language_model_hooks.go

  1package openaicompat
  2
  3import (
  4	"encoding/base64"
  5	"encoding/json"
  6	"fmt"
  7	"strings"
  8
  9	"charm.land/fantasy"
 10	"charm.land/fantasy/providers/openai"
 11	openaisdk "github.com/charmbracelet/openai-go"
 12	"github.com/charmbracelet/openai-go/packages/param"
 13	"github.com/charmbracelet/openai-go/shared"
 14)
 15
// reasoningStartedCtx is the key in the streaming context map (see
// StreamExtraFunc) holding a bool: true once a reasoning block has been
// opened, reset to false when it is closed.
const reasoningStartedCtx = "reasoning_started"
 17
 18// PrepareCallFunc prepares the call for the language model.
 19func PrepareCallFunc(_ fantasy.LanguageModel, params *openaisdk.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) {
 20	providerOptions := &ProviderOptions{}
 21	if v, ok := call.ProviderOptions[Name]; ok {
 22		providerOptions, ok = v.(*ProviderOptions)
 23		if !ok {
 24			return nil, &fantasy.Error{Title: "invalid argument", Message: "openai-compat provider options should be *openaicompat.ProviderOptions"}
 25		}
 26	}
 27
 28	if providerOptions.ReasoningEffort != nil {
 29		switch *providerOptions.ReasoningEffort {
 30		case openai.ReasoningEffortNone:
 31			params.ReasoningEffort = shared.ReasoningEffortNone
 32		case openai.ReasoningEffortMinimal:
 33			params.ReasoningEffort = shared.ReasoningEffortMinimal
 34		case openai.ReasoningEffortLow:
 35			params.ReasoningEffort = shared.ReasoningEffortLow
 36		case openai.ReasoningEffortMedium:
 37			params.ReasoningEffort = shared.ReasoningEffortMedium
 38		case openai.ReasoningEffortHigh:
 39			params.ReasoningEffort = shared.ReasoningEffortHigh
 40		case openai.ReasoningEffortXHigh:
 41			params.ReasoningEffort = shared.ReasoningEffortXhigh
 42		default:
 43			return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
 44		}
 45	}
 46
 47	if providerOptions.User != nil {
 48		params.User = param.NewOpt(*providerOptions.User)
 49	}
 50	return nil, nil
 51}
 52
 53// ExtraContentFunc adds extra content to the response.
 54func ExtraContentFunc(choice openaisdk.ChatCompletionChoice) []fantasy.Content {
 55	var content []fantasy.Content
 56	reasoningData := ReasoningData{}
 57	err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
 58	if err != nil {
 59		return content
 60	}
 61	if reasoningData.ReasoningContent != "" {
 62		content = append(content, fantasy.ReasoningContent{
 63			Text: reasoningData.ReasoningContent,
 64		})
 65	}
 66	return content
 67}
 68
 69func extractReasoningContext(ctx map[string]any) bool {
 70	reasoningStarted, ok := ctx[reasoningStartedCtx]
 71	if !ok {
 72		return false
 73	}
 74	b, ok := reasoningStarted.(bool)
 75	if !ok {
 76		return false
 77	}
 78	return b
 79}
 80
 81// StreamExtraFunc handles extra functionality for streaming responses.
 82func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
 83	if len(chunk.Choices) == 0 {
 84		return ctx, true
 85	}
 86
 87	reasoningStarted := extractReasoningContext(ctx)
 88
 89	for inx, choice := range chunk.Choices {
 90		reasoningData := ReasoningData{}
 91		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
 92		if err != nil {
 93			yield(fantasy.StreamPart{
 94				Type:  fantasy.StreamPartTypeError,
 95				Error: &fantasy.Error{Title: "stream error", Message: "error unmarshalling delta", Cause: err},
 96			})
 97			return ctx, false
 98		}
 99
100		emitEvent := func(reasoningContent string) bool {
101			if !reasoningStarted {
102				shouldContinue := yield(fantasy.StreamPart{
103					Type: fantasy.StreamPartTypeReasoningStart,
104					ID:   fmt.Sprintf("%d", inx),
105				})
106				if !shouldContinue {
107					return false
108				}
109			}
110
111			return yield(fantasy.StreamPart{
112				Type:  fantasy.StreamPartTypeReasoningDelta,
113				ID:    fmt.Sprintf("%d", inx),
114				Delta: reasoningContent,
115			})
116		}
117		if reasoningData.ReasoningContent != "" {
118			if !reasoningStarted {
119				ctx[reasoningStartedCtx] = true
120			}
121			return ctx, emitEvent(reasoningData.ReasoningContent)
122		}
123		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
124			ctx[reasoningStartedCtx] = false
125			return ctx, yield(fantasy.StreamPart{
126				Type: fantasy.StreamPartTypeReasoningEnd,
127				ID:   fmt.Sprintf("%d", inx),
128			})
129		}
130	}
131	return ctx, true
132}
133
134// ToPromptFunc converts a fantasy prompt to OpenAI format with reasoning support.
135// It handles fantasy.ContentTypeReasoning in assistant messages by adding the
136// reasoning_content field to the message JSON.
137func ToPromptFunc(prompt fantasy.Prompt, _, _ string) ([]openaisdk.ChatCompletionMessageParamUnion, []fantasy.CallWarning) {
138	var messages []openaisdk.ChatCompletionMessageParamUnion
139	var warnings []fantasy.CallWarning
140	for _, msg := range prompt {
141		switch msg.Role {
142		case fantasy.MessageRoleSystem:
143			var systemPromptParts []string
144			for _, c := range msg.Content {
145				if c.GetType() != fantasy.ContentTypeText {
146					warnings = append(warnings, fantasy.CallWarning{
147						Type:    fantasy.CallWarningTypeOther,
148						Message: "system prompt can only have text content",
149					})
150					continue
151				}
152				textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
153				if !ok {
154					warnings = append(warnings, fantasy.CallWarning{
155						Type:    fantasy.CallWarningTypeOther,
156						Message: "system prompt text part does not have the right type",
157					})
158					continue
159				}
160				text := textPart.Text
161				if strings.TrimSpace(text) != "" {
162					systemPromptParts = append(systemPromptParts, textPart.Text)
163				}
164			}
165			if len(systemPromptParts) == 0 {
166				warnings = append(warnings, fantasy.CallWarning{
167					Type:    fantasy.CallWarningTypeOther,
168					Message: "system prompt has no text parts",
169				})
170				continue
171			}
172			messages = append(messages, openaisdk.SystemMessage(strings.Join(systemPromptParts, "\n")))
173		case fantasy.MessageRoleUser:
174			// simple user message just text content
175			if len(msg.Content) == 1 && msg.Content[0].GetType() == fantasy.ContentTypeText {
176				textPart, ok := fantasy.AsContentType[fantasy.TextPart](msg.Content[0])
177				if !ok {
178					warnings = append(warnings, fantasy.CallWarning{
179						Type:    fantasy.CallWarningTypeOther,
180						Message: "user message text part does not have the right type",
181					})
182					continue
183				}
184				messages = append(messages, openaisdk.UserMessage(textPart.Text))
185				continue
186			}
187			// text content and attachments
188			var content []openaisdk.ChatCompletionContentPartUnionParam
189			for _, c := range msg.Content {
190				switch c.GetType() {
191				case fantasy.ContentTypeText:
192					textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
193					if !ok {
194						warnings = append(warnings, fantasy.CallWarning{
195							Type:    fantasy.CallWarningTypeOther,
196							Message: "user message text part does not have the right type",
197						})
198						continue
199					}
200					content = append(content, openaisdk.ChatCompletionContentPartUnionParam{
201						OfText: &openaisdk.ChatCompletionContentPartTextParam{
202							Text: textPart.Text,
203						},
204					})
205				case fantasy.ContentTypeFile:
206					filePart, ok := fantasy.AsContentType[fantasy.FilePart](c)
207					if !ok {
208						warnings = append(warnings, fantasy.CallWarning{
209							Type:    fantasy.CallWarningTypeOther,
210							Message: "user message file part does not have the right type",
211						})
212						continue
213					}
214
215					switch {
216					case strings.HasPrefix(filePart.MediaType, "text/"):
217						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
218						documentBlock := openaisdk.ChatCompletionContentPartFileFileParam{
219							FileData: param.NewOpt(base64Encoded),
220						}
221						content = append(content, openaisdk.FileContentPart(documentBlock))
222
223					case strings.HasPrefix(filePart.MediaType, "image/"):
224						// Handle image files
225						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
226						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
227						imageURL := openaisdk.ChatCompletionContentPartImageImageURLParam{URL: data}
228
229						// Check for provider-specific options like image detail
230						if providerOptions, ok := filePart.ProviderOptions[openai.Name]; ok {
231							if detail, ok := providerOptions.(*openai.ProviderFileOptions); ok {
232								imageURL.Detail = detail.ImageDetail
233							}
234						}
235
236						imageBlock := openaisdk.ChatCompletionContentPartImageParam{ImageURL: imageURL}
237						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
238
239					case filePart.MediaType == "audio/wav":
240						// Handle WAV audio files
241						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
242						audioBlock := openaisdk.ChatCompletionContentPartInputAudioParam{
243							InputAudio: openaisdk.ChatCompletionContentPartInputAudioInputAudioParam{
244								Data:   base64Encoded,
245								Format: "wav",
246							},
247						}
248						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
249
250					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
251						// Handle MP3 audio files
252						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
253						audioBlock := openaisdk.ChatCompletionContentPartInputAudioParam{
254							InputAudio: openaisdk.ChatCompletionContentPartInputAudioInputAudioParam{
255								Data:   base64Encoded,
256								Format: "mp3",
257							},
258						}
259						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
260
261					case filePart.MediaType == "application/pdf":
262						// Handle PDF files
263						dataStr := string(filePart.Data)
264
265						// Check if data looks like a file ID (starts with "file-")
266						if strings.HasPrefix(dataStr, "file-") {
267							fileBlock := openaisdk.ChatCompletionContentPartFileParam{
268								File: openaisdk.ChatCompletionContentPartFileFileParam{
269									FileID: param.NewOpt(dataStr),
270								},
271							}
272							content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
273						} else {
274							// Handle as base64 data
275							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
276							data := "data:application/pdf;base64," + base64Encoded
277
278							filename := filePart.Filename
279							if filename == "" {
280								// Generate default filename based on content index
281								filename = fmt.Sprintf("part-%d.pdf", len(content))
282							}
283
284							fileBlock := openaisdk.ChatCompletionContentPartFileParam{
285								File: openaisdk.ChatCompletionContentPartFileFileParam{
286									Filename: param.NewOpt(filename),
287									FileData: param.NewOpt(data),
288								},
289							}
290							content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
291						}
292
293					default:
294						warnings = append(warnings, fantasy.CallWarning{
295							Type:    fantasy.CallWarningTypeOther,
296							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
297						})
298					}
299				}
300			}
301			if !hasVisibleCompatUserContent(content) {
302				warnings = append(warnings, fantasy.CallWarning{
303					Type:    fantasy.CallWarningTypeOther,
304					Message: "dropping empty user message (contains neither user-facing content nor tool results)",
305				})
306				continue
307			}
308			messages = append(messages, openaisdk.UserMessage(content))
309		case fantasy.MessageRoleAssistant:
310			// simple assistant message just text content
311			if len(msg.Content) == 1 && msg.Content[0].GetType() == fantasy.ContentTypeText {
312				textPart, ok := fantasy.AsContentType[fantasy.TextPart](msg.Content[0])
313				if !ok {
314					warnings = append(warnings, fantasy.CallWarning{
315						Type:    fantasy.CallWarningTypeOther,
316						Message: "assistant message text part does not have the right type",
317					})
318					continue
319				}
320				messages = append(messages, openaisdk.AssistantMessage(textPart.Text))
321				continue
322			}
323			assistantMsg := openaisdk.ChatCompletionAssistantMessageParam{
324				Role: "assistant",
325			}
326			var reasoningText string
327			for _, c := range msg.Content {
328				switch c.GetType() {
329				case fantasy.ContentTypeText:
330					textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
331					if !ok {
332						warnings = append(warnings, fantasy.CallWarning{
333							Type:    fantasy.CallWarningTypeOther,
334							Message: "assistant message text part does not have the right type",
335						})
336						continue
337					}
338					assistantMsg.Content = openaisdk.ChatCompletionAssistantMessageParamContentUnion{
339						OfString: param.NewOpt(textPart.Text),
340					}
341				case fantasy.ContentTypeReasoning:
342					reasoningPart, ok := fantasy.AsContentType[fantasy.ReasoningPart](c)
343					if !ok {
344						warnings = append(warnings, fantasy.CallWarning{
345							Type:    fantasy.CallWarningTypeOther,
346							Message: "assistant message reasoning part does not have the right type",
347						})
348						continue
349					}
350					reasoningText = reasoningPart.Text
351				case fantasy.ContentTypeToolCall:
352					toolCallPart, ok := fantasy.AsContentType[fantasy.ToolCallPart](c)
353					if !ok {
354						warnings = append(warnings, fantasy.CallWarning{
355							Type:    fantasy.CallWarningTypeOther,
356							Message: "assistant message tool part does not have the right type",
357						})
358						continue
359					}
360					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
361						openaisdk.ChatCompletionMessageToolCallUnionParam{
362							OfFunction: &openaisdk.ChatCompletionMessageFunctionToolCallParam{
363								ID:   toolCallPart.ToolCallID,
364								Type: "function",
365								Function: openaisdk.ChatCompletionMessageFunctionToolCallFunctionParam{
366									Name:      toolCallPart.ToolName,
367									Arguments: toolCallPart.Input,
368								},
369							},
370						})
371				}
372			}
373			// Add reasoning_content field if present
374			if reasoningText != "" {
375				assistantMsg.SetExtraFields(map[string]any{
376					"reasoning_content": reasoningText,
377				})
378			}
379			if !hasVisibleCompatAssistantContent(&assistantMsg) {
380				warnings = append(warnings, fantasy.CallWarning{
381					Type:    fantasy.CallWarningTypeOther,
382					Message: "dropping empty assistant message (contains neither user-facing content nor tool calls)",
383				})
384				continue
385			}
386			messages = append(messages, openaisdk.ChatCompletionMessageParamUnion{
387				OfAssistant: &assistantMsg,
388			})
389		case fantasy.MessageRoleTool:
390			for _, c := range msg.Content {
391				if c.GetType() != fantasy.ContentTypeToolResult {
392					warnings = append(warnings, fantasy.CallWarning{
393						Type:    fantasy.CallWarningTypeOther,
394						Message: "tool message can only have tool result content",
395					})
396					continue
397				}
398
399				toolResultPart, ok := fantasy.AsContentType[fantasy.ToolResultPart](c)
400				if !ok {
401					warnings = append(warnings, fantasy.CallWarning{
402						Type:    fantasy.CallWarningTypeOther,
403						Message: "tool message result part does not have the right type",
404					})
405					continue
406				}
407
408				switch toolResultPart.Output.GetType() {
409				case fantasy.ToolResultContentTypeText:
410					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](toolResultPart.Output)
411					if !ok {
412						warnings = append(warnings, fantasy.CallWarning{
413							Type:    fantasy.CallWarningTypeOther,
414							Message: "tool result output does not have the right type",
415						})
416						continue
417					}
418					messages = append(messages, openaisdk.ToolMessage(output.Text, toolResultPart.ToolCallID))
419				case fantasy.ToolResultContentTypeError:
420					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](toolResultPart.Output)
421					if !ok {
422						warnings = append(warnings, fantasy.CallWarning{
423							Type:    fantasy.CallWarningTypeOther,
424							Message: "tool result output does not have the right type",
425						})
426						continue
427					}
428					messages = append(messages, openaisdk.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
429				}
430			}
431		}
432	}
433	return messages, warnings
434}
435
436func hasVisibleCompatUserContent(content []openaisdk.ChatCompletionContentPartUnionParam) bool {
437	for _, part := range content {
438		if part.OfText != nil || part.OfImageURL != nil || part.OfInputAudio != nil || part.OfFile != nil {
439			return true
440		}
441	}
442	return false
443}
444
445func hasVisibleCompatAssistantContent(msg *openaisdk.ChatCompletionAssistantMessageParam) bool {
446	// Check if there's text content
447	if !param.IsOmitted(msg.Content.OfString) || len(msg.Content.OfArrayOfContentParts) > 0 {
448		return true
449	}
450	// Check if there are tool calls
451	if len(msg.ToolCalls) > 0 {
452		return true
453	}
454	return false
455}