language_model_hooks.go

  1package openaicompat
  2
  3import (
  4	"fmt"
  5
  6	"github.com/charmbracelet/fantasy/ai"
  7	"github.com/charmbracelet/fantasy/openai"
  8	openaisdk "github.com/openai/openai-go/v2"
  9	"github.com/openai/openai-go/v2/packages/param"
 10	"github.com/openai/openai-go/v2/shared"
 11)
 12
 13const reasoningStartedCtx = "reasoning_started"
 14
 15func languagePrepareModelCall(model ai.LanguageModel, params *openaisdk.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
 16	providerOptions := &ProviderOptions{}
 17	if v, ok := call.ProviderOptions[Name]; ok {
 18		providerOptions, ok = v.(*ProviderOptions)
 19		if !ok {
 20			return nil, ai.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
 21		}
 22	}
 23
 24	if providerOptions.ReasoningEffort != nil {
 25		switch *providerOptions.ReasoningEffort {
 26		case openai.ReasoningEffortMinimal:
 27			params.ReasoningEffort = shared.ReasoningEffortMinimal
 28		case openai.ReasoningEffortLow:
 29			params.ReasoningEffort = shared.ReasoningEffortLow
 30		case openai.ReasoningEffortMedium:
 31			params.ReasoningEffort = shared.ReasoningEffortMedium
 32		case openai.ReasoningEffortHigh:
 33			params.ReasoningEffort = shared.ReasoningEffortHigh
 34		default:
 35			return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
 36		}
 37	}
 38
 39	if providerOptions.User != nil {
 40		params.User = param.NewOpt(*providerOptions.User)
 41	}
 42	return nil, nil
 43}
 44
 45func languageModelExtraContent(choice openaisdk.ChatCompletionChoice) []ai.Content {
 46	// TODO: check this
 47	return []ai.Content{}
 48}
 49
 50func extractReasoningContext(ctx map[string]any) bool {
 51	reasoningStarted, ok := ctx[reasoningStartedCtx]
 52	if !ok {
 53		return false
 54	}
 55	b, ok := reasoningStarted.(bool)
 56	if !ok {
 57		return false
 58	}
 59	return b
 60}
 61
 62func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
 63	// TODO: check this
 64	// if len(chunk.Choices) == 0 {
 65	// 	return ctx, true
 66	// }
 67	//
 68	// reasoningStarted := extractReasoningContext(ctx)
 69	//
 70	// for inx, choice := range chunk.Choices {
 71	// 	reasoningData := ReasoningData{}
 72	// 	err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
 73	// 	if err != nil {
 74	// 		yield(ai.StreamPart{
 75	// 			Type:  ai.StreamPartTypeError,
 76	// 			Error: ai.NewAIError("Unexpected", "error unmarshalling delta", err),
 77	// 		})
 78	// 		return ctx, false
 79	// 	}
 80	//
 81	// 	emitEvent := func(reasoningContent string) bool {
 82	// 		if !reasoningStarted {
 83	// 			shouldContinue := yield(ai.StreamPart{
 84	// 				Type: ai.StreamPartTypeReasoningStart,
 85	// 				ID:   fmt.Sprintf("%d", inx),
 86	// 			})
 87	// 			if !shouldContinue {
 88	// 				return false
 89	// 			}
 90	// 		}
 91	//
 92	// 		return yield(ai.StreamPart{
 93	// 			Type:  ai.StreamPartTypeReasoningDelta,
 94	// 			ID:    fmt.Sprintf("%d", inx),
 95	// 			Delta: reasoningContent,
 96	// 		})
 97	// 	}
 98	// 	if len(reasoningData.ReasoningDetails) > 0 {
 99	// 		for _, detail := range reasoningData.ReasoningDetails {
100	// 			if !reasoningStarted {
101	// 				ctx[reasoningStartedCtx] = true
102	// 			}
103	// 			switch detail.Type {
104	// 			case "reasoning.text":
105	// 				return ctx, emitEvent(detail.Text)
106	// 			case "reasoning.summary":
107	// 				return ctx, emitEvent(detail.Summary)
108	// 			case "reasoning.encrypted":
109	// 				return ctx, emitEvent("[REDACTED]")
110	// 			}
111	// 		}
112	// 	} else if reasoningData.Reasoning != "" {
113	// 		return ctx, emitEvent(reasoningData.Reasoning)
114	// 	}
115	// 	if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
116	// 		ctx[reasoningStartedCtx] = false
117	// 		return ctx, yield(ai.StreamPart{
118	// 			Type: ai.StreamPartTypeReasoningEnd,
119	// 			ID:   fmt.Sprintf("%d", inx),
120	// 		})
121	// 	}
122	// }
123	// return ctx, true
124	return nil, true
125}