language_model_hooks.go

  1package openaicompat
  2
  3import (
  4	"encoding/base64"
  5	"encoding/json"
  6	"fmt"
  7	"strings"
  8
  9	"charm.land/fantasy"
 10	"charm.land/fantasy/providers/openai"
 11	openaisdk "github.com/charmbracelet/openai-go"
 12	"github.com/charmbracelet/openai-go/packages/param"
 13	"github.com/charmbracelet/openai-go/shared"
 14)
 15
// reasoningStartedCtx is the key in the per-stream context map under which
// StreamExtraFunc stores a bool: true after a reasoning-start part has been
// emitted, false again once the matching reasoning-end part has been emitted.
const reasoningStartedCtx = "reasoning_started"
 17
 18// PrepareCallFunc prepares the call for the language model.
 19func PrepareCallFunc(_ fantasy.LanguageModel, params *openaisdk.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) {
 20	providerOptions := &ProviderOptions{}
 21	if v, ok := call.ProviderOptions[Name]; ok {
 22		providerOptions, ok = v.(*ProviderOptions)
 23		if !ok {
 24			return nil, &fantasy.Error{Title: "invalid argument", Message: "openai-compat provider options should be *openaicompat.ProviderOptions"}
 25		}
 26	}
 27
 28	if providerOptions.ReasoningEffort != nil {
 29		switch *providerOptions.ReasoningEffort {
 30		case openai.ReasoningEffortMinimal:
 31			params.ReasoningEffort = shared.ReasoningEffortMinimal
 32		case openai.ReasoningEffortLow:
 33			params.ReasoningEffort = shared.ReasoningEffortLow
 34		case openai.ReasoningEffortMedium:
 35			params.ReasoningEffort = shared.ReasoningEffortMedium
 36		case openai.ReasoningEffortHigh:
 37			params.ReasoningEffort = shared.ReasoningEffortHigh
 38		default:
 39			return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
 40		}
 41	}
 42
 43	if providerOptions.User != nil {
 44		params.User = param.NewOpt(*providerOptions.User)
 45	}
 46	return nil, nil
 47}
 48
 49// ExtraContentFunc adds extra content to the response.
 50func ExtraContentFunc(choice openaisdk.ChatCompletionChoice) []fantasy.Content {
 51	var content []fantasy.Content
 52	reasoningData := ReasoningData{}
 53	err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
 54	if err != nil {
 55		return content
 56	}
 57	if reasoningData.ReasoningContent != "" {
 58		content = append(content, fantasy.ReasoningContent{
 59			Text: reasoningData.ReasoningContent,
 60		})
 61	}
 62	return content
 63}
 64
 65func extractReasoningContext(ctx map[string]any) bool {
 66	reasoningStarted, ok := ctx[reasoningStartedCtx]
 67	if !ok {
 68		return false
 69	}
 70	b, ok := reasoningStarted.(bool)
 71	if !ok {
 72		return false
 73	}
 74	return b
 75}
 76
 77// StreamExtraFunc handles extra functionality for streaming responses.
 78func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
 79	if len(chunk.Choices) == 0 {
 80		return ctx, true
 81	}
 82
 83	reasoningStarted := extractReasoningContext(ctx)
 84
 85	for inx, choice := range chunk.Choices {
 86		reasoningData := ReasoningData{}
 87		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
 88		if err != nil {
 89			yield(fantasy.StreamPart{
 90				Type:  fantasy.StreamPartTypeError,
 91				Error: &fantasy.Error{Title: "stream error", Message: "error unmarshalling delta", Cause: err},
 92			})
 93			return ctx, false
 94		}
 95
 96		emitEvent := func(reasoningContent string) bool {
 97			if !reasoningStarted {
 98				shouldContinue := yield(fantasy.StreamPart{
 99					Type: fantasy.StreamPartTypeReasoningStart,
100					ID:   fmt.Sprintf("%d", inx),
101				})
102				if !shouldContinue {
103					return false
104				}
105			}
106
107			return yield(fantasy.StreamPart{
108				Type:  fantasy.StreamPartTypeReasoningDelta,
109				ID:    fmt.Sprintf("%d", inx),
110				Delta: reasoningContent,
111			})
112		}
113		if reasoningData.ReasoningContent != "" {
114			if !reasoningStarted {
115				ctx[reasoningStartedCtx] = true
116			}
117			return ctx, emitEvent(reasoningData.ReasoningContent)
118		}
119		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
120			ctx[reasoningStartedCtx] = false
121			return ctx, yield(fantasy.StreamPart{
122				Type: fantasy.StreamPartTypeReasoningEnd,
123				ID:   fmt.Sprintf("%d", inx),
124			})
125		}
126	}
127	return ctx, true
128}
129
130// ToPromptFunc converts a fantasy prompt to OpenAI format with reasoning support.
131// It handles fantasy.ContentTypeReasoning in assistant messages by adding the
132// reasoning_content field to the message JSON.
133func ToPromptFunc(prompt fantasy.Prompt, _, _ string) ([]openaisdk.ChatCompletionMessageParamUnion, []fantasy.CallWarning) {
134	var messages []openaisdk.ChatCompletionMessageParamUnion
135	var warnings []fantasy.CallWarning
136	for _, msg := range prompt {
137		switch msg.Role {
138		case fantasy.MessageRoleSystem:
139			var systemPromptParts []string
140			for _, c := range msg.Content {
141				if c.GetType() != fantasy.ContentTypeText {
142					warnings = append(warnings, fantasy.CallWarning{
143						Type:    fantasy.CallWarningTypeOther,
144						Message: "system prompt can only have text content",
145					})
146					continue
147				}
148				textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
149				if !ok {
150					warnings = append(warnings, fantasy.CallWarning{
151						Type:    fantasy.CallWarningTypeOther,
152						Message: "system prompt text part does not have the right type",
153					})
154					continue
155				}
156				text := textPart.Text
157				if strings.TrimSpace(text) != "" {
158					systemPromptParts = append(systemPromptParts, textPart.Text)
159				}
160			}
161			if len(systemPromptParts) == 0 {
162				warnings = append(warnings, fantasy.CallWarning{
163					Type:    fantasy.CallWarningTypeOther,
164					Message: "system prompt has no text parts",
165				})
166				continue
167			}
168			messages = append(messages, openaisdk.SystemMessage(strings.Join(systemPromptParts, "\n")))
169		case fantasy.MessageRoleUser:
170			// simple user message just text content
171			if len(msg.Content) == 1 && msg.Content[0].GetType() == fantasy.ContentTypeText {
172				textPart, ok := fantasy.AsContentType[fantasy.TextPart](msg.Content[0])
173				if !ok {
174					warnings = append(warnings, fantasy.CallWarning{
175						Type:    fantasy.CallWarningTypeOther,
176						Message: "user message text part does not have the right type",
177					})
178					continue
179				}
180				messages = append(messages, openaisdk.UserMessage(textPart.Text))
181				continue
182			}
183			// text content and attachments
184			var content []openaisdk.ChatCompletionContentPartUnionParam
185			for _, c := range msg.Content {
186				switch c.GetType() {
187				case fantasy.ContentTypeText:
188					textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
189					if !ok {
190						warnings = append(warnings, fantasy.CallWarning{
191							Type:    fantasy.CallWarningTypeOther,
192							Message: "user message text part does not have the right type",
193						})
194						continue
195					}
196					content = append(content, openaisdk.ChatCompletionContentPartUnionParam{
197						OfText: &openaisdk.ChatCompletionContentPartTextParam{
198							Text: textPart.Text,
199						},
200					})
201				case fantasy.ContentTypeFile:
202					filePart, ok := fantasy.AsContentType[fantasy.FilePart](c)
203					if !ok {
204						warnings = append(warnings, fantasy.CallWarning{
205							Type:    fantasy.CallWarningTypeOther,
206							Message: "user message file part does not have the right type",
207						})
208						continue
209					}
210
211					switch {
212					case strings.HasPrefix(filePart.MediaType, "text/"):
213						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
214						documentBlock := openaisdk.ChatCompletionContentPartFileFileParam{
215							FileData: param.NewOpt(base64Encoded),
216						}
217						content = append(content, openaisdk.FileContentPart(documentBlock))
218
219					case strings.HasPrefix(filePart.MediaType, "image/"):
220						// Handle image files
221						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
222						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
223						imageURL := openaisdk.ChatCompletionContentPartImageImageURLParam{URL: data}
224
225						// Check for provider-specific options like image detail
226						if providerOptions, ok := filePart.ProviderOptions[openai.Name]; ok {
227							if detail, ok := providerOptions.(*openai.ProviderFileOptions); ok {
228								imageURL.Detail = detail.ImageDetail
229							}
230						}
231
232						imageBlock := openaisdk.ChatCompletionContentPartImageParam{ImageURL: imageURL}
233						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
234
235					case filePart.MediaType == "audio/wav":
236						// Handle WAV audio files
237						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
238						audioBlock := openaisdk.ChatCompletionContentPartInputAudioParam{
239							InputAudio: openaisdk.ChatCompletionContentPartInputAudioInputAudioParam{
240								Data:   base64Encoded,
241								Format: "wav",
242							},
243						}
244						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
245
246					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
247						// Handle MP3 audio files
248						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
249						audioBlock := openaisdk.ChatCompletionContentPartInputAudioParam{
250							InputAudio: openaisdk.ChatCompletionContentPartInputAudioInputAudioParam{
251								Data:   base64Encoded,
252								Format: "mp3",
253							},
254						}
255						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
256
257					case filePart.MediaType == "application/pdf":
258						// Handle PDF files
259						dataStr := string(filePart.Data)
260
261						// Check if data looks like a file ID (starts with "file-")
262						if strings.HasPrefix(dataStr, "file-") {
263							fileBlock := openaisdk.ChatCompletionContentPartFileParam{
264								File: openaisdk.ChatCompletionContentPartFileFileParam{
265									FileID: param.NewOpt(dataStr),
266								},
267							}
268							content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
269						} else {
270							// Handle as base64 data
271							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
272							data := "data:application/pdf;base64," + base64Encoded
273
274							filename := filePart.Filename
275							if filename == "" {
276								// Generate default filename based on content index
277								filename = fmt.Sprintf("part-%d.pdf", len(content))
278							}
279
280							fileBlock := openaisdk.ChatCompletionContentPartFileParam{
281								File: openaisdk.ChatCompletionContentPartFileFileParam{
282									Filename: param.NewOpt(filename),
283									FileData: param.NewOpt(data),
284								},
285							}
286							content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
287						}
288
289					default:
290						warnings = append(warnings, fantasy.CallWarning{
291							Type:    fantasy.CallWarningTypeOther,
292							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
293						})
294					}
295				}
296			}
297			if !hasVisibleCompatUserContent(content) {
298				warnings = append(warnings, fantasy.CallWarning{
299					Type:    fantasy.CallWarningTypeOther,
300					Message: "dropping empty user message (contains neither user-facing content nor tool results)",
301				})
302				continue
303			}
304			messages = append(messages, openaisdk.UserMessage(content))
305		case fantasy.MessageRoleAssistant:
306			// simple assistant message just text content
307			if len(msg.Content) == 1 && msg.Content[0].GetType() == fantasy.ContentTypeText {
308				textPart, ok := fantasy.AsContentType[fantasy.TextPart](msg.Content[0])
309				if !ok {
310					warnings = append(warnings, fantasy.CallWarning{
311						Type:    fantasy.CallWarningTypeOther,
312						Message: "assistant message text part does not have the right type",
313					})
314					continue
315				}
316				messages = append(messages, openaisdk.AssistantMessage(textPart.Text))
317				continue
318			}
319			assistantMsg := openaisdk.ChatCompletionAssistantMessageParam{
320				Role: "assistant",
321			}
322			var reasoningText string
323			for _, c := range msg.Content {
324				switch c.GetType() {
325				case fantasy.ContentTypeText:
326					textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
327					if !ok {
328						warnings = append(warnings, fantasy.CallWarning{
329							Type:    fantasy.CallWarningTypeOther,
330							Message: "assistant message text part does not have the right type",
331						})
332						continue
333					}
334					assistantMsg.Content = openaisdk.ChatCompletionAssistantMessageParamContentUnion{
335						OfString: param.NewOpt(textPart.Text),
336					}
337				case fantasy.ContentTypeReasoning:
338					reasoningPart, ok := fantasy.AsContentType[fantasy.ReasoningPart](c)
339					if !ok {
340						warnings = append(warnings, fantasy.CallWarning{
341							Type:    fantasy.CallWarningTypeOther,
342							Message: "assistant message reasoning part does not have the right type",
343						})
344						continue
345					}
346					reasoningText = reasoningPart.Text
347				case fantasy.ContentTypeToolCall:
348					toolCallPart, ok := fantasy.AsContentType[fantasy.ToolCallPart](c)
349					if !ok {
350						warnings = append(warnings, fantasy.CallWarning{
351							Type:    fantasy.CallWarningTypeOther,
352							Message: "assistant message tool part does not have the right type",
353						})
354						continue
355					}
356					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
357						openaisdk.ChatCompletionMessageToolCallUnionParam{
358							OfFunction: &openaisdk.ChatCompletionMessageFunctionToolCallParam{
359								ID:   toolCallPart.ToolCallID,
360								Type: "function",
361								Function: openaisdk.ChatCompletionMessageFunctionToolCallFunctionParam{
362									Name:      toolCallPart.ToolName,
363									Arguments: toolCallPart.Input,
364								},
365							},
366						})
367				}
368			}
369			// Add reasoning_content field if present
370			if reasoningText != "" {
371				assistantMsg.SetExtraFields(map[string]any{
372					"reasoning_content": reasoningText,
373				})
374			}
375			if !hasVisibleCompatAssistantContent(&assistantMsg) {
376				warnings = append(warnings, fantasy.CallWarning{
377					Type:    fantasy.CallWarningTypeOther,
378					Message: "dropping empty assistant message (contains neither user-facing content nor tool calls)",
379				})
380				continue
381			}
382			messages = append(messages, openaisdk.ChatCompletionMessageParamUnion{
383				OfAssistant: &assistantMsg,
384			})
385		case fantasy.MessageRoleTool:
386			for _, c := range msg.Content {
387				if c.GetType() != fantasy.ContentTypeToolResult {
388					warnings = append(warnings, fantasy.CallWarning{
389						Type:    fantasy.CallWarningTypeOther,
390						Message: "tool message can only have tool result content",
391					})
392					continue
393				}
394
395				toolResultPart, ok := fantasy.AsContentType[fantasy.ToolResultPart](c)
396				if !ok {
397					warnings = append(warnings, fantasy.CallWarning{
398						Type:    fantasy.CallWarningTypeOther,
399						Message: "tool message result part does not have the right type",
400					})
401					continue
402				}
403
404				switch toolResultPart.Output.GetType() {
405				case fantasy.ToolResultContentTypeText:
406					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](toolResultPart.Output)
407					if !ok {
408						warnings = append(warnings, fantasy.CallWarning{
409							Type:    fantasy.CallWarningTypeOther,
410							Message: "tool result output does not have the right type",
411						})
412						continue
413					}
414					messages = append(messages, openaisdk.ToolMessage(output.Text, toolResultPart.ToolCallID))
415				case fantasy.ToolResultContentTypeError:
416					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](toolResultPart.Output)
417					if !ok {
418						warnings = append(warnings, fantasy.CallWarning{
419							Type:    fantasy.CallWarningTypeOther,
420							Message: "tool result output does not have the right type",
421						})
422						continue
423					}
424					messages = append(messages, openaisdk.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
425				}
426			}
427		}
428	}
429	return messages, warnings
430}
431
432func hasVisibleCompatUserContent(content []openaisdk.ChatCompletionContentPartUnionParam) bool {
433	for _, part := range content {
434		if part.OfText != nil || part.OfImageURL != nil || part.OfInputAudio != nil || part.OfFile != nil {
435			return true
436		}
437	}
438	return false
439}
440
441func hasVisibleCompatAssistantContent(msg *openaisdk.ChatCompletionAssistantMessageParam) bool {
442	// Check if there's text content
443	if !param.IsOmitted(msg.Content.OfString) || len(msg.Content.OfArrayOfContentParts) > 0 {
444		return true
445	}
446	// Check if there are tool calls
447	if len(msg.ToolCalls) > 0 {
448		return true
449	}
450	return false
451}