agent.go

  1package ai
  2
  3import (
  4	"context"
  5	"errors"
  6	"maps"
  7	"slices"
  8	"sync"
  9
 10	"github.com/charmbracelet/crush/internal/llm/tools"
 11)
 12
// StepResult is the record of one iteration of the agent loop: the model's
// Response for that step together with the messages the step added to the
// conversation.
type StepResult struct {
	Response
	// Messages generated during this step: the assistant message built from
	// the model's content and, when tools ran, the tool-result message.
	Messages []Message
}
 18
// StopCondition reports whether the agent loop should stop, given every step
// completed so far. The loop stops after a step as soon as any configured
// condition returns true.
type StopCondition = func(steps []StepResult) bool
 20
// PrepareStepFunctionOptions is passed to a PrepareStepFunction before each
// step of the agent loop.
type PrepareStepFunctionOptions struct {
	// Steps are the steps completed so far (empty before the first step).
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run (len(Steps)).
	StepNumber int
	// Model is the agent's default model for this step.
	Model LanguageModel
	// Messages is the default step input: the initial prompt followed by every
	// message produced by earlier steps.
	Messages []Message
}
 27
// PrepareStepResult tells the agent loop how to run the next step.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the default model for this step only.
	Model LanguageModel
	// Messages is the input to send for this step. Hooks that only want to
	// change the model should return the Messages they were given.
	Messages []Message
}
 32
type (
	// PrepareStepFunction is called before every step and may replace the
	// model and the input messages used for that step.
	PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult
	// OnStepFinishedFunction is called with each StepResult after the step's
	// tool calls have been executed.
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCall maps a tool call to a replacement tool call (presumably
	// to fix malformed calls — confirm intended use).
	// NOTE(review): prepareCall propagates this hook, but Generate in this
	// file never invokes it.
	RepairToolCall = func(ToolCallContent) ToolCallContent
)
 38
// AgentSettings holds an agent's defaults. NewAgent fills it from the
// agentOption values, and prepareCall uses it to fill in whatever an
// individual AgentCall leaves unset. For the pointer and func fields, nil
// means "no agent-level default".
type AgentSettings struct {
	// systemPrompt, when non-empty, becomes the system message at the start
	// of every prompt.
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// headers and providerOptions are merged by prepareCall, with the call's
	// own entries taking precedence over these.
	headers          map[string]string
	providerOptions  ProviderOptions

	// TODO: add support for provider tools
	// tools are offered to the model and executed locally when it calls them.
	tools      []tools.BaseTool
	// NOTE(review): no With* option in this file sets maxRetries, and
	// Generate does not read the resolved value.
	maxRetries *int

	// model is the default model; a PrepareStep hook may override it per step.
	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCall
	onStepFinished OnStepFinishedFunction
	// NOTE(review): no With* option in this file sets onRetry.
	onRetry        OnRetryCallback
}
 62
// AgentCall holds the inputs and per-call overrides for one agent run. Any
// option left nil or empty falls back to the agent's default (see
// prepareCall). Headers and ProviderOptions are merged with the agent's,
// with this call's entries taking precedence.
//
// NOTE(review): encoding/json cannot encode func-typed fields (OnRetry,
// PrepareStep, ...), so consider `json:"-"` on them if AgentCall is ever
// marshaled.
type AgentCall struct {
	// Prompt is the new user message; Generate rejects an empty Prompt.
	Prompt           string     `json:"prompt"`
	// Files are attached to the user message built from Prompt.
	Files            []FilePart `json:"files"`
	// Messages are included in the initial prompt together with Prompt
	// (presumably earlier conversation history — confirm with callers).
	Messages         []Message  `json:"messages"`
	// NOTE(review): unlike its neighbours, MaxOutputTokens has no json tag.
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// ActiveTools, when non-empty, limits the offered tools to these names.
	ActiveTools      []string `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	// OnRetry is installed on the retry policy used for each model call.
	OnRetry          OnRetryCallback
	// NOTE(review): MaxRetries is resolved by prepareCall but Generate does
	// not apply it.
	MaxRetries       *int

	// StopWhen is checked after every step; any true condition ends the loop.
	StopWhen       []StopCondition
	// PrepareStep runs before every step.
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCall
	// OnStepFinished runs after every step.
	OnStepFinished OnStepFinishedFunction
}
 84
// AgentResult is what Generate returns for a completed run.
type AgentResult struct {
	// Steps holds every step in execution order.
	Steps []StepResult
	// Final response: the Response of the last step.
	Response   Response
	// TotalUsage is Generate's aggregate of every step's Usage.
	TotalUsage Usage
}
 91
// Agent runs a language model in a multi-step, tool-calling loop.
type Agent interface {
	// Generate runs the loop to completion and returns every step.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream is the streaming counterpart of Generate.
	// NOTE(review): not yet implemented by the agent in this file.
	Stream(context.Context, AgentCall) (StreamResponse, error)
}
 96
// agentOption is a functional option that mutates AgentSettings. NewAgent
// applies the options in the order they are given.
type agentOption = func(*AgentSettings)
 98
// agent is the Agent implementation returned by NewAgent; it holds the
// defaults resolved from the options passed to NewAgent.
type agent struct {
	settings AgentSettings
}
102
103func NewAgent(model LanguageModel, opts ...agentOption) Agent {
104	settings := AgentSettings{
105		model: model,
106	}
107	for _, o := range opts {
108		o(&settings)
109	}
110	return &agent{
111		settings: settings,
112	}
113}
114
115func (a *agent) prepareCall(call AgentCall) AgentCall {
116	if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
117		call.MaxOutputTokens = a.settings.maxOutputTokens
118	}
119	if call.Temperature == nil && a.settings.temperature != nil {
120		call.Temperature = a.settings.temperature
121	}
122	if call.TopP == nil && a.settings.topP != nil {
123		call.TopP = a.settings.topP
124	}
125	if call.TopK == nil && a.settings.topK != nil {
126		call.TopK = a.settings.topK
127	}
128	if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
129		call.PresencePenalty = a.settings.presencePenalty
130	}
131	if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
132		call.FrequencyPenalty = a.settings.frequencyPenalty
133	}
134	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
135		call.StopWhen = a.settings.stopWhen
136	}
137	if call.PrepareStep == nil && a.settings.prepareStep != nil {
138		call.PrepareStep = a.settings.prepareStep
139	}
140	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
141		call.RepairToolCall = a.settings.repairToolCall
142	}
143	if call.OnStepFinished == nil && a.settings.onStepFinished != nil {
144		call.OnStepFinished = a.settings.onStepFinished
145	}
146	if call.OnRetry == nil && a.settings.onRetry != nil {
147		call.OnRetry = a.settings.onRetry
148	}
149	if call.MaxRetries == nil && a.settings.maxRetries != nil {
150		call.MaxRetries = a.settings.maxRetries
151	}
152
153	providerOptions := ProviderOptions{}
154	if a.settings.providerOptions != nil {
155		maps.Copy(providerOptions, a.settings.providerOptions)
156	}
157	if call.ProviderOptions != nil {
158		maps.Copy(providerOptions, call.ProviderOptions)
159	}
160	call.ProviderOptions = providerOptions
161
162	headers := map[string]string{}
163
164	if a.settings.headers != nil {
165		maps.Copy(headers, a.settings.headers)
166	}
167
168	if call.Headers != nil {
169		maps.Copy(headers, call.Headers)
170	}
171	call.Headers = headers
172	return call
173}
174
175// Generate implements Agent.
176func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
177	opts = a.prepareCall(opts)
178	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
179	if err != nil {
180		return nil, err
181	}
182	var responseMessages []Message
183	var steps []StepResult
184
185	for {
186		stepInputMessages := append(initialPrompt, responseMessages...)
187		stepModel := a.settings.model
188		if opts.PrepareStep != nil {
189			prepared := opts.PrepareStep(PrepareStepFunctionOptions{
190				Model:      stepModel,
191				Steps:      steps,
192				StepNumber: len(steps),
193				Messages:   stepInputMessages,
194			})
195			stepInputMessages = prepared.Messages
196			if prepared.Model != nil {
197				stepModel = prepared.Model
198			}
199		}
200
201		preparedTools := a.prepareTools(a.settings.tools, opts.ActiveTools)
202
203		toolChoice := ToolChoiceAuto
204		retryOptions := DefaultRetryOptions()
205		retryOptions.OnRetry = opts.OnRetry
206		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
207
208		result, err := retry(ctx, func() (*Response, error) {
209			return stepModel.Generate(ctx, Call{
210				Prompt:           stepInputMessages,
211				MaxOutputTokens:  opts.MaxOutputTokens,
212				Temperature:      opts.Temperature,
213				TopP:             opts.TopP,
214				TopK:             opts.TopK,
215				PresencePenalty:  opts.PresencePenalty,
216				FrequencyPenalty: opts.FrequencyPenalty,
217				Tools:            preparedTools,
218				ToolChoice:       &toolChoice,
219				Headers:          opts.Headers,
220				ProviderOptions:  opts.ProviderOptions,
221			})
222		})
223		if err != nil {
224			return nil, err
225		}
226
227		var stepToolCalls []ToolCallContent
228		for _, content := range result.Content {
229			if content.GetType() == ContentTypeToolCall {
230				toolCall, ok := AsContentType[ToolCallContent](content)
231				if !ok {
232					continue
233				}
234				stepToolCalls = append(stepToolCalls, toolCall)
235			}
236		}
237
238		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls)
239
240		stepContent := result.Content
241		for _, result := range toolResults {
242			stepContent = append(stepContent, result)
243		}
244		currentStepMessages := toResponseMessages(stepContent)
245		responseMessages = append(responseMessages, currentStepMessages...)
246
247		stepResult := StepResult{
248			Response: *result,
249			Messages: currentStepMessages,
250		}
251		steps = append(steps, stepResult)
252		if opts.OnStepFinished != nil {
253			opts.OnStepFinished(stepResult)
254		}
255
256		shouldStop := isStopConditionMet(opts.StopWhen, steps)
257
258		if shouldStop || err != nil || len(stepToolCalls) == 0 {
259			break
260		}
261	}
262
263	totalUsage := Usage{}
264
265	for _, step := range steps {
266		usage := step.Usage
267		totalUsage.InputTokens += usage.InputTokens
268		totalUsage.OutputTokens += usage.OutputTokens
269		totalUsage.ReasoningTokens += usage.ReasoningTokens
270		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
271		totalUsage.CacheReadTokens += usage.CacheReadTokens
272		totalUsage.TotalTokens += totalUsage.TotalTokens
273	}
274
275	agentResult := &AgentResult{
276		Steps:      steps,
277		Response:   steps[len(steps)-1].Response,
278		TotalUsage: totalUsage,
279	}
280	return agentResult, nil
281}
282
283func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
284	if len(conditions) == 0 {
285		return false
286	}
287
288	for _, condition := range conditions {
289		if condition(steps) {
290			return true
291		}
292	}
293	return false
294}
295
296func toResponseMessages(content []Content) []Message {
297	var assistantParts []MessagePart
298	var toolParts []MessagePart
299
300	for _, c := range content {
301		switch c.GetType() {
302		case ContentTypeText:
303			text, ok := AsContentType[TextContent](c)
304			if !ok {
305				continue
306			}
307			assistantParts = append(assistantParts, TextPart{
308				Text:            text.Text,
309				ProviderOptions: ProviderOptions(text.ProviderMetadata),
310			})
311		case ContentTypeReasoning:
312			reasoning, ok := AsContentType[ReasoningContent](c)
313			if !ok {
314				continue
315			}
316			assistantParts = append(assistantParts, ReasoningPart{
317				Text:            reasoning.Text,
318				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
319			})
320		case ContentTypeToolCall:
321			toolCall, ok := AsContentType[ToolCallContent](c)
322			if !ok {
323				continue
324			}
325			assistantParts = append(assistantParts, ToolCallPart{
326				ToolCallID:       toolCall.ToolCallID,
327				ToolName:         toolCall.ToolName,
328				Input:            toolCall.Input,
329				ProviderExecuted: toolCall.ProviderExecuted,
330				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
331			})
332		case ContentTypeFile:
333			file, ok := AsContentType[FileContent](c)
334			if !ok {
335				continue
336			}
337			assistantParts = append(assistantParts, FilePart{
338				Data:            file.Data,
339				MediaType:       file.MediaType,
340				ProviderOptions: ProviderOptions(file.ProviderMetadata),
341			})
342		case ContentTypeToolResult:
343			result, ok := AsContentType[ToolResultContent](c)
344			if !ok {
345				continue
346			}
347			toolParts = append(toolParts, ToolResultPart{
348				ToolCallID:      result.ToolCallID,
349				Output:          result.Result,
350				ProviderOptions: ProviderOptions(result.ProviderMetadata),
351			})
352		}
353	}
354
355	var messages []Message
356	if len(assistantParts) > 0 {
357		messages = append(messages, Message{
358			Role:    MessageRoleAssistant,
359			Content: assistantParts,
360		})
361	}
362	if len(toolParts) > 0 {
363		messages = append(messages, Message{
364			Role:    MessageRoleTool,
365			Content: toolParts,
366		})
367	}
368	return messages
369}
370
371func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent) ([]ToolResultContent, error) {
372	if len(toolCalls) == 0 {
373		return nil, nil
374	}
375
376	// Create a map for quick tool lookup
377	toolMap := make(map[string]tools.BaseTool)
378	for _, tool := range allTools {
379		toolMap[tool.Info().Name] = tool
380	}
381
382	// Execute all tool calls in parallel
383	results := make([]ToolResultContent, len(toolCalls))
384	var toolExecutionError error
385	var wg sync.WaitGroup
386
387	for i, toolCall := range toolCalls {
388		wg.Add(1)
389		go func(index int, call ToolCallContent) {
390			defer wg.Done()
391
392			tool, exists := toolMap[call.ToolName]
393			if !exists {
394				results[index] = ToolResultContent{
395					ToolCallID: call.ToolCallID,
396					ToolName:   call.ToolName,
397					Result: ToolResultOutputContentError{
398						Error: errors.New("Error: Tool not found: " + call.ToolName),
399					},
400					ProviderExecuted: false,
401				}
402				return
403			}
404
405			// Execute the tool
406			result, err := tool.Run(ctx, tools.ToolCall{
407				ID:    call.ToolCallID,
408				Name:  call.ToolName,
409				Input: call.Input,
410			})
411			if err != nil {
412				results[index] = ToolResultContent{
413					ToolCallID: call.ToolCallID,
414					ToolName:   call.ToolName,
415					Result: ToolResultOutputContentError{
416						Error: err,
417					},
418					ProviderExecuted: false,
419				}
420				toolExecutionError = err
421				return
422			}
423
424			if result.IsError {
425				results[index] = ToolResultContent{
426					ToolCallID: call.ToolCallID,
427					ToolName:   call.ToolName,
428					Result: ToolResultOutputContentError{
429						Error: errors.New(result.Content),
430					},
431					ProviderExecuted: false,
432				}
433			} else {
434				results[index] = ToolResultContent{
435					ToolCallID: call.ToolCallID,
436					ToolName:   toolCall.ToolName,
437					Result: ToolResultOutputContentText{
438						Text: result.Content,
439					},
440					ProviderExecuted: false,
441				}
442			}
443		}(i, toolCall)
444	}
445
446	// Wait for all tool executions to complete
447	wg.Wait()
448
449	return results, toolExecutionError
450}
451
452// Stream implements Agent.
453func (a *agent) Stream(ctx context.Context, opts AgentCall) (StreamResponse, error) {
454	// TODO: implement the agentic stuff
455	panic("not implemented")
456}
457
458func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string) []Tool {
459	var preparedTools []Tool
460	for _, tool := range tools {
461		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
462			continue
463		}
464		info := tool.Info()
465		preparedTools = append(preparedTools, FunctionTool{
466			Name:        info.Name,
467			Description: info.Description,
468			InputSchema: map[string]any{
469				"type":       "object",
470				"properties": info.Parameters,
471				"required":   info.Required,
472			},
473		})
474	}
475	return preparedTools
476}
477
478func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
479	if prompt == "" {
480		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
481	}
482
483	var preparedPrompt Prompt
484
485	if system != "" {
486		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
487	}
488
489	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
490	preparedPrompt = append(preparedPrompt, messages...)
491	return preparedPrompt, nil
492}
493
494func WithSystemPrompt(prompt string) agentOption {
495	return func(s *AgentSettings) {
496		s.systemPrompt = prompt
497	}
498}
499
500func WithMaxOutputTokens(tokens int64) agentOption {
501	return func(s *AgentSettings) {
502		s.maxOutputTokens = &tokens
503	}
504}
505
506func WithTemperature(temp float64) agentOption {
507	return func(s *AgentSettings) {
508		s.temperature = &temp
509	}
510}
511
512func WithTopP(topP float64) agentOption {
513	return func(s *AgentSettings) {
514		s.topP = &topP
515	}
516}
517
518func WithTopK(topK int64) agentOption {
519	return func(s *AgentSettings) {
520		s.topK = &topK
521	}
522}
523
524func WithPresencePenalty(penalty float64) agentOption {
525	return func(s *AgentSettings) {
526		s.presencePenalty = &penalty
527	}
528}
529
530func WithFrequencyPenalty(penalty float64) agentOption {
531	return func(s *AgentSettings) {
532		s.frequencyPenalty = &penalty
533	}
534}
535
536func WithTools(tools ...tools.BaseTool) agentOption {
537	return func(s *AgentSettings) {
538		s.tools = append(s.tools, tools...)
539	}
540}
541
542func WithStopConditions(conditions ...StopCondition) agentOption {
543	return func(s *AgentSettings) {
544		s.stopWhen = append(s.stopWhen, conditions...)
545	}
546}
547
548func WithPrepareStep(fn PrepareStepFunction) agentOption {
549	return func(s *AgentSettings) {
550		s.prepareStep = fn
551	}
552}
553
554func WithRepairToolCall(fn RepairToolCall) agentOption {
555	return func(s *AgentSettings) {
556		s.repairToolCall = fn
557	}
558}
559
560func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
561	return func(s *AgentSettings) {
562		s.onStepFinished = fn
563	}
564}
565
566func WithHeaders(headers map[string]string) agentOption {
567	return func(s *AgentSettings) {
568		s.headers = headers
569	}
570}
571
572func WithProviderOptions(providerOptions ProviderOptions) agentOption {
573	return func(s *AgentSettings) {
574		s.providerOptions = providerOptions
575	}
576}