1package openai
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"io"
  8	"strings"
  9
 10	"charm.land/fantasy"
 11	xjson "github.com/charmbracelet/x/json"
 12	"github.com/google/uuid"
 13	"github.com/openai/openai-go/v2"
 14	"github.com/openai/openai-go/v2/packages/param"
 15	"github.com/openai/openai-go/v2/shared"
 16)
 17
// languageModel implements fantasy.LanguageModel on top of the OpenAI chat
// completions API. Request preparation and response mapping are delegated to
// the hook functions below; newLanguageModel fills them with package defaults
// and LanguageModelOption values can override them.
type languageModel struct {
	// provider is the name reported by Provider() and passed to toPromptFunc.
	provider                   string
	// modelID is sent as the request model and drives capability checks
	// (reasoning / search-preview handling in prepareParams).
	modelID                    string
	client                     openai.Client
	// prepareCallFunc adjusts the request params after the generic settings
	// are applied; it may return extra warnings or fail the call.
	prepareCallFunc            LanguageModelPrepareCallFunc
	// mapFinishReasonFunc maps OpenAI finish-reason strings to fantasy ones.
	mapFinishReasonFunc        LanguageModelMapFinishReasonFunc
	// extraContentFunc is optional (nil disables it); Generate appends its
	// result for the first choice.
	extraContentFunc           LanguageModelExtraContentFunc
	// usageFunc extracts usage and provider metadata from a non-streaming response.
	usageFunc                  LanguageModelUsageFunc
	// streamUsageFunc updates usage and provider metadata on every streamed chunk.
	streamUsageFunc            LanguageModelStreamUsageFunc
	// streamExtraFunc is optional (nil disables it); it is called per streamed
	// choice, may yield extra parts, and stops the stream by returning false.
	streamExtraFunc            LanguageModelStreamExtraFunc
	// streamProviderMetadataFunc updates provider metadata from the
	// accumulated first choice once the stream has finished.
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	// toPromptFunc converts the fantasy prompt into OpenAI messages plus warnings.
	toPromptFunc               LanguageModelToPromptFunc
}
 31
 32// LanguageModelOption is a function that configures a languageModel.
 33type LanguageModelOption = func(*languageModel)
 34
 35// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
 36func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
 37	return func(l *languageModel) {
 38		l.prepareCallFunc = fn
 39	}
 40}
 41
 42// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
 43func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
 44	return func(l *languageModel) {
 45		l.mapFinishReasonFunc = fn
 46	}
 47}
 48
 49// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
 50func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
 51	return func(l *languageModel) {
 52		l.extraContentFunc = fn
 53	}
 54}
 55
 56// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
 57func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
 58	return func(l *languageModel) {
 59		l.streamExtraFunc = fn
 60	}
 61}
 62
 63// WithLanguageModelUsageFunc sets the usage function for the language model.
 64func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
 65	return func(l *languageModel) {
 66		l.usageFunc = fn
 67	}
 68}
 69
 70// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
 71func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
 72	return func(l *languageModel) {
 73		l.streamUsageFunc = fn
 74	}
 75}
 76
 77// WithLanguageModelToPromptFunc sets the to prompt function for the language model.
 78func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
 79	return func(l *languageModel) {
 80		l.toPromptFunc = fn
 81	}
 82}
 83
 84func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
 85	model := languageModel{
 86		modelID:                    modelID,
 87		provider:                   provider,
 88		client:                     client,
 89		prepareCallFunc:            DefaultPrepareCallFunc,
 90		mapFinishReasonFunc:        DefaultMapFinishReasonFunc,
 91		usageFunc:                  DefaultUsageFunc,
 92		streamUsageFunc:            DefaultStreamUsageFunc,
 93		streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
 94		toPromptFunc:               DefaultToPrompt,
 95	}
 96
 97	for _, o := range opts {
 98		o(&model)
 99	}
100	return model
101}
102
// streamToolCall accumulates one tool call while its arguments arrive in
// fragments across streamed chunks. Stream keys these by the delta's index.
type streamToolCall struct {
	// id and name come from the first delta seen for the index.
	id          string
	name        string
	// arguments is the concatenation of the argument fragments received so far.
	arguments   string
	// hasFinished is set once arguments parsed as valid JSON and the tool
	// call part was emitted; later deltas for the index are ignored.
	hasFinished bool
}
109
110// Model implements fantasy.LanguageModel.
111func (o languageModel) Model() string {
112	return o.modelID
113}
114
115// Provider implements fantasy.LanguageModel.
116func (o languageModel) Provider() string {
117	return o.provider
118}
119
120func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
121	params := &openai.ChatCompletionNewParams{}
122	messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
123	if call.TopK != nil {
124		warnings = append(warnings, fantasy.CallWarning{
125			Type:    fantasy.CallWarningTypeUnsupportedSetting,
126			Setting: "top_k",
127		})
128	}
129
130	if call.MaxOutputTokens != nil {
131		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
132	}
133	if call.Temperature != nil {
134		params.Temperature = param.NewOpt(*call.Temperature)
135	}
136	if call.TopP != nil {
137		params.TopP = param.NewOpt(*call.TopP)
138	}
139	if call.FrequencyPenalty != nil {
140		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
141	}
142	if call.PresencePenalty != nil {
143		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
144	}
145
146	if isReasoningModel(o.modelID) {
147		// remove unsupported settings for reasoning models
148		// see https://platform.openai.com/docs/guides/reasoning#limitations
149		if call.Temperature != nil {
150			params.Temperature = param.Opt[float64]{}
151			warnings = append(warnings, fantasy.CallWarning{
152				Type:    fantasy.CallWarningTypeUnsupportedSetting,
153				Setting: "temperature",
154				Details: "temperature is not supported for reasoning models",
155			})
156		}
157		if call.TopP != nil {
158			params.TopP = param.Opt[float64]{}
159			warnings = append(warnings, fantasy.CallWarning{
160				Type:    fantasy.CallWarningTypeUnsupportedSetting,
161				Setting: "TopP",
162				Details: "TopP is not supported for reasoning models",
163			})
164		}
165		if call.FrequencyPenalty != nil {
166			params.FrequencyPenalty = param.Opt[float64]{}
167			warnings = append(warnings, fantasy.CallWarning{
168				Type:    fantasy.CallWarningTypeUnsupportedSetting,
169				Setting: "FrequencyPenalty",
170				Details: "FrequencyPenalty is not supported for reasoning models",
171			})
172		}
173		if call.PresencePenalty != nil {
174			params.PresencePenalty = param.Opt[float64]{}
175			warnings = append(warnings, fantasy.CallWarning{
176				Type:    fantasy.CallWarningTypeUnsupportedSetting,
177				Setting: "PresencePenalty",
178				Details: "PresencePenalty is not supported for reasoning models",
179			})
180		}
181
182		// reasoning models use max_completion_tokens instead of max_tokens
183		if call.MaxOutputTokens != nil {
184			if !params.MaxCompletionTokens.Valid() {
185				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
186			}
187			params.MaxTokens = param.Opt[int64]{}
188		}
189	}
190
191	// Handle search preview models
192	if isSearchPreviewModel(o.modelID) {
193		if call.Temperature != nil {
194			params.Temperature = param.Opt[float64]{}
195			warnings = append(warnings, fantasy.CallWarning{
196				Type:    fantasy.CallWarningTypeUnsupportedSetting,
197				Setting: "temperature",
198				Details: "temperature is not supported for the search preview models and has been removed.",
199			})
200		}
201	}
202
203	optionsWarnings, err := o.prepareCallFunc(o, params, call)
204	if err != nil {
205		return nil, nil, err
206	}
207
208	if len(optionsWarnings) > 0 {
209		warnings = append(warnings, optionsWarnings...)
210	}
211
212	params.Messages = messages
213	params.Model = o.modelID
214
215	if len(call.Tools) > 0 {
216		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
217		params.Tools = tools
218		if toolChoice != nil {
219			params.ToolChoice = *toolChoice
220		}
221		warnings = append(warnings, toolWarnings...)
222	}
223	return params, warnings, nil
224}
225
226func (o languageModel) handleError(err error) error {
227	var apiErr *openai.Error
228	if errors.As(err, &apiErr) {
229		requestDump := apiErr.DumpRequest(true)
230		responseDump := apiErr.DumpResponse(true)
231		headers := map[string]string{}
232		for k, h := range apiErr.Response.Header {
233			v := h[len(h)-1]
234			headers[strings.ToLower(k)] = v
235		}
236		return fantasy.NewAPICallError(
237			apiErr.Message,
238			apiErr.Request.URL.String(),
239			string(requestDump),
240			apiErr.StatusCode,
241			headers,
242			string(responseDump),
243			apiErr,
244			false,
245		)
246	}
247	return err
248}
249
250// Generate implements fantasy.LanguageModel.
251func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
252	params, warnings, err := o.prepareParams(call)
253	if err != nil {
254		return nil, err
255	}
256	response, err := o.client.Chat.Completions.New(ctx, *params)
257	if err != nil {
258		return nil, o.handleError(err)
259	}
260
261	if len(response.Choices) == 0 {
262		return nil, errors.New("no response generated")
263	}
264	choice := response.Choices[0]
265	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
266	text := choice.Message.Content
267	if text != "" {
268		content = append(content, fantasy.TextContent{
269			Text: text,
270		})
271	}
272	if o.extraContentFunc != nil {
273		extraContent := o.extraContentFunc(choice)
274		content = append(content, extraContent...)
275	}
276	for _, tc := range choice.Message.ToolCalls {
277		toolCallID := tc.ID
278		content = append(content, fantasy.ToolCallContent{
279			ProviderExecuted: false, // TODO: update when handling other tools
280			ToolCallID:       toolCallID,
281			ToolName:         tc.Function.Name,
282			Input:            tc.Function.Arguments,
283		})
284	}
285	// Handle annotations/citations
286	for _, annotation := range choice.Message.Annotations {
287		if annotation.Type == "url_citation" {
288			content = append(content, fantasy.SourceContent{
289				SourceType: fantasy.SourceTypeURL,
290				ID:         uuid.NewString(),
291				URL:        annotation.URLCitation.URL,
292				Title:      annotation.URLCitation.Title,
293			})
294		}
295	}
296
297	usage, providerMetadata := o.usageFunc(*response)
298
299	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
300	if len(choice.Message.ToolCalls) > 0 {
301		mappedFinishReason = fantasy.FinishReasonToolCalls
302	}
303	return &fantasy.Response{
304		Content:      content,
305		Usage:        usage,
306		FinishReason: mappedFinishReason,
307		ProviderMetadata: fantasy.ProviderMetadata{
308			Name: providerMetadata,
309		},
310		Warnings: warnings,
311	}, nil
312}
313
314// Stream implements fantasy.LanguageModel.
315func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
316	params, warnings, err := o.prepareParams(call)
317	if err != nil {
318		return nil, err
319	}
320
321	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
322		IncludeUsage: openai.Bool(true),
323	}
324
325	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
326	isActiveText := false
327	toolCalls := make(map[int64]streamToolCall)
328
329	// Build provider metadata for streaming
330	providerMetadata := fantasy.ProviderMetadata{
331		Name: &ProviderMetadata{},
332	}
333	acc := openai.ChatCompletionAccumulator{}
334	extraContext := make(map[string]any)
335	var usage fantasy.Usage
336	var finishReason string
337	return func(yield func(fantasy.StreamPart) bool) {
338		if len(warnings) > 0 {
339			if !yield(fantasy.StreamPart{
340				Type:     fantasy.StreamPartTypeWarnings,
341				Warnings: warnings,
342			}) {
343				return
344			}
345		}
346		for stream.Next() {
347			chunk := stream.Current()
348			acc.AddChunk(chunk)
349			usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
350			if len(chunk.Choices) == 0 {
351				continue
352			}
353			for _, choice := range chunk.Choices {
354				if choice.FinishReason != "" {
355					finishReason = choice.FinishReason
356				}
357				switch {
358				case choice.Delta.Content != "":
359					if !isActiveText {
360						isActiveText = true
361						if !yield(fantasy.StreamPart{
362							Type: fantasy.StreamPartTypeTextStart,
363							ID:   "0",
364						}) {
365							return
366						}
367					}
368					if !yield(fantasy.StreamPart{
369						Type:  fantasy.StreamPartTypeTextDelta,
370						ID:    "0",
371						Delta: choice.Delta.Content,
372					}) {
373						return
374					}
375				case len(choice.Delta.ToolCalls) > 0:
376					if isActiveText {
377						isActiveText = false
378						if !yield(fantasy.StreamPart{
379							Type: fantasy.StreamPartTypeTextEnd,
380							ID:   "0",
381						}) {
382							return
383						}
384					}
385
386					for _, toolCallDelta := range choice.Delta.ToolCalls {
387						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
388							if existingToolCall.hasFinished {
389								continue
390							}
391							if toolCallDelta.Function.Arguments != "" {
392								existingToolCall.arguments += toolCallDelta.Function.Arguments
393							}
394							if !yield(fantasy.StreamPart{
395								Type:  fantasy.StreamPartTypeToolInputDelta,
396								ID:    existingToolCall.id,
397								Delta: toolCallDelta.Function.Arguments,
398							}) {
399								return
400							}
401							toolCalls[toolCallDelta.Index] = existingToolCall
402							if xjson.IsValid(existingToolCall.arguments) {
403								if !yield(fantasy.StreamPart{
404									Type: fantasy.StreamPartTypeToolInputEnd,
405									ID:   existingToolCall.id,
406								}) {
407									return
408								}
409
410								if !yield(fantasy.StreamPart{
411									Type:          fantasy.StreamPartTypeToolCall,
412									ID:            existingToolCall.id,
413									ToolCallName:  existingToolCall.name,
414									ToolCallInput: existingToolCall.arguments,
415								}) {
416									return
417								}
418								existingToolCall.hasFinished = true
419								toolCalls[toolCallDelta.Index] = existingToolCall
420							}
421						} else {
422							// Does not exist
423							var err error
424							if toolCallDelta.Type != "function" {
425								err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
426							}
427							if toolCallDelta.ID == "" {
428								err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
429							}
430							if toolCallDelta.Function.Name == "" {
431								err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
432							}
433							if err != nil {
434								yield(fantasy.StreamPart{
435									Type:  fantasy.StreamPartTypeError,
436									Error: o.handleError(stream.Err()),
437								})
438								return
439							}
440
441							if !yield(fantasy.StreamPart{
442								Type:         fantasy.StreamPartTypeToolInputStart,
443								ID:           toolCallDelta.ID,
444								ToolCallName: toolCallDelta.Function.Name,
445							}) {
446								return
447							}
448							toolCalls[toolCallDelta.Index] = streamToolCall{
449								id:        toolCallDelta.ID,
450								name:      toolCallDelta.Function.Name,
451								arguments: toolCallDelta.Function.Arguments,
452							}
453
454							exTc := toolCalls[toolCallDelta.Index]
455							if exTc.arguments != "" {
456								if !yield(fantasy.StreamPart{
457									Type:  fantasy.StreamPartTypeToolInputDelta,
458									ID:    exTc.id,
459									Delta: exTc.arguments,
460								}) {
461									return
462								}
463								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
464									if !yield(fantasy.StreamPart{
465										Type: fantasy.StreamPartTypeToolInputEnd,
466										ID:   toolCallDelta.ID,
467									}) {
468										return
469									}
470
471									if !yield(fantasy.StreamPart{
472										Type:          fantasy.StreamPartTypeToolCall,
473										ID:            exTc.id,
474										ToolCallName:  exTc.name,
475										ToolCallInput: exTc.arguments,
476									}) {
477										return
478									}
479									exTc.hasFinished = true
480									toolCalls[toolCallDelta.Index] = exTc
481								}
482							}
483							continue
484						}
485					}
486				}
487
488				if o.streamExtraFunc != nil {
489					updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
490					if !shouldContinue {
491						return
492					}
493					extraContext = updatedContext
494				}
495			}
496
497			// Check for annotations in the delta's raw JSON
498			for _, choice := range chunk.Choices {
499				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
500					for _, annotation := range annotations {
501						if annotation.Type == "url_citation" {
502							if !yield(fantasy.StreamPart{
503								Type:       fantasy.StreamPartTypeSource,
504								ID:         uuid.NewString(),
505								SourceType: fantasy.SourceTypeURL,
506								URL:        annotation.URLCitation.URL,
507								Title:      annotation.URLCitation.Title,
508							}) {
509								return
510							}
511						}
512					}
513				}
514			}
515		}
516		err := stream.Err()
517		if err == nil || errors.Is(err, io.EOF) {
518			// finished
519			if isActiveText {
520				isActiveText = false
521				if !yield(fantasy.StreamPart{
522					Type: fantasy.StreamPartTypeTextEnd,
523					ID:   "0",
524				}) {
525					return
526				}
527			}
528
529			if len(acc.Choices) > 0 {
530				choice := acc.Choices[0]
531				// Add logprobs if available
532				providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)
533
534				// Handle annotations/citations from accumulated response
535				for _, annotation := range choice.Message.Annotations {
536					if annotation.Type == "url_citation" {
537						if !yield(fantasy.StreamPart{
538							Type:       fantasy.StreamPartTypeSource,
539							ID:         acc.ID,
540							SourceType: fantasy.SourceTypeURL,
541							URL:        annotation.URLCitation.URL,
542							Title:      annotation.URLCitation.Title,
543						}) {
544							return
545						}
546					}
547				}
548			}
549			mappedFinishReason := o.mapFinishReasonFunc(finishReason)
550			if len(acc.Choices) > 0 {
551				choice := acc.Choices[0]
552				if len(choice.Message.ToolCalls) > 0 {
553					mappedFinishReason = fantasy.FinishReasonToolCalls
554				}
555			}
556			yield(fantasy.StreamPart{
557				Type:             fantasy.StreamPartTypeFinish,
558				Usage:            usage,
559				FinishReason:     mappedFinishReason,
560				ProviderMetadata: providerMetadata,
561			})
562			return
563		} else { //nolint: revive
564			yield(fantasy.StreamPart{
565				Type:  fantasy.StreamPartTypeError,
566				Error: o.handleError(err),
567			})
568			return
569		}
570	}, nil
571}
572
// isReasoningModel reports whether modelID is treated as an OpenAI reasoning
// model, meaning an ID starting with "o" (the o-series) or with "gpt-5".
// prepareParams strips sampling settings such as temperature for these
// models and sends max_completion_tokens instead of max_tokens.
//
// gpt-5-chat models are excluded: despite the gpt-5 prefix they are
// non-reasoning chat models.
// NOTE(review): the bare "o" prefix also matches unrelated IDs that happen
// to start with "o".
func isReasoningModel(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-chat") {
		return false
	}
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}
576
// isSearchPreviewModel reports whether modelID names a search preview model,
// meaning its ID contains "search-preview" anywhere. prepareParams drops
// temperature for these models.
func isSearchPreviewModel(modelID string) bool {
	const marker = "search-preview"
	_, _, found := strings.Cut(modelID, marker)
	return found
}
580
// supportsFlexProcessing reports whether modelID accepts the "flex" service
// tier. This covers o3, o4-mini and the gpt-5 family, except the gpt-5-chat
// models, which are non-reasoning chat snapshots (consistent with
// isReasoningModel).
func supportsFlexProcessing(modelID string) bool {
	switch {
	case strings.HasPrefix(modelID, "o3"), strings.HasPrefix(modelID, "o4-mini"):
		return true
	case strings.HasPrefix(modelID, "gpt-5"):
		return !strings.HasPrefix(modelID, "gpt-5-chat")
	default:
		return false
	}
}
584
// supportsPriorityProcessing reports whether modelID accepts the "priority"
// service tier. This covers the gpt-4 family, o3, o4-mini, and gpt-5
// including gpt-5-mini.
//
// gpt-5-nano and gpt-5-chat are excluded. The previous explicit gpt-5-mini
// check was already implied by the gpt-5 prefix, which indicates the
// exclusions were the intent.
func supportsPriorityProcessing(modelID string) bool {
	switch {
	case strings.HasPrefix(modelID, "gpt-4"),
		strings.HasPrefix(modelID, "o3"),
		strings.HasPrefix(modelID, "o4-mini"):
		return true
	case strings.HasPrefix(modelID, "gpt-5"):
		return !strings.HasPrefix(modelID, "gpt-5-nano") &&
			!strings.HasPrefix(modelID, "gpt-5-chat")
	default:
		return false
	}
}
590
591func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
592	for _, tool := range tools {
593		if tool.GetType() == fantasy.ToolTypeFunction {
594			ft, ok := tool.(fantasy.FunctionTool)
595			if !ok {
596				continue
597			}
598			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
599				OfFunction: &openai.ChatCompletionFunctionToolParam{
600					Function: shared.FunctionDefinitionParam{
601						Name:        ft.Name,
602						Description: param.NewOpt(ft.Description),
603						Parameters:  openai.FunctionParameters(ft.InputSchema),
604						Strict:      param.NewOpt(false),
605					},
606					Type: "function",
607				},
608			})
609			continue
610		}
611
612		// TODO: handle provider tool calls
613		warnings = append(warnings, fantasy.CallWarning{
614			Type:    fantasy.CallWarningTypeUnsupportedTool,
615			Tool:    tool,
616			Message: "tool is not supported",
617		})
618	}
619	if toolChoice == nil {
620		return openAiTools, openAiToolChoice, warnings
621	}
622
623	switch *toolChoice {
624	case fantasy.ToolChoiceAuto:
625		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
626			OfAuto: param.NewOpt("auto"),
627		}
628	case fantasy.ToolChoiceNone:
629		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
630			OfAuto: param.NewOpt("none"),
631		}
632	default:
633		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
634			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
635				Type: "function",
636				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
637					Name: string(*toolChoice),
638				},
639			},
640		}
641	}
642	return openAiTools, openAiToolChoice, warnings
643}
644
645// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
646func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
647	var annotations []openai.ChatCompletionMessageAnnotation
648
649	// Parse the raw JSON to extract annotations
650	var deltaData map[string]any
651	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
652		return annotations
653	}
654
655	// Check if annotations exist in the delta
656	if annotationsData, ok := deltaData["annotations"].([]any); ok {
657		for _, annotationData := range annotationsData {
658			if annotationMap, ok := annotationData.(map[string]any); ok {
659				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
660					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
661						annotation := openai.ChatCompletionMessageAnnotation{
662							Type: "url_citation",
663							URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
664								URL:   urlCitationData["url"].(string),
665								Title: urlCitationData["title"].(string),
666							},
667						}
668						annotations = append(annotations, annotation)
669					}
670				}
671			}
672		}
673	}
674
675	return annotations
676}