package openai

import (
	"fmt"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

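// PrepareLanguageModelCallFunc lets callers customize the OpenAI chat
// completion parameters for a call before the request is sent. It may mutate
// params in place and returns warnings for any settings it had to drop.
//
// A minimal sketch of a custom prepare function layered on top of the default
// (forceLowReasoningEffort is a hypothetical name, not part of this package):
//
//	func forceLowReasoningEffort(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
//		warnings, err := defaultPrepareLanguageModelCall(model, params, call)
//		if err != nil {
//			return nil, err
//		}
//		params.ReasoningEffort = shared.ReasoningEffortLow
//		return warnings, nil
//	}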
type PrepareLanguageModelCallFunc = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)

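// defaultPrepareLanguageModelCall applies the OpenAI-specific options carried
// in call.ProviderOptions to params. Settings the target model cannot honor
// are dropped and reported as warnings rather than failing the call.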
func defaultPrepareLanguageModelCall(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
	if call.ProviderOptions == nil {
		return nil, nil
	}
	var warnings []ai.CallWarning
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[Name]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
		}
	}

	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	// TopLogProbs takes precedence: if both are set, drop the bare LogProbs
	// flag so that only the TopLogprobs parameter is forwarded.
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}
	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam;
		// only string content is handled.
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string; non-string values are
		// silently dropped.
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

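	// Map the provider-level reasoning effort onto the SDK's shared enum;
	// unrecognized values are rejected rather than passed through.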
	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case ReasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case ReasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case ReasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case ReasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
		}
	}

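	// Reasoning models do not accept logit bias or logprob settings; clear
	// them and surface warnings instead of failing the request.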
	if isReasoningModel(model.Model()) {
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "LogitBias",
				Message: "LogitBias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			// The zero Opt value marks the field as unset.
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "Logprobs",
				Message: "Logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopLogprobs",
				Message: "TopLogprobs is not supported for reasoning models",
			})
		}
	}

	// Validate that the requested service tier is available for this model.
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
			})
		}
	}
	return warnings, nil
}