1package openai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "io"
8 "strings"
9
10 "charm.land/fantasy"
11 xjson "github.com/charmbracelet/x/json"
12 "github.com/google/uuid"
13 "github.com/openai/openai-go/v2"
14 "github.com/openai/openai-go/v2/packages/param"
15 "github.com/openai/openai-go/v2/shared"
16)
17
// languageModel is a fantasy.LanguageModel implementation backed by the
// OpenAI chat completions API. The *Func fields are hook points that let
// OpenAI-compatible providers customize request preparation and response
// mapping; newLanguageModel installs defaults for most of them.
type languageModel struct {
	provider                   string // provider name returned by Provider()
	modelID                    string // model identifier returned by Model()
	client                     openai.Client
	prepareCallFunc            LanguageModelPrepareCallFunc
	mapFinishReasonFunc        LanguageModelMapFinishReasonFunc
	extraContentFunc           LanguageModelExtraContentFunc
	usageFunc                  LanguageModelUsageFunc
	streamUsageFunc            LanguageModelStreamUsageFunc
	streamExtraFunc            LanguageModelStreamExtraFunc
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	toPromptFunc               LanguageModelToPromptFunc
}
31
// LanguageModelOption is a function that configures a languageModel.
// Options are applied in order by newLanguageModel, after defaults are set.
type LanguageModelOption = func(*languageModel)
34
// WithLanguageModelPrepareCallFunc sets the prepare call function for the
// language model, replacing DefaultPrepareCallFunc.
func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.prepareCallFunc = fn
	}
}
41
// WithLanguageModelMapFinishReasonFunc sets the map finish reason function
// for the language model, replacing DefaultMapFinishReasonFunc.
func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.mapFinishReasonFunc = fn
	}
}
48
// WithLanguageModelExtraContentFunc sets the extra content function for the
// language model. There is no default; the hook is skipped when nil.
func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.extraContentFunc = fn
	}
}
55
// WithLanguageModelStreamExtraFunc sets the stream extra function for the
// language model. There is no default; the hook is skipped when nil.
func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.streamExtraFunc = fn
	}
}
62
// WithLanguageModelUsageFunc sets the usage function for the language model,
// replacing DefaultUsageFunc.
func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.usageFunc = fn
	}
}
69
// WithLanguageModelStreamUsageFunc sets the stream usage function for the
// language model, replacing DefaultStreamUsageFunc.
func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.streamUsageFunc = fn
	}
}
76
// WithLanguageModelToPromptFunc sets the to prompt function for the language
// model, replacing DefaultToPrompt.
func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.toPromptFunc = fn
	}
}
83
84func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
85 model := languageModel{
86 modelID: modelID,
87 provider: provider,
88 client: client,
89 prepareCallFunc: DefaultPrepareCallFunc,
90 mapFinishReasonFunc: DefaultMapFinishReasonFunc,
91 usageFunc: DefaultUsageFunc,
92 streamUsageFunc: DefaultStreamUsageFunc,
93 streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
94 toPromptFunc: DefaultToPrompt,
95 }
96
97 for _, o := range opts {
98 o(&model)
99 }
100 return model
101}
102
// streamToolCall accumulates one tool call across streaming deltas: argument
// fragments are concatenated until they form valid JSON, at which point the
// call is emitted and hasFinished is set.
type streamToolCall struct {
	id          string
	name        string
	arguments   string // concatenated JSON argument fragments so far
	hasFinished bool   // true once the complete tool call has been yielded
}
109
// Model implements fantasy.LanguageModel. It returns the model ID this
// instance was constructed with.
func (o languageModel) Model() string {
	return o.modelID
}
114
// Provider implements fantasy.LanguageModel. It returns the provider name
// this instance was constructed with.
func (o languageModel) Provider() string {
	return o.provider
}
119
120func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
121 params := &openai.ChatCompletionNewParams{}
122 messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
123 if call.TopK != nil {
124 warnings = append(warnings, fantasy.CallWarning{
125 Type: fantasy.CallWarningTypeUnsupportedSetting,
126 Setting: "top_k",
127 })
128 }
129
130 if call.MaxOutputTokens != nil {
131 params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
132 }
133 if call.Temperature != nil {
134 params.Temperature = param.NewOpt(*call.Temperature)
135 }
136 if call.TopP != nil {
137 params.TopP = param.NewOpt(*call.TopP)
138 }
139 if call.FrequencyPenalty != nil {
140 params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
141 }
142 if call.PresencePenalty != nil {
143 params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
144 }
145
146 if isReasoningModel(o.modelID) {
147 // remove unsupported settings for reasoning models
148 // see https://platform.openai.com/docs/guides/reasoning#limitations
149 if call.Temperature != nil {
150 params.Temperature = param.Opt[float64]{}
151 warnings = append(warnings, fantasy.CallWarning{
152 Type: fantasy.CallWarningTypeUnsupportedSetting,
153 Setting: "temperature",
154 Details: "temperature is not supported for reasoning models",
155 })
156 }
157 if call.TopP != nil {
158 params.TopP = param.Opt[float64]{}
159 warnings = append(warnings, fantasy.CallWarning{
160 Type: fantasy.CallWarningTypeUnsupportedSetting,
161 Setting: "TopP",
162 Details: "TopP is not supported for reasoning models",
163 })
164 }
165 if call.FrequencyPenalty != nil {
166 params.FrequencyPenalty = param.Opt[float64]{}
167 warnings = append(warnings, fantasy.CallWarning{
168 Type: fantasy.CallWarningTypeUnsupportedSetting,
169 Setting: "FrequencyPenalty",
170 Details: "FrequencyPenalty is not supported for reasoning models",
171 })
172 }
173 if call.PresencePenalty != nil {
174 params.PresencePenalty = param.Opt[float64]{}
175 warnings = append(warnings, fantasy.CallWarning{
176 Type: fantasy.CallWarningTypeUnsupportedSetting,
177 Setting: "PresencePenalty",
178 Details: "PresencePenalty is not supported for reasoning models",
179 })
180 }
181
182 // reasoning models use max_completion_tokens instead of max_tokens
183 if call.MaxOutputTokens != nil {
184 if !params.MaxCompletionTokens.Valid() {
185 params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
186 }
187 params.MaxTokens = param.Opt[int64]{}
188 }
189 }
190
191 // Handle search preview models
192 if isSearchPreviewModel(o.modelID) {
193 if call.Temperature != nil {
194 params.Temperature = param.Opt[float64]{}
195 warnings = append(warnings, fantasy.CallWarning{
196 Type: fantasy.CallWarningTypeUnsupportedSetting,
197 Setting: "temperature",
198 Details: "temperature is not supported for the search preview models and has been removed.",
199 })
200 }
201 }
202
203 optionsWarnings, err := o.prepareCallFunc(o, params, call)
204 if err != nil {
205 return nil, nil, err
206 }
207
208 if len(optionsWarnings) > 0 {
209 warnings = append(warnings, optionsWarnings...)
210 }
211
212 params.Messages = messages
213 params.Model = o.modelID
214
215 if len(call.Tools) > 0 {
216 tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
217 params.Tools = tools
218 if toolChoice != nil {
219 params.ToolChoice = *toolChoice
220 }
221 warnings = append(warnings, toolWarnings...)
222 }
223 return params, warnings, nil
224}
225
226func (o languageModel) handleError(err error) error {
227 var apiErr *openai.Error
228 if errors.As(err, &apiErr) {
229 requestDump := apiErr.DumpRequest(true)
230 responseDump := apiErr.DumpResponse(true)
231 headers := map[string]string{}
232 for k, h := range apiErr.Response.Header {
233 v := h[len(h)-1]
234 headers[strings.ToLower(k)] = v
235 }
236 return fantasy.NewAPICallError(
237 apiErr.Message,
238 apiErr.Request.URL.String(),
239 string(requestDump),
240 apiErr.StatusCode,
241 headers,
242 string(responseDump),
243 apiErr,
244 false,
245 )
246 }
247 return err
248}
249
// Generate implements fantasy.LanguageModel. It performs a single
// non-streaming chat completion request and maps the first choice into a
// fantasy.Response (text, hook-provided extras, tool calls, url citations).
func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	// Only the first choice is mapped.
	choice := response.Choices[0]
	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, fantasy.TextContent{
			Text: text,
		})
	}
	// Let the provider-specific hook contribute extra content parts.
	if o.extraContentFunc != nil {
		extraContent := o.extraContentFunc(choice)
		content = append(content, extraContent...)
	}
	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		content = append(content, fantasy.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations: only url_citation annotations are
	// mapped; each gets a fresh random source ID.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, fantasy.SourceContent{
				SourceType: fantasy.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	usage, providerMetadata := o.usageFunc(*response)

	// Tool calls take precedence over the finish reason reported by the API.
	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
	if len(choice.Message.ToolCalls) > 0 {
		mappedFinishReason = fantasy.FinishReasonToolCalls
	}
	return &fantasy.Response{
		Content:      content,
		Usage:        usage,
		FinishReason: mappedFinishReason,
		ProviderMetadata: fantasy.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}
313
// Stream implements fantasy.LanguageModel. It starts a streaming chat
// completion and returns an iterator that yields fantasy stream parts:
// warnings, text start/delta/end, tool-call input fragments and completed
// tool calls, url-citation sources, and a final finish (or error) part.
// The iterator stops early whenever a yield returns false.
func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	// Ask OpenAI to append a final chunk carrying token usage.
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	// isActiveText tracks whether a text part (ID "0") is currently open.
	isActiveText := false
	// toolCalls accumulates partial tool calls keyed by their delta index.
	toolCalls := make(map[int64]streamToolCall)

	// Build provider metadata for streaming
	providerMetadata := fantasy.ProviderMetadata{
		Name: &ProviderMetadata{},
	}
	// acc merges all chunks so the completed choice (annotations, tool
	// calls) is available after the stream ends.
	acc := openai.ChatCompletionAccumulator{}
	extraContext := make(map[string]any)
	var usage fantasy.Usage
	var finishReason string
	return func(yield func(fantasy.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(fantasy.StreamPart{
				Type:     fantasy.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				// Remember the last non-empty finish reason for the final part.
				if choice.FinishReason != "" {
					finishReason = choice.FinishReason
				}
				switch {
				case choice.Delta.Content != "":
					// All text parts share the synthetic ID "0"; emit a
					// start part before the first delta.
					if !isActiveText {
						isActiveText = true
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(fantasy.StreamPart{
						Type:  fantasy.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					// A tool call closes any open text part.
					if isActiveText {
						isActiveText = false
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(fantasy.StreamPart{
								Type:  fantasy.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							// The tool call is complete once the accumulated
							// arguments parse as valid JSON.
							if xjson.IsValid(existingToolCall.arguments) {
								if !yield(fantasy.StreamPart{
									Type: fantasy.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(fantasy.StreamPart{
									Type:          fantasy.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// Does not exist: validate the first fragment of
							// a new tool call.
							var err error
							if toolCallDelta.Type != "function" {
								err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
							}
							if toolCallDelta.ID == "" {
								err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
							}
							if toolCallDelta.Function.Name == "" {
								err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
							}
							if err != nil {
								// NOTE(review): this yields stream.Err() (likely
								// nil here) rather than the validation err built
								// above — confirm whether err was intended.
								yield(fantasy.StreamPart{
									Type:  fantasy.StreamPartTypeError,
									Error: o.handleError(stream.Err()),
								})
								return
							}

							if !yield(fantasy.StreamPart{
								Type:         fantasy.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = streamToolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							// The first fragment may already contain complete
							// JSON arguments; if so, emit the call now.
							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(fantasy.StreamPart{
									Type:  fantasy.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
									if !yield(fantasy.StreamPart{
										Type: fantasy.StreamPartTypeToolInputEnd,
										ID:   toolCallDelta.ID,
									}) {
										return
									}

									if !yield(fantasy.StreamPart{
										Type:          fantasy.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}

				// Give the provider hook a chance to emit extra parts or
				// stop the stream.
				if o.streamExtraFunc != nil {
					updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
					if !shouldContinue {
						return
					}
					extraContext = updatedContext
				}
			}

			// Check for annotations in the delta's raw JSON
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(fantasy.StreamPart{
								Type:       fantasy.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: fantasy.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err == nil || errors.Is(err, io.EOF) {
			// finished
			if isActiveText {
				isActiveText = false
				if !yield(fantasy.StreamPart{
					Type: fantasy.StreamPartTypeTextEnd,
					ID:   "0",
				}) {
					return
				}
			}

			if len(acc.Choices) > 0 {
				choice := acc.Choices[0]
				// Add logprobs if available
				providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)

				// Handle annotations/citations from accumulated response
				for _, annotation := range choice.Message.Annotations {
					if annotation.Type == "url_citation" {
						if !yield(fantasy.StreamPart{
							Type:       fantasy.StreamPartTypeSource,
							ID:         acc.ID,
							SourceType: fantasy.SourceTypeURL,
							URL:        annotation.URLCitation.URL,
							Title:      annotation.URLCitation.Title,
						}) {
							return
						}
					}
				}
			}
			// Tool calls take precedence over the reported finish reason.
			mappedFinishReason := o.mapFinishReasonFunc(finishReason)
			if len(acc.Choices) > 0 {
				choice := acc.Choices[0]
				if len(choice.Message.ToolCalls) > 0 {
					mappedFinishReason = fantasy.FinishReasonToolCalls
				}
			}
			yield(fantasy.StreamPart{
				Type:             fantasy.StreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     mappedFinishReason,
				ProviderMetadata: providerMetadata,
			})
			return
		} else { //nolint: revive
			yield(fantasy.StreamPart{
				Type:  fantasy.StreamPartTypeError,
				Error: o.handleError(err),
			})
			return
		}
	}, nil
}
572
// isReasoningModel reports whether modelID is an OpenAI reasoning model
// (o-series or the gpt-5 family). The gpt-5-chat models are conversational,
// non-reasoning models and are excluded.
//
// Fix: the original OR'd in HasPrefix(modelID, "gpt-5-chat"), which was dead
// code (already covered by the "gpt-5" prefix); the intended logic is to
// exclude gpt-5-chat so its sampling parameters are not stripped.
func isReasoningModel(modelID string) bool {
	return (strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")) &&
		!strings.HasPrefix(modelID, "gpt-5-chat")
}
576
// isSearchPreviewModel reports whether the model is one of the
// "search preview" variants (e.g. gpt-4o-search-preview).
func isSearchPreviewModel(modelID string) bool {
	const marker = "search-preview"
	return strings.Contains(modelID, marker)
}
580
// supportsFlexProcessing reports whether the model can be called with the
// "flex" service tier.
func supportsFlexProcessing(modelID string) bool {
	for _, prefix := range []string{"o3", "o4-mini", "gpt-5"} {
		if strings.HasPrefix(modelID, prefix) {
			return true
		}
	}
	return false
}
584
// supportsPriorityProcessing reports whether the model can be called with
// the "priority" service tier.
//
// Fix: the original also tested the "gpt-5-mini" prefix, which is dead code
// because the "gpt-5" prefix already matches it; the redundant clause is
// removed with no behavior change.
func supportsPriorityProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "gpt-4") ||
		strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini")
}
590
591func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
592 for _, tool := range tools {
593 if tool.GetType() == fantasy.ToolTypeFunction {
594 ft, ok := tool.(fantasy.FunctionTool)
595 if !ok {
596 continue
597 }
598 openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
599 OfFunction: &openai.ChatCompletionFunctionToolParam{
600 Function: shared.FunctionDefinitionParam{
601 Name: ft.Name,
602 Description: param.NewOpt(ft.Description),
603 Parameters: openai.FunctionParameters(ft.InputSchema),
604 Strict: param.NewOpt(false),
605 },
606 Type: "function",
607 },
608 })
609 continue
610 }
611
612 // TODO: handle provider tool calls
613 warnings = append(warnings, fantasy.CallWarning{
614 Type: fantasy.CallWarningTypeUnsupportedTool,
615 Tool: tool,
616 Message: "tool is not supported",
617 })
618 }
619 if toolChoice == nil {
620 return openAiTools, openAiToolChoice, warnings
621 }
622
623 switch *toolChoice {
624 case fantasy.ToolChoiceAuto:
625 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
626 OfAuto: param.NewOpt("auto"),
627 }
628 case fantasy.ToolChoiceNone:
629 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
630 OfAuto: param.NewOpt("none"),
631 }
632 default:
633 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
634 OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
635 Type: "function",
636 Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
637 Name: string(*toolChoice),
638 },
639 },
640 }
641 }
642 return openAiTools, openAiToolChoice, warnings
643}
644
645// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
646func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
647 var annotations []openai.ChatCompletionMessageAnnotation
648
649 // Parse the raw JSON to extract annotations
650 var deltaData map[string]any
651 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
652 return annotations
653 }
654
655 // Check if annotations exist in the delta
656 if annotationsData, ok := deltaData["annotations"].([]any); ok {
657 for _, annotationData := range annotationsData {
658 if annotationMap, ok := annotationData.(map[string]any); ok {
659 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
660 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
661 annotation := openai.ChatCompletionMessageAnnotation{
662 Type: "url_citation",
663 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
664 URL: urlCitationData["url"].(string),
665 Title: urlCitationData["title"].(string),
666 },
667 }
668 annotations = append(annotations, annotation)
669 }
670 }
671 }
672 }
673 }
674
675 return annotations
676}