feat: initial ai-sdk implementation

Created by Kujtim Hoxha

Change summary

agent.go                          |  108 +
content.go                        |  432 +++++
errors.go                         |  295 +++
model.go                          |  148 +
provider.go                       |   26 
providers/examples/simple/main.go |   28 
providers/examples/stream/main.go |   49 
providers/openai.go               | 1161 +++++++++++++
providers/openai_test.go          | 2850 +++++++++++++++++++++++++++++++++
9 files changed, 5,097 insertions(+)

Detailed changes

agent.go

@@ -0,0 +1,108 @@
+package ai
+
+import (
+	"context"
+)
+
+type StepResponse struct {
+	Response
+	// Messages generated during this step
+	Messages []Message
+}
+
+type StepCondition = func(steps []StepResponse) bool
+
+type PrepareStepFunctionOptions struct {
+	Steps      []StepResponse
+	StepNumber int
+	Model      LanguageModel
+	Messages   []Message
+}
+
+type PrepareStepResult struct {
+	SystemPrompt string
+	Model        LanguageModel
+	Messages     []Message
+}
+
+type PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult
+
+type OnStepFinishedFunction = func(step StepResponse)
+
+type AgentSettings struct {
+	Call
+	Model LanguageModel
+
+	StopWhen       []StepCondition
+	PrepareStep    PrepareStepFunction
+	OnStepFinished OnStepFinishedFunction
+}
+
+type Agent interface {
+	Generate(context.Context, Call) (*Response, error)
+	Stream(context.Context, Call) (StreamResponse, error)
+}
+
+type agentOption = func(*AgentSettings)
+
+type agent struct {
+	settings AgentSettings
+}
+
+func NewAgent(model LanguageModel, opts ...agentOption) Agent {
+	settings := AgentSettings{
+		Model: model,
+	}
+	for _, o := range opts {
+		o(&settings)
+	}
+	return &agent{
+		settings: settings,
+	}
+}
+
+func mergeCall(agentOpts, opts Call) Call {
+	if len(opts.Prompt) > 0 {
+		agentOpts.Prompt = opts.Prompt
+	}
+	if opts.MaxOutputTokens != nil {
+		agentOpts.MaxOutputTokens = opts.MaxOutputTokens
+	}
+	if opts.Temperature != nil {
+		agentOpts.Temperature = opts.Temperature
+	}
+	if opts.TopP != nil {
+		agentOpts.TopP = opts.TopP
+	}
+	if opts.TopK != nil {
+		agentOpts.TopK = opts.TopK
+	}
+	if opts.PresencePenalty != nil {
+		agentOpts.PresencePenalty = opts.PresencePenalty
+	}
+	if opts.FrequencyPenalty != nil {
+		agentOpts.FrequencyPenalty = opts.FrequencyPenalty
+	}
+	if opts.Tools != nil {
+		agentOpts.Tools = opts.Tools
+	}
+	if opts.Headers != nil {
+		agentOpts.Headers = opts.Headers
+	}
+	if opts.ProviderOptions != nil {
+		agentOpts.ProviderOptions = opts.ProviderOptions
+	}
+	return agentOpts
+}
+
+// Generate implements Agent.
+func (a *agent) Generate(ctx context.Context, opts Call) (*Response, error) {
+	// TODO: implement the agentic stuff
+	return a.settings.Model.Generate(ctx, mergeCall(a.settings.Call, opts))
+}
+
+// Stream implements Agent.
+func (a *agent) Stream(ctx context.Context, opts Call) (StreamResponse, error) {
+	// TODO: implement the agentic stuff
+	return a.settings.Model.Stream(ctx, mergeCall(a.settings.Call, opts))
+}
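
A minimal usage sketch of the agent surface above (illustrative, not part of this change): NewAgent wraps a LanguageModel, agent-level defaults live in the embedded Call, and per-call options passed to Generate override them field by field via mergeCall. Generate currently forwards straight to the model; the multi-step logic is still TODO. Provider setup mirrors providers/examples/simple/main.go.

package main

import (
	"context"
	"fmt"

	"github.com/charmbracelet/crush/internal/ai"
	"github.com/charmbracelet/crush/internal/ai/providers"
)

func main() {
	model := providers.NewOpenAIProvider(providers.WithOpenAIApiKey("$OPENAI_API_KEY")).LanguageModel("gpt-4o")

	// Agent-level defaults; per-call options override them via mergeCall.
	agent := ai.NewAgent(model, func(s *ai.AgentSettings) {
		s.Temperature = ai.FloatOption(0.2)
	})

	resp, err := agent.Generate(context.Background(), ai.Call{
		Prompt: ai.Prompt{ai.NewUserMessage("Say hello in one word.")},
	})
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(resp.Content.Text())
}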

content.go

@@ -0,0 +1,432 @@
+package ai
+
+// ProviderMetadata represents additional provider-specific metadata.
+// They are passed through from the provider to the AI SDK and enable
+// provider-specific results that can be fully encapsulated in the provider.
+//
+// The outer map is keyed by the provider name, and the inner
+// map is keyed by the provider-specific metadata key.
+//
+// Example:
+//
+//	{
+//	  "anthropic": {
+//	    "cacheControl": { "type": "ephemeral" }
+//	  }
+//	}
+type ProviderMetadata map[string]map[string]any
+
+// ProviderOptions represents additional provider-specific options.
+// Options are additional input to the provider. They are passed through
+// to the provider from the AI SDK and enable provider-specific functionality
+// that can be fully encapsulated in the provider.
+//
+// This enables us to quickly ship provider-specific functionality
+// without affecting the core AI SDK.
+//
+// The outer map is keyed by the provider name, and the inner
+// map is keyed by the provider-specific option key.
+//
+// Example:
+//
+//	{
+//	  "anthropic": {
+//	    "cacheControl": { "type": "ephemeral" }
+//	  }
+//	}
+type ProviderOptions map[string]map[string]any
+
+// FinishReason represents why a language model finished generating a response.
+//
+// Can be one of the following:
+// - `stop`: model generated stop sequence
+// - `length`: model generated maximum number of tokens
+// - `content-filter`: content filter violation stopped the model
+// - `tool-calls`: model triggered tool calls
+// - `error`: model stopped because of an error
+// - `other`: model stopped for other reasons
+// - `unknown`: the model has not transmitted a finish reason
+type FinishReason string
+
+const (
+	FinishReasonStop          FinishReason = "stop"           // model generated stop sequence
+	FinishReasonLength        FinishReason = "length"         // model generated maximum number of tokens
+	FinishReasonContentFilter FinishReason = "content-filter" // content filter violation stopped the model
+	FinishReasonToolCalls     FinishReason = "tool-calls"     // model triggered tool calls
+	FinishReasonError         FinishReason = "error"          // model stopped because of an error
+	FinishReasonOther         FinishReason = "other"          // model stopped for other reasons
+	FinishReasonUnknown       FinishReason = "unknown"        // the model has not transmitted a finish reason
+)
+
+// Prompt represents a list of messages for the language model.
+type Prompt []Message
+
+// MessageRole represents the role of a message.
+type MessageRole string
+
+const (
+	MessageRoleSystem    MessageRole = "system"
+	MessageRoleUser      MessageRole = "user"
+	MessageRoleAssistant MessageRole = "assistant"
+	MessageRoleTool      MessageRole = "tool"
+)
+
+// Message represents a message in a prompt.
+type Message struct {
+	Role            MessageRole     `json:"role"`
+	Content         []MessagePart   `json:"content"`
+	ProviderOptions ProviderOptions `json:"provider_options"`
+}
+
+func AsContentType[T MessagePart](content MessagePart) (T, bool) {
+	var zero T
+	if content == nil {
+		return zero, false
+	}
+	switch v := any(content).(type) {
+	case T:
+		return v, true
+	case *T:
+		return *v, true
+	default:
+		return zero, false
+	}
+}
+
+// MessagePart represents a part of a message content.
+type MessagePart interface {
+	GetType() ContentType
+}
+
+// TextPart represents text content in a message.
+type TextPart struct {
+	Text            string          `json:"text"`
+	ProviderOptions ProviderOptions `json:"provider_options"`
+}
+
+// GetType returns the type of the text part.
+func (t TextPart) GetType() ContentType {
+	return ContentTypeText
+}
+
+// ReasoningPart represents reasoning content in a message.
+type ReasoningPart struct {
+	Text            string          `json:"text"`
+	ProviderOptions ProviderOptions `json:"provider_options"`
+}
+
+// GetType returns the type of the reasoning part.
+func (r ReasoningPart) GetType() ContentType {
+	return ContentTypeReasoning
+}
+
+// FilePart represents file content in a message.
+type FilePart struct {
+	Filename        string          `json:"filename"`
+	Data            []byte          `json:"data"`
+	MediaType       string          `json:"media_type"`
+	ProviderOptions ProviderOptions `json:"provider_options"`
+}
+
+// GetType returns the type of the file part.
+func (f FilePart) GetType() ContentType {
+	return ContentTypeFile
+}
+
+// ToolCallPart represents a tool call in a message.
+type ToolCallPart struct {
+	ToolCallID       string          `json:"tool_call_id"`
+	ToolName         string          `json:"tool_name"`
+	Input            string          `json:"input"` // the json string
+	ProviderExecuted bool            `json:"provider_executed"`
+	ProviderOptions  ProviderOptions `json:"provider_options"`
+}
+
+// GetType returns the type of the tool call part.
+func (t ToolCallPart) GetType() ContentType {
+	return ContentTypeToolCall
+}
+
+// ToolResultPart represents a tool result in a message.
+type ToolResultPart struct {
+	ToolCallID      string                  `json:"tool_call_id"`
+	Output          ToolResultOutputContent `json:"output"`
+	ProviderOptions ProviderOptions         `json:"provider_options"`
+}
+
+// GetType returns the type of the tool result part.
+func (t ToolResultPart) GetType() ContentType {
+	return ContentTypeToolResult
+}
+
+// ToolResultContentType represents the type of tool result output.
+type ToolResultContentType string
+
+const (
+	// ToolResultContentTypeText represents text output.
+	ToolResultContentTypeText ToolResultContentType = "text"
+	// ToolResultContentTypeError represents error text output.
+	ToolResultContentTypeError ToolResultContentType = "error"
+	// ToolResultContentTypeMedia represents media content output.
+	ToolResultContentTypeMedia ToolResultContentType = "media"
+)
+
+type ToolResultOutputContent interface {
+	GetType() ToolResultContentType
+}
+
+type ToolResultOutputContentText struct {
+	Text string `json:"text"`
+}
+
+func (t ToolResultOutputContentText) GetType() ToolResultContentType {
+	return ToolResultContentTypeText
+}
+
+type ToolResultOutputContentError struct {
+	Error string `json:"error"`
+}
+
+func (t ToolResultOutputContentError) GetType() ToolResultContentType {
+	return ToolResultContentTypeError
+}
+
+type ToolResultOutputContentMedia struct {
+	Data      string `json:"data"`       // for media type (base64)
+	MediaType string `json:"media_type"` // for media type
+}
+
+func (t ToolResultOutputContentMedia) GetType() ToolResultContentType {
+	return ToolResultContentTypeMedia
+}
+
+func AsToolResultOutputType[T ToolResultOutputContent](content ToolResultOutputContent) (T, bool) {
+	var zero T
+	if content == nil {
+		return zero, false
+	}
+	switch v := any(content).(type) {
+	case T:
+		return v, true
+	case *T:
+		return *v, true
+	default:
+		return zero, false
+	}
+}
+
+// ContentType represents the type of content.
+type ContentType string
+
+const (
+	// ContentTypeText represents text content.
+	ContentTypeText ContentType = "text"
+	// ContentTypeReasoning represents reasoning content.
+	ContentTypeReasoning ContentType = "reasoning"
+	// ContentTypeFile represents file content.
+	ContentTypeFile ContentType = "file"
+	// ContentTypeSource represents source content.
+	ContentTypeSource ContentType = "source"
+	// ContentTypeToolCall represents a tool call.
+	ContentTypeToolCall ContentType = "tool-call"
+	// ContentTypeToolResult represents a tool result.
+	ContentTypeToolResult ContentType = "tool-result"
+)
+
+// Content represents generated content from the model.
+type Content interface {
+	GetType() ContentType
+}
+
+// TextContent represents text that the model has generated.
+type TextContent struct {
+	// The text content.
+	Text             string           `json:"text"`
+	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
+}
+
+// GetType returns the type of the text content.
+func (t TextContent) GetType() ContentType {
+	return ContentTypeText
+}
+
+// ReasoningContent represents reasoning that the model has generated.
+type ReasoningContent struct {
+	Text             string           `json:"text"`
+	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
+}
+
+// GetType returns the type of the reasoning content.
+func (r ReasoningContent) GetType() ContentType {
+	return ContentTypeReasoning
+}
+
+// FileContent represents a file that has been generated by the model.
+// Generated files are returned as base64 encoded strings or binary data.
+// The files should be returned without any unnecessary conversion.
+type FileContent struct {
+	// The IANA media type of the file, e.g. `image/png` or `audio/mp3`.
+	// @see https://www.iana.org/assignments/media-types/media-types.xhtml
+	MediaType string `json:"media_type"`
+	// Generated file data as base64 encoded strings or binary data.
+	// If the API returns base64 encoded strings, the file data should be returned
+	// as base64 encoded strings. If the API returns binary data, the file data should
+	// be returned as binary data.
+	Data any `json:"data"` // string (base64) or []byte
+}
+
+// GetType returns the type of the file content.
+func (f FileContent) GetType() ContentType {
+	return ContentTypeFile
+}
+
+// SourceType represents the type of source.
+type SourceType string
+
+const (
+	// SourceTypeURL represents a URL source.
+	SourceTypeURL SourceType = "url"
+	// SourceTypeDocument represents a document source.
+	SourceTypeDocument SourceType = "document"
+)
+
+// SourceContent represents a source that has been used as input to generate the response.
+type SourceContent struct {
+	SourceType       SourceType       `json:"source_type"` // "url" or "document"
+	ID               string           `json:"id"`
+	URL              string           `json:"url"` // for URL sources
+	Title            string           `json:"title"`
+	MediaType        string           `json:"media_type"` // for document sources (IANA media type)
+	Filename         string           `json:"filename"`   // for document sources
+	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
+}
+
+// GetType returns the type of the source content.
+func (s SourceContent) GetType() ContentType {
+	return ContentTypeSource
+}
+
+// ToolCallContent represents tool calls that the model has generated.
+type ToolCallContent struct {
+	ToolCallID string `json:"tool_call_id"`
+	ToolName   string `json:"tool_name"`
+	// Stringified JSON object with the tool call arguments.
+	// Must match the parameters schema of the tool.
+	Input string `json:"input"`
+	// Whether the tool call will be executed by the provider.
+	// If this flag is not set or is false, the tool call will be executed by the client.
+	ProviderExecuted bool `json:"provider_executed"`
+	// Additional provider-specific metadata for the tool call.
+	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
+}
+
+// GetType returns the type of the tool call content.
+func (t ToolCallContent) GetType() ContentType {
+	return ContentTypeToolCall
+}
+
+// ToolResultContent represents result of a tool call that has been executed by the provider.
+type ToolResultContent struct {
+	// The ID of the tool call that this result is associated with.
+	ToolCallID string `json:"tool_call_id"`
+	// Name of the tool that generated this result.
+	ToolName string `json:"tool_name"`
+	// Result of the tool call. This is a JSON-serializable object.
+	Result any `json:"result"`
+	// Optional flag if the result is an error or an error message.
+	IsError bool `json:"is_error"`
+	// Whether the tool result was generated by the provider.
+	// If this flag is set to true, the tool result was generated by the provider.
+	// If this flag is not set or is false, the tool result was generated by the client.
+	ProviderExecuted bool `json:"provider_executed"`
+	// Additional provider-specific metadata for the tool result.
+	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
+}
+
+// GetType returns the type of the tool result content.
+func (t ToolResultContent) GetType() ContentType {
+	return ContentTypeToolResult
+}
+
+// ToolType represents the type of tool.
+type ToolType string
+
+const (
+	// ToolTypeFunction represents a function tool.
+	ToolTypeFunction ToolType = "function"
+	// ToolTypeProviderDefined represents a provider-defined tool.
+	ToolTypeProviderDefined ToolType = "provider-defined"
+)
+
+// Tool represents a tool that can be used by the model.
+//
+// Note: this is **not** the user-facing tool definition. The AI SDK methods will
+// map the user-facing tool definitions to this format.
+type Tool interface {
+	GetType() ToolType
+	GetName() string
+}
+
+// FunctionTool represents a function tool.
+//
+// A tool has a name, a description, and a set of parameters.
+type FunctionTool struct {
+	// Name of the tool. Unique within this model call.
+	Name string `json:"name"`
+	// Description of the tool. The language model uses this to understand the
+	// tool's purpose and to provide better completion suggestions.
+	Description string `json:"description"`
+	// InputSchema - the parameters that the tool expects. The language model uses this to
+	// understand the tool's input requirements and to provide matching suggestions.
+	InputSchema map[string]any `json:"input_schema"` // JSON Schema
+	// ProviderOptions are provider-specific options for the tool.
+	ProviderOptions ProviderOptions `json:"provider_options"`
+}
+
+// GetType returns the type of the function tool.
+func (f FunctionTool) GetType() ToolType {
+	return ToolTypeFunction
+}
+
+// GetName returns the name of the function tool.
+func (f FunctionTool) GetName() string {
+	return f.Name
+}
+
+// ProviderDefinedTool represents the configuration of a tool that is defined by the provider.
+type ProviderDefinedTool struct {
+	// ID of the tool. Should follow the format `<provider-name>.<unique-tool-name>`.
+	ID string `json:"id"`
+	// Name of the tool that the user must use in the tool set.
+	Name string `json:"name"`
+	// Args for configuring the tool. Must match the expected arguments defined by the provider for this tool.
+	Args map[string]any `json:"args"`
+}
+
+// GetType returns the type of the provider-defined tool.
+func (p ProviderDefinedTool) GetType() ToolType {
+	return ToolTypeProviderDefined
+}
+
+// GetName returns the name of the provider-defined tool.
+func (p ProviderDefinedTool) GetName() string {
+	return p.Name
+}
+
+// Helpers
+func NewUserMessage(prompt string, files ...FilePart) Message {
+	content := []MessagePart{
+		TextPart{
+			Text: prompt,
+		},
+	}
+
+	for _, f := range files {
+		content = append(content, f)
+	}
+
+	return Message{
+		Role:    MessageRoleUser,
+		Content: content,
+	}
+}
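
A short sketch of how the message part types compose (illustrative, not part of the diff): build a user message with an attached file via the NewUserMessage helper, then inspect its parts with AsContentType. The chart.png file is hypothetical.

package main

import (
	"fmt"
	"os"

	"github.com/charmbracelet/crush/internal/ai"
)

func main() {
	pngBytes, err := os.ReadFile("chart.png") // hypothetical local file
	if err != nil {
		fmt.Println(err)
		return
	}

	msg := ai.NewUserMessage(
		"What is in this image?",
		ai.FilePart{Filename: "chart.png", Data: pngBytes, MediaType: "image/png"},
	)

	for _, part := range msg.Content {
		switch part.GetType() {
		case ai.ContentTypeText:
			if text, ok := ai.AsContentType[ai.TextPart](part); ok {
				fmt.Println("text:", text.Text)
			}
		case ai.ContentTypeFile:
			if file, ok := ai.AsContentType[ai.FilePart](part); ok {
				fmt.Println("file:", file.Filename, file.MediaType)
			}
		}
	}
}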

errors.go

@@ -0,0 +1,295 @@
+package ai
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+)
+
+// markerSymbol is used for identifying AI SDK Error instances.
+var markerSymbol = "ai.error"
+
+// AIError is a custom error type for AI SDK related errors.
+type AIError struct {
+	Name    string
+	Message string
+	Cause   error
+	marker  string
+}
+
+// Error implements the error interface.
+func (e *AIError) Error() string {
+	return e.Message
+}
+
+// Unwrap returns the underlying cause of the error.
+func (e *AIError) Unwrap() error {
+	return e.Cause
+}
+
+// NewAIError creates a new AI SDK Error.
+func NewAIError(name, message string, cause error) *AIError {
+	return &AIError{
+		Name:    name,
+		Message: message,
+		Cause:   cause,
+		marker:  markerSymbol,
+	}
+}
+
+// IsAIError checks if the given error is an AI SDK Error.
+func IsAIError(err error) bool {
+	var sdkErr *AIError
+	return errors.As(err, &sdkErr) && sdkErr.marker == markerSymbol
+}
+
+// APICallError represents an error from an API call.
+type APICallError struct {
+	*AIError
+	URL               string
+	RequestBodyValues any
+	StatusCode        int
+	ResponseHeaders   map[string]string
+	ResponseBody      string
+	IsRetryable       bool
+	Data              any
+}
+
+// NewAPICallError creates a new API call error.
+func NewAPICallError(message, url string, requestBodyValues any, statusCode int, responseHeaders map[string]string, responseBody string, cause error, isRetryable bool, data any) *APICallError {
+	if !isRetryable && statusCode != 0 {
+		isRetryable = statusCode == 408 || statusCode == 409 || statusCode == 429 || statusCode >= 500
+	}
+
+	return &APICallError{
+		AIError:           NewAIError("AI_APICallError", message, cause),
+		URL:               url,
+		RequestBodyValues: requestBodyValues,
+		StatusCode:        statusCode,
+		ResponseHeaders:   responseHeaders,
+		ResponseBody:      responseBody,
+		IsRetryable:       isRetryable,
+		Data:              data,
+	}
+}
+
+// EmptyResponseBodyError represents an empty response body error.
+type EmptyResponseBodyError struct {
+	*AIError
+}
+
+// NewEmptyResponseBodyError creates a new empty response body error.
+func NewEmptyResponseBodyError(message string) *EmptyResponseBodyError {
+	if message == "" {
+		message = "Empty response body"
+	}
+	return &EmptyResponseBodyError{
+		AIError: NewAIError("AI_EmptyResponseBodyError", message, nil),
+	}
+}
+
+// InvalidArgumentError represents an invalid function argument error.
+type InvalidArgumentError struct {
+	*AIError
+	Argument string
+}
+
+// NewInvalidArgumentError creates a new invalid argument error.
+func NewInvalidArgumentError(argument, message string, cause error) *InvalidArgumentError {
+	return &InvalidArgumentError{
+		AIError:  NewAIError("AI_InvalidArgumentError", message, cause),
+		Argument: argument,
+	}
+}
+
+// InvalidPromptError represents an invalid prompt error.
+type InvalidPromptError struct {
+	*AIError
+	Prompt any
+}
+
+// NewInvalidPromptError creates a new invalid prompt error.
+func NewInvalidPromptError(prompt any, message string, cause error) *InvalidPromptError {
+	return &InvalidPromptError{
+		AIError: NewAIError("AI_InvalidPromptError", fmt.Sprintf("Invalid prompt: %s", message), cause),
+		Prompt:  prompt,
+	}
+}
+
+// InvalidResponseDataError represents invalid response data from the server.
+type InvalidResponseDataError struct {
+	*AIError
+	Data any
+}
+
+// NewInvalidResponseDataError creates a new invalid response data error.
+func NewInvalidResponseDataError(data any, message string) *InvalidResponseDataError {
+	if message == "" {
+		dataJSON, _ := json.Marshal(data)
+		message = fmt.Sprintf("Invalid response data: %s.", string(dataJSON))
+	}
+	return &InvalidResponseDataError{
+		AIError: NewAIError("AI_InvalidResponseDataError", message, nil),
+		Data:    data,
+	}
+}
+
+// JSONParseError represents a JSON parsing error.
+type JSONParseError struct {
+	*AIError
+	Text string
+}
+
+// NewJSONParseError creates a new JSON parse error.
+func NewJSONParseError(text string, cause error) *JSONParseError {
+	message := fmt.Sprintf("JSON parsing failed: Text: %s.\nError message: %s", text, GetErrorMessage(cause))
+	return &JSONParseError{
+		AIError: NewAIError("AI_JSONParseError", message, cause),
+		Text:    text,
+	}
+}
+
+// LoadAPIKeyError represents an error loading an API key.
+type LoadAPIKeyError struct {
+	*AIError
+}
+
+// NewLoadAPIKeyError creates a new load API key error.
+func NewLoadAPIKeyError(message string) *LoadAPIKeyError {
+	return &LoadAPIKeyError{
+		AIError: NewAIError("AI_LoadAPIKeyError", message, nil),
+	}
+}
+
+// LoadSettingError represents an error loading a setting.
+type LoadSettingError struct {
+	*AIError
+}
+
+// NewLoadSettingError creates a new load setting error.
+func NewLoadSettingError(message string) *LoadSettingError {
+	return &LoadSettingError{
+		AIError: NewAIError("AI_LoadSettingError", message, nil),
+	}
+}
+
+// NoContentGeneratedError is thrown when the AI provider fails to generate any content.
+type NoContentGeneratedError struct {
+	*AIError
+}
+
+// NewNoContentGeneratedError creates a new no content generated error.
+func NewNoContentGeneratedError(message string) *NoContentGeneratedError {
+	if message == "" {
+		message = "No content generated."
+	}
+	return &NoContentGeneratedError{
+		AIError: NewAIError("AI_NoContentGeneratedError", message, nil),
+	}
+}
+
+// ModelType represents the type of model.
+type ModelType string
+
+const (
+	ModelTypeLanguage      ModelType = "languageModel"
+	ModelTypeTextEmbedding ModelType = "textEmbeddingModel"
+	ModelTypeImage         ModelType = "imageModel"
+	ModelTypeTranscription ModelType = "transcriptionModel"
+	ModelTypeSpeech        ModelType = "speechModel"
+)
+
+// NoSuchModelError represents an error when a model is not found.
+type NoSuchModelError struct {
+	*AIError
+	ModelID   string
+	ModelType ModelType
+}
+
+// NewNoSuchModelError creates a new no such model error.
+func NewNoSuchModelError(modelID string, modelType ModelType, message string) *NoSuchModelError {
+	if message == "" {
+		message = fmt.Sprintf("No such %s: %s", modelType, modelID)
+	}
+	return &NoSuchModelError{
+		AIError:   NewAIError("AI_NoSuchModelError", message, nil),
+		ModelID:   modelID,
+		ModelType: modelType,
+	}
+}
+
+// TooManyEmbeddingValuesForCallError represents an error when too many values are provided for embedding.
+type TooManyEmbeddingValuesForCallError struct {
+	*AIError
+	Provider             string
+	ModelID              string
+	MaxEmbeddingsPerCall int
+	Values               []any
+}
+
+// NewTooManyEmbeddingValuesForCallError creates a new too many embedding values error.
+func NewTooManyEmbeddingValuesForCallError(provider, modelID string, maxEmbeddingsPerCall int, values []any) *TooManyEmbeddingValuesForCallError {
+	message := fmt.Sprintf(
+		"Too many values for a single embedding call. The %s model \"%s\" can only embed up to %d values per call, but %d values were provided.",
+		provider, modelID, maxEmbeddingsPerCall, len(values),
+	)
+	return &TooManyEmbeddingValuesForCallError{
+		AIError:              NewAIError("AI_TooManyEmbeddingValuesForCallError", message, nil),
+		Provider:             provider,
+		ModelID:              modelID,
+		MaxEmbeddingsPerCall: maxEmbeddingsPerCall,
+		Values:               values,
+	}
+}
+
+// TypeValidationError represents a type validation error.
+type TypeValidationError struct {
+	*AIError
+	Value any
+}
+
+// NewTypeValidationError creates a new type validation error.
+func NewTypeValidationError(value any, cause error) *TypeValidationError {
+	valueJSON, _ := json.Marshal(value)
+	message := fmt.Sprintf(
+		"Type validation failed: Value: %s.\nError message: %s",
+		string(valueJSON), GetErrorMessage(cause),
+	)
+	return &TypeValidationError{
+		AIError: NewAIError("AI_TypeValidationError", message, cause),
+		Value:   value,
+	}
+}
+
+// WrapTypeValidationError wraps an error into a TypeValidationError.
+func WrapTypeValidationError(value any, cause error) *TypeValidationError {
+	if tvErr, ok := cause.(*TypeValidationError); ok && tvErr.Value == value {
+		return tvErr
+	}
+	return NewTypeValidationError(value, cause)
+}
+
+// UnsupportedFunctionalityError represents an unsupported functionality error.
+type UnsupportedFunctionalityError struct {
+	*AIError
+	Functionality string
+}
+
+// NewUnsupportedFunctionalityError creates a new unsupported functionality error.
+func NewUnsupportedFunctionalityError(functionality, message string) *UnsupportedFunctionalityError {
+	if message == "" {
+		message = fmt.Sprintf("'%s' functionality not supported.", functionality)
+	}
+	return &UnsupportedFunctionalityError{
+		AIError:       NewAIError("AI_UnsupportedFunctionalityError", message, nil),
+		Functionality: functionality,
+	}
+}
+
+// GetErrorMessage extracts a message from an error.
+func GetErrorMessage(err error) string {
+	if err == nil {
+		return "unknown error"
+	}
+	return err.Error()
+}
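
An illustrative sketch of the error helpers (not part of the diff): plain NewAIError values are detected by IsAIError, and the richer types work with the standard errors package. NewAPICallError marks 408/409/429 and 5xx responses as retryable even when isRetryable is passed as false. The error name and URL below are only examples.

package main

import (
	"errors"
	"fmt"

	"github.com/charmbracelet/crush/internal/ai"
)

func main() {
	base := ai.NewAIError("AI_ExampleError", "something went wrong", nil)
	fmt.Println(ai.IsAIError(base)) // true

	var err error = ai.NewAPICallError(
		"rate limited", // message
		"https://api.openai.com/v1/chat/completions", // url
		nil,   // requestBodyValues
		429,   // statusCode
		nil,   // responseHeaders
		"",    // responseBody
		nil,   // cause
		false, // isRetryable; overridden by the 429 status code
		nil,   // data
	)

	var apiErr *ai.APICallError
	if errors.As(err, &apiErr) {
		fmt.Println(apiErr.StatusCode, apiErr.IsRetryable) // 429 true
	}
}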

model.go

@@ -0,0 +1,148 @@
+package ai
+
+import (
+	"context"
+	"fmt"
+	"iter"
+)
+
+type Usage struct {
+	InputTokens         int64 `json:"input_tokens"`
+	OutputTokens        int64 `json:"output_tokens"`
+	TotalTokens         int64 `json:"total_tokens"`
+	ReasoningTokens     int64 `json:"reasoning_tokens"`
+	CacheCreationTokens int64 `json:"cache_creation_tokens"`
+	CacheReadTokens     int64 `json:"cache_read_tokens"`
+}
+
+func (u Usage) String() string {
+	return fmt.Sprintf("Usage{Input: %d, Output: %d, Total: %d, Reasoning: %d, CacheCreation: %d, CacheRead: %d}",
+		u.InputTokens,
+		u.OutputTokens,
+		u.TotalTokens,
+		u.ReasoningTokens,
+		u.CacheCreationTokens,
+		u.CacheReadTokens,
+	)
+}
+
+type ResponseContent []Content
+
+func (r ResponseContent) Text() string {
+	for _, c := range r {
+		if c.GetType() == ContentTypeText {
+			return c.(TextContent).Text
+		}
+	}
+	return ""
+}
+
+type Response struct {
+	Content      ResponseContent `json:"content"`
+	FinishReason FinishReason    `json:"finish_reason"`
+	Usage        Usage           `json:"usage"`
+	Warnings     []CallWarning   `json:"warnings"`
+
+	// for provider specific response metadata, the key is the provider id
+	ProviderMetadata map[string]map[string]any `json:"provider_metadata"`
+}
+
+type StreamPartType string
+
+const (
+	StreamPartTypeWarnings  StreamPartType = "warnings"
+	StreamPartTypeTextStart StreamPartType = "text_start"
+	StreamPartTypeTextDelta StreamPartType = "text_delta"
+	StreamPartTypeTextEnd   StreamPartType = "text_end"
+
+	StreamPartTypeReasoningStart StreamPartType = "reasoning_start"
+	StreamPartTypeReasoningDelta StreamPartType = "reasoning_delta"
+	StreamPartTypeReasoningEnd   StreamPartType = "reasoning_end"
+	StreamPartTypeToolInputStart StreamPartType = "tool_input_start"
+	StreamPartTypeToolInputDelta StreamPartType = "tool_input_delta"
+	StreamPartTypeToolInputEnd   StreamPartType = "tool_input_end"
+	StreamPartTypeToolCall       StreamPartType = "tool_call"
+	StreamPartTypeToolResult     StreamPartType = "tool_result"
+	StreamPartTypeSource         StreamPartType = "source"
+	StreamPartTypeFinish         StreamPartType = "finish"
+	StreamPartTypeError          StreamPartType = "error"
+)
+
+type StreamPart struct {
+	Type             StreamPartType `json:"type"`
+	ID               string         `json:"id"`
+	ToolCallName     string         `json:"tool_call_name"`
+	ToolCallInput    string         `json:"tool_call_input"`
+	Delta            string         `json:"delta"`
+	ProviderExecuted bool           `json:"provider_executed"`
+	Usage            Usage          `json:"usage"`
+	FinishReason     FinishReason   `json:"finish_reason"`
+	Error            error          `json:"error"`
+	Warnings         []CallWarning  `json:"warnings"`
+
+	// Source-related fields
+	SourceType SourceType `json:"source_type"`
+	URL        string     `json:"url"`
+	Title      string     `json:"title"`
+
+	ProviderMetadata ProviderOptions `json:"provider_metadata"`
+}
+type StreamResponse = iter.Seq[StreamPart]
+
+type ToolChoice string
+
+const (
+	ToolChoiceNone ToolChoice = "none"
+	ToolChoiceAuto ToolChoice = "auto"
+)
+
+func SpecificToolChoice(name string) ToolChoice {
+	return ToolChoice(name)
+}
+
+type Call struct {
+	Prompt           Prompt            `json:"prompt"`
+	MaxOutputTokens  *int64            `json:"max_output_tokens"`
+	Temperature      *float64          `json:"temperature"`
+	TopP             *float64          `json:"top_p"`
+	TopK             *int64            `json:"top_k"`
+	PresencePenalty  *float64          `json:"presence_penalty"`
+	FrequencyPenalty *float64          `json:"frequency_penalty"`
+	Tools            []Tool            `json:"tools"`
+	ToolChoice       *ToolChoice       `json:"tool_choice"`
+	Headers          map[string]string `json:"headers"`
+
+	// for provider specific options, the key is the provider id
+	ProviderOptions ProviderOptions `json:"provider_options"`
+}
+
+// CallWarningType represents the type of call warning.
+type CallWarningType string
+
+const (
+	// CallWarningTypeUnsupportedSetting indicates an unsupported setting.
+	CallWarningTypeUnsupportedSetting CallWarningType = "unsupported-setting"
+	// CallWarningTypeUnsupportedTool indicates an unsupported tool.
+	CallWarningTypeUnsupportedTool CallWarningType = "unsupported-tool"
+	// CallWarningTypeOther indicates other warnings.
+	CallWarningTypeOther CallWarningType = "other"
+)
+
+// CallWarning represents a warning from the model provider for this call.
+// The call will proceed, but e.g. some settings might not be supported,
+// which can lead to suboptimal results.
+type CallWarning struct {
+	Type    CallWarningType `json:"type"`
+	Setting string          `json:"setting"`
+	Tool    Tool            `json:"tool"`
+	Details string          `json:"details"`
+	Message string          `json:"message"`
+}
+
+type LanguageModel interface {
+	Generate(context.Context, Call) (*Response, error)
+	Stream(context.Context, Call) (StreamResponse, error)
+
+	Provider() string
+	Model() string
+}
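
A sketch of consuming a StreamResponse, which is an iter.Seq[StreamPart]: a plain range loop drives the stream, text deltas are accumulated, and usage arrives on the finish part. Provider and model setup follow providers/examples/stream/main.go and are illustrative only.

package main

import (
	"context"
	"fmt"
	"strings"

	"github.com/charmbracelet/crush/internal/ai"
	"github.com/charmbracelet/crush/internal/ai/providers"
)

func main() {
	model := providers.NewOpenAIProvider(providers.WithOpenAIApiKey("$OPENAI_API_KEY")).LanguageModel("gpt-4o")

	stream, err := model.Stream(context.Background(), ai.Call{
		Prompt: ai.Prompt{ai.NewUserMessage("Write a haiku about terminals.")},
	})
	if err != nil {
		fmt.Println(err)
		return
	}

	var text strings.Builder
	for part := range stream { // ranging over the iter.Seq drives the stream
		switch part.Type {
		case ai.StreamPartTypeTextDelta:
			text.WriteString(part.Delta)
		case ai.StreamPartTypeError:
			fmt.Println("stream error:", part.Error)
		case ai.StreamPartTypeFinish:
			fmt.Println(text.String())
			fmt.Println(part.Usage, part.FinishReason)
		}
	}
}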

provider.go

@@ -0,0 +1,26 @@
+package ai
+
+import (
+	"encoding/json"
+
+	"github.com/go-viper/mapstructure/v2"
+)
+
+type Provider interface {
+	LanguageModel(modelID string) LanguageModel
+	// TODO: add other model types when needed
+}
+
+func ParseOptions[T any](options map[string]any, m *T) error {
+	return mapstructure.Decode(options, m)
+}
+
+func FloatOption(f float64) *float64 {
+	return &f
+}
+
+func IsParsableJSON(data string) bool {
+	var m map[string]any
+	err := json.Unmarshal([]byte(data), &m)
+	return err == nil
+}
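
ParseOptions is how a provider decodes its entry in Call.ProviderOptions into a typed struct (via mapstructure). A small illustrative sketch; the option struct and keys below are hypothetical stand-ins, not the OpenAI provider's actual fields.

package main

import (
	"fmt"

	"github.com/charmbracelet/crush/internal/ai"
)

// exampleOptions is a hypothetical stand-in for a provider's option struct.
type exampleOptions struct {
	User            string `mapstructure:"user"`
	ReasoningEffort string `mapstructure:"reasoning_effort"`
}

func main() {
	call := ai.Call{
		ProviderOptions: ai.ProviderOptions{
			"openai": {
				"user":             "crush",
				"reasoning_effort": "low",
			},
		},
	}

	var opts exampleOptions
	if err := ai.ParseOptions(call.ProviderOptions["openai"], &opts); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Printf("%+v\n", opts) // {User:crush ReasoningEffort:low}
}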

providers/examples/simple/main.go

@@ -0,0 +1,28 @@
+package main
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/charmbracelet/crush/internal/ai"
+	"github.com/charmbracelet/crush/internal/ai/providers"
+)
+
+func main() {
+	provider := providers.NewOpenAIProvider(providers.WithOpenAIApiKey("$OPENAI_API_KEY"))
+	model := provider.LanguageModel("gpt-4o")
+
+	response, err := model.Generate(context.Background(), ai.Call{
+		Prompt: ai.Prompt{
+			ai.NewUserMessage("Hello"),
+		},
+		Temperature: ai.FloatOption(0.7),
+	})
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+
+	fmt.Println("Assistant: ", response.Content.Text())
+	fmt.Println("Usage:", response.Usage)
+}

providers/examples/stream/main.go

@@ -0,0 +1,49 @@
+package main
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+
+	"github.com/charmbracelet/crush/internal/ai"
+	"github.com/charmbracelet/crush/internal/ai/providers"
+)
+
+func main() {
+	provider := providers.NewOpenAIProvider(providers.WithOpenAIApiKey("$OPENAI_API_KEY"))
+	model := provider.LanguageModel("gpt-4o")
+
+	stream, err := model.Stream(context.Background(), ai.Call{
+		Prompt: ai.Prompt{
+			ai.NewUserMessage("What's the weather in Pristina?"),
+		},
+		Temperature: ai.FloatOption(0.7),
+		Tools: []ai.Tool{
+			ai.FunctionTool{
+				Name:        "weather",
+				Description: "Gets the weather for a location",
+				InputSchema: map[string]any{
+					"type": "object",
+					"properties": map[string]any{
+						"location": map[string]string{
+							"type":        "string",
+							"description": "the city",
+						},
+					},
+					"required": []string{
+						"location",
+					},
+				},
+			},
+		},
+	})
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+
+	for chunk := range stream {
+		data, _ := json.Marshal(chunk)
+		fmt.Println(string(data))
+	}
+}

providers/openai.go

@@ -0,0 +1,1161 @@
+package providers
+
+import (
+	"context"
+	"encoding/base64"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"maps"
+	"strings"
+
+	"github.com/charmbracelet/crush/internal/ai"
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/env"
+	"github.com/google/uuid"
+	"github.com/openai/openai-go/v2"
+	"github.com/openai/openai-go/v2/option"
+	"github.com/openai/openai-go/v2/packages/param"
+	"github.com/openai/openai-go/v2/shared"
+)
+
+type ReasoningEffort string
+
+const (
+	ReasoningEffortMinimal ReasoningEffort = "minimal"
+	ReasoningEffortLow     ReasoningEffort = "low"
+	ReasoningEffortMedium  ReasoningEffort = "medium"
+	ReasoningEffortHigh    ReasoningEffort = "high"
+)
+
+type OpenAIProviderOptions struct {
+	LogitBias           map[string]int64 `json:"logit_bias"`
+	LogProbs            *bool            `json:"log_probs"`
+	TopLogProbs         *int64           `json:"top_log_probs"`
+	ParallelToolCalls   *bool            `json:"parallel_tool_calls"`
+	User                *string          `json:"user"`
+	ReasoningEffort     *ReasoningEffort `json:"reasoning_effort"`
+	MaxCompletionTokens *int64           `json:"max_completion_tokens"`
+	TextVerbosity       *string          `json:"text_verbosity"`
+	Prediction          map[string]any   `json:"prediction"`
+	Store               *bool            `json:"store"`
+	Metadata            map[string]any   `json:"metadata"`
+	PromptCacheKey      *string          `json:"prompt_cache_key"`
+	SafetyIdentifier    *string          `json:"safety_identifier"`
+	ServiceTier         *string          `json:"service_tier"`
+	StructuredOutputs   *bool            `json:"structured_outputs"`
+}
+
+type openAIProvider struct {
+	options openAIProviderOptions
+}
+
+type openAIProviderOptions struct {
+	baseURL      string
+	apiKey       string
+	organization string
+	project      string
+	name         string
+	headers      map[string]string
+	client       option.HTTPClient
+	resolver     config.VariableResolver
+}
+
+type OpenAIOption = func(*openAIProviderOptions)
+
+func NewOpenAIProvider(opts ...OpenAIOption) ai.Provider {
+	options := openAIProviderOptions{
+		headers: map[string]string{},
+	}
+	for _, o := range opts {
+		o(&options)
+	}
+
+	if options.resolver == nil {
+		// use the default resolver
+		options.resolver = config.NewShellVariableResolver(env.New())
+	}
+	options.apiKey, _ = options.resolver.ResolveValue(options.apiKey)
+	options.baseURL, _ = options.resolver.ResolveValue(options.baseURL)
+	if options.baseURL == "" {
+		options.baseURL = "https://api.openai.com/v1"
+	}
+
+	options.name, _ = options.resolver.ResolveValue(options.name)
+	if options.name == "" {
+		options.name = "openai"
+	}
+
+	for k, v := range options.headers {
+		options.headers[k], _ = options.resolver.ResolveValue(v)
+	}
+
+	options.organization, _ = options.resolver.ResolveValue(options.organization)
+	if options.organization != "" {
+		options.headers["OpenAI-Organization"] = options.organization
+	}
+
+	options.project, _ = options.resolver.ResolveValue(options.project)
+	if options.project != "" {
+		options.headers["OpenAI-Project"] = options.project
+	}
+
+	return &openAIProvider{
+		options: options,
+	}
+}
+
+func WithOpenAIBaseURL(baseURL string) OpenAIOption {
+	return func(o *openAIProviderOptions) {
+		o.baseURL = baseURL
+	}
+}
+
+func WithOpenAIApiKey(apiKey string) OpenAIOption {
+	return func(o *openAIProviderOptions) {
+		o.apiKey = apiKey
+	}
+}
+
+func WithOpenAIOrganization(organization string) OpenAIOption {
+	return func(o *openAIProviderOptions) {
+		o.organization = organization
+	}
+}
+
+func WithOpenAIProject(project string) OpenAIOption {
+	return func(o *openAIProviderOptions) {
+		o.project = project
+	}
+}
+
+func WithOpenAIName(name string) OpenAIOption {
+	return func(o *openAIProviderOptions) {
+		o.name = name
+	}
+}
+
+func WithOpenAIHeaders(headers map[string]string) OpenAIOption {
+	return func(o *openAIProviderOptions) {
+		maps.Copy(o.headers, headers)
+	}
+}
+
+func WithOpenAIHttpClient(client option.HTTPClient) OpenAIOption {
+	return func(o *openAIProviderOptions) {
+		o.client = client
+	}
+}
+
+func WithOpenAIVariableResolver(resolver config.VariableResolver) OpenAIOption {
+	return func(o *openAIProviderOptions) {
+		o.resolver = resolver
+	}
+}
+
+// LanguageModel implements ai.Provider.
+func (o *openAIProvider) LanguageModel(modelID string) ai.LanguageModel {
+	openaiClientOptions := []option.RequestOption{}
+	if o.options.apiKey != "" {
+		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
+	}
+	if o.options.baseURL != "" {
+		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
+	}
+
+	for key, value := range o.options.headers {
+		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
+	}
+
+	if o.options.client != nil {
+		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
+	}
+
+	return openAILanguageModel{
+		modelID:         modelID,
+		provider:        fmt.Sprintf("%s.chat", o.options.name),
+		providerOptions: o.options,
+		client:          openai.NewClient(openaiClientOptions...),
+	}
+}
+
+type openAILanguageModel struct {
+	provider        string
+	modelID         string
+	client          openai.Client
+	providerOptions openAIProviderOptions
+}
+
+// Model implements ai.LanguageModel.
+func (o openAILanguageModel) Model() string {
+	return o.modelID
+}
+
+// Provider implements ai.LanguageModel.
+func (o openAILanguageModel) Provider() string {
+	return o.provider
+}
+
+func (o openAILanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
+	params := &openai.ChatCompletionNewParams{}
+	messages, warnings := toOpenAIPrompt(call.Prompt)
+	providerOptions := &OpenAIProviderOptions{}
+	if v, ok := call.ProviderOptions["openai"]; ok {
+		err := ai.ParseOptions(v, providerOptions)
+		if err != nil {
+			return nil, nil, err
+		}
+	}
+	if call.TopK != nil {
+		warnings = append(warnings, ai.CallWarning{
+			Type:    ai.CallWarningTypeUnsupportedSetting,
+			Setting: "top_k",
+		})
+	}
+	params.Messages = messages
+	params.Model = o.modelID
+	if providerOptions.LogitBias != nil {
+		params.LogitBias = providerOptions.LogitBias
+	}
+	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
+		providerOptions.LogProbs = nil
+	}
+	if providerOptions.LogProbs != nil {
+		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
+	}
+	if providerOptions.TopLogProbs != nil {
+		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
+	}
+	if providerOptions.User != nil {
+		params.User = param.NewOpt(*providerOptions.User)
+	}
+	if providerOptions.ParallelToolCalls != nil {
+		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
+	}
+
+	if call.MaxOutputTokens != nil {
+		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
+	}
+	if call.Temperature != nil {
+		params.Temperature = param.NewOpt(*call.Temperature)
+	}
+	if call.TopP != nil {
+		params.TopP = param.NewOpt(*call.TopP)
+	}
+	if call.FrequencyPenalty != nil {
+		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
+	}
+	if call.PresencePenalty != nil {
+		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
+	}
+
+	if providerOptions.MaxCompletionTokens != nil {
+		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
+	}
+
+	if providerOptions.TextVerbosity != nil {
+		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
+	}
+	if providerOptions.Prediction != nil {
+		// Convert map[string]any to ChatCompletionPredictionContentParam
+		if content, ok := providerOptions.Prediction["content"]; ok {
+			if contentStr, ok := content.(string); ok {
+				params.Prediction = openai.ChatCompletionPredictionContentParam{
+					Content: openai.ChatCompletionPredictionContentContentUnionParam{
+						OfString: param.NewOpt(contentStr),
+					},
+				}
+			}
+		}
+	}
+	if providerOptions.Store != nil {
+		params.Store = param.NewOpt(*providerOptions.Store)
+	}
+	if providerOptions.Metadata != nil {
+		// Convert map[string]any to map[string]string
+		metadata := make(map[string]string)
+		for k, v := range providerOptions.Metadata {
+			if str, ok := v.(string); ok {
+				metadata[k] = str
+			}
+		}
+		params.Metadata = metadata
+	}
+	if providerOptions.PromptCacheKey != nil {
+		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
+	}
+	if providerOptions.SafetyIdentifier != nil {
+		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
+	}
+	if providerOptions.ServiceTier != nil {
+		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
+	}
+
+	if providerOptions.ReasoningEffort != nil {
+		switch *providerOptions.ReasoningEffort {
+		case ReasoningEffortMinimal:
+			params.ReasoningEffort = shared.ReasoningEffortMinimal
+		case ReasoningEffortLow:
+			params.ReasoningEffort = shared.ReasoningEffortLow
+		case ReasoningEffortMedium:
+			params.ReasoningEffort = shared.ReasoningEffortMedium
+		case ReasoningEffortHigh:
+			params.ReasoningEffort = shared.ReasoningEffortHigh
+		default:
+			return nil, nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
+		}
+	}
+
+	if isReasoningModel(o.modelID) {
+		// remove unsupported settings for reasoning models
+		// see https://platform.openai.com/docs/guides/reasoning#limitations
+		if call.Temperature != nil {
+			params.Temperature = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "temperature",
+				Details: "temperature is not supported for reasoning models",
+			})
+		}
+		if call.TopP != nil {
+			params.TopP = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "top_p",
+				Details: "topP is not supported for reasoning models",
+			})
+		}
+		if call.FrequencyPenalty != nil {
+			params.FrequencyPenalty = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "frequency_penalty",
+				Details: "frequencyPenalty is not supported for reasoning models",
+			})
+		}
+		if call.PresencePenalty != nil {
+			params.PresencePenalty = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "presence_penalty",
+				Details: "presencePenalty is not supported for reasoning models",
+			})
+		}
+		if providerOptions.LogitBias != nil {
+			params.LogitBias = nil
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeOther,
+				Message: "logitBias is not supported for reasoning models",
+			})
+		}
+		if providerOptions.LogProbs != nil {
+			params.Logprobs = param.Opt[bool]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeOther,
+				Message: "logprobs is not supported for reasoning models",
+			})
+		}
+		if providerOptions.TopLogProbs != nil {
+			params.TopLogprobs = param.Opt[int64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeOther,
+				Message: "topLogprobs is not supported for reasoning models",
+			})
+		}
+
+		// reasoning models use max_completion_tokens instead of max_tokens
+		if call.MaxOutputTokens != nil {
+			if providerOptions.MaxCompletionTokens == nil {
+				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
+			}
+			params.MaxTokens = param.Opt[int64]{}
+		}
+	}
+
+	// Handle search preview models
+	if isSearchPreviewModel(o.modelID) {
+		if call.Temperature != nil {
+			params.Temperature = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "temperature",
+				Details: "temperature is not supported for the search preview models and has been removed.",
+			})
+		}
+	}
+
+	// Handle service tier validation
+	if providerOptions.ServiceTier != nil {
+		serviceTier := *providerOptions.ServiceTier
+		if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
+			params.ServiceTier = ""
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "serviceTier",
+				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
+			})
+		} else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
+			params.ServiceTier = ""
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "serviceTier",
+				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
+			})
+		}
+	}
+
+	if len(call.Tools) > 0 {
+		tools, toolChoice, toolWarnings := toOpenAITools(call.Tools, call.ToolChoice)
+		params.Tools = tools
+		if toolChoice != nil {
+			params.ToolChoice = *toolChoice
+		}
+		warnings = append(warnings, toolWarnings...)
+	}
+	return params, warnings, nil
+}
+
+// Generate implements ai.LanguageModel.
+func (o openAILanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
+	params, warnings, err := o.prepareParams(call)
+	if err != nil {
+		return nil, err
+	}
+	response, err := o.client.Chat.Completions.New(ctx, *params)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(response.Choices) == 0 {
+		return nil, errors.New("no response generated")
+	}
+	choice := response.Choices[0]
+	var content []ai.Content
+	text := choice.Message.Content
+	if text != "" {
+		content = append(content, ai.TextContent{
+			Text: text,
+		})
+	}
+
+	for _, tc := range choice.Message.ToolCalls {
+		toolCallID := tc.ID
+		if toolCallID == "" {
+			toolCallID = uuid.NewString()
+		}
+		content = append(content, ai.ToolCallContent{
+			ProviderExecuted: false, // TODO: update when handling other tools
+			ToolCallID:       toolCallID,
+			ToolName:         tc.Function.Name,
+			Input:            tc.Function.Arguments,
+		})
+	}
+	// Handle annotations/citations
+	for _, annotation := range choice.Message.Annotations {
+		if annotation.Type == "url_citation" {
+			content = append(content, ai.SourceContent{
+				SourceType: ai.SourceTypeURL,
+				ID:         uuid.NewString(),
+				URL:        annotation.URLCitation.URL,
+				Title:      annotation.URLCitation.Title,
+			})
+		}
+	}
+
+	completionTokenDetails := response.Usage.CompletionTokensDetails
+	promptTokenDetails := response.Usage.PromptTokensDetails
+
+	// Build provider metadata
+	providerMetadata := ai.ProviderMetadata{
+		"openai": make(map[string]any),
+	}
+
+	// Add logprobs if available
+	if len(choice.Logprobs.Content) > 0 {
+		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
+	}
+
+	// Add prediction tokens if available
+	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
+		if completionTokenDetails.AcceptedPredictionTokens > 0 {
+			providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
+		}
+		if completionTokenDetails.RejectedPredictionTokens > 0 {
+			providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
+		}
+	}
+
+	return &ai.Response{
+		Content: content,
+		Usage: ai.Usage{
+			InputTokens:     response.Usage.PromptTokens,
+			OutputTokens:    response.Usage.CompletionTokens,
+			TotalTokens:     response.Usage.TotalTokens,
+			ReasoningTokens: completionTokenDetails.ReasoningTokens,
+			CacheReadTokens: promptTokenDetails.CachedTokens,
+		},
+		FinishReason:     mapOpenAIFinishReason(choice.FinishReason),
+		ProviderMetadata: providerMetadata,
+		Warnings:         warnings,
+	}, nil
+}
+
+type toolCall struct {
+	id          string
+	name        string
+	arguments   string
+	hasFinished bool
+}
+
+// Stream implements ai.LanguageModel.
+func (o openAILanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
+	params, warnings, err := o.prepareParams(call)
+	if err != nil {
+		return nil, err
+	}
+
+	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
+		IncludeUsage: openai.Bool(true),
+	}
+
+	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
+	isActiveText := false
+	toolCalls := make(map[int64]toolCall)
+
+	// Build provider metadata for streaming
+	streamProviderMetadata := ai.ProviderOptions{
+		"openai": make(map[string]any),
+	}
+
+	acc := openai.ChatCompletionAccumulator{}
+	var usage ai.Usage
+	return func(yield func(ai.StreamPart) bool) {
+		if len(warnings) > 0 {
+			if !yield(ai.StreamPart{
+				Type:     ai.StreamPartTypeWarnings,
+				Warnings: warnings,
+			}) {
+				return
+			}
+		}
+		for stream.Next() {
+			chunk := stream.Current()
+			acc.AddChunk(chunk)
+			if chunk.Usage.TotalTokens > 0 {
+				// we do this here because the acc does not add prompt details
+				completionTokenDetails := chunk.Usage.CompletionTokensDetails
+				promptTokenDetails := chunk.Usage.PromptTokensDetails
+				usage = ai.Usage{
+					InputTokens:     chunk.Usage.PromptTokens,
+					OutputTokens:    chunk.Usage.CompletionTokens,
+					TotalTokens:     chunk.Usage.TotalTokens,
+					ReasoningTokens: completionTokenDetails.ReasoningTokens,
+					CacheReadTokens: promptTokenDetails.CachedTokens,
+				}
+
+				// Add prediction tokens if available
+				if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
+					if completionTokenDetails.AcceptedPredictionTokens > 0 {
+						streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
+					}
+					if completionTokenDetails.RejectedPredictionTokens > 0 {
+						streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
+					}
+				}
+			}
+			if len(chunk.Choices) == 0 {
+				continue
+			}
+			for _, choice := range chunk.Choices {
+				switch {
+				case choice.Delta.Content != "":
+					if !isActiveText {
+						isActiveText = true
+						if !yield(ai.StreamPart{
+							Type: ai.StreamPartTypeTextStart,
+							ID:   "0",
+						}) {
+							return
+						}
+					}
+					if !yield(ai.StreamPart{
+						Type:  ai.StreamPartTypeTextDelta,
+						ID:    "0",
+						Delta: choice.Delta.Content,
+					}) {
+						return
+					}
+				case len(choice.Delta.ToolCalls) > 0:
+					if isActiveText {
+						isActiveText = false
+						if !yield(ai.StreamPart{
+							Type: ai.StreamPartTypeTextEnd,
+							ID:   "0",
+						}) {
+							return
+						}
+					}
+
+					for _, toolCallDelta := range choice.Delta.ToolCalls {
+						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
+							if existingToolCall.hasFinished {
+								continue
+							}
+							if toolCallDelta.Function.Arguments != "" {
+								existingToolCall.arguments += toolCallDelta.Function.Arguments
+							}
+							if !yield(ai.StreamPart{
+								Type:  ai.StreamPartTypeToolInputDelta,
+								ID:    existingToolCall.id,
+								Delta: toolCallDelta.Function.Arguments,
+							}) {
+								return
+							}
+							toolCalls[toolCallDelta.Index] = existingToolCall
+							if existingToolCall.arguments != "" && ai.IsParsableJSON(existingToolCall.arguments) {
+								if !yield(ai.StreamPart{
+									Type: ai.StreamPartTypeToolInputEnd,
+									ID:   existingToolCall.id,
+								}) {
+									return
+								}
+
+								if !yield(ai.StreamPart{
+									Type:          ai.StreamPartTypeToolCall,
+									ID:            existingToolCall.id,
+									ToolCallName:  existingToolCall.name,
+									ToolCallInput: existingToolCall.arguments,
+								}) {
+									return
+								}
+								existingToolCall.hasFinished = true
+								toolCalls[toolCallDelta.Index] = existingToolCall
+							}
+
+						} else {
+							// Does not exist
+							var err error
+							if toolCallDelta.Type != "function" {
+								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
+							}
+							if toolCallDelta.ID == "" {
+								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
+							}
+							if toolCallDelta.Function.Name == "" {
+								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
+							}
+							if err != nil {
+								yield(ai.StreamPart{
+									Type:  ai.StreamPartTypeError,
+									Error: err,
+								})
+								return
+							}
+
+							if !yield(ai.StreamPart{
+								Type:         ai.StreamPartTypeToolInputStart,
+								ID:           toolCallDelta.ID,
+								ToolCallName: toolCallDelta.Function.Name,
+							}) {
+								return
+							}
+							toolCalls[toolCallDelta.Index] = toolCall{
+								id:        toolCallDelta.ID,
+								name:      toolCallDelta.Function.Name,
+								arguments: toolCallDelta.Function.Arguments,
+							}
+
+							exTc := toolCalls[toolCallDelta.Index]
+							if exTc.arguments != "" {
+								if !yield(ai.StreamPart{
+									Type:  ai.StreamPartTypeToolInputDelta,
+									ID:    exTc.id,
+									Delta: exTc.arguments,
+								}) {
+									return
+								}
+								if ai.IsParsableJSON(toolCalls[toolCallDelta.Index].arguments) {
+									if !yield(ai.StreamPart{
+										Type: ai.StreamPartTypeToolInputEnd,
+										ID:   toolCallDelta.ID,
+									}) {
+										return
+									}
+
+									if !yield(ai.StreamPart{
+										Type:          ai.StreamPartTypeToolCall,
+										ID:            exTc.id,
+										ToolCallName:  exTc.name,
+										ToolCallInput: exTc.arguments,
+									}) {
+										return
+									}
+									exTc.hasFinished = true
+									toolCalls[toolCallDelta.Index] = exTc
+								}
+							}
+							continue
+						}
+					}
+				}
+			}
+
+			// Check for annotations in the delta's raw JSON
+			for _, choice := range chunk.Choices {
+				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
+					for _, annotation := range annotations {
+						if annotation.Type == "url_citation" {
+							if !yield(ai.StreamPart{
+								Type:       ai.StreamPartTypeSource,
+								ID:         uuid.NewString(),
+								SourceType: ai.SourceTypeURL,
+								URL:        annotation.URLCitation.URL,
+								Title:      annotation.URLCitation.Title,
+							}) {
+								return
+							}
+						}
+					}
+				}
+			}
+
+		}
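+		// The stream is drained: a nil error or io.EOF means normal
+		// termination; anything else is surfaced as an error part.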
+		err := stream.Err()
+		if err == nil || errors.Is(err, io.EOF) {
+			// finished
+			if isActiveText {
+				isActiveText = false
+				if !yield(ai.StreamPart{
+					Type: ai.StreamPartTypeTextEnd,
+					ID:   "0",
+				}) {
+					return
+				}
+			}
+
+			// Add logprobs if available
+			if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
+				streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
+			}
+
+			// Handle annotations/citations from accumulated response
+			if len(acc.Choices) > 0 {
+				for _, annotation := range acc.Choices[0].Message.Annotations {
+					if annotation.Type == "url_citation" {
+						if !yield(ai.StreamPart{
+							Type:       ai.StreamPartTypeSource,
+							ID:         uuid.NewString(),
+							SourceType: ai.SourceTypeURL,
+							URL:        annotation.URLCitation.URL,
+							Title:      annotation.URLCitation.Title,
+						}) {
+							return
+						}
+					}
+				}
+			}
+
+			finishReason := ai.FinishReasonUnknown
+			if len(acc.Choices) > 0 {
+				finishReason = mapOpenAIFinishReason(acc.Choices[0].FinishReason)
+			}
+			yield(ai.StreamPart{
+				Type:             ai.StreamPartTypeFinish,
+				Usage:            usage,
+				FinishReason:     finishReason,
+				ProviderMetadata: streamProviderMetadata,
+			})
+			return
+
+		} else {
+			yield(ai.StreamPart{
+				Type:  ai.StreamPartTypeError,
+				Error: stream.Err(),
+			})
+			return
+		}
+	}, nil
+}
+
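+// mapOpenAIFinishReason translates an OpenAI finish reason into the SDK's
+// finish reason.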
+func mapOpenAIFinishReason(finishReason string) ai.FinishReason {
+	switch finishReason {
+	case "stop":
+		return ai.FinishReasonStop
+	case "length":
+		return ai.FinishReasonLength
+	case "content_filter":
+		return ai.FinishReasonContentFilter
+	case "function_call", "tool_calls":
+		return ai.FinishReasonToolCalls
+	default:
+		return ai.FinishReasonUnknown
+	}
+}
+
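+// isReasoningModel reports whether the model belongs to OpenAI's reasoning
+// family (o-series and gpt-5, excluding gpt-5-chat), which rejects sampling
+// parameters such as temperature and top_p.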
+func isReasoningModel(modelID string) bool {
+	return strings.HasPrefix(modelID, "o") || (strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-chat"))
+}
+
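+// isSearchPreviewModel reports whether the model is a chat-completions
+// search-preview variant.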
+func isSearchPreviewModel(modelID string) bool {
+	return strings.Contains(modelID, "search-preview")
+}
+
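+// supportsFlexProcessing reports whether the model can run on OpenAI's "flex"
+// service tier.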
+func supportsFlexProcessing(modelID string) bool {
+	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
+}
+
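+// supportsPriorityProcessing reports whether the model can run on OpenAI's
+// "priority" service tier.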
+func supportsPriorityProcessing(modelID string) bool {
+	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
+		strings.HasPrefix(modelID, "gpt-5-mini") || strings.HasPrefix(modelID, "o3") ||
+		strings.HasPrefix(modelID, "o4-mini")
+}
+
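+// toOpenAITools converts the SDK tool definitions and tool choice into their
+// Chat Completions equivalents, reporting unsupported tool types as call warnings.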
+func toOpenAITools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAITools []openai.ChatCompletionToolUnionParam, openAIToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
+	for _, tool := range tools {
+		if tool.GetType() == ai.ToolTypeFunction {
+			ft, ok := tool.(ai.FunctionTool)
+			if !ok {
+				continue
+			}
+			openAITools = append(openAITools, openai.ChatCompletionToolUnionParam{
+				OfFunction: &openai.ChatCompletionFunctionToolParam{
+					Function: shared.FunctionDefinitionParam{
+						Name:        ft.Name,
+						Description: param.NewOpt(ft.Description),
+						Parameters:  openai.FunctionParameters(ft.InputSchema),
+						Strict:      param.NewOpt(false),
+					},
+					Type: "function",
+				},
+			})
+			continue
+		}
+
+		// TODO: handle provider tool calls
+		warnings = append(warnings, ai.CallWarning{
+			Type:    ai.CallWarningTypeUnsupportedTool,
+			Tool:    tool,
+			Message: "tool is not supported",
+		})
+	}
+	if toolChoice == nil {
+		return
+	}
+
+	switch *toolChoice {
+	case ai.ToolChoiceAuto:
+		openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
+			OfAuto: param.NewOpt("auto"),
+		}
+	case ai.ToolChoiceNone:
+		openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
+			OfAuto: param.NewOpt("none"),
+		}
+	default:
+		openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
+			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
+				Type: "function",
+				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
+					Name: string(*toolChoice),
+				},
+			},
+		}
+	}
+	return
+}
+
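+// toOpenAIPrompt converts an ai.Prompt into OpenAI chat messages, collecting
+// warnings for content that cannot be represented in the Chat Completions API.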
+func toOpenAIPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
+	var messages []openai.ChatCompletionMessageParamUnion
+	var warnings []ai.CallWarning
+	for _, msg := range prompt {
+		switch msg.Role {
+		case ai.MessageRoleSystem:
+			var systemPromptParts []string
+			for _, c := range msg.Content {
+				if c.GetType() != ai.ContentTypeText {
+					warnings = append(warnings, ai.CallWarning{
+						Type:    ai.CallWarningTypeOther,
+						Message: "system prompt can only have text content",
+					})
+					continue
+				}
+				textPart, ok := ai.AsContentType[ai.TextPart](c)
+				if !ok {
+					warnings = append(warnings, ai.CallWarning{
+						Type:    ai.CallWarningTypeOther,
+						Message: "system prompt text part does not have the right type",
+					})
+					continue
+				}
+				if strings.TrimSpace(textPart.Text) != "" {
+					systemPromptParts = append(systemPromptParts, textPart.Text)
+				}
+			}
+			if len(systemPromptParts) == 0 {
+				warnings = append(warnings, ai.CallWarning{
+					Type:    ai.CallWarningTypeOther,
+					Message: "system prompt has no text parts",
+				})
+				continue
+			}
+			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
+		case ai.MessageRoleUser:
+			// Simple case: the user message has a single text part.
+			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
+				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
+				if !ok {
+					warnings = append(warnings, ai.CallWarning{
+						Type:    ai.CallWarningTypeOther,
+						Message: "user message text part does not have the right type",
+					})
+					continue
+				}
+				messages = append(messages, openai.UserMessage(textPart.Text))
+				continue
+			}
+			// Text content plus attachments. For now only image, audio, and PDF
+			// parts are supported.
+			// TODO: add the supported media types to the language model so we
+			// can use that to validate the data here.
+			var content []openai.ChatCompletionContentPartUnionParam
+			for _, c := range msg.Content {
+				switch c.GetType() {
+				case ai.ContentTypeText:
+					textPart, ok := ai.AsContentType[ai.TextPart](c)
+					if !ok {
+						warnings = append(warnings, ai.CallWarning{
+							Type:    ai.CallWarningTypeOther,
+							Message: "user message text part does not have the right type",
+						})
+						continue
+					}
+					content = append(content, openai.ChatCompletionContentPartUnionParam{
+						OfText: &openai.ChatCompletionContentPartTextParam{
+							Text: textPart.Text,
+						},
+					})
+				case ai.ContentTypeFile:
+					filePart, ok := ai.AsContentType[ai.FilePart](c)
+					if !ok {
+						warnings = append(warnings, ai.CallWarning{
+							Type:    ai.CallWarningTypeOther,
+							Message: "user message file part does not have the right type",
+						})
+						continue
+					}
+
+					switch {
+					case strings.HasPrefix(filePart.MediaType, "image/"):
+						// Handle image files
+						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
+						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
+						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}
+
+						// Check for provider-specific options like image detail
+						if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
+							if detail, ok := providerOptions["imageDetail"].(string); ok {
+								imageURL.Detail = detail
+							}
+						}
+
+						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
+						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
+
+					case filePart.MediaType == "audio/wav":
+						// Handle WAV audio files
+						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
+						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
+							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
+								Data:   base64Encoded,
+								Format: "wav",
+							},
+						}
+						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
+
+					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
+						// Handle MP3 audio files
+						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
+						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
+							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
+								Data:   base64Encoded,
+								Format: "mp3",
+							},
+						}
+						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
+
+					case filePart.MediaType == "application/pdf":
+						// Handle PDF files
+						dataStr := string(filePart.Data)
+
+						// Check if data looks like a file ID (starts with "file-")
+						if strings.HasPrefix(dataStr, "file-") {
+							fileBlock := openai.ChatCompletionContentPartFileParam{
+								File: openai.ChatCompletionContentPartFileFileParam{
+									FileID: param.NewOpt(dataStr),
+								},
+							}
+							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
+						} else {
+							// Handle as base64 data
+							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
+							data := "data:application/pdf;base64," + base64Encoded
+
+							filename := filePart.Filename
+							if filename == "" {
+								// Generate default filename based on content index
+								filename = fmt.Sprintf("part-%d.pdf", len(content))
+							}
+
+							fileBlock := openai.ChatCompletionContentPartFileParam{
+								File: openai.ChatCompletionContentPartFileFileParam{
+									Filename: param.NewOpt(filename),
+									FileData: param.NewOpt(data),
+								},
+							}
+							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
+						}
+
+					default:
+						warnings = append(warnings, ai.CallWarning{
+							Type:    ai.CallWarningTypeOther,
+							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
+						})
+					}
+				}
+			}
+			messages = append(messages, openai.UserMessage(content))
+		case ai.MessageRoleAssistant:
+			// Simple case: the assistant message has a single text part.
+			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
+				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
+				if !ok {
+					warnings = append(warnings, ai.CallWarning{
+						Type:    ai.CallWarningTypeOther,
+						Message: "assistant message text part does not have the right type",
+					})
+					continue
+				}
+				messages = append(messages, openai.AssistantMessage(textPart.Text))
+				continue
+			}
+			assistantMsg := openai.ChatCompletionAssistantMessageParam{
+				Role: "assistant",
+			}
+			for _, c := range msg.Content {
+				switch c.GetType() {
+				case ai.ContentTypeText:
+					textPart, ok := ai.AsContentType[ai.TextPart](c)
+					if !ok {
+						warnings = append(warnings, ai.CallWarning{
+							Type:    ai.CallWarningTypeOther,
+							Message: "assistant message text part does not have the right type",
+						})
+						continue
+					}
+					// Concatenate text parts so earlier ones are not overwritten.
+					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
+						OfString: param.NewOpt(assistantMsg.Content.OfString.Value + textPart.Text),
+					}
+				case ai.ContentTypeToolCall:
+					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
+					if !ok {
+						warnings = append(warnings, ai.CallWarning{
+							Type:    ai.CallWarningTypeOther,
+							Message: "assistant message tool part does not have the right type",
+						})
+						continue
+					}
+					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
+						openai.ChatCompletionMessageToolCallUnionParam{
+							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
+								ID:   toolCallPart.ToolCallID,
+								Type: "function",
+								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
+									Name:      toolCallPart.ToolName,
+									Arguments: toolCallPart.Input,
+								},
+							},
+						})
+				}
+			}
+			messages = append(messages, openai.ChatCompletionMessageParamUnion{
+				OfAssistant: &assistantMsg,
+			})
+		case ai.MessageRoleTool:
+			for _, c := range msg.Content {
+				if c.GetType() != ai.ContentTypeToolResult {
+					warnings = append(warnings, ai.CallWarning{
+						Type:    ai.CallWarningTypeOther,
+						Message: "tool message can only have tool result content",
+					})
+					continue
+				}
+
+				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
+				if !ok {
+					warnings = append(warnings, ai.CallWarning{
+						Type:    ai.CallWarningTypeOther,
+						Message: "tool message result part does not have the right type",
+					})
+					continue
+				}
+
+				switch toolResultPart.Output.GetType() {
+				case ai.ToolResultContentTypeText:
+					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
+					if !ok {
+						warnings = append(warnings, ai.CallWarning{
+							Type:    ai.CallWarningTypeOther,
+							Message: "tool result output does not have the right type",
+						})
+						continue
+					}
+					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
+				case ai.ToolResultContentTypeError:
+					// TODO: check if better handling is needed
+					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
+					if !ok {
+						warnings = append(warnings, ai.CallWarning{
+							Type:    ai.CallWarningTypeOther,
+							Message: "tool result output does not have the right type",
+						})
+						continue
+					}
+					messages = append(messages, openai.ToolMessage(output.Error, toolResultPart.ToolCallID))
+				}
+			}
+		}
+	}
+	return messages, warnings
+}
+
+// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
+func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
+	var annotations []openai.ChatCompletionMessageAnnotation
+
+	// Parse the raw JSON to extract annotations.
+	var deltaData map[string]any
+	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
+		return annotations
+	}
+
+	// Check if annotations exist in the delta.
+	annotationsData, ok := deltaData["annotations"].([]any)
+	if !ok {
+		return annotations
+	}
+
+	for _, annotationData := range annotationsData {
+		annotationMap, ok := annotationData.(map[string]any)
+		if !ok {
+			continue
+		}
+		if annotationType, _ := annotationMap["type"].(string); annotationType != "url_citation" {
+			continue
+		}
+		urlCitationData, ok := annotationMap["url_citation"].(map[string]any)
+		if !ok {
+			continue
+		}
+		// Use comma-ok assertions so malformed annotations are skipped instead
+		// of panicking on missing fields.
+		url, _ := urlCitationData["url"].(string)
+		title, _ := urlCitationData["title"].(string)
+		annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
+			Type: "url_citation",
+			URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
+				URL:   url,
+				Title: title,
+			},
+		})
+	}
+
+	return annotations
+}

providers/openai_test.go 🔗

@@ -0,0 +1,2850 @@
+package providers
+
+import (
+	"context"
+	"encoding/base64"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+
+	"github.com/charmbracelet/crush/internal/ai"
+	"github.com/openai/openai-go/v2/packages/param"
+	"github.com/stretchr/testify/require"
+)
+
+func TestToOpenAIPrompt_SystemMessages(t *testing.T) {
+	t.Parallel()
+
+	t.Run("should forward system messages", func(t *testing.T) {
+		t.Parallel()
+
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleSystem,
+				Content: []ai.MessagePart{
+					ai.TextPart{Text: "You are a helpful assistant."},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 1)
+
+		systemMsg := messages[0].OfSystem
+		require.NotNil(t, systemMsg)
+		require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
+	})
+
+	t.Run("should handle empty system messages", func(t *testing.T) {
+		t.Parallel()
+
+		prompt := ai.Prompt{
+			{
+				Role:    ai.MessageRoleSystem,
+				Content: []ai.MessagePart{},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Len(t, warnings, 1)
+		require.Contains(t, warnings[0].Message, "system prompt has no text parts")
+		require.Empty(t, messages)
+	})
+
+	t.Run("should join multiple system text parts", func(t *testing.T) {
+		t.Parallel()
+
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleSystem,
+				Content: []ai.MessagePart{
+					ai.TextPart{Text: "You are a helpful assistant."},
+					ai.TextPart{Text: "Be concise."},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 1)
+
+		systemMsg := messages[0].OfSystem
+		require.NotNil(t, systemMsg)
+		require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
+	})
+}
+
+func TestToOpenAIPrompt_UserMessages(t *testing.T) {
+	t.Parallel()
+
+	t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
+		t.Parallel()
+
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleUser,
+				Content: []ai.MessagePart{
+					ai.TextPart{Text: "Hello"},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 1)
+
+		userMsg := messages[0].OfUser
+		require.NotNil(t, userMsg)
+		require.Equal(t, "Hello", userMsg.Content.OfString.Value)
+	})
+
+	t.Run("should convert messages with image parts", func(t *testing.T) {
+		t.Parallel()
+
+		imageData := []byte{0, 1, 2, 3}
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleUser,
+				Content: []ai.MessagePart{
+					ai.TextPart{Text: "Hello"},
+					ai.FilePart{
+						MediaType: "image/png",
+						Data:      imageData,
+					},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 1)
+
+		userMsg := messages[0].OfUser
+		require.NotNil(t, userMsg)
+
+		content := userMsg.Content.OfArrayOfContentParts
+		require.Len(t, content, 2)
+
+		// Check text part
+		textPart := content[0].OfText
+		require.NotNil(t, textPart)
+		require.Equal(t, "Hello", textPart.Text)
+
+		// Check image part
+		imagePart := content[1].OfImageURL
+		require.NotNil(t, imagePart)
+		expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
+		require.Equal(t, expectedURL, imagePart.ImageURL.URL)
+	})
+
+	t.Run("should add image detail when specified through provider options", func(t *testing.T) {
+		t.Parallel()
+
+		imageData := []byte{0, 1, 2, 3}
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleUser,
+				Content: []ai.MessagePart{
+					ai.FilePart{
+						MediaType: "image/png",
+						Data:      imageData,
+						ProviderOptions: ai.ProviderOptions{
+							"openai": map[string]any{
+								"imageDetail": "low",
+							},
+						},
+					},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 1)
+
+		userMsg := messages[0].OfUser
+		require.NotNil(t, userMsg)
+
+		content := userMsg.Content.OfArrayOfContentParts
+		require.Len(t, content, 1)
+
+		imagePart := content[0].OfImageURL
+		require.NotNil(t, imagePart)
+		require.Equal(t, "low", imagePart.ImageURL.Detail)
+	})
+}
+
+func TestToOpenAIPrompt_FileParts(t *testing.T) {
+	t.Parallel()
+
+	t.Run("should warn for unsupported mime types", func(t *testing.T) {
+		t.Parallel()
+
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleUser,
+				Content: []ai.MessagePart{
+					ai.FilePart{
+						MediaType: "application/something",
+						Data:      []byte("test"),
+					},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Len(t, warnings, 1)
+		require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
+		require.Len(t, messages, 1) // Message is still created but with empty content array
+	})
+
+	t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
+		t.Parallel()
+
+		audioData := []byte{0, 1, 2, 3}
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleUser,
+				Content: []ai.MessagePart{
+					ai.FilePart{
+						MediaType: "audio/wav",
+						Data:      audioData,
+					},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 1)
+
+		userMsg := messages[0].OfUser
+		require.NotNil(t, userMsg)
+
+		content := userMsg.Content.OfArrayOfContentParts
+		require.Len(t, content, 1)
+
+		audioPart := content[0].OfInputAudio
+		require.NotNil(t, audioPart)
+		require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
+		require.Equal(t, "wav", audioPart.InputAudio.Format)
+	})
+
+	t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
+		t.Parallel()
+
+		audioData := []byte{0, 1, 2, 3}
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleUser,
+				Content: []ai.MessagePart{
+					ai.FilePart{
+						MediaType: "audio/mpeg",
+						Data:      audioData,
+					},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 1)
+
+		userMsg := messages[0].OfUser
+		content := userMsg.Content.OfArrayOfContentParts
+		audioPart := content[0].OfInputAudio
+		require.NotNil(t, audioPart)
+		require.Equal(t, "mp3", audioPart.InputAudio.Format)
+	})
+
+	t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
+		t.Parallel()
+
+		audioData := []byte{0, 1, 2, 3}
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleUser,
+				Content: []ai.MessagePart{
+					ai.FilePart{
+						MediaType: "audio/mp3",
+						Data:      audioData,
+					},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 1)
+
+		userMsg := messages[0].OfUser
+		content := userMsg.Content.OfArrayOfContentParts
+		audioPart := content[0].OfInputAudio
+		require.NotNil(t, audioPart)
+		require.Equal(t, "mp3", audioPart.InputAudio.Format)
+	})
+
+	t.Run("should convert messages with PDF file parts", func(t *testing.T) {
+		t.Parallel()
+
+		pdfData := []byte{1, 2, 3, 4, 5}
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleUser,
+				Content: []ai.MessagePart{
+					ai.FilePart{
+						MediaType: "application/pdf",
+						Data:      pdfData,
+						Filename:  "document.pdf",
+					},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 1)
+
+		userMsg := messages[0].OfUser
+		content := userMsg.Content.OfArrayOfContentParts
+		require.Len(t, content, 1)
+
+		filePart := content[0].OfFile
+		require.NotNil(t, filePart)
+		require.Equal(t, "document.pdf", filePart.File.Filename.Value)
+
+		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
+		require.Equal(t, expectedData, filePart.File.FileData.Value)
+	})
+
+	t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
+		t.Parallel()
+
+		pdfData := []byte{1, 2, 3, 4, 5}
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleUser,
+				Content: []ai.MessagePart{
+					ai.FilePart{
+						MediaType: "application/pdf",
+						Data:      pdfData,
+						Filename:  "document.pdf",
+					},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 1)
+
+		userMsg := messages[0].OfUser
+		content := userMsg.Content.OfArrayOfContentParts
+		filePart := content[0].OfFile
+		require.NotNil(t, filePart)
+
+		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
+		require.Equal(t, expectedData, filePart.File.FileData.Value)
+	})
+
+	t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
+		t.Parallel()
+
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleUser,
+				Content: []ai.MessagePart{
+					ai.FilePart{
+						MediaType: "application/pdf",
+						Data:      []byte("file-pdf-12345"),
+					},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 1)
+
+		userMsg := messages[0].OfUser
+		content := userMsg.Content.OfArrayOfContentParts
+		filePart := content[0].OfFile
+		require.NotNil(t, filePart)
+		require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
+		require.True(t, param.IsOmitted(filePart.File.FileData))
+		require.True(t, param.IsOmitted(filePart.File.Filename))
+	})
+
+	t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
+		t.Parallel()
+
+		pdfData := []byte{1, 2, 3, 4, 5}
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleUser,
+				Content: []ai.MessagePart{
+					ai.FilePart{
+						MediaType: "application/pdf",
+						Data:      pdfData,
+					},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 1)
+
+		userMsg := messages[0].OfUser
+		content := userMsg.Content.OfArrayOfContentParts
+		filePart := content[0].OfFile
+		require.NotNil(t, filePart)
+		require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
+	})
+}
+
+func TestToOpenAIPrompt_ToolCalls(t *testing.T) {
+	t.Parallel()
+
+	t.Run("should stringify arguments to tool calls", func(t *testing.T) {
+		t.Parallel()
+
+		inputArgs := map[string]any{"foo": "bar123"}
+		inputJSON, _ := json.Marshal(inputArgs)
+
+		outputResult := map[string]any{"oof": "321rab"}
+		outputJSON, _ := json.Marshal(outputResult)
+
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleAssistant,
+				Content: []ai.MessagePart{
+					ai.ToolCallPart{
+						ToolCallID: "quux",
+						ToolName:   "thwomp",
+						Input:      string(inputJSON),
+					},
+				},
+			},
+			{
+				Role: ai.MessageRoleTool,
+				Content: []ai.MessagePart{
+					ai.ToolResultPart{
+						ToolCallID: "quux",
+						Output: ai.ToolResultOutputContentText{
+							Text: string(outputJSON),
+						},
+					},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 2)
+
+		// Check assistant message with tool call
+		assistantMsg := messages[0].OfAssistant
+		require.NotNil(t, assistantMsg)
+		require.Equal(t, "", assistantMsg.Content.OfString.Value)
+		require.Len(t, assistantMsg.ToolCalls, 1)
+
+		toolCall := assistantMsg.ToolCalls[0].OfFunction
+		require.NotNil(t, toolCall)
+		require.Equal(t, "quux", toolCall.ID)
+		require.Equal(t, "thwomp", toolCall.Function.Name)
+		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
+
+		// Check tool message
+		toolMsg := messages[1].OfTool
+		require.NotNil(t, toolMsg)
+		require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
+		require.Equal(t, "quux", toolMsg.ToolCallID)
+	})
+
+	t.Run("should handle different tool output types", func(t *testing.T) {
+		t.Parallel()
+
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleTool,
+				Content: []ai.MessagePart{
+					ai.ToolResultPart{
+						ToolCallID: "text-tool",
+						Output: ai.ToolResultOutputContentText{
+							Text: "Hello world",
+						},
+					},
+					ai.ToolResultPart{
+						ToolCallID: "error-tool",
+						Output: ai.ToolResultOutputContentError{
+							Error: "Something went wrong",
+						},
+					},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 2)
+
+		// Check first tool message (text)
+		textToolMsg := messages[0].OfTool
+		require.NotNil(t, textToolMsg)
+		require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
+		require.Equal(t, "text-tool", textToolMsg.ToolCallID)
+
+		// Check second tool message (error)
+		errorToolMsg := messages[1].OfTool
+		require.NotNil(t, errorToolMsg)
+		require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
+		require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
+	})
+}
+
+func TestToOpenAIPrompt_AssistantMessages(t *testing.T) {
+	t.Parallel()
+
+	t.Run("should handle simple text assistant messages", func(t *testing.T) {
+		t.Parallel()
+
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleAssistant,
+				Content: []ai.MessagePart{
+					ai.TextPart{Text: "Hello, how can I help you?"},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 1)
+
+		assistantMsg := messages[0].OfAssistant
+		require.NotNil(t, assistantMsg)
+		require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
+	})
+
+	t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
+		t.Parallel()
+
+		inputArgs := map[string]any{"query": "test"}
+		inputJSON, _ := json.Marshal(inputArgs)
+
+		prompt := ai.Prompt{
+			{
+				Role: ai.MessageRoleAssistant,
+				Content: []ai.MessagePart{
+					ai.TextPart{Text: "Let me search for that."},
+					ai.ToolCallPart{
+						ToolCallID: "call-123",
+						ToolName:   "search",
+						Input:      string(inputJSON),
+					},
+				},
+			},
+		}
+
+		messages, warnings := toOpenAIPrompt(prompt)
+
+		require.Empty(t, warnings)
+		require.Len(t, messages, 1)
+
+		assistantMsg := messages[0].OfAssistant
+		require.NotNil(t, assistantMsg)
+		require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
+		require.Len(t, assistantMsg.ToolCalls, 1)
+
+		toolCall := assistantMsg.ToolCalls[0].OfFunction
+		require.Equal(t, "call-123", toolCall.ID)
+		require.Equal(t, "search", toolCall.Function.Name)
+		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
+	})
+}
+
+var testPrompt = ai.Prompt{
+	{
+		Role: ai.MessageRoleUser,
+		Content: []ai.MessagePart{
+			ai.TextPart{Text: "Hello"},
+		},
+	},
+}
+
+var testLogprobs = map[string]any{
+	"content": []map[string]any{
+		{
+			"token":   "Hello",
+			"logprob": -0.0009994634,
+			"top_logprobs": []map[string]any{
+				{
+					"token":   "Hello",
+					"logprob": -0.0009994634,
+				},
+			},
+		},
+		{
+			"token":   "!",
+			"logprob": -0.13410144,
+			"top_logprobs": []map[string]any{
+				{
+					"token":   "!",
+					"logprob": -0.13410144,
+				},
+			},
+		},
+		{
+			"token":   " How",
+			"logprob": -0.0009250381,
+			"top_logprobs": []map[string]any{
+				{
+					"token":   " How",
+					"logprob": -0.0009250381,
+				},
+			},
+		},
+		{
+			"token":   " can",
+			"logprob": -0.047709424,
+			"top_logprobs": []map[string]any{
+				{
+					"token":   " can",
+					"logprob": -0.047709424,
+				},
+			},
+		},
+		{
+			"token":   " I",
+			"logprob": -0.000009014684,
+			"top_logprobs": []map[string]any{
+				{
+					"token":   " I",
+					"logprob": -0.000009014684,
+				},
+			},
+		},
+		{
+			"token":   " assist",
+			"logprob": -0.009125131,
+			"top_logprobs": []map[string]any{
+				{
+					"token":   " assist",
+					"logprob": -0.009125131,
+				},
+			},
+		},
+		{
+			"token":   " you",
+			"logprob": -0.0000066306106,
+			"top_logprobs": []map[string]any{
+				{
+					"token":   " you",
+					"logprob": -0.0000066306106,
+				},
+			},
+		},
+		{
+			"token":   " today",
+			"logprob": -0.00011093382,
+			"top_logprobs": []map[string]any{
+				{
+					"token":   " today",
+					"logprob": -0.00011093382,
+				},
+			},
+		},
+		{
+			"token":   "?",
+			"logprob": -0.00004596782,
+			"top_logprobs": []map[string]any{
+				{
+					"token":   "?",
+					"logprob": -0.00004596782,
+				},
+			},
+		},
+	},
+}
+
+type mockServer struct {
+	server   *httptest.Server
+	response map[string]any
+	calls    []mockCall
+}
+
+type mockCall struct {
+	method  string
+	path    string
+	headers map[string]string
+	body    map[string]any
+}
+
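+// newMockServer starts an httptest server that records every request it
+// receives and replies with the currently configured JSON response.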
+func newMockServer() *mockServer {
+	ms := &mockServer{
+		calls: make([]mockCall, 0),
+	}
+
+	ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Record the call
+		call := mockCall{
+			method:  r.Method,
+			path:    r.URL.Path,
+			headers: make(map[string]string),
+		}
+
+		for k, v := range r.Header {
+			if len(v) > 0 {
+				call.headers[k] = v[0]
+			}
+		}
+
+		// Parse the request body, ignoring bodies that are not valid JSON.
+		if r.Body != nil {
+			var body map[string]any
+			if err := json.NewDecoder(r.Body).Decode(&body); err == nil {
+				call.body = body
+			}
+		}
+
+		ms.calls = append(ms.calls, call)
+
+		// Return mock response
+		w.Header().Set("Content-Type", "application/json")
+		json.NewEncoder(w).Encode(ms.response)
+	}))
+
+	return ms
+}
+
+func (ms *mockServer) close() {
+	ms.server.Close()
+}
+
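+// prepareJSONResponse sets the mock chat-completions response, starting from
+// sensible defaults and overriding individual fields from opts.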
+func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
+	// Default values
+	response := map[string]any{
+		"id":      "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
+		"object":  "chat.completion",
+		"created": 1711115037,
+		"model":   "gpt-3.5-turbo-0125",
+		"choices": []map[string]any{
+			{
+				"index": 0,
+				"message": map[string]any{
+					"role":    "assistant",
+					"content": "",
+				},
+				"finish_reason": "stop",
+			},
+		},
+		"usage": map[string]any{
+			"prompt_tokens":     4,
+			"total_tokens":      34,
+			"completion_tokens": 30,
+		},
+		"system_fingerprint": "fp_3bc1b5746c",
+	}
+
+	// Override with provided options
+	for k, v := range opts {
+		switch k {
+		case "content":
+			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["content"] = v
+		case "tool_calls":
+			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["tool_calls"] = v
+		case "function_call":
+			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["function_call"] = v
+		case "annotations":
+			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["annotations"] = v
+		case "usage":
+			response["usage"] = v
+		case "finish_reason":
+			response["choices"].([]map[string]any)[0]["finish_reason"] = v
+		case "id":
+			response["id"] = v
+		case "created":
+			response["created"] = v
+		case "model":
+			response["model"] = v
+		case "logprobs":
+			if v != nil {
+				response["choices"].([]map[string]any)[0]["logprobs"] = v
+			}
+		}
+	}
+
+	ms.response = response
+}
+
+func TestDoGenerate(t *testing.T) {
+	t.Parallel()
+
+	t.Run("should extract text response", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"content": "Hello, World!",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+		require.Len(t, result.Content, 1)
+
+		textContent, ok := result.Content[0].(ai.TextContent)
+		require.True(t, ok)
+		require.Equal(t, "Hello, World!", textContent.Text)
+	})
+
+	t.Run("should extract usage", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"usage": map[string]any{
+				"prompt_tokens":     20,
+				"total_tokens":      25,
+				"completion_tokens": 5,
+			},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+		require.Equal(t, int64(20), result.Usage.InputTokens)
+		require.Equal(t, int64(5), result.Usage.OutputTokens)
+		require.Equal(t, int64(25), result.Usage.TotalTokens)
+	})
+
+	t.Run("should send request body", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "POST", call.method)
+		require.Equal(t, "/chat/completions", call.path)
+		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+
+		messages, ok := call.body["messages"].([]any)
+		require.True(t, ok)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should support partial usage", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"usage": map[string]any{
+				"prompt_tokens": 20,
+				"total_tokens":  20,
+			},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+		require.Equal(t, int64(20), result.Usage.InputTokens)
+		require.Equal(t, int64(0), result.Usage.OutputTokens)
+		require.Equal(t, int64(20), result.Usage.TotalTokens)
+	})
+
+	t.Run("should extract logprobs", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"logprobs": testLogprobs,
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"logProbs": true,
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result.ProviderMetadata)
+
+		openaiMeta, ok := result.ProviderMetadata["openai"]
+		require.True(t, ok)
+
+		logprobs, ok := openaiMeta["logprobs"]
+		require.True(t, ok)
+		require.NotNil(t, logprobs)
+	})
+
+	t.Run("should extract finish reason", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"finish_reason": "stop",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+		require.Equal(t, ai.FinishReasonStop, result.FinishReason)
+	})
+
+	t.Run("should support unknown finish reason", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"finish_reason": "eos",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+		require.Equal(t, ai.FinishReasonUnknown, result.FinishReason)
+	})
+
+	t.Run("should pass the model and the messages", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"content": "",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should pass settings", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"logitBias": map[string]int64{
+						"50256": -100,
+					},
+					"parallelToolCalls": false,
+					"user":              "test-user-id",
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		logitBias := call.body["logit_bias"].(map[string]any)
+		require.Equal(t, float64(-100), logitBias["50256"])
+		require.Equal(t, false, call.body["parallel_tool_calls"])
+		require.Equal(t, "test-user-id", call.body["user"])
+	})
+
+	t.Run("should pass reasoningEffort setting", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"content": "",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("o1-mini")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"reasoningEffort": "low",
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "o1-mini", call.body["model"])
+		require.Equal(t, "low", call.body["reasoning_effort"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should pass textVerbosity setting", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"content": "",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-4o")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"textVerbosity": "low",
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-4o", call.body["model"])
+		require.Equal(t, "low", call.body["verbosity"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should pass tools and toolChoice", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"content": "",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			Tools: []ai.Tool{
+				ai.FunctionTool{
+					Name: "test-tool",
+					InputSchema: map[string]any{
+						"type": "object",
+						"properties": map[string]any{
+							"value": map[string]any{
+								"type": "string",
+							},
+						},
+						"required":             []string{"value"},
+						"additionalProperties": false,
+						"$schema":              "http://json-schema.org/draft-07/schema#",
+					},
+				},
+			},
+			ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		tools := call.body["tools"].([]any)
+		require.Len(t, tools, 1)
+
+		tool := tools[0].(map[string]any)
+		require.Equal(t, "function", tool["type"])
+
+		function := tool["function"].(map[string]any)
+		require.Equal(t, "test-tool", function["name"])
+		require.Equal(t, false, function["strict"])
+
+		toolChoice := call.body["tool_choice"].(map[string]any)
+		require.Equal(t, "function", toolChoice["type"])
+
+		toolChoiceFunction := toolChoice["function"].(map[string]any)
+		require.Equal(t, "test-tool", toolChoiceFunction["name"])
+	})
+
+	t.Run("should parse tool results", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"tool_calls": []map[string]any{
+				{
+					"id":   "call_O17Uplv4lJvD6DVdIvFFeRMw",
+					"type": "function",
+					"function": map[string]any{
+						"name":      "test-tool",
+						"arguments": `{"value":"Spark"}`,
+					},
+				},
+			},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			Tools: []ai.Tool{
+				ai.FunctionTool{
+					Name: "test-tool",
+					InputSchema: map[string]any{
+						"type": "object",
+						"properties": map[string]any{
+							"value": map[string]any{
+								"type": "string",
+							},
+						},
+						"required":             []string{"value"},
+						"additionalProperties": false,
+						"$schema":              "http://json-schema.org/draft-07/schema#",
+					},
+				},
+			},
+			ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
+		})
+
+		require.NoError(t, err)
+		require.Len(t, result.Content, 1)
+
+		toolCall, ok := result.Content[0].(ai.ToolCallContent)
+		require.True(t, ok)
+		require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
+		require.Equal(t, "test-tool", toolCall.ToolName)
+		require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
+	})
+
+	t.Run("should parse annotations/citations", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"content": "Based on the search results [doc1], I found information.",
+			"annotations": []map[string]any{
+				{
+					"type": "url_citation",
+					"url_citation": map[string]any{
+						"start_index": 24,
+						"end_index":   29,
+						"url":         "https://example.com/doc1.pdf",
+						"title":       "Document 1",
+					},
+				},
+			},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+		require.Len(t, result.Content, 2)
+
+		textContent, ok := result.Content[0].(ai.TextContent)
+		require.True(t, ok)
+		require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)
+
+		sourceContent, ok := result.Content[1].(ai.SourceContent)
+		require.True(t, ok)
+		require.Equal(t, ai.SourceTypeURL, sourceContent.SourceType)
+		require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
+		require.Equal(t, "Document 1", sourceContent.Title)
+		require.NotEmpty(t, sourceContent.ID)
+	})
+
+	t.Run("should return cached_tokens in prompt_details_tokens", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"usage": map[string]any{
+				"prompt_tokens":     15,
+				"completion_tokens": 20,
+				"total_tokens":      35,
+				"prompt_tokens_details": map[string]any{
+					"cached_tokens": 1152,
+				},
+			},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-4o-mini")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+		require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
+		require.Equal(t, int64(15), result.Usage.InputTokens)
+		require.Equal(t, int64(20), result.Usage.OutputTokens)
+		require.Equal(t, int64(35), result.Usage.TotalTokens)
+	})
+
+	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"usage": map[string]any{
+				"prompt_tokens":     15,
+				"completion_tokens": 20,
+				"total_tokens":      35,
+				"completion_tokens_details": map[string]any{
+					"accepted_prediction_tokens": 123,
+					"rejected_prediction_tokens": 456,
+				},
+			},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-4o-mini")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result.ProviderMetadata)
+
+		openaiMeta, ok := result.ProviderMetadata["openai"]
+		require.True(t, ok)
+		require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
+		require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
+	})
+
+	t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("o1-preview")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt:           testPrompt,
+			Temperature:      &[]float64{0.5}[0],
+			TopP:             &[]float64{0.7}[0],
+			FrequencyPenalty: &[]float64{0.2}[0],
+			PresencePenalty:  &[]float64{0.3}[0],
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "o1-preview", call.body["model"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+
+		// These should not be present
+		require.Nil(t, call.body["temperature"])
+		require.Nil(t, call.body["top_p"])
+		require.Nil(t, call.body["frequency_penalty"])
+		require.Nil(t, call.body["presence_penalty"])
+
+		// Should have warnings
+		require.Len(t, result.Warnings, 4)
+		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
+		require.Equal(t, "temperature", result.Warnings[0].Setting)
+		require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
+	})
+
+	t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("o1-preview")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt:          testPrompt,
+			MaxOutputTokens: &[]int64{1000}[0],
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "o1-preview", call.body["model"])
+		require.Equal(t, float64(1000), call.body["max_completion_tokens"])
+		require.Nil(t, call.body["max_tokens"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should return reasoning tokens", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"usage": map[string]any{
+				"prompt_tokens":     15,
+				"completion_tokens": 20,
+				"total_tokens":      35,
+				"completion_tokens_details": map[string]any{
+					"reasoning_tokens": 10,
+				},
+			},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("o1-preview")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+		require.Equal(t, int64(15), result.Usage.InputTokens)
+		require.Equal(t, int64(20), result.Usage.OutputTokens)
+		require.Equal(t, int64(35), result.Usage.TotalTokens)
+		require.Equal(t, int64(10), result.Usage.ReasoningTokens)
+	})
+
+	t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"model": "o1-preview",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("o1-preview")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"maxCompletionTokens": 255,
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "o1-preview", call.body["model"])
+		require.Equal(t, float64(255), call.body["max_completion_tokens"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should send prediction extension setting", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"content": "",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"prediction": map[string]any{
+						"type":    "content",
+						"content": "Hello, World!",
+					},
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+
+		prediction := call.body["prediction"].(map[string]any)
+		require.Equal(t, "content", prediction["type"])
+		require.Equal(t, "Hello, World!", prediction["content"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should send store extension setting", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"content": "",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"store": true,
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+		require.Equal(t, true, call.body["store"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should send metadata extension values", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"content": "",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"metadata": map[string]any{
+						"custom": "value",
+					},
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+
+		metadata := call.body["metadata"].(map[string]any)
+		require.Equal(t, "value", metadata["custom"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should send promptCacheKey extension value", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"content": "",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"promptCacheKey": "test-cache-key-123",
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+		require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should send safetyIdentifier extension value", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"content": "",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"safetyIdentifier": "test-safety-identifier-123",
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+		require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-4o-search-preview")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt:      testPrompt,
+			Temperature: &[]float64{0.7}[0],
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-4o-search-preview", call.body["model"])
+		require.Nil(t, call.body["temperature"])
+
+		require.Len(t, result.Warnings, 1)
+		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
+		require.Equal(t, "temperature", result.Warnings[0].Setting)
+		require.Contains(t, result.Warnings[0].Details, "search preview models")
+	})
+
+	t.Run("should send serviceTier flex processing setting", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"content": "",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("o3-mini")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"serviceTier": "flex",
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "o3-mini", call.body["model"])
+		require.Equal(t, "flex", call.body["service_tier"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-4o-mini")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"serviceTier": "flex",
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Nil(t, call.body["service_tier"])
+
+		require.Len(t, result.Warnings, 1)
+		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
+		require.Equal(t, "serviceTier", result.Warnings[0].Setting)
+		require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
+	})
+
+	t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-4o-mini")
+
+		_, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"serviceTier": "priority",
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-4o-mini", call.body["model"])
+		require.Equal(t, "priority", call.body["service_tier"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		result, err := model.Generate(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"serviceTier": "priority",
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Nil(t, call.body["service_tier"])
+
+		require.Len(t, result.Warnings, 1)
+		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
+		require.Equal(t, "serviceTier", result.Warnings[0].Setting)
+		require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
+	})
+}
+
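+// streamingMockServer serves pre-canned SSE chunks for streaming tests and
+// records every incoming request so tests can assert on the request body.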
+type streamingMockServer struct {
+	server *httptest.Server
+	chunks []string
+	calls  []mockCall
+}
+
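+// newStreamingMockServer starts an httptest server that records each call and
+// replays the configured chunks as a server-sent event stream.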
+func newStreamingMockServer() *streamingMockServer {
+	sms := &streamingMockServer{
+		calls: make([]mockCall, 0),
+	}
+
+	sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Record the call
+		call := mockCall{
+			method:  r.Method,
+			path:    r.URL.Path,
+			headers: make(map[string]string),
+		}
+
+		for k, v := range r.Header {
+			if len(v) > 0 {
+				call.headers[k] = v[0]
+			}
+		}
+
+		// Parse the request body (best effort; a decode error leaves call.body nil)
+		if r.Body != nil {
+			var body map[string]any
+			if err := json.NewDecoder(r.Body).Decode(&body); err == nil {
+				call.body = body
+			}
+		}
+
+		sms.calls = append(sms.calls, call)
+
+		// Set streaming headers
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.Header().Set("Cache-Control", "no-cache")
+		w.Header().Set("Connection", "keep-alive")
+
+		// Apply custom headers encoded as "HEADER:<name>:<value>" pseudo-chunks
+		for _, chunk := range sms.chunks {
+			if strings.HasPrefix(chunk, "HEADER:") {
+				parts := strings.SplitN(chunk[7:], ":", 2)
+				if len(parts) == 2 {
+					w.Header().Set(parts[0], parts[1])
+				}
+				continue
+			}
+		}
+
+		w.WriteHeader(http.StatusOK)
+
+		// Write the SSE data chunks, skipping the header pseudo-entries handled above
+		for _, chunk := range sms.chunks {
+			if strings.HasPrefix(chunk, "HEADER:") {
+				continue
+			}
+			w.Write([]byte(chunk))
+			if f, ok := w.(http.Flusher); ok {
+				f.Flush()
+			}
+		}
+	}))
+
+	return sms
+}
+
+func (sms *streamingMockServer) close() {
+	sms.server.Close()
+}
+
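+// prepareStreamResponse builds a standard chat-completion chunk sequence
+// (role, content deltas, optional annotations, finish, usage, [DONE]) from the
+// given options. Custom headers are encoded as "HEADER:<name>:<value>" entries.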
+func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
+	content := []string{}
+	if c, ok := opts["content"].([]string); ok {
+		content = c
+	}
+
+	usage := map[string]any{
+		"prompt_tokens":     17,
+		"total_tokens":      244,
+		"completion_tokens": 227,
+	}
+	if u, ok := opts["usage"].(map[string]any); ok {
+		usage = u
+	}
+
+	logprobs := map[string]any{}
+	if l, ok := opts["logprobs"].(map[string]any); ok {
+		logprobs = l
+	}
+
+	finishReason := "stop"
+	if fr, ok := opts["finish_reason"].(string); ok {
+		finishReason = fr
+	}
+
+	model := "gpt-3.5-turbo-0613"
+	if m, ok := opts["model"].(string); ok {
+		model = m
+	}
+
+	headers := map[string]string{}
+	if h, ok := opts["headers"].(map[string]string); ok {
+		headers = h
+	}
+
+	chunks := []string{}
+
+	// Encode custom headers as pseudo-chunks; the handler applies them before streaming
+	for k, v := range headers {
+		chunks = append(chunks, "HEADER:"+k+":"+v)
+	}
+
+	// Initial chunk with role
+	initialChunk := map[string]any{
+		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
+		"object":             "chat.completion.chunk",
+		"created":            1702657020,
+		"model":              model,
+		"system_fingerprint": nil,
+		"choices": []map[string]any{
+			{
+				"index": 0,
+				"delta": map[string]any{
+					"role":    "assistant",
+					"content": "",
+				},
+				"finish_reason": nil,
+			},
+		},
+	}
+	initialData, _ := json.Marshal(initialChunk)
+	chunks = append(chunks, "data: "+string(initialData)+"\n\n")
+
+	// Content chunks
+	for i, text := range content {
+		contentChunk := map[string]any{
+			"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
+			"object":             "chat.completion.chunk",
+			"created":            1702657020,
+			"model":              model,
+			"system_fingerprint": nil,
+			"choices": []map[string]any{
+				{
+					"index": 1,
+					"delta": map[string]any{
+						"content": text,
+					},
+					"finish_reason": nil,
+				},
+			},
+		}
+		contentData, _ := json.Marshal(contentChunk)
+		chunks = append(chunks, "data: "+string(contentData)+"\n\n")
+
+		// Add annotations if this is the last content chunk and we have annotations
+		if i == len(content)-1 {
+			if annotations, ok := opts["annotations"].([]map[string]any); ok {
+				annotationChunk := map[string]any{
+					"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
+					"object":             "chat.completion.chunk",
+					"created":            1702657020,
+					"model":              model,
+					"system_fingerprint": nil,
+					"choices": []map[string]any{
+						{
+							"index": 1,
+							"delta": map[string]any{
+								"annotations": annotations,
+							},
+							"finish_reason": nil,
+						},
+					},
+				}
+				annotationData, _ := json.Marshal(annotationChunk)
+				chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
+			}
+		}
+	}
+
+	// Finish chunk
+	finishChunk := map[string]any{
+		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
+		"object":             "chat.completion.chunk",
+		"created":            1702657020,
+		"model":              model,
+		"system_fingerprint": nil,
+		"choices": []map[string]any{
+			{
+				"index":         0,
+				"delta":         map[string]any{},
+				"finish_reason": finishReason,
+			},
+		},
+	}
+
+	if len(logprobs) > 0 {
+		finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
+	}
+
+	finishData, _ := json.Marshal(finishChunk)
+	chunks = append(chunks, "data: "+string(finishData)+"\n\n")
+
+	// Usage chunk
+	usageChunk := map[string]any{
+		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
+		"object":             "chat.completion.chunk",
+		"created":            1702657020,
+		"model":              model,
+		"system_fingerprint": "fp_3bc1b5746c",
+		"choices":            []map[string]any{},
+		"usage":              usage,
+	}
+	usageData, _ := json.Marshal(usageChunk)
+	chunks = append(chunks, "data: "+string(usageData)+"\n\n")
+
+	// Done
+	chunks = append(chunks, "data: [DONE]\n\n")
+
+	sms.chunks = chunks
+}
+
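+// prepareToolStreamResponse replays a fixed tool-call stream whose argument
+// deltas concatenate to {"value":"Sparkle Day"}.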
+func (sms *streamingMockServer) prepareToolStreamResponse() {
+	chunks := []string{
+		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
+		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
+		"data: [DONE]\n\n",
+	}
+	sms.chunks = chunks
+}
+
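+// prepareErrorStreamResponse replays a stream containing a single error payload
+// followed by [DONE].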
+func (sms *streamingMockServer) prepareErrorStreamResponse() {
+	chunks := []string{
+		`data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
+		"data: [DONE]\n\n",
+	}
+	sms.chunks = chunks
+}
+
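+// collectStreamParts drains the stream into a slice, stopping after the first
+// error or finish part.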
+func collectStreamParts(stream ai.StreamResponse) ([]ai.StreamPart, error) {
+	var parts []ai.StreamPart
+	for part := range stream {
+		parts = append(parts, part)
+		if part.Type == ai.StreamPartTypeError || part.Type == ai.StreamPartTypeFinish {
+			break
+		}
+	}
+	return parts, nil
+}
+
+func TestDoStream(t *testing.T) {
+	t.Parallel()
+
+	t.Run("should stream text deltas", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareStreamResponse(map[string]any{
+			"content":       []string{"Hello", ", ", "World!"},
+			"finish_reason": "stop",
+			"usage": map[string]any{
+				"prompt_tokens":     17,
+				"total_tokens":      244,
+				"completion_tokens": 227,
+			},
+			"logprobs": testLogprobs,
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		stream, err := model.Stream(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+
+		parts, err := collectStreamParts(stream)
+		require.NoError(t, err)
+
+		// Verify stream structure: at least text-start, deltas, text-end, and finish
+		require.GreaterOrEqual(t, len(parts), 4)
+
+		// Find text parts
+		var textStart, textEnd, finish int = -1, -1, -1
+		var deltas []string
+
+		for i, part := range parts {
+			switch part.Type {
+			case ai.StreamPartTypeTextStart:
+				textStart = i
+			case ai.StreamPartTypeTextDelta:
+				deltas = append(deltas, part.Delta)
+			case ai.StreamPartTypeTextEnd:
+				textEnd = i
+			case ai.StreamPartTypeFinish:
+				finish = i
+			}
+		}
+
+		require.NotEqual(t, -1, textStart)
+		require.NotEqual(t, -1, textEnd)
+		require.NotEqual(t, -1, finish)
+		require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)
+
+		// Check finish part
+		finishPart := parts[finish]
+		require.Equal(t, ai.FinishReasonStop, finishPart.FinishReason)
+		require.Equal(t, int64(17), finishPart.Usage.InputTokens)
+		require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
+		require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
+	})
+
+	t.Run("should stream tool deltas", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareToolStreamResponse()
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		stream, err := model.Stream(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			Tools: []ai.Tool{
+				ai.FunctionTool{
+					Name: "test-tool",
+					InputSchema: map[string]any{
+						"type": "object",
+						"properties": map[string]any{
+							"value": map[string]any{
+								"type": "string",
+							},
+						},
+						"required":             []string{"value"},
+						"additionalProperties": false,
+						"$schema":              "http://json-schema.org/draft-07/schema#",
+					},
+				},
+			},
+		})
+
+		require.NoError(t, err)
+
+		parts, err := collectStreamParts(stream)
+		require.NoError(t, err)
+
+		// Find tool-related parts
+		toolInputStart, toolInputEnd, toolCall := -1, -1, -1
+		var toolDeltas []string
+
+		for i, part := range parts {
+			switch part.Type {
+			case ai.StreamPartTypeToolInputStart:
+				toolInputStart = i
+				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
+				require.Equal(t, "test-tool", part.ToolCallName)
+			case ai.StreamPartTypeToolInputDelta:
+				toolDeltas = append(toolDeltas, part.Delta)
+			case ai.StreamPartTypeToolInputEnd:
+				toolInputEnd = i
+			case ai.StreamPartTypeToolCall:
+				toolCall = i
+				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
+				require.Equal(t, "test-tool", part.ToolCallName)
+				require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
+			}
+		}
+
+		require.NotEqual(t, -1, toolInputStart)
+		require.NotEqual(t, -1, toolInputEnd)
+		require.NotEqual(t, -1, toolCall)
+
+		// Verify tool deltas combine to form the complete input
+		fullInput := strings.Join(toolDeltas, "")
+		require.Equal(t, `{"value":"Sparkle Day"}`, fullInput)
+	})
+
+	t.Run("should stream annotations/citations", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareStreamResponse(map[string]any{
+			"content": []string{"Based on search results"},
+			"annotations": []map[string]any{
+				{
+					"type": "url_citation",
+					"url_citation": map[string]any{
+						"start_index": 24,
+						"end_index":   29,
+						"url":         "https://example.com/doc1.pdf",
+						"title":       "Document 1",
+					},
+				},
+			},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		stream, err := model.Stream(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+
+		parts, err := collectStreamParts(stream)
+		require.NoError(t, err)
+
+		// Find source part
+		var sourcePart *ai.StreamPart
+		for _, part := range parts {
+			if part.Type == ai.StreamPartTypeSource {
+				sourcePart = &part
+				break
+			}
+		}
+
+		require.NotNil(t, sourcePart)
+		require.Equal(t, ai.SourceTypeURL, sourcePart.SourceType)
+		require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
+		require.Equal(t, "Document 1", sourcePart.Title)
+		require.NotEmpty(t, sourcePart.ID)
+	})
+
+	t.Run("should handle error stream parts", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareErrorStreamResponse()
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		stream, err := model.Stream(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+
+		parts, err := collectStreamParts(stream)
+		require.NoError(t, err)
+
+		// Should have at least the error part (the collector stops at the error)
+		require.NotEmpty(t, parts)
+
+		// Find error part
+		var errorPart *ai.StreamPart
+		for _, part := range parts {
+			if part.Type == ai.StreamPartTypeError {
+				errorPart = &part
+				break
+			}
+		}
+
+		require.NotNil(t, errorPart)
+		require.NotNil(t, errorPart.Error)
+	})
+
+	t.Run("should send request body", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareStreamResponse(map[string]any{
+			"content": []string{},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		_, err := model.Stream(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "POST", call.method)
+		require.Equal(t, "/chat/completions", call.path)
+		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+		require.Equal(t, true, call.body["stream"])
+
+		streamOptions := call.body["stream_options"].(map[string]any)
+		require.Equal(t, true, streamOptions["include_usage"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareStreamResponse(map[string]any{
+			"content": []string{},
+			"usage": map[string]any{
+				"prompt_tokens":     15,
+				"completion_tokens": 20,
+				"total_tokens":      35,
+				"prompt_tokens_details": map[string]any{
+					"cached_tokens": 1152,
+				},
+			},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		stream, err := model.Stream(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+
+		parts, err := collectStreamParts(stream)
+		require.NoError(t, err)
+
+		// Find finish part
+		var finishPart *ai.StreamPart
+		for _, part := range parts {
+			if part.Type == ai.StreamPartTypeFinish {
+				finishPart = &part
+				break
+			}
+		}
+
+		require.NotNil(t, finishPart)
+		require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
+		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
+		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
+		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
+	})
+
+	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareStreamResponse(map[string]any{
+			"content": []string{},
+			"usage": map[string]any{
+				"prompt_tokens":     15,
+				"completion_tokens": 20,
+				"total_tokens":      35,
+				"completion_tokens_details": map[string]any{
+					"accepted_prediction_tokens": 123,
+					"rejected_prediction_tokens": 456,
+				},
+			},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		stream, err := model.Stream(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+
+		parts, err := collectStreamParts(stream)
+		require.NoError(t, err)
+
+		// Find finish part
+		var finishPart *ai.StreamPart
+		for _, part := range parts {
+			if part.Type == ai.StreamPartTypeFinish {
+				finishPart = &part
+				break
+			}
+		}
+
+		require.NotNil(t, finishPart)
+		require.NotNil(t, finishPart.ProviderMetadata)
+
+		openaiMeta, ok := finishPart.ProviderMetadata["openai"]
+		require.True(t, ok)
+		require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
+		require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
+	})
+
+	t.Run("should send store extension setting", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareStreamResponse(map[string]any{
+			"content": []string{},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		_, err := model.Stream(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"store": true,
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+		require.Equal(t, true, call.body["stream"])
+		require.Equal(t, true, call.body["store"])
+
+		streamOptions := call.body["stream_options"].(map[string]any)
+		require.Equal(t, true, streamOptions["include_usage"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should send metadata extension values", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareStreamResponse(map[string]any{
+			"content": []string{},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-3.5-turbo")
+
+		_, err := model.Stream(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"metadata": map[string]any{
+						"custom": "value",
+					},
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+		require.Equal(t, true, call.body["stream"])
+
+		metadata := call.body["metadata"].(map[string]any)
+		require.Equal(t, "value", metadata["custom"])
+
+		streamOptions := call.body["stream_options"].(map[string]any)
+		require.Equal(t, true, streamOptions["include_usage"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareStreamResponse(map[string]any{
+			"content": []string{},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("o3-mini")
+
+		_, err := model.Stream(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"serviceTier": "flex",
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "o3-mini", call.body["model"])
+		require.Equal(t, "flex", call.body["service_tier"])
+		require.Equal(t, true, call.body["stream"])
+
+		streamOptions := call.body["stream_options"].(map[string]any)
+		require.Equal(t, true, streamOptions["include_usage"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareStreamResponse(map[string]any{
+			"content": []string{},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("gpt-4o-mini")
+
+		_, err := model.Stream(context.Background(), ai.Call{
+			Prompt: testPrompt,
+			ProviderOptions: ai.ProviderOptions{
+				"openai": map[string]any{
+					"serviceTier": "priority",
+				},
+			},
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-4o-mini", call.body["model"])
+		require.Equal(t, "priority", call.body["service_tier"])
+		require.Equal(t, true, call.body["stream"])
+
+		streamOptions := call.body["stream_options"].(map[string]any)
+		require.Equal(t, true, streamOptions["include_usage"])
+
+		messages := call.body["messages"].([]any)
+		require.Len(t, messages, 1)
+
+		message := messages[0].(map[string]any)
+		require.Equal(t, "user", message["role"])
+		require.Equal(t, "Hello", message["content"])
+	})
+
+	t.Run("should stream text delta for reasoning models", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareStreamResponse(map[string]any{
+			"content": []string{"Hello, World!"},
+			"model":   "o1-preview",
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("o1-preview")
+
+		stream, err := model.Stream(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+
+		parts, err := collectStreamParts(stream)
+		require.NoError(t, err)
+
+		// Find text parts
+		var textDeltas []string
+		for _, part := range parts {
+			if part.Type == ai.StreamPartTypeTextDelta {
+				textDeltas = append(textDeltas, part.Delta)
+			}
+		}
+
+		// Should contain the text content (without empty delta)
+		require.Equal(t, []string{"Hello, World!"}, textDeltas)
+	})
+
+	t.Run("should send reasoning tokens", func(t *testing.T) {
+		t.Parallel()
+
+		server := newStreamingMockServer()
+		defer server.close()
+
+		server.prepareStreamResponse(map[string]any{
+			"content": []string{"Hello, World!"},
+			"model":   "o1-preview",
+			"usage": map[string]any{
+				"prompt_tokens":     15,
+				"completion_tokens": 20,
+				"total_tokens":      35,
+				"completion_tokens_details": map[string]any{
+					"reasoning_tokens": 10,
+				},
+			},
+		})
+
+		provider := NewOpenAIProvider(
+			WithOpenAIApiKey("test-api-key"),
+			WithOpenAIBaseURL(server.server.URL),
+		)
+		model := provider.LanguageModel("o1-preview")
+
+		stream, err := model.Stream(context.Background(), ai.Call{
+			Prompt: testPrompt,
+		})
+
+		require.NoError(t, err)
+
+		parts, err := collectStreamParts(stream)
+		require.NoError(t, err)
+
+		// Find finish part
+		var finishPart *ai.StreamPart
+		for _, part := range parts {
+			if part.Type == ai.StreamPartTypeFinish {
+				finishPart = &part
+				break
+			}
+		}
+
+		require.NotNil(t, finishPart)
+		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
+		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
+		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
+		require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
+	})
+}