1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sync"
11
12 "github.com/charmbracelet/crush/internal/llm/tools"
13)
14
// StepResult is the outcome of a single agent step: the model's response for
// that step (its Content holds the validated tool calls and, after them, any
// tool results) together with the conversation messages derived from it.
type StepResult struct {
	Response
	// Messages are the messages produced by this step, built from the step's
	// content by toResponseMessages.
	Messages []Message
}

// StopCondition decides, given every step completed so far, whether the agent
// loop should stop. Generate and Stream stop as soon as any configured
// condition returns true.
type StopCondition = func(steps []StepResult) bool
21
22// StepCountIs returns a stop condition that stops after the specified number of steps.
23func StepCountIs(stepCount int) StopCondition {
24 return func(steps []StepResult) bool {
25 return len(steps) >= stepCount
26 }
27}
28
29// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
30func HasToolCall(toolName string) StopCondition {
31 return func(steps []StepResult) bool {
32 if len(steps) == 0 {
33 return false
34 }
35 lastStep := steps[len(steps)-1]
36 toolCalls := lastStep.Content.ToolCalls()
37 for _, toolCall := range toolCalls {
38 if toolCall.ToolName == toolName {
39 return true
40 }
41 }
42 return false
43 }
44}
45
46// HasContent returns a stop condition that stops when the specified content type appears in the last step.
47func HasContent(contentType ContentType) StopCondition {
48 return func(steps []StepResult) bool {
49 if len(steps) == 0 {
50 return false
51 }
52 lastStep := steps[len(steps)-1]
53 for _, content := range lastStep.Content {
54 if content.GetType() == contentType {
55 return true
56 }
57 }
58 return false
59 }
60}
61
62// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
63func FinishReasonIs(reason FinishReason) StopCondition {
64 return func(steps []StepResult) bool {
65 if len(steps) == 0 {
66 return false
67 }
68 lastStep := steps[len(steps)-1]
69 return lastStep.FinishReason == reason
70 }
71}
72
73// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
74func MaxTokensUsed(maxTokens int64) StopCondition {
75 return func(steps []StepResult) bool {
76 var totalTokens int64
77 for _, step := range steps {
78 totalTokens += step.Usage.TotalTokens
79 }
80 return totalTokens >= maxTokens
81 }
82}
83
// PrepareStepFunctionOptions is the input to a PrepareStepFunction, which is
// invoked before every step of Generate and Stream.
type PrepareStepFunctionOptions struct {
	// Steps are the steps completed so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run.
	StepNumber int
	// Model is the agent's configured model.
	Model LanguageModel
	// Messages are the input messages the step will send unless overridden.
	Messages []Message
}

// PrepareStepResult lets a PrepareStepFunction override settings for a single
// step. Zero values keep the default: a nil Model, Messages, System or
// ToolChoice and an empty ActiveTools are ignored. ToolChoice otherwise
// defaults to ToolChoiceAuto. A System different from the agent's system
// prompt replaces the system message for this step. DisableAllTools sends no
// tools to the model for this step.
type PrepareStepResult struct {
	Model           LanguageModel
	Messages        []Message
	System          *string
	ToolChoice      *ToolChoice
	ActiveTools     []string
	DisableAllTools bool
}

// ToolCallRepairOptions is the input to a RepairToolCallFunction.
//
// OriginalToolCall is the call that failed validation and ValidationError
// explains why. AvailableTools are the agent's tools. SystemPrompt and Messages
// give conversational context. Note that Stream passes the agent's base system
// prompt and nil Messages.
type ToolCallRepairOptions struct {
	OriginalToolCall ToolCallContent
	ValidationError  error
	AvailableTools   []tools.BaseTool
	SystemPrompt     string
	Messages         []Message
}

type (
	// PrepareStepFunction is called before each step to adjust that step's
	// model, messages, system prompt, tool choice and active tools.
	PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult
	// OnStepFinishedFunction is called with each completed step.
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCallFunction may return a corrected version of an invalid tool
	// call. The replacement is used only if the returned error is nil, the
	// call is non-nil, and it passes validation.
	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
113
// AgentSettings holds agent-level configuration. Its fields are unexported and
// are populated by the option functions passed to NewAgent (WithSystemPrompt,
// WithTools, ...). systemPrompt is used for every call (PrepareStep may
// override it per step). Most other fields act as defaults that prepareCall
// applies when a call leaves the corresponding field unset.
//
// NOTE(review): no option in this file sets maxRetries or onRetry; they may be
// set elsewhere in the package — confirm.
type AgentSettings struct {
	// headers and providerOptions are merged with the call's own values, and
	// the call's entries win on conflicting keys.
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	headers          map[string]string
	providerOptions  ProviderOptions

	// tools are offered to the model (subject to active-tool filtering) and
	// executed by the agent when the model calls them.
	// TODO: add support for provider tools
	tools      []tools.BaseTool
	maxRetries *int

	// model is the default model; PrepareStep may substitute another per step.
	model LanguageModel

	// Loop hooks, used when a call does not provide its own.
	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCallFunction
	onStepFinished OnStepFinishedFunction
	onRetry        OnRetryCallback
}
137
// AgentCall holds the per-call options for Generate. Stream builds one from
// its AgentStreamCall.
//
// Nil pointer and func fields, and an empty StopWhen, fall back to the
// agent-level settings via prepareCall. Headers and ProviderOptions are merged
// with the agent's, and the call's entries win on conflicting keys. ActiveTools
// restricts the tools offered to the model; when it is empty, all of the
// agent's tools are offered.
//
// Prompt must be non-empty. Files are attached to the user message built from
// Prompt, and Messages carry additional conversation messages.
type AgentCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	// Loop hooks; nil/empty values fall back to the agent's settings.
	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction
	OnStepFinished OnStepFinishedFunction
}
159
// AgentStreamCall holds the per-call options and callbacks for Stream.
//
// The model and loop options mirror AgentCall. Stream copies them into an
// AgentCall and merges the agent-level defaults into it via prepareCall.
//
// Every callback is optional; nil callbacks are skipped. OnToolResult is
// invoked from the goroutines that execute tools (see executeTools), so
// concurrent invocations are possible. After a successful run, OnFinish and
// then OnAgentFinish are called with the same result.
//
// NOTE(review): Stream does not retry model calls, so OnRetry and MaxRetries
// appear to have no effect here — confirm.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCallFunction

	// Agent-level callbacks

	// OnAgentStart is called once, before the first step.
	OnAgentStart func()
	// OnAgentFinish is called when the agent finishes successfully.
	OnAgentFinish func(result *AgentResult)
	// OnStepStart is called with the zero-based step number before each model call.
	OnStepStart func(stepNumber int)
	// OnStepFinish is called with each completed step.
	OnStepFinish func(stepResult StepResult)
	// OnFinish is called when the entire agent run completes successfully.
	OnFinish func(result *AgentResult)
	// OnError is called with errors from the model's Stream call or from
	// processing a step.
	OnError func(error)

	// Stream part callbacks - called for each corresponding stream part type

	// OnChunk is called for every stream part (catch-all), before the
	// type-specific callback.
	OnChunk func(StreamPart)
	// OnWarnings is called for warnings.
	OnWarnings func(warnings []CallWarning)
	// OnTextStart is called when text starts.
	OnTextStart func(id string)
	// OnTextDelta is called for text deltas.
	OnTextDelta func(id, text string)
	// OnTextEnd is called when text ends.
	OnTextEnd func(id string)
	// OnReasoningStart is called when reasoning starts.
	OnReasoningStart func(id string)
	// OnReasoningDelta is called for reasoning deltas.
	OnReasoningDelta func(id, text string)
	// OnReasoningEnd is called when reasoning ends.
	OnReasoningEnd func(id string)
	// OnToolInputStart is called when tool input starts.
	OnToolInputStart func(id, toolName string)
	// OnToolInputDelta is called for tool input deltas.
	OnToolInputDelta func(id, delta string)
	// OnToolInputEnd is called when tool input ends.
	OnToolInputEnd func(id string)
	// OnToolCall is called with each complete tool call after validation. The
	// call may be marked Invalid.
	OnToolCall func(toolCall ToolCallContent)
	// OnToolResult is called when a tool execution completes.
	OnToolResult func(result ToolResultContent)
	// OnSource is called for source references.
	OnSource func(source SourceContent)
	// OnStreamFinish is called when the model stream finishes.
	OnStreamFinish func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions)
	// OnStreamError is called when the stream emits an error part.
	OnStreamError func(error)
}
206
// AgentResult is the outcome of a Generate or Stream run.
type AgentResult struct {
	// Steps are all steps that were run, in order.
	Steps []StepResult
	// Response is the final response, i.e. the last step's response.
	Response Response
	// TotalUsage is the token usage summed over all steps.
	TotalUsage Usage
}

// Agent runs a multi-step loop against a language model, executing the tools
// the model calls between steps. Generate works on complete responses; Stream
// consumes streamed responses and reports progress through callbacks.
type Agent interface {
	Generate(context.Context, AgentCall) (*AgentResult, error)
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}

// agentOption mutates AgentSettings; options are passed to NewAgent.
type agentOption = func(*AgentSettings)

// agent is the Agent implementation returned by NewAgent.
type agent struct {
	settings AgentSettings
}
224
225func NewAgent(model LanguageModel, opts ...agentOption) Agent {
226 settings := AgentSettings{
227 model: model,
228 }
229 for _, o := range opts {
230 o(&settings)
231 }
232 return &agent{
233 settings: settings,
234 }
235}
236
237func (a *agent) prepareCall(call AgentCall) AgentCall {
238 if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
239 call.MaxOutputTokens = a.settings.maxOutputTokens
240 }
241 if call.Temperature == nil && a.settings.temperature != nil {
242 call.Temperature = a.settings.temperature
243 }
244 if call.TopP == nil && a.settings.topP != nil {
245 call.TopP = a.settings.topP
246 }
247 if call.TopK == nil && a.settings.topK != nil {
248 call.TopK = a.settings.topK
249 }
250 if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
251 call.PresencePenalty = a.settings.presencePenalty
252 }
253 if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
254 call.FrequencyPenalty = a.settings.frequencyPenalty
255 }
256 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
257 call.StopWhen = a.settings.stopWhen
258 }
259 if call.PrepareStep == nil && a.settings.prepareStep != nil {
260 call.PrepareStep = a.settings.prepareStep
261 }
262 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
263 call.RepairToolCall = a.settings.repairToolCall
264 }
265 if call.OnStepFinished == nil && a.settings.onStepFinished != nil {
266 call.OnStepFinished = a.settings.onStepFinished
267 }
268 if call.OnRetry == nil && a.settings.onRetry != nil {
269 call.OnRetry = a.settings.onRetry
270 }
271 if call.MaxRetries == nil && a.settings.maxRetries != nil {
272 call.MaxRetries = a.settings.maxRetries
273 }
274
275 providerOptions := ProviderOptions{}
276 if a.settings.providerOptions != nil {
277 maps.Copy(providerOptions, a.settings.providerOptions)
278 }
279 if call.ProviderOptions != nil {
280 maps.Copy(providerOptions, call.ProviderOptions)
281 }
282 call.ProviderOptions = providerOptions
283
284 headers := map[string]string{}
285
286 if a.settings.headers != nil {
287 maps.Copy(headers, a.settings.headers)
288 }
289
290 if call.Headers != nil {
291 maps.Copy(headers, call.Headers)
292 }
293 call.Headers = headers
294 return call
295}
296
297// Generate implements Agent.
298func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
299 opts = a.prepareCall(opts)
300 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
301 if err != nil {
302 return nil, err
303 }
304 var responseMessages []Message
305 var steps []StepResult
306
307 for {
308 stepInputMessages := append(initialPrompt, responseMessages...)
309 stepModel := a.settings.model
310 stepSystemPrompt := a.settings.systemPrompt
311 stepActiveTools := opts.ActiveTools
312 stepToolChoice := ToolChoiceAuto
313 disableAllTools := false
314
315 if opts.PrepareStep != nil {
316 prepared := opts.PrepareStep(PrepareStepFunctionOptions{
317 Model: stepModel,
318 Steps: steps,
319 StepNumber: len(steps),
320 Messages: stepInputMessages,
321 })
322
323 // Apply prepared step modifications
324 if prepared.Messages != nil {
325 stepInputMessages = prepared.Messages
326 }
327 if prepared.Model != nil {
328 stepModel = prepared.Model
329 }
330 if prepared.System != nil {
331 stepSystemPrompt = *prepared.System
332 }
333 if prepared.ToolChoice != nil {
334 stepToolChoice = *prepared.ToolChoice
335 }
336 if len(prepared.ActiveTools) > 0 {
337 stepActiveTools = prepared.ActiveTools
338 }
339 disableAllTools = prepared.DisableAllTools
340 }
341
342 // Recreate prompt with potentially modified system prompt
343 if stepSystemPrompt != a.settings.systemPrompt {
344 stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
345 if err != nil {
346 return nil, err
347 }
348 // Replace system message part, keep the rest
349 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
350 stepInputMessages[0] = stepPrompt[0] // Replace system message
351 }
352 }
353
354 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
355
356 retryOptions := DefaultRetryOptions()
357 retryOptions.OnRetry = opts.OnRetry
358 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
359
360 result, err := retry(ctx, func() (*Response, error) {
361 return stepModel.Generate(ctx, Call{
362 Prompt: stepInputMessages,
363 MaxOutputTokens: opts.MaxOutputTokens,
364 Temperature: opts.Temperature,
365 TopP: opts.TopP,
366 TopK: opts.TopK,
367 PresencePenalty: opts.PresencePenalty,
368 FrequencyPenalty: opts.FrequencyPenalty,
369 Tools: preparedTools,
370 ToolChoice: &stepToolChoice,
371 Headers: opts.Headers,
372 ProviderOptions: opts.ProviderOptions,
373 })
374 })
375 if err != nil {
376 return nil, err
377 }
378
379 var stepToolCalls []ToolCallContent
380 for _, content := range result.Content {
381 if content.GetType() == ContentTypeToolCall {
382 toolCall, ok := AsContentType[ToolCallContent](content)
383 if !ok {
384 continue
385 }
386
387 // Validate and potentially repair the tool call
388 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
389 stepToolCalls = append(stepToolCalls, validatedToolCall)
390 }
391 }
392
393 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
394
395 // Build step content with validated tool calls and tool results
396 stepContent := []Content{}
397 toolCallIndex := 0
398 for _, content := range result.Content {
399 if content.GetType() == ContentTypeToolCall {
400 // Replace with validated tool call
401 if toolCallIndex < len(stepToolCalls) {
402 stepContent = append(stepContent, stepToolCalls[toolCallIndex])
403 toolCallIndex++
404 }
405 } else {
406 // Keep other content as-is
407 stepContent = append(stepContent, content)
408 }
409 }
410 // Add tool results
411 for _, result := range toolResults {
412 stepContent = append(stepContent, result)
413 }
414 currentStepMessages := toResponseMessages(stepContent)
415 responseMessages = append(responseMessages, currentStepMessages...)
416
417 stepResult := StepResult{
418 Response: Response{
419 Content: stepContent,
420 FinishReason: result.FinishReason,
421 Usage: result.Usage,
422 Warnings: result.Warnings,
423 ProviderMetadata: result.ProviderMetadata,
424 },
425 Messages: currentStepMessages,
426 }
427 steps = append(steps, stepResult)
428 if opts.OnStepFinished != nil {
429 opts.OnStepFinished(stepResult)
430 }
431
432 shouldStop := isStopConditionMet(opts.StopWhen, steps)
433
434 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
435 break
436 }
437 }
438
439 totalUsage := Usage{}
440
441 for _, step := range steps {
442 usage := step.Usage
443 totalUsage.InputTokens += usage.InputTokens
444 totalUsage.OutputTokens += usage.OutputTokens
445 totalUsage.ReasoningTokens += usage.ReasoningTokens
446 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
447 totalUsage.CacheReadTokens += usage.CacheReadTokens
448 totalUsage.TotalTokens += usage.TotalTokens
449 }
450
451 agentResult := &AgentResult{
452 Steps: steps,
453 Response: steps[len(steps)-1].Response,
454 TotalUsage: totalUsage,
455 }
456 return agentResult, nil
457}
458
459func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
460 if len(conditions) == 0 {
461 return false
462 }
463
464 for _, condition := range conditions {
465 if condition(steps) {
466 return true
467 }
468 }
469 return false
470}
471
472func toResponseMessages(content []Content) []Message {
473 var assistantParts []MessagePart
474 var toolParts []MessagePart
475
476 for _, c := range content {
477 switch c.GetType() {
478 case ContentTypeText:
479 text, ok := AsContentType[TextContent](c)
480 if !ok {
481 continue
482 }
483 assistantParts = append(assistantParts, TextPart{
484 Text: text.Text,
485 ProviderOptions: ProviderOptions(text.ProviderMetadata),
486 })
487 case ContentTypeReasoning:
488 reasoning, ok := AsContentType[ReasoningContent](c)
489 if !ok {
490 continue
491 }
492 assistantParts = append(assistantParts, ReasoningPart{
493 Text: reasoning.Text,
494 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
495 })
496 case ContentTypeToolCall:
497 toolCall, ok := AsContentType[ToolCallContent](c)
498 if !ok {
499 continue
500 }
501 assistantParts = append(assistantParts, ToolCallPart{
502 ToolCallID: toolCall.ToolCallID,
503 ToolName: toolCall.ToolName,
504 Input: toolCall.Input,
505 ProviderExecuted: toolCall.ProviderExecuted,
506 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
507 })
508 case ContentTypeFile:
509 file, ok := AsContentType[FileContent](c)
510 if !ok {
511 continue
512 }
513 assistantParts = append(assistantParts, FilePart{
514 Data: file.Data,
515 MediaType: file.MediaType,
516 ProviderOptions: ProviderOptions(file.ProviderMetadata),
517 })
518 case ContentTypeSource:
519 // Sources are metadata about references used to generate the response.
520 // They don't need to be included in the conversation messages.
521 continue
522 case ContentTypeToolResult:
523 result, ok := AsContentType[ToolResultContent](c)
524 if !ok {
525 continue
526 }
527 toolParts = append(toolParts, ToolResultPart{
528 ToolCallID: result.ToolCallID,
529 Output: result.Result,
530 ProviderOptions: ProviderOptions(result.ProviderMetadata),
531 })
532 }
533 }
534
535 var messages []Message
536 if len(assistantParts) > 0 {
537 messages = append(messages, Message{
538 Role: MessageRoleAssistant,
539 Content: assistantParts,
540 })
541 }
542 if len(toolParts) > 0 {
543 messages = append(messages, Message{
544 Role: MessageRoleTool,
545 Content: toolParts,
546 })
547 }
548 return messages
549}
550
551func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) {
552 if len(toolCalls) == 0 {
553 return nil, nil
554 }
555
556 // Create a map for quick tool lookup
557 toolMap := make(map[string]tools.BaseTool)
558 for _, tool := range allTools {
559 toolMap[tool.Info().Name] = tool
560 }
561
562 // Execute all tool calls in parallel
563 results := make([]ToolResultContent, len(toolCalls))
564 var toolExecutionError error
565 var wg sync.WaitGroup
566
567 for i, toolCall := range toolCalls {
568 wg.Add(1)
569 go func(index int, call ToolCallContent) {
570 defer wg.Done()
571
572 // Skip invalid tool calls - create error result
573 if call.Invalid {
574 results[index] = ToolResultContent{
575 ToolCallID: call.ToolCallID,
576 ToolName: call.ToolName,
577 Result: ToolResultOutputContentError{
578 Error: call.ValidationError,
579 },
580 ProviderExecuted: false,
581 }
582 if toolResultCallback != nil {
583 toolResultCallback(results[index])
584 }
585
586 return
587 }
588
589 tool, exists := toolMap[call.ToolName]
590 if !exists {
591 results[index] = ToolResultContent{
592 ToolCallID: call.ToolCallID,
593 ToolName: call.ToolName,
594 Result: ToolResultOutputContentError{
595 Error: errors.New("Error: Tool not found: " + call.ToolName),
596 },
597 ProviderExecuted: false,
598 }
599
600 if toolResultCallback != nil {
601 toolResultCallback(results[index])
602 }
603 return
604 }
605
606 // Execute the tool
607 result, err := tool.Run(ctx, tools.ToolCall{
608 ID: call.ToolCallID,
609 Name: call.ToolName,
610 Input: call.Input,
611 })
612 if err != nil {
613 results[index] = ToolResultContent{
614 ToolCallID: call.ToolCallID,
615 ToolName: call.ToolName,
616 Result: ToolResultOutputContentError{
617 Error: err,
618 },
619 ProviderExecuted: false,
620 }
621 if toolResultCallback != nil {
622 toolResultCallback(results[index])
623 }
624 toolExecutionError = err
625 return
626 }
627
628 if result.IsError {
629 results[index] = ToolResultContent{
630 ToolCallID: call.ToolCallID,
631 ToolName: call.ToolName,
632 Result: ToolResultOutputContentError{
633 Error: errors.New(result.Content),
634 },
635 ProviderExecuted: false,
636 }
637
638 if toolResultCallback != nil {
639 toolResultCallback(results[index])
640 }
641 } else {
642 results[index] = ToolResultContent{
643 ToolCallID: call.ToolCallID,
644 ToolName: toolCall.ToolName,
645 Result: ToolResultOutputContentText{
646 Text: result.Content,
647 },
648 ProviderExecuted: false,
649 }
650 if toolResultCallback != nil {
651 toolResultCallback(results[index])
652 }
653 }
654 }(i, toolCall)
655 }
656
657 // Wait for all tool executions to complete
658 wg.Wait()
659
660 return results, toolExecutionError
661}
662
663// Stream implements Agent.
664func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
665 // Convert AgentStreamCall to AgentCall for preparation
666 call := AgentCall{
667 Prompt: opts.Prompt,
668 Files: opts.Files,
669 Messages: opts.Messages,
670 MaxOutputTokens: opts.MaxOutputTokens,
671 Temperature: opts.Temperature,
672 TopP: opts.TopP,
673 TopK: opts.TopK,
674 PresencePenalty: opts.PresencePenalty,
675 FrequencyPenalty: opts.FrequencyPenalty,
676 ActiveTools: opts.ActiveTools,
677 Headers: opts.Headers,
678 ProviderOptions: opts.ProviderOptions,
679 MaxRetries: opts.MaxRetries,
680 StopWhen: opts.StopWhen,
681 PrepareStep: opts.PrepareStep,
682 RepairToolCall: opts.RepairToolCall,
683 }
684
685 call = a.prepareCall(call)
686
687 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
688 if err != nil {
689 return nil, err
690 }
691
692 var responseMessages []Message
693 var steps []StepResult
694 var totalUsage Usage
695
696 // Start agent stream
697 if opts.OnAgentStart != nil {
698 opts.OnAgentStart()
699 }
700
701 for stepNumber := 0; ; stepNumber++ {
702 stepInputMessages := append(initialPrompt, responseMessages...)
703 stepModel := a.settings.model
704 stepSystemPrompt := a.settings.systemPrompt
705 stepActiveTools := call.ActiveTools
706 stepToolChoice := ToolChoiceAuto
707 disableAllTools := false
708
709 // Apply step preparation if provided
710 if call.PrepareStep != nil {
711 prepared := call.PrepareStep(PrepareStepFunctionOptions{
712 Model: stepModel,
713 Steps: steps,
714 StepNumber: stepNumber,
715 Messages: stepInputMessages,
716 })
717
718 if prepared.Messages != nil {
719 stepInputMessages = prepared.Messages
720 }
721 if prepared.Model != nil {
722 stepModel = prepared.Model
723 }
724 if prepared.System != nil {
725 stepSystemPrompt = *prepared.System
726 }
727 if prepared.ToolChoice != nil {
728 stepToolChoice = *prepared.ToolChoice
729 }
730 if len(prepared.ActiveTools) > 0 {
731 stepActiveTools = prepared.ActiveTools
732 }
733 disableAllTools = prepared.DisableAllTools
734 }
735
736 // Recreate prompt with potentially modified system prompt
737 if stepSystemPrompt != a.settings.systemPrompt {
738 stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
739 if err != nil {
740 return nil, err
741 }
742 if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
743 stepInputMessages[0] = stepPrompt[0]
744 }
745 }
746
747 preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
748
749 // Start step stream
750 if opts.OnStepStart != nil {
751 opts.OnStepStart(stepNumber)
752 }
753
754 // Create streaming call
755 streamCall := Call{
756 Prompt: stepInputMessages,
757 MaxOutputTokens: call.MaxOutputTokens,
758 Temperature: call.Temperature,
759 TopP: call.TopP,
760 TopK: call.TopK,
761 PresencePenalty: call.PresencePenalty,
762 FrequencyPenalty: call.FrequencyPenalty,
763 Tools: preparedTools,
764 ToolChoice: &stepToolChoice,
765 Headers: call.Headers,
766 ProviderOptions: call.ProviderOptions,
767 }
768
769 // Get streaming response
770 stream, err := stepModel.Stream(ctx, streamCall)
771 if err != nil {
772 if opts.OnError != nil {
773 opts.OnError(err)
774 }
775 return nil, err
776 }
777
778 // Process stream with tool execution
779 stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
780 if err != nil {
781 if opts.OnError != nil {
782 opts.OnError(err)
783 }
784 return nil, err
785 }
786
787 steps = append(steps, stepResult)
788 totalUsage = addUsage(totalUsage, stepResult.Usage)
789
790 // Call step finished callback
791 if opts.OnStepFinish != nil {
792 opts.OnStepFinish(stepResult)
793 }
794
795 // Add step messages to response messages
796 stepMessages := toResponseMessages(stepResult.Content)
797 responseMessages = append(responseMessages, stepMessages...)
798
799 // Check stop conditions
800 shouldStop := isStopConditionMet(call.StopWhen, steps)
801 if shouldStop || !shouldContinue {
802 break
803 }
804 }
805
806 // Finish agent stream
807 agentResult := &AgentResult{
808 Steps: steps,
809 Response: steps[len(steps)-1].Response,
810 TotalUsage: totalUsage,
811 }
812
813 if opts.OnFinish != nil {
814 opts.OnFinish(agentResult)
815 }
816
817 if opts.OnAgentFinish != nil {
818 opts.OnAgentFinish(agentResult)
819 }
820
821 return agentResult, nil
822}
823
824func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string, disableAllTools bool) []Tool {
825 var preparedTools []Tool
826
827 // If explicitly disabling all tools, return no tools
828 if disableAllTools {
829 return preparedTools
830 }
831
832 for _, tool := range tools {
833 // If activeTools has items, only include tools in the list
834 // If activeTools is empty, include all tools
835 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
836 continue
837 }
838 info := tool.Info()
839 preparedTools = append(preparedTools, FunctionTool{
840 Name: info.Name,
841 Description: info.Description,
842 InputSchema: map[string]any{
843 "type": "object",
844 "properties": info.Parameters,
845 "required": info.Required,
846 },
847 })
848 }
849 return preparedTools
850}
851
852// validateAndRepairToolCall validates a tool call and attempts repair if validation fails
853func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []tools.BaseTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
854 if err := a.validateToolCall(toolCall, availableTools); err == nil {
855 return toolCall
856 } else {
857 if repairFunc != nil {
858 repairOptions := ToolCallRepairOptions{
859 OriginalToolCall: toolCall,
860 ValidationError: err,
861 AvailableTools: availableTools,
862 SystemPrompt: systemPrompt,
863 Messages: messages,
864 }
865
866 if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
867 if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
868 return *repairedToolCall
869 }
870 }
871 }
872
873 invalidToolCall := toolCall
874 invalidToolCall.Invalid = true
875 invalidToolCall.ValidationError = err
876 return invalidToolCall
877 }
878}
879
880// validateToolCall validates a tool call against available tools and their schemas
881func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []tools.BaseTool) error {
882 var tool tools.BaseTool
883 for _, t := range availableTools {
884 if t.Info().Name == toolCall.ToolName {
885 tool = t
886 break
887 }
888 }
889
890 if tool == nil {
891 return fmt.Errorf("tool not found: %s", toolCall.ToolName)
892 }
893
894 // Validate JSON parsing
895 var input map[string]any
896 if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
897 return fmt.Errorf("invalid JSON input: %w", err)
898 }
899
900 // Basic schema validation (check required fields)
901 // TODO: more robust schema validation using JSON Schema or similar
902 toolInfo := tool.Info()
903 for _, required := range toolInfo.Required {
904 if _, exists := input[required]; !exists {
905 return fmt.Errorf("missing required parameter: %s", required)
906 }
907 }
908 return nil
909}
910
911func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
912 if prompt == "" {
913 return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
914 }
915
916 var preparedPrompt Prompt
917
918 if system != "" {
919 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
920 }
921
922 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
923 preparedPrompt = append(preparedPrompt, messages...)
924 return preparedPrompt, nil
925}
926
927func WithSystemPrompt(prompt string) agentOption {
928 return func(s *AgentSettings) {
929 s.systemPrompt = prompt
930 }
931}
932
933func WithMaxOutputTokens(tokens int64) agentOption {
934 return func(s *AgentSettings) {
935 s.maxOutputTokens = &tokens
936 }
937}
938
939func WithTemperature(temp float64) agentOption {
940 return func(s *AgentSettings) {
941 s.temperature = &temp
942 }
943}
944
945func WithTopP(topP float64) agentOption {
946 return func(s *AgentSettings) {
947 s.topP = &topP
948 }
949}
950
951func WithTopK(topK int64) agentOption {
952 return func(s *AgentSettings) {
953 s.topK = &topK
954 }
955}
956
957func WithPresencePenalty(penalty float64) agentOption {
958 return func(s *AgentSettings) {
959 s.presencePenalty = &penalty
960 }
961}
962
963func WithFrequencyPenalty(penalty float64) agentOption {
964 return func(s *AgentSettings) {
965 s.frequencyPenalty = &penalty
966 }
967}
968
969func WithTools(tools ...tools.BaseTool) agentOption {
970 return func(s *AgentSettings) {
971 s.tools = append(s.tools, tools...)
972 }
973}
974
975func WithStopConditions(conditions ...StopCondition) agentOption {
976 return func(s *AgentSettings) {
977 s.stopWhen = append(s.stopWhen, conditions...)
978 }
979}
980
981func WithPrepareStep(fn PrepareStepFunction) agentOption {
982 return func(s *AgentSettings) {
983 s.prepareStep = fn
984 }
985}
986
987func WithRepairToolCall(fn RepairToolCallFunction) agentOption {
988 return func(s *AgentSettings) {
989 s.repairToolCall = fn
990 }
991}
992
993func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
994 return func(s *AgentSettings) {
995 s.onStepFinished = fn
996 }
997}
998
999// processStepStream processes a single step's stream and returns the step result
1000func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1001 var stepContent []Content
1002 var stepToolCalls []ToolCallContent
1003 var stepUsage Usage
1004 var stepFinishReason FinishReason = FinishReasonUnknown
1005 var stepWarnings []CallWarning
1006 var stepProviderMetadata ProviderMetadata
1007
1008 activeToolCalls := make(map[string]*ToolCallContent)
1009 activeTextContent := make(map[string]string)
1010
1011 // Process stream parts
1012 for part := range stream {
1013 // Forward all parts to chunk callback
1014 if opts.OnChunk != nil {
1015 opts.OnChunk(part)
1016 }
1017
1018 switch part.Type {
1019 case StreamPartTypeWarnings:
1020 stepWarnings = part.Warnings
1021 if opts.OnWarnings != nil {
1022 opts.OnWarnings(part.Warnings)
1023 }
1024
1025 case StreamPartTypeTextStart:
1026 activeTextContent[part.ID] = ""
1027 if opts.OnTextStart != nil {
1028 opts.OnTextStart(part.ID)
1029 }
1030
1031 case StreamPartTypeTextDelta:
1032 if _, exists := activeTextContent[part.ID]; exists {
1033 activeTextContent[part.ID] += part.Delta
1034 }
1035 if opts.OnTextDelta != nil {
1036 opts.OnTextDelta(part.ID, part.Delta)
1037 }
1038
1039 case StreamPartTypeTextEnd:
1040 if text, exists := activeTextContent[part.ID]; exists {
1041 stepContent = append(stepContent, TextContent{
1042 Text: text,
1043 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1044 })
1045 delete(activeTextContent, part.ID)
1046 }
1047 if opts.OnTextEnd != nil {
1048 opts.OnTextEnd(part.ID)
1049 }
1050
1051 case StreamPartTypeReasoningStart:
1052 activeTextContent[part.ID] = ""
1053 if opts.OnReasoningStart != nil {
1054 opts.OnReasoningStart(part.ID)
1055 }
1056
1057 case StreamPartTypeReasoningDelta:
1058 if _, exists := activeTextContent[part.ID]; exists {
1059 activeTextContent[part.ID] += part.Delta
1060 }
1061 if opts.OnReasoningDelta != nil {
1062 opts.OnReasoningDelta(part.ID, part.Delta)
1063 }
1064
1065 case StreamPartTypeReasoningEnd:
1066 if text, exists := activeTextContent[part.ID]; exists {
1067 stepContent = append(stepContent, ReasoningContent{
1068 Text: text,
1069 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1070 })
1071 delete(activeTextContent, part.ID)
1072 }
1073 if opts.OnReasoningEnd != nil {
1074 opts.OnReasoningEnd(part.ID)
1075 }
1076
1077 case StreamPartTypeToolInputStart:
1078 activeToolCalls[part.ID] = &ToolCallContent{
1079 ToolCallID: part.ID,
1080 ToolName: part.ToolCallName,
1081 Input: "",
1082 ProviderExecuted: part.ProviderExecuted,
1083 }
1084 if opts.OnToolInputStart != nil {
1085 opts.OnToolInputStart(part.ID, part.ToolCallName)
1086 }
1087
1088 case StreamPartTypeToolInputDelta:
1089 if toolCall, exists := activeToolCalls[part.ID]; exists {
1090 toolCall.Input += part.Delta
1091 }
1092 if opts.OnToolInputDelta != nil {
1093 opts.OnToolInputDelta(part.ID, part.Delta)
1094 }
1095
1096 case StreamPartTypeToolInputEnd:
1097 if opts.OnToolInputEnd != nil {
1098 opts.OnToolInputEnd(part.ID)
1099 }
1100
1101 case StreamPartTypeToolCall:
1102 toolCall := ToolCallContent{
1103 ToolCallID: part.ID,
1104 ToolName: part.ToolCallName,
1105 Input: part.ToolCallInput,
1106 ProviderExecuted: part.ProviderExecuted,
1107 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1108 }
1109
1110 // Validate and potentially repair the tool call
1111 validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1112 stepToolCalls = append(stepToolCalls, validatedToolCall)
1113 stepContent = append(stepContent, validatedToolCall)
1114
1115 if opts.OnToolCall != nil {
1116 opts.OnToolCall(validatedToolCall)
1117 }
1118
1119 // Clean up active tool call
1120 delete(activeToolCalls, part.ID)
1121
1122 case StreamPartTypeSource:
1123 sourceContent := SourceContent{
1124 SourceType: part.SourceType,
1125 ID: part.ID,
1126 URL: part.URL,
1127 Title: part.Title,
1128 ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
1129 }
1130 stepContent = append(stepContent, sourceContent)
1131 if opts.OnSource != nil {
1132 opts.OnSource(sourceContent)
1133 }
1134
1135 case StreamPartTypeFinish:
1136 stepUsage = part.Usage
1137 stepFinishReason = part.FinishReason
1138 stepProviderMetadata = ProviderMetadata(part.ProviderMetadata)
1139 if opts.OnStreamFinish != nil {
1140 opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
1141 }
1142
1143 case StreamPartTypeError:
1144 if opts.OnStreamError != nil {
1145 opts.OnStreamError(part.Error)
1146 }
1147 if opts.OnError != nil {
1148 opts.OnError(part.Error)
1149 }
1150 return StepResult{}, false, part.Error
1151 }
1152 }
1153
1154 // Execute tools if any
1155 var toolResults []ToolResultContent
1156 if len(stepToolCalls) > 0 {
1157 var err error
1158 toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1159 if err != nil {
1160 return StepResult{}, false, err
1161 }
1162 // Add tool results to content
1163 for _, result := range toolResults {
1164 stepContent = append(stepContent, result)
1165 }
1166 }
1167
1168 stepResult := StepResult{
1169 Response: Response{
1170 Content: stepContent,
1171 FinishReason: stepFinishReason,
1172 Usage: stepUsage,
1173 Warnings: stepWarnings,
1174 ProviderMetadata: stepProviderMetadata,
1175 },
1176 Messages: toResponseMessages(stepContent),
1177 }
1178
1179 // Determine if we should continue (has tool calls and not stopped)
1180 shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
1181
1182 return stepResult, shouldContinue, nil
1183}
1184
1185func addUsage(a, b Usage) Usage {
1186 return Usage{
1187 InputTokens: a.InputTokens + b.InputTokens,
1188 OutputTokens: a.OutputTokens + b.OutputTokens,
1189 TotalTokens: a.TotalTokens + b.TotalTokens,
1190 ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
1191 CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
1192 CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
1193 }
1194}
1195
1196func WithHeaders(headers map[string]string) agentOption {
1197 return func(s *AgentSettings) {
1198 s.headers = headers
1199 }
1200}
1201
1202func WithProviderOptions(providerOptions ProviderOptions) agentOption {
1203 return func(s *AgentSettings) {
1204 s.providerOptions = providerOptions
1205 }
1206}