language_model_hooks.go

  1package openaicompat
  2
  3import (
  4	"encoding/base64"
  5	"encoding/json"
  6	"fmt"
  7	"strings"
  8
  9	"charm.land/fantasy"
 10	"charm.land/fantasy/providers/openai"
 11	openaisdk "github.com/openai/openai-go/v2"
 12	"github.com/openai/openai-go/v2/packages/param"
 13	"github.com/openai/openai-go/v2/shared"
 14)
 15
 16const reasoningStartedCtx = "reasoning_started"
 17
 18// PrepareCallFunc prepares the call for the language model.
 19func PrepareCallFunc(_ fantasy.LanguageModel, params *openaisdk.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) {
 20	providerOptions := &ProviderOptions{}
 21	if v, ok := call.ProviderOptions[Name]; ok {
 22		providerOptions, ok = v.(*ProviderOptions)
 23		if !ok {
 24			return nil, &fantasy.Error{Title: "invalid argument", Message: "openai-compat provider options should be *openaicompat.ProviderOptions"}
 25		}
 26	}
 27
 28	if providerOptions.ReasoningEffort != nil {
 29		switch *providerOptions.ReasoningEffort {
 30		case openai.ReasoningEffortMinimal:
 31			params.ReasoningEffort = shared.ReasoningEffortMinimal
 32		case openai.ReasoningEffortLow:
 33			params.ReasoningEffort = shared.ReasoningEffortLow
 34		case openai.ReasoningEffortMedium:
 35			params.ReasoningEffort = shared.ReasoningEffortMedium
 36		case openai.ReasoningEffortHigh:
 37			params.ReasoningEffort = shared.ReasoningEffortHigh
 38		default:
 39			return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
 40		}
 41	}
 42
 43	if providerOptions.User != nil {
 44		params.User = param.NewOpt(*providerOptions.User)
 45	}
 46	return nil, nil
 47}
 48
 49// ExtraContentFunc adds extra content to the response.
 50func ExtraContentFunc(choice openaisdk.ChatCompletionChoice) []fantasy.Content {
 51	var content []fantasy.Content
 52	reasoningData := ReasoningData{}
 53	err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
 54	if err != nil {
 55		return content
 56	}
 57	if reasoningData.ReasoningContent != "" {
 58		content = append(content, fantasy.ReasoningContent{
 59			Text: reasoningData.ReasoningContent,
 60		})
 61	}
 62	return content
 63}
 64
 65func extractReasoningContext(ctx map[string]any) bool {
 66	reasoningStarted, ok := ctx[reasoningStartedCtx]
 67	if !ok {
 68		return false
 69	}
 70	b, ok := reasoningStarted.(bool)
 71	if !ok {
 72		return false
 73	}
 74	return b
 75}
 76
 77// StreamExtraFunc handles extra functionality for streaming responses.
 78func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
 79	if len(chunk.Choices) == 0 {
 80		return ctx, true
 81	}
 82
 83	reasoningStarted := extractReasoningContext(ctx)
 84
 85	for inx, choice := range chunk.Choices {
 86		reasoningData := ReasoningData{}
 87		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
 88		if err != nil {
 89			yield(fantasy.StreamPart{
 90				Type:  fantasy.StreamPartTypeError,
 91				Error: &fantasy.Error{Title: "stream error", Message: "error unmarshalling delta", Cause: err},
 92			})
 93			return ctx, false
 94		}
 95
 96		emitEvent := func(reasoningContent string) bool {
 97			if !reasoningStarted {
 98				shouldContinue := yield(fantasy.StreamPart{
 99					Type: fantasy.StreamPartTypeReasoningStart,
100					ID:   fmt.Sprintf("%d", inx),
101				})
102				if !shouldContinue {
103					return false
104				}
105			}
106
107			return yield(fantasy.StreamPart{
108				Type:  fantasy.StreamPartTypeReasoningDelta,
109				ID:    fmt.Sprintf("%d", inx),
110				Delta: reasoningContent,
111			})
112		}
113		if reasoningData.ReasoningContent != "" {
114			if !reasoningStarted {
115				ctx[reasoningStartedCtx] = true
116			}
117			return ctx, emitEvent(reasoningData.ReasoningContent)
118		}
119		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
120			ctx[reasoningStartedCtx] = false
121			return ctx, yield(fantasy.StreamPart{
122				Type: fantasy.StreamPartTypeReasoningEnd,
123				ID:   fmt.Sprintf("%d", inx),
124			})
125		}
126	}
127	return ctx, true
128}
129
130// ToPromptFunc converts a fantasy prompt to OpenAI format with reasoning support.
131// It handles fantasy.ContentTypeReasoning in assistant messages by adding the
132// reasoning_content field to the message JSON.
133func ToPromptFunc(prompt fantasy.Prompt, _, _ string) ([]openaisdk.ChatCompletionMessageParamUnion, []fantasy.CallWarning) {
134	var messages []openaisdk.ChatCompletionMessageParamUnion
135	var warnings []fantasy.CallWarning
136	for _, msg := range prompt {
137		switch msg.Role {
138		case fantasy.MessageRoleSystem:
139			var systemPromptParts []string
140			for _, c := range msg.Content {
141				if c.GetType() != fantasy.ContentTypeText {
142					warnings = append(warnings, fantasy.CallWarning{
143						Type:    fantasy.CallWarningTypeOther,
144						Message: "system prompt can only have text content",
145					})
146					continue
147				}
148				textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
149				if !ok {
150					warnings = append(warnings, fantasy.CallWarning{
151						Type:    fantasy.CallWarningTypeOther,
152						Message: "system prompt text part does not have the right type",
153					})
154					continue
155				}
156				text := textPart.Text
157				if strings.TrimSpace(text) != "" {
158					systemPromptParts = append(systemPromptParts, textPart.Text)
159				}
160			}
161			if len(systemPromptParts) == 0 {
162				warnings = append(warnings, fantasy.CallWarning{
163					Type:    fantasy.CallWarningTypeOther,
164					Message: "system prompt has no text parts",
165				})
166				continue
167			}
168			messages = append(messages, openaisdk.SystemMessage(strings.Join(systemPromptParts, "\n")))
169		case fantasy.MessageRoleUser:
170			// simple user message just text content
171			if len(msg.Content) == 1 && msg.Content[0].GetType() == fantasy.ContentTypeText {
172				textPart, ok := fantasy.AsContentType[fantasy.TextPart](msg.Content[0])
173				if !ok {
174					warnings = append(warnings, fantasy.CallWarning{
175						Type:    fantasy.CallWarningTypeOther,
176						Message: "user message text part does not have the right type",
177					})
178					continue
179				}
180				messages = append(messages, openaisdk.UserMessage(textPart.Text))
181				continue
182			}
183			// text content and attachments
184			var content []openaisdk.ChatCompletionContentPartUnionParam
185			for _, c := range msg.Content {
186				switch c.GetType() {
187				case fantasy.ContentTypeText:
188					textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
189					if !ok {
190						warnings = append(warnings, fantasy.CallWarning{
191							Type:    fantasy.CallWarningTypeOther,
192							Message: "user message text part does not have the right type",
193						})
194						continue
195					}
196					content = append(content, openaisdk.ChatCompletionContentPartUnionParam{
197						OfText: &openaisdk.ChatCompletionContentPartTextParam{
198							Text: textPart.Text,
199						},
200					})
201				case fantasy.ContentTypeFile:
202					filePart, ok := fantasy.AsContentType[fantasy.FilePart](c)
203					if !ok {
204						warnings = append(warnings, fantasy.CallWarning{
205							Type:    fantasy.CallWarningTypeOther,
206							Message: "user message file part does not have the right type",
207						})
208						continue
209					}
210
211					switch {
212					case strings.HasPrefix(filePart.MediaType, "image/"):
213						// Handle image files
214						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
215						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
216						imageURL := openaisdk.ChatCompletionContentPartImageImageURLParam{URL: data}
217
218						// Check for provider-specific options like image detail
219						if providerOptions, ok := filePart.ProviderOptions[openai.Name]; ok {
220							if detail, ok := providerOptions.(*openai.ProviderFileOptions); ok {
221								imageURL.Detail = detail.ImageDetail
222							}
223						}
224
225						imageBlock := openaisdk.ChatCompletionContentPartImageParam{ImageURL: imageURL}
226						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
227
228					case filePart.MediaType == "audio/wav":
229						// Handle WAV audio files
230						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
231						audioBlock := openaisdk.ChatCompletionContentPartInputAudioParam{
232							InputAudio: openaisdk.ChatCompletionContentPartInputAudioInputAudioParam{
233								Data:   base64Encoded,
234								Format: "wav",
235							},
236						}
237						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
238
239					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
240						// Handle MP3 audio files
241						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
242						audioBlock := openaisdk.ChatCompletionContentPartInputAudioParam{
243							InputAudio: openaisdk.ChatCompletionContentPartInputAudioInputAudioParam{
244								Data:   base64Encoded,
245								Format: "mp3",
246							},
247						}
248						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
249
250					case filePart.MediaType == "application/pdf":
251						// Handle PDF files
252						dataStr := string(filePart.Data)
253
254						// Check if data looks like a file ID (starts with "file-")
255						if strings.HasPrefix(dataStr, "file-") {
256							fileBlock := openaisdk.ChatCompletionContentPartFileParam{
257								File: openaisdk.ChatCompletionContentPartFileFileParam{
258									FileID: param.NewOpt(dataStr),
259								},
260							}
261							content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
262						} else {
263							// Handle as base64 data
264							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
265							data := "data:application/pdf;base64," + base64Encoded
266
267							filename := filePart.Filename
268							if filename == "" {
269								// Generate default filename based on content index
270								filename = fmt.Sprintf("part-%d.pdf", len(content))
271							}
272
273							fileBlock := openaisdk.ChatCompletionContentPartFileParam{
274								File: openaisdk.ChatCompletionContentPartFileFileParam{
275									Filename: param.NewOpt(filename),
276									FileData: param.NewOpt(data),
277								},
278							}
279							content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
280						}
281
282					default:
283						warnings = append(warnings, fantasy.CallWarning{
284							Type:    fantasy.CallWarningTypeOther,
285							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
286						})
287					}
288				}
289			}
290			messages = append(messages, openaisdk.UserMessage(content))
291		case fantasy.MessageRoleAssistant:
292			// simple assistant message just text content
293			if len(msg.Content) == 1 && msg.Content[0].GetType() == fantasy.ContentTypeText {
294				textPart, ok := fantasy.AsContentType[fantasy.TextPart](msg.Content[0])
295				if !ok {
296					warnings = append(warnings, fantasy.CallWarning{
297						Type:    fantasy.CallWarningTypeOther,
298						Message: "assistant message text part does not have the right type",
299					})
300					continue
301				}
302				messages = append(messages, openaisdk.AssistantMessage(textPart.Text))
303				continue
304			}
305			assistantMsg := openaisdk.ChatCompletionAssistantMessageParam{
306				Role: "assistant",
307			}
308			var reasoningText string
309			for _, c := range msg.Content {
310				switch c.GetType() {
311				case fantasy.ContentTypeText:
312					textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
313					if !ok {
314						warnings = append(warnings, fantasy.CallWarning{
315							Type:    fantasy.CallWarningTypeOther,
316							Message: "assistant message text part does not have the right type",
317						})
318						continue
319					}
320					assistantMsg.Content = openaisdk.ChatCompletionAssistantMessageParamContentUnion{
321						OfString: param.NewOpt(textPart.Text),
322					}
323				case fantasy.ContentTypeReasoning:
324					reasoningPart, ok := fantasy.AsContentType[fantasy.ReasoningPart](c)
325					if !ok {
326						warnings = append(warnings, fantasy.CallWarning{
327							Type:    fantasy.CallWarningTypeOther,
328							Message: "assistant message reasoning part does not have the right type",
329						})
330						continue
331					}
332					reasoningText = reasoningPart.Text
333				case fantasy.ContentTypeToolCall:
334					toolCallPart, ok := fantasy.AsContentType[fantasy.ToolCallPart](c)
335					if !ok {
336						warnings = append(warnings, fantasy.CallWarning{
337							Type:    fantasy.CallWarningTypeOther,
338							Message: "assistant message tool part does not have the right type",
339						})
340						continue
341					}
342					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
343						openaisdk.ChatCompletionMessageToolCallUnionParam{
344							OfFunction: &openaisdk.ChatCompletionMessageFunctionToolCallParam{
345								ID:   toolCallPart.ToolCallID,
346								Type: "function",
347								Function: openaisdk.ChatCompletionMessageFunctionToolCallFunctionParam{
348									Name:      toolCallPart.ToolName,
349									Arguments: toolCallPart.Input,
350								},
351							},
352						})
353				}
354			}
355			// Add reasoning_content field if present
356			if reasoningText != "" {
357				assistantMsg.SetExtraFields(map[string]any{
358					"reasoning_content": reasoningText,
359				})
360			}
361			messages = append(messages, openaisdk.ChatCompletionMessageParamUnion{
362				OfAssistant: &assistantMsg,
363			})
364		case fantasy.MessageRoleTool:
365			for _, c := range msg.Content {
366				if c.GetType() != fantasy.ContentTypeToolResult {
367					warnings = append(warnings, fantasy.CallWarning{
368						Type:    fantasy.CallWarningTypeOther,
369						Message: "tool message can only have tool result content",
370					})
371					continue
372				}
373
374				toolResultPart, ok := fantasy.AsContentType[fantasy.ToolResultPart](c)
375				if !ok {
376					warnings = append(warnings, fantasy.CallWarning{
377						Type:    fantasy.CallWarningTypeOther,
378						Message: "tool message result part does not have the right type",
379					})
380					continue
381				}
382
383				switch toolResultPart.Output.GetType() {
384				case fantasy.ToolResultContentTypeText:
385					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](toolResultPart.Output)
386					if !ok {
387						warnings = append(warnings, fantasy.CallWarning{
388							Type:    fantasy.CallWarningTypeOther,
389							Message: "tool result output does not have the right type",
390						})
391						continue
392					}
393					messages = append(messages, openaisdk.ToolMessage(output.Text, toolResultPart.ToolCallID))
394				case fantasy.ToolResultContentTypeError:
395					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](toolResultPart.Output)
396					if !ok {
397						warnings = append(warnings, fantasy.CallWarning{
398							Type:    fantasy.CallWarningTypeOther,
399							Message: "tool result output does not have the right type",
400						})
401						continue
402					}
403					messages = append(messages, openaisdk.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
404				}
405			}
406		}
407	}
408	return messages, warnings
409}