package openai

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/charmbracelet/ai/ai"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/stretchr/testify/require"
)

func TestToOpenAiPrompt_SystemMessages(t *testing.T) {
	t.Parallel()

	t.Run("should forward system messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleSystem,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "You are a helpful assistant."},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		systemMsg := messages[0].OfSystem
		require.NotNil(t, systemMsg)
		require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
	})

	t.Run("should handle empty system messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role:    ai.MessageRoleSystem,
				Content: []ai.MessagePart{},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Len(t, warnings, 1)
		require.Contains(t, warnings[0].Message, "system prompt has no text parts")
		require.Empty(t, messages)
	})

	t.Run("should join multiple system text parts", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleSystem,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "You are a helpful assistant."},
					ai.TextPart{Text: "Be concise."},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		systemMsg := messages[0].OfSystem
		require.NotNil(t, systemMsg)
		require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
	})
}

func TestToOpenAiPrompt_UserMessages(t *testing.T) {
	t.Parallel()

	t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello"},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)
		require.Equal(t, "Hello", userMsg.Content.OfString.Value)
	})

	t.Run("should convert messages with image parts", func(t *testing.T) {
		t.Parallel()

		imageData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello"},
					ai.FilePart{
						MediaType: "image/png",
						Data:      imageData,
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 2)

		// Check text part
		textPart := content[0].OfText
		require.NotNil(t, textPart)
		require.Equal(t, "Hello", textPart.Text)

		// Check image part
		imagePart := content[1].OfImageURL
		require.NotNil(t, imagePart)
		expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
		require.Equal(t, expectedURL, imagePart.ImageURL.URL)
	})

	t.Run("should add image detail when specified through provider options", func(t *testing.T) {
		t.Parallel()

		imageData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "image/png",
						Data:      imageData,
						ProviderOptions: NewProviderFileOptions(&ProviderFileOptions{
							ImageDetail: "low",
						}),
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		imagePart := content[0].OfImageURL
		require.NotNil(t, imagePart)
		require.Equal(t, "low", imagePart.ImageURL.Detail)
	})
}

func TestToOpenAiPrompt_FileParts(t *testing.T) {
	t.Parallel()

	t.Run("should warn for unsupported mime types", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/something",
						Data:      []byte("test"),
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Len(t, warnings, 1)
		require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
		require.Len(t, messages, 1) // Message is still created but with empty content array
	})

	t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/wav",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
		require.Equal(t, "wav", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/mpeg",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/mp3",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should convert messages with PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "document.pdf", filePart.File.Filename.Value)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

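	// Judging by the assertions below, PDF data that looks like an OpenAI file
	// ID ("file-...") is forwarded as a file_id reference instead of being
	// inlined as base64 data.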
	t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      []byte("file-pdf-12345"),
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
		require.True(t, param.IsOmitted(filePart.File.FileData))
		require.True(t, param.IsOmitted(filePart.File.Filename))
	})

	t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
	})
}

func TestToOpenAiPrompt_ToolCalls(t *testing.T) {
	t.Parallel()

	t.Run("should stringify arguments to tool calls", func(t *testing.T) {
		t.Parallel()

		inputArgs := map[string]any{"foo": "bar123"}
		inputJSON, _ := json.Marshal(inputArgs)

		outputResult := map[string]any{"oof": "321rab"}
		outputJSON, _ := json.Marshal(outputResult)

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.ToolCallPart{
						ToolCallID: "quux",
						ToolName:   "thwomp",
						Input:      string(inputJSON),
					},
				},
			},
			{
				Role: ai.MessageRoleTool,
				Content: []ai.MessagePart{
					ai.ToolResultPart{
						ToolCallID: "quux",
						Output: ai.ToolResultOutputContentText{
							Text: string(outputJSON),
						},
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 2)

		// Check assistant message with tool call
		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "", assistantMsg.Content.OfString.Value)
		require.Len(t, assistantMsg.ToolCalls, 1)

		toolCall := assistantMsg.ToolCalls[0].OfFunction
		require.NotNil(t, toolCall)
		require.Equal(t, "quux", toolCall.ID)
		require.Equal(t, "thwomp", toolCall.Function.Name)
		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)

		// Check tool message
		toolMsg := messages[1].OfTool
		require.NotNil(t, toolMsg)
		require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
		require.Equal(t, "quux", toolMsg.ToolCallID)
	})

	t.Run("should handle different tool output types", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleTool,
				Content: []ai.MessagePart{
					ai.ToolResultPart{
						ToolCallID: "text-tool",
						Output: ai.ToolResultOutputContentText{
							Text: "Hello world",
						},
					},
					ai.ToolResultPart{
						ToolCallID: "error-tool",
						Output: ai.ToolResultOutputContentError{
							Error: errors.New("Something went wrong"),
						},
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 2)

		// Check first tool message (text)
		textToolMsg := messages[0].OfTool
		require.NotNil(t, textToolMsg)
		require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
		require.Equal(t, "text-tool", textToolMsg.ToolCallID)

		// Check second tool message (error)
		errorToolMsg := messages[1].OfTool
		require.NotNil(t, errorToolMsg)
		require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
		require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
	})
}

func TestToOpenAiPrompt_AssistantMessages(t *testing.T) {
	t.Parallel()

	t.Run("should handle simple text assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello, how can I help you?"},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
	})

	t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
		t.Parallel()

		inputArgs := map[string]any{"query": "test"}
		inputJSON, _ := json.Marshal(inputArgs)

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Let me search for that."},
					ai.ToolCallPart{
						ToolCallID: "call-123",
						ToolName:   "search",
						Input:      string(inputJSON),
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
		require.Len(t, assistantMsg.ToolCalls, 1)

		toolCall := assistantMsg.ToolCalls[0].OfFunction
		require.Equal(t, "call-123", toolCall.ID)
		require.Equal(t, "search", toolCall.Function.Name)
		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
	})
}

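// testPrompt is a minimal single-turn user prompt shared by the Generate tests below.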
var testPrompt = ai.Prompt{
	{
		Role: ai.MessageRoleUser,
		Content: []ai.MessagePart{
			ai.TextPart{Text: "Hello"},
		},
	},
}

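// testLogprobs mimics the logprobs block the Chat Completions API returns for
// the completion "Hello! How can I assist you today?".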
var testLogprobs = map[string]any{
	"content": []map[string]any{
		{
			"token":   "Hello",
			"logprob": -0.0009994634,
			"top_logprobs": []map[string]any{
				{
					"token":   "Hello",
					"logprob": -0.0009994634,
				},
			},
		},
		{
			"token":   "!",
			"logprob": -0.13410144,
			"top_logprobs": []map[string]any{
				{
					"token":   "!",
					"logprob": -0.13410144,
				},
			},
		},
		{
			"token":   " How",
			"logprob": -0.0009250381,
			"top_logprobs": []map[string]any{
				{
					"token":   " How",
					"logprob": -0.0009250381,
				},
			},
		},
		{
			"token":   " can",
			"logprob": -0.047709424,
			"top_logprobs": []map[string]any{
				{
					"token":   " can",
					"logprob": -0.047709424,
				},
			},
		},
		{
			"token":   " I",
			"logprob": -0.000009014684,
			"top_logprobs": []map[string]any{
				{
					"token":   " I",
					"logprob": -0.000009014684,
				},
			},
		},
		{
			"token":   " assist",
			"logprob": -0.009125131,
			"top_logprobs": []map[string]any{
				{
					"token":   " assist",
					"logprob": -0.009125131,
				},
			},
		},
		{
			"token":   " you",
			"logprob": -0.0000066306106,
			"top_logprobs": []map[string]any{
				{
					"token":   " you",
					"logprob": -0.0000066306106,
				},
			},
		},
		{
			"token":   " today",
			"logprob": -0.00011093382,
			"top_logprobs": []map[string]any{
				{
					"token":   " today",
					"logprob": -0.00011093382,
				},
			},
		},
		{
			"token":   "?",
			"logprob": -0.00004596782,
			"top_logprobs": []map[string]any{
				{
					"token":   "?",
					"logprob": -0.00004596782,
				},
			},
		},
	},
}

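// mockServer wraps an httptest.Server that records every incoming request and
// answers with a canned JSON chat-completion body.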
type mockServer struct {
	server   *httptest.Server
	response map[string]any
	calls    []mockCall
}

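// mockCall captures one recorded request: method, path, the first value of
// each header, and the decoded JSON body.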
type mockCall struct {
	method  string
	path    string
	headers map[string]string
	body    map[string]any
}

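// newMockServer starts the recording test server; callers are responsible for
// calling close when they are done with it.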
func newMockServer() *mockServer {
	ms := &mockServer{
		calls: make([]mockCall, 0),
	}

	ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record the call
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}

		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Parse request body
		if r.Body != nil {
			var body map[string]any
			json.NewDecoder(r.Body).Decode(&body)
			call.body = body
		}

		ms.calls = append(ms.calls, call)

		// Return mock response
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(ms.response)
	}))

	return ms
}

func (ms *mockServer) close() {
	ms.server.Close()
}

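// prepareJSONResponse builds a default chat.completion payload and applies the
// given overrides before storing it as the server's response.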
func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
	// Default values
	response := map[string]any{
		"id":      "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
		"object":  "chat.completion",
		"created": 1711115037,
		"model":   "gpt-3.5-turbo-0125",
		"choices": []map[string]any{
			{
				"index": 0,
				"message": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": "stop",
			},
		},
		"usage": map[string]any{
			"prompt_tokens":     4,
			"total_tokens":      34,
			"completion_tokens": 30,
		},
		"system_fingerprint": "fp_3bc1b5746c",
	}

	// Override with provided options
	for k, v := range opts {
		switch k {
		case "content":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["content"] = v
		case "tool_calls":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["tool_calls"] = v
		case "function_call":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["function_call"] = v
		case "annotations":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["annotations"] = v
		case "usage":
			response["usage"] = v
		case "finish_reason":
			response["choices"].([]map[string]any)[0]["finish_reason"] = v
		case "id":
			response["id"] = v
		case "created":
			response["created"] = v
		case "model":
			response["model"] = v
		case "logprobs":
			if v != nil {
				response["choices"].([]map[string]any)[0]["logprobs"] = v
			}
		}
	}

	ms.response = response
}

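// TestDoGenerate exercises Generate against the mock server: response parsing,
// request shaping, and provider-specific option handling.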
func TestDoGenerate(t *testing.T) {
	t.Parallel()

	t.Run("should extract text response", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "Hello, World!",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 1)

		textContent, ok := result.Content[0].(ai.TextContent)
		require.True(t, ok)
		require.Equal(t, "Hello, World!", textContent.Text)
	})

	t.Run("should extract usage", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     20,
				"total_tokens":      25,
				"completion_tokens": 5,
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(20), result.Usage.InputTokens)
		require.Equal(t, int64(5), result.Usage.OutputTokens)
		require.Equal(t, int64(25), result.Usage.TotalTokens)
	})

	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages, ok := call.body["messages"].([]any)
		require.True(t, ok)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should support partial usage", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens": 20,
				"total_tokens":  20,
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(20), result.Usage.InputTokens)
		require.Equal(t, int64(0), result.Usage.OutputTokens)
		require.Equal(t, int64(20), result.Usage.TotalTokens)
	})

	t.Run("should extract logprobs", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"logprobs": testLogprobs,
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				LogProbs: ai.BoolOption(true),
			}),
		})

		require.NoError(t, err)
		require.NotNil(t, result.ProviderMetadata)

		openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
		require.True(t, ok)

		logprobs := openaiMeta.Logprobs
		require.NotNil(t, logprobs)
	})

	t.Run("should extract finish reason", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"finish_reason": "stop",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, ai.FinishReasonStop, result.FinishReason)
	})

	t.Run("should support unknown finish reason", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"finish_reason": "eos",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, ai.FinishReasonUnknown, result.FinishReason)
	})

	t.Run("should pass the model and the messages", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass settings", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				LogitBias: map[string]int64{
					"50256": -100,
				},
				ParallelToolCalls: ai.BoolOption(false),
				User:              ai.StringOption("test-user-id"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		logitBias := call.body["logit_bias"].(map[string]any)
		require.Equal(t, float64(-100), logitBias["50256"])
		require.Equal(t, false, call.body["parallel_tool_calls"])
		require.Equal(t, "test-user-id", call.body["user"])
	})

	t.Run("should pass reasoningEffort setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(
				&ProviderOptions{
					ReasoningEffort: ReasoningEffortOption(ReasoningEffortLow),
				},
			),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-mini", call.body["model"])
		require.Equal(t, "low", call.body["reasoning_effort"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass textVerbosity setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				TextVerbosity: ai.StringOption("low"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o", call.body["model"])
		require.Equal(t, "low", call.body["verbosity"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass tools and toolChoice", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			Tools: []ai.Tool{
				ai.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
			ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		tools := call.body["tools"].([]any)
		require.Len(t, tools, 1)

		tool := tools[0].(map[string]any)
		require.Equal(t, "function", tool["type"])

		function := tool["function"].(map[string]any)
		require.Equal(t, "test-tool", function["name"])
		require.Equal(t, false, function["strict"])

		toolChoice := call.body["tool_choice"].(map[string]any)
		require.Equal(t, "function", toolChoice["type"])

		toolChoiceFunction := toolChoice["function"].(map[string]any)
		require.Equal(t, "test-tool", toolChoiceFunction["name"])
	})

	t.Run("should parse tool results", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"tool_calls": []map[string]any{
				{
					"id":   "call_O17Uplv4lJvD6DVdIvFFeRMw",
					"type": "function",
					"function": map[string]any{
						"name":      "test-tool",
						"arguments": `{"value":"Spark"}`,
					},
				},
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			Tools: []ai.Tool{
				ai.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
			ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 1)

		toolCall, ok := result.Content[0].(ai.ToolCallContent)
		require.True(t, ok)
		require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
		require.Equal(t, "test-tool", toolCall.ToolName)
		require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
	})

	t.Run("should parse annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "Based on the search results [doc1], I found information.",
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index":   29,
						"url":         "https://example.com/doc1.pdf",
						"title":       "Document 1",
					},
				},
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 2)

		textContent, ok := result.Content[0].(ai.TextContent)
		require.True(t, ok)
		require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)

		sourceContent, ok := result.Content[1].(ai.SourceContent)
		require.True(t, ok)
		require.Equal(t, ai.SourceTypeURL, sourceContent.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
		require.Equal(t, "Document 1", sourceContent.Title)
		require.NotEmpty(t, sourceContent.ID)
	})

	t.Run("should return cached_tokens in prompt_tokens_details", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
		require.Equal(t, int64(15), result.Usage.InputTokens)
		require.Equal(t, int64(20), result.Usage.OutputTokens)
		require.Equal(t, int64(35), result.Usage.TotalTokens)
	})

	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.NotNil(t, result.ProviderMetadata)

		openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
		require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
	})

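	// Per the warnings asserted below, reasoning models (the o1 family) do not
	// accept the usual sampling parameters; the converter drops them and
	// surfaces unsupported-setting warnings instead.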
	t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt:           testPrompt,
			Temperature:      &[]float64{0.5}[0],
			TopP:             &[]float64{0.7}[0],
			FrequencyPenalty: &[]float64{0.2}[0],
			PresencePenalty:  &[]float64{0.3}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])

		// These should not be present
		require.Nil(t, call.body["temperature"])
		require.Nil(t, call.body["top_p"])
		require.Nil(t, call.body["frequency_penalty"])
		require.Nil(t, call.body["presence_penalty"])

		// Should have warnings
		require.Len(t, result.Warnings, 4)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "temperature", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
	})

	t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt:          testPrompt,
			MaxOutputTokens: &[]int64{1000}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])
		require.Equal(t, float64(1000), call.body["max_completion_tokens"])
		require.Nil(t, call.body["max_tokens"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should return reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(15), result.Usage.InputTokens)
		require.Equal(t, int64(20), result.Usage.OutputTokens)
		require.Equal(t, int64(35), result.Usage.TotalTokens)
		require.Equal(t, int64(10), result.Usage.ReasoningTokens)
	})

	t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"model": "o1-preview",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				MaxCompletionTokens: ai.IntOption(255),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])
		require.Equal(t, float64(255), call.body["max_completion_tokens"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send prediction extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Prediction: map[string]any{
					"type":    "content",
					"content": "Hello, World!",
				},
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		prediction := call.body["prediction"].(map[string]any)
		require.Equal(t, "content", prediction["type"])
		require.Equal(t, "Hello, World!", prediction["content"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Store: ai.BoolOption(true),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["store"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Metadata: map[string]any{
					"custom": "value",
				},
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send promptCacheKey extension value", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				PromptCacheKey: ai.StringOption("test-cache-key-123"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send safety_identifier extension value", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				SafetyIdentifier: ai.StringOption("test-safety-identifier-123"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-search-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt:      testPrompt,
			Temperature: &[]float64{0.7}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-search-preview", call.body["model"])
		require.Nil(t, call.body["temperature"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "temperature", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "search preview models")
	})

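	// service_tier gating as exercised by the tests below: "flex" is forwarded
	// for o3-mini but dropped (with a warning) for gpt-4o-mini, while "priority"
	// is forwarded for gpt-4o-mini but dropped for gpt-3.5-turbo.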
	t.Run("should send ServiceTier flex processing setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o3-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: ai.StringOption("flex"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: ai.StringOption("flex"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Nil(t, call.body["service_tier"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
	})

	t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: ai.StringOption("priority"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: ai.StringOption("priority"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Nil(t, call.body["service_tier"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
	})
}

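// streamingMockServer mimics the Chat Completions SSE endpoint: it records
// each request, then replays a prepared list of event-stream chunks.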
type streamingMockServer struct {
	server *httptest.Server
	chunks []string
	calls  []mockCall
}

func newStreamingMockServer() *streamingMockServer {
	sms := &streamingMockServer{
		calls: make([]mockCall, 0),
	}

	sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record the call
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}

		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Parse request body
		if r.Body != nil {
			var body map[string]any
			json.NewDecoder(r.Body).Decode(&body)
			call.body = body
		}

		sms.calls = append(sms.calls, call)

		// Set streaming headers
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")

		// Add custom headers if any
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				parts := strings.SplitN(chunk[7:], ":", 2)
				if len(parts) == 2 {
					w.Header().Set(parts[0], parts[1])
				}
				continue
			}
		}

		w.WriteHeader(http.StatusOK)

		// Write chunks
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				continue
			}
			w.Write([]byte(chunk))
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
		}
	}))

	return sms
}

func (sms *streamingMockServer) close() {
	sms.server.Close()
}

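// prepareStreamResponse assembles the SSE chunk sequence for a streamed
// completion: an initial role chunk, one chunk per content fragment (plus
// optional annotations), a finish chunk, a usage chunk, and a final [DONE].
// Entries prefixed with "HEADER:" are turned into response headers by the
// handler rather than written to the body.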
1994func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
1995 content := []string{}
1996 if c, ok := opts["content"].([]string); ok {
1997 content = c
1998 }
1999
	usage := map[string]any{
		"prompt_tokens":     17,
		"total_tokens":      244,
		"completion_tokens": 227,
	}
	if u, ok := opts["usage"].(map[string]any); ok {
		usage = u
	}

	logprobs := map[string]any{}
	if l, ok := opts["logprobs"].(map[string]any); ok {
		logprobs = l
	}

	finishReason := "stop"
	if fr, ok := opts["finish_reason"].(string); ok {
		finishReason = fr
	}

	model := "gpt-3.5-turbo-0613"
	if m, ok := opts["model"].(string); ok {
		model = m
	}

	headers := map[string]string{}
	if h, ok := opts["headers"].(map[string]string); ok {
		headers = h
	}

	chunks := []string{}

	// Custom headers are encoded as "HEADER:key:value" sentinel chunks; the
	// mock server is expected to apply these as response headers before it
	// starts streaming.
	for k, v := range headers {
		chunks = append(chunks, "HEADER:"+k+":"+v)
	}

	// Initial chunk with role
	initialChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index": 0,
				"delta": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": nil,
			},
		},
	}
	initialData, _ := json.Marshal(initialChunk)
	chunks = append(chunks, "data: "+string(initialData)+"\n\n")

	// Content chunks
	for i, text := range content {
		contentChunk := map[string]any{
			"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
			"object":             "chat.completion.chunk",
			"created":            1702657020,
			"model":              model,
			"system_fingerprint": nil,
			"choices": []map[string]any{
				{
					"index": 1,
					"delta": map[string]any{
						"content": text,
					},
					"finish_reason": nil,
				},
			},
		}
		contentData, _ := json.Marshal(contentChunk)
		chunks = append(chunks, "data: "+string(contentData)+"\n\n")

		// Add annotations if this is the last content chunk and we have annotations
		if i == len(content)-1 {
			if annotations, ok := opts["annotations"].([]map[string]any); ok {
				annotationChunk := map[string]any{
					"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
					"object":             "chat.completion.chunk",
					"created":            1702657020,
					"model":              model,
					"system_fingerprint": nil,
					"choices": []map[string]any{
						{
							"index": 1,
							"delta": map[string]any{
								"annotations": annotations,
							},
							"finish_reason": nil,
						},
					},
				}
				annotationData, _ := json.Marshal(annotationChunk)
				chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
			}
		}
	}

	// Finish chunk
	finishChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index":         0,
				"delta":         map[string]any{},
				"finish_reason": finishReason,
			},
		},
	}

	if len(logprobs) > 0 {
		finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
	}

	finishData, _ := json.Marshal(finishChunk)
	chunks = append(chunks, "data: "+string(finishData)+"\n\n")

	// Usage chunk
	usageChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": "fp_3bc1b5746c",
		"choices":            []map[string]any{},
		"usage":              usage,
	}
	usageData, _ := json.Marshal(usageChunk)
	chunks = append(chunks, "data: "+string(usageData)+"\n\n")

	// Done
	chunks = append(chunks, "data: [DONE]\n\n")

	sms.chunks = chunks
}

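// prepareToolStreamResponse loads a canned SSE stream in which the model
// makes a single call to "test-tool", streaming the JSON arguments
// {"value":"Sparkle Day"} across several tool_call deltas before a
// "tool_calls" finish chunk and a final usage chunk.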
func (sms *streamingMockServer) prepareToolStreamResponse() {
	chunks := []string{
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}

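// prepareErrorStreamResponse loads an SSE stream whose only payload is an
// OpenAI-style error object, so tests can assert that the error surfaces
// as a stream part rather than as a transport failure.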
func (sms *streamingMockServer) prepareErrorStreamResponse() {
	chunks := []string{
		`data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}

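// collectStreamParts drains a stream into a slice, stopping at the first
// error or finish part. The error result is currently always nil; it is
// kept so call sites read like other stream-consuming helpers.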
func collectStreamParts(stream ai.StreamResponse) ([]ai.StreamPart, error) {
	var parts []ai.StreamPart
	for part := range stream {
		parts = append(parts, part)
		if part.Type == ai.StreamPartTypeError || part.Type == ai.StreamPartTypeFinish {
			break
		}
	}
	return parts, nil
}

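// TestDoStream exercises the streaming code path against the mock SSE
// server: text and tool-call deltas, annotations, error chunks, the shape
// of the outgoing request body, and usage/provider-metadata accounting.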
func TestDoStream(t *testing.T) {
	t.Parallel()

	t.Run("should stream text deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content":       []string{"Hello", ", ", "World!"},
			"finish_reason": "stop",
			"usage": map[string]any{
				"prompt_tokens":     17,
				"total_tokens":      244,
				"completion_tokens": 227,
			},
			"logprobs": testLogprobs,
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Verify stream structure
		require.GreaterOrEqual(t, len(parts), 4) // text-start, deltas, text-end, finish

		// Find text parts
		textStart, textEnd, finish := -1, -1, -1
		var deltas []string

		for i, part := range parts {
			switch part.Type {
			case ai.StreamPartTypeTextStart:
				textStart = i
			case ai.StreamPartTypeTextDelta:
				deltas = append(deltas, part.Delta)
			case ai.StreamPartTypeTextEnd:
				textEnd = i
			case ai.StreamPartTypeFinish:
				finish = i
			}
		}

		require.NotEqual(t, -1, textStart)
		require.NotEqual(t, -1, textEnd)
		require.NotEqual(t, -1, finish)
		require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)

		// Check finish part
		finishPart := parts[finish]
		require.Equal(t, ai.FinishReasonStop, finishPart.FinishReason)
		require.Equal(t, int64(17), finishPart.Usage.InputTokens)
		require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
	})

	t.Run("should stream tool deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareToolStreamResponse()

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			Tools: []ai.Tool{
				ai.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find tool-related parts
		toolInputStart, toolInputEnd, toolCall := -1, -1, -1
		var toolDeltas []string

		for i, part := range parts {
			switch part.Type {
			case ai.StreamPartTypeToolInputStart:
				toolInputStart = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
			case ai.StreamPartTypeToolInputDelta:
				toolDeltas = append(toolDeltas, part.Delta)
			case ai.StreamPartTypeToolInputEnd:
				toolInputEnd = i
			case ai.StreamPartTypeToolCall:
				toolCall = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
				require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
			}
		}

		require.NotEqual(t, -1, toolInputStart)
		require.NotEqual(t, -1, toolInputEnd)
		require.NotEqual(t, -1, toolCall)

		// Verify the tool deltas concatenate to the complete input
		fullInput := strings.Join(toolDeltas, "")
		require.Equal(t, `{"value":"Sparkle Day"}`, fullInput)
	})

	t.Run("should stream annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Based on search results"},
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index":   29,
						"url":         "https://example.com/doc1.pdf",
						"title":       "Document 1",
					},
				},
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find source part
		var sourcePart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeSource {
				sourcePart = &part
				break
			}
		}

		require.NotNil(t, sourcePart)
		require.Equal(t, ai.SourceTypeURL, sourcePart.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
		require.Equal(t, "Document 1", sourcePart.Title)
		require.NotEmpty(t, sourcePart.ID)
	})

	t.Run("should handle error stream parts", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareErrorStreamResponse()

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Should surface at least an error part
		require.NotEmpty(t, parts)

		// Find error part
		var errorPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeError {
				errorPart = &part
				break
			}
		}

		require.NotNil(t, errorPart)
		require.NotNil(t, errorPart.Error)
	})

	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
	})

	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.NotNil(t, finishPart.ProviderMetadata)

		openaiMeta, ok := finishPart.ProviderMetadata["openai"].(*ProviderMetadata)
		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
		require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Store: ai.BoolOption(true),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])
		require.Equal(t, true, call.body["store"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Metadata: map[string]any{
					"custom": "value",
				},
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o3-mini")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: ai.StringOption("flex"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-mini")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: ai.StringOption("priority"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should stream text delta for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model":   "o1-preview",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find text parts
		var textDeltas []string
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeTextDelta {
				textDeltas = append(textDeltas, part.Delta)
			}
		}

		// Should contain the text content (without empty delta)
		require.Equal(t, []string{"Hello, World!"}, textDeltas)
	})

	t.Run("should send reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model":   "o1-preview",
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
		require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
	})
}