1package openai
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "net/http"
9 "net/http/httptest"
10 "strings"
11 "testing"
12
13 "charm.land/fantasy"
14 "github.com/charmbracelet/openai-go/packages/param"
15 "github.com/stretchr/testify/assert"
16 "github.com/stretchr/testify/require"
17)
18
19func TestToOpenAiPrompt_SystemMessages(t *testing.T) {
20 t.Parallel()
21
22 t.Run("should forward system messages", func(t *testing.T) {
23 t.Parallel()
24
25 prompt := fantasy.Prompt{
26 {
27 Role: fantasy.MessageRoleSystem,
28 Content: []fantasy.MessagePart{
29 fantasy.TextPart{Text: "You are a helpful assistant."},
30 },
31 },
32 }
33
34 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
35
36 require.Empty(t, warnings)
37 require.Len(t, messages, 1)
38
39 systemMsg := messages[0].OfSystem
40 require.NotNil(t, systemMsg)
41 require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
42 })
43
44 t.Run("should handle empty system messages", func(t *testing.T) {
45 t.Parallel()
46
47 prompt := fantasy.Prompt{
48 {
49 Role: fantasy.MessageRoleSystem,
50 Content: []fantasy.MessagePart{},
51 },
52 }
53
54 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
55
56 require.Len(t, warnings, 1)
57 require.Contains(t, warnings[0].Message, "system prompt has no text parts")
58 require.Empty(t, messages)
59 })
60
61 t.Run("should join multiple system text parts", func(t *testing.T) {
62 t.Parallel()
63
64 prompt := fantasy.Prompt{
65 {
66 Role: fantasy.MessageRoleSystem,
67 Content: []fantasy.MessagePart{
68 fantasy.TextPart{Text: "You are a helpful assistant."},
69 fantasy.TextPart{Text: "Be concise."},
70 },
71 },
72 }
73
74 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
75
76 require.Empty(t, warnings)
77 require.Len(t, messages, 1)
78
79 systemMsg := messages[0].OfSystem
80 require.NotNil(t, systemMsg)
81 require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
82 })
83}
84
85func TestToOpenAiPrompt_UserMessages(t *testing.T) {
86 t.Parallel()
87
88 t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
89 t.Parallel()
90
91 prompt := fantasy.Prompt{
92 {
93 Role: fantasy.MessageRoleUser,
94 Content: []fantasy.MessagePart{
95 fantasy.TextPart{Text: "Hello"},
96 },
97 },
98 }
99
100 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
101
102 require.Empty(t, warnings)
103 require.Len(t, messages, 1)
104
105 userMsg := messages[0].OfUser
106 require.NotNil(t, userMsg)
107 require.Equal(t, "Hello", userMsg.Content.OfString.Value)
108 })
109
110 t.Run("should convert messages with image parts", func(t *testing.T) {
111 t.Parallel()
112
113 imageData := []byte{0, 1, 2, 3}
114 prompt := fantasy.Prompt{
115 {
116 Role: fantasy.MessageRoleUser,
117 Content: []fantasy.MessagePart{
118 fantasy.TextPart{Text: "Hello"},
119 fantasy.FilePart{
120 MediaType: "image/png",
121 Data: imageData,
122 },
123 },
124 },
125 }
126
127 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
128
129 require.Empty(t, warnings)
130 require.Len(t, messages, 1)
131
132 userMsg := messages[0].OfUser
133 require.NotNil(t, userMsg)
134
135 content := userMsg.Content.OfArrayOfContentParts
136 require.Len(t, content, 2)
137
138 // Check text part
139 textPart := content[0].OfText
140 require.NotNil(t, textPart)
141 require.Equal(t, "Hello", textPart.Text)
142
143 // Check image part
144 imagePart := content[1].OfImageURL
145 require.NotNil(t, imagePart)
146 expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
147 require.Equal(t, expectedURL, imagePart.ImageURL.URL)
148 })
149
150 t.Run("should add image detail when specified through provider options", func(t *testing.T) {
151 t.Parallel()
152
153 imageData := []byte{0, 1, 2, 3}
154 prompt := fantasy.Prompt{
155 {
156 Role: fantasy.MessageRoleUser,
157 Content: []fantasy.MessagePart{
158 fantasy.FilePart{
159 MediaType: "image/png",
160 Data: imageData,
161 ProviderOptions: NewProviderFileOptions(&ProviderFileOptions{
162 ImageDetail: "low",
163 }),
164 },
165 },
166 },
167 }
168
169 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
170
171 require.Empty(t, warnings)
172 require.Len(t, messages, 1)
173
174 userMsg := messages[0].OfUser
175 require.NotNil(t, userMsg)
176
177 content := userMsg.Content.OfArrayOfContentParts
178 require.Len(t, content, 1)
179
180 imagePart := content[0].OfImageURL
181 require.NotNil(t, imagePart)
182 require.Equal(t, "low", imagePart.ImageURL.Detail)
183 })
184}
185
// TestToOpenAiPrompt_FileParts verifies conversion of file attachments on user
// messages: unsupported media types warn and drop, audio is inlined as base64
// with a normalized format name, and PDFs become file parts (inline data or
// file ID) with an optional generated filename.
func TestToOpenAiPrompt_FileParts(t *testing.T) {
	t.Parallel()

	t.Run("should throw for unsupported mime types", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/something",
						Data:      []byte("test"),
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		// Two warnings: one for the unsupported part, one because dropping
		// that part leaves the user message empty.
		require.Len(t, warnings, 2) // unsupported type + empty message
		require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
		require.Contains(t, warnings[1].Message, "dropping empty user message")
		require.Empty(t, messages) // Message is now dropped because it's empty
	})

	t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/wav",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		// Audio bytes are base64-encoded and the media subtype becomes the format.
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
		require.Equal(t, "wav", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/mpeg",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		// audio/mpeg is normalized to the "mp3" format name OpenAI expects.
		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/mp3",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		// The non-standard audio/mp3 media type also maps to "mp3".
		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should convert messages with PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		// The explicit filename is preserved on the file part.
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "document.pdf", filePart.File.Filename.Value)

		// PDF bytes are inlined as a base64 data URL in FileData.
		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	// NOTE(review): this subtest is byte-identical to the previous one
	// ("should convert messages with PDF file parts") — the "binary" variant
	// exercises nothing different. Consider removing it or making it use a
	// distinct input (e.g. data supplied via a URL vs raw bytes).
	t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
		t.Parallel()

		// Data whose bytes look like an OpenAI file ID ("file-...") is
		// presumably detected by the converter and sent as FileID instead of
		// inline data — confirmed below by the omitted FileData/Filename.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      []byte("file-pdf-12345"),
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
		require.True(t, param.IsOmitted(filePart.File.FileData))
		require.True(t, param.IsOmitted(filePart.File.Filename))
	})

	t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		// Missing Filename falls back to a generated "part-<index>.pdf" name.
		require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
	})
}
423
424func TestToOpenAiPrompt_ToolCalls(t *testing.T) {
425 t.Parallel()
426
427 t.Run("should stringify arguments to tool calls", func(t *testing.T) {
428 t.Parallel()
429
430 inputArgs := map[string]any{"foo": "bar123"}
431 inputJSON, _ := json.Marshal(inputArgs)
432
433 outputResult := map[string]any{"oof": "321rab"}
434 outputJSON, _ := json.Marshal(outputResult)
435
436 prompt := fantasy.Prompt{
437 {
438 Role: fantasy.MessageRoleAssistant,
439 Content: []fantasy.MessagePart{
440 fantasy.ToolCallPart{
441 ToolCallID: "quux",
442 ToolName: "thwomp",
443 Input: string(inputJSON),
444 },
445 },
446 },
447 {
448 Role: fantasy.MessageRoleTool,
449 Content: []fantasy.MessagePart{
450 fantasy.ToolResultPart{
451 ToolCallID: "quux",
452 Output: fantasy.ToolResultOutputContentText{
453 Text: string(outputJSON),
454 },
455 },
456 },
457 },
458 }
459
460 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
461
462 require.Empty(t, warnings)
463 require.Len(t, messages, 2)
464
465 // Check assistant message with tool call
466 assistantMsg := messages[0].OfAssistant
467 require.NotNil(t, assistantMsg)
468 require.Equal(t, "", assistantMsg.Content.OfString.Value)
469 require.Len(t, assistantMsg.ToolCalls, 1)
470
471 toolCall := assistantMsg.ToolCalls[0].OfFunction
472 require.NotNil(t, toolCall)
473 require.Equal(t, "quux", toolCall.ID)
474 require.Equal(t, "thwomp", toolCall.Function.Name)
475 require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
476
477 // Check tool message
478 toolMsg := messages[1].OfTool
479 require.NotNil(t, toolMsg)
480 require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
481 require.Equal(t, "quux", toolMsg.ToolCallID)
482 })
483
484 t.Run("should handle different tool output types", func(t *testing.T) {
485 t.Parallel()
486
487 prompt := fantasy.Prompt{
488 {
489 Role: fantasy.MessageRoleTool,
490 Content: []fantasy.MessagePart{
491 fantasy.ToolResultPart{
492 ToolCallID: "text-tool",
493 Output: fantasy.ToolResultOutputContentText{
494 Text: "Hello world",
495 },
496 },
497 fantasy.ToolResultPart{
498 ToolCallID: "error-tool",
499 Output: fantasy.ToolResultOutputContentError{
500 Error: errors.New("Something went wrong"),
501 },
502 },
503 },
504 },
505 }
506
507 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
508
509 require.Empty(t, warnings)
510 require.Len(t, messages, 2)
511
512 // Check first tool message (text)
513 textToolMsg := messages[0].OfTool
514 require.NotNil(t, textToolMsg)
515 require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
516 require.Equal(t, "text-tool", textToolMsg.ToolCallID)
517
518 // Check second tool message (error)
519 errorToolMsg := messages[1].OfTool
520 require.NotNil(t, errorToolMsg)
521 require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
522 require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
523 })
524}
525
526func TestToOpenAiPrompt_AssistantMessages(t *testing.T) {
527 t.Parallel()
528
529 t.Run("should handle simple text assistant messages", func(t *testing.T) {
530 t.Parallel()
531
532 prompt := fantasy.Prompt{
533 {
534 Role: fantasy.MessageRoleAssistant,
535 Content: []fantasy.MessagePart{
536 fantasy.TextPart{Text: "Hello, how can I help you?"},
537 },
538 },
539 }
540
541 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
542
543 require.Empty(t, warnings)
544 require.Len(t, messages, 1)
545
546 assistantMsg := messages[0].OfAssistant
547 require.NotNil(t, assistantMsg)
548 require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
549 })
550
551 t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
552 t.Parallel()
553
554 inputArgs := map[string]any{"query": "test"}
555 inputJSON, _ := json.Marshal(inputArgs)
556
557 prompt := fantasy.Prompt{
558 {
559 Role: fantasy.MessageRoleAssistant,
560 Content: []fantasy.MessagePart{
561 fantasy.TextPart{Text: "Let me search for that."},
562 fantasy.ToolCallPart{
563 ToolCallID: "call-123",
564 ToolName: "search",
565 Input: string(inputJSON),
566 },
567 },
568 },
569 }
570
571 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
572
573 require.Empty(t, warnings)
574 require.Len(t, messages, 1)
575
576 assistantMsg := messages[0].OfAssistant
577 require.NotNil(t, assistantMsg)
578 require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
579 require.Len(t, assistantMsg.ToolCalls, 1)
580
581 toolCall := assistantMsg.ToolCalls[0].OfFunction
582 require.Equal(t, "call-123", toolCall.ID)
583 require.Equal(t, "search", toolCall.Function.Name)
584 require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
585 })
586}
587
// testPrompt is a minimal single-turn user prompt ("Hello") shared by the
// Generate tests below.
var testPrompt = fantasy.Prompt{
	{
		Role: fantasy.MessageRoleUser,
		Content: []fantasy.MessagePart{
			fantasy.TextPart{Text: "Hello"},
		},
	},
}

// testLogprobs mirrors the `logprobs` object of an OpenAI chat-completion
// choice: one entry per generated token with its log probability and a single
// top_logprobs candidate. Used to stub the server response when testing
// logprobs extraction.
var testLogprobs = map[string]any{
	"content": []map[string]any{
		{
			"token":   "Hello",
			"logprob": -0.0009994634,
			"top_logprobs": []map[string]any{
				{
					"token":   "Hello",
					"logprob": -0.0009994634,
				},
			},
		},
		{
			"token":   "!",
			"logprob": -0.13410144,
			"top_logprobs": []map[string]any{
				{
					"token":   "!",
					"logprob": -0.13410144,
				},
			},
		},
		{
			"token":   " How",
			"logprob": -0.0009250381,
			"top_logprobs": []map[string]any{
				{
					"token":   " How",
					"logprob": -0.0009250381,
				},
			},
		},
		{
			"token":   " can",
			"logprob": -0.047709424,
			"top_logprobs": []map[string]any{
				{
					"token":   " can",
					"logprob": -0.047709424,
				},
			},
		},
		{
			"token":   " I",
			"logprob": -0.000009014684,
			"top_logprobs": []map[string]any{
				{
					"token":   " I",
					"logprob": -0.000009014684,
				},
			},
		},
		{
			"token":   " assist",
			"logprob": -0.009125131,
			"top_logprobs": []map[string]any{
				{
					"token":   " assist",
					"logprob": -0.009125131,
				},
			},
		},
		{
			"token":   " you",
			"logprob": -0.0000066306106,
			"top_logprobs": []map[string]any{
				{
					"token":   " you",
					"logprob": -0.0000066306106,
				},
			},
		},
		{
			"token":   " today",
			"logprob": -0.00011093382,
			"top_logprobs": []map[string]any{
				{
					"token":   " today",
					"logprob": -0.00011093382,
				},
			},
		},
		{
			"token":   "?",
			"logprob": -0.00004596782,
			"top_logprobs": []map[string]any{
				{
					"token":   "?",
					"logprob": -0.00004596782,
				},
			},
		},
	},
}
691
// mockServer wraps an httptest.Server that records every incoming request and
// replies with a canned JSON response. Callers set the response via
// prepareJSONResponse (or by assigning response directly) and inspect calls
// after exercising the client under test.
type mockServer struct {
	server   *httptest.Server // underlying test server; base URL is server.URL
	response map[string]any   // JSON body returned for every request
	calls    []mockCall       // requests received, in arrival order
}

// mockCall captures one recorded HTTP request.
type mockCall struct {
	method  string            // HTTP method, e.g. "POST"
	path    string            // request path, e.g. "/chat/completions"
	headers map[string]string // first value of each request header
	body    map[string]any    // decoded JSON request body; nil when absent or invalid
}

// newMockServer starts an httptest.Server that records each call and serves
// ms.response as JSON. The caller must invoke close when finished.
func newMockServer() *mockServer {
	ms := &mockServer{
		calls: make([]mockCall, 0),
	}

	ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record method, path, and the first value of every header.
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}
		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Best-effort body decode: an empty or non-JSON body leaves call.body nil.
		if r.Body != nil {
			var body map[string]any
			if err := json.NewDecoder(r.Body).Decode(&body); err == nil {
				call.body = body
			}
		}

		ms.calls = append(ms.calls, call)

		// Return the canned response.
		w.Header().Set("Content-Type", "application/json")
		// Encoding a map[string]any of JSON-safe values cannot fail; ignore
		// the error explicitly to satisfy errcheck.
		_ = json.NewEncoder(w).Encode(ms.response)
	}))

	return ms
}

// close shuts down the underlying test server.
func (ms *mockServer) close() {
	ms.server.Close()
}

// prepareJSONResponse installs a chat-completion-shaped JSON response with
// sensible defaults, then applies the overrides in opts. Recognized keys:
// "content", "tool_calls", "function_call", "annotations" (message fields),
// "finish_reason", "logprobs" (choice fields), and "usage", "id", "created",
// "model" (top-level fields). Unknown keys are ignored.
func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
	// Build the nested maps once so overrides can mutate them directly
	// instead of re-asserting the types on every key.
	message := map[string]any{
		"role":    "assistant",
		"content": "",
	}
	choice := map[string]any{
		"index":         0,
		"message":       message,
		"finish_reason": "stop",
	}
	response := map[string]any{
		"id":      "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
		"object":  "chat.completion",
		"created": 1711115037,
		"model":   "gpt-3.5-turbo-0125",
		"choices": []map[string]any{choice},
		"usage": map[string]any{
			"prompt_tokens":     4,
			"total_tokens":      34,
			"completion_tokens": 30,
		},
		"system_fingerprint": "fp_3bc1b5746c",
	}

	for k, v := range opts {
		switch k {
		case "content", "tool_calls", "function_call", "annotations":
			message[k] = v
		case "usage", "id", "created", "model":
			response[k] = v
		case "finish_reason":
			choice["finish_reason"] = v
		case "logprobs":
			// A nil logprobs override means "leave the field absent".
			if v != nil {
				choice["logprobs"] = v
			}
		}
	}

	ms.response = response
}
800
801func TestDoGenerate(t *testing.T) {
802 t.Parallel()
803
804 t.Run("should extract text response", func(t *testing.T) {
805 t.Parallel()
806
807 server := newMockServer()
808 defer server.close()
809
810 server.prepareJSONResponse(map[string]any{
811 "content": "Hello, World!",
812 })
813
814 provider, err := New(
815 WithAPIKey("test-api-key"),
816 WithBaseURL(server.server.URL),
817 )
818 require.NoError(t, err)
819 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
820
821 result, err := model.Generate(context.Background(), fantasy.Call{
822 Prompt: testPrompt,
823 })
824
825 require.NoError(t, err)
826 require.Len(t, result.Content, 1)
827
828 textContent, ok := result.Content[0].(fantasy.TextContent)
829 require.True(t, ok)
830 require.Equal(t, "Hello, World!", textContent.Text)
831 })
832
833 t.Run("should extract usage", func(t *testing.T) {
834 t.Parallel()
835
836 server := newMockServer()
837 defer server.close()
838
839 server.prepareJSONResponse(map[string]any{
840 "usage": map[string]any{
841 "prompt_tokens": 20,
842 "total_tokens": 25,
843 "completion_tokens": 5,
844 },
845 })
846
847 provider, err := New(
848 WithAPIKey("test-api-key"),
849 WithBaseURL(server.server.URL),
850 )
851 require.NoError(t, err)
852 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
853
854 result, err := model.Generate(context.Background(), fantasy.Call{
855 Prompt: testPrompt,
856 })
857
858 require.NoError(t, err)
859 require.Equal(t, int64(20), result.Usage.InputTokens)
860 require.Equal(t, int64(5), result.Usage.OutputTokens)
861 require.Equal(t, int64(25), result.Usage.TotalTokens)
862 })
863
864 t.Run("should send request body", func(t *testing.T) {
865 t.Parallel()
866
867 server := newMockServer()
868 defer server.close()
869
870 server.prepareJSONResponse(map[string]any{})
871
872 provider, err := New(
873 WithAPIKey("test-api-key"),
874 WithBaseURL(server.server.URL),
875 )
876 require.NoError(t, err)
877 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
878
879 _, err = model.Generate(context.Background(), fantasy.Call{
880 Prompt: testPrompt,
881 })
882
883 require.NoError(t, err)
884 require.Len(t, server.calls, 1)
885
886 call := server.calls[0]
887 require.Equal(t, "POST", call.method)
888 require.Equal(t, "/chat/completions", call.path)
889 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
890
891 messages, ok := call.body["messages"].([]any)
892 require.True(t, ok)
893 require.Len(t, messages, 1)
894
895 message := messages[0].(map[string]any)
896 require.Equal(t, "user", message["role"])
897 require.Equal(t, "Hello", message["content"])
898 })
899
900 t.Run("should support partial usage", func(t *testing.T) {
901 t.Parallel()
902
903 server := newMockServer()
904 defer server.close()
905
906 server.prepareJSONResponse(map[string]any{
907 "usage": map[string]any{
908 "prompt_tokens": 20,
909 "total_tokens": 20,
910 },
911 })
912
913 provider, err := New(
914 WithAPIKey("test-api-key"),
915 WithBaseURL(server.server.URL),
916 )
917 require.NoError(t, err)
918 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
919
920 result, err := model.Generate(context.Background(), fantasy.Call{
921 Prompt: testPrompt,
922 })
923
924 require.NoError(t, err)
925 require.Equal(t, int64(20), result.Usage.InputTokens)
926 require.Equal(t, int64(0), result.Usage.OutputTokens)
927 require.Equal(t, int64(20), result.Usage.TotalTokens)
928 })
929
930 t.Run("should extract logprobs", func(t *testing.T) {
931 t.Parallel()
932
933 server := newMockServer()
934 defer server.close()
935
936 server.prepareJSONResponse(map[string]any{
937 "logprobs": testLogprobs,
938 })
939
940 provider, err := New(
941 WithAPIKey("test-api-key"),
942 WithBaseURL(server.server.URL),
943 )
944 require.NoError(t, err)
945 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
946
947 result, err := model.Generate(context.Background(), fantasy.Call{
948 Prompt: testPrompt,
949 ProviderOptions: NewProviderOptions(&ProviderOptions{
950 LogProbs: new(true),
951 }),
952 })
953
954 require.NoError(t, err)
955 require.NotNil(t, result.ProviderMetadata)
956
957 openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
958 require.True(t, ok)
959
960 logprobs := openaiMeta.Logprobs
961 require.True(t, ok)
962 require.NotNil(t, logprobs)
963 })
964
965 t.Run("should extract finish reason", func(t *testing.T) {
966 t.Parallel()
967
968 server := newMockServer()
969 defer server.close()
970
971 server.prepareJSONResponse(map[string]any{
972 "finish_reason": "stop",
973 })
974
975 provider, err := New(
976 WithAPIKey("test-api-key"),
977 WithBaseURL(server.server.URL),
978 )
979 require.NoError(t, err)
980 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
981
982 result, err := model.Generate(context.Background(), fantasy.Call{
983 Prompt: testPrompt,
984 })
985
986 require.NoError(t, err)
987 require.Equal(t, fantasy.FinishReasonStop, result.FinishReason)
988 })
989
990 t.Run("should support unknown finish reason", func(t *testing.T) {
991 t.Parallel()
992
993 server := newMockServer()
994 defer server.close()
995
996 server.prepareJSONResponse(map[string]any{
997 "finish_reason": "eos",
998 })
999
1000 provider, err := New(
1001 WithAPIKey("test-api-key"),
1002 WithBaseURL(server.server.URL),
1003 )
1004 require.NoError(t, err)
1005 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1006
1007 result, err := model.Generate(context.Background(), fantasy.Call{
1008 Prompt: testPrompt,
1009 })
1010
1011 require.NoError(t, err)
1012 require.Equal(t, fantasy.FinishReasonUnknown, result.FinishReason)
1013 })
1014
1015 t.Run("should pass the model and the messages", func(t *testing.T) {
1016 t.Parallel()
1017
1018 server := newMockServer()
1019 defer server.close()
1020
1021 server.prepareJSONResponse(map[string]any{
1022 "content": "",
1023 })
1024
1025 provider, err := New(
1026 WithAPIKey("test-api-key"),
1027 WithBaseURL(server.server.URL),
1028 )
1029 require.NoError(t, err)
1030 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1031
1032 _, err = model.Generate(context.Background(), fantasy.Call{
1033 Prompt: testPrompt,
1034 })
1035
1036 require.NoError(t, err)
1037 require.Len(t, server.calls, 1)
1038
1039 call := server.calls[0]
1040 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1041
1042 messages := call.body["messages"].([]any)
1043 require.Len(t, messages, 1)
1044
1045 message := messages[0].(map[string]any)
1046 require.Equal(t, "user", message["role"])
1047 require.Equal(t, "Hello", message["content"])
1048 })
1049
1050 t.Run("should pass settings", func(t *testing.T) {
1051 t.Parallel()
1052
1053 server := newMockServer()
1054 defer server.close()
1055
1056 server.prepareJSONResponse(map[string]any{})
1057
1058 provider, err := New(
1059 WithAPIKey("test-api-key"),
1060 WithBaseURL(server.server.URL),
1061 )
1062 require.NoError(t, err)
1063 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1064
1065 _, err = model.Generate(context.Background(), fantasy.Call{
1066 Prompt: testPrompt,
1067 ProviderOptions: NewProviderOptions(&ProviderOptions{
1068 LogitBias: map[string]int64{
1069 "50256": -100,
1070 },
1071 ParallelToolCalls: new(false),
1072 User: new("test-user-id"),
1073 }),
1074 })
1075
1076 require.NoError(t, err)
1077 require.Len(t, server.calls, 1)
1078
1079 call := server.calls[0]
1080 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1081
1082 messages := call.body["messages"].([]any)
1083 require.Len(t, messages, 1)
1084
1085 logitBias := call.body["logit_bias"].(map[string]any)
1086 require.Equal(t, float64(-100), logitBias["50256"])
1087 require.Equal(t, false, call.body["parallel_tool_calls"])
1088 require.Equal(t, "test-user-id", call.body["user"])
1089 })
1090
1091 t.Run("should pass reasoningEffort setting", func(t *testing.T) {
1092 t.Parallel()
1093
1094 server := newMockServer()
1095 defer server.close()
1096
1097 server.prepareJSONResponse(map[string]any{
1098 "content": "",
1099 })
1100
1101 provider, err := New(
1102 WithAPIKey("test-api-key"),
1103 WithBaseURL(server.server.URL),
1104 )
1105 require.NoError(t, err)
1106 model, _ := provider.LanguageModel(t.Context(), "o1-mini")
1107
1108 _, err = model.Generate(context.Background(), fantasy.Call{
1109 Prompt: testPrompt,
1110 ProviderOptions: NewProviderOptions(
1111 &ProviderOptions{
1112 ReasoningEffort: new(ReasoningEffortLow),
1113 },
1114 ),
1115 })
1116
1117 require.NoError(t, err)
1118 require.Len(t, server.calls, 1)
1119
1120 call := server.calls[0]
1121 require.Equal(t, "o1-mini", call.body["model"])
1122 require.Equal(t, "low", call.body["reasoning_effort"])
1123
1124 messages := call.body["messages"].([]any)
1125 require.Len(t, messages, 1)
1126
1127 message := messages[0].(map[string]any)
1128 require.Equal(t, "user", message["role"])
1129 require.Equal(t, "Hello", message["content"])
1130 })
1131
1132 t.Run("should pass textVerbosity setting", func(t *testing.T) {
1133 t.Parallel()
1134
1135 server := newMockServer()
1136 defer server.close()
1137
1138 server.prepareJSONResponse(map[string]any{
1139 "content": "",
1140 })
1141
1142 provider, err := New(
1143 WithAPIKey("test-api-key"),
1144 WithBaseURL(server.server.URL),
1145 )
1146 require.NoError(t, err)
1147 model, _ := provider.LanguageModel(t.Context(), "gpt-4o")
1148
1149 _, err = model.Generate(context.Background(), fantasy.Call{
1150 Prompt: testPrompt,
1151 ProviderOptions: NewProviderOptions(&ProviderOptions{
1152 TextVerbosity: new("low"),
1153 }),
1154 })
1155
1156 require.NoError(t, err)
1157 require.Len(t, server.calls, 1)
1158
1159 call := server.calls[0]
1160 require.Equal(t, "gpt-4o", call.body["model"])
1161 require.Equal(t, "low", call.body["verbosity"])
1162
1163 messages := call.body["messages"].([]any)
1164 require.Len(t, messages, 1)
1165
1166 message := messages[0].(map[string]any)
1167 require.Equal(t, "user", message["role"])
1168 require.Equal(t, "Hello", message["content"])
1169 })
1170
1171 t.Run("should pass tools and toolChoice", func(t *testing.T) {
1172 t.Parallel()
1173
1174 server := newMockServer()
1175 defer server.close()
1176
1177 server.prepareJSONResponse(map[string]any{
1178 "content": "",
1179 })
1180
1181 provider, err := New(
1182 WithAPIKey("test-api-key"),
1183 WithBaseURL(server.server.URL),
1184 )
1185 require.NoError(t, err)
1186 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1187
1188 _, err = model.Generate(context.Background(), fantasy.Call{
1189 Prompt: testPrompt,
1190 Tools: []fantasy.Tool{
1191 fantasy.FunctionTool{
1192 Name: "test-tool",
1193 InputSchema: map[string]any{
1194 "type": "object",
1195 "properties": map[string]any{
1196 "value": map[string]any{
1197 "type": "string",
1198 },
1199 },
1200 "required": []string{"value"},
1201 "additionalProperties": false,
1202 "$schema": "http://json-schema.org/draft-07/schema#",
1203 },
1204 },
1205 },
1206 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0],
1207 })
1208
1209 require.NoError(t, err)
1210 require.Len(t, server.calls, 1)
1211
1212 call := server.calls[0]
1213 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1214
1215 messages := call.body["messages"].([]any)
1216 require.Len(t, messages, 1)
1217
1218 tools := call.body["tools"].([]any)
1219 require.Len(t, tools, 1)
1220
1221 tool := tools[0].(map[string]any)
1222 require.Equal(t, "function", tool["type"])
1223
1224 function := tool["function"].(map[string]any)
1225 require.Equal(t, "test-tool", function["name"])
1226 require.Equal(t, false, function["strict"])
1227
1228 toolChoice := call.body["tool_choice"].(map[string]any)
1229 require.Equal(t, "function", toolChoice["type"])
1230
1231 toolChoiceFunction := toolChoice["function"].(map[string]any)
1232 require.Equal(t, "test-tool", toolChoiceFunction["name"])
1233 })
1234
1235 t.Run("should parse tool results", func(t *testing.T) {
1236 t.Parallel()
1237
1238 server := newMockServer()
1239 defer server.close()
1240
1241 server.prepareJSONResponse(map[string]any{
1242 "tool_calls": []map[string]any{
1243 {
1244 "id": "call_O17Uplv4lJvD6DVdIvFFeRMw",
1245 "type": "function",
1246 "function": map[string]any{
1247 "name": "test-tool",
1248 "arguments": `{"value":"Spark"}`,
1249 },
1250 },
1251 },
1252 })
1253
1254 provider, err := New(
1255 WithAPIKey("test-api-key"),
1256 WithBaseURL(server.server.URL),
1257 )
1258 require.NoError(t, err)
1259 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1260
1261 result, err := model.Generate(context.Background(), fantasy.Call{
1262 Prompt: testPrompt,
1263 Tools: []fantasy.Tool{
1264 fantasy.FunctionTool{
1265 Name: "test-tool",
1266 InputSchema: map[string]any{
1267 "type": "object",
1268 "properties": map[string]any{
1269 "value": map[string]any{
1270 "type": "string",
1271 },
1272 },
1273 "required": []string{"value"},
1274 "additionalProperties": false,
1275 "$schema": "http://json-schema.org/draft-07/schema#",
1276 },
1277 },
1278 },
1279 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0],
1280 })
1281
1282 require.NoError(t, err)
1283 require.Len(t, result.Content, 1)
1284
1285 toolCall, ok := result.Content[0].(fantasy.ToolCallContent)
1286 require.True(t, ok)
1287 require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
1288 require.Equal(t, "test-tool", toolCall.ToolName)
1289 require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
1290 })
1291
1292 t.Run("should handle ToolChoiceRequired", func(t *testing.T) {
1293 t.Parallel()
1294
1295 server := newMockServer()
1296 defer server.close()
1297
1298 server.prepareJSONResponse(map[string]any{
1299 "content": "",
1300 })
1301
1302 provider, err := New(
1303 WithAPIKey("test-api-key"),
1304 WithBaseURL(server.server.URL),
1305 )
1306 require.NoError(t, err)
1307 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1308
1309 _, err = model.Generate(context.Background(), fantasy.Call{
1310 Prompt: testPrompt,
1311 Tools: []fantasy.Tool{
1312 fantasy.FunctionTool{
1313 Name: "test-tool",
1314 InputSchema: map[string]any{
1315 "type": "object",
1316 "properties": map[string]any{
1317 "value": map[string]any{
1318 "type": "string",
1319 },
1320 },
1321 "required": []string{"value"},
1322 "additionalProperties": false,
1323 "$schema": "http://json-schema.org/draft-07/schema#",
1324 },
1325 },
1326 },
1327 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoiceRequired}[0],
1328 })
1329
1330 require.NoError(t, err)
1331 require.Len(t, server.calls, 1)
1332
1333 call := server.calls[0]
1334 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1335
1336 // Verify tool is present
1337 tools := call.body["tools"].([]any)
1338 require.Len(t, tools, 1)
1339
1340 tool := tools[0].(map[string]any)
1341 require.Equal(t, "function", tool["type"])
1342
1343 function := tool["function"].(map[string]any)
1344 require.Equal(t, "test-tool", function["name"])
1345
1346 // Verify tool_choice is set to "required" (not a function name)
1347 toolChoice := call.body["tool_choice"]
1348 require.Equal(t, "required", toolChoice)
1349 })
1350
1351 t.Run("should parse annotations/citations", func(t *testing.T) {
1352 t.Parallel()
1353
1354 server := newMockServer()
1355 defer server.close()
1356
1357 server.prepareJSONResponse(map[string]any{
1358 "content": "Based on the search results [doc1], I found information.",
1359 "annotations": []map[string]any{
1360 {
1361 "type": "url_citation",
1362 "url_citation": map[string]any{
1363 "start_index": 24,
1364 "end_index": 29,
1365 "url": "https://example.com/doc1.pdf",
1366 "title": "Document 1",
1367 },
1368 },
1369 },
1370 })
1371
1372 provider, err := New(
1373 WithAPIKey("test-api-key"),
1374 WithBaseURL(server.server.URL),
1375 )
1376 require.NoError(t, err)
1377 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1378
1379 result, err := model.Generate(context.Background(), fantasy.Call{
1380 Prompt: testPrompt,
1381 })
1382
1383 require.NoError(t, err)
1384 require.Len(t, result.Content, 2)
1385
1386 textContent, ok := result.Content[0].(fantasy.TextContent)
1387 require.True(t, ok)
1388 require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)
1389
1390 sourceContent, ok := result.Content[1].(fantasy.SourceContent)
1391 require.True(t, ok)
1392 require.Equal(t, fantasy.SourceTypeURL, sourceContent.SourceType)
1393 require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
1394 require.Equal(t, "Document 1", sourceContent.Title)
1395 require.NotEmpty(t, sourceContent.ID)
1396 })
1397
1398 t.Run("should return cached_tokens in prompt_details_tokens", func(t *testing.T) {
1399 t.Parallel()
1400
1401 server := newMockServer()
1402 defer server.close()
1403
1404 server.prepareJSONResponse(map[string]any{
1405 "usage": map[string]any{
1406 "prompt_tokens": 15,
1407 "completion_tokens": 20,
1408 "total_tokens": 35,
1409 "prompt_tokens_details": map[string]any{
1410 "cached_tokens": 1152,
1411 },
1412 },
1413 })
1414
1415 provider, err := New(
1416 WithAPIKey("test-api-key"),
1417 WithBaseURL(server.server.URL),
1418 )
1419 require.NoError(t, err)
1420 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1421
1422 result, err := model.Generate(context.Background(), fantasy.Call{
1423 Prompt: testPrompt,
1424 })
1425
1426 require.NoError(t, err)
1427 require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
1428 // InputTokens = prompt_tokens - cached_tokens = 15 - 1152 = -1137 → clamped to 0
1429 require.Equal(t, int64(0), result.Usage.InputTokens)
1430 require.Equal(t, int64(20), result.Usage.OutputTokens)
1431 require.Equal(t, int64(35), result.Usage.TotalTokens)
1432 })
1433
1434 t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
1435 t.Parallel()
1436
1437 server := newMockServer()
1438 defer server.close()
1439
1440 server.prepareJSONResponse(map[string]any{
1441 "usage": map[string]any{
1442 "prompt_tokens": 15,
1443 "completion_tokens": 20,
1444 "total_tokens": 35,
1445 "completion_tokens_details": map[string]any{
1446 "accepted_prediction_tokens": 123,
1447 "rejected_prediction_tokens": 456,
1448 },
1449 },
1450 })
1451
1452 provider, err := New(
1453 WithAPIKey("test-api-key"),
1454 WithBaseURL(server.server.URL),
1455 )
1456 require.NoError(t, err)
1457 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1458
1459 result, err := model.Generate(context.Background(), fantasy.Call{
1460 Prompt: testPrompt,
1461 })
1462
1463 require.NoError(t, err)
1464 require.NotNil(t, result.ProviderMetadata)
1465
1466 openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
1467
1468 require.True(t, ok)
1469 require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
1470 require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
1471 })
1472
1473 t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
1474 t.Parallel()
1475
1476 server := newMockServer()
1477 defer server.close()
1478
1479 server.prepareJSONResponse(map[string]any{})
1480
1481 provider, err := New(
1482 WithAPIKey("test-api-key"),
1483 WithBaseURL(server.server.URL),
1484 )
1485 require.NoError(t, err)
1486 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1487
1488 result, err := model.Generate(context.Background(), fantasy.Call{
1489 Prompt: testPrompt,
1490 Temperature: &[]float64{0.5}[0],
1491 TopP: &[]float64{0.7}[0],
1492 FrequencyPenalty: &[]float64{0.2}[0],
1493 PresencePenalty: &[]float64{0.3}[0],
1494 })
1495
1496 require.NoError(t, err)
1497 require.Len(t, server.calls, 1)
1498
1499 call := server.calls[0]
1500 require.Equal(t, "o1-preview", call.body["model"])
1501
1502 messages := call.body["messages"].([]any)
1503 require.Len(t, messages, 1)
1504
1505 message := messages[0].(map[string]any)
1506 require.Equal(t, "user", message["role"])
1507 require.Equal(t, "Hello", message["content"])
1508
1509 // These should not be present
1510 require.Nil(t, call.body["temperature"])
1511 require.Nil(t, call.body["top_p"])
1512 require.Nil(t, call.body["frequency_penalty"])
1513 require.Nil(t, call.body["presence_penalty"])
1514
1515 // Should have warnings
1516 require.Len(t, result.Warnings, 4)
1517 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1518 require.Equal(t, "temperature", result.Warnings[0].Setting)
1519 require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
1520 })
1521
1522 t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
1523 t.Parallel()
1524
1525 server := newMockServer()
1526 defer server.close()
1527
1528 server.prepareJSONResponse(map[string]any{})
1529
1530 provider, err := New(
1531 WithAPIKey("test-api-key"),
1532 WithBaseURL(server.server.URL),
1533 )
1534 require.NoError(t, err)
1535 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1536
1537 _, err = model.Generate(context.Background(), fantasy.Call{
1538 Prompt: testPrompt,
1539 MaxOutputTokens: &[]int64{1000}[0],
1540 })
1541
1542 require.NoError(t, err)
1543 require.Len(t, server.calls, 1)
1544
1545 call := server.calls[0]
1546 require.Equal(t, "o1-preview", call.body["model"])
1547 require.Equal(t, float64(1000), call.body["max_completion_tokens"])
1548 require.Nil(t, call.body["max_tokens"])
1549
1550 messages := call.body["messages"].([]any)
1551 require.Len(t, messages, 1)
1552
1553 message := messages[0].(map[string]any)
1554 require.Equal(t, "user", message["role"])
1555 require.Equal(t, "Hello", message["content"])
1556 })
1557
1558 t.Run("should return reasoning tokens", func(t *testing.T) {
1559 t.Parallel()
1560
1561 server := newMockServer()
1562 defer server.close()
1563
1564 server.prepareJSONResponse(map[string]any{
1565 "usage": map[string]any{
1566 "prompt_tokens": 15,
1567 "completion_tokens": 20,
1568 "total_tokens": 35,
1569 "completion_tokens_details": map[string]any{
1570 "reasoning_tokens": 10,
1571 },
1572 },
1573 })
1574
1575 provider, err := New(
1576 WithAPIKey("test-api-key"),
1577 WithBaseURL(server.server.URL),
1578 )
1579 require.NoError(t, err)
1580 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1581
1582 result, err := model.Generate(context.Background(), fantasy.Call{
1583 Prompt: testPrompt,
1584 })
1585
1586 require.NoError(t, err)
1587 require.Equal(t, int64(15), result.Usage.InputTokens)
1588 require.Equal(t, int64(20), result.Usage.OutputTokens)
1589 require.Equal(t, int64(35), result.Usage.TotalTokens)
1590 require.Equal(t, int64(10), result.Usage.ReasoningTokens)
1591 })
1592
1593 t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
1594 t.Parallel()
1595
1596 server := newMockServer()
1597 defer server.close()
1598
1599 server.prepareJSONResponse(map[string]any{
1600 "model": "o1-preview",
1601 })
1602
1603 provider, err := New(
1604 WithAPIKey("test-api-key"),
1605 WithBaseURL(server.server.URL),
1606 )
1607 require.NoError(t, err)
1608 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1609
1610 _, err = model.Generate(context.Background(), fantasy.Call{
1611 Prompt: testPrompt,
1612 ProviderOptions: NewProviderOptions(&ProviderOptions{
1613 MaxCompletionTokens: new(int64(255)),
1614 }),
1615 })
1616
1617 require.NoError(t, err)
1618 require.Len(t, server.calls, 1)
1619
1620 call := server.calls[0]
1621 require.Equal(t, "o1-preview", call.body["model"])
1622 require.Equal(t, float64(255), call.body["max_completion_tokens"])
1623
1624 messages := call.body["messages"].([]any)
1625 require.Len(t, messages, 1)
1626
1627 message := messages[0].(map[string]any)
1628 require.Equal(t, "user", message["role"])
1629 require.Equal(t, "Hello", message["content"])
1630 })
1631
1632 t.Run("should send prediction extension setting", func(t *testing.T) {
1633 t.Parallel()
1634
1635 server := newMockServer()
1636 defer server.close()
1637
1638 server.prepareJSONResponse(map[string]any{
1639 "content": "",
1640 })
1641
1642 provider, err := New(
1643 WithAPIKey("test-api-key"),
1644 WithBaseURL(server.server.URL),
1645 )
1646 require.NoError(t, err)
1647 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1648
1649 _, err = model.Generate(context.Background(), fantasy.Call{
1650 Prompt: testPrompt,
1651 ProviderOptions: NewProviderOptions(&ProviderOptions{
1652 Prediction: map[string]any{
1653 "type": "content",
1654 "content": "Hello, World!",
1655 },
1656 }),
1657 })
1658
1659 require.NoError(t, err)
1660 require.Len(t, server.calls, 1)
1661
1662 call := server.calls[0]
1663 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1664
1665 prediction := call.body["prediction"].(map[string]any)
1666 require.Equal(t, "content", prediction["type"])
1667 require.Equal(t, "Hello, World!", prediction["content"])
1668
1669 messages := call.body["messages"].([]any)
1670 require.Len(t, messages, 1)
1671
1672 message := messages[0].(map[string]any)
1673 require.Equal(t, "user", message["role"])
1674 require.Equal(t, "Hello", message["content"])
1675 })
1676
1677 t.Run("should send store extension setting", func(t *testing.T) {
1678 t.Parallel()
1679
1680 server := newMockServer()
1681 defer server.close()
1682
1683 server.prepareJSONResponse(map[string]any{
1684 "content": "",
1685 })
1686
1687 provider, err := New(
1688 WithAPIKey("test-api-key"),
1689 WithBaseURL(server.server.URL),
1690 )
1691 require.NoError(t, err)
1692 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1693
1694 _, err = model.Generate(context.Background(), fantasy.Call{
1695 Prompt: testPrompt,
1696 ProviderOptions: NewProviderOptions(&ProviderOptions{
1697 Store: new(true),
1698 }),
1699 })
1700
1701 require.NoError(t, err)
1702 require.Len(t, server.calls, 1)
1703
1704 call := server.calls[0]
1705 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1706 require.Equal(t, true, call.body["store"])
1707
1708 messages := call.body["messages"].([]any)
1709 require.Len(t, messages, 1)
1710
1711 message := messages[0].(map[string]any)
1712 require.Equal(t, "user", message["role"])
1713 require.Equal(t, "Hello", message["content"])
1714 })
1715
1716 t.Run("should send metadata extension values", func(t *testing.T) {
1717 t.Parallel()
1718
1719 server := newMockServer()
1720 defer server.close()
1721
1722 server.prepareJSONResponse(map[string]any{
1723 "content": "",
1724 })
1725
1726 provider, err := New(
1727 WithAPIKey("test-api-key"),
1728 WithBaseURL(server.server.URL),
1729 )
1730 require.NoError(t, err)
1731 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1732
1733 _, err = model.Generate(context.Background(), fantasy.Call{
1734 Prompt: testPrompt,
1735 ProviderOptions: NewProviderOptions(&ProviderOptions{
1736 Metadata: map[string]any{
1737 "custom": "value",
1738 },
1739 }),
1740 })
1741
1742 require.NoError(t, err)
1743 require.Len(t, server.calls, 1)
1744
1745 call := server.calls[0]
1746 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1747
1748 metadata := call.body["metadata"].(map[string]any)
1749 require.Equal(t, "value", metadata["custom"])
1750
1751 messages := call.body["messages"].([]any)
1752 require.Len(t, messages, 1)
1753
1754 message := messages[0].(map[string]any)
1755 require.Equal(t, "user", message["role"])
1756 require.Equal(t, "Hello", message["content"])
1757 })
1758
1759 t.Run("should send promptCacheKey extension value", func(t *testing.T) {
1760 t.Parallel()
1761
1762 server := newMockServer()
1763 defer server.close()
1764
1765 server.prepareJSONResponse(map[string]any{
1766 "content": "",
1767 })
1768
1769 provider, err := New(
1770 WithAPIKey("test-api-key"),
1771 WithBaseURL(server.server.URL),
1772 )
1773 require.NoError(t, err)
1774 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1775
1776 _, err = model.Generate(context.Background(), fantasy.Call{
1777 Prompt: testPrompt,
1778 ProviderOptions: NewProviderOptions(&ProviderOptions{
1779 PromptCacheKey: new("test-cache-key-123"),
1780 }),
1781 })
1782
1783 require.NoError(t, err)
1784 require.Len(t, server.calls, 1)
1785
1786 call := server.calls[0]
1787 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1788 require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])
1789
1790 messages := call.body["messages"].([]any)
1791 require.Len(t, messages, 1)
1792
1793 message := messages[0].(map[string]any)
1794 require.Equal(t, "user", message["role"])
1795 require.Equal(t, "Hello", message["content"])
1796 })
1797
1798 t.Run("should send safety_identifier extension value", func(t *testing.T) {
1799 t.Parallel()
1800
1801 server := newMockServer()
1802 defer server.close()
1803
1804 server.prepareJSONResponse(map[string]any{
1805 "content": "",
1806 })
1807
1808 provider, err := New(
1809 WithAPIKey("test-api-key"),
1810 WithBaseURL(server.server.URL),
1811 )
1812 require.NoError(t, err)
1813 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1814
1815 _, err = model.Generate(context.Background(), fantasy.Call{
1816 Prompt: testPrompt,
1817 ProviderOptions: NewProviderOptions(&ProviderOptions{
1818 SafetyIdentifier: new("test-safety-identifier-123"),
1819 }),
1820 })
1821
1822 require.NoError(t, err)
1823 require.Len(t, server.calls, 1)
1824
1825 call := server.calls[0]
1826 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1827 require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])
1828
1829 messages := call.body["messages"].([]any)
1830 require.Len(t, messages, 1)
1831
1832 message := messages[0].(map[string]any)
1833 require.Equal(t, "user", message["role"])
1834 require.Equal(t, "Hello", message["content"])
1835 })
1836
1837 t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
1838 t.Parallel()
1839
1840 server := newMockServer()
1841 defer server.close()
1842
1843 server.prepareJSONResponse(map[string]any{})
1844
1845 provider, err := New(
1846 WithAPIKey("test-api-key"),
1847 WithBaseURL(server.server.URL),
1848 )
1849 require.NoError(t, err)
1850 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-search-preview")
1851
1852 result, err := model.Generate(context.Background(), fantasy.Call{
1853 Prompt: testPrompt,
1854 Temperature: &[]float64{0.7}[0],
1855 })
1856
1857 require.NoError(t, err)
1858 require.Len(t, server.calls, 1)
1859
1860 call := server.calls[0]
1861 require.Equal(t, "gpt-4o-search-preview", call.body["model"])
1862 require.Nil(t, call.body["temperature"])
1863
1864 require.Len(t, result.Warnings, 1)
1865 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1866 require.Equal(t, "temperature", result.Warnings[0].Setting)
1867 require.Contains(t, result.Warnings[0].Details, "search preview models")
1868 })
1869
1870 t.Run("should send ServiceTier flex processing setting", func(t *testing.T) {
1871 t.Parallel()
1872
1873 server := newMockServer()
1874 defer server.close()
1875
1876 server.prepareJSONResponse(map[string]any{
1877 "content": "",
1878 })
1879
1880 provider, err := New(
1881 WithAPIKey("test-api-key"),
1882 WithBaseURL(server.server.URL),
1883 )
1884 require.NoError(t, err)
1885 model, _ := provider.LanguageModel(t.Context(), "o3-mini")
1886
1887 _, err = model.Generate(context.Background(), fantasy.Call{
1888 Prompt: testPrompt,
1889 ProviderOptions: NewProviderOptions(&ProviderOptions{
1890 ServiceTier: new("flex"),
1891 }),
1892 })
1893
1894 require.NoError(t, err)
1895 require.Len(t, server.calls, 1)
1896
1897 call := server.calls[0]
1898 require.Equal(t, "o3-mini", call.body["model"])
1899 require.Equal(t, "flex", call.body["service_tier"])
1900
1901 messages := call.body["messages"].([]any)
1902 require.Len(t, messages, 1)
1903
1904 message := messages[0].(map[string]any)
1905 require.Equal(t, "user", message["role"])
1906 require.Equal(t, "Hello", message["content"])
1907 })
1908
1909 t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
1910 t.Parallel()
1911
1912 server := newMockServer()
1913 defer server.close()
1914
1915 server.prepareJSONResponse(map[string]any{})
1916
1917 provider, err := New(
1918 WithAPIKey("test-api-key"),
1919 WithBaseURL(server.server.URL),
1920 )
1921 require.NoError(t, err)
1922 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1923
1924 result, err := model.Generate(context.Background(), fantasy.Call{
1925 Prompt: testPrompt,
1926 ProviderOptions: NewProviderOptions(&ProviderOptions{
1927 ServiceTier: new("flex"),
1928 }),
1929 })
1930
1931 require.NoError(t, err)
1932 require.Len(t, server.calls, 1)
1933
1934 call := server.calls[0]
1935 require.Nil(t, call.body["service_tier"])
1936
1937 require.Len(t, result.Warnings, 1)
1938 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1939 require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
1940 require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
1941 })
1942
1943 t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
1944 t.Parallel()
1945
1946 server := newMockServer()
1947 defer server.close()
1948
1949 server.prepareJSONResponse(map[string]any{})
1950
1951 provider, err := New(
1952 WithAPIKey("test-api-key"),
1953 WithBaseURL(server.server.URL),
1954 )
1955 require.NoError(t, err)
1956 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1957
1958 _, err = model.Generate(context.Background(), fantasy.Call{
1959 Prompt: testPrompt,
1960 ProviderOptions: NewProviderOptions(&ProviderOptions{
1961 ServiceTier: new("priority"),
1962 }),
1963 })
1964
1965 require.NoError(t, err)
1966 require.Len(t, server.calls, 1)
1967
1968 call := server.calls[0]
1969 require.Equal(t, "gpt-4o-mini", call.body["model"])
1970 require.Equal(t, "priority", call.body["service_tier"])
1971
1972 messages := call.body["messages"].([]any)
1973 require.Len(t, messages, 1)
1974
1975 message := messages[0].(map[string]any)
1976 require.Equal(t, "user", message["role"])
1977 require.Equal(t, "Hello", message["content"])
1978 })
1979
1980 t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
1981 t.Parallel()
1982
1983 server := newMockServer()
1984 defer server.close()
1985
1986 server.prepareJSONResponse(map[string]any{})
1987
1988 provider, err := New(
1989 WithAPIKey("test-api-key"),
1990 WithBaseURL(server.server.URL),
1991 )
1992 require.NoError(t, err)
1993 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1994
1995 result, err := model.Generate(context.Background(), fantasy.Call{
1996 Prompt: testPrompt,
1997 ProviderOptions: NewProviderOptions(&ProviderOptions{
1998 ServiceTier: new("priority"),
1999 }),
2000 })
2001
2002 require.NoError(t, err)
2003 require.Len(t, server.calls, 1)
2004
2005 call := server.calls[0]
2006 require.Nil(t, call.body["service_tier"])
2007
2008 require.Len(t, result.Warnings, 1)
2009 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
2010 require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
2011 require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
2012 })
2013
2014 t.Run("should return error instead of panic on empty response body", func(t *testing.T) {
2015 t.Parallel()
2016
2017 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2018 w.Header().Set("Content-Type", "application/json")
2019 w.WriteHeader(http.StatusOK)
2020 // Write empty body — some OpenAI-compatible endpoints may do this
2021 // under edge conditions, causing the SDK to return (nil, nil).
2022 }))
2023 defer server.Close()
2024
2025 provider, err := New(
2026 WithAPIKey("test-api-key"),
2027 WithBaseURL(server.URL),
2028 )
2029 require.NoError(t, err)
2030 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
2031
2032 require.NotPanics(t, func() {
2033 _, _ = model.Generate(context.Background(), fantasy.Call{
2034 Prompt: testPrompt,
2035 })
2036 })
2037 })
2038}
2039
// streamingMockServer is a test double for a streaming (SSE) chat-completions
// endpoint. Tests seed chunks before making a request and inspect the
// recorded calls afterwards.
type streamingMockServer struct {
	server *httptest.Server // underlying HTTP test server
	chunks []string         // response chunks, in order; "HEADER:k:v" entries become response headers
	calls  []mockCall       // incoming requests recorded by the handler
}
2045
// newStreamingMockServer starts an httptest server that records each request
// and then replays sms.chunks as a server-sent-events stream. Chunks prefixed
// with "HEADER:name:value" are not written to the body; they set a response
// header instead (and must therefore be processed before WriteHeader).
func newStreamingMockServer() *streamingMockServer {
	sms := &streamingMockServer{
		calls: make([]mockCall, 0),
	}

	sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record the call
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}

		// Keep only the first value of each request header.
		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Parse request body
		// NOTE(review): the decode error is deliberately ignored — a request
		// with a malformed JSON body is recorded with a nil/partial body map
		// rather than failing the mock.
		if r.Body != nil {
			var body map[string]any
			json.NewDecoder(r.Body).Decode(&body)
			call.body = body
		}

		sms.calls = append(sms.calls, call)

		// Set streaming headers
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")

		// Add custom headers if any
		// (done in a separate pass because headers must be set before the
		// status line is written below).
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				parts := strings.SplitN(chunk[7:], ":", 2)
				if len(parts) == 2 {
					w.Header().Set(parts[0], parts[1])
				}
				continue
			}
		}

		w.WriteHeader(http.StatusOK)

		// Write chunks
		// Flush after each one so the client observes incremental delivery.
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				continue
			}
			w.Write([]byte(chunk))
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
		}
	}))

	return sms
}
2106
// close shuts down the underlying httptest server.
func (sms *streamingMockServer) close() {
	sms.server.Close()
}
2110
// prepareStreamResponse builds the SSE chunk sequence for a text completion
// stream and stores it in sms.chunks. Recognized opts keys (all optional):
// "content" ([]string of text deltas), "usage" (map), "logprobs" (map,
// attached to the finish chunk when non-empty), "finish_reason" (string,
// default "stop"), "model" (string), "annotations" ([]map, emitted as an
// extra delta after the last content chunk), and "headers"
// (map[string]string, encoded as "HEADER:k:v" pseudo-chunks).
//
// The emitted order is: initial role chunk, content chunks (plus optional
// annotations), finish chunk, usage chunk, then "data: [DONE]".
func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
	content := []string{}
	if c, ok := opts["content"].([]string); ok {
		content = c
	}

	// Default usage mirrors a typical non-trivial completion.
	usage := map[string]any{
		"prompt_tokens":     17,
		"total_tokens":      244,
		"completion_tokens": 227,
	}
	if u, ok := opts["usage"].(map[string]any); ok {
		usage = u
	}

	logprobs := map[string]any{}
	if l, ok := opts["logprobs"].(map[string]any); ok {
		logprobs = l
	}

	finishReason := "stop"
	if fr, ok := opts["finish_reason"].(string); ok {
		finishReason = fr
	}

	model := "gpt-3.5-turbo-0613"
	if m, ok := opts["model"].(string); ok {
		model = m
	}

	headers := map[string]string{}
	if h, ok := opts["headers"].(map[string]string); ok {
		headers = h
	}

	chunks := []string{}

	// Add custom headers
	// (encoded as pseudo-chunks the server handler intercepts; see
	// newStreamingMockServer).
	for k, v := range headers {
		chunks = append(chunks, "HEADER:"+k+":"+v)
	}

	// Initial chunk with role
	initialChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index": 0,
				"delta": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": nil,
			},
		},
	}
	initialData, _ := json.Marshal(initialChunk)
	chunks = append(chunks, "data: "+string(initialData)+"\n\n")

	// Content chunks
	// NOTE(review): these use choice index 1 while the initial and finish
	// chunks use index 0 — likely unintentional; confirm the stream parser
	// ignores the choice index before relying on it.
	for i, text := range content {
		contentChunk := map[string]any{
			"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
			"object":             "chat.completion.chunk",
			"created":            1702657020,
			"model":              model,
			"system_fingerprint": nil,
			"choices": []map[string]any{
				{
					"index": 1,
					"delta": map[string]any{
						"content": text,
					},
					"finish_reason": nil,
				},
			},
		}
		contentData, _ := json.Marshal(contentChunk)
		chunks = append(chunks, "data: "+string(contentData)+"\n\n")

		// Add annotations if this is the last content chunk and we have annotations
		if i == len(content)-1 {
			if annotations, ok := opts["annotations"].([]map[string]any); ok {
				annotationChunk := map[string]any{
					"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
					"object":             "chat.completion.chunk",
					"created":            1702657020,
					"model":              model,
					"system_fingerprint": nil,
					"choices": []map[string]any{
						{
							"index": 1,
							"delta": map[string]any{
								"annotations": annotations,
							},
							"finish_reason": nil,
						},
					},
				}
				annotationData, _ := json.Marshal(annotationChunk)
				chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
			}
		}
	}

	// Finish chunk
	finishChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index":         0,
				"delta":         map[string]any{},
				"finish_reason": finishReason,
			},
		},
	}

	// Logprobs, when provided, ride on the finish chunk's single choice.
	if len(logprobs) > 0 {
		finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
	}

	finishData, _ := json.Marshal(finishChunk)
	chunks = append(chunks, "data: "+string(finishData)+"\n\n")

	// Usage chunk
	// (empty choices, as OpenAI sends usage in a trailing chunk).
	usageChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": "fp_3bc1b5746c",
		"choices":            []map[string]any{},
		"usage":              usage,
	}
	usageData, _ := json.Marshal(usageChunk)
	chunks = append(chunks, "data: "+string(usageData)+"\n\n")

	// Done
	chunks = append(chunks, "data: [DONE]\n\n")

	sms.chunks = chunks
}
2261
// prepareToolStreamResponse seeds a canned SSE stream in which a single tool
// call ("test-tool", id call_O17Uplv4lJvD6DVdIvFFeRMw) arrives with its JSON
// arguments split across several deltas, assembling to
// {"value":"Sparkle Day"}. The stream finishes with reason "tool_calls",
// a usage chunk, and the [DONE] sentinel. The payloads are kept as literal
// strings so the wire format is byte-exact.
func (sms *streamingMockServer) prepareToolStreamResponse() {
	chunks := []string{
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}
2278
2279func (sms *streamingMockServer) prepareErrorStreamResponse() {
2280 chunks := []string{
2281 `data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
2282 "data: [DONE]\n\n",
2283 }
2284 sms.chunks = chunks
2285}
2286
2287func (sms *streamingMockServer) prepareToolStreamResponseWithEmptyArgs() {
2288 chunks := []string{
2289 // Tool call start with empty arguments (like Copilot sometimes does)
2290 `data: {"id":"chatcmpl-emptyargs","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_empty_args","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2291 // Finish without any argument deltas
2292 `data: {"id":"chatcmpl-emptyargs","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
2293 `data: {"id":"chatcmpl-emptyargs","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
2294 "data: [DONE]\n\n",
2295 }
2296 sms.chunks = chunks
2297}
2298
2299func (sms *streamingMockServer) prepareToolStreamResponseWithInvalidJSON() {
2300 chunks := []string{
2301 // Tool call start
2302 `data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_invalid_json","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2303 // Arguments delta containing \x00 which is not a valid JSON escape
2304 `data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"old_string\":\"hello\\x00"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2305 // Remaining arguments — combined is {"old_string":"hello\x00world"} which is invalid JSON
2306 `data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"world\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2307 // Finish with tool_calls
2308 `data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
2309 // Usage
2310 `data: {"id":"chatcmpl-invalid","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
2311 "data: [DONE]\n\n",
2312 }
2313 sms.chunks = chunks
2314}
2315
2316func collectStreamParts(stream fantasy.StreamResponse) ([]fantasy.StreamPart, error) {
2317 var parts []fantasy.StreamPart
2318 for part := range stream {
2319 parts = append(parts, part)
2320 if part.Type == fantasy.StreamPartTypeError {
2321 break
2322 }
2323 if part.Type == fantasy.StreamPartTypeFinish {
2324 break
2325 }
2326 }
2327 return parts, nil
2328}
2329
2330func TestDoStream(t *testing.T) {
2331 t.Parallel()
2332
2333 t.Run("should stream text deltas", func(t *testing.T) {
2334 t.Parallel()
2335
2336 server := newStreamingMockServer()
2337 defer server.close()
2338
2339 server.prepareStreamResponse(map[string]any{
2340 "content": []string{"Hello", ", ", "World!"},
2341 "finish_reason": "stop",
2342 "usage": map[string]any{
2343 "prompt_tokens": 17,
2344 "total_tokens": 244,
2345 "completion_tokens": 227,
2346 },
2347 "logprobs": testLogprobs,
2348 })
2349
2350 provider, err := New(
2351 WithAPIKey("test-api-key"),
2352 WithBaseURL(server.server.URL),
2353 )
2354 require.NoError(t, err)
2355 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
2356
2357 stream, err := model.Stream(context.Background(), fantasy.Call{
2358 Prompt: testPrompt,
2359 })
2360
2361 require.NoError(t, err)
2362
2363 parts, err := collectStreamParts(stream)
2364 require.NoError(t, err)
2365
2366 // Verify stream structure
2367 require.True(t, len(parts) >= 4) // text-start, deltas, text-end, finish
2368
2369 // Find text parts
2370 textStart, textEnd, finish := -1, -1, -1
2371 var deltas []string
2372
2373 for i, part := range parts {
2374 switch part.Type {
2375 case fantasy.StreamPartTypeTextStart:
2376 textStart = i
2377 case fantasy.StreamPartTypeTextDelta:
2378 deltas = append(deltas, part.Delta)
2379 case fantasy.StreamPartTypeTextEnd:
2380 textEnd = i
2381 case fantasy.StreamPartTypeFinish:
2382 finish = i
2383 }
2384 }
2385
2386 require.NotEqual(t, -1, textStart)
2387 require.NotEqual(t, -1, textEnd)
2388 require.NotEqual(t, -1, finish)
2389 require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)
2390
2391 // Check finish part
2392 finishPart := parts[finish]
2393 require.Equal(t, fantasy.FinishReasonStop, finishPart.FinishReason)
2394 require.Equal(t, int64(17), finishPart.Usage.InputTokens)
2395 require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
2396 require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
2397 })
2398
2399 t.Run("should stream tool deltas", func(t *testing.T) {
2400 t.Parallel()
2401
2402 server := newStreamingMockServer()
2403 defer server.close()
2404
2405 server.prepareToolStreamResponse()
2406
2407 provider, err := New(
2408 WithAPIKey("test-api-key"),
2409 WithBaseURL(server.server.URL),
2410 )
2411 require.NoError(t, err)
2412 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
2413
2414 stream, err := model.Stream(context.Background(), fantasy.Call{
2415 Prompt: testPrompt,
2416 Tools: []fantasy.Tool{
2417 fantasy.FunctionTool{
2418 Name: "test-tool",
2419 InputSchema: map[string]any{
2420 "type": "object",
2421 "properties": map[string]any{
2422 "value": map[string]any{
2423 "type": "string",
2424 },
2425 },
2426 "required": []string{"value"},
2427 "additionalProperties": false,
2428 "$schema": "http://json-schema.org/draft-07/schema#",
2429 },
2430 },
2431 },
2432 })
2433
2434 require.NoError(t, err)
2435
2436 parts, err := collectStreamParts(stream)
2437 require.NoError(t, err)
2438
2439 // Find tool-related parts
2440 toolInputStart, toolInputEnd, toolCall := -1, -1, -1
2441 var toolDeltas []string
2442
2443 for i, part := range parts {
2444 switch part.Type {
2445 case fantasy.StreamPartTypeToolInputStart:
2446 toolInputStart = i
2447 require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
2448 require.Equal(t, "test-tool", part.ToolCallName)
2449 case fantasy.StreamPartTypeToolInputDelta:
2450 toolDeltas = append(toolDeltas, part.Delta)
2451 case fantasy.StreamPartTypeToolInputEnd:
2452 toolInputEnd = i
2453 case fantasy.StreamPartTypeToolCall:
2454 toolCall = i
2455 require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
2456 require.Equal(t, "test-tool", part.ToolCallName)
2457 require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
2458 }
2459 }
2460
2461 require.NotEqual(t, -1, toolInputStart)
2462 require.NotEqual(t, -1, toolInputEnd)
2463 require.NotEqual(t, -1, toolCall)
2464
2465 // Verify tool deltas combine to form the complete input
2466 var fullInput strings.Builder
2467 for _, delta := range toolDeltas {
2468 fullInput.WriteString(delta)
2469 }
2470 require.Equal(t, `{"value":"Sparkle Day"}`, fullInput.String())
2471 })
2472
2473 t.Run("should handle tool calls with empty arguments", func(t *testing.T) {
2474 t.Parallel()
2475
2476 server := newStreamingMockServer()
2477 defer server.close()
2478
2479 server.prepareToolStreamResponseWithEmptyArgs()
2480
2481 provider, err := New(
2482 WithAPIKey("test-api-key"),
2483 WithBaseURL(server.server.URL),
2484 )
2485 require.NoError(t, err)
2486 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
2487
2488 stream, err := model.Stream(context.Background(), fantasy.Call{
2489 Prompt: testPrompt,
2490 Tools: []fantasy.Tool{
2491 fantasy.FunctionTool{
2492 Name: "test-tool",
2493 InputSchema: map[string]any{
2494 "type": "object",
2495 "properties": map[string]any{
2496 "value": map[string]any{
2497 "type": "string",
2498 },
2499 },
2500 "required": []string{"value"},
2501 "additionalProperties": false,
2502 "$schema": "http://json-schema.org/draft-07/schema#",
2503 },
2504 },
2505 },
2506 })
2507
2508 require.NoError(t, err)
2509
2510 parts, err := collectStreamParts(stream)
2511 require.NoError(t, err)
2512
2513 // Find tool-related parts
2514 toolInputStart, toolInputEnd, toolCall := -1, -1, -1
2515
2516 for i, part := range parts {
2517 switch part.Type {
2518 case fantasy.StreamPartTypeToolInputStart:
2519 toolInputStart = i
2520 require.Equal(t, "call_empty_args", part.ID)
2521 require.Equal(t, "test-tool", part.ToolCallName)
2522 case fantasy.StreamPartTypeToolInputEnd:
2523 toolInputEnd = i
2524 require.Equal(t, "call_empty_args", part.ID)
2525 case fantasy.StreamPartTypeToolCall:
2526 toolCall = i
2527 require.Equal(t, "call_empty_args", part.ID)
2528 require.Equal(t, "test-tool", part.ToolCallName)
2529 // Empty arguments should be normalized to "{}"
2530 require.Equal(t, "{}", part.ToolCallInput)
2531 }
2532 }
2533
2534 require.NotEqual(t, -1, toolInputStart, "expected ToolInputStart part")
2535 require.NotEqual(t, -1, toolInputEnd, "expected ToolInputEnd part")
2536 require.NotEqual(t, -1, toolCall, "expected ToolCall part")
2537 })
2538
2539 t.Run("should stream annotations/citations", func(t *testing.T) {
2540 t.Parallel()
2541
2542 server := newStreamingMockServer()
2543 defer server.close()
2544
2545 server.prepareStreamResponse(map[string]any{
2546 "content": []string{"Based on search results"},
2547 "annotations": []map[string]any{
2548 {
2549 "type": "url_citation",
2550 "url_citation": map[string]any{
2551 "start_index": 24,
2552 "end_index": 29,
2553 "url": "https://example.com/doc1.pdf",
2554 "title": "Document 1",
2555 },
2556 },
2557 },
2558 })
2559
2560 provider, err := New(
2561 WithAPIKey("test-api-key"),
2562 WithBaseURL(server.server.URL),
2563 )
2564 require.NoError(t, err)
2565 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
2566
2567 stream, err := model.Stream(context.Background(), fantasy.Call{
2568 Prompt: testPrompt,
2569 })
2570
2571 require.NoError(t, err)
2572
2573 parts, err := collectStreamParts(stream)
2574 require.NoError(t, err)
2575
2576 // Find source part
2577 var sourcePart *fantasy.StreamPart
2578 for _, part := range parts {
2579 if part.Type == fantasy.StreamPartTypeSource {
2580 sourcePart = &part
2581 break
2582 }
2583 }
2584
2585 require.NotNil(t, sourcePart)
2586 require.Equal(t, fantasy.SourceTypeURL, sourcePart.SourceType)
2587 require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
2588 require.Equal(t, "Document 1", sourcePart.Title)
2589 require.NotEmpty(t, sourcePart.ID)
2590 })
2591
2592 t.Run("should handle error stream parts", func(t *testing.T) {
2593 t.Parallel()
2594
2595 server := newStreamingMockServer()
2596 defer server.close()
2597
2598 server.prepareErrorStreamResponse()
2599
2600 provider, err := New(
2601 WithAPIKey("test-api-key"),
2602 WithBaseURL(server.server.URL),
2603 )
2604 require.NoError(t, err)
2605 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
2606
2607 stream, err := model.Stream(context.Background(), fantasy.Call{
2608 Prompt: testPrompt,
2609 })
2610
2611 require.NoError(t, err)
2612
2613 parts, err := collectStreamParts(stream)
2614 require.NoError(t, err)
2615
2616 // Should have error and finish parts
2617 require.True(t, len(parts) >= 1)
2618
2619 // Find error part
2620 var errorPart *fantasy.StreamPart
2621 for _, part := range parts {
2622 if part.Type == fantasy.StreamPartTypeError {
2623 errorPart = &part
2624 break
2625 }
2626 }
2627
2628 require.NotNil(t, errorPart)
2629 require.NotNil(t, errorPart.Error)
2630 })
2631
2632 t.Run("should send request body", func(t *testing.T) {
2633 t.Parallel()
2634
2635 server := newStreamingMockServer()
2636 defer server.close()
2637
2638 server.prepareStreamResponse(map[string]any{
2639 "content": []string{},
2640 })
2641
2642 provider, err := New(
2643 WithAPIKey("test-api-key"),
2644 WithBaseURL(server.server.URL),
2645 )
2646 require.NoError(t, err)
2647 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
2648
2649 _, err = model.Stream(context.Background(), fantasy.Call{
2650 Prompt: testPrompt,
2651 })
2652
2653 require.NoError(t, err)
2654 require.Len(t, server.calls, 1)
2655
2656 call := server.calls[0]
2657 require.Equal(t, "POST", call.method)
2658 require.Equal(t, "/chat/completions", call.path)
2659 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
2660 require.Equal(t, true, call.body["stream"])
2661
2662 streamOptions := call.body["stream_options"].(map[string]any)
2663 require.Equal(t, true, streamOptions["include_usage"])
2664
2665 messages := call.body["messages"].([]any)
2666 require.Len(t, messages, 1)
2667
2668 message := messages[0].(map[string]any)
2669 require.Equal(t, "user", message["role"])
2670 require.Equal(t, "Hello", message["content"])
2671 })
2672
2673 t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
2674 t.Parallel()
2675
2676 server := newStreamingMockServer()
2677 defer server.close()
2678
2679 server.prepareStreamResponse(map[string]any{
2680 "content": []string{},
2681 "usage": map[string]any{
2682 "prompt_tokens": 15,
2683 "completion_tokens": 20,
2684 "total_tokens": 35,
2685 "prompt_tokens_details": map[string]any{
2686 "cached_tokens": 1152,
2687 },
2688 },
2689 })
2690
2691 provider, err := New(
2692 WithAPIKey("test-api-key"),
2693 WithBaseURL(server.server.URL),
2694 )
2695 require.NoError(t, err)
2696 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
2697
2698 stream, err := model.Stream(context.Background(), fantasy.Call{
2699 Prompt: testPrompt,
2700 })
2701
2702 require.NoError(t, err)
2703
2704 parts, err := collectStreamParts(stream)
2705 require.NoError(t, err)
2706
2707 // Find finish part
2708 var finishPart *fantasy.StreamPart
2709 for _, part := range parts {
2710 if part.Type == fantasy.StreamPartTypeFinish {
2711 finishPart = &part
2712 break
2713 }
2714 }
2715
2716 require.NotNil(t, finishPart)
2717 require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
2718 // InputTokens = prompt_tokens - cached_tokens = 15 - 1152 = -1137 → clamped to 0
2719 require.Equal(t, int64(0), finishPart.Usage.InputTokens)
2720 require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
2721 require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
2722 })
2723
2724 t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
2725 t.Parallel()
2726
2727 server := newStreamingMockServer()
2728 defer server.close()
2729
2730 server.prepareStreamResponse(map[string]any{
2731 "content": []string{},
2732 "usage": map[string]any{
2733 "prompt_tokens": 15,
2734 "completion_tokens": 20,
2735 "total_tokens": 35,
2736 "completion_tokens_details": map[string]any{
2737 "accepted_prediction_tokens": 123,
2738 "rejected_prediction_tokens": 456,
2739 },
2740 },
2741 })
2742
2743 provider, err := New(
2744 WithAPIKey("test-api-key"),
2745 WithBaseURL(server.server.URL),
2746 )
2747 require.NoError(t, err)
2748 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
2749
2750 stream, err := model.Stream(context.Background(), fantasy.Call{
2751 Prompt: testPrompt,
2752 })
2753
2754 require.NoError(t, err)
2755
2756 parts, err := collectStreamParts(stream)
2757 require.NoError(t, err)
2758
2759 // Find finish part
2760 var finishPart *fantasy.StreamPart
2761 for _, part := range parts {
2762 if part.Type == fantasy.StreamPartTypeFinish {
2763 finishPart = &part
2764 break
2765 }
2766 }
2767
2768 require.NotNil(t, finishPart)
2769 require.NotNil(t, finishPart.ProviderMetadata)
2770
2771 openaiMeta, ok := finishPart.ProviderMetadata["openai"].(*ProviderMetadata)
2772 require.True(t, ok)
2773 require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
2774 require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
2775 })
2776
2777 t.Run("should send store extension setting", func(t *testing.T) {
2778 t.Parallel()
2779
2780 server := newStreamingMockServer()
2781 defer server.close()
2782
2783 server.prepareStreamResponse(map[string]any{
2784 "content": []string{},
2785 })
2786
2787 provider, err := New(
2788 WithAPIKey("test-api-key"),
2789 WithBaseURL(server.server.URL),
2790 )
2791 require.NoError(t, err)
2792 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
2793
2794 _, err = model.Stream(context.Background(), fantasy.Call{
2795 Prompt: testPrompt,
2796 ProviderOptions: NewProviderOptions(&ProviderOptions{
2797 Store: new(true),
2798 }),
2799 })
2800
2801 require.NoError(t, err)
2802 require.Len(t, server.calls, 1)
2803
2804 call := server.calls[0]
2805 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
2806 require.Equal(t, true, call.body["stream"])
2807 require.Equal(t, true, call.body["store"])
2808
2809 streamOptions := call.body["stream_options"].(map[string]any)
2810 require.Equal(t, true, streamOptions["include_usage"])
2811
2812 messages := call.body["messages"].([]any)
2813 require.Len(t, messages, 1)
2814
2815 message := messages[0].(map[string]any)
2816 require.Equal(t, "user", message["role"])
2817 require.Equal(t, "Hello", message["content"])
2818 })
2819
2820 t.Run("should send metadata extension values", func(t *testing.T) {
2821 t.Parallel()
2822
2823 server := newStreamingMockServer()
2824 defer server.close()
2825
2826 server.prepareStreamResponse(map[string]any{
2827 "content": []string{},
2828 })
2829
2830 provider, err := New(
2831 WithAPIKey("test-api-key"),
2832 WithBaseURL(server.server.URL),
2833 )
2834 require.NoError(t, err)
2835 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
2836
2837 _, err = model.Stream(context.Background(), fantasy.Call{
2838 Prompt: testPrompt,
2839 ProviderOptions: NewProviderOptions(&ProviderOptions{
2840 Metadata: map[string]any{
2841 "custom": "value",
2842 },
2843 }),
2844 })
2845
2846 require.NoError(t, err)
2847 require.Len(t, server.calls, 1)
2848
2849 call := server.calls[0]
2850 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
2851 require.Equal(t, true, call.body["stream"])
2852
2853 metadata := call.body["metadata"].(map[string]any)
2854 require.Equal(t, "value", metadata["custom"])
2855
2856 streamOptions := call.body["stream_options"].(map[string]any)
2857 require.Equal(t, true, streamOptions["include_usage"])
2858
2859 messages := call.body["messages"].([]any)
2860 require.Len(t, messages, 1)
2861
2862 message := messages[0].(map[string]any)
2863 require.Equal(t, "user", message["role"])
2864 require.Equal(t, "Hello", message["content"])
2865 })
2866
2867 t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
2868 t.Parallel()
2869
2870 server := newStreamingMockServer()
2871 defer server.close()
2872
2873 server.prepareStreamResponse(map[string]any{
2874 "content": []string{},
2875 })
2876
2877 provider, err := New(
2878 WithAPIKey("test-api-key"),
2879 WithBaseURL(server.server.URL),
2880 )
2881 require.NoError(t, err)
2882 model, _ := provider.LanguageModel(t.Context(), "o3-mini")
2883
2884 _, err = model.Stream(context.Background(), fantasy.Call{
2885 Prompt: testPrompt,
2886 ProviderOptions: NewProviderOptions(&ProviderOptions{
2887 ServiceTier: new("flex"),
2888 }),
2889 })
2890
2891 require.NoError(t, err)
2892 require.Len(t, server.calls, 1)
2893
2894 call := server.calls[0]
2895 require.Equal(t, "o3-mini", call.body["model"])
2896 require.Equal(t, "flex", call.body["service_tier"])
2897 require.Equal(t, true, call.body["stream"])
2898
2899 streamOptions := call.body["stream_options"].(map[string]any)
2900 require.Equal(t, true, streamOptions["include_usage"])
2901
2902 messages := call.body["messages"].([]any)
2903 require.Len(t, messages, 1)
2904
2905 message := messages[0].(map[string]any)
2906 require.Equal(t, "user", message["role"])
2907 require.Equal(t, "Hello", message["content"])
2908 })
2909
2910 t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
2911 t.Parallel()
2912
2913 server := newStreamingMockServer()
2914 defer server.close()
2915
2916 server.prepareStreamResponse(map[string]any{
2917 "content": []string{},
2918 })
2919
2920 provider, err := New(
2921 WithAPIKey("test-api-key"),
2922 WithBaseURL(server.server.URL),
2923 )
2924 require.NoError(t, err)
2925 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
2926
2927 _, err = model.Stream(context.Background(), fantasy.Call{
2928 Prompt: testPrompt,
2929 ProviderOptions: NewProviderOptions(&ProviderOptions{
2930 ServiceTier: new("priority"),
2931 }),
2932 })
2933
2934 require.NoError(t, err)
2935 require.Len(t, server.calls, 1)
2936
2937 call := server.calls[0]
2938 require.Equal(t, "gpt-4o-mini", call.body["model"])
2939 require.Equal(t, "priority", call.body["service_tier"])
2940 require.Equal(t, true, call.body["stream"])
2941
2942 streamOptions := call.body["stream_options"].(map[string]any)
2943 require.Equal(t, true, streamOptions["include_usage"])
2944
2945 messages := call.body["messages"].([]any)
2946 require.Len(t, messages, 1)
2947
2948 message := messages[0].(map[string]any)
2949 require.Equal(t, "user", message["role"])
2950 require.Equal(t, "Hello", message["content"])
2951 })
2952
2953 t.Run("should stream text delta for reasoning models", func(t *testing.T) {
2954 t.Parallel()
2955
2956 server := newStreamingMockServer()
2957 defer server.close()
2958
2959 server.prepareStreamResponse(map[string]any{
2960 "content": []string{"Hello, World!"},
2961 "model": "o1-preview",
2962 })
2963
2964 provider, err := New(
2965 WithAPIKey("test-api-key"),
2966 WithBaseURL(server.server.URL),
2967 )
2968 require.NoError(t, err)
2969 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
2970
2971 stream, err := model.Stream(context.Background(), fantasy.Call{
2972 Prompt: testPrompt,
2973 })
2974
2975 require.NoError(t, err)
2976
2977 parts, err := collectStreamParts(stream)
2978 require.NoError(t, err)
2979
2980 // Find text parts
2981 var textDeltas []string
2982 for _, part := range parts {
2983 if part.Type == fantasy.StreamPartTypeTextDelta {
2984 textDeltas = append(textDeltas, part.Delta)
2985 }
2986 }
2987
2988 // Should contain the text content (without empty delta)
2989 require.Equal(t, []string{"Hello, World!"}, textDeltas)
2990 })
2991
2992 t.Run("should send reasoning tokens", func(t *testing.T) {
2993 t.Parallel()
2994
2995 server := newStreamingMockServer()
2996 defer server.close()
2997
2998 server.prepareStreamResponse(map[string]any{
2999 "content": []string{"Hello, World!"},
3000 "model": "o1-preview",
3001 "usage": map[string]any{
3002 "prompt_tokens": 15,
3003 "completion_tokens": 20,
3004 "total_tokens": 35,
3005 "completion_tokens_details": map[string]any{
3006 "reasoning_tokens": 10,
3007 },
3008 },
3009 })
3010
3011 provider, err := New(
3012 WithAPIKey("test-api-key"),
3013 WithBaseURL(server.server.URL),
3014 )
3015 require.NoError(t, err)
3016 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
3017
3018 stream, err := model.Stream(context.Background(), fantasy.Call{
3019 Prompt: testPrompt,
3020 })
3021
3022 require.NoError(t, err)
3023
3024 parts, err := collectStreamParts(stream)
3025 require.NoError(t, err)
3026
3027 // Find finish part
3028 var finishPart *fantasy.StreamPart
3029 for _, part := range parts {
3030 if part.Type == fantasy.StreamPartTypeFinish {
3031 finishPart = &part
3032 break
3033 }
3034 }
3035
3036 require.NotNil(t, finishPart)
3037 require.Equal(t, int64(15), finishPart.Usage.InputTokens)
3038 require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
3039 require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
3040 require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
3041 })
3042
3043 t.Run("should drop tool calls with invalid JSON arguments", func(t *testing.T) {
3044 t.Parallel()
3045
3046 server := newStreamingMockServer()
3047 defer server.close()
3048
3049 server.prepareToolStreamResponseWithInvalidJSON()
3050
3051 provider, err := New(
3052 WithAPIKey("test-api-key"),
3053 WithBaseURL(server.server.URL),
3054 )
3055 require.NoError(t, err)
3056 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
3057
3058 stream, err := model.Stream(context.Background(), fantasy.Call{
3059 Prompt: testPrompt,
3060 Tools: []fantasy.Tool{
3061 fantasy.FunctionTool{
3062 Name: "test-tool",
3063 InputSchema: map[string]any{
3064 "type": "object",
3065 "properties": map[string]any{
3066 "old_string": map[string]any{
3067 "type": "string",
3068 },
3069 "new_string": map[string]any{
3070 "type": "string",
3071 },
3072 },
3073 "required": []string{"old_string", "new_string"},
3074 "additionalProperties": false,
3075 "$schema": "http://json-schema.org/draft-07/schema#",
3076 },
3077 },
3078 },
3079 })
3080
3081 require.NoError(t, err)
3082
3083 parts, err := collectStreamParts(stream)
3084 require.NoError(t, err)
3085
3086 // Find tool-related parts
3087 toolInputStart, toolInputEnd, toolCall := -1, -1, -1
3088 var toolDeltas []string
3089 var finishPart *fantasy.StreamPart
3090
3091 for i, part := range parts {
3092 switch part.Type {
3093 case fantasy.StreamPartTypeToolInputStart:
3094 toolInputStart = i
3095 require.Equal(t, "call_invalid_json", part.ID)
3096 require.Equal(t, "test-tool", part.ToolCallName)
3097 case fantasy.StreamPartTypeToolInputDelta:
3098 toolDeltas = append(toolDeltas, part.Delta)
3099 case fantasy.StreamPartTypeToolInputEnd:
3100 toolInputEnd = i
3101 require.Equal(t, "call_invalid_json", part.ID)
3102 case fantasy.StreamPartTypeToolCall:
3103 toolCall = i
3104 require.Equal(t, "call_invalid_json", part.ID)
3105 require.Equal(t, "test-tool", part.ToolCallName)
3106 case fantasy.StreamPartTypeFinish:
3107 finishPart = &part
3108 }
3109 }
3110
3111 require.NotEqual(t, -1, toolInputStart, "expected ToolInputStart part")
3112 require.NotEqual(t, -1, toolInputEnd, "expected ToolInputEnd part")
3113 require.NotEqual(t, -1, toolCall, "expected ToolCall part")
3114
3115 // Verify tool deltas combine to the complete input with \x00
3116 var fullInput strings.Builder
3117 for _, delta := range toolDeltas {
3118 fullInput.WriteString(delta)
3119 }
3120 require.Equal(t, `{"old_string":"hello\x00world"}`, fullInput.String())
3121
3122 // Finish reason is ToolCalls since the tool call was yielded
3123 require.NotNil(t, finishPart)
3124 require.Equal(t, fantasy.FinishReasonToolCalls, finishPart.FinishReason)
3125 })
3126}
3127
// TestDefaultToPrompt_DropsEmptyMessages verifies that DefaultToPrompt drops
// messages with no usable content (emitting a warning for each drop) while
// keeping messages that carry text, tool calls, images, or tool results.
func TestDefaultToPrompt_DropsEmptyMessages(t *testing.T) {
	t.Parallel()

	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
		t.Parallel()

		// Assistant message with an empty Content slice — nothing to send.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role:    fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	t.Run("should keep assistant messages with text content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
		t.Parallel()

		// A tool call alone (no text) is still meaningful content.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "What's the weather?"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName:   "get_weather",
						Input:      `{"location":"NYC"}`,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should drop user messages without visible content", func(t *testing.T) {
		t.Parallel()

		// The only part has an unsupported media type, so after it is skipped
		// the message is effectively empty and should be dropped.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte("not supported"),
						MediaType: "application/unknown",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Empty(t, messages)
		require.Len(t, warnings, 2) // One for unsupported type, one for empty message
		require.Contains(t, warnings[1].Message, "dropping empty user message")
	})

	t.Run("should keep user messages with image content", func(t *testing.T) {
		t.Parallel()

		// An image part counts as visible content even without any text.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte{0x01, 0x02, 0x03},
						MediaType: "image/png",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_123",
						Output:     fantasy.ToolResultOutputContentText{Text: "done"},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool error results", func(t *testing.T) {
		t.Parallel()

		// Error outputs are still tool results and must be forwarded.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_456",
						Output:     fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})
}
3292
// TestResponsesToPrompt_DropsEmptyMessages verifies the message-filtering
// behavior of toResponsesPrompt: messages with no visible content are
// dropped (with a CallWarning), while messages carrying text, images,
// tool calls, or tool results are kept.
func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) {
	t.Parallel()

	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			// Assistant message with zero parts; should be dropped.
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system", false)

		require.Len(t, input, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	t.Run("should keep assistant messages with text content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system", false)

		require.Len(t, input, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
		t.Parallel()

		// A tool call counts as content even without any text part.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "What's the weather?"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName:   "get_weather",
						Input:      `{"location":"NYC"}`,
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system", false)

		require.Len(t, input, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should drop user messages without visible content", func(t *testing.T) {
		t.Parallel()

		// Unsupported file type is discarded, leaving the message empty.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte("not supported"),
						MediaType: "application/unknown",
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system", false)

		require.Empty(t, input)
		require.Len(t, warnings, 2) // One for unsupported type, one for empty message
		require.Contains(t, warnings[1].Message, "dropping empty user message")
	})

	t.Run("should keep user messages with image content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte{0x01, 0x02, 0x03},
						MediaType: "image/png",
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system", false)

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool results", func(t *testing.T) {
		t.Parallel()

		// NOTE: tool results arrive on a tool-role message, which the
		// converter treats as user-side input.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_123",
						Output:     fantasy.ToolResultOutputContentText{Text: "done"},
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system", false)

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool error results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_456",
						Output:     fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system", false)

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})
}
3457
3458func TestParseContextTooLargeError(t *testing.T) {
3459 t.Parallel()
3460
3461 tests := []struct {
3462 name string
3463 message string
3464 wantErr bool
3465 wantUsed int
3466 wantMax int
3467 }{
3468 {
3469 name: "matches openai format with resulted in",
3470 message: "This model's maximum context length is 128000 tokens. However, your messages resulted in 150000 tokens.",
3471 wantErr: true,
3472 wantUsed: 150000,
3473 wantMax: 128000,
3474 },
3475 {
3476 name: "matches openai format with requested",
3477 message: "maximum context length is 8192 tokens, however you requested 10000 tokens",
3478 wantErr: true,
3479 wantUsed: 10000,
3480 wantMax: 8192,
3481 },
3482 {
3483 name: "does not match unrelated error",
3484 message: "invalid api key",
3485 wantErr: false,
3486 },
3487 {
3488 name: "does not match rate limit error",
3489 message: "rate limit exceeded",
3490 wantErr: false,
3491 },
3492 }
3493
3494 for _, tt := range tests {
3495 t.Run(tt.name, func(t *testing.T) {
3496 t.Parallel()
3497 providerErr := &fantasy.ProviderError{Message: tt.message}
3498 parseContextTooLargeError(tt.message, providerErr)
3499
3500 if tt.wantErr {
3501 require.True(t, providerErr.IsContextTooLarge())
3502 if tt.wantUsed > 0 {
3503 require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens)
3504 require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens)
3505 }
3506 } else {
3507 require.False(t, providerErr.IsContextTooLarge())
3508 }
3509 })
3510 }
3511}
3512
// TestUserAgent verifies the precedence chain for the User-Agent header
// sent by the provider: call-level / agent-level UA beats the provider's
// WithUserAgent, which beats WithHeaders, which beats the built-in default.
func TestUserAgent(t *testing.T) {
	t.Parallel()

	t.Run("default UA applied", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()
		server.prepareJSONResponse(map[string]any{})

		p, err := New(WithAPIKey("k"), WithBaseURL(server.server.URL))
		require.NoError(t, err)
		model, _ := p.LanguageModel(t.Context(), "gpt-4")
		_, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})

		require.Len(t, server.calls, 1)
		// With no overrides the library-default UA is used.
		assert.Equal(t, "Charm-Fantasy/"+fantasy.Version+" (https://charm.land/fantasy)", server.calls[0].headers["User-Agent"])
	})

	t.Run("WithHeaders User-Agent wins over default", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()
		server.prepareJSONResponse(map[string]any{})

		p, err := New(WithAPIKey("k"), WithBaseURL(server.server.URL), WithHeaders(map[string]string{"User-Agent": "custom-from-headers"}))
		require.NoError(t, err)
		model, _ := p.LanguageModel(t.Context(), "gpt-4")
		_, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})

		require.Len(t, server.calls, 1)
		assert.Equal(t, "custom-from-headers", server.calls[0].headers["User-Agent"])
	})

	t.Run("WithUserAgent wins over both", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()
		server.prepareJSONResponse(map[string]any{})

		// WithUserAgent takes precedence over a UA set via WithHeaders.
		p, err := New(
			WithAPIKey("k"),
			WithBaseURL(server.server.URL),
			WithHeaders(map[string]string{"User-Agent": "from-headers"}),
			WithUserAgent("explicit-ua"),
		)
		require.NoError(t, err)
		model, _ := p.LanguageModel(t.Context(), "gpt-4")
		_, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})

		require.Len(t, server.calls, 1)
		assert.Equal(t, "explicit-ua", server.calls[0].headers["User-Agent"])
	})

	t.Run("Call.UserAgent overrides provider WithHeaders UA", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()
		server.prepareJSONResponse(map[string]any{})

		p, err := New(
			WithAPIKey("k"),
			WithBaseURL(server.server.URL),
			WithHeaders(map[string]string{"User-Agent": "header-ua"}),
		)
		require.NoError(t, err)
		model, _ := p.LanguageModel(t.Context(), "gpt-4")
		// Per-call UA is the highest-precedence override.
		_, _ = model.Generate(t.Context(), fantasy.Call{
			Prompt:    testPrompt,
			UserAgent: "call-level-ua",
		})

		require.Len(t, server.calls, 1)
		assert.Equal(t, "call-level-ua", server.calls[0].headers["User-Agent"])
	})

	t.Run("no Call UA falls through to provider UA", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()
		server.prepareJSONResponse(map[string]any{})

		p, err := New(
			WithAPIKey("k"),
			WithBaseURL(server.server.URL),
			WithUserAgent("provider-ua"),
		)
		require.NoError(t, err)
		model, _ := p.LanguageModel(t.Context(), "gpt-4")
		_, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})

		require.Len(t, server.calls, 1)
		assert.Equal(t, "provider-ua", server.calls[0].headers["User-Agent"])
	})

	t.Run("agent WithUserAgent overrides provider UA end-to-end", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()
		server.prepareJSONResponse(map[string]any{})

		p, err := New(
			WithAPIKey("k"),
			WithBaseURL(server.server.URL),
			WithUserAgent("provider-ua"),
		)
		require.NoError(t, err)
		model, _ := p.LanguageModel(t.Context(), "gpt-4")

		// Exercise the UA plumbing through the agent layer as well.
		agent := fantasy.NewAgent(model, fantasy.WithUserAgent("agent-ua"))
		_, _ = agent.Generate(t.Context(), fantasy.AgentCall{Prompt: "hi"})

		require.Len(t, server.calls, 1)
		assert.Equal(t, "agent-ua", server.calls[0].headers["User-Agent"])
	})

	t.Run("agent without UA falls through to provider UA end-to-end", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()
		server.prepareJSONResponse(map[string]any{})

		p, err := New(
			WithAPIKey("k"),
			WithBaseURL(server.server.URL),
			WithUserAgent("provider-ua"),
		)
		require.NoError(t, err)
		model, _ := p.LanguageModel(t.Context(), "gpt-4")

		agent := fantasy.NewAgent(model)
		_, _ = agent.Generate(t.Context(), fantasy.AgentCall{Prompt: "hi"})

		require.Len(t, server.calls, 1)
		assert.Equal(t, "provider-ua", server.calls[0].headers["User-Agent"])
	})
}
3656
3657// --- OpenAI Responses API Web Search Tests ---
3658
// mockResponsesWebSearchResponse returns a canned Responses API response
// containing a completed web_search_call output item followed by an
// assistant message whose output_text carries two url_citation
// annotations. Shared by the web-search Generate tests below.
func mockResponsesWebSearchResponse() map[string]any {
	return map[string]any{
		"id": "resp_01WebSearch",
		"object": "response",
		"model": "gpt-4.1",
		"output": []any{
			// Provider-executed web search call with its search action.
			map[string]any{
				"type": "web_search_call",
				"id": "ws_01",
				"status": "completed",
				"action": map[string]any{
					"type": "search",
					"query": "latest AI news",
				},
			},
			// Assistant answer citing two URLs via annotations.
			map[string]any{
				"type": "message",
				"id": "msg_01",
				"role": "assistant",
				"status": "completed",
				"content": []any{
					map[string]any{
						"type": "output_text",
						"text": "Based on recent search results, here is the latest AI news.",
						"annotations": []any{
							map[string]any{
								"type": "url_citation",
								"url": "https://example.com/ai-news",
								"title": "Latest AI News",
								"start_index": 0,
								"end_index": 50,
							},
							map[string]any{
								"type": "url_citation",
								"url": "https://example.com/ml-update",
								"title": "ML Update",
								"start_index": 51,
								"end_index": 60,
							},
						},
					},
				},
			},
		},
		"status": "completed",
		"usage": map[string]any{
			"input_tokens": 100,
			"output_tokens": 50,
			"total_tokens": 150,
		},
	}
}
3714
3715func newResponsesProvider(t *testing.T, serverURL string) fantasy.LanguageModel {
3716 t.Helper()
3717 provider, err := New(
3718 WithAPIKey("test-api-key"),
3719 WithBaseURL(serverURL),
3720 WithUseResponsesAPI(),
3721 )
3722 require.NoError(t, err)
3723 model, err := provider.LanguageModel(context.Background(), "gpt-4.1")
3724 require.NoError(t, err)
3725 return model
3726}
3727
3728func TestResponsesGenerate_WebSearchResponse(t *testing.T) {
3729 t.Parallel()
3730
3731 server := newMockServer()
3732 defer server.close()
3733 server.response = mockResponsesWebSearchResponse()
3734
3735 model := newResponsesProvider(t, server.server.URL)
3736
3737 resp, err := model.Generate(context.Background(), fantasy.Call{
3738 Prompt: testPrompt,
3739 Tools: []fantasy.Tool{WebSearchTool(nil)},
3740 })
3741 require.NoError(t, err)
3742
3743 require.Equal(t, "POST", server.calls[0].method)
3744 require.Equal(t, "/responses", server.calls[0].path)
3745
3746 var (
3747 toolCalls []fantasy.ToolCallContent
3748 sources []fantasy.SourceContent
3749 toolResults []fantasy.ToolResultContent
3750 texts []fantasy.TextContent
3751 )
3752 for _, c := range resp.Content {
3753 switch v := c.(type) {
3754 case fantasy.ToolCallContent:
3755 toolCalls = append(toolCalls, v)
3756 case fantasy.SourceContent:
3757 sources = append(sources, v)
3758 case fantasy.ToolResultContent:
3759 toolResults = append(toolResults, v)
3760 case fantasy.TextContent:
3761 texts = append(texts, v)
3762 }
3763 }
3764
3765 // ToolCallContent for the provider-executed web_search.
3766 require.Len(t, toolCalls, 1)
3767 require.True(t, toolCalls[0].ProviderExecuted)
3768 require.Equal(t, "web_search", toolCalls[0].ToolName)
3769 require.Equal(t, "ws_01", toolCalls[0].ToolCallID)
3770
3771 // SourceContent entries from url_citation annotations.
3772 require.Len(t, sources, 2)
3773 require.Equal(t, "https://example.com/ai-news", sources[0].URL)
3774 require.Equal(t, "Latest AI News", sources[0].Title)
3775 require.Equal(t, fantasy.SourceTypeURL, sources[0].SourceType)
3776 require.Equal(t, "https://example.com/ml-update", sources[1].URL)
3777 require.Equal(t, "ML Update", sources[1].Title)
3778
3779 // ToolResultContent with provider metadata.
3780 require.Len(t, toolResults, 1)
3781 require.True(t, toolResults[0].ProviderExecuted)
3782 require.Equal(t, "web_search", toolResults[0].ToolName)
3783 require.Equal(t, "ws_01", toolResults[0].ToolCallID)
3784
3785 metaVal, ok := toolResults[0].ProviderMetadata[Name]
3786 require.True(t, ok, "providerMetadata should contain openai key")
3787 wsMeta, ok := metaVal.(*WebSearchCallMetadata)
3788 require.True(t, ok, "metadata should be *WebSearchCallMetadata")
3789 require.Equal(t, "ws_01", wsMeta.ItemID)
3790 require.NotNil(t, wsMeta.Action)
3791 require.Equal(t, "search", wsMeta.Action.Type)
3792 require.Equal(t, "latest AI news", wsMeta.Action.Query)
3793
3794 // TextContent with the final answer.
3795 require.Len(t, texts, 1)
3796 require.Equal(t,
3797 "Based on recent search results, here is the latest AI news.",
3798 texts[0].Text,
3799 )
3800}
3801
3802func TestResponsesGenerate_StoreOption(t *testing.T) {
3803 t.Parallel()
3804
3805 server := newMockServer()
3806 defer server.close()
3807 server.response = mockResponsesWebSearchResponse()
3808
3809 model := newResponsesProvider(t, server.server.URL)
3810
3811 _, err := model.Generate(context.Background(), fantasy.Call{
3812 Prompt: testPrompt,
3813 ProviderOptions: fantasy.ProviderOptions{
3814 Name: &ResponsesProviderOptions{
3815 Store: new(true),
3816 },
3817 },
3818 })
3819 require.NoError(t, err)
3820
3821 require.Equal(t, "POST", server.calls[0].method)
3822 require.Equal(t, "/responses", server.calls[0].path)
3823 require.Equal(t, true, server.calls[0].body["store"])
3824}
3825
3826func TestResponsesGenerate_PreviousResponseIDOption(t *testing.T) {
3827 t.Parallel()
3828
3829 server := newMockServer()
3830 defer server.close()
3831 server.response = mockResponsesWebSearchResponse()
3832
3833 model := newResponsesProvider(t, server.server.URL)
3834
3835 _, err := model.Generate(context.Background(), fantasy.Call{
3836 Prompt: testPrompt,
3837 ProviderOptions: fantasy.ProviderOptions{
3838 Name: &ResponsesProviderOptions{
3839 PreviousResponseID: new("resp_prev_123"),
3840 Store: new(true),
3841 },
3842 },
3843 })
3844 require.NoError(t, err)
3845
3846 require.Equal(t, "POST", server.calls[0].method)
3847 require.Equal(t, "/responses", server.calls[0].path)
3848 require.Equal(t, "resp_prev_123", server.calls[0].body["previous_response_id"])
3849}
3850
3851func TestResponsesGenerate_StateChainingAcrossTurns(t *testing.T) {
3852 t.Parallel()
3853
3854 server := newMockServer()
3855 defer server.close()
3856 server.response = map[string]any{
3857 "id": "resp_turn_1",
3858 "object": "response",
3859 "model": "gpt-4.1",
3860 "output": []any{
3861 map[string]any{
3862 "type": "message",
3863 "id": "msg_1",
3864 "role": "assistant",
3865 "status": "completed",
3866 "content": []any{
3867 map[string]any{
3868 "type": "output_text",
3869 "text": "First turn",
3870 },
3871 },
3872 },
3873 },
3874 "status": "completed",
3875 "usage": map[string]any{
3876 "input_tokens": 10,
3877 "output_tokens": 5,
3878 "total_tokens": 15,
3879 },
3880 }
3881
3882 model := newResponsesProvider(t, server.server.URL)
3883
3884 first, err := model.Generate(context.Background(), fantasy.Call{
3885 Prompt: testPrompt,
3886 ProviderOptions: fantasy.ProviderOptions{
3887 Name: &ResponsesProviderOptions{Store: new(true)},
3888 },
3889 })
3890 require.NoError(t, err)
3891
3892 meta, ok := first.ProviderMetadata[Name].(*ResponsesProviderMetadata)
3893 require.True(t, ok)
3894 require.Equal(t, "resp_turn_1", meta.ResponseID)
3895
3896 server.response = map[string]any{
3897 "id": "resp_turn_2",
3898 "object": "response",
3899 "model": "gpt-4.1",
3900 "output": []any{
3901 map[string]any{
3902 "type": "message",
3903 "id": "msg_2",
3904 "role": "assistant",
3905 "status": "completed",
3906 "content": []any{
3907 map[string]any{
3908 "type": "output_text",
3909 "text": "Second turn",
3910 },
3911 },
3912 },
3913 },
3914 "status": "completed",
3915 "usage": map[string]any{
3916 "input_tokens": 8,
3917 "output_tokens": 4,
3918 "total_tokens": 12,
3919 },
3920 }
3921
3922 _, err = model.Generate(context.Background(), fantasy.Call{
3923 Prompt: fantasy.Prompt{
3924 fantasy.NewUserMessage("follow-up only"),
3925 },
3926 ProviderOptions: fantasy.ProviderOptions{
3927 Name: &ResponsesProviderOptions{
3928 Store: new(true),
3929 PreviousResponseID: &meta.ResponseID,
3930 },
3931 },
3932 })
3933 require.NoError(t, err)
3934 require.Len(t, server.calls, 2)
3935
3936 firstCall := server.calls[0]
3937 require.Equal(t, true, firstCall.body["store"])
3938
3939 secondCall := server.calls[1]
3940 require.Equal(t, "resp_turn_1", secondCall.body["previous_response_id"])
3941 require.Equal(t, true, secondCall.body["store"])
3942
3943 input, ok := secondCall.body["input"].([]any)
3944 require.True(t, ok)
3945 require.Len(t, input, 1)
3946
3947 inputMessage, ok := input[0].(map[string]any)
3948 require.True(t, ok)
3949 require.Equal(t, "user", inputMessage["role"])
3950}
3951
3952func TestResponsesGenerate_WebSearchToolInRequest(t *testing.T) {
3953 t.Parallel()
3954
3955 t.Run("basic web_search tool", func(t *testing.T) {
3956 t.Parallel()
3957
3958 server := newMockServer()
3959 defer server.close()
3960 server.response = mockResponsesWebSearchResponse()
3961
3962 model := newResponsesProvider(t, server.server.URL)
3963
3964 _, err := model.Generate(context.Background(), fantasy.Call{
3965 Prompt: testPrompt,
3966 Tools: []fantasy.Tool{WebSearchTool(nil)},
3967 })
3968 require.NoError(t, err)
3969
3970 tools, ok := server.calls[0].body["tools"].([]any)
3971 require.True(t, ok, "request body should have tools array")
3972 require.Len(t, tools, 1)
3973
3974 tool, ok := tools[0].(map[string]any)
3975 require.True(t, ok)
3976 require.Equal(t, "web_search", tool["type"])
3977 })
3978
3979 t.Run("with search_context_size and allowed_domains", func(t *testing.T) {
3980 t.Parallel()
3981
3982 server := newMockServer()
3983 defer server.close()
3984 server.response = mockResponsesWebSearchResponse()
3985
3986 model := newResponsesProvider(t, server.server.URL)
3987
3988 _, err := model.Generate(context.Background(), fantasy.Call{
3989 Prompt: testPrompt,
3990 Tools: []fantasy.Tool{
3991 WebSearchTool(&WebSearchToolOptions{
3992 SearchContextSize: SearchContextSizeHigh,
3993 AllowedDomains: []string{"example.com", "test.com"},
3994 }),
3995 },
3996 })
3997 require.NoError(t, err)
3998
3999 tools, ok := server.calls[0].body["tools"].([]any)
4000 require.True(t, ok)
4001 require.Len(t, tools, 1)
4002
4003 tool, ok := tools[0].(map[string]any)
4004 require.True(t, ok)
4005 require.Equal(t, "web_search", tool["type"])
4006 require.Equal(t, "high", tool["search_context_size"])
4007
4008 filters, ok := tool["filters"].(map[string]any)
4009 require.True(t, ok, "tool should have filters")
4010 domains, ok := filters["allowed_domains"].([]any)
4011 require.True(t, ok, "filters should have allowed_domains")
4012 require.Len(t, domains, 2)
4013 require.Equal(t, "example.com", domains[0])
4014 require.Equal(t, "test.com", domains[1])
4015 })
4016
4017 t.Run("with user_location", func(t *testing.T) {
4018 t.Parallel()
4019
4020 server := newMockServer()
4021 defer server.close()
4022 server.response = mockResponsesWebSearchResponse()
4023
4024 model := newResponsesProvider(t, server.server.URL)
4025
4026 _, err := model.Generate(context.Background(), fantasy.Call{
4027 Prompt: testPrompt,
4028 Tools: []fantasy.Tool{
4029 WebSearchTool(&WebSearchToolOptions{
4030 UserLocation: &WebSearchUserLocation{
4031 City: "San Francisco",
4032 Country: "US",
4033 },
4034 }),
4035 },
4036 })
4037 require.NoError(t, err)
4038
4039 tools, ok := server.calls[0].body["tools"].([]any)
4040 require.True(t, ok)
4041 require.Len(t, tools, 1)
4042
4043 tool, ok := tools[0].(map[string]any)
4044 require.True(t, ok)
4045 require.Equal(t, "web_search", tool["type"])
4046
4047 userLoc, ok := tool["user_location"].(map[string]any)
4048 require.True(t, ok, "tool should have user_location")
4049 require.Equal(t, "San Francisco", userLoc["city"])
4050 require.Equal(t, "US", userLoc["country"])
4051 })
4052}
4053
// TestResponsesToPrompt_WebSearchProviderExecutedToolResults verifies how
// provider-executed web_search tool call/result pairs in an assistant
// message are converted: with store=true they become an item_reference to
// the stored server-side item; with store=false they are skipped entirely.
func TestResponsesToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) {
	t.Parallel()

	// Conversation containing a provider-executed tool call, its matching
	// result, and trailing assistant text.
	prompt := fantasy.Prompt{
		{
			Role: fantasy.MessageRoleUser,
			Content: []fantasy.MessagePart{
				fantasy.TextPart{Text: "Search for the latest AI news"},
			},
		},
		{
			Role: fantasy.MessageRoleAssistant,
			Content: []fantasy.MessagePart{
				fantasy.ToolCallPart{
					ToolCallID: "ws_01",
					ToolName: "web_search",
					ProviderExecuted: true,
				},
				fantasy.ToolResultPart{
					ToolCallID: "ws_01",
					ProviderExecuted: true,
				},
				fantasy.TextPart{Text: "Here is what I found."},
			},
		},
	}

	t.Run("store false skips item reference", func(t *testing.T) {
		t.Parallel()

		input, warnings := toResponsesPrompt(prompt, "system instructions", false)

		require.Empty(t, warnings)
		require.Len(t, input, 2,
			"expected user + assistant text when store=false")
		require.Nil(t, input[0].OfItemReference)
		require.Nil(t, input[1].OfItemReference)
	})

	t.Run("store true uses item reference", func(t *testing.T) {
		t.Parallel()

		input, warnings := toResponsesPrompt(prompt, "system instructions", true)

		require.Empty(t, warnings)
		require.Len(t, input, 3,
			"expected user + item_reference + assistant text when store=true")
		// The reference carries the original tool call ID.
		require.NotNil(t, input[1].OfItemReference)
		require.Equal(t, "ws_01", input[1].OfItemReference.ID)
	})
}
4105
// TestResponsesToPrompt_ReasoningWithStore verifies that assistant
// reasoning parts (with Responses reasoning metadata attached) are never
// replayed as reasoning input items, regardless of the store setting.
func TestResponsesToPrompt_ReasoningWithStore(t *testing.T) {
	t.Parallel()

	encryptedContent := "gAAAAABpvAwtDPh5dSXW86hwbwoTo4DJHANQ"
	reasoningItemID := "rs_08d030b87966238b0069bc095b7e5c81"

	// Reasoning part carrying the provider metadata a previous Responses
	// call would have attached.
	reasoningPart := fantasy.ReasoningPart{
		Text: "Let me think about this...",
		ProviderOptions: fantasy.ProviderOptions{
			Name: &ResponsesReasoningMetadata{
				ItemID: reasoningItemID,
				EncryptedContent: &encryptedContent,
				Summary: []string{},
			},
		},
	}

	prompt := fantasy.Prompt{
		{
			Role: fantasy.MessageRoleUser,
			Content: []fantasy.MessagePart{
				fantasy.TextPart{Text: "What is 2+2?"},
			},
		},
		{
			Role: fantasy.MessageRoleAssistant,
			Content: []fantasy.MessagePart{
				reasoningPart,
				fantasy.TextPart{Text: "4"},
			},
		},
		{
			Role: fantasy.MessageRoleUser,
			Content: []fantasy.MessagePart{
				fantasy.TextPart{Text: "And 3+3?"},
			},
		},
	}

	t.Run("store true skips reasoning", func(t *testing.T) {
		t.Parallel()

		input, warnings := toResponsesPrompt(prompt, "system", true)
		require.Empty(t, warnings)

		// With store=true: user, assistant text (reasoning
		// skipped), follow-up user.
		require.Len(t, input, 3)

		// Verify no reasoning item leaked through.
		for _, item := range input {
			require.Nil(t, item.OfReasoning,
				"reasoning items must not appear when store=true")
		}
	})

	t.Run("store false skips reasoning", func(t *testing.T) {
		t.Parallel()

		input, warnings := toResponsesPrompt(prompt, "system", false)
		require.Empty(t, warnings)

		// With store=false: user, assistant text, follow-up user.
		require.Len(t, input, 3)

		for _, item := range input {
			require.Nil(t, item.OfReasoning,
				"reasoning items must not appear when store=false")
		}
	})
}
4177
4178func TestResponsesStream_WebSearchResponse(t *testing.T) {
4179 t.Parallel()
4180
4181 chunks := []string{
4182 "event: response.output_item.added\n" +
4183 `data: {"type":"response.output_item.added","output_index":0,"item":{"type":"web_search_call","id":"ws_01","status":"in_progress"}}` + "\n\n",
4184 "event: response.output_item.done\n" +
4185 `data: {"type":"response.output_item.done","output_index":0,"item":{"type":"web_search_call","id":"ws_01","status":"completed","action":{"type":"search","query":"latest AI news"}}}` + "\n\n",
4186 "event: response.output_item.added\n" +
4187 `data: {"type":"response.output_item.added","output_index":1,"item":{"type":"message","id":"msg_01","role":"assistant","status":"in_progress","content":[]}}` + "\n\n",
4188 "event: response.output_text.delta\n" +
4189 `data: {"type":"response.output_text.delta","output_index":1,"content_index":0,"delta":"Here are the results."}` + "\n\n",
4190 "event: response.output_text.annotation.added\n" +
4191 `data: {"type":"response.output_text.annotation.added","annotation":{"type":"url_citation","url":"https://example.com/ai-news","title":"Latest AI News","start_index":0,"end_index":21},"annotation_index":0,"content_index":0,"item_id":"msg_01","output_index":1,"sequence_number":10}` + "\n\n",
4192 "event: response.output_text.annotation.added\n" +
4193 `data: {"type":"response.output_text.annotation.added","annotation":{"type":"url_citation","url":"https://example.com/more-news","title":"More AI News","start_index":22,"end_index":40},"annotation_index":1,"content_index":0,"item_id":"msg_01","output_index":1,"sequence_number":11}` + "\n\n",
4194 "event: response.output_item.done\n" +
4195 `data: {"type":"response.output_item.done","output_index":1,"item":{"type":"message","id":"msg_01","role":"assistant","status":"completed","content":[{"type":"output_text","text":"Here are the results.","annotations":[{"type":"url_citation","url":"https://example.com/ai-news","title":"Latest AI News","start_index":0,"end_index":21},{"type":"url_citation","url":"https://example.com/more-news","title":"More AI News","start_index":22,"end_index":40}]}]}}` + "\n\n",
4196 "event: response.completed\n" +
4197 `data: {"type":"response.completed","response":{"id":"resp_01","status":"completed","output":[],"usage":{"input_tokens":100,"output_tokens":50,"total_tokens":150}}}` + "\n\n",
4198 }
4199
4200 sms := newStreamingMockServer()
4201 defer sms.close()
4202 sms.chunks = chunks
4203
4204 model := newResponsesProvider(t, sms.server.URL)
4205
4206 stream, err := model.Stream(context.Background(), fantasy.Call{
4207 Prompt: testPrompt,
4208 Tools: []fantasy.Tool{WebSearchTool(nil)},
4209 })
4210 require.NoError(t, err)
4211
4212 var parts []fantasy.StreamPart
4213 stream(func(part fantasy.StreamPart) bool {
4214 parts = append(parts, part)
4215 return true
4216 })
4217
4218 var (
4219 toolInputStarts []fantasy.StreamPart
4220 toolCalls []fantasy.StreamPart
4221 toolResults []fantasy.StreamPart
4222 textDeltas []fantasy.StreamPart
4223 sources []fantasy.StreamPart
4224 finishes []fantasy.StreamPart
4225 )
4226 for _, p := range parts {
4227 switch p.Type {
4228 case fantasy.StreamPartTypeToolInputStart:
4229 toolInputStarts = append(toolInputStarts, p)
4230 case fantasy.StreamPartTypeToolCall:
4231 toolCalls = append(toolCalls, p)
4232 case fantasy.StreamPartTypeToolResult:
4233 toolResults = append(toolResults, p)
4234 case fantasy.StreamPartTypeTextDelta:
4235 textDeltas = append(textDeltas, p)
4236 case fantasy.StreamPartTypeSource:
4237 sources = append(sources, p)
4238 case fantasy.StreamPartTypeFinish:
4239 finishes = append(finishes, p)
4240 }
4241 }
4242
4243 require.NotEmpty(t, toolInputStarts, "should have a tool input start")
4244 require.True(t, toolInputStarts[0].ProviderExecuted)
4245 require.Equal(t, "web_search", toolInputStarts[0].ToolCallName)
4246
4247 require.NotEmpty(t, toolCalls, "should have a tool call")
4248 require.True(t, toolCalls[0].ProviderExecuted)
4249 require.Equal(t, "web_search", toolCalls[0].ToolCallName)
4250
4251 require.NotEmpty(t, toolResults, "should have a tool result")
4252 require.True(t, toolResults[0].ProviderExecuted)
4253 require.Equal(t, "web_search", toolResults[0].ToolCallName)
4254 require.Equal(t, "ws_01", toolResults[0].ID)
4255
4256 require.NotEmpty(t, textDeltas, "should have text deltas")
4257 require.Equal(t, "Here are the results.", textDeltas[0].Delta)
4258
4259 require.Len(t, sources, 2, "should have two source citations from annotation events")
4260 require.Equal(t, fantasy.SourceTypeURL, sources[0].SourceType)
4261 require.Equal(t, "https://example.com/ai-news", sources[0].URL)
4262 require.Equal(t, "Latest AI News", sources[0].Title)
4263 require.NotEmpty(t, sources[0].ID, "source should have an ID")
4264 require.Equal(t, fantasy.SourceTypeURL, sources[1].SourceType)
4265 require.Equal(t, "https://example.com/more-news", sources[1].URL)
4266 require.Equal(t, "More AI News", sources[1].Title)
4267 require.NotEmpty(t, sources[1].ID, "source should have an ID")
4268
4269 require.Len(t, finishes, 1)
4270 responsesMeta, ok := finishes[0].ProviderMetadata[Name].(*ResponsesProviderMetadata)
4271 require.True(t, ok)
4272 require.Equal(t, "resp_01", responsesMeta.ResponseID)
4273}
4274
4275func TestResponsesStream_StoreOption(t *testing.T) {
4276 t.Parallel()
4277
4278 chunks := []string{
4279 "event: response.completed\n" +
4280 `data: {"type":"response.completed","response":{"id":"resp_01","status":"completed","output":[],"usage":{"input_tokens":100,"output_tokens":50,"total_tokens":150}}}` + "\n\n",
4281 }
4282
4283 sms := newStreamingMockServer()
4284 defer sms.close()
4285 sms.chunks = chunks
4286
4287 model := newResponsesProvider(t, sms.server.URL)
4288
4289 stream, err := model.Stream(context.Background(), fantasy.Call{
4290 Prompt: testPrompt,
4291 ProviderOptions: fantasy.ProviderOptions{
4292 Name: &ResponsesProviderOptions{
4293 Store: new(true),
4294 },
4295 },
4296 })
4297 require.NoError(t, err)
4298
4299 stream(func(part fantasy.StreamPart) bool {
4300 return part.Type != fantasy.StreamPartTypeFinish
4301 })
4302
4303 require.Equal(t, "POST", sms.calls[0].method)
4304 require.Equal(t, "/responses", sms.calls[0].path)
4305 require.Equal(t, true, sms.calls[0].body["store"])
4306}
4307
4308func TestResponsesStream_PreviousResponseIDOption(t *testing.T) {
4309 t.Parallel()
4310
4311 chunks := []string{
4312 "event: response.completed\n" +
4313 `data: {"type":"response.completed","response":{"id":"resp_01","status":"completed","output":[],"usage":{"input_tokens":100,"output_tokens":50,"total_tokens":150}}}` + "\n\n",
4314 }
4315
4316 sms := newStreamingMockServer()
4317 defer sms.close()
4318 sms.chunks = chunks
4319
4320 model := newResponsesProvider(t, sms.server.URL)
4321
4322 stream, err := model.Stream(context.Background(), fantasy.Call{
4323 Prompt: testPrompt,
4324 ProviderOptions: fantasy.ProviderOptions{
4325 Name: &ResponsesProviderOptions{
4326 PreviousResponseID: new("resp_prev_456"),
4327 Store: new(true),
4328 },
4329 },
4330 })
4331 require.NoError(t, err)
4332
4333 stream(func(part fantasy.StreamPart) bool {
4334 return part.Type != fantasy.StreamPartTypeFinish
4335 })
4336
4337 require.Equal(t, "POST", sms.calls[0].method)
4338 require.Equal(t, "/responses", sms.calls[0].path)
4339 require.Equal(t, "resp_prev_456", sms.calls[0].body["previous_response_id"])
4340}