openai_test.go

package providers

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/charmbracelet/crush/internal/ai"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/stretchr/testify/require"
)

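// TestToOpenAIPrompt_SystemMessages covers conversion of system messages:
// plain forwarding, the warning emitted for a system prompt without text
// parts, and newline-joining of multiple text parts.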
func TestToOpenAIPrompt_SystemMessages(t *testing.T) {
	t.Parallel()

	t.Run("should forward system messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleSystem,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "You are a helpful assistant."},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		systemMsg := messages[0].OfSystem
		require.NotNil(t, systemMsg)
		require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
	})

	t.Run("should handle empty system messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role:    ai.MessageRoleSystem,
				Content: []ai.MessagePart{},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Len(t, warnings, 1)
		require.Contains(t, warnings[0].Message, "system prompt has no text parts")
		require.Empty(t, messages)
	})

	t.Run("should join multiple system text parts", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleSystem,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "You are a helpful assistant."},
					ai.TextPart{Text: "Be concise."},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		systemMsg := messages[0].OfSystem
		require.NotNil(t, systemMsg)
		require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
	})
}

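// TestToOpenAIPrompt_UserMessages covers conversion of user messages:
// text-only messages collapse to plain string content, image parts become
// base64 data URLs, and the provider-specific "imageDetail" option maps to
// the image_url detail field.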
func TestToOpenAIPrompt_UserMessages(t *testing.T) {
	t.Parallel()

	t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello"},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)
		require.Equal(t, "Hello", userMsg.Content.OfString.Value)
	})

	t.Run("should convert messages with image parts", func(t *testing.T) {
		t.Parallel()

		imageData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello"},
					ai.FilePart{
						MediaType: "image/png",
						Data:      imageData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 2)

		// Check text part
		textPart := content[0].OfText
		require.NotNil(t, textPart)
		require.Equal(t, "Hello", textPart.Text)

		// Check image part
		imagePart := content[1].OfImageURL
		require.NotNil(t, imagePart)
		expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
		require.Equal(t, expectedURL, imagePart.ImageURL.URL)
	})

	t.Run("should add image detail when specified through provider options", func(t *testing.T) {
		t.Parallel()

		imageData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "image/png",
						Data:      imageData,
						ProviderOptions: ai.ProviderOptions{
							"openai": map[string]any{
								"imageDetail": "low",
							},
						},
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		imagePart := content[0].OfImageURL
		require.NotNil(t, imagePart)
		require.Equal(t, "low", imagePart.ImageURL.Detail)
	})
}

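// TestToOpenAIPrompt_FileParts covers conversion of file parts: unsupported
// media types produce a warning, audio/wav and audio/mpeg|mp3 become
// input_audio parts, and PDFs become file parts carrying either a base64
// data URL (with an explicit or generated filename) or a file ID.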
func TestToOpenAIPrompt_FileParts(t *testing.T) {
	t.Parallel()

	t.Run("should warn for unsupported media types", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/something",
						Data:      []byte("test"),
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Len(t, warnings, 1)
		require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
		require.Len(t, messages, 1) // Message is still created but with empty content array
	})

	t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/wav",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
		require.Equal(t, "wav", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/mpeg",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/mp3",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should convert messages with PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "document.pdf", filePart.File.Filename.Value)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      []byte("file-pdf-12345"),
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
		require.True(t, param.IsOmitted(filePart.File.FileData))
		require.True(t, param.IsOmitted(filePart.File.Filename))
	})

	t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
	})
}

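// TestToOpenAIPrompt_ToolCalls covers conversion of tool calls and tool
// results: tool-call arguments are passed through as JSON strings, and each
// tool result part becomes its own tool-role message.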
func TestToOpenAIPrompt_ToolCalls(t *testing.T) {
	t.Parallel()

	t.Run("should stringify arguments to tool calls", func(t *testing.T) {
		t.Parallel()

		inputArgs := map[string]any{"foo": "bar123"}
		inputJSON, _ := json.Marshal(inputArgs)

		outputResult := map[string]any{"oof": "321rab"}
		outputJSON, _ := json.Marshal(outputResult)

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.ToolCallPart{
						ToolCallID: "quux",
						ToolName:   "thwomp",
						Input:      string(inputJSON),
					},
				},
			},
			{
				Role: ai.MessageRoleTool,
				Content: []ai.MessagePart{
					ai.ToolResultPart{
						ToolCallID: "quux",
						Output: ai.ToolResultOutputContentText{
							Text: string(outputJSON),
						},
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 2)

		// Check assistant message with tool call
		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "", assistantMsg.Content.OfString.Value)
		require.Len(t, assistantMsg.ToolCalls, 1)

		toolCall := assistantMsg.ToolCalls[0].OfFunction
		require.NotNil(t, toolCall)
		require.Equal(t, "quux", toolCall.ID)
		require.Equal(t, "thwomp", toolCall.Function.Name)
		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)

		// Check tool message
		toolMsg := messages[1].OfTool
		require.NotNil(t, toolMsg)
		require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
		require.Equal(t, "quux", toolMsg.ToolCallID)
	})

	t.Run("should handle different tool output types", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleTool,
				Content: []ai.MessagePart{
					ai.ToolResultPart{
						ToolCallID: "text-tool",
						Output: ai.ToolResultOutputContentText{
							Text: "Hello world",
						},
					},
					ai.ToolResultPart{
						ToolCallID: "error-tool",
						Output: ai.ToolResultOutputContentError{
							Error: "Something went wrong",
						},
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 2)

		// Check first tool message (text)
		textToolMsg := messages[0].OfTool
		require.NotNil(t, textToolMsg)
		require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
		require.Equal(t, "text-tool", textToolMsg.ToolCallID)

		// Check second tool message (error)
		errorToolMsg := messages[1].OfTool
		require.NotNil(t, errorToolMsg)
		require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
		require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
	})
}

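// TestToOpenAIPrompt_AssistantMessages covers conversion of assistant
// messages, both plain text and text mixed with tool calls.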
func TestToOpenAIPrompt_AssistantMessages(t *testing.T) {
	t.Parallel()

	t.Run("should handle simple text assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello, how can I help you?"},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
	})

	t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
		t.Parallel()

		inputArgs := map[string]any{"query": "test"}
		inputJSON, _ := json.Marshal(inputArgs)

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Let me search for that."},
					ai.ToolCallPart{
						ToolCallID: "call-123",
						ToolName:   "search",
						Input:      string(inputJSON),
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
		require.Len(t, assistantMsg.ToolCalls, 1)

		toolCall := assistantMsg.ToolCalls[0].OfFunction
		require.Equal(t, "call-123", toolCall.ID)
		require.Equal(t, "search", toolCall.Function.Name)
		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
	})
}

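// testPrompt is the minimal single-user-message prompt shared by the
// Generate tests below.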
var testPrompt = ai.Prompt{
	{
		Role: ai.MessageRoleUser,
		Content: []ai.MessagePart{
			ai.TextPart{Text: "Hello"},
		},
	},
}

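// testLogprobs mirrors the logprobs block of an OpenAI chat completion
// response for the text "Hello! How can I assist you today?".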
var testLogprobs = map[string]any{
	"content": []map[string]any{
		{
			"token":   "Hello",
			"logprob": -0.0009994634,
			"top_logprobs": []map[string]any{
				{
					"token":   "Hello",
					"logprob": -0.0009994634,
				},
			},
		},
		{
			"token":   "!",
			"logprob": -0.13410144,
			"top_logprobs": []map[string]any{
				{
					"token":   "!",
					"logprob": -0.13410144,
				},
			},
		},
		{
			"token":   " How",
			"logprob": -0.0009250381,
			"top_logprobs": []map[string]any{
				{
					"token":   " How",
					"logprob": -0.0009250381,
				},
			},
		},
		{
			"token":   " can",
			"logprob": -0.047709424,
			"top_logprobs": []map[string]any{
				{
					"token":   " can",
					"logprob": -0.047709424,
				},
			},
		},
		{
			"token":   " I",
			"logprob": -0.000009014684,
			"top_logprobs": []map[string]any{
				{
					"token":   " I",
					"logprob": -0.000009014684,
				},
			},
		},
		{
			"token":   " assist",
			"logprob": -0.009125131,
			"top_logprobs": []map[string]any{
				{
					"token":   " assist",
					"logprob": -0.009125131,
				},
			},
		},
		{
			"token":   " you",
			"logprob": -0.0000066306106,
			"top_logprobs": []map[string]any{
				{
					"token":   " you",
					"logprob": -0.0000066306106,
				},
			},
		},
		{
			"token":   " today",
			"logprob": -0.00011093382,
			"top_logprobs": []map[string]any{
				{
					"token":   " today",
					"logprob": -0.00011093382,
				},
			},
		},
		{
			"token":   "?",
			"logprob": -0.00004596782,
			"top_logprobs": []map[string]any{
				{
					"token":   "?",
					"logprob": -0.00004596782,
				},
			},
		},
	},
}

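// mockServer wraps an httptest.Server that records every request (method,
// path, headers, decoded JSON body) and replies with a canned JSON response.
// Typical usage, as in the tests below:
//
//	server := newMockServer()
//	defer server.close()
//	server.prepareJSONResponse(map[string]any{"content": "Hi"})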
type mockServer struct {
	server   *httptest.Server
	response map[string]any
	calls    []mockCall
}

type mockCall struct {
	method  string
	path    string
	headers map[string]string
	body    map[string]any
}

func newMockServer() *mockServer {
	ms := &mockServer{
		calls: make([]mockCall, 0),
	}

	ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record the call
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}

		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Parse request body
		if r.Body != nil {
			var body map[string]any
			json.NewDecoder(r.Body).Decode(&body)
			call.body = body
		}

		ms.calls = append(ms.calls, call)

		// Return mock response
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(ms.response)
	}))

	return ms
}

func (ms *mockServer) close() {
	ms.server.Close()
}

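// prepareJSONResponse installs a chat.completion response with sensible
// defaults; opts may override content, tool_calls, function_call,
// annotations, usage, finish_reason, id, created, model, and logprobs.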
func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
	// Default values
	response := map[string]any{
		"id":      "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
		"object":  "chat.completion",
		"created": 1711115037,
		"model":   "gpt-3.5-turbo-0125",
		"choices": []map[string]any{
			{
				"index": 0,
				"message": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": "stop",
			},
		},
		"usage": map[string]any{
			"prompt_tokens":     4,
			"total_tokens":      34,
			"completion_tokens": 30,
		},
		"system_fingerprint": "fp_3bc1b5746c",
	}

	// Override with provided options
	for k, v := range opts {
		switch k {
		case "content":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["content"] = v
		case "tool_calls":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["tool_calls"] = v
		case "function_call":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["function_call"] = v
		case "annotations":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["annotations"] = v
		case "usage":
			response["usage"] = v
		case "finish_reason":
			response["choices"].([]map[string]any)[0]["finish_reason"] = v
		case "id":
			response["id"] = v
		case "created":
			response["created"] = v
		case "model":
			response["model"] = v
		case "logprobs":
			if v != nil {
				response["choices"].([]map[string]any)[0]["logprobs"] = v
			}
		}
	}

	ms.response = response
}

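// TestDoGenerate exercises Generate end to end against the mock server:
// response parsing (text, tool calls, annotations, usage, logprobs, finish
// reasons) and request shaping (settings, provider options, and the
// parameter restrictions of reasoning and search-preview models).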
func TestDoGenerate(t *testing.T) {
	t.Parallel()

	t.Run("should extract text response", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "Hello, World!",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 1)

		textContent, ok := result.Content[0].(ai.TextContent)
		require.True(t, ok)
		require.Equal(t, "Hello, World!", textContent.Text)
	})

	t.Run("should extract usage", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     20,
				"total_tokens":      25,
				"completion_tokens": 5,
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(20), result.Usage.InputTokens)
		require.Equal(t, int64(5), result.Usage.OutputTokens)
		require.Equal(t, int64(25), result.Usage.TotalTokens)
	})

	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages, ok := call.body["messages"].([]any)
		require.True(t, ok)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should support partial usage", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens": 20,
				"total_tokens":  20,
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(20), result.Usage.InputTokens)
		require.Equal(t, int64(0), result.Usage.OutputTokens)
		require.Equal(t, int64(20), result.Usage.TotalTokens)
	})

	t.Run("should extract logprobs", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"logprobs": testLogprobs,
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"logProbs": true,
				},
			},
		})

		require.NoError(t, err)
		require.NotNil(t, result.ProviderMetadata)

		openaiMeta, ok := result.ProviderMetadata["openai"]
		require.True(t, ok)

		logprobs, ok := openaiMeta["logprobs"]
		require.True(t, ok)
		require.NotNil(t, logprobs)
	})

	t.Run("should extract finish reason", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"finish_reason": "stop",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, ai.FinishReasonStop, result.FinishReason)
	})

	t.Run("should support unknown finish reason", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"finish_reason": "eos",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, ai.FinishReasonUnknown, result.FinishReason)
	})

	t.Run("should pass the model and the messages", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass settings", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"logitBias": map[string]int64{
						"50256": -100,
					},
					"parallelToolCalls": false,
					"user":              "test-user-id",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		logitBias := call.body["logit_bias"].(map[string]any)
		require.Equal(t, float64(-100), logitBias["50256"])
		require.Equal(t, false, call.body["parallel_tool_calls"])
		require.Equal(t, "test-user-id", call.body["user"])
	})

	t.Run("should pass reasoningEffort setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o1-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"reasoningEffort": "low",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-mini", call.body["model"])
		require.Equal(t, "low", call.body["reasoning_effort"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass textVerbosity setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-4o")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"textVerbosity": "low",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o", call.body["model"])
		require.Equal(t, "low", call.body["verbosity"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass tools and toolChoice", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			Tools: []ai.Tool{
				ai.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
			ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		tools := call.body["tools"].([]any)
		require.Len(t, tools, 1)

		tool := tools[0].(map[string]any)
		require.Equal(t, "function", tool["type"])

		function := tool["function"].(map[string]any)
		require.Equal(t, "test-tool", function["name"])
		require.Equal(t, false, function["strict"])

		toolChoice := call.body["tool_choice"].(map[string]any)
		require.Equal(t, "function", toolChoice["type"])

		toolChoiceFunction := toolChoice["function"].(map[string]any)
		require.Equal(t, "test-tool", toolChoiceFunction["name"])
	})

	t.Run("should parse tool calls", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"tool_calls": []map[string]any{
				{
					"id":   "call_O17Uplv4lJvD6DVdIvFFeRMw",
					"type": "function",
					"function": map[string]any{
						"name":      "test-tool",
						"arguments": `{"value":"Spark"}`,
					},
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			Tools: []ai.Tool{
				ai.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
			ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 1)

		toolCall, ok := result.Content[0].(ai.ToolCallContent)
		require.True(t, ok)
		require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
		require.Equal(t, "test-tool", toolCall.ToolName)
		require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
	})

	t.Run("should parse annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "Based on the search results [doc1], I found information.",
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index":   29,
						"url":         "https://example.com/doc1.pdf",
						"title":       "Document 1",
					},
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 2)

		textContent, ok := result.Content[0].(ai.TextContent)
		require.True(t, ok)
		require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)

		sourceContent, ok := result.Content[1].(ai.SourceContent)
		require.True(t, ok)
		require.Equal(t, ai.SourceTypeURL, sourceContent.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
		require.Equal(t, "Document 1", sourceContent.Title)
		require.NotEmpty(t, sourceContent.ID)
	})

	t.Run("should return cached_tokens from prompt_tokens_details", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
		require.Equal(t, int64(15), result.Usage.InputTokens)
		require.Equal(t, int64(20), result.Usage.OutputTokens)
		require.Equal(t, int64(35), result.Usage.TotalTokens)
	})

	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.NotNil(t, result.ProviderMetadata)

		openaiMeta, ok := result.ProviderMetadata["openai"]
		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
		require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
	})

	t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o1-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt:           testPrompt,
			Temperature:      &[]float64{0.5}[0],
			TopP:             &[]float64{0.7}[0],
			FrequencyPenalty: &[]float64{0.2}[0],
			PresencePenalty:  &[]float64{0.3}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])

		// These should not be present
		require.Nil(t, call.body["temperature"])
		require.Nil(t, call.body["top_p"])
		require.Nil(t, call.body["frequency_penalty"])
		require.Nil(t, call.body["presence_penalty"])

		// Should have warnings
		require.Len(t, result.Warnings, 4)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "temperature", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
	})

	t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o1-preview")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt:          testPrompt,
			MaxOutputTokens: &[]int64{1000}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])
		require.Equal(t, float64(1000), call.body["max_completion_tokens"])
		require.Nil(t, call.body["max_tokens"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should return reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o1-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(15), result.Usage.InputTokens)
		require.Equal(t, int64(20), result.Usage.OutputTokens)
		require.Equal(t, int64(35), result.Usage.TotalTokens)
		require.Equal(t, int64(10), result.Usage.ReasoningTokens)
	})

	t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"model": "o1-preview",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o1-preview")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"maxCompletionTokens": 255,
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])
		require.Equal(t, float64(255), call.body["max_completion_tokens"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send prediction extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"prediction": map[string]any{
						"type":    "content",
						"content": "Hello, World!",
					},
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		prediction := call.body["prediction"].(map[string]any)
		require.Equal(t, "content", prediction["type"])
		require.Equal(t, "Hello, World!", prediction["content"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"store": true,
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["store"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"metadata": map[string]any{
						"custom": "value",
					},
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send promptCacheKey extension value", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"promptCacheKey": "test-cache-key-123",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send safetyIdentifier extension value", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"safetyIdentifier": "test-safety-identifier-123",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-4o-search-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt:      testPrompt,
			Temperature: &[]float64{0.7}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-search-preview", call.body["model"])
		require.Nil(t, call.body["temperature"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "temperature", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "search preview models")
	})

	t.Run("should send serviceTier flex processing setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o3-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "flex",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "flex",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Nil(t, call.body["service_tier"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "serviceTier", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
	})

	t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-4o-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "priority",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "priority",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Nil(t, call.body["service_tier"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "serviceTier", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
	})
}

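// streamingMockServer replays a fixed sequence of SSE chunks while recording
// requests like mockServer. Entries prefixed with "HEADER:name:value" are
// sent as response headers instead of body chunks.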
1949type streamingMockServer struct {
1950	server *httptest.Server
1951	chunks []string
1952	calls  []mockCall
1953}
1954
func newStreamingMockServer() *streamingMockServer {
	sms := &streamingMockServer{
		calls: make([]mockCall, 0),
	}

	sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record the call
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}

		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Parse the request body, if any; malformed bodies are ignored.
		if r.Body != nil {
			var body map[string]any
			if err := json.NewDecoder(r.Body).Decode(&body); err == nil {
				call.body = body
			}
		}

		sms.calls = append(sms.calls, call)

		// Set streaming headers
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")

		// Apply any custom headers encoded as "HEADER:<key>:<value>" pseudo-chunks.
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				if parts := strings.SplitN(chunk[len("HEADER:"):], ":", 2); len(parts) == 2 {
					w.Header().Set(parts[0], parts[1])
				}
			}
		}

		w.WriteHeader(http.StatusOK)

		// Write chunks
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				continue
			}
			w.Write([]byte(chunk))
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
		}
	}))

	return sms
}

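// close shuts down the underlying test server.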
func (sms *streamingMockServer) close() {
	sms.server.Close()
}

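// prepareStreamResponse builds a canned SSE stream from the given options.
// Recognized keys (all optional): "content" ([]string of text deltas),
// "usage" (map), "logprobs" (map), "finish_reason" (string), "model"
// (string), "headers" (map[string]string), and "annotations"
// ([]map[string]any, emitted after the final content chunk).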
func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
	content := []string{}
	if c, ok := opts["content"].([]string); ok {
		content = c
	}

	usage := map[string]any{
		"prompt_tokens":     17,
		"total_tokens":      244,
		"completion_tokens": 227,
	}
	if u, ok := opts["usage"].(map[string]any); ok {
		usage = u
	}

	logprobs := map[string]any{}
	if l, ok := opts["logprobs"].(map[string]any); ok {
		logprobs = l
	}

	finishReason := "stop"
	if fr, ok := opts["finish_reason"].(string); ok {
		finishReason = fr
	}

	model := "gpt-3.5-turbo-0613"
	if m, ok := opts["model"].(string); ok {
		model = m
	}

	headers := map[string]string{}
	if h, ok := opts["headers"].(map[string]string); ok {
		headers = h
	}

	chunks := []string{}

	// Encode custom headers as pseudo-chunks; the handler turns these into
	// response headers instead of writing them to the body.
	for k, v := range headers {
		chunks = append(chunks, "HEADER:"+k+":"+v)
	}

	// Initial chunk with role
	initialChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index": 0,
				"delta": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": nil,
			},
		},
	}
	initialData, _ := json.Marshal(initialChunk)
	chunks = append(chunks, "data: "+string(initialData)+"\n\n")

	// Content chunks
	for i, text := range content {
		contentChunk := map[string]any{
			"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
			"object":             "chat.completion.chunk",
			"created":            1702657020,
			"model":              model,
			"system_fingerprint": nil,
			"choices": []map[string]any{
				{
					"index": 1,
					"delta": map[string]any{
						"content": text,
					},
					"finish_reason": nil,
				},
			},
		}
		contentData, _ := json.Marshal(contentChunk)
		chunks = append(chunks, "data: "+string(contentData)+"\n\n")

		// Add annotations if this is the last content chunk and we have annotations
		if i == len(content)-1 {
			if annotations, ok := opts["annotations"].([]map[string]any); ok {
				annotationChunk := map[string]any{
					"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
					"object":             "chat.completion.chunk",
					"created":            1702657020,
					"model":              model,
					"system_fingerprint": nil,
					"choices": []map[string]any{
						{
							"index": 1,
							"delta": map[string]any{
								"annotations": annotations,
							},
							"finish_reason": nil,
						},
					},
				}
				annotationData, _ := json.Marshal(annotationChunk)
				chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
			}
		}
	}

	// Finish chunk
	finishChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index":         0,
				"delta":         map[string]any{},
				"finish_reason": finishReason,
			},
		},
	}

	if len(logprobs) > 0 {
		finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
	}

	finishData, _ := json.Marshal(finishChunk)
	chunks = append(chunks, "data: "+string(finishData)+"\n\n")

	// Usage chunk
	usageChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": "fp_3bc1b5746c",
		"choices":            []map[string]any{},
		"usage":              usage,
	}
	usageData, _ := json.Marshal(usageChunk)
	chunks = append(chunks, "data: "+string(usageData)+"\n\n")

	// Done
	chunks = append(chunks, "data: [DONE]\n\n")

	sms.chunks = chunks
}

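// prepareToolStreamResponse loads a fixed chunk sequence in which the model
// streams a single "test-tool" call whose JSON arguments arrive split across
// several deltas, finishing with reason "tool_calls".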
func (sms *streamingMockServer) prepareToolStreamResponse() {
	chunks := []string{
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}

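// prepareErrorStreamResponse loads a stream whose first chunk is an API
// error payload rather than a completion chunk.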
func (sms *streamingMockServer) prepareErrorStreamResponse() {
	chunks := []string{
		`data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}

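// collectStreamParts drains a stream into a slice, stopping at the first
// error or finish part.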
func collectStreamParts(stream ai.StreamResponse) ([]ai.StreamPart, error) {
	var parts []ai.StreamPart
	for part := range stream {
		parts = append(parts, part)
		if part.Type == ai.StreamPartTypeError || part.Type == ai.StreamPartTypeFinish {
			break
		}
	}
	return parts, nil
}

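// TestDoStream exercises the streaming code path end to end against the
// mock SSE server above.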
func TestDoStream(t *testing.T) {
	t.Parallel()

	t.Run("should stream text deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content":       []string{"Hello", ", ", "World!"},
			"finish_reason": "stop",
			"usage": map[string]any{
				"prompt_tokens":     17,
				"total_tokens":      244,
				"completion_tokens": 227,
			},
			"logprobs": testLogprobs,
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Verify stream structure
		require.GreaterOrEqual(t, len(parts), 4) // text-start, deltas, text-end, finish

		// Find text parts
		textStart, textEnd, finish := -1, -1, -1
		var deltas []string

		for i, part := range parts {
			switch part.Type {
			case ai.StreamPartTypeTextStart:
				textStart = i
			case ai.StreamPartTypeTextDelta:
				deltas = append(deltas, part.Delta)
			case ai.StreamPartTypeTextEnd:
				textEnd = i
			case ai.StreamPartTypeFinish:
				finish = i
			}
		}

		require.NotEqual(t, -1, textStart)
		require.NotEqual(t, -1, textEnd)
		require.NotEqual(t, -1, finish)
		require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)

		// Check finish part
		finishPart := parts[finish]
		require.Equal(t, ai.FinishReasonStop, finishPart.FinishReason)
		require.Equal(t, int64(17), finishPart.Usage.InputTokens)
		require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
	})

	t.Run("should stream tool deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareToolStreamResponse()

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			Tools: []ai.Tool{
				ai.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find tool-related parts
		toolInputStart, toolInputEnd, toolCall := -1, -1, -1
		var toolDeltas []string

		for i, part := range parts {
			switch part.Type {
			case ai.StreamPartTypeToolInputStart:
				toolInputStart = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
			case ai.StreamPartTypeToolInputDelta:
				toolDeltas = append(toolDeltas, part.Delta)
			case ai.StreamPartTypeToolInputEnd:
				toolInputEnd = i
			case ai.StreamPartTypeToolCall:
				toolCall = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
				require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
			}
		}

		require.NotEqual(t, -1, toolInputStart)
		require.NotEqual(t, -1, toolInputEnd)
		require.NotEqual(t, -1, toolCall)

		// Verify tool deltas combine to form the complete input
		fullInput := strings.Join(toolDeltas, "")
		require.Equal(t, `{"value":"Sparkle Day"}`, fullInput)
	})

	t.Run("should stream annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Based on search results"},
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index":   29,
						"url":         "https://example.com/doc1.pdf",
						"title":       "Document 1",
					},
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find source part
		var sourcePart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeSource {
				sourcePart = &part
				break
			}
		}

		require.NotNil(t, sourcePart)
		require.Equal(t, ai.SourceTypeURL, sourcePart.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
		require.Equal(t, "Document 1", sourcePart.Title)
		require.NotEmpty(t, sourcePart.ID)
	})

	t.Run("should handle error stream parts", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareErrorStreamResponse()

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Should contain at least the error part
		require.NotEmpty(t, parts)

		// Find error part
		var errorPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeError {
				errorPart = &part
				break
			}
		}

		require.NotNil(t, errorPart)
		require.NotNil(t, errorPart.Error)
	})

	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
	})

	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.NotNil(t, finishPart.ProviderMetadata)

		openaiMeta, ok := finishPart.ProviderMetadata["openai"]
		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
		require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"store": true,
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])
		require.Equal(t, true, call.body["store"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"metadata": map[string]any{
						"custom": "value",
					},
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o3-mini")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "flex",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-4o-mini")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "priority",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should stream text delta for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model":   "o1-preview",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o1-preview")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find text parts
		var textDeltas []string
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeTextDelta {
				textDeltas = append(textDeltas, part.Delta)
			}
		}

		// Should contain the text content (without the empty initial delta)
		require.Equal(t, []string{"Hello, World!"}, textDeltas)
	})

	t.Run("should send reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model":   "o1-preview",
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o1-preview")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
		require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
	})
}