1package openaicompat
2
3import (
4 "encoding/json"
5 "fmt"
6
7 "charm.land/fantasy"
8 "charm.land/fantasy/providers/openai"
9 openaisdk "github.com/openai/openai-go/v2"
10 "github.com/openai/openai-go/v2/packages/param"
11 "github.com/openai/openai-go/v2/shared"
12)
13
const (
	// reasoningStartedCtx is the key in the per-stream context map under which
	// StreamExtraFunc records whether a reasoning block is currently open
	// (true after a reasoning start part was emitted, false after its end).
	reasoningStartedCtx = "reasoning_started"
)
15
16// PrepareCallFunc prepares the call for the language model.
17func PrepareCallFunc(_ fantasy.LanguageModel, params *openaisdk.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) {
18 providerOptions := &ProviderOptions{}
19 if v, ok := call.ProviderOptions[Name]; ok {
20 providerOptions, ok = v.(*ProviderOptions)
21 if !ok {
22 return nil, fantasy.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
23 }
24 }
25
26 if providerOptions.ReasoningEffort != nil {
27 switch *providerOptions.ReasoningEffort {
28 case openai.ReasoningEffortMinimal:
29 params.ReasoningEffort = shared.ReasoningEffortMinimal
30 case openai.ReasoningEffortLow:
31 params.ReasoningEffort = shared.ReasoningEffortLow
32 case openai.ReasoningEffortMedium:
33 params.ReasoningEffort = shared.ReasoningEffortMedium
34 case openai.ReasoningEffortHigh:
35 params.ReasoningEffort = shared.ReasoningEffortHigh
36 default:
37 return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
38 }
39 }
40
41 if providerOptions.User != nil {
42 params.User = param.NewOpt(*providerOptions.User)
43 }
44 return nil, nil
45}
46
47// ExtraContentFunc adds extra content to the response.
48func ExtraContentFunc(choice openaisdk.ChatCompletionChoice) []fantasy.Content {
49 var content []fantasy.Content
50 reasoningData := ReasoningData{}
51 err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
52 if err != nil {
53 return content
54 }
55 if reasoningData.ReasoningContent != "" {
56 content = append(content, fantasy.ReasoningContent{
57 Text: reasoningData.ReasoningContent,
58 })
59 }
60 return content
61}
62
63func extractReasoningContext(ctx map[string]any) bool {
64 reasoningStarted, ok := ctx[reasoningStartedCtx]
65 if !ok {
66 return false
67 }
68 b, ok := reasoningStarted.(bool)
69 if !ok {
70 return false
71 }
72 return b
73}
74
75// StreamExtraFunc handles extra functionality for streaming responses.
76func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
77 if len(chunk.Choices) == 0 {
78 return ctx, true
79 }
80
81 reasoningStarted := extractReasoningContext(ctx)
82
83 for inx, choice := range chunk.Choices {
84 reasoningData := ReasoningData{}
85 err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
86 if err != nil {
87 yield(fantasy.StreamPart{
88 Type: fantasy.StreamPartTypeError,
89 Error: fantasy.NewAIError("Unexpected", "error unmarshalling delta", err),
90 })
91 return ctx, false
92 }
93
94 emitEvent := func(reasoningContent string) bool {
95 if !reasoningStarted {
96 shouldContinue := yield(fantasy.StreamPart{
97 Type: fantasy.StreamPartTypeReasoningStart,
98 ID: fmt.Sprintf("%d", inx),
99 })
100 if !shouldContinue {
101 return false
102 }
103 }
104
105 return yield(fantasy.StreamPart{
106 Type: fantasy.StreamPartTypeReasoningDelta,
107 ID: fmt.Sprintf("%d", inx),
108 Delta: reasoningContent,
109 })
110 }
111 if reasoningData.ReasoningContent != "" {
112 if !reasoningStarted {
113 ctx[reasoningStartedCtx] = true
114 }
115 return ctx, emitEvent(reasoningData.ReasoningContent)
116 }
117 if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
118 ctx[reasoningStartedCtx] = false
119 return ctx, yield(fantasy.StreamPart{
120 Type: fantasy.StreamPartTypeReasoningEnd,
121 ID: fmt.Sprintf("%d", inx),
122 })
123 }
124 }
125 return ctx, true
126}