1package kronk
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "io"
8
9 "charm.land/fantasy"
10 "charm.land/fantasy/object"
11 "github.com/ardanlabs/kronk/sdk/kronk"
12 "github.com/ardanlabs/kronk/sdk/kronk/model"
13 xjson "github.com/charmbracelet/x/json"
14 "github.com/google/uuid"
15)
16
// languageModel holds the state used to serve fantasy.LanguageModel calls
// through a Kronk instance.
type languageModel struct {
	// provider is the name returned by Provider and passed to toPromptFunc.
	provider string

	// modelID is the identifier returned by Model and passed to toPromptFunc.
	modelID string

	// kronk is the handle whose ChatStreaming method performs every request.
	kronk *kronk.Kronk

	// objectMode selects the strategy used by GenerateObject and StreamObject.
	objectMode fantasy.ObjectMode

	// prepareCallFunc is invoked by prepareDocument with the request document
	// and the call; it may return extra warnings or abort with an error.
	prepareCallFunc LanguageModelPrepareCallFunc

	// mapFinishReasonFunc converts the finish reason reported by Kronk into
	// the value placed in fantasy responses.
	mapFinishReasonFunc LanguageModelMapFinishReasonFunc

	// toPromptFunc converts call.Prompt into the "messages" entry of the
	// request document and reports conversion warnings.
	toPromptFunc LanguageModelToPromptFunc
}
26
// LanguageModelOption is a function that configures a languageModel.
// newLanguageModel applies options in the order given, after the defaults
// have been set, so an option overrides the corresponding default.
type LanguageModelOption func(*languageModel)
29
30// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
31func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
32 return func(l *languageModel) {
33 l.prepareCallFunc = fn
34 }
35}
36
37// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
38func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
39 return func(l *languageModel) {
40 l.mapFinishReasonFunc = fn
41 }
42}
43
44// WithLanguageModelToPromptFunc sets the to prompt function for the language model.
45func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
46 return func(l *languageModel) {
47 l.toPromptFunc = fn
48 }
49}
50
51// WithLanguageModelObjectMode sets the object generation mode.
52func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
53 return func(l *languageModel) {
54 l.objectMode = om
55 }
56}
57
58func newLanguageModel(modelID string, provider string, krn *kronk.Kronk, opts ...LanguageModelOption) *languageModel {
59 lm := languageModel{
60 modelID: modelID,
61 provider: provider,
62 kronk: krn,
63 objectMode: fantasy.ObjectModeAuto,
64 prepareCallFunc: DefaultPrepareCallFunc,
65 mapFinishReasonFunc: DefaultMapFinishReasonFunc,
66 toPromptFunc: DefaultToPrompt,
67 }
68
69 for _, o := range opts {
70 o(&lm)
71 }
72
73 return &lm
74}
75
// streamToolCall records one tool call that Stream has started emitting.
type streamToolCall struct {
	// id is the tool call ID used on every stream part for this call.
	id string

	// name is the tool (function) name.
	name string

	// arguments is the JSON-encoded argument text gathered so far.
	arguments string

	// hasFinished is set once the tool-call part for this call has been
	// yielded.
	hasFinished bool
}
82
// Model implements fantasy.LanguageModel. It returns the model ID the
// languageModel was constructed with.
func (l *languageModel) Model() string {
	return l.modelID
}
87
// Provider implements fantasy.LanguageModel. It returns the provider name the
// languageModel was constructed with.
func (l *languageModel) Provider() string {
	return l.provider
}
92
93func (l *languageModel) prepareDocument(call fantasy.Call) (model.D, []fantasy.CallWarning, error) {
94 messages, warnings := l.toPromptFunc(call.Prompt, l.provider, l.modelID)
95
96 if call.TopK != nil {
97 warnings = append(warnings, fantasy.CallWarning{
98 Type: fantasy.CallWarningTypeUnsupportedSetting,
99 Setting: "top_k",
100 })
101 }
102
103 d := model.D{
104 "messages": messages,
105 }
106
107 if call.MaxOutputTokens != nil {
108 d["max_tokens"] = *call.MaxOutputTokens
109 }
110
111 if call.Temperature != nil {
112 d["temperature"] = *call.Temperature
113 }
114
115 if call.TopP != nil {
116 d["top_p"] = *call.TopP
117 }
118
119 if call.FrequencyPenalty != nil {
120 warnings = append(warnings, fantasy.CallWarning{
121 Type: fantasy.CallWarningTypeUnsupportedSetting,
122 Setting: "frequency_penalty",
123 Details: "frequency_penalty is not supported by Kronk",
124 })
125 }
126
127 if call.PresencePenalty != nil {
128 warnings = append(warnings, fantasy.CallWarning{
129 Type: fantasy.CallWarningTypeUnsupportedSetting,
130 Setting: "presence_penalty",
131 Details: "presence_penalty is not supported by Kronk",
132 })
133 }
134
135 optionsWarnings, err := l.prepareCallFunc(l, d, call)
136 if err != nil {
137 return nil, nil, err
138 }
139
140 if len(optionsWarnings) > 0 {
141 warnings = append(warnings, optionsWarnings...)
142 }
143
144 if len(call.Tools) > 0 {
145 tools, toolWarnings := toKronkTools(call.Tools)
146 d["tools"] = tools
147 warnings = append(warnings, toolWarnings...)
148 }
149
150 return d, warnings, nil
151}
152
153// Generate implements fantasy.LanguageModel.
154func (l *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
155 d, warnings, err := l.prepareDocument(call)
156 if err != nil {
157 return nil, err
158 }
159
160 ch, err := l.kronk.ChatStreaming(ctx, d)
161 if err != nil {
162 return nil, toProviderErr(err)
163 }
164
165 var lastResponse model.ChatResponse
166 var fullContent string
167
168 for resp := range ch {
169 lastResponse = resp
170
171 if len(resp.Choice) > 0 && resp.Choice[0].Delta != nil {
172 switch resp.Choice[0].FinishReason() {
173 case model.FinishReasonError:
174 return nil, &fantasy.Error{Title: "model error", Message: resp.Choice[0].Delta.Content}
175
176 case model.FinishReasonStop, model.FinishReasonTool:
177 // Final response already contains full accumulated content in Delta.Content,
178 // so we use it directly instead of continuing to accumulate.
179 fullContent = resp.Choice[0].Delta.Content
180
181 default:
182 fullContent += resp.Choice[0].Delta.Content
183 }
184 }
185 }
186
187 if len(lastResponse.Choice) == 0 {
188 return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
189 }
190
191 choice := lastResponse.Choice[0]
192 var content []fantasy.Content
193 if choice.Delta != nil {
194 content = make([]fantasy.Content, 0, 1+len(choice.Delta.ToolCalls))
195 }
196
197 if fullContent != "" {
198 content = append(content, fantasy.TextContent{
199 Text: fullContent,
200 })
201 }
202
203 if choice.Delta != nil {
204 for _, tc := range choice.Delta.ToolCalls {
205 // Marshal the underlying map directly, not the ToolCallArguments type
206 // which has a custom MarshalJSON that double-encodes to a JSON string.
207 argsJSON, _ := json.Marshal(map[string]any(tc.Function.Arguments))
208
209 content = append(content, fantasy.ToolCallContent{
210 ProviderExecuted: false,
211 ToolCallID: tc.ID,
212 ToolName: tc.Function.Name,
213 Input: string(argsJSON),
214 })
215 }
216 }
217
218 usage := fantasy.Usage{}
219 if lastResponse.Usage != nil {
220 usage = fantasy.Usage{
221 InputTokens: int64(lastResponse.Usage.PromptTokens),
222 OutputTokens: int64(lastResponse.Usage.CompletionTokens),
223 TotalTokens: int64(lastResponse.Usage.PromptTokens + lastResponse.Usage.CompletionTokens),
224 ReasoningTokens: int64(lastResponse.Usage.ReasoningTokens),
225 }
226 }
227
228 mappedFinishReason := l.mapFinishReasonFunc(choice.FinishReason())
229 if choice.Delta != nil && len(choice.Delta.ToolCalls) > 0 {
230 mappedFinishReason = fantasy.FinishReasonToolCalls
231 }
232
233 providerMetadata := fantasy.ProviderMetadata{}
234 if lastResponse.Usage != nil {
235 providerMetadata = fantasy.ProviderMetadata{
236 Name: &ProviderMetadata{
237 TokensPerSecond: lastResponse.Usage.TokensPerSecond,
238 OutputTokens: int64(lastResponse.Usage.OutputTokens),
239 },
240 }
241 }
242
243 resp := fantasy.Response{
244 Content: content,
245 Usage: usage,
246 FinishReason: mappedFinishReason,
247 ProviderMetadata: providerMetadata,
248 Warnings: warnings,
249 }
250
251 return &resp, nil
252}
253
// Stream implements fantasy.LanguageModel.
//
// It builds the request with prepareDocument, starts a Kronk streaming chat
// and returns a function that pushes stream parts to yield: first the
// preparation warnings (if any), then reasoning, text and tool-call parts as
// chunks arrive, and finally a finish part carrying usage, the mapped finish
// reason and provider metadata. The function stops as soon as yield returns
// false. When a chunk carries the error finish reason, an error part is
// yielded and no finish part follows.
//
// An error is returned directly only when the request cannot be prepared or
// the stream cannot be started.
//
// NOTE(review): the state below (flags, toolCalls, usage, finishReason) and
// the channel are captured by the returned function rather than created
// inside it, so invoking the returned function a second time would reuse
// them — presumably callers iterate exactly once; confirm.
func (l *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
	d, warnings, err := l.prepareDocument(call)
	if err != nil {
		return nil, err
	}

	ch, err := l.kronk.ChatStreaming(ctx, d)
	if err != nil {
		return nil, toProviderErr(err)
	}

	// isActiveText / isActiveReasoning track whether a text ("0") or
	// reasoning ("reasoning-0") part has been started and not yet ended.
	isActiveText := false
	isActiveReasoning := false
	toolCalls := make(map[int]streamToolCall)

	// The metadata value is created once and updated in place whenever a
	// chunk carries usage information.
	providerMetadata := fantasy.ProviderMetadata{
		Name: &ProviderMetadata{},
	}

	var usage fantasy.Usage
	var finishReason string

	return func(yield func(fantasy.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(fantasy.StreamPart{
				Type:     fantasy.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

		// toolIndex is the key under which the next tool call is stored in
		// toolCalls; it advances after every stored call.
		toolIndex := 0
		for resp := range ch {
			// Chunks without a choice or without a delta are skipped
			// entirely, including any usage they might carry.
			if len(resp.Choice) == 0 {
				continue
			}

			choice := resp.Choice[0]
			if choice.Delta == nil {
				continue
			}

			// Keep the most recent usage figures; they are reported on the
			// finish part.
			if resp.Usage != nil {
				usage = fantasy.Usage{
					InputTokens:     int64(resp.Usage.PromptTokens),
					OutputTokens:    int64(resp.Usage.CompletionTokens),
					TotalTokens:     int64(resp.Usage.PromptTokens + resp.Usage.CompletionTokens),
					ReasoningTokens: int64(resp.Usage.ReasoningTokens),
				}

				if pm, ok := providerMetadata[Name]; ok {
					if metadata, ok := pm.(*ProviderMetadata); ok {
						metadata.TokensPerSecond = resp.Usage.TokensPerSecond
						metadata.OutputTokens = int64(resp.Usage.OutputTokens)
					}
				}
			}

			// Remember the last non-empty finish reason for the finish part.
			if choice.FinishReason() != "" {
				finishReason = choice.FinishReason()
			}

			switch choice.FinishReason() {
			case model.FinishReasonError:
				// The delta content carries the error message. The stream
				// ends here without a finish part.
				yield(fantasy.StreamPart{
					Type:  fantasy.StreamPartTypeError,
					Error: &fantasy.Error{Title: "model error", Message: choice.Delta.Content},
				})
				return

			case model.FinishReasonTool:
				// Close any open reasoning/text part before emitting tool
				// calls. The delta's text content is not emitted in this
				// branch.
				if isActiveReasoning {
					isActiveReasoning = false
					if !yield(fantasy.StreamPart{
						Type: fantasy.StreamPartTypeReasoningEnd,
						ID:   "reasoning-0",
					}) {
						return
					}
				}

				if isActiveText {
					isActiveText = false
					if !yield(fantasy.StreamPart{
						Type: fantasy.StreamPartTypeTextEnd,
						ID:   "0",
					}) {
						return
					}
				}

				// Each tool call on this chunk is emitted as a complete
				// start / delta / end / call sequence.
				for _, tc := range choice.Delta.ToolCalls {
					// NOTE(review): the marshal error is discarded; on
					// failure argsStr would be empty.
					argsJSON, _ := json.Marshal(map[string]any(tc.Function.Arguments))
					argsStr := string(argsJSON)

					// Substitute a generated ID when the chunk supplies none.
					toolID := tc.ID
					if toolID == "" {
						toolID = uuid.NewString()
					}

					if !yield(fantasy.StreamPart{
						Type:         fantasy.StreamPartTypeToolInputStart,
						ID:           toolID,
						ToolCallName: tc.Function.Name,
					}) {
						return
					}

					if !yield(fantasy.StreamPart{
						Type:  fantasy.StreamPartTypeToolInputDelta,
						ID:    toolID,
						Delta: argsStr,
					}) {
						return
					}

					if !yield(fantasy.StreamPart{
						Type: fantasy.StreamPartTypeToolInputEnd,
						ID:   toolID,
					}) {
						return
					}

					if !yield(fantasy.StreamPart{
						Type:          fantasy.StreamPartTypeToolCall,
						ID:            toolID,
						ToolCallName:  tc.Function.Name,
						ToolCallInput: argsStr,
					}) {
						return
					}

					toolCalls[toolIndex] = streamToolCall{
						id:          toolID,
						name:        tc.Function.Name,
						arguments:   argsStr,
						hasFinished: true,
					}
					toolIndex++
				}

			default:
				// Every other chunk, including one with the stop finish
				// reason, is handled here.
				//
				// NOTE(review): Generate treats the stop chunk's
				// Delta.Content as the full accumulated text; if that holds
				// for this stream too, the text branch below would emit the
				// whole text again as a delta — confirm against Kronk's
				// streaming contract.
				if choice.Delta.Reasoning != "" {
					if !isActiveReasoning {
						isActiveReasoning = true
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeReasoningStart,
							ID:   "reasoning-0",
						}) {
							return
						}
					}

					if !yield(fantasy.StreamPart{
						Type:  fantasy.StreamPartTypeReasoningDelta,
						ID:    "reasoning-0",
						Delta: choice.Delta.Reasoning,
					}) {
						return
					}
				}

				hasToolCalls := len(choice.Delta.ToolCalls) > 0
				hasContent := choice.Delta.Content != ""

				// Text or tool calls end an open reasoning part.
				if isActiveReasoning && (hasContent || hasToolCalls) {
					isActiveReasoning = false
					if !yield(fantasy.StreamPart{
						Type: fantasy.StreamPartTypeReasoningEnd,
						ID:   "reasoning-0",
					}) {
						return
					}
				}

				if hasContent {
					if !isActiveText {
						isActiveText = true
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}

					if !yield(fantasy.StreamPart{
						Type:  fantasy.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				}

				// Tool calls end an open text part.
				if hasToolCalls && isActiveText {
					isActiveText = false
					if !yield(fantasy.StreamPart{
						Type: fantasy.StreamPartTypeTextEnd,
						ID:   "0",
					}) {
						return
					}
				}

				for _, tc := range choice.Delta.ToolCalls {
					// NOTE(review): the marshal error is discarded; on
					// failure argsStr would be empty.
					argsJSON, _ := json.Marshal(map[string]any(tc.Function.Arguments))
					argsStr := string(argsJSON)

					// NOTE(review): toolIndex is advanced right after every
					// insertion into toolCalls, so this lookup appears never
					// to find an entry and the "true" arm looks unreachable
					// as written — confirm whether calls were meant to be
					// keyed by an index supplied with the tool call.
					switch existingTC, ok := toolCalls[toolIndex]; ok {
					case true:
						// Continuation of a call that was already started:
						// append the new argument text.
						if existingTC.hasFinished {
							continue
						}

						existingTC.arguments += argsStr

						if !yield(fantasy.StreamPart{
							Type:  fantasy.StreamPartTypeToolInputDelta,
							ID:    existingTC.id,
							Delta: argsStr,
						}) {
							return
						}

						toolCalls[toolIndex] = existingTC

						// Once the gathered arguments form valid JSON the
						// call is complete.
						if xjson.IsValid(existingTC.arguments) {
							if !yield(fantasy.StreamPart{
								Type: fantasy.StreamPartTypeToolInputEnd,
								ID:   existingTC.id,
							}) {
								return
							}

							if !yield(fantasy.StreamPart{
								Type:          fantasy.StreamPartTypeToolCall,
								ID:            existingTC.id,
								ToolCallName:  existingTC.name,
								ToolCallInput: existingTC.arguments,
							}) {
								return
							}

							existingTC.hasFinished = true
							toolCalls[toolIndex] = existingTC
						}

					case false:
						// First sighting of this call: announce it and
						// record it.
						toolID := tc.ID
						if toolID == "" {
							toolID = uuid.NewString()
						}

						if !yield(fantasy.StreamPart{
							Type:         fantasy.StreamPartTypeToolInputStart,
							ID:           toolID,
							ToolCallName: tc.Function.Name,
						}) {
							return
						}

						toolCalls[toolIndex] = streamToolCall{
							id:        toolID,
							name:      tc.Function.Name,
							arguments: argsStr,
						}

						// Arguments that are empty or encode to "null"
						// produce no delta, end or call part here; the call
						// stays recorded as unfinished.
						if argsStr != "" && argsStr != "null" {
							if !yield(fantasy.StreamPart{
								Type:  fantasy.StreamPartTypeToolInputDelta,
								ID:    toolID,
								Delta: argsStr,
							}) {
								return
							}

							if xjson.IsValid(argsStr) {
								if !yield(fantasy.StreamPart{
									Type: fantasy.StreamPartTypeToolInputEnd,
									ID:   toolID,
								}) {
									return
								}

								if !yield(fantasy.StreamPart{
									Type:          fantasy.StreamPartTypeToolCall,
									ID:            toolID,
									ToolCallName:  tc.Function.Name,
									ToolCallInput: argsStr,
								}) {
									return
								}

								stc := toolCalls[toolIndex]
								stc.hasFinished = true
								toolCalls[toolIndex] = stc
							}
						}
					}

					toolIndex++
				}
			}
		}

		// The channel is drained: close whatever part is still open.
		if isActiveReasoning {
			if !yield(fantasy.StreamPart{
				Type: fantasy.StreamPartTypeReasoningEnd,
				ID:   "reasoning-0",
			}) {
				return
			}
		}

		if isActiveText {
			if !yield(fantasy.StreamPart{
				Type: fantasy.StreamPartTypeTextEnd,
				ID:   "0",
			}) {
				return
			}
		}

		// Any recorded tool call overrides the mapped finish reason.
		mappedFinishReason := l.mapFinishReasonFunc(finishReason)
		if len(toolCalls) > 0 {
			mappedFinishReason = fantasy.FinishReasonToolCalls
		}

		yield(fantasy.StreamPart{
			Type:             fantasy.StreamPartTypeFinish,
			Usage:            usage,
			FinishReason:     mappedFinishReason,
			ProviderMetadata: providerMetadata,
		})
	}, nil
}
593
594// GenerateObject implements fantasy.LanguageModel.
595func (l *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
596 switch l.objectMode {
597 case fantasy.ObjectModeText:
598 return object.GenerateWithText(ctx, l, call)
599
600 case fantasy.ObjectModeTool:
601 return object.GenerateWithTool(ctx, l, call)
602
603 default:
604 return object.GenerateWithTool(ctx, l, call)
605 }
606}
607
608// StreamObject implements fantasy.LanguageModel.
609func (l *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
610 switch l.objectMode {
611 case fantasy.ObjectModeTool:
612 return object.StreamWithTool(ctx, l, call)
613
614 case fantasy.ObjectModeText:
615 return object.StreamWithText(ctx, l, call)
616
617 default:
618 return object.StreamWithTool(ctx, l, call)
619 }
620}
621
622func toKronkTools(tools []fantasy.Tool) ([]model.D, []fantasy.CallWarning) {
623 var kronkTools []model.D
624 var warnings []fantasy.CallWarning
625
626 for _, tool := range tools {
627 if tool.GetType() == fantasy.ToolTypeFunction {
628 ft, ok := tool.(fantasy.FunctionTool)
629 if !ok {
630 continue
631 }
632
633 kronkTools = append(kronkTools, model.D{
634 "type": "function",
635 "function": model.D{
636 "name": ft.Name,
637 "description": ft.Description,
638 "parameters": ft.InputSchema,
639 },
640 })
641
642 continue
643 }
644
645 warnings = append(warnings, fantasy.CallWarning{
646 Type: fantasy.CallWarningTypeUnsupportedTool,
647 Tool: tool,
648 Message: "tool is not supported",
649 })
650 }
651
652 return kronkTools, warnings
653}
654
655func toProviderErr(err error) error {
656 if err == nil {
657 return nil
658 }
659
660 if errors.Is(err, io.EOF) {
661 return nil
662 }
663
664 return &fantasy.ProviderError{
665 Title: "kronk error",
666 Message: err.Error(),
667 Cause: err,
668 }
669}