openai_test.go

   1package providers
   2
import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/charmbracelet/crush/internal/ai"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/stretchr/testify/require"
)
  17
  18func TestToOpenAIPrompt_SystemMessages(t *testing.T) {
  19	t.Parallel()
  20
  21	t.Run("should forward system messages", func(t *testing.T) {
  22		t.Parallel()
  23
  24		prompt := ai.Prompt{
  25			{
  26				Role: ai.MessageRoleSystem,
  27				Content: []ai.MessagePart{
  28					ai.TextPart{Text: "You are a helpful assistant."},
  29				},
  30			},
  31		}
  32
  33		messages, warnings := toOpenAIPrompt(prompt)
  34
  35		require.Empty(t, warnings)
  36		require.Len(t, messages, 1)
  37
  38		systemMsg := messages[0].OfSystem
  39		require.NotNil(t, systemMsg)
  40		require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
  41	})
  42
  43	t.Run("should handle empty system messages", func(t *testing.T) {
  44		t.Parallel()
  45
  46		prompt := ai.Prompt{
  47			{
  48				Role:    ai.MessageRoleSystem,
  49				Content: []ai.MessagePart{},
  50			},
  51		}
  52
  53		messages, warnings := toOpenAIPrompt(prompt)
  54
  55		require.Len(t, warnings, 1)
  56		require.Contains(t, warnings[0].Message, "system prompt has no text parts")
  57		require.Empty(t, messages)
  58	})
  59
  60	t.Run("should join multiple system text parts", func(t *testing.T) {
  61		t.Parallel()
  62
  63		prompt := ai.Prompt{
  64			{
  65				Role: ai.MessageRoleSystem,
  66				Content: []ai.MessagePart{
  67					ai.TextPart{Text: "You are a helpful assistant."},
  68					ai.TextPart{Text: "Be concise."},
  69				},
  70			},
  71		}
  72
  73		messages, warnings := toOpenAIPrompt(prompt)
  74
  75		require.Empty(t, warnings)
  76		require.Len(t, messages, 1)
  77
  78		systemMsg := messages[0].OfSystem
  79		require.NotNil(t, systemMsg)
  80		require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
  81	})
  82}
  83
  84func TestToOpenAIPrompt_UserMessages(t *testing.T) {
  85	t.Parallel()
  86
  87	t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
  88		t.Parallel()
  89
  90		prompt := ai.Prompt{
  91			{
  92				Role: ai.MessageRoleUser,
  93				Content: []ai.MessagePart{
  94					ai.TextPart{Text: "Hello"},
  95				},
  96			},
  97		}
  98
  99		messages, warnings := toOpenAIPrompt(prompt)
 100
 101		require.Empty(t, warnings)
 102		require.Len(t, messages, 1)
 103
 104		userMsg := messages[0].OfUser
 105		require.NotNil(t, userMsg)
 106		require.Equal(t, "Hello", userMsg.Content.OfString.Value)
 107	})
 108
 109	t.Run("should convert messages with image parts", func(t *testing.T) {
 110		t.Parallel()
 111
 112		imageData := []byte{0, 1, 2, 3}
 113		prompt := ai.Prompt{
 114			{
 115				Role: ai.MessageRoleUser,
 116				Content: []ai.MessagePart{
 117					ai.TextPart{Text: "Hello"},
 118					ai.FilePart{
 119						MediaType: "image/png",
 120						Data:      imageData,
 121					},
 122				},
 123			},
 124		}
 125
 126		messages, warnings := toOpenAIPrompt(prompt)
 127
 128		require.Empty(t, warnings)
 129		require.Len(t, messages, 1)
 130
 131		userMsg := messages[0].OfUser
 132		require.NotNil(t, userMsg)
 133
 134		content := userMsg.Content.OfArrayOfContentParts
 135		require.Len(t, content, 2)
 136
 137		// Check text part
 138		textPart := content[0].OfText
 139		require.NotNil(t, textPart)
 140		require.Equal(t, "Hello", textPart.Text)
 141
 142		// Check image part
 143		imagePart := content[1].OfImageURL
 144		require.NotNil(t, imagePart)
 145		expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
 146		require.Equal(t, expectedURL, imagePart.ImageURL.URL)
 147	})
 148
 149	t.Run("should add image detail when specified through provider options", func(t *testing.T) {
 150		t.Parallel()
 151
 152		imageData := []byte{0, 1, 2, 3}
 153		prompt := ai.Prompt{
 154			{
 155				Role: ai.MessageRoleUser,
 156				Content: []ai.MessagePart{
 157					ai.FilePart{
 158						MediaType: "image/png",
 159						Data:      imageData,
 160						ProviderOptions: ai.ProviderOptions{
 161							"openai": map[string]any{
 162								"imageDetail": "low",
 163							},
 164						},
 165					},
 166				},
 167			},
 168		}
 169
 170		messages, warnings := toOpenAIPrompt(prompt)
 171
 172		require.Empty(t, warnings)
 173		require.Len(t, messages, 1)
 174
 175		userMsg := messages[0].OfUser
 176		require.NotNil(t, userMsg)
 177
 178		content := userMsg.Content.OfArrayOfContentParts
 179		require.Len(t, content, 1)
 180
 181		imagePart := content[0].OfImageURL
 182		require.NotNil(t, imagePart)
 183		require.Equal(t, "low", imagePart.ImageURL.Detail)
 184	})
 185}
 186
 187func TestToOpenAIPrompt_FileParts(t *testing.T) {
 188	t.Parallel()
 189
 190	t.Run("should throw for unsupported mime types", func(t *testing.T) {
 191		t.Parallel()
 192
 193		prompt := ai.Prompt{
 194			{
 195				Role: ai.MessageRoleUser,
 196				Content: []ai.MessagePart{
 197					ai.FilePart{
 198						MediaType: "application/something",
 199						Data:      []byte("test"),
 200					},
 201				},
 202			},
 203		}
 204
 205		messages, warnings := toOpenAIPrompt(prompt)
 206
 207		require.Len(t, warnings, 1)
 208		require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
 209		require.Len(t, messages, 1) // Message is still created but with empty content array
 210	})
 211
 212	t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
 213		t.Parallel()
 214
 215		audioData := []byte{0, 1, 2, 3}
 216		prompt := ai.Prompt{
 217			{
 218				Role: ai.MessageRoleUser,
 219				Content: []ai.MessagePart{
 220					ai.FilePart{
 221						MediaType: "audio/wav",
 222						Data:      audioData,
 223					},
 224				},
 225			},
 226		}
 227
 228		messages, warnings := toOpenAIPrompt(prompt)
 229
 230		require.Empty(t, warnings)
 231		require.Len(t, messages, 1)
 232
 233		userMsg := messages[0].OfUser
 234		require.NotNil(t, userMsg)
 235
 236		content := userMsg.Content.OfArrayOfContentParts
 237		require.Len(t, content, 1)
 238
 239		audioPart := content[0].OfInputAudio
 240		require.NotNil(t, audioPart)
 241		require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
 242		require.Equal(t, "wav", audioPart.InputAudio.Format)
 243	})
 244
 245	t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
 246		t.Parallel()
 247
 248		audioData := []byte{0, 1, 2, 3}
 249		prompt := ai.Prompt{
 250			{
 251				Role: ai.MessageRoleUser,
 252				Content: []ai.MessagePart{
 253					ai.FilePart{
 254						MediaType: "audio/mpeg",
 255						Data:      audioData,
 256					},
 257				},
 258			},
 259		}
 260
 261		messages, warnings := toOpenAIPrompt(prompt)
 262
 263		require.Empty(t, warnings)
 264		require.Len(t, messages, 1)
 265
 266		userMsg := messages[0].OfUser
 267		content := userMsg.Content.OfArrayOfContentParts
 268		audioPart := content[0].OfInputAudio
 269		require.NotNil(t, audioPart)
 270		require.Equal(t, "mp3", audioPart.InputAudio.Format)
 271	})
 272
 273	t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
 274		t.Parallel()
 275
 276		audioData := []byte{0, 1, 2, 3}
 277		prompt := ai.Prompt{
 278			{
 279				Role: ai.MessageRoleUser,
 280				Content: []ai.MessagePart{
 281					ai.FilePart{
 282						MediaType: "audio/mp3",
 283						Data:      audioData,
 284					},
 285				},
 286			},
 287		}
 288
 289		messages, warnings := toOpenAIPrompt(prompt)
 290
 291		require.Empty(t, warnings)
 292		require.Len(t, messages, 1)
 293
 294		userMsg := messages[0].OfUser
 295		content := userMsg.Content.OfArrayOfContentParts
 296		audioPart := content[0].OfInputAudio
 297		require.NotNil(t, audioPart)
 298		require.Equal(t, "mp3", audioPart.InputAudio.Format)
 299	})
 300
 301	t.Run("should convert messages with PDF file parts", func(t *testing.T) {
 302		t.Parallel()
 303
 304		pdfData := []byte{1, 2, 3, 4, 5}
 305		prompt := ai.Prompt{
 306			{
 307				Role: ai.MessageRoleUser,
 308				Content: []ai.MessagePart{
 309					ai.FilePart{
 310						MediaType: "application/pdf",
 311						Data:      pdfData,
 312						Filename:  "document.pdf",
 313					},
 314				},
 315			},
 316		}
 317
 318		messages, warnings := toOpenAIPrompt(prompt)
 319
 320		require.Empty(t, warnings)
 321		require.Len(t, messages, 1)
 322
 323		userMsg := messages[0].OfUser
 324		content := userMsg.Content.OfArrayOfContentParts
 325		require.Len(t, content, 1)
 326
 327		filePart := content[0].OfFile
 328		require.NotNil(t, filePart)
 329		require.Equal(t, "document.pdf", filePart.File.Filename.Value)
 330
 331		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
 332		require.Equal(t, expectedData, filePart.File.FileData.Value)
 333	})
 334
 335	t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
 336		t.Parallel()
 337
 338		pdfData := []byte{1, 2, 3, 4, 5}
 339		prompt := ai.Prompt{
 340			{
 341				Role: ai.MessageRoleUser,
 342				Content: []ai.MessagePart{
 343					ai.FilePart{
 344						MediaType: "application/pdf",
 345						Data:      pdfData,
 346						Filename:  "document.pdf",
 347					},
 348				},
 349			},
 350		}
 351
 352		messages, warnings := toOpenAIPrompt(prompt)
 353
 354		require.Empty(t, warnings)
 355		require.Len(t, messages, 1)
 356
 357		userMsg := messages[0].OfUser
 358		content := userMsg.Content.OfArrayOfContentParts
 359		filePart := content[0].OfFile
 360		require.NotNil(t, filePart)
 361
 362		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
 363		require.Equal(t, expectedData, filePart.File.FileData.Value)
 364	})
 365
 366	t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
 367		t.Parallel()
 368
 369		prompt := ai.Prompt{
 370			{
 371				Role: ai.MessageRoleUser,
 372				Content: []ai.MessagePart{
 373					ai.FilePart{
 374						MediaType: "application/pdf",
 375						Data:      []byte("file-pdf-12345"),
 376					},
 377				},
 378			},
 379		}
 380
 381		messages, warnings := toOpenAIPrompt(prompt)
 382
 383		require.Empty(t, warnings)
 384		require.Len(t, messages, 1)
 385
 386		userMsg := messages[0].OfUser
 387		content := userMsg.Content.OfArrayOfContentParts
 388		filePart := content[0].OfFile
 389		require.NotNil(t, filePart)
 390		require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
 391		require.True(t, param.IsOmitted(filePart.File.FileData))
 392		require.True(t, param.IsOmitted(filePart.File.Filename))
 393	})
 394
 395	t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
 396		t.Parallel()
 397
 398		pdfData := []byte{1, 2, 3, 4, 5}
 399		prompt := ai.Prompt{
 400			{
 401				Role: ai.MessageRoleUser,
 402				Content: []ai.MessagePart{
 403					ai.FilePart{
 404						MediaType: "application/pdf",
 405						Data:      pdfData,
 406					},
 407				},
 408			},
 409		}
 410
 411		messages, warnings := toOpenAIPrompt(prompt)
 412
 413		require.Empty(t, warnings)
 414		require.Len(t, messages, 1)
 415
 416		userMsg := messages[0].OfUser
 417		content := userMsg.Content.OfArrayOfContentParts
 418		filePart := content[0].OfFile
 419		require.NotNil(t, filePart)
 420		require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
 421	})
 422}
 423
 424func TestToOpenAIPrompt_ToolCalls(t *testing.T) {
 425	t.Parallel()
 426
 427	t.Run("should stringify arguments to tool calls", func(t *testing.T) {
 428		t.Parallel()
 429
 430		inputArgs := map[string]any{"foo": "bar123"}
 431		inputJSON, _ := json.Marshal(inputArgs)
 432
 433		outputResult := map[string]any{"oof": "321rab"}
 434		outputJSON, _ := json.Marshal(outputResult)
 435
 436		prompt := ai.Prompt{
 437			{
 438				Role: ai.MessageRoleAssistant,
 439				Content: []ai.MessagePart{
 440					ai.ToolCallPart{
 441						ToolCallID: "quux",
 442						ToolName:   "thwomp",
 443						Input:      string(inputJSON),
 444					},
 445				},
 446			},
 447			{
 448				Role: ai.MessageRoleTool,
 449				Content: []ai.MessagePart{
 450					ai.ToolResultPart{
 451						ToolCallID: "quux",
 452						Output: ai.ToolResultOutputContentText{
 453							Text: string(outputJSON),
 454						},
 455					},
 456				},
 457			},
 458		}
 459
 460		messages, warnings := toOpenAIPrompt(prompt)
 461
 462		require.Empty(t, warnings)
 463		require.Len(t, messages, 2)
 464
 465		// Check assistant message with tool call
 466		assistantMsg := messages[0].OfAssistant
 467		require.NotNil(t, assistantMsg)
 468		require.Equal(t, "", assistantMsg.Content.OfString.Value)
 469		require.Len(t, assistantMsg.ToolCalls, 1)
 470
 471		toolCall := assistantMsg.ToolCalls[0].OfFunction
 472		require.NotNil(t, toolCall)
 473		require.Equal(t, "quux", toolCall.ID)
 474		require.Equal(t, "thwomp", toolCall.Function.Name)
 475		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
 476
 477		// Check tool message
 478		toolMsg := messages[1].OfTool
 479		require.NotNil(t, toolMsg)
 480		require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
 481		require.Equal(t, "quux", toolMsg.ToolCallID)
 482	})
 483
 484	t.Run("should handle different tool output types", func(t *testing.T) {
 485		t.Parallel()
 486
 487		prompt := ai.Prompt{
 488			{
 489				Role: ai.MessageRoleTool,
 490				Content: []ai.MessagePart{
 491					ai.ToolResultPart{
 492						ToolCallID: "text-tool",
 493						Output: ai.ToolResultOutputContentText{
 494							Text: "Hello world",
 495						},
 496					},
 497					ai.ToolResultPart{
 498						ToolCallID: "error-tool",
 499						Output: ai.ToolResultOutputContentError{
 500							Error: errors.New("Something went wrong"),
 501						},
 502					},
 503				},
 504			},
 505		}
 506
 507		messages, warnings := toOpenAIPrompt(prompt)
 508
 509		require.Empty(t, warnings)
 510		require.Len(t, messages, 2)
 511
 512		// Check first tool message (text)
 513		textToolMsg := messages[0].OfTool
 514		require.NotNil(t, textToolMsg)
 515		require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
 516		require.Equal(t, "text-tool", textToolMsg.ToolCallID)
 517
 518		// Check second tool message (error)
 519		errorToolMsg := messages[1].OfTool
 520		require.NotNil(t, errorToolMsg)
 521		require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
 522		require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
 523	})
 524}
 525
 526func TestToOpenAIPrompt_AssistantMessages(t *testing.T) {
 527	t.Parallel()
 528
 529	t.Run("should handle simple text assistant messages", func(t *testing.T) {
 530		t.Parallel()
 531
 532		prompt := ai.Prompt{
 533			{
 534				Role: ai.MessageRoleAssistant,
 535				Content: []ai.MessagePart{
 536					ai.TextPart{Text: "Hello, how can I help you?"},
 537				},
 538			},
 539		}
 540
 541		messages, warnings := toOpenAIPrompt(prompt)
 542
 543		require.Empty(t, warnings)
 544		require.Len(t, messages, 1)
 545
 546		assistantMsg := messages[0].OfAssistant
 547		require.NotNil(t, assistantMsg)
 548		require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
 549	})
 550
 551	t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
 552		t.Parallel()
 553
 554		inputArgs := map[string]any{"query": "test"}
 555		inputJSON, _ := json.Marshal(inputArgs)
 556
 557		prompt := ai.Prompt{
 558			{
 559				Role: ai.MessageRoleAssistant,
 560				Content: []ai.MessagePart{
 561					ai.TextPart{Text: "Let me search for that."},
 562					ai.ToolCallPart{
 563						ToolCallID: "call-123",
 564						ToolName:   "search",
 565						Input:      string(inputJSON),
 566					},
 567				},
 568			},
 569		}
 570
 571		messages, warnings := toOpenAIPrompt(prompt)
 572
 573		require.Empty(t, warnings)
 574		require.Len(t, messages, 1)
 575
 576		assistantMsg := messages[0].OfAssistant
 577		require.NotNil(t, assistantMsg)
 578		require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
 579		require.Len(t, assistantMsg.ToolCalls, 1)
 580
 581		toolCall := assistantMsg.ToolCalls[0].OfFunction
 582		require.Equal(t, "call-123", toolCall.ID)
 583		require.Equal(t, "search", toolCall.Function.Name)
 584		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
 585	})
 586}
 587
 588var testPrompt = ai.Prompt{
 589	{
 590		Role: ai.MessageRoleUser,
 591		Content: []ai.MessagePart{
 592			ai.TextPart{Text: "Hello"},
 593		},
 594	},
 595}
 596
 597var testLogprobs = map[string]any{
 598	"content": []map[string]any{
 599		{
 600			"token":   "Hello",
 601			"logprob": -0.0009994634,
 602			"top_logprobs": []map[string]any{
 603				{
 604					"token":   "Hello",
 605					"logprob": -0.0009994634,
 606				},
 607			},
 608		},
 609		{
 610			"token":   "!",
 611			"logprob": -0.13410144,
 612			"top_logprobs": []map[string]any{
 613				{
 614					"token":   "!",
 615					"logprob": -0.13410144,
 616				},
 617			},
 618		},
 619		{
 620			"token":   " How",
 621			"logprob": -0.0009250381,
 622			"top_logprobs": []map[string]any{
 623				{
 624					"token":   " How",
 625					"logprob": -0.0009250381,
 626				},
 627			},
 628		},
 629		{
 630			"token":   " can",
 631			"logprob": -0.047709424,
 632			"top_logprobs": []map[string]any{
 633				{
 634					"token":   " can",
 635					"logprob": -0.047709424,
 636				},
 637			},
 638		},
 639		{
 640			"token":   " I",
 641			"logprob": -0.000009014684,
 642			"top_logprobs": []map[string]any{
 643				{
 644					"token":   " I",
 645					"logprob": -0.000009014684,
 646				},
 647			},
 648		},
 649		{
 650			"token":   " assist",
 651			"logprob": -0.009125131,
 652			"top_logprobs": []map[string]any{
 653				{
 654					"token":   " assist",
 655					"logprob": -0.009125131,
 656				},
 657			},
 658		},
 659		{
 660			"token":   " you",
 661			"logprob": -0.0000066306106,
 662			"top_logprobs": []map[string]any{
 663				{
 664					"token":   " you",
 665					"logprob": -0.0000066306106,
 666				},
 667			},
 668		},
 669		{
 670			"token":   " today",
 671			"logprob": -0.00011093382,
 672			"top_logprobs": []map[string]any{
 673				{
 674					"token":   " today",
 675					"logprob": -0.00011093382,
 676				},
 677			},
 678		},
 679		{
 680			"token":   "?",
 681			"logprob": -0.00004596782,
 682			"top_logprobs": []map[string]any{
 683				{
 684					"token":   "?",
 685					"logprob": -0.00004596782,
 686				},
 687			},
 688		},
 689	},
 690}
 691
 692type mockServer struct {
 693	server   *httptest.Server
 694	response map[string]any
 695	calls    []mockCall
 696}
 697
 698type mockCall struct {
 699	method  string
 700	path    string
 701	headers map[string]string
 702	body    map[string]any
 703}
 704
 705func newMockServer() *mockServer {
 706	ms := &mockServer{
 707		calls: make([]mockCall, 0),
 708	}
 709
 710	ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 711		// Record the call
 712		call := mockCall{
 713			method:  r.Method,
 714			path:    r.URL.Path,
 715			headers: make(map[string]string),
 716		}
 717
 718		for k, v := range r.Header {
 719			if len(v) > 0 {
 720				call.headers[k] = v[0]
 721			}
 722		}
 723
 724		// Parse request body
 725		if r.Body != nil {
 726			var body map[string]any
 727			json.NewDecoder(r.Body).Decode(&body)
 728			call.body = body
 729		}
 730
 731		ms.calls = append(ms.calls, call)
 732
 733		// Return mock response
 734		w.Header().Set("Content-Type", "application/json")
 735		json.NewEncoder(w).Encode(ms.response)
 736	}))
 737
 738	return ms
 739}
 740
 741func (ms *mockServer) close() {
 742	ms.server.Close()
 743}
 744
 745func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
 746	// Default values
 747	response := map[string]any{
 748		"id":      "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
 749		"object":  "chat.completion",
 750		"created": 1711115037,
 751		"model":   "gpt-3.5-turbo-0125",
 752		"choices": []map[string]any{
 753			{
 754				"index": 0,
 755				"message": map[string]any{
 756					"role":    "assistant",
 757					"content": "",
 758				},
 759				"finish_reason": "stop",
 760			},
 761		},
 762		"usage": map[string]any{
 763			"prompt_tokens":     4,
 764			"total_tokens":      34,
 765			"completion_tokens": 30,
 766		},
 767		"system_fingerprint": "fp_3bc1b5746c",
 768	}
 769
 770	// Override with provided options
 771	for k, v := range opts {
 772		switch k {
 773		case "content":
 774			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["content"] = v
 775		case "tool_calls":
 776			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["tool_calls"] = v
 777		case "function_call":
 778			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["function_call"] = v
 779		case "annotations":
 780			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["annotations"] = v
 781		case "usage":
 782			response["usage"] = v
 783		case "finish_reason":
 784			response["choices"].([]map[string]any)[0]["finish_reason"] = v
 785		case "id":
 786			response["id"] = v
 787		case "created":
 788			response["created"] = v
 789		case "model":
 790			response["model"] = v
 791		case "logprobs":
 792			if v != nil {
 793				response["choices"].([]map[string]any)[0]["logprobs"] = v
 794			}
 795		}
 796	}
 797
 798	ms.response = response
 799}
 800
 801func TestDoGenerate(t *testing.T) {
 802	t.Parallel()
 803
 804	t.Run("should extract text response", func(t *testing.T) {
 805		t.Parallel()
 806
 807		server := newMockServer()
 808		defer server.close()
 809
 810		server.prepareJSONResponse(map[string]any{
 811			"content": "Hello, World!",
 812		})
 813
 814		provider := NewOpenAIProvider(
 815			WithOpenAIApiKey("test-api-key"),
 816			WithOpenAIBaseURL(server.server.URL),
 817		)
 818		model := provider.LanguageModel("gpt-3.5-turbo")
 819
 820		result, err := model.Generate(context.Background(), ai.Call{
 821			Prompt: testPrompt,
 822		})
 823
 824		require.NoError(t, err)
 825		require.Len(t, result.Content, 1)
 826
 827		textContent, ok := result.Content[0].(ai.TextContent)
 828		require.True(t, ok)
 829		require.Equal(t, "Hello, World!", textContent.Text)
 830	})
 831
 832	t.Run("should extract usage", func(t *testing.T) {
 833		t.Parallel()
 834
 835		server := newMockServer()
 836		defer server.close()
 837
 838		server.prepareJSONResponse(map[string]any{
 839			"usage": map[string]any{
 840				"prompt_tokens":     20,
 841				"total_tokens":      25,
 842				"completion_tokens": 5,
 843			},
 844		})
 845
 846		provider := NewOpenAIProvider(
 847			WithOpenAIApiKey("test-api-key"),
 848			WithOpenAIBaseURL(server.server.URL),
 849		)
 850		model := provider.LanguageModel("gpt-3.5-turbo")
 851
 852		result, err := model.Generate(context.Background(), ai.Call{
 853			Prompt: testPrompt,
 854		})
 855
 856		require.NoError(t, err)
 857		require.Equal(t, int64(20), result.Usage.InputTokens)
 858		require.Equal(t, int64(5), result.Usage.OutputTokens)
 859		require.Equal(t, int64(25), result.Usage.TotalTokens)
 860	})
 861
 862	t.Run("should send request body", func(t *testing.T) {
 863		t.Parallel()
 864
 865		server := newMockServer()
 866		defer server.close()
 867
 868		server.prepareJSONResponse(map[string]any{})
 869
 870		provider := NewOpenAIProvider(
 871			WithOpenAIApiKey("test-api-key"),
 872			WithOpenAIBaseURL(server.server.URL),
 873		)
 874		model := provider.LanguageModel("gpt-3.5-turbo")
 875
 876		_, err := model.Generate(context.Background(), ai.Call{
 877			Prompt: testPrompt,
 878		})
 879
 880		require.NoError(t, err)
 881		require.Len(t, server.calls, 1)
 882
 883		call := server.calls[0]
 884		require.Equal(t, "POST", call.method)
 885		require.Equal(t, "/chat/completions", call.path)
 886		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
 887
 888		messages, ok := call.body["messages"].([]any)
 889		require.True(t, ok)
 890		require.Len(t, messages, 1)
 891
 892		message := messages[0].(map[string]any)
 893		require.Equal(t, "user", message["role"])
 894		require.Equal(t, "Hello", message["content"])
 895	})
 896
 897	t.Run("should support partial usage", func(t *testing.T) {
 898		t.Parallel()
 899
 900		server := newMockServer()
 901		defer server.close()
 902
 903		server.prepareJSONResponse(map[string]any{
 904			"usage": map[string]any{
 905				"prompt_tokens": 20,
 906				"total_tokens":  20,
 907			},
 908		})
 909
 910		provider := NewOpenAIProvider(
 911			WithOpenAIApiKey("test-api-key"),
 912			WithOpenAIBaseURL(server.server.URL),
 913		)
 914		model := provider.LanguageModel("gpt-3.5-turbo")
 915
 916		result, err := model.Generate(context.Background(), ai.Call{
 917			Prompt: testPrompt,
 918		})
 919
 920		require.NoError(t, err)
 921		require.Equal(t, int64(20), result.Usage.InputTokens)
 922		require.Equal(t, int64(0), result.Usage.OutputTokens)
 923		require.Equal(t, int64(20), result.Usage.TotalTokens)
 924	})
 925
 926	t.Run("should extract logprobs", func(t *testing.T) {
 927		t.Parallel()
 928
 929		server := newMockServer()
 930		defer server.close()
 931
 932		server.prepareJSONResponse(map[string]any{
 933			"logprobs": testLogprobs,
 934		})
 935
 936		provider := NewOpenAIProvider(
 937			WithOpenAIApiKey("test-api-key"),
 938			WithOpenAIBaseURL(server.server.URL),
 939		)
 940		model := provider.LanguageModel("gpt-3.5-turbo")
 941
 942		result, err := model.Generate(context.Background(), ai.Call{
 943			Prompt: testPrompt,
 944			ProviderOptions: ai.ProviderOptions{
 945				"openai": map[string]any{
 946					"logProbs": true,
 947				},
 948			},
 949		})
 950
 951		require.NoError(t, err)
 952		require.NotNil(t, result.ProviderMetadata)
 953
 954		openaiMeta, ok := result.ProviderMetadata["openai"]
 955		require.True(t, ok)
 956
 957		logprobs, ok := openaiMeta["logprobs"]
 958		require.True(t, ok)
 959		require.NotNil(t, logprobs)
 960	})
 961
 962	t.Run("should extract finish reason", func(t *testing.T) {
 963		t.Parallel()
 964
 965		server := newMockServer()
 966		defer server.close()
 967
 968		server.prepareJSONResponse(map[string]any{
 969			"finish_reason": "stop",
 970		})
 971
 972		provider := NewOpenAIProvider(
 973			WithOpenAIApiKey("test-api-key"),
 974			WithOpenAIBaseURL(server.server.URL),
 975		)
 976		model := provider.LanguageModel("gpt-3.5-turbo")
 977
 978		result, err := model.Generate(context.Background(), ai.Call{
 979			Prompt: testPrompt,
 980		})
 981
 982		require.NoError(t, err)
 983		require.Equal(t, ai.FinishReasonStop, result.FinishReason)
 984	})
 985
 986	t.Run("should support unknown finish reason", func(t *testing.T) {
 987		t.Parallel()
 988
 989		server := newMockServer()
 990		defer server.close()
 991
 992		server.prepareJSONResponse(map[string]any{
 993			"finish_reason": "eos",
 994		})
 995
 996		provider := NewOpenAIProvider(
 997			WithOpenAIApiKey("test-api-key"),
 998			WithOpenAIBaseURL(server.server.URL),
 999		)
1000		model := provider.LanguageModel("gpt-3.5-turbo")
1001
1002		result, err := model.Generate(context.Background(), ai.Call{
1003			Prompt: testPrompt,
1004		})
1005
1006		require.NoError(t, err)
1007		require.Equal(t, ai.FinishReasonUnknown, result.FinishReason)
1008	})
1009
1010	t.Run("should pass the model and the messages", func(t *testing.T) {
1011		t.Parallel()
1012
1013		server := newMockServer()
1014		defer server.close()
1015
1016		server.prepareJSONResponse(map[string]any{
1017			"content": "",
1018		})
1019
1020		provider := NewOpenAIProvider(
1021			WithOpenAIApiKey("test-api-key"),
1022			WithOpenAIBaseURL(server.server.URL),
1023		)
1024		model := provider.LanguageModel("gpt-3.5-turbo")
1025
1026		_, err := model.Generate(context.Background(), ai.Call{
1027			Prompt: testPrompt,
1028		})
1029
1030		require.NoError(t, err)
1031		require.Len(t, server.calls, 1)
1032
1033		call := server.calls[0]
1034		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1035
1036		messages := call.body["messages"].([]any)
1037		require.Len(t, messages, 1)
1038
1039		message := messages[0].(map[string]any)
1040		require.Equal(t, "user", message["role"])
1041		require.Equal(t, "Hello", message["content"])
1042	})
1043
1044	t.Run("should pass settings", func(t *testing.T) {
1045		t.Parallel()
1046
1047		server := newMockServer()
1048		defer server.close()
1049
1050		server.prepareJSONResponse(map[string]any{})
1051
1052		provider := NewOpenAIProvider(
1053			WithOpenAIApiKey("test-api-key"),
1054			WithOpenAIBaseURL(server.server.URL),
1055		)
1056		model := provider.LanguageModel("gpt-3.5-turbo")
1057
1058		_, err := model.Generate(context.Background(), ai.Call{
1059			Prompt: testPrompt,
1060			ProviderOptions: ai.ProviderOptions{
1061				"openai": map[string]any{
1062					"logitBias": map[string]int64{
1063						"50256": -100,
1064					},
1065					"parallelToolCalls": false,
1066					"user":              "test-user-id",
1067				},
1068			},
1069		})
1070
1071		require.NoError(t, err)
1072		require.Len(t, server.calls, 1)
1073
1074		call := server.calls[0]
1075		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1076
1077		messages := call.body["messages"].([]any)
1078		require.Len(t, messages, 1)
1079
1080		logitBias := call.body["logit_bias"].(map[string]any)
1081		require.Equal(t, float64(-100), logitBias["50256"])
1082		require.Equal(t, false, call.body["parallel_tool_calls"])
1083		require.Equal(t, "test-user-id", call.body["user"])
1084	})
1085
1086	t.Run("should pass reasoningEffort setting", func(t *testing.T) {
1087		t.Parallel()
1088
1089		server := newMockServer()
1090		defer server.close()
1091
1092		server.prepareJSONResponse(map[string]any{
1093			"content": "",
1094		})
1095
1096		provider := NewOpenAIProvider(
1097			WithOpenAIApiKey("test-api-key"),
1098			WithOpenAIBaseURL(server.server.URL),
1099		)
1100		model := provider.LanguageModel("o1-mini")
1101
1102		_, err := model.Generate(context.Background(), ai.Call{
1103			Prompt: testPrompt,
1104			ProviderOptions: ai.ProviderOptions{
1105				"openai": map[string]any{
1106					"reasoningEffort": "low",
1107				},
1108			},
1109		})
1110
1111		require.NoError(t, err)
1112		require.Len(t, server.calls, 1)
1113
1114		call := server.calls[0]
1115		require.Equal(t, "o1-mini", call.body["model"])
1116		require.Equal(t, "low", call.body["reasoning_effort"])
1117
1118		messages := call.body["messages"].([]any)
1119		require.Len(t, messages, 1)
1120
1121		message := messages[0].(map[string]any)
1122		require.Equal(t, "user", message["role"])
1123		require.Equal(t, "Hello", message["content"])
1124	})
1125
1126	t.Run("should pass textVerbosity setting", func(t *testing.T) {
1127		t.Parallel()
1128
1129		server := newMockServer()
1130		defer server.close()
1131
1132		server.prepareJSONResponse(map[string]any{
1133			"content": "",
1134		})
1135
1136		provider := NewOpenAIProvider(
1137			WithOpenAIApiKey("test-api-key"),
1138			WithOpenAIBaseURL(server.server.URL),
1139		)
1140		model := provider.LanguageModel("gpt-4o")
1141
1142		_, err := model.Generate(context.Background(), ai.Call{
1143			Prompt: testPrompt,
1144			ProviderOptions: ai.ProviderOptions{
1145				"openai": map[string]any{
1146					"textVerbosity": "low",
1147				},
1148			},
1149		})
1150
1151		require.NoError(t, err)
1152		require.Len(t, server.calls, 1)
1153
1154		call := server.calls[0]
1155		require.Equal(t, "gpt-4o", call.body["model"])
1156		require.Equal(t, "low", call.body["verbosity"])
1157
1158		messages := call.body["messages"].([]any)
1159		require.Len(t, messages, 1)
1160
1161		message := messages[0].(map[string]any)
1162		require.Equal(t, "user", message["role"])
1163		require.Equal(t, "Hello", message["content"])
1164	})
1165
1166	t.Run("should pass tools and toolChoice", func(t *testing.T) {
1167		t.Parallel()
1168
1169		server := newMockServer()
1170		defer server.close()
1171
1172		server.prepareJSONResponse(map[string]any{
1173			"content": "",
1174		})
1175
1176		provider := NewOpenAIProvider(
1177			WithOpenAIApiKey("test-api-key"),
1178			WithOpenAIBaseURL(server.server.URL),
1179		)
1180		model := provider.LanguageModel("gpt-3.5-turbo")
1181
1182		_, err := model.Generate(context.Background(), ai.Call{
1183			Prompt: testPrompt,
1184			Tools: []ai.Tool{
1185				ai.FunctionTool{
1186					Name: "test-tool",
1187					InputSchema: map[string]any{
1188						"type": "object",
1189						"properties": map[string]any{
1190							"value": map[string]any{
1191								"type": "string",
1192							},
1193						},
1194						"required":             []string{"value"},
1195						"additionalProperties": false,
1196						"$schema":              "http://json-schema.org/draft-07/schema#",
1197					},
1198				},
1199			},
1200			ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
1201		})
1202
1203		require.NoError(t, err)
1204		require.Len(t, server.calls, 1)
1205
1206		call := server.calls[0]
1207		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1208
1209		messages := call.body["messages"].([]any)
1210		require.Len(t, messages, 1)
1211
1212		tools := call.body["tools"].([]any)
1213		require.Len(t, tools, 1)
1214
1215		tool := tools[0].(map[string]any)
1216		require.Equal(t, "function", tool["type"])
1217
1218		function := tool["function"].(map[string]any)
1219		require.Equal(t, "test-tool", function["name"])
1220		require.Equal(t, false, function["strict"])
1221
1222		toolChoice := call.body["tool_choice"].(map[string]any)
1223		require.Equal(t, "function", toolChoice["type"])
1224
1225		toolChoiceFunction := toolChoice["function"].(map[string]any)
1226		require.Equal(t, "test-tool", toolChoiceFunction["name"])
1227	})
1228
1229	t.Run("should parse tool results", func(t *testing.T) {
1230		t.Parallel()
1231
1232		server := newMockServer()
1233		defer server.close()
1234
1235		server.prepareJSONResponse(map[string]any{
1236			"tool_calls": []map[string]any{
1237				{
1238					"id":   "call_O17Uplv4lJvD6DVdIvFFeRMw",
1239					"type": "function",
1240					"function": map[string]any{
1241						"name":      "test-tool",
1242						"arguments": `{"value":"Spark"}`,
1243					},
1244				},
1245			},
1246		})
1247
1248		provider := NewOpenAIProvider(
1249			WithOpenAIApiKey("test-api-key"),
1250			WithOpenAIBaseURL(server.server.URL),
1251		)
1252		model := provider.LanguageModel("gpt-3.5-turbo")
1253
1254		result, err := model.Generate(context.Background(), ai.Call{
1255			Prompt: testPrompt,
1256			Tools: []ai.Tool{
1257				ai.FunctionTool{
1258					Name: "test-tool",
1259					InputSchema: map[string]any{
1260						"type": "object",
1261						"properties": map[string]any{
1262							"value": map[string]any{
1263								"type": "string",
1264							},
1265						},
1266						"required":             []string{"value"},
1267						"additionalProperties": false,
1268						"$schema":              "http://json-schema.org/draft-07/schema#",
1269					},
1270				},
1271			},
1272			ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
1273		})
1274
1275		require.NoError(t, err)
1276		require.Len(t, result.Content, 1)
1277
1278		toolCall, ok := result.Content[0].(ai.ToolCallContent)
1279		require.True(t, ok)
1280		require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
1281		require.Equal(t, "test-tool", toolCall.ToolName)
1282		require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
1283	})
1284
1285	t.Run("should parse annotations/citations", func(t *testing.T) {
1286		t.Parallel()
1287
1288		server := newMockServer()
1289		defer server.close()
1290
1291		server.prepareJSONResponse(map[string]any{
1292			"content": "Based on the search results [doc1], I found information.",
1293			"annotations": []map[string]any{
1294				{
1295					"type": "url_citation",
1296					"url_citation": map[string]any{
1297						"start_index": 24,
1298						"end_index":   29,
1299						"url":         "https://example.com/doc1.pdf",
1300						"title":       "Document 1",
1301					},
1302				},
1303			},
1304		})
1305
1306		provider := NewOpenAIProvider(
1307			WithOpenAIApiKey("test-api-key"),
1308			WithOpenAIBaseURL(server.server.URL),
1309		)
1310		model := provider.LanguageModel("gpt-3.5-turbo")
1311
1312		result, err := model.Generate(context.Background(), ai.Call{
1313			Prompt: testPrompt,
1314		})
1315
1316		require.NoError(t, err)
1317		require.Len(t, result.Content, 2)
1318
1319		textContent, ok := result.Content[0].(ai.TextContent)
1320		require.True(t, ok)
1321		require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)
1322
1323		sourceContent, ok := result.Content[1].(ai.SourceContent)
1324		require.True(t, ok)
1325		require.Equal(t, ai.SourceTypeURL, sourceContent.SourceType)
1326		require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
1327		require.Equal(t, "Document 1", sourceContent.Title)
1328		require.NotEmpty(t, sourceContent.ID)
1329	})
1330
1331	t.Run("should return cached_tokens in prompt_details_tokens", func(t *testing.T) {
1332		t.Parallel()
1333
1334		server := newMockServer()
1335		defer server.close()
1336
1337		server.prepareJSONResponse(map[string]any{
1338			"usage": map[string]any{
1339				"prompt_tokens":     15,
1340				"completion_tokens": 20,
1341				"total_tokens":      35,
1342				"prompt_tokens_details": map[string]any{
1343					"cached_tokens": 1152,
1344				},
1345			},
1346		})
1347
1348		provider := NewOpenAIProvider(
1349			WithOpenAIApiKey("test-api-key"),
1350			WithOpenAIBaseURL(server.server.URL),
1351		)
1352		model := provider.LanguageModel("gpt-4o-mini")
1353
1354		result, err := model.Generate(context.Background(), ai.Call{
1355			Prompt: testPrompt,
1356		})
1357
1358		require.NoError(t, err)
1359		require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
1360		require.Equal(t, int64(15), result.Usage.InputTokens)
1361		require.Equal(t, int64(20), result.Usage.OutputTokens)
1362		require.Equal(t, int64(35), result.Usage.TotalTokens)
1363	})
1364
1365	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
1366		t.Parallel()
1367
1368		server := newMockServer()
1369		defer server.close()
1370
1371		server.prepareJSONResponse(map[string]any{
1372			"usage": map[string]any{
1373				"prompt_tokens":     15,
1374				"completion_tokens": 20,
1375				"total_tokens":      35,
1376				"completion_tokens_details": map[string]any{
1377					"accepted_prediction_tokens": 123,
1378					"rejected_prediction_tokens": 456,
1379				},
1380			},
1381		})
1382
1383		provider := NewOpenAIProvider(
1384			WithOpenAIApiKey("test-api-key"),
1385			WithOpenAIBaseURL(server.server.URL),
1386		)
1387		model := provider.LanguageModel("gpt-4o-mini")
1388
1389		result, err := model.Generate(context.Background(), ai.Call{
1390			Prompt: testPrompt,
1391		})
1392
1393		require.NoError(t, err)
1394		require.NotNil(t, result.ProviderMetadata)
1395
1396		openaiMeta, ok := result.ProviderMetadata["openai"]
1397		require.True(t, ok)
1398		require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
1399		require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
1400	})
1401
1402	t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
1403		t.Parallel()
1404
1405		server := newMockServer()
1406		defer server.close()
1407
1408		server.prepareJSONResponse(map[string]any{})
1409
1410		provider := NewOpenAIProvider(
1411			WithOpenAIApiKey("test-api-key"),
1412			WithOpenAIBaseURL(server.server.URL),
1413		)
1414		model := provider.LanguageModel("o1-preview")
1415
1416		result, err := model.Generate(context.Background(), ai.Call{
1417			Prompt:           testPrompt,
1418			Temperature:      &[]float64{0.5}[0],
1419			TopP:             &[]float64{0.7}[0],
1420			FrequencyPenalty: &[]float64{0.2}[0],
1421			PresencePenalty:  &[]float64{0.3}[0],
1422		})
1423
1424		require.NoError(t, err)
1425		require.Len(t, server.calls, 1)
1426
1427		call := server.calls[0]
1428		require.Equal(t, "o1-preview", call.body["model"])
1429
1430		messages := call.body["messages"].([]any)
1431		require.Len(t, messages, 1)
1432
1433		message := messages[0].(map[string]any)
1434		require.Equal(t, "user", message["role"])
1435		require.Equal(t, "Hello", message["content"])
1436
1437		// These should not be present
1438		require.Nil(t, call.body["temperature"])
1439		require.Nil(t, call.body["top_p"])
1440		require.Nil(t, call.body["frequency_penalty"])
1441		require.Nil(t, call.body["presence_penalty"])
1442
1443		// Should have warnings
1444		require.Len(t, result.Warnings, 4)
1445		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1446		require.Equal(t, "temperature", result.Warnings[0].Setting)
1447		require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
1448	})
1449
1450	t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
1451		t.Parallel()
1452
1453		server := newMockServer()
1454		defer server.close()
1455
1456		server.prepareJSONResponse(map[string]any{})
1457
1458		provider := NewOpenAIProvider(
1459			WithOpenAIApiKey("test-api-key"),
1460			WithOpenAIBaseURL(server.server.URL),
1461		)
1462		model := provider.LanguageModel("o1-preview")
1463
1464		_, err := model.Generate(context.Background(), ai.Call{
1465			Prompt:          testPrompt,
1466			MaxOutputTokens: &[]int64{1000}[0],
1467		})
1468
1469		require.NoError(t, err)
1470		require.Len(t, server.calls, 1)
1471
1472		call := server.calls[0]
1473		require.Equal(t, "o1-preview", call.body["model"])
1474		require.Equal(t, float64(1000), call.body["max_completion_tokens"])
1475		require.Nil(t, call.body["max_tokens"])
1476
1477		messages := call.body["messages"].([]any)
1478		require.Len(t, messages, 1)
1479
1480		message := messages[0].(map[string]any)
1481		require.Equal(t, "user", message["role"])
1482		require.Equal(t, "Hello", message["content"])
1483	})
1484
1485	t.Run("should return reasoning tokens", func(t *testing.T) {
1486		t.Parallel()
1487
1488		server := newMockServer()
1489		defer server.close()
1490
1491		server.prepareJSONResponse(map[string]any{
1492			"usage": map[string]any{
1493				"prompt_tokens":     15,
1494				"completion_tokens": 20,
1495				"total_tokens":      35,
1496				"completion_tokens_details": map[string]any{
1497					"reasoning_tokens": 10,
1498				},
1499			},
1500		})
1501
1502		provider := NewOpenAIProvider(
1503			WithOpenAIApiKey("test-api-key"),
1504			WithOpenAIBaseURL(server.server.URL),
1505		)
1506		model := provider.LanguageModel("o1-preview")
1507
1508		result, err := model.Generate(context.Background(), ai.Call{
1509			Prompt: testPrompt,
1510		})
1511
1512		require.NoError(t, err)
1513		require.Equal(t, int64(15), result.Usage.InputTokens)
1514		require.Equal(t, int64(20), result.Usage.OutputTokens)
1515		require.Equal(t, int64(35), result.Usage.TotalTokens)
1516		require.Equal(t, int64(10), result.Usage.ReasoningTokens)
1517	})
1518
1519	t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
1520		t.Parallel()
1521
1522		server := newMockServer()
1523		defer server.close()
1524
1525		server.prepareJSONResponse(map[string]any{
1526			"model": "o1-preview",
1527		})
1528
1529		provider := NewOpenAIProvider(
1530			WithOpenAIApiKey("test-api-key"),
1531			WithOpenAIBaseURL(server.server.URL),
1532		)
1533		model := provider.LanguageModel("o1-preview")
1534
1535		_, err := model.Generate(context.Background(), ai.Call{
1536			Prompt: testPrompt,
1537			ProviderOptions: ai.ProviderOptions{
1538				"openai": map[string]any{
1539					"maxCompletionTokens": 255,
1540				},
1541			},
1542		})
1543
1544		require.NoError(t, err)
1545		require.Len(t, server.calls, 1)
1546
1547		call := server.calls[0]
1548		require.Equal(t, "o1-preview", call.body["model"])
1549		require.Equal(t, float64(255), call.body["max_completion_tokens"])
1550
1551		messages := call.body["messages"].([]any)
1552		require.Len(t, messages, 1)
1553
1554		message := messages[0].(map[string]any)
1555		require.Equal(t, "user", message["role"])
1556		require.Equal(t, "Hello", message["content"])
1557	})
1558
1559	t.Run("should send prediction extension setting", func(t *testing.T) {
1560		t.Parallel()
1561
1562		server := newMockServer()
1563		defer server.close()
1564
1565		server.prepareJSONResponse(map[string]any{
1566			"content": "",
1567		})
1568
1569		provider := NewOpenAIProvider(
1570			WithOpenAIApiKey("test-api-key"),
1571			WithOpenAIBaseURL(server.server.URL),
1572		)
1573		model := provider.LanguageModel("gpt-3.5-turbo")
1574
1575		_, err := model.Generate(context.Background(), ai.Call{
1576			Prompt: testPrompt,
1577			ProviderOptions: ai.ProviderOptions{
1578				"openai": map[string]any{
1579					"prediction": map[string]any{
1580						"type":    "content",
1581						"content": "Hello, World!",
1582					},
1583				},
1584			},
1585		})
1586
1587		require.NoError(t, err)
1588		require.Len(t, server.calls, 1)
1589
1590		call := server.calls[0]
1591		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1592
1593		prediction := call.body["prediction"].(map[string]any)
1594		require.Equal(t, "content", prediction["type"])
1595		require.Equal(t, "Hello, World!", prediction["content"])
1596
1597		messages := call.body["messages"].([]any)
1598		require.Len(t, messages, 1)
1599
1600		message := messages[0].(map[string]any)
1601		require.Equal(t, "user", message["role"])
1602		require.Equal(t, "Hello", message["content"])
1603	})
1604
1605	t.Run("should send store extension setting", func(t *testing.T) {
1606		t.Parallel()
1607
1608		server := newMockServer()
1609		defer server.close()
1610
1611		server.prepareJSONResponse(map[string]any{
1612			"content": "",
1613		})
1614
1615		provider := NewOpenAIProvider(
1616			WithOpenAIApiKey("test-api-key"),
1617			WithOpenAIBaseURL(server.server.URL),
1618		)
1619		model := provider.LanguageModel("gpt-3.5-turbo")
1620
1621		_, err := model.Generate(context.Background(), ai.Call{
1622			Prompt: testPrompt,
1623			ProviderOptions: ai.ProviderOptions{
1624				"openai": map[string]any{
1625					"store": true,
1626				},
1627			},
1628		})
1629
1630		require.NoError(t, err)
1631		require.Len(t, server.calls, 1)
1632
1633		call := server.calls[0]
1634		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1635		require.Equal(t, true, call.body["store"])
1636
1637		messages := call.body["messages"].([]any)
1638		require.Len(t, messages, 1)
1639
1640		message := messages[0].(map[string]any)
1641		require.Equal(t, "user", message["role"])
1642		require.Equal(t, "Hello", message["content"])
1643	})
1644
1645	t.Run("should send metadata extension values", func(t *testing.T) {
1646		t.Parallel()
1647
1648		server := newMockServer()
1649		defer server.close()
1650
1651		server.prepareJSONResponse(map[string]any{
1652			"content": "",
1653		})
1654
1655		provider := NewOpenAIProvider(
1656			WithOpenAIApiKey("test-api-key"),
1657			WithOpenAIBaseURL(server.server.URL),
1658		)
1659		model := provider.LanguageModel("gpt-3.5-turbo")
1660
1661		_, err := model.Generate(context.Background(), ai.Call{
1662			Prompt: testPrompt,
1663			ProviderOptions: ai.ProviderOptions{
1664				"openai": map[string]any{
1665					"metadata": map[string]any{
1666						"custom": "value",
1667					},
1668				},
1669			},
1670		})
1671
1672		require.NoError(t, err)
1673		require.Len(t, server.calls, 1)
1674
1675		call := server.calls[0]
1676		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1677
1678		metadata := call.body["metadata"].(map[string]any)
1679		require.Equal(t, "value", metadata["custom"])
1680
1681		messages := call.body["messages"].([]any)
1682		require.Len(t, messages, 1)
1683
1684		message := messages[0].(map[string]any)
1685		require.Equal(t, "user", message["role"])
1686		require.Equal(t, "Hello", message["content"])
1687	})
1688
1689	t.Run("should send promptCacheKey extension value", func(t *testing.T) {
1690		t.Parallel()
1691
1692		server := newMockServer()
1693		defer server.close()
1694
1695		server.prepareJSONResponse(map[string]any{
1696			"content": "",
1697		})
1698
1699		provider := NewOpenAIProvider(
1700			WithOpenAIApiKey("test-api-key"),
1701			WithOpenAIBaseURL(server.server.URL),
1702		)
1703		model := provider.LanguageModel("gpt-3.5-turbo")
1704
1705		_, err := model.Generate(context.Background(), ai.Call{
1706			Prompt: testPrompt,
1707			ProviderOptions: ai.ProviderOptions{
1708				"openai": map[string]any{
1709					"promptCacheKey": "test-cache-key-123",
1710				},
1711			},
1712		})
1713
1714		require.NoError(t, err)
1715		require.Len(t, server.calls, 1)
1716
1717		call := server.calls[0]
1718		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1719		require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])
1720
1721		messages := call.body["messages"].([]any)
1722		require.Len(t, messages, 1)
1723
1724		message := messages[0].(map[string]any)
1725		require.Equal(t, "user", message["role"])
1726		require.Equal(t, "Hello", message["content"])
1727	})
1728
1729	t.Run("should send safetyIdentifier extension value", func(t *testing.T) {
1730		t.Parallel()
1731
1732		server := newMockServer()
1733		defer server.close()
1734
1735		server.prepareJSONResponse(map[string]any{
1736			"content": "",
1737		})
1738
1739		provider := NewOpenAIProvider(
1740			WithOpenAIApiKey("test-api-key"),
1741			WithOpenAIBaseURL(server.server.URL),
1742		)
1743		model := provider.LanguageModel("gpt-3.5-turbo")
1744
1745		_, err := model.Generate(context.Background(), ai.Call{
1746			Prompt: testPrompt,
1747			ProviderOptions: ai.ProviderOptions{
1748				"openai": map[string]any{
1749					"safetyIdentifier": "test-safety-identifier-123",
1750				},
1751			},
1752		})
1753
1754		require.NoError(t, err)
1755		require.Len(t, server.calls, 1)
1756
1757		call := server.calls[0]
1758		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
1759		require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])
1760
1761		messages := call.body["messages"].([]any)
1762		require.Len(t, messages, 1)
1763
1764		message := messages[0].(map[string]any)
1765		require.Equal(t, "user", message["role"])
1766		require.Equal(t, "Hello", message["content"])
1767	})
1768
1769	t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
1770		t.Parallel()
1771
1772		server := newMockServer()
1773		defer server.close()
1774
1775		server.prepareJSONResponse(map[string]any{})
1776
1777		provider := NewOpenAIProvider(
1778			WithOpenAIApiKey("test-api-key"),
1779			WithOpenAIBaseURL(server.server.URL),
1780		)
1781		model := provider.LanguageModel("gpt-4o-search-preview")
1782
1783		result, err := model.Generate(context.Background(), ai.Call{
1784			Prompt:      testPrompt,
1785			Temperature: &[]float64{0.7}[0],
1786		})
1787
1788		require.NoError(t, err)
1789		require.Len(t, server.calls, 1)
1790
1791		call := server.calls[0]
1792		require.Equal(t, "gpt-4o-search-preview", call.body["model"])
1793		require.Nil(t, call.body["temperature"])
1794
1795		require.Len(t, result.Warnings, 1)
1796		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1797		require.Equal(t, "temperature", result.Warnings[0].Setting)
1798		require.Contains(t, result.Warnings[0].Details, "search preview models")
1799	})
1800
1801	t.Run("should send serviceTier flex processing setting", func(t *testing.T) {
1802		t.Parallel()
1803
1804		server := newMockServer()
1805		defer server.close()
1806
1807		server.prepareJSONResponse(map[string]any{
1808			"content": "",
1809		})
1810
1811		provider := NewOpenAIProvider(
1812			WithOpenAIApiKey("test-api-key"),
1813			WithOpenAIBaseURL(server.server.URL),
1814		)
1815		model := provider.LanguageModel("o3-mini")
1816
1817		_, err := model.Generate(context.Background(), ai.Call{
1818			Prompt: testPrompt,
1819			ProviderOptions: ai.ProviderOptions{
1820				"openai": map[string]any{
1821					"serviceTier": "flex",
1822				},
1823			},
1824		})
1825
1826		require.NoError(t, err)
1827		require.Len(t, server.calls, 1)
1828
1829		call := server.calls[0]
1830		require.Equal(t, "o3-mini", call.body["model"])
1831		require.Equal(t, "flex", call.body["service_tier"])
1832
1833		messages := call.body["messages"].([]any)
1834		require.Len(t, messages, 1)
1835
1836		message := messages[0].(map[string]any)
1837		require.Equal(t, "user", message["role"])
1838		require.Equal(t, "Hello", message["content"])
1839	})
1840
1841	t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
1842		t.Parallel()
1843
1844		server := newMockServer()
1845		defer server.close()
1846
1847		server.prepareJSONResponse(map[string]any{})
1848
1849		provider := NewOpenAIProvider(
1850			WithOpenAIApiKey("test-api-key"),
1851			WithOpenAIBaseURL(server.server.URL),
1852		)
1853		model := provider.LanguageModel("gpt-4o-mini")
1854
1855		result, err := model.Generate(context.Background(), ai.Call{
1856			Prompt: testPrompt,
1857			ProviderOptions: ai.ProviderOptions{
1858				"openai": map[string]any{
1859					"serviceTier": "flex",
1860				},
1861			},
1862		})
1863
1864		require.NoError(t, err)
1865		require.Len(t, server.calls, 1)
1866
1867		call := server.calls[0]
1868		require.Nil(t, call.body["service_tier"])
1869
1870		require.Len(t, result.Warnings, 1)
1871		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1872		require.Equal(t, "serviceTier", result.Warnings[0].Setting)
1873		require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
1874	})
1875
1876	t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
1877		t.Parallel()
1878
1879		server := newMockServer()
1880		defer server.close()
1881
1882		server.prepareJSONResponse(map[string]any{})
1883
1884		provider := NewOpenAIProvider(
1885			WithOpenAIApiKey("test-api-key"),
1886			WithOpenAIBaseURL(server.server.URL),
1887		)
1888		model := provider.LanguageModel("gpt-4o-mini")
1889
1890		_, err := model.Generate(context.Background(), ai.Call{
1891			Prompt: testPrompt,
1892			ProviderOptions: ai.ProviderOptions{
1893				"openai": map[string]any{
1894					"serviceTier": "priority",
1895				},
1896			},
1897		})
1898
1899		require.NoError(t, err)
1900		require.Len(t, server.calls, 1)
1901
1902		call := server.calls[0]
1903		require.Equal(t, "gpt-4o-mini", call.body["model"])
1904		require.Equal(t, "priority", call.body["service_tier"])
1905
1906		messages := call.body["messages"].([]any)
1907		require.Len(t, messages, 1)
1908
1909		message := messages[0].(map[string]any)
1910		require.Equal(t, "user", message["role"])
1911		require.Equal(t, "Hello", message["content"])
1912	})
1913
1914	t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
1915		t.Parallel()
1916
1917		server := newMockServer()
1918		defer server.close()
1919
1920		server.prepareJSONResponse(map[string]any{})
1921
1922		provider := NewOpenAIProvider(
1923			WithOpenAIApiKey("test-api-key"),
1924			WithOpenAIBaseURL(server.server.URL),
1925		)
1926		model := provider.LanguageModel("gpt-3.5-turbo")
1927
1928		result, err := model.Generate(context.Background(), ai.Call{
1929			Prompt: testPrompt,
1930			ProviderOptions: ai.ProviderOptions{
1931				"openai": map[string]any{
1932					"serviceTier": "priority",
1933				},
1934			},
1935		})
1936
1937		require.NoError(t, err)
1938		require.Len(t, server.calls, 1)
1939
1940		call := server.calls[0]
1941		require.Nil(t, call.body["service_tier"])
1942
1943		require.Len(t, result.Warnings, 1)
1944		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
1945		require.Equal(t, "serviceTier", result.Warnings[0].Setting)
1946		require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
1947	})
1948}
1949
1950type streamingMockServer struct {
1951	server *httptest.Server
1952	chunks []string
1953	calls  []mockCall
1954}
1955
1956func newStreamingMockServer() *streamingMockServer {
1957	sms := &streamingMockServer{
1958		calls: make([]mockCall, 0),
1959	}
1960
1961	sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1962		// Record the call
1963		call := mockCall{
1964			method:  r.Method,
1965			path:    r.URL.Path,
1966			headers: make(map[string]string),
1967		}
1968
1969		for k, v := range r.Header {
1970			if len(v) > 0 {
1971				call.headers[k] = v[0]
1972			}
1973		}
1974
1975		// Parse request body
1976		if r.Body != nil {
1977			var body map[string]any
1978			json.NewDecoder(r.Body).Decode(&body)
1979			call.body = body
1980		}
1981
1982		sms.calls = append(sms.calls, call)
1983
1984		// Set streaming headers
1985		w.Header().Set("Content-Type", "text/event-stream")
1986		w.Header().Set("Cache-Control", "no-cache")
1987		w.Header().Set("Connection", "keep-alive")
1988
1989		// Add custom headers if any
1990		for _, chunk := range sms.chunks {
1991			if strings.HasPrefix(chunk, "HEADER:") {
1992				parts := strings.SplitN(chunk[7:], ":", 2)
1993				if len(parts) == 2 {
1994					w.Header().Set(parts[0], parts[1])
1995				}
1996				continue
1997			}
1998		}
1999
2000		w.WriteHeader(http.StatusOK)
2001
2002		// Write chunks
2003		for _, chunk := range sms.chunks {
2004			if strings.HasPrefix(chunk, "HEADER:") {
2005				continue
2006			}
2007			w.Write([]byte(chunk))
2008			if f, ok := w.(http.Flusher); ok {
2009				f.Flush()
2010			}
2011		}
2012	}))
2013
2014	return sms
2015}
2016
2017func (sms *streamingMockServer) close() {
2018	sms.server.Close()
2019}
2020
2021func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
2022	content := []string{}
2023	if c, ok := opts["content"].([]string); ok {
2024		content = c
2025	}
2026
2027	usage := map[string]any{
2028		"prompt_tokens":     17,
2029		"total_tokens":      244,
2030		"completion_tokens": 227,
2031	}
2032	if u, ok := opts["usage"].(map[string]any); ok {
2033		usage = u
2034	}
2035
2036	logprobs := map[string]any{}
2037	if l, ok := opts["logprobs"].(map[string]any); ok {
2038		logprobs = l
2039	}
2040
2041	finishReason := "stop"
2042	if fr, ok := opts["finish_reason"].(string); ok {
2043		finishReason = fr
2044	}
2045
2046	model := "gpt-3.5-turbo-0613"
2047	if m, ok := opts["model"].(string); ok {
2048		model = m
2049	}
2050
2051	headers := map[string]string{}
2052	if h, ok := opts["headers"].(map[string]string); ok {
2053		headers = h
2054	}
2055
2056	chunks := []string{}
2057
2058	// Add custom headers
2059	for k, v := range headers {
2060		chunks = append(chunks, "HEADER:"+k+":"+v)
2061	}
2062
2063	// Initial chunk with role
2064	initialChunk := map[string]any{
2065		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2066		"object":             "chat.completion.chunk",
2067		"created":            1702657020,
2068		"model":              model,
2069		"system_fingerprint": nil,
2070		"choices": []map[string]any{
2071			{
2072				"index": 0,
2073				"delta": map[string]any{
2074					"role":    "assistant",
2075					"content": "",
2076				},
2077				"finish_reason": nil,
2078			},
2079		},
2080	}
2081	initialData, _ := json.Marshal(initialChunk)
2082	chunks = append(chunks, "data: "+string(initialData)+"\n\n")
2083
2084	// Content chunks
2085	for i, text := range content {
2086		contentChunk := map[string]any{
2087			"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2088			"object":             "chat.completion.chunk",
2089			"created":            1702657020,
2090			"model":              model,
2091			"system_fingerprint": nil,
2092			"choices": []map[string]any{
2093				{
2094					"index": 1,
2095					"delta": map[string]any{
2096						"content": text,
2097					},
2098					"finish_reason": nil,
2099				},
2100			},
2101		}
2102		contentData, _ := json.Marshal(contentChunk)
2103		chunks = append(chunks, "data: "+string(contentData)+"\n\n")
2104
2105		// Add annotations if this is the last content chunk and we have annotations
2106		if i == len(content)-1 {
2107			if annotations, ok := opts["annotations"].([]map[string]any); ok {
2108				annotationChunk := map[string]any{
2109					"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2110					"object":             "chat.completion.chunk",
2111					"created":            1702657020,
2112					"model":              model,
2113					"system_fingerprint": nil,
2114					"choices": []map[string]any{
2115						{
2116							"index": 1,
2117							"delta": map[string]any{
2118								"annotations": annotations,
2119							},
2120							"finish_reason": nil,
2121						},
2122					},
2123				}
2124				annotationData, _ := json.Marshal(annotationChunk)
2125				chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
2126			}
2127		}
2128	}
2129
2130	// Finish chunk
2131	finishChunk := map[string]any{
2132		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2133		"object":             "chat.completion.chunk",
2134		"created":            1702657020,
2135		"model":              model,
2136		"system_fingerprint": nil,
2137		"choices": []map[string]any{
2138			{
2139				"index":         0,
2140				"delta":         map[string]any{},
2141				"finish_reason": finishReason,
2142			},
2143		},
2144	}
2145
2146	if len(logprobs) > 0 {
2147		finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
2148	}
2149
2150	finishData, _ := json.Marshal(finishChunk)
2151	chunks = append(chunks, "data: "+string(finishData)+"\n\n")
2152
2153	// Usage chunk
2154	usageChunk := map[string]any{
2155		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
2156		"object":             "chat.completion.chunk",
2157		"created":            1702657020,
2158		"model":              model,
2159		"system_fingerprint": "fp_3bc1b5746c",
2160		"choices":            []map[string]any{},
2161		"usage":              usage,
2162	}
2163	usageData, _ := json.Marshal(usageChunk)
2164	chunks = append(chunks, "data: "+string(usageData)+"\n\n")
2165
2166	// Done
2167	chunks = append(chunks, "data: [DONE]\n\n")
2168
2169	sms.chunks = chunks
2170}
2171
2172func (sms *streamingMockServer) prepareToolStreamResponse() {
2173	chunks := []string{
2174		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2175		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2176		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2177		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2178		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2179		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2180		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2181		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
2182		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
2183		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
2184		"data: [DONE]\n\n",
2185	}
2186	sms.chunks = chunks
2187}
2188
2189func (sms *streamingMockServer) prepareErrorStreamResponse() {
2190	chunks := []string{
2191		`data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
2192		"data: [DONE]\n\n",
2193	}
2194	sms.chunks = chunks
2195}
2196
2197func collectStreamParts(stream ai.StreamResponse) ([]ai.StreamPart, error) {
2198	var parts []ai.StreamPart
2199	for part := range stream {
2200		parts = append(parts, part)
2201		if part.Type == ai.StreamPartTypeError {
2202			break
2203		}
2204		if part.Type == ai.StreamPartTypeFinish {
2205			break
2206		}
2207	}
2208	return parts, nil
2209}
2210
2211func TestDoStream(t *testing.T) {
2212	t.Parallel()
2213
2214	t.Run("should stream text deltas", func(t *testing.T) {
2215		t.Parallel()
2216
2217		server := newStreamingMockServer()
2218		defer server.close()
2219
2220		server.prepareStreamResponse(map[string]any{
2221			"content":       []string{"Hello", ", ", "World!"},
2222			"finish_reason": "stop",
2223			"usage": map[string]any{
2224				"prompt_tokens":     17,
2225				"total_tokens":      244,
2226				"completion_tokens": 227,
2227			},
2228			"logprobs": testLogprobs,
2229		})
2230
2231		provider := NewOpenAIProvider(
2232			WithOpenAIApiKey("test-api-key"),
2233			WithOpenAIBaseURL(server.server.URL),
2234		)
2235		model := provider.LanguageModel("gpt-3.5-turbo")
2236
2237		stream, err := model.Stream(context.Background(), ai.Call{
2238			Prompt: testPrompt,
2239		})
2240
2241		require.NoError(t, err)
2242
2243		parts, err := collectStreamParts(stream)
2244		require.NoError(t, err)
2245
2246		// Verify stream structure
2247		require.True(t, len(parts) >= 4) // text-start, deltas, text-end, finish
2248
2249		// Find text parts
2250		var textStart, textEnd, finish int = -1, -1, -1
2251		var deltas []string
2252
2253		for i, part := range parts {
2254			switch part.Type {
2255			case ai.StreamPartTypeTextStart:
2256				textStart = i
2257			case ai.StreamPartTypeTextDelta:
2258				deltas = append(deltas, part.Delta)
2259			case ai.StreamPartTypeTextEnd:
2260				textEnd = i
2261			case ai.StreamPartTypeFinish:
2262				finish = i
2263			}
2264		}
2265
2266		require.NotEqual(t, -1, textStart)
2267		require.NotEqual(t, -1, textEnd)
2268		require.NotEqual(t, -1, finish)
2269		require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)
2270
2271		// Check finish part
2272		finishPart := parts[finish]
2273		require.Equal(t, ai.FinishReasonStop, finishPart.FinishReason)
2274		require.Equal(t, int64(17), finishPart.Usage.InputTokens)
2275		require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
2276		require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
2277	})
2278
2279	t.Run("should stream tool deltas", func(t *testing.T) {
2280		t.Parallel()
2281
2282		server := newStreamingMockServer()
2283		defer server.close()
2284
2285		server.prepareToolStreamResponse()
2286
2287		provider := NewOpenAIProvider(
2288			WithOpenAIApiKey("test-api-key"),
2289			WithOpenAIBaseURL(server.server.URL),
2290		)
2291		model := provider.LanguageModel("gpt-3.5-turbo")
2292
2293		stream, err := model.Stream(context.Background(), ai.Call{
2294			Prompt: testPrompt,
2295			Tools: []ai.Tool{
2296				ai.FunctionTool{
2297					Name: "test-tool",
2298					InputSchema: map[string]any{
2299						"type": "object",
2300						"properties": map[string]any{
2301							"value": map[string]any{
2302								"type": "string",
2303							},
2304						},
2305						"required":             []string{"value"},
2306						"additionalProperties": false,
2307						"$schema":              "http://json-schema.org/draft-07/schema#",
2308					},
2309				},
2310			},
2311		})
2312
2313		require.NoError(t, err)
2314
2315		parts, err := collectStreamParts(stream)
2316		require.NoError(t, err)
2317
2318		// Find tool-related parts
2319		toolInputStart, toolInputEnd, toolCall := -1, -1, -1
2320		var toolDeltas []string
2321
2322		for i, part := range parts {
2323			switch part.Type {
2324			case ai.StreamPartTypeToolInputStart:
2325				toolInputStart = i
2326				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
2327				require.Equal(t, "test-tool", part.ToolCallName)
2328			case ai.StreamPartTypeToolInputDelta:
2329				toolDeltas = append(toolDeltas, part.Delta)
2330			case ai.StreamPartTypeToolInputEnd:
2331				toolInputEnd = i
2332			case ai.StreamPartTypeToolCall:
2333				toolCall = i
2334				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
2335				require.Equal(t, "test-tool", part.ToolCallName)
2336				require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
2337			}
2338		}
2339
2340		require.NotEqual(t, -1, toolInputStart)
2341		require.NotEqual(t, -1, toolInputEnd)
2342		require.NotEqual(t, -1, toolCall)
2343
2344		// Verify tool deltas combine to form the complete input
2345		fullInput := ""
2346		for _, delta := range toolDeltas {
2347			fullInput += delta
2348		}
2349		require.Equal(t, `{"value":"Sparkle Day"}`, fullInput)
2350	})
2351
2352	t.Run("should stream annotations/citations", func(t *testing.T) {
2353		t.Parallel()
2354
2355		server := newStreamingMockServer()
2356		defer server.close()
2357
2358		server.prepareStreamResponse(map[string]any{
2359			"content": []string{"Based on search results"},
2360			"annotations": []map[string]any{
2361				{
2362					"type": "url_citation",
2363					"url_citation": map[string]any{
2364						"start_index": 24,
2365						"end_index":   29,
2366						"url":         "https://example.com/doc1.pdf",
2367						"title":       "Document 1",
2368					},
2369				},
2370			},
2371		})
2372
2373		provider := NewOpenAIProvider(
2374			WithOpenAIApiKey("test-api-key"),
2375			WithOpenAIBaseURL(server.server.URL),
2376		)
2377		model := provider.LanguageModel("gpt-3.5-turbo")
2378
2379		stream, err := model.Stream(context.Background(), ai.Call{
2380			Prompt: testPrompt,
2381		})
2382
2383		require.NoError(t, err)
2384
2385		parts, err := collectStreamParts(stream)
2386		require.NoError(t, err)
2387
2388		// Find source part
2389		var sourcePart *ai.StreamPart
2390		for _, part := range parts {
2391			if part.Type == ai.StreamPartTypeSource {
2392				sourcePart = &part
2393				break
2394			}
2395		}
2396
2397		require.NotNil(t, sourcePart)
2398		require.Equal(t, ai.SourceTypeURL, sourcePart.SourceType)
2399		require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
2400		require.Equal(t, "Document 1", sourcePart.Title)
2401		require.NotEmpty(t, sourcePart.ID)
2402	})
2403
2404	t.Run("should handle error stream parts", func(t *testing.T) {
2405		t.Parallel()
2406
2407		server := newStreamingMockServer()
2408		defer server.close()
2409
2410		server.prepareErrorStreamResponse()
2411
2412		provider := NewOpenAIProvider(
2413			WithOpenAIApiKey("test-api-key"),
2414			WithOpenAIBaseURL(server.server.URL),
2415		)
2416		model := provider.LanguageModel("gpt-3.5-turbo")
2417
2418		stream, err := model.Stream(context.Background(), ai.Call{
2419			Prompt: testPrompt,
2420		})
2421
2422		require.NoError(t, err)
2423
2424		parts, err := collectStreamParts(stream)
2425		require.NoError(t, err)
2426
2427		// Should have error and finish parts
2428		require.True(t, len(parts) >= 1)
2429
2430		// Find error part
2431		var errorPart *ai.StreamPart
2432		for _, part := range parts {
2433			if part.Type == ai.StreamPartTypeError {
2434				errorPart = &part
2435				break
2436			}
2437		}
2438
2439		require.NotNil(t, errorPart)
2440		require.NotNil(t, errorPart.Error)
2441	})
2442
2443	t.Run("should send request body", func(t *testing.T) {
2444		t.Parallel()
2445
2446		server := newStreamingMockServer()
2447		defer server.close()
2448
2449		server.prepareStreamResponse(map[string]any{
2450			"content": []string{},
2451		})
2452
2453		provider := NewOpenAIProvider(
2454			WithOpenAIApiKey("test-api-key"),
2455			WithOpenAIBaseURL(server.server.URL),
2456		)
2457		model := provider.LanguageModel("gpt-3.5-turbo")
2458
2459		_, err := model.Stream(context.Background(), ai.Call{
2460			Prompt: testPrompt,
2461		})
2462
2463		require.NoError(t, err)
2464		require.Len(t, server.calls, 1)
2465
2466		call := server.calls[0]
2467		require.Equal(t, "POST", call.method)
2468		require.Equal(t, "/chat/completions", call.path)
2469		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
2470		require.Equal(t, true, call.body["stream"])
2471
2472		streamOptions := call.body["stream_options"].(map[string]any)
2473		require.Equal(t, true, streamOptions["include_usage"])
2474
2475		messages := call.body["messages"].([]any)
2476		require.Len(t, messages, 1)
2477
2478		message := messages[0].(map[string]any)
2479		require.Equal(t, "user", message["role"])
2480		require.Equal(t, "Hello", message["content"])
2481	})
2482
2483	t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
2484		t.Parallel()
2485
2486		server := newStreamingMockServer()
2487		defer server.close()
2488
2489		server.prepareStreamResponse(map[string]any{
2490			"content": []string{},
2491			"usage": map[string]any{
2492				"prompt_tokens":     15,
2493				"completion_tokens": 20,
2494				"total_tokens":      35,
2495				"prompt_tokens_details": map[string]any{
2496					"cached_tokens": 1152,
2497				},
2498			},
2499		})
2500
2501		provider := NewOpenAIProvider(
2502			WithOpenAIApiKey("test-api-key"),
2503			WithOpenAIBaseURL(server.server.URL),
2504		)
2505		model := provider.LanguageModel("gpt-3.5-turbo")
2506
2507		stream, err := model.Stream(context.Background(), ai.Call{
2508			Prompt: testPrompt,
2509		})
2510
2511		require.NoError(t, err)
2512
2513		parts, err := collectStreamParts(stream)
2514		require.NoError(t, err)
2515
2516		// Find finish part
2517		var finishPart *ai.StreamPart
2518		for _, part := range parts {
2519			if part.Type == ai.StreamPartTypeFinish {
2520				finishPart = &part
2521				break
2522			}
2523		}
2524
2525		require.NotNil(t, finishPart)
2526		require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
2527		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
2528		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
2529		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
2530	})
2531
2532	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
2533		t.Parallel()
2534
2535		server := newStreamingMockServer()
2536		defer server.close()
2537
2538		server.prepareStreamResponse(map[string]any{
2539			"content": []string{},
2540			"usage": map[string]any{
2541				"prompt_tokens":     15,
2542				"completion_tokens": 20,
2543				"total_tokens":      35,
2544				"completion_tokens_details": map[string]any{
2545					"accepted_prediction_tokens": 123,
2546					"rejected_prediction_tokens": 456,
2547				},
2548			},
2549		})
2550
2551		provider := NewOpenAIProvider(
2552			WithOpenAIApiKey("test-api-key"),
2553			WithOpenAIBaseURL(server.server.URL),
2554		)
2555		model := provider.LanguageModel("gpt-3.5-turbo")
2556
2557		stream, err := model.Stream(context.Background(), ai.Call{
2558			Prompt: testPrompt,
2559		})
2560
2561		require.NoError(t, err)
2562
2563		parts, err := collectStreamParts(stream)
2564		require.NoError(t, err)
2565
2566		// Find finish part
2567		var finishPart *ai.StreamPart
2568		for _, part := range parts {
2569			if part.Type == ai.StreamPartTypeFinish {
2570				finishPart = &part
2571				break
2572			}
2573		}
2574
2575		require.NotNil(t, finishPart)
2576		require.NotNil(t, finishPart.ProviderMetadata)
2577
2578		openaiMeta, ok := finishPart.ProviderMetadata["openai"]
2579		require.True(t, ok)
2580		require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
2581		require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
2582	})
2583
2584	t.Run("should send store extension setting", func(t *testing.T) {
2585		t.Parallel()
2586
2587		server := newStreamingMockServer()
2588		defer server.close()
2589
2590		server.prepareStreamResponse(map[string]any{
2591			"content": []string{},
2592		})
2593
2594		provider := NewOpenAIProvider(
2595			WithOpenAIApiKey("test-api-key"),
2596			WithOpenAIBaseURL(server.server.URL),
2597		)
2598		model := provider.LanguageModel("gpt-3.5-turbo")
2599
2600		_, err := model.Stream(context.Background(), ai.Call{
2601			Prompt: testPrompt,
2602			ProviderOptions: ai.ProviderOptions{
2603				"openai": map[string]any{
2604					"store": true,
2605				},
2606			},
2607		})
2608
2609		require.NoError(t, err)
2610		require.Len(t, server.calls, 1)
2611
2612		call := server.calls[0]
2613		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
2614		require.Equal(t, true, call.body["stream"])
2615		require.Equal(t, true, call.body["store"])
2616
2617		streamOptions := call.body["stream_options"].(map[string]any)
2618		require.Equal(t, true, streamOptions["include_usage"])
2619
2620		messages := call.body["messages"].([]any)
2621		require.Len(t, messages, 1)
2622
2623		message := messages[0].(map[string]any)
2624		require.Equal(t, "user", message["role"])
2625		require.Equal(t, "Hello", message["content"])
2626	})
2627
2628	t.Run("should send metadata extension values", func(t *testing.T) {
2629		t.Parallel()
2630
2631		server := newStreamingMockServer()
2632		defer server.close()
2633
2634		server.prepareStreamResponse(map[string]any{
2635			"content": []string{},
2636		})
2637
2638		provider := NewOpenAIProvider(
2639			WithOpenAIApiKey("test-api-key"),
2640			WithOpenAIBaseURL(server.server.URL),
2641		)
2642		model := provider.LanguageModel("gpt-3.5-turbo")
2643
2644		_, err := model.Stream(context.Background(), ai.Call{
2645			Prompt: testPrompt,
2646			ProviderOptions: ai.ProviderOptions{
2647				"openai": map[string]any{
2648					"metadata": map[string]any{
2649						"custom": "value",
2650					},
2651				},
2652			},
2653		})
2654
2655		require.NoError(t, err)
2656		require.Len(t, server.calls, 1)
2657
2658		call := server.calls[0]
2659		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
2660		require.Equal(t, true, call.body["stream"])
2661
2662		metadata := call.body["metadata"].(map[string]any)
2663		require.Equal(t, "value", metadata["custom"])
2664
2665		streamOptions := call.body["stream_options"].(map[string]any)
2666		require.Equal(t, true, streamOptions["include_usage"])
2667
2668		messages := call.body["messages"].([]any)
2669		require.Len(t, messages, 1)
2670
2671		message := messages[0].(map[string]any)
2672		require.Equal(t, "user", message["role"])
2673		require.Equal(t, "Hello", message["content"])
2674	})
2675
2676	t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
2677		t.Parallel()
2678
2679		server := newStreamingMockServer()
2680		defer server.close()
2681
2682		server.prepareStreamResponse(map[string]any{
2683			"content": []string{},
2684		})
2685
2686		provider := NewOpenAIProvider(
2687			WithOpenAIApiKey("test-api-key"),
2688			WithOpenAIBaseURL(server.server.URL),
2689		)
2690		model := provider.LanguageModel("o3-mini")
2691
2692		_, err := model.Stream(context.Background(), ai.Call{
2693			Prompt: testPrompt,
2694			ProviderOptions: ai.ProviderOptions{
2695				"openai": map[string]any{
2696					"serviceTier": "flex",
2697				},
2698			},
2699		})
2700
2701		require.NoError(t, err)
2702		require.Len(t, server.calls, 1)
2703
2704		call := server.calls[0]
2705		require.Equal(t, "o3-mini", call.body["model"])
2706		require.Equal(t, "flex", call.body["service_tier"])
2707		require.Equal(t, true, call.body["stream"])
2708
2709		streamOptions := call.body["stream_options"].(map[string]any)
2710		require.Equal(t, true, streamOptions["include_usage"])
2711
2712		messages := call.body["messages"].([]any)
2713		require.Len(t, messages, 1)
2714
2715		message := messages[0].(map[string]any)
2716		require.Equal(t, "user", message["role"])
2717		require.Equal(t, "Hello", message["content"])
2718	})
2719
2720	t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
2721		t.Parallel()
2722
2723		server := newStreamingMockServer()
2724		defer server.close()
2725
2726		server.prepareStreamResponse(map[string]any{
2727			"content": []string{},
2728		})
2729
2730		provider := NewOpenAIProvider(
2731			WithOpenAIApiKey("test-api-key"),
2732			WithOpenAIBaseURL(server.server.URL),
2733		)
2734		model := provider.LanguageModel("gpt-4o-mini")
2735
2736		_, err := model.Stream(context.Background(), ai.Call{
2737			Prompt: testPrompt,
2738			ProviderOptions: ai.ProviderOptions{
2739				"openai": map[string]any{
2740					"serviceTier": "priority",
2741				},
2742			},
2743		})
2744
2745		require.NoError(t, err)
2746		require.Len(t, server.calls, 1)
2747
2748		call := server.calls[0]
2749		require.Equal(t, "gpt-4o-mini", call.body["model"])
2750		require.Equal(t, "priority", call.body["service_tier"])
2751		require.Equal(t, true, call.body["stream"])
2752
2753		streamOptions := call.body["stream_options"].(map[string]any)
2754		require.Equal(t, true, streamOptions["include_usage"])
2755
2756		messages := call.body["messages"].([]any)
2757		require.Len(t, messages, 1)
2758
2759		message := messages[0].(map[string]any)
2760		require.Equal(t, "user", message["role"])
2761		require.Equal(t, "Hello", message["content"])
2762	})
2763
2764	t.Run("should stream text delta for reasoning models", func(t *testing.T) {
2765		t.Parallel()
2766
2767		server := newStreamingMockServer()
2768		defer server.close()
2769
2770		server.prepareStreamResponse(map[string]any{
2771			"content": []string{"Hello, World!"},
2772			"model":   "o1-preview",
2773		})
2774
2775		provider := NewOpenAIProvider(
2776			WithOpenAIApiKey("test-api-key"),
2777			WithOpenAIBaseURL(server.server.URL),
2778		)
2779		model := provider.LanguageModel("o1-preview")
2780
2781		stream, err := model.Stream(context.Background(), ai.Call{
2782			Prompt: testPrompt,
2783		})
2784
2785		require.NoError(t, err)
2786
2787		parts, err := collectStreamParts(stream)
2788		require.NoError(t, err)
2789
2790		// Find text parts
2791		var textDeltas []string
2792		for _, part := range parts {
2793			if part.Type == ai.StreamPartTypeTextDelta {
2794				textDeltas = append(textDeltas, part.Delta)
2795			}
2796		}
2797
2798		// Should contain the text content (without empty delta)
2799		require.Equal(t, []string{"Hello, World!"}, textDeltas)
2800	})
2801
2802	t.Run("should send reasoning tokens", func(t *testing.T) {
2803		t.Parallel()
2804
2805		server := newStreamingMockServer()
2806		defer server.close()
2807
2808		server.prepareStreamResponse(map[string]any{
2809			"content": []string{"Hello, World!"},
2810			"model":   "o1-preview",
2811			"usage": map[string]any{
2812				"prompt_tokens":     15,
2813				"completion_tokens": 20,
2814				"total_tokens":      35,
2815				"completion_tokens_details": map[string]any{
2816					"reasoning_tokens": 10,
2817				},
2818			},
2819		})
2820
2821		provider := NewOpenAIProvider(
2822			WithOpenAIApiKey("test-api-key"),
2823			WithOpenAIBaseURL(server.server.URL),
2824		)
2825		model := provider.LanguageModel("o1-preview")
2826
2827		stream, err := model.Stream(context.Background(), ai.Call{
2828			Prompt: testPrompt,
2829		})
2830
2831		require.NoError(t, err)
2832
2833		parts, err := collectStreamParts(stream)
2834		require.NoError(t, err)
2835
2836		// Find finish part
2837		var finishPart *ai.StreamPart
2838		for _, part := range parts {
2839			if part.Type == ai.StreamPartTypeFinish {
2840				finishPart = &part
2841				break
2842			}
2843		}
2844
2845		require.NotNil(t, finishPart)
2846		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
2847		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
2848		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
2849		require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
2850	})
2851}