package openai

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"charm.land/fantasy"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/stretchr/testify/require"
)

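// TestToOpenAiPrompt_SystemMessages covers how system messages are converted:
// forwarded as-is, warned about and dropped when empty, and joined with
// newlines when they contain multiple text parts.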
func TestToOpenAiPrompt_SystemMessages(t *testing.T) {
	t.Parallel()

	t.Run("should forward system messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleSystem,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "You are a helpful assistant."},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		systemMsg := messages[0].OfSystem
		require.NotNil(t, systemMsg)
		require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
	})

	t.Run("should handle empty system messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role:    fantasy.MessageRoleSystem,
				Content: []fantasy.MessagePart{},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Len(t, warnings, 1)
		require.Contains(t, warnings[0].Message, "system prompt has no text parts")
		require.Empty(t, messages)
	})

	t.Run("should join multiple system text parts", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleSystem,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "You are a helpful assistant."},
					fantasy.TextPart{Text: "Be concise."},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		systemMsg := messages[0].OfSystem
		require.NotNil(t, systemMsg)
		require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
	})
}

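// TestToOpenAiPrompt_UserMessages covers user message conversion: a lone text
// part becomes plain string content, while mixed parts become a content-part
// array (text plus base64 data-URL images, with optional image detail).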
func TestToOpenAiPrompt_UserMessages(t *testing.T) {
	t.Parallel()

	t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)
		require.Equal(t, "Hello", userMsg.Content.OfString.Value)
	})

	t.Run("should convert messages with image parts", func(t *testing.T) {
		t.Parallel()

		imageData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
					fantasy.FilePart{
						MediaType: "image/png",
						Data:      imageData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 2)

		// Check text part
		textPart := content[0].OfText
		require.NotNil(t, textPart)
		require.Equal(t, "Hello", textPart.Text)

		// Check image part
		imagePart := content[1].OfImageURL
		require.NotNil(t, imagePart)
		expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
		require.Equal(t, expectedURL, imagePart.ImageURL.URL)
	})

	t.Run("should add image detail when specified through provider options", func(t *testing.T) {
		t.Parallel()

		imageData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "image/png",
						Data:      imageData,
						ProviderOptions: NewProviderFileOptions(&ProviderFileOptions{
							ImageDetail: "low",
						}),
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		imagePart := content[0].OfImageURL
		require.NotNil(t, imagePart)
		require.Equal(t, "low", imagePart.ImageURL.Detail)
	})
}

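// TestToOpenAiPrompt_FileParts covers file part conversion: unsupported media
// types produce warnings, audio becomes input_audio parts, and PDFs become
// file parts carrying either inline base64 data or an OpenAI file ID.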
func TestToOpenAiPrompt_FileParts(t *testing.T) {
	t.Parallel()

	t.Run("should throw for unsupported mime types", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/something",
						Data:      []byte("test"),
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Len(t, warnings, 2) // unsupported type + empty message
		require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
		require.Contains(t, warnings[1].Message, "dropping empty user message")
		require.Empty(t, messages) // Message is now dropped because it's empty
	})

	t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/wav",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
		require.Equal(t, "wav", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/mpeg",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/mp3",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should convert messages with PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "document.pdf", filePart.File.Filename.Value)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      []byte("file-pdf-12345"),
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
		require.True(t, param.IsOmitted(filePart.File.FileData))
		require.True(t, param.IsOmitted(filePart.File.Filename))
	})

	t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
	})
}

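// TestToOpenAiPrompt_ToolCalls covers tool call and tool result conversion,
// including stringified JSON arguments and error outputs.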
func TestToOpenAiPrompt_ToolCalls(t *testing.T) {
	t.Parallel()

	t.Run("should stringify arguments to tool calls", func(t *testing.T) {
		t.Parallel()

		inputArgs := map[string]any{"foo": "bar123"}
		inputJSON, _ := json.Marshal(inputArgs)

		outputResult := map[string]any{"oof": "321rab"}
		outputJSON, _ := json.Marshal(outputResult)

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "quux",
						ToolName:   "thwomp",
						Input:      string(inputJSON),
					},
				},
			},
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "quux",
						Output: fantasy.ToolResultOutputContentText{
							Text: string(outputJSON),
						},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 2)

		// Check assistant message with tool call
		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "", assistantMsg.Content.OfString.Value)
		require.Len(t, assistantMsg.ToolCalls, 1)

		toolCall := assistantMsg.ToolCalls[0].OfFunction
		require.NotNil(t, toolCall)
		require.Equal(t, "quux", toolCall.ID)
		require.Equal(t, "thwomp", toolCall.Function.Name)
		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)

		// Check tool message
		toolMsg := messages[1].OfTool
		require.NotNil(t, toolMsg)
		require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
		require.Equal(t, "quux", toolMsg.ToolCallID)
	})

	t.Run("should handle different tool output types", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "text-tool",
						Output: fantasy.ToolResultOutputContentText{
							Text: "Hello world",
						},
					},
					fantasy.ToolResultPart{
						ToolCallID: "error-tool",
						Output: fantasy.ToolResultOutputContentError{
							Error: errors.New("Something went wrong"),
						},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 2)

		// Check first tool message (text)
		textToolMsg := messages[0].OfTool
		require.NotNil(t, textToolMsg)
		require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
		require.Equal(t, "text-tool", textToolMsg.ToolCallID)

		// Check second tool message (error)
		errorToolMsg := messages[1].OfTool
		require.NotNil(t, errorToolMsg)
		require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
		require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
	})
}

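// TestToOpenAiPrompt_AssistantMessages covers assistant message conversion,
// both plain text and text mixed with tool calls.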
func TestToOpenAiPrompt_AssistantMessages(t *testing.T) {
	t.Parallel()

	t.Run("should handle simple text assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello, how can I help you?"},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
	})

	t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
		t.Parallel()

		inputArgs := map[string]any{"query": "test"}
		inputJSON, _ := json.Marshal(inputArgs)

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Let me search for that."},
					fantasy.ToolCallPart{
						ToolCallID: "call-123",
						ToolName:   "search",
						Input:      string(inputJSON),
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
		require.Len(t, assistantMsg.ToolCalls, 1)

		toolCall := assistantMsg.ToolCalls[0].OfFunction
		require.Equal(t, "call-123", toolCall.ID)
		require.Equal(t, "search", toolCall.Function.Name)
		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
	})
}

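// testPrompt is the minimal single-message user prompt shared by the
// TestDoGenerate subtests below.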
var testPrompt = fantasy.Prompt{
	{
		Role: fantasy.MessageRoleUser,
		Content: []fantasy.MessagePart{
			fantasy.TextPart{Text: "Hello"},
		},
	},
}

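// testLogprobs mirrors the logprobs payload shape of the Chat Completions
// API: one entry per token with its log probability and top alternatives.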
var testLogprobs = map[string]any{
	"content": []map[string]any{
		{
			"token":   "Hello",
			"logprob": -0.0009994634,
			"top_logprobs": []map[string]any{
				{
					"token":   "Hello",
					"logprob": -0.0009994634,
				},
			},
		},
		{
			"token":   "!",
			"logprob": -0.13410144,
			"top_logprobs": []map[string]any{
				{
					"token":   "!",
					"logprob": -0.13410144,
				},
			},
		},
		{
			"token":   " How",
			"logprob": -0.0009250381,
			"top_logprobs": []map[string]any{
				{
					"token":   " How",
					"logprob": -0.0009250381,
				},
			},
		},
		{
			"token":   " can",
			"logprob": -0.047709424,
			"top_logprobs": []map[string]any{
				{
					"token":   " can",
					"logprob": -0.047709424,
				},
			},
		},
		{
			"token":   " I",
			"logprob": -0.000009014684,
			"top_logprobs": []map[string]any{
				{
					"token":   " I",
					"logprob": -0.000009014684,
				},
			},
		},
		{
			"token":   " assist",
			"logprob": -0.009125131,
			"top_logprobs": []map[string]any{
				{
					"token":   " assist",
					"logprob": -0.009125131,
				},
			},
		},
		{
			"token":   " you",
			"logprob": -0.0000066306106,
			"top_logprobs": []map[string]any{
				{
					"token":   " you",
					"logprob": -0.0000066306106,
				},
			},
		},
		{
			"token":   " today",
			"logprob": -0.00011093382,
			"top_logprobs": []map[string]any{
				{
					"token":   " today",
					"logprob": -0.00011093382,
				},
			},
		},
		{
			"token":   "?",
			"logprob": -0.00004596782,
			"top_logprobs": []map[string]any{
				{
					"token":   "?",
					"logprob": -0.00004596782,
				},
			},
		},
	},
}

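// mockServer wraps an httptest.Server that records every request it receives
// and replies with a canned chat.completion JSON body.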
type mockServer struct {
	server   *httptest.Server
	response map[string]any
	calls    []mockCall
}

type mockCall struct {
	method  string
	path    string
	headers map[string]string
	body    map[string]any
}

func newMockServer() *mockServer {
	ms := &mockServer{
		calls: make([]mockCall, 0),
	}

	ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record the call
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}

		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Parse request body
		if r.Body != nil {
			var body map[string]any
			_ = json.NewDecoder(r.Body).Decode(&body)
			call.body = body
		}

		ms.calls = append(ms.calls, call)

		// Return mock response
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(ms.response)
	}))

	return ms
}

func (ms *mockServer) close() {
	ms.server.Close()
}

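// prepareJSONResponse installs a default chat.completion response, then
// applies per-test overrides for selected fields (content, tool_calls,
// usage, finish_reason, and so on).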
func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
	// Default values
	response := map[string]any{
		"id":      "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
		"object":  "chat.completion",
		"created": 1711115037,
		"model":   "gpt-3.5-turbo-0125",
		"choices": []map[string]any{
			{
				"index": 0,
				"message": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": "stop",
			},
		},
		"usage": map[string]any{
			"prompt_tokens":     4,
			"total_tokens":      34,
			"completion_tokens": 30,
		},
		"system_fingerprint": "fp_3bc1b5746c",
	}

	// Override with provided options
	for k, v := range opts {
		switch k {
		case "content":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["content"] = v
		case "tool_calls":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["tool_calls"] = v
		case "function_call":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["function_call"] = v
		case "annotations":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["annotations"] = v
		case "usage":
			response["usage"] = v
		case "finish_reason":
			response["choices"].([]map[string]any)[0]["finish_reason"] = v
		case "id":
			response["id"] = v
		case "created":
			response["created"] = v
		case "model":
			response["model"] = v
		case "logprobs":
			if v != nil {
				response["choices"].([]map[string]any)[0]["logprobs"] = v
			}
		}
	}

	ms.response = response
}

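// TestDoGenerate drives Generate against the mock server, asserting both the
// request bodies sent to /chat/completions and the parsed results.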
func TestDoGenerate(t *testing.T) {
	t.Parallel()

	t.Run("should extract text response", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "Hello, World!",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 1)

		textContent, ok := result.Content[0].(fantasy.TextContent)
		require.True(t, ok)
		require.Equal(t, "Hello, World!", textContent.Text)
	})

	t.Run("should extract usage", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     20,
				"total_tokens":      25,
				"completion_tokens": 5,
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(20), result.Usage.InputTokens)
		require.Equal(t, int64(5), result.Usage.OutputTokens)
		require.Equal(t, int64(25), result.Usage.TotalTokens)
	})

	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages, ok := call.body["messages"].([]any)
		require.True(t, ok)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should support partial usage", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens": 20,
				"total_tokens":  20,
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(20), result.Usage.InputTokens)
		require.Equal(t, int64(0), result.Usage.OutputTokens)
		require.Equal(t, int64(20), result.Usage.TotalTokens)
	})

	t.Run("should extract logprobs", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"logprobs": testLogprobs,
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				LogProbs: fantasy.Opt(true),
			}),
		})

		require.NoError(t, err)
		require.NotNil(t, result.ProviderMetadata)

		openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
		require.True(t, ok)
		require.NotNil(t, openaiMeta.Logprobs)
	})

	t.Run("should extract finish reason", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"finish_reason": "stop",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, fantasy.FinishReasonStop, result.FinishReason)
	})

	t.Run("should support unknown finish reason", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"finish_reason": "eos",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, fantasy.FinishReasonUnknown, result.FinishReason)
	})

	t.Run("should pass the model and the messages", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass settings", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				LogitBias: map[string]int64{
					"50256": -100,
				},
				ParallelToolCalls: fantasy.Opt(false),
				User:              fantasy.Opt("test-user-id"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		logitBias := call.body["logit_bias"].(map[string]any)
		require.Equal(t, float64(-100), logitBias["50256"])
		require.Equal(t, false, call.body["parallel_tool_calls"])
		require.Equal(t, "test-user-id", call.body["user"])
	})

	t.Run("should pass reasoningEffort setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-mini")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(
				&ProviderOptions{
					ReasoningEffort: ReasoningEffortOption(ReasoningEffortLow),
				},
			),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-mini", call.body["model"])
		require.Equal(t, "low", call.body["reasoning_effort"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass textVerbosity setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				TextVerbosity: fantasy.Opt("low"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o", call.body["model"])
		require.Equal(t, "low", call.body["verbosity"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass tools and toolChoice", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			Tools: []fantasy.Tool{
				fantasy.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
			ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		tools := call.body["tools"].([]any)
		require.Len(t, tools, 1)

		tool := tools[0].(map[string]any)
		require.Equal(t, "function", tool["type"])

		function := tool["function"].(map[string]any)
		require.Equal(t, "test-tool", function["name"])
		require.Equal(t, false, function["strict"])

		toolChoice := call.body["tool_choice"].(map[string]any)
		require.Equal(t, "function", toolChoice["type"])

		toolChoiceFunction := toolChoice["function"].(map[string]any)
		require.Equal(t, "test-tool", toolChoiceFunction["name"])
	})

	t.Run("should parse tool results", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"tool_calls": []map[string]any{
				{
					"id":   "call_O17Uplv4lJvD6DVdIvFFeRMw",
					"type": "function",
					"function": map[string]any{
						"name":      "test-tool",
						"arguments": `{"value":"Spark"}`,
					},
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			Tools: []fantasy.Tool{
				fantasy.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
			ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0],
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 1)

		toolCall, ok := result.Content[0].(fantasy.ToolCallContent)
		require.True(t, ok)
		require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
		require.Equal(t, "test-tool", toolCall.ToolName)
		require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
	})

	t.Run("should parse annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "Based on the search results [doc1], I found information.",
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index":   29,
						"url":         "https://example.com/doc1.pdf",
						"title":       "Document 1",
					},
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 2)

		textContent, ok := result.Content[0].(fantasy.TextContent)
		require.True(t, ok)
		require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)

		sourceContent, ok := result.Content[1].(fantasy.SourceContent)
		require.True(t, ok)
		require.Equal(t, fantasy.SourceTypeURL, sourceContent.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
		require.Equal(t, "Document 1", sourceContent.Title)
		require.NotEmpty(t, sourceContent.ID)
	})

	t.Run("should return cached_tokens in prompt_details_tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
		require.Equal(t, int64(15), result.Usage.InputTokens)
		require.Equal(t, int64(20), result.Usage.OutputTokens)
		require.Equal(t, int64(35), result.Usage.TotalTokens)
	})

	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.NotNil(t, result.ProviderMetadata)

		openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)

		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
		require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
	})

	t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt:           testPrompt,
			Temperature:      &[]float64{0.5}[0],
			TopP:             &[]float64{0.7}[0],
			FrequencyPenalty: &[]float64{0.2}[0],
			PresencePenalty:  &[]float64{0.3}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])

		// These should not be present
		require.Nil(t, call.body["temperature"])
		require.Nil(t, call.body["top_p"])
		require.Nil(t, call.body["frequency_penalty"])
		require.Nil(t, call.body["presence_penalty"])

		// Should have warnings
		require.Len(t, result.Warnings, 4)
		require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "temperature", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
	})

	t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt:          testPrompt,
			MaxOutputTokens: &[]int64{1000}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])
		require.Equal(t, float64(1000), call.body["max_completion_tokens"])
		require.Nil(t, call.body["max_tokens"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should return reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(15), result.Usage.InputTokens)
		require.Equal(t, int64(20), result.Usage.OutputTokens)
		require.Equal(t, int64(35), result.Usage.TotalTokens)
		require.Equal(t, int64(10), result.Usage.ReasoningTokens)
	})

	t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"model": "o1-preview",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				MaxCompletionTokens: fantasy.Opt(int64(255)),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])
		require.Equal(t, float64(255), call.body["max_completion_tokens"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send prediction extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Prediction: map[string]any{
					"type":    "content",
					"content": "Hello, World!",
				},
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		prediction := call.body["prediction"].(map[string]any)
		require.Equal(t, "content", prediction["type"])
		require.Equal(t, "Hello, World!", prediction["content"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Store: fantasy.Opt(true),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["store"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Metadata: map[string]any{
					"custom": "value",
				},
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send promptCacheKey extension value", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				PromptCacheKey: fantasy.Opt("test-cache-key-123"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send safety_identifier extension value", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				SafetyIdentifier: fantasy.Opt("test-safety-identifier-123"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-search-preview")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt:      testPrompt,
			Temperature: &[]float64{0.7}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-search-preview", call.body["model"])
		require.Nil(t, call.body["temperature"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "temperature", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "search preview models")
	})

	t.Run("should send ServiceTier flex processing setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o3-mini")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("flex"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("flex"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Nil(t, call.body["service_tier"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
	})

	t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")

		_, err = model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("priority"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("priority"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Nil(t, call.body["service_tier"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
	})
}

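// streamingMockServer is the server-sent-events counterpart of mockServer:
// it records calls and replays pre-rendered "data: ..." chunks as a
// text/event-stream response.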
type streamingMockServer struct {
	server *httptest.Server
	chunks []string
	calls  []mockCall
}

func newStreamingMockServer() *streamingMockServer {
	sms := &streamingMockServer{
		calls: make([]mockCall, 0),
	}

	sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record the call
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}

		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Parse request body
		if r.Body != nil {
			var body map[string]any
			_ = json.NewDecoder(r.Body).Decode(&body)
			call.body = body
		}

		sms.calls = append(sms.calls, call)

		// Set streaming headers
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")

		// Add custom headers if any
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				parts := strings.SplitN(chunk[7:], ":", 2)
				if len(parts) == 2 {
					w.Header().Set(parts[0], parts[1])
				}
				continue
			}
		}

		w.WriteHeader(http.StatusOK)

		// Write chunks
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				continue
			}
			_, _ = w.Write([]byte(chunk))
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
		}
	}))

	return sms
}

func (sms *streamingMockServer) close() {
	sms.server.Close()
}

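// prepareStreamResponse renders the standard chunk sequence for a streamed
// completion: an initial role delta, one content delta per entry (plus any
// annotations after the last one), a finish chunk, a usage chunk, and the
// trailing [DONE] sentinel. Entries prefixed with "HEADER:" are turned into
// response headers instead of body chunks.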
func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
	content := []string{}
	if c, ok := opts["content"].([]string); ok {
		content = c
	}

	usage := map[string]any{
		"prompt_tokens":     17,
		"total_tokens":      244,
		"completion_tokens": 227,
	}
	if u, ok := opts["usage"].(map[string]any); ok {
		usage = u
	}

	logprobs := map[string]any{}
	if l, ok := opts["logprobs"].(map[string]any); ok {
		logprobs = l
	}

	finishReason := "stop"
	if fr, ok := opts["finish_reason"].(string); ok {
		finishReason = fr
	}

	model := "gpt-3.5-turbo-0613"
	if m, ok := opts["model"].(string); ok {
		model = m
	}

	headers := map[string]string{}
	if h, ok := opts["headers"].(map[string]string); ok {
		headers = h
	}

	chunks := []string{}

	// Add custom headers
	for k, v := range headers {
		chunks = append(chunks, "HEADER:"+k+":"+v)
	}

	// Initial chunk with role
	initialChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index": 0,
				"delta": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": nil,
			},
		},
	}
	initialData, _ := json.Marshal(initialChunk)
	chunks = append(chunks, "data: "+string(initialData)+"\n\n")

	// Content chunks
	for i, text := range content {
		contentChunk := map[string]any{
			"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
			"object":             "chat.completion.chunk",
			"created":            1702657020,
			"model":              model,
			"system_fingerprint": nil,
			"choices": []map[string]any{
				{
					"index": 1,
					"delta": map[string]any{
						"content": text,
					},
					"finish_reason": nil,
				},
			},
		}
		contentData, _ := json.Marshal(contentChunk)
		chunks = append(chunks, "data: "+string(contentData)+"\n\n")

		// Add annotations if this is the last content chunk and we have annotations
		if i == len(content)-1 {
			if annotations, ok := opts["annotations"].([]map[string]any); ok {
				annotationChunk := map[string]any{
					"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
					"object":             "chat.completion.chunk",
					"created":            1702657020,
					"model":              model,
					"system_fingerprint": nil,
					"choices": []map[string]any{
						{
							"index": 1,
							"delta": map[string]any{
								"annotations": annotations,
							},
							"finish_reason": nil,
						},
					},
				}
				annotationData, _ := json.Marshal(annotationChunk)
				chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
			}
		}
	}

	// Finish chunk
	finishChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index":         0,
				"delta":         map[string]any{},
				"finish_reason": finishReason,
			},
		},
	}

	if len(logprobs) > 0 {
		finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
	}

	finishData, _ := json.Marshal(finishChunk)
	chunks = append(chunks, "data: "+string(finishData)+"\n\n")

	// Usage chunk
	usageChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": "fp_3bc1b5746c",
		"choices":            []map[string]any{},
		"usage":              usage,
	}
	usageData, _ := json.Marshal(usageChunk)
	chunks = append(chunks, "data: "+string(usageData)+"\n\n")

	// Done
	chunks = append(chunks, "data: [DONE]\n\n")

	sms.chunks = chunks
}

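// prepareToolStreamResponse loads a canned SSE sequence in which the model
// streams a single "test-tool" call whose JSON arguments arrive in fragments.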
func (sms *streamingMockServer) prepareToolStreamResponse() {
	chunks := []string{
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}

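// prepareErrorStreamResponse loads a canned SSE sequence containing a
// server_error payload followed by the [DONE] sentinel.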
func (sms *streamingMockServer) prepareErrorStreamResponse() {
	chunks := []string{
		`data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}

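// collectStreamParts drains the stream, stopping after the first error or
// finish part. It never fails; the error return keeps call sites uniform.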
func collectStreamParts(stream fantasy.StreamResponse) ([]fantasy.StreamPart, error) {
	var parts []fantasy.StreamPart
	for part := range stream {
		parts = append(parts, part)
		if part.Type == fantasy.StreamPartTypeError || part.Type == fantasy.StreamPartTypeFinish {
			break
		}
	}
	return parts, nil
}

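// TestDoStream exercises the streaming path end to end against the mock SSE
// server: text deltas, tool-call deltas, annotations, error payloads, request
// bodies, and usage accounting.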
func TestDoStream(t *testing.T) {
	t.Parallel()

	t.Run("should stream text deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content":       []string{"Hello", ", ", "World!"},
			"finish_reason": "stop",
			"usage": map[string]any{
				"prompt_tokens":     17,
				"total_tokens":      244,
				"completion_tokens": 227,
			},
			"logprobs": testLogprobs,
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Verify stream structure
		require.GreaterOrEqual(t, len(parts), 4) // text-start, deltas, text-end, finish

		// Find text parts
		textStart, textEnd, finish := -1, -1, -1
		var deltas []string

		for i, part := range parts {
			switch part.Type {
			case fantasy.StreamPartTypeTextStart:
				textStart = i
			case fantasy.StreamPartTypeTextDelta:
				deltas = append(deltas, part.Delta)
			case fantasy.StreamPartTypeTextEnd:
				textEnd = i
			case fantasy.StreamPartTypeFinish:
				finish = i
			}
		}

		require.NotEqual(t, -1, textStart)
		require.NotEqual(t, -1, textEnd)
		require.NotEqual(t, -1, finish)
		require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)

		// Check finish part
		finishPart := parts[finish]
		require.Equal(t, fantasy.FinishReasonStop, finishPart.FinishReason)
		require.Equal(t, int64(17), finishPart.Usage.InputTokens)
		require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
	})

	t.Run("should stream tool deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareToolStreamResponse()

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			Tools: []fantasy.Tool{
				fantasy.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find tool-related parts
		toolInputStart, toolInputEnd, toolCall := -1, -1, -1
		var toolDeltas []string

		for i, part := range parts {
			switch part.Type {
			case fantasy.StreamPartTypeToolInputStart:
				toolInputStart = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
			case fantasy.StreamPartTypeToolInputDelta:
				toolDeltas = append(toolDeltas, part.Delta)
			case fantasy.StreamPartTypeToolInputEnd:
				toolInputEnd = i
			case fantasy.StreamPartTypeToolCall:
				toolCall = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
				require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
			}
		}

		require.NotEqual(t, -1, toolInputStart)
		require.NotEqual(t, -1, toolInputEnd)
		require.NotEqual(t, -1, toolCall)

		// Verify the tool deltas combine to form the complete input
		fullInput := strings.Join(toolDeltas, "")
		require.Equal(t, `{"value":"Sparkle Day"}`, fullInput)
	})

	t.Run("should stream annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Based on search results"},
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index":   29,
						"url":         "https://example.com/doc1.pdf",
						"title":       "Document 1",
					},
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find source part
		var sourcePart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeSource {
				sourcePart = &part
				break
			}
		}

		require.NotNil(t, sourcePart)
		require.Equal(t, fantasy.SourceTypeURL, sourcePart.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
		require.Equal(t, "Document 1", sourcePart.Title)
		require.NotEmpty(t, sourcePart.ID)
	})

	t.Run("should handle error stream parts", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareErrorStreamResponse()

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Should contain at least the error part
		require.NotEmpty(t, parts)

		// Find error part
		var errorPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeError {
				errorPart = &part
				break
			}
		}

		require.NotNil(t, errorPart)
		require.NotNil(t, errorPart.Error)
	})

	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
	})

	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.NotNil(t, finishPart.ProviderMetadata)

		openaiMeta, ok := finishPart.ProviderMetadata["openai"].(*ProviderMetadata)
		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
		require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Store: fantasy.Opt(true),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])
		require.Equal(t, true, call.body["store"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Metadata: map[string]any{
					"custom": "value",
				},
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o3-mini")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("flex"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("priority"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should stream text delta for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model":   "o1-preview",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find text parts
		var textDeltas []string
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeTextDelta {
				textDeltas = append(textDeltas, part.Delta)
			}
		}

		// Should contain the text content (without the empty initial delta)
		require.Equal(t, []string{"Hello, World!"}, textDeltas)
	})

	t.Run("should send reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model":   "o1-preview",
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
		require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
	})
}

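// TestDefaultToPrompt_DropsEmptyMessages verifies that the Chat Completions
// prompt conversion drops messages with no renderable content and warns,
// while keeping messages that carry text, images, tool calls, or tool
// results.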
func TestDefaultToPrompt_DropsEmptyMessages(t *testing.T) {
	t.Parallel()

	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role:    fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	t.Run("should keep assistant messages with text content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "What's the weather?"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName:   "get_weather",
						Input:      `{"location":"NYC"}`,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should drop user messages without visible content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte("not supported"),
						MediaType: "application/unknown",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Empty(t, messages)
		require.Len(t, warnings, 2) // One for the unsupported type, one for the empty message
		require.Contains(t, warnings[1].Message, "dropping empty user message")
	})

	t.Run("should keep user messages with image content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte{0x01, 0x02, 0x03},
						MediaType: "image/png",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep tool messages with tool results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_123",
						Output:     fantasy.ToolResultOutputContentText{Text: "done"},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep tool messages with tool error results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_456",
						Output:     fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})
}

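// TestResponsesToPrompt_DropsEmptyMessages mirrors the checks above for the
// Responses API prompt conversion.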
func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) {
	t.Parallel()

	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role:    fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	t.Run("should keep assistant messages with text content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "What's the weather?"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName:   "get_weather",
						Input:      `{"location":"NYC"}`,
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should drop user messages without visible content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte("not supported"),
						MediaType: "application/unknown",
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Empty(t, input)
		require.Len(t, warnings, 2) // One for the unsupported type, one for the empty message
		require.Contains(t, warnings[1].Message, "dropping empty user message")
	})

	t.Run("should keep user messages with image content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte{0x01, 0x02, 0x03},
						MediaType: "image/png",
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep tool messages with tool results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_123",
						Output:     fantasy.ToolResultOutputContentText{Text: "done"},
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep tool messages with tool error results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_456",
						Output:     fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})
}