diff --git a/openai/language_model.go b/openai/language_model.go new file mode 100644 index 0000000000000000000000000000000000000000..40a00144a756f17ce401782782f7ef40f82983ec --- /dev/null +++ b/openai/language_model.go @@ -0,0 +1,925 @@ +package openai + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "strings" + + "github.com/charmbracelet/fantasy/ai" + xjson "github.com/charmbracelet/x/json" + "github.com/google/uuid" + "github.com/openai/openai-go/v2" + "github.com/openai/openai-go/v2/packages/param" + "github.com/openai/openai-go/v2/shared" +) + +type languageModel struct { + provider string + modelID string + client openai.Client + prepareCallFunc PrepareLanguageModelCallFunc +} + +type LanguageModelOption = func(*languageModel) + +func WithPrepareLanguageModelCall(fn PrepareLanguageModelCallFunc) LanguageModelOption { + return func(l *languageModel) { + l.prepareCallFunc = fn + } +} + +func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel { + model := languageModel{ + modelID: modelID, + provider: provider, + client: client, + prepareCallFunc: defaultPrepareLanguageModelCall, + } + + for _, o := range opts { + o(&model) + } + return model +} + +type streamToolCall struct { + id string + name string + arguments string + hasFinished bool +} + +// Model implements ai.LanguageModel. +func (o languageModel) Model() string { + return o.modelID +} + +// Provider implements ai.LanguageModel. +func (o languageModel) Provider() string { + return o.provider +} + +func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) { + params := &openai.ChatCompletionNewParams{} + messages, warnings := toPrompt(call.Prompt) + if call.TopK != nil { + warnings = append(warnings, ai.CallWarning{ + Type: ai.CallWarningTypeUnsupportedSetting, + Setting: "top_k", + }) + } + params.Messages = messages + params.Model = o.modelID + + if call.MaxOutputTokens != nil { + params.MaxTokens = param.NewOpt(*call.MaxOutputTokens) + } + if call.Temperature != nil { + params.Temperature = param.NewOpt(*call.Temperature) + } + if call.TopP != nil { + params.TopP = param.NewOpt(*call.TopP) + } + if call.FrequencyPenalty != nil { + params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty) + } + if call.PresencePenalty != nil { + params.PresencePenalty = param.NewOpt(*call.PresencePenalty) + } + + if isReasoningModel(o.modelID) { + // remove unsupported settings for reasoning models + // see https://platform.openai.com/docs/guides/reasoning#limitations + if call.Temperature != nil { + params.Temperature = param.Opt[float64]{} + warnings = append(warnings, ai.CallWarning{ + Type: ai.CallWarningTypeUnsupportedSetting, + Setting: "temperature", + Details: "temperature is not supported for reasoning models", + }) + } + if call.TopP != nil { + params.TopP = param.Opt[float64]{} + warnings = append(warnings, ai.CallWarning{ + Type: ai.CallWarningTypeUnsupportedSetting, + Setting: "TopP", + Details: "TopP is not supported for reasoning models", + }) + } + if call.FrequencyPenalty != nil { + params.FrequencyPenalty = param.Opt[float64]{} + warnings = append(warnings, ai.CallWarning{ + Type: ai.CallWarningTypeUnsupportedSetting, + Setting: "FrequencyPenalty", + Details: "FrequencyPenalty is not supported for reasoning models", + }) + } + if call.PresencePenalty != nil { + params.PresencePenalty = param.Opt[float64]{} + warnings = append(warnings, ai.CallWarning{ + Type: 
ai.CallWarningTypeUnsupportedSetting, + Setting: "PresencePenalty", + Details: "PresencePenalty is not supported for reasoning models", + }) + } + + // reasoning models use max_completion_tokens instead of max_tokens + if call.MaxOutputTokens != nil { + if !params.MaxCompletionTokens.Valid() { + params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens) + } + params.MaxTokens = param.Opt[int64]{} + } + } + + // Handle search preview models + if isSearchPreviewModel(o.modelID) { + if call.Temperature != nil { + params.Temperature = param.Opt[float64]{} + warnings = append(warnings, ai.CallWarning{ + Type: ai.CallWarningTypeUnsupportedSetting, + Setting: "temperature", + Details: "temperature is not supported for the search preview models and has been removed.", + }) + } + } + + optionsWarnings, err := o.prepareCallFunc(o, params, call) + if err != nil { + return nil, nil, err + } + + if len(optionsWarnings) > 0 { + warnings = append(warnings, optionsWarnings...) + } + + if len(call.Tools) > 0 { + tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice) + params.Tools = tools + if toolChoice != nil { + params.ToolChoice = *toolChoice + } + warnings = append(warnings, toolWarnings...) + } + return params, warnings, nil +} + +func (o languageModel) handleError(err error) error { + var apiErr *openai.Error + if errors.As(err, &apiErr) { + requestDump := apiErr.DumpRequest(true) + responseDump := apiErr.DumpResponse(true) + headers := map[string]string{} + for k, h := range apiErr.Response.Header { + v := h[len(h)-1] + headers[strings.ToLower(k)] = v + } + return ai.NewAPICallError( + apiErr.Message, + apiErr.Request.URL.String(), + string(requestDump), + apiErr.StatusCode, + headers, + string(responseDump), + apiErr, + false, + ) + } + return err +} + +// Generate implements ai.LanguageModel. 
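+// It sends a single chat completion request and maps the response into
+// ai.Response content parts (text, tool calls, and URL-citation sources).
+//
+// A minimal usage sketch (illustrative only, not part of this change:
+// assumes a provider from New with credentials configured; ctx and prompt
+// are placeholders and error handling is elided):
+//
+//	model, _ := New().LanguageModel("gpt-4o")
+//	resp, _ := model.Generate(ctx, ai.Call{Prompt: prompt})
+//	for _, part := range resp.Content {
+//		// inspect text, tool calls, and citations here
+//	}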
+func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
+	params, warnings, err := o.prepareParams(call)
+	if err != nil {
+		return nil, err
+	}
+	response, err := o.client.Chat.Completions.New(ctx, *params)
+	if err != nil {
+		return nil, o.handleError(err)
+	}
+
+	if len(response.Choices) == 0 {
+		return nil, errors.New("no response generated")
+	}
+	choice := response.Choices[0]
+	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
+	text := choice.Message.Content
+	if text != "" {
+		content = append(content, ai.TextContent{
+			Text: text,
+		})
+	}
+
+	for _, tc := range choice.Message.ToolCalls {
+		toolCallID := tc.ID
+		if toolCallID == "" {
+			toolCallID = uuid.NewString()
+		}
+		content = append(content, ai.ToolCallContent{
+			ProviderExecuted: false, // TODO: update when handling other tools
+			ToolCallID:       toolCallID,
+			ToolName:         tc.Function.Name,
+			Input:            tc.Function.Arguments,
+		})
+	}
+	// Handle annotations/citations
+	for _, annotation := range choice.Message.Annotations {
+		if annotation.Type == "url_citation" {
+			content = append(content, ai.SourceContent{
+				SourceType: ai.SourceTypeURL,
+				ID:         uuid.NewString(),
+				URL:        annotation.URLCitation.URL,
+				Title:      annotation.URLCitation.Title,
+			})
+		}
+	}
+
+	completionTokenDetails := response.Usage.CompletionTokensDetails
+	promptTokenDetails := response.Usage.PromptTokensDetails
+
+	// Build provider metadata
+	providerMetadata := &ProviderMetadata{}
+	// Add logprobs if available
+	if len(choice.Logprobs.Content) > 0 {
+		providerMetadata.Logprobs = choice.Logprobs.Content
+	}
+
+	// Add prediction tokens if available
+	if completionTokenDetails.AcceptedPredictionTokens > 0 {
+		providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
+	}
+	if completionTokenDetails.RejectedPredictionTokens > 0 {
+		providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
+	}
+
+	return &ai.Response{
+		Content: content,
+		Usage: ai.Usage{
+			InputTokens:     response.Usage.PromptTokens,
+			OutputTokens:    response.Usage.CompletionTokens,
+			TotalTokens:     response.Usage.TotalTokens,
+			ReasoningTokens: completionTokenDetails.ReasoningTokens,
+			CacheReadTokens: promptTokenDetails.CachedTokens,
+		},
+		FinishReason: mapOpenAiFinishReason(choice.FinishReason),
+		ProviderMetadata: ai.ProviderMetadata{
+			Name: providerMetadata,
+		},
+		Warnings: warnings,
+	}, nil
+}
+
+// Stream implements ai.LanguageModel.
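+// The returned ai.StreamResponse is a range-over-func iterator, so callers
+// on Go 1.23+ can consume it with a plain for/range loop.
+//
+// A minimal consumption sketch (same illustrative assumptions as Generate
+// above; ctx and call are placeholders):
+//
+//	stream, _ := model.Stream(ctx, call)
+//	for part := range stream {
+//		if part.Type == ai.StreamPartTypeTextDelta {
+//			fmt.Print(part.Delta)
+//		}
+//	}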
+func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
+	params, warnings, err := o.prepareParams(call)
+	if err != nil {
+		return nil, err
+	}
+
+	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
+		IncludeUsage: openai.Bool(true),
+	}
+
+	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
+	isActiveText := false
+	toolCalls := make(map[int64]streamToolCall)
+
+	// Build provider metadata for streaming
+	streamProviderMetadata := &ProviderMetadata{}
+	acc := openai.ChatCompletionAccumulator{}
+	var usage ai.Usage
+	return func(yield func(ai.StreamPart) bool) {
+		if len(warnings) > 0 {
+			if !yield(ai.StreamPart{
+				Type:     ai.StreamPartTypeWarnings,
+				Warnings: warnings,
+			}) {
+				return
+			}
+		}
+		for stream.Next() {
+			chunk := stream.Current()
+			acc.AddChunk(chunk)
+			if chunk.Usage.TotalTokens > 0 {
+				// we do this here because the acc does not add prompt details
+				completionTokenDetails := chunk.Usage.CompletionTokensDetails
+				promptTokenDetails := chunk.Usage.PromptTokensDetails
+				usage = ai.Usage{
+					InputTokens:     chunk.Usage.PromptTokens,
+					OutputTokens:    chunk.Usage.CompletionTokens,
+					TotalTokens:     chunk.Usage.TotalTokens,
+					ReasoningTokens: completionTokenDetails.ReasoningTokens,
+					CacheReadTokens: promptTokenDetails.CachedTokens,
+				}
+
+				// Add prediction tokens if available
+				if completionTokenDetails.AcceptedPredictionTokens > 0 {
+					streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
+				}
+				if completionTokenDetails.RejectedPredictionTokens > 0 {
+					streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
+				}
+			}
+			if len(chunk.Choices) == 0 {
+				continue
+			}
+			for _, choice := range chunk.Choices {
+				switch {
+				case choice.Delta.Content != "":
+					if !isActiveText {
+						isActiveText = true
+						if !yield(ai.StreamPart{
+							Type: ai.StreamPartTypeTextStart,
+							ID:   "0",
+						}) {
+							return
+						}
+					}
+					if !yield(ai.StreamPart{
+						Type:  ai.StreamPartTypeTextDelta,
+						ID:    "0",
+						Delta: choice.Delta.Content,
+					}) {
+						return
+					}
+				case len(choice.Delta.ToolCalls) > 0:
+					if isActiveText {
+						isActiveText = false
+						if !yield(ai.StreamPart{
+							Type: ai.StreamPartTypeTextEnd,
+							ID:   "0",
+						}) {
+							return
+						}
+					}
+
+					// OpenAI streams tool calls as indexed fragments whose
+					// arguments arrive as partial JSON, so accumulate them
+					// per index and emit the tool call once the JSON parses.
+					for _, toolCallDelta := range choice.Delta.ToolCalls {
+						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
+							if existingToolCall.hasFinished {
+								continue
+							}
+							if toolCallDelta.Function.Arguments != "" {
+								existingToolCall.arguments += toolCallDelta.Function.Arguments
+							}
+							if !yield(ai.StreamPart{
+								Type:  ai.StreamPartTypeToolInputDelta,
+								ID:    existingToolCall.id,
+								Delta: toolCallDelta.Function.Arguments,
+							}) {
+								return
+							}
+							toolCalls[toolCallDelta.Index] = existingToolCall
+							if xjson.IsValid(existingToolCall.arguments) {
+								if !yield(ai.StreamPart{
+									Type: ai.StreamPartTypeToolInputEnd,
+									ID:   existingToolCall.id,
+								}) {
+									return
+								}
+
+								if !yield(ai.StreamPart{
+									Type:          ai.StreamPartTypeToolCall,
+									ID:            existingToolCall.id,
+									ToolCallName:  existingToolCall.name,
+									ToolCallInput: existingToolCall.arguments,
+								}) {
+									return
+								}
+								existingToolCall.hasFinished = true
+								toolCalls[toolCallDelta.Index] = existingToolCall
+							}
+						} else {
+							// First fragment of a new tool call: validate it
+							// before tracking it.
+							var err error
+							if toolCallDelta.Type != "function" {
+								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
+							}
+							if toolCallDelta.ID == "" {
+								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
+							}
+							if toolCallDelta.Function.Name == "" {
+								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
+							}
+							if err != nil {
+								// Yield the validation error itself;
+								// stream.Err() is nil while the stream is
+								// still active.
+								yield(ai.StreamPart{
+									Type:  ai.StreamPartTypeError,
+									Error: err,
+								})
+								return
+							}
+
+							if !yield(ai.StreamPart{
+								Type:         ai.StreamPartTypeToolInputStart,
+								ID:           toolCallDelta.ID,
+								ToolCallName: toolCallDelta.Function.Name,
+							}) {
+								return
+							}
+							toolCalls[toolCallDelta.Index] = streamToolCall{
+								id:        toolCallDelta.ID,
+								name:      toolCallDelta.Function.Name,
+								arguments: toolCallDelta.Function.Arguments,
+							}
+
+							exTc := toolCalls[toolCallDelta.Index]
+							if exTc.arguments != "" {
+								if !yield(ai.StreamPart{
+									Type:  ai.StreamPartTypeToolInputDelta,
+									ID:    exTc.id,
+									Delta: exTc.arguments,
+								}) {
+									return
+								}
+								if xjson.IsValid(exTc.arguments) {
+									if !yield(ai.StreamPart{
+										Type: ai.StreamPartTypeToolInputEnd,
+										ID:   toolCallDelta.ID,
+									}) {
+										return
+									}
+
+									if !yield(ai.StreamPart{
+										Type:          ai.StreamPartTypeToolCall,
+										ID:            exTc.id,
+										ToolCallName:  exTc.name,
+										ToolCallInput: exTc.arguments,
+									}) {
+										return
+									}
+									exTc.hasFinished = true
+									toolCalls[toolCallDelta.Index] = exTc
+								}
+							}
+							continue
+						}
+					}
+				}
+			}
+
+			// Check for annotations in the delta's raw JSON
+			for _, choice := range chunk.Choices {
+				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
+					for _, annotation := range annotations {
+						if annotation.Type == "url_citation" {
+							if !yield(ai.StreamPart{
+								Type:       ai.StreamPartTypeSource,
+								ID:         uuid.NewString(),
+								SourceType: ai.SourceTypeURL,
+								URL:        annotation.URLCitation.URL,
+								Title:      annotation.URLCitation.Title,
+							}) {
+								return
+							}
+						}
+					}
+				}
+			}
+		}
+		err := stream.Err()
+		if err == nil || errors.Is(err, io.EOF) {
+			// finished
+			if isActiveText {
+				isActiveText = false
+				if !yield(ai.StreamPart{
+					Type: ai.StreamPartTypeTextEnd,
+					ID:   "0",
+				}) {
+					return
+				}
+			}
+
+			// Add logprobs if available
+			if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
+				streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
+			}
+
+			// Handle annotations/citations from accumulated response
+			if len(acc.Choices) > 0 {
+				for _, annotation := range acc.Choices[0].Message.Annotations {
+					if annotation.Type == "url_citation" {
+						if !yield(ai.StreamPart{
+							Type:       ai.StreamPartTypeSource,
+							ID:         acc.ID,
+							SourceType: ai.SourceTypeURL,
+							URL:        annotation.URLCitation.URL,
+							Title:      annotation.URLCitation.Title,
+						}) {
+							return
+						}
+					}
+				}
+			}
+
+			// Guard the accumulator access: a stream can finish without
+			// delivering any choices.
+			finishReason := ai.FinishReasonUnknown
+			if len(acc.Choices) > 0 {
+				finishReason = mapOpenAiFinishReason(acc.Choices[0].FinishReason)
+			}
+			yield(ai.StreamPart{
+				Type:         ai.StreamPartTypeFinish,
+				Usage:        usage,
+				FinishReason: finishReason,
+				ProviderMetadata: ai.ProviderMetadata{
+					Name: streamProviderMetadata,
+				},
+			})
+			return
+		}
+		yield(ai.StreamPart{
+			Type:  ai.StreamPartTypeError,
+			Error: o.handleError(err),
+		})
+	}, nil
+}
+
+func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
+	switch finishReason {
+	case "stop":
+		return ai.FinishReasonStop
+	case "length":
+		return ai.FinishReasonLength
+	case "content_filter":
+		return ai.FinishReasonContentFilter
+	case "function_call", "tool_calls":
+		return ai.FinishReasonToolCalls
+	default:
+		return ai.FinishReasonUnknown
+	}
+}
+
+func isReasoningModel(modelID string) bool {
+	// gpt-5-chat models share the gpt-5 prefix but are conversational, not
+	// reasoning, models, so they are excluded here.
+	return (strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")) &&
+		!strings.HasPrefix(modelID, "gpt-5-chat")
+}
+
+func isSearchPreviewModel(modelID string) bool {
+	return strings.Contains(modelID, "search-preview")
+}
+
+func supportsFlexProcessing(modelID string) bool {
+	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
+}
+
+func supportsPriorityProcessing(modelID string) bool {
+	// gpt-5-nano shares the gpt-5 prefix but does not support priority
+	// processing, so it is excluded explicitly.
+	return strings.HasPrefix(modelID, "gpt-4") ||
+		strings.HasPrefix(modelID, "gpt-5-mini") ||
+		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-nano")) ||
+		strings.HasPrefix(modelID, "o3") ||
+		strings.HasPrefix(modelID, "o4-mini")
+}
+
+func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
+	for _, tool := range tools {
+		if tool.GetType() == ai.ToolTypeFunction {
+			ft, ok := tool.(ai.FunctionTool)
+			if !ok {
+				continue
+			}
+			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
+				OfFunction: &openai.ChatCompletionFunctionToolParam{
+					Function: shared.FunctionDefinitionParam{
+						Name:        ft.Name,
+						Description: param.NewOpt(ft.Description),
+						Parameters:  openai.FunctionParameters(ft.InputSchema),
+						Strict:      param.NewOpt(false),
+					},
+					Type: "function",
+				},
+			})
+			continue
+		}
+
+		// TODO: handle provider tool calls
+		warnings = append(warnings, ai.CallWarning{
+			Type:    ai.CallWarningTypeUnsupportedTool,
+			Tool:    tool,
+			Message: "tool is not supported",
+		})
+	}
+	if toolChoice == nil {
+		return openAiTools, openAiToolChoice, warnings
+	}
+
+	switch *toolChoice {
+	case ai.ToolChoiceAuto:
+		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
+			OfAuto: param.NewOpt("auto"),
+		}
+	case ai.ToolChoiceNone:
+		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
+			OfAuto: param.NewOpt("none"),
+		}
+	default:
+		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
+			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
+				Type: "function",
+				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
+					Name: string(*toolChoice),
+				},
+			},
+		}
+	}
+	return openAiTools, openAiToolChoice, warnings
+}
+
+func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
+	var messages []openai.ChatCompletionMessageParamUnion
+	var warnings []ai.CallWarning
+	for _, msg := range prompt {
+		switch msg.Role {
+		case ai.MessageRoleSystem:
+			var systemPromptParts []string
+			for _, c := range msg.Content {
+				if c.GetType() != ai.ContentTypeText {
+					warnings = append(warnings, ai.CallWarning{
+						Type:    ai.CallWarningTypeOther,
+						Message: "system prompt can only have text content",
+					})
+					continue
+				}
+				textPart, ok := ai.AsContentType[ai.TextPart](c)
+				if !ok {
+					warnings = append(warnings, ai.CallWarning{
+						Type:    ai.CallWarningTypeOther,
+						Message: "system prompt text part does not have the right type",
+					})
+					continue
+				}
+				if strings.TrimSpace(textPart.Text) != "" {
+					systemPromptParts = append(systemPromptParts, textPart.Text)
+				}
+			}
+			if len(systemPromptParts) == 0 {
+				warnings = append(warnings, ai.CallWarning{
+					Type:    ai.CallWarningTypeOther,
+					Message: "system prompt has no text parts",
+				})
+				continue
+			}
+			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
+		case ai.MessageRoleUser:
+			// simple user message just text content
+			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
+				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
+				if !ok {
+					warnings = append(warnings, ai.CallWarning{
+						Type:    ai.CallWarningTypeOther,
+						Message: "user message text part does not have the right type",
+					})
+					continue
+				}
+				messages = append(messages, openai.UserMessage(textPart.Text))
+				continue
+			}
+			// Text content plus attachments. Only a limited set of media
+			// types is handled for now.
+			// TODO: add the supported media types to the language model so we
+			// can use that to validate the data here.
+			var content []openai.ChatCompletionContentPartUnionParam
+			for _, c := range msg.Content {
+				switch c.GetType() {
+				case ai.ContentTypeText:
+					textPart, ok := ai.AsContentType[ai.TextPart](c)
+					if !ok {
+						warnings = append(warnings, ai.CallWarning{
+							Type:    ai.CallWarningTypeOther,
+							Message: "user message text part does not have the right type",
+						})
+						continue
+					}
+					content = append(content, openai.ChatCompletionContentPartUnionParam{
+						OfText: &openai.ChatCompletionContentPartTextParam{
+							Text: textPart.Text,
+						},
+					})
+				case ai.ContentTypeFile:
+					filePart, ok := ai.AsContentType[ai.FilePart](c)
+					if !ok {
+						warnings = append(warnings, ai.CallWarning{
+							Type:    ai.CallWarningTypeOther,
+							Message: "user message file part does not have the right type",
+						})
+						continue
+					}
+
+					switch {
+					case strings.HasPrefix(filePart.MediaType, "image/"):
+						// Handle image files
+						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
+						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
+						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}
+
+						// Check for provider-specific options like image detail
+						if providerOptions, ok := filePart.ProviderOptions[Name]; ok {
+							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
+								imageURL.Detail = detail.ImageDetail
+							}
+						}
+
+						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
+						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
+
+					case filePart.MediaType == "audio/wav":
+						// Handle WAV audio files
+						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
+						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
+							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
+								Data:   base64Encoded,
+								Format: "wav",
+							},
+						}
+						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
+
+					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
+						// Handle MP3 audio files
+						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
+						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
+							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
+								Data:   base64Encoded,
+								Format: "mp3",
+							},
+						}
+						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
+
+					case filePart.MediaType == "application/pdf":
+						// Handle PDF files
+						dataStr := string(filePart.Data)
+
+						// Check if data looks like a file ID (starts with "file-")
+						if strings.HasPrefix(dataStr, "file-") {
+							fileBlock := openai.ChatCompletionContentPartFileParam{
+								File: openai.ChatCompletionContentPartFileFileParam{
+									FileID: param.NewOpt(dataStr),
+								},
+							}
+							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
+						} else {
+							// Handle as base64 data
+							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
+							data := "data:application/pdf;base64," + base64Encoded
+
+							filename := filePart.Filename
+							if filename == "" {
+								// Generate default filename based on content index
+								filename = fmt.Sprintf("part-%d.pdf", len(content))
+							}
+
+							fileBlock := openai.ChatCompletionContentPartFileParam{
+								File: openai.ChatCompletionContentPartFileFileParam{
+									Filename: param.NewOpt(filename),
+									FileData: param.NewOpt(data),
+								},
+							}
+							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
+						}
+
+					default:
+						warnings = append(warnings, ai.CallWarning{
+							Type:    ai.CallWarningTypeOther,
+							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
+						})
+					}
+				}
+			}
+			messages = append(messages, openai.UserMessage(content))
+		case ai.MessageRoleAssistant:
+			// simple assistant message just text content
+			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
+				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
+				if !ok {
+					warnings = append(warnings, ai.CallWarning{
+						Type:    ai.CallWarningTypeOther,
+						Message: "assistant message text part does not have the right type",
+					})
+					continue
+				}
+				messages = append(messages, openai.AssistantMessage(textPart.Text))
+				continue
+			}
+			assistantMsg := openai.ChatCompletionAssistantMessageParam{
+				Role: "assistant",
+			}
+			for _, c := range msg.Content {
+				switch c.GetType() {
+				case ai.ContentTypeText:
+					textPart, ok := ai.AsContentType[ai.TextPart](c)
+					if !ok {
+						warnings = append(warnings, ai.CallWarning{
+							Type:    ai.CallWarningTypeOther,
+							Message: "assistant message text part does not have the right type",
+						})
+						continue
+					}
+					// Concatenate instead of overwriting so multiple text
+					// parts in one assistant message are preserved.
+					text := textPart.Text
+					if assistantMsg.Content.OfString.Valid() {
+						text = assistantMsg.Content.OfString.Value + text
+					}
+					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
+						OfString: param.NewOpt(text),
+					}
+				case ai.ContentTypeToolCall:
+					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
+					if !ok {
+						warnings = append(warnings, ai.CallWarning{
+							Type:    ai.CallWarningTypeOther,
+							Message: "assistant message tool part does not have the right type",
+						})
+						continue
+					}
+					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
+						openai.ChatCompletionMessageToolCallUnionParam{
+							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
+								ID:   toolCallPart.ToolCallID,
+								Type: "function",
+								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
+									Name:      toolCallPart.ToolName,
+									Arguments: toolCallPart.Input,
+								},
+							},
+						})
+				}
+			}
+			messages = append(messages, openai.ChatCompletionMessageParamUnion{
+				OfAssistant: &assistantMsg,
+			})
+		case ai.MessageRoleTool:
+			for _, c := range msg.Content {
+				if c.GetType() != ai.ContentTypeToolResult {
+					warnings = append(warnings, ai.CallWarning{
+						Type:    ai.CallWarningTypeOther,
+						Message: "tool message can only have tool result content",
+					})
+					continue
+				}
+
+				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
+				if !ok {
+					warnings = append(warnings, ai.CallWarning{
+						Type:    ai.CallWarningTypeOther,
+						Message: "tool message result part does not have the right type",
+					})
+					continue
+				}
+
+				switch toolResultPart.Output.GetType() {
+				case ai.ToolResultContentTypeText:
+					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
+					if !ok {
+						warnings = append(warnings, ai.CallWarning{
+							Type:    ai.CallWarningTypeOther,
+							Message: "tool result output does not have the right type",
+						})
+						continue
+					}
+					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
+				case ai.ToolResultContentTypeError:
+					// TODO: check if better handling is needed
+					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
+					if !ok {
+						warnings = append(warnings, ai.CallWarning{
+							Type:    ai.CallWarningTypeOther,
+							Message: "tool result output does not have the right type",
+						})
+						continue
+					}
+					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
+				}
+			}
+		}
+	}
+	return messages, warnings
+}
+
+// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
+func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
+	var annotations []openai.ChatCompletionMessageAnnotation
+
+	// Parse the raw JSON to extract annotations
+	var deltaData map[string]any
+	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
+		return annotations
+	}
+
+	// Check if annotations exist in the delta
+	if annotationsData, ok := deltaData["annotations"].([]any); ok {
+		for _, annotationData := range annotationsData {
+			annotationMap, ok := annotationData.(map[string]any)
+			if !ok {
+				continue
+			}
+			annotationType, ok := annotationMap["type"].(string)
+			if !ok || annotationType != "url_citation" {
+				continue
+			}
+			urlCitationData, ok := annotationMap["url_citation"].(map[string]any)
+			if !ok {
+				continue
+			}
+			// Use comma-ok assertions so a malformed citation is skipped
+			// instead of panicking.
+			url, ok := urlCitationData["url"].(string)
+			if !ok {
+				continue
+			}
+			title, _ := urlCitationData["title"].(string)
+			annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
+				Type: "url_citation",
+				URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
+					URL:   url,
+					Title: title,
+				},
+			})
+		}
+	}
+
+	return annotations
+}
diff --git a/openai/language_model_hooks.go b/openai/language_model_hooks.go
new file mode 100644
index 0000000000000000000000000000000000000000..537c88f6e2af8ae9082e0aae0339866e60521b2c
--- /dev/null
+++ b/openai/language_model_hooks.go
@@ -0,0 +1,149 @@
+package openai
+
+import (
+	"fmt"
+
+	"github.com/charmbracelet/fantasy/ai"
+	"github.com/openai/openai-go/v2"
+	"github.com/openai/openai-go/v2/packages/param"
+	"github.com/openai/openai-go/v2/shared"
+)
+
+type PrepareLanguageModelCallFunc = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
+
+func defaultPrepareLanguageModelCall(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
+	if call.ProviderOptions == nil {
+		return nil, nil
+	}
+	var warnings []ai.CallWarning
+	providerOptions := &ProviderOptions{}
+	if v, ok := call.ProviderOptions[Name]; ok {
+		providerOptions, ok = v.(*ProviderOptions)
+		if !ok {
+			return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
+		}
+	}
+
+	if providerOptions.LogitBias != nil {
+		params.LogitBias = providerOptions.LogitBias
+	}
+	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
+		providerOptions.LogProbs = nil
+	}
+	if providerOptions.LogProbs != nil {
+		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
+	}
+	if providerOptions.TopLogProbs != nil {
+		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
+	}
+	if providerOptions.User != nil {
+		params.User = param.NewOpt(*providerOptions.User)
+	}
+	if providerOptions.ParallelToolCalls != nil {
+		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
+	}
+	if providerOptions.MaxCompletionTokens != nil {
+		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
+	}
+
+	if providerOptions.TextVerbosity != nil {
+		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
+	}
+	if providerOptions.Prediction != nil {
+		// Convert map[string]any to ChatCompletionPredictionContentParam
+		if content, ok := providerOptions.Prediction["content"]; ok {
+			if contentStr, ok := content.(string); ok {
+				params.Prediction = openai.ChatCompletionPredictionContentParam{
+					Content: openai.ChatCompletionPredictionContentContentUnionParam{
+						OfString: param.NewOpt(contentStr),
+					},
+				}
+			}
+		}
+	}
+	if providerOptions.Store != nil {
+		params.Store = param.NewOpt(*providerOptions.Store)
+	}
+	if providerOptions.Metadata != nil {
+		// Convert map[string]any to map[string]string
+		metadata := make(map[string]string)
+		for k, v := range providerOptions.Metadata {
+			if str, ok := v.(string); ok {
+				metadata[k] = str
+			}
+		}
+		params.Metadata = metadata
+	}
+	if providerOptions.PromptCacheKey != nil {
+		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
+	}
+	if providerOptions.SafetyIdentifier != nil {
+		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
+	}
+	if providerOptions.ServiceTier != nil {
+		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
+	}
+
+	if providerOptions.ReasoningEffort != nil {
+		switch *providerOptions.ReasoningEffort {
+		case ReasoningEffortMinimal:
+			params.ReasoningEffort = shared.ReasoningEffortMinimal
+		case ReasoningEffortLow:
+			params.ReasoningEffort = shared.ReasoningEffortLow
+		case ReasoningEffortMedium:
+			params.ReasoningEffort = shared.ReasoningEffortMedium
+		case ReasoningEffortHigh:
+			params.ReasoningEffort = shared.ReasoningEffortHigh
+		default:
+			return nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
+		}
+	}
+
+	if isReasoningModel(model.Model()) {
+		if providerOptions.LogitBias != nil {
+			params.LogitBias = nil
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "LogitBias",
+				Message: "LogitBias is not supported for reasoning models",
+			})
+		}
+		if providerOptions.LogProbs != nil {
+			params.Logprobs = param.Opt[bool]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "Logprobs",
+				Message: "Logprobs is not supported for reasoning models",
+			})
+		}
+		if providerOptions.TopLogProbs != nil {
+			params.TopLogprobs = param.Opt[int64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "TopLogprobs",
+				Message: "TopLogprobs is not supported for reasoning models",
+			})
+		}
+	}
+
+	// Handle service tier validation
+	if providerOptions.ServiceTier != nil {
+		serviceTier := *providerOptions.ServiceTier
+		if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
+			params.ServiceTier = ""
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "ServiceTier",
+				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
+			})
+		} else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
+			params.ServiceTier = ""
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "ServiceTier",
+				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access.
gpt-5-nano is not supported", + }) + } + } + return warnings, nil +} diff --git a/openai/openai.go b/openai/openai.go index 19fe1db20700a629de64f4687afeab8922df7be9..bde785f168acf8855656a59b9418ffc2b5a7c48f 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -2,22 +2,11 @@ package openai import ( "cmp" - "context" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" "maps" - "strings" "github.com/charmbracelet/fantasy/ai" - xjson "github.com/charmbracelet/x/json" - "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" - "github.com/openai/openai-go/v2/packages/param" - "github.com/openai/openai-go/v2/shared" ) const ( @@ -29,28 +18,23 @@ type provider struct { options options } -type PrepareCallWithOptions = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) - -type Hooks struct { - PrepareCallWithOptions PrepareCallWithOptions -} - type options struct { - baseURL string - apiKey string - organization string - project string - name string - hooks Hooks - headers map[string]string - client option.HTTPClient + baseURL string + apiKey string + organization string + project string + name string + headers map[string]string + client option.HTTPClient + languageModelOptions []LanguageModelOption } type Option = func(*options) func New(opts ...Option) ai.Provider { providerOptions := options{ - headers: map[string]string{}, + headers: map[string]string{}, + languageModelOptions: make([]LanguageModelOption, 0), } for _, o := range opts { o(&providerOptions) @@ -111,9 +95,9 @@ func WithHTTPClient(client option.HTTPClient) Option { } } -func WithHooks(hooks Hooks) Option { +func WithLanguageModelOptions(opts ...LanguageModelOption) Option { return func(o *options) { - o.hooks = hooks + o.languageModelOptions = append(o.languageModelOptions, opts...) } } @@ -135,517 +119,12 @@ func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) { openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client)) } - return languageModel{ - modelID: modelID, - provider: o.options.name, - options: o.options, - client: openai.NewClient(openaiClientOptions...), - }, nil -} - -type languageModel struct { - provider string - modelID string - client openai.Client - options options -} - -// Model implements ai.LanguageModel. -func (o languageModel) Model() string { - return o.modelID -} - -// Provider implements ai.LanguageModel. 
-func (o languageModel) Provider() string { - return o.provider -} - -func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) { - params := &openai.ChatCompletionNewParams{} - messages, warnings := toPrompt(call.Prompt) - if call.TopK != nil { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "top_k", - }) - } - params.Messages = messages - params.Model = o.modelID - - if call.MaxOutputTokens != nil { - params.MaxTokens = param.NewOpt(*call.MaxOutputTokens) - } - if call.Temperature != nil { - params.Temperature = param.NewOpt(*call.Temperature) - } - if call.TopP != nil { - params.TopP = param.NewOpt(*call.TopP) - } - if call.FrequencyPenalty != nil { - params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty) - } - if call.PresencePenalty != nil { - params.PresencePenalty = param.NewOpt(*call.PresencePenalty) - } - - if isReasoningModel(o.modelID) { - // remove unsupported settings for reasoning models - // see https://platform.openai.com/docs/guides/reasoning#limitations - if call.Temperature != nil { - params.Temperature = param.Opt[float64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "temperature", - Details: "temperature is not supported for reasoning models", - }) - } - if call.TopP != nil { - params.TopP = param.Opt[float64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "TopP", - Details: "TopP is not supported for reasoning models", - }) - } - if call.FrequencyPenalty != nil { - params.FrequencyPenalty = param.Opt[float64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "FrequencyPenalty", - Details: "FrequencyPenalty is not supported for reasoning models", - }) - } - if call.PresencePenalty != nil { - params.PresencePenalty = param.Opt[float64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "PresencePenalty", - Details: "PresencePenalty is not supported for reasoning models", - }) - } - - // reasoning models use max_completion_tokens instead of max_tokens - if call.MaxOutputTokens != nil { - if !params.MaxCompletionTokens.Valid() { - params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens) - } - params.MaxTokens = param.Opt[int64]{} - } - } - - // Handle search preview models - if isSearchPreviewModel(o.modelID) { - if call.Temperature != nil { - params.Temperature = param.Opt[float64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "temperature", - Details: "temperature is not supported for the search preview models and has been removed.", - }) - } - } - - prepareOptions := prepareCallWithOptions - if o.options.hooks.PrepareCallWithOptions != nil { - prepareOptions = o.options.hooks.PrepareCallWithOptions - } - - optionsWarnings, err := prepareOptions(o, params, call) - if err != nil { - return nil, nil, err - } - - if len(optionsWarnings) > 0 { - warnings = append(warnings, optionsWarnings...) - } - - if len(call.Tools) > 0 { - tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice) - params.Tools = tools - if toolChoice != nil { - params.ToolChoice = *toolChoice - } - warnings = append(warnings, toolWarnings...) 
- } - return params, warnings, nil -} - -func (o languageModel) handleError(err error) error { - var apiErr *openai.Error - if errors.As(err, &apiErr) { - requestDump := apiErr.DumpRequest(true) - responseDump := apiErr.DumpResponse(true) - headers := map[string]string{} - for k, h := range apiErr.Response.Header { - v := h[len(h)-1] - headers[strings.ToLower(k)] = v - } - return ai.NewAPICallError( - apiErr.Message, - apiErr.Request.URL.String(), - string(requestDump), - apiErr.StatusCode, - headers, - string(responseDump), - apiErr, - false, - ) - } - return err -} - -// Generate implements ai.LanguageModel. -func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { - params, warnings, err := o.prepareParams(call) - if err != nil { - return nil, err - } - response, err := o.client.Chat.Completions.New(ctx, *params) - if err != nil { - return nil, o.handleError(err) - } - - if len(response.Choices) == 0 { - return nil, errors.New("no response generated") - } - choice := response.Choices[0] - content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations)) - text := choice.Message.Content - if text != "" { - content = append(content, ai.TextContent{ - Text: text, - }) - } - - for _, tc := range choice.Message.ToolCalls { - toolCallID := tc.ID - if toolCallID == "" { - toolCallID = uuid.NewString() - } - content = append(content, ai.ToolCallContent{ - ProviderExecuted: false, // TODO: update when handling other tools - ToolCallID: toolCallID, - ToolName: tc.Function.Name, - Input: tc.Function.Arguments, - }) - } - // Handle annotations/citations - for _, annotation := range choice.Message.Annotations { - if annotation.Type == "url_citation" { - content = append(content, ai.SourceContent{ - SourceType: ai.SourceTypeURL, - ID: uuid.NewString(), - URL: annotation.URLCitation.URL, - Title: annotation.URLCitation.Title, - }) - } - } - - completionTokenDetails := response.Usage.CompletionTokensDetails - promptTokenDetails := response.Usage.PromptTokensDetails - - // Build provider metadata - providerMetadata := &ProviderMetadata{} - // Add logprobs if available - if len(choice.Logprobs.Content) > 0 { - providerMetadata.Logprobs = choice.Logprobs.Content - } - - // Add prediction tokens if available - if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 { - if completionTokenDetails.AcceptedPredictionTokens > 0 { - providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens - } - if completionTokenDetails.RejectedPredictionTokens > 0 { - providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens - } - } - - return &ai.Response{ - Content: content, - Usage: ai.Usage{ - InputTokens: response.Usage.PromptTokens, - OutputTokens: response.Usage.CompletionTokens, - TotalTokens: response.Usage.TotalTokens, - ReasoningTokens: completionTokenDetails.ReasoningTokens, - CacheReadTokens: promptTokenDetails.CachedTokens, - }, - FinishReason: mapOpenAiFinishReason(choice.FinishReason), - ProviderMetadata: ai.ProviderMetadata{ - Name: providerMetadata, - }, - Warnings: warnings, - }, nil -} - -type toolCall struct { - id string - name string - arguments string - hasFinished bool -} - -// Stream implements ai.LanguageModel. 
-func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) { - params, warnings, err := o.prepareParams(call) - if err != nil { - return nil, err - } - - params.StreamOptions = openai.ChatCompletionStreamOptionsParam{ - IncludeUsage: openai.Bool(true), - } - - stream := o.client.Chat.Completions.NewStreaming(ctx, *params) - isActiveText := false - toolCalls := make(map[int64]toolCall) - - // Build provider metadata for streaming - streamProviderMetadata := &ProviderMetadata{} - acc := openai.ChatCompletionAccumulator{} - var usage ai.Usage - return func(yield func(ai.StreamPart) bool) { - if len(warnings) > 0 { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeWarnings, - Warnings: warnings, - }) { - return - } - } - for stream.Next() { - chunk := stream.Current() - acc.AddChunk(chunk) - if chunk.Usage.TotalTokens > 0 { - // we do this here because the acc does not add prompt details - completionTokenDetails := chunk.Usage.CompletionTokensDetails - promptTokenDetails := chunk.Usage.PromptTokensDetails - usage = ai.Usage{ - InputTokens: chunk.Usage.PromptTokens, - OutputTokens: chunk.Usage.CompletionTokens, - TotalTokens: chunk.Usage.TotalTokens, - ReasoningTokens: completionTokenDetails.ReasoningTokens, - CacheReadTokens: promptTokenDetails.CachedTokens, - } - - // Add prediction tokens if available - if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 { - if completionTokenDetails.AcceptedPredictionTokens > 0 { - streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens - } - if completionTokenDetails.RejectedPredictionTokens > 0 { - streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens - } - } - } - if len(chunk.Choices) == 0 { - continue - } - for _, choice := range chunk.Choices { - switch { - case choice.Delta.Content != "": - if !isActiveText { - isActiveText = true - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextStart, - ID: "0", - }) { - return - } - } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextDelta, - ID: "0", - Delta: choice.Delta.Content, - }) { - return - } - case len(choice.Delta.ToolCalls) > 0: - if isActiveText { - isActiveText = false - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextEnd, - ID: "0", - }) { - return - } - } - - for _, toolCallDelta := range choice.Delta.ToolCalls { - if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok { - if existingToolCall.hasFinished { - continue - } - if toolCallDelta.Function.Arguments != "" { - existingToolCall.arguments += toolCallDelta.Function.Arguments - } - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputDelta, - ID: existingToolCall.id, - Delta: toolCallDelta.Function.Arguments, - }) { - return - } - toolCalls[toolCallDelta.Index] = existingToolCall - if xjson.IsValid(existingToolCall.arguments) { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputEnd, - ID: existingToolCall.id, - }) { - return - } - - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolCall, - ID: existingToolCall.id, - ToolCallName: existingToolCall.name, - ToolCallInput: existingToolCall.arguments, - }) { - return - } - existingToolCall.hasFinished = true - toolCalls[toolCallDelta.Index] = existingToolCall - } - } else { - // Does not exist - var err error - if toolCallDelta.Type != "function" { - err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.") - } - if toolCallDelta.ID == "" { - err = 
ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.") - } - if toolCallDelta.Function.Name == "" { - err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.") - } - if err != nil { - yield(ai.StreamPart{ - Type: ai.StreamPartTypeError, - Error: o.handleError(stream.Err()), - }) - return - } - - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputStart, - ID: toolCallDelta.ID, - ToolCallName: toolCallDelta.Function.Name, - }) { - return - } - toolCalls[toolCallDelta.Index] = toolCall{ - id: toolCallDelta.ID, - name: toolCallDelta.Function.Name, - arguments: toolCallDelta.Function.Arguments, - } - - exTc := toolCalls[toolCallDelta.Index] - if exTc.arguments != "" { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputDelta, - ID: exTc.id, - Delta: exTc.arguments, - }) { - return - } - if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolInputEnd, - ID: toolCallDelta.ID, - }) { - return - } - - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeToolCall, - ID: exTc.id, - ToolCallName: exTc.name, - ToolCallInput: exTc.arguments, - }) { - return - } - exTc.hasFinished = true - toolCalls[toolCallDelta.Index] = exTc - } - } - continue - } - } - } - } - - // Check for annotations in the delta's raw JSON - for _, choice := range chunk.Choices { - if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 { - for _, annotation := range annotations { - if annotation.Type == "url_citation" { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeSource, - ID: uuid.NewString(), - SourceType: ai.SourceTypeURL, - URL: annotation.URLCitation.URL, - Title: annotation.URLCitation.Title, - }) { - return - } - } - } - } - } - } - err := stream.Err() - if err == nil || errors.Is(err, io.EOF) { - // finished - if isActiveText { - isActiveText = false - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeTextEnd, - ID: "0", - }) { - return - } - } - - // Add logprobs if available - if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 { - streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content - } - - // Handle annotations/citations from accumulated response - if len(acc.Choices) > 0 { - for _, annotation := range acc.Choices[0].Message.Annotations { - if annotation.Type == "url_citation" { - if !yield(ai.StreamPart{ - Type: ai.StreamPartTypeSource, - ID: acc.ID, - SourceType: ai.SourceTypeURL, - URL: annotation.URLCitation.URL, - Title: annotation.URLCitation.Title, - }) { - return - } - } - } - } - - finishReason := mapOpenAiFinishReason(acc.Choices[0].FinishReason) - yield(ai.StreamPart{ - Type: ai.StreamPartTypeFinish, - Usage: usage, - FinishReason: finishReason, - ProviderMetadata: ai.ProviderMetadata{ - Name: streamProviderMetadata, - }, - }) - return - } else { - yield(ai.StreamPart{ - Type: ai.StreamPartTypeError, - Error: o.handleError(err), - }) - return - } - }, nil + return newLanguageModel( + modelID, + o.options.name, + openai.NewClient(openaiClientOptions...), + o.options.languageModelOptions..., + ), nil } func (o *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) { @@ -659,525 +138,3 @@ func (o *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, er func (o *provider) Name() string { return Name } - -func prepareCallWithOptions(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) { - if call.ProviderOptions == nil { - return 
nil, nil - } - var warnings []ai.CallWarning - providerOptions := &ProviderOptions{} - if v, ok := call.ProviderOptions[Name]; ok { - providerOptions, ok = v.(*ProviderOptions) - if !ok { - return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil) - } - } - - if providerOptions.LogitBias != nil { - params.LogitBias = providerOptions.LogitBias - } - if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil { - providerOptions.LogProbs = nil - } - if providerOptions.LogProbs != nil { - params.Logprobs = param.NewOpt(*providerOptions.LogProbs) - } - if providerOptions.TopLogProbs != nil { - params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs) - } - if providerOptions.User != nil { - params.User = param.NewOpt(*providerOptions.User) - } - if providerOptions.ParallelToolCalls != nil { - params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls) - } - if providerOptions.MaxCompletionTokens != nil { - params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens) - } - - if providerOptions.TextVerbosity != nil { - params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity) - } - if providerOptions.Prediction != nil { - // Convert map[string]any to ChatCompletionPredictionContentParam - if content, ok := providerOptions.Prediction["content"]; ok { - if contentStr, ok := content.(string); ok { - params.Prediction = openai.ChatCompletionPredictionContentParam{ - Content: openai.ChatCompletionPredictionContentContentUnionParam{ - OfString: param.NewOpt(contentStr), - }, - } - } - } - } - if providerOptions.Store != nil { - params.Store = param.NewOpt(*providerOptions.Store) - } - if providerOptions.Metadata != nil { - // Convert map[string]any to map[string]string - metadata := make(map[string]string) - for k, v := range providerOptions.Metadata { - if str, ok := v.(string); ok { - metadata[k] = str - } - } - params.Metadata = metadata - } - if providerOptions.PromptCacheKey != nil { - params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey) - } - if providerOptions.SafetyIdentifier != nil { - params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier) - } - if providerOptions.ServiceTier != nil { - params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier) - } - - if providerOptions.ReasoningEffort != nil { - switch *providerOptions.ReasoningEffort { - case ReasoningEffortMinimal: - params.ReasoningEffort = shared.ReasoningEffortMinimal - case ReasoningEffortLow: - params.ReasoningEffort = shared.ReasoningEffortLow - case ReasoningEffortMedium: - params.ReasoningEffort = shared.ReasoningEffortMedium - case ReasoningEffortHigh: - params.ReasoningEffort = shared.ReasoningEffortHigh - default: - return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort) - } - } - - if isReasoningModel(model.Model()) { - if providerOptions.LogitBias != nil { - params.LogitBias = nil - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "LogitBias", - Message: "LogitBias is not supported for reasoning models", - }) - } - if providerOptions.LogProbs != nil { - params.Logprobs = param.Opt[bool]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "Logprobs", - Message: "Logprobs is not supported for reasoning models", - }) - } - if providerOptions.TopLogProbs != nil { - 
params.TopLogprobs = param.Opt[int64]{} - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "TopLogprobs", - Message: "TopLogprobs is not supported for reasoning models", - }) - } - } - - // Handle service tier validation - if providerOptions.ServiceTier != nil { - serviceTier := *providerOptions.ServiceTier - if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) { - params.ServiceTier = "" - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "ServiceTier", - Details: "flex processing is only available for o3, o4-mini, and gpt-5 models", - }) - } else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) { - params.ServiceTier = "" - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedSetting, - Setting: "ServiceTier", - Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported", - }) - } - } - return warnings, nil -} - -func mapOpenAiFinishReason(finishReason string) ai.FinishReason { - switch finishReason { - case "stop": - return ai.FinishReasonStop - case "length": - return ai.FinishReasonLength - case "content_filter": - return ai.FinishReasonContentFilter - case "function_call", "tool_calls": - return ai.FinishReasonToolCalls - default: - return ai.FinishReasonUnknown - } -} - -func isReasoningModel(modelID string) bool { - return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5") || strings.HasPrefix(modelID, "gpt-5-chat") -} - -func isSearchPreviewModel(modelID string) bool { - return strings.Contains(modelID, "search-preview") -} - -func supportsFlexProcessing(modelID string) bool { - return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5") -} - -func supportsPriorityProcessing(modelID string) bool { - return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") || - strings.HasPrefix(modelID, "gpt-5-mini") || strings.HasPrefix(modelID, "o3") || - strings.HasPrefix(modelID, "o4-mini") -} - -func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) { - for _, tool := range tools { - if tool.GetType() == ai.ToolTypeFunction { - ft, ok := tool.(ai.FunctionTool) - if !ok { - continue - } - openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{ - OfFunction: &openai.ChatCompletionFunctionToolParam{ - Function: shared.FunctionDefinitionParam{ - Name: ft.Name, - Description: param.NewOpt(ft.Description), - Parameters: openai.FunctionParameters(ft.InputSchema), - Strict: param.NewOpt(false), - }, - Type: "function", - }, - }) - continue - } - - // TODO: handle provider tool calls - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeUnsupportedTool, - Tool: tool, - Message: "tool is not supported", - }) - } - if toolChoice == nil { - return openAiTools, openAiToolChoice, warnings - } - - switch *toolChoice { - case ai.ToolChoiceAuto: - openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{ - OfAuto: param.NewOpt("auto"), - } - case ai.ToolChoiceNone: - openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{ - OfAuto: param.NewOpt("none"), - } - default: - openAiToolChoice = 
&openai.ChatCompletionToolChoiceOptionUnionParam{ - OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{ - Type: "function", - Function: openai.ChatCompletionNamedToolChoiceFunctionParam{ - Name: string(*toolChoice), - }, - }, - } - } - return openAiTools, openAiToolChoice, warnings -} - -func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) { - var messages []openai.ChatCompletionMessageParamUnion - var warnings []ai.CallWarning - for _, msg := range prompt { - switch msg.Role { - case ai.MessageRoleSystem: - var systemPromptParts []string - for _, c := range msg.Content { - if c.GetType() != ai.ContentTypeText { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, - Message: "system prompt can only have text content", - }) - continue - } - textPart, ok := ai.AsContentType[ai.TextPart](c) - if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, - Message: "system prompt text part does not have the right type", - }) - continue - } - text := textPart.Text - if strings.TrimSpace(text) != "" { - systemPromptParts = append(systemPromptParts, textPart.Text) - } - } - if len(systemPromptParts) == 0 { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, - Message: "system prompt has no text parts", - }) - continue - } - messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n"))) - case ai.MessageRoleUser: - // simple user message just text content - if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText { - textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0]) - if !ok { - warnings = append(warnings, ai.CallWarning{ - Type: ai.CallWarningTypeOther, - Message: "user message text part does not have the right type", - }) - continue - } - messages = append(messages, openai.UserMessage(textPart.Text)) - continue - } - // text content and attachments - // for now we only support image content later we need to check - // TODO: add the supported media types to the language model so we - // can use that to validate the data here. 
-
-func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
-	var messages []openai.ChatCompletionMessageParamUnion
-	var warnings []ai.CallWarning
-	for _, msg := range prompt {
-		switch msg.Role {
-		case ai.MessageRoleSystem:
-			var systemPromptParts []string
-			for _, c := range msg.Content {
-				if c.GetType() != ai.ContentTypeText {
-					warnings = append(warnings, ai.CallWarning{
-						Type:    ai.CallWarningTypeOther,
-						Message: "system prompt can only have text content",
-					})
-					continue
-				}
-				textPart, ok := ai.AsContentType[ai.TextPart](c)
-				if !ok {
-					warnings = append(warnings, ai.CallWarning{
-						Type:    ai.CallWarningTypeOther,
-						Message: "system prompt text part does not have the right type",
-					})
-					continue
-				}
-				text := textPart.Text
-				if strings.TrimSpace(text) != "" {
-					systemPromptParts = append(systemPromptParts, textPart.Text)
-				}
-			}
-			if len(systemPromptParts) == 0 {
-				warnings = append(warnings, ai.CallWarning{
-					Type:    ai.CallWarningTypeOther,
-					Message: "system prompt has no text parts",
-				})
-				continue
-			}
-			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
-		case ai.MessageRoleUser:
-			// simple user message just text content
-			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
-				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
-				if !ok {
-					warnings = append(warnings, ai.CallWarning{
-						Type:    ai.CallWarningTypeOther,
-						Message: "user message text part does not have the right type",
-					})
-					continue
-				}
-				messages = append(messages, openai.UserMessage(textPart.Text))
-				continue
-			}
-			// text content and attachments
-			// for now we only support image content later we need to check
-			// TODO: add the supported media types to the language model so we
-			// can use that to validate the data here.
-			var content []openai.ChatCompletionContentPartUnionParam
-			for _, c := range msg.Content {
-				switch c.GetType() {
-				case ai.ContentTypeText:
-					textPart, ok := ai.AsContentType[ai.TextPart](c)
-					if !ok {
-						warnings = append(warnings, ai.CallWarning{
-							Type:    ai.CallWarningTypeOther,
-							Message: "user message text part does not have the right type",
-						})
-						continue
-					}
-					content = append(content, openai.ChatCompletionContentPartUnionParam{
-						OfText: &openai.ChatCompletionContentPartTextParam{
-							Text: textPart.Text,
-						},
-					})
-				case ai.ContentTypeFile:
-					filePart, ok := ai.AsContentType[ai.FilePart](c)
-					if !ok {
-						warnings = append(warnings, ai.CallWarning{
-							Type:    ai.CallWarningTypeOther,
-							Message: "user message file part does not have the right type",
-						})
-						continue
-					}
-
-					switch {
-					case strings.HasPrefix(filePart.MediaType, "image/"):
-						// Handle image files
-						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
-						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
-						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}
-
-						// Check for provider-specific options like image detail
-						if providerOptions, ok := filePart.ProviderOptions[Name]; ok {
-							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
-								imageURL.Detail = detail.ImageDetail
-							}
-						}
-
-						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
-						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
-
-					case filePart.MediaType == "audio/wav":
-						// Handle WAV audio files
-						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
-						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
-							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
-								Data:   base64Encoded,
-								Format: "wav",
-							},
-						}
-						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
-
-					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
-						// Handle MP3 audio files
-						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
-						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
-							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
-								Data:   base64Encoded,
-								Format: "mp3",
-							},
-						}
-						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
-
-					case filePart.MediaType == "application/pdf":
-						// Handle PDF files
-						dataStr := string(filePart.Data)
-
-						// Check if data looks like a file ID (starts with "file-")
-						if strings.HasPrefix(dataStr, "file-") {
-							fileBlock := openai.ChatCompletionContentPartFileParam{
-								File: openai.ChatCompletionContentPartFileFileParam{
-									FileID: param.NewOpt(dataStr),
-								},
-							}
-							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
-						} else {
-							// Handle as base64 data
-							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
-							data := "data:application/pdf;base64," + base64Encoded
-
-							filename := filePart.Filename
-							if filename == "" {
-								// Generate default filename based on content index
-								filename = fmt.Sprintf("part-%d.pdf", len(content))
-							}
-
-							fileBlock := openai.ChatCompletionContentPartFileParam{
-								File: openai.ChatCompletionContentPartFileFileParam{
-									Filename: param.NewOpt(filename),
-									FileData: param.NewOpt(data),
-								},
-							}
-							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
-						}
-
-					default:
-						warnings = append(warnings, ai.CallWarning{
-							Type:    ai.CallWarningTypeOther,
-							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
-						})
-					}
-				}
-			}
-			messages = append(messages, openai.UserMessage(content))
-		case ai.MessageRoleAssistant:
-			// simple assistant message just text content
-			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
-				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
-				if !ok {
-					warnings = append(warnings, ai.CallWarning{
-						Type:    ai.CallWarningTypeOther,
-						Message: "assistant message text part does not have the right type",
-					})
-					continue
-				}
-				messages = append(messages, openai.AssistantMessage(textPart.Text))
-				continue
-			}
-			assistantMsg := openai.ChatCompletionAssistantMessageParam{
-				Role: "assistant",
-			}
-			for _, c := range msg.Content {
-				switch c.GetType() {
-				case ai.ContentTypeText:
-					textPart, ok := ai.AsContentType[ai.TextPart](c)
-					if !ok {
-						warnings = append(warnings, ai.CallWarning{
-							Type:    ai.CallWarningTypeOther,
-							Message: "assistant message text part does not have the right type",
-						})
-						continue
-					}
-					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
-						OfString: param.NewOpt(textPart.Text),
-					}
-				case ai.ContentTypeToolCall:
-					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
-					if !ok {
-						warnings = append(warnings, ai.CallWarning{
-							Type:    ai.CallWarningTypeOther,
-							Message: "assistant message tool part does not have the right type",
-						})
-						continue
-					}
-					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
-						openai.ChatCompletionMessageToolCallUnionParam{
-							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
-								ID:   toolCallPart.ToolCallID,
-								Type: "function",
-								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
-									Name:      toolCallPart.ToolName,
-									Arguments: toolCallPart.Input,
-								},
-							},
-						})
-				}
-			}
-			messages = append(messages, openai.ChatCompletionMessageParamUnion{
-				OfAssistant: &assistantMsg,
-			})
-		case ai.MessageRoleTool:
-			for _, c := range msg.Content {
-				if c.GetType() != ai.ContentTypeToolResult {
-					warnings = append(warnings, ai.CallWarning{
-						Type:    ai.CallWarningTypeOther,
-						Message: "tool message can only have tool result content",
-					})
-					continue
-				}
-
-				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
-				if !ok {
-					warnings = append(warnings, ai.CallWarning{
-						Type:    ai.CallWarningTypeOther,
-						Message: "tool message result part does not have the right type",
-					})
-					continue
-				}
-
-				switch toolResultPart.Output.GetType() {
-				case ai.ToolResultContentTypeText:
-					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
-					if !ok {
-						warnings = append(warnings, ai.CallWarning{
-							Type:    ai.CallWarningTypeOther,
-							Message: "tool result output does not have the right type",
-						})
-						continue
-					}
-					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
-				case ai.ToolResultContentTypeError:
-					// TODO: check if better handling is needed
-					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
-					if !ok {
-						warnings = append(warnings, ai.CallWarning{
-							Type:    ai.CallWarningTypeOther,
-							Message: "tool result output does not have the right type",
-						})
-						continue
-					}
-					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
-				}
-			}
-		}
-	}
-	return messages, warnings
-}
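
[ Reviewer note: illustration only, not part of the patch. A sketch of toPrompt in use, assuming the same package; the Prompt, TextPart, and FilePart literals below infer field names from their use in the function body, so treat them as hypothetical. ]

func examplePromptConversion(pngBytes []byte) {
	prompt := ai.Prompt{
		{Role: ai.MessageRoleSystem, Content: []ai.Content{ai.TextPart{Text: "Be terse."}}},
		{Role: ai.MessageRoleUser, Content: []ai.Content{
			ai.TextPart{Text: "Describe this image."},
			ai.FilePart{MediaType: "image/png", Data: pngBytes},
		}},
	}
	messages, warnings := toPrompt(prompt)
	// messages[0] is a system message; messages[1] is a user message with a
	// text part plus an image part carrying a data: URL. Unsupported media
	// types are not dropped silently: they surface in warnings instead.
	_, _ = messages, warnings
}
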
-
-// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
-func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
-	var annotations []openai.ChatCompletionMessageAnnotation
-
-	// Parse the raw JSON to extract annotations
-	var deltaData map[string]any
-	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
-		return annotations
-	}
-
-	// Check if annotations exist in the delta
-	if annotationsData, ok := deltaData["annotations"].([]any); ok {
-		for _, annotationData := range annotationsData {
-			if annotationMap, ok := annotationData.(map[string]any); ok {
-				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
-					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
-						annotation := openai.ChatCompletionMessageAnnotation{
-							Type: "url_citation",
-							URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
-								URL:   urlCitationData["url"].(string),
-								Title: urlCitationData["title"].(string),
-							},
-						}
-						annotations = append(annotations, annotation)
-					}
-				}
-			}
-		}
-	}
-
-	return annotations
-}
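
[ Reviewer note: illustration only, not part of the patch. The nested lookups above expect a delta payload shaped roughly like the sketch below. Note the unchecked .(string) assertions on "url" and "title": a payload where either is missing or non-string would panic, so a defensive version should use the two-value form of the assertion. ]

// {"annotations": [
//   {"type": "url_citation",
//    "url_citation": {"url": "https://example.com", "title": "Example"}}
// ]}
// Given that delta, parseAnnotationsFromDelta returns one annotation
// carrying the URL and title; any other annotation type is skipped.
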
diff --git a/openrouter/openrouter.go b/openrouter/openrouter.go
index bf3f32113253717950116c990fc6431217893dce..3c647789d074bdda8866f460aead8b1176706b61 100644
--- a/openrouter/openrouter.go
+++ b/openrouter/openrouter.go
@@ -36,9 +36,9 @@ func New(opts ...Option) ai.Provider {
 	providerOptions := options{
 		openaiOptions: []openai.Option{
 			openai.WithBaseURL(DefaultURL),
-			openai.WithHooks(openai.Hooks{
-				PrepareCallWithOptions: prepareCallWithOptions,
-			}),
+			openai.WithLanguageModelOptions(
+				openai.WithPrepareLanguageModelCall(prepareCallWithOptions),
+			),
 		},
 	}
 	for _, o := range opts {
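
[ Reviewer note: illustration only, not part of the patch. The hunk above swaps the removed hook struct for the functional options introduced in openai/language_model.go. A sketch of the same wiring for any other OpenAI-compatible provider; the base URL and myPrepareCall are hypothetical, while the option constructors come from this patch. ]

opts := []openai.Option{
	openai.WithBaseURL("https://example.invalid/v1"),
	openai.WithLanguageModelOptions(
		openai.WithPrepareLanguageModelCall(myPrepareCall), // a PrepareLanguageModelCallFunc
	),
}
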