language_model_hooks.go

  1package openaicompat
  2
  3import (
  4	"encoding/base64"
  5	"encoding/json"
  6	"fmt"
  7	"strings"
  8
  9	"charm.land/fantasy"
 10	"charm.land/fantasy/providers/openai"
 11	openaisdk "github.com/openai/openai-go/v2"
 12	"github.com/openai/openai-go/v2/packages/param"
 13	"github.com/openai/openai-go/v2/shared"
 14)
 15
// reasoningStartedCtx is the key in the streaming context map under which
// StreamExtraFunc records whether a reasoning block is currently open; it is
// read back through extractReasoningContext.
const reasoningStartedCtx = "reasoning_started"
 17
 18// PrepareCallFunc prepares the call for the language model.
 19func PrepareCallFunc(_ fantasy.LanguageModel, params *openaisdk.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) {
 20	providerOptions := &ProviderOptions{}
 21	if v, ok := call.ProviderOptions[Name]; ok {
 22		providerOptions, ok = v.(*ProviderOptions)
 23		if !ok {
 24			return nil, &fantasy.Error{Title: "invalid argument", Message: "openai-compat provider options should be *openaicompat.ProviderOptions"}
 25		}
 26	}
 27
 28	if providerOptions.ReasoningEffort != nil {
 29		switch *providerOptions.ReasoningEffort {
 30		case openai.ReasoningEffortMinimal:
 31			params.ReasoningEffort = shared.ReasoningEffortMinimal
 32		case openai.ReasoningEffortLow:
 33			params.ReasoningEffort = shared.ReasoningEffortLow
 34		case openai.ReasoningEffortMedium:
 35			params.ReasoningEffort = shared.ReasoningEffortMedium
 36		case openai.ReasoningEffortHigh:
 37			params.ReasoningEffort = shared.ReasoningEffortHigh
 38		default:
 39			return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
 40		}
 41	}
 42
 43	if providerOptions.User != nil {
 44		params.User = param.NewOpt(*providerOptions.User)
 45	}
 46	return nil, nil
 47}
 48
 49// ExtraContentFunc adds extra content to the response.
 50func ExtraContentFunc(choice openaisdk.ChatCompletionChoice) []fantasy.Content {
 51	var content []fantasy.Content
 52	reasoningData := ReasoningData{}
 53	err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
 54	if err != nil {
 55		return content
 56	}
 57	if reasoningData.ReasoningContent != "" {
 58		content = append(content, fantasy.ReasoningContent{
 59			Text: reasoningData.ReasoningContent,
 60		})
 61	}
 62	return content
 63}
 64
 65func extractReasoningContext(ctx map[string]any) bool {
 66	reasoningStarted, ok := ctx[reasoningStartedCtx]
 67	if !ok {
 68		return false
 69	}
 70	b, ok := reasoningStarted.(bool)
 71	if !ok {
 72		return false
 73	}
 74	return b
 75}
 76
 77// StreamExtraFunc handles extra functionality for streaming responses.
 78func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
 79	if len(chunk.Choices) == 0 {
 80		return ctx, true
 81	}
 82
 83	reasoningStarted := extractReasoningContext(ctx)
 84
 85	for inx, choice := range chunk.Choices {
 86		reasoningData := ReasoningData{}
 87		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
 88		if err != nil {
 89			yield(fantasy.StreamPart{
 90				Type:  fantasy.StreamPartTypeError,
 91				Error: &fantasy.Error{Title: "stream error", Message: "error unmarshalling delta", Cause: err},
 92			})
 93			return ctx, false
 94		}
 95
 96		emitEvent := func(reasoningContent string) bool {
 97			if !reasoningStarted {
 98				shouldContinue := yield(fantasy.StreamPart{
 99					Type: fantasy.StreamPartTypeReasoningStart,
100					ID:   fmt.Sprintf("%d", inx),
101				})
102				if !shouldContinue {
103					return false
104				}
105			}
106
107			return yield(fantasy.StreamPart{
108				Type:  fantasy.StreamPartTypeReasoningDelta,
109				ID:    fmt.Sprintf("%d", inx),
110				Delta: reasoningContent,
111			})
112		}
113		if reasoningData.ReasoningContent != "" {
114			if !reasoningStarted {
115				ctx[reasoningStartedCtx] = true
116			}
117			return ctx, emitEvent(reasoningData.ReasoningContent)
118		}
119		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
120			ctx[reasoningStartedCtx] = false
121			return ctx, yield(fantasy.StreamPart{
122				Type: fantasy.StreamPartTypeReasoningEnd,
123				ID:   fmt.Sprintf("%d", inx),
124			})
125		}
126	}
127	return ctx, true
128}
129
130// ToPromptFunc converts a fantasy prompt to OpenAI format with reasoning support.
131// It handles fantasy.ContentTypeReasoning in assistant messages by adding the
132// reasoning_content field to the message JSON.
133func ToPromptFunc(prompt fantasy.Prompt, _, _ string) ([]openaisdk.ChatCompletionMessageParamUnion, []fantasy.CallWarning) {
134	var messages []openaisdk.ChatCompletionMessageParamUnion
135	var warnings []fantasy.CallWarning
136	for _, msg := range prompt {
137		switch msg.Role {
138		case fantasy.MessageRoleSystem:
139			var systemPromptParts []string
140			for _, c := range msg.Content {
141				if c.GetType() != fantasy.ContentTypeText {
142					warnings = append(warnings, fantasy.CallWarning{
143						Type:    fantasy.CallWarningTypeOther,
144						Message: "system prompt can only have text content",
145					})
146					continue
147				}
148				textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
149				if !ok {
150					warnings = append(warnings, fantasy.CallWarning{
151						Type:    fantasy.CallWarningTypeOther,
152						Message: "system prompt text part does not have the right type",
153					})
154					continue
155				}
156				text := textPart.Text
157				if strings.TrimSpace(text) != "" {
158					systemPromptParts = append(systemPromptParts, textPart.Text)
159				}
160			}
161			if len(systemPromptParts) == 0 {
162				warnings = append(warnings, fantasy.CallWarning{
163					Type:    fantasy.CallWarningTypeOther,
164					Message: "system prompt has no text parts",
165				})
166				continue
167			}
168			messages = append(messages, openaisdk.SystemMessage(strings.Join(systemPromptParts, "\n")))
169		case fantasy.MessageRoleUser:
170			// simple user message just text content
171			if len(msg.Content) == 1 && msg.Content[0].GetType() == fantasy.ContentTypeText {
172				textPart, ok := fantasy.AsContentType[fantasy.TextPart](msg.Content[0])
173				if !ok {
174					warnings = append(warnings, fantasy.CallWarning{
175						Type:    fantasy.CallWarningTypeOther,
176						Message: "user message text part does not have the right type",
177					})
178					continue
179				}
180				messages = append(messages, openaisdk.UserMessage(textPart.Text))
181				continue
182			}
183			// text content and attachments
184			var content []openaisdk.ChatCompletionContentPartUnionParam
185			for _, c := range msg.Content {
186				switch c.GetType() {
187				case fantasy.ContentTypeText:
188					textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
189					if !ok {
190						warnings = append(warnings, fantasy.CallWarning{
191							Type:    fantasy.CallWarningTypeOther,
192							Message: "user message text part does not have the right type",
193						})
194						continue
195					}
196					content = append(content, openaisdk.ChatCompletionContentPartUnionParam{
197						OfText: &openaisdk.ChatCompletionContentPartTextParam{
198							Text: textPart.Text,
199						},
200					})
201				case fantasy.ContentTypeFile:
202					filePart, ok := fantasy.AsContentType[fantasy.FilePart](c)
203					if !ok {
204						warnings = append(warnings, fantasy.CallWarning{
205							Type:    fantasy.CallWarningTypeOther,
206							Message: "user message file part does not have the right type",
207						})
208						continue
209					}
210
211					switch {
212					case strings.HasPrefix(filePart.MediaType, "image/"):
213						// Handle image files
214						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
215						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
216						imageURL := openaisdk.ChatCompletionContentPartImageImageURLParam{URL: data}
217
218						// Check for provider-specific options like image detail
219						if providerOptions, ok := filePart.ProviderOptions[openai.Name]; ok {
220							if detail, ok := providerOptions.(*openai.ProviderFileOptions); ok {
221								imageURL.Detail = detail.ImageDetail
222							}
223						}
224
225						imageBlock := openaisdk.ChatCompletionContentPartImageParam{ImageURL: imageURL}
226						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
227
228					case filePart.MediaType == "audio/wav":
229						// Handle WAV audio files
230						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
231						audioBlock := openaisdk.ChatCompletionContentPartInputAudioParam{
232							InputAudio: openaisdk.ChatCompletionContentPartInputAudioInputAudioParam{
233								Data:   base64Encoded,
234								Format: "wav",
235							},
236						}
237						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
238
239					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
240						// Handle MP3 audio files
241						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
242						audioBlock := openaisdk.ChatCompletionContentPartInputAudioParam{
243							InputAudio: openaisdk.ChatCompletionContentPartInputAudioInputAudioParam{
244								Data:   base64Encoded,
245								Format: "mp3",
246							},
247						}
248						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
249
250					case filePart.MediaType == "application/pdf":
251						// Handle PDF files
252						dataStr := string(filePart.Data)
253
254						// Check if data looks like a file ID (starts with "file-")
255						if strings.HasPrefix(dataStr, "file-") {
256							fileBlock := openaisdk.ChatCompletionContentPartFileParam{
257								File: openaisdk.ChatCompletionContentPartFileFileParam{
258									FileID: param.NewOpt(dataStr),
259								},
260							}
261							content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
262						} else {
263							// Handle as base64 data
264							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
265							data := "data:application/pdf;base64," + base64Encoded
266
267							filename := filePart.Filename
268							if filename == "" {
269								// Generate default filename based on content index
270								filename = fmt.Sprintf("part-%d.pdf", len(content))
271							}
272
273							fileBlock := openaisdk.ChatCompletionContentPartFileParam{
274								File: openaisdk.ChatCompletionContentPartFileFileParam{
275									Filename: param.NewOpt(filename),
276									FileData: param.NewOpt(data),
277								},
278							}
279							content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
280						}
281
282					default:
283						warnings = append(warnings, fantasy.CallWarning{
284							Type:    fantasy.CallWarningTypeOther,
285							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
286						})
287					}
288				}
289			}
290			if !hasVisibleCompatUserContent(content) {
291				warnings = append(warnings, fantasy.CallWarning{
292					Type:    fantasy.CallWarningTypeOther,
293					Message: "dropping empty user message (contains neither user-facing content nor tool results)",
294				})
295				continue
296			}
297			messages = append(messages, openaisdk.UserMessage(content))
298		case fantasy.MessageRoleAssistant:
299			// simple assistant message just text content
300			if len(msg.Content) == 1 && msg.Content[0].GetType() == fantasy.ContentTypeText {
301				textPart, ok := fantasy.AsContentType[fantasy.TextPart](msg.Content[0])
302				if !ok {
303					warnings = append(warnings, fantasy.CallWarning{
304						Type:    fantasy.CallWarningTypeOther,
305						Message: "assistant message text part does not have the right type",
306					})
307					continue
308				}
309				messages = append(messages, openaisdk.AssistantMessage(textPart.Text))
310				continue
311			}
312			assistantMsg := openaisdk.ChatCompletionAssistantMessageParam{
313				Role: "assistant",
314			}
315			var reasoningText string
316			for _, c := range msg.Content {
317				switch c.GetType() {
318				case fantasy.ContentTypeText:
319					textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
320					if !ok {
321						warnings = append(warnings, fantasy.CallWarning{
322							Type:    fantasy.CallWarningTypeOther,
323							Message: "assistant message text part does not have the right type",
324						})
325						continue
326					}
327					assistantMsg.Content = openaisdk.ChatCompletionAssistantMessageParamContentUnion{
328						OfString: param.NewOpt(textPart.Text),
329					}
330				case fantasy.ContentTypeReasoning:
331					reasoningPart, ok := fantasy.AsContentType[fantasy.ReasoningPart](c)
332					if !ok {
333						warnings = append(warnings, fantasy.CallWarning{
334							Type:    fantasy.CallWarningTypeOther,
335							Message: "assistant message reasoning part does not have the right type",
336						})
337						continue
338					}
339					reasoningText = reasoningPart.Text
340				case fantasy.ContentTypeToolCall:
341					toolCallPart, ok := fantasy.AsContentType[fantasy.ToolCallPart](c)
342					if !ok {
343						warnings = append(warnings, fantasy.CallWarning{
344							Type:    fantasy.CallWarningTypeOther,
345							Message: "assistant message tool part does not have the right type",
346						})
347						continue
348					}
349					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
350						openaisdk.ChatCompletionMessageToolCallUnionParam{
351							OfFunction: &openaisdk.ChatCompletionMessageFunctionToolCallParam{
352								ID:   toolCallPart.ToolCallID,
353								Type: "function",
354								Function: openaisdk.ChatCompletionMessageFunctionToolCallFunctionParam{
355									Name:      toolCallPart.ToolName,
356									Arguments: toolCallPart.Input,
357								},
358							},
359						})
360				}
361			}
362			// Add reasoning_content field if present
363			if reasoningText != "" {
364				assistantMsg.SetExtraFields(map[string]any{
365					"reasoning_content": reasoningText,
366				})
367			}
368			if !hasVisibleCompatAssistantContent(&assistantMsg) {
369				warnings = append(warnings, fantasy.CallWarning{
370					Type:    fantasy.CallWarningTypeOther,
371					Message: "dropping empty assistant message (contains neither user-facing content nor tool calls)",
372				})
373				continue
374			}
375			messages = append(messages, openaisdk.ChatCompletionMessageParamUnion{
376				OfAssistant: &assistantMsg,
377			})
378		case fantasy.MessageRoleTool:
379			for _, c := range msg.Content {
380				if c.GetType() != fantasy.ContentTypeToolResult {
381					warnings = append(warnings, fantasy.CallWarning{
382						Type:    fantasy.CallWarningTypeOther,
383						Message: "tool message can only have tool result content",
384					})
385					continue
386				}
387
388				toolResultPart, ok := fantasy.AsContentType[fantasy.ToolResultPart](c)
389				if !ok {
390					warnings = append(warnings, fantasy.CallWarning{
391						Type:    fantasy.CallWarningTypeOther,
392						Message: "tool message result part does not have the right type",
393					})
394					continue
395				}
396
397				switch toolResultPart.Output.GetType() {
398				case fantasy.ToolResultContentTypeText:
399					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](toolResultPart.Output)
400					if !ok {
401						warnings = append(warnings, fantasy.CallWarning{
402							Type:    fantasy.CallWarningTypeOther,
403							Message: "tool result output does not have the right type",
404						})
405						continue
406					}
407					messages = append(messages, openaisdk.ToolMessage(output.Text, toolResultPart.ToolCallID))
408				case fantasy.ToolResultContentTypeError:
409					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](toolResultPart.Output)
410					if !ok {
411						warnings = append(warnings, fantasy.CallWarning{
412							Type:    fantasy.CallWarningTypeOther,
413							Message: "tool result output does not have the right type",
414						})
415						continue
416					}
417					messages = append(messages, openaisdk.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
418				}
419			}
420		}
421	}
422	return messages, warnings
423}
424
425func hasVisibleCompatUserContent(content []openaisdk.ChatCompletionContentPartUnionParam) bool {
426	for _, part := range content {
427		if part.OfText != nil || part.OfImageURL != nil || part.OfInputAudio != nil || part.OfFile != nil {
428			return true
429		}
430	}
431	return false
432}
433
434func hasVisibleCompatAssistantContent(msg *openaisdk.ChatCompletionAssistantMessageParam) bool {
435	// Check if there's text content
436	if !param.IsOmitted(msg.Content.OfString) || len(msg.Content.OfArrayOfContentParts) > 0 {
437		return true
438	}
439	// Check if there are tool calls
440	if len(msg.ToolCalls) > 0 {
441		return true
442	}
443	return false
444}