1package openai
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "net/http"
9 "net/http/httptest"
10 "strings"
11 "testing"
12
13 "charm.land/fantasy"
14 "github.com/openai/openai-go/v2/packages/param"
15 "github.com/stretchr/testify/require"
16)
17
18func TestToOpenAiPrompt_SystemMessages(t *testing.T) {
19 t.Parallel()
20
21 t.Run("should forward system messages", func(t *testing.T) {
22 t.Parallel()
23
24 prompt := fantasy.Prompt{
25 {
26 Role: fantasy.MessageRoleSystem,
27 Content: []fantasy.MessagePart{
28 fantasy.TextPart{Text: "You are a helpful assistant."},
29 },
30 },
31 }
32
33 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
34
35 require.Empty(t, warnings)
36 require.Len(t, messages, 1)
37
38 systemMsg := messages[0].OfSystem
39 require.NotNil(t, systemMsg)
40 require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
41 })
42
43 t.Run("should handle empty system messages", func(t *testing.T) {
44 t.Parallel()
45
46 prompt := fantasy.Prompt{
47 {
48 Role: fantasy.MessageRoleSystem,
49 Content: []fantasy.MessagePart{},
50 },
51 }
52
53 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
54
55 require.Len(t, warnings, 1)
56 require.Contains(t, warnings[0].Message, "system prompt has no text parts")
57 require.Empty(t, messages)
58 })
59
60 t.Run("should join multiple system text parts", func(t *testing.T) {
61 t.Parallel()
62
63 prompt := fantasy.Prompt{
64 {
65 Role: fantasy.MessageRoleSystem,
66 Content: []fantasy.MessagePart{
67 fantasy.TextPart{Text: "You are a helpful assistant."},
68 fantasy.TextPart{Text: "Be concise."},
69 },
70 },
71 }
72
73 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
74
75 require.Empty(t, warnings)
76 require.Len(t, messages, 1)
77
78 systemMsg := messages[0].OfSystem
79 require.NotNil(t, systemMsg)
80 require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
81 })
82}
83
84func TestToOpenAiPrompt_UserMessages(t *testing.T) {
85 t.Parallel()
86
87 t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
88 t.Parallel()
89
90 prompt := fantasy.Prompt{
91 {
92 Role: fantasy.MessageRoleUser,
93 Content: []fantasy.MessagePart{
94 fantasy.TextPart{Text: "Hello"},
95 },
96 },
97 }
98
99 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
100
101 require.Empty(t, warnings)
102 require.Len(t, messages, 1)
103
104 userMsg := messages[0].OfUser
105 require.NotNil(t, userMsg)
106 require.Equal(t, "Hello", userMsg.Content.OfString.Value)
107 })
108
109 t.Run("should convert messages with image parts", func(t *testing.T) {
110 t.Parallel()
111
112 imageData := []byte{0, 1, 2, 3}
113 prompt := fantasy.Prompt{
114 {
115 Role: fantasy.MessageRoleUser,
116 Content: []fantasy.MessagePart{
117 fantasy.TextPart{Text: "Hello"},
118 fantasy.FilePart{
119 MediaType: "image/png",
120 Data: imageData,
121 },
122 },
123 },
124 }
125
126 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
127
128 require.Empty(t, warnings)
129 require.Len(t, messages, 1)
130
131 userMsg := messages[0].OfUser
132 require.NotNil(t, userMsg)
133
134 content := userMsg.Content.OfArrayOfContentParts
135 require.Len(t, content, 2)
136
137 // Check text part
138 textPart := content[0].OfText
139 require.NotNil(t, textPart)
140 require.Equal(t, "Hello", textPart.Text)
141
142 // Check image part
143 imagePart := content[1].OfImageURL
144 require.NotNil(t, imagePart)
145 expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
146 require.Equal(t, expectedURL, imagePart.ImageURL.URL)
147 })
148
149 t.Run("should add image detail when specified through provider options", func(t *testing.T) {
150 t.Parallel()
151
152 imageData := []byte{0, 1, 2, 3}
153 prompt := fantasy.Prompt{
154 {
155 Role: fantasy.MessageRoleUser,
156 Content: []fantasy.MessagePart{
157 fantasy.FilePart{
158 MediaType: "image/png",
159 Data: imageData,
160 ProviderOptions: NewProviderFileOptions(&ProviderFileOptions{
161 ImageDetail: "low",
162 }),
163 },
164 },
165 },
166 }
167
168 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
169
170 require.Empty(t, warnings)
171 require.Len(t, messages, 1)
172
173 userMsg := messages[0].OfUser
174 require.NotNil(t, userMsg)
175
176 content := userMsg.Content.OfArrayOfContentParts
177 require.Len(t, content, 1)
178
179 imagePart := content[0].OfImageURL
180 require.NotNil(t, imagePart)
181 require.Equal(t, "low", imagePart.ImageURL.Detail)
182 })
183}
184
// TestToOpenAiPrompt_FileParts covers conversion of user file parts:
// unsupported media types (warn + drop), audio formats (wav/mpeg/mp3),
// and PDF handling (inline data, file-ID data, and filename defaulting).
func TestToOpenAiPrompt_FileParts(t *testing.T) {
	t.Parallel()

	t.Run("should throw for unsupported mime types", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/something",
						Data:      []byte("test"),
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Len(t, warnings, 2) // unsupported type + empty message
		require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
		require.Contains(t, warnings[1].Message, "dropping empty user message")
		require.Empty(t, messages) // Message is now dropped because it's empty
	})

	t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/wav",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		// Audio is sent as base64 data plus a short format tag.
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
		require.Equal(t, "wav", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/mpeg",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		// audio/mpeg is normalized to OpenAI's "mp3" format tag.
		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/mp3",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		// audio/mp3 (non-standard but common) maps to "mp3" as well.
		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should convert messages with PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "document.pdf", filePart.File.Filename.Value)

		// PDF bytes travel inline as a base64 data URL in FileData.
		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
		t.Parallel()

		// Data that looks like an OpenAI file ID ("file-...") is presumably
		// routed to FileID instead of inline FileData — confirm against the
		// converter implementation.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      []byte("file-pdf-12345"),
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
		// When a file ID is used, no inline data or filename is sent.
		require.True(t, param.IsOmitted(filePart.File.FileData))
		require.True(t, param.IsOmitted(filePart.File.Filename))
	})

	t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		// Missing filename falls back to "part-<index>.pdf".
		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
	})
}
422
423func TestToOpenAiPrompt_ToolCalls(t *testing.T) {
424 t.Parallel()
425
426 t.Run("should stringify arguments to tool calls", func(t *testing.T) {
427 t.Parallel()
428
429 inputArgs := map[string]any{"foo": "bar123"}
430 inputJSON, _ := json.Marshal(inputArgs)
431
432 outputResult := map[string]any{"oof": "321rab"}
433 outputJSON, _ := json.Marshal(outputResult)
434
435 prompt := fantasy.Prompt{
436 {
437 Role: fantasy.MessageRoleAssistant,
438 Content: []fantasy.MessagePart{
439 fantasy.ToolCallPart{
440 ToolCallID: "quux",
441 ToolName: "thwomp",
442 Input: string(inputJSON),
443 },
444 },
445 },
446 {
447 Role: fantasy.MessageRoleTool,
448 Content: []fantasy.MessagePart{
449 fantasy.ToolResultPart{
450 ToolCallID: "quux",
451 Output: fantasy.ToolResultOutputContentText{
452 Text: string(outputJSON),
453 },
454 },
455 },
456 },
457 }
458
459 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
460
461 require.Empty(t, warnings)
462 require.Len(t, messages, 2)
463
464 // Check assistant message with tool call
465 assistantMsg := messages[0].OfAssistant
466 require.NotNil(t, assistantMsg)
467 require.Equal(t, "", assistantMsg.Content.OfString.Value)
468 require.Len(t, assistantMsg.ToolCalls, 1)
469
470 toolCall := assistantMsg.ToolCalls[0].OfFunction
471 require.NotNil(t, toolCall)
472 require.Equal(t, "quux", toolCall.ID)
473 require.Equal(t, "thwomp", toolCall.Function.Name)
474 require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
475
476 // Check tool message
477 toolMsg := messages[1].OfTool
478 require.NotNil(t, toolMsg)
479 require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
480 require.Equal(t, "quux", toolMsg.ToolCallID)
481 })
482
483 t.Run("should handle different tool output types", func(t *testing.T) {
484 t.Parallel()
485
486 prompt := fantasy.Prompt{
487 {
488 Role: fantasy.MessageRoleTool,
489 Content: []fantasy.MessagePart{
490 fantasy.ToolResultPart{
491 ToolCallID: "text-tool",
492 Output: fantasy.ToolResultOutputContentText{
493 Text: "Hello world",
494 },
495 },
496 fantasy.ToolResultPart{
497 ToolCallID: "error-tool",
498 Output: fantasy.ToolResultOutputContentError{
499 Error: errors.New("Something went wrong"),
500 },
501 },
502 },
503 },
504 }
505
506 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
507
508 require.Empty(t, warnings)
509 require.Len(t, messages, 2)
510
511 // Check first tool message (text)
512 textToolMsg := messages[0].OfTool
513 require.NotNil(t, textToolMsg)
514 require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
515 require.Equal(t, "text-tool", textToolMsg.ToolCallID)
516
517 // Check second tool message (error)
518 errorToolMsg := messages[1].OfTool
519 require.NotNil(t, errorToolMsg)
520 require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
521 require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
522 })
523}
524
525func TestToOpenAiPrompt_AssistantMessages(t *testing.T) {
526 t.Parallel()
527
528 t.Run("should handle simple text assistant messages", func(t *testing.T) {
529 t.Parallel()
530
531 prompt := fantasy.Prompt{
532 {
533 Role: fantasy.MessageRoleAssistant,
534 Content: []fantasy.MessagePart{
535 fantasy.TextPart{Text: "Hello, how can I help you?"},
536 },
537 },
538 }
539
540 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
541
542 require.Empty(t, warnings)
543 require.Len(t, messages, 1)
544
545 assistantMsg := messages[0].OfAssistant
546 require.NotNil(t, assistantMsg)
547 require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
548 })
549
550 t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
551 t.Parallel()
552
553 inputArgs := map[string]any{"query": "test"}
554 inputJSON, _ := json.Marshal(inputArgs)
555
556 prompt := fantasy.Prompt{
557 {
558 Role: fantasy.MessageRoleAssistant,
559 Content: []fantasy.MessagePart{
560 fantasy.TextPart{Text: "Let me search for that."},
561 fantasy.ToolCallPart{
562 ToolCallID: "call-123",
563 ToolName: "search",
564 Input: string(inputJSON),
565 },
566 },
567 },
568 }
569
570 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
571
572 require.Empty(t, warnings)
573 require.Len(t, messages, 1)
574
575 assistantMsg := messages[0].OfAssistant
576 require.NotNil(t, assistantMsg)
577 require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
578 require.Len(t, assistantMsg.ToolCalls, 1)
579
580 toolCall := assistantMsg.ToolCalls[0].OfFunction
581 require.Equal(t, "call-123", toolCall.ID)
582 require.Equal(t, "search", toolCall.Function.Name)
583 require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
584 })
585}
586
// testPrompt is the minimal single-user-message prompt shared by the
// Generate tests below.
var testPrompt = fantasy.Prompt{
	{
		Role: fantasy.MessageRoleUser,
		Content: []fantasy.MessagePart{
			fantasy.TextPart{Text: "Hello"},
		},
	},
}
595
// testLogprobs is a canned OpenAI logprobs payload for the reply
// "Hello! How can I assist you today?". Every token carries its own logprob
// plus a single-entry top_logprobs list, mirroring the API wire shape.
var testLogprobs = func() map[string]any {
	entries := []struct {
		token   string
		logprob float64
	}{
		{"Hello", -0.0009994634},
		{"!", -0.13410144},
		{" How", -0.0009250381},
		{" can", -0.047709424},
		{" I", -0.000009014684},
		{" assist", -0.009125131},
		{" you", -0.0000066306106},
		{" today", -0.00011093382},
		{"?", -0.00004596782},
	}

	content := make([]map[string]any, 0, len(entries))
	for _, e := range entries {
		content = append(content, map[string]any{
			"token":   e.token,
			"logprob": e.logprob,
			"top_logprobs": []map[string]any{
				{"token": e.token, "logprob": e.logprob},
			},
		})
	}

	return map[string]any{"content": content}
}()
690
// mockServer wraps an httptest.Server that records every incoming request
// and answers each one with a canned JSON response.
type mockServer struct {
	server   *httptest.Server
	response map[string]any // JSON payload returned to every request
	calls    []mockCall     // one entry per request received, in order
}

// mockCall captures the parts of a recorded HTTP request that the tests
// assert on.
type mockCall struct {
	method  string
	path    string
	headers map[string]string // first value of each request header
	body    map[string]any    // request body decoded as JSON; nil when decoding fails
}
703
704func newMockServer() *mockServer {
705 ms := &mockServer{
706 calls: make([]mockCall, 0),
707 }
708
709 ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
710 // Record the call
711 call := mockCall{
712 method: r.Method,
713 path: r.URL.Path,
714 headers: make(map[string]string),
715 }
716
717 for k, v := range r.Header {
718 if len(v) > 0 {
719 call.headers[k] = v[0]
720 }
721 }
722
723 // Parse request body
724 if r.Body != nil {
725 var body map[string]any
726 json.NewDecoder(r.Body).Decode(&body)
727 call.body = body
728 }
729
730 ms.calls = append(ms.calls, call)
731
732 // Return mock response
733 w.Header().Set("Content-Type", "application/json")
734 json.NewEncoder(w).Encode(ms.response)
735 }))
736
737 return ms
738}
739
// close shuts down the underlying httptest server.
func (ms *mockServer) close() {
	ms.server.Close()
}
743
744func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
745 // Default values
746 response := map[string]any{
747 "id": "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
748 "object": "chat.completion",
749 "created": 1711115037,
750 "model": "gpt-3.5-turbo-0125",
751 "choices": []map[string]any{
752 {
753 "index": 0,
754 "message": map[string]any{
755 "role": "assistant",
756 "content": "",
757 },
758 "finish_reason": "stop",
759 },
760 },
761 "usage": map[string]any{
762 "prompt_tokens": 4,
763 "total_tokens": 34,
764 "completion_tokens": 30,
765 },
766 "system_fingerprint": "fp_3bc1b5746c",
767 }
768
769 // Override with provided options
770 for k, v := range opts {
771 switch k {
772 case "content":
773 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["content"] = v
774 case "tool_calls":
775 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["tool_calls"] = v
776 case "function_call":
777 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["function_call"] = v
778 case "annotations":
779 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["annotations"] = v
780 case "usage":
781 response["usage"] = v
782 case "finish_reason":
783 response["choices"].([]map[string]any)[0]["finish_reason"] = v
784 case "id":
785 response["id"] = v
786 case "created":
787 response["created"] = v
788 case "model":
789 response["model"] = v
790 case "logprobs":
791 if v != nil {
792 response["choices"].([]map[string]any)[0]["logprobs"] = v
793 }
794 }
795 }
796
797 ms.response = response
798}
799
800func TestDoGenerate(t *testing.T) {
801 t.Parallel()
802
803 t.Run("should extract text response", func(t *testing.T) {
804 t.Parallel()
805
806 server := newMockServer()
807 defer server.close()
808
809 server.prepareJSONResponse(map[string]any{
810 "content": "Hello, World!",
811 })
812
813 provider, err := New(
814 WithAPIKey("test-api-key"),
815 WithBaseURL(server.server.URL),
816 )
817 require.NoError(t, err)
818 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
819
820 result, err := model.Generate(context.Background(), fantasy.Call{
821 Prompt: testPrompt,
822 })
823
824 require.NoError(t, err)
825 require.Len(t, result.Content, 1)
826
827 textContent, ok := result.Content[0].(fantasy.TextContent)
828 require.True(t, ok)
829 require.Equal(t, "Hello, World!", textContent.Text)
830 })
831
832 t.Run("should extract usage", func(t *testing.T) {
833 t.Parallel()
834
835 server := newMockServer()
836 defer server.close()
837
838 server.prepareJSONResponse(map[string]any{
839 "usage": map[string]any{
840 "prompt_tokens": 20,
841 "total_tokens": 25,
842 "completion_tokens": 5,
843 },
844 })
845
846 provider, err := New(
847 WithAPIKey("test-api-key"),
848 WithBaseURL(server.server.URL),
849 )
850 require.NoError(t, err)
851 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
852
853 result, err := model.Generate(context.Background(), fantasy.Call{
854 Prompt: testPrompt,
855 })
856
857 require.NoError(t, err)
858 require.Equal(t, int64(20), result.Usage.InputTokens)
859 require.Equal(t, int64(5), result.Usage.OutputTokens)
860 require.Equal(t, int64(25), result.Usage.TotalTokens)
861 })
862
863 t.Run("should send request body", func(t *testing.T) {
864 t.Parallel()
865
866 server := newMockServer()
867 defer server.close()
868
869 server.prepareJSONResponse(map[string]any{})
870
871 provider, err := New(
872 WithAPIKey("test-api-key"),
873 WithBaseURL(server.server.URL),
874 )
875 require.NoError(t, err)
876 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
877
878 _, err = model.Generate(context.Background(), fantasy.Call{
879 Prompt: testPrompt,
880 })
881
882 require.NoError(t, err)
883 require.Len(t, server.calls, 1)
884
885 call := server.calls[0]
886 require.Equal(t, "POST", call.method)
887 require.Equal(t, "/chat/completions", call.path)
888 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
889
890 messages, ok := call.body["messages"].([]any)
891 require.True(t, ok)
892 require.Len(t, messages, 1)
893
894 message := messages[0].(map[string]any)
895 require.Equal(t, "user", message["role"])
896 require.Equal(t, "Hello", message["content"])
897 })
898
899 t.Run("should support partial usage", func(t *testing.T) {
900 t.Parallel()
901
902 server := newMockServer()
903 defer server.close()
904
905 server.prepareJSONResponse(map[string]any{
906 "usage": map[string]any{
907 "prompt_tokens": 20,
908 "total_tokens": 20,
909 },
910 })
911
912 provider, err := New(
913 WithAPIKey("test-api-key"),
914 WithBaseURL(server.server.URL),
915 )
916 require.NoError(t, err)
917 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
918
919 result, err := model.Generate(context.Background(), fantasy.Call{
920 Prompt: testPrompt,
921 })
922
923 require.NoError(t, err)
924 require.Equal(t, int64(20), result.Usage.InputTokens)
925 require.Equal(t, int64(0), result.Usage.OutputTokens)
926 require.Equal(t, int64(20), result.Usage.TotalTokens)
927 })
928
929 t.Run("should extract logprobs", func(t *testing.T) {
930 t.Parallel()
931
932 server := newMockServer()
933 defer server.close()
934
935 server.prepareJSONResponse(map[string]any{
936 "logprobs": testLogprobs,
937 })
938
939 provider, err := New(
940 WithAPIKey("test-api-key"),
941 WithBaseURL(server.server.URL),
942 )
943 require.NoError(t, err)
944 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
945
946 result, err := model.Generate(context.Background(), fantasy.Call{
947 Prompt: testPrompt,
948 ProviderOptions: NewProviderOptions(&ProviderOptions{
949 LogProbs: fantasy.Opt(true),
950 }),
951 })
952
953 require.NoError(t, err)
954 require.NotNil(t, result.ProviderMetadata)
955
956 openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
957 require.True(t, ok)
958
959 logprobs := openaiMeta.Logprobs
960 require.True(t, ok)
961 require.NotNil(t, logprobs)
962 })
963
964 t.Run("should extract finish reason", func(t *testing.T) {
965 t.Parallel()
966
967 server := newMockServer()
968 defer server.close()
969
970 server.prepareJSONResponse(map[string]any{
971 "finish_reason": "stop",
972 })
973
974 provider, err := New(
975 WithAPIKey("test-api-key"),
976 WithBaseURL(server.server.URL),
977 )
978 require.NoError(t, err)
979 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
980
981 result, err := model.Generate(context.Background(), fantasy.Call{
982 Prompt: testPrompt,
983 })
984
985 require.NoError(t, err)
986 require.Equal(t, fantasy.FinishReasonStop, result.FinishReason)
987 })
988
989 t.Run("should support unknown finish reason", func(t *testing.T) {
990 t.Parallel()
991
992 server := newMockServer()
993 defer server.close()
994
995 server.prepareJSONResponse(map[string]any{
996 "finish_reason": "eos",
997 })
998
999 provider, err := New(
1000 WithAPIKey("test-api-key"),
1001 WithBaseURL(server.server.URL),
1002 )
1003 require.NoError(t, err)
1004 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1005
1006 result, err := model.Generate(context.Background(), fantasy.Call{
1007 Prompt: testPrompt,
1008 })
1009
1010 require.NoError(t, err)
1011 require.Equal(t, fantasy.FinishReasonUnknown, result.FinishReason)
1012 })
1013
1014 t.Run("should pass the model and the messages", func(t *testing.T) {
1015 t.Parallel()
1016
1017 server := newMockServer()
1018 defer server.close()
1019
1020 server.prepareJSONResponse(map[string]any{
1021 "content": "",
1022 })
1023
1024 provider, err := New(
1025 WithAPIKey("test-api-key"),
1026 WithBaseURL(server.server.URL),
1027 )
1028 require.NoError(t, err)
1029 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1030
1031 _, err = model.Generate(context.Background(), fantasy.Call{
1032 Prompt: testPrompt,
1033 })
1034
1035 require.NoError(t, err)
1036 require.Len(t, server.calls, 1)
1037
1038 call := server.calls[0]
1039 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1040
1041 messages := call.body["messages"].([]any)
1042 require.Len(t, messages, 1)
1043
1044 message := messages[0].(map[string]any)
1045 require.Equal(t, "user", message["role"])
1046 require.Equal(t, "Hello", message["content"])
1047 })
1048
1049 t.Run("should pass settings", func(t *testing.T) {
1050 t.Parallel()
1051
1052 server := newMockServer()
1053 defer server.close()
1054
1055 server.prepareJSONResponse(map[string]any{})
1056
1057 provider, err := New(
1058 WithAPIKey("test-api-key"),
1059 WithBaseURL(server.server.URL),
1060 )
1061 require.NoError(t, err)
1062 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1063
1064 _, err = model.Generate(context.Background(), fantasy.Call{
1065 Prompt: testPrompt,
1066 ProviderOptions: NewProviderOptions(&ProviderOptions{
1067 LogitBias: map[string]int64{
1068 "50256": -100,
1069 },
1070 ParallelToolCalls: fantasy.Opt(false),
1071 User: fantasy.Opt("test-user-id"),
1072 }),
1073 })
1074
1075 require.NoError(t, err)
1076 require.Len(t, server.calls, 1)
1077
1078 call := server.calls[0]
1079 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1080
1081 messages := call.body["messages"].([]any)
1082 require.Len(t, messages, 1)
1083
1084 logitBias := call.body["logit_bias"].(map[string]any)
1085 require.Equal(t, float64(-100), logitBias["50256"])
1086 require.Equal(t, false, call.body["parallel_tool_calls"])
1087 require.Equal(t, "test-user-id", call.body["user"])
1088 })
1089
1090 t.Run("should pass reasoningEffort setting", func(t *testing.T) {
1091 t.Parallel()
1092
1093 server := newMockServer()
1094 defer server.close()
1095
1096 server.prepareJSONResponse(map[string]any{
1097 "content": "",
1098 })
1099
1100 provider, err := New(
1101 WithAPIKey("test-api-key"),
1102 WithBaseURL(server.server.URL),
1103 )
1104 require.NoError(t, err)
1105 model, _ := provider.LanguageModel(t.Context(), "o1-mini")
1106
1107 _, err = model.Generate(context.Background(), fantasy.Call{
1108 Prompt: testPrompt,
1109 ProviderOptions: NewProviderOptions(
1110 &ProviderOptions{
1111 ReasoningEffort: ReasoningEffortOption(ReasoningEffortLow),
1112 },
1113 ),
1114 })
1115
1116 require.NoError(t, err)
1117 require.Len(t, server.calls, 1)
1118
1119 call := server.calls[0]
1120 require.Equal(t, "o1-mini", call.body["model"])
1121 require.Equal(t, "low", call.body["reasoning_effort"])
1122
1123 messages := call.body["messages"].([]any)
1124 require.Len(t, messages, 1)
1125
1126 message := messages[0].(map[string]any)
1127 require.Equal(t, "user", message["role"])
1128 require.Equal(t, "Hello", message["content"])
1129 })
1130
1131 t.Run("should pass textVerbosity setting", func(t *testing.T) {
1132 t.Parallel()
1133
1134 server := newMockServer()
1135 defer server.close()
1136
1137 server.prepareJSONResponse(map[string]any{
1138 "content": "",
1139 })
1140
1141 provider, err := New(
1142 WithAPIKey("test-api-key"),
1143 WithBaseURL(server.server.URL),
1144 )
1145 require.NoError(t, err)
1146 model, _ := provider.LanguageModel(t.Context(), "gpt-4o")
1147
1148 _, err = model.Generate(context.Background(), fantasy.Call{
1149 Prompt: testPrompt,
1150 ProviderOptions: NewProviderOptions(&ProviderOptions{
1151 TextVerbosity: fantasy.Opt("low"),
1152 }),
1153 })
1154
1155 require.NoError(t, err)
1156 require.Len(t, server.calls, 1)
1157
1158 call := server.calls[0]
1159 require.Equal(t, "gpt-4o", call.body["model"])
1160 require.Equal(t, "low", call.body["verbosity"])
1161
1162 messages := call.body["messages"].([]any)
1163 require.Len(t, messages, 1)
1164
1165 message := messages[0].(map[string]any)
1166 require.Equal(t, "user", message["role"])
1167 require.Equal(t, "Hello", message["content"])
1168 })
1169
1170 t.Run("should pass tools and toolChoice", func(t *testing.T) {
1171 t.Parallel()
1172
1173 server := newMockServer()
1174 defer server.close()
1175
1176 server.prepareJSONResponse(map[string]any{
1177 "content": "",
1178 })
1179
1180 provider, err := New(
1181 WithAPIKey("test-api-key"),
1182 WithBaseURL(server.server.URL),
1183 )
1184 require.NoError(t, err)
1185 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1186
1187 _, err = model.Generate(context.Background(), fantasy.Call{
1188 Prompt: testPrompt,
1189 Tools: []fantasy.Tool{
1190 fantasy.FunctionTool{
1191 Name: "test-tool",
1192 InputSchema: map[string]any{
1193 "type": "object",
1194 "properties": map[string]any{
1195 "value": map[string]any{
1196 "type": "string",
1197 },
1198 },
1199 "required": []string{"value"},
1200 "additionalProperties": false,
1201 "$schema": "http://json-schema.org/draft-07/schema#",
1202 },
1203 },
1204 },
1205 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0],
1206 })
1207
1208 require.NoError(t, err)
1209 require.Len(t, server.calls, 1)
1210
1211 call := server.calls[0]
1212 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1213
1214 messages := call.body["messages"].([]any)
1215 require.Len(t, messages, 1)
1216
1217 tools := call.body["tools"].([]any)
1218 require.Len(t, tools, 1)
1219
1220 tool := tools[0].(map[string]any)
1221 require.Equal(t, "function", tool["type"])
1222
1223 function := tool["function"].(map[string]any)
1224 require.Equal(t, "test-tool", function["name"])
1225 require.Equal(t, false, function["strict"])
1226
1227 toolChoice := call.body["tool_choice"].(map[string]any)
1228 require.Equal(t, "function", toolChoice["type"])
1229
1230 toolChoiceFunction := toolChoice["function"].(map[string]any)
1231 require.Equal(t, "test-tool", toolChoiceFunction["name"])
1232 })
1233
1234 t.Run("should parse tool results", func(t *testing.T) {
1235 t.Parallel()
1236
1237 server := newMockServer()
1238 defer server.close()
1239
1240 server.prepareJSONResponse(map[string]any{
1241 "tool_calls": []map[string]any{
1242 {
1243 "id": "call_O17Uplv4lJvD6DVdIvFFeRMw",
1244 "type": "function",
1245 "function": map[string]any{
1246 "name": "test-tool",
1247 "arguments": `{"value":"Spark"}`,
1248 },
1249 },
1250 },
1251 })
1252
1253 provider, err := New(
1254 WithAPIKey("test-api-key"),
1255 WithBaseURL(server.server.URL),
1256 )
1257 require.NoError(t, err)
1258 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1259
1260 result, err := model.Generate(context.Background(), fantasy.Call{
1261 Prompt: testPrompt,
1262 Tools: []fantasy.Tool{
1263 fantasy.FunctionTool{
1264 Name: "test-tool",
1265 InputSchema: map[string]any{
1266 "type": "object",
1267 "properties": map[string]any{
1268 "value": map[string]any{
1269 "type": "string",
1270 },
1271 },
1272 "required": []string{"value"},
1273 "additionalProperties": false,
1274 "$schema": "http://json-schema.org/draft-07/schema#",
1275 },
1276 },
1277 },
1278 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0],
1279 })
1280
1281 require.NoError(t, err)
1282 require.Len(t, result.Content, 1)
1283
1284 toolCall, ok := result.Content[0].(fantasy.ToolCallContent)
1285 require.True(t, ok)
1286 require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
1287 require.Equal(t, "test-tool", toolCall.ToolName)
1288 require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
1289 })
1290
1291 t.Run("should handle ToolChoiceRequired", func(t *testing.T) {
1292 t.Parallel()
1293
1294 server := newMockServer()
1295 defer server.close()
1296
1297 server.prepareJSONResponse(map[string]any{
1298 "content": "",
1299 })
1300
1301 provider, err := New(
1302 WithAPIKey("test-api-key"),
1303 WithBaseURL(server.server.URL),
1304 )
1305 require.NoError(t, err)
1306 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1307
1308 _, err = model.Generate(context.Background(), fantasy.Call{
1309 Prompt: testPrompt,
1310 Tools: []fantasy.Tool{
1311 fantasy.FunctionTool{
1312 Name: "test-tool",
1313 InputSchema: map[string]any{
1314 "type": "object",
1315 "properties": map[string]any{
1316 "value": map[string]any{
1317 "type": "string",
1318 },
1319 },
1320 "required": []string{"value"},
1321 "additionalProperties": false,
1322 "$schema": "http://json-schema.org/draft-07/schema#",
1323 },
1324 },
1325 },
1326 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoiceRequired}[0],
1327 })
1328
1329 require.NoError(t, err)
1330 require.Len(t, server.calls, 1)
1331
1332 call := server.calls[0]
1333 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1334
1335 // Verify tool is present
1336 tools := call.body["tools"].([]any)
1337 require.Len(t, tools, 1)
1338
1339 tool := tools[0].(map[string]any)
1340 require.Equal(t, "function", tool["type"])
1341
1342 function := tool["function"].(map[string]any)
1343 require.Equal(t, "test-tool", function["name"])
1344
1345 // Verify tool_choice is set to "required" (not a function name)
1346 toolChoice := call.body["tool_choice"]
1347 require.Equal(t, "required", toolChoice)
1348 })
1349
1350 t.Run("should parse annotations/citations", func(t *testing.T) {
1351 t.Parallel()
1352
1353 server := newMockServer()
1354 defer server.close()
1355
1356 server.prepareJSONResponse(map[string]any{
1357 "content": "Based on the search results [doc1], I found information.",
1358 "annotations": []map[string]any{
1359 {
1360 "type": "url_citation",
1361 "url_citation": map[string]any{
1362 "start_index": 24,
1363 "end_index": 29,
1364 "url": "https://example.com/doc1.pdf",
1365 "title": "Document 1",
1366 },
1367 },
1368 },
1369 })
1370
1371 provider, err := New(
1372 WithAPIKey("test-api-key"),
1373 WithBaseURL(server.server.URL),
1374 )
1375 require.NoError(t, err)
1376 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1377
1378 result, err := model.Generate(context.Background(), fantasy.Call{
1379 Prompt: testPrompt,
1380 })
1381
1382 require.NoError(t, err)
1383 require.Len(t, result.Content, 2)
1384
1385 textContent, ok := result.Content[0].(fantasy.TextContent)
1386 require.True(t, ok)
1387 require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)
1388
1389 sourceContent, ok := result.Content[1].(fantasy.SourceContent)
1390 require.True(t, ok)
1391 require.Equal(t, fantasy.SourceTypeURL, sourceContent.SourceType)
1392 require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
1393 require.Equal(t, "Document 1", sourceContent.Title)
1394 require.NotEmpty(t, sourceContent.ID)
1395 })
1396
1397 t.Run("should return cached_tokens in prompt_details_tokens", func(t *testing.T) {
1398 t.Parallel()
1399
1400 server := newMockServer()
1401 defer server.close()
1402
1403 server.prepareJSONResponse(map[string]any{
1404 "usage": map[string]any{
1405 "prompt_tokens": 15,
1406 "completion_tokens": 20,
1407 "total_tokens": 35,
1408 "prompt_tokens_details": map[string]any{
1409 "cached_tokens": 1152,
1410 },
1411 },
1412 })
1413
1414 provider, err := New(
1415 WithAPIKey("test-api-key"),
1416 WithBaseURL(server.server.URL),
1417 )
1418 require.NoError(t, err)
1419 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1420
1421 result, err := model.Generate(context.Background(), fantasy.Call{
1422 Prompt: testPrompt,
1423 })
1424
1425 require.NoError(t, err)
1426 require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
1427 require.Equal(t, int64(15), result.Usage.InputTokens)
1428 require.Equal(t, int64(20), result.Usage.OutputTokens)
1429 require.Equal(t, int64(35), result.Usage.TotalTokens)
1430 })
1431
1432 t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
1433 t.Parallel()
1434
1435 server := newMockServer()
1436 defer server.close()
1437
1438 server.prepareJSONResponse(map[string]any{
1439 "usage": map[string]any{
1440 "prompt_tokens": 15,
1441 "completion_tokens": 20,
1442 "total_tokens": 35,
1443 "completion_tokens_details": map[string]any{
1444 "accepted_prediction_tokens": 123,
1445 "rejected_prediction_tokens": 456,
1446 },
1447 },
1448 })
1449
1450 provider, err := New(
1451 WithAPIKey("test-api-key"),
1452 WithBaseURL(server.server.URL),
1453 )
1454 require.NoError(t, err)
1455 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1456
1457 result, err := model.Generate(context.Background(), fantasy.Call{
1458 Prompt: testPrompt,
1459 })
1460
1461 require.NoError(t, err)
1462 require.NotNil(t, result.ProviderMetadata)
1463
1464 openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
1465
1466 require.True(t, ok)
1467 require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
1468 require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
1469 })
1470
1471 t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
1472 t.Parallel()
1473
1474 server := newMockServer()
1475 defer server.close()
1476
1477 server.prepareJSONResponse(map[string]any{})
1478
1479 provider, err := New(
1480 WithAPIKey("test-api-key"),
1481 WithBaseURL(server.server.URL),
1482 )
1483 require.NoError(t, err)
1484 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1485
1486 result, err := model.Generate(context.Background(), fantasy.Call{
1487 Prompt: testPrompt,
1488 Temperature: &[]float64{0.5}[0],
1489 TopP: &[]float64{0.7}[0],
1490 FrequencyPenalty: &[]float64{0.2}[0],
1491 PresencePenalty: &[]float64{0.3}[0],
1492 })
1493
1494 require.NoError(t, err)
1495 require.Len(t, server.calls, 1)
1496
1497 call := server.calls[0]
1498 require.Equal(t, "o1-preview", call.body["model"])
1499
1500 messages := call.body["messages"].([]any)
1501 require.Len(t, messages, 1)
1502
1503 message := messages[0].(map[string]any)
1504 require.Equal(t, "user", message["role"])
1505 require.Equal(t, "Hello", message["content"])
1506
1507 // These should not be present
1508 require.Nil(t, call.body["temperature"])
1509 require.Nil(t, call.body["top_p"])
1510 require.Nil(t, call.body["frequency_penalty"])
1511 require.Nil(t, call.body["presence_penalty"])
1512
1513 // Should have warnings
1514 require.Len(t, result.Warnings, 4)
1515 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1516 require.Equal(t, "temperature", result.Warnings[0].Setting)
1517 require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
1518 })
1519
1520 t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
1521 t.Parallel()
1522
1523 server := newMockServer()
1524 defer server.close()
1525
1526 server.prepareJSONResponse(map[string]any{})
1527
1528 provider, err := New(
1529 WithAPIKey("test-api-key"),
1530 WithBaseURL(server.server.URL),
1531 )
1532 require.NoError(t, err)
1533 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1534
1535 _, err = model.Generate(context.Background(), fantasy.Call{
1536 Prompt: testPrompt,
1537 MaxOutputTokens: &[]int64{1000}[0],
1538 })
1539
1540 require.NoError(t, err)
1541 require.Len(t, server.calls, 1)
1542
1543 call := server.calls[0]
1544 require.Equal(t, "o1-preview", call.body["model"])
1545 require.Equal(t, float64(1000), call.body["max_completion_tokens"])
1546 require.Nil(t, call.body["max_tokens"])
1547
1548 messages := call.body["messages"].([]any)
1549 require.Len(t, messages, 1)
1550
1551 message := messages[0].(map[string]any)
1552 require.Equal(t, "user", message["role"])
1553 require.Equal(t, "Hello", message["content"])
1554 })
1555
1556 t.Run("should return reasoning tokens", func(t *testing.T) {
1557 t.Parallel()
1558
1559 server := newMockServer()
1560 defer server.close()
1561
1562 server.prepareJSONResponse(map[string]any{
1563 "usage": map[string]any{
1564 "prompt_tokens": 15,
1565 "completion_tokens": 20,
1566 "total_tokens": 35,
1567 "completion_tokens_details": map[string]any{
1568 "reasoning_tokens": 10,
1569 },
1570 },
1571 })
1572
1573 provider, err := New(
1574 WithAPIKey("test-api-key"),
1575 WithBaseURL(server.server.URL),
1576 )
1577 require.NoError(t, err)
1578 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1579
1580 result, err := model.Generate(context.Background(), fantasy.Call{
1581 Prompt: testPrompt,
1582 })
1583
1584 require.NoError(t, err)
1585 require.Equal(t, int64(15), result.Usage.InputTokens)
1586 require.Equal(t, int64(20), result.Usage.OutputTokens)
1587 require.Equal(t, int64(35), result.Usage.TotalTokens)
1588 require.Equal(t, int64(10), result.Usage.ReasoningTokens)
1589 })
1590
1591 t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
1592 t.Parallel()
1593
1594 server := newMockServer()
1595 defer server.close()
1596
1597 server.prepareJSONResponse(map[string]any{
1598 "model": "o1-preview",
1599 })
1600
1601 provider, err := New(
1602 WithAPIKey("test-api-key"),
1603 WithBaseURL(server.server.URL),
1604 )
1605 require.NoError(t, err)
1606 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1607
1608 _, err = model.Generate(context.Background(), fantasy.Call{
1609 Prompt: testPrompt,
1610 ProviderOptions: NewProviderOptions(&ProviderOptions{
1611 MaxCompletionTokens: fantasy.Opt(int64(255)),
1612 }),
1613 })
1614
1615 require.NoError(t, err)
1616 require.Len(t, server.calls, 1)
1617
1618 call := server.calls[0]
1619 require.Equal(t, "o1-preview", call.body["model"])
1620 require.Equal(t, float64(255), call.body["max_completion_tokens"])
1621
1622 messages := call.body["messages"].([]any)
1623 require.Len(t, messages, 1)
1624
1625 message := messages[0].(map[string]any)
1626 require.Equal(t, "user", message["role"])
1627 require.Equal(t, "Hello", message["content"])
1628 })
1629
1630 t.Run("should send prediction extension setting", func(t *testing.T) {
1631 t.Parallel()
1632
1633 server := newMockServer()
1634 defer server.close()
1635
1636 server.prepareJSONResponse(map[string]any{
1637 "content": "",
1638 })
1639
1640 provider, err := New(
1641 WithAPIKey("test-api-key"),
1642 WithBaseURL(server.server.URL),
1643 )
1644 require.NoError(t, err)
1645 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1646
1647 _, err = model.Generate(context.Background(), fantasy.Call{
1648 Prompt: testPrompt,
1649 ProviderOptions: NewProviderOptions(&ProviderOptions{
1650 Prediction: map[string]any{
1651 "type": "content",
1652 "content": "Hello, World!",
1653 },
1654 }),
1655 })
1656
1657 require.NoError(t, err)
1658 require.Len(t, server.calls, 1)
1659
1660 call := server.calls[0]
1661 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1662
1663 prediction := call.body["prediction"].(map[string]any)
1664 require.Equal(t, "content", prediction["type"])
1665 require.Equal(t, "Hello, World!", prediction["content"])
1666
1667 messages := call.body["messages"].([]any)
1668 require.Len(t, messages, 1)
1669
1670 message := messages[0].(map[string]any)
1671 require.Equal(t, "user", message["role"])
1672 require.Equal(t, "Hello", message["content"])
1673 })
1674
1675 t.Run("should send store extension setting", func(t *testing.T) {
1676 t.Parallel()
1677
1678 server := newMockServer()
1679 defer server.close()
1680
1681 server.prepareJSONResponse(map[string]any{
1682 "content": "",
1683 })
1684
1685 provider, err := New(
1686 WithAPIKey("test-api-key"),
1687 WithBaseURL(server.server.URL),
1688 )
1689 require.NoError(t, err)
1690 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1691
1692 _, err = model.Generate(context.Background(), fantasy.Call{
1693 Prompt: testPrompt,
1694 ProviderOptions: NewProviderOptions(&ProviderOptions{
1695 Store: fantasy.Opt(true),
1696 }),
1697 })
1698
1699 require.NoError(t, err)
1700 require.Len(t, server.calls, 1)
1701
1702 call := server.calls[0]
1703 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1704 require.Equal(t, true, call.body["store"])
1705
1706 messages := call.body["messages"].([]any)
1707 require.Len(t, messages, 1)
1708
1709 message := messages[0].(map[string]any)
1710 require.Equal(t, "user", message["role"])
1711 require.Equal(t, "Hello", message["content"])
1712 })
1713
1714 t.Run("should send metadata extension values", func(t *testing.T) {
1715 t.Parallel()
1716
1717 server := newMockServer()
1718 defer server.close()
1719
1720 server.prepareJSONResponse(map[string]any{
1721 "content": "",
1722 })
1723
1724 provider, err := New(
1725 WithAPIKey("test-api-key"),
1726 WithBaseURL(server.server.URL),
1727 )
1728 require.NoError(t, err)
1729 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1730
1731 _, err = model.Generate(context.Background(), fantasy.Call{
1732 Prompt: testPrompt,
1733 ProviderOptions: NewProviderOptions(&ProviderOptions{
1734 Metadata: map[string]any{
1735 "custom": "value",
1736 },
1737 }),
1738 })
1739
1740 require.NoError(t, err)
1741 require.Len(t, server.calls, 1)
1742
1743 call := server.calls[0]
1744 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1745
1746 metadata := call.body["metadata"].(map[string]any)
1747 require.Equal(t, "value", metadata["custom"])
1748
1749 messages := call.body["messages"].([]any)
1750 require.Len(t, messages, 1)
1751
1752 message := messages[0].(map[string]any)
1753 require.Equal(t, "user", message["role"])
1754 require.Equal(t, "Hello", message["content"])
1755 })
1756
1757 t.Run("should send promptCacheKey extension value", func(t *testing.T) {
1758 t.Parallel()
1759
1760 server := newMockServer()
1761 defer server.close()
1762
1763 server.prepareJSONResponse(map[string]any{
1764 "content": "",
1765 })
1766
1767 provider, err := New(
1768 WithAPIKey("test-api-key"),
1769 WithBaseURL(server.server.URL),
1770 )
1771 require.NoError(t, err)
1772 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1773
1774 _, err = model.Generate(context.Background(), fantasy.Call{
1775 Prompt: testPrompt,
1776 ProviderOptions: NewProviderOptions(&ProviderOptions{
1777 PromptCacheKey: fantasy.Opt("test-cache-key-123"),
1778 }),
1779 })
1780
1781 require.NoError(t, err)
1782 require.Len(t, server.calls, 1)
1783
1784 call := server.calls[0]
1785 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1786 require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])
1787
1788 messages := call.body["messages"].([]any)
1789 require.Len(t, messages, 1)
1790
1791 message := messages[0].(map[string]any)
1792 require.Equal(t, "user", message["role"])
1793 require.Equal(t, "Hello", message["content"])
1794 })
1795
1796 t.Run("should send safety_identifier extension value", func(t *testing.T) {
1797 t.Parallel()
1798
1799 server := newMockServer()
1800 defer server.close()
1801
1802 server.prepareJSONResponse(map[string]any{
1803 "content": "",
1804 })
1805
1806 provider, err := New(
1807 WithAPIKey("test-api-key"),
1808 WithBaseURL(server.server.URL),
1809 )
1810 require.NoError(t, err)
1811 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1812
1813 _, err = model.Generate(context.Background(), fantasy.Call{
1814 Prompt: testPrompt,
1815 ProviderOptions: NewProviderOptions(&ProviderOptions{
1816 SafetyIdentifier: fantasy.Opt("test-safety-identifier-123"),
1817 }),
1818 })
1819
1820 require.NoError(t, err)
1821 require.Len(t, server.calls, 1)
1822
1823 call := server.calls[0]
1824 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1825 require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])
1826
1827 messages := call.body["messages"].([]any)
1828 require.Len(t, messages, 1)
1829
1830 message := messages[0].(map[string]any)
1831 require.Equal(t, "user", message["role"])
1832 require.Equal(t, "Hello", message["content"])
1833 })
1834
1835 t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
1836 t.Parallel()
1837
1838 server := newMockServer()
1839 defer server.close()
1840
1841 server.prepareJSONResponse(map[string]any{})
1842
1843 provider, err := New(
1844 WithAPIKey("test-api-key"),
1845 WithBaseURL(server.server.URL),
1846 )
1847 require.NoError(t, err)
1848 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-search-preview")
1849
1850 result, err := model.Generate(context.Background(), fantasy.Call{
1851 Prompt: testPrompt,
1852 Temperature: &[]float64{0.7}[0],
1853 })
1854
1855 require.NoError(t, err)
1856 require.Len(t, server.calls, 1)
1857
1858 call := server.calls[0]
1859 require.Equal(t, "gpt-4o-search-preview", call.body["model"])
1860 require.Nil(t, call.body["temperature"])
1861
1862 require.Len(t, result.Warnings, 1)
1863 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1864 require.Equal(t, "temperature", result.Warnings[0].Setting)
1865 require.Contains(t, result.Warnings[0].Details, "search preview models")
1866 })
1867
1868 t.Run("should send ServiceTier flex processing setting", func(t *testing.T) {
1869 t.Parallel()
1870
1871 server := newMockServer()
1872 defer server.close()
1873
1874 server.prepareJSONResponse(map[string]any{
1875 "content": "",
1876 })
1877
1878 provider, err := New(
1879 WithAPIKey("test-api-key"),
1880 WithBaseURL(server.server.URL),
1881 )
1882 require.NoError(t, err)
1883 model, _ := provider.LanguageModel(t.Context(), "o3-mini")
1884
1885 _, err = model.Generate(context.Background(), fantasy.Call{
1886 Prompt: testPrompt,
1887 ProviderOptions: NewProviderOptions(&ProviderOptions{
1888 ServiceTier: fantasy.Opt("flex"),
1889 }),
1890 })
1891
1892 require.NoError(t, err)
1893 require.Len(t, server.calls, 1)
1894
1895 call := server.calls[0]
1896 require.Equal(t, "o3-mini", call.body["model"])
1897 require.Equal(t, "flex", call.body["service_tier"])
1898
1899 messages := call.body["messages"].([]any)
1900 require.Len(t, messages, 1)
1901
1902 message := messages[0].(map[string]any)
1903 require.Equal(t, "user", message["role"])
1904 require.Equal(t, "Hello", message["content"])
1905 })
1906
1907 t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
1908 t.Parallel()
1909
1910 server := newMockServer()
1911 defer server.close()
1912
1913 server.prepareJSONResponse(map[string]any{})
1914
1915 provider, err := New(
1916 WithAPIKey("test-api-key"),
1917 WithBaseURL(server.server.URL),
1918 )
1919 require.NoError(t, err)
1920 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1921
1922 result, err := model.Generate(context.Background(), fantasy.Call{
1923 Prompt: testPrompt,
1924 ProviderOptions: NewProviderOptions(&ProviderOptions{
1925 ServiceTier: fantasy.Opt("flex"),
1926 }),
1927 })
1928
1929 require.NoError(t, err)
1930 require.Len(t, server.calls, 1)
1931
1932 call := server.calls[0]
1933 require.Nil(t, call.body["service_tier"])
1934
1935 require.Len(t, result.Warnings, 1)
1936 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1937 require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
1938 require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
1939 })
1940
1941 t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
1942 t.Parallel()
1943
1944 server := newMockServer()
1945 defer server.close()
1946
1947 server.prepareJSONResponse(map[string]any{})
1948
1949 provider, err := New(
1950 WithAPIKey("test-api-key"),
1951 WithBaseURL(server.server.URL),
1952 )
1953 require.NoError(t, err)
1954 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1955
1956 _, err = model.Generate(context.Background(), fantasy.Call{
1957 Prompt: testPrompt,
1958 ProviderOptions: NewProviderOptions(&ProviderOptions{
1959 ServiceTier: fantasy.Opt("priority"),
1960 }),
1961 })
1962
1963 require.NoError(t, err)
1964 require.Len(t, server.calls, 1)
1965
1966 call := server.calls[0]
1967 require.Equal(t, "gpt-4o-mini", call.body["model"])
1968 require.Equal(t, "priority", call.body["service_tier"])
1969
1970 messages := call.body["messages"].([]any)
1971 require.Len(t, messages, 1)
1972
1973 message := messages[0].(map[string]any)
1974 require.Equal(t, "user", message["role"])
1975 require.Equal(t, "Hello", message["content"])
1976 })
1977
1978 t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
1979 t.Parallel()
1980
1981 server := newMockServer()
1982 defer server.close()
1983
1984 server.prepareJSONResponse(map[string]any{})
1985
1986 provider, err := New(
1987 WithAPIKey("test-api-key"),
1988 WithBaseURL(server.server.URL),
1989 )
1990 require.NoError(t, err)
1991 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1992
1993 result, err := model.Generate(context.Background(), fantasy.Call{
1994 Prompt: testPrompt,
1995 ProviderOptions: NewProviderOptions(&ProviderOptions{
1996 ServiceTier: fantasy.Opt("priority"),
1997 }),
1998 })
1999
2000 require.NoError(t, err)
2001 require.Len(t, server.calls, 1)
2002
2003 call := server.calls[0]
2004 require.Nil(t, call.body["service_tier"])
2005
2006 require.Len(t, result.Warnings, 1)
2007 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
2008 require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
2009 require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
2010 })
2011}
2012
// streamingMockServer simulates an SSE (text/event-stream) endpoint for
// streamed chat completions. The prepared chunks are replayed verbatim to
// every request, and each incoming request is recorded for later assertions.
type streamingMockServer struct {
	server *httptest.Server // underlying test HTTP server
	chunks []string         // SSE lines to replay; "HEADER:k:v" entries become response headers
	calls  []mockCall       // one entry per request received
}
2018
// newStreamingMockServer starts an httptest server that records each request
// (method, path, first header values, decoded JSON body) and then replays
// sms.chunks as a text/event-stream response. Chunks prefixed with "HEADER:"
// are applied as response headers instead of being written to the body.
func newStreamingMockServer() *streamingMockServer {
	sms := &streamingMockServer{
		calls: make([]mockCall, 0),
	}

	sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record the call
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}

		// Keep only the first value of each request header.
		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Parse request body
		if r.Body != nil {
			var body map[string]any
			// Decode errors are deliberately ignored: a malformed body
			// simply leaves call.body nil for the test to observe.
			json.NewDecoder(r.Body).Decode(&body)
			call.body = body
		}

		sms.calls = append(sms.calls, call)

		// Set streaming headers
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")

		// Add custom headers if any
		// (must be set before WriteHeader, after which headers are frozen).
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				// Strip the 7-byte "HEADER:" prefix, then split into key:value.
				parts := strings.SplitN(chunk[7:], ":", 2)
				if len(parts) == 2 {
					w.Header().Set(parts[0], parts[1])
				}
				continue
			}
		}

		w.WriteHeader(http.StatusOK)

		// Write chunks
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				continue
			}
			w.Write([]byte(chunk))
			// Flush after each chunk so the client observes true streaming.
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
		}
	}))

	return sms
}
2079
// close shuts down the underlying test HTTP server.
func (sms *streamingMockServer) close() {
	sms.server.Close()
}
2083
// prepareStreamResponse builds the SSE chunk sequence for a streamed chat
// completion and stores it in sms.chunks. Recognized opts keys (all
// optional): "content" ([]string, one text delta per element), "usage"
// (map[string]any), "logprobs" (map[string]any, attached to the finish
// chunk), "finish_reason" (string, default "stop"), "model" (string),
// "headers" (map[string]string, emitted as "HEADER:k:v" pseudo-chunks
// consumed by the server handler), and "annotations" ([]map[string]any,
// sent in an extra chunk after the last content chunk).
func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
	content := []string{}
	if c, ok := opts["content"].([]string); ok {
		content = c
	}

	// Default usage values, overridable via opts["usage"].
	usage := map[string]any{
		"prompt_tokens":     17,
		"total_tokens":      244,
		"completion_tokens": 227,
	}
	if u, ok := opts["usage"].(map[string]any); ok {
		usage = u
	}

	logprobs := map[string]any{}
	if l, ok := opts["logprobs"].(map[string]any); ok {
		logprobs = l
	}

	finishReason := "stop"
	if fr, ok := opts["finish_reason"].(string); ok {
		finishReason = fr
	}

	model := "gpt-3.5-turbo-0613"
	if m, ok := opts["model"].(string); ok {
		model = m
	}

	headers := map[string]string{}
	if h, ok := opts["headers"].(map[string]string); ok {
		headers = h
	}

	chunks := []string{}

	// Add custom headers
	// ("HEADER:" entries are interpreted by the handler, never sent as body).
	for k, v := range headers {
		chunks = append(chunks, "HEADER:"+k+":"+v)
	}

	// Initial chunk with role
	initialChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index": 0,
				"delta": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": nil,
			},
		},
	}
	initialData, _ := json.Marshal(initialChunk)
	chunks = append(chunks, "data: "+string(initialData)+"\n\n")

	// Content chunks
	// NOTE(review): content chunks use choice index 1 while the initial and
	// finish chunks use index 0 — presumably mirroring captured fixture
	// data; confirm this is intentional.
	for i, text := range content {
		contentChunk := map[string]any{
			"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
			"object":             "chat.completion.chunk",
			"created":            1702657020,
			"model":              model,
			"system_fingerprint": nil,
			"choices": []map[string]any{
				{
					"index": 1,
					"delta": map[string]any{
						"content": text,
					},
					"finish_reason": nil,
				},
			},
		}
		contentData, _ := json.Marshal(contentChunk)
		chunks = append(chunks, "data: "+string(contentData)+"\n\n")

		// Add annotations if this is the last content chunk and we have annotations
		if i == len(content)-1 {
			if annotations, ok := opts["annotations"].([]map[string]any); ok {
				annotationChunk := map[string]any{
					"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
					"object":             "chat.completion.chunk",
					"created":            1702657020,
					"model":              model,
					"system_fingerprint": nil,
					"choices": []map[string]any{
						{
							"index": 1,
							"delta": map[string]any{
								"annotations": annotations,
							},
							"finish_reason": nil,
						},
					},
				}
				annotationData, _ := json.Marshal(annotationChunk)
				chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
			}
		}
	}

	// Finish chunk
	finishChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index":         0,
				"delta":         map[string]any{},
				"finish_reason": finishReason,
			},
		},
	}

	// Logprobs, when provided, ride on the finish chunk's first choice.
	if len(logprobs) > 0 {
		finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
	}

	finishData, _ := json.Marshal(finishChunk)
	chunks = append(chunks, "data: "+string(finishData)+"\n\n")

	// Usage chunk
	// (empty choices, usage only — matches the stream_options usage frame).
	usageChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": "fp_3bc1b5746c",
		"choices":            []map[string]any{},
		"usage":              usage,
	}
	usageData, _ := json.Marshal(usageChunk)
	chunks = append(chunks, "data: "+string(usageData)+"\n\n")

	// Done
	chunks = append(chunks, "data: [DONE]\n\n")

	sms.chunks = chunks
}
2234
// prepareToolStreamResponse loads a fixed SSE fixture in which a single tool
// call ("test-tool") is streamed: the first chunk carries the call ID and
// name, subsequent chunks stream the JSON arguments ({"value":"Sparkle Day"})
// fragment by fragment, followed by a tool_calls finish chunk, a usage
// chunk, and the [DONE] terminator.
func (sms *streamingMockServer) prepareToolStreamResponse() {
	chunks := []string{
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}
2251
2252func (sms *streamingMockServer) prepareErrorStreamResponse() {
2253 chunks := []string{
2254 `data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
2255 "data: [DONE]\n\n",
2256 }
2257 sms.chunks = chunks
2258}
2259
2260func collectStreamParts(stream fantasy.StreamResponse) ([]fantasy.StreamPart, error) {
2261 var parts []fantasy.StreamPart
2262 for part := range stream {
2263 parts = append(parts, part)
2264 if part.Type == fantasy.StreamPartTypeError {
2265 break
2266 }
2267 if part.Type == fantasy.StreamPartTypeFinish {
2268 break
2269 }
2270 }
2271 return parts, nil
2272}
2273
// TestDoStream exercises the streaming (SSE) path of the language model
// against a mock server: text deltas, tool-call deltas, URL citations,
// error events, the outgoing request body, usage accounting (cached,
// prediction, and reasoning tokens), and provider options such as store,
// metadata, and service_tier.
func TestDoStream(t *testing.T) {
	t.Parallel()

	// Text content should surface as text-start, ordered deltas, text-end,
	// then a finish part carrying usage numbers.
	t.Run("should stream text deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello", ", ", "World!"},
			"finish_reason": "stop",
			"usage": map[string]any{
				"prompt_tokens": 17,
				"total_tokens": 244,
				"completion_tokens": 227,
			},
			"logprobs": testLogprobs,
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Verify stream structure
		require.True(t, len(parts) >= 4) // text-start, deltas, text-end, finish

		// Find text parts
		textStart, textEnd, finish := -1, -1, -1
		var deltas []string

		for i, part := range parts {
			switch part.Type {
			case fantasy.StreamPartTypeTextStart:
				textStart = i
			case fantasy.StreamPartTypeTextDelta:
				deltas = append(deltas, part.Delta)
			case fantasy.StreamPartTypeTextEnd:
				textEnd = i
			case fantasy.StreamPartTypeFinish:
				finish = i
			}
		}

		require.NotEqual(t, -1, textStart)
		require.NotEqual(t, -1, textEnd)
		require.NotEqual(t, -1, finish)
		require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)

		// Check finish part
		finishPart := parts[finish]
		require.Equal(t, fantasy.FinishReasonStop, finishPart.FinishReason)
		require.Equal(t, int64(17), finishPart.Usage.InputTokens)
		require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
	})

	// Tool-call argument fragments should be re-emitted as input deltas and
	// assembled into one complete tool-call part at the end.
	t.Run("should stream tool deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareToolStreamResponse()

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			Tools: []fantasy.Tool{
				fantasy.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required": []string{"value"},
						"additionalProperties": false,
						"$schema": "http://json-schema.org/draft-07/schema#",
					},
				},
			},
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find tool-related parts
		toolInputStart, toolInputEnd, toolCall := -1, -1, -1
		var toolDeltas []string

		for i, part := range parts {
			switch part.Type {
			case fantasy.StreamPartTypeToolInputStart:
				toolInputStart = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
			case fantasy.StreamPartTypeToolInputDelta:
				toolDeltas = append(toolDeltas, part.Delta)
			case fantasy.StreamPartTypeToolInputEnd:
				toolInputEnd = i
			case fantasy.StreamPartTypeToolCall:
				toolCall = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
				require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
			}
		}

		require.NotEqual(t, -1, toolInputStart)
		require.NotEqual(t, -1, toolInputEnd)
		require.NotEqual(t, -1, toolCall)

		// Verify tool deltas combine to form the complete input
		fullInput := ""
		for _, delta := range toolDeltas {
			fullInput += delta
		}
		require.Equal(t, `{"value":"Sparkle Day"}`, fullInput)
	})

	// url_citation annotations should be converted into source parts.
	t.Run("should stream annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Based on search results"},
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index": 29,
						"url": "https://example.com/doc1.pdf",
						"title": "Document 1",
					},
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find source part
		var sourcePart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeSource {
				sourcePart = &part
				break
			}
		}

		require.NotNil(t, sourcePart)
		require.Equal(t, fantasy.SourceTypeURL, sourcePart.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
		require.Equal(t, "Document 1", sourcePart.Title)
		require.NotEmpty(t, sourcePart.ID)
	})

	// An SSE error event should surface as a stream part with a non-nil error.
	t.Run("should handle error stream parts", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareErrorStreamResponse()

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Should have error and finish parts
		require.True(t, len(parts) >= 1)

		// Find error part
		var errorPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeError {
				errorPart = &part
				break
			}
		}

		require.NotNil(t, errorPart)
		require.NotNil(t, errorPart.Error)
	})

	// Verify the exact wire request: endpoint, stream flags, and messages.
	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	// prompt_tokens_details.cached_tokens should map to Usage.CacheReadTokens.
	t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens": 15,
				"completion_tokens": 20,
				"total_tokens": 35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
	})

	// Prediction-token counts should be exposed via the openai provider metadata.
	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens": 15,
				"completion_tokens": 20,
				"total_tokens": 35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.NotNil(t, finishPart.ProviderMetadata)

		openaiMeta, ok := finishPart.ProviderMetadata["openai"].(*ProviderMetadata)
		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
		require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
	})

	// The Store provider option should be forwarded as the "store" body field.
	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Store: fantasy.Opt(true),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])
		require.Equal(t, true, call.body["store"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	// Custom metadata from provider options should pass through verbatim.
	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Metadata: map[string]any{
					"custom": "value",
				},
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	// service_tier=flex should be accepted for models that support it.
	t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o3-mini")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("flex"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	// service_tier=priority should be forwarded unchanged as well.
	t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("priority"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	// Reasoning models still stream plain text; empty deltas must be skipped.
	t.Run("should stream text delta for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model": "o1-preview",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find text parts
		var textDeltas []string
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeTextDelta {
				textDeltas = append(textDeltas, part.Delta)
			}
		}

		// Should contain the text content (without empty delta)
		require.Equal(t, []string{"Hello, World!"}, textDeltas)
	})

	// completion_tokens_details.reasoning_tokens should map to Usage.ReasoningTokens.
	t.Run("should send reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model": "o1-preview",
			"usage": map[string]any{
				"prompt_tokens": 15,
				"completion_tokens": 20,
				"total_tokens": 35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
		require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
	})
}
2920
// TestDefaultToPrompt_DropsEmptyMessages verifies that DefaultToPrompt drops
// messages with no convertible content (emitting a warning for each) while
// keeping messages that carry text, images, tool calls, or tool results.
func TestDefaultToPrompt_DropsEmptyMessages(t *testing.T) {
	t.Parallel()

	// An assistant message with zero parts is dropped with a warning.
	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	t.Run("should keep assistant messages with text content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	// A tool call alone counts as content, so the assistant message survives.
	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "What's the weather?"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName: "get_weather",
						Input: `{"location":"NYC"}`,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	// An unsupported file type is skipped, leaving the user message empty,
	// which then triggers the empty-message drop (hence two warnings).
	t.Run("should drop user messages without visible content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data: []byte("not supported"),
						MediaType: "application/unknown",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Empty(t, messages)
		require.Len(t, warnings, 2) // One for unsupported type, one for empty message
		require.Contains(t, warnings[1].Message, "dropping empty user message")
	})

	t.Run("should keep user messages with image content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data: []byte{0x01, 0x02, 0x03},
						MediaType: "image/png",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_123",
						Output: fantasy.ToolResultOutputContentText{Text: "done"},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	// Error outputs are still tool results and must not be dropped.
	t.Run("should keep user messages with tool error results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_456",
						Output: fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})
}
3085
// TestResponsesToPrompt_DropsEmptyMessages mirrors the DefaultToPrompt
// empty-message tests for the Responses API converter (toResponsesPrompt):
// contentless messages are dropped with warnings; messages carrying text,
// images, tool calls, or tool results are preserved.
func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) {
	t.Parallel()

	// An assistant message with zero parts is dropped with a warning.
	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	t.Run("should keep assistant messages with text content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	// A tool call alone counts as content, so the assistant message survives.
	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "What's the weather?"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName: "get_weather",
						Input: `{"location":"NYC"}`,
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	// An unsupported file type is skipped, leaving the user message empty,
	// which then triggers the empty-message drop (hence two warnings).
	t.Run("should drop user messages without visible content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data: []byte("not supported"),
						MediaType: "application/unknown",
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Empty(t, input)
		require.Len(t, warnings, 2) // One for unsupported type, one for empty message
		require.Contains(t, warnings[1].Message, "dropping empty user message")
	})

	t.Run("should keep user messages with image content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data: []byte{0x01, 0x02, 0x03},
						MediaType: "image/png",
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_123",
						Output: fantasy.ToolResultOutputContentText{Text: "done"},
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})

	// Error outputs are still tool results and must not be dropped.
	t.Run("should keep user messages with tool error results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_456",
						Output: fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})
}