1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sync"
11
12 "github.com/charmbracelet/crush/internal/llm/tools"
13)
14
15type StepResult struct {
16 Response
17 Messages []Message
18}
19
20type StopCondition = func(steps []StepResult) bool
21
22// StepCountIs returns a stop condition that stops after the specified number of steps.
23func StepCountIs(stepCount int) StopCondition {
24 return func(steps []StepResult) bool {
25 return len(steps) >= stepCount
26 }
27}
28
29// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
30func HasToolCall(toolName string) StopCondition {
31 return func(steps []StepResult) bool {
32 if len(steps) == 0 {
33 return false
34 }
35 lastStep := steps[len(steps)-1]
36 toolCalls := lastStep.Content.ToolCalls()
37 for _, toolCall := range toolCalls {
38 if toolCall.ToolName == toolName {
39 return true
40 }
41 }
42 return false
43 }
44}
45
46// HasContent returns a stop condition that stops when the specified content type appears in the last step.
47func HasContent(contentType ContentType) StopCondition {
48 return func(steps []StepResult) bool {
49 if len(steps) == 0 {
50 return false
51 }
52 lastStep := steps[len(steps)-1]
53 for _, content := range lastStep.Content {
54 if content.GetType() == contentType {
55 return true
56 }
57 }
58 return false
59 }
60}
61
62// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
63func FinishReasonIs(reason FinishReason) StopCondition {
64 return func(steps []StepResult) bool {
65 if len(steps) == 0 {
66 return false
67 }
68 lastStep := steps[len(steps)-1]
69 return lastStep.FinishReason == reason
70 }
71}
72
73// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
74func MaxTokensUsed(maxTokens int64) StopCondition {
75 return func(steps []StepResult) bool {
76 var totalTokens int64
77 for _, step := range steps {
78 totalTokens += step.Usage.TotalTokens
79 }
80 return totalTokens >= maxTokens
81 }
82}
83
// PrepareStepFunctionOptions is the input handed to a PrepareStepFunction
// before each step of the agent loop.
type PrepareStepFunctionOptions struct {
	// Steps holds the steps completed so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run; Generate
	// passes len(Steps).
	StepNumber int
	// Model is the agent's configured language model.
	Model LanguageModel
	// Messages is the input prepared for this step: the initial prompt
	// followed by the messages produced by earlier steps.
	Messages []Message
}
90
// PrepareStepResult lets a PrepareStepFunction override the defaults of a
// single step. In Generate, nil (or, for ActiveTools, empty) fields leave the
// corresponding default unchanged.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the agent's model for this step.
	Model LanguageModel
	// Messages, when non-nil, replaces the step's input messages.
	Messages []Message
	// System, when non-nil, replaces the agent-level system prompt for this
	// step.
	System *string
	// ToolChoice, when non-nil, replaces the default ToolChoiceAuto.
	ToolChoice *ToolChoice
	// ActiveTools, when non-empty, restricts the tools offered to the model
	// to those with these names.
	ActiveTools []string
	// DisableAllTools, when true, offers no tools to the model for this step.
	DisableAllTools bool
}
99
// ToolCallRepairOptions is passed to a RepairToolCallFunction when a tool
// call emitted by the model fails validation.
type ToolCallRepairOptions struct {
	// OriginalToolCall is the tool call as emitted by the model.
	OriginalToolCall ToolCallContent
	// ValidationError explains why OriginalToolCall failed validation.
	ValidationError error
	// AvailableTools are the tools the call was validated against.
	AvailableTools []tools.BaseTool
	// SystemPrompt is the system prompt in effect for the step.
	SystemPrompt string
	// Messages are the input messages sent to the model for the step.
	Messages []Message
}
107
108type (
109 PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult
110 OnStepFinishedFunction = func(step StepResult)
111 RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
112)
113
// AgentSettings holds the agent-level configuration that NewAgent assembles
// from its options. Most fields are defaults that prepareCall copies into an
// AgentCall when the call leaves them unset.
type AgentSettings struct {
	// systemPrompt is used when building the prompt for each Generate call.
	systemPrompt string
	// Sampling and length defaults; nil means "not set".
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers and providerOptions are merged with the call's values, which
	// win on key collisions.
	headers         map[string]string
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	// tools are offered to the model and executed by the agent.
	tools []tools.BaseTool
	// NOTE(review): no With* option in this file sets maxRetries or onRetry;
	// confirm whether another file does, otherwise they stay nil.
	maxRetries *int

	// model is the language model used for each step unless PrepareStep
	// overrides it.
	model LanguageModel

	// Default loop hooks; the call's own values take precedence.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onStepFinished OnStepFinishedFunction
	onRetry        OnRetryCallback
}
137
// AgentCall holds the inputs for a single Generate (or Stream) call.
//
// Options left nil, or StopWhen left empty, are filled from the agent's
// settings by prepareCall; Headers and ProviderOptions are merged with the
// agent's values, with the call's entries taking precedence.
//
// NOTE(review): only some fields carry json tags, and encoding/json cannot
// encode the function-valued fields; confirm whether JSON encoding of this
// type is intended before relying on the tags.
type AgentCall struct {
	// Prompt is the text of the user message for this call; it must be
	// non-empty.
	Prompt string `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files []FilePart `json:"files"`
	// Messages are additional conversation messages included in the prompt.
	Messages []Message `json:"messages"`
	// Sampling and length options; nil means "use the agent default".
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, restricts the tools offered to the model
	// to those with these names.
	ActiveTools     []string `json:"active_tools"`
	Headers         map[string]string
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	// Loop hooks; nil (or empty StopWhen) falls back to the agent's settings.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
	OnStepFinished OnStepFinishedFunction
}
159
// AgentResult is the outcome of Agent.Generate.
type AgentResult struct {
	// Steps holds every step the agent ran, in order.
	Steps []StepResult
	// Response is the response of the final step.
	Response Response
	// TotalUsage is the sum of the Usage of all steps.
	TotalUsage Usage
}
166
// Agent runs a language model in a loop, executing the tools it calls and
// feeding the results back into the conversation until the loop stops.
type Agent interface {
	// Generate runs the agent loop to completion and returns every step,
	// the final response and the accumulated usage.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream is presumably the streaming variant of Generate.
	// NOTE(review): the agent implementation in this file does not
	// implement it yet.
	Stream(context.Context, AgentCall) (StreamResponse, error)
}
171
172type agentOption = func(*AgentSettings)
173
174type agent struct {
175 settings AgentSettings
176}
177
178func NewAgent(model LanguageModel, opts ...agentOption) Agent {
179 settings := AgentSettings{
180 model: model,
181 }
182 for _, o := range opts {
183 o(&settings)
184 }
185 return &agent{
186 settings: settings,
187 }
188}
189
190func (a *agent) prepareCall(call AgentCall) AgentCall {
191 if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
192 call.MaxOutputTokens = a.settings.maxOutputTokens
193 }
194 if call.Temperature == nil && a.settings.temperature != nil {
195 call.Temperature = a.settings.temperature
196 }
197 if call.TopP == nil && a.settings.topP != nil {
198 call.TopP = a.settings.topP
199 }
200 if call.TopK == nil && a.settings.topK != nil {
201 call.TopK = a.settings.topK
202 }
203 if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
204 call.PresencePenalty = a.settings.presencePenalty
205 }
206 if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
207 call.FrequencyPenalty = a.settings.frequencyPenalty
208 }
209 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
210 call.StopWhen = a.settings.stopWhen
211 }
212 if call.PrepareStep == nil && a.settings.prepareStep != nil {
213 call.PrepareStep = a.settings.prepareStep
214 }
215 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
216 call.RepairToolCall = a.settings.repairToolCall
217 }
218 if call.OnStepFinished == nil && a.settings.onStepFinished != nil {
219 call.OnStepFinished = a.settings.onStepFinished
220 }
221 if call.OnRetry == nil && a.settings.onRetry != nil {
222 call.OnRetry = a.settings.onRetry
223 }
224 if call.MaxRetries == nil && a.settings.maxRetries != nil {
225 call.MaxRetries = a.settings.maxRetries
226 }
227
228 providerOptions := ProviderOptions{}
229 if a.settings.providerOptions != nil {
230 maps.Copy(providerOptions, a.settings.providerOptions)
231 }
232 if call.ProviderOptions != nil {
233 maps.Copy(providerOptions, call.ProviderOptions)
234 }
235 call.ProviderOptions = providerOptions
236
237 headers := map[string]string{}
238
239 if a.settings.headers != nil {
240 maps.Copy(headers, a.settings.headers)
241 }
242
243 if call.Headers != nil {
244 maps.Copy(headers, call.Headers)
245 }
246 call.Headers = headers
247 return call
248}
249
250// Generate implements Agent.
251func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
252 opts = a.prepareCall(opts)
253 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
254 if err != nil {
255 return nil, err
256 }
257 var responseMessages []Message
258 var steps []StepResult
259
260 for {
261 stepInputMessages := append(initialPrompt, responseMessages...)
262 stepModel := a.settings.model
263 stepSystemPrompt := a.settings.systemPrompt
264 stepActiveTools := opts.ActiveTools
265 stepToolChoice := ToolChoiceAuto
266 disableAllTools := false
267
268 if opts.PrepareStep != nil {
269 prepared := opts.PrepareStep(PrepareStepFunctionOptions{
270 Model: stepModel,
271 Steps: steps,
272 StepNumber: len(steps),
273 Messages: stepInputMessages,
274 })
275
276 // Apply prepared step modifications
277 if prepared.Messages != nil {
278 stepInputMessages = prepared.Messages
279 }
280 if prepared.Model != nil {
281 stepModel = prepared.Model
282 }
283 if prepared.System != nil {
284 stepSystemPrompt = *prepared.System
285 }
286 if prepared.ToolChoice != nil {
287 stepToolChoice = *prepared.ToolChoice
288 }
289 if len(prepared.ActiveTools) > 0 {
290 stepActiveTools = prepared.ActiveTools
291 }
292 disableAllTools = prepared.DisableAllTools
293 }
294
295 // Recreate prompt with potentially modified system prompt
296 if stepSystemPrompt != a.settings.systemPrompt {
297 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
298 if err != nil {
299 return nil, err
300 }
301 // Replace system message part, keep the rest
302 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
303 stepInputMessages[0] = stepPrompt[0] // Replace system message
304 }
305 }
306
307 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
308
309 retryOptions := DefaultRetryOptions()
310 retryOptions.OnRetry = opts.OnRetry
311 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
312
313 result, err := retry(ctx, func() (*Response, error) {
314 return stepModel.Generate(ctx, Call{
315 Prompt: stepInputMessages,
316 MaxOutputTokens: opts.MaxOutputTokens,
317 Temperature: opts.Temperature,
318 TopP: opts.TopP,
319 TopK: opts.TopK,
320 PresencePenalty: opts.PresencePenalty,
321 FrequencyPenalty: opts.FrequencyPenalty,
322 Tools: preparedTools,
323 ToolChoice: &stepToolChoice,
324 Headers: opts.Headers,
325 ProviderOptions: opts.ProviderOptions,
326 })
327 })
328 if err != nil {
329 return nil, err
330 }
331
332 var stepToolCalls []ToolCallContent
333 for _, content := range result.Content {
334 if content.GetType() == ContentTypeToolCall {
335 toolCall, ok := AsContentType[ToolCallContent](content)
336 if !ok {
337 continue
338 }
339
340 // Validate and potentially repair the tool call
341 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
342 stepToolCalls = append(stepToolCalls, validatedToolCall)
343 }
344 }
345
346 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls)
347
348 // Build step content with validated tool calls and tool results
349 stepContent := []Content{}
350 toolCallIndex := 0
351 for _, content := range result.Content {
352 if content.GetType() == ContentTypeToolCall {
353 // Replace with validated tool call
354 if toolCallIndex < len(stepToolCalls) {
355 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
356 toolCallIndex++
357 }
358 } else {
359 // Keep other content as-is
360 stepContent = append(stepContent, content)
361 }
362 }
363 // Add tool results
364 for _, result := range toolResults {
365 stepContent = append(stepContent, result)
366 }
367 currentStepMessages := toResponseMessages(stepContent)
368 responseMessages = append(responseMessages, currentStepMessages...)
369
370 stepResult := StepResult{
371 Response: Response{
372 Content: stepContent,
373 FinishReason: result.FinishReason,
374 Usage: result.Usage,
375 Warnings: result.Warnings,
376 ProviderMetadata: result.ProviderMetadata,
377 },
378 Messages: currentStepMessages,
379 }
380 steps = append(steps, stepResult)
381 if opts.OnStepFinished != nil {
382 opts.OnStepFinished(stepResult)
383 }
384
385 shouldStop := isStopConditionMet(opts.StopWhen, steps)
386
387 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
388 break
389 }
390 }
391
392 totalUsage := Usage{}
393
394 for _, step := range steps {
395 usage := step.Usage
396 totalUsage.InputTokens += usage.InputTokens
397 totalUsage.OutputTokens += usage.OutputTokens
398 totalUsage.ReasoningTokens += usage.ReasoningTokens
399 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
400 totalUsage.CacheReadTokens += usage.CacheReadTokens
401 totalUsage.TotalTokens += usage.TotalTokens
402 }
403
404 agentResult := &AgentResult{
405 Steps: steps,
406 Response: steps[len(steps)-1].Response,
407 TotalUsage: totalUsage,
408 }
409 return agentResult, nil
410}
411
412func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
413 if len(conditions) == 0 {
414 return false
415 }
416
417 for _, condition := range conditions {
418 if condition(steps) {
419 return true
420 }
421 }
422 return false
423}
424
425func toResponseMessages(content []Content) []Message {
426 var assistantParts []MessagePart
427 var toolParts []MessagePart
428
429 for _, c := range content {
430 switch c.GetType() {
431 case ContentTypeText:
432 text, ok := AsContentType[TextContent](c)
433 if !ok {
434 continue
435 }
436 assistantParts = append(assistantParts, TextPart{
437 Text: text.Text,
438 ProviderOptions: ProviderOptions(text.ProviderMetadata),
439 })
440 case ContentTypeReasoning:
441 reasoning, ok := AsContentType[ReasoningContent](c)
442 if !ok {
443 continue
444 }
445 assistantParts = append(assistantParts, ReasoningPart{
446 Text: reasoning.Text,
447 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
448 })
449 case ContentTypeToolCall:
450 toolCall, ok := AsContentType[ToolCallContent](c)
451 if !ok {
452 continue
453 }
454 assistantParts = append(assistantParts, ToolCallPart{
455 ToolCallID: toolCall.ToolCallID,
456 ToolName: toolCall.ToolName,
457 Input: toolCall.Input,
458 ProviderExecuted: toolCall.ProviderExecuted,
459 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
460 })
461 case ContentTypeFile:
462 file, ok := AsContentType[FileContent](c)
463 if !ok {
464 continue
465 }
466 assistantParts = append(assistantParts, FilePart{
467 Data: file.Data,
468 MediaType: file.MediaType,
469 ProviderOptions: ProviderOptions(file.ProviderMetadata),
470 })
471 case ContentTypeSource:
472 // Sources are metadata about references used to generate the response.
473 // They don't need to be included in the conversation messages.
474 continue
475 case ContentTypeToolResult:
476 result, ok := AsContentType[ToolResultContent](c)
477 if !ok {
478 continue
479 }
480 toolParts = append(toolParts, ToolResultPart{
481 ToolCallID: result.ToolCallID,
482 Output: result.Result,
483 ProviderOptions: ProviderOptions(result.ProviderMetadata),
484 })
485 }
486 }
487
488 var messages []Message
489 if len(assistantParts) > 0 {
490 messages = append(messages, Message{
491 Role: MessageRoleAssistant,
492 Content: assistantParts,
493 })
494 }
495 if len(toolParts) > 0 {
496 messages = append(messages, Message{
497 Role: MessageRoleTool,
498 Content: toolParts,
499 })
500 }
501 return messages
502}
503
504func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent) ([]ToolResultContent, error) {
505 if len(toolCalls) == 0 {
506 return nil, nil
507 }
508
509 // Create a map for quick tool lookup
510 toolMap := make(map[string]tools.BaseTool)
511 for _, tool := range allTools {
512 toolMap[tool.Info().Name] = tool
513 }
514
515 // Execute all tool calls in parallel
516 results := make([]ToolResultContent, len(toolCalls))
517 var toolExecutionError error
518 var wg sync.WaitGroup
519
520 for i, toolCall := range toolCalls {
521 wg.Add(1)
522 go func(index int, call ToolCallContent) {
523 defer wg.Done()
524
525 // Skip invalid tool calls - create error result
526 if call.Invalid {
527 results[index] = ToolResultContent{
528 ToolCallID: call.ToolCallID,
529 ToolName: call.ToolName,
530 Result: ToolResultOutputContentError{
531 Error: call.ValidationError,
532 },
533 ProviderExecuted: false,
534 }
535 return
536 }
537
538 tool, exists := toolMap[call.ToolName]
539 if !exists {
540 results[index] = ToolResultContent{
541 ToolCallID: call.ToolCallID,
542 ToolName: call.ToolName,
543 Result: ToolResultOutputContentError{
544 Error: errors.New("Error: Tool not found: " + call.ToolName),
545 },
546 ProviderExecuted: false,
547 }
548 return
549 }
550
551 // Execute the tool
552 result, err := tool.Run(ctx, tools.ToolCall{
553 ID: call.ToolCallID,
554 Name: call.ToolName,
555 Input: call.Input,
556 })
557 if err != nil {
558 results[index] = ToolResultContent{
559 ToolCallID: call.ToolCallID,
560 ToolName: call.ToolName,
561 Result: ToolResultOutputContentError{
562 Error: err,
563 },
564 ProviderExecuted: false,
565 }
566 toolExecutionError = err
567 return
568 }
569
570 if result.IsError {
571 results[index] = ToolResultContent{
572 ToolCallID: call.ToolCallID,
573 ToolName: call.ToolName,
574 Result: ToolResultOutputContentError{
575 Error: errors.New(result.Content),
576 },
577 ProviderExecuted: false,
578 }
579 } else {
580 results[index] = ToolResultContent{
581 ToolCallID: call.ToolCallID,
582 ToolName: toolCall.ToolName,
583 Result: ToolResultOutputContentText{
584 Text: result.Content,
585 },
586 ProviderExecuted: false,
587 }
588 }
589 }(i, toolCall)
590 }
591
592 // Wait for all tool executions to complete
593 wg.Wait()
594
595 return results, toolExecutionError
596}
597
598// Stream implements Agent.
599func (a *agent) Stream(ctx context.Context, opts AgentCall) (StreamResponse, error) {
600 // TODO: implement the agentic stuff
601 panic("not implemented")
602}
603
604func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string, disableAllTools bool) []Tool {
605 var preparedTools []Tool
606
607 // If explicitly disabling all tools, return no tools
608 if disableAllTools {
609 return preparedTools
610 }
611
612 for _, tool := range tools {
613 // If activeTools has items, only include tools in the list
614 // If activeTools is empty, include all tools
615 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
616 continue
617 }
618 info := tool.Info()
619 preparedTools = append(preparedTools, FunctionTool{
620 Name: info.Name,
621 Description: info.Description,
622 InputSchema: map[string]any{
623 "type": "object",
624 "properties": info.Parameters,
625 "required": info.Required,
626 },
627 })
628 }
629 return preparedTools
630}
631
632// validateAndRepairToolCall validates a tool call and attempts repair if validation fails
633func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []tools.BaseTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
634 if err := a.validateToolCall(toolCall, availableTools); err == nil {
635 return toolCall
636 } else {
637 if repairFunc != nil {
638 repairOptions := ToolCallRepairOptions{
639 OriginalToolCall: toolCall,
640 ValidationError: err,
641 AvailableTools: availableTools,
642 SystemPrompt: systemPrompt,
643 Messages: messages,
644 }
645
646 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
647 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
648 return *repairedToolCall
649 }
650 }
651 }
652
653 invalidToolCall := toolCall
654 invalidToolCall.Invalid = true
655 invalidToolCall.ValidationError = err
656 return invalidToolCall
657 }
658}
659
660// validateToolCall validates a tool call against available tools and their schemas
661func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []tools.BaseTool) error {
662 var tool tools.BaseTool
663 for _, t := range availableTools {
664 if t.Info().Name == toolCall.ToolName {
665 tool = t
666 break
667 }
668 }
669
670 if tool == nil {
671 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
672 }
673
674 // Validate JSON parsing
675 var input map[string]any
676 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
677 return fmt.Errorf("invalid JSON input: %w", err)
678 }
679
680 // Basic schema validation (check required fields)
681 // TODO: more robust schema validation using JSON Schema or similar
682 toolInfo := tool.Info()
683 for _, required := range toolInfo.Required {
684 if _, exists := input[required]; !exists {
685 return fmt.Errorf("missing required parameter: %s", required)
686 }
687 }
688 return nil
689}
690
691func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
692 if prompt == "" {
693 return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
694 }
695
696 var preparedPrompt Prompt
697
698 if system != "" {
699 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
700 }
701
702 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
703 preparedPrompt = append(preparedPrompt, messages...)
704 return preparedPrompt, nil
705}
706
707func WithSystemPrompt(prompt string) agentOption {
708 return func(s *AgentSettings) {
709 s.systemPrompt = prompt
710 }
711}
712
713func WithMaxOutputTokens(tokens int64) agentOption {
714 return func(s *AgentSettings) {
715 s.maxOutputTokens = &tokens
716 }
717}
718
719func WithTemperature(temp float64) agentOption {
720 return func(s *AgentSettings) {
721 s.temperature = &temp
722 }
723}
724
725func WithTopP(topP float64) agentOption {
726 return func(s *AgentSettings) {
727 s.topP = &topP
728 }
729}
730
731func WithTopK(topK int64) agentOption {
732 return func(s *AgentSettings) {
733 s.topK = &topK
734 }
735}
736
737func WithPresencePenalty(penalty float64) agentOption {
738 return func(s *AgentSettings) {
739 s.presencePenalty = &penalty
740 }
741}
742
743func WithFrequencyPenalty(penalty float64) agentOption {
744 return func(s *AgentSettings) {
745 s.frequencyPenalty = &penalty
746 }
747}
748
749func WithTools(tools ...tools.BaseTool) agentOption {
750 return func(s *AgentSettings) {
751 s.tools = append(s.tools, tools...)
752 }
753}
754
755func WithStopConditions(conditions ...StopCondition) agentOption {
756 return func(s *AgentSettings) {
757 s.stopWhen = append(s.stopWhen, conditions...)
758 }
759}
760
761func WithPrepareStep(fn PrepareStepFunction) agentOption {
762 return func(s *AgentSettings) {
763 s.prepareStep = fn
764 }
765}
766
767func WithRepairToolCall(fn RepairToolCallFunction) agentOption {
768 return func(s *AgentSettings) {
769 s.repairToolCall = fn
770 }
771}
772
773func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
774 return func(s *AgentSettings) {
775 s.onStepFinished = fn
776 }
777}
778
779func WithHeaders(headers map[string]string) agentOption {
780 return func(s *AgentSettings) {
781 s.headers = headers
782 }
783}
784
785func WithProviderOptions(providerOptions ProviderOptions) agentOption {
786 return func(s *AgentSettings) {
787 s.providerOptions = providerOptions
788 }
789}