language_model_hooks.go

  1package openaicompat
  2
  3import (
  4	"encoding/base64"
  5	"encoding/json"
  6	"fmt"
  7	"strings"
  8
  9	"charm.land/fantasy"
 10	"charm.land/fantasy/providers/openai"
 11	openaisdk "github.com/openai/openai-go/v2"
 12	"github.com/openai/openai-go/v2/packages/param"
 13	"github.com/openai/openai-go/v2/shared"
 14)
 15
 16const reasoningStartedCtx = "reasoning_started"
 17
 18// PrepareCallFunc prepares the call for the language model.
 19func PrepareCallFunc(_ fantasy.LanguageModel, params *openaisdk.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) {
 20	providerOptions := &ProviderOptions{}
 21	if v, ok := call.ProviderOptions[Name]; ok {
 22		providerOptions, ok = v.(*ProviderOptions)
 23		if !ok {
 24			return nil, &fantasy.Error{Title: "invalid argument", Message: "openai-compat provider options should be *openaicompat.ProviderOptions"}
 25		}
 26	}
 27
 28	if providerOptions.ReasoningEffort != nil {
 29		switch *providerOptions.ReasoningEffort {
 30		case openai.ReasoningEffortMinimal:
 31			params.ReasoningEffort = shared.ReasoningEffortMinimal
 32		case openai.ReasoningEffortLow:
 33			params.ReasoningEffort = shared.ReasoningEffortLow
 34		case openai.ReasoningEffortMedium:
 35			params.ReasoningEffort = shared.ReasoningEffortMedium
 36		case openai.ReasoningEffortHigh:
 37			params.ReasoningEffort = shared.ReasoningEffortHigh
 38		default:
 39			return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
 40		}
 41	}
 42
 43	if providerOptions.User != nil {
 44		params.User = param.NewOpt(*providerOptions.User)
 45	}
 46	return nil, nil
 47}
 48
 49// ExtraContentFunc adds extra content to the response.
 50func ExtraContentFunc(choice openaisdk.ChatCompletionChoice) []fantasy.Content {
 51	var content []fantasy.Content
 52	reasoningData := ReasoningData{}
 53	err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
 54	if err != nil {
 55		return content
 56	}
 57	if reasoningData.ReasoningContent != "" {
 58		content = append(content, fantasy.ReasoningContent{
 59			Text: reasoningData.ReasoningContent,
 60		})
 61	}
 62	return content
 63}
 64
 65func extractReasoningContext(ctx map[string]any) bool {
 66	reasoningStarted, ok := ctx[reasoningStartedCtx]
 67	if !ok {
 68		return false
 69	}
 70	b, ok := reasoningStarted.(bool)
 71	if !ok {
 72		return false
 73	}
 74	return b
 75}
 76
 77// StreamExtraFunc handles extra functionality for streaming responses.
 78func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
 79	if len(chunk.Choices) == 0 {
 80		return ctx, true
 81	}
 82
 83	reasoningStarted := extractReasoningContext(ctx)
 84
 85	for inx, choice := range chunk.Choices {
 86		reasoningData := ReasoningData{}
 87		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
 88		if err != nil {
 89			yield(fantasy.StreamPart{
 90				Type:  fantasy.StreamPartTypeError,
 91				Error: &fantasy.Error{Title: "stream error", Message: "error unmarshalling delta", Cause: err},
 92			})
 93			return ctx, false
 94		}
 95
 96		emitEvent := func(reasoningContent string) bool {
 97			if !reasoningStarted {
 98				shouldContinue := yield(fantasy.StreamPart{
 99					Type: fantasy.StreamPartTypeReasoningStart,
100					ID:   fmt.Sprintf("%d", inx),
101				})
102				if !shouldContinue {
103					return false
104				}
105			}
106
107			return yield(fantasy.StreamPart{
108				Type:  fantasy.StreamPartTypeReasoningDelta,
109				ID:    fmt.Sprintf("%d", inx),
110				Delta: reasoningContent,
111			})
112		}
113		if reasoningData.ReasoningContent != "" {
114			if !reasoningStarted {
115				ctx[reasoningStartedCtx] = true
116			}
117			return ctx, emitEvent(reasoningData.ReasoningContent)
118		}
119		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
120			ctx[reasoningStartedCtx] = false
121			return ctx, yield(fantasy.StreamPart{
122				Type: fantasy.StreamPartTypeReasoningEnd,
123				ID:   fmt.Sprintf("%d", inx),
124			})
125		}
126	}
127	return ctx, true
128}
129
// ToPromptFunc converts a fantasy prompt to OpenAI format with reasoning support.
//
// Conversion per role:
//   - system: non-blank text parts are joined with "\n" into one system message;
//   - user: a single text part becomes a plain text message, otherwise the parts
//     become content parts (text, images, WAV/MP3 audio, PDFs, text/* files);
//   - assistant: all text parts are concatenated into the message content, tool
//     calls are attached, and all fantasy.ContentTypeReasoning parts are
//     concatenated into a reasoning_content extra field on the message JSON;
//   - tool: each tool result becomes one OpenAI tool message.
//
// Parts that cannot be converted are skipped and reported as CallWarnings
// instead of failing the conversion. The two string parameters are unused.
func ToPromptFunc(prompt fantasy.Prompt, _, _ string) ([]openaisdk.ChatCompletionMessageParamUnion, []fantasy.CallWarning) {
	var messages []openaisdk.ChatCompletionMessageParamUnion
	var warnings []fantasy.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case fantasy.MessageRoleSystem:
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != fantasy.ContentTypeText {
					warnings = append(warnings, fantasy.CallWarning{
						Type:    fantasy.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
				if !ok {
					warnings = append(warnings, fantasy.CallWarning{
						Type:    fantasy.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				// Whitespace-only parts are dropped so they add no blank lines.
				if strings.TrimSpace(textPart.Text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, fantasy.CallWarning{
					Type:    fantasy.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openaisdk.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case fantasy.MessageRoleUser:
			// simple user message just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == fantasy.ContentTypeText {
				textPart, ok := fantasy.AsContentType[fantasy.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, fantasy.CallWarning{
						Type:    fantasy.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openaisdk.UserMessage(textPart.Text))
				continue
			}
			// text content and attachments
			var content []openaisdk.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case fantasy.ContentTypeText:
					textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
					if !ok {
						warnings = append(warnings, fantasy.CallWarning{
							Type:    fantasy.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openaisdk.ChatCompletionContentPartUnionParam{
						OfText: &openaisdk.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case fantasy.ContentTypeFile:
					filePart, ok := fantasy.AsContentType[fantasy.FilePart](c)
					if !ok {
						warnings = append(warnings, fantasy.CallWarning{
							Type:    fantasy.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "text/"):
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						content = append(content, openaisdk.FileContentPart(openaisdk.ChatCompletionContentPartFileFileParam{
							FileData: param.NewOpt(base64Encoded),
						}))
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Images are sent inline as a base64 data URL.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openaisdk.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Honor the OpenAI provider's per-file image detail option when present.
						if providerOptions, ok := filePart.ProviderOptions[openai.Name]; ok {
							if detail, ok := providerOptions.(*openai.ProviderFileOptions); ok {
								imageURL.Detail = detail.ImageDetail
							}
						}

						imageBlock := openaisdk.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openaisdk.ChatCompletionContentPartInputAudioParam{
							InputAudio: openaisdk.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openaisdk.ChatCompletionContentPartInputAudioParam{
							InputAudio: openaisdk.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						dataStr := string(filePart.Data)

						// Data starting with "file-" is treated as an uploaded file ID
						// rather than raw PDF bytes.
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openaisdk.ChatCompletionContentPartFileParam{
								File: openaisdk.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Default filename derived from the part's position in content.
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openaisdk.ChatCompletionContentPartFileParam{
								File: openaisdk.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openaisdk.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, fantasy.CallWarning{
							Type:    fantasy.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openaisdk.UserMessage(content))
		case fantasy.MessageRoleAssistant:
			// simple assistant message just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == fantasy.ContentTypeText {
				textPart, ok := fantasy.AsContentType[fantasy.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, fantasy.CallWarning{
						Type:    fantasy.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openaisdk.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openaisdk.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			// Accumulate every text and reasoning part; assigning per part
			// would keep only the last one and silently drop the others.
			var text, reasoning strings.Builder
			hasText := false
			for _, c := range msg.Content {
				switch c.GetType() {
				case fantasy.ContentTypeText:
					textPart, ok := fantasy.AsContentType[fantasy.TextPart](c)
					if !ok {
						warnings = append(warnings, fantasy.CallWarning{
							Type:    fantasy.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					text.WriteString(textPart.Text)
					hasText = true
				case fantasy.ContentTypeReasoning:
					reasoningPart, ok := fantasy.AsContentType[fantasy.ReasoningPart](c)
					if !ok {
						warnings = append(warnings, fantasy.CallWarning{
							Type:    fantasy.CallWarningTypeOther,
							Message: "assistant message reasoning part does not have the right type",
						})
						continue
					}
					reasoning.WriteString(reasoningPart.Text)
				case fantasy.ContentTypeToolCall:
					toolCallPart, ok := fantasy.AsContentType[fantasy.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, fantasy.CallWarning{
							Type:    fantasy.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openaisdk.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openaisdk.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openaisdk.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			// Content is only set when at least one text part exists, so
			// tool-call-only messages leave the field unset.
			if hasText {
				assistantMsg.Content = openaisdk.ChatCompletionAssistantMessageParamContentUnion{
					OfString: param.NewOpt(text.String()),
				}
			}
			// Add reasoning_content field if present
			if reasoningText := reasoning.String(); reasoningText != "" {
				assistantMsg.SetExtraFields(map[string]any{
					"reasoning_content": reasoningText,
				})
			}
			messages = append(messages, openaisdk.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case fantasy.MessageRoleTool:
			for _, c := range msg.Content {
				if c.GetType() != fantasy.ContentTypeToolResult {
					warnings = append(warnings, fantasy.CallWarning{
						Type:    fantasy.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := fantasy.AsContentType[fantasy.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, fantasy.CallWarning{
						Type:    fantasy.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case fantasy.ToolResultContentTypeText:
					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, fantasy.CallWarning{
							Type:    fantasy.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openaisdk.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case fantasy.ToolResultContentTypeError:
					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, fantasy.CallWarning{
							Type:    fantasy.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					// A nil error (e.g. lost in serialization) must not panic;
					// send an empty error text instead.
					errText := ""
					if output.Error != nil {
						errText = output.Error.Error()
					}
					messages = append(messages, openaisdk.ToolMessage(errText, toolResultPart.ToolCallID))
				default:
					// Previously these results vanished without a trace; surface
					// them so callers know the tool call has no response.
					warnings = append(warnings, fantasy.CallWarning{
						Type:    fantasy.CallWarningTypeOther,
						Message: fmt.Sprintf("tool result output type %v not supported", toolResultPart.Output.GetType()),
					})
				}
			}
		}
	}
	return messages, warnings
}