openai_test.go

package openai

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/charmbracelet/ai/ai"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/stretchr/testify/require"
)

func TestToOpenAiPrompt_SystemMessages(t *testing.T) {
	t.Parallel()

	t.Run("should forward system messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleSystem,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "You are a helpful assistant."},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		systemMsg := messages[0].OfSystem
		require.NotNil(t, systemMsg)
		require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
	})

	t.Run("should handle empty system messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role:    ai.MessageRoleSystem,
				Content: []ai.MessagePart{},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Len(t, warnings, 1)
		require.Contains(t, warnings[0].Message, "system prompt has no text parts")
		require.Empty(t, messages)
	})

	t.Run("should join multiple system text parts", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleSystem,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "You are a helpful assistant."},
					ai.TextPart{Text: "Be concise."},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		systemMsg := messages[0].OfSystem
		require.NotNil(t, systemMsg)
		require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
	})
}

func TestToOpenAiPrompt_UserMessages(t *testing.T) {
	t.Parallel()

	t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello"},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)
		require.Equal(t, "Hello", userMsg.Content.OfString.Value)
	})

	t.Run("should convert messages with image parts", func(t *testing.T) {
		t.Parallel()

		imageData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello"},
					ai.FilePart{
						MediaType: "image/png",
						Data:      imageData,
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 2)

		// Check text part
		textPart := content[0].OfText
		require.NotNil(t, textPart)
		require.Equal(t, "Hello", textPart.Text)

		// Check image part
		imagePart := content[1].OfImageURL
		require.NotNil(t, imagePart)
		expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
		require.Equal(t, expectedURL, imagePart.ImageURL.URL)
	})

	t.Run("should add image detail when specified through provider options", func(t *testing.T) {
		t.Parallel()

		imageData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "image/png",
						Data:      imageData,
						ProviderOptions: NewProviderFileOptions(&ProviderFileOptions{
							ImageDetail: "low",
						}),
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		imagePart := content[0].OfImageURL
		require.NotNil(t, imagePart)
		require.Equal(t, "low", imagePart.ImageURL.Detail)
	})
}

func TestToOpenAiPrompt_FileParts(t *testing.T) {
	t.Parallel()

	t.Run("should warn for unsupported media types", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/something",
						Data:      []byte("test"),
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Len(t, warnings, 1)
		require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
		require.Len(t, messages, 1) // The message is still created, but with an empty content array.
	})

	t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/wav",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
		require.Equal(t, "wav", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/mpeg",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/mp3",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should convert messages with PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "document.pdf", filePart.File.Filename.Value)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      []byte("file-pdf-12345"),
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
		require.True(t, param.IsOmitted(filePart.File.FileData))
		require.True(t, param.IsOmitted(filePart.File.Filename))
	})

	t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
	})
}

func TestToOpenAiPrompt_ToolCalls(t *testing.T) {
	t.Parallel()

	t.Run("should stringify arguments to tool calls", func(t *testing.T) {
		t.Parallel()

		inputArgs := map[string]any{"foo": "bar123"}
		inputJSON, _ := json.Marshal(inputArgs)

		outputResult := map[string]any{"oof": "321rab"}
		outputJSON, _ := json.Marshal(outputResult)

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.ToolCallPart{
						ToolCallID: "quux",
						ToolName:   "thwomp",
						Input:      string(inputJSON),
					},
				},
			},
			{
				Role: ai.MessageRoleTool,
				Content: []ai.MessagePart{
					ai.ToolResultPart{
						ToolCallID: "quux",
						Output: ai.ToolResultOutputContentText{
							Text: string(outputJSON),
						},
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 2)

		// Check assistant message with tool call
		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "", assistantMsg.Content.OfString.Value)
		require.Len(t, assistantMsg.ToolCalls, 1)

		toolCall := assistantMsg.ToolCalls[0].OfFunction
		require.NotNil(t, toolCall)
		require.Equal(t, "quux", toolCall.ID)
		require.Equal(t, "thwomp", toolCall.Function.Name)
		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)

		// Check tool message
		toolMsg := messages[1].OfTool
		require.NotNil(t, toolMsg)
		require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
		require.Equal(t, "quux", toolMsg.ToolCallID)
	})

	t.Run("should handle different tool output types", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleTool,
				Content: []ai.MessagePart{
					ai.ToolResultPart{
						ToolCallID: "text-tool",
						Output: ai.ToolResultOutputContentText{
							Text: "Hello world",
						},
					},
					ai.ToolResultPart{
						ToolCallID: "error-tool",
						Output: ai.ToolResultOutputContentError{
							Error: errors.New("Something went wrong"),
						},
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 2)

		// Check first tool message (text)
		textToolMsg := messages[0].OfTool
		require.NotNil(t, textToolMsg)
		require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
		require.Equal(t, "text-tool", textToolMsg.ToolCallID)

		// Check second tool message (error)
		errorToolMsg := messages[1].OfTool
		require.NotNil(t, errorToolMsg)
		require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
		require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
	})
}

func TestToOpenAiPrompt_AssistantMessages(t *testing.T) {
	t.Parallel()

	t.Run("should handle simple text assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello, how can I help you?"},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
	})

	t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
		t.Parallel()

		inputArgs := map[string]any{"query": "test"}
		inputJSON, _ := json.Marshal(inputArgs)

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Let me search for that."},
					ai.ToolCallPart{
						ToolCallID: "call-123",
						ToolName:   "search",
						Input:      string(inputJSON),
					},
				},
			},
		}

		messages, warnings := toPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
		require.Len(t, assistantMsg.ToolCalls, 1)

		toolCall := assistantMsg.ToolCalls[0].OfFunction
		require.Equal(t, "call-123", toolCall.ID)
		require.Equal(t, "search", toolCall.Function.Name)
		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
	})
}

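// testPrompt is the minimal single-user-message prompt shared by the request tests below.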
var testPrompt = ai.Prompt{
	{
		Role: ai.MessageRoleUser,
		Content: []ai.MessagePart{
			ai.TextPart{Text: "Hello"},
		},
	},
}

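// testLogprobs mirrors the logprobs block of a chat completion response,
// with one entry (and its top_logprobs) per generated token.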
var testLogprobs = map[string]any{
	"content": []map[string]any{
		{
			"token":   "Hello",
			"logprob": -0.0009994634,
			"top_logprobs": []map[string]any{
				{
					"token":   "Hello",
					"logprob": -0.0009994634,
				},
			},
		},
		{
			"token":   "!",
			"logprob": -0.13410144,
			"top_logprobs": []map[string]any{
				{
					"token":   "!",
					"logprob": -0.13410144,
				},
			},
		},
		{
			"token":   " How",
			"logprob": -0.0009250381,
			"top_logprobs": []map[string]any{
				{
					"token":   " How",
					"logprob": -0.0009250381,
				},
			},
		},
		{
			"token":   " can",
			"logprob": -0.047709424,
			"top_logprobs": []map[string]any{
				{
					"token":   " can",
					"logprob": -0.047709424,
				},
			},
		},
		{
			"token":   " I",
			"logprob": -0.000009014684,
			"top_logprobs": []map[string]any{
				{
					"token":   " I",
					"logprob": -0.000009014684,
				},
			},
		},
		{
			"token":   " assist",
			"logprob": -0.009125131,
			"top_logprobs": []map[string]any{
				{
					"token":   " assist",
					"logprob": -0.009125131,
				},
			},
		},
		{
			"token":   " you",
			"logprob": -0.0000066306106,
			"top_logprobs": []map[string]any{
				{
					"token":   " you",
					"logprob": -0.0000066306106,
				},
			},
		},
		{
			"token":   " today",
			"logprob": -0.00011093382,
			"top_logprobs": []map[string]any{
				{
					"token":   " today",
					"logprob": -0.00011093382,
				},
			},
		},
		{
			"token":   "?",
			"logprob": -0.00004596782,
			"top_logprobs": []map[string]any{
				{
					"token":   "?",
					"logprob": -0.00004596782,
				},
			},
		},
	},
}

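// mockServer wraps an httptest.Server that records each incoming request as
// a mockCall and replies with the configured JSON response.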
type mockServer struct {
	server   *httptest.Server
	response map[string]any
	calls    []mockCall
}

type mockCall struct {
	method  string
	path    string
	headers map[string]string
	body    map[string]any
}

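// newMockServer starts the mock chat completions endpoint. Callers should set
// the response via prepareJSONResponse and defer close.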
func newMockServer() *mockServer {
	ms := &mockServer{
		calls: make([]mockCall, 0),
	}

	ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record the call
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}

		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Parse the request body; on a decode error the recorded body stays nil.
		if r.Body != nil {
			var body map[string]any
			if err := json.NewDecoder(r.Body).Decode(&body); err == nil {
				call.body = body
			}
		}

		ms.calls = append(ms.calls, call)

		// Return mock response
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(ms.response)
	}))

	return ms
}

func (ms *mockServer) close() {
	ms.server.Close()
}

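// prepareJSONResponse builds a default chat completion response body and
// applies the given overrides (content, tool_calls, usage, finish_reason, ...).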
func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
	// Default values
	response := map[string]any{
		"id":      "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
		"object":  "chat.completion",
		"created": 1711115037,
		"model":   "gpt-3.5-turbo-0125",
		"choices": []map[string]any{
			{
				"index": 0,
				"message": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": "stop",
			},
		},
		"usage": map[string]any{
			"prompt_tokens":     4,
			"total_tokens":      34,
			"completion_tokens": 30,
		},
		"system_fingerprint": "fp_3bc1b5746c",
	}

	// Override with provided options
	for k, v := range opts {
		switch k {
		case "content":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["content"] = v
		case "tool_calls":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["tool_calls"] = v
		case "function_call":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["function_call"] = v
		case "annotations":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["annotations"] = v
		case "usage":
			response["usage"] = v
		case "finish_reason":
			response["choices"].([]map[string]any)[0]["finish_reason"] = v
		case "id":
			response["id"] = v
		case "created":
			response["created"] = v
		case "model":
			response["model"] = v
		case "logprobs":
			if v != nil {
				response["choices"].([]map[string]any)[0]["logprobs"] = v
			}
		}
	}

	ms.response = response
}

func TestDoGenerate(t *testing.T) {
	t.Parallel()

	t.Run("should extract text response", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "Hello, World!",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 1)

		textContent, ok := result.Content[0].(ai.TextContent)
		require.True(t, ok)
		require.Equal(t, "Hello, World!", textContent.Text)
	})

	t.Run("should extract usage", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     20,
				"total_tokens":      25,
				"completion_tokens": 5,
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(20), result.Usage.InputTokens)
		require.Equal(t, int64(5), result.Usage.OutputTokens)
		require.Equal(t, int64(25), result.Usage.TotalTokens)
	})

	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages, ok := call.body["messages"].([]any)
		require.True(t, ok)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should support partial usage", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens": 20,
				"total_tokens":  20,
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(20), result.Usage.InputTokens)
		require.Equal(t, int64(0), result.Usage.OutputTokens)
		require.Equal(t, int64(20), result.Usage.TotalTokens)
	})

	t.Run("should extract logprobs", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"logprobs": testLogprobs,
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				LogProbs: ai.BoolOption(true),
			}),
		})

		require.NoError(t, err)
		require.NotNil(t, result.ProviderMetadata)

		openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
		require.True(t, ok)
		require.NotNil(t, openaiMeta.Logprobs)
	})

	t.Run("should extract finish reason", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"finish_reason": "stop",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, ai.FinishReasonStop, result.FinishReason)
	})

	t.Run("should support unknown finish reason", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"finish_reason": "eos",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, ai.FinishReasonUnknown, result.FinishReason)
	})

	t.Run("should pass the model and the messages", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass settings", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				LogitBias: map[string]int64{
					"50256": -100,
				},
				ParallelToolCalls: ai.BoolOption(false),
				User:              ai.StringOption("test-user-id"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		logitBias := call.body["logit_bias"].(map[string]any)
		require.Equal(t, float64(-100), logitBias["50256"])
		require.Equal(t, false, call.body["parallel_tool_calls"])
		require.Equal(t, "test-user-id", call.body["user"])
	})

	t.Run("should pass reasoningEffort setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(
				&ProviderOptions{
					ReasoningEffort: ReasoningEffortOption(ReasoningEffortLow),
				},
			),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-mini", call.body["model"])
		require.Equal(t, "low", call.body["reasoning_effort"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass textVerbosity setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				TextVerbosity: ai.StringOption("low"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o", call.body["model"])
		require.Equal(t, "low", call.body["verbosity"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass tools and toolChoice", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			Tools: []ai.Tool{
				ai.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
			ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		tools := call.body["tools"].([]any)
		require.Len(t, tools, 1)

		tool := tools[0].(map[string]any)
		require.Equal(t, "function", tool["type"])

		function := tool["function"].(map[string]any)
		require.Equal(t, "test-tool", function["name"])
		require.Equal(t, false, function["strict"])

		toolChoice := call.body["tool_choice"].(map[string]any)
		require.Equal(t, "function", toolChoice["type"])

		toolChoiceFunction := toolChoice["function"].(map[string]any)
		require.Equal(t, "test-tool", toolChoiceFunction["name"])
	})

	t.Run("should parse tool calls", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"tool_calls": []map[string]any{
				{
					"id":   "call_O17Uplv4lJvD6DVdIvFFeRMw",
					"type": "function",
					"function": map[string]any{
						"name":      "test-tool",
						"arguments": `{"value":"Spark"}`,
					},
				},
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			Tools: []ai.Tool{
				ai.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
			ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 1)

		toolCall, ok := result.Content[0].(ai.ToolCallContent)
		require.True(t, ok)
		require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
		require.Equal(t, "test-tool", toolCall.ToolName)
		require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
	})

	t.Run("should parse annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "Based on the search results [doc1], I found information.",
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index":   29,
						"url":         "https://example.com/doc1.pdf",
						"title":       "Document 1",
					},
				},
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 2)

		textContent, ok := result.Content[0].(ai.TextContent)
		require.True(t, ok)
		require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)

		sourceContent, ok := result.Content[1].(ai.SourceContent)
		require.True(t, ok)
		require.Equal(t, ai.SourceTypeURL, sourceContent.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
		require.Equal(t, "Document 1", sourceContent.Title)
		require.NotEmpty(t, sourceContent.ID)
	})

	t.Run("should return cached_tokens in prompt_tokens_details", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
		require.Equal(t, int64(15), result.Usage.InputTokens)
		require.Equal(t, int64(20), result.Usage.OutputTokens)
		require.Equal(t, int64(35), result.Usage.TotalTokens)
	})

	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.NotNil(t, result.ProviderMetadata)

		openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
		require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
	})

	t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt:           testPrompt,
			Temperature:      &[]float64{0.5}[0],
			TopP:             &[]float64{0.7}[0],
			FrequencyPenalty: &[]float64{0.2}[0],
			PresencePenalty:  &[]float64{0.3}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])

		// These should not be present
		require.Nil(t, call.body["temperature"])
		require.Nil(t, call.body["top_p"])
		require.Nil(t, call.body["frequency_penalty"])
		require.Nil(t, call.body["presence_penalty"])

		// Should have warnings
		require.Len(t, result.Warnings, 4)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "temperature", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
	})

	t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt:          testPrompt,
			MaxOutputTokens: &[]int64{1000}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])
		require.Equal(t, float64(1000), call.body["max_completion_tokens"])
		require.Nil(t, call.body["max_tokens"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should return reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(15), result.Usage.InputTokens)
		require.Equal(t, int64(20), result.Usage.OutputTokens)
		require.Equal(t, int64(35), result.Usage.TotalTokens)
		require.Equal(t, int64(10), result.Usage.ReasoningTokens)
	})

	t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"model": "o1-preview",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o1-preview")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				MaxCompletionTokens: ai.IntOption(255),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])
		require.Equal(t, float64(255), call.body["max_completion_tokens"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send prediction extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Prediction: map[string]any{
					"type":    "content",
					"content": "Hello, World!",
				},
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		prediction := call.body["prediction"].(map[string]any)
		require.Equal(t, "content", prediction["type"])
		require.Equal(t, "Hello, World!", prediction["content"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Store: ai.BoolOption(true),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["store"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				Metadata: map[string]any{
					"custom": "value",
				},
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send promptCacheKey extension value", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				PromptCacheKey: ai.StringOption("test-cache-key-123"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send safety_identifier extension value", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				SafetyIdentifier: ai.StringOption("test-safety-identifier-123"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-search-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt:      testPrompt,
			Temperature: &[]float64{0.7}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-search-preview", call.body["model"])
		require.Nil(t, call.body["temperature"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "temperature", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "search preview models")
	})

	t.Run("should send ServiceTier flex processing setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("o3-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: ai.StringOption("flex"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: ai.StringOption("flex"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Nil(t, call.body["service_tier"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
	})

	t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-4o-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: ai.StringOption("priority"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := New(
			WithAPIKey("test-api-key"),
			WithBaseURL(server.server.URL),
		)
		model, _ := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: NewProviderOptions(&ProviderOptions{
				ServiceTier: ai.StringOption("priority"),
			}),
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Nil(t, call.body["service_tier"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "ServiceTier", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
	})
}

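// streamingMockServer is the streaming counterpart of mockServer: it records
// incoming requests and replays a fixed sequence of server-sent-event chunks.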
type streamingMockServer struct {
	server *httptest.Server
	chunks []string
	calls  []mockCall
}

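// newStreamingMockServer starts a server that writes sms.chunks as a
// text/event-stream response, flushing after each chunk. Entries prefixed
// with "HEADER:name:value" are sent as response headers instead of body chunks.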
1929func newStreamingMockServer() *streamingMockServer {
1930	sms := &streamingMockServer{
1931		calls: make([]mockCall, 0),
1932	}
1933
1934	sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1935		// Record the call
1936		call := mockCall{
1937			method:  r.Method,
1938			path:    r.URL.Path,
1939			headers: make(map[string]string),
1940		}
1941
1942		for k, v := range r.Header {
1943			if len(v) > 0 {
1944				call.headers[k] = v[0]
1945			}
1946		}
1947
1948		// Parse request body
1949		if r.Body != nil {
1950			var body map[string]any
1951			json.NewDecoder(r.Body).Decode(&body)
1952			call.body = body
1953		}
1954
1955		sms.calls = append(sms.calls, call)
1956
1957		// Set streaming headers
1958		w.Header().Set("Content-Type", "text/event-stream")
1959		w.Header().Set("Cache-Control", "no-cache")
1960		w.Header().Set("Connection", "keep-alive")
1961
1962		// Add custom headers if any
1963		for _, chunk := range sms.chunks {
1964			if strings.HasPrefix(chunk, "HEADER:") {
1965				parts := strings.SplitN(chunk[7:], ":", 2)
1966				if len(parts) == 2 {
1967					w.Header().Set(parts[0], parts[1])
1968				}
1969				continue
1970			}
1971		}
1972
1973		w.WriteHeader(http.StatusOK)
1974
1975		// Write chunks
1976		for _, chunk := range sms.chunks {
1977			if strings.HasPrefix(chunk, "HEADER:") {
1978				continue
1979			}
1980			w.Write([]byte(chunk))
1981			if f, ok := w.(http.Flusher); ok {
1982				f.Flush()
1983			}
1984		}
1985	}))
1986
1987	return sms
1988}
1989
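// close shuts down the underlying test server.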
1990func (sms *streamingMockServer) close() {
1991	sms.server.Close()
1992}
1993
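// prepareStreamResponse assembles the chunk sequence for a streamed
// chat completion: an initial role chunk, one chunk per content string
// (with optional annotations after the last one), a finish chunk, a
// usage chunk, and the terminating [DONE] event. Recognized opts keys:
// "content", "usage", "logprobs", "finish_reason", "model", "headers",
// and "annotations".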
1994func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
1995	content := []string{}
1996	if c, ok := opts["content"].([]string); ok {
1997		content = c
1998	}
1999
2000	usage := map[string]any{
2001		"prompt_tokens":     17,
2002		"total_tokens":      244,
2003		"completion_tokens": 227,
2004	}
2005	if u, ok := opts["usage"].(map[string]any); ok {
2006		usage = u
2007	}
2008
2009	logprobs := map[string]any{}
2010	if l, ok := opts["logprobs"].(map[string]any); ok {
2011		logprobs = l
2012	}
2013
2014	finishReason := "stop"
2015	if fr, ok := opts["finish_reason"].(string); ok {
2016		finishReason = fr
2017	}
2018
2019	model := "gpt-3.5-turbo-0613"
2020	if m, ok := opts["model"].(string); ok {
2021		model = m
2022	}
2023
2024	headers := map[string]string{}
2025	if h, ok := opts["headers"].(map[string]string); ok {
2026		headers = h
2027	}
2028
2029	chunks := []string{}
2030
2031	// Encode custom headers as HEADER: pseudo-chunks; the handler applies them before writing the body
2032	for k, v := range headers {
2033		chunks = append(chunks, "HEADER:"+k+":"+v)
2034	}
2035
2036	// Initial chunk with role
2037	initialChunk := map[string]any{
2038		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2039		"object":             "chat.completion.chunk",
2040		"created":            1702657020,
2041		"model":              model,
2042		"system_fingerprint": nil,
2043		"choices": []map[string]any{
2044			{
2045				"index": 0,
2046				"delta": map[string]any{
2047					"role":    "assistant",
2048					"content": "",
2049				},
2050				"finish_reason": nil,
2051			},
2052		},
2053	}
2054	initialData, _ := json.Marshal(initialChunk)
2055	chunks = append(chunks, "data: "+string(initialData)+"\n\n")
2056
2057	// Content chunks
2058	for i, text := range content {
2059		contentChunk := map[string]any{
2060			"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2061			"object":             "chat.completion.chunk",
2062			"created":            1702657020,
2063			"model":              model,
2064			"system_fingerprint": nil,
2065			"choices": []map[string]any{
2066				{
2067					"index": 1,
2068					"delta": map[string]any{
2069						"content": text,
2070					},
2071					"finish_reason": nil,
2072				},
2073			},
2074		}
2075		contentData, _ := json.Marshal(contentChunk)
2076		chunks = append(chunks, "data: "+string(contentData)+"\n\n")
2077
2078		// Add annotations if this is the last content chunk and we have annotations
2079		if i == len(content)-1 {
2080			if annotations, ok := opts["annotations"].([]map[string]any); ok {
2081				annotationChunk := map[string]any{
2082					"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2083					"object":             "chat.completion.chunk",
2084					"created":            1702657020,
2085					"model":              model,
2086					"system_fingerprint": nil,
2087					"choices": []map[string]any{
2088						{
2089							"index": 1,
2090							"delta": map[string]any{
2091								"annotations": annotations,
2092							},
2093							"finish_reason": nil,
2094						},
2095					},
2096				}
2097				annotationData, _ := json.Marshal(annotationChunk)
2098				chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
2099			}
2100		}
2101	}
2102
2103	// Finish chunk
2104	finishChunk := map[string]any{
2105		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2106		"object":             "chat.completion.chunk",
2107		"created":            1702657020,
2108		"model":              model,
2109		"system_fingerprint": nil,
2110		"choices": []map[string]any{
2111			{
2112				"index":         0,
2113				"delta":         map[string]any{},
2114				"finish_reason": finishReason,
2115			},
2116		},
2117	}
2118
2119	if len(logprobs) > 0 {
2120		finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
2121	}
2122
2123	finishData, _ := json.Marshal(finishChunk)
2124	chunks = append(chunks, "data: "+string(finishData)+"\n\n")
2125
2126	// Usage chunk
2127	usageChunk := map[string]any{
2128		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2129		"object":             "chat.completion.chunk",
2130		"created":            1702657020,
2131		"model":              model,
2132		"system_fingerprint": "fp_3bc1b5746c",
2133		"choices":            []map[string]any{},
2134		"usage":              usage,
2135	}
2136	usageData, _ := json.Marshal(usageChunk)
2137	chunks = append(chunks, "data: "+string(usageData)+"\n\n")
2138
2139	// Done
2140	chunks = append(chunks, "data: [DONE]\n\n")
2141
2142	sms.chunks = chunks
2143}
2144
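// prepareToolStreamResponse replays a canned stream in which a single
// tool call's arguments arrive as incremental deltas that concatenate
// to {"value":"Sparkle Day"}, followed by a tool_calls finish reason
// and a usage chunk.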
2145func (sms *streamingMockServer) prepareToolStreamResponse() {
2146	chunks := []string{
2147		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2148		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2149		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2150		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2151		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2152		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2153		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2154		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2155		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
2156		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
2157		"data: [DONE]\n\n",
2158	}
2159	sms.chunks = chunks
2160}
2161
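// prepareErrorStreamResponse replays a stream whose only event before
// [DONE] is an OpenAI error payload, so the stream must surface an
// error part instead of text.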
2162func (sms *streamingMockServer) prepareErrorStreamResponse() {
2163	chunks := []string{
2164		`data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
2165		"data: [DONE]\n\n",
2166	}
2167	sms.chunks = chunks
2168}
2169
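// collectStreamParts drains a stream into a slice, stopping after the
// first error or finish part.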
2170func collectStreamParts(stream ai.StreamResponse) ([]ai.StreamPart, error) {
2171	var parts []ai.StreamPart
2172	for part := range stream {
2173		parts = append(parts, part)
2174		if part.Type == ai.StreamPartTypeError || part.Type == ai.StreamPartTypeFinish {
2175			break
2176		}
2180	}
2181	return parts, nil
2182}
2183
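// TestDoStream exercises the streaming path end to end against the SSE
// mock: text deltas, tool-call deltas, annotations, error handling,
// request-body shape, and usage accounting.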
2184func TestDoStream(t *testing.T) {
2185	t.Parallel()
2186
2187	t.Run("should stream text deltas", func(t *testing.T) {
2188		t.Parallel()
2189
2190		server := newStreamingMockServer()
2191		defer server.close()
2192
2193		server.prepareStreamResponse(map[string]any{
2194			"content":       []string{"Hello", ", ", "World!"},
2195			"finish_reason": "stop",
2196			"usage": map[string]any{
2197				"prompt_tokens":     17,
2198				"total_tokens":      244,
2199				"completion_tokens": 227,
2200			},
2201			"logprobs": testLogprobs,
2202		})
2203
2204		provider := New(
2205			WithAPIKey("test-api-key"),
2206			WithBaseURL(server.server.URL),
2207		)
2208		model, _ := provider.LanguageModel("gpt-3.5-turbo")
2209
2210		stream, err := model.Stream(context.Background(), ai.Call{
2211			Prompt: testPrompt,
2212		})
2213
2214		require.NoError(t, err)
2215
2216		parts, err := collectStreamParts(stream)
2217		require.NoError(t, err)
2218
2219		// Verify stream structure
2220		require.GreaterOrEqual(t, len(parts), 4) // text-start, deltas, text-end, finish
2221
2222		// Find text parts
2223		textStart, textEnd, finish := -1, -1, -1
2224		var deltas []string
2225
2226		for i, part := range parts {
2227			switch part.Type {
2228			case ai.StreamPartTypeTextStart:
2229				textStart = i
2230			case ai.StreamPartTypeTextDelta:
2231				deltas = append(deltas, part.Delta)
2232			case ai.StreamPartTypeTextEnd:
2233				textEnd = i
2234			case ai.StreamPartTypeFinish:
2235				finish = i
2236			}
2237		}
2238
2239		require.NotEqual(t, -1, textStart)
2240		require.NotEqual(t, -1, textEnd)
2241		require.NotEqual(t, -1, finish)
2242		require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)
2243
2244		// Check finish part
2245		finishPart := parts[finish]
2246		require.Equal(t, ai.FinishReasonStop, finishPart.FinishReason)
2247		require.Equal(t, int64(17), finishPart.Usage.InputTokens)
2248		require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
2249		require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
2250	})
2251
2252	t.Run("should stream tool deltas", func(t *testing.T) {
2253		t.Parallel()
2254
2255		server := newStreamingMockServer()
2256		defer server.close()
2257
2258		server.prepareToolStreamResponse()
2259
2260		provider := New(
2261			WithAPIKey("test-api-key"),
2262			WithBaseURL(server.server.URL),
2263		)
2264		model, _ := provider.LanguageModel("gpt-3.5-turbo")
2265
2266		stream, err := model.Stream(context.Background(), ai.Call{
2267			Prompt: testPrompt,
2268			Tools: []ai.Tool{
2269				ai.FunctionTool{
2270					Name: "test-tool",
2271					InputSchema: map[string]any{
2272						"type": "object",
2273						"properties": map[string]any{
2274							"value": map[string]any{
2275								"type": "string",
2276							},
2277						},
2278						"required":             []string{"value"},
2279						"additionalProperties": false,
2280						"$schema":              "http://json-schema.org/draft-07/schema#",
2281					},
2282				},
2283			},
2284		})
2285
2286		require.NoError(t, err)
2287
2288		parts, err := collectStreamParts(stream)
2289		require.NoError(t, err)
2290
2291		// Find tool-related parts
2292		toolInputStart, toolInputEnd, toolCall := -1, -1, -1
2293		var toolDeltas []string
2294
2295		for i, part := range parts {
2296			switch part.Type {
2297			case ai.StreamPartTypeToolInputStart:
2298				toolInputStart = i
2299				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
2300				require.Equal(t, "test-tool", part.ToolCallName)
2301			case ai.StreamPartTypeToolInputDelta:
2302				toolDeltas = append(toolDeltas, part.Delta)
2303			case ai.StreamPartTypeToolInputEnd:
2304				toolInputEnd = i
2305			case ai.StreamPartTypeToolCall:
2306				toolCall = i
2307				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
2308				require.Equal(t, "test-tool", part.ToolCallName)
2309				require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
2310			}
2311		}
2312
2313		require.NotEqual(t, -1, toolInputStart)
2314		require.NotEqual(t, -1, toolInputEnd)
2315		require.NotEqual(t, -1, toolCall)
2316
2317		// Verify tool deltas combine to form the complete input
2318		fullInput := strings.Join(toolDeltas, "")
2322		require.Equal(t, `{"value":"Sparkle Day"}`, fullInput)
2323	})
2324
2325	t.Run("should stream annotations/citations", func(t *testing.T) {
2326		t.Parallel()
2327
2328		server := newStreamingMockServer()
2329		defer server.close()
2330
2331		server.prepareStreamResponse(map[string]any{
2332			"content": []string{"Based on search results"},
2333			"annotations": []map[string]any{
2334				{
2335					"type": "url_citation",
2336					"url_citation": map[string]any{
2337						"start_index": 24,
2338						"end_index":   29,
2339						"url":         "https://example.com/doc1.pdf",
2340						"title":       "Document 1",
2341					},
2342				},
2343			},
2344		})
2345
2346		provider := New(
2347			WithAPIKey("test-api-key"),
2348			WithBaseURL(server.server.URL),
2349		)
2350		model, _ := provider.LanguageModel("gpt-3.5-turbo")
2351
2352		stream, err := model.Stream(context.Background(), ai.Call{
2353			Prompt: testPrompt,
2354		})
2355
2356		require.NoError(t, err)
2357
2358		parts, err := collectStreamParts(stream)
2359		require.NoError(t, err)
2360
2361		// Find source part
2362		var sourcePart *ai.StreamPart
2363		for _, part := range parts {
2364			if part.Type == ai.StreamPartTypeSource {
2365				sourcePart = &part
2366				break
2367			}
2368		}
2369
2370		require.NotNil(t, sourcePart)
2371		require.Equal(t, ai.SourceTypeURL, sourcePart.SourceType)
2372		require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
2373		require.Equal(t, "Document 1", sourcePart.Title)
2374		require.NotEmpty(t, sourcePart.ID)
2375	})
2376
2377	t.Run("should handle error stream parts", func(t *testing.T) {
2378		t.Parallel()
2379
2380		server := newStreamingMockServer()
2381		defer server.close()
2382
2383		server.prepareErrorStreamResponse()
2384
2385		provider := New(
2386			WithAPIKey("test-api-key"),
2387			WithBaseURL(server.server.URL),
2388		)
2389		model, _ := provider.LanguageModel("gpt-3.5-turbo")
2390
2391		stream, err := model.Stream(context.Background(), ai.Call{
2392			Prompt: testPrompt,
2393		})
2394
2395		require.NoError(t, err)
2396
2397		parts, err := collectStreamParts(stream)
2398		require.NoError(t, err)
2399
2400		// Should contain at least the error part (collection stops at the first error)
2401		require.NotEmpty(t, parts)
2402
2403		// Find error part
2404		var errorPart *ai.StreamPart
2405		for _, part := range parts {
2406			if part.Type == ai.StreamPartTypeError {
2407				errorPart = &part
2408				break
2409			}
2410		}
2411
2412		require.NotNil(t, errorPart)
2413		require.NotNil(t, errorPart.Error)
2414	})
2415
2416	t.Run("should send request body", func(t *testing.T) {
2417		t.Parallel()
2418
2419		server := newStreamingMockServer()
2420		defer server.close()
2421
2422		server.prepareStreamResponse(map[string]any{
2423			"content": []string{},
2424		})
2425
2426		provider := New(
2427			WithAPIKey("test-api-key"),
2428			WithBaseURL(server.server.URL),
2429		)
2430		model, _ := provider.LanguageModel("gpt-3.5-turbo")
2431
2432		_, err := model.Stream(context.Background(), ai.Call{
2433			Prompt: testPrompt,
2434		})
2435
2436		require.NoError(t, err)
2437		require.Len(t, server.calls, 1)
2438
2439		call := server.calls[0]
2440		require.Equal(t, "POST", call.method)
2441		require.Equal(t, "/chat/completions", call.path)
2442		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
2443		require.Equal(t, true, call.body["stream"])
2444
2445		streamOptions := call.body["stream_options"].(map[string]any)
2446		require.Equal(t, true, streamOptions["include_usage"])
2447
2448		messages := call.body["messages"].([]any)
2449		require.Len(t, messages, 1)
2450
2451		message := messages[0].(map[string]any)
2452		require.Equal(t, "user", message["role"])
2453		require.Equal(t, "Hello", message["content"])
2454	})
2455
2456	t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
2457		t.Parallel()
2458
2459		server := newStreamingMockServer()
2460		defer server.close()
2461
2462		server.prepareStreamResponse(map[string]any{
2463			"content": []string{},
2464			"usage": map[string]any{
2465				"prompt_tokens":     15,
2466				"completion_tokens": 20,
2467				"total_tokens":      35,
2468				"prompt_tokens_details": map[string]any{
2469					"cached_tokens": 1152,
2470				},
2471			},
2472		})
2473
2474		provider := New(
2475			WithAPIKey("test-api-key"),
2476			WithBaseURL(server.server.URL),
2477		)
2478		model, _ := provider.LanguageModel("gpt-3.5-turbo")
2479
2480		stream, err := model.Stream(context.Background(), ai.Call{
2481			Prompt: testPrompt,
2482		})
2483
2484		require.NoError(t, err)
2485
2486		parts, err := collectStreamParts(stream)
2487		require.NoError(t, err)
2488
2489		// Find finish part
2490		var finishPart *ai.StreamPart
2491		for _, part := range parts {
2492			if part.Type == ai.StreamPartTypeFinish {
2493				finishPart = &part
2494				break
2495			}
2496		}
2497
2498		require.NotNil(t, finishPart)
2499		require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
2500		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
2501		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
2502		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
2503	})
2504
2505	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
2506		t.Parallel()
2507
2508		server := newStreamingMockServer()
2509		defer server.close()
2510
2511		server.prepareStreamResponse(map[string]any{
2512			"content": []string{},
2513			"usage": map[string]any{
2514				"prompt_tokens":     15,
2515				"completion_tokens": 20,
2516				"total_tokens":      35,
2517				"completion_tokens_details": map[string]any{
2518					"accepted_prediction_tokens": 123,
2519					"rejected_prediction_tokens": 456,
2520				},
2521			},
2522		})
2523
2524		provider := New(
2525			WithAPIKey("test-api-key"),
2526			WithBaseURL(server.server.URL),
2527		)
2528		model, _ := provider.LanguageModel("gpt-3.5-turbo")
2529
2530		stream, err := model.Stream(context.Background(), ai.Call{
2531			Prompt: testPrompt,
2532		})
2533
2534		require.NoError(t, err)
2535
2536		parts, err := collectStreamParts(stream)
2537		require.NoError(t, err)
2538
2539		// Find finish part
2540		var finishPart *ai.StreamPart
2541		for _, part := range parts {
2542			if part.Type == ai.StreamPartTypeFinish {
2543				finishPart = &part
2544				break
2545			}
2546		}
2547
2548		require.NotNil(t, finishPart)
2549		require.NotNil(t, finishPart.ProviderMetadata)
2550
2551		openaiMeta, ok := finishPart.ProviderMetadata["openai"].(*ProviderMetadata)
2552		require.True(t, ok)
2553		require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
2554		require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
2555	})
2556
2557	t.Run("should send store extension setting", func(t *testing.T) {
2558		t.Parallel()
2559
2560		server := newStreamingMockServer()
2561		defer server.close()
2562
2563		server.prepareStreamResponse(map[string]any{
2564			"content": []string{},
2565		})
2566
2567		provider := New(
2568			WithAPIKey("test-api-key"),
2569			WithBaseURL(server.server.URL),
2570		)
2571		model, _ := provider.LanguageModel("gpt-3.5-turbo")
2572
2573		_, err := model.Stream(context.Background(), ai.Call{
2574			Prompt: testPrompt,
2575			ProviderOptions: NewProviderOptions(&ProviderOptions{
2576				Store: ai.BoolOption(true),
2577			}),
2578		})
2579
2580		require.NoError(t, err)
2581		require.Len(t, server.calls, 1)
2582
2583		call := server.calls[0]
2584		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
2585		require.Equal(t, true, call.body["stream"])
2586		require.Equal(t, true, call.body["store"])
2587
2588		streamOptions := call.body["stream_options"].(map[string]any)
2589		require.Equal(t, true, streamOptions["include_usage"])
2590
2591		messages := call.body["messages"].([]any)
2592		require.Len(t, messages, 1)
2593
2594		message := messages[0].(map[string]any)
2595		require.Equal(t, "user", message["role"])
2596		require.Equal(t, "Hello", message["content"])
2597	})
2598
2599	t.Run("should send metadata extension values", func(t *testing.T) {
2600		t.Parallel()
2601
2602		server := newStreamingMockServer()
2603		defer server.close()
2604
2605		server.prepareStreamResponse(map[string]any{
2606			"content": []string{},
2607		})
2608
2609		provider := New(
2610			WithAPIKey("test-api-key"),
2611			WithBaseURL(server.server.URL),
2612		)
2613		model, _ := provider.LanguageModel("gpt-3.5-turbo")
2614
2615		_, err := model.Stream(context.Background(), ai.Call{
2616			Prompt: testPrompt,
2617			ProviderOptions: NewProviderOptions(&ProviderOptions{
2618				Metadata: map[string]any{
2619					"custom": "value",
2620				},
2621			}),
2622		})
2623
2624		require.NoError(t, err)
2625		require.Len(t, server.calls, 1)
2626
2627		call := server.calls[0]
2628		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
2629		require.Equal(t, true, call.body["stream"])
2630
2631		metadata := call.body["metadata"].(map[string]any)
2632		require.Equal(t, "value", metadata["custom"])
2633
2634		streamOptions := call.body["stream_options"].(map[string]any)
2635		require.Equal(t, true, streamOptions["include_usage"])
2636
2637		messages := call.body["messages"].([]any)
2638		require.Len(t, messages, 1)
2639
2640		message := messages[0].(map[string]any)
2641		require.Equal(t, "user", message["role"])
2642		require.Equal(t, "Hello", message["content"])
2643	})
2644
2645	t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
2646		t.Parallel()
2647
2648		server := newStreamingMockServer()
2649		defer server.close()
2650
2651		server.prepareStreamResponse(map[string]any{
2652			"content": []string{},
2653		})
2654
2655		provider := New(
2656			WithAPIKey("test-api-key"),
2657			WithBaseURL(server.server.URL),
2658		)
2659		model, _ := provider.LanguageModel("o3-mini")
2660
2661		_, err := model.Stream(context.Background(), ai.Call{
2662			Prompt: testPrompt,
2663			ProviderOptions: NewProviderOptions(&ProviderOptions{
2664				ServiceTier: ai.StringOption("flex"),
2665			}),
2666		})
2667
2668		require.NoError(t, err)
2669		require.Len(t, server.calls, 1)
2670
2671		call := server.calls[0]
2672		require.Equal(t, "o3-mini", call.body["model"])
2673		require.Equal(t, "flex", call.body["service_tier"])
2674		require.Equal(t, true, call.body["stream"])
2675
2676		streamOptions := call.body["stream_options"].(map[string]any)
2677		require.Equal(t, true, streamOptions["include_usage"])
2678
2679		messages := call.body["messages"].([]any)
2680		require.Len(t, messages, 1)
2681
2682		message := messages[0].(map[string]any)
2683		require.Equal(t, "user", message["role"])
2684		require.Equal(t, "Hello", message["content"])
2685	})
2686
2687	t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
2688		t.Parallel()
2689
2690		server := newStreamingMockServer()
2691		defer server.close()
2692
2693		server.prepareStreamResponse(map[string]any{
2694			"content": []string{},
2695		})
2696
2697		provider := New(
2698			WithAPIKey("test-api-key"),
2699			WithBaseURL(server.server.URL),
2700		)
2701		model, _ := provider.LanguageModel("gpt-4o-mini")
2702
2703		_, err := model.Stream(context.Background(), ai.Call{
2704			Prompt: testPrompt,
2705			ProviderOptions: NewProviderOptions(&ProviderOptions{
2706				ServiceTier: ai.StringOption("priority"),
2707			}),
2708		})
2709
2710		require.NoError(t, err)
2711		require.Len(t, server.calls, 1)
2712
2713		call := server.calls[0]
2714		require.Equal(t, "gpt-4o-mini", call.body["model"])
2715		require.Equal(t, "priority", call.body["service_tier"])
2716		require.Equal(t, true, call.body["stream"])
2717
2718		streamOptions := call.body["stream_options"].(map[string]any)
2719		require.Equal(t, true, streamOptions["include_usage"])
2720
2721		messages := call.body["messages"].([]any)
2722		require.Len(t, messages, 1)
2723
2724		message := messages[0].(map[string]any)
2725		require.Equal(t, "user", message["role"])
2726		require.Equal(t, "Hello", message["content"])
2727	})
2728
2729	t.Run("should stream text delta for reasoning models", func(t *testing.T) {
2730		t.Parallel()
2731
2732		server := newStreamingMockServer()
2733		defer server.close()
2734
2735		server.prepareStreamResponse(map[string]any{
2736			"content": []string{"Hello, World!"},
2737			"model":   "o1-preview",
2738		})
2739
2740		provider := New(
2741			WithAPIKey("test-api-key"),
2742			WithBaseURL(server.server.URL),
2743		)
2744		model, _ := provider.LanguageModel("o1-preview")
2745
2746		stream, err := model.Stream(context.Background(), ai.Call{
2747			Prompt: testPrompt,
2748		})
2749
2750		require.NoError(t, err)
2751
2752		parts, err := collectStreamParts(stream)
2753		require.NoError(t, err)
2754
2755		// Find text parts
2756		var textDeltas []string
2757		for _, part := range parts {
2758			if part.Type == ai.StreamPartTypeTextDelta {
2759				textDeltas = append(textDeltas, part.Delta)
2760			}
2761		}
2762
2763		// Should contain only the actual text content; the initial empty delta is not emitted
2764		require.Equal(t, []string{"Hello, World!"}, textDeltas)
2765	})
2766
2767	t.Run("should send reasoning tokens", func(t *testing.T) {
2768		t.Parallel()
2769
2770		server := newStreamingMockServer()
2771		defer server.close()
2772
2773		server.prepareStreamResponse(map[string]any{
2774			"content": []string{"Hello, World!"},
2775			"model":   "o1-preview",
2776			"usage": map[string]any{
2777				"prompt_tokens":     15,
2778				"completion_tokens": 20,
2779				"total_tokens":      35,
2780				"completion_tokens_details": map[string]any{
2781					"reasoning_tokens": 10,
2782				},
2783			},
2784		})
2785
2786		provider := New(
2787			WithAPIKey("test-api-key"),
2788			WithBaseURL(server.server.URL),
2789		)
2790		model, _ := provider.LanguageModel("o1-preview")
2791
2792		stream, err := model.Stream(context.Background(), ai.Call{
2793			Prompt: testPrompt,
2794		})
2795
2796		require.NoError(t, err)
2797
2798		parts, err := collectStreamParts(stream)
2799		require.NoError(t, err)
2800
2801		// Find finish part
2802		var finishPart *ai.StreamPart
2803		for _, part := range parts {
2804			if part.Type == ai.StreamPartTypeFinish {
2805				finishPart = &part
2806				break
2807			}
2808		}
2809
2810		require.NotNil(t, finishPart)
2811		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
2812		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
2813		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
2814		require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
2815	})
2816}