language_model_hooks.go

  1package kronk
  2
  3import (
  4	"encoding/base64"
  5	"fmt"
  6	"strings"
  7
  8	"charm.land/fantasy"
  9	"github.com/ardanlabs/kronk/sdk/kronk/model"
 10)
 11
 12// LanguageModelPrepareCallFunc is a function that prepares the call for the language model.
 13type LanguageModelPrepareCallFunc func(lm fantasy.LanguageModel, d model.D, call fantasy.Call) ([]fantasy.CallWarning, error)
 14
 15// LanguageModelMapFinishReasonFunc is a function that maps the finish reason for the language model.
 16type LanguageModelMapFinishReasonFunc func(finishReason string) fantasy.FinishReason
 17
 18// LanguageModelToPromptFunc is a function that handles converting fantasy prompts to Kronk SDK messages.
 19type LanguageModelToPromptFunc func(prompt fantasy.Prompt, provider, modelID string) ([]model.D, []fantasy.CallWarning)
 20
 21// DefaultPrepareCallFunc is the default implementation for preparing a call to the language model.
 22func DefaultPrepareCallFunc(_ fantasy.LanguageModel, d model.D, call fantasy.Call) ([]fantasy.CallWarning, error) {
 23	if call.ProviderOptions == nil {
 24		return nil, nil
 25	}
 26
 27	var warnings []fantasy.CallWarning
 28	providerOptions := &ProviderOptions{}
 29	if v, ok := call.ProviderOptions[Name]; ok {
 30		providerOptions, ok = v.(*ProviderOptions)
 31		if !ok {
 32			return nil, &fantasy.Error{Title: "invalid argument", Message: "kronk provider options should be *kronk.ProviderOptions"}
 33		}
 34	}
 35
 36	if providerOptions.TopK != nil {
 37		d["top_k"] = *providerOptions.TopK
 38	}
 39
 40	if providerOptions.RepeatPenalty != nil {
 41		d["repeat_penalty"] = *providerOptions.RepeatPenalty
 42	}
 43
 44	if providerOptions.Seed != nil {
 45		d["seed"] = *providerOptions.Seed
 46	}
 47
 48	if providerOptions.MinP != nil {
 49		d["min_p"] = *providerOptions.MinP
 50	}
 51
 52	if providerOptions.NumPredict != nil {
 53		d["num_predict"] = *providerOptions.NumPredict
 54	}
 55
 56	if providerOptions.Stop != nil {
 57		d["stop"] = providerOptions.Stop
 58	}
 59
 60	return warnings, nil
 61}
 62
 63// DefaultMapFinishReasonFunc is the default implementation for mapping finish reasons.
 64func DefaultMapFinishReasonFunc(finishReason string) fantasy.FinishReason {
 65	switch finishReason {
 66	case string(model.FinishReasonStop):
 67		return fantasy.FinishReasonStop
 68
 69	case string(model.FinishReasonTool):
 70		return fantasy.FinishReasonToolCalls
 71
 72	case string(model.FinishReasonError):
 73		return fantasy.FinishReasonError
 74
 75	default:
 76		return fantasy.FinishReasonUnknown
 77	}
 78}
 79
 80// DefaultToPrompt is the default implementation for converting fantasy prompts to Kronk SDK messages.
 81func DefaultToPrompt(prompt fantasy.Prompt, _ string, _ string) ([]model.D, []fantasy.CallWarning) {
 82	var messages []model.D
 83	var warnings []fantasy.CallWarning
 84
 85	for _, msg := range prompt {
 86		switch msg.Role {
 87		case fantasy.MessageRoleSystem:
 88			for _, c := range msg.Content {
 89				if c.GetType() == fantasy.ContentTypeText {
 90					textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](c)
 91					if !ok {
 92						warnings = append(warnings, fantasy.CallWarning{
 93							Type:    fantasy.CallWarningTypeOther,
 94							Message: "system message text part does not have the right type",
 95						})
 96
 97						continue
 98					}
 99
100					messages = append(messages, model.TextMessage(model.RoleSystem, textPart.Text))
101				}
102			}
103
104		case fantasy.MessageRoleUser:
105			var content []model.D
106			for _, c := range msg.Content {
107				switch c.GetType() {
108				case fantasy.ContentTypeText:
109					textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](c)
110					if !ok {
111						warnings = append(warnings, fantasy.CallWarning{
112							Type:    fantasy.CallWarningTypeOther,
113							Message: "user message text part does not have the right type",
114						})
115
116						continue
117					}
118
119					content = append(content, model.D{
120						"type": "text",
121						"text": textPart.Text,
122					})
123
124				case fantasy.ContentTypeFile:
125					filePart, ok := fantasy.AsMessagePart[fantasy.FilePart](c)
126					if !ok {
127						warnings = append(warnings, fantasy.CallWarning{
128							Type:    fantasy.CallWarningTypeOther,
129							Message: "user message file part does not have the right type",
130						})
131
132						continue
133					}
134
135					switch {
136					case strings.HasPrefix(filePart.MediaType, "image/"):
137						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
138						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
139						content = append(content, model.D{
140							"type": "image_url",
141							"image_url": model.D{
142								"url": data,
143							},
144						})
145
146					default:
147						warnings = append(warnings, fantasy.CallWarning{
148							Type:    fantasy.CallWarningTypeOther,
149							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
150						})
151					}
152				}
153			}
154
155			switch {
156			case len(content) == 1 && content[0]["type"] == "text":
157				messages = append(messages, model.TextMessage(model.RoleUser, content[0]["text"].(string)))
158
159			case len(content) > 0:
160				messages = append(messages, model.D{
161					"role":    model.RoleUser,
162					"content": content,
163				})
164			}
165
166		case fantasy.MessageRoleAssistant:
167			var textContent string
168			var toolCalls []model.D
169
170			for _, c := range msg.Content {
171				switch c.GetType() {
172				case fantasy.ContentTypeText:
173					textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](c)
174					if !ok {
175						warnings = append(warnings, fantasy.CallWarning{
176							Type:    fantasy.CallWarningTypeOther,
177							Message: "assistant message text part does not have the right type",
178						})
179
180						continue
181					}
182
183					textContent += textPart.Text
184
185				case fantasy.ContentTypeToolCall:
186					toolCallPart, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](c)
187					if !ok {
188						warnings = append(warnings, fantasy.CallWarning{
189							Type:    fantasy.CallWarningTypeOther,
190							Message: "assistant message tool part does not have the right type",
191						})
192
193						continue
194					}
195
196					toolCalls = append(toolCalls, model.D{
197						"id":   toolCallPart.ToolCallID,
198						"type": "function",
199						"function": model.D{
200							"name":      toolCallPart.ToolName,
201							"arguments": toolCallPart.Input,
202						},
203					})
204				}
205			}
206
207			assistantMsg := model.D{
208				"role": model.RoleAssistant,
209			}
210
211			if textContent != "" {
212				assistantMsg["content"] = textContent
213			}
214
215			if len(toolCalls) > 0 {
216				assistantMsg["tool_calls"] = toolCalls
217			}
218
219			if textContent != "" || len(toolCalls) > 0 {
220				messages = append(messages, assistantMsg)
221			}
222
223		case fantasy.MessageRoleTool:
224			for _, c := range msg.Content {
225				if c.GetType() != fantasy.ContentTypeToolResult {
226					warnings = append(warnings, fantasy.CallWarning{
227						Type:    fantasy.CallWarningTypeOther,
228						Message: "tool message can only have tool result content",
229					})
230
231					continue
232				}
233
234				toolResultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](c)
235				if !ok {
236					warnings = append(warnings, fantasy.CallWarning{
237						Type:    fantasy.CallWarningTypeOther,
238						Message: "tool message result part does not have the right type",
239					})
240
241					continue
242				}
243
244				var resultContent string
245				switch toolResultPart.Output.GetType() {
246				case fantasy.ToolResultContentTypeText:
247					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](toolResultPart.Output)
248					if ok {
249						resultContent = output.Text
250					}
251
252				case fantasy.ToolResultContentTypeError:
253					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](toolResultPart.Output)
254					if ok {
255						resultContent = output.Error.Error()
256					}
257				}
258
259				messages = append(messages, model.D{
260					"role":         "tool",
261					"content":      resultContent,
262					"tool_call_id": toolResultPart.ToolCallID,
263				})
264			}
265		}
266	}
267
268	return messages, warnings
269}