language_model.go

  1package kronk
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"io"
  8
  9	"charm.land/fantasy"
 10	"charm.land/fantasy/object"
 11	"github.com/ardanlabs/kronk/sdk/kronk"
 12	"github.com/ardanlabs/kronk/sdk/kronk/model"
 13	xjson "github.com/charmbracelet/x/json"
 14	"github.com/google/uuid"
 15)
 16
// languageModel is the Kronk-backed implementation of fantasy.LanguageModel.
type languageModel struct {
	// provider is the provider name reported by Provider and handed to
	// toPromptFunc.
	provider string
	// modelID is the model identifier reported by Model and handed to
	// toPromptFunc.
	modelID string
	// kronk is the client whose ChatStreaming method serves every call.
	kronk *kronk.Kronk
	// objectMode selects the strategy used by GenerateObject and StreamObject.
	objectMode fantasy.ObjectMode
	// prepareCallFunc is invoked by prepareDocument with the request document
	// and the call; it returns additional warnings or an error.
	prepareCallFunc LanguageModelPrepareCallFunc
	// mapFinishReasonFunc converts a Kronk finish reason into the fantasy
	// finish reason reported to callers.
	mapFinishReasonFunc LanguageModelMapFinishReasonFunc
	// toPromptFunc converts the call prompt into the value stored under the
	// request's "messages" key, plus any warnings.
	toPromptFunc LanguageModelToPromptFunc
}
 26
// LanguageModelOption is a function that configures a languageModel.
// newLanguageModel applies the options in order after it has populated the
// defaults, so an option overrides the default it targets.
type LanguageModelOption func(*languageModel)
 29
 30// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
 31func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
 32	return func(l *languageModel) {
 33		l.prepareCallFunc = fn
 34	}
 35}
 36
 37// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
 38func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
 39	return func(l *languageModel) {
 40		l.mapFinishReasonFunc = fn
 41	}
 42}
 43
 44// WithLanguageModelToPromptFunc sets the to prompt function for the language model.
 45func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
 46	return func(l *languageModel) {
 47		l.toPromptFunc = fn
 48	}
 49}
 50
 51// WithLanguageModelObjectMode sets the object generation mode.
 52func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
 53	return func(l *languageModel) {
 54		l.objectMode = om
 55	}
 56}
 57
 58func newLanguageModel(modelID string, provider string, krn *kronk.Kronk, opts ...LanguageModelOption) *languageModel {
 59	lm := languageModel{
 60		modelID:             modelID,
 61		provider:            provider,
 62		kronk:               krn,
 63		objectMode:          fantasy.ObjectModeAuto,
 64		prepareCallFunc:     DefaultPrepareCallFunc,
 65		mapFinishReasonFunc: DefaultMapFinishReasonFunc,
 66		toPromptFunc:        DefaultToPrompt,
 67	}
 68
 69	for _, o := range opts {
 70		o(&lm)
 71	}
 72
 73	return &lm
 74}
 75
// streamToolCall is the per-tool-call bookkeeping kept by Stream while it
// turns Kronk chunks into stream parts.
type streamToolCall struct {
	// id is the tool call ID used on every stream part emitted for this call.
	id string
	// name is the name of the function being called.
	name string
	// arguments holds the JSON-encoded arguments gathered so far.
	arguments string
	// hasFinished is set once the tool-input-end and tool-call parts have been
	// emitted for this call.
	hasFinished bool
}
 82
// Model implements fantasy.LanguageModel.
// It returns the model ID this languageModel was created with.
func (l *languageModel) Model() string {
	return l.modelID
}
 87
// Provider implements fantasy.LanguageModel.
// It returns the provider name this languageModel was created with.
func (l *languageModel) Provider() string {
	return l.provider
}
 92
 93func (l *languageModel) prepareDocument(call fantasy.Call) (model.D, []fantasy.CallWarning, error) {
 94	messages, warnings := l.toPromptFunc(call.Prompt, l.provider, l.modelID)
 95
 96	if call.TopK != nil {
 97		warnings = append(warnings, fantasy.CallWarning{
 98			Type:    fantasy.CallWarningTypeUnsupportedSetting,
 99			Setting: "top_k",
100		})
101	}
102
103	d := model.D{
104		"messages": messages,
105	}
106
107	if call.MaxOutputTokens != nil {
108		d["max_tokens"] = *call.MaxOutputTokens
109	}
110
111	if call.Temperature != nil {
112		d["temperature"] = *call.Temperature
113	}
114
115	if call.TopP != nil {
116		d["top_p"] = *call.TopP
117	}
118
119	if call.FrequencyPenalty != nil {
120		warnings = append(warnings, fantasy.CallWarning{
121			Type:    fantasy.CallWarningTypeUnsupportedSetting,
122			Setting: "frequency_penalty",
123			Details: "frequency_penalty is not supported by Kronk",
124		})
125	}
126
127	if call.PresencePenalty != nil {
128		warnings = append(warnings, fantasy.CallWarning{
129			Type:    fantasy.CallWarningTypeUnsupportedSetting,
130			Setting: "presence_penalty",
131			Details: "presence_penalty is not supported by Kronk",
132		})
133	}
134
135	optionsWarnings, err := l.prepareCallFunc(l, d, call)
136	if err != nil {
137		return nil, nil, err
138	}
139
140	if len(optionsWarnings) > 0 {
141		warnings = append(warnings, optionsWarnings...)
142	}
143
144	if len(call.Tools) > 0 {
145		tools, toolWarnings := toKronkTools(call.Tools)
146		d["tools"] = tools
147		warnings = append(warnings, toolWarnings...)
148	}
149
150	return d, warnings, nil
151}
152
153// Generate implements fantasy.LanguageModel.
154func (l *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
155	d, warnings, err := l.prepareDocument(call)
156	if err != nil {
157		return nil, err
158	}
159
160	ch, err := l.kronk.ChatStreaming(ctx, d)
161	if err != nil {
162		return nil, toProviderErr(err)
163	}
164
165	var lastResponse model.ChatResponse
166	var fullContent string
167
168	for resp := range ch {
169		lastResponse = resp
170
171		if len(resp.Choice) > 0 && resp.Choice[0].Delta != nil {
172			switch resp.Choice[0].FinishReason() {
173			case model.FinishReasonError:
174				return nil, &fantasy.Error{Title: "model error", Message: resp.Choice[0].Delta.Content}
175
176			case model.FinishReasonStop, model.FinishReasonTool:
177				// Final response already contains full accumulated content in Delta.Content,
178				// so we use it directly instead of continuing to accumulate.
179				fullContent = resp.Choice[0].Delta.Content
180
181			default:
182				fullContent += resp.Choice[0].Delta.Content
183			}
184		}
185	}
186
187	if len(lastResponse.Choice) == 0 {
188		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
189	}
190
191	choice := lastResponse.Choice[0]
192	var content []fantasy.Content
193	if choice.Delta != nil {
194		content = make([]fantasy.Content, 0, 1+len(choice.Delta.ToolCalls))
195	}
196
197	if fullContent != "" {
198		content = append(content, fantasy.TextContent{
199			Text: fullContent,
200		})
201	}
202
203	if choice.Delta != nil {
204		for _, tc := range choice.Delta.ToolCalls {
205			// Marshal the underlying map directly, not the ToolCallArguments type
206			// which has a custom MarshalJSON that double-encodes to a JSON string.
207			argsJSON, _ := json.Marshal(map[string]any(tc.Function.Arguments))
208
209			content = append(content, fantasy.ToolCallContent{
210				ProviderExecuted: false,
211				ToolCallID:       tc.ID,
212				ToolName:         tc.Function.Name,
213				Input:            string(argsJSON),
214			})
215		}
216	}
217
218	usage := fantasy.Usage{}
219	if lastResponse.Usage != nil {
220		usage = fantasy.Usage{
221			InputTokens:     int64(lastResponse.Usage.PromptTokens),
222			OutputTokens:    int64(lastResponse.Usage.CompletionTokens),
223			TotalTokens:     int64(lastResponse.Usage.PromptTokens + lastResponse.Usage.CompletionTokens),
224			ReasoningTokens: int64(lastResponse.Usage.ReasoningTokens),
225		}
226	}
227
228	mappedFinishReason := l.mapFinishReasonFunc(choice.FinishReason())
229	if choice.Delta != nil && len(choice.Delta.ToolCalls) > 0 {
230		mappedFinishReason = fantasy.FinishReasonToolCalls
231	}
232
233	providerMetadata := fantasy.ProviderMetadata{}
234	if lastResponse.Usage != nil {
235		providerMetadata = fantasy.ProviderMetadata{
236			Name: &ProviderMetadata{
237				TokensPerSecond: lastResponse.Usage.TokensPerSecond,
238				OutputTokens:    int64(lastResponse.Usage.OutputTokens),
239			},
240		}
241	}
242
243	resp := fantasy.Response{
244		Content:          content,
245		Usage:            usage,
246		FinishReason:     mappedFinishReason,
247		ProviderMetadata: providerMetadata,
248		Warnings:         warnings,
249	}
250
251	return &resp, nil
252}
253
// Stream implements fantasy.LanguageModel.
//
// It prepares the request, starts a Kronk streaming chat and returns an
// iterator that translates each chunk into fantasy stream parts:
//
//   - a warnings part first, when preparing the request produced warnings;
//   - reasoning start/delta/end parts, all with ID "reasoning-0";
//   - text start/delta/end parts, all with ID "0";
//   - tool-input start/delta/end parts followed by a tool-call part per tool call;
//   - an error part when a chunk carries the error finish reason, which ends
//     the stream without a finish part;
//   - otherwise a final finish part with usage, finish reason and provider
//     metadata.
//
// Preparation and start-up failures are returned as the error; the iterator
// itself stops as soon as the consumer's yield returns false.
func (l *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
	d, warnings, err := l.prepareDocument(call)
	if err != nil {
		return nil, err
	}

	ch, err := l.kronk.ChatStreaming(ctx, d)
	if err != nil {
		// NOTE(review): toProviderErr returns nil for io.EOF, which would make
		// this return (nil, nil) — confirm ChatStreaming never reports io.EOF.
		return nil, toProviderErr(err)
	}

	// State shared by the iterator below. It lives outside the closure, so it
	// is not reset if the returned iterator is invoked more than once.
	isActiveText := false
	isActiveReasoning := false
	toolCalls := make(map[int]streamToolCall)

	providerMetadata := fantasy.ProviderMetadata{
		Name: &ProviderMetadata{},
	}

	var usage fantasy.Usage
	var finishReason string

	return func(yield func(fantasy.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(fantasy.StreamPart{
				Type:     fantasy.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

		// toolIndex is the key under which the next new tool call is recorded.
		// NOTE(review): every early return below leaves ch unread; whether the
		// producer stops then depends on ctx being cancelled — confirm.
		toolIndex := 0
		for resp := range ch {
			// Only the first choice is considered; chunks without a choice or
			// without a delta are skipped (including any usage they carry).
			if len(resp.Choice) == 0 {
				continue
			}

			choice := resp.Choice[0]
			if choice.Delta == nil {
				continue
			}

			// Keep the most recent usage numbers; they are reported on finish.
			if resp.Usage != nil {
				usage = fantasy.Usage{
					InputTokens:     int64(resp.Usage.PromptTokens),
					OutputTokens:    int64(resp.Usage.CompletionTokens),
					TotalTokens:     int64(resp.Usage.PromptTokens + resp.Usage.CompletionTokens),
					ReasoningTokens: int64(resp.Usage.ReasoningTokens),
				}

				if pm, ok := providerMetadata[Name]; ok {
					if metadata, ok := pm.(*ProviderMetadata); ok {
						metadata.TokensPerSecond = resp.Usage.TokensPerSecond
						metadata.OutputTokens = int64(resp.Usage.OutputTokens)
					}
				}
			}

			// Remember the last non-empty finish reason for the finish part.
			if choice.FinishReason() != "" {
				finishReason = choice.FinishReason()
			}

			switch choice.FinishReason() {
			case model.FinishReasonError:
				// The delta content is used as the error message; the stream
				// ends here without a finish part.
				yield(fantasy.StreamPart{
					Type:  fantasy.StreamPartTypeError,
					Error: &fantasy.Error{Title: "model error", Message: choice.Delta.Content},
				})
				return

			case model.FinishReasonTool:
				// Close any open reasoning and text blocks before emitting the
				// tool calls.
				if isActiveReasoning {
					isActiveReasoning = false
					if !yield(fantasy.StreamPart{
						Type: fantasy.StreamPartTypeReasoningEnd,
						ID:   "reasoning-0",
					}) {
						return
					}
				}

				if isActiveText {
					isActiveText = false
					if !yield(fantasy.StreamPart{
						Type: fantasy.StreamPartTypeTextEnd,
						ID:   "0",
					}) {
						return
					}
				}

				// Each tool call in this chunk is emitted as a full
				// start/delta/end/call sequence in one go.
				for _, tc := range choice.Delta.ToolCalls {
					// NOTE(review): the marshal error is discarded; on failure
					// argsStr would be empty.
					argsJSON, _ := json.Marshal(map[string]any(tc.Function.Arguments))
					argsStr := string(argsJSON)

					// Generate an ID when the chunk does not provide one.
					toolID := tc.ID
					if toolID == "" {
						toolID = uuid.NewString()
					}

					if !yield(fantasy.StreamPart{
						Type:         fantasy.StreamPartTypeToolInputStart,
						ID:           toolID,
						ToolCallName: tc.Function.Name,
					}) {
						return
					}

					if !yield(fantasy.StreamPart{
						Type:  fantasy.StreamPartTypeToolInputDelta,
						ID:    toolID,
						Delta: argsStr,
					}) {
						return
					}

					if !yield(fantasy.StreamPart{
						Type: fantasy.StreamPartTypeToolInputEnd,
						ID:   toolID,
					}) {
						return
					}

					if !yield(fantasy.StreamPart{
						Type:          fantasy.StreamPartTypeToolCall,
						ID:            toolID,
						ToolCallName:  tc.Function.Name,
						ToolCallInput: argsStr,
					}) {
						return
					}

					toolCalls[toolIndex] = streamToolCall{
						id:          toolID,
						name:        tc.Function.Name,
						arguments:   argsStr,
						hasFinished: true,
					}
					toolIndex++
				}

			default:
				// Any other finish reason, including none: an ordinary delta.
				if choice.Delta.Reasoning != "" {
					if !isActiveReasoning {
						isActiveReasoning = true
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeReasoningStart,
							ID:   "reasoning-0",
						}) {
							return
						}
					}

					if !yield(fantasy.StreamPart{
						Type:  fantasy.StreamPartTypeReasoningDelta,
						ID:    "reasoning-0",
						Delta: choice.Delta.Reasoning,
					}) {
						return
					}
				}

				hasToolCalls := len(choice.Delta.ToolCalls) > 0
				hasContent := choice.Delta.Content != ""

				// Reasoning ends as soon as text or tool calls show up.
				if isActiveReasoning && (hasContent || hasToolCalls) {
					isActiveReasoning = false
					if !yield(fantasy.StreamPart{
						Type: fantasy.StreamPartTypeReasoningEnd,
						ID:   "reasoning-0",
					}) {
						return
					}
				}

				if hasContent {
					if !isActiveText {
						isActiveText = true
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}

					if !yield(fantasy.StreamPart{
						Type:  fantasy.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				}

				// The text block is closed before any tool call parts.
				if hasToolCalls && isActiveText {
					isActiveText = false
					if !yield(fantasy.StreamPart{
						Type: fantasy.StreamPartTypeTextEnd,
						ID:   "0",
					}) {
						return
					}
				}

				for _, tc := range choice.Delta.ToolCalls {
					// NOTE(review): the marshal error is discarded here as well.
					argsJSON, _ := json.Marshal(map[string]any(tc.Function.Arguments))
					argsStr := string(argsJSON)

					// NOTE(review): toolIndex is advanced right after every
					// insertion, so within a single pass of the iterator this
					// lookup appears never to find an entry — confirm whether
					// the `case true` continuation path is reachable.
					switch existingTC, ok := toolCalls[toolIndex]; ok {
					case true:
						if existingTC.hasFinished {
							continue
						}

						existingTC.arguments += argsStr

						if !yield(fantasy.StreamPart{
							Type:  fantasy.StreamPartTypeToolInputDelta,
							ID:    existingTC.id,
							Delta: argsStr,
						}) {
							return
						}

						toolCalls[toolIndex] = existingTC

						// The call is complete once the accumulated arguments
						// form valid JSON.
						if xjson.IsValid(existingTC.arguments) {
							if !yield(fantasy.StreamPart{
								Type: fantasy.StreamPartTypeToolInputEnd,
								ID:   existingTC.id,
							}) {
								return
							}

							if !yield(fantasy.StreamPart{
								Type:          fantasy.StreamPartTypeToolCall,
								ID:            existingTC.id,
								ToolCallName:  existingTC.name,
								ToolCallInput: existingTC.arguments,
							}) {
								return
							}

							existingTC.hasFinished = true
							toolCalls[toolIndex] = existingTC
						}

					case false:
						// First sighting of this tool call: announce it and
						// record it under toolIndex.
						toolID := tc.ID
						if toolID == "" {
							toolID = uuid.NewString()
						}

						if !yield(fantasy.StreamPart{
							Type:         fantasy.StreamPartTypeToolInputStart,
							ID:           toolID,
							ToolCallName: tc.Function.Name,
						}) {
							return
						}

						toolCalls[toolIndex] = streamToolCall{
							id:        toolID,
							name:      tc.Function.Name,
							arguments: argsStr,
						}

						// "null" is what a nil arguments map marshals to; such
						// a call only gets its start part here.
						if argsStr != "" && argsStr != "null" {
							if !yield(fantasy.StreamPart{
								Type:  fantasy.StreamPartTypeToolInputDelta,
								ID:    toolID,
								Delta: argsStr,
							}) {
								return
							}

							if xjson.IsValid(argsStr) {
								if !yield(fantasy.StreamPart{
									Type: fantasy.StreamPartTypeToolInputEnd,
									ID:   toolID,
								}) {
									return
								}

								if !yield(fantasy.StreamPart{
									Type:          fantasy.StreamPartTypeToolCall,
									ID:            toolID,
									ToolCallName:  tc.Function.Name,
									ToolCallInput: argsStr,
								}) {
									return
								}

								stc := toolCalls[toolIndex]
								stc.hasFinished = true
								toolCalls[toolIndex] = stc
							}
						}

						toolIndex++
					}
				}
			}
		}

		// The stream is exhausted: close whatever block is still open.
		if isActiveReasoning {
			if !yield(fantasy.StreamPart{
				Type: fantasy.StreamPartTypeReasoningEnd,
				ID:   "reasoning-0",
			}) {
				return
			}
		}

		if isActiveText {
			if !yield(fantasy.StreamPart{
				Type: fantasy.StreamPartTypeTextEnd,
				ID:   "0",
			}) {
				return
			}
		}

		// Any recorded tool call overrides the mapped finish reason.
		mappedFinishReason := l.mapFinishReasonFunc(finishReason)
		if len(toolCalls) > 0 {
			mappedFinishReason = fantasy.FinishReasonToolCalls
		}

		yield(fantasy.StreamPart{
			Type:             fantasy.StreamPartTypeFinish,
			Usage:            usage,
			FinishReason:     mappedFinishReason,
			ProviderMetadata: providerMetadata,
		})
	}, nil
}
593
594// GenerateObject implements fantasy.LanguageModel.
595func (l *languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
596	switch l.objectMode {
597	case fantasy.ObjectModeText:
598		return object.GenerateWithText(ctx, l, call)
599
600	case fantasy.ObjectModeTool:
601		return object.GenerateWithTool(ctx, l, call)
602
603	default:
604		return object.GenerateWithTool(ctx, l, call)
605	}
606}
607
608// StreamObject implements fantasy.LanguageModel.
609func (l *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
610	switch l.objectMode {
611	case fantasy.ObjectModeTool:
612		return object.StreamWithTool(ctx, l, call)
613
614	case fantasy.ObjectModeText:
615		return object.StreamWithText(ctx, l, call)
616
617	default:
618		return object.StreamWithTool(ctx, l, call)
619	}
620}
621
622func toKronkTools(tools []fantasy.Tool) ([]model.D, []fantasy.CallWarning) {
623	var kronkTools []model.D
624	var warnings []fantasy.CallWarning
625
626	for _, tool := range tools {
627		if tool.GetType() == fantasy.ToolTypeFunction {
628			ft, ok := tool.(fantasy.FunctionTool)
629			if !ok {
630				continue
631			}
632
633			kronkTools = append(kronkTools, model.D{
634				"type": "function",
635				"function": model.D{
636					"name":        ft.Name,
637					"description": ft.Description,
638					"parameters":  ft.InputSchema,
639				},
640			})
641
642			continue
643		}
644
645		warnings = append(warnings, fantasy.CallWarning{
646			Type:    fantasy.CallWarningTypeUnsupportedTool,
647			Tool:    tool,
648			Message: "tool is not supported",
649		})
650	}
651
652	return kronkTools, warnings
653}
654
655func toProviderErr(err error) error {
656	if err == nil {
657		return nil
658	}
659
660	if errors.Is(err, io.EOF) {
661		return nil
662	}
663
664	return &fantasy.ProviderError{
665		Title:   "kronk error",
666		Message: err.Error(),
667		Cause:   err,
668	}
669}