package openai

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"strings"

	"charm.land/fantasy"
	xjson "github.com/charmbracelet/x/json"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

type languageModel struct {
	provider                   string
	modelID                    string
	client                     openai.Client
	prepareCallFunc            LanguageModelPrepareCallFunc
	mapFinishReasonFunc        LanguageModelMapFinishReasonFunc
	extraContentFunc           LanguageModelExtraContentFunc
	usageFunc                  LanguageModelUsageFunc
	streamUsageFunc            LanguageModelStreamUsageFunc
	streamExtraFunc            LanguageModelStreamExtraFunc
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	toPromptFunc               LanguageModelToPromptFunc
}

// LanguageModelOption configures a languageModel.
type LanguageModelOption = func(*languageModel)

// WithLanguageModelPrepareCallFunc sets the call preparation function for the language model.
func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.prepareCallFunc = fn
	}
}

// WithLanguageModelMapFinishReasonFunc sets the finish reason mapping function for the language model.
func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.mapFinishReasonFunc = fn
	}
}

// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.extraContentFunc = fn
	}
}

// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.streamExtraFunc = fn
	}
}

// WithLanguageModelUsageFunc sets the usage function for the language model.
func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.usageFunc = fn
	}
}

// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.streamUsageFunc = fn
	}
}

// WithLanguageModelToPromptFunc sets the prompt conversion function for the language model.
func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.toPromptFunc = fn
	}
}

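// newLanguageModel builds a languageModel with default hook implementations,
// then applies any options on top. A minimal construction sketch, assuming an
// existing openai.Client; myPrepareCall is a hypothetical custom hook:
//
//	model := newLanguageModel(
//		"gpt-4o", "openai", client,
//		WithLanguageModelPrepareCallFunc(myPrepareCall),
//	)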
func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
	model := languageModel{
		modelID:                    modelID,
		provider:                   provider,
		client:                     client,
		prepareCallFunc:            DefaultPrepareCallFunc,
		mapFinishReasonFunc:        DefaultMapFinishReasonFunc,
		usageFunc:                  DefaultUsageFunc,
		streamUsageFunc:            DefaultStreamUsageFunc,
		streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
		toPromptFunc:               DefaultToPrompt,
	}

	for _, o := range opts {
		o(&model)
	}
	return model
}

// streamToolCall accumulates a tool call across streaming deltas: arguments
// grow until they parse as valid JSON, at which point the call is emitted and
// marked as finished.
type streamToolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

// Model implements fantasy.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements fantasy.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}

// prepareParams converts a fantasy.Call into OpenAI chat completion
// parameters, collecting warnings for settings the model does not support.
func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
	if call.TopK != nil {
		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if isReasoningModel(o.modelID) {
		// Remove unsupported settings for reasoning models.
		// See https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "top_p",
				Details: "top_p is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "frequency_penalty",
				Details: "frequency_penalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "presence_penalty",
				Details: "presence_penalty is not supported for reasoning models",
			})
		}

		// Reasoning models use max_completion_tokens instead of max_tokens.
		if call.MaxOutputTokens != nil {
			if !params.MaxCompletionTokens.Valid() {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models.
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	optionsWarnings, err := o.prepareCallFunc(o, params, call)
	if err != nil {
		return nil, nil, err
	}

	if len(optionsWarnings) > 0 {
		warnings = append(warnings, optionsWarnings...)
	}

	params.Messages = messages
	params.Model = o.modelID

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}

// Generate implements fantasy.LanguageModel.
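//
// A minimal usage sketch, assuming ctx, a model built by newLanguageModel,
// and a fantasy.Call with a populated Prompt (all hypothetical here):
//
//	resp, err := model.Generate(ctx, call)
//	if err != nil {
//		// handle the provider error
//	}
//	for _, c := range resp.Content {
//		switch c := c.(type) {
//		case fantasy.TextContent:
//			fmt.Print(c.Text)
//		case fantasy.ToolCallContent:
//			// c.ToolName, c.Input
//		}
//	}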
func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, toProviderErr(err)
	}

	if len(response.Choices) == 0 {
		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
	}
	choice := response.Choices[0]
	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, fantasy.TextContent{
			Text: text,
		})
	}
	if o.extraContentFunc != nil {
		extraContent := o.extraContentFunc(choice)
		content = append(content, extraContent...)
	}
	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		content = append(content, fantasy.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, fantasy.SourceContent{
				SourceType: fantasy.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	usage, providerMetadata := o.usageFunc(*response)

	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
	if len(choice.Message.ToolCalls) > 0 {
		mappedFinishReason = fantasy.FinishReasonToolCalls
	}
	return &fantasy.Response{
		Content:      content,
		Usage:        usage,
		FinishReason: mappedFinishReason,
		ProviderMetadata: fantasy.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}

// Stream implements fantasy.LanguageModel.
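//
// A minimal consumption sketch, assuming Go 1.23 range-over-func support and
// that fantasy.StreamResponse is an iterator over fantasy.StreamPart (as the
// returned closure below suggests); ctx, model, and call are hypothetical:
//
//	stream, err := model.Stream(ctx, call)
//	if err != nil {
//		// handle the provider error
//	}
//	for part := range stream {
//		switch part.Type {
//		case fantasy.StreamPartTypeTextDelta:
//			fmt.Print(part.Delta)
//		case fantasy.StreamPartTypeFinish:
//			// part.Usage and part.FinishReason are final
//		}
//	}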
func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]streamToolCall)

	// Build provider metadata for streaming.
	providerMetadata := fantasy.ProviderMetadata{
		Name: &ProviderMetadata{},
	}
	acc := openai.ChatCompletionAccumulator{}
	extraContext := make(map[string]any)
	var usage fantasy.Usage
	var finishReason string
	return func(yield func(fantasy.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(fantasy.StreamPart{
				Type:     fantasy.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				if choice.FinishReason != "" {
					finishReason = choice.FinishReason
				}
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(fantasy.StreamPart{
						Type:  fantasy.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(fantasy.StreamPart{
								Type:  fantasy.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if xjson.IsValid(existingToolCall.arguments) {
								if !yield(fantasy.StreamPart{
									Type: fantasy.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(fantasy.StreamPart{
									Type:          fantasy.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// A new tool call: validate the first delta before tracking it.
							var err error
							if toolCallDelta.Type != "function" {
								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
							}
							if toolCallDelta.ID == "" {
								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
							}
							if toolCallDelta.Function.Name == "" {
								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
							}
							if err != nil {
								yield(fantasy.StreamPart{
									Type:  fantasy.StreamPartTypeError,
									Error: err,
								})
								return
							}

							if !yield(fantasy.StreamPart{
								Type:         fantasy.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = streamToolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(fantasy.StreamPart{
									Type:  fantasy.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if xjson.IsValid(exTc.arguments) {
									if !yield(fantasy.StreamPart{
										Type: fantasy.StreamPartTypeToolInputEnd,
										ID:   exTc.id,
									}) {
										return
									}

									if !yield(fantasy.StreamPart{
										Type:          fantasy.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}

				if o.streamExtraFunc != nil {
					updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
					if !shouldContinue {
						return
					}
					extraContext = updatedContext
				}
			}

			// Check for annotations in the delta's raw JSON.
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(fantasy.StreamPart{
								Type:       fantasy.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: fantasy.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			yield(fantasy.StreamPart{
				Type:  fantasy.StreamPartTypeError,
				Error: toProviderErr(err),
			})
			return
		}
		// The stream finished normally: close any open text block, emit
		// accumulated sources, and finish with usage and finish reason.
		if isActiveText {
			isActiveText = false
			if !yield(fantasy.StreamPart{
				Type: fantasy.StreamPartTypeTextEnd,
				ID:   "0",
			}) {
				return
			}
		}

		if len(acc.Choices) > 0 {
			choice := acc.Choices[0]
			// Add logprobs if available.
			providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)

			// Handle annotations/citations from the accumulated response.
			for _, annotation := range choice.Message.Annotations {
				if annotation.Type == "url_citation" {
					if !yield(fantasy.StreamPart{
						Type:       fantasy.StreamPartTypeSource,
						ID:         acc.ID,
						SourceType: fantasy.SourceTypeURL,
						URL:        annotation.URLCitation.URL,
						Title:      annotation.URLCitation.Title,
					}) {
						return
					}
				}
			}
		}
		mappedFinishReason := o.mapFinishReasonFunc(finishReason)
		if len(acc.Choices) > 0 {
			choice := acc.Choices[0]
			if len(choice.Message.ToolCalls) > 0 {
				mappedFinishReason = fantasy.FinishReasonToolCalls
			}
		}
		yield(fantasy.StreamPart{
			Type:             fantasy.StreamPartTypeFinish,
			Usage:            usage,
			FinishReason:     mappedFinishReason,
			ProviderMetadata: providerMetadata,
		})
	}, nil
}

func isReasoningModel(modelID string) bool {
	// The gpt-5-chat prefix is already covered by the gpt-5 prefix.
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}

func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

func supportsPriorityProcessing(modelID string) bool {
	// The gpt-5-mini prefix is already covered by the gpt-5 prefix.
	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini")
}

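// toOpenAiTools converts fantasy tools and an optional tool choice into their
// OpenAI chat completion equivalents, warning about unsupported tool types.
//
// A minimal sketch of the conversion, assuming weatherTool is an existing
// fantasy.FunctionTool value (illustrative only):
//
//	choice := fantasy.ToolChoiceAuto
//	tools, tc, warns := toOpenAiTools([]fantasy.Tool{weatherTool}, &choice)
//	// tools holds the function definitions; tc selects "auto".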
func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == fantasy.ToolTypeFunction {
			ft, ok := tool.(fantasy.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls.
		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case fantasy.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case fantasy.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		// Any other value names a specific function the model must call.
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
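// The SDK's typed delta does not expose annotations, so they are recovered
// from the raw payload. The expected shape (an assumption inferred from the
// parsing below) is:
//
//	{"annotations": [{"type": "url_citation", "url_citation": {"url": "...", "title": "..."}}]}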
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations.
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta.
	if annotationsData, ok := deltaData["annotations"].([]any); ok {
		for _, annotationData := range annotationsData {
			if annotationMap, ok := annotationData.(map[string]any); ok {
				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
						// Use checked assertions so malformed payloads are
						// skipped instead of panicking.
						url, urlOK := urlCitationData["url"].(string)
						title, titleOK := urlCitationData["title"].(string)
						if !urlOK || !titleOK {
							continue
						}
						annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
							Type: "url_citation",
							URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
								URL:   url,
								Title: title,
							},
						})
					}
				}
			}
		}
	}

	return annotations
}