language_model_hooks.go

  1package openaicompat
  2
  3import (
  4	"encoding/base64"
  5	"encoding/json"
  6	"fmt"
  7	"strings"
  8
  9	"charm.land/fantasy"
 10	"charm.land/fantasy/providers/openai"
 11	openaisdk "github.com/charmbracelet/openai-go"
 12	"github.com/charmbracelet/openai-go/packages/param"
 13	"github.com/charmbracelet/openai-go/shared"
 14)
 15
// reasoningStartedCtx is the key in the stream context map under which
// StreamExtraFunc stores a bool: true once a reasoning-start part has been
// emitted, false again after the matching reasoning-end part.
const reasoningStartedCtx = "reasoning_started"
 17
 18// PrepareCallFunc prepares the call for the language model.
 19func PrepareCallFunc(_ fantasy.LanguageModel, params *openaisdk.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) {
 20	providerOptions := &ProviderOptions{}
 21	if v, ok := call.ProviderOptions[Name]; ok {
 22		providerOptions, ok = v.(*ProviderOptions)
 23		if !ok {
 24			return nil, &fantasy.Error{Title: "invalid argument", Message: "openai-compat provider options should be *openaicompat.ProviderOptions"}
 25		}
 26	}
 27
 28	if providerOptions.ReasoningEffort != nil {
 29		switch *providerOptions.ReasoningEffort {
 30		case openai.ReasoningEffortNone:
 31			params.ReasoningEffort = shared.ReasoningEffortNone
 32		case openai.ReasoningEffortMinimal:
 33			params.ReasoningEffort = shared.ReasoningEffortMinimal
 34		case openai.ReasoningEffortLow:
 35			params.ReasoningEffort = shared.ReasoningEffortLow
 36		case openai.ReasoningEffortMedium:
 37			params.ReasoningEffort = shared.ReasoningEffortMedium
 38		case openai.ReasoningEffortHigh:
 39			params.ReasoningEffort = shared.ReasoningEffortHigh
 40		case openai.ReasoningEffortXHigh:
 41			params.ReasoningEffort = shared.ReasoningEffortXhigh
 42		default:
 43			return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
 44		}
 45	}
 46
 47	if providerOptions.User != nil {
 48		params.User = param.NewOpt(*providerOptions.User)
 49	}
 50	return nil, nil
 51}
 52
 53// ExtraContentFunc adds extra content to the response.
 54func ExtraContentFunc(choice openaisdk.ChatCompletionChoice) []fantasy.Content {
 55	var content []fantasy.Content
 56	reasoningData := ReasoningData{}
 57	err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
 58	if err != nil {
 59		return content
 60	}
 61	if rc := reasoningData.GetReasoningContent(); rc != "" {
 62		content = append(content, fantasy.ReasoningContent{
 63			Text: rc,
 64		})
 65	}
 66	return content
 67}
 68
 69func extractReasoningContext(ctx map[string]any) bool {
 70	reasoningStarted, ok := ctx[reasoningStartedCtx]
 71	if !ok {
 72		return false
 73	}
 74	b, ok := reasoningStarted.(bool)
 75	if !ok {
 76		return false
 77	}
 78	return b
 79}
 80
 81// StreamExtraFunc handles extra functionality for streaming responses.
 82func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
 83	if len(chunk.Choices) == 0 {
 84		return ctx, true
 85	}
 86
 87	reasoningStarted := extractReasoningContext(ctx)
 88
 89	for inx, choice := range chunk.Choices {
 90		reasoningData := ReasoningData{}
 91		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
 92		if err != nil {
 93			yield(fantasy.StreamPart{
 94				Type:  fantasy.StreamPartTypeError,
 95				Error: &fantasy.Error{Title: "stream error", Message: "error unmarshalling delta", Cause: err},
 96			})
 97			return ctx, false
 98		}
 99
100		emitEvent := func(reasoningContent string) bool {
101			if !reasoningStarted {
102				shouldContinue := yield(fantasy.StreamPart{
103					Type: fantasy.StreamPartTypeReasoningStart,
104					ID:   fmt.Sprintf("%d", inx),
105				})
106				if !shouldContinue {
107					return false
108				}
109			}
110
111			return yield(fantasy.StreamPart{
112				Type:  fantasy.StreamPartTypeReasoningDelta,
113				ID:    fmt.Sprintf("%d", inx),
114				Delta: reasoningContent,
115			})
116		}
117		if rc := reasoningData.GetReasoningContent(); rc != "" {
118			if !reasoningStarted {
119				ctx[reasoningStartedCtx] = true
120			}
121			return ctx, emitEvent(rc)
122		}
123		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
124			ctx[reasoningStartedCtx] = false
125			return ctx, yield(fantasy.StreamPart{
126				Type: fantasy.StreamPartTypeReasoningEnd,
127				ID:   fmt.Sprintf("%d", inx),
128			})
129		}
130	}
131	return ctx, true
132}
133
134// ToPromptFunc converts a fantasy prompt to OpenAI format with reasoning support.
135// It handles fantasy.ContentTypeReasoning in assistant messages by adding the
136// reasoning_content field to the message JSON.
137func ToPromptFunc(prompt fantasy.Prompt, _, _ string) ([]openaisdk.ChatCompletionMessageParamUnion, []fantasy.CallWarning) {
138	var messages []openaisdk.ChatCompletionMessageParamUnion
139	var warnings []fantasy.CallWarning
140	hasReasoning := false
141
142	for _, msg := range prompt {
143		switch msg.Role {
144		case fantasy.MessageRoleSystem:
145			var systemPromptParts []string
146			for _, c := range msg.Content {
147				if c.GetType() != fantasy.ContentTypeText {
148					warnings = append(warnings, fantasy.CallWarning{
149						Type:    fantasy.CallWarningTypeOther,
150						Message: "system prompt can only have text content",
151					})
152					continue
153				}
154				textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
155				if !ok {
156					warnings = append(warnings, fantasy.CallWarning{
157						Type:    fantasy.CallWarningTypeOther,
158						Message: "system prompt text part does not have the right type",
159					})
160					continue
161				}
162				text := textPart.Text
163				if strings.TrimSpace(text) != "" {
164					systemPromptParts = append(systemPromptParts, textPart.Text)
165				}
166			}
167			if len(systemPromptParts) == 0 {
168				warnings = append(warnings, fantasy.CallWarning{
169					Type:    fantasy.CallWarningTypeOther,
170					Message: "system prompt has no text parts",
171				})
172				continue
173			}
174			messages = append(messages, openaisdk.SystemMessage(strings.Join(systemPromptParts, "\n")))
175		case fantasy.MessageRoleUser:
176			// simple user message just text content
177			if len(msg.Content) == 1 && msg.Content[0].GetType() == fantasy.ContentTypeText {
178				textPart, ok := fantasy.AsContentType[fantasy.TextPart](msg.Content[0])
179				if !ok {
180					warnings = append(warnings, fantasy.CallWarning{
181						Type:    fantasy.CallWarningTypeOther,
182						Message: "user message text part does not have the right type",
183					})
184					continue
185				}
186				messages = append(messages, openaisdk.UserMessage(textPart.Text))
187				continue
188			}
189			// text content and attachments
190			var content []openaisdk.ChatCompletionContentPartUnionParam
191			for _, c := range msg.Content {
192				switch c.GetType() {
193				case fantasy.ContentTypeText:
194					textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
195					if !ok {
196						warnings = append(warnings, fantasy.CallWarning{
197							Type:    fantasy.CallWarningTypeOther,
198							Message: "user message text part does not have the right type",
199						})
200						continue
201					}
202					content = append(content, openaisdk.ChatCompletionContentPartUnionParam{
203						OfText: &openaisdk.ChatCompletionContentPartTextParam{
204							Text: textPart.Text,
205						},
206					})
207				case fantasy.ContentTypeFile:
208					filePart, ok := fantasy.AsContentType[fantasy.FilePart](c)
209					if !ok {
210						warnings = append(warnings, fantasy.CallWarning{
211							Type:    fantasy.CallWarningTypeOther,
212							Message: "user message file part does not have the right type",
213						})
214						continue
215					}
216
217					switch {
218					case strings.HasPrefix(filePart.MediaType, "text/"):
219						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
220						documentBlock := openaisdk.ChatCompletionContentPartFileFileParam{
221							FileData: param.NewOpt(base64Encoded),
222						}
223						content = append(content, openaisdk.FileContentPart(documentBlock))
224
225					case strings.HasPrefix(filePart.MediaType, "image/"):
226						// Handle image files
227						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
228						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
229						imageURL := openaisdk.ChatCompletionContentPartImageImageURLParam{URL: data}
230
231						// Check for provider-specific options like image detail
232						if providerOptions, ok := filePart.ProviderOptions[openai.Name]; ok {
233							if detail, ok := providerOptions.(*openai.ProviderFileOptions); ok {
234								imageURL.Detail = detail.ImageDetail
235							}
236						}
237
238						imageBlock := openaisdk.ChatCompletionContentPartImageParam{ImageURL: imageURL}
239						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
240
241					case filePart.MediaType == "audio/wav":
242						// Handle WAV audio files
243						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
244						audioBlock := openaisdk.ChatCompletionContentPartInputAudioParam{
245							InputAudio: openaisdk.ChatCompletionContentPartInputAudioInputAudioParam{
246								Data:   base64Encoded,
247								Format: "wav",
248							},
249						}
250						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
251
252					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
253						// Handle MP3 audio files
254						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
255						audioBlock := openaisdk.ChatCompletionContentPartInputAudioParam{
256							InputAudio: openaisdk.ChatCompletionContentPartInputAudioInputAudioParam{
257								Data:   base64Encoded,
258								Format: "mp3",
259							},
260						}
261						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
262
263					case filePart.MediaType == "application/pdf":
264						// Handle PDF files
265						dataStr := string(filePart.Data)
266
267						// Check if data looks like a file ID (starts with "file-")
268						if strings.HasPrefix(dataStr, "file-") {
269							fileBlock := openaisdk.ChatCompletionContentPartFileParam{
270								File: openaisdk.ChatCompletionContentPartFileFileParam{
271									FileID: param.NewOpt(dataStr),
272								},
273							}
274							content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
275						} else {
276							// Handle as base64 data
277							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
278							data := "data:application/pdf;base64," + base64Encoded
279
280							filename := filePart.Filename
281							if filename == "" {
282								// Generate default filename based on content index
283								filename = fmt.Sprintf("part-%d.pdf", len(content))
284							}
285
286							fileBlock := openaisdk.ChatCompletionContentPartFileParam{
287								File: openaisdk.ChatCompletionContentPartFileFileParam{
288									Filename: param.NewOpt(filename),
289									FileData: param.NewOpt(data),
290								},
291							}
292							content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
293						}
294
295					default:
296						warnings = append(warnings, fantasy.CallWarning{
297							Type:    fantasy.CallWarningTypeOther,
298							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
299						})
300					}
301				}
302			}
303			if !hasVisibleCompatUserContent(content) {
304				warnings = append(warnings, fantasy.CallWarning{
305					Type:    fantasy.CallWarningTypeOther,
306					Message: "dropping empty user message (contains neither user-facing content nor tool results)",
307				})
308				continue
309			}
310			messages = append(messages, openaisdk.UserMessage(content))
311		case fantasy.MessageRoleAssistant:
312			// simple assistant message just text content
313			if len(msg.Content) == 1 && msg.Content[0].GetType() == fantasy.ContentTypeText {
314				textPart, ok := fantasy.AsContentType[fantasy.TextPart](msg.Content[0])
315				if !ok {
316					warnings = append(warnings, fantasy.CallWarning{
317						Type:    fantasy.CallWarningTypeOther,
318						Message: "assistant message text part does not have the right type",
319					})
320					continue
321				}
322				messages = append(messages, openaisdk.AssistantMessage(textPart.Text))
323				continue
324			}
325			assistantMsg := openaisdk.ChatCompletionAssistantMessageParam{
326				Role: "assistant",
327			}
328			var reasoningText string
329			for _, c := range msg.Content {
330				switch c.GetType() {
331				case fantasy.ContentTypeText:
332					textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
333					if !ok {
334						warnings = append(warnings, fantasy.CallWarning{
335							Type:    fantasy.CallWarningTypeOther,
336							Message: "assistant message text part does not have the right type",
337						})
338						continue
339					}
340					assistantMsg.Content = openaisdk.ChatCompletionAssistantMessageParamContentUnion{
341						OfString: param.NewOpt(textPart.Text),
342					}
343				case fantasy.ContentTypeReasoning:
344					reasoningPart, ok := fantasy.AsContentType[fantasy.ReasoningPart](c)
345					if !ok {
346						warnings = append(warnings, fantasy.CallWarning{
347							Type:    fantasy.CallWarningTypeOther,
348							Message: "assistant message reasoning part does not have the right type",
349						})
350						continue
351					}
352					reasoningText = reasoningPart.Text
353					hasReasoning = true
354				case fantasy.ContentTypeToolCall:
355					toolCallPart, ok := fantasy.AsContentType[fantasy.ToolCallPart](c)
356					if !ok {
357						warnings = append(warnings, fantasy.CallWarning{
358							Type:    fantasy.CallWarningTypeOther,
359							Message: "assistant message tool part does not have the right type",
360						})
361						continue
362					}
363					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
364						openaisdk.ChatCompletionMessageToolCallUnionParam{
365							OfFunction: &openaisdk.ChatCompletionMessageFunctionToolCallParam{
366								ID:   toolCallPart.ToolCallID,
367								Type: "function",
368								Function: openaisdk.ChatCompletionMessageFunctionToolCallFunctionParam{
369									Name:      toolCallPart.ToolName,
370									Arguments: toolCallPart.Input,
371								},
372							},
373						})
374				}
375			}
376			// Add reasoning_content field if present, or if thinking is enabled
377			// and the message has tool calls (some providers like Kimi require
378			// reasoning_content on all assistant messages when thinking is enabled).
379			if reasoningText != "" || (hasReasoning && len(assistantMsg.ToolCalls) > 0) {
380				assistantMsg.SetExtraFields(map[string]any{
381					"reasoning_content": reasoningText,
382				})
383			}
384			if !hasVisibleCompatAssistantContent(&assistantMsg) {
385				warnings = append(warnings, fantasy.CallWarning{
386					Type:    fantasy.CallWarningTypeOther,
387					Message: "dropping empty assistant message (contains neither user-facing content nor tool calls)",
388				})
389				continue
390			}
391			messages = append(messages, openaisdk.ChatCompletionMessageParamUnion{
392				OfAssistant: &assistantMsg,
393			})
394		case fantasy.MessageRoleTool:
395			for _, c := range msg.Content {
396				if c.GetType() != fantasy.ContentTypeToolResult {
397					warnings = append(warnings, fantasy.CallWarning{
398						Type:    fantasy.CallWarningTypeOther,
399						Message: "tool message can only have tool result content",
400					})
401					continue
402				}
403
404				toolResultPart, ok := fantasy.AsContentType[fantasy.ToolResultPart](c)
405				if !ok {
406					warnings = append(warnings, fantasy.CallWarning{
407						Type:    fantasy.CallWarningTypeOther,
408						Message: "tool message result part does not have the right type",
409					})
410					continue
411				}
412
413				switch toolResultPart.Output.GetType() {
414				case fantasy.ToolResultContentTypeText:
415					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](toolResultPart.Output)
416					if !ok {
417						warnings = append(warnings, fantasy.CallWarning{
418							Type:    fantasy.CallWarningTypeOther,
419							Message: "tool result output does not have the right type",
420						})
421						continue
422					}
423					messages = append(messages, openaisdk.ToolMessage(output.Text, toolResultPart.ToolCallID))
424				case fantasy.ToolResultContentTypeError:
425					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](toolResultPart.Output)
426					if !ok {
427						warnings = append(warnings, fantasy.CallWarning{
428							Type:    fantasy.CallWarningTypeOther,
429							Message: "tool result output does not have the right type",
430						})
431						continue
432					}
433					messages = append(messages, openaisdk.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
434				}
435			}
436		}
437	}
438	return messages, warnings
439}
440
441func hasVisibleCompatUserContent(content []openaisdk.ChatCompletionContentPartUnionParam) bool {
442	for _, part := range content {
443		if part.OfText != nil || part.OfImageURL != nil || part.OfInputAudio != nil || part.OfFile != nil {
444			return true
445		}
446	}
447	return false
448}
449
450func hasVisibleCompatAssistantContent(msg *openaisdk.ChatCompletionAssistantMessageParam) bool {
451	// Check if there's text content
452	if !param.IsOmitted(msg.Content.OfString) || len(msg.Content.OfArrayOfContentParts) > 0 {
453		return true
454	}
455	// Check if there are tool calls
456	if len(msg.ToolCalls) > 0 {
457		return true
458	}
459	return false
460}