1package openai
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "net/http"
9 "net/http/httptest"
10 "strings"
11 "testing"
12
13 "charm.land/fantasy"
14 "github.com/openai/openai-go/v2/packages/param"
15 "github.com/stretchr/testify/assert"
16 "github.com/stretchr/testify/require"
17)
18
19func TestToOpenAiPrompt_SystemMessages(t *testing.T) {
20 t.Parallel()
21
22 t.Run("should forward system messages", func(t *testing.T) {
23 t.Parallel()
24
25 prompt := fantasy.Prompt{
26 {
27 Role: fantasy.MessageRoleSystem,
28 Content: []fantasy.MessagePart{
29 fantasy.TextPart{Text: "You are a helpful assistant."},
30 },
31 },
32 }
33
34 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
35
36 require.Empty(t, warnings)
37 require.Len(t, messages, 1)
38
39 systemMsg := messages[0].OfSystem
40 require.NotNil(t, systemMsg)
41 require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
42 })
43
44 t.Run("should handle empty system messages", func(t *testing.T) {
45 t.Parallel()
46
47 prompt := fantasy.Prompt{
48 {
49 Role: fantasy.MessageRoleSystem,
50 Content: []fantasy.MessagePart{},
51 },
52 }
53
54 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
55
56 require.Len(t, warnings, 1)
57 require.Contains(t, warnings[0].Message, "system prompt has no text parts")
58 require.Empty(t, messages)
59 })
60
61 t.Run("should join multiple system text parts", func(t *testing.T) {
62 t.Parallel()
63
64 prompt := fantasy.Prompt{
65 {
66 Role: fantasy.MessageRoleSystem,
67 Content: []fantasy.MessagePart{
68 fantasy.TextPart{Text: "You are a helpful assistant."},
69 fantasy.TextPart{Text: "Be concise."},
70 },
71 },
72 }
73
74 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
75
76 require.Empty(t, warnings)
77 require.Len(t, messages, 1)
78
79 systemMsg := messages[0].OfSystem
80 require.NotNil(t, systemMsg)
81 require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
82 })
83}
84
85func TestToOpenAiPrompt_UserMessages(t *testing.T) {
86 t.Parallel()
87
88 t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
89 t.Parallel()
90
91 prompt := fantasy.Prompt{
92 {
93 Role: fantasy.MessageRoleUser,
94 Content: []fantasy.MessagePart{
95 fantasy.TextPart{Text: "Hello"},
96 },
97 },
98 }
99
100 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
101
102 require.Empty(t, warnings)
103 require.Len(t, messages, 1)
104
105 userMsg := messages[0].OfUser
106 require.NotNil(t, userMsg)
107 require.Equal(t, "Hello", userMsg.Content.OfString.Value)
108 })
109
110 t.Run("should convert messages with image parts", func(t *testing.T) {
111 t.Parallel()
112
113 imageData := []byte{0, 1, 2, 3}
114 prompt := fantasy.Prompt{
115 {
116 Role: fantasy.MessageRoleUser,
117 Content: []fantasy.MessagePart{
118 fantasy.TextPart{Text: "Hello"},
119 fantasy.FilePart{
120 MediaType: "image/png",
121 Data: imageData,
122 },
123 },
124 },
125 }
126
127 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
128
129 require.Empty(t, warnings)
130 require.Len(t, messages, 1)
131
132 userMsg := messages[0].OfUser
133 require.NotNil(t, userMsg)
134
135 content := userMsg.Content.OfArrayOfContentParts
136 require.Len(t, content, 2)
137
138 // Check text part
139 textPart := content[0].OfText
140 require.NotNil(t, textPart)
141 require.Equal(t, "Hello", textPart.Text)
142
143 // Check image part
144 imagePart := content[1].OfImageURL
145 require.NotNil(t, imagePart)
146 expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
147 require.Equal(t, expectedURL, imagePart.ImageURL.URL)
148 })
149
150 t.Run("should add image detail when specified through provider options", func(t *testing.T) {
151 t.Parallel()
152
153 imageData := []byte{0, 1, 2, 3}
154 prompt := fantasy.Prompt{
155 {
156 Role: fantasy.MessageRoleUser,
157 Content: []fantasy.MessagePart{
158 fantasy.FilePart{
159 MediaType: "image/png",
160 Data: imageData,
161 ProviderOptions: NewProviderFileOptions(&ProviderFileOptions{
162 ImageDetail: "low",
163 }),
164 },
165 },
166 },
167 }
168
169 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
170
171 require.Empty(t, warnings)
172 require.Len(t, messages, 1)
173
174 userMsg := messages[0].OfUser
175 require.NotNil(t, userMsg)
176
177 content := userMsg.Content.OfArrayOfContentParts
178 require.Len(t, content, 1)
179
180 imagePart := content[0].OfImageURL
181 require.NotNil(t, imagePart)
182 require.Equal(t, "low", imagePart.ImageURL.Detail)
183 })
184}
185
// TestToOpenAiPrompt_FileParts verifies conversion of non-image file parts:
// unsupported media types produce warnings (and the emptied message is
// dropped), audio parts become input_audio content with the proper format,
// and PDFs become file content — inline base64 data, a file_id reference
// when the data looks like "file-...", or a generated default filename.
func TestToOpenAiPrompt_FileParts(t *testing.T) {
	t.Parallel()

	t.Run("should throw for unsupported mime types", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/something",
						Data:      []byte("test"),
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Len(t, warnings, 2) // unsupported type + empty message
		require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
		require.Contains(t, warnings[1].Message, "dropping empty user message")
		require.Empty(t, messages) // Message is now dropped because it's empty
	})

	t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/wav",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		// Audio is base64-encoded and tagged with the short format name.
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
		require.Equal(t, "wav", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/mpeg",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		// audio/mpeg is normalized to the "mp3" format label.
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "audio/mp3",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should convert messages with PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "document.pdf", filePart.File.Filename.Value)

		// PDF bytes are embedded as a base64 data URL.
		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	// NOTE(review): this subtest is byte-for-byte identical to
	// "should convert messages with PDF file parts" above (same fixture and
	// assertions) — confirm whether a distinct binary fixture was intended.
	t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						// Data beginning with "file-" is treated as an
						// OpenAI file ID rather than inline content.
						Data: []byte("file-pdf-12345"),
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
		// When a file ID is used, inline data and filename must be omitted.
		require.True(t, param.IsOmitted(filePart.File.FileData))
		require.True(t, param.IsOmitted(filePart.File.Filename))
	})

	t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		// Filename defaults to "part-<index>.pdf" when none is supplied.
		require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
	})
}
423
424func TestToOpenAiPrompt_ToolCalls(t *testing.T) {
425 t.Parallel()
426
427 t.Run("should stringify arguments to tool calls", func(t *testing.T) {
428 t.Parallel()
429
430 inputArgs := map[string]any{"foo": "bar123"}
431 inputJSON, _ := json.Marshal(inputArgs)
432
433 outputResult := map[string]any{"oof": "321rab"}
434 outputJSON, _ := json.Marshal(outputResult)
435
436 prompt := fantasy.Prompt{
437 {
438 Role: fantasy.MessageRoleAssistant,
439 Content: []fantasy.MessagePart{
440 fantasy.ToolCallPart{
441 ToolCallID: "quux",
442 ToolName: "thwomp",
443 Input: string(inputJSON),
444 },
445 },
446 },
447 {
448 Role: fantasy.MessageRoleTool,
449 Content: []fantasy.MessagePart{
450 fantasy.ToolResultPart{
451 ToolCallID: "quux",
452 Output: fantasy.ToolResultOutputContentText{
453 Text: string(outputJSON),
454 },
455 },
456 },
457 },
458 }
459
460 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
461
462 require.Empty(t, warnings)
463 require.Len(t, messages, 2)
464
465 // Check assistant message with tool call
466 assistantMsg := messages[0].OfAssistant
467 require.NotNil(t, assistantMsg)
468 require.Equal(t, "", assistantMsg.Content.OfString.Value)
469 require.Len(t, assistantMsg.ToolCalls, 1)
470
471 toolCall := assistantMsg.ToolCalls[0].OfFunction
472 require.NotNil(t, toolCall)
473 require.Equal(t, "quux", toolCall.ID)
474 require.Equal(t, "thwomp", toolCall.Function.Name)
475 require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
476
477 // Check tool message
478 toolMsg := messages[1].OfTool
479 require.NotNil(t, toolMsg)
480 require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
481 require.Equal(t, "quux", toolMsg.ToolCallID)
482 })
483
484 t.Run("should handle different tool output types", func(t *testing.T) {
485 t.Parallel()
486
487 prompt := fantasy.Prompt{
488 {
489 Role: fantasy.MessageRoleTool,
490 Content: []fantasy.MessagePart{
491 fantasy.ToolResultPart{
492 ToolCallID: "text-tool",
493 Output: fantasy.ToolResultOutputContentText{
494 Text: "Hello world",
495 },
496 },
497 fantasy.ToolResultPart{
498 ToolCallID: "error-tool",
499 Output: fantasy.ToolResultOutputContentError{
500 Error: errors.New("Something went wrong"),
501 },
502 },
503 },
504 },
505 }
506
507 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
508
509 require.Empty(t, warnings)
510 require.Len(t, messages, 2)
511
512 // Check first tool message (text)
513 textToolMsg := messages[0].OfTool
514 require.NotNil(t, textToolMsg)
515 require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
516 require.Equal(t, "text-tool", textToolMsg.ToolCallID)
517
518 // Check second tool message (error)
519 errorToolMsg := messages[1].OfTool
520 require.NotNil(t, errorToolMsg)
521 require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
522 require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
523 })
524}
525
526func TestToOpenAiPrompt_AssistantMessages(t *testing.T) {
527 t.Parallel()
528
529 t.Run("should handle simple text assistant messages", func(t *testing.T) {
530 t.Parallel()
531
532 prompt := fantasy.Prompt{
533 {
534 Role: fantasy.MessageRoleAssistant,
535 Content: []fantasy.MessagePart{
536 fantasy.TextPart{Text: "Hello, how can I help you?"},
537 },
538 },
539 }
540
541 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
542
543 require.Empty(t, warnings)
544 require.Len(t, messages, 1)
545
546 assistantMsg := messages[0].OfAssistant
547 require.NotNil(t, assistantMsg)
548 require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
549 })
550
551 t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
552 t.Parallel()
553
554 inputArgs := map[string]any{"query": "test"}
555 inputJSON, _ := json.Marshal(inputArgs)
556
557 prompt := fantasy.Prompt{
558 {
559 Role: fantasy.MessageRoleAssistant,
560 Content: []fantasy.MessagePart{
561 fantasy.TextPart{Text: "Let me search for that."},
562 fantasy.ToolCallPart{
563 ToolCallID: "call-123",
564 ToolName: "search",
565 Input: string(inputJSON),
566 },
567 },
568 },
569 }
570
571 messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-5")
572
573 require.Empty(t, warnings)
574 require.Len(t, messages, 1)
575
576 assistantMsg := messages[0].OfAssistant
577 require.NotNil(t, assistantMsg)
578 require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
579 require.Len(t, assistantMsg.ToolCalls, 1)
580
581 toolCall := assistantMsg.ToolCalls[0].OfFunction
582 require.Equal(t, "call-123", toolCall.ID)
583 require.Equal(t, "search", toolCall.Function.Name)
584 require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
585 })
586}
587
// testPrompt is a minimal single-user-message prompt ("Hello") shared by
// the generation tests below.
var testPrompt = fantasy.Prompt{
	{
		Role: fantasy.MessageRoleUser,
		Content: []fantasy.MessagePart{
			fantasy.TextPart{Text: "Hello"},
		},
	},
}
596
// testLogprobs is a canned OpenAI-style logprobs payload (one entry per
// token of "Hello! How can I assist you today?", each with a single
// top_logprobs candidate) used as a mock response fixture.
var testLogprobs = map[string]any{
	"content": []map[string]any{
		{
			"token":   "Hello",
			"logprob": -0.0009994634,
			"top_logprobs": []map[string]any{
				{
					"token":   "Hello",
					"logprob": -0.0009994634,
				},
			},
		},
		{
			"token":   "!",
			"logprob": -0.13410144,
			"top_logprobs": []map[string]any{
				{
					"token":   "!",
					"logprob": -0.13410144,
				},
			},
		},
		{
			"token":   " How",
			"logprob": -0.0009250381,
			"top_logprobs": []map[string]any{
				{
					"token":   " How",
					"logprob": -0.0009250381,
				},
			},
		},
		{
			"token":   " can",
			"logprob": -0.047709424,
			"top_logprobs": []map[string]any{
				{
					"token":   " can",
					"logprob": -0.047709424,
				},
			},
		},
		{
			"token":   " I",
			"logprob": -0.000009014684,
			"top_logprobs": []map[string]any{
				{
					"token":   " I",
					"logprob": -0.000009014684,
				},
			},
		},
		{
			"token":   " assist",
			"logprob": -0.009125131,
			"top_logprobs": []map[string]any{
				{
					"token":   " assist",
					"logprob": -0.009125131,
				},
			},
		},
		{
			"token":   " you",
			"logprob": -0.0000066306106,
			"top_logprobs": []map[string]any{
				{
					"token":   " you",
					"logprob": -0.0000066306106,
				},
			},
		},
		{
			"token":   " today",
			"logprob": -0.00011093382,
			"top_logprobs": []map[string]any{
				{
					"token":   " today",
					"logprob": -0.00011093382,
				},
			},
		},
		{
			"token":   "?",
			"logprob": -0.00004596782,
			"top_logprobs": []map[string]any{
				{
					"token":   "?",
					"logprob": -0.00004596782,
				},
			},
		},
	},
}
691
// mockServer wraps an httptest.Server that records every incoming request
// and replies with a canned JSON response (set via prepareJSONResponse).
type mockServer struct {
	server   *httptest.Server // underlying in-process HTTP server
	response map[string]any   // JSON payload returned to every request
	// calls records each received request in arrival order.
	// NOTE(review): appended from the handler goroutine without locking;
	// this is only safe while tests issue requests sequentially — confirm
	// no caller fires concurrent requests against one mockServer.
	calls []mockCall
}
697
// mockCall captures one HTTP request received by the mock server.
type mockCall struct {
	method  string            // HTTP method, e.g. "POST"
	path    string            // request URL path, e.g. "/chat/completions"
	headers map[string]string // first value of each request header
	body    map[string]any    // decoded JSON request body (nil if not JSON)
}
704
705func newMockServer() *mockServer {
706 ms := &mockServer{
707 calls: make([]mockCall, 0),
708 }
709
710 ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
711 // Record the call
712 call := mockCall{
713 method: r.Method,
714 path: r.URL.Path,
715 headers: make(map[string]string),
716 }
717
718 for k, v := range r.Header {
719 if len(v) > 0 {
720 call.headers[k] = v[0]
721 }
722 }
723
724 // Parse request body
725 if r.Body != nil {
726 var body map[string]any
727 json.NewDecoder(r.Body).Decode(&body)
728 call.body = body
729 }
730
731 ms.calls = append(ms.calls, call)
732
733 // Return mock response
734 w.Header().Set("Content-Type", "application/json")
735 json.NewEncoder(w).Encode(ms.response)
736 }))
737
738 return ms
739}
740
// close shuts down the underlying httptest server.
func (ms *mockServer) close() {
	ms.server.Close()
}
744
745func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
746 // Default values
747 response := map[string]any{
748 "id": "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
749 "object": "chat.completion",
750 "created": 1711115037,
751 "model": "gpt-3.5-turbo-0125",
752 "choices": []map[string]any{
753 {
754 "index": 0,
755 "message": map[string]any{
756 "role": "assistant",
757 "content": "",
758 },
759 "finish_reason": "stop",
760 },
761 },
762 "usage": map[string]any{
763 "prompt_tokens": 4,
764 "total_tokens": 34,
765 "completion_tokens": 30,
766 },
767 "system_fingerprint": "fp_3bc1b5746c",
768 }
769
770 // Override with provided options
771 for k, v := range opts {
772 switch k {
773 case "content":
774 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["content"] = v
775 case "tool_calls":
776 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["tool_calls"] = v
777 case "function_call":
778 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["function_call"] = v
779 case "annotations":
780 response["choices"].([]map[string]any)[0]["message"].(map[string]any)["annotations"] = v
781 case "usage":
782 response["usage"] = v
783 case "finish_reason":
784 response["choices"].([]map[string]any)[0]["finish_reason"] = v
785 case "id":
786 response["id"] = v
787 case "created":
788 response["created"] = v
789 case "model":
790 response["model"] = v
791 case "logprobs":
792 if v != nil {
793 response["choices"].([]map[string]any)[0]["logprobs"] = v
794 }
795 }
796 }
797
798 ms.response = response
799}
800
801func TestDoGenerate(t *testing.T) {
802 t.Parallel()
803
804 t.Run("should extract text response", func(t *testing.T) {
805 t.Parallel()
806
807 server := newMockServer()
808 defer server.close()
809
810 server.prepareJSONResponse(map[string]any{
811 "content": "Hello, World!",
812 })
813
814 provider, err := New(
815 WithAPIKey("test-api-key"),
816 WithBaseURL(server.server.URL),
817 )
818 require.NoError(t, err)
819 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
820
821 result, err := model.Generate(context.Background(), fantasy.Call{
822 Prompt: testPrompt,
823 })
824
825 require.NoError(t, err)
826 require.Len(t, result.Content, 1)
827
828 textContent, ok := result.Content[0].(fantasy.TextContent)
829 require.True(t, ok)
830 require.Equal(t, "Hello, World!", textContent.Text)
831 })
832
833 t.Run("should extract usage", func(t *testing.T) {
834 t.Parallel()
835
836 server := newMockServer()
837 defer server.close()
838
839 server.prepareJSONResponse(map[string]any{
840 "usage": map[string]any{
841 "prompt_tokens": 20,
842 "total_tokens": 25,
843 "completion_tokens": 5,
844 },
845 })
846
847 provider, err := New(
848 WithAPIKey("test-api-key"),
849 WithBaseURL(server.server.URL),
850 )
851 require.NoError(t, err)
852 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
853
854 result, err := model.Generate(context.Background(), fantasy.Call{
855 Prompt: testPrompt,
856 })
857
858 require.NoError(t, err)
859 require.Equal(t, int64(20), result.Usage.InputTokens)
860 require.Equal(t, int64(5), result.Usage.OutputTokens)
861 require.Equal(t, int64(25), result.Usage.TotalTokens)
862 })
863
864 t.Run("should send request body", func(t *testing.T) {
865 t.Parallel()
866
867 server := newMockServer()
868 defer server.close()
869
870 server.prepareJSONResponse(map[string]any{})
871
872 provider, err := New(
873 WithAPIKey("test-api-key"),
874 WithBaseURL(server.server.URL),
875 )
876 require.NoError(t, err)
877 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
878
879 _, err = model.Generate(context.Background(), fantasy.Call{
880 Prompt: testPrompt,
881 })
882
883 require.NoError(t, err)
884 require.Len(t, server.calls, 1)
885
886 call := server.calls[0]
887 require.Equal(t, "POST", call.method)
888 require.Equal(t, "/chat/completions", call.path)
889 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
890
891 messages, ok := call.body["messages"].([]any)
892 require.True(t, ok)
893 require.Len(t, messages, 1)
894
895 message := messages[0].(map[string]any)
896 require.Equal(t, "user", message["role"])
897 require.Equal(t, "Hello", message["content"])
898 })
899
900 t.Run("should support partial usage", func(t *testing.T) {
901 t.Parallel()
902
903 server := newMockServer()
904 defer server.close()
905
906 server.prepareJSONResponse(map[string]any{
907 "usage": map[string]any{
908 "prompt_tokens": 20,
909 "total_tokens": 20,
910 },
911 })
912
913 provider, err := New(
914 WithAPIKey("test-api-key"),
915 WithBaseURL(server.server.URL),
916 )
917 require.NoError(t, err)
918 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
919
920 result, err := model.Generate(context.Background(), fantasy.Call{
921 Prompt: testPrompt,
922 })
923
924 require.NoError(t, err)
925 require.Equal(t, int64(20), result.Usage.InputTokens)
926 require.Equal(t, int64(0), result.Usage.OutputTokens)
927 require.Equal(t, int64(20), result.Usage.TotalTokens)
928 })
929
930 t.Run("should extract logprobs", func(t *testing.T) {
931 t.Parallel()
932
933 server := newMockServer()
934 defer server.close()
935
936 server.prepareJSONResponse(map[string]any{
937 "logprobs": testLogprobs,
938 })
939
940 provider, err := New(
941 WithAPIKey("test-api-key"),
942 WithBaseURL(server.server.URL),
943 )
944 require.NoError(t, err)
945 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
946
947 result, err := model.Generate(context.Background(), fantasy.Call{
948 Prompt: testPrompt,
949 ProviderOptions: NewProviderOptions(&ProviderOptions{
950 LogProbs: fantasy.Opt(true),
951 }),
952 })
953
954 require.NoError(t, err)
955 require.NotNil(t, result.ProviderMetadata)
956
957 openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
958 require.True(t, ok)
959
960 logprobs := openaiMeta.Logprobs
961 require.True(t, ok)
962 require.NotNil(t, logprobs)
963 })
964
965 t.Run("should extract finish reason", func(t *testing.T) {
966 t.Parallel()
967
968 server := newMockServer()
969 defer server.close()
970
971 server.prepareJSONResponse(map[string]any{
972 "finish_reason": "stop",
973 })
974
975 provider, err := New(
976 WithAPIKey("test-api-key"),
977 WithBaseURL(server.server.URL),
978 )
979 require.NoError(t, err)
980 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
981
982 result, err := model.Generate(context.Background(), fantasy.Call{
983 Prompt: testPrompt,
984 })
985
986 require.NoError(t, err)
987 require.Equal(t, fantasy.FinishReasonStop, result.FinishReason)
988 })
989
990 t.Run("should support unknown finish reason", func(t *testing.T) {
991 t.Parallel()
992
993 server := newMockServer()
994 defer server.close()
995
996 server.prepareJSONResponse(map[string]any{
997 "finish_reason": "eos",
998 })
999
1000 provider, err := New(
1001 WithAPIKey("test-api-key"),
1002 WithBaseURL(server.server.URL),
1003 )
1004 require.NoError(t, err)
1005 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1006
1007 result, err := model.Generate(context.Background(), fantasy.Call{
1008 Prompt: testPrompt,
1009 })
1010
1011 require.NoError(t, err)
1012 require.Equal(t, fantasy.FinishReasonUnknown, result.FinishReason)
1013 })
1014
1015 t.Run("should pass the model and the messages", func(t *testing.T) {
1016 t.Parallel()
1017
1018 server := newMockServer()
1019 defer server.close()
1020
1021 server.prepareJSONResponse(map[string]any{
1022 "content": "",
1023 })
1024
1025 provider, err := New(
1026 WithAPIKey("test-api-key"),
1027 WithBaseURL(server.server.URL),
1028 )
1029 require.NoError(t, err)
1030 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1031
1032 _, err = model.Generate(context.Background(), fantasy.Call{
1033 Prompt: testPrompt,
1034 })
1035
1036 require.NoError(t, err)
1037 require.Len(t, server.calls, 1)
1038
1039 call := server.calls[0]
1040 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1041
1042 messages := call.body["messages"].([]any)
1043 require.Len(t, messages, 1)
1044
1045 message := messages[0].(map[string]any)
1046 require.Equal(t, "user", message["role"])
1047 require.Equal(t, "Hello", message["content"])
1048 })
1049
1050 t.Run("should pass settings", func(t *testing.T) {
1051 t.Parallel()
1052
1053 server := newMockServer()
1054 defer server.close()
1055
1056 server.prepareJSONResponse(map[string]any{})
1057
1058 provider, err := New(
1059 WithAPIKey("test-api-key"),
1060 WithBaseURL(server.server.URL),
1061 )
1062 require.NoError(t, err)
1063 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1064
1065 _, err = model.Generate(context.Background(), fantasy.Call{
1066 Prompt: testPrompt,
1067 ProviderOptions: NewProviderOptions(&ProviderOptions{
1068 LogitBias: map[string]int64{
1069 "50256": -100,
1070 },
1071 ParallelToolCalls: fantasy.Opt(false),
1072 User: fantasy.Opt("test-user-id"),
1073 }),
1074 })
1075
1076 require.NoError(t, err)
1077 require.Len(t, server.calls, 1)
1078
1079 call := server.calls[0]
1080 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1081
1082 messages := call.body["messages"].([]any)
1083 require.Len(t, messages, 1)
1084
1085 logitBias := call.body["logit_bias"].(map[string]any)
1086 require.Equal(t, float64(-100), logitBias["50256"])
1087 require.Equal(t, false, call.body["parallel_tool_calls"])
1088 require.Equal(t, "test-user-id", call.body["user"])
1089 })
1090
1091 t.Run("should pass reasoningEffort setting", func(t *testing.T) {
1092 t.Parallel()
1093
1094 server := newMockServer()
1095 defer server.close()
1096
1097 server.prepareJSONResponse(map[string]any{
1098 "content": "",
1099 })
1100
1101 provider, err := New(
1102 WithAPIKey("test-api-key"),
1103 WithBaseURL(server.server.URL),
1104 )
1105 require.NoError(t, err)
1106 model, _ := provider.LanguageModel(t.Context(), "o1-mini")
1107
1108 _, err = model.Generate(context.Background(), fantasy.Call{
1109 Prompt: testPrompt,
1110 ProviderOptions: NewProviderOptions(
1111 &ProviderOptions{
1112 ReasoningEffort: ReasoningEffortOption(ReasoningEffortLow),
1113 },
1114 ),
1115 })
1116
1117 require.NoError(t, err)
1118 require.Len(t, server.calls, 1)
1119
1120 call := server.calls[0]
1121 require.Equal(t, "o1-mini", call.body["model"])
1122 require.Equal(t, "low", call.body["reasoning_effort"])
1123
1124 messages := call.body["messages"].([]any)
1125 require.Len(t, messages, 1)
1126
1127 message := messages[0].(map[string]any)
1128 require.Equal(t, "user", message["role"])
1129 require.Equal(t, "Hello", message["content"])
1130 })
1131
1132 t.Run("should pass textVerbosity setting", func(t *testing.T) {
1133 t.Parallel()
1134
1135 server := newMockServer()
1136 defer server.close()
1137
1138 server.prepareJSONResponse(map[string]any{
1139 "content": "",
1140 })
1141
1142 provider, err := New(
1143 WithAPIKey("test-api-key"),
1144 WithBaseURL(server.server.URL),
1145 )
1146 require.NoError(t, err)
1147 model, _ := provider.LanguageModel(t.Context(), "gpt-4o")
1148
1149 _, err = model.Generate(context.Background(), fantasy.Call{
1150 Prompt: testPrompt,
1151 ProviderOptions: NewProviderOptions(&ProviderOptions{
1152 TextVerbosity: fantasy.Opt("low"),
1153 }),
1154 })
1155
1156 require.NoError(t, err)
1157 require.Len(t, server.calls, 1)
1158
1159 call := server.calls[0]
1160 require.Equal(t, "gpt-4o", call.body["model"])
1161 require.Equal(t, "low", call.body["verbosity"])
1162
1163 messages := call.body["messages"].([]any)
1164 require.Len(t, messages, 1)
1165
1166 message := messages[0].(map[string]any)
1167 require.Equal(t, "user", message["role"])
1168 require.Equal(t, "Hello", message["content"])
1169 })
1170
1171 t.Run("should pass tools and toolChoice", func(t *testing.T) {
1172 t.Parallel()
1173
1174 server := newMockServer()
1175 defer server.close()
1176
1177 server.prepareJSONResponse(map[string]any{
1178 "content": "",
1179 })
1180
1181 provider, err := New(
1182 WithAPIKey("test-api-key"),
1183 WithBaseURL(server.server.URL),
1184 )
1185 require.NoError(t, err)
1186 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1187
1188 _, err = model.Generate(context.Background(), fantasy.Call{
1189 Prompt: testPrompt,
1190 Tools: []fantasy.Tool{
1191 fantasy.FunctionTool{
1192 Name: "test-tool",
1193 InputSchema: map[string]any{
1194 "type": "object",
1195 "properties": map[string]any{
1196 "value": map[string]any{
1197 "type": "string",
1198 },
1199 },
1200 "required": []string{"value"},
1201 "additionalProperties": false,
1202 "$schema": "http://json-schema.org/draft-07/schema#",
1203 },
1204 },
1205 },
1206 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0],
1207 })
1208
1209 require.NoError(t, err)
1210 require.Len(t, server.calls, 1)
1211
1212 call := server.calls[0]
1213 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1214
1215 messages := call.body["messages"].([]any)
1216 require.Len(t, messages, 1)
1217
1218 tools := call.body["tools"].([]any)
1219 require.Len(t, tools, 1)
1220
1221 tool := tools[0].(map[string]any)
1222 require.Equal(t, "function", tool["type"])
1223
1224 function := tool["function"].(map[string]any)
1225 require.Equal(t, "test-tool", function["name"])
1226 require.Equal(t, false, function["strict"])
1227
1228 toolChoice := call.body["tool_choice"].(map[string]any)
1229 require.Equal(t, "function", toolChoice["type"])
1230
1231 toolChoiceFunction := toolChoice["function"].(map[string]any)
1232 require.Equal(t, "test-tool", toolChoiceFunction["name"])
1233 })
1234
1235 t.Run("should parse tool results", func(t *testing.T) {
1236 t.Parallel()
1237
1238 server := newMockServer()
1239 defer server.close()
1240
1241 server.prepareJSONResponse(map[string]any{
1242 "tool_calls": []map[string]any{
1243 {
1244 "id": "call_O17Uplv4lJvD6DVdIvFFeRMw",
1245 "type": "function",
1246 "function": map[string]any{
1247 "name": "test-tool",
1248 "arguments": `{"value":"Spark"}`,
1249 },
1250 },
1251 },
1252 })
1253
1254 provider, err := New(
1255 WithAPIKey("test-api-key"),
1256 WithBaseURL(server.server.URL),
1257 )
1258 require.NoError(t, err)
1259 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1260
1261 result, err := model.Generate(context.Background(), fantasy.Call{
1262 Prompt: testPrompt,
1263 Tools: []fantasy.Tool{
1264 fantasy.FunctionTool{
1265 Name: "test-tool",
1266 InputSchema: map[string]any{
1267 "type": "object",
1268 "properties": map[string]any{
1269 "value": map[string]any{
1270 "type": "string",
1271 },
1272 },
1273 "required": []string{"value"},
1274 "additionalProperties": false,
1275 "$schema": "http://json-schema.org/draft-07/schema#",
1276 },
1277 },
1278 },
1279 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoice("test-tool")}[0],
1280 })
1281
1282 require.NoError(t, err)
1283 require.Len(t, result.Content, 1)
1284
1285 toolCall, ok := result.Content[0].(fantasy.ToolCallContent)
1286 require.True(t, ok)
1287 require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
1288 require.Equal(t, "test-tool", toolCall.ToolName)
1289 require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
1290 })
1291
1292 t.Run("should handle ToolChoiceRequired", func(t *testing.T) {
1293 t.Parallel()
1294
1295 server := newMockServer()
1296 defer server.close()
1297
1298 server.prepareJSONResponse(map[string]any{
1299 "content": "",
1300 })
1301
1302 provider, err := New(
1303 WithAPIKey("test-api-key"),
1304 WithBaseURL(server.server.URL),
1305 )
1306 require.NoError(t, err)
1307 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1308
1309 _, err = model.Generate(context.Background(), fantasy.Call{
1310 Prompt: testPrompt,
1311 Tools: []fantasy.Tool{
1312 fantasy.FunctionTool{
1313 Name: "test-tool",
1314 InputSchema: map[string]any{
1315 "type": "object",
1316 "properties": map[string]any{
1317 "value": map[string]any{
1318 "type": "string",
1319 },
1320 },
1321 "required": []string{"value"},
1322 "additionalProperties": false,
1323 "$schema": "http://json-schema.org/draft-07/schema#",
1324 },
1325 },
1326 },
1327 ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoiceRequired}[0],
1328 })
1329
1330 require.NoError(t, err)
1331 require.Len(t, server.calls, 1)
1332
1333 call := server.calls[0]
1334 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1335
1336 // Verify tool is present
1337 tools := call.body["tools"].([]any)
1338 require.Len(t, tools, 1)
1339
1340 tool := tools[0].(map[string]any)
1341 require.Equal(t, "function", tool["type"])
1342
1343 function := tool["function"].(map[string]any)
1344 require.Equal(t, "test-tool", function["name"])
1345
1346 // Verify tool_choice is set to "required" (not a function name)
1347 toolChoice := call.body["tool_choice"]
1348 require.Equal(t, "required", toolChoice)
1349 })
1350
1351 t.Run("should parse annotations/citations", func(t *testing.T) {
1352 t.Parallel()
1353
1354 server := newMockServer()
1355 defer server.close()
1356
1357 server.prepareJSONResponse(map[string]any{
1358 "content": "Based on the search results [doc1], I found information.",
1359 "annotations": []map[string]any{
1360 {
1361 "type": "url_citation",
1362 "url_citation": map[string]any{
1363 "start_index": 24,
1364 "end_index": 29,
1365 "url": "https://example.com/doc1.pdf",
1366 "title": "Document 1",
1367 },
1368 },
1369 },
1370 })
1371
1372 provider, err := New(
1373 WithAPIKey("test-api-key"),
1374 WithBaseURL(server.server.URL),
1375 )
1376 require.NoError(t, err)
1377 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1378
1379 result, err := model.Generate(context.Background(), fantasy.Call{
1380 Prompt: testPrompt,
1381 })
1382
1383 require.NoError(t, err)
1384 require.Len(t, result.Content, 2)
1385
1386 textContent, ok := result.Content[0].(fantasy.TextContent)
1387 require.True(t, ok)
1388 require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)
1389
1390 sourceContent, ok := result.Content[1].(fantasy.SourceContent)
1391 require.True(t, ok)
1392 require.Equal(t, fantasy.SourceTypeURL, sourceContent.SourceType)
1393 require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
1394 require.Equal(t, "Document 1", sourceContent.Title)
1395 require.NotEmpty(t, sourceContent.ID)
1396 })
1397
1398 t.Run("should return cached_tokens in prompt_details_tokens", func(t *testing.T) {
1399 t.Parallel()
1400
1401 server := newMockServer()
1402 defer server.close()
1403
1404 server.prepareJSONResponse(map[string]any{
1405 "usage": map[string]any{
1406 "prompt_tokens": 15,
1407 "completion_tokens": 20,
1408 "total_tokens": 35,
1409 "prompt_tokens_details": map[string]any{
1410 "cached_tokens": 1152,
1411 },
1412 },
1413 })
1414
1415 provider, err := New(
1416 WithAPIKey("test-api-key"),
1417 WithBaseURL(server.server.URL),
1418 )
1419 require.NoError(t, err)
1420 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1421
1422 result, err := model.Generate(context.Background(), fantasy.Call{
1423 Prompt: testPrompt,
1424 })
1425
1426 require.NoError(t, err)
1427 require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
1428 require.Equal(t, int64(15), result.Usage.InputTokens)
1429 require.Equal(t, int64(20), result.Usage.OutputTokens)
1430 require.Equal(t, int64(35), result.Usage.TotalTokens)
1431 })
1432
1433 t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
1434 t.Parallel()
1435
1436 server := newMockServer()
1437 defer server.close()
1438
1439 server.prepareJSONResponse(map[string]any{
1440 "usage": map[string]any{
1441 "prompt_tokens": 15,
1442 "completion_tokens": 20,
1443 "total_tokens": 35,
1444 "completion_tokens_details": map[string]any{
1445 "accepted_prediction_tokens": 123,
1446 "rejected_prediction_tokens": 456,
1447 },
1448 },
1449 })
1450
1451 provider, err := New(
1452 WithAPIKey("test-api-key"),
1453 WithBaseURL(server.server.URL),
1454 )
1455 require.NoError(t, err)
1456 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1457
1458 result, err := model.Generate(context.Background(), fantasy.Call{
1459 Prompt: testPrompt,
1460 })
1461
1462 require.NoError(t, err)
1463 require.NotNil(t, result.ProviderMetadata)
1464
1465 openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
1466
1467 require.True(t, ok)
1468 require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
1469 require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
1470 })
1471
1472 t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
1473 t.Parallel()
1474
1475 server := newMockServer()
1476 defer server.close()
1477
1478 server.prepareJSONResponse(map[string]any{})
1479
1480 provider, err := New(
1481 WithAPIKey("test-api-key"),
1482 WithBaseURL(server.server.URL),
1483 )
1484 require.NoError(t, err)
1485 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1486
1487 result, err := model.Generate(context.Background(), fantasy.Call{
1488 Prompt: testPrompt,
1489 Temperature: &[]float64{0.5}[0],
1490 TopP: &[]float64{0.7}[0],
1491 FrequencyPenalty: &[]float64{0.2}[0],
1492 PresencePenalty: &[]float64{0.3}[0],
1493 })
1494
1495 require.NoError(t, err)
1496 require.Len(t, server.calls, 1)
1497
1498 call := server.calls[0]
1499 require.Equal(t, "o1-preview", call.body["model"])
1500
1501 messages := call.body["messages"].([]any)
1502 require.Len(t, messages, 1)
1503
1504 message := messages[0].(map[string]any)
1505 require.Equal(t, "user", message["role"])
1506 require.Equal(t, "Hello", message["content"])
1507
1508 // These should not be present
1509 require.Nil(t, call.body["temperature"])
1510 require.Nil(t, call.body["top_p"])
1511 require.Nil(t, call.body["frequency_penalty"])
1512 require.Nil(t, call.body["presence_penalty"])
1513
1514 // Should have warnings
1515 require.Len(t, result.Warnings, 4)
1516 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1517 require.Equal(t, "temperature", result.Warnings[0].Setting)
1518 require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
1519 })
1520
1521 t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
1522 t.Parallel()
1523
1524 server := newMockServer()
1525 defer server.close()
1526
1527 server.prepareJSONResponse(map[string]any{})
1528
1529 provider, err := New(
1530 WithAPIKey("test-api-key"),
1531 WithBaseURL(server.server.URL),
1532 )
1533 require.NoError(t, err)
1534 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1535
1536 _, err = model.Generate(context.Background(), fantasy.Call{
1537 Prompt: testPrompt,
1538 MaxOutputTokens: &[]int64{1000}[0],
1539 })
1540
1541 require.NoError(t, err)
1542 require.Len(t, server.calls, 1)
1543
1544 call := server.calls[0]
1545 require.Equal(t, "o1-preview", call.body["model"])
1546 require.Equal(t, float64(1000), call.body["max_completion_tokens"])
1547 require.Nil(t, call.body["max_tokens"])
1548
1549 messages := call.body["messages"].([]any)
1550 require.Len(t, messages, 1)
1551
1552 message := messages[0].(map[string]any)
1553 require.Equal(t, "user", message["role"])
1554 require.Equal(t, "Hello", message["content"])
1555 })
1556
1557 t.Run("should return reasoning tokens", func(t *testing.T) {
1558 t.Parallel()
1559
1560 server := newMockServer()
1561 defer server.close()
1562
1563 server.prepareJSONResponse(map[string]any{
1564 "usage": map[string]any{
1565 "prompt_tokens": 15,
1566 "completion_tokens": 20,
1567 "total_tokens": 35,
1568 "completion_tokens_details": map[string]any{
1569 "reasoning_tokens": 10,
1570 },
1571 },
1572 })
1573
1574 provider, err := New(
1575 WithAPIKey("test-api-key"),
1576 WithBaseURL(server.server.URL),
1577 )
1578 require.NoError(t, err)
1579 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1580
1581 result, err := model.Generate(context.Background(), fantasy.Call{
1582 Prompt: testPrompt,
1583 })
1584
1585 require.NoError(t, err)
1586 require.Equal(t, int64(15), result.Usage.InputTokens)
1587 require.Equal(t, int64(20), result.Usage.OutputTokens)
1588 require.Equal(t, int64(35), result.Usage.TotalTokens)
1589 require.Equal(t, int64(10), result.Usage.ReasoningTokens)
1590 })
1591
1592 t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
1593 t.Parallel()
1594
1595 server := newMockServer()
1596 defer server.close()
1597
1598 server.prepareJSONResponse(map[string]any{
1599 "model": "o1-preview",
1600 })
1601
1602 provider, err := New(
1603 WithAPIKey("test-api-key"),
1604 WithBaseURL(server.server.URL),
1605 )
1606 require.NoError(t, err)
1607 model, _ := provider.LanguageModel(t.Context(), "o1-preview")
1608
1609 _, err = model.Generate(context.Background(), fantasy.Call{
1610 Prompt: testPrompt,
1611 ProviderOptions: NewProviderOptions(&ProviderOptions{
1612 MaxCompletionTokens: fantasy.Opt(int64(255)),
1613 }),
1614 })
1615
1616 require.NoError(t, err)
1617 require.Len(t, server.calls, 1)
1618
1619 call := server.calls[0]
1620 require.Equal(t, "o1-preview", call.body["model"])
1621 require.Equal(t, float64(255), call.body["max_completion_tokens"])
1622
1623 messages := call.body["messages"].([]any)
1624 require.Len(t, messages, 1)
1625
1626 message := messages[0].(map[string]any)
1627 require.Equal(t, "user", message["role"])
1628 require.Equal(t, "Hello", message["content"])
1629 })
1630
1631 t.Run("should send prediction extension setting", func(t *testing.T) {
1632 t.Parallel()
1633
1634 server := newMockServer()
1635 defer server.close()
1636
1637 server.prepareJSONResponse(map[string]any{
1638 "content": "",
1639 })
1640
1641 provider, err := New(
1642 WithAPIKey("test-api-key"),
1643 WithBaseURL(server.server.URL),
1644 )
1645 require.NoError(t, err)
1646 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1647
1648 _, err = model.Generate(context.Background(), fantasy.Call{
1649 Prompt: testPrompt,
1650 ProviderOptions: NewProviderOptions(&ProviderOptions{
1651 Prediction: map[string]any{
1652 "type": "content",
1653 "content": "Hello, World!",
1654 },
1655 }),
1656 })
1657
1658 require.NoError(t, err)
1659 require.Len(t, server.calls, 1)
1660
1661 call := server.calls[0]
1662 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1663
1664 prediction := call.body["prediction"].(map[string]any)
1665 require.Equal(t, "content", prediction["type"])
1666 require.Equal(t, "Hello, World!", prediction["content"])
1667
1668 messages := call.body["messages"].([]any)
1669 require.Len(t, messages, 1)
1670
1671 message := messages[0].(map[string]any)
1672 require.Equal(t, "user", message["role"])
1673 require.Equal(t, "Hello", message["content"])
1674 })
1675
1676 t.Run("should send store extension setting", func(t *testing.T) {
1677 t.Parallel()
1678
1679 server := newMockServer()
1680 defer server.close()
1681
1682 server.prepareJSONResponse(map[string]any{
1683 "content": "",
1684 })
1685
1686 provider, err := New(
1687 WithAPIKey("test-api-key"),
1688 WithBaseURL(server.server.URL),
1689 )
1690 require.NoError(t, err)
1691 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1692
1693 _, err = model.Generate(context.Background(), fantasy.Call{
1694 Prompt: testPrompt,
1695 ProviderOptions: NewProviderOptions(&ProviderOptions{
1696 Store: fantasy.Opt(true),
1697 }),
1698 })
1699
1700 require.NoError(t, err)
1701 require.Len(t, server.calls, 1)
1702
1703 call := server.calls[0]
1704 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1705 require.Equal(t, true, call.body["store"])
1706
1707 messages := call.body["messages"].([]any)
1708 require.Len(t, messages, 1)
1709
1710 message := messages[0].(map[string]any)
1711 require.Equal(t, "user", message["role"])
1712 require.Equal(t, "Hello", message["content"])
1713 })
1714
1715 t.Run("should send metadata extension values", func(t *testing.T) {
1716 t.Parallel()
1717
1718 server := newMockServer()
1719 defer server.close()
1720
1721 server.prepareJSONResponse(map[string]any{
1722 "content": "",
1723 })
1724
1725 provider, err := New(
1726 WithAPIKey("test-api-key"),
1727 WithBaseURL(server.server.URL),
1728 )
1729 require.NoError(t, err)
1730 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1731
1732 _, err = model.Generate(context.Background(), fantasy.Call{
1733 Prompt: testPrompt,
1734 ProviderOptions: NewProviderOptions(&ProviderOptions{
1735 Metadata: map[string]any{
1736 "custom": "value",
1737 },
1738 }),
1739 })
1740
1741 require.NoError(t, err)
1742 require.Len(t, server.calls, 1)
1743
1744 call := server.calls[0]
1745 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1746
1747 metadata := call.body["metadata"].(map[string]any)
1748 require.Equal(t, "value", metadata["custom"])
1749
1750 messages := call.body["messages"].([]any)
1751 require.Len(t, messages, 1)
1752
1753 message := messages[0].(map[string]any)
1754 require.Equal(t, "user", message["role"])
1755 require.Equal(t, "Hello", message["content"])
1756 })
1757
1758 t.Run("should send promptCacheKey extension value", func(t *testing.T) {
1759 t.Parallel()
1760
1761 server := newMockServer()
1762 defer server.close()
1763
1764 server.prepareJSONResponse(map[string]any{
1765 "content": "",
1766 })
1767
1768 provider, err := New(
1769 WithAPIKey("test-api-key"),
1770 WithBaseURL(server.server.URL),
1771 )
1772 require.NoError(t, err)
1773 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1774
1775 _, err = model.Generate(context.Background(), fantasy.Call{
1776 Prompt: testPrompt,
1777 ProviderOptions: NewProviderOptions(&ProviderOptions{
1778 PromptCacheKey: fantasy.Opt("test-cache-key-123"),
1779 }),
1780 })
1781
1782 require.NoError(t, err)
1783 require.Len(t, server.calls, 1)
1784
1785 call := server.calls[0]
1786 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1787 require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])
1788
1789 messages := call.body["messages"].([]any)
1790 require.Len(t, messages, 1)
1791
1792 message := messages[0].(map[string]any)
1793 require.Equal(t, "user", message["role"])
1794 require.Equal(t, "Hello", message["content"])
1795 })
1796
1797 t.Run("should send safety_identifier extension value", func(t *testing.T) {
1798 t.Parallel()
1799
1800 server := newMockServer()
1801 defer server.close()
1802
1803 server.prepareJSONResponse(map[string]any{
1804 "content": "",
1805 })
1806
1807 provider, err := New(
1808 WithAPIKey("test-api-key"),
1809 WithBaseURL(server.server.URL),
1810 )
1811 require.NoError(t, err)
1812 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1813
1814 _, err = model.Generate(context.Background(), fantasy.Call{
1815 Prompt: testPrompt,
1816 ProviderOptions: NewProviderOptions(&ProviderOptions{
1817 SafetyIdentifier: fantasy.Opt("test-safety-identifier-123"),
1818 }),
1819 })
1820
1821 require.NoError(t, err)
1822 require.Len(t, server.calls, 1)
1823
1824 call := server.calls[0]
1825 require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1826 require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])
1827
1828 messages := call.body["messages"].([]any)
1829 require.Len(t, messages, 1)
1830
1831 message := messages[0].(map[string]any)
1832 require.Equal(t, "user", message["role"])
1833 require.Equal(t, "Hello", message["content"])
1834 })
1835
1836 t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
1837 t.Parallel()
1838
1839 server := newMockServer()
1840 defer server.close()
1841
1842 server.prepareJSONResponse(map[string]any{})
1843
1844 provider, err := New(
1845 WithAPIKey("test-api-key"),
1846 WithBaseURL(server.server.URL),
1847 )
1848 require.NoError(t, err)
1849 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-search-preview")
1850
1851 result, err := model.Generate(context.Background(), fantasy.Call{
1852 Prompt: testPrompt,
1853 Temperature: &[]float64{0.7}[0],
1854 })
1855
1856 require.NoError(t, err)
1857 require.Len(t, server.calls, 1)
1858
1859 call := server.calls[0]
1860 require.Equal(t, "gpt-4o-search-preview", call.body["model"])
1861 require.Nil(t, call.body["temperature"])
1862
1863 require.Len(t, result.Warnings, 1)
1864 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1865 require.Equal(t, "temperature", result.Warnings[0].Setting)
1866 require.Contains(t, result.Warnings[0].Details, "search preview models")
1867 })
1868
1869 t.Run("should send ServiceTier flex processing setting", func(t *testing.T) {
1870 t.Parallel()
1871
1872 server := newMockServer()
1873 defer server.close()
1874
1875 server.prepareJSONResponse(map[string]any{
1876 "content": "",
1877 })
1878
1879 provider, err := New(
1880 WithAPIKey("test-api-key"),
1881 WithBaseURL(server.server.URL),
1882 )
1883 require.NoError(t, err)
1884 model, _ := provider.LanguageModel(t.Context(), "o3-mini")
1885
1886 _, err = model.Generate(context.Background(), fantasy.Call{
1887 Prompt: testPrompt,
1888 ProviderOptions: NewProviderOptions(&ProviderOptions{
1889 ServiceTier: fantasy.Opt("flex"),
1890 }),
1891 })
1892
1893 require.NoError(t, err)
1894 require.Len(t, server.calls, 1)
1895
1896 call := server.calls[0]
1897 require.Equal(t, "o3-mini", call.body["model"])
1898 require.Equal(t, "flex", call.body["service_tier"])
1899
1900 messages := call.body["messages"].([]any)
1901 require.Len(t, messages, 1)
1902
1903 message := messages[0].(map[string]any)
1904 require.Equal(t, "user", message["role"])
1905 require.Equal(t, "Hello", message["content"])
1906 })
1907
1908 t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
1909 t.Parallel()
1910
1911 server := newMockServer()
1912 defer server.close()
1913
1914 server.prepareJSONResponse(map[string]any{})
1915
1916 provider, err := New(
1917 WithAPIKey("test-api-key"),
1918 WithBaseURL(server.server.URL),
1919 )
1920 require.NoError(t, err)
1921 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1922
1923 result, err := model.Generate(context.Background(), fantasy.Call{
1924 Prompt: testPrompt,
1925 ProviderOptions: NewProviderOptions(&ProviderOptions{
1926 ServiceTier: fantasy.Opt("flex"),
1927 }),
1928 })
1929
1930 require.NoError(t, err)
1931 require.Len(t, server.calls, 1)
1932
1933 call := server.calls[0]
1934 require.Nil(t, call.body["service_tier"])
1935
1936 require.Len(t, result.Warnings, 1)
1937 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1938 require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
1939 require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
1940 })
1941
1942 t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
1943 t.Parallel()
1944
1945 server := newMockServer()
1946 defer server.close()
1947
1948 server.prepareJSONResponse(map[string]any{})
1949
1950 provider, err := New(
1951 WithAPIKey("test-api-key"),
1952 WithBaseURL(server.server.URL),
1953 )
1954 require.NoError(t, err)
1955 model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
1956
1957 _, err = model.Generate(context.Background(), fantasy.Call{
1958 Prompt: testPrompt,
1959 ProviderOptions: NewProviderOptions(&ProviderOptions{
1960 ServiceTier: fantasy.Opt("priority"),
1961 }),
1962 })
1963
1964 require.NoError(t, err)
1965 require.Len(t, server.calls, 1)
1966
1967 call := server.calls[0]
1968 require.Equal(t, "gpt-4o-mini", call.body["model"])
1969 require.Equal(t, "priority", call.body["service_tier"])
1970
1971 messages := call.body["messages"].([]any)
1972 require.Len(t, messages, 1)
1973
1974 message := messages[0].(map[string]any)
1975 require.Equal(t, "user", message["role"])
1976 require.Equal(t, "Hello", message["content"])
1977 })
1978
1979 t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
1980 t.Parallel()
1981
1982 server := newMockServer()
1983 defer server.close()
1984
1985 server.prepareJSONResponse(map[string]any{})
1986
1987 provider, err := New(
1988 WithAPIKey("test-api-key"),
1989 WithBaseURL(server.server.URL),
1990 )
1991 require.NoError(t, err)
1992 model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
1993
1994 result, err := model.Generate(context.Background(), fantasy.Call{
1995 Prompt: testPrompt,
1996 ProviderOptions: NewProviderOptions(&ProviderOptions{
1997 ServiceTier: fantasy.Opt("priority"),
1998 }),
1999 })
2000
2001 require.NoError(t, err)
2002 require.Len(t, server.calls, 1)
2003
2004 call := server.calls[0]
2005 require.Nil(t, call.body["service_tier"])
2006
2007 require.Len(t, result.Warnings, 1)
2008 require.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
2009 require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
2010 require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
2011 })
2012}
2013
// streamingMockServer is an httptest-backed stand-in for a streaming
// chat-completions endpoint. It replays a canned sequence of SSE chunks
// and records every request it receives.
type streamingMockServer struct {
	server *httptest.Server // underlying test HTTP server
	chunks []string         // payloads to replay; "HEADER:k:v" entries become response headers
	calls  []mockCall       // one entry per received request, in arrival order
}
2019
2020func newStreamingMockServer() *streamingMockServer {
2021 sms := &streamingMockServer{
2022 calls: make([]mockCall, 0),
2023 }
2024
2025 sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2026 // Record the call
2027 call := mockCall{
2028 method: r.Method,
2029 path: r.URL.Path,
2030 headers: make(map[string]string),
2031 }
2032
2033 for k, v := range r.Header {
2034 if len(v) > 0 {
2035 call.headers[k] = v[0]
2036 }
2037 }
2038
2039 // Parse request body
2040 if r.Body != nil {
2041 var body map[string]any
2042 json.NewDecoder(r.Body).Decode(&body)
2043 call.body = body
2044 }
2045
2046 sms.calls = append(sms.calls, call)
2047
2048 // Set streaming headers
2049 w.Header().Set("Content-Type", "text/event-stream")
2050 w.Header().Set("Cache-Control", "no-cache")
2051 w.Header().Set("Connection", "keep-alive")
2052
2053 // Add custom headers if any
2054 for _, chunk := range sms.chunks {
2055 if strings.HasPrefix(chunk, "HEADER:") {
2056 parts := strings.SplitN(chunk[7:], ":", 2)
2057 if len(parts) == 2 {
2058 w.Header().Set(parts[0], parts[1])
2059 }
2060 continue
2061 }
2062 }
2063
2064 w.WriteHeader(http.StatusOK)
2065
2066 // Write chunks
2067 for _, chunk := range sms.chunks {
2068 if strings.HasPrefix(chunk, "HEADER:") {
2069 continue
2070 }
2071 w.Write([]byte(chunk))
2072 if f, ok := w.(http.Flusher); ok {
2073 f.Flush()
2074 }
2075 }
2076 }))
2077
2078 return sms
2079}
2080
// close shuts down the underlying test HTTP server.
func (sms *streamingMockServer) close() {
	sms.server.Close()
}
2084
// prepareStreamResponse builds the SSE chunk sequence the mock server will
// replay: optional HEADER entries, an initial role chunk, one chunk per
// content string (with annotations appended after the last one, if given),
// a finish chunk, a usage chunk, and the terminating [DONE] sentinel.
//
// Recognized opts keys (all optional, defaults in parentheses):
//   - "content"       []string          text deltas to stream (none)
//   - "usage"         map[string]any    usage payload (17/227/244 token counts)
//   - "logprobs"      map[string]any    attached to the finish chunk when non-empty
//   - "finish_reason" string            ("stop")
//   - "model"         string            ("gpt-3.5-turbo-0613")
//   - "headers"       map[string]string custom response headers (none)
//   - "annotations"   []map[string]any  sent in a delta after the last content chunk
func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
	content := []string{}
	if c, ok := opts["content"].([]string); ok {
		content = c
	}

	usage := map[string]any{
		"prompt_tokens":     17,
		"total_tokens":      244,
		"completion_tokens": 227,
	}
	if u, ok := opts["usage"].(map[string]any); ok {
		usage = u
	}

	logprobs := map[string]any{}
	if l, ok := opts["logprobs"].(map[string]any); ok {
		logprobs = l
	}

	finishReason := "stop"
	if fr, ok := opts["finish_reason"].(string); ok {
		finishReason = fr
	}

	model := "gpt-3.5-turbo-0613"
	if m, ok := opts["model"].(string); ok {
		model = m
	}

	headers := map[string]string{}
	if h, ok := opts["headers"].(map[string]string); ok {
		headers = h
	}

	chunks := []string{}

	// Custom headers go first; the server handler interprets "HEADER:" entries
	// before writing the status line.
	for k, v := range headers {
		chunks = append(chunks, "HEADER:"+k+":"+v)
	}

	// Initial chunk carrying the assistant role and empty content.
	initialChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index": 0,
				"delta": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": nil,
			},
		},
	}
	initialData, _ := json.Marshal(initialChunk)
	chunks = append(chunks, "data: "+string(initialData)+"\n\n")

	// One chunk per content delta.
	for i, text := range content {
		contentChunk := map[string]any{
			"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
			"object":             "chat.completion.chunk",
			"created":            1702657020,
			"model":              model,
			"system_fingerprint": nil,
			"choices": []map[string]any{
				{
					"index": 1,
					"delta": map[string]any{
						"content": text,
					},
					"finish_reason": nil,
				},
			},
		}
		contentData, _ := json.Marshal(contentChunk)
		chunks = append(chunks, "data: "+string(contentData)+"\n\n")

		// Add annotations if this is the last content chunk and we have annotations.
		if i == len(content)-1 {
			if annotations, ok := opts["annotations"].([]map[string]any); ok {
				annotationChunk := map[string]any{
					"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
					"object":             "chat.completion.chunk",
					"created":            1702657020,
					"model":              model,
					"system_fingerprint": nil,
					"choices": []map[string]any{
						{
							"index": 1,
							"delta": map[string]any{
								"annotations": annotations,
							},
							"finish_reason": nil,
						},
					},
				}
				annotationData, _ := json.Marshal(annotationChunk)
				chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
			}
		}
	}

	// Finish chunk carrying the finish_reason (and logprobs when provided).
	finishChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index":         0,
				"delta":         map[string]any{},
				"finish_reason": finishReason,
			},
		},
	}

	if len(logprobs) > 0 {
		finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
	}

	finishData, _ := json.Marshal(finishChunk)
	chunks = append(chunks, "data: "+string(finishData)+"\n\n")

	// Usage chunk: empty choices, token counts only.
	usageChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": "fp_3bc1b5746c",
		"choices":            []map[string]any{},
		"usage":              usage,
	}
	usageData, _ := json.Marshal(usageChunk)
	chunks = append(chunks, "data: "+string(usageData)+"\n\n")

	// Terminating sentinel.
	chunks = append(chunks, "data: [DONE]\n\n")

	sms.chunks = chunks
}
2235
// prepareToolStreamResponse queues a canned SSE stream in which a single
// tool call ("test-tool") arrives with its JSON arguments split across
// several argument deltas, followed by a "tool_calls" finish chunk, a
// usage chunk, and the [DONE] sentinel. The deltas reassemble to
// {"value":"Sparkle Day"}.
func (sms *streamingMockServer) prepareToolStreamResponse() {
	chunks := []string{
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}
2252
2253func (sms *streamingMockServer) prepareErrorStreamResponse() {
2254 chunks := []string{
2255 `data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
2256 "data: [DONE]\n\n",
2257 }
2258 sms.chunks = chunks
2259}
2260
2261func collectStreamParts(stream fantasy.StreamResponse) ([]fantasy.StreamPart, error) {
2262 var parts []fantasy.StreamPart
2263 for part := range stream {
2264 parts = append(parts, part)
2265 if part.Type == fantasy.StreamPartTypeError {
2266 break
2267 }
2268 if part.Type == fantasy.StreamPartTypeFinish {
2269 break
2270 }
2271 }
2272 return parts, nil
2273}
2274
// TestDoStream exercises the streaming (SSE) path of the provider against a
// local mock server: text deltas, tool-call deltas, URL citations, error
// events, the outgoing request body, usage/provider-metadata propagation,
// and reasoning-model behavior.
func TestDoStream(t *testing.T) {
	t.Parallel()

	t.Run("should stream text deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content":       []string{"Hello", ", ", "World!"},
			"finish_reason": "stop",
			"usage": map[string]any{
				"prompt_tokens":     17,
				"total_tokens":      244,
				"completion_tokens": 227,
			},
			"logprobs": testLogprobs,
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Verify stream structure
		require.True(t, len(parts) >= 4) // text-start, deltas, text-end, finish

		// Find text parts
		textStart, textEnd, finish := -1, -1, -1
		var deltas []string

		// Record the index of each structural part and collect deltas in order.
		for i, part := range parts {
			switch part.Type {
			case fantasy.StreamPartTypeTextStart:
				textStart = i
			case fantasy.StreamPartTypeTextDelta:
				deltas = append(deltas, part.Delta)
			case fantasy.StreamPartTypeTextEnd:
				textEnd = i
			case fantasy.StreamPartTypeFinish:
				finish = i
			}
		}

		require.NotEqual(t, -1, textStart)
		require.NotEqual(t, -1, textEnd)
		require.NotEqual(t, -1, finish)
		require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)

		// Check finish part
		finishPart := parts[finish]
		require.Equal(t, fantasy.FinishReasonStop, finishPart.FinishReason)
		require.Equal(t, int64(17), finishPart.Usage.InputTokens)
		require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
	})

	t.Run("should stream tool deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareToolStreamResponse()

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			Tools: []fantasy.Tool{
				fantasy.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find tool-related parts
		toolInputStart, toolInputEnd, toolCall := -1, -1, -1
		var toolDeltas []string

		// The ID/name expectations match the fixture emitted by
		// prepareToolStreamResponse.
		for i, part := range parts {
			switch part.Type {
			case fantasy.StreamPartTypeToolInputStart:
				toolInputStart = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
			case fantasy.StreamPartTypeToolInputDelta:
				toolDeltas = append(toolDeltas, part.Delta)
			case fantasy.StreamPartTypeToolInputEnd:
				toolInputEnd = i
			case fantasy.StreamPartTypeToolCall:
				toolCall = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
				require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
			}
		}

		require.NotEqual(t, -1, toolInputStart)
		require.NotEqual(t, -1, toolInputEnd)
		require.NotEqual(t, -1, toolCall)

		// Verify tool deltas combine to form the complete input
		var fullInput strings.Builder
		for _, delta := range toolDeltas {
			fullInput.WriteString(delta)
		}
		require.Equal(t, `{"value":"Sparkle Day"}`, fullInput.String())
	})

	t.Run("should stream annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Based on search results"},
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index":   29,
						"url":         "https://example.com/doc1.pdf",
						"title":       "Document 1",
					},
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find source part
		var sourcePart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeSource {
				sourcePart = &part
				break
			}
		}

		// url_citation annotations should surface as a source part with the
		// URL/title carried through and a generated non-empty ID.
		require.NotNil(t, sourcePart)
		require.Equal(t, fantasy.SourceTypeURL, sourcePart.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
		require.Equal(t, "Document 1", sourcePart.Title)
		require.NotEmpty(t, sourcePart.ID)
	})

	t.Run("should handle error stream parts", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareErrorStreamResponse()

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Should have error and finish parts
		require.True(t, len(parts) >= 1)

		// Find error part
		var errorPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeError {
				errorPart = &part
				break
			}
		}

		require.NotNil(t, errorPart)
		require.NotNil(t, errorPart.Error)
	})

	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		// Inspect the captured outbound request: streaming must be enabled
		// and stream_options.include_usage set so usage arrives in the final
		// chunk.
		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		// cached_tokens from prompt_tokens_details maps to CacheReadTokens.
		require.NotNil(t, finishPart)
		require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
	})

	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.NotNil(t, finishPart.ProviderMetadata)

		// Prediction-token details are OpenAI-specific and therefore live in
		// provider metadata rather than the generic Usage struct.
		openaiMeta, ok := finishPart.ProviderMetadata["openai"].(*ProviderMetadata)
		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
		require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Store: fantasy.Opt(true),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])
		require.Equal(t, true, call.body["store"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Metadata: map[string]any{
					"custom": "value",
				},
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o3-mini")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("flex"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")

		_, err = model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: fantasy.Opt("priority"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should stream text delta for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model":   "o1-preview",
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find text parts
		var textDeltas []string
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeTextDelta {
				textDeltas = append(textDeltas, part.Delta)
			}
		}

		// Should contain the text content (without empty delta)
		require.Equal(t, []string{"Hello, World!"}, textDeltas)
	})

	t.Run("should send reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model":   "o1-preview",
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider, err := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		require.NoError(t, err)
		model, _ := provider.LanguageModel(t.Context(), "o1-preview")

		stream, err := model.Stream(context.Background(), fantasy.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *fantasy.StreamPart
		for _, part := range parts {
			if part.Type == fantasy.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		// reasoning_tokens from completion_tokens_details maps to
		// Usage.ReasoningTokens on the finish part.
		require.NotNil(t, finishPart)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
		require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
	})
}
2921
// TestDefaultToPrompt_DropsEmptyMessages verifies that DefaultToPrompt
// silently drops messages that would serialize to nothing (with a warning),
// while keeping messages that carry text, images, tool calls, or tool
// results.
func TestDefaultToPrompt_DropsEmptyMessages(t *testing.T) {
	t.Parallel()

	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role:    fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	t.Run("should keep assistant messages with text content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "What's the weather?"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName:   "get_weather",
						Input:      `{"location":"NYC"}`,
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should drop user messages without visible content", func(t *testing.T) {
		t.Parallel()

		// An unsupported file type is filtered out, leaving the message with
		// no renderable content.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte("not supported"),
						MediaType: "application/unknown",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Empty(t, messages)
		require.Len(t, warnings, 2) // One for unsupported type, one for empty message
		require.Contains(t, warnings[1].Message, "dropping empty user message")
	})

	t.Run("should keep user messages with image content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte{0x01, 0x02, 0x03},
						MediaType: "image/png",
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_123",
						Output:     fantasy.ToolResultOutputContentText{Text: "done"},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool error results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_456",
						Output:     fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
					},
				},
			},
		}

		messages, warnings := DefaultToPrompt(prompt, "openai", "gpt-4")

		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})
}
3086
// TestResponsesToPrompt_DropsEmptyMessages mirrors
// TestDefaultToPrompt_DropsEmptyMessages for the Responses API conversion
// (toResponsesPrompt): empty messages are dropped with a warning, non-empty
// ones are preserved.
func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) {
	t.Parallel()

	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role:    fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	t.Run("should keep assistant messages with text content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "What's the weather?"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName:   "get_weather",
						Input:      `{"location":"NYC"}`,
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should drop user messages without visible content", func(t *testing.T) {
		t.Parallel()

		// An unsupported file type is filtered out, leaving the message with
		// no renderable content.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte("not supported"),
						MediaType: "application/unknown",
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Empty(t, input)
		require.Len(t, warnings, 2) // One for unsupported type, one for empty message
		require.Contains(t, warnings[1].Message, "dropping empty user message")
	})

	t.Run("should keep user messages with image content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte{0x01, 0x02, 0x03},
						MediaType: "image/png",
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_123",
						Output:     fantasy.ToolResultOutputContentText{Text: "done"},
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool error results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_456",
						Output:     fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
					},
				},
			},
		}

		input, warnings := toResponsesPrompt(prompt, "system")

		require.Len(t, input, 1)
		require.Empty(t, warnings)
	})
}
3251
// TestParseContextTooLargeError checks that parseContextTooLargeError
// recognizes OpenAI "maximum context length" error messages, extracts the
// used/max token counts into the ProviderError, and leaves unrelated errors
// untouched.
func TestParseContextTooLargeError(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		message  string
		wantErr  bool // whether the message should be classified as context-too-large
		wantUsed int
		wantMax  int
	}{
		{
			name:     "matches openai format with resulted in",
			message:  "This model's maximum context length is 128000 tokens. However, your messages resulted in 150000 tokens.",
			wantErr:  true,
			wantUsed: 150000,
			wantMax:  128000,
		},
		{
			name:     "matches openai format with requested",
			message:  "maximum context length is 8192 tokens, however you requested 10000 tokens",
			wantErr:  true,
			wantUsed: 10000,
			wantMax:  8192,
		},
		{
			name:    "does not match unrelated error",
			message: "invalid api key",
			wantErr: false,
		},
		{
			name:    "does not match rate limit error",
			message: "rate limit exceeded",
			wantErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			// parseContextTooLargeError mutates the ProviderError in place.
			providerErr := &fantasy.ProviderError{Message: tt.message}
			parseContextTooLargeError(tt.message, providerErr)

			if tt.wantErr {
				require.True(t, providerErr.IsContextTooLarge())
				if tt.wantUsed > 0 {
					require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens)
					require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens)
				}
			} else {
				require.False(t, providerErr.IsContextTooLarge())
			}
		})
	}
}
3306
3307func TestUserAgent(t *testing.T) {
3308 t.Parallel()
3309
3310 t.Run("default UA applied", func(t *testing.T) {
3311 t.Parallel()
3312
3313 server := newMockServer()
3314 defer server.close()
3315 server.prepareJSONResponse(map[string]any{})
3316
3317 p, err := New(WithAPIKey("k"), WithBaseURL(server.server.URL))
3318 require.NoError(t, err)
3319 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3320 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})
3321
3322 require.Len(t, server.calls, 1)
3323 assert.Equal(t, "Charm Fantasy/"+fantasy.Version, server.calls[0].headers["User-Agent"])
3324 })
3325
3326 t.Run("WithHeaders User-Agent wins over default", func(t *testing.T) {
3327 t.Parallel()
3328
3329 server := newMockServer()
3330 defer server.close()
3331 server.prepareJSONResponse(map[string]any{})
3332
3333 p, err := New(WithAPIKey("k"), WithBaseURL(server.server.URL), WithHeaders(map[string]string{"User-Agent": "custom-from-headers"}))
3334 require.NoError(t, err)
3335 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3336 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})
3337
3338 require.Len(t, server.calls, 1)
3339 assert.Equal(t, "custom-from-headers", server.calls[0].headers["User-Agent"])
3340 })
3341
3342 t.Run("WithUserAgent wins over both", func(t *testing.T) {
3343 t.Parallel()
3344
3345 server := newMockServer()
3346 defer server.close()
3347 server.prepareJSONResponse(map[string]any{})
3348
3349 p, err := New(
3350 WithAPIKey("k"),
3351 WithBaseURL(server.server.URL),
3352 WithHeaders(map[string]string{"User-Agent": "from-headers"}),
3353 WithUserAgent("explicit-ua"),
3354 )
3355 require.NoError(t, err)
3356 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3357 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})
3358
3359 require.Len(t, server.calls, 1)
3360 assert.Equal(t, "explicit-ua", server.calls[0].headers["User-Agent"])
3361 })
3362
3363 t.Run("Call.UserAgent overrides provider WithHeaders UA", func(t *testing.T) {
3364 t.Parallel()
3365
3366 server := newMockServer()
3367 defer server.close()
3368 server.prepareJSONResponse(map[string]any{})
3369
3370 p, err := New(
3371 WithAPIKey("k"),
3372 WithBaseURL(server.server.URL),
3373 WithHeaders(map[string]string{"User-Agent": "header-ua"}),
3374 )
3375 require.NoError(t, err)
3376 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3377 _, _ = model.Generate(t.Context(), fantasy.Call{
3378 Prompt: testPrompt,
3379 UserAgent: "call-level-ua",
3380 })
3381
3382 require.Len(t, server.calls, 1)
3383 assert.Equal(t, "call-level-ua", server.calls[0].headers["User-Agent"])
3384 })
3385
3386 t.Run("no Call UA falls through to provider UA", func(t *testing.T) {
3387 t.Parallel()
3388
3389 server := newMockServer()
3390 defer server.close()
3391 server.prepareJSONResponse(map[string]any{})
3392
3393 p, err := New(
3394 WithAPIKey("k"),
3395 WithBaseURL(server.server.URL),
3396 WithUserAgent("provider-ua"),
3397 )
3398 require.NoError(t, err)
3399 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3400 _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt})
3401
3402 require.Len(t, server.calls, 1)
3403 assert.Equal(t, "provider-ua", server.calls[0].headers["User-Agent"])
3404 })
3405
3406 t.Run("agent WithUserAgent overrides provider UA end-to-end", func(t *testing.T) {
3407 t.Parallel()
3408
3409 server := newMockServer()
3410 defer server.close()
3411 server.prepareJSONResponse(map[string]any{})
3412
3413 p, err := New(
3414 WithAPIKey("k"),
3415 WithBaseURL(server.server.URL),
3416 WithUserAgent("provider-ua"),
3417 )
3418 require.NoError(t, err)
3419 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3420
3421 agent := fantasy.NewAgent(model, fantasy.WithUserAgent("agent-ua"))
3422 _, _ = agent.Generate(t.Context(), fantasy.AgentCall{Prompt: "hi"})
3423
3424 require.Len(t, server.calls, 1)
3425 assert.Equal(t, "agent-ua", server.calls[0].headers["User-Agent"])
3426 })
3427
3428 t.Run("agent without UA falls through to provider UA end-to-end", func(t *testing.T) {
3429 t.Parallel()
3430
3431 server := newMockServer()
3432 defer server.close()
3433 server.prepareJSONResponse(map[string]any{})
3434
3435 p, err := New(
3436 WithAPIKey("k"),
3437 WithBaseURL(server.server.URL),
3438 WithUserAgent("provider-ua"),
3439 )
3440 require.NoError(t, err)
3441 model, _ := p.LanguageModel(t.Context(), "gpt-4")
3442
3443 agent := fantasy.NewAgent(model)
3444 _, _ = agent.Generate(t.Context(), fantasy.AgentCall{Prompt: "hi"})
3445
3446 require.Len(t, server.calls, 1)
3447 assert.Equal(t, "provider-ua", server.calls[0].headers["User-Agent"])
3448 })
3449}