1package kronk
2
3import (
4 "encoding/base64"
5 "fmt"
6 "strings"
7
8 "charm.land/fantasy"
9 "github.com/ardanlabs/kronk/sdk/kronk/model"
10)
11
12// LanguageModelPrepareCallFunc is a function that prepares the call for the language model.
13type LanguageModelPrepareCallFunc func(lm fantasy.LanguageModel, d model.D, call fantasy.Call) ([]fantasy.CallWarning, error)
14
15// LanguageModelMapFinishReasonFunc is a function that maps the finish reason for the language model.
16type LanguageModelMapFinishReasonFunc func(finishReason string) fantasy.FinishReason
17
18// LanguageModelToPromptFunc is a function that handles converting fantasy prompts to Kronk SDK messages.
19type LanguageModelToPromptFunc func(prompt fantasy.Prompt, provider, modelID string) ([]model.D, []fantasy.CallWarning)
20
21// DefaultPrepareCallFunc is the default implementation for preparing a call to the language model.
22func DefaultPrepareCallFunc(_ fantasy.LanguageModel, d model.D, call fantasy.Call) ([]fantasy.CallWarning, error) {
23 if call.ProviderOptions == nil {
24 return nil, nil
25 }
26
27 var warnings []fantasy.CallWarning
28 providerOptions := &ProviderOptions{}
29 if v, ok := call.ProviderOptions[Name]; ok {
30 providerOptions, ok = v.(*ProviderOptions)
31 if !ok {
32 return nil, &fantasy.Error{Title: "invalid argument", Message: "kronk provider options should be *kronk.ProviderOptions"}
33 }
34 }
35
36 if providerOptions.TopK != nil {
37 d["top_k"] = *providerOptions.TopK
38 }
39
40 if providerOptions.RepeatPenalty != nil {
41 d["repeat_penalty"] = *providerOptions.RepeatPenalty
42 }
43
44 if providerOptions.Seed != nil {
45 d["seed"] = *providerOptions.Seed
46 }
47
48 if providerOptions.MinP != nil {
49 d["min_p"] = *providerOptions.MinP
50 }
51
52 if providerOptions.NumPredict != nil {
53 d["num_predict"] = *providerOptions.NumPredict
54 }
55
56 if providerOptions.Stop != nil {
57 d["stop"] = providerOptions.Stop
58 }
59
60 return warnings, nil
61}
62
63// DefaultMapFinishReasonFunc is the default implementation for mapping finish reasons.
64func DefaultMapFinishReasonFunc(finishReason string) fantasy.FinishReason {
65 switch finishReason {
66 case string(model.FinishReasonStop):
67 return fantasy.FinishReasonStop
68
69 case string(model.FinishReasonTool):
70 return fantasy.FinishReasonToolCalls
71
72 case string(model.FinishReasonError):
73 return fantasy.FinishReasonError
74
75 default:
76 return fantasy.FinishReasonUnknown
77 }
78}
79
80// DefaultToPrompt is the default implementation for converting fantasy prompts to Kronk SDK messages.
81func DefaultToPrompt(prompt fantasy.Prompt, _ string, _ string) ([]model.D, []fantasy.CallWarning) {
82 var messages []model.D
83 var warnings []fantasy.CallWarning
84
85 for _, msg := range prompt {
86 switch msg.Role {
87 case fantasy.MessageRoleSystem:
88 for _, c := range msg.Content {
89 if c.GetType() == fantasy.ContentTypeText {
90 textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](c)
91 if !ok {
92 warnings = append(warnings, fantasy.CallWarning{
93 Type: fantasy.CallWarningTypeOther,
94 Message: "system message text part does not have the right type",
95 })
96
97 continue
98 }
99
100 messages = append(messages, model.TextMessage(model.RoleSystem, textPart.Text))
101 }
102 }
103
104 case fantasy.MessageRoleUser:
105 var content []model.D
106 for _, c := range msg.Content {
107 switch c.GetType() {
108 case fantasy.ContentTypeText:
109 textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](c)
110 if !ok {
111 warnings = append(warnings, fantasy.CallWarning{
112 Type: fantasy.CallWarningTypeOther,
113 Message: "user message text part does not have the right type",
114 })
115
116 continue
117 }
118
119 content = append(content, model.D{
120 "type": "text",
121 "text": textPart.Text,
122 })
123
124 case fantasy.ContentTypeFile:
125 filePart, ok := fantasy.AsMessagePart[fantasy.FilePart](c)
126 if !ok {
127 warnings = append(warnings, fantasy.CallWarning{
128 Type: fantasy.CallWarningTypeOther,
129 Message: "user message file part does not have the right type",
130 })
131
132 continue
133 }
134
135 switch {
136 case strings.HasPrefix(filePart.MediaType, "image/"):
137 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
138 data := "data:" + filePart.MediaType + ";base64," + base64Encoded
139 content = append(content, model.D{
140 "type": "image_url",
141 "image_url": model.D{
142 "url": data,
143 },
144 })
145
146 default:
147 warnings = append(warnings, fantasy.CallWarning{
148 Type: fantasy.CallWarningTypeOther,
149 Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
150 })
151 }
152 }
153 }
154
155 switch {
156 case len(content) == 1 && content[0]["type"] == "text":
157 messages = append(messages, model.TextMessage(model.RoleUser, content[0]["text"].(string)))
158
159 case len(content) > 0:
160 messages = append(messages, model.D{
161 "role": model.RoleUser,
162 "content": content,
163 })
164 }
165
166 case fantasy.MessageRoleAssistant:
167 var textContent string
168 var toolCalls []model.D
169
170 for _, c := range msg.Content {
171 switch c.GetType() {
172 case fantasy.ContentTypeText:
173 textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](c)
174 if !ok {
175 warnings = append(warnings, fantasy.CallWarning{
176 Type: fantasy.CallWarningTypeOther,
177 Message: "assistant message text part does not have the right type",
178 })
179
180 continue
181 }
182
183 textContent += textPart.Text
184
185 case fantasy.ContentTypeToolCall:
186 toolCallPart, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](c)
187 if !ok {
188 warnings = append(warnings, fantasy.CallWarning{
189 Type: fantasy.CallWarningTypeOther,
190 Message: "assistant message tool part does not have the right type",
191 })
192
193 continue
194 }
195
196 toolCalls = append(toolCalls, model.D{
197 "id": toolCallPart.ToolCallID,
198 "type": "function",
199 "function": model.D{
200 "name": toolCallPart.ToolName,
201 "arguments": toolCallPart.Input,
202 },
203 })
204 }
205 }
206
207 assistantMsg := model.D{
208 "role": model.RoleAssistant,
209 }
210
211 if textContent != "" {
212 assistantMsg["content"] = textContent
213 }
214
215 if len(toolCalls) > 0 {
216 assistantMsg["tool_calls"] = toolCalls
217 }
218
219 if textContent != "" || len(toolCalls) > 0 {
220 messages = append(messages, assistantMsg)
221 }
222
223 case fantasy.MessageRoleTool:
224 for _, c := range msg.Content {
225 if c.GetType() != fantasy.ContentTypeToolResult {
226 warnings = append(warnings, fantasy.CallWarning{
227 Type: fantasy.CallWarningTypeOther,
228 Message: "tool message can only have tool result content",
229 })
230
231 continue
232 }
233
234 toolResultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](c)
235 if !ok {
236 warnings = append(warnings, fantasy.CallWarning{
237 Type: fantasy.CallWarningTypeOther,
238 Message: "tool message result part does not have the right type",
239 })
240
241 continue
242 }
243
244 var resultContent string
245 switch toolResultPart.Output.GetType() {
246 case fantasy.ToolResultContentTypeText:
247 output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](toolResultPart.Output)
248 if ok {
249 resultContent = output.Text
250 }
251
252 case fantasy.ToolResultContentTypeError:
253 output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](toolResultPart.Output)
254 if ok {
255 resultContent = output.Error.Error()
256 }
257 }
258
259 messages = append(messages, model.D{
260 "role": "tool",
261 "content": resultContent,
262 "tool_call_id": toolResultPart.ToolCallID,
263 })
264 }
265 }
266 }
267
268 return messages, warnings
269}