1package openai
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "net/http"
9 "net/http/httptest"
10 "strings"
11 "testing"
12
13 "charm.land/fantasy"
14 "github.com/charmbracelet/openai-go/packages/param"
15 "github.com/stretchr/testify/assert"
16 "github.com/stretchr/testify/require"
17)
18
19func TestToOpenAiPrompt_SystemMessages(t *testing.T) {
20 t.Parallel()
21
22 t.Run("should forward system messages", func(t *testing.T) {
23 t.Parallel()
24
25 prompt := fantasy.Prompt{
26 {
27 Role: fantasy.MessageRoleSystem,
28 Content: []fantasy.MessagePart{
29 fantasy.TextPart{Text: "You are a helpful assistant."},
30 },
31 },
32 }
33
34 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
35
36 require.Empty(t, warnings)
37 require.Len(t, messages, 1)
38
39 systemMsg := messages[0].OfSystem
40 require.NotNil(t, systemMsg)
41 require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
42 })
43
44 t.Run("should handle empty system messages", func(t *testing.T) {
45 t.Parallel()
46
47 prompt := fantasy.Prompt{
48 {
49 Role: fantasy.MessageRoleSystem,
50 Content: []fantasy.MessagePart{},
51 },
52 }
53
54 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
55
56 require.Len(t, warnings, 1)
57 require.Contains(t, warnings[0].Message, "system prompt has no text parts")
58 require.Empty(t, messages)
59 })
60
61 t.Run("should join multiple system text parts", func(t *testing.T) {
62 t.Parallel()
63
64 prompt := fantasy.Prompt{
65 {
66 Role: fantasy.MessageRoleSystem,
67 Content: []fantasy.MessagePart{
68 fantasy.TextPart{Text: "You are a helpful assistant."},
69 fantasy.TextPart{Text: "Be concise."},
70 },
71 },
72 }
73
74 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
75
76 require.Empty(t, warnings)
77 require.Len(t, messages, 1)
78
79 systemMsg := messages[0].OfSystem
80 require.NotNil(t, systemMsg)
81 require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
82 })
83}
84
85func TestToOpenAiPrompt_UserMessages(t *testing.T) {
86 t.Parallel()
87
88 t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
89 t.Parallel()
90
91 prompt := fantasy.Prompt{
92 {
93 Role: fantasy.MessageRoleUser,
94 Content: []fantasy.MessagePart{
95 fantasy.TextPart{Text: "Hello"},
96 },
97 },
98 }
99
100 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
101
102 require.Empty(t, warnings)
103 require.Len(t, messages, 1)
104
105 userMsg := messages[0].OfUser
106 require.NotNil(t, userMsg)
107 require.Equal(t, "Hello", userMsg.Content.OfString.Value)
108 })
109
110 t.Run("should convert messages with image parts", func(t *testing.T) {
111 t.Parallel()
112
113 imageData := []byte{0, 1, 2, 3}
114 prompt := fantasy.Prompt{
115 {
116 Role: fantasy.MessageRoleUser,
117 Content: []fantasy.MessagePart{
118 fantasy.TextPart{Text: "Hello"},
119 fantasy.FilePart{
120 MediaType: "image/png",
121 Data: imageData,
122 },
123 },
124 },
125 }
126
127 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
128
129 require.Empty(t, warnings)
130 require.Len(t, messages, 1)
131
132 userMsg := messages[0].OfUser
133 require.NotNil(t, userMsg)
134
135 content := userMsg.Content.OfArrayOfContentParts
136 require.Len(t, content, 2)
137
138 // Check text part
139 textPart := content[0].OfText
140 require.NotNil(t, textPart)
141 require.Equal(t, "Hello", textPart.Text)
142
143 // Check image part
144 imagePart := content[1].OfImageURL
145 require.NotNil(t, imagePart)
146 expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
147 require.Equal(t, expectedURL, imagePart.ImageURL.URL)
148 })
149
150 t.Run("should add image detail when specified through provider options", func(t *testing.T) {
151 t.Parallel()
152
153 imageData := []byte{0, 1, 2, 3}
154 prompt := fantasy.Prompt{
155 {
156 Role: fantasy.MessageRoleUser,
157 Content: []fantasy.MessagePart{
158 fantasy.FilePart{
159 MediaType: "image/png",
160 Data: imageData,
161 ProviderOptions: NewProviderFileOptions(&ProviderFileOptions{
162 ImageDetail: "low",
163 }),
164 },
165 },
166 },
167 }
168
169 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
170
171 require.Empty(t, warnings)
172 require.Len(t, messages, 1)
173
174 userMsg := messages[0].OfUser
175 require.NotNil(t, userMsg)
176
177 content := userMsg.Content.OfArrayOfContentParts
178 require.Len(t, content, 1)
179
180 imagePart := content[0].OfImageURL
181 require.NotNil(t, imagePart)
182 require.Equal(t, "low", imagePart.ImageURL.Detail)
183 })
184}
185
// TestToOpenAiPrompt_FileParts exercises conversion of non-image file
// attachments: unsupported media types, the supported audio formats
// (wav, mpeg, mp3), and PDFs (inline data, file IDs, and filename defaults).
func TestToOpenAiPrompt_FileParts(t *testing.T) {
	t.Parallel()

	t.Run("should throw for unsupported mime types", func(t *testing.T) {
		t.Parallel()

		// An unsupported media type produces a warning, and since the part is
		// dropped the now-empty user message is dropped with a second warning.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/something",
						Data:      []byte("test"),
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Len(t, warnings, 2) // unsupported type + empty message
		require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
		require.Contains(t, warnings[1].Message, "dropping empty user message")
		require.Empty(t, messages) // Message is now dropped because it's empty
	})

	t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
		t.Parallel()

		// WAV audio is base64-encoded and tagged with format "wav".
		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/wav",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
		require.Equal(t, "wav", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
		t.Parallel()

		// audio/mpeg is normalized to OpenAI's "mp3" format label.
		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/mpeg",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
		t.Parallel()

		// audio/mp3 (non-standard but common) also maps to "mp3".
		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/mp3",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should convert messages with PDF file parts", func(t *testing.T) {
		t.Parallel()

		// PDF data is inlined as a base64 data URL with the given filename.
		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "document.pdf", filePart.File.Filename.Value)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	// NOTE(review): this subtest is currently byte-identical to the previous
	// one ("should convert messages with PDF file parts"); presumably it was
	// meant to exercise a distinct binary input path — confirm and
	// differentiate or remove.
	t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
		t.Parallel()

		// Data beginning with "file-" is treated as an uploaded-file ID
		// rather than inline content: FileID is set, FileData/Filename omitted.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      []byte("file-pdf-12345"),
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
		require.True(t, param.IsOmitted(filePart.File.FileData))
		require.True(t, param.IsOmitted(filePart.File.Filename))
	})

	t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
		t.Parallel()

		// Without an explicit Filename, a positional default ("part-<i>.pdf")
		// is generated.
		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
	})
}
423
424func TestToOpenAiPrompt_ToolCalls(t *testing.T) {
425 t.Parallel()
426
427 t.Run("should stringify arguments to tool calls", func(t *testing.T) {
428 t.Parallel()
429
430 inputArgs := map[string]any{"foo": "bar123"}
431 inputJSON, _ := json.Marshal(inputArgs)
432
433 outputResult := map[string]any{"oof": "321rab"}
434 outputJSON, _ := json.Marshal(outputResult)
435
436 prompt := fantasy.Prompt{
437 {
438 Role: fantasy.MessageRoleAssistant,
439 Content: []fantasy.MessagePart{
440 fantasy.ToolCallPart{
441 ToolCallID: "quux",
442 ToolName: "thwomp",
443 Input: string(inputJSON),
444 },
445 },
446 },
447 {
448 Role: fantasy.MessageRoleTool,
449 Content: []fantasy.MessagePart{
450 fantasy.ToolResultPart{
451 ToolCallID: "quux",
452 Output: fantasy.ToolResultOutputContentText{
453 Text: string(outputJSON),
454 },
455 },
456 },
457 },
458 }
459
460 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
461
462 require.Empty(t, warnings)
463 require.Len(t, messages, 2)
464
465 // Check assistant message with tool call
466 assistantMsg := messages[0].OfAssistant
467 require.NotNil(t, assistantMsg)
468 require.Equal(t, "", assistantMsg.Content.OfString.Value)
469 require.Len(t, assistantMsg.ToolCalls, 1)
470
471 toolCall := assistantMsg.ToolCalls[0].OfFunction
472 require.NotNil(t, toolCall)
473 require.Equal(t, "quux", toolCall.ID)
474 require.Equal(t, "thwomp", toolCall.Function.Name)
475 require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
476
477 // Check tool message
478 toolMsg := messages[1].OfTool
479 require.NotNil(t, toolMsg)
480 require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
481 require.Equal(t, "quux", toolMsg.ToolCallID)
482 })
483
484 t.Run("should handle different tool output types", func(t *testing.T) {
485 t.Parallel()
486
487 prompt := fantasy.Prompt{
488 {
489 Role: fantasy.MessageRoleTool,
490 Content: []fantasy.MessagePart{
491 fantasy.ToolResultPart{
492 ToolCallID: "text-tool",
493 Output: fantasy.ToolResultOutputContentText{
494 Text: "Hello world",
495 },
496 },
497 fantasy.ToolResultPart{
498 ToolCallID: "error-tool",
499 Output: fantasy.ToolResultOutputContentError{
500 Error: errors.New("Something went wrong"),
501 },
502 },
503 },
504 },
505 }
506
507 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
508
509 require.Empty(t, warnings)
510 require.Len(t, messages, 2)
511
512 // Check first tool message (text)
513 textToolMsg := messages[0].OfTool
514 require.NotNil(t, textToolMsg)
515 require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
516 require.Equal(t, "text-tool", textToolMsg.ToolCallID)
517
518 // Check second tool message (error)
519 errorToolMsg := messages[1].OfTool
520 require.NotNil(t, errorToolMsg)
521 require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
522 require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
523 })
524}
525
526func TestToOpenAiPrompt_AssistantMessages(t *testing.T) {
527 t.Parallel()
528
529 t.Run("should handle simple text assistant messages", func(t *testing.T) {
530 t.Parallel()
531
532 prompt := fantasy.Prompt{
533 {
534 Role: fantasy.MessageRoleAssistant,
535 Content: []fantasy.MessagePart{
536 fantasy.TextPart{Text: "Hello, how can I help you?"},
537 },
538 },
539 }
540
541 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
542
543 require.Empty(t, warnings)
544 require.Len(t, messages, 1)
545
546 assistantMsg := messages[0].OfAssistant
547 require.NotNil(t, assistantMsg)
548 require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
549 })
550
551 t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
552 t.Parallel()
553
554 inputArgs := map[string]any{"query": "test"}
555 inputJSON, _ := json.Marshal(inputArgs)
556
557 prompt := fantasy.Prompt{
558 {
559 Role: fantasy.MessageRoleAssistant,
560 Content: []fantasy.MessagePart{
561 fantasy.TextPart{Text: "Let me search for that."},
562 fantasy.ToolCallPart{
563 ToolCallID: "call-123",
564 ToolName: "search",
565 Input: string(inputJSON),
566 },
567 },
568 },
569 }
570
571 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
572
573 require.Empty(t, warnings)
574 require.Len(t, messages, 1)
575
576 assistantMsg := messages[0].OfAssistant
577 require.NotNil(t, assistantMsg)
578 require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
579 require.Len(t, assistantMsg.ToolCalls, 1)
580
581 toolCall := assistantMsg.ToolCalls[0].OfFunction
582 require.Equal(t, "call-123", toolCall.ID)
583 require.Equal(t, "search", toolCall.Function.Name)
584 require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
585 })
586}
587
// testPrompt is a minimal single-turn user prompt ("Hello") shared by the
// generate tests below.
var testPrompt = fantasy.Prompt{
	{
		Role: fantasy.MessageRoleUser,
		Content: []fantasy.MessagePart{
			fantasy.TextPart{Text: "Hello"},
		},
	},
}
596
// testLogprobs is a canned OpenAI-style logprobs payload (per-token logprob
// plus a single top_logprobs entry) used to stub the mock server's response
// in the logprobs tests.
var testLogprobs = map[string]any{
	"content": []map[string]any{
		{
			"token":   "Hello",
			"logprob": -0.0009994634,
			"top_logprobs": []map[string]any{
				{
					"token":   "Hello",
					"logprob": -0.0009994634,
				},
			},
		},
		{
			"token":   "!",
			"logprob": -0.13410144,
			"top_logprobs": []map[string]any{
				{
					"token":   "!",
					"logprob": -0.13410144,
				},
			},
		},
		{
			"token":   " How",
			"logprob": -0.0009250381,
			"top_logprobs": []map[string]any{
				{
					"token":   " How",
					"logprob": -0.0009250381,
				},
			},
		},
		{
			"token":   " can",
			"logprob": -0.047709424,
			"top_logprobs": []map[string]any{
				{
					"token":   " can",
					"logprob": -0.047709424,
				},
			},
		},
		{
			"token":   " I",
			"logprob": -0.000009014684,
			"top_logprobs": []map[string]any{
				{
					"token":   " I",
					"logprob": -0.000009014684,
				},
			},
		},
		{
			"token":   " assist",
			"logprob": -0.009125131,
			"top_logprobs": []map[string]any{
				{
					"token":   " assist",
					"logprob": -0.009125131,
				},
			},
		},
		{
			"token":   " you",
			"logprob": -0.0000066306106,
			"top_logprobs": []map[string]any{
				{
					"token":   " you",
					"logprob": -0.0000066306106,
				},
			},
		},
		{
			"token":   " today",
			"logprob": -0.00011093382,
			"top_logprobs": []map[string]any{
				{
					"token":   " today",
					"logprob": -0.00011093382,
				},
			},
		},
		{
			"token":   "?",
			"logprob": -0.00004596782,
			"top_logprobs": []map[string]any{
				{
					"token":   "?",
					"logprob": -0.00004596782,
				},
			},
		},
	},
}
691
// mockServer wraps an httptest.Server that records incoming requests and
// replies with a canned JSON response.
type mockServer struct {
	server   *httptest.Server // underlying test HTTP server
	response map[string]any   // JSON payload returned for every request
	calls    []mockCall       // recorded requests, in arrival order
}

// mockCall captures a single request received by the mock server.
type mockCall struct {
	method  string            // HTTP method, e.g. "POST"
	path    string            // request path, e.g. "/chat/completions"
	headers map[string]string // first value of each request header
	body    map[string]any    // decoded JSON request body, if any
}
704
705func newMockServer() *mockServer {
706 ms := &mockServer{
707 calls: make([]mockCall, 0),
708 }
709
710 ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
711 // Record the call
712 call := mockCall{
713 method: r.Method,
714 path: r.URL.Path,
715 headers: make(map[string]string),
716 }
717
718 for k, v := range r.Header {
719 if len(v) > 0 {
720 call.headers[k] = v[0]
721 }
722 }
723
724 // Parse request body
725 if r.Body != nil {
726 var body map[string]any
727 json.NewDecoder(r.Body).Decode(&body)
728 call.body = body
729 }
730
731 ms.calls = append(ms.calls, call)
732
733 // Return mock response
734 w.Header().Set("Content-Type", "application/json")
735 json.NewEncoder(w).Encode(ms.response)
736 }))
737
738 return ms
739}
740
// close shuts down the underlying httptest server and releases its listener.
// Intended to be deferred right after newMockServer in each subtest.
func (ms *mockServer) close() {
	ms.server.Close()
}
744
745func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
746 // Default values
747 response := map[string]any{
748 "id": "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
749 "object": "chat.completion",
750 "created": 1711115037,
751 "model": "gpt-3.5-turbo-0125",
752 "choices": []map[string]any{
753 {
754 "index": 0,
755 "message": map[string]any{
756 "role": "assistant",
757 "content": "",
758 },
759 "finish_reason": "stop",
760 },
761 },
762 "usage": map[string]any{
763 "prompt_tokens": 4,
764 "total_tokens": 34,
765 "completion_tokens": 30,
766 },
767 "system_fingerprint": "fp_3bc1b5746c",
768 }
769
770 // Override with provided options
771 for k, v := range opts {
772 switch k {
773 case "content":
774 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["content"] = v
775 case "tool_calls":
776 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["tool_calls"] = v
777 case "function_call":
778 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["function_call"] = v
779 case "annotations":
780 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["annotations"] = v
781 case "usage":
782 response["usage"] = v
783 case "finish_reason":
784 response["choices"].([]map[string]any)[0]["finish_reason"] = v
785 case "id":
786 response["id"] = v
787 case "created":
788 response["created"] = v
789 case "model":
790 response["model"] = v
791 case "logprobs":
792 if v != nil {
793 response["choices"].([]map[string]any)[0]["logprobs"] = v
794 }
795 }
796 }
797
798 ms.response = response
799}
800
801func TestDoGenerate(t *testing.T) {
802 t.Parallel()
803
804 t.Run("should extract text response", func(t *testing.T) {
805 t.Parallel()
806
807 server := newMockServer()
808 defer server.close()
809
810 server.prepareJSONResponse(map[string]any{
811 "content": "Hello, World!",
812 })
813
814 provider, err := New(
815 WithAPIKey("test-api-key"),
816 WithBaseURL(server.server.URL),
817 )
818 require.NoError(t, err)
819 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
820
821 result, err := model.Generate(context.Background(), fantasy.Call{
822 Prompt: testPrompt,
823 })
824
825 require.NoError(t, err)
826 require.Len(t, result.Content, 1)
827
828 textContent, ok := result.Content[0].(fantasy.TextContent)
829 require.True(t, ok)
830 require.Equal(t, "Hello, World!", textContent.Text)
831 })
832
833 t.Run("should extract usage", func(t *testing.T) {
834 t.Parallel()
835
836 server := newMockServer()
837 defer server.close()
838
839 server.prepareJSONResponse(map[string]any{
840 "usage": map[string]any{
841 "prompt_tokens": 20,
842 "total_tokens": 25,
843 "completion_tokens": 5,
844 },
845 })
846
847 provider, err := New(
848 WithAPIKey("test-api-key"),
849 WithBaseURL(server.server.URL),
850 )
851 require.NoError(t, err)
852 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
853
854 result, err := model.Generate(context.Background(), fantasy.Call{
855 Prompt: testPrompt,
856 })
857
858 require.NoError(t, err)
859 require.Equal(t, int64(20), result.Usage.InputTokens)
860 require.Equal(t, int64(5), result.Usage.OutputTokens)
861 require.Equal(t, int64(25), result.Usage.TotalTokens)
862 })
863
864 t.Run("should send request body", func(t *testing.T) {
865 t.Parallel()
866
867 server := newMockServer()
868 defer server.close()
869
870 server.prepareJSONResponse(map[string]any{})
871
872 provider, err := New(
873 WithAPIKey("test-api-key"),
874 WithBaseURL(server.server.URL),
875 )
876 require.NoError(t, err)
877 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
878
879 _, err = model.Generate(context.Background(), fantasy.Call{
880 Prompt: testPrompt,
881 })
882
883 require.NoError(t, err)
884 require.Len(t, server.calls, 1)
885
886 call := server.calls[0]
887 require.Equal(t, "POST", call.method)
888 require.Equal(t, "/chat/completions", call.path)
889 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
890
891 messages, ok := call.body["messages"].([]any)
892 require.True(t, ok)
893 require.Len(t, messages, 1)
894
895 message := messages[0].(map[string]any)
896 require.Equal(t, "user", message["role"])
897 require.Equal(t, "Hello", message["content"])
898 })
899
900 t.Run("should support partial usage", func(t *testing.T) {
901 t.Parallel()
902
903 server := newMockServer()
904 defer server.close()
905
906 server.prepareJSONResponse(map[string]any{
907 "usage": map[string]any{
908 "prompt_tokens": 20,
909 "total_tokens": 20,
910 },
911 })
912
913 provider, err := New(
914 WithAPIKey("test-api-key"),
915 WithBaseURL(server.server.URL),
916 )
917 require.NoError(t, err)
918 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
919
920 result, err := model.Generate(context.Background(), fantasy.Call{
921 Prompt: testPrompt,
922 })
923
924 require.NoError(t, err)
925 require.Equal(t, int64(20), result.Usage.InputTokens)
926 require.Equal(t, int64(0), result.Usage.OutputTokens)
927 require.Equal(t, int64(20), result.Usage.TotalTokens)
928 })
929
930 t.Run("should extract logprobs", func(t *testing.T) {
931 t.Parallel()
932
933 server := newMockServer()
934 defer server.close()
935
936 server.prepareJSONResponse(map[string]any{
937 "logprobs": testLogprobs,
938 })
939
940 provider, err := New(
941 WithAPIKey("test-api-key"),
942 WithBaseURL(server.server.URL),
943 )
944 require.NoError(t, err)
945 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
946
947 result, err := model.Generate(context.Background(), fantasy.Call{
948 Prompt: testPrompt,
949 ProviderOptions: NewProviderOptions(&ProviderOptions{
950 LogProbs: fantasy.Opt(true),
951 }),
952 })
953
954 require.NoError(t, err)
955 require.NotNil(t, result.ProviderMetadata)
956
957 openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
958 require.True(t, ok)
959
960 logprobs := openaiMeta.Logprobs
961 require.True(t, ok)
962 require.NotNil(t, logprobs)
963 })
964
965 t.Run("should extract finish reason", func(t *testing.T) {
966 t.Parallel()
967
968 server := newMockServer()
969 defer server.close()
970
971 server.prepareJSONResponse(map[string]any{
972 "finish_reason": "stop",
973 })
974
975 provider, err := New(
976 WithAPIKey("test-api-key"),
977 WithBaseURL(server.server.URL),
978 )
979 require.NoError(t, err)
980 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
981
982 result, err := model.Generate(context.Background(), fantasy.Call{
983 Prompt: testPrompt,
984 })
985
986 require.NoError(t, err)
987 require.Equal(t, fantasy.FinishReasonStop, result.FinishReason)
988 })
989
990 t.Run("should support unknown finish reason", func(t *testing.T) {
991 t.Parallel()
992
993 server := newMockServer()
994 defer server.close()
995
996 server.prepareJSONResponse(map[string]any{
997 "finish_reason": "eos",
998 })
999
1000 provider, err := New(
1001 WithAPIKey("test-api-key"),
1002 WithBaseURL(server.server.URL),
1003 )
1004 require.NoError(t, err)
1005 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1006
1007 result, err := model.Generate(context.Background(), fantasy.Call{
1008 Prompt: testPrompt,
1009 })
1010
1011 require.NoError(t, err)
1012 require.Equal(t, fantasy.FinishReasonUnknown, result.FinishReason)
1013 })
1014
1015 t.Run("should pass the model and the messages", func(t *testing.T) {
1016 t.Parallel()
1017
1018 server := newMockServer()
1019 defer server.close()
1020
1021 server.prepareJSONResponse(map[string]any{
1022 "content": "",
1023 })
1024
1025 provider, err := New(
1026 WithAPIKey("test-api-key"),
1027 WithBaseURL(server.server.URL),
1028 )
1029 require.NoError(t, err)
1030 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1031
1032 _, err = model.Generate(context.Background(), fantasy.Call{
1033 Prompt: testPrompt,
1034 })
1035
1036 require.NoError(t, err)
1037 require.Len(t, server.calls, 1)
1038
1039 call := server.calls[0]
1040 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1041
1042 messages := call.body["messages"].([]any)
1043 require.Len(t, messages, 1)
1044
1045 message := messages[0].(map[string]any)
1046 require.Equal(t, "user", message["role"])
1047 require.Equal(t, "Hello", message["content"])
1048 })
1049
1050 t.Run("should pass settings", func(t *testing.T) {
1051 t.Parallel()
1052
1053 server := newMockServer()
1054 defer server.close()
1055
1056 server.prepareJSONResponse(map[string]any{})
1057
1058 provider, err := New(
1059 WithAPIKey("test-api-key"),
1060 WithBaseURL(server.server.URL),
1061 )
1062 require.NoError(t, err)
1063 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1064
1065 _, err = model.Generate(context.Background(), fantasy.Call{
1066 Prompt: testPrompt,
1067 ProviderOptions: NewProviderOptions(&ProviderOptions{
1068 LogitBias: map[string]int64{
1069 "50256": -100,
1070 },
1071 ParallelToolCalls: fantasy.Opt(false),
1072 User: fantasy.Opt("test-user-id"),
1073 }),
1074 })
1075
1076 require.NoError(t, err)
1077 require.Len(t, server.calls, 1)
1078
1079 call := server.calls[0]
1080 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1081
1082 messages := call.body["messages"].([]any)
1083 require.Len(t, messages, 1)
1084
1085 logitBias := call.body["logit_bias"].(map[string]any)
1086 require.Equal(t, float64(-100), logitBias["50256"])
1087 require.Equal(t, false, call.body["parallel_tool_calls"])
1088 require.Equal(t, "test-user-id", call.body["user"])
1089 })
1090
1091 t.Run("should pass reasoningEffort setting", func(t *testing.T) {
1092 t.Parallel()
1093
1094 server := newMockServer()
1095 defer server.close()
1096
1097 server.prepareJSONResponse(map[string]any{
1098 "content": "",
1099 })
1100
1101 provider, err := New(
1102 WithAPIKey("test-api-key"),
1103 WithBaseURL(server.server.URL),
1104 )
1105 require.NoError(t, err)
1106 model, _ := provider.LanguageModel(t.Context(), "o1-mini")
1107
1108 _, err = model.Generate(context.Background(), fantasy.Call{
1109 Prompt: testPrompt,
1110 ProviderOptions: NewProviderOptions(
1111 &ProviderOptions{
1112 ReasoningEffort: ReasoningEffortOption(ReasoningEffortLow),
1113 },
1114 ),
1115 })
1116
1117 require.NoError(t, err)
1118 require.Len(t, server.calls, 1)
1119
1120 call := server.calls[0]
1121 require.Equal(t, "o1-mini", call.body["model"])
1122 require.Equal(t, "low", call.body["reasoning_effort"])
1123
1124 messages := call.body["messages"].([]any)
1125 require.Len(t, messages, 1)
1126
1127 message := messages[0].(map[string]any)
1128 require.Equal(t, "user", message["role"])
1129 require.Equal(t, "Hello", message["content"])
1130 })
1131
1132 t.Run("should pass textVerbosity setting", func(t *testing.T) {
1133 t.Parallel()
1134
1135 server := newMockServer()
1136 defer server.close()
1137
1138 server.prepareJSONResponse(map[string]any{
1139 "content": "",
1140 })
1141
1142 provider, err := New(
1143 WithAPIKey("test-api-key"),
1144 WithBaseURL(server.server.URL),
1145 )
1146 require.NoError(t, err)
1147 model, _ := provider.LanguageModel(t.Context(), "gpt-4o")
1148
1149 _, err = model.Generate(context.Background(), fantasy.Call{
1150 Prompt: testPrompt,
1151 ProviderOptions: NewProviderOptions(&ProviderOptions{
1152 TextVerbosity: fantasy.Opt("low"),
1153 }),
1154 })
1155
1156 require.NoError(t, err)
1157 require.Len(t, server.calls, 1)
1158
1159 call := server.calls[0]
1160 require.Equal(t, "gpt-4o", call.body["model"])
1161 require.Equal(t, "low", call.body["verbosity"])
1162
1163 messages := call.body["messages"].([]any)
1164 require.Len(t, messages, 1)
1165
1166 message := messages[0].(map[string]any)
1167 require.Equal(t, "user", message["role"])
1168 require.Equal(t, "Hello", message["content"])
1169 })
1170
1171 t.Run("should pass tools and toolChoice", func(t *testing.T) {
1172 t.Parallel()
1173
1174 server := newMockServer()
1175 defer server.close()
1176
1177 server.prepareJSONResponse(map[string]any{
1178 "content": "",
1179 })
1180
1181 provider, err := New(
1182 WithAPIKey("test-api-key"),
1183 WithBaseURL(server.server.URL),
1184 )
1185 require.NoError(t, err)
1186 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1187
1188 _, err = model.Generate(context.Background(), fantasy.Call{
1189 Prompt: testPrompt,
1190 Tools: []fantasy.Tool{
1191 fantasy.FunctionTool{
1192 Name: "test-tool",
1193 InputSchema: map[string]any{
1194 "type": "object",
1195 "properties": map[string]any{
1196 "value": map[string]any{
1197 "type": "string",
1198 },
1199 },
1200 "required": []string{"value"},
1201 "additionalProperties": false,
1202 "$schema": "http://json-schema.org/draft-07/schema#",
1203 },
1204 },
1205 },
1206 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0],
1207 })
1208
1209 require.NoError(t, err)
1210 require.Len(t, server.calls, 1)
1211
1212 call := server.calls[0]
1213 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1214
1215 messages := call.body["messages"].([]any)
1216 require.Len(t, messages, 1)
1217
1218 tools := call.body["tools"].([]any)
1219 require.Len(t, tools, 1)
1220
1221 tool := tools[0].(map[string]any)
1222 require.Equal(t, "function", tool["type"])
1223
1224 function := tool["function"].(map[string]any)
1225 require.Equal(t, "test-tool", function["name"])
1226 require.Equal(t, false, function["strict"])
1227
1228 toolChoice := call.body["tool_choice"].(map[string]any)
1229 require.Equal(t, "function", toolChoice["type"])
1230
1231 toolChoiceFunction := toolChoice["function"].(map[string]any)
1232 require.Equal(t, "test-tool", toolChoiceFunction["name"])
1233 })
1234
1235 t.Run("should parse tool results", func(t *testing.T) {
1236 t.Parallel()
1237
1238 server := newMockServer()
1239 defer server.close()
1240
1241 server.prepareJSONResponse(map[string]any{
1242 "tool_calls": []map[string]any{
1243 {
1244 "id": "call_O17Uplv4lJvD6DVdIvFFeRMw",
1245 "type": "function",
1246 "function": map[string]any{
1247 "name": "test-tool",
1248 "arguments": `{"value":"Spark"}`,
1249 },
1250 },
1251 },
1252 })
1253
1254 provider, err := New(
1255 WithAPIKey("test-api-key"),
1256 WithBaseURL(server.server.URL),
1257 )
1258 require.NoError(t, err)
1259 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1260
1261 result, err := model.Generate(context.Background(), fantasy.Call{
1262 Prompt: testPrompt,
1263 Tools: []fantasy.Tool{
1264 fantasy.FunctionTool{
1265 Name: "test-tool",
1266 InputSchema: map[string]any{
1267 "type": "object",
1268 "properties": map[string]any{
1269 "value": map[string]any{
1270 "type": "string",
1271 },
1272 },
1273 "required": []string{"value"},
1274 "additionalProperties": false,
1275 "$schema": "http://json-schema.org/draft-07/schema#",
1276 },
1277 },
1278 },
1279 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0],
1280 })
1281
1282 require.NoError(t, err)
1283 require.Len(t, result.Content, 1)
1284
1285 toolCall, ok := result.Content[0].(fantasy.ToolCallContent)
1286 require.True(t, ok)
1287 require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
1288 require.Equal(t, "test-tool", toolCall.ToolName)
1289 require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
1290 })
1291
1292 t.Run("should handle ToolChoiceRequired", func(t *testing.T) {
1293 t.Parallel()
1294
1295 server := newMockServer()
1296 defer server.close()
1297
1298 server.prepareJSONResponse(map[string]any{
1299 "content": "",
1300 })
1301
1302 provider, err := New(
1303 WithAPIKey("test-api-key"),
1304 WithBaseURL(server.server.URL),
1305 )
1306 require.NoError(t, err)
1307 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1308
1309 _, err = model.Generate(context.Background(), fantasy.Call{
1310 Prompt: testPrompt,
1311 Tools: []fantasy.Tool{
1312 fantasy.FunctionTool{
1313 Name: "test-tool",
1314 InputSchema: map[string]any{
1315 "type": "object",
1316 "properties": map[string]any{
1317 "value": map[string]any{
1318 "type": "string",
1319 },
1320 },
1321 "required": []string{"value"},
1322 "additionalProperties": false,
1323 "$schema": "http://json-schema.org/draft-07/schema#",
1324 },
1325 },
1326 },
1327 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoiceRequired}[0],
1328 })
1329
1330 require.NoError(t, err)
1331 require.Len(t, server.calls, 1)
1332
1333 call := server.calls[0]
1334 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1335
1336 // Verify tool is present
1337 tools := call.body["tools"].([]any)
1338 require.Len(t, tools, 1)
1339
1340 tool := tools[0].(map[string]any)
1341 require.Equal(t, "function", tool["type"])
1342
1343 function := tool["function"].(map[string]any)
1344 require.Equal(t, "test-tool", function["name"])
1345
1346 // Verify tool_choice is set to "required" (not a function name)
1347 toolChoice := call.body["tool_choice"]
1348 require.Equal(t, "required", toolChoice)
1349 })
1350
	// Verifies url_citation annotations are converted into a SourceContent
	// entry appended after the text content.
	t.Run("should parse annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "Based on the search results [doc1], I found information.",
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index":   29,
						"url":         "https://example.com/doc1.pdf",
						"title":       "Document 1",
					},
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 2)

		textContent, ok := result.Content[0].(fantasy.TextContent)
		require.True(t, ok)
		require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)

		sourceContent, ok := result.Content[1].(fantasy.SourceContent)
		require.True(t, ok)
		require.Equal(t, fantasy.SourceTypeURL, sourceContent.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
		require.Equal(t, "Document 1", sourceContent.Title)
		require.NotEmpty(t, sourceContent.ID)
	})

	// Verifies prompt_tokens_details.cached_tokens surfaces as CacheReadTokens
	// and is subtracted from InputTokens (clamped at zero).
	t.Run("should return cached_tokens in prompt_details_tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
		// InputTokens = prompt_tokens - cached_tokens = 15 - 1152 = -1137 → clamped to 0
		require.Equal(t, int64(0), result.Usage.InputTokens)
		require.Equal(t, int64(20), result.Usage.OutputTokens)
		require.Equal(t, int64(35), result.Usage.TotalTokens)
	})

	// Verifies completion_tokens_details prediction counters are exposed via
	// the provider-specific metadata under the "openai" key.
	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.NotNil(t, result.ProviderMetadata)

		openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)

		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
		require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
	})
1472
	// Verifies sampling parameters are stripped for reasoning models (o1-*)
	// and that each removal produces an unsupported-setting warning.
	t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt:           testPrompt,
			Temperature:      &[]float64{0.5}[0],
			TopP:             &[]float64{0.7}[0],
			FrequencyPenalty: &[]float64{0.2}[0],
			PresencePenalty:  &[]float64{0.3}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])

		// These should not be present
		require.Nil(t, call.body["temperature"])
		require.Nil(t, call.body["top_p"])
		require.Nil(t, call.body["frequency_penalty"])
		require.Nil(t, call.body["presence_penalty"])

		// Should have warnings
		require.Len(t, result.Warnings, 4)
		require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "temperature", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
	})

	// Verifies MaxOutputTokens maps to max_completion_tokens (not max_tokens)
	// for reasoning models.
	t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt:          testPrompt,
			MaxOutputTokens: &[]int64{1000}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])
		// JSON numbers decode as float64 in the recorded request body.
		require.Equal(t, float64(1000), call.body["max_completion_tokens"])
		require.Nil(t, call.body["max_tokens"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	// Verifies completion_tokens_details.reasoning_tokens is surfaced on Usage.
	t.Run("should return reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(15), result.Usage.InputTokens)
		require.Equal(t, int64(20), result.Usage.OutputTokens)
		require.Equal(t, int64(35), result.Usage.TotalTokens)
		require.Equal(t, int64(10), result.Usage.ReasoningTokens)
	})
1592
	// Provider-option passthrough subtests: each OpenAI-specific extension
	// setting on ProviderOptions should appear verbatim in the request body.
	t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"model": "o1-preview",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				MaxCompletionTokens: fantasy.Opt(int64(255)),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])
		// JSON numbers decode as float64 in the recorded request body.
		require.Equal(t, float64(255), call.body["max_completion_tokens"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send prediction extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Prediction: map[string]any{
					"type":    "content",
					"content": "Hello, World!",
				},
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		prediction := call.body["prediction"].(map[string]any)
		require.Equal(t, "content", prediction["type"])
		require.Equal(t, "Hello, World!", prediction["content"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Store: fantasy.Opt(true),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["store"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Metadata: map[string]any{
					"custom": "value",
				},
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send promptCacheKey extension value", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				PromptCacheKey: fantasy.Opt("test-cache-key-123"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send safety_identifier extension value", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				SafetyIdentifier: fantasy.Opt("test-safety-identifier-123"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})
1836
	// Verifies temperature is dropped (with a warning) for search-preview
	// model variants.
	t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-search-preview")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt:      testPrompt,
			Temperature: &[]float64{0.7}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-search-preview", call.body["model"])
		require.Nil(t, call.body["temperature"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "temperature", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "search preview models")
	})

	// Service-tier subtests: "flex" and "priority" are forwarded for supported
	// models and dropped with a warning otherwise.
	t.Run("should send ServiceTier flex processing setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o3-mini")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("flex"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("flex"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Nil(t, call.body["service_tier"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
	})

	t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("priority"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("priority"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Nil(t, call.body["service_tier"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
	})
2013}
2014
// streamingMockServer simulates a server-sent-events (SSE) chat-completions
// endpoint. Tests queue raw "data: ..." chunks (or "HEADER:key:value"
// pseudo-chunks, which become response headers) and then inspect the
// recorded requests.
type streamingMockServer struct {
	server *httptest.Server // live test HTTP server
	chunks []string         // SSE chunks (and HEADER: directives) replayed per request
	calls  []mockCall       // recorded incoming requests
}
2020
// newStreamingMockServer starts an httptest server that records each request
// (method, path, first value of each header, decoded JSON body) and then
// replays sms.chunks as a text/event-stream response. Chunks prefixed with
// "HEADER:" are applied as response headers before the status is written and
// are skipped when writing the body.
func newStreamingMockServer() *streamingMockServer {
	sms := &streamingMockServer{
		calls: make([]mockCall, 0),
	}

	sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record the call
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}

		// Only the first value of each header is kept.
		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Parse request body. The decode error is deliberately ignored: a
		// malformed body just leaves call.body nil, failing any assertion on it.
		if r.Body != nil {
			var body map[string]any
			json.NewDecoder(r.Body).Decode(&body)
			call.body = body
		}

		sms.calls = append(sms.calls, call)

		// Set streaming headers
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")

		// Add custom headers if any; these must be set before WriteHeader below.
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				parts := strings.SplitN(chunk[7:], ":", 2)
				if len(parts) == 2 {
					w.Header().Set(parts[0], parts[1])
				}
				continue
			}
		}

		w.WriteHeader(http.StatusOK)

		// Write chunks, flushing after each so the client observes them
		// incrementally like a real SSE stream.
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				continue
			}
			w.Write([]byte(chunk))
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
		}
	}))

	return sms
}
2081
// close shuts down the underlying httptest server.
func (sms *streamingMockServer) close() {
	sms.server.Close()
}
2085
// prepareStreamResponse builds the SSE chunk sequence for a streamed chat
// completion and stores it in sms.chunks. Recognized opts keys (all optional):
//   - "content":       []string of text deltas, one chunk each
//   - "usage":          map[string]any overriding the default usage payload
//   - "logprobs":       map[string]any attached to the finish chunk when non-empty
//   - "finish_reason":  string (default "stop")
//   - "model":          string (default "gpt-3.5-turbo-0613")
//   - "annotations":    []map[string]any emitted after the last content chunk
//   - "headers":        map[string]string emitted as HEADER: pseudo-chunks
func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
	content := []string{}
	if c, ok := opts["content"].([]string); ok {
		content = c
	}

	// Default usage payload used when the test does not override it.
	usage := map[string]any{
		"prompt_tokens":     17,
		"total_tokens":      244,
		"completion_tokens": 227,
	}
	if u, ok := opts["usage"].(map[string]any); ok {
		usage = u
	}

	logprobs := map[string]any{}
	if l, ok := opts["logprobs"].(map[string]any); ok {
		logprobs = l
	}

	finishReason := "stop"
	if fr, ok := opts["finish_reason"].(string); ok {
		finishReason = fr
	}

	model := "gpt-3.5-turbo-0613"
	if m, ok := opts["model"].(string); ok {
		model = m
	}

	headers := map[string]string{}
	if h, ok := opts["headers"].(map[string]string); ok {
		headers = h
	}

	chunks := []string{}

	// Add custom headers (converted to response headers by the handler).
	for k, v := range headers {
		chunks = append(chunks, "HEADER:"+k+":"+v)
	}

	// Initial chunk with role
	initialChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index": 0,
				"delta": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": nil,
			},
		},
	}
	initialData, _ := json.Marshal(initialChunk)
	chunks = append(chunks, "data: "+string(initialData)+"\n\n")

	// Content chunks
	// NOTE(review): content/annotation chunks use choice index 1 while the
	// initial and finish chunks use index 0 — presumably matching recorded
	// fixtures; confirm this asymmetry is intentional.
	for i, text := range content {
		contentChunk := map[string]any{
			"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
			"object":             "chat.completion.chunk",
			"created":            1702657020,
			"model":              model,
			"system_fingerprint": nil,
			"choices": []map[string]any{
				{
					"index": 1,
					"delta": map[string]any{
						"content": text,
					},
					"finish_reason": nil,
				},
			},
		}
		contentData, _ := json.Marshal(contentChunk)
		chunks = append(chunks, "data: "+string(contentData)+"\n\n")

		// Add annotations if this is the last content chunk and we have annotations
		if i == len(content)-1 {
			if annotations, ok := opts["annotations"].([]map[string]any); ok {
				annotationChunk := map[string]any{
					"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
					"object":             "chat.completion.chunk",
					"created":            1702657020,
					"model":              model,
					"system_fingerprint": nil,
					"choices": []map[string]any{
						{
							"index": 1,
							"delta": map[string]any{
								"annotations": annotations,
							},
							"finish_reason": nil,
						},
					},
				}
				annotationData, _ := json.Marshal(annotationChunk)
				chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
			}
		}
	}

	// Finish chunk carrying finish_reason (and logprobs when provided).
	finishChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index":         0,
				"delta":         map[string]any{},
				"finish_reason": finishReason,
			},
		},
	}

	if len(logprobs) > 0 {
		finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
	}

	finishData, _ := json.Marshal(finishChunk)
	chunks = append(chunks, "data: "+string(finishData)+"\n\n")

	// Usage chunk: empty choices, usage totals only.
	usageChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": "fp_3bc1b5746c",
		"choices":            []map[string]any{},
		"usage":              usage,
	}
	usageData, _ := json.Marshal(usageChunk)
	chunks = append(chunks, "data: "+string(usageData)+"\n\n")

	// Done
	chunks = append(chunks, "data: [DONE]\n\n")

	sms.chunks = chunks
}
2236
// prepareToolStreamResponse queues a canned SSE stream in which one tool call
// ("test-tool") arrives with its JSON arguments fragmented across several
// deltas ({"value":"Sparkle Day"}), finishing with finish_reason "tool_calls"
// and a trailing usage chunk.
func (sms *streamingMockServer) prepareToolStreamResponse() {
	chunks := []string{
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}
2253
2254func (sms *streamingMockServer) prepareErrorStreamResponse() {
2255 chunks := []string{
2256 `data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
2257 "data: [DONE]\n\n",
2258 }
2259 sms.chunks = chunks
2260}
2261
2262func (sms *streamingMockServer) prepareToolStreamResponseWithEmptyArgs() {
2263 chunks := []string{
2264 // Tool call start with empty arguments (like Copilot sometimes does)
2265 `data: {"id":"chatcmpl-emptyargs","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_empty_args","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2266 // Finish without any argument deltas
2267 `data: {"id":"chatcmpl-emptyargs","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
2268 `data: {"id":"chatcmpl-emptyargs","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
2269 "data: [DONE]\n\n",
2270 }
2271 sms.chunks = chunks
2272}
2273
2274func collectStreamParts(stream fantasy.StreamResponse) ([]fantasy.StreamPart, error) {
2275 var parts []fantasy.StreamPart
2276 for part := range stream {
2277 parts = append(parts, part)
2278 if part.Type == fantasy.StreamPartTypeError {
2279 break
2280 }
2281 if part.Type == fantasy.StreamPartTypeFinish {
2282 break
2283 }
2284 }
2285 return parts, nil
2286}
2287
// TestDoStream exercises the streaming (Stream) path of the chat-completions
// language model against a mock SSE server: text deltas, tool-call deltas
// (including empty arguments), URL-citation annotations, in-stream errors,
// request-body shape, provider options (store, metadata, service_tier), and
// usage accounting (cached, prediction, and reasoning tokens).
func TestDoStream(t *testing.T) {
	t.Parallel()

	t.Run("should stream text deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello", ", ", "World!"},
			"finish_reason": "stop",
			"usage": map[string]any{
				"prompt_tokens": 17,
				"total_tokens": 244,
				"completion_tokens": 227,
			},
			"logprobs": testLogprobs,
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Verify stream structure
		require.True(t, len(parts) >= 4) // text-start, deltas, text-end, finish

		// Find text parts
		textStart, textEnd, finish := -1, -1, -1
		var deltas []string

		for i, part := range parts {
			switch part.Type {
			case fantasy.StreamPartTypeTextStart:
				textStart = i
			case fantasy.StreamPartTypeTextDelta:
				deltas = append(deltas, part.Delta)
			case fantasy.StreamPartTypeTextEnd:
				textEnd = i
			case fantasy.StreamPartTypeFinish:
				finish = i
			}
		}

		require.NotEqual(t, -1, textStart)
		require.NotEqual(t, -1, textEnd)
		require.NotEqual(t, -1, finish)
		require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)

		// Check finish part
		finishPart := parts[finish]
		require.Equal(t, fantasy.FinishReasonStop, finishPart.FinishReason)
		require.Equal(t, int64(17), finishPart.Usage.InputTokens)
		require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
	})

	t.Run("should stream tool deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareToolStreamResponse()

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			Tools: []fantasy.Tool{
				fantasy.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required": []string{"value"},
						"additionalProperties": false,
						"$schema": "http://json-schema.org/draft-07/schema#",
					},
				},
			},
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find tool-related parts
		toolInputStart, toolInputEnd, toolCall := -1, -1, -1
		var toolDeltas []string

		for i, part := range parts {
			switch part.Type {
			case fantasy.StreamPartTypeToolInputStart:
				toolInputStart = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
			case fantasy.StreamPartTypeToolInputDelta:
				toolDeltas = append(toolDeltas, part.Delta)
			case fantasy.StreamPartTypeToolInputEnd:
				toolInputEnd = i
			case fantasy.StreamPartTypeToolCall:
				toolCall = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
				require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
			}
		}

		require.NotEqual(t, -1, toolInputStart)
		require.NotEqual(t, -1, toolInputEnd)
		require.NotEqual(t, -1, toolCall)

		// Verify tool deltas combine to form the complete input
		var fullInput strings.Builder
		for _, delta := range toolDeltas {
			fullInput.WriteString(delta)
		}
		require.Equal(t, `{"value":"Sparkle Day"}`, fullInput.String())
	})

	t.Run("should handle tool calls with empty arguments", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareToolStreamResponseWithEmptyArgs()

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			Tools: []fantasy.Tool{
				fantasy.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required": []string{"value"},
						"additionalProperties": false,
						"$schema": "http://json-schema.org/draft-07/schema#",
					},
				},
			},
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find tool-related parts
		toolInputStart, toolInputEnd, toolCall := -1, -1, -1

		for i, part := range parts {
			switch part.Type {
			case fantasy.StreamPartTypeToolInputStart:
				toolInputStart = i
				require.Equal(t, "call_empty_args", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
			case fantasy.StreamPartTypeToolInputEnd:
				toolInputEnd = i
				require.Equal(t, "call_empty_args", part.ID)
			case fantasy.StreamPartTypeToolCall:
				toolCall = i
				require.Equal(t, "call_empty_args", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
				// Empty arguments should be normalized to "{}"
				require.Equal(t, "{}", part.ToolCallInput)
			}
		}

		require.NotEqual(t, -1, toolInputStart, "expected ToolInputStart part")
		require.NotEqual(t, -1, toolInputEnd, "expected ToolInputEnd part")
		require.NotEqual(t, -1, toolCall, "expected ToolCall part")
	})

	t.Run("should stream annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Based on search results"},
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index": 29,
						"url": "https://example.com/doc1.pdf",
						"title": "Document 1",
					},
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find source part
		var sourcePart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeSource {
				sourcePart = &part
				break
			}
		}

		require.NotNil(t, sourcePart)
		require.Equal(t, fantasy.SourceTypeURL, sourcePart.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
		require.Equal(t, "Document 1", sourcePart.Title)
		require.NotEmpty(t, sourcePart.ID)
	})

	t.Run("should handle error stream parts", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareErrorStreamResponse()

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Should have error and finish parts
		require.True(t, len(parts) >= 1)

		// Find error part
		var errorPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeError {
				errorPart = &part
				break
			}
		}

		require.NotNil(t, errorPart)
		require.NotNil(t, errorPart.Error)
	})

	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens": 15,
				"completion_tokens": 20,
				"total_tokens": 35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
		// InputTokens = prompt_tokens - cached_tokens = 15 - 1152 = -1137 → clamped to 0
		require.Equal(t, int64(0), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
	})

	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens": 15,
				"completion_tokens": 20,
				"total_tokens": 35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.NotNil(t, finishPart.ProviderMetadata)

		// Prediction-token counts surface via provider-specific metadata.
		openaiMeta, ok := finishPart.ProviderMetadata["openai"].(*ProviderMetadata)
		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
		require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Store: fantasy.Opt(true),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])
		require.Equal(t, true, call.body["store"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Metadata: map[string]any{
					"custom": "value",
				},
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o3-mini")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("flex"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("priority"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should stream text delta for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model": "o1-preview",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find text parts
		var textDeltas []string
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeTextDelta {
				textDeltas = append(textDeltas, part.Delta)
			}
		}

		// Should contain the text content (without empty delta)
		require.Equal(t, []string{"Hello, World!"}, textDeltas)
	})

	t.Run("should send reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model": "o1-preview",
			"usage": map[string]any{
				"prompt_tokens": 15,
				"completion_tokens": 20,
				"total_tokens": 35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
		require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
	})
}
3001
// TestDefaultToPrompt_DropsEmptyMessages verifies that DefaultToPrompt drops
// assistant/user messages with no convertible content (emitting a warning for
// each) while keeping messages that carry text, images, tool calls, or tool
// results.
func TestDefaultToPrompt_DropsEmptyMessages(t *testing.T) {
	t.Parallel()

	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	t.Run("should keep assistant messages with text content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "What's the weather?"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName: "get_weather",
						Input: `{"location":"NYC"}`,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should drop user messages without visible content", func(t *testing.T) {
		t.Parallel()

		// A file part with an unsupported media type converts to nothing,
		// leaving the user message empty.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data: []byte("not supported"),
						MediaType: "application/unknown",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Empty(t, messages)
		require.Len(t, warnings, 2) // One for unsupported type, one for empty message
		require.Contains(t, warnings[1].Message, "dropping empty user message")
	})

	t.Run("should keep user messages with image content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data: []byte{0x01, 0x02, 0x03},
						MediaType: "image/png",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_123",
						Output: fantasy.ToolResultOutputContentText{Text: "done"},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool error results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_456",
						Output: fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})
}
3166
// TestResponsesToPrompt_DropsEmptyMessages mirrors
// TestDefaultToPrompt_DropsEmptyMessages for the Responses API conversion
// (toResponsesPrompt): empty assistant/user messages are dropped with
// warnings, while text, image, tool-call, and tool-result content is kept.
func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) {
	t.Parallel()

	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system", false)

		require.Len(t, input, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	t.Run("should keep assistant messages with text content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system", false)

		require.Len(t, input, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "What's the weather?"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName: "get_weather",
						Input: `{"location":"NYC"}`,
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system", false)

		require.Len(t, input, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should drop user messages without visible content", func(t *testing.T) {
		t.Parallel()

		// A file part with an unsupported media type converts to nothing,
		// leaving the user message empty.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data: []byte("not supported"),
						MediaType: "application/unknown",
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system", false)

		require.Empty(t, input)
		require.Len(t, warnings, 2) // One for unsupported type, one for empty message
		require.Contains(t, warnings[1].Message, "dropping empty user message")
	})

	t.Run("should keep user messages with image content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data: []byte{0x01, 0x02, 0x03},
						MediaType: "image/png",
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system", false)

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_123",
						Output: fantasy.ToolResultOutputContentText{Text: "done"},
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system", false)

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool error results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_456",
						Output: fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system", false)

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})
}
3331
3332func TestParseContextTooLargeError(t *testing.T) {
3333 t.Parallel()
3334
3335 tests := []struct {
3336 name string
3337 message string
3338 wantErr bool
3339 wantUsed int
3340 wantMax int
3341 }{
3342 {
3343 name: "matches openai format with resulted in",
3344 message: "This model's maximum context length is 128000 tokens. However, your messages resulted in 150000 tokens.",
3345 wantErr: true,
3346 wantUsed: 150000,
3347 wantMax: 128000,
3348 },
3349 {
3350 name: "matches openai format with requested",
3351 message: "maximum context length is 8192 tokens, however you requested 10000 tokens",
3352 wantErr: true,
3353 wantUsed: 10000,
3354 wantMax: 8192,
3355 },
3356 {
3357 name: "does not match unrelated error",
3358 message: "invalid api key",
3359 wantErr: false,
3360 },
3361 {
3362 name: "does not match rate limit error",
3363 message: "rate limit exceeded",
3364 wantErr: false,
3365 },
3366 }
3367
3368 for _, tt := range tests {
3369 t.Run(tt.name, func(t *testing.T) {
3370 t.Parallel()
3371 providerErr := &fantasy.ProviderError{Message: tt.message}
3372 parseContextTooLargeError(tt.message, providerErr)
3373
3374 if tt.wantErr {
3375 require.True(t, providerErr.IsContextTooLarge())
3376 if tt.wantUsed > 0 {
3377 require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens)
3378 require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens)
3379 }
3380 } else {
3381 require.False(t, providerErr.IsContextTooLarge())
3382 }
3383 })
3384 }
3385}
3386
3387func TestUserAgent(t *testing.T) {
3388 t.Parallel()
3389
3390 t.Run("default UA applied", func(t *testing.T) {
3391 t.Parallel()
3392
3393 server := newMockServer()
3394 defer server.close()
3395 server.prepareJSONResponse(map[string]any{})
3396
3397 p, err := New(WithAPIKey("k"), WithBaseURL(server.server.URL))
3398 require.NoError(t, err)
3399 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3400 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})
3401
3402 require.Len(t, server.calls, 1)
3403 assert.Equal(t, "Charm-Fantasy/"+fantasy.Version+" (https://charm.land/fantasy)", server.calls[0].headers["User-Agent"])
3404 })
3405
3406 t.Run("WithHeaders User-Agent wins over default", func(t *testing.T) {
3407 t.Parallel()
3408
3409 server := newMockServer()
3410 defer server.close()
3411 server.prepareJSONResponse(map[string]any{})
3412
3413 p, err := New(WithAPIKey("k"), WithBaseURL(server.server.URL), WithHeaders(map[string]string{"User-Agent": "custom-from-headers"}))
3414 require.NoError(t, err)
3415 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3416 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})
3417
3418 require.Len(t, server.calls, 1)
3419 assert.Equal(t, "custom-from-headers", server.calls[0].headers["User-Agent"])
3420 })
3421
3422 t.Run("WithUserAgent wins over both", func(t *testing.T) {
3423 t.Parallel()
3424
3425 server := newMockServer()
3426 defer server.close()
3427 server.prepareJSONResponse(map[string]any{})
3428
3429 p, err := New(
3430 WithAPIKey("k"),
3431 WithBaseURL(server.server.URL),
3432 WithHeaders(map[string]string{"User-Agent": "from-headers"}),
3433 WithUserAgent("explicit-ua"),
3434 )
3435 require.NoError(t, err)
3436 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3437 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})
3438
3439 require.Len(t, server.calls, 1)
3440 assert.Equal(t, "explicit-ua", server.calls[0].headers["User-Agent"])
3441 })
3442
3443 t.Run("Call.UserAgent overrides provider WithHeaders UA", func(t *testing.T) {
3444 t.Parallel()
3445
3446 server := newMockServer()
3447 defer server.close()
3448 server.prepareJSONResponse(map[string]any{})
3449
3450 p, err := New(
3451 WithAPIKey("k"),
3452 WithBaseURL(server.server.URL),
3453 WithHeaders(map[string]string{"User-Agent": "header-ua"}),
3454 )
3455 require.NoError(t, err)
3456 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3457 _, _ = model.Generate(t.Context(), fantasy.Call{
3458 Prompt: testPrompt,
3459 UserAgent: "call-level-ua",
3460 })
3461
3462 require.Len(t, server.calls, 1)
3463 assert.Equal(t, "call-level-ua", server.calls[0].headers["User-Agent"])
3464 })
3465
3466 t.Run("no Call UA falls through to provider UA", func(t *testing.T) {
3467 t.Parallel()
3468
3469 server := newMockServer()
3470 defer server.close()
3471 server.prepareJSONResponse(map[string]any{})
3472
3473 p, err := New(
3474 WithAPIKey("k"),
3475 WithBaseURL(server.server.URL),
3476 WithUserAgent("provider-ua"),
3477 )
3478 require.NoError(t, err)
3479 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3480 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})
3481
3482 require.Len(t, server.calls, 1)
3483 assert.Equal(t, "provider-ua", server.calls[0].headers["User-Agent"])
3484 })
3485
3486 t.Run("agent WithUserAgent overrides provider UA end-to-end", func(t *testing.T) {
3487 t.Parallel()
3488
3489 server := newMockServer()
3490 defer server.close()
3491 server.prepareJSONResponse(map[string]any{})
3492
3493 p, err := New(
3494 WithAPIKey("k"),
3495 WithBaseURL(server.server.URL),
3496 WithUserAgent("provider-ua"),
3497 )
3498 require.NoError(t, err)
3499 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3500
3501 agent := fantasy.NewAgent(model, fantasy.WithUserAgent("agent-ua"))
3502 _, _ = agent.Generate(t.Context(), fantasy.AgentCall{Prompt: "hi"})
3503
3504 require.Len(t, server.calls, 1)
3505 assert.Equal(t, "agent-ua", server.calls[0].headers["User-Agent"])
3506 })
3507
3508 t.Run("agent without UA falls through to provider UA end-to-end", func(t *testing.T) {
3509 t.Parallel()
3510
3511 server := newMockServer()
3512 defer server.close()
3513 server.prepareJSONResponse(map[string]any{})
3514
3515 p, err := New(
3516 WithAPIKey("k"),
3517 WithBaseURL(server.server.URL),
3518 WithUserAgent("provider-ua"),
3519 )
3520 require.NoError(t, err)
3521 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3522
3523 agent := fantasy.NewAgent(model)
3524 _, _ = agent.Generate(t.Context(), fantasy.AgentCall{Prompt: "hi"})
3525
3526 require.Len(t, server.calls, 1)
3527 assert.Equal(t, "provider-ua", server.calls[0].headers["User-Agent"])
3528 })
3529}
3530
3531// --- OpenAI Responses API Web Search Tests ---
3532
// mockResponsesWebSearchResponse builds a canned Responses API payload:
// a completed web_search_call output item followed by an assistant
// message whose text carries two url_citation annotations.
func mockResponsesWebSearchResponse() map[string]any {
	searchCall := map[string]any{
		"type":   "web_search_call",
		"id":     "ws_01",
		"status": "completed",
		"action": map[string]any{
			"type":  "search",
			"query": "latest AI news",
		},
	}

	annotations := []any{
		map[string]any{
			"type":        "url_citation",
			"url":         "https://example.com/ai-news",
			"title":       "Latest AI News",
			"start_index": 0,
			"end_index":   50,
		},
		map[string]any{
			"type":        "url_citation",
			"url":         "https://example.com/ml-update",
			"title":       "ML Update",
			"start_index": 51,
			"end_index":   60,
		},
	}

	assistantMessage := map[string]any{
		"type":   "message",
		"id":     "msg_01",
		"role":   "assistant",
		"status": "completed",
		"content": []any{
			map[string]any{
				"type":        "output_text",
				"text":        "Based on recent search results, here is the latest AI news.",
				"annotations": annotations,
			},
		},
	}

	return map[string]any{
		"id":     "resp_01WebSearch",
		"object": "response",
		"model":  "gpt-4.1",
		"output": []any{searchCall, assistantMessage},
		"status": "completed",
		"usage": map[string]any{
			"input_tokens":  100,
			"output_tokens": 50,
			"total_tokens":  150,
		},
	}
}
3588
3589func newResponsesProvider(t *testing.T, serverURL string) fantasy.LanguageModel {
3590 t.Helper()
3591 provider, err := New(
3592 WithAPIKey("test-api-key"),
3593 WithBaseURL(serverURL),
3594 WithUseResponsesAPI(),
3595 )
3596 require.NoError(t, err)
3597 model, err := provider.LanguageModel(context.Background(), "gpt-4.1")
3598 require.NoError(t, err)
3599 return model
3600}
3601
3602func TestResponsesGenerate_WebSearchResponse(t *testing.T) {
3603 t.Parallel()
3604
3605 server := newMockServer()
3606 defer server.close()
3607 server.response = mockResponsesWebSearchResponse()
3608
3609 model := newResponsesProvider(t, server.server.URL)
3610
3611 resp, err := model.Generate(context.Background(), fantasy.Call{
3612 Prompt: testPrompt,
3613 Tools: []fantasy.Tool{WebSearchTool(nil)},
3614 })
3615 require.NoError(t, err)
3616
3617 require.Equal(t, "POST", server.calls[0].method)
3618 require.Equal(t, "/responses", server.calls[0].path)
3619
3620 var (
3621 toolCalls []fantasy.ToolCallContent
3622 sources []fantasy.SourceContent
3623 toolResults []fantasy.ToolResultContent
3624 texts []fantasy.TextContent
3625 )
3626 for _, c := range resp.Content {
3627 switch v := c.(type) {
3628 case fantasy.ToolCallContent:
3629 toolCalls = append(toolCalls, v)
3630 case fantasy.SourceContent:
3631 sources = append(sources, v)
3632 case fantasy.ToolResultContent:
3633 toolResults = append(toolResults, v)
3634 case fantasy.TextContent:
3635 texts = append(texts, v)
3636 }
3637 }
3638
3639 // ToolCallContent for the provider-executed web_search.
3640 require.Len(t, toolCalls, 1)
3641 require.True(t, toolCalls[0].ProviderExecuted)
3642 require.Equal(t, "web_search", toolCalls[0].ToolName)
3643 require.Equal(t, "ws_01", toolCalls[0].ToolCallID)
3644
3645 // SourceContent entries from url_citation annotations.
3646 require.Len(t, sources, 2)
3647 require.Equal(t, "https://example.com/ai-news", sources[0].URL)
3648 require.Equal(t, "Latest AI News", sources[0].Title)
3649 require.Equal(t, fantasy.SourceTypeURL, sources[0].SourceType)
3650 require.Equal(t, "https://example.com/ml-update", sources[1].URL)
3651 require.Equal(t, "ML Update", sources[1].Title)
3652
3653 // ToolResultContent with provider metadata.
3654 require.Len(t, toolResults, 1)
3655 require.True(t, toolResults[0].ProviderExecuted)
3656 require.Equal(t, "web_search", toolResults[0].ToolName)
3657 require.Equal(t, "ws_01", toolResults[0].ToolCallID)
3658
3659 metaVal, ok := toolResults[0].ProviderMetadata[Name]
3660 require.True(t, ok, "providerMetadata should contain openai key")
3661 wsMeta, ok := metaVal.(*WebSearchCallMetadata)
3662 require.True(t, ok, "metadata should be *WebSearchCallMetadata")
3663 require.Equal(t, "ws_01", wsMeta.ItemID)
3664 require.NotNil(t, wsMeta.Action)
3665 require.Equal(t, "search", wsMeta.Action.Type)
3666 require.Equal(t, "latest AI news", wsMeta.Action.Query)
3667
3668 // TextContent with the final answer.
3669 require.Len(t, texts, 1)
3670 require.Equal(t,
3671 "Based on recent search results, here is the latest AI news.",
3672 texts[0].Text,
3673 )
3674}
3675
3676func TestResponsesGenerate_StoreOption(t *testing.T) {
3677 t.Parallel()
3678
3679 server := newMockServer()
3680 defer server.close()
3681 server.response = mockResponsesWebSearchResponse()
3682
3683 model := newResponsesProvider(t, server.server.URL)
3684
3685 _, err := model.Generate(context.Background(), fantasy.Call{
3686 Prompt: testPrompt,
3687 ProviderOptions: fantasy.ProviderOptions{
3688 Name: &ResponsesProviderOptions{
3689 Store: fantasy.Opt(true),
3690 },
3691 },
3692 })
3693 require.NoError(t, err)
3694
3695 require.Equal(t, "POST", server.calls[0].method)
3696 require.Equal(t, "/responses", server.calls[0].path)
3697 require.Equal(t, true, server.calls[0].body["store"])
3698}
3699
3700func TestResponsesGenerate_PreviousResponseIDOption(t *testing.T) {
3701 t.Parallel()
3702
3703 server := newMockServer()
3704 defer server.close()
3705 server.response = mockResponsesWebSearchResponse()
3706
3707 model := newResponsesProvider(t, server.server.URL)
3708
3709 _, err := model.Generate(context.Background(), fantasy.Call{
3710 Prompt: testPrompt,
3711 ProviderOptions: fantasy.ProviderOptions{
3712 Name: &ResponsesProviderOptions{
3713 PreviousResponseID: fantasy.Opt("resp_prev_123"),
3714 Store: fantasy.Opt(true),
3715 },
3716 },
3717 })
3718 require.NoError(t, err)
3719
3720 require.Equal(t, "POST", server.calls[0].method)
3721 require.Equal(t, "/responses", server.calls[0].path)
3722 require.Equal(t, "resp_prev_123", server.calls[0].body["previous_response_id"])
3723}
3724
3725func TestResponsesGenerate_StateChainingAcrossTurns(t *testing.T) {
3726 t.Parallel()
3727
3728 server := newMockServer()
3729 defer server.close()
3730 server.response = map[string]any{
3731 "id": "resp_turn_1",
3732 "object": "response",
3733 "model": "gpt-4.1",
3734 "output": []any{
3735 map[string]any{
3736 "type": "message",
3737 "id": "msg_1",
3738 "role": "assistant",
3739 "status": "completed",
3740 "content": []any{
3741 map[string]any{
3742 "type": "output_text",
3743 "text": "First turn",
3744 },
3745 },
3746 },
3747 },
3748 "status": "completed",
3749 "usage": map[string]any{
3750 "input_tokens": 10,
3751 "output_tokens": 5,
3752 "total_tokens": 15,
3753 },
3754 }
3755
3756 model := newResponsesProvider(t, server.server.URL)
3757
3758 first, err := model.Generate(context.Background(), fantasy.Call{
3759 Prompt: testPrompt,
3760 ProviderOptions: fantasy.ProviderOptions{
3761 Name: &ResponsesProviderOptions{Store: fantasy.Opt(true)},
3762 },
3763 })
3764 require.NoError(t, err)
3765
3766 meta, ok := first.ProviderMetadata[Name].(*ResponsesProviderMetadata)
3767 require.True(t, ok)
3768 require.Equal(t, "resp_turn_1", meta.ResponseID)
3769
3770 server.response = map[string]any{
3771 "id": "resp_turn_2",
3772 "object": "response",
3773 "model": "gpt-4.1",
3774 "output": []any{
3775 map[string]any{
3776 "type": "message",
3777 "id": "msg_2",
3778 "role": "assistant",
3779 "status": "completed",
3780 "content": []any{
3781 map[string]any{
3782 "type": "output_text",
3783 "text": "Second turn",
3784 },
3785 },
3786 },
3787 },
3788 "status": "completed",
3789 "usage": map[string]any{
3790 "input_tokens": 8,
3791 "output_tokens": 4,
3792 "total_tokens": 12,
3793 },
3794 }
3795
3796 _, err = model.Generate(context.Background(), fantasy.Call{
3797 Prompt: fantasy.Prompt{
3798 fantasy.NewUserMessage("follow-up only"),
3799 },
3800 ProviderOptions: fantasy.ProviderOptions{
3801 Name: &ResponsesProviderOptions{
3802 Store: fantasy.Opt(true),
3803 PreviousResponseID: &meta.ResponseID,
3804 },
3805 },
3806 })
3807 require.NoError(t, err)
3808 require.Len(t, server.calls, 2)
3809
3810 firstCall := server.calls[0]
3811 require.Equal(t, true, firstCall.body["store"])
3812
3813 secondCall := server.calls[1]
3814 require.Equal(t, "resp_turn_1", secondCall.body["previous_response_id"])
3815 require.Equal(t, true, secondCall.body["store"])
3816
3817 input, ok := secondCall.body["input"].([]any)
3818 require.True(t, ok)
3819 require.Len(t, input, 1)
3820
3821 inputMessage, ok := input[0].(map[string]any)
3822 require.True(t, ok)
3823 require.Equal(t, "user", inputMessage["role"])
3824}
3825
3826func TestResponsesGenerate_WebSearchToolInRequest(t *testing.T) {
3827 t.Parallel()
3828
3829 t.Run("basic web_search tool", func(t *testing.T) {
3830 t.Parallel()
3831
3832 server := newMockServer()
3833 defer server.close()
3834 server.response = mockResponsesWebSearchResponse()
3835
3836 model := newResponsesProvider(t, server.server.URL)
3837
3838 _, err := model.Generate(context.Background(), fantasy.Call{
3839 Prompt: testPrompt,
3840 Tools: []fantasy.Tool{WebSearchTool(nil)},
3841 })
3842 require.NoError(t, err)
3843
3844 tools, ok := server.calls[0].body["tools"].([]any)
3845 require.True(t, ok, "request body should have tools array")
3846 require.Len(t, tools, 1)
3847
3848 tool, ok := tools[0].(map[string]any)
3849 require.True(t, ok)
3850 require.Equal(t, "web_search", tool["type"])
3851 })
3852
3853 t.Run("with search_context_size and allowed_domains", func(t *testing.T) {
3854 t.Parallel()
3855
3856 server := newMockServer()
3857 defer server.close()
3858 server.response = mockResponsesWebSearchResponse()
3859
3860 model := newResponsesProvider(t, server.server.URL)
3861
3862 _, err := model.Generate(context.Background(), fantasy.Call{
3863 Prompt: testPrompt,
3864 Tools: []fantasy.Tool{
3865 WebSearchTool(&WebSearchToolOptions{
3866 SearchContextSize: SearchContextSizeHigh,
3867 AllowedDomains: []string{"example.com", "test.com"},
3868 }),
3869 },
3870 })
3871 require.NoError(t, err)
3872
3873 tools, ok := server.calls[0].body["tools"].([]any)
3874 require.True(t, ok)
3875 require.Len(t, tools, 1)
3876
3877 tool, ok := tools[0].(map[string]any)
3878 require.True(t, ok)
3879 require.Equal(t, "web_search", tool["type"])
3880 require.Equal(t, "high", tool["search_context_size"])
3881
3882 filters, ok := tool["filters"].(map[string]any)
3883 require.True(t, ok, "tool should have filters")
3884 domains, ok := filters["allowed_domains"].([]any)
3885 require.True(t, ok, "filters should have allowed_domains")
3886 require.Len(t, domains, 2)
3887 require.Equal(t, "example.com", domains[0])
3888 require.Equal(t, "test.com", domains[1])
3889 })
3890
3891 t.Run("with user_location", func(t *testing.T) {
3892 t.Parallel()
3893
3894 server := newMockServer()
3895 defer server.close()
3896 server.response = mockResponsesWebSearchResponse()
3897
3898 model := newResponsesProvider(t, server.server.URL)
3899
3900 _, err := model.Generate(context.Background(), fantasy.Call{
3901 Prompt: testPrompt,
3902 Tools: []fantasy.Tool{
3903 WebSearchTool(&WebSearchToolOptions{
3904 UserLocation: &WebSearchUserLocation{
3905 City: "San Francisco",
3906 Country: "US",
3907 },
3908 }),
3909 },
3910 })
3911 require.NoError(t, err)
3912
3913 tools, ok := server.calls[0].body["tools"].([]any)
3914 require.True(t, ok)
3915 require.Len(t, tools, 1)
3916
3917 tool, ok := tools[0].(map[string]any)
3918 require.True(t, ok)
3919 require.Equal(t, "web_search", tool["type"])
3920
3921 userLoc, ok := tool["user_location"].(map[string]any)
3922 require.True(t, ok, "tool should have user_location")
3923 require.Equal(t, "San Francisco", userLoc["city"])
3924 require.Equal(t, "US", userLoc["country"])
3925 })
3926}
3927
// TestResponsesToPrompt_WebSearchProviderExecutedToolResults checks how
// provider-executed web_search tool calls/results in conversation
// history are converted back into Responses API input items, depending
// on the store flag.
func TestResponsesToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) {
	t.Parallel()

	// History: a user query, then an assistant turn containing the
	// provider-executed web_search call, its result, and final text.
	prompt := fantasy.Prompt{
		{
			Role: fantasy.MessageRoleUser,
			Content: []fantasy.MessagePart{
				fantasy.TextPart{Text: "Search for the latest AI news"},
			},
		},
		{
			Role: fantasy.MessageRoleAssistant,
			Content: []fantasy.MessagePart{
				fantasy.ToolCallPart{
					ToolCallID: "ws_01",
					ToolName: "web_search",
					ProviderExecuted: true,
				},
				fantasy.ToolResultPart{
					ToolCallID: "ws_01",
					ProviderExecuted: true,
				},
				fantasy.TextPart{Text: "Here is what I found."},
			},
		},
	}

	t.Run("store false skips item reference", func(t *testing.T) {
		t.Parallel()

		// With store=false the server keeps no state, so there is no
		// stored item to reference; the call/result pair is dropped.
		input, warnings := toResponsesPrompt(prompt, "system instructions", false)

		require.Empty(t, warnings)
		require.Len(t, input, 2,
			"expected user + assistant text when store=false")
		require.Nil(t, input[0].OfItemReference)
		require.Nil(t, input[1].OfItemReference)
	})

	t.Run("store true uses item reference", func(t *testing.T) {
		t.Parallel()

		// With store=true the provider-executed call is replayed as an
		// item_reference pointing at the stored web_search item by ID.
		input, warnings := toResponsesPrompt(prompt, "system instructions", true)

		require.Empty(t, warnings)
		require.Len(t, input, 3,
			"expected user + item_reference + assistant text when store=true")
		require.NotNil(t, input[1].OfItemReference)
		require.Equal(t, "ws_01", input[1].OfItemReference.ID)
	})
}
3979
// TestResponsesToPrompt_ReasoningWithStore verifies that assistant
// reasoning parts (with Responses reasoning metadata attached) are
// never converted into reasoning input items, regardless of the store
// flag.
func TestResponsesToPrompt_ReasoningWithStore(t *testing.T) {
	t.Parallel()

	encryptedContent := "gAAAAABpvAwtDPh5dSXW86hwbwoTo4DJHANQ"
	reasoningItemID := "rs_08d030b87966238b0069bc095b7e5c81"

	// Reasoning part carrying the provider metadata a real Responses
	// round-trip would have attached (item ID + encrypted content).
	reasoningPart := fantasy.ReasoningPart{
		Text: "Let me think about this...",
		ProviderOptions: fantasy.ProviderOptions{
			Name: &ResponsesReasoningMetadata{
				ItemID: reasoningItemID,
				EncryptedContent: &encryptedContent,
				Summary: []string{},
			},
		},
	}

	// Three-message history: question, assistant (reasoning + answer),
	// follow-up question.
	prompt := fantasy.Prompt{
		{
			Role: fantasy.MessageRoleUser,
			Content: []fantasy.MessagePart{
				fantasy.TextPart{Text: "What is 2+2?"},
			},
		},
		{
			Role: fantasy.MessageRoleAssistant,
			Content: []fantasy.MessagePart{
				reasoningPart,
				fantasy.TextPart{Text: "4"},
			},
		},
		{
			Role: fantasy.MessageRoleUser,
			Content: []fantasy.MessagePart{
				fantasy.TextPart{Text: "And 3+3?"},
			},
		},
	}

	t.Run("store true skips reasoning", func(t *testing.T) {
		t.Parallel()

		input, warnings := toResponsesPrompt(prompt, "system", true)
		require.Empty(t, warnings)

		// With store=true: user, assistant text (reasoning
		// skipped), follow-up user.
		require.Len(t, input, 3)

		// Verify no reasoning item leaked through.
		for _, item := range input {
			require.Nil(t, item.OfReasoning,
				"reasoning items must not appear when store=true")
		}
	})

	t.Run("store false skips reasoning", func(t *testing.T) {
		t.Parallel()

		input, warnings := toResponsesPrompt(prompt, "system", false)
		require.Empty(t, warnings)

		// With store=false: user, assistant text, follow-up user.
		require.Len(t, input, 3)

		for _, item := range input {
			require.Nil(t, item.OfReasoning,
				"reasoning items must not appear when store=false")
		}
	})
}
4051
4052func TestResponsesStream_WebSearchResponse(t *testing.T) {
4053 t.Parallel()
4054
4055 chunks := []string{
4056 "event: response.output_item.added\n" +
4057 `data: {"type":"response.output_item.added","output_index":0,"item":{"type":"web_search_call","id":"ws_01","status":"in_progress"}}` + "\n\n",
4058 "event: response.output_item.done\n" +
4059 `data: {"type":"response.output_item.done","output_index":0,"item":{"type":"web_search_call","id":"ws_01","status":"completed","action":{"type":"search","query":"latest AI news"}}}` + "\n\n",
4060 "event: response.output_item.added\n" +
4061 `data: {"type":"response.output_item.added","output_index":1,"item":{"type":"message","id":"msg_01","role":"assistant","status":"in_progress","content":[]}}` + "\n\n",
4062 "event: response.output_text.delta\n" +
4063 `data: {"type":"response.output_text.delta","output_index":1,"content_index":0,"delta":"Here are the results."}` + "\n\n",
4064 "event: response.output_item.done\n" +
4065 `data: {"type":"response.output_item.done","output_index":1,"item":{"type":"message","id":"msg_01","role":"assistant","status":"completed","content":[{"type":"output_text","text":"Here are the results.","annotations":[{"type":"url_citation","url":"https://example.com/ai-news","title":"Latest AI News","start_index":0,"end_index":21}]}]}}` + "\n\n",
4066 "event: response.completed\n" +
4067 `data: {"type":"response.completed","response":{"id":"resp_01","status":"completed","output":[],"usage":{"input_tokens":100,"output_tokens":50,"total_tokens":150}}}` + "\n\n",
4068 }
4069
4070 sms := newStreamingMockServer()
4071 defer sms.close()
4072 sms.chunks = chunks
4073
4074 model := newResponsesProvider(t, sms.server.URL)
4075
4076 stream, err := model.Stream(context.Background(), fantasy.Call{
4077 Prompt: testPrompt,
4078 Tools: []fantasy.Tool{WebSearchTool(nil)},
4079 })
4080 require.NoError(t, err)
4081
4082 var parts []fantasy.StreamPart
4083 stream(func(part fantasy.StreamPart) bool {
4084 parts = append(parts, part)
4085 return true
4086 })
4087
4088 var (
4089 toolInputStarts []fantasy.StreamPart
4090 toolCalls []fantasy.StreamPart
4091 toolResults []fantasy.StreamPart
4092 textDeltas []fantasy.StreamPart
4093 finishes []fantasy.StreamPart
4094 )
4095 for _, p := range parts {
4096 switch p.Type {
4097 case fantasy.StreamPartTypeToolInputStart:
4098 toolInputStarts = append(toolInputStarts, p)
4099 case fantasy.StreamPartTypeToolCall:
4100 toolCalls = append(toolCalls, p)
4101 case fantasy.StreamPartTypeToolResult:
4102 toolResults = append(toolResults, p)
4103 case fantasy.StreamPartTypeTextDelta:
4104 textDeltas = append(textDeltas, p)
4105 case fantasy.StreamPartTypeFinish:
4106 finishes = append(finishes, p)
4107 }
4108 }
4109
4110 require.NotEmpty(t, toolInputStarts, "should have a tool input start")
4111 require.True(t, toolInputStarts[0].ProviderExecuted)
4112 require.Equal(t, "web_search", toolInputStarts[0].ToolCallName)
4113
4114 require.NotEmpty(t, toolCalls, "should have a tool call")
4115 require.True(t, toolCalls[0].ProviderExecuted)
4116 require.Equal(t, "web_search", toolCalls[0].ToolCallName)
4117
4118 require.NotEmpty(t, toolResults, "should have a tool result")
4119 require.True(t, toolResults[0].ProviderExecuted)
4120 require.Equal(t, "web_search", toolResults[0].ToolCallName)
4121 require.Equal(t, "ws_01", toolResults[0].ID)
4122
4123 require.NotEmpty(t, textDeltas, "should have text deltas")
4124 require.Equal(t, "Here are the results.", textDeltas[0].Delta)
4125
4126 require.Len(t, finishes, 1)
4127 responsesMeta, ok := finishes[0].ProviderMetadata[Name].(*ResponsesProviderMetadata)
4128 require.True(t, ok)
4129 require.Equal(t, "resp_01", responsesMeta.ResponseID)
4130}
4131
4132func TestResponsesStream_StoreOption(t *testing.T) {
4133 t.Parallel()
4134
4135 chunks := []string{
4136 "event: response.completed\n" +
4137 `data: {"type":"response.completed","response":{"id":"resp_01","status":"completed","output":[],"usage":{"input_tokens":100,"output_tokens":50,"total_tokens":150}}}` + "\n\n",
4138 }
4139
4140 sms := newStreamingMockServer()
4141 defer sms.close()
4142 sms.chunks = chunks
4143
4144 model := newResponsesProvider(t, sms.server.URL)
4145
4146 stream, err := model.Stream(context.Background(), fantasy.Call{
4147 Prompt: testPrompt,
4148 ProviderOptions: fantasy.ProviderOptions{
4149 Name: &ResponsesProviderOptions{
4150 Store: fantasy.Opt(true),
4151 },
4152 },
4153 })
4154 require.NoError(t, err)
4155
4156 stream(func(part fantasy.StreamPart) bool {
4157 return part.Type != fantasy.StreamPartTypeFinish
4158 })
4159
4160 require.Equal(t, "POST", sms.calls[0].method)
4161 require.Equal(t, "/responses", sms.calls[0].path)
4162 require.Equal(t, true, sms.calls[0].body["store"])
4163}
4164
4165func TestResponsesStream_PreviousResponseIDOption(t *testing.T) {
4166 t.Parallel()
4167
4168 chunks := []string{
4169 "event: response.completed\n" +
4170 `data: {"type":"response.completed","response":{"id":"resp_01","status":"completed","output":[],"usage":{"input_tokens":100,"output_tokens":50,"total_tokens":150}}}` + "\n\n",
4171 }
4172
4173 sms := newStreamingMockServer()
4174 defer sms.close()
4175 sms.chunks = chunks
4176
4177 model := newResponsesProvider(t, sms.server.URL)
4178
4179 stream, err := model.Stream(context.Background(), fantasy.Call{
4180 Prompt: testPrompt,
4181 ProviderOptions: fantasy.ProviderOptions{
4182 Name: &ResponsesProviderOptions{
4183 PreviousResponseID: fantasy.Opt("resp_prev_456"),
4184 Store: fantasy.Opt(true),
4185 },
4186 },
4187 })
4188 require.NoError(t, err)
4189
4190 stream(func(part fantasy.StreamPart) bool {
4191 return part.Type != fantasy.StreamPartTypeFinish
4192 })
4193
4194 require.Equal(t, "POST", sms.calls[0].method)
4195 require.Equal(t, "/responses", sms.calls[0].path)
4196 require.Equal(t, "resp_prev_456", sms.calls[0].body["previous_response_id"])
4197}