package providers

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/charmbracelet/crush/internal/ai"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/stretchr/testify/require"
)

func TestToOpenAIPrompt_SystemMessages(t *testing.T) {
	t.Parallel()

	t.Run("should forward system messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleSystem,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "You are a helpful assistant."},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		systemMsg := messages[0].OfSystem
		require.NotNil(t, systemMsg)
		require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
	})

	t.Run("should handle empty system messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role:    ai.MessageRoleSystem,
				Content: []ai.MessagePart{},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Len(t, warnings, 1)
		require.Contains(t, warnings[0].Message, "system prompt has no text parts")
		require.Empty(t, messages)
	})

	t.Run("should join multiple system text parts", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleSystem,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "You are a helpful assistant."},
					ai.TextPart{Text: "Be concise."},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		systemMsg := messages[0].OfSystem
		require.NotNil(t, systemMsg)
		require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
	})
}

func TestToOpenAIPrompt_UserMessages(t *testing.T) {
	t.Parallel()

	t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello"},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)
		require.Equal(t, "Hello", userMsg.Content.OfString.Value)
	})

	t.Run("should convert messages with image parts", func(t *testing.T) {
		t.Parallel()

		imageData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello"},
					ai.FilePart{
						MediaType: "image/png",
						Data:      imageData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 2)

		// Check text part
		textPart := content[0].OfText
		require.NotNil(t, textPart)
		require.Equal(t, "Hello", textPart.Text)

		// Check image part
		imagePart := content[1].OfImageURL
		require.NotNil(t, imagePart)
		expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
		require.Equal(t, expectedURL, imagePart.ImageURL.URL)
	})

	t.Run("should add image detail when specified through provider options", func(t *testing.T) {
		t.Parallel()

		imageData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "image/png",
						Data:      imageData,
						ProviderOptions: ai.ProviderOptions{
							"openai": map[string]any{
								"imageDetail": "low",
							},
						},
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		imagePart := content[0].OfImageURL
		require.NotNil(t, imagePart)
		require.Equal(t, "low", imagePart.ImageURL.Detail)
	})
}

func TestToOpenAIPrompt_FileParts(t *testing.T) {
	t.Parallel()

	t.Run("should warn for unsupported mime types", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/something",
						Data:      []byte("test"),
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Len(t, warnings, 1)
		require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
		require.Len(t, messages, 1) // Message is still created but with empty content array
	})

	t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/wav",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
		require.Equal(t, "wav", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/mpeg",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/mp3",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should convert messages with PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "document.pdf", filePart.File.Filename.Value)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      []byte("file-pdf-12345"),
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
		require.True(t, param.IsOmitted(filePart.File.FileData))
		require.True(t, param.IsOmitted(filePart.File.Filename))
	})

	t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
	})
}

func TestToOpenAIPrompt_ToolCalls(t *testing.T) {
	t.Parallel()

	t.Run("should stringify arguments to tool calls", func(t *testing.T) {
		t.Parallel()

		inputArgs := map[string]any{"foo": "bar123"}
		inputJSON, _ := json.Marshal(inputArgs)

		outputResult := map[string]any{"oof": "321rab"}
		outputJSON, _ := json.Marshal(outputResult)

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.ToolCallPart{
						ToolCallID: "quux",
						ToolName:   "thwomp",
						Input:      string(inputJSON),
					},
				},
			},
			{
				Role: ai.MessageRoleTool,
				Content: []ai.MessagePart{
					ai.ToolResultPart{
						ToolCallID: "quux",
						Output: ai.ToolResultOutputContentText{
							Text: string(outputJSON),
						},
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 2)

		// Check assistant message with tool call
		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "", assistantMsg.Content.OfString.Value)
		require.Len(t, assistantMsg.ToolCalls, 1)

		toolCall := assistantMsg.ToolCalls[0].OfFunction
		require.NotNil(t, toolCall)
		require.Equal(t, "quux", toolCall.ID)
		require.Equal(t, "thwomp", toolCall.Function.Name)
		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)

		// Check tool message
		toolMsg := messages[1].OfTool
		require.NotNil(t, toolMsg)
		require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
		require.Equal(t, "quux", toolMsg.ToolCallID)
	})

	t.Run("should handle different tool output types", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleTool,
				Content: []ai.MessagePart{
					ai.ToolResultPart{
						ToolCallID: "text-tool",
						Output: ai.ToolResultOutputContentText{
							Text: "Hello world",
						},
					},
					ai.ToolResultPart{
						ToolCallID: "error-tool",
						Output: ai.ToolResultOutputContentError{
							Error: errors.New("Something went wrong"),
						},
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 2)

		// Check first tool message (text)
		textToolMsg := messages[0].OfTool
		require.NotNil(t, textToolMsg)
		require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
		require.Equal(t, "text-tool", textToolMsg.ToolCallID)

		// Check second tool message (error)
		errorToolMsg := messages[1].OfTool
		require.NotNil(t, errorToolMsg)
		require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
		require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
	})
}

func TestToOpenAIPrompt_AssistantMessages(t *testing.T) {
	t.Parallel()

	t.Run("should handle simple text assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello, how can I help you?"},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
	})

	t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
		t.Parallel()

		inputArgs := map[string]any{"query": "test"}
		inputJSON, _ := json.Marshal(inputArgs)

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Let me search for that."},
					ai.ToolCallPart{
						ToolCallID: "call-123",
						ToolName:   "search",
						Input:      string(inputJSON),
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
		require.Len(t, assistantMsg.ToolCalls, 1)

		toolCall := assistantMsg.ToolCalls[0].OfFunction
		require.Equal(t, "call-123", toolCall.ID)
		require.Equal(t, "search", toolCall.Function.Name)
		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
	})
}

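// testPrompt is a minimal single-user-message prompt shared by the
// Generate tests below.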
var testPrompt = ai.Prompt{
	{
		Role: ai.MessageRoleUser,
		Content: []ai.MessagePart{
			ai.TextPart{Text: "Hello"},
		},
	},
}

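// testLogprobs is a canned Chat Completions logprobs payload for the
// completion "Hello! How can I assist you today?", one entry per token.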
var testLogprobs = map[string]any{
	"content": []map[string]any{
		{
			"token":   "Hello",
			"logprob": -0.0009994634,
			"top_logprobs": []map[string]any{
				{
					"token":   "Hello",
					"logprob": -0.0009994634,
				},
			},
		},
		{
			"token":   "!",
			"logprob": -0.13410144,
			"top_logprobs": []map[string]any{
				{
					"token":   "!",
					"logprob": -0.13410144,
				},
			},
		},
		{
			"token":   " How",
			"logprob": -0.0009250381,
			"top_logprobs": []map[string]any{
				{
					"token":   " How",
					"logprob": -0.0009250381,
				},
			},
		},
		{
			"token":   " can",
			"logprob": -0.047709424,
			"top_logprobs": []map[string]any{
				{
					"token":   " can",
					"logprob": -0.047709424,
				},
			},
		},
		{
			"token":   " I",
			"logprob": -0.000009014684,
			"top_logprobs": []map[string]any{
				{
					"token":   " I",
					"logprob": -0.000009014684,
				},
			},
		},
		{
			"token":   " assist",
			"logprob": -0.009125131,
			"top_logprobs": []map[string]any{
				{
					"token":   " assist",
					"logprob": -0.009125131,
				},
			},
		},
		{
			"token":   " you",
			"logprob": -0.0000066306106,
			"top_logprobs": []map[string]any{
				{
					"token":   " you",
					"logprob": -0.0000066306106,
				},
			},
		},
		{
			"token":   " today",
			"logprob": -0.00011093382,
			"top_logprobs": []map[string]any{
				{
					"token":   " today",
					"logprob": -0.00011093382,
				},
			},
		},
		{
			"token":   "?",
			"logprob": -0.00004596782,
			"top_logprobs": []map[string]any{
				{
					"token":   "?",
					"logprob": -0.00004596782,
				},
			},
		},
	},
}

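// mockServer wraps an httptest.Server that records every request it
// receives and replies with a canned chat.completion JSON body.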
type mockServer struct {
	server   *httptest.Server
	response map[string]any
	calls    []mockCall
}

type mockCall struct {
	method  string
	path    string
	headers map[string]string
	body    map[string]any
}

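// newMockServer starts a mock Chat Completions endpoint. Callers must
// close it and should seed a response via prepareJSONResponse first.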
func newMockServer() *mockServer {
	ms := &mockServer{
		calls: make([]mockCall, 0),
	}

	ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record the call
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}

		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Parse request body
		if r.Body != nil {
			var body map[string]any
			json.NewDecoder(r.Body).Decode(&body)
			call.body = body
		}

		ms.calls = append(ms.calls, call)

		// Return mock response
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(ms.response)
	}))

	return ms
}

func (ms *mockServer) close() {
	ms.server.Close()
}

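// prepareJSONResponse builds a default chat.completion response body and
// applies the given overrides (content, tool_calls, usage, and so on).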
func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
	// Default values
	response := map[string]any{
		"id":      "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
		"object":  "chat.completion",
		"created": 1711115037,
		"model":   "gpt-3.5-turbo-0125",
		"choices": []map[string]any{
			{
				"index": 0,
				"message": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": "stop",
			},
		},
		"usage": map[string]any{
			"prompt_tokens":     4,
			"total_tokens":      34,
			"completion_tokens": 30,
		},
		"system_fingerprint": "fp_3bc1b5746c",
	}

	// Override with provided options
	for k, v := range opts {
		switch k {
		case "content":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["content"] = v
		case "tool_calls":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["tool_calls"] = v
		case "function_call":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["function_call"] = v
		case "annotations":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["annotations"] = v
		case "usage":
			response["usage"] = v
		case "finish_reason":
			response["choices"].([]map[string]any)[0]["finish_reason"] = v
		case "id":
			response["id"] = v
		case "created":
			response["created"] = v
		case "model":
			response["model"] = v
		case "logprobs":
			if v != nil {
				response["choices"].([]map[string]any)[0]["logprobs"] = v
			}
		}
	}

	ms.response = response
}

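// TestDoGenerate exercises Generate against the mock server: response
// parsing, usage accounting, warnings, and provider-specific options.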
func TestDoGenerate(t *testing.T) {
	t.Parallel()

	t.Run("should extract text response", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "Hello, World!",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 1)

		textContent, ok := result.Content[0].(ai.TextContent)
		require.True(t, ok)
		require.Equal(t, "Hello, World!", textContent.Text)
	})

	t.Run("should extract usage", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     20,
				"total_tokens":      25,
				"completion_tokens": 5,
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(20), result.Usage.InputTokens)
		require.Equal(t, int64(5), result.Usage.OutputTokens)
		require.Equal(t, int64(25), result.Usage.TotalTokens)
	})

	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages, ok := call.body["messages"].([]any)
		require.True(t, ok)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should support partial usage", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens": 20,
				"total_tokens":  20,
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(20), result.Usage.InputTokens)
		require.Equal(t, int64(0), result.Usage.OutputTokens)
		require.Equal(t, int64(20), result.Usage.TotalTokens)
	})

	t.Run("should extract logprobs", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"logprobs": testLogprobs,
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"logProbs": true,
				},
			},
		})

		require.NoError(t, err)
		require.NotNil(t, result.ProviderMetadata)

		openaiMeta, ok := result.ProviderMetadata["openai"]
		require.True(t, ok)

		logprobs, ok := openaiMeta["logprobs"]
		require.True(t, ok)
		require.NotNil(t, logprobs)
	})

	t.Run("should extract finish reason", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"finish_reason": "stop",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, ai.FinishReasonStop, result.FinishReason)
	})

	t.Run("should support unknown finish reason", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"finish_reason": "eos",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, ai.FinishReasonUnknown, result.FinishReason)
	})

	t.Run("should pass the model and the messages", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass settings", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"logitBias": map[string]int64{
						"50256": -100,
					},
					"parallelToolCalls": false,
					"user":              "test-user-id",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		logitBias := call.body["logit_bias"].(map[string]any)
		require.Equal(t, float64(-100), logitBias["50256"])
		require.Equal(t, false, call.body["parallel_tool_calls"])
		require.Equal(t, "test-user-id", call.body["user"])
	})

	t.Run("should pass reasoningEffort setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"reasoningEffort": "low",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-mini", call.body["model"])
		require.Equal(t, "low", call.body["reasoning_effort"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass textVerbosity setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"textVerbosity": "low",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o", call.body["model"])
		require.Equal(t, "low", call.body["verbosity"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass tools and toolChoice", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			Tools: []ai.Tool{
				ai.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
			ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		tools := call.body["tools"].([]any)
		require.Len(t, tools, 1)

		tool := tools[0].(map[string]any)
		require.Equal(t, "function", tool["type"])

		function := tool["function"].(map[string]any)
		require.Equal(t, "test-tool", function["name"])
		require.Equal(t, false, function["strict"])

		toolChoice := call.body["tool_choice"].(map[string]any)
		require.Equal(t, "function", toolChoice["type"])

		toolChoiceFunction := toolChoice["function"].(map[string]any)
		require.Equal(t, "test-tool", toolChoiceFunction["name"])
	})

	t.Run("should parse tool results", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"tool_calls": []map[string]any{
				{
					"id":   "call_O17Uplv4lJvD6DVdIvFFeRMw",
					"type": "function",
					"function": map[string]any{
						"name":      "test-tool",
						"arguments": `{"value":"Spark"}`,
					},
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			Tools: []ai.Tool{
				ai.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
			ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 1)

		toolCall, ok := result.Content[0].(ai.ToolCallContent)
		require.True(t, ok)
		require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
		require.Equal(t, "test-tool", toolCall.ToolName)
		require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
	})

	t.Run("should parse annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "Based on the search results [doc1], I found information.",
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index":   29,
						"url":         "https://example.com/doc1.pdf",
						"title":       "Document 1",
					},
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 2)

		textContent, ok := result.Content[0].(ai.TextContent)
		require.True(t, ok)
		require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)

		sourceContent, ok := result.Content[1].(ai.SourceContent)
		require.True(t, ok)
		require.Equal(t, ai.SourceTypeURL, sourceContent.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
		require.Equal(t, "Document 1", sourceContent.Title)
		require.NotEmpty(t, sourceContent.ID)
	})

	t.Run("should return cached_tokens in prompt_tokens_details", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
		require.Equal(t, int64(15), result.Usage.InputTokens)
		require.Equal(t, int64(20), result.Usage.OutputTokens)
		require.Equal(t, int64(35), result.Usage.TotalTokens)
	})

	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.NotNil(t, result.ProviderMetadata)

		openaiMeta, ok := result.ProviderMetadata["openai"]
		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
		require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
	})

	t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt:           testPrompt,
			Temperature:      &[]float64{0.5}[0],
			TopP:             &[]float64{0.7}[0],
			FrequencyPenalty: &[]float64{0.2}[0],
			PresencePenalty:  &[]float64{0.3}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])

		// These should not be present
		require.Nil(t, call.body["temperature"])
		require.Nil(t, call.body["top_p"])
		require.Nil(t, call.body["frequency_penalty"])
		require.Nil(t, call.body["presence_penalty"])

		// Should have warnings
		require.Len(t, result.Warnings, 4)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "temperature", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
	})

	t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt:          testPrompt,
			MaxOutputTokens: &[]int64{1000}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])
		require.Equal(t, float64(1000), call.body["max_completion_tokens"])
		require.Nil(t, call.body["max_tokens"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should return reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(15), result.Usage.InputTokens)
		require.Equal(t, int64(20), result.Usage.OutputTokens)
		require.Equal(t, int64(35), result.Usage.TotalTokens)
		require.Equal(t, int64(10), result.Usage.ReasoningTokens)
	})

	t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"model": "o1-preview",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"maxCompletionTokens": 255,
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])
		require.Equal(t, float64(255), call.body["max_completion_tokens"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send prediction extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"prediction": map[string]any{
						"type":    "content",
						"content": "Hello, World!",
					},
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		prediction := call.body["prediction"].(map[string]any)
		require.Equal(t, "content", prediction["type"])
		require.Equal(t, "Hello, World!", prediction["content"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"store": true,
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["store"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"metadata": map[string]any{
						"custom": "value",
					},
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send promptCacheKey extension value", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"promptCacheKey": "test-cache-key-123",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send safetyIdentifier extension value", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"safetyIdentifier": "test-safety-identifier-123",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-search-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt:      testPrompt,
			Temperature: &[]float64{0.7}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-search-preview", call.body["model"])
		require.Nil(t, call.body["temperature"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "temperature", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "search preview models")
	})

	t.Run("should send serviceTier flex processing setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o3-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "flex",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "flex",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Nil(t, call.body["service_tier"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
	})

	t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "priority",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "priority",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Nil(t, call.body["service_tier"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
	})
}

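// streamingMockServer is the server-sent-events counterpart of
// mockServer: it records calls and replays a scripted chunk sequence.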
type streamingMockServer struct {
	server *httptest.Server
	chunks []string
	calls  []mockCall
}

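// newStreamingMockServer starts a mock endpoint that streams the prepared
// chunks as server-sent events. Chunks prefixed with "HEADER:key:value"
// are emitted as response headers instead of body data.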
func newStreamingMockServer() *streamingMockServer {
	sms := &streamingMockServer{
		calls: make([]mockCall, 0),
	}

	sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record the call
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}

		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Parse request body
		if r.Body != nil {
			var body map[string]any
			json.NewDecoder(r.Body).Decode(&body)
			call.body = body
		}

		sms.calls = append(sms.calls, call)

		// Set streaming headers
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")

		// Add custom headers if any
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				parts := strings.SplitN(chunk[7:], ":", 2)
				if len(parts) == 2 {
					w.Header().Set(parts[0], parts[1])
				}
				continue
			}
		}

		w.WriteHeader(http.StatusOK)

		// Write chunks
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				continue
			}
			w.Write([]byte(chunk))
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
		}
	}))

	return sms
}

func (sms *streamingMockServer) close() {
	sms.server.Close()
}

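// prepareStreamResponse scripts a full streaming exchange: an initial
// role chunk, one chunk per content string, an optional annotations
// chunk, a finish chunk, a usage chunk, and the final [DONE] marker.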
func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
	content := []string{}
	if c, ok := opts["content"].([]string); ok {
		content = c
	}

	usage := map[string]any{
		"prompt_tokens":     17,
		"total_tokens":      244,
		"completion_tokens": 227,
	}
	if u, ok := opts["usage"].(map[string]any); ok {
		usage = u
	}

	logprobs := map[string]any{}
	if l, ok := opts["logprobs"].(map[string]any); ok {
		logprobs = l
	}

	finishReason := "stop"
	if fr, ok := opts["finish_reason"].(string); ok {
		finishReason = fr
	}

	model := "gpt-3.5-turbo-0613"
	if m, ok := opts["model"].(string); ok {
		model = m
	}

	headers := map[string]string{}
	if h, ok := opts["headers"].(map[string]string); ok {
		headers = h
	}

	chunks := []string{}

	// Add custom headers
	for k, v := range headers {
		chunks = append(chunks, "HEADER:"+k+":"+v)
	}

	// Initial chunk with role
	initialChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index": 0,
				"delta": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": nil,
			},
		},
	}
	initialData, _ := json.Marshal(initialChunk)
	chunks = append(chunks, "data: "+string(initialData)+"\n\n")

	// Content chunks
	for i, text := range content {
		contentChunk := map[string]any{
			"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
			"object":             "chat.completion.chunk",
			"created":            1702657020,
			"model":              model,
			"system_fingerprint": nil,
			"choices": []map[string]any{
				{
					"index": 1,
					"delta": map[string]any{
						"content": text,
					},
					"finish_reason": nil,
				},
			},
		}
		contentData, _ := json.Marshal(contentChunk)
		chunks = append(chunks, "data: "+string(contentData)+"\n\n")

		// Add annotations if this is the last content chunk and we have annotations
		if i == len(content)-1 {
			if annotations, ok := opts["annotations"].([]map[string]any); ok {
				annotationChunk := map[string]any{
					"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
					"object":             "chat.completion.chunk",
					"created":            1702657020,
					"model":              model,
					"system_fingerprint": nil,
					"choices": []map[string]any{
						{
							"index": 1,
							"delta": map[string]any{
								"annotations": annotations,
							},
							"finish_reason": nil,
						},
					},
				}
				annotationData, _ := json.Marshal(annotationChunk)
				chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
			}
		}
	}

	// Finish chunk
	finishChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index":         0,
				"delta":         map[string]any{},
				"finish_reason": finishReason,
			},
		},
	}

	if len(logprobs) > 0 {
		finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
	}

	finishData, _ := json.Marshal(finishChunk)
	chunks = append(chunks, "data: "+string(finishData)+"\n\n")

	// Usage chunk
	usageChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": "fp_3bc1b5746c",
		"choices":            []map[string]any{},
		"usage":              usage,
	}
	usageData, _ := json.Marshal(usageChunk)
	chunks = append(chunks, "data: "+string(usageData)+"\n\n")

	// Done
	chunks = append(chunks, "data: [DONE]\n\n")

	sms.chunks = chunks
}

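// prepareToolStreamResponse stores a fixed SSE sequence in which a single
// tool call's arguments arrive fragmented across several delta chunks.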
func (sms *streamingMockServer) prepareToolStreamResponse() {
	chunks := []string{
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}

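// prepareErrorStreamResponse stores an SSE sequence whose only event is an
// API error payload.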
func (sms *streamingMockServer) prepareErrorStreamResponse() {
	chunks := []string{
		`data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}

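// collectStreamParts drains the stream, stopping after the first error or
// finish part.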
func collectStreamParts(stream ai.StreamResponse) ([]ai.StreamPart, error) {
	var parts []ai.StreamPart
	for part := range stream {
		parts = append(parts, part)
		if part.Type == ai.StreamPartTypeError || part.Type == ai.StreamPartTypeFinish {
			break
		}
	}
	return parts, nil
}

func TestDoStream(t *testing.T) {
	t.Parallel()

	t.Run("should stream text deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content":       []string{"Hello", ", ", "World!"},
			"finish_reason": "stop",
			"usage": map[string]any{
				"prompt_tokens":     17,
				"total_tokens":      244,
				"completion_tokens": 227,
			},
			"logprobs": testLogprobs,
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Verify stream structure
		require.GreaterOrEqual(t, len(parts), 4) // text-start, deltas, text-end, finish

		// Find text parts
		textStart, textEnd, finish := -1, -1, -1
		var deltas []string

		for i, part := range parts {
			switch part.Type {
			case ai.StreamPartTypeTextStart:
				textStart = i
			case ai.StreamPartTypeTextDelta:
				deltas = append(deltas, part.Delta)
			case ai.StreamPartTypeTextEnd:
				textEnd = i
			case ai.StreamPartTypeFinish:
				finish = i
			}
		}

		require.NotEqual(t, -1, textStart)
		require.NotEqual(t, -1, textEnd)
		require.NotEqual(t, -1, finish)
		require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)

		// Check finish part
		finishPart := parts[finish]
		require.Equal(t, ai.FinishReasonStop, finishPart.FinishReason)
		require.Equal(t, int64(17), finishPart.Usage.InputTokens)
		require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
	})

	t.Run("should stream tool deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

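		// The mock emits the tool-call arguments split across several chunks;
		// the provider is expected to reassemble them into one tool call.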
		server.prepareToolStreamResponse()

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			Tools: []ai.Tool{
				ai.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find tool-related parts
		toolInputStart, toolInputEnd, toolCall := -1, -1, -1
		var toolDeltas []string

		for i, part := range parts {
			switch part.Type {
			case ai.StreamPartTypeToolInputStart:
				toolInputStart = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
			case ai.StreamPartTypeToolInputDelta:
				toolDeltas = append(toolDeltas, part.Delta)
			case ai.StreamPartTypeToolInputEnd:
				toolInputEnd = i
			case ai.StreamPartTypeToolCall:
				toolCall = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
				require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
			}
		}

		require.NotEqual(t, -1, toolInputStart)
		require.NotEqual(t, -1, toolInputEnd)
		require.NotEqual(t, -1, toolCall)

		// Verify the tool deltas combine to form the complete input
		fullInput := strings.Join(toolDeltas, "")
		require.Equal(t, `{"value":"Sparkle Day"}`, fullInput)
	})

	t.Run("should stream annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

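		// url_citation annotations in the stream should surface as source parts.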
		server.prepareStreamResponse(map[string]any{
			"content": []string{"Based on search results"},
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index":   29,
						"url":         "https://example.com/doc1.pdf",
						"title":       "Document 1",
					},
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find source part
		var sourcePart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeSource {
				sourcePart = &part
				break
			}
		}

		require.NotNil(t, sourcePart)
		require.Equal(t, ai.SourceTypeURL, sourcePart.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
		require.Equal(t, "Document 1", sourcePart.Title)
		require.NotEmpty(t, sourcePart.ID)
	})

	t.Run("should handle error stream parts", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareErrorStreamResponse()

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Should have at least an error part (collectStreamParts stops there)
		require.NotEmpty(t, parts)

		// Find error part
		var errorPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeError {
				errorPart = &part
				break
			}
		}

		require.NotNil(t, errorPart)
		require.NotNil(t, errorPart.Error)
	})

	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

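		// prompt_tokens_details.cached_tokens should map to Usage.CacheReadTokens.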
		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
	})

	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

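		// completion_tokens_details prediction counts should surface in the
		// "openai" provider metadata on the finish part.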
		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.NotNil(t, finishPart.ProviderMetadata)

		openaiMeta, ok := finishPart.ProviderMetadata["openai"]
		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
		require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"store": true,
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])
		require.Equal(t, true, call.body["store"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"metadata": map[string]any{
						"custom": "value",
					},
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o3-mini")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "flex",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-mini")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "priority",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should stream text delta for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model":   "o1-preview",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find text parts
		var textDeltas []string
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeTextDelta {
				textDeltas = append(textDeltas, part.Delta)
			}
		}

		// Should contain the text content (without the empty initial delta)
		require.Equal(t, []string{"Hello, World!"}, textDeltas)
	})

	t.Run("should send reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

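		// completion_tokens_details.reasoning_tokens should map to
		// Usage.ReasoningTokens.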
		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model":   "o1-preview",
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
		require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
	})
}