1package openaicompat
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6
  7	"charm.land/fantasy"
  8	"charm.land/fantasy/providers/openai"
  9	openaisdk "github.com/openai/openai-go/v2"
 10	"github.com/openai/openai-go/v2/packages/param"
 11	"github.com/openai/openai-go/v2/shared"
 12)
 13
// reasoningStartedCtx is the key in the context map handed to StreamExtraFunc
// under which a bool is stored: true after a reasoning-start part has been
// emitted, false again once the matching reasoning-end part has been emitted.
const reasoningStartedCtx = "reasoning_started"
 15
 16// PrepareCallFunc prepares the call for the language model.
 17func PrepareCallFunc(_ fantasy.LanguageModel, params *openaisdk.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) {
 18	providerOptions := &ProviderOptions{}
 19	if v, ok := call.ProviderOptions[Name]; ok {
 20		providerOptions, ok = v.(*ProviderOptions)
 21		if !ok {
 22			return nil, fantasy.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
 23		}
 24	}
 25
 26	if providerOptions.ReasoningEffort != nil {
 27		switch *providerOptions.ReasoningEffort {
 28		case openai.ReasoningEffortMinimal:
 29			params.ReasoningEffort = shared.ReasoningEffortMinimal
 30		case openai.ReasoningEffortLow:
 31			params.ReasoningEffort = shared.ReasoningEffortLow
 32		case openai.ReasoningEffortMedium:
 33			params.ReasoningEffort = shared.ReasoningEffortMedium
 34		case openai.ReasoningEffortHigh:
 35			params.ReasoningEffort = shared.ReasoningEffortHigh
 36		default:
 37			return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
 38		}
 39	}
 40
 41	if providerOptions.User != nil {
 42		params.User = param.NewOpt(*providerOptions.User)
 43	}
 44	return nil, nil
 45}
 46
 47// ExtraContentFunc adds extra content to the response.
 48func ExtraContentFunc(choice openaisdk.ChatCompletionChoice) []fantasy.Content {
 49	var content []fantasy.Content
 50	reasoningData := ReasoningData{}
 51	err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
 52	if err != nil {
 53		return content
 54	}
 55	if reasoningData.ReasoningContent != "" {
 56		content = append(content, fantasy.ReasoningContent{
 57			Text: reasoningData.ReasoningContent,
 58		})
 59	}
 60	return content
 61}
 62
 63func extractReasoningContext(ctx map[string]any) bool {
 64	reasoningStarted, ok := ctx[reasoningStartedCtx]
 65	if !ok {
 66		return false
 67	}
 68	b, ok := reasoningStarted.(bool)
 69	if !ok {
 70		return false
 71	}
 72	return b
 73}
 74
 75// StreamExtraFunc handles extra functionality for streaming responses.
 76func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
 77	if len(chunk.Choices) == 0 {
 78		return ctx, true
 79	}
 80
 81	reasoningStarted := extractReasoningContext(ctx)
 82
 83	for inx, choice := range chunk.Choices {
 84		reasoningData := ReasoningData{}
 85		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
 86		if err != nil {
 87			yield(fantasy.StreamPart{
 88				Type:  fantasy.StreamPartTypeError,
 89				Error: fantasy.NewAIError("Unexpected", "error unmarshalling delta", err),
 90			})
 91			return ctx, false
 92		}
 93
 94		emitEvent := func(reasoningContent string) bool {
 95			if !reasoningStarted {
 96				shouldContinue := yield(fantasy.StreamPart{
 97					Type: fantasy.StreamPartTypeReasoningStart,
 98					ID:   fmt.Sprintf("%d", inx),
 99				})
100				if !shouldContinue {
101					return false
102				}
103			}
104
105			return yield(fantasy.StreamPart{
106				Type:  fantasy.StreamPartTypeReasoningDelta,
107				ID:    fmt.Sprintf("%d", inx),
108				Delta: reasoningContent,
109			})
110		}
111		if reasoningData.ReasoningContent != "" {
112			if !reasoningStarted {
113				ctx[reasoningStartedCtx] = true
114			}
115			return ctx, emitEvent(reasoningData.ReasoningContent)
116		}
117		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
118			ctx[reasoningStartedCtx] = false
119			return ctx, yield(fantasy.StreamPart{
120				Type: fantasy.StreamPartTypeReasoningEnd,
121				ID:   fmt.Sprintf("%d", inx),
122			})
123		}
124	}
125	return ctx, true
126}