diff --git a/openai/openai.go b/openai/openai.go index 1acdfd1b0480bde24a553e185823ca9c8ac1f6f0..19fe1db20700a629de64f4687afeab8922df7be9 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -160,143 +160,6 @@ func (o languageModel) Provider() string { return o.provider } -func prepareCallWithOptions(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) { - if call.ProviderOptions == nil { - return nil, nil - } - var warnings []ai.CallWarning - providerOptions := &ProviderOptions{} - if v, ok := call.ProviderOptions[Name]; ok { - providerOptions, ok = v.(*ProviderOptions) - if !ok { - return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil) - } - } - - if providerOptions.LogitBias != nil { - params.LogitBias = providerOptions.LogitBias - } - if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil { - providerOptions.LogProbs = nil - } - if providerOptions.LogProbs != nil { - params.Logprobs = param.NewOpt(*providerOptions.LogProbs) - } - if providerOptions.TopLogProbs != nil { - params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs) - } - if providerOptions.User != nil { - params.User = param.NewOpt(*providerOptions.User) - } - if providerOptions.ParallelToolCalls != nil { - params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls) - } - if providerOptions.MaxCompletionTokens != nil { - params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens) - } - - if providerOptions.TextVerbosity != nil { - params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity) - } - if providerOptions.Prediction != nil { - // Convert map[string]any to ChatCompletionPredictionContentParam - if content, ok := providerOptions.Prediction["content"]; ok { - if contentStr, ok := content.(string); ok { - params.Prediction = openai.ChatCompletionPredictionContentParam{ - Content: openai.ChatCompletionPredictionContentContentUnionParam{ - OfString: param.NewOpt(contentStr), - }, - } - } - } - } - if providerOptions.Store != nil { - params.Store = param.NewOpt(*providerOptions.Store) - } - if providerOptions.Metadata != nil { - // Convert map[string]any to map[string]string - metadata := make(map[string]string) - for k, v := range providerOptions.Metadata { - if str, ok := v.(string); ok { - metadata[k] = str - } - } - params.Metadata = metadata - } - if providerOptions.PromptCacheKey != nil { - params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey) - } - if providerOptions.SafetyIdentifier != nil { - params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier) - } - if providerOptions.ServiceTier != nil { - params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier) - } - - if providerOptions.ReasoningEffort != nil { - switch *providerOptions.ReasoningEffort { - case ReasoningEffortMinimal: - params.ReasoningEffort = shared.ReasoningEffortMinimal - case ReasoningEffortLow: - params.ReasoningEffort = shared.ReasoningEffortLow - case ReasoningEffortMedium: - params.ReasoningEffort = shared.ReasoningEffortMedium - case ReasoningEffortHigh: - params.ReasoningEffort = shared.ReasoningEffortHigh - default: - return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort) - } - } - - if isReasoningModel(model.Model()) { - if providerOptions.LogitBias != nil { - params.LogitBias = nil - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "LogitBias", - Message: "LogitBias is not supported for reasoning models", - }) - } - if providerOptions.LogProbs != nil { - params.Logprobs = param.Opt[bool]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "Logprobs", - Message: "Logprobs is not supported for reasoning models", - }) - } - if providerOptions.TopLogProbs != nil { - params.TopLogprobs = param.Opt[int64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "TopLogprobs", - Message: "TopLogprobs is not supported for reasoning models", - }) - } - } - - // Handle service tier validation - if providerOptions.ServiceTier != nil { - serviceTier := *providerOptions.ServiceTier - if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) { - params.ServiceTier = "" - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "ServiceTier", - Details: "flex processing is only available for o3, o4-mini, and gpt-5 models", - }) - } else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) { - params.ServiceTier = "" - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "ServiceTier", - Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported", - }) - } - } - return warnings, nil -} - func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) { params := &openai.ChatCompletionNewParams{} messages, warnings := toPrompt(call.Prompt) @@ -797,6 +660,143 @@ func (o *provider) Name() string { return Name } +func prepareCallWithOptions(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) { + if call.ProviderOptions == nil { + return nil, nil + } + var warnings []ai.CallWarning + providerOptions := &ProviderOptions{} + if v, ok := call.ProviderOptions[Name]; ok { + providerOptions, ok = v.(*ProviderOptions) + if !ok { + return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil) + } + } + + if providerOptions.LogitBias != nil { + params.LogitBias = providerOptions.LogitBias + } + if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil { + providerOptions.LogProbs = nil + } + if providerOptions.LogProbs != nil { + params.Logprobs = param.NewOpt(*providerOptions.LogProbs) + } + if providerOptions.TopLogProbs != nil { + params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs) + } + if providerOptions.User != nil { + params.User = param.NewOpt(*providerOptions.User) + } + if providerOptions.ParallelToolCalls != nil { + params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls) + } + if providerOptions.MaxCompletionTokens != nil { + params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens) + } + + if providerOptions.TextVerbosity != nil { + params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity) + } + if providerOptions.Prediction != nil { + // Convert map[string]any to ChatCompletionPredictionContentParam + if content, ok := providerOptions.Prediction["content"]; ok { + if contentStr, ok := content.(string); ok { + params.Prediction = openai.ChatCompletionPredictionContentParam{ + Content: openai.ChatCompletionPredictionContentContentUnionParam{ + OfString: param.NewOpt(contentStr), + }, + } + } + } + } + if providerOptions.Store != nil { + params.Store = param.NewOpt(*providerOptions.Store) + } + if providerOptions.Metadata != nil { + // Convert map[string]any to map[string]string + metadata := make(map[string]string) + for k, v := range providerOptions.Metadata { + if str, ok := v.(string); ok { + metadata[k] = str + } + } + params.Metadata = metadata + } + if providerOptions.PromptCacheKey != nil { + params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey) + } + if providerOptions.SafetyIdentifier != nil { + params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier) + } + if providerOptions.ServiceTier != nil { + params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier) + } + + if providerOptions.ReasoningEffort != nil { + switch *providerOptions.ReasoningEffort { + case ReasoningEffortMinimal: + params.ReasoningEffort = shared.ReasoningEffortMinimal + case ReasoningEffortLow: + params.ReasoningEffort = shared.ReasoningEffortLow + case ReasoningEffortMedium: + params.ReasoningEffort = shared.ReasoningEffortMedium + case ReasoningEffortHigh: + params.ReasoningEffort = shared.ReasoningEffortHigh + default: + return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort) + } + } + + if isReasoningModel(model.Model()) { + if providerOptions.LogitBias != nil { + params.LogitBias = nil + warnings = append(warnings, ai.CallWarning{ + Type: ai.CallWarningTypeUnsupportedSetting, + Setting: "LogitBias", + Message: "LogitBias is not supported for reasoning models", + }) + } + if providerOptions.LogProbs != nil { + params.Logprobs = param.Opt[bool]{} + warnings = append(warnings, ai.CallWarning{ + Type: ai.CallWarningTypeUnsupportedSetting, + Setting: "Logprobs", + Message: "Logprobs is not supported for reasoning models", + }) + } + if providerOptions.TopLogProbs != nil { + params.TopLogprobs = param.Opt[int64]{} + warnings = append(warnings, ai.CallWarning{ + Type: ai.CallWarningTypeUnsupportedSetting, + Setting: "TopLogprobs", + Message: "TopLogprobs is not supported for reasoning models", + }) + } + } + + // Handle service tier validation + if providerOptions.ServiceTier != nil { + serviceTier := *providerOptions.ServiceTier + if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) { + params.ServiceTier = "" + warnings = append(warnings, ai.CallWarning{ + Type: ai.CallWarningTypeUnsupportedSetting, + Setting: "ServiceTier", + Details: "flex processing is only available for o3, o4-mini, and gpt-5 models", + }) + } else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) { + params.ServiceTier = "" + warnings = append(warnings, ai.CallWarning{ + Type: ai.CallWarningTypeUnsupportedSetting, + Setting: "ServiceTier", + Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported", + }) + } + } + return warnings, nil +} + func mapOpenAiFinishReason(finishReason string) ai.FinishReason { switch finishReason { case "stop":