language_model_hooks.go

  1package openaicompat
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6
  7	"charm.land/fantasy"
  8	"charm.land/fantasy/openai"
  9	openaisdk "github.com/openai/openai-go/v2"
 10	"github.com/openai/openai-go/v2/packages/param"
 11	"github.com/openai/openai-go/v2/shared"
 12)
 13
 14const reasoningStartedCtx = "reasoning_started"
 15
 16func PrepareCallFunc(model fantasy.LanguageModel, params *openaisdk.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) {
 17	providerOptions := &ProviderOptions{}
 18	if v, ok := call.ProviderOptions[Name]; ok {
 19		providerOptions, ok = v.(*ProviderOptions)
 20		if !ok {
 21			return nil, fantasy.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
 22		}
 23	}
 24
 25	if providerOptions.ReasoningEffort != nil {
 26		switch *providerOptions.ReasoningEffort {
 27		case openai.ReasoningEffortMinimal:
 28			params.ReasoningEffort = shared.ReasoningEffortMinimal
 29		case openai.ReasoningEffortLow:
 30			params.ReasoningEffort = shared.ReasoningEffortLow
 31		case openai.ReasoningEffortMedium:
 32			params.ReasoningEffort = shared.ReasoningEffortMedium
 33		case openai.ReasoningEffortHigh:
 34			params.ReasoningEffort = shared.ReasoningEffortHigh
 35		default:
 36			return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
 37		}
 38	}
 39
 40	if providerOptions.User != nil {
 41		params.User = param.NewOpt(*providerOptions.User)
 42	}
 43	return nil, nil
 44}
 45
 46func ExtraContentFunc(choice openaisdk.ChatCompletionChoice) []fantasy.Content {
 47	var content []fantasy.Content
 48	reasoningData := ReasoningData{}
 49	err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
 50	if err != nil {
 51		return content
 52	}
 53	if reasoningData.ReasoningContent != "" {
 54		content = append(content, fantasy.ReasoningContent{
 55			Text: reasoningData.ReasoningContent,
 56		})
 57	}
 58	return content
 59}
 60
 61func extractReasoningContext(ctx map[string]any) bool {
 62	reasoningStarted, ok := ctx[reasoningStartedCtx]
 63	if !ok {
 64		return false
 65	}
 66	b, ok := reasoningStarted.(bool)
 67	if !ok {
 68		return false
 69	}
 70	return b
 71}
 72
 73func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
 74	if len(chunk.Choices) == 0 {
 75		return ctx, true
 76	}
 77
 78	reasoningStarted := extractReasoningContext(ctx)
 79
 80	for inx, choice := range chunk.Choices {
 81		reasoningData := ReasoningData{}
 82		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
 83		if err != nil {
 84			yield(fantasy.StreamPart{
 85				Type:  fantasy.StreamPartTypeError,
 86				Error: fantasy.NewAIError("Unexpected", "error unmarshalling delta", err),
 87			})
 88			return ctx, false
 89		}
 90
 91		emitEvent := func(reasoningContent string) bool {
 92			if !reasoningStarted {
 93				shouldContinue := yield(fantasy.StreamPart{
 94					Type: fantasy.StreamPartTypeReasoningStart,
 95					ID:   fmt.Sprintf("%d", inx),
 96				})
 97				if !shouldContinue {
 98					return false
 99				}
100			}
101
102			return yield(fantasy.StreamPart{
103				Type:  fantasy.StreamPartTypeReasoningDelta,
104				ID:    fmt.Sprintf("%d", inx),
105				Delta: reasoningContent,
106			})
107		}
108		if reasoningData.ReasoningContent != "" {
109			if !reasoningStarted {
110				ctx[reasoningStartedCtx] = true
111			}
112			return ctx, emitEvent(reasoningData.ReasoningContent)
113		}
114		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
115			ctx[reasoningStartedCtx] = false
116			return ctx, yield(fantasy.StreamPart{
117				Type: fantasy.StreamPartTypeReasoningEnd,
118				ID:   fmt.Sprintf("%d", inx),
119			})
120		}
121	}
122	return ctx, true
123}