language_model.go

package openai

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"strings"

	"charm.land/fantasy"
	xjson "github.com/charmbracelet/x/json"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

// languageModel implements fantasy.LanguageModel on top of the OpenAI Chat
// Completions API. Provider-specific behavior is injected via the hook
// functions below; the optional hooks (extraContentFunc, streamExtraFunc) may
// be nil, while the rest default to the Default* implementations in this
// package.
type languageModel struct {
	provider                   string
	modelID                    string
	client                     openai.Client
	prepareCallFunc            LanguageModelPrepareCallFunc
	mapFinishReasonFunc        LanguageModelMapFinishReasonFunc
	extraContentFunc           LanguageModelExtraContentFunc
	usageFunc                  LanguageModelUsageFunc
	streamUsageFunc            LanguageModelStreamUsageFunc
	streamExtraFunc            LanguageModelStreamExtraFunc
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	toPromptFunc               LanguageModelToPromptFunc
}

// LanguageModelOption is a function that configures a languageModel.
type LanguageModelOption = func(*languageModel)

// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.prepareCallFunc = fn
	}
}

// WithLanguageModelMapFinishReasonFunc sets the finish reason mapping function for the language model.
func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.mapFinishReasonFunc = fn
	}
}

// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.extraContentFunc = fn
	}
}

// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.streamExtraFunc = fn
	}
}

// WithLanguageModelUsageFunc sets the usage function for the language model.
func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.usageFunc = fn
	}
}

// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.streamUsageFunc = fn
	}
}

// WithLanguageModelToPromptFunc sets the prompt conversion function for the language model.
func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.toPromptFunc = fn
	}
}

func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
	model := languageModel{
		modelID:                    modelID,
		provider:                   provider,
		client:                     client,
		prepareCallFunc:            DefaultPrepareCallFunc,
		mapFinishReasonFunc:        DefaultMapFinishReasonFunc,
		usageFunc:                  DefaultUsageFunc,
		streamUsageFunc:            DefaultStreamUsageFunc,
		streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
		toPromptFunc:               DefaultToPrompt,
	}

	for _, o := range opts {
		o(&model)
	}
	return model
}
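// A minimal construction sketch (illustrative only; assumes a configured
// openai.Client from openai-go):
//
//	client := openai.NewClient()
//	model := newLanguageModel("gpt-4o", "openai", client,
//		WithLanguageModelMapFinishReasonFunc(DefaultMapFinishReasonFunc),
//	)
//	_ = model.Model() // "gpt-4o"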

// streamToolCall accumulates a tool call across streaming deltas until its
// JSON arguments form a complete document.
type streamToolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

// Model implements fantasy.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements fantasy.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}

func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
	if call.TopK != nil {
		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if isReasoningModel(o.modelID) {
		// Remove unsupported settings for reasoning models.
		// See https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "top_p",
				Details: "top_p is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "frequency_penalty",
				Details: "frequency_penalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "presence_penalty",
				Details: "presence_penalty is not supported for reasoning models",
			})
		}

		// Reasoning models use max_completion_tokens instead of max_tokens.
		if call.MaxOutputTokens != nil {
			if !params.MaxCompletionTokens.Valid() {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models.
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed",
			})
		}
	}

	optionsWarnings, err := o.prepareCallFunc(o, params, call)
	if err != nil {
		return nil, nil, err
	}

	if len(optionsWarnings) > 0 {
		warnings = append(warnings, optionsWarnings...)
	}

	params.Messages = messages
	params.Model = o.modelID

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}
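// A minimal sketch of a custom prepare-call hook (illustrative; the hook
// signature is assumed from the o.prepareCallFunc(o, params, call) call site
// above, and the flex service-tier constant is assumed from openai-go):
//
//	WithLanguageModelPrepareCallFunc(func(m languageModel, params *openai.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) {
//		if supportsFlexProcessing(m.modelID) {
//			params.ServiceTier = openai.ChatCompletionNewParamsServiceTierFlex
//		}
//		return nil, nil
//	})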

// Generate implements fantasy.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, toProviderErr(err)
	}

	if len(response.Choices) == 0 {
		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
	}
	choice := response.Choices[0]
	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, fantasy.TextContent{
			Text: text,
		})
	}
	if o.extraContentFunc != nil {
		extraContent := o.extraContentFunc(choice)
		content = append(content, extraContent...)
	}
	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		content = append(content, fantasy.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, fantasy.SourceContent{
				SourceType: fantasy.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	usage, providerMetadata := o.usageFunc(*response)

	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
	if len(choice.Message.ToolCalls) > 0 {
		mappedFinishReason = fantasy.FinishReasonToolCalls
	}
	return &fantasy.Response{
		Content:      content,
		Usage:        usage,
		FinishReason: mappedFinishReason,
		ProviderMetadata: fantasy.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}
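// A minimal call sketch (illustrative; only fantasy.Call fields referenced in
// prepareParams are shown, and prompt construction is assumed to happen
// elsewhere):
//
//	temp := 0.2
//	resp, err := model.Generate(ctx, fantasy.Call{Prompt: prompt, Temperature: &temp})
//	if err != nil {
//		// handle provider error
//	}
//	_ = resp.FinishReason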

// Stream implements fantasy.LanguageModel.
func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]streamToolCall)

	// Build provider metadata for streaming.
	providerMetadata := fantasy.ProviderMetadata{
		Name: &ProviderMetadata{},
	}
	acc := openai.ChatCompletionAccumulator{}
	extraContext := make(map[string]any)
	var usage fantasy.Usage
	var finishReason string
	return func(yield func(fantasy.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(fantasy.StreamPart{
				Type:     fantasy.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				if choice.FinishReason != "" {
					finishReason = choice.FinishReason
				}
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(fantasy.StreamPart{
						Type:  fantasy.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(fantasy.StreamPart{
								Type:  fantasy.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if xjson.IsValid(existingToolCall.arguments) {
								if !yield(fantasy.StreamPart{
									Type: fantasy.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(fantasy.StreamPart{
									Type:          fantasy.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// First delta for this tool call: validate it before tracking.
							// A switch keeps the first failed check's error rather than
							// letting later checks overwrite it.
							var err error
							switch {
							case toolCallDelta.Type != "function":
								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
							case toolCallDelta.ID == "":
								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
							case toolCallDelta.Function.Name == "":
								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
							}
							if err != nil {
								// Yield the validation error itself; stream.Err() is
								// typically nil here and would mask the real problem.
								yield(fantasy.StreamPart{
									Type:  fantasy.StreamPartTypeError,
									Error: err,
								})
								return
							}

							if !yield(fantasy.StreamPart{
								Type:         fantasy.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = streamToolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(fantasy.StreamPart{
									Type:  fantasy.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if xjson.IsValid(exTc.arguments) {
									if !yield(fantasy.StreamPart{
										Type: fantasy.StreamPartTypeToolInputEnd,
										ID:   exTc.id,
									}) {
										return
									}

									if !yield(fantasy.StreamPart{
										Type:          fantasy.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}

				if o.streamExtraFunc != nil {
					updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
					if !shouldContinue {
						return
					}
					extraContext = updatedContext
				}
			}

			// Check for annotations in the delta's raw JSON.
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(fantasy.StreamPart{
								Type:       fantasy.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: fantasy.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err == nil || errors.Is(err, io.EOF) {
			// Stream finished cleanly.
			if isActiveText {
				isActiveText = false
				if !yield(fantasy.StreamPart{
					Type: fantasy.StreamPartTypeTextEnd,
					ID:   "0",
				}) {
					return
				}
			}

			if len(acc.Choices) > 0 {
				choice := acc.Choices[0]
				// Add logprobs if available.
				providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)

				// Handle annotations/citations from the accumulated response.
				for _, annotation := range choice.Message.Annotations {
					if annotation.Type == "url_citation" {
						if !yield(fantasy.StreamPart{
							Type:       fantasy.StreamPartTypeSource,
							ID:         acc.ID,
							SourceType: fantasy.SourceTypeURL,
							URL:        annotation.URLCitation.URL,
							Title:      annotation.URLCitation.Title,
						}) {
							return
						}
					}
				}
			}
			mappedFinishReason := o.mapFinishReasonFunc(finishReason)
			if len(acc.Choices) > 0 {
				choice := acc.Choices[0]
				if len(choice.Message.ToolCalls) > 0 {
					mappedFinishReason = fantasy.FinishReasonToolCalls
				}
			}
			yield(fantasy.StreamPart{
				Type:             fantasy.StreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     mappedFinishReason,
				ProviderMetadata: providerMetadata,
			})
			return
		} else { //nolint: revive
			yield(fantasy.StreamPart{
				Type:  fantasy.StreamPartTypeError,
				Error: toProviderErr(err),
			})
			return
		}
	}, nil
}
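// A minimal consumption sketch (illustrative; assumes fantasy.StreamResponse
// is range-able over fantasy.StreamPart, as the yield-based body above
// implies):
//
//	parts, err := model.Stream(ctx, call)
//	if err != nil {
//		// handle error
//	}
//	for part := range parts {
//		if part.Type == fantasy.StreamPartTypeTextDelta {
//			fmt.Print(part.Delta)
//		}
//	}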

func isReasoningModel(modelID string) bool {
	// gpt-5-chat models are conventional chat models and must be excluded
	// even though they share the gpt-5 prefix.
	return (strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")) &&
		!strings.HasPrefix(modelID, "gpt-5-chat")
}

func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

func supportsPriorityProcessing(modelID string) bool {
	// The gpt-5 prefix already covers gpt-5-mini and the other gpt-5 variants.
	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini")
}

func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == fantasy.ToolTypeFunction {
			ft, ok := tool.(fantasy.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls.
		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case fantasy.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case fantasy.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		// Any other value names a specific tool the model must call.
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}
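// A minimal conversion sketch (illustrative; the tool literal and its schema
// are hypothetical and assume fantasy.FunctionTool's fields as used above):
//
//	choice := fantasy.ToolChoiceAuto
//	tools, tc, warns := toOpenAiTools([]fantasy.Tool{
//		fantasy.FunctionTool{
//			Name:        "get_weather",
//			Description: "Look up the current weather for a city",
//			InputSchema: map[string]any{"type": "object"},
//		},
//	}, &choice)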

// parseAnnotationsFromDelta parses url_citation annotations out of a delta's
// raw JSON, which the typed delta struct does not expose.
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations.
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta.
	annotationsData, ok := deltaData["annotations"].([]any)
	if !ok {
		return annotations
	}
	for _, annotationData := range annotationsData {
		annotationMap, ok := annotationData.(map[string]any)
		if !ok {
			continue
		}
		if annotationType, ok := annotationMap["type"].(string); !ok || annotationType != "url_citation" {
			continue
		}
		urlCitationData, ok := annotationMap["url_citation"].(map[string]any)
		if !ok {
			continue
		}
		// Checked assertions: a missing or non-string field yields an empty
		// value instead of panicking.
		url, _ := urlCitationData["url"].(string)
		title, _ := urlCitationData["title"].(string)
		annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
			Type: "url_citation",
			URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
				URL:   url,
				Title: title,
			},
		})
	}

	return annotations
}
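// For reference, the delta payload this parser expects looks roughly like
// the following (abbreviated, illustrative):
//
//	{"annotations": [{"type": "url_citation",
//	  "url_citation": {"url": "https://example.com", "title": "Example"}}]}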