1package openai
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "net/http"
9 "net/http/httptest"
10 "strings"
11 "testing"
12
13 "charm.land/fantasy"
14 "github.com/openai/openai-go/v2/packages/param"
15 "github.com/stretchr/testify/require"
16)
17
18func TestToOpenAiPrompt_SystemMessages(t *testing.T) {
19 t.Parallel()
20
21 t.Run("should forward system messages", func(t *testing.T) {
22 t.Parallel()
23
24 prompt := fantasy.Prompt{
25 {
26 Role: fantasy.MessageRoleSystem,
27 Content: []fantasy.MessagePart{
28 fantasy.TextPart{Text: "You are a helpful assistant."},
29 },
30 },
31 }
32
33 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
34
35 require.Empty(t, warnings)
36 require.Len(t, messages, 1)
37
38 systemMsg := messages[0].OfSystem
39 require.NotNil(t, systemMsg)
40 require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
41 })
42
43 t.Run("should handle empty system messages", func(t *testing.T) {
44 t.Parallel()
45
46 prompt := fantasy.Prompt{
47 {
48 Role: fantasy.MessageRoleSystem,
49 Content: []fantasy.MessagePart{},
50 },
51 }
52
53 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
54
55 require.Len(t, warnings, 1)
56 require.Contains(t, warnings[0].Message, "system prompt has no text parts")
57 require.Empty(t, messages)
58 })
59
60 t.Run("should join multiple system text parts", func(t *testing.T) {
61 t.Parallel()
62
63 prompt := fantasy.Prompt{
64 {
65 Role: fantasy.MessageRoleSystem,
66 Content: []fantasy.MessagePart{
67 fantasy.TextPart{Text: "You are a helpful assistant."},
68 fantasy.TextPart{Text: "Be concise."},
69 },
70 },
71 }
72
73 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
74
75 require.Empty(t, warnings)
76 require.Len(t, messages, 1)
77
78 systemMsg := messages[0].OfSystem
79 require.NotNil(t, systemMsg)
80 require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
81 })
82}
83
84func TestToOpenAiPrompt_UserMessages(t *testing.T) {
85 t.Parallel()
86
87 t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
88 t.Parallel()
89
90 prompt := fantasy.Prompt{
91 {
92 Role: fantasy.MessageRoleUser,
93 Content: []fantasy.MessagePart{
94 fantasy.TextPart{Text: "Hello"},
95 },
96 },
97 }
98
99 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
100
101 require.Empty(t, warnings)
102 require.Len(t, messages, 1)
103
104 userMsg := messages[0].OfUser
105 require.NotNil(t, userMsg)
106 require.Equal(t, "Hello", userMsg.Content.OfString.Value)
107 })
108
109 t.Run("should convert messages with image parts", func(t *testing.T) {
110 t.Parallel()
111
112 imageData := []byte{0, 1, 2, 3}
113 prompt := fantasy.Prompt{
114 {
115 Role: fantasy.MessageRoleUser,
116 Content: []fantasy.MessagePart{
117 fantasy.TextPart{Text: "Hello"},
118 fantasy.FilePart{
119 MediaType: "image/png",
120 Data: imageData,
121 },
122 },
123 },
124 }
125
126 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
127
128 require.Empty(t, warnings)
129 require.Len(t, messages, 1)
130
131 userMsg := messages[0].OfUser
132 require.NotNil(t, userMsg)
133
134 content := userMsg.Content.OfArrayOfContentParts
135 require.Len(t, content, 2)
136
137 // Check text part
138 textPart := content[0].OfText
139 require.NotNil(t, textPart)
140 require.Equal(t, "Hello", textPart.Text)
141
142 // Check image part
143 imagePart := content[1].OfImageURL
144 require.NotNil(t, imagePart)
145 expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
146 require.Equal(t, expectedURL, imagePart.ImageURL.URL)
147 })
148
149 t.Run("should add image detail when specified through provider options", func(t *testing.T) {
150 t.Parallel()
151
152 imageData := []byte{0, 1, 2, 3}
153 prompt := fantasy.Prompt{
154 {
155 Role: fantasy.MessageRoleUser,
156 Content: []fantasy.MessagePart{
157 fantasy.FilePart{
158 MediaType: "image/png",
159 Data: imageData,
160 ProviderOptions: NewProviderFileOptions(&ProviderFileOptions{
161 ImageDetail: "low",
162 }),
163 },
164 },
165 },
166 }
167
168 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
169
170 require.Empty(t, warnings)
171 require.Len(t, messages, 1)
172
173 userMsg := messages[0].OfUser
174 require.NotNil(t, userMsg)
175
176 content := userMsg.Content.OfArrayOfContentParts
177 require.Len(t, content, 1)
178
179 imagePart := content[0].OfImageURL
180 require.NotNil(t, imagePart)
181 require.Equal(t, "low", imagePart.ImageURL.Detail)
182 })
183}
184
185func TestToOpenAiPrompt_FileParts(t *testing.T) {
186 t.Parallel()
187
188 t.Run("should throw for unsupported mime types", func(t *testing.T) {
189 t.Parallel()
190
191 prompt := fantasy.Prompt{
192 {
193 Role: fantasy.MessageRoleUser,
194 Content: []fantasy.MessagePart{
195 fantasy.FilePart{
196 MediaType: "application/something",
197 Data: []byte("test"),
198 },
199 },
200 },
201 }
202
203 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
204
205 require.Len(t, warnings, 1)
206 require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
207 require.Len(t, messages, 1) // Message is still created but with empty content array
208 })
209
210 t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
211 t.Parallel()
212
213 audioData := []byte{0, 1, 2, 3}
214 prompt := fantasy.Prompt{
215 {
216 Role: fantasy.MessageRoleUser,
217 Content: []fantasy.MessagePart{
218 fantasy.FilePart{
219 MediaType: "audio/wav",
220 Data: audioData,
221 },
222 },
223 },
224 }
225
226 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
227
228 require.Empty(t, warnings)
229 require.Len(t, messages, 1)
230
231 userMsg := messages[0].OfUser
232 require.NotNil(t, userMsg)
233
234 content := userMsg.Content.OfArrayOfContentParts
235 require.Len(t, content, 1)
236
237 audioPart := content[0].OfInputAudio
238 require.NotNil(t, audioPart)
239 require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
240 require.Equal(t, "wav", audioPart.InputAudio.Format)
241 })
242
243 t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
244 t.Parallel()
245
246 audioData := []byte{0, 1, 2, 3}
247 prompt := fantasy.Prompt{
248 {
249 Role: fantasy.MessageRoleUser,
250 Content: []fantasy.MessagePart{
251 fantasy.FilePart{
252 MediaType: "audio/mpeg",
253 Data: audioData,
254 },
255 },
256 },
257 }
258
259 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
260
261 require.Empty(t, warnings)
262 require.Len(t, messages, 1)
263
264 userMsg := messages[0].OfUser
265 content := userMsg.Content.OfArrayOfContentParts
266 audioPart := content[0].OfInputAudio
267 require.NotNil(t, audioPart)
268 require.Equal(t, "mp3", audioPart.InputAudio.Format)
269 })
270
271 t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
272 t.Parallel()
273
274 audioData := []byte{0, 1, 2, 3}
275 prompt := fantasy.Prompt{
276 {
277 Role: fantasy.MessageRoleUser,
278 Content: []fantasy.MessagePart{
279 fantasy.FilePart{
280 MediaType: "audio/mp3",
281 Data: audioData,
282 },
283 },
284 },
285 }
286
287 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
288
289 require.Empty(t, warnings)
290 require.Len(t, messages, 1)
291
292 userMsg := messages[0].OfUser
293 content := userMsg.Content.OfArrayOfContentParts
294 audioPart := content[0].OfInputAudio
295 require.NotNil(t, audioPart)
296 require.Equal(t, "mp3", audioPart.InputAudio.Format)
297 })
298
299 t.Run("should convert messages with PDF file parts", func(t *testing.T) {
300 t.Parallel()
301
302 pdfData := []byte{1, 2, 3, 4, 5}
303 prompt := fantasy.Prompt{
304 {
305 Role: fantasy.MessageRoleUser,
306 Content: []fantasy.MessagePart{
307 fantasy.FilePart{
308 MediaType: "application/pdf",
309 Data: pdfData,
310 Filename: "document.pdf",
311 },
312 },
313 },
314 }
315
316 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
317
318 require.Empty(t, warnings)
319 require.Len(t, messages, 1)
320
321 userMsg := messages[0].OfUser
322 content := userMsg.Content.OfArrayOfContentParts
323 require.Len(t, content, 1)
324
325 filePart := content[0].OfFile
326 require.NotNil(t, filePart)
327 require.Equal(t, "document.pdf", filePart.File.Filename.Value)
328
329 expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
330 require.Equal(t, expectedData, filePart.File.FileData.Value)
331 })
332
333 t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
334 t.Parallel()
335
336 pdfData := []byte{1, 2, 3, 4, 5}
337 prompt := fantasy.Prompt{
338 {
339 Role: fantasy.MessageRoleUser,
340 Content: []fantasy.MessagePart{
341 fantasy.FilePart{
342 MediaType: "application/pdf",
343 Data: pdfData,
344 Filename: "document.pdf",
345 },
346 },
347 },
348 }
349
350 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
351
352 require.Empty(t, warnings)
353 require.Len(t, messages, 1)
354
355 userMsg := messages[0].OfUser
356 content := userMsg.Content.OfArrayOfContentParts
357 filePart := content[0].OfFile
358 require.NotNil(t, filePart)
359
360 expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
361 require.Equal(t, expectedData, filePart.File.FileData.Value)
362 })
363
364 t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
365 t.Parallel()
366
367 prompt := fantasy.Prompt{
368 {
369 Role: fantasy.MessageRoleUser,
370 Content: []fantasy.MessagePart{
371 fantasy.FilePart{
372 MediaType: "application/pdf",
373 Data: []byte("file-pdf-12345"),
374 },
375 },
376 },
377 }
378
379 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
380
381 require.Empty(t, warnings)
382 require.Len(t, messages, 1)
383
384 userMsg := messages[0].OfUser
385 content := userMsg.Content.OfArrayOfContentParts
386 filePart := content[0].OfFile
387 require.NotNil(t, filePart)
388 require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
389 require.True(t, param.IsOmitted(filePart.File.FileData))
390 require.True(t, param.IsOmitted(filePart.File.Filename))
391 })
392
393 t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
394 t.Parallel()
395
396 pdfData := []byte{1, 2, 3, 4, 5}
397 prompt := fantasy.Prompt{
398 {
399 Role: fantasy.MessageRoleUser,
400 Content: []fantasy.MessagePart{
401 fantasy.FilePart{
402 MediaType: "application/pdf",
403 Data: pdfData,
404 },
405 },
406 },
407 }
408
409 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
410
411 require.Empty(t, warnings)
412 require.Len(t, messages, 1)
413
414 userMsg := messages[0].OfUser
415 content := userMsg.Content.OfArrayOfContentParts
416 filePart := content[0].OfFile
417 require.NotNil(t, filePart)
418 require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
419 })
420}
421
422func TestToOpenAiPrompt_ToolCalls(t *testing.T) {
423 t.Parallel()
424
425 t.Run("should stringify arguments to tool calls", func(t *testing.T) {
426 t.Parallel()
427
428 inputArgs := map[string]any{"foo": "bar123"}
429 inputJSON, _ := json.Marshal(inputArgs)
430
431 outputResult := map[string]any{"oof": "321rab"}
432 outputJSON, _ := json.Marshal(outputResult)
433
434 prompt := fantasy.Prompt{
435 {
436 Role: fantasy.MessageRoleAssistant,
437 Content: []fantasy.MessagePart{
438 fantasy.ToolCallPart{
439 ToolCallID: "quux",
440 ToolName: "thwomp",
441 Input: string(inputJSON),
442 },
443 },
444 },
445 {
446 Role: fantasy.MessageRoleTool,
447 Content: []fantasy.MessagePart{
448 fantasy.ToolResultPart{
449 ToolCallID: "quux",
450 Output: fantasy.ToolResultOutputContentText{
451 Text: string(outputJSON),
452 },
453 },
454 },
455 },
456 }
457
458 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
459
460 require.Empty(t, warnings)
461 require.Len(t, messages, 2)
462
463 // Check assistant message with tool call
464 assistantMsg := messages[0].OfAssistant
465 require.NotNil(t, assistantMsg)
466 require.Equal(t, "", assistantMsg.Content.OfString.Value)
467 require.Len(t, assistantMsg.ToolCalls, 1)
468
469 toolCall := assistantMsg.ToolCalls[0].OfFunction
470 require.NotNil(t, toolCall)
471 require.Equal(t, "quux", toolCall.ID)
472 require.Equal(t, "thwomp", toolCall.Function.Name)
473 require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
474
475 // Check tool message
476 toolMsg := messages[1].OfTool
477 require.NotNil(t, toolMsg)
478 require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
479 require.Equal(t, "quux", toolMsg.ToolCallID)
480 })
481
482 t.Run("should handle different tool output types", func(t *testing.T) {
483 t.Parallel()
484
485 prompt := fantasy.Prompt{
486 {
487 Role: fantasy.MessageRoleTool,
488 Content: []fantasy.MessagePart{
489 fantasy.ToolResultPart{
490 ToolCallID: "text-tool",
491 Output: fantasy.ToolResultOutputContentText{
492 Text: "Hello world",
493 },
494 },
495 fantasy.ToolResultPart{
496 ToolCallID: "error-tool",
497 Output: fantasy.ToolResultOutputContentError{
498 Error: errors.New("Something went wrong"),
499 },
500 },
501 },
502 },
503 }
504
505 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
506
507 require.Empty(t, warnings)
508 require.Len(t, messages, 2)
509
510 // Check first tool message (text)
511 textToolMsg := messages[0].OfTool
512 require.NotNil(t, textToolMsg)
513 require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
514 require.Equal(t, "text-tool", textToolMsg.ToolCallID)
515
516 // Check second tool message (error)
517 errorToolMsg := messages[1].OfTool
518 require.NotNil(t, errorToolMsg)
519 require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
520 require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
521 })
522}
523
524func TestToOpenAiPrompt_AssistantMessages(t *testing.T) {
525 t.Parallel()
526
527 t.Run("should handle simple text assistant messages", func(t *testing.T) {
528 t.Parallel()
529
530 prompt := fantasy.Prompt{
531 {
532 Role: fantasy.MessageRoleAssistant,
533 Content: []fantasy.MessagePart{
534 fantasy.TextPart{Text: "Hello, how can I help you?"},
535 },
536 },
537 }
538
539 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
540
541 require.Empty(t, warnings)
542 require.Len(t, messages, 1)
543
544 assistantMsg := messages[0].OfAssistant
545 require.NotNil(t, assistantMsg)
546 require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
547 })
548
549 t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
550 t.Parallel()
551
552 inputArgs := map[string]any{"query": "test"}
553 inputJSON, _ := json.Marshal(inputArgs)
554
555 prompt := fantasy.Prompt{
556 {
557 Role: fantasy.MessageRoleAssistant,
558 Content: []fantasy.MessagePart{
559 fantasy.TextPart{Text: "Let me search for that."},
560 fantasy.ToolCallPart{
561 ToolCallID: "call-123",
562 ToolName: "search",
563 Input: string(inputJSON),
564 },
565 },
566 },
567 }
568
569 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
570
571 require.Empty(t, warnings)
572 require.Len(t, messages, 1)
573
574 assistantMsg := messages[0].OfAssistant
575 require.NotNil(t, assistantMsg)
576 require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
577 require.Len(t, assistantMsg.ToolCalls, 1)
578
579 toolCall := assistantMsg.ToolCalls[0].OfFunction
580 require.Equal(t, "call-123", toolCall.ID)
581 require.Equal(t, "search", toolCall.Function.Name)
582 require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
583 })
584}
585
586var testPrompt = fantasy.Prompt{
587 {
588 Role: fantasy.MessageRoleUser,
589 Content: []fantasy.MessagePart{
590 fantasy.TextPart{Text: "Hello"},
591 },
592 },
593}
594
595var testLogprobs = map[string]any{
596 "content": []map[string]any{
597 {
598 "token": "Hello",
599 "logprob": -0.0009994634,
600 "top_logprobs": []map[string]any{
601 {
602 "token": "Hello",
603 "logprob": -0.0009994634,
604 },
605 },
606 },
607 {
608 "token": "!",
609 "logprob": -0.13410144,
610 "top_logprobs": []map[string]any{
611 {
612 "token": "!",
613 "logprob": -0.13410144,
614 },
615 },
616 },
617 {
618 "token": " How",
619 "logprob": -0.0009250381,
620 "top_logprobs": []map[string]any{
621 {
622 "token": " How",
623 "logprob": -0.0009250381,
624 },
625 },
626 },
627 {
628 "token": " can",
629 "logprob": -0.047709424,
630 "top_logprobs": []map[string]any{
631 {
632 "token": " can",
633 "logprob": -0.047709424,
634 },
635 },
636 },
637 {
638 "token": " I",
639 "logprob": -0.000009014684,
640 "top_logprobs": []map[string]any{
641 {
642 "token": " I",
643 "logprob": -0.000009014684,
644 },
645 },
646 },
647 {
648 "token": " assist",
649 "logprob": -0.009125131,
650 "top_logprobs": []map[string]any{
651 {
652 "token": " assist",
653 "logprob": -0.009125131,
654 },
655 },
656 },
657 {
658 "token": " you",
659 "logprob": -0.0000066306106,
660 "top_logprobs": []map[string]any{
661 {
662 "token": " you",
663 "logprob": -0.0000066306106,
664 },
665 },
666 },
667 {
668 "token": " today",
669 "logprob": -0.00011093382,
670 "top_logprobs": []map[string]any{
671 {
672 "token": " today",
673 "logprob": -0.00011093382,
674 },
675 },
676 },
677 {
678 "token": "?",
679 "logprob": -0.00004596782,
680 "top_logprobs": []map[string]any{
681 {
682 "token": "?",
683 "logprob": -0.00004596782,
684 },
685 },
686 },
687 },
688}
689
690type mockServer struct {
691 server *httptest.Server
692 response map[string]any
693 calls []mockCall
694}
695
696type mockCall struct {
697 method string
698 path string
699 headers map[string]string
700 body map[string]any
701}
702
703func newMockServer() *mockServer {
704 ms := &mockServer{
705 calls: make([]mockCall, 0),
706 }
707
708 ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
709 // Record the call
710 call := mockCall{
711 method: r.Method,
712 path: r.URL.Path,
713 headers: make(map[string]string),
714 }
715
716 for k, v := range r.Header {
717 if len(v) > 0 {
718 call.headers[k] = v[0]
719 }
720 }
721
722 // Parse request body
723 if r.Body != nil {
724 var body map[string]any
725 json.NewDecoder(r.Body).Decode(&body)
726 call.body = body
727 }
728
729 ms.calls = append(ms.calls, call)
730
731 // Return mock response
732 w.Header().Set("Content-Type", "application/json")
733 json.NewEncoder(w).Encode(ms.response)
734 }))
735
736 return ms
737}
738
739func (ms *mockServer) close() {
740 ms.server.Close()
741}
742
743func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
744 // Default values
745 response := map[string]any{
746 "id": "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
747 "object": "chat.completion",
748 "created": 1711115037,
749 "model": "gpt-3.5-turbo-0125",
750 "choices": []map[string]any{
751 {
752 "index": 0,
753 "message": map[string]any{
754 "role": "assistant",
755 "content": "",
756 },
757 "finish_reason": "stop",
758 },
759 },
760 "usage": map[string]any{
761 "prompt_tokens": 4,
762 "total_tokens": 34,
763 "completion_tokens": 30,
764 },
765 "system_fingerprint": "fp_3bc1b5746c",
766 }
767
768 // Override with provided options
769 for k, v := range opts {
770 switch k {
771 case "content":
772 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["content"] = v
773 case "tool_calls":
774 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["tool_calls"] = v
775 case "function_call":
776 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["function_call"] = v
777 case "annotations":
778 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["annotations"] = v
779 case "usage":
780 response["usage"] = v
781 case "finish_reason":
782 response["choices"].([]map[string]any)[0]["finish_reason"] = v
783 case "id":
784 response["id"] = v
785 case "created":
786 response["created"] = v
787 case "model":
788 response["model"] = v
789 case "logprobs":
790 if v != nil {
791 response["choices"].([]map[string]any)[0]["logprobs"] = v
792 }
793 }
794 }
795
796 ms.response = response
797}
798
799func TestDoGenerate(t *testing.T) {
800 t.Parallel()
801
802 t.Run("should extract text response", func(t *testing.T) {
803 t.Parallel()
804
805 server := newMockServer()
806 defer server.close()
807
808 server.prepareJSONResponse(map[string]any{
809 "content": "Hello, World!",
810 })
811
812 provider, err := New(
813 WithAPIKey("test-api-key"),
814 WithBaseURL(server.server.URL),
815 )
816 require.NoError(t, err)
817 model, _ := provider.LanguageModel("gpt-3.5-turbo")
818
819 result, err := model.Generate(context.Background(), fantasy.Call{
820 Prompt: testPrompt,
821 })
822
823 require.NoError(t, err)
824 require.Len(t, result.Content, 1)
825
826 textContent, ok := result.Content[0].(fantasy.TextContent)
827 require.True(t, ok)
828 require.Equal(t, "Hello, World!", textContent.Text)
829 })
830
831 t.Run("should extract usage", func(t *testing.T) {
832 t.Parallel()
833
834 server := newMockServer()
835 defer server.close()
836
837 server.prepareJSONResponse(map[string]any{
838 "usage": map[string]any{
839 "prompt_tokens": 20,
840 "total_tokens": 25,
841 "completion_tokens": 5,
842 },
843 })
844
845 provider, err := New(
846 WithAPIKey("test-api-key"),
847 WithBaseURL(server.server.URL),
848 )
849 require.NoError(t, err)
850 model, _ := provider.LanguageModel("gpt-3.5-turbo")
851
852 result, err := model.Generate(context.Background(), fantasy.Call{
853 Prompt: testPrompt,
854 })
855
856 require.NoError(t, err)
857 require.Equal(t, int64(20), result.Usage.InputTokens)
858 require.Equal(t, int64(5), result.Usage.OutputTokens)
859 require.Equal(t, int64(25), result.Usage.TotalTokens)
860 })
861
862 t.Run("should send request body", func(t *testing.T) {
863 t.Parallel()
864
865 server := newMockServer()
866 defer server.close()
867
868 server.prepareJSONResponse(map[string]any{})
869
870 provider, err := New(
871 WithAPIKey("test-api-key"),
872 WithBaseURL(server.server.URL),
873 )
874 require.NoError(t, err)
875 model, _ := provider.LanguageModel("gpt-3.5-turbo")
876
877 _, err = model.Generate(context.Background(), fantasy.Call{
878 Prompt: testPrompt,
879 })
880
881 require.NoError(t, err)
882 require.Len(t, server.calls, 1)
883
884 call := server.calls[0]
885 require.Equal(t, "POST", call.method)
886 require.Equal(t, "/chat/completions", call.path)
887 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
888
889 messages, ok := call.body["messages"].([]any)
890 require.True(t, ok)
891 require.Len(t, messages, 1)
892
893 message := messages[0].(map[string]any)
894 require.Equal(t, "user", message["role"])
895 require.Equal(t, "Hello", message["content"])
896 })
897
898 t.Run("should support partial usage", func(t *testing.T) {
899 t.Parallel()
900
901 server := newMockServer()
902 defer server.close()
903
904 server.prepareJSONResponse(map[string]any{
905 "usage": map[string]any{
906 "prompt_tokens": 20,
907 "total_tokens": 20,
908 },
909 })
910
911 provider, err := New(
912 WithAPIKey("test-api-key"),
913 WithBaseURL(server.server.URL),
914 )
915 require.NoError(t, err)
916 model, _ := provider.LanguageModel("gpt-3.5-turbo")
917
918 result, err := model.Generate(context.Background(), fantasy.Call{
919 Prompt: testPrompt,
920 })
921
922 require.NoError(t, err)
923 require.Equal(t, int64(20), result.Usage.InputTokens)
924 require.Equal(t, int64(0), result.Usage.OutputTokens)
925 require.Equal(t, int64(20), result.Usage.TotalTokens)
926 })
927
928 t.Run("should extract logprobs", func(t *testing.T) {
929 t.Parallel()
930
931 server := newMockServer()
932 defer server.close()
933
934 server.prepareJSONResponse(map[string]any{
935 "logprobs": testLogprobs,
936 })
937
938 provider, err := New(
939 WithAPIKey("test-api-key"),
940 WithBaseURL(server.server.URL),
941 )
942 require.NoError(t, err)
943 model, _ := provider.LanguageModel("gpt-3.5-turbo")
944
945 result, err := model.Generate(context.Background(), fantasy.Call{
946 Prompt: testPrompt,
947 ProviderOptions: NewProviderOptions(&ProviderOptions{
948 LogProbs: fantasy.Opt(true),
949 }),
950 })
951
952 require.NoError(t, err)
953 require.NotNil(t, result.ProviderMetadata)
954
955 openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
956 require.True(t, ok)
957
958 logprobs := openaiMeta.Logprobs
959 require.True(t, ok)
960 require.NotNil(t, logprobs)
961 })
962
963 t.Run("should extract finish reason", func(t *testing.T) {
964 t.Parallel()
965
966 server := newMockServer()
967 defer server.close()
968
969 server.prepareJSONResponse(map[string]any{
970 "finish_reason": "stop",
971 })
972
973 provider, err := New(
974 WithAPIKey("test-api-key"),
975 WithBaseURL(server.server.URL),
976 )
977 require.NoError(t, err)
978 model, _ := provider.LanguageModel("gpt-3.5-turbo")
979
980 result, err := model.Generate(context.Background(), fantasy.Call{
981 Prompt: testPrompt,
982 })
983
984 require.NoError(t, err)
985 require.Equal(t, fantasy.FinishReasonStop, result.FinishReason)
986 })
987
988 t.Run("should support unknown finish reason", func(t *testing.T) {
989 t.Parallel()
990
991 server := newMockServer()
992 defer server.close()
993
994 server.prepareJSONResponse(map[string]any{
995 "finish_reason": "eos",
996 })
997
998 provider, err := New(
999 WithAPIKey("test-api-key"),
1000 WithBaseURL(server.server.URL),
1001 )
1002 require.NoError(t, err)
1003 model, _ := provider.LanguageModel("gpt-3.5-turbo")
1004
1005 result, err := model.Generate(context.Background(), fantasy.Call{
1006 Prompt: testPrompt,
1007 })
1008
1009 require.NoError(t, err)
1010 require.Equal(t, fantasy.FinishReasonUnknown, result.FinishReason)
1011 })
1012
1013 t.Run("should pass the model and the messages", func(t *testing.T) {
1014 t.Parallel()
1015
1016 server := newMockServer()
1017 defer server.close()
1018
1019 server.prepareJSONResponse(map[string]any{
1020 "content": "",
1021 })
1022
1023 provider, err := New(
1024 WithAPIKey("test-api-key"),
1025 WithBaseURL(server.server.URL),
1026 )
1027 require.NoError(t, err)
1028 model, _ := provider.LanguageModel("gpt-3.5-turbo")
1029
1030 _, err = model.Generate(context.Background(), fantasy.Call{
1031 Prompt: testPrompt,
1032 })
1033
1034 require.NoError(t, err)
1035 require.Len(t, server.calls, 1)
1036
1037 call := server.calls[0]
1038 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1039
1040 messages := call.body["messages"].([]any)
1041 require.Len(t, messages, 1)
1042
1043 message := messages[0].(map[string]any)
1044 require.Equal(t, "user", message["role"])
1045 require.Equal(t, "Hello", message["content"])
1046 })
1047
1048 t.Run("should pass settings", func(t *testing.T) {
1049 t.Parallel()
1050
1051 server := newMockServer()
1052 defer server.close()
1053
1054 server.prepareJSONResponse(map[string]any{})
1055
1056 provider, err := New(
1057 WithAPIKey("test-api-key"),
1058 WithBaseURL(server.server.URL),
1059 )
1060 require.NoError(t, err)
1061 model, _ := provider.LanguageModel("gpt-3.5-turbo")
1062
1063 _, err = model.Generate(context.Background(), fantasy.Call{
1064 Prompt: testPrompt,
1065 ProviderOptions: NewProviderOptions(&ProviderOptions{
1066 LogitBias: map[string]int64{
1067 "50256": -100,
1068 },
1069 ParallelToolCalls: fantasy.Opt(false),
1070 User: fantasy.Opt("test-user-id"),
1071 }),
1072 })
1073
1074 require.NoError(t, err)
1075 require.Len(t, server.calls, 1)
1076
1077 call := server.calls[0]
1078 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1079
1080 messages := call.body["messages"].([]any)
1081 require.Len(t, messages, 1)
1082
1083 logitBias := call.body["logit_bias"].(map[string]any)
1084 require.Equal(t, float64(-100), logitBias["50256"])
1085 require.Equal(t, false, call.body["parallel_tool_calls"])
1086 require.Equal(t, "test-user-id", call.body["user"])
1087 })
1088
1089 t.Run("should pass reasoningEffort setting", func(t *testing.T) {
1090 t.Parallel()
1091
1092 server := newMockServer()
1093 defer server.close()
1094
1095 server.prepareJSONResponse(map[string]any{
1096 "content": "",
1097 })
1098
1099 provider, err := New(
1100 WithAPIKey("test-api-key"),
1101 WithBaseURL(server.server.URL),
1102 )
1103 require.NoError(t, err)
1104 model, _ := provider.LanguageModel("o1-mini")
1105
1106 _, err = model.Generate(context.Background(), fantasy.Call{
1107 Prompt: testPrompt,
1108 ProviderOptions: NewProviderOptions(
1109 &ProviderOptions{
1110 ReasoningEffort: ReasoningEffortOption(ReasoningEffortLow),
1111 },
1112 ),
1113 })
1114
1115 require.NoError(t, err)
1116 require.Len(t, server.calls, 1)
1117
1118 call := server.calls[0]
1119 require.Equal(t, "o1-mini", call.body["model"])
1120 require.Equal(t, "low", call.body["reasoning_effort"])
1121
1122 messages := call.body["messages"].([]any)
1123 require.Len(t, messages, 1)
1124
1125 message := messages[0].(map[string]any)
1126 require.Equal(t, "user", message["role"])
1127 require.Equal(t, "Hello", message["content"])
1128 })
1129
1130 t.Run("should pass textVerbosity setting", func(t *testing.T) {
1131 t.Parallel()
1132
1133 server := newMockServer()
1134 defer server.close()
1135
1136 server.prepareJSONResponse(map[string]any{
1137 "content": "",
1138 })
1139
1140 provider, err := New(
1141 WithAPIKey("test-api-key"),
1142 WithBaseURL(server.server.URL),
1143 )
1144 require.NoError(t, err)
1145 model, _ := provider.LanguageModel("gpt-4o")
1146
1147 _, err = model.Generate(context.Background(), fantasy.Call{
1148 Prompt: testPrompt,
1149 ProviderOptions: NewProviderOptions(&ProviderOptions{
1150 TextVerbosity: fantasy.Opt("low"),
1151 }),
1152 })
1153
1154 require.NoError(t, err)
1155 require.Len(t, server.calls, 1)
1156
1157 call := server.calls[0]
1158 require.Equal(t, "gpt-4o", call.body["model"])
1159 require.Equal(t, "low", call.body["verbosity"])
1160
1161 messages := call.body["messages"].([]any)
1162 require.Len(t, messages, 1)
1163
1164 message := messages[0].(map[string]any)
1165 require.Equal(t, "user", message["role"])
1166 require.Equal(t, "Hello", message["content"])
1167 })
1168
1169 t.Run("should pass tools and toolChoice", func(t *testing.T) {
1170 t.Parallel()
1171
1172 server := newMockServer()
1173 defer server.close()
1174
1175 server.prepareJSONResponse(map[string]any{
1176 "content": "",
1177 })
1178
1179 provider, err := New(
1180 WithAPIKey("test-api-key"),
1181 WithBaseURL(server.server.URL),
1182 )
1183 require.NoError(t, err)
1184 model, _ := provider.LanguageModel("gpt-3.5-turbo")
1185
1186 _, err = model.Generate(context.Background(), fantasy.Call{
1187 Prompt: testPrompt,
1188 Tools: []fantasy.Tool{
1189 fantasy.FunctionTool{
1190 Name: "test-tool",
1191 InputSchema: map[string]any{
1192 "type": "object",
1193 "properties": map[string]any{
1194 "value": map[string]any{
1195 "type": "string",
1196 },
1197 },
1198 "required": []string{"value"},
1199 "additionalProperties": false,
1200 "$schema": "http://json-schema.org/draft-07/schema#",
1201 },
1202 },
1203 },
1204 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0],
1205 })
1206
1207 require.NoError(t, err)
1208 require.Len(t, server.calls, 1)
1209
1210 call := server.calls[0]
1211 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1212
1213 messages := call.body["messages"].([]any)
1214 require.Len(t, messages, 1)
1215
1216 tools := call.body["tools"].([]any)
1217 require.Len(t, tools, 1)
1218
1219 tool := tools[0].(map[string]any)
1220 require.Equal(t, "function", tool["type"])
1221
1222 function := tool["function"].(map[string]any)
1223 require.Equal(t, "test-tool", function["name"])
1224 require.Equal(t, false, function["strict"])
1225
1226 toolChoice := call.body["tool_choice"].(map[string]any)
1227 require.Equal(t, "function", toolChoice["type"])
1228
1229 toolChoiceFunction := toolChoice["function"].(map[string]any)
1230 require.Equal(t, "test-tool", toolChoiceFunction["name"])
1231 })
1232
1233 t.Run("should parse tool results", func(t *testing.T) {
1234 t.Parallel()
1235
1236 server := newMockServer()
1237 defer server.close()
1238
1239 server.prepareJSONResponse(map[string]any{
1240 "tool_calls": []map[string]any{
1241 {
1242 "id": "call_O17Uplv4lJvD6DVdIvFFeRMw",
1243 "type": "function",
1244 "function": map[string]any{
1245 "name": "test-tool",
1246 "arguments": `{"value":"Spark"}`,
1247 },
1248 },
1249 },
1250 })
1251
1252 provider, err := New(
1253 WithAPIKey("test-api-key"),
1254 WithBaseURL(server.server.URL),
1255 )
1256 require.NoError(t, err)
1257 model, _ := provider.LanguageModel("gpt-3.5-turbo")
1258
1259 result, err := model.Generate(context.Background(), fantasy.Call{
1260 Prompt: testPrompt,
1261 Tools: []fantasy.Tool{
1262 fantasy.FunctionTool{
1263 Name: "test-tool",
1264 InputSchema: map[string]any{
1265 "type": "object",
1266 "properties": map[string]any{
1267 "value": map[string]any{
1268 "type": "string",
1269 },
1270 },
1271 "required": []string{"value"},
1272 "additionalProperties": false,
1273 "$schema": "http://json-schema.org/draft-07/schema#",
1274 },
1275 },
1276 },
1277 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0],
1278 })
1279
1280 require.NoError(t, err)
1281 require.Len(t, result.Content, 1)
1282
1283 toolCall, ok := result.Content[0].(fantasy.ToolCallContent)
1284 require.True(t, ok)
1285 require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
1286 require.Equal(t, "test-tool", toolCall.ToolName)
1287 require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
1288 })
1289
1290 t.Run("should parse annotations/citations", func(t *testing.T) {
1291 t.Parallel()
1292
1293 server := newMockServer()
1294 defer server.close()
1295
1296 server.prepareJSONResponse(map[string]any{
1297 "content": "Based on the search results [doc1], I found information.",
1298 "annotations": []map[string]any{
1299 {
1300 "type": "url_citation",
1301 "url_citation": map[string]any{
1302 "start_index": 24,
1303 "end_index": 29,
1304 "url": "https://example.com/doc1.pdf",
1305 "title": "Document 1",
1306 },
1307 },
1308 },
1309 })
1310
1311 provider, err := New(
1312 WithAPIKey("test-api-key"),
1313 WithBaseURL(server.server.URL),
1314 )
1315 require.NoError(t, err)
1316 model, _ := provider.LanguageModel("gpt-3.5-turbo")
1317
1318 result, err := model.Generate(context.Background(), fantasy.Call{
1319 Prompt: testPrompt,
1320 })
1321
1322 require.NoError(t, err)
1323 require.Len(t, result.Content, 2)
1324
1325 textContent, ok := result.Content[0].(fantasy.TextContent)
1326 require.True(t, ok)
1327 require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)
1328
1329 sourceContent, ok := result.Content[1].(fantasy.SourceContent)
1330 require.True(t, ok)
1331 require.Equal(t, fantasy.SourceTypeURL, sourceContent.SourceType)
1332 require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
1333 require.Equal(t, "Document 1", sourceContent.Title)
1334 require.NotEmpty(t, sourceContent.ID)
1335 })
1336
1337 t.Run("should return cached_tokens in prompt_details_tokens", func(t *testing.T) {
1338 t.Parallel()
1339
1340 server := newMockServer()
1341 defer server.close()
1342
1343 server.prepareJSONResponse(map[string]any{
1344 "usage": map[string]any{
1345 "prompt_tokens": 15,
1346 "completion_tokens": 20,
1347 "total_tokens": 35,
1348 "prompt_tokens_details": map[string]any{
1349 "cached_tokens": 1152,
1350 },
1351 },
1352 })
1353
1354 provider, err := New(
1355 WithAPIKey("test-api-key"),
1356 WithBaseURL(server.server.URL),
1357 )
1358 require.NoError(t, err)
1359 model, _ := provider.LanguageModel("gpt-4o-mini")
1360
1361 result, err := model.Generate(context.Background(), fantasy.Call{
1362 Prompt: testPrompt,
1363 })
1364
1365 require.NoError(t, err)
1366 require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
1367 require.Equal(t, int64(15), result.Usage.InputTokens)
1368 require.Equal(t, int64(20), result.Usage.OutputTokens)
1369 require.Equal(t, int64(35), result.Usage.TotalTokens)
1370 })
1371
1372 t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
1373 t.Parallel()
1374
1375 server := newMockServer()
1376 defer server.close()
1377
1378 server.prepareJSONResponse(map[string]any{
1379 "usage": map[string]any{
1380 "prompt_tokens": 15,
1381 "completion_tokens": 20,
1382 "total_tokens": 35,
1383 "completion_tokens_details": map[string]any{
1384 "accepted_prediction_tokens": 123,
1385 "rejected_prediction_tokens": 456,
1386 },
1387 },
1388 })
1389
1390 provider, err := New(
1391 WithAPIKey("test-api-key"),
1392 WithBaseURL(server.server.URL),
1393 )
1394 require.NoError(t, err)
1395 model, _ := provider.LanguageModel("gpt-4o-mini")
1396
1397 result, err := model.Generate(context.Background(), fantasy.Call{
1398 Prompt: testPrompt,
1399 })
1400
1401 require.NoError(t, err)
1402 require.NotNil(t, result.ProviderMetadata)
1403
1404 openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
1405
1406 require.True(t, ok)
1407 require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
1408 require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
1409 })
1410
1411 t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
1412 t.Parallel()
1413
1414 server := newMockServer()
1415 defer server.close()
1416
1417 server.prepareJSONResponse(map[string]any{})
1418
1419 provider, err := New(
1420 WithAPIKey("test-api-key"),
1421 WithBaseURL(server.server.URL),
1422 )
1423 require.NoError(t, err)
1424 model, _ := provider.LanguageModel("o1-preview")
1425
1426 result, err := model.Generate(context.Background(), fantasy.Call{
1427 Prompt: testPrompt,
1428 Temperature: &[]float64{0.5}[0],
1429 TopP: &[]float64{0.7}[0],
1430 FrequencyPenalty: &[]float64{0.2}[0],
1431 PresencePenalty: &[]float64{0.3}[0],
1432 })
1433
1434 require.NoError(t, err)
1435 require.Len(t, server.calls, 1)
1436
1437 call := server.calls[0]
1438 require.Equal(t, "o1-preview", call.body["model"])
1439
1440 messages := call.body["messages"].([]any)
1441 require.Len(t, messages, 1)
1442
1443 message := messages[0].(map[string]any)
1444 require.Equal(t, "user", message["role"])
1445 require.Equal(t, "Hello", message["content"])
1446
1447 // These should not be present
1448 require.Nil(t, call.body["temperature"])
1449 require.Nil(t, call.body["top_p"])
1450 require.Nil(t, call.body["frequency_penalty"])
1451 require.Nil(t, call.body["presence_penalty"])
1452
1453 // Should have warnings
1454 require.Len(t, result.Warnings, 4)
1455 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1456 require.Equal(t, "temperature", result.Warnings[0].Setting)
1457 require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
1458 })
1459
1460 t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
1461 t.Parallel()
1462
1463 server := newMockServer()
1464 defer server.close()
1465
1466 server.prepareJSONResponse(map[string]any{})
1467
1468 provider, err := New(
1469 WithAPIKey("test-api-key"),
1470 WithBaseURL(server.server.URL),
1471 )
1472 require.NoError(t, err)
1473 model, _ := provider.LanguageModel("o1-preview")
1474
1475 _, err = model.Generate(context.Background(), fantasy.Call{
1476 Prompt: testPrompt,
1477 MaxOutputTokens: &[]int64{1000}[0],
1478 })
1479
1480 require.NoError(t, err)
1481 require.Len(t, server.calls, 1)
1482
1483 call := server.calls[0]
1484 require.Equal(t, "o1-preview", call.body["model"])
1485 require.Equal(t, float64(1000), call.body["max_completion_tokens"])
1486 require.Nil(t, call.body["max_tokens"])
1487
1488 messages := call.body["messages"].([]any)
1489 require.Len(t, messages, 1)
1490
1491 message := messages[0].(map[string]any)
1492 require.Equal(t, "user", message["role"])
1493 require.Equal(t, "Hello", message["content"])
1494 })
1495
1496 t.Run("should return reasoning tokens", func(t *testing.T) {
1497 t.Parallel()
1498
1499 server := newMockServer()
1500 defer server.close()
1501
1502 server.prepareJSONResponse(map[string]any{
1503 "usage": map[string]any{
1504 "prompt_tokens": 15,
1505 "completion_tokens": 20,
1506 "total_tokens": 35,
1507 "completion_tokens_details": map[string]any{
1508 "reasoning_tokens": 10,
1509 },
1510 },
1511 })
1512
1513 provider, err := New(
1514 WithAPIKey("test-api-key"),
1515 WithBaseURL(server.server.URL),
1516 )
1517 require.NoError(t, err)
1518 model, _ := provider.LanguageModel("o1-preview")
1519
1520 result, err := model.Generate(context.Background(), fantasy.Call{
1521 Prompt: testPrompt,
1522 })
1523
1524 require.NoError(t, err)
1525 require.Equal(t, int64(15), result.Usage.InputTokens)
1526 require.Equal(t, int64(20), result.Usage.OutputTokens)
1527 require.Equal(t, int64(35), result.Usage.TotalTokens)
1528 require.Equal(t, int64(10), result.Usage.ReasoningTokens)
1529 })
1530
1531 t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
1532 t.Parallel()
1533
1534 server := newMockServer()
1535 defer server.close()
1536
1537 server.prepareJSONResponse(map[string]any{
1538 "model": "o1-preview",
1539 })
1540
1541 provider, err := New(
1542 WithAPIKey("test-api-key"),
1543 WithBaseURL(server.server.URL),
1544 )
1545 require.NoError(t, err)
1546 model, _ := provider.LanguageModel("o1-preview")
1547
1548 _, err = model.Generate(context.Background(), fantasy.Call{
1549 Prompt: testPrompt,
1550 ProviderOptions: NewProviderOptions(&ProviderOptions{
1551 MaxCompletionTokens: fantasy.Opt(int64(255)),
1552 }),
1553 })
1554
1555 require.NoError(t, err)
1556 require.Len(t, server.calls, 1)
1557
1558 call := server.calls[0]
1559 require.Equal(t, "o1-preview", call.body["model"])
1560 require.Equal(t, float64(255), call.body["max_completion_tokens"])
1561
1562 messages := call.body["messages"].([]any)
1563 require.Len(t, messages, 1)
1564
1565 message := messages[0].(map[string]any)
1566 require.Equal(t, "user", message["role"])
1567 require.Equal(t, "Hello", message["content"])
1568 })
1569
1570 t.Run("should send prediction extension setting", func(t *testing.T) {
1571 t.Parallel()
1572
1573 server := newMockServer()
1574 defer server.close()
1575
1576 server.prepareJSONResponse(map[string]any{
1577 "content": "",
1578 })
1579
1580 provider, err := New(
1581 WithAPIKey("test-api-key"),
1582 WithBaseURL(server.server.URL),
1583 )
1584 require.NoError(t, err)
1585 model, _ := provider.LanguageModel("gpt-3.5-turbo")
1586
1587 _, err = model.Generate(context.Background(), fantasy.Call{
1588 Prompt: testPrompt,
1589 ProviderOptions: NewProviderOptions(&ProviderOptions{
1590 Prediction: map[string]any{
1591 "type": "content",
1592 "content": "Hello, World!",
1593 },
1594 }),
1595 })
1596
1597 require.NoError(t, err)
1598 require.Len(t, server.calls, 1)
1599
1600 call := server.calls[0]
1601 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1602
1603 prediction := call.body["prediction"].(map[string]any)
1604 require.Equal(t, "content", prediction["type"])
1605 require.Equal(t, "Hello, World!", prediction["content"])
1606
1607 messages := call.body["messages"].([]any)
1608 require.Len(t, messages, 1)
1609
1610 message := messages[0].(map[string]any)
1611 require.Equal(t, "user", message["role"])
1612 require.Equal(t, "Hello", message["content"])
1613 })
1614
1615 t.Run("should send store extension setting", func(t *testing.T) {
1616 t.Parallel()
1617
1618 server := newMockServer()
1619 defer server.close()
1620
1621 server.prepareJSONResponse(map[string]any{
1622 "content": "",
1623 })
1624
1625 provider, err := New(
1626 WithAPIKey("test-api-key"),
1627 WithBaseURL(server.server.URL),
1628 )
1629 require.NoError(t, err)
1630 model, _ := provider.LanguageModel("gpt-3.5-turbo")
1631
1632 _, err = model.Generate(context.Background(), fantasy.Call{
1633 Prompt: testPrompt,
1634 ProviderOptions: NewProviderOptions(&ProviderOptions{
1635 Store: fantasy.Opt(true),
1636 }),
1637 })
1638
1639 require.NoError(t, err)
1640 require.Len(t, server.calls, 1)
1641
1642 call := server.calls[0]
1643 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1644 require.Equal(t, true, call.body["store"])
1645
1646 messages := call.body["messages"].([]any)
1647 require.Len(t, messages, 1)
1648
1649 message := messages[0].(map[string]any)
1650 require.Equal(t, "user", message["role"])
1651 require.Equal(t, "Hello", message["content"])
1652 })
1653
1654 t.Run("should send metadata extension values", func(t *testing.T) {
1655 t.Parallel()
1656
1657 server := newMockServer()
1658 defer server.close()
1659
1660 server.prepareJSONResponse(map[string]any{
1661 "content": "",
1662 })
1663
1664 provider, err := New(
1665 WithAPIKey("test-api-key"),
1666 WithBaseURL(server.server.URL),
1667 )
1668 require.NoError(t, err)
1669 model, _ := provider.LanguageModel("gpt-3.5-turbo")
1670
1671 _, err = model.Generate(context.Background(), fantasy.Call{
1672 Prompt: testPrompt,
1673 ProviderOptions: NewProviderOptions(&ProviderOptions{
1674 Metadata: map[string]any{
1675 "custom": "value",
1676 },
1677 }),
1678 })
1679
1680 require.NoError(t, err)
1681 require.Len(t, server.calls, 1)
1682
1683 call := server.calls[0]
1684 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1685
1686 metadata := call.body["metadata"].(map[string]any)
1687 require.Equal(t, "value", metadata["custom"])
1688
1689 messages := call.body["messages"].([]any)
1690 require.Len(t, messages, 1)
1691
1692 message := messages[0].(map[string]any)
1693 require.Equal(t, "user", message["role"])
1694 require.Equal(t, "Hello", message["content"])
1695 })
1696
1697 t.Run("should send promptCacheKey extension value", func(t *testing.T) {
1698 t.Parallel()
1699
1700 server := newMockServer()
1701 defer server.close()
1702
1703 server.prepareJSONResponse(map[string]any{
1704 "content": "",
1705 })
1706
1707 provider, err := New(
1708 WithAPIKey("test-api-key"),
1709 WithBaseURL(server.server.URL),
1710 )
1711 require.NoError(t, err)
1712 model, _ := provider.LanguageModel("gpt-3.5-turbo")
1713
1714 _, err = model.Generate(context.Background(), fantasy.Call{
1715 Prompt: testPrompt,
1716 ProviderOptions: NewProviderOptions(&ProviderOptions{
1717 PromptCacheKey: fantasy.Opt("test-cache-key-123"),
1718 }),
1719 })
1720
1721 require.NoError(t, err)
1722 require.Len(t, server.calls, 1)
1723
1724 call := server.calls[0]
1725 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1726 require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])
1727
1728 messages := call.body["messages"].([]any)
1729 require.Len(t, messages, 1)
1730
1731 message := messages[0].(map[string]any)
1732 require.Equal(t, "user", message["role"])
1733 require.Equal(t, "Hello", message["content"])
1734 })
1735
1736 t.Run("should send safety_identifier extension value", func(t *testing.T) {
1737 t.Parallel()
1738
1739 server := newMockServer()
1740 defer server.close()
1741
1742 server.prepareJSONResponse(map[string]any{
1743 "content": "",
1744 })
1745
1746 provider, err := New(
1747 WithAPIKey("test-api-key"),
1748 WithBaseURL(server.server.URL),
1749 )
1750 require.NoError(t, err)
1751 model, _ := provider.LanguageModel("gpt-3.5-turbo")
1752
1753 _, err = model.Generate(context.Background(), fantasy.Call{
1754 Prompt: testPrompt,
1755 ProviderOptions: NewProviderOptions(&ProviderOptions{
1756 SafetyIdentifier: fantasy.Opt("test-safety-identifier-123"),
1757 }),
1758 })
1759
1760 require.NoError(t, err)
1761 require.Len(t, server.calls, 1)
1762
1763 call := server.calls[0]
1764 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1765 require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])
1766
1767 messages := call.body["messages"].([]any)
1768 require.Len(t, messages, 1)
1769
1770 message := messages[0].(map[string]any)
1771 require.Equal(t, "user", message["role"])
1772 require.Equal(t, "Hello", message["content"])
1773 })
1774
1775 t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
1776 t.Parallel()
1777
1778 server := newMockServer()
1779 defer server.close()
1780
1781 server.prepareJSONResponse(map[string]any{})
1782
1783 provider, err := New(
1784 WithAPIKey("test-api-key"),
1785 WithBaseURL(server.server.URL),
1786 )
1787 require.NoError(t, err)
1788 model, _ := provider.LanguageModel("gpt-4o-search-preview")
1789
1790 result, err := model.Generate(context.Background(), fantasy.Call{
1791 Prompt: testPrompt,
1792 Temperature: &[]float64{0.7}[0],
1793 })
1794
1795 require.NoError(t, err)
1796 require.Len(t, server.calls, 1)
1797
1798 call := server.calls[0]
1799 require.Equal(t, "gpt-4o-search-preview", call.body["model"])
1800 require.Nil(t, call.body["temperature"])
1801
1802 require.Len(t, result.Warnings, 1)
1803 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1804 require.Equal(t, "temperature", result.Warnings[0].Setting)
1805 require.Contains(t, result.Warnings[0].Details, "search preview models")
1806 })
1807
1808 t.Run("should send ServiceTier flex processing setting", func(t *testing.T) {
1809 t.Parallel()
1810
1811 server := newMockServer()
1812 defer server.close()
1813
1814 server.prepareJSONResponse(map[string]any{
1815 "content": "",
1816 })
1817
1818 provider, err := New(
1819 WithAPIKey("test-api-key"),
1820 WithBaseURL(server.server.URL),
1821 )
1822 require.NoError(t, err)
1823 model, _ := provider.LanguageModel("o3-mini")
1824
1825 _, err = model.Generate(context.Background(), fantasy.Call{
1826 Prompt: testPrompt,
1827 ProviderOptions: NewProviderOptions(&ProviderOptions{
1828 ServiceTier: fantasy.Opt("flex"),
1829 }),
1830 })
1831
1832 require.NoError(t, err)
1833 require.Len(t, server.calls, 1)
1834
1835 call := server.calls[0]
1836 require.Equal(t, "o3-mini", call.body["model"])
1837 require.Equal(t, "flex", call.body["service_tier"])
1838
1839 messages := call.body["messages"].([]any)
1840 require.Len(t, messages, 1)
1841
1842 message := messages[0].(map[string]any)
1843 require.Equal(t, "user", message["role"])
1844 require.Equal(t, "Hello", message["content"])
1845 })
1846
1847 t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
1848 t.Parallel()
1849
1850 server := newMockServer()
1851 defer server.close()
1852
1853 server.prepareJSONResponse(map[string]any{})
1854
1855 provider, err := New(
1856 WithAPIKey("test-api-key"),
1857 WithBaseURL(server.server.URL),
1858 )
1859 require.NoError(t, err)
1860 model, _ := provider.LanguageModel("gpt-4o-mini")
1861
1862 result, err := model.Generate(context.Background(), fantasy.Call{
1863 Prompt: testPrompt,
1864 ProviderOptions: NewProviderOptions(&ProviderOptions{
1865 ServiceTier: fantasy.Opt("flex"),
1866 }),
1867 })
1868
1869 require.NoError(t, err)
1870 require.Len(t, server.calls, 1)
1871
1872 call := server.calls[0]
1873 require.Nil(t, call.body["service_tier"])
1874
1875 require.Len(t, result.Warnings, 1)
1876 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1877 require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
1878 require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
1879 })
1880
1881 t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
1882 t.Parallel()
1883
1884 server := newMockServer()
1885 defer server.close()
1886
1887 server.prepareJSONResponse(map[string]any{})
1888
1889 provider, err := New(
1890 WithAPIKey("test-api-key"),
1891 WithBaseURL(server.server.URL),
1892 )
1893 require.NoError(t, err)
1894 model, _ := provider.LanguageModel("gpt-4o-mini")
1895
1896 _, err = model.Generate(context.Background(), fantasy.Call{
1897 Prompt: testPrompt,
1898 ProviderOptions: NewProviderOptions(&ProviderOptions{
1899 ServiceTier: fantasy.Opt("priority"),
1900 }),
1901 })
1902
1903 require.NoError(t, err)
1904 require.Len(t, server.calls, 1)
1905
1906 call := server.calls[0]
1907 require.Equal(t, "gpt-4o-mini", call.body["model"])
1908 require.Equal(t, "priority", call.body["service_tier"])
1909
1910 messages := call.body["messages"].([]any)
1911 require.Len(t, messages, 1)
1912
1913 message := messages[0].(map[string]any)
1914 require.Equal(t, "user", message["role"])
1915 require.Equal(t, "Hello", message["content"])
1916 })
1917
1918 t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
1919 t.Parallel()
1920
1921 server := newMockServer()
1922 defer server.close()
1923
1924 server.prepareJSONResponse(map[string]any{})
1925
1926 provider, err := New(
1927 WithAPIKey("test-api-key"),
1928 WithBaseURL(server.server.URL),
1929 )
1930 require.NoError(t, err)
1931 model, _ := provider.LanguageModel("gpt-3.5-turbo")
1932
1933 result, err := model.Generate(context.Background(), fantasy.Call{
1934 Prompt: testPrompt,
1935 ProviderOptions: NewProviderOptions(&ProviderOptions{
1936 ServiceTier: fantasy.Opt("priority"),
1937 }),
1938 })
1939
1940 require.NoError(t, err)
1941 require.Len(t, server.calls, 1)
1942
1943 call := server.calls[0]
1944 require.Nil(t, call.body["service_tier"])
1945
1946 require.Len(t, result.Warnings, 1)
1947 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1948 require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
1949 require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
1950 })
1951}
1952
1953type streamingMockServer struct {
1954 server *httptest.Server
1955 chunks []string
1956 calls []mockCall
1957}
1958
1959func newStreamingMockServer() *streamingMockServer {
1960 sms := &streamingMockServer{
1961 calls: make([]mockCall, 0),
1962 }
1963
1964 sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1965 // Record the call
1966 call := mockCall{
1967 method: r.Method,
1968 path: r.URL.Path,
1969 headers: make(map[string]string),
1970 }
1971
1972 for k, v := range r.Header {
1973 if len(v) > 0 {
1974 call.headers[k] = v[0]
1975 }
1976 }
1977
1978 // Parse request body
1979 if r.Body != nil {
1980 var body map[string]any
1981 json.NewDecoder(r.Body).Decode(&body)
1982 call.body = body
1983 }
1984
1985 sms.calls = append(sms.calls, call)
1986
1987 // Set streaming headers
1988 w.Header().Set("Content-Type", "text/event-stream")
1989 w.Header().Set("Cache-Control", "no-cache")
1990 w.Header().Set("Connection", "keep-alive")
1991
1992 // Add custom headers if any
1993 for _, chunk := range sms.chunks {
1994 if strings.HasPrefix(chunk, "HEADER:") {
1995 parts := strings.SplitN(chunk[7:], ":", 2)
1996 if len(parts) == 2 {
1997 w.Header().Set(parts[0], parts[1])
1998 }
1999 continue
2000 }
2001 }
2002
2003 w.WriteHeader(http.StatusOK)
2004
2005 // Write chunks
2006 for _, chunk := range sms.chunks {
2007 if strings.HasPrefix(chunk, "HEADER:") {
2008 continue
2009 }
2010 w.Write([]byte(chunk))
2011 if f, ok := w.(http.Flusher); ok {
2012 f.Flush()
2013 }
2014 }
2015 }))
2016
2017 return sms
2018}
2019
2020func (sms *streamingMockServer) close() {
2021 sms.server.Close()
2022}
2023
2024func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
2025 content := []string{}
2026 if c, ok := opts["content"].([]string); ok {
2027 content = c
2028 }
2029
2030 usage := map[string]any{
2031 "prompt_tokens": 17,
2032 "total_tokens": 244,
2033 "completion_tokens": 227,
2034 }
2035 if u, ok := opts["usage"].(map[string]any); ok {
2036 usage = u
2037 }
2038
2039 logprobs := map[string]any{}
2040 if l, ok := opts["logprobs"].(map[string]any); ok {
2041 logprobs = l
2042 }
2043
2044 finishReason := "stop"
2045 if fr, ok := opts["finish_reason"].(string); ok {
2046 finishReason = fr
2047 }
2048
2049 model := "gpt-3.5-turbo-0613"
2050 if m, ok := opts["model"].(string); ok {
2051 model = m
2052 }
2053
2054 headers := map[string]string{}
2055 if h, ok := opts["headers"].(map[string]string); ok {
2056 headers = h
2057 }
2058
2059 chunks := []string{}
2060
2061 // Add custom headers
2062 for k, v := range headers {
2063 chunks = append(chunks, "HEADER:"+k+":"+v)
2064 }
2065
2066 // Initial chunk with role
2067 initialChunk := map[string]any{
2068 "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2069 "object": "chat.completion.chunk",
2070 "created": 1702657020,
2071 "model": model,
2072 "system_fingerprint": nil,
2073 "choices": []map[string]any{
2074 {
2075 "index": 0,
2076 "delta": map[string]any{
2077 "role": "assistant",
2078 "content": "",
2079 },
2080 "finish_reason": nil,
2081 },
2082 },
2083 }
2084 initialData, _ := json.Marshal(initialChunk)
2085 chunks = append(chunks, "data: "+string(initialData)+"\n\n")
2086
2087 // Content chunks
2088 for i, text := range content {
2089 contentChunk := map[string]any{
2090 "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2091 "object": "chat.completion.chunk",
2092 "created": 1702657020,
2093 "model": model,
2094 "system_fingerprint": nil,
2095 "choices": []map[string]any{
2096 {
2097 "index": 1,
2098 "delta": map[string]any{
2099 "content": text,
2100 },
2101 "finish_reason": nil,
2102 },
2103 },
2104 }
2105 contentData, _ := json.Marshal(contentChunk)
2106 chunks = append(chunks, "data: "+string(contentData)+"\n\n")
2107
2108 // Add annotations if this is the last content chunk and we have annotations
2109 if i == len(content)-1 {
2110 if annotations, ok := opts["annotations"].([]map[string]any); ok {
2111 annotationChunk := map[string]any{
2112 "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2113 "object": "chat.completion.chunk",
2114 "created": 1702657020,
2115 "model": model,
2116 "system_fingerprint": nil,
2117 "choices": []map[string]any{
2118 {
2119 "index": 1,
2120 "delta": map[string]any{
2121 "annotations": annotations,
2122 },
2123 "finish_reason": nil,
2124 },
2125 },
2126 }
2127 annotationData, _ := json.Marshal(annotationChunk)
2128 chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
2129 }
2130 }
2131 }
2132
2133 // Finish chunk
2134 finishChunk := map[string]any{
2135 "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2136 "object": "chat.completion.chunk",
2137 "created": 1702657020,
2138 "model": model,
2139 "system_fingerprint": nil,
2140 "choices": []map[string]any{
2141 {
2142 "index": 0,
2143 "delta": map[string]any{},
2144 "finish_reason": finishReason,
2145 },
2146 },
2147 }
2148
2149 if len(logprobs) > 0 {
2150 finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
2151 }
2152
2153 finishData, _ := json.Marshal(finishChunk)
2154 chunks = append(chunks, "data: "+string(finishData)+"\n\n")
2155
2156 // Usage chunk
2157 usageChunk := map[string]any{
2158 "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2159 "object": "chat.completion.chunk",
2160 "created": 1702657020,
2161 "model": model,
2162 "system_fingerprint": "fp_3bc1b5746c",
2163 "choices": []map[string]any{},
2164 "usage": usage,
2165 }
2166 usageData, _ := json.Marshal(usageChunk)
2167 chunks = append(chunks, "data: "+string(usageData)+"\n\n")
2168
2169 // Done
2170 chunks = append(chunks, "data: [DONE]\n\n")
2171
2172 sms.chunks = chunks
2173}
2174
2175func (sms *streamingMockServer) prepareToolStreamResponse() {
2176 chunks := []string{
2177 `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2178 `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2179 `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2180 `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2181 `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2182 `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2183 `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2184 `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2185 `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
2186 `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
2187 "data: [DONE]\n\n",
2188 }
2189 sms.chunks = chunks
2190}
2191
2192func (sms *streamingMockServer) prepareErrorStreamResponse() {
2193 chunks := []string{
2194 `data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
2195 "data: [DONE]\n\n",
2196 }
2197 sms.chunks = chunks
2198}
2199
2200func collectStreamParts(stream fantasy.StreamResponse) ([]fantasy.StreamPart, error) {
2201 var parts []fantasy.StreamPart
2202 for part := range stream {
2203 parts = append(parts, part)
2204 if part.Type == fantasy.StreamPartTypeError {
2205 break
2206 }
2207 if part.Type == fantasy.StreamPartTypeFinish {
2208 break
2209 }
2210 }
2211 return parts, nil
2212}
2213
2214func TestDoStream(t *testing.T) {
2215 t.Parallel()
2216
2217 t.Run("should stream text deltas", func(t *testing.T) {
2218 t.Parallel()
2219
2220 server := newStreamingMockServer()
2221 defer server.close()
2222
2223 server.prepareStreamResponse(map[string]any{
2224 "content": []string{"Hello", ", ", "World!"},
2225 "finish_reason": "stop",
2226 "usage": map[string]any{
2227 "prompt_tokens": 17,
2228 "total_tokens": 244,
2229 "completion_tokens": 227,
2230 },
2231 "logprobs": testLogprobs,
2232 })
2233
2234 provider, err := New(
2235 WithAPIKey("test-api-key"),
2236 WithBaseURL(server.server.URL),
2237 )
2238 require.NoError(t, err)
2239 model, _ := provider.LanguageModel("gpt-3.5-turbo")
2240
2241 stream, err := model.Stream(context.Background(), fantasy.Call{
2242 Prompt: testPrompt,
2243 })
2244
2245 require.NoError(t, err)
2246
2247 parts, err := collectStreamParts(stream)
2248 require.NoError(t, err)
2249
2250 // Verify stream structure
2251 require.True(t, len(parts) >= 4) // text-start, deltas, text-end, finish
2252
2253 // Find text parts
2254 textStart, textEnd, finish := -1, -1, -1
2255 var deltas []string
2256
2257 for i, part := range parts {
2258 switch part.Type {
2259 case fantasy.StreamPartTypeTextStart:
2260 textStart = i
2261 case fantasy.StreamPartTypeTextDelta:
2262 deltas = append(deltas, part.Delta)
2263 case fantasy.StreamPartTypeTextEnd:
2264 textEnd = i
2265 case fantasy.StreamPartTypeFinish:
2266 finish = i
2267 }
2268 }
2269
2270 require.NotEqual(t, -1, textStart)
2271 require.NotEqual(t, -1, textEnd)
2272 require.NotEqual(t, -1, finish)
2273 require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)
2274
2275 // Check finish part
2276 finishPart := parts[finish]
2277 require.Equal(t, fantasy.FinishReasonStop, finishPart.FinishReason)
2278 require.Equal(t, int64(17), finishPart.Usage.InputTokens)
2279 require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
2280 require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
2281 })
2282
2283 t.Run("should stream tool deltas", func(t *testing.T) {
2284 t.Parallel()
2285
2286 server := newStreamingMockServer()
2287 defer server.close()
2288
2289 server.prepareToolStreamResponse()
2290
2291 provider, err := New(
2292 WithAPIKey("test-api-key"),
2293 WithBaseURL(server.server.URL),
2294 )
2295 require.NoError(t, err)
2296 model, _ := provider.LanguageModel("gpt-3.5-turbo")
2297
2298 stream, err := model.Stream(context.Background(), fantasy.Call{
2299 Prompt: testPrompt,
2300 Tools: []fantasy.Tool{
2301 fantasy.FunctionTool{
2302 Name: "test-tool",
2303 InputSchema: map[string]any{
2304 "type": "object",
2305 "properties": map[string]any{
2306 "value": map[string]any{
2307 "type": "string",
2308 },
2309 },
2310 "required": []string{"value"},
2311 "additionalProperties": false,
2312 "$schema": "http://json-schema.org/draft-07/schema#",
2313 },
2314 },
2315 },
2316 })
2317
2318 require.NoError(t, err)
2319
2320 parts, err := collectStreamParts(stream)
2321 require.NoError(t, err)
2322
2323 // Find tool-related parts
2324 toolInputStart, toolInputEnd, toolCall := -1, -1, -1
2325 var toolDeltas []string
2326
2327 for i, part := range parts {
2328 switch part.Type {
2329 case fantasy.StreamPartTypeToolInputStart:
2330 toolInputStart = i
2331 require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
2332 require.Equal(t, "test-tool", part.ToolCallName)
2333 case fantasy.StreamPartTypeToolInputDelta:
2334 toolDeltas = append(toolDeltas, part.Delta)
2335 case fantasy.StreamPartTypeToolInputEnd:
2336 toolInputEnd = i
2337 case fantasy.StreamPartTypeToolCall:
2338 toolCall = i
2339 require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
2340 require.Equal(t, "test-tool", part.ToolCallName)
2341 require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
2342 }
2343 }
2344
2345 require.NotEqual(t, -1, toolInputStart)
2346 require.NotEqual(t, -1, toolInputEnd)
2347 require.NotEqual(t, -1, toolCall)
2348
2349 // Verify tool deltas combine to form the complete input
2350 fullInput := ""
2351 for _, delta := range toolDeltas {
2352 fullInput += delta
2353 }
2354 require.Equal(t, `{"value":"Sparkle Day"}`, fullInput)
2355 })
2356
2357 t.Run("should stream annotations/citations", func(t *testing.T) {
2358 t.Parallel()
2359
2360 server := newStreamingMockServer()
2361 defer server.close()
2362
2363 server.prepareStreamResponse(map[string]any{
2364 "content": []string{"Based on search results"},
2365 "annotations": []map[string]any{
2366 {
2367 "type": "url_citation",
2368 "url_citation": map[string]any{
2369 "start_index": 24,
2370 "end_index": 29,
2371 "url": "https://example.com/doc1.pdf",
2372 "title": "Document 1",
2373 },
2374 },
2375 },
2376 })
2377
2378 provider, err := New(
2379 WithAPIKey("test-api-key"),
2380 WithBaseURL(server.server.URL),
2381 )
2382 require.NoError(t, err)
2383 model, _ := provider.LanguageModel("gpt-3.5-turbo")
2384
2385 stream, err := model.Stream(context.Background(), fantasy.Call{
2386 Prompt: testPrompt,
2387 })
2388
2389 require.NoError(t, err)
2390
2391 parts, err := collectStreamParts(stream)
2392 require.NoError(t, err)
2393
2394 // Find source part
2395 var sourcePart *fantasy.StreamPart
2396 for _, part := range parts {
2397 if part.Type == fantasy.StreamPartTypeSource {
2398 sourcePart = &part
2399 break
2400 }
2401 }
2402
2403 require.NotNil(t, sourcePart)
2404 require.Equal(t, fantasy.SourceTypeURL, sourcePart.SourceType)
2405 require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
2406 require.Equal(t, "Document 1", sourcePart.Title)
2407 require.NotEmpty(t, sourcePart.ID)
2408 })
2409
2410 t.Run("should handle error stream parts", func(t *testing.T) {
2411 t.Parallel()
2412
2413 server := newStreamingMockServer()
2414 defer server.close()
2415
2416 server.prepareErrorStreamResponse()
2417
2418 provider, err := New(
2419 WithAPIKey("test-api-key"),
2420 WithBaseURL(server.server.URL),
2421 )
2422 require.NoError(t, err)
2423 model, _ := provider.LanguageModel("gpt-3.5-turbo")
2424
2425 stream, err := model.Stream(context.Background(), fantasy.Call{
2426 Prompt: testPrompt,
2427 })
2428
2429 require.NoError(t, err)
2430
2431 parts, err := collectStreamParts(stream)
2432 require.NoError(t, err)
2433
2434 // Should have error and finish parts
2435 require.True(t, len(parts) >= 1)
2436
2437 // Find error part
2438 var errorPart *fantasy.StreamPart
2439 for _, part := range parts {
2440 if part.Type == fantasy.StreamPartTypeError {
2441 errorPart = &part
2442 break
2443 }
2444 }
2445
2446 require.NotNil(t, errorPart)
2447 require.NotNil(t, errorPart.Error)
2448 })
2449
2450 t.Run("should send request body", func(t *testing.T) {
2451 t.Parallel()
2452
2453 server := newStreamingMockServer()
2454 defer server.close()
2455
2456 server.prepareStreamResponse(map[string]any{
2457 "content": []string{},
2458 })
2459
2460 provider, err := New(
2461 WithAPIKey("test-api-key"),
2462 WithBaseURL(server.server.URL),
2463 )
2464 require.NoError(t, err)
2465 model, _ := provider.LanguageModel("gpt-3.5-turbo")
2466
2467 _, err = model.Stream(context.Background(), fantasy.Call{
2468 Prompt: testPrompt,
2469 })
2470
2471 require.NoError(t, err)
2472 require.Len(t, server.calls, 1)
2473
2474 call := server.calls[0]
2475 require.Equal(t, "POST", call.method)
2476 require.Equal(t, "/chat/completions", call.path)
2477 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
2478 require.Equal(t, true, call.body["stream"])
2479
2480 streamOptions := call.body["stream_options"].(map[string]any)
2481 require.Equal(t, true, streamOptions["include_usage"])
2482
2483 messages := call.body["messages"].([]any)
2484 require.Len(t, messages, 1)
2485
2486 message := messages[0].(map[string]any)
2487 require.Equal(t, "user", message["role"])
2488 require.Equal(t, "Hello", message["content"])
2489 })
2490
2491 t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
2492 t.Parallel()
2493
2494 server := newStreamingMockServer()
2495 defer server.close()
2496
2497 server.prepareStreamResponse(map[string]any{
2498 "content": []string{},
2499 "usage": map[string]any{
2500 "prompt_tokens": 15,
2501 "completion_tokens": 20,
2502 "total_tokens": 35,
2503 "prompt_tokens_details": map[string]any{
2504 "cached_tokens": 1152,
2505 },
2506 },
2507 })
2508
2509 provider, err := New(
2510 WithAPIKey("test-api-key"),
2511 WithBaseURL(server.server.URL),
2512 )
2513 require.NoError(t, err)
2514 model, _ := provider.LanguageModel("gpt-3.5-turbo")
2515
2516 stream, err := model.Stream(context.Background(), fantasy.Call{
2517 Prompt: testPrompt,
2518 })
2519
2520 require.NoError(t, err)
2521
2522 parts, err := collectStreamParts(stream)
2523 require.NoError(t, err)
2524
2525 // Find finish part
2526 var finishPart *fantasy.StreamPart
2527 for _, part := range parts {
2528 if part.Type == fantasy.StreamPartTypeFinish {
2529 finishPart = &part
2530 break
2531 }
2532 }
2533
2534 require.NotNil(t, finishPart)
2535 require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
2536 require.Equal(t, int64(15), finishPart.Usage.InputTokens)
2537 require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
2538 require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
2539 })
2540
2541 t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
2542 t.Parallel()
2543
2544 server := newStreamingMockServer()
2545 defer server.close()
2546
2547 server.prepareStreamResponse(map[string]any{
2548 "content": []string{},
2549 "usage": map[string]any{
2550 "prompt_tokens": 15,
2551 "completion_tokens": 20,
2552 "total_tokens": 35,
2553 "completion_tokens_details": map[string]any{
2554 "accepted_prediction_tokens": 123,
2555 "rejected_prediction_tokens": 456,
2556 },
2557 },
2558 })
2559
2560 provider, err := New(
2561 WithAPIKey("test-api-key"),
2562 WithBaseURL(server.server.URL),
2563 )
2564 require.NoError(t, err)
2565 model, _ := provider.LanguageModel("gpt-3.5-turbo")
2566
2567 stream, err := model.Stream(context.Background(), fantasy.Call{
2568 Prompt: testPrompt,
2569 })
2570
2571 require.NoError(t, err)
2572
2573 parts, err := collectStreamParts(stream)
2574 require.NoError(t, err)
2575
2576 // Find finish part
2577 var finishPart *fantasy.StreamPart
2578 for _, part := range parts {
2579 if part.Type == fantasy.StreamPartTypeFinish {
2580 finishPart = &part
2581 break
2582 }
2583 }
2584
2585 require.NotNil(t, finishPart)
2586 require.NotNil(t, finishPart.ProviderMetadata)
2587
2588 openaiMeta, ok := finishPart.ProviderMetadata["openai"].(*ProviderMetadata)
2589 require.True(t, ok)
2590 require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
2591 require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
2592 })
2593
2594 t.Run("should send store extension setting", func(t *testing.T) {
2595 t.Parallel()
2596
2597 server := newStreamingMockServer()
2598 defer server.close()
2599
2600 server.prepareStreamResponse(map[string]any{
2601 "content": []string{},
2602 })
2603
2604 provider, err := New(
2605 WithAPIKey("test-api-key"),
2606 WithBaseURL(server.server.URL),
2607 )
2608 require.NoError(t, err)
2609 model, _ := provider.LanguageModel("gpt-3.5-turbo")
2610
2611 _, err = model.Stream(context.Background(), fantasy.Call{
2612 Prompt: testPrompt,
2613 ProviderOptions: NewProviderOptions(&ProviderOptions{
2614 Store: fantasy.Opt(true),
2615 }),
2616 })
2617
2618 require.NoError(t, err)
2619 require.Len(t, server.calls, 1)
2620
2621 call := server.calls[0]
2622 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
2623 require.Equal(t, true, call.body["stream"])
2624 require.Equal(t, true, call.body["store"])
2625
2626 streamOptions := call.body["stream_options"].(map[string]any)
2627 require.Equal(t, true, streamOptions["include_usage"])
2628
2629 messages := call.body["messages"].([]any)
2630 require.Len(t, messages, 1)
2631
2632 message := messages[0].(map[string]any)
2633 require.Equal(t, "user", message["role"])
2634 require.Equal(t, "Hello", message["content"])
2635 })
2636
2637 t.Run("should send metadata extension values", func(t *testing.T) {
2638 t.Parallel()
2639
2640 server := newStreamingMockServer()
2641 defer server.close()
2642
2643 server.prepareStreamResponse(map[string]any{
2644 "content": []string{},
2645 })
2646
2647 provider, err := New(
2648 WithAPIKey("test-api-key"),
2649 WithBaseURL(server.server.URL),
2650 )
2651 require.NoError(t, err)
2652 model, _ := provider.LanguageModel("gpt-3.5-turbo")
2653
2654 _, err = model.Stream(context.Background(), fantasy.Call{
2655 Prompt: testPrompt,
2656 ProviderOptions: NewProviderOptions(&ProviderOptions{
2657 Metadata: map[string]any{
2658 "custom": "value",
2659 },
2660 }),
2661 })
2662
2663 require.NoError(t, err)
2664 require.Len(t, server.calls, 1)
2665
2666 call := server.calls[0]
2667 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
2668 require.Equal(t, true, call.body["stream"])
2669
2670 metadata := call.body["metadata"].(map[string]any)
2671 require.Equal(t, "value", metadata["custom"])
2672
2673 streamOptions := call.body["stream_options"].(map[string]any)
2674 require.Equal(t, true, streamOptions["include_usage"])
2675
2676 messages := call.body["messages"].([]any)
2677 require.Len(t, messages, 1)
2678
2679 message := messages[0].(map[string]any)
2680 require.Equal(t, "user", message["role"])
2681 require.Equal(t, "Hello", message["content"])
2682 })
2683
2684 t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
2685 t.Parallel()
2686
2687 server := newStreamingMockServer()
2688 defer server.close()
2689
2690 server.prepareStreamResponse(map[string]any{
2691 "content": []string{},
2692 })
2693
2694 provider, err := New(
2695 WithAPIKey("test-api-key"),
2696 WithBaseURL(server.server.URL),
2697 )
2698 require.NoError(t, err)
2699 model, _ := provider.LanguageModel("o3-mini")
2700
2701 _, err = model.Stream(context.Background(), fantasy.Call{
2702 Prompt: testPrompt,
2703 ProviderOptions: NewProviderOptions(&ProviderOptions{
2704 ServiceTier: fantasy.Opt("flex"),
2705 }),
2706 })
2707
2708 require.NoError(t, err)
2709 require.Len(t, server.calls, 1)
2710
2711 call := server.calls[0]
2712 require.Equal(t, "o3-mini", call.body["model"])
2713 require.Equal(t, "flex", call.body["service_tier"])
2714 require.Equal(t, true, call.body["stream"])
2715
2716 streamOptions := call.body["stream_options"].(map[string]any)
2717 require.Equal(t, true, streamOptions["include_usage"])
2718
2719 messages := call.body["messages"].([]any)
2720 require.Len(t, messages, 1)
2721
2722 message := messages[0].(map[string]any)
2723 require.Equal(t, "user", message["role"])
2724 require.Equal(t, "Hello", message["content"])
2725 })
2726
2727 t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
2728 t.Parallel()
2729
2730 server := newStreamingMockServer()
2731 defer server.close()
2732
2733 server.prepareStreamResponse(map[string]any{
2734 "content": []string{},
2735 })
2736
2737 provider, err := New(
2738 WithAPIKey("test-api-key"),
2739 WithBaseURL(server.server.URL),
2740 )
2741 require.NoError(t, err)
2742 model, _ := provider.LanguageModel("gpt-4o-mini")
2743
2744 _, err = model.Stream(context.Background(), fantasy.Call{
2745 Prompt: testPrompt,
2746 ProviderOptions: NewProviderOptions(&ProviderOptions{
2747 ServiceTier: fantasy.Opt("priority"),
2748 }),
2749 })
2750
2751 require.NoError(t, err)
2752 require.Len(t, server.calls, 1)
2753
2754 call := server.calls[0]
2755 require.Equal(t, "gpt-4o-mini", call.body["model"])
2756 require.Equal(t, "priority", call.body["service_tier"])
2757 require.Equal(t, true, call.body["stream"])
2758
2759 streamOptions := call.body["stream_options"].(map[string]any)
2760 require.Equal(t, true, streamOptions["include_usage"])
2761
2762 messages := call.body["messages"].([]any)
2763 require.Len(t, messages, 1)
2764
2765 message := messages[0].(map[string]any)
2766 require.Equal(t, "user", message["role"])
2767 require.Equal(t, "Hello", message["content"])
2768 })
2769
2770 t.Run("should stream text delta for reasoning models", func(t *testing.T) {
2771 t.Parallel()
2772
2773 server := newStreamingMockServer()
2774 defer server.close()
2775
2776 server.prepareStreamResponse(map[string]any{
2777 "content": []string{"Hello, World!"},
2778 "model": "o1-preview",
2779 })
2780
2781 provider, err := New(
2782 WithAPIKey("test-api-key"),
2783 WithBaseURL(server.server.URL),
2784 )
2785 require.NoError(t, err)
2786 model, _ := provider.LanguageModel("o1-preview")
2787
2788 stream, err := model.Stream(context.Background(), fantasy.Call{
2789 Prompt: testPrompt,
2790 })
2791
2792 require.NoError(t, err)
2793
2794 parts, err := collectStreamParts(stream)
2795 require.NoError(t, err)
2796
2797 // Find text parts
2798 var textDeltas []string
2799 for _, part := range parts {
2800 if part.Type == fantasy.StreamPartTypeTextDelta {
2801 textDeltas = append(textDeltas, part.Delta)
2802 }
2803 }
2804
2805 // Should contain the text content (without empty delta)
2806 require.Equal(t, []string{"Hello, World!"}, textDeltas)
2807 })
2808
2809 t.Run("should send reasoning tokens", func(t *testing.T) {
2810 t.Parallel()
2811
2812 server := newStreamingMockServer()
2813 defer server.close()
2814
2815 server.prepareStreamResponse(map[string]any{
2816 "content": []string{"Hello, World!"},
2817 "model": "o1-preview",
2818 "usage": map[string]any{
2819 "prompt_tokens": 15,
2820 "completion_tokens": 20,
2821 "total_tokens": 35,
2822 "completion_tokens_details": map[string]any{
2823 "reasoning_tokens": 10,
2824 },
2825 },
2826 })
2827
2828 provider, err := New(
2829 WithAPIKey("test-api-key"),
2830 WithBaseURL(server.server.URL),
2831 )
2832 require.NoError(t, err)
2833 model, _ := provider.LanguageModel("o1-preview")
2834
2835 stream, err := model.Stream(context.Background(), fantasy.Call{
2836 Prompt: testPrompt,
2837 })
2838
2839 require.NoError(t, err)
2840
2841 parts, err := collectStreamParts(stream)
2842 require.NoError(t, err)
2843
2844 // Find finish part
2845 var finishPart *fantasy.StreamPart
2846 for _, part := range parts {
2847 if part.Type == fantasy.StreamPartTypeFinish {
2848 finishPart = &part
2849 break
2850 }
2851 }
2852
2853 require.NotNil(t, finishPart)
2854 require.Equal(t, int64(15), finishPart.Usage.InputTokens)
2855 require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
2856 require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
2857 require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
2858 })
2859}