1package openai
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "net/http"
9 "net/http/httptest"
10 "strings"
11 "testing"
12
13 "charm.land/fantasy"
14 "github.com/openai/openai-go/v3/packages/param"
15 "github.com/stretchr/testify/assert"
16 "github.com/stretchr/testify/require"
17)
18
19func TestToOpenAiPrompt_SystemMessages(t *testing.T) {
20 t.Parallel()
21
22 t.Run("should forward system messages", func(t *testing.T) {
23 t.Parallel()
24
25 prompt := fantasy.Prompt{
26 {
27 Role: fantasy.MessageRoleSystem,
28 Content: []fantasy.MessagePart{
29 fantasy.TextPart{Text: "You are a helpful assistant."},
30 },
31 },
32 }
33
34 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
35
36 require.Empty(t, warnings)
37 require.Len(t, messages, 1)
38
39 systemMsg := messages[0].OfSystem
40 require.NotNil(t, systemMsg)
41 require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
42 })
43
44 t.Run("should handle empty system messages", func(t *testing.T) {
45 t.Parallel()
46
47 prompt := fantasy.Prompt{
48 {
49 Role: fantasy.MessageRoleSystem,
50 Content: []fantasy.MessagePart{},
51 },
52 }
53
54 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
55
56 require.Len(t, warnings, 1)
57 require.Contains(t, warnings[0].Message, "system prompt has no text parts")
58 require.Empty(t, messages)
59 })
60
61 t.Run("should join multiple system text parts", func(t *testing.T) {
62 t.Parallel()
63
64 prompt := fantasy.Prompt{
65 {
66 Role: fantasy.MessageRoleSystem,
67 Content: []fantasy.MessagePart{
68 fantasy.TextPart{Text: "You are a helpful assistant."},
69 fantasy.TextPart{Text: "Be concise."},
70 },
71 },
72 }
73
74 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
75
76 require.Empty(t, warnings)
77 require.Len(t, messages, 1)
78
79 systemMsg := messages[0].OfSystem
80 require.NotNil(t, systemMsg)
81 require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
82 })
83}
84
85func TestToOpenAiPrompt_UserMessages(t *testing.T) {
86 t.Parallel()
87
88 t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
89 t.Parallel()
90
91 prompt := fantasy.Prompt{
92 {
93 Role: fantasy.MessageRoleUser,
94 Content: []fantasy.MessagePart{
95 fantasy.TextPart{Text: "Hello"},
96 },
97 },
98 }
99
100 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
101
102 require.Empty(t, warnings)
103 require.Len(t, messages, 1)
104
105 userMsg := messages[0].OfUser
106 require.NotNil(t, userMsg)
107 require.Equal(t, "Hello", userMsg.Content.OfString.Value)
108 })
109
110 t.Run("should convert messages with image parts", func(t *testing.T) {
111 t.Parallel()
112
113 imageData := []byte{0, 1, 2, 3}
114 prompt := fantasy.Prompt{
115 {
116 Role: fantasy.MessageRoleUser,
117 Content: []fantasy.MessagePart{
118 fantasy.TextPart{Text: "Hello"},
119 fantasy.FilePart{
120 MediaType: "image/png",
121 Data: imageData,
122 },
123 },
124 },
125 }
126
127 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
128
129 require.Empty(t, warnings)
130 require.Len(t, messages, 1)
131
132 userMsg := messages[0].OfUser
133 require.NotNil(t, userMsg)
134
135 content := userMsg.Content.OfArrayOfContentParts
136 require.Len(t, content, 2)
137
138 // Check text part
139 textPart := content[0].OfText
140 require.NotNil(t, textPart)
141 require.Equal(t, "Hello", textPart.Text)
142
143 // Check image part
144 imagePart := content[1].OfImageURL
145 require.NotNil(t, imagePart)
146 expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
147 require.Equal(t, expectedURL, imagePart.ImageURL.URL)
148 })
149
150 t.Run("should add image detail when specified through provider options", func(t *testing.T) {
151 t.Parallel()
152
153 imageData := []byte{0, 1, 2, 3}
154 prompt := fantasy.Prompt{
155 {
156 Role: fantasy.MessageRoleUser,
157 Content: []fantasy.MessagePart{
158 fantasy.FilePart{
159 MediaType: "image/png",
160 Data: imageData,
161 ProviderOptions: NewProviderFileOptions(&ProviderFileOptions{
162 ImageDetail: "low",
163 }),
164 },
165 },
166 },
167 }
168
169 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
170
171 require.Empty(t, warnings)
172 require.Len(t, messages, 1)
173
174 userMsg := messages[0].OfUser
175 require.NotNil(t, userMsg)
176
177 content := userMsg.Content.OfArrayOfContentParts
178 require.Len(t, content, 1)
179
180 imagePart := content[0].OfImageURL
181 require.NotNil(t, imagePart)
182 require.Equal(t, "low", imagePart.ImageURL.Detail)
183 })
184}
185
// TestToOpenAiPrompt_FileParts verifies conversion of file attachments in
// user messages: unsupported media types, audio (wav/mpeg/mp3), and PDFs
// (inline data, file IDs, and default filenames).
func TestToOpenAiPrompt_FileParts(t *testing.T) {
	t.Parallel()

	t.Run("should throw for unsupported mime types", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/something",
						Data:      []byte("test"),
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Len(t, warnings, 2) // unsupported type + empty message
		require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
		require.Contains(t, warnings[1].Message, "dropping empty user message")
		require.Empty(t, messages) // Message is now dropped because it's empty
	})

	t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/wav",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		// Audio is sent as base64 data plus a short format name.
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
		require.Equal(t, "wav", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/mpeg",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		// audio/mpeg is normalized to the "mp3" format name.
		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/mp3",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		// audio/mp3 (a common non-standard alias) also maps to "mp3".
		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should convert messages with PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "document.pdf", filePart.File.Filename.Value)

		// PDF bytes are inlined as a base64 data URL in file_data.
		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	// NOTE(review): this subtest is identical to the previous one apart from
	// its name; consider deduplicating or exercising genuinely distinct
	// binary data.
	t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
		t.Parallel()

		// Data beginning with "file-" is treated as an OpenAI file ID rather
		// than inline bytes — presumably detected by prefix; confirm against
		// the converter implementation.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      []byte("file-pdf-12345"),
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
		// file_data and filename must be omitted entirely when a file ID is
		// used.
		require.True(t, param.IsOmitted(filePart.File.FileData))
		require.True(t, param.IsOmitted(filePart.File.Filename))
	})

	t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		// Missing filenames fall back to "part-<index>.pdf".
		require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
	})
}
423
// TestToOpenAiPrompt_ToolCalls verifies conversion of assistant tool-call
// parts and tool-role result messages, including error outputs.
func TestToOpenAiPrompt_ToolCalls(t *testing.T) {
	t.Parallel()

	t.Run("should stringify arguments to tool calls", func(t *testing.T) {
		t.Parallel()

		inputArgs := map[string]any{"foo": "bar123"}
		inputJSON, _ := json.Marshal(inputArgs)

		outputResult := map[string]any{"oof": "321rab"}
		outputJSON, _ := json.Marshal(outputResult)

		// An assistant tool call followed by the matching tool result.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "quux",
						ToolName:   "thwomp",
						Input:      string(inputJSON),
					},
				},
			},
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "quux",
						Output: fantasy.ToolResultOutputContentText{
							Text: string(outputJSON),
						},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 2)

		// Check assistant message with tool call: no text content, one
		// function-style tool call carrying the raw JSON arguments.
		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "", assistantMsg.Content.OfString.Value)
		require.Len(t, assistantMsg.ToolCalls, 1)

		toolCall := assistantMsg.ToolCalls[0].OfFunction
		require.NotNil(t, toolCall)
		require.Equal(t, "quux", toolCall.ID)
		require.Equal(t, "thwomp", toolCall.Function.Name)
		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)

		// Check tool message: result text keyed by the same tool call ID.
		toolMsg := messages[1].OfTool
		require.NotNil(t, toolMsg)
		require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
		require.Equal(t, "quux", toolMsg.ToolCallID)
	})

	t.Run("should handle different tool output types", func(t *testing.T) {
		t.Parallel()

		// One prompt message containing both a text result and an error
		// result.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "text-tool",
						Output: fantasy.ToolResultOutputContentText{
							Text: "Hello world",
						},
					},
					fantasy.ToolResultPart{
						ToolCallID: "error-tool",
						Output: fantasy.ToolResultOutputContentError{
							Error: errors.New("Something went wrong"),
						},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		// Each result part becomes its own tool message.
		require.Len(t, messages, 2)

		// Check first tool message (text)
		textToolMsg := messages[0].OfTool
		require.NotNil(t, textToolMsg)
		require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
		require.Equal(t, "text-tool", textToolMsg.ToolCallID)

		// Check second tool message (error): the error string becomes the
		// message content.
		errorToolMsg := messages[1].OfTool
		require.NotNil(t, errorToolMsg)
		require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
		require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
	})
}
525
526func TestToOpenAiPrompt_AssistantMessages(t *testing.T) {
527 t.Parallel()
528
529 t.Run("should handle simple text assistant messages", func(t *testing.T) {
530 t.Parallel()
531
532 prompt := fantasy.Prompt{
533 {
534 Role: fantasy.MessageRoleAssistant,
535 Content: []fantasy.MessagePart{
536 fantasy.TextPart{Text: "Hello, how can I help you?"},
537 },
538 },
539 }
540
541 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
542
543 require.Empty(t, warnings)
544 require.Len(t, messages, 1)
545
546 assistantMsg := messages[0].OfAssistant
547 require.NotNil(t, assistantMsg)
548 require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
549 })
550
551 t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
552 t.Parallel()
553
554 inputArgs := map[string]any{"query": "test"}
555 inputJSON, _ := json.Marshal(inputArgs)
556
557 prompt := fantasy.Prompt{
558 {
559 Role: fantasy.MessageRoleAssistant,
560 Content: []fantasy.MessagePart{
561 fantasy.TextPart{Text: "Let me search for that."},
562 fantasy.ToolCallPart{
563 ToolCallID: "call-123",
564 ToolName: "search",
565 Input: string(inputJSON),
566 },
567 },
568 },
569 }
570
571 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
572
573 require.Empty(t, warnings)
574 require.Len(t, messages, 1)
575
576 assistantMsg := messages[0].OfAssistant
577 require.NotNil(t, assistantMsg)
578 require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
579 require.Len(t, assistantMsg.ToolCalls, 1)
580
581 toolCall := assistantMsg.ToolCalls[0].OfFunction
582 require.Equal(t, "call-123", toolCall.ID)
583 require.Equal(t, "search", toolCall.Function.Name)
584 require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
585 })
586}
587
// testPrompt is a minimal single-user-message prompt ("Hello") shared by
// the DoGenerate tests below.
var testPrompt = fantasy.Prompt{
	{
		Role: fantasy.MessageRoleUser,
		Content: []fantasy.MessagePart{
			fantasy.TextPart{Text: "Hello"},
		},
	},
}
596
// testLogprobs is a canned OpenAI-style logprobs payload: per-token log
// probabilities (each with a single top_logprobs entry) for the reply
// "Hello! How can I assist you today?". Used to exercise logprob
// extraction in the DoGenerate tests.
var testLogprobs = map[string]any{
	"content": []map[string]any{
		{
			"token":   "Hello",
			"logprob": -0.0009994634,
			"top_logprobs": []map[string]any{
				{
					"token":   "Hello",
					"logprob": -0.0009994634,
				},
			},
		},
		{
			"token":   "!",
			"logprob": -0.13410144,
			"top_logprobs": []map[string]any{
				{
					"token":   "!",
					"logprob": -0.13410144,
				},
			},
		},
		{
			"token":   " How",
			"logprob": -0.0009250381,
			"top_logprobs": []map[string]any{
				{
					"token":   " How",
					"logprob": -0.0009250381,
				},
			},
		},
		{
			"token":   " can",
			"logprob": -0.047709424,
			"top_logprobs": []map[string]any{
				{
					"token":   " can",
					"logprob": -0.047709424,
				},
			},
		},
		{
			"token":   " I",
			"logprob": -0.000009014684,
			"top_logprobs": []map[string]any{
				{
					"token":   " I",
					"logprob": -0.000009014684,
				},
			},
		},
		{
			"token":   " assist",
			"logprob": -0.009125131,
			"top_logprobs": []map[string]any{
				{
					"token":   " assist",
					"logprob": -0.009125131,
				},
			},
		},
		{
			"token":   " you",
			"logprob": -0.0000066306106,
			"top_logprobs": []map[string]any{
				{
					"token":   " you",
					"logprob": -0.0000066306106,
				},
			},
		},
		{
			"token":   " today",
			"logprob": -0.00011093382,
			"top_logprobs": []map[string]any{
				{
					"token":   " today",
					"logprob": -0.00011093382,
				},
			},
		},
		{
			"token":   "?",
			"logprob": -0.00004596782,
			"top_logprobs": []map[string]any{
				{
					"token":   "?",
					"logprob": -0.00004596782,
				},
			},
		},
	},
}
691
692type mockServer struct {
693 server *httptest.Server
694 response map[string]any
695 calls []mockCall
696}
697
698type mockCall struct {
699 method string
700 path string
701 headers map[string]string
702 body map[string]any
703}
704
705func newMockServer() *mockServer {
706 ms := &mockServer{
707 calls: make([]mockCall, 0),
708 }
709
710 ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
711 // Record the call
712 call := mockCall{
713 method: r.Method,
714 path: r.URL.Path,
715 headers: make(map[string]string),
716 }
717
718 for k, v := range r.Header {
719 if len(v) > 0 {
720 call.headers[k] = v[0]
721 }
722 }
723
724 // Parse request body
725 if r.Body != nil {
726 var body map[string]any
727 json.NewDecoder(r.Body).Decode(&body)
728 call.body = body
729 }
730
731 ms.calls = append(ms.calls, call)
732
733 // Return mock response
734 w.Header().Set("Content-Type", "application/json")
735 json.NewEncoder(w).Encode(ms.response)
736 }))
737
738 return ms
739}
740
741func (ms *mockServer) close() {
742 ms.server.Close()
743}
744
745func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
746 // Default values
747 response := map[string]any{
748 "id": "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
749 "object": "chat.completion",
750 "created": 1711115037,
751 "model": "gpt-3.5-turbo-0125",
752 "choices": []map[string]any{
753 {
754 "index": 0,
755 "message": map[string]any{
756 "role": "assistant",
757 "content": "",
758 },
759 "finish_reason": "stop",
760 },
761 },
762 "usage": map[string]any{
763 "prompt_tokens": 4,
764 "total_tokens": 34,
765 "completion_tokens": 30,
766 },
767 "system_fingerprint": "fp_3bc1b5746c",
768 }
769
770 // Override with provided options
771 for k, v := range opts {
772 switch k {
773 case "content":
774 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["content"] = v
775 case "tool_calls":
776 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["tool_calls"] = v
777 case "function_call":
778 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["function_call"] = v
779 case "annotations":
780 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["annotations"] = v
781 case "usage":
782 response["usage"] = v
783 case "finish_reason":
784 response["choices"].([]map[string]any)[0]["finish_reason"] = v
785 case "id":
786 response["id"] = v
787 case "created":
788 response["created"] = v
789 case "model":
790 response["model"] = v
791 case "logprobs":
792 if v != nil {
793 response["choices"].([]map[string]any)[0]["logprobs"] = v
794 }
795 }
796 }
797
798 ms.response = response
799}
800
801func TestDoGenerate(t *testing.T) {
802 t.Parallel()
803
804 t.Run("should extract text response", func(t *testing.T) {
805 t.Parallel()
806
807 server := newMockServer()
808 defer server.close()
809
810 server.prepareJSONResponse(map[string]any{
811 "content": "Hello, World!",
812 })
813
814 provider, err := New(
815 WithAPIKey("test-api-key"),
816 WithBaseURL(server.server.URL),
817 )
818 require.NoError(t, err)
819 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
820
821 result, err := model.Generate(context.Background(), fantasy.Call{
822 Prompt: testPrompt,
823 })
824
825 require.NoError(t, err)
826 require.Len(t, result.Content, 1)
827
828 textContent, ok := result.Content[0].(fantasy.TextContent)
829 require.True(t, ok)
830 require.Equal(t, "Hello, World!", textContent.Text)
831 })
832
833 t.Run("should extract usage", func(t *testing.T) {
834 t.Parallel()
835
836 server := newMockServer()
837 defer server.close()
838
839 server.prepareJSONResponse(map[string]any{
840 "usage": map[string]any{
841 "prompt_tokens": 20,
842 "total_tokens": 25,
843 "completion_tokens": 5,
844 },
845 })
846
847 provider, err := New(
848 WithAPIKey("test-api-key"),
849 WithBaseURL(server.server.URL),
850 )
851 require.NoError(t, err)
852 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
853
854 result, err := model.Generate(context.Background(), fantasy.Call{
855 Prompt: testPrompt,
856 })
857
858 require.NoError(t, err)
859 require.Equal(t, int64(20), result.Usage.InputTokens)
860 require.Equal(t, int64(5), result.Usage.OutputTokens)
861 require.Equal(t, int64(25), result.Usage.TotalTokens)
862 })
863
864 t.Run("should send request body", func(t *testing.T) {
865 t.Parallel()
866
867 server := newMockServer()
868 defer server.close()
869
870 server.prepareJSONResponse(map[string]any{})
871
872 provider, err := New(
873 WithAPIKey("test-api-key"),
874 WithBaseURL(server.server.URL),
875 )
876 require.NoError(t, err)
877 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
878
879 _, err = model.Generate(context.Background(), fantasy.Call{
880 Prompt: testPrompt,
881 })
882
883 require.NoError(t, err)
884 require.Len(t, server.calls, 1)
885
886 call := server.calls[0]
887 require.Equal(t, "POST", call.method)
888 require.Equal(t, "/chat/completions", call.path)
889 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
890
891 messages, ok := call.body["messages"].([]any)
892 require.True(t, ok)
893 require.Len(t, messages, 1)
894
895 message := messages[0].(map[string]any)
896 require.Equal(t, "user", message["role"])
897 require.Equal(t, "Hello", message["content"])
898 })
899
900 t.Run("should support partial usage", func(t *testing.T) {
901 t.Parallel()
902
903 server := newMockServer()
904 defer server.close()
905
906 server.prepareJSONResponse(map[string]any{
907 "usage": map[string]any{
908 "prompt_tokens": 20,
909 "total_tokens": 20,
910 },
911 })
912
913 provider, err := New(
914 WithAPIKey("test-api-key"),
915 WithBaseURL(server.server.URL),
916 )
917 require.NoError(t, err)
918 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
919
920 result, err := model.Generate(context.Background(), fantasy.Call{
921 Prompt: testPrompt,
922 })
923
924 require.NoError(t, err)
925 require.Equal(t, int64(20), result.Usage.InputTokens)
926 require.Equal(t, int64(0), result.Usage.OutputTokens)
927 require.Equal(t, int64(20), result.Usage.TotalTokens)
928 })
929
930 t.Run("should extract logprobs", func(t *testing.T) {
931 t.Parallel()
932
933 server := newMockServer()
934 defer server.close()
935
936 server.prepareJSONResponse(map[string]any{
937 "logprobs": testLogprobs,
938 })
939
940 provider, err := New(
941 WithAPIKey("test-api-key"),
942 WithBaseURL(server.server.URL),
943 )
944 require.NoError(t, err)
945 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
946
947 result, err := model.Generate(context.Background(), fantasy.Call{
948 Prompt: testPrompt,
949 ProviderOptions: NewProviderOptions(&ProviderOptions{
950 LogProbs: fantasy.Opt(true),
951 }),
952 })
953
954 require.NoError(t, err)
955 require.NotNil(t, result.ProviderMetadata)
956
957 openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
958 require.True(t, ok)
959
960 logprobs := openaiMeta.Logprobs
961 require.True(t, ok)
962 require.NotNil(t, logprobs)
963 })
964
965 t.Run("should extract finish reason", func(t *testing.T) {
966 t.Parallel()
967
968 server := newMockServer()
969 defer server.close()
970
971 server.prepareJSONResponse(map[string]any{
972 "finish_reason": "stop",
973 })
974
975 provider, err := New(
976 WithAPIKey("test-api-key"),
977 WithBaseURL(server.server.URL),
978 )
979 require.NoError(t, err)
980 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
981
982 result, err := model.Generate(context.Background(), fantasy.Call{
983 Prompt: testPrompt,
984 })
985
986 require.NoError(t, err)
987 require.Equal(t, fantasy.FinishReasonStop, result.FinishReason)
988 })
989
990 t.Run("should support unknown finish reason", func(t *testing.T) {
991 t.Parallel()
992
993 server := newMockServer()
994 defer server.close()
995
996 server.prepareJSONResponse(map[string]any{
997 "finish_reason": "eos",
998 })
999
1000 provider, err := New(
1001 WithAPIKey("test-api-key"),
1002 WithBaseURL(server.server.URL),
1003 )
1004 require.NoError(t, err)
1005 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1006
1007 result, err := model.Generate(context.Background(), fantasy.Call{
1008 Prompt: testPrompt,
1009 })
1010
1011 require.NoError(t, err)
1012 require.Equal(t, fantasy.FinishReasonUnknown, result.FinishReason)
1013 })
1014
1015 t.Run("should pass the model and the messages", func(t *testing.T) {
1016 t.Parallel()
1017
1018 server := newMockServer()
1019 defer server.close()
1020
1021 server.prepareJSONResponse(map[string]any{
1022 "content": "",
1023 })
1024
1025 provider, err := New(
1026 WithAPIKey("test-api-key"),
1027 WithBaseURL(server.server.URL),
1028 )
1029 require.NoError(t, err)
1030 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1031
1032 _, err = model.Generate(context.Background(), fantasy.Call{
1033 Prompt: testPrompt,
1034 })
1035
1036 require.NoError(t, err)
1037 require.Len(t, server.calls, 1)
1038
1039 call := server.calls[0]
1040 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1041
1042 messages := call.body["messages"].([]any)
1043 require.Len(t, messages, 1)
1044
1045 message := messages[0].(map[string]any)
1046 require.Equal(t, "user", message["role"])
1047 require.Equal(t, "Hello", message["content"])
1048 })
1049
1050 t.Run("should pass settings", func(t *testing.T) {
1051 t.Parallel()
1052
1053 server := newMockServer()
1054 defer server.close()
1055
1056 server.prepareJSONResponse(map[string]any{})
1057
1058 provider, err := New(
1059 WithAPIKey("test-api-key"),
1060 WithBaseURL(server.server.URL),
1061 )
1062 require.NoError(t, err)
1063 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1064
1065 _, err = model.Generate(context.Background(), fantasy.Call{
1066 Prompt: testPrompt,
1067 ProviderOptions: NewProviderOptions(&ProviderOptions{
1068 LogitBias: map[string]int64{
1069 "50256": -100,
1070 },
1071 ParallelToolCalls: fantasy.Opt(false),
1072 User: fantasy.Opt("test-user-id"),
1073 }),
1074 })
1075
1076 require.NoError(t, err)
1077 require.Len(t, server.calls, 1)
1078
1079 call := server.calls[0]
1080 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1081
1082 messages := call.body["messages"].([]any)
1083 require.Len(t, messages, 1)
1084
1085 logitBias := call.body["logit_bias"].(map[string]any)
1086 require.Equal(t, float64(-100), logitBias["50256"])
1087 require.Equal(t, false, call.body["parallel_tool_calls"])
1088 require.Equal(t, "test-user-id", call.body["user"])
1089 })
1090
1091 t.Run("should pass reasoningEffort setting", func(t *testing.T) {
1092 t.Parallel()
1093
1094 server := newMockServer()
1095 defer server.close()
1096
1097 server.prepareJSONResponse(map[string]any{
1098 "content": "",
1099 })
1100
1101 provider, err := New(
1102 WithAPIKey("test-api-key"),
1103 WithBaseURL(server.server.URL),
1104 )
1105 require.NoError(t, err)
1106 model, _ := provider.LanguageModel(t.Context(), "o1-mini")
1107
1108 _, err = model.Generate(context.Background(), fantasy.Call{
1109 Prompt: testPrompt,
1110 ProviderOptions: NewProviderOptions(
1111 &ProviderOptions{
1112 ReasoningEffort: ReasoningEffortOption(ReasoningEffortLow),
1113 },
1114 ),
1115 })
1116
1117 require.NoError(t, err)
1118 require.Len(t, server.calls, 1)
1119
1120 call := server.calls[0]
1121 require.Equal(t, "o1-mini", call.body["model"])
1122 require.Equal(t, "low", call.body["reasoning_effort"])
1123
1124 messages := call.body["messages"].([]any)
1125 require.Len(t, messages, 1)
1126
1127 message := messages[0].(map[string]any)
1128 require.Equal(t, "user", message["role"])
1129 require.Equal(t, "Hello", message["content"])
1130 })
1131
1132 t.Run("should pass textVerbosity setting", func(t *testing.T) {
1133 t.Parallel()
1134
1135 server := newMockServer()
1136 defer server.close()
1137
1138 server.prepareJSONResponse(map[string]any{
1139 "content": "",
1140 })
1141
1142 provider, err := New(
1143 WithAPIKey("test-api-key"),
1144 WithBaseURL(server.server.URL),
1145 )
1146 require.NoError(t, err)
1147 model, _ := provider.LanguageModel(t.Context(), "gpt-4o")
1148
1149 _, err = model.Generate(context.Background(), fantasy.Call{
1150 Prompt: testPrompt,
1151 ProviderOptions: NewProviderOptions(&ProviderOptions{
1152 TextVerbosity: fantasy.Opt("low"),
1153 }),
1154 })
1155
1156 require.NoError(t, err)
1157 require.Len(t, server.calls, 1)
1158
1159 call := server.calls[0]
1160 require.Equal(t, "gpt-4o", call.body["model"])
1161 require.Equal(t, "low", call.body["verbosity"])
1162
1163 messages := call.body["messages"].([]any)
1164 require.Len(t, messages, 1)
1165
1166 message := messages[0].(map[string]any)
1167 require.Equal(t, "user", message["role"])
1168 require.Equal(t, "Hello", message["content"])
1169 })
1170
1171 t.Run("should pass tools and toolChoice", func(t *testing.T) {
1172 t.Parallel()
1173
1174 server := newMockServer()
1175 defer server.close()
1176
1177 server.prepareJSONResponse(map[string]any{
1178 "content": "",
1179 })
1180
1181 provider, err := New(
1182 WithAPIKey("test-api-key"),
1183 WithBaseURL(server.server.URL),
1184 )
1185 require.NoError(t, err)
1186 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1187
1188 _, err = model.Generate(context.Background(), fantasy.Call{
1189 Prompt: testPrompt,
1190 Tools: []fantasy.Tool{
1191 fantasy.FunctionTool{
1192 Name: "test-tool",
1193 InputSchema: map[string]any{
1194 "type": "object",
1195 "properties": map[string]any{
1196 "value": map[string]any{
1197 "type": "string",
1198 },
1199 },
1200 "required": []string{"value"},
1201 "additionalProperties": false,
1202 "$schema": "http://json-schema.org/draft-07/schema#",
1203 },
1204 },
1205 },
1206 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0],
1207 })
1208
1209 require.NoError(t, err)
1210 require.Len(t, server.calls, 1)
1211
1212 call := server.calls[0]
1213 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1214
1215 messages := call.body["messages"].([]any)
1216 require.Len(t, messages, 1)
1217
1218 tools := call.body["tools"].([]any)
1219 require.Len(t, tools, 1)
1220
1221 tool := tools[0].(map[string]any)
1222 require.Equal(t, "function", tool["type"])
1223
1224 function := tool["function"].(map[string]any)
1225 require.Equal(t, "test-tool", function["name"])
1226 require.Equal(t, false, function["strict"])
1227
1228 toolChoice := call.body["tool_choice"].(map[string]any)
1229 require.Equal(t, "function", toolChoice["type"])
1230
1231 toolChoiceFunction := toolChoice["function"].(map[string]any)
1232 require.Equal(t, "test-tool", toolChoiceFunction["name"])
1233 })
1234
1235 t.Run("should parse tool results", func(t *testing.T) {
1236 t.Parallel()
1237
1238 server := newMockServer()
1239 defer server.close()
1240
1241 server.prepareJSONResponse(map[string]any{
1242 "tool_calls": []map[string]any{
1243 {
1244 "id": "call_O17Uplv4lJvD6DVdIvFFeRMw",
1245 "type": "function",
1246 "function": map[string]any{
1247 "name": "test-tool",
1248 "arguments": `{"value":"Spark"}`,
1249 },
1250 },
1251 },
1252 })
1253
1254 provider, err := New(
1255 WithAPIKey("test-api-key"),
1256 WithBaseURL(server.server.URL),
1257 )
1258 require.NoError(t, err)
1259 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1260
1261 result, err := model.Generate(context.Background(), fantasy.Call{
1262 Prompt: testPrompt,
1263 Tools: []fantasy.Tool{
1264 fantasy.FunctionTool{
1265 Name: "test-tool",
1266 InputSchema: map[string]any{
1267 "type": "object",
1268 "properties": map[string]any{
1269 "value": map[string]any{
1270 "type": "string",
1271 },
1272 },
1273 "required": []string{"value"},
1274 "additionalProperties": false,
1275 "$schema": "http://json-schema.org/draft-07/schema#",
1276 },
1277 },
1278 },
1279 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0],
1280 })
1281
1282 require.NoError(t, err)
1283 require.Len(t, result.Content, 1)
1284
1285 toolCall, ok := result.Content[0].(fantasy.ToolCallContent)
1286 require.True(t, ok)
1287 require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
1288 require.Equal(t, "test-tool", toolCall.ToolName)
1289 require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
1290 })
1291
1292 t.Run("should handle ToolChoiceRequired", func(t *testing.T) {
1293 t.Parallel()
1294
1295 server := newMockServer()
1296 defer server.close()
1297
1298 server.prepareJSONResponse(map[string]any{
1299 "content": "",
1300 })
1301
1302 provider, err := New(
1303 WithAPIKey("test-api-key"),
1304 WithBaseURL(server.server.URL),
1305 )
1306 require.NoError(t, err)
1307 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1308
1309 _, err = model.Generate(context.Background(), fantasy.Call{
1310 Prompt: testPrompt,
1311 Tools: []fantasy.Tool{
1312 fantasy.FunctionTool{
1313 Name: "test-tool",
1314 InputSchema: map[string]any{
1315 "type": "object",
1316 "properties": map[string]any{
1317 "value": map[string]any{
1318 "type": "string",
1319 },
1320 },
1321 "required": []string{"value"},
1322 "additionalProperties": false,
1323 "$schema": "http://json-schema.org/draft-07/schema#",
1324 },
1325 },
1326 },
1327 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoiceRequired}[0],
1328 })
1329
1330 require.NoError(t, err)
1331 require.Len(t, server.calls, 1)
1332
1333 call := server.calls[0]
1334 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1335
1336 // Verify tool is present
1337 tools := call.body["tools"].([]any)
1338 require.Len(t, tools, 1)
1339
1340 tool := tools[0].(map[string]any)
1341 require.Equal(t, "function", tool["type"])
1342
1343 function := tool["function"].(map[string]any)
1344 require.Equal(t, "test-tool", function["name"])
1345
1346 // Verify tool_choice is set to "required" (not a function name)
1347 toolChoice := call.body["tool_choice"]
1348 require.Equal(t, "required", toolChoice)
1349 })
1350
1351 t.Run("should parse annotations/citations", func(t *testing.T) {
1352 t.Parallel()
1353
1354 server := newMockServer()
1355 defer server.close()
1356
1357 server.prepareJSONResponse(map[string]any{
1358 "content": "Based on the search results [doc1], I found information.",
1359 "annotations": []map[string]any{
1360 {
1361 "type": "url_citation",
1362 "url_citation": map[string]any{
1363 "start_index": 24,
1364 "end_index": 29,
1365 "url": "https://example.com/doc1.pdf",
1366 "title": "Document 1",
1367 },
1368 },
1369 },
1370 })
1371
1372 provider, err := New(
1373 WithAPIKey("test-api-key"),
1374 WithBaseURL(server.server.URL),
1375 )
1376 require.NoError(t, err)
1377 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1378
1379 result, err := model.Generate(context.Background(), fantasy.Call{
1380 Prompt: testPrompt,
1381 })
1382
1383 require.NoError(t, err)
1384 require.Len(t, result.Content, 2)
1385
1386 textContent, ok := result.Content[0].(fantasy.TextContent)
1387 require.True(t, ok)
1388 require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)
1389
1390 sourceContent, ok := result.Content[1].(fantasy.SourceContent)
1391 require.True(t, ok)
1392 require.Equal(t, fantasy.SourceTypeURL, sourceContent.SourceType)
1393 require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
1394 require.Equal(t, "Document 1", sourceContent.Title)
1395 require.NotEmpty(t, sourceContent.ID)
1396 })
1397
1398 t.Run("should return cached_tokens in prompt_details_tokens", func(t *testing.T) {
1399 t.Parallel()
1400
1401 server := newMockServer()
1402 defer server.close()
1403
1404 server.prepareJSONResponse(map[string]any{
1405 "usage": map[string]any{
1406 "prompt_tokens": 15,
1407 "completion_tokens": 20,
1408 "total_tokens": 35,
1409 "prompt_tokens_details": map[string]any{
1410 "cached_tokens": 1152,
1411 },
1412 },
1413 })
1414
1415 provider, err := New(
1416 WithAPIKey("test-api-key"),
1417 WithBaseURL(server.server.URL),
1418 )
1419 require.NoError(t, err)
1420 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1421
1422 result, err := model.Generate(context.Background(), fantasy.Call{
1423 Prompt: testPrompt,
1424 })
1425
1426 require.NoError(t, err)
1427 require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
1428 // InputTokens = prompt_tokens - cached_tokens = 15 - 1152 = -1137 → clamped to 0
1429 require.Equal(t, int64(0), result.Usage.InputTokens)
1430 require.Equal(t, int64(20), result.Usage.OutputTokens)
1431 require.Equal(t, int64(35), result.Usage.TotalTokens)
1432 })
1433
1434 t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
1435 t.Parallel()
1436
1437 server := newMockServer()
1438 defer server.close()
1439
1440 server.prepareJSONResponse(map[string]any{
1441 "usage": map[string]any{
1442 "prompt_tokens": 15,
1443 "completion_tokens": 20,
1444 "total_tokens": 35,
1445 "completion_tokens_details": map[string]any{
1446 "accepted_prediction_tokens": 123,
1447 "rejected_prediction_tokens": 456,
1448 },
1449 },
1450 })
1451
1452 provider, err := New(
1453 WithAPIKey("test-api-key"),
1454 WithBaseURL(server.server.URL),
1455 )
1456 require.NoError(t, err)
1457 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1458
1459 result, err := model.Generate(context.Background(), fantasy.Call{
1460 Prompt: testPrompt,
1461 })
1462
1463 require.NoError(t, err)
1464 require.NotNil(t, result.ProviderMetadata)
1465
1466 openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
1467
1468 require.True(t, ok)
1469 require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
1470 require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
1471 })
1472
1473 t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
1474 t.Parallel()
1475
1476 server := newMockServer()
1477 defer server.close()
1478
1479 server.prepareJSONResponse(map[string]any{})
1480
1481 provider, err := New(
1482 WithAPIKey("test-api-key"),
1483 WithBaseURL(server.server.URL),
1484 )
1485 require.NoError(t, err)
1486 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1487
1488 result, err := model.Generate(context.Background(), fantasy.Call{
1489 Prompt: testPrompt,
1490 Temperature: &[]float64{0.5}[0],
1491 TopP: &[]float64{0.7}[0],
1492 FrequencyPenalty: &[]float64{0.2}[0],
1493 PresencePenalty: &[]float64{0.3}[0],
1494 })
1495
1496 require.NoError(t, err)
1497 require.Len(t, server.calls, 1)
1498
1499 call := server.calls[0]
1500 require.Equal(t, "o1-preview", call.body["model"])
1501
1502 messages := call.body["messages"].([]any)
1503 require.Len(t, messages, 1)
1504
1505 message := messages[0].(map[string]any)
1506 require.Equal(t, "user", message["role"])
1507 require.Equal(t, "Hello", message["content"])
1508
1509 // These should not be present
1510 require.Nil(t, call.body["temperature"])
1511 require.Nil(t, call.body["top_p"])
1512 require.Nil(t, call.body["frequency_penalty"])
1513 require.Nil(t, call.body["presence_penalty"])
1514
1515 // Should have warnings
1516 require.Len(t, result.Warnings, 4)
1517 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1518 require.Equal(t, "temperature", result.Warnings[0].Setting)
1519 require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
1520 })
1521
1522 t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
1523 t.Parallel()
1524
1525 server := newMockServer()
1526 defer server.close()
1527
1528 server.prepareJSONResponse(map[string]any{})
1529
1530 provider, err := New(
1531 WithAPIKey("test-api-key"),
1532 WithBaseURL(server.server.URL),
1533 )
1534 require.NoError(t, err)
1535 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1536
1537 _, err = model.Generate(context.Background(), fantasy.Call{
1538 Prompt: testPrompt,
1539 MaxOutputTokens: &[]int64{1000}[0],
1540 })
1541
1542 require.NoError(t, err)
1543 require.Len(t, server.calls, 1)
1544
1545 call := server.calls[0]
1546 require.Equal(t, "o1-preview", call.body["model"])
1547 require.Equal(t, float64(1000), call.body["max_completion_tokens"])
1548 require.Nil(t, call.body["max_tokens"])
1549
1550 messages := call.body["messages"].([]any)
1551 require.Len(t, messages, 1)
1552
1553 message := messages[0].(map[string]any)
1554 require.Equal(t, "user", message["role"])
1555 require.Equal(t, "Hello", message["content"])
1556 })
1557
1558 t.Run("should return reasoning tokens", func(t *testing.T) {
1559 t.Parallel()
1560
1561 server := newMockServer()
1562 defer server.close()
1563
1564 server.prepareJSONResponse(map[string]any{
1565 "usage": map[string]any{
1566 "prompt_tokens": 15,
1567 "completion_tokens": 20,
1568 "total_tokens": 35,
1569 "completion_tokens_details": map[string]any{
1570 "reasoning_tokens": 10,
1571 },
1572 },
1573 })
1574
1575 provider, err := New(
1576 WithAPIKey("test-api-key"),
1577 WithBaseURL(server.server.URL),
1578 )
1579 require.NoError(t, err)
1580 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1581
1582 result, err := model.Generate(context.Background(), fantasy.Call{
1583 Prompt: testPrompt,
1584 })
1585
1586 require.NoError(t, err)
1587 require.Equal(t, int64(15), result.Usage.InputTokens)
1588 require.Equal(t, int64(20), result.Usage.OutputTokens)
1589 require.Equal(t, int64(35), result.Usage.TotalTokens)
1590 require.Equal(t, int64(10), result.Usage.ReasoningTokens)
1591 })
1592
1593 t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
1594 t.Parallel()
1595
1596 server := newMockServer()
1597 defer server.close()
1598
1599 server.prepareJSONResponse(map[string]any{
1600 "model": "o1-preview",
1601 })
1602
1603 provider, err := New(
1604 WithAPIKey("test-api-key"),
1605 WithBaseURL(server.server.URL),
1606 )
1607 require.NoError(t, err)
1608 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1609
1610 _, err = model.Generate(context.Background(), fantasy.Call{
1611 Prompt: testPrompt,
1612 ProviderOptions: NewProviderOptions(&ProviderOptions{
1613 MaxCompletionTokens: fantasy.Opt(int64(255)),
1614 }),
1615 })
1616
1617 require.NoError(t, err)
1618 require.Len(t, server.calls, 1)
1619
1620 call := server.calls[0]
1621 require.Equal(t, "o1-preview", call.body["model"])
1622 require.Equal(t, float64(255), call.body["max_completion_tokens"])
1623
1624 messages := call.body["messages"].([]any)
1625 require.Len(t, messages, 1)
1626
1627 message := messages[0].(map[string]any)
1628 require.Equal(t, "user", message["role"])
1629 require.Equal(t, "Hello", message["content"])
1630 })
1631
1632 t.Run("should send prediction extension setting", func(t *testing.T) {
1633 t.Parallel()
1634
1635 server := newMockServer()
1636 defer server.close()
1637
1638 server.prepareJSONResponse(map[string]any{
1639 "content": "",
1640 })
1641
1642 provider, err := New(
1643 WithAPIKey("test-api-key"),
1644 WithBaseURL(server.server.URL),
1645 )
1646 require.NoError(t, err)
1647 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1648
1649 _, err = model.Generate(context.Background(), fantasy.Call{
1650 Prompt: testPrompt,
1651 ProviderOptions: NewProviderOptions(&ProviderOptions{
1652 Prediction: map[string]any{
1653 "type": "content",
1654 "content": "Hello, World!",
1655 },
1656 }),
1657 })
1658
1659 require.NoError(t, err)
1660 require.Len(t, server.calls, 1)
1661
1662 call := server.calls[0]
1663 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1664
1665 prediction := call.body["prediction"].(map[string]any)
1666 require.Equal(t, "content", prediction["type"])
1667 require.Equal(t, "Hello, World!", prediction["content"])
1668
1669 messages := call.body["messages"].([]any)
1670 require.Len(t, messages, 1)
1671
1672 message := messages[0].(map[string]any)
1673 require.Equal(t, "user", message["role"])
1674 require.Equal(t, "Hello", message["content"])
1675 })
1676
1677 t.Run("should send store extension setting", func(t *testing.T) {
1678 t.Parallel()
1679
1680 server := newMockServer()
1681 defer server.close()
1682
1683 server.prepareJSONResponse(map[string]any{
1684 "content": "",
1685 })
1686
1687 provider, err := New(
1688 WithAPIKey("test-api-key"),
1689 WithBaseURL(server.server.URL),
1690 )
1691 require.NoError(t, err)
1692 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1693
1694 _, err = model.Generate(context.Background(), fantasy.Call{
1695 Prompt: testPrompt,
1696 ProviderOptions: NewProviderOptions(&ProviderOptions{
1697 Store: fantasy.Opt(true),
1698 }),
1699 })
1700
1701 require.NoError(t, err)
1702 require.Len(t, server.calls, 1)
1703
1704 call := server.calls[0]
1705 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1706 require.Equal(t, true, call.body["store"])
1707
1708 messages := call.body["messages"].([]any)
1709 require.Len(t, messages, 1)
1710
1711 message := messages[0].(map[string]any)
1712 require.Equal(t, "user", message["role"])
1713 require.Equal(t, "Hello", message["content"])
1714 })
1715
1716 t.Run("should send metadata extension values", func(t *testing.T) {
1717 t.Parallel()
1718
1719 server := newMockServer()
1720 defer server.close()
1721
1722 server.prepareJSONResponse(map[string]any{
1723 "content": "",
1724 })
1725
1726 provider, err := New(
1727 WithAPIKey("test-api-key"),
1728 WithBaseURL(server.server.URL),
1729 )
1730 require.NoError(t, err)
1731 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1732
1733 _, err = model.Generate(context.Background(), fantasy.Call{
1734 Prompt: testPrompt,
1735 ProviderOptions: NewProviderOptions(&ProviderOptions{
1736 Metadata: map[string]any{
1737 "custom": "value",
1738 },
1739 }),
1740 })
1741
1742 require.NoError(t, err)
1743 require.Len(t, server.calls, 1)
1744
1745 call := server.calls[0]
1746 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1747
1748 metadata := call.body["metadata"].(map[string]any)
1749 require.Equal(t, "value", metadata["custom"])
1750
1751 messages := call.body["messages"].([]any)
1752 require.Len(t, messages, 1)
1753
1754 message := messages[0].(map[string]any)
1755 require.Equal(t, "user", message["role"])
1756 require.Equal(t, "Hello", message["content"])
1757 })
1758
1759 t.Run("should send promptCacheKey extension value", func(t *testing.T) {
1760 t.Parallel()
1761
1762 server := newMockServer()
1763 defer server.close()
1764
1765 server.prepareJSONResponse(map[string]any{
1766 "content": "",
1767 })
1768
1769 provider, err := New(
1770 WithAPIKey("test-api-key"),
1771 WithBaseURL(server.server.URL),
1772 )
1773 require.NoError(t, err)
1774 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1775
1776 _, err = model.Generate(context.Background(), fantasy.Call{
1777 Prompt: testPrompt,
1778 ProviderOptions: NewProviderOptions(&ProviderOptions{
1779 PromptCacheKey: fantasy.Opt("test-cache-key-123"),
1780 }),
1781 })
1782
1783 require.NoError(t, err)
1784 require.Len(t, server.calls, 1)
1785
1786 call := server.calls[0]
1787 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1788 require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])
1789
1790 messages := call.body["messages"].([]any)
1791 require.Len(t, messages, 1)
1792
1793 message := messages[0].(map[string]any)
1794 require.Equal(t, "user", message["role"])
1795 require.Equal(t, "Hello", message["content"])
1796 })
1797
1798 t.Run("should send safety_identifier extension value", func(t *testing.T) {
1799 t.Parallel()
1800
1801 server := newMockServer()
1802 defer server.close()
1803
1804 server.prepareJSONResponse(map[string]any{
1805 "content": "",
1806 })
1807
1808 provider, err := New(
1809 WithAPIKey("test-api-key"),
1810 WithBaseURL(server.server.URL),
1811 )
1812 require.NoError(t, err)
1813 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1814
1815 _, err = model.Generate(context.Background(), fantasy.Call{
1816 Prompt: testPrompt,
1817 ProviderOptions: NewProviderOptions(&ProviderOptions{
1818 SafetyIdentifier: fantasy.Opt("test-safety-identifier-123"),
1819 }),
1820 })
1821
1822 require.NoError(t, err)
1823 require.Len(t, server.calls, 1)
1824
1825 call := server.calls[0]
1826 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1827 require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])
1828
1829 messages := call.body["messages"].([]any)
1830 require.Len(t, messages, 1)
1831
1832 message := messages[0].(map[string]any)
1833 require.Equal(t, "user", message["role"])
1834 require.Equal(t, "Hello", message["content"])
1835 })
1836
1837 t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
1838 t.Parallel()
1839
1840 server := newMockServer()
1841 defer server.close()
1842
1843 server.prepareJSONResponse(map[string]any{})
1844
1845 provider, err := New(
1846 WithAPIKey("test-api-key"),
1847 WithBaseURL(server.server.URL),
1848 )
1849 require.NoError(t, err)
1850 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-search-preview")
1851
1852 result, err := model.Generate(context.Background(), fantasy.Call{
1853 Prompt: testPrompt,
1854 Temperature: &[]float64{0.7}[0],
1855 })
1856
1857 require.NoError(t, err)
1858 require.Len(t, server.calls, 1)
1859
1860 call := server.calls[0]
1861 require.Equal(t, "gpt-4o-search-preview", call.body["model"])
1862 require.Nil(t, call.body["temperature"])
1863
1864 require.Len(t, result.Warnings, 1)
1865 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1866 require.Equal(t, "temperature", result.Warnings[0].Setting)
1867 require.Contains(t, result.Warnings[0].Details, "search preview models")
1868 })
1869
1870 t.Run("should send ServiceTier flex processing setting", func(t *testing.T) {
1871 t.Parallel()
1872
1873 server := newMockServer()
1874 defer server.close()
1875
1876 server.prepareJSONResponse(map[string]any{
1877 "content": "",
1878 })
1879
1880 provider, err := New(
1881 WithAPIKey("test-api-key"),
1882 WithBaseURL(server.server.URL),
1883 )
1884 require.NoError(t, err)
1885 model, _ := provider.LanguageModel(t.Context(), "o3-mini")
1886
1887 _, err = model.Generate(context.Background(), fantasy.Call{
1888 Prompt: testPrompt,
1889 ProviderOptions: NewProviderOptions(&ProviderOptions{
1890 ServiceTier: fantasy.Opt("flex"),
1891 }),
1892 })
1893
1894 require.NoError(t, err)
1895 require.Len(t, server.calls, 1)
1896
1897 call := server.calls[0]
1898 require.Equal(t, "o3-mini", call.body["model"])
1899 require.Equal(t, "flex", call.body["service_tier"])
1900
1901 messages := call.body["messages"].([]any)
1902 require.Len(t, messages, 1)
1903
1904 message := messages[0].(map[string]any)
1905 require.Equal(t, "user", message["role"])
1906 require.Equal(t, "Hello", message["content"])
1907 })
1908
1909 t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
1910 t.Parallel()
1911
1912 server := newMockServer()
1913 defer server.close()
1914
1915 server.prepareJSONResponse(map[string]any{})
1916
1917 provider, err := New(
1918 WithAPIKey("test-api-key"),
1919 WithBaseURL(server.server.URL),
1920 )
1921 require.NoError(t, err)
1922 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1923
1924 result, err := model.Generate(context.Background(), fantasy.Call{
1925 Prompt: testPrompt,
1926 ProviderOptions: NewProviderOptions(&ProviderOptions{
1927 ServiceTier: fantasy.Opt("flex"),
1928 }),
1929 })
1930
1931 require.NoError(t, err)
1932 require.Len(t, server.calls, 1)
1933
1934 call := server.calls[0]
1935 require.Nil(t, call.body["service_tier"])
1936
1937 require.Len(t, result.Warnings, 1)
1938 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1939 require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
1940 require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
1941 })
1942
1943 t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
1944 t.Parallel()
1945
1946 server := newMockServer()
1947 defer server.close()
1948
1949 server.prepareJSONResponse(map[string]any{})
1950
1951 provider, err := New(
1952 WithAPIKey("test-api-key"),
1953 WithBaseURL(server.server.URL),
1954 )
1955 require.NoError(t, err)
1956 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1957
1958 _, err = model.Generate(context.Background(), fantasy.Call{
1959 Prompt: testPrompt,
1960 ProviderOptions: NewProviderOptions(&ProviderOptions{
1961 ServiceTier: fantasy.Opt("priority"),
1962 }),
1963 })
1964
1965 require.NoError(t, err)
1966 require.Len(t, server.calls, 1)
1967
1968 call := server.calls[0]
1969 require.Equal(t, "gpt-4o-mini", call.body["model"])
1970 require.Equal(t, "priority", call.body["service_tier"])
1971
1972 messages := call.body["messages"].([]any)
1973 require.Len(t, messages, 1)
1974
1975 message := messages[0].(map[string]any)
1976 require.Equal(t, "user", message["role"])
1977 require.Equal(t, "Hello", message["content"])
1978 })
1979
1980 t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
1981 t.Parallel()
1982
1983 server := newMockServer()
1984 defer server.close()
1985
1986 server.prepareJSONResponse(map[string]any{})
1987
1988 provider, err := New(
1989 WithAPIKey("test-api-key"),
1990 WithBaseURL(server.server.URL),
1991 )
1992 require.NoError(t, err)
1993 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1994
1995 result, err := model.Generate(context.Background(), fantasy.Call{
1996 Prompt: testPrompt,
1997 ProviderOptions: NewProviderOptions(&ProviderOptions{
1998 ServiceTier: fantasy.Opt("priority"),
1999 }),
2000 })
2001
2002 require.NoError(t, err)
2003 require.Len(t, server.calls, 1)
2004
2005 call := server.calls[0]
2006 require.Nil(t, call.body["service_tier"])
2007
2008 require.Len(t, result.Warnings, 1)
2009 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
2010 require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
2011 require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
2012 })
2013}
2014
// streamingMockServer is a test double for the OpenAI streaming (SSE) API.
// It replays a pre-canned sequence of chunks to each request and records
// every incoming call for later assertions.
type streamingMockServer struct {
	server *httptest.Server // in-process HTTP server that replays chunks
	chunks []string         // SSE lines to stream; "HEADER:<k>:<v>" entries become response headers instead
	calls  []mockCall       // recorded requests, in arrival order
}
2020
2021func newStreamingMockServer() *streamingMockServer {
2022 sms := &streamingMockServer{
2023 calls: make([]mockCall, 0),
2024 }
2025
2026 sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2027 // Record the call
2028 call := mockCall{
2029 method: r.Method,
2030 path: r.URL.Path,
2031 headers: make(map[string]string),
2032 }
2033
2034 for k, v := range r.Header {
2035 if len(v) > 0 {
2036 call.headers[k] = v[0]
2037 }
2038 }
2039
2040 // Parse request body
2041 if r.Body != nil {
2042 var body map[string]any
2043 json.NewDecoder(r.Body).Decode(&body)
2044 call.body = body
2045 }
2046
2047 sms.calls = append(sms.calls, call)
2048
2049 // Set streaming headers
2050 w.Header().Set("Content-Type", "text/event-stream")
2051 w.Header().Set("Cache-Control", "no-cache")
2052 w.Header().Set("Connection", "keep-alive")
2053
2054 // Add custom headers if any
2055 for _, chunk := range sms.chunks {
2056 if strings.HasPrefix(chunk, "HEADER:") {
2057 parts := strings.SplitN(chunk[7:], ":", 2)
2058 if len(parts) == 2 {
2059 w.Header().Set(parts[0], parts[1])
2060 }
2061 continue
2062 }
2063 }
2064
2065 w.WriteHeader(http.StatusOK)
2066
2067 // Write chunks
2068 for _, chunk := range sms.chunks {
2069 if strings.HasPrefix(chunk, "HEADER:") {
2070 continue
2071 }
2072 w.Write([]byte(chunk))
2073 if f, ok := w.(http.Flusher); ok {
2074 f.Flush()
2075 }
2076 }
2077 }))
2078
2079 return sms
2080}
2081
// close shuts down the underlying httptest server, releasing its listener.
func (sms *streamingMockServer) close() {
	sms.server.Close()
}
2085
2086func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
2087 content := []string{}
2088 if c, ok := opts["content"].([]string); ok {
2089 content = c
2090 }
2091
2092 usage := map[string]any{
2093 "prompt_tokens": 17,
2094 "total_tokens": 244,
2095 "completion_tokens": 227,
2096 }
2097 if u, ok := opts["usage"].(map[string]any); ok {
2098 usage = u
2099 }
2100
2101 logprobs := map[string]any{}
2102 if l, ok := opts["logprobs"].(map[string]any); ok {
2103 logprobs = l
2104 }
2105
2106 finishReason := "stop"
2107 if fr, ok := opts["finish_reason"].(string); ok {
2108 finishReason = fr
2109 }
2110
2111 model := "gpt-3.5-turbo-0613"
2112 if m, ok := opts["model"].(string); ok {
2113 model = m
2114 }
2115
2116 headers := map[string]string{}
2117 if h, ok := opts["headers"].(map[string]string); ok {
2118 headers = h
2119 }
2120
2121 chunks := []string{}
2122
2123 // Add custom headers
2124 for k, v := range headers {
2125 chunks = append(chunks, "HEADER:"+k+":"+v)
2126 }
2127
2128 // Initial chunk with role
2129 initialChunk := map[string]any{
2130 "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2131 "object": "chat.completion.chunk",
2132 "created": 1702657020,
2133 "model": model,
2134 "system_fingerprint": nil,
2135 "choices": []map[string]any{
2136 {
2137 "index": 0,
2138 "delta": map[string]any{
2139 "role": "assistant",
2140 "content": "",
2141 },
2142 "finish_reason": nil,
2143 },
2144 },
2145 }
2146 initialData, _ := json.Marshal(initialChunk)
2147 chunks = append(chunks, "data: "+string(initialData)+"\n\n")
2148
2149 // Content chunks
2150 for i, text := range content {
2151 contentChunk := map[string]any{
2152 "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2153 "object": "chat.completion.chunk",
2154 "created": 1702657020,
2155 "model": model,
2156 "system_fingerprint": nil,
2157 "choices": []map[string]any{
2158 {
2159 "index": 1,
2160 "delta": map[string]any{
2161 "content": text,
2162 },
2163 "finish_reason": nil,
2164 },
2165 },
2166 }
2167 contentData, _ := json.Marshal(contentChunk)
2168 chunks = append(chunks, "data: "+string(contentData)+"\n\n")
2169
2170 // Add annotations if this is the last content chunk and we have annotations
2171 if i == len(content)-1 {
2172 if annotations, ok := opts["annotations"].([]map[string]any); ok {
2173 annotationChunk := map[string]any{
2174 "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2175 "object": "chat.completion.chunk",
2176 "created": 1702657020,
2177 "model": model,
2178 "system_fingerprint": nil,
2179 "choices": []map[string]any{
2180 {
2181 "index": 1,
2182 "delta": map[string]any{
2183 "annotations": annotations,
2184 },
2185 "finish_reason": nil,
2186 },
2187 },
2188 }
2189 annotationData, _ := json.Marshal(annotationChunk)
2190 chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
2191 }
2192 }
2193 }
2194
2195 // Finish chunk
2196 finishChunk := map[string]any{
2197 "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2198 "object": "chat.completion.chunk",
2199 "created": 1702657020,
2200 "model": model,
2201 "system_fingerprint": nil,
2202 "choices": []map[string]any{
2203 {
2204 "index": 0,
2205 "delta": map[string]any{},
2206 "finish_reason": finishReason,
2207 },
2208 },
2209 }
2210
2211 if len(logprobs) > 0 {
2212 finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
2213 }
2214
2215 finishData, _ := json.Marshal(finishChunk)
2216 chunks = append(chunks, "data: "+string(finishData)+"\n\n")
2217
2218 // Usage chunk
2219 usageChunk := map[string]any{
2220 "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2221 "object": "chat.completion.chunk",
2222 "created": 1702657020,
2223 "model": model,
2224 "system_fingerprint": "fp_3bc1b5746c",
2225 "choices": []map[string]any{},
2226 "usage": usage,
2227 }
2228 usageData, _ := json.Marshal(usageChunk)
2229 chunks = append(chunks, "data: "+string(usageData)+"\n\n")
2230
2231 // Done
2232 chunks = append(chunks, "data: [DONE]\n\n")
2233
2234 sms.chunks = chunks
2235}
2236
// prepareToolStreamResponse queues a fixed stream that replays a single
// tool call ("test-tool") whose JSON arguments arrive split across several
// argument deltas, ending with a "tool_calls" finish reason, a usage chunk,
// and the [DONE] terminator. Used to test incremental tool-call assembly.
func (sms *streamingMockServer) prepareToolStreamResponse() {
	chunks := []string{
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}
2253
2254func (sms *streamingMockServer) prepareErrorStreamResponse() {
2255 chunks := []string{
2256 `data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
2257 "data: [DONE]\n\n",
2258 }
2259 sms.chunks = chunks
2260}
2261
2262func collectStreamParts(stream fantasy.StreamResponse) ([]fantasy.StreamPart, error) {
2263 var parts []fantasy.StreamPart
2264 for part := range stream {
2265 parts = append(parts, part)
2266 if part.Type == fantasy.StreamPartTypeError {
2267 break
2268 }
2269 if part.Type == fantasy.StreamPartTypeFinish {
2270 break
2271 }
2272 }
2273 return parts, nil
2274}
2275
// TestDoStream exercises the streaming code path of the chat-completions
// language model against a mock SSE server: text deltas, tool-call deltas,
// URL-citation annotations, error chunks, request-body shape, usage/token
// accounting (cached, predicted, reasoning), and provider-specific options
// (store, metadata, service_tier).
func TestDoStream(t *testing.T) {
	t.Parallel()

	t.Run("should stream text deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello", ", ", "World!"},
			"finish_reason": "stop",
			"usage": map[string]any{
				"prompt_tokens": 17,
				"total_tokens": 244,
				"completion_tokens": 227,
			},
			"logprobs": testLogprobs,
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Verify stream structure
		require.True(t, len(parts) >= 4) // text-start, deltas, text-end, finish

		// Find text parts
		textStart, textEnd, finish := -1, -1, -1
		var deltas []string

		for i, part := range parts {
			switch part.Type {
			case fantasy.StreamPartTypeTextStart:
				textStart = i
			case fantasy.StreamPartTypeTextDelta:
				deltas = append(deltas, part.Delta)
			case fantasy.StreamPartTypeTextEnd:
				textEnd = i
			case fantasy.StreamPartTypeFinish:
				finish = i
			}
		}

		require.NotEqual(t, -1, textStart)
		require.NotEqual(t, -1, textEnd)
		require.NotEqual(t, -1, finish)
		require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)

		// Check finish part
		finishPart := parts[finish]
		require.Equal(t, fantasy.FinishReasonStop, finishPart.FinishReason)
		require.Equal(t, int64(17), finishPart.Usage.InputTokens)
		require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
	})

	t.Run("should stream tool deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareToolStreamResponse()

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			Tools: []fantasy.Tool{
				fantasy.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required": []string{"value"},
						"additionalProperties": false,
						"$schema": "http://json-schema.org/draft-07/schema#",
					},
				},
			},
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find tool-related parts
		toolInputStart, toolInputEnd, toolCall := -1, -1, -1
		var toolDeltas []string

		for i, part := range parts {
			switch part.Type {
			case fantasy.StreamPartTypeToolInputStart:
				toolInputStart = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
			case fantasy.StreamPartTypeToolInputDelta:
				toolDeltas = append(toolDeltas, part.Delta)
			case fantasy.StreamPartTypeToolInputEnd:
				toolInputEnd = i
			case fantasy.StreamPartTypeToolCall:
				toolCall = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
				require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
			}
		}

		require.NotEqual(t, -1, toolInputStart)
		require.NotEqual(t, -1, toolInputEnd)
		require.NotEqual(t, -1, toolCall)

		// Verify tool deltas combine to form the complete input
		var fullInput strings.Builder
		for _, delta := range toolDeltas {
			fullInput.WriteString(delta)
		}
		require.Equal(t, `{"value":"Sparkle Day"}`, fullInput.String())
	})

	t.Run("should stream annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Based on search results"},
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index": 29,
						"url": "https://example.com/doc1.pdf",
						"title": "Document 1",
					},
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find source part
		var sourcePart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeSource {
				sourcePart = &part
				break
			}
		}

		require.NotNil(t, sourcePart)
		require.Equal(t, fantasy.SourceTypeURL, sourcePart.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
		require.Equal(t, "Document 1", sourcePart.Title)
		require.NotEmpty(t, sourcePart.ID)
	})

	t.Run("should handle error stream parts", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareErrorStreamResponse()

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Should have error and finish parts
		require.True(t, len(parts) >= 1)

		// Find error part
		var errorPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeError {
				errorPart = &part
				break
			}
		}

		require.NotNil(t, errorPart)
		require.NotNil(t, errorPart.Error)
	})

	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		// Streaming must set stream=true and request usage in the final chunk.
		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens": 15,
				"completion_tokens": 20,
				"total_tokens": 35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
		// InputTokens = prompt_tokens - cached_tokens = 15 - 1152 = -1137 → clamped to 0
		require.Equal(t, int64(0), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
	})

	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens": 15,
				"completion_tokens": 20,
				"total_tokens": 35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.NotNil(t, finishPart.ProviderMetadata)

		// Prediction-token counts surface via provider-specific metadata,
		// not the generic Usage struct.
		openaiMeta, ok := finishPart.ProviderMetadata["openai"].(*ProviderMetadata)
		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
		require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Store: fantasy.Opt(true),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])
		require.Equal(t, true, call.body["store"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Metadata: map[string]any{
					"custom": "value",
				},
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o3-mini")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("flex"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("priority"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should stream text delta for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model": "o1-preview",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find text parts
		var textDeltas []string
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeTextDelta {
				textDeltas = append(textDeltas, part.Delta)
			}
		}

		// Should contain the text content (without empty delta)
		require.Equal(t, []string{"Hello, World!"}, textDeltas)
	})

	t.Run("should send reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model": "o1-preview",
			"usage": map[string]any{
				"prompt_tokens": 15,
				"completion_tokens": 20,
				"total_tokens": 35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
		require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
	})
}
2923
// TestDefaultToPrompt_DropsEmptyMessages verifies that the chat-completions
// prompt converter drops messages with no renderable content (emitting a
// warning) while keeping messages that carry text, images, tool calls, or
// tool results.
func TestDefaultToPrompt_DropsEmptyMessages(t *testing.T) {
	t.Parallel()

	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	t.Run("should keep assistant messages with text content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
		t.Parallel()

		// A tool call alone (no text) still counts as content.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "What's the weather?"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName: "get_weather",
						Input: `{"location":"NYC"}`,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should drop user messages without visible content", func(t *testing.T) {
		t.Parallel()

		// An unsupported file type leaves the user message effectively empty.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data: []byte("not supported"),
						MediaType: "application/unknown",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Empty(t, messages)
		require.Len(t, warnings, 2) // One for unsupported type, one for empty message
		require.Contains(t, warnings[1].Message, "dropping empty user message")
	})

	t.Run("should keep user messages with image content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data: []byte{0x01, 0x02, 0x03},
						MediaType: "image/png",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_123",
						Output: fantasy.ToolResultOutputContentText{Text: "done"},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool error results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_456",
						Output: fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})
}
3088
// TestResponsesToPrompt_DropsEmptyMessages mirrors
// TestDefaultToPrompt_DropsEmptyMessages for the Responses API converter
// (toResponsesPrompt): empty messages are dropped with a warning, while
// text, images, tool calls, and tool results are preserved.
func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) {
	t.Parallel()

	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	t.Run("should keep assistant messages with text content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
		t.Parallel()

		// A tool call alone (no text) still counts as content.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "What's the weather?"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName: "get_weather",
						Input: `{"location":"NYC"}`,
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should drop user messages without visible content", func(t *testing.T) {
		t.Parallel()

		// An unsupported file type leaves the user message effectively empty.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data: []byte("not supported"),
						MediaType: "application/unknown",
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Empty(t, input)
		require.Len(t, warnings, 2) // One for unsupported type, one for empty message
		require.Contains(t, warnings[1].Message, "dropping empty user message")
	})

	t.Run("should keep user messages with image content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data: []byte{0x01, 0x02, 0x03},
						MediaType: "image/png",
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_123",
						Output: fantasy.ToolResultOutputContentText{Text: "done"},
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool error results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_456",
						Output: fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})
}
3253
3254func TestParseContextTooLargeError(t *testing.T) {
3255 t.Parallel()
3256
3257 tests := []struct {
3258 name string
3259 message string
3260 wantErr bool
3261 wantUsed int
3262 wantMax int
3263 }{
3264 {
3265 name: "matches openai format with resulted in",
3266 message: "This model's maximum context length is 128000 tokens. However, your messages resulted in 150000 tokens.",
3267 wantErr: true,
3268 wantUsed: 150000,
3269 wantMax: 128000,
3270 },
3271 {
3272 name: "matches openai format with requested",
3273 message: "maximum context length is 8192 tokens, however you requested 10000 tokens",
3274 wantErr: true,
3275 wantUsed: 10000,
3276 wantMax: 8192,
3277 },
3278 {
3279 name: "does not match unrelated error",
3280 message: "invalid api key",
3281 wantErr: false,
3282 },
3283 {
3284 name: "does not match rate limit error",
3285 message: "rate limit exceeded",
3286 wantErr: false,
3287 },
3288 }
3289
3290 for _, tt := range tests {
3291 t.Run(tt.name, func(t *testing.T) {
3292 t.Parallel()
3293 providerErr := &fantasy.ProviderError{Message: tt.message}
3294 parseContextTooLargeError(tt.message, providerErr)
3295
3296 if tt.wantErr {
3297 require.True(t, providerErr.IsContextTooLarge())
3298 if tt.wantUsed > 0 {
3299 require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens)
3300 require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens)
3301 }
3302 } else {
3303 require.False(t, providerErr.IsContextTooLarge())
3304 }
3305 })
3306 }
3307}
3308
3309func TestUserAgent(t *testing.T) {
3310 t.Parallel()
3311
3312 t.Run("default UA applied", func(t *testing.T) {
3313 t.Parallel()
3314
3315 server := newMockServer()
3316 defer server.close()
3317 server.prepareJSONResponse(map[string]any{})
3318
3319 p, err := New(WithAPIKey("k"), WithBaseURL(server.server.URL))
3320 require.NoError(t, err)
3321 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3322 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})
3323
3324 require.Len(t, server.calls, 1)
3325 assert.Equal(t, "Charm-Fantasy/"+fantasy.Version+" (https://charm.land/fantasy)", server.calls[0].headers["User-Agent"])
3326 })
3327
3328 t.Run("WithHeaders User-Agent wins over default", func(t *testing.T) {
3329 t.Parallel()
3330
3331 server := newMockServer()
3332 defer server.close()
3333 server.prepareJSONResponse(map[string]any{})
3334
3335 p, err := New(WithAPIKey("k"), WithBaseURL(server.server.URL), WithHeaders(map[string]string{"User-Agent": "custom-from-headers"}))
3336 require.NoError(t, err)
3337 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3338 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})
3339
3340 require.Len(t, server.calls, 1)
3341 assert.Equal(t, "custom-from-headers", server.calls[0].headers["User-Agent"])
3342 })
3343
3344 t.Run("WithUserAgent wins over both", func(t *testing.T) {
3345 t.Parallel()
3346
3347 server := newMockServer()
3348 defer server.close()
3349 server.prepareJSONResponse(map[string]any{})
3350
3351 p, err := New(
3352 WithAPIKey("k"),
3353 WithBaseURL(server.server.URL),
3354 WithHeaders(map[string]string{"User-Agent": "from-headers"}),
3355 WithUserAgent("explicit-ua"),
3356 )
3357 require.NoError(t, err)
3358 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3359 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})
3360
3361 require.Len(t, server.calls, 1)
3362 assert.Equal(t, "explicit-ua", server.calls[0].headers["User-Agent"])
3363 })
3364
3365 t.Run("Call.UserAgent overrides provider WithHeaders UA", func(t *testing.T) {
3366 t.Parallel()
3367
3368 server := newMockServer()
3369 defer server.close()
3370 server.prepareJSONResponse(map[string]any{})
3371
3372 p, err := New(
3373 WithAPIKey("k"),
3374 WithBaseURL(server.server.URL),
3375 WithHeaders(map[string]string{"User-Agent": "header-ua"}),
3376 )
3377 require.NoError(t, err)
3378 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3379 _, _ = model.Generate(t.Context(), fantasy.Call{
3380 Prompt: testPrompt,
3381 UserAgent: "call-level-ua",
3382 })
3383
3384 require.Len(t, server.calls, 1)
3385 assert.Equal(t, "call-level-ua", server.calls[0].headers["User-Agent"])
3386 })
3387
3388 t.Run("no Call UA falls through to provider UA", func(t *testing.T) {
3389 t.Parallel()
3390
3391 server := newMockServer()
3392 defer server.close()
3393 server.prepareJSONResponse(map[string]any{})
3394
3395 p, err := New(
3396 WithAPIKey("k"),
3397 WithBaseURL(server.server.URL),
3398 WithUserAgent("provider-ua"),
3399 )
3400 require.NoError(t, err)
3401 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3402 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})
3403
3404 require.Len(t, server.calls, 1)
3405 assert.Equal(t, "provider-ua", server.calls[0].headers["User-Agent"])
3406 })
3407
3408 t.Run("agent WithUserAgent overrides provider UA end-to-end", func(t *testing.T) {
3409 t.Parallel()
3410
3411 server := newMockServer()
3412 defer server.close()
3413 server.prepareJSONResponse(map[string]any{})
3414
3415 p, err := New(
3416 WithAPIKey("k"),
3417 WithBaseURL(server.server.URL),
3418 WithUserAgent("provider-ua"),
3419 )
3420 require.NoError(t, err)
3421 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3422
3423 agent := fantasy.NewAgent(model, fantasy.WithUserAgent("agent-ua"))
3424 _, _ = agent.Generate(t.Context(), fantasy.AgentCall{Prompt: "hi"})
3425
3426 require.Len(t, server.calls, 1)
3427 assert.Equal(t, "agent-ua", server.calls[0].headers["User-Agent"])
3428 })
3429
3430 t.Run("agent without UA falls through to provider UA end-to-end", func(t *testing.T) {
3431 t.Parallel()
3432
3433 server := newMockServer()
3434 defer server.close()
3435 server.prepareJSONResponse(map[string]any{})
3436
3437 p, err := New(
3438 WithAPIKey("k"),
3439 WithBaseURL(server.server.URL),
3440 WithUserAgent("provider-ua"),
3441 )
3442 require.NoError(t, err)
3443 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3444
3445 agent := fantasy.NewAgent(model)
3446 _, _ = agent.Generate(t.Context(), fantasy.AgentCall{Prompt: "hi"})
3447
3448 require.Len(t, server.calls, 1)
3449 assert.Equal(t, "provider-ua", server.calls[0].headers["User-Agent"])
3450 })
3451}
3452
3453// --- OpenAI Responses API Web Search Tests ---
3454
// mockResponsesWebSearchResponse builds a canned Responses API payload:
// one completed web_search_call output item followed by an assistant
// message whose text carries two url_citation annotations.
func mockResponsesWebSearchResponse() map[string]any {
	citations := []any{
		map[string]any{
			"type":        "url_citation",
			"url":         "https://example.com/ai-news",
			"title":       "Latest AI News",
			"start_index": 0,
			"end_index":   50,
		},
		map[string]any{
			"type":        "url_citation",
			"url":         "https://example.com/ml-update",
			"title":       "ML Update",
			"start_index": 51,
			"end_index":   60,
		},
	}

	webSearchCall := map[string]any{
		"type":   "web_search_call",
		"id":     "ws_01",
		"status": "completed",
		"action": map[string]any{
			"type":  "search",
			"query": "latest AI news",
		},
	}

	assistantMessage := map[string]any{
		"type":   "message",
		"id":     "msg_01",
		"role":   "assistant",
		"status": "completed",
		"content": []any{
			map[string]any{
				"type":        "output_text",
				"text":        "Based on recent search results, here is the latest AI news.",
				"annotations": citations,
			},
		},
	}

	return map[string]any{
		"id":     "resp_01WebSearch",
		"object": "response",
		"model":  "gpt-4.1",
		"output": []any{webSearchCall, assistantMessage},
		"status": "completed",
		"usage": map[string]any{
			"input_tokens":  100,
			"output_tokens": 50,
			"total_tokens":  150,
		},
	}
}
3510
3511func newResponsesProvider(t *testing.T, serverURL string) fantasy.LanguageModel {
3512 t.Helper()
3513 provider, err := New(
3514 WithAPIKey("test-api-key"),
3515 WithBaseURL(serverURL),
3516 WithUseResponsesAPI(),
3517 )
3518 require.NoError(t, err)
3519 model, err := provider.LanguageModel(context.Background(), "gpt-4.1")
3520 require.NoError(t, err)
3521 return model
3522}
3523
3524func TestResponsesGenerate_WebSearchResponse(t *testing.T) {
3525 t.Parallel()
3526
3527 server := newMockServer()
3528 defer server.close()
3529 server.response = mockResponsesWebSearchResponse()
3530
3531 model := newResponsesProvider(t, server.server.URL)
3532
3533 resp, err := model.Generate(context.Background(), fantasy.Call{
3534 Prompt: testPrompt,
3535 Tools: []fantasy.Tool{WebSearchTool(nil)},
3536 })
3537 require.NoError(t, err)
3538
3539 require.Equal(t, "POST", server.calls[0].method)
3540 require.Equal(t, "/responses", server.calls[0].path)
3541
3542 var (
3543 toolCalls []fantasy.ToolCallContent
3544 sources []fantasy.SourceContent
3545 toolResults []fantasy.ToolResultContent
3546 texts []fantasy.TextContent
3547 )
3548 for _, c := range resp.Content {
3549 switch v := c.(type) {
3550 case fantasy.ToolCallContent:
3551 toolCalls = append(toolCalls, v)
3552 case fantasy.SourceContent:
3553 sources = append(sources, v)
3554 case fantasy.ToolResultContent:
3555 toolResults = append(toolResults, v)
3556 case fantasy.TextContent:
3557 texts = append(texts, v)
3558 }
3559 }
3560
3561 // ToolCallContent for the provider-executed web_search.
3562 require.Len(t, toolCalls, 1)
3563 require.True(t, toolCalls[0].ProviderExecuted)
3564 require.Equal(t, "web_search", toolCalls[0].ToolName)
3565 require.Equal(t, "ws_01", toolCalls[0].ToolCallID)
3566
3567 // SourceContent entries from url_citation annotations.
3568 require.Len(t, sources, 2)
3569 require.Equal(t, "https://example.com/ai-news", sources[0].URL)
3570 require.Equal(t, "Latest AI News", sources[0].Title)
3571 require.Equal(t, fantasy.SourceTypeURL, sources[0].SourceType)
3572 require.Equal(t, "https://example.com/ml-update", sources[1].URL)
3573 require.Equal(t, "ML Update", sources[1].Title)
3574
3575 // ToolResultContent with provider metadata.
3576 require.Len(t, toolResults, 1)
3577 require.True(t, toolResults[0].ProviderExecuted)
3578 require.Equal(t, "web_search", toolResults[0].ToolName)
3579 require.Equal(t, "ws_01", toolResults[0].ToolCallID)
3580
3581 metaVal, ok := toolResults[0].ProviderMetadata[Name]
3582 require.True(t, ok, "providerMetadata should contain openai key")
3583 wsMeta, ok := metaVal.(*WebSearchCallMetadata)
3584 require.True(t, ok, "metadata should be *WebSearchCallMetadata")
3585 require.Equal(t, "ws_01", wsMeta.ItemID)
3586 require.NotNil(t, wsMeta.Action)
3587 require.Equal(t, "search", wsMeta.Action.Type)
3588 require.Equal(t, "latest AI news", wsMeta.Action.Query)
3589
3590 // TextContent with the final answer.
3591 require.Len(t, texts, 1)
3592 require.Equal(t,
3593 "Based on recent search results, here is the latest AI news.",
3594 texts[0].Text,
3595 )
3596}
3597
3598func TestResponsesGenerate_StoreOption(t *testing.T) {
3599 t.Parallel()
3600
3601 server := newMockServer()
3602 defer server.close()
3603 server.response = mockResponsesWebSearchResponse()
3604
3605 model := newResponsesProvider(t, server.server.URL)
3606
3607 _, err := model.Generate(context.Background(), fantasy.Call{
3608 Prompt: testPrompt,
3609 ProviderOptions: fantasy.ProviderOptions{
3610 Name: &ResponsesProviderOptions{
3611 Store: fantasy.Opt(true),
3612 },
3613 },
3614 })
3615 require.NoError(t, err)
3616
3617 require.Equal(t, "POST", server.calls[0].method)
3618 require.Equal(t, "/responses", server.calls[0].path)
3619 require.Equal(t, true, server.calls[0].body["store"])
3620}
3621
3622func TestResponsesGenerate_PreviousResponseIDOption(t *testing.T) {
3623 t.Parallel()
3624
3625 server := newMockServer()
3626 defer server.close()
3627 server.response = mockResponsesWebSearchResponse()
3628
3629 model := newResponsesProvider(t, server.server.URL)
3630
3631 _, err := model.Generate(context.Background(), fantasy.Call{
3632 Prompt: testPrompt,
3633 ProviderOptions: fantasy.ProviderOptions{
3634 Name: &ResponsesProviderOptions{
3635 PreviousResponseID: fantasy.Opt("resp_prev_123"),
3636 Store: fantasy.Opt(true),
3637 },
3638 },
3639 })
3640 require.NoError(t, err)
3641
3642 require.Equal(t, "POST", server.calls[0].method)
3643 require.Equal(t, "/responses", server.calls[0].path)
3644 require.Equal(t, "resp_prev_123", server.calls[0].body["previous_response_id"])
3645}
3646
3647func TestResponsesGenerate_StateChainingAcrossTurns(t *testing.T) {
3648 t.Parallel()
3649
3650 server := newMockServer()
3651 defer server.close()
3652 server.response = map[string]any{
3653 "id": "resp_turn_1",
3654 "object": "response",
3655 "model": "gpt-4.1",
3656 "output": []any{
3657 map[string]any{
3658 "type": "message",
3659 "id": "msg_1",
3660 "role": "assistant",
3661 "status": "completed",
3662 "content": []any{
3663 map[string]any{
3664 "type": "output_text",
3665 "text": "First turn",
3666 },
3667 },
3668 },
3669 },
3670 "status": "completed",
3671 "usage": map[string]any{
3672 "input_tokens": 10,
3673 "output_tokens": 5,
3674 "total_tokens": 15,
3675 },
3676 }
3677
3678 model := newResponsesProvider(t, server.server.URL)
3679
3680 first, err := model.Generate(context.Background(), fantasy.Call{
3681 Prompt: testPrompt,
3682 ProviderOptions: fantasy.ProviderOptions{
3683 Name: &ResponsesProviderOptions{Store: fantasy.Opt(true)},
3684 },
3685 })
3686 require.NoError(t, err)
3687
3688 meta, ok := first.ProviderMetadata[Name].(*ResponsesProviderMetadata)
3689 require.True(t, ok)
3690 require.Equal(t, "resp_turn_1", meta.ResponseID)
3691
3692 server.response = map[string]any{
3693 "id": "resp_turn_2",
3694 "object": "response",
3695 "model": "gpt-4.1",
3696 "output": []any{
3697 map[string]any{
3698 "type": "message",
3699 "id": "msg_2",
3700 "role": "assistant",
3701 "status": "completed",
3702 "content": []any{
3703 map[string]any{
3704 "type": "output_text",
3705 "text": "Second turn",
3706 },
3707 },
3708 },
3709 },
3710 "status": "completed",
3711 "usage": map[string]any{
3712 "input_tokens": 8,
3713 "output_tokens": 4,
3714 "total_tokens": 12,
3715 },
3716 }
3717
3718 _, err = model.Generate(context.Background(), fantasy.Call{
3719 Prompt: fantasy.Prompt{
3720 fantasy.NewUserMessage("follow-up only"),
3721 },
3722 ProviderOptions: fantasy.ProviderOptions{
3723 Name: &ResponsesProviderOptions{
3724 Store: fantasy.Opt(true),
3725 PreviousResponseID: &meta.ResponseID,
3726 },
3727 },
3728 })
3729 require.NoError(t, err)
3730 require.Len(t, server.calls, 2)
3731
3732 firstCall := server.calls[0]
3733 require.Equal(t, true, firstCall.body["store"])
3734
3735 secondCall := server.calls[1]
3736 require.Equal(t, "resp_turn_1", secondCall.body["previous_response_id"])
3737 require.Equal(t, true, secondCall.body["store"])
3738
3739 input, ok := secondCall.body["input"].([]any)
3740 require.True(t, ok)
3741 require.Len(t, input, 1)
3742
3743 inputMessage, ok := input[0].(map[string]any)
3744 require.True(t, ok)
3745 require.Equal(t, "user", inputMessage["role"])
3746}
3747
3748func TestResponsesGenerate_WebSearchToolInRequest(t *testing.T) {
3749 t.Parallel()
3750
3751 t.Run("basic web_search tool", func(t *testing.T) {
3752 t.Parallel()
3753
3754 server := newMockServer()
3755 defer server.close()
3756 server.response = mockResponsesWebSearchResponse()
3757
3758 model := newResponsesProvider(t, server.server.URL)
3759
3760 _, err := model.Generate(context.Background(), fantasy.Call{
3761 Prompt: testPrompt,
3762 Tools: []fantasy.Tool{WebSearchTool(nil)},
3763 })
3764 require.NoError(t, err)
3765
3766 tools, ok := server.calls[0].body["tools"].([]any)
3767 require.True(t, ok, "request body should have tools array")
3768 require.Len(t, tools, 1)
3769
3770 tool, ok := tools[0].(map[string]any)
3771 require.True(t, ok)
3772 require.Equal(t, "web_search", tool["type"])
3773 })
3774
3775 t.Run("with search_context_size and allowed_domains", func(t *testing.T) {
3776 t.Parallel()
3777
3778 server := newMockServer()
3779 defer server.close()
3780 server.response = mockResponsesWebSearchResponse()
3781
3782 model := newResponsesProvider(t, server.server.URL)
3783
3784 _, err := model.Generate(context.Background(), fantasy.Call{
3785 Prompt: testPrompt,
3786 Tools: []fantasy.Tool{
3787 WebSearchTool(&WebSearchToolOptions{
3788 SearchContextSize: SearchContextSizeHigh,
3789 AllowedDomains: []string{"example.com", "test.com"},
3790 }),
3791 },
3792 })
3793 require.NoError(t, err)
3794
3795 tools, ok := server.calls[0].body["tools"].([]any)
3796 require.True(t, ok)
3797 require.Len(t, tools, 1)
3798
3799 tool, ok := tools[0].(map[string]any)
3800 require.True(t, ok)
3801 require.Equal(t, "web_search", tool["type"])
3802 require.Equal(t, "high", tool["search_context_size"])
3803
3804 filters, ok := tool["filters"].(map[string]any)
3805 require.True(t, ok, "tool should have filters")
3806 domains, ok := filters["allowed_domains"].([]any)
3807 require.True(t, ok, "filters should have allowed_domains")
3808 require.Len(t, domains, 2)
3809 require.Equal(t, "example.com", domains[0])
3810 require.Equal(t, "test.com", domains[1])
3811 })
3812
3813 t.Run("with user_location", func(t *testing.T) {
3814 t.Parallel()
3815
3816 server := newMockServer()
3817 defer server.close()
3818 server.response = mockResponsesWebSearchResponse()
3819
3820 model := newResponsesProvider(t, server.server.URL)
3821
3822 _, err := model.Generate(context.Background(), fantasy.Call{
3823 Prompt: testPrompt,
3824 Tools: []fantasy.Tool{
3825 WebSearchTool(&WebSearchToolOptions{
3826 UserLocation: &WebSearchUserLocation{
3827 City: "San Francisco",
3828 Country: "US",
3829 },
3830 }),
3831 },
3832 })
3833 require.NoError(t, err)
3834
3835 tools, ok := server.calls[0].body["tools"].([]any)
3836 require.True(t, ok)
3837 require.Len(t, tools, 1)
3838
3839 tool, ok := tools[0].(map[string]any)
3840 require.True(t, ok)
3841 require.Equal(t, "web_search", tool["type"])
3842
3843 userLoc, ok := tool["user_location"].(map[string]any)
3844 require.True(t, ok, "tool should have user_location")
3845 require.Equal(t, "San Francisco", userLoc["city"])
3846 require.Equal(t, "US", userLoc["country"])
3847 })
3848}
3849
3850func TestResponsesToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) {
3851 t.Parallel()
3852
3853 prompt := fantasy.Prompt{
3854 {
3855 Role: fantasy.MessageRoleUser,
3856 Content: []fantasy.MessagePart{
3857 fantasy.TextPart{Text: "Search for the latest AI news"},
3858 },
3859 },
3860 {
3861 Role: fantasy.MessageRoleAssistant,
3862 Content: []fantasy.MessagePart{
3863 fantasy.ToolCallPart{
3864 ToolCallID: "ws_01",
3865 ToolName: "web_search",
3866 ProviderExecuted: true,
3867 },
3868 fantasy.ToolResultPart{
3869 ToolCallID: "ws_01",
3870 ProviderExecuted: true,
3871 },
3872 fantasy.TextPart{Text: "Here is what I found."},
3873 },
3874 },
3875 }
3876
3877 input, warnings := toResponsesPrompt(prompt, "system instructions")
3878
3879 require.Empty(t, warnings)
3880
3881 // Expected input items: user message, item_reference (for
3882 // provider-executed tool call; the ToolResultPart is skipped),
3883 // and assistant text message. System instructions are passed
3884 // via params.Instructions, not as an input item.
3885 require.Len(t, input, 3,
3886 "expected user + item_reference + assistant text")
3887}
3888
3889func TestResponsesStream_WebSearchResponse(t *testing.T) {
3890 t.Parallel()
3891
3892 chunks := []string{
3893 "event: response.output_item.added\n" +
3894 `data: {"type":"response.output_item.added","output_index":0,"item":{"type":"web_search_call","id":"ws_01","status":"in_progress"}}` + "\n\n",
3895 "event: response.output_item.done\n" +
3896 `data: {"type":"response.output_item.done","output_index":0,"item":{"type":"web_search_call","id":"ws_01","status":"completed","action":{"type":"search","query":"latest AI news"}}}` + "\n\n",
3897 "event: response.output_item.added\n" +
3898 `data: {"type":"response.output_item.added","output_index":1,"item":{"type":"message","id":"msg_01","role":"assistant","status":"in_progress","content":[]}}` + "\n\n",
3899 "event: response.output_text.delta\n" +
3900 `data: {"type":"response.output_text.delta","output_index":1,"content_index":0,"delta":"Here are the results."}` + "\n\n",
3901 "event: response.output_item.done\n" +
3902 `data: {"type":"response.output_item.done","output_index":1,"item":{"type":"message","id":"msg_01","role":"assistant","status":"completed","content":[{"type":"output_text","text":"Here are the results.","annotations":[{"type":"url_citation","url":"https://example.com/ai-news","title":"Latest AI News","start_index":0,"end_index":21}]}]}}` + "\n\n",
3903 "event: response.completed\n" +
3904 `data: {"type":"response.completed","response":{"id":"resp_01","status":"completed","output":[],"usage":{"input_tokens":100,"output_tokens":50,"total_tokens":150}}}` + "\n\n",
3905 }
3906
3907 sms := newStreamingMockServer()
3908 defer sms.close()
3909 sms.chunks = chunks
3910
3911 model := newResponsesProvider(t, sms.server.URL)
3912
3913 stream, err := model.Stream(context.Background(), fantasy.Call{
3914 Prompt: testPrompt,
3915 Tools: []fantasy.Tool{WebSearchTool(nil)},
3916 })
3917 require.NoError(t, err)
3918
3919 var parts []fantasy.StreamPart
3920 stream(func(part fantasy.StreamPart) bool {
3921 parts = append(parts, part)
3922 return true
3923 })
3924
3925 var (
3926 toolInputStarts []fantasy.StreamPart
3927 toolCalls []fantasy.StreamPart
3928 toolResults []fantasy.StreamPart
3929 textDeltas []fantasy.StreamPart
3930 finishes []fantasy.StreamPart
3931 )
3932 for _, p := range parts {
3933 switch p.Type {
3934 case fantasy.StreamPartTypeToolInputStart:
3935 toolInputStarts = append(toolInputStarts, p)
3936 case fantasy.StreamPartTypeToolCall:
3937 toolCalls = append(toolCalls, p)
3938 case fantasy.StreamPartTypeToolResult:
3939 toolResults = append(toolResults, p)
3940 case fantasy.StreamPartTypeTextDelta:
3941 textDeltas = append(textDeltas, p)
3942 case fantasy.StreamPartTypeFinish:
3943 finishes = append(finishes, p)
3944 }
3945 }
3946
3947 require.NotEmpty(t, toolInputStarts, "should have a tool input start")
3948 require.True(t, toolInputStarts[0].ProviderExecuted)
3949 require.Equal(t, "web_search", toolInputStarts[0].ToolCallName)
3950
3951 require.NotEmpty(t, toolCalls, "should have a tool call")
3952 require.True(t, toolCalls[0].ProviderExecuted)
3953 require.Equal(t, "web_search", toolCalls[0].ToolCallName)
3954
3955 require.NotEmpty(t, toolResults, "should have a tool result")
3956 require.True(t, toolResults[0].ProviderExecuted)
3957 require.Equal(t, "web_search", toolResults[0].ToolCallName)
3958 require.Equal(t, "ws_01", toolResults[0].ID)
3959
3960 require.NotEmpty(t, textDeltas, "should have text deltas")
3961 require.Equal(t, "Here are the results.", textDeltas[0].Delta)
3962
3963 require.Len(t, finishes, 1)
3964 responsesMeta, ok := finishes[0].ProviderMetadata[Name].(*ResponsesProviderMetadata)
3965 require.True(t, ok)
3966 require.Equal(t, "resp_01", responsesMeta.ResponseID)
3967}
3968
3969func TestResponsesStream_StoreOption(t *testing.T) {
3970 t.Parallel()
3971
3972 chunks := []string{
3973 "event: response.completed\n" +
3974 `data: {"type":"response.completed","response":{"id":"resp_01","status":"completed","output":[],"usage":{"input_tokens":100,"output_tokens":50,"total_tokens":150}}}` + "\n\n",
3975 }
3976
3977 sms := newStreamingMockServer()
3978 defer sms.close()
3979 sms.chunks = chunks
3980
3981 model := newResponsesProvider(t, sms.server.URL)
3982
3983 stream, err := model.Stream(context.Background(), fantasy.Call{
3984 Prompt: testPrompt,
3985 ProviderOptions: fantasy.ProviderOptions{
3986 Name: &ResponsesProviderOptions{
3987 Store: fantasy.Opt(true),
3988 },
3989 },
3990 })
3991 require.NoError(t, err)
3992
3993 stream(func(part fantasy.StreamPart) bool {
3994 return part.Type != fantasy.StreamPartTypeFinish
3995 })
3996
3997 require.Equal(t, "POST", sms.calls[0].method)
3998 require.Equal(t, "/responses", sms.calls[0].path)
3999 require.Equal(t, true, sms.calls[0].body["store"])
4000}
4001
4002func TestResponsesStream_PreviousResponseIDOption(t *testing.T) {
4003 t.Parallel()
4004
4005 chunks := []string{
4006 "event: response.completed\n" +
4007 `data: {"type":"response.completed","response":{"id":"resp_01","status":"completed","output":[],"usage":{"input_tokens":100,"output_tokens":50,"total_tokens":150}}}` + "\n\n",
4008 }
4009
4010 sms := newStreamingMockServer()
4011 defer sms.close()
4012 sms.chunks = chunks
4013
4014 model := newResponsesProvider(t, sms.server.URL)
4015
4016 stream, err := model.Stream(context.Background(), fantasy.Call{
4017 Prompt: testPrompt,
4018 ProviderOptions: fantasy.ProviderOptions{
4019 Name: &ResponsesProviderOptions{
4020 PreviousResponseID: fantasy.Opt("resp_prev_456"),
4021 Store: fantasy.Opt(true),
4022 },
4023 },
4024 })
4025 require.NoError(t, err)
4026
4027 stream(func(part fantasy.StreamPart) bool {
4028 return part.Type != fantasy.StreamPartTypeFinish
4029 })
4030
4031 require.Equal(t, "POST", sms.calls[0].method)
4032 require.Equal(t, "/responses", sms.calls[0].path)
4033 require.Equal(t, "resp_prev_456", sms.calls[0].body["previous_response_id"])
4034}