1package ai
  2
  3import (
  4	"context"
  5	"errors"
  6	"maps"
  7	"slices"
  8	"sync"
  9
 10	"github.com/charmbracelet/crush/internal/llm/tools"
 11)
 12
 13type StepResult struct {
 14	Response
 15	// Messages generated during this step
 16	Messages []Message
 17}
 18
 19type StopCondition = func(steps []StepResult) bool
 20
 21type PrepareStepFunctionOptions struct {
 22	Steps      []StepResult
 23	StepNumber int
 24	Model      LanguageModel
 25	Messages   []Message
 26}
 27
 28type PrepareStepResult struct {
 29	Model    LanguageModel
 30	Messages []Message
 31}
 32
 33type (
 34	PrepareStepFunction    = func(options PrepareStepFunctionOptions) PrepareStepResult
 35	OnStepFinishedFunction = func(step StepResult)
 36	RepairToolCall         = func(ToolCallContent) ToolCallContent
 37)
 38
 39type AgentSettings struct {
 40	systemPrompt     string
 41	maxOutputTokens  *int64
 42	temperature      *float64
 43	topP             *float64
 44	topK             *int64
 45	presencePenalty  *float64
 46	frequencyPenalty *float64
 47	headers          map[string]string
 48	providerOptions  ProviderOptions
 49
 50	// TODO: add support for provider tools
 51	tools      []tools.BaseTool
 52	maxRetries *int
 53
 54	model LanguageModel
 55
 56	stopWhen       []StopCondition
 57	prepareStep    PrepareStepFunction
 58	repairToolCall RepairToolCall
 59	onStepFinished OnStepFinishedFunction
 60	onRetry        OnRetryCallback
 61}
 62
 63type AgentCall struct {
 64	Prompt           string     `json:"prompt"`
 65	Files            []FilePart `json:"files"`
 66	Messages         []Message  `json:"messages"`
 67	MaxOutputTokens  *int64
 68	Temperature      *float64 `json:"temperature"`
 69	TopP             *float64 `json:"top_p"`
 70	TopK             *int64   `json:"top_k"`
 71	PresencePenalty  *float64 `json:"presence_penalty"`
 72	FrequencyPenalty *float64 `json:"frequency_penalty"`
 73	ActiveTools      []string `json:"active_tools"`
 74	Headers          map[string]string
 75	ProviderOptions  ProviderOptions
 76	OnRetry          OnRetryCallback
 77	MaxRetries       *int
 78
 79	StopWhen       []StopCondition
 80	PrepareStep    PrepareStepFunction
 81	RepairToolCall RepairToolCall
 82	OnStepFinished OnStepFinishedFunction
 83}
 84
 85type AgentResult struct {
 86	Steps []StepResult
 87	// Final response
 88	Response   Response
 89	TotalUsage Usage
 90}
 91
 92type Agent interface {
 93	Generate(context.Context, AgentCall) (*AgentResult, error)
 94	Stream(context.Context, AgentCall) (StreamResponse, error)
 95}
 96
 97type agentOption = func(*AgentSettings)
 98
 99type agent struct {
100	settings AgentSettings
101}
102
103func NewAgent(model LanguageModel, opts ...agentOption) Agent {
104	settings := AgentSettings{
105		model: model,
106	}
107	for _, o := range opts {
108		o(&settings)
109	}
110	return &agent{
111		settings: settings,
112	}
113}
114
115func (a *agent) prepareCall(call AgentCall) AgentCall {
116	if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
117		call.MaxOutputTokens = a.settings.maxOutputTokens
118	}
119	if call.Temperature == nil && a.settings.temperature != nil {
120		call.Temperature = a.settings.temperature
121	}
122	if call.TopP == nil && a.settings.topP != nil {
123		call.TopP = a.settings.topP
124	}
125	if call.TopK == nil && a.settings.topK != nil {
126		call.TopK = a.settings.topK
127	}
128	if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
129		call.PresencePenalty = a.settings.presencePenalty
130	}
131	if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
132		call.FrequencyPenalty = a.settings.frequencyPenalty
133	}
134	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
135		call.StopWhen = a.settings.stopWhen
136	}
137	if call.PrepareStep == nil && a.settings.prepareStep != nil {
138		call.PrepareStep = a.settings.prepareStep
139	}
140	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
141		call.RepairToolCall = a.settings.repairToolCall
142	}
143	if call.OnStepFinished == nil && a.settings.onStepFinished != nil {
144		call.OnStepFinished = a.settings.onStepFinished
145	}
146	if call.OnRetry == nil && a.settings.onRetry != nil {
147		call.OnRetry = a.settings.onRetry
148	}
149	if call.MaxRetries == nil && a.settings.maxRetries != nil {
150		call.MaxRetries = a.settings.maxRetries
151	}
152
153	providerOptions := ProviderOptions{}
154	if a.settings.providerOptions != nil {
155		maps.Copy(providerOptions, a.settings.providerOptions)
156	}
157	if call.ProviderOptions != nil {
158		maps.Copy(providerOptions, call.ProviderOptions)
159	}
160	call.ProviderOptions = providerOptions
161
162	headers := map[string]string{}
163
164	if a.settings.headers != nil {
165		maps.Copy(headers, a.settings.headers)
166	}
167
168	if call.Headers != nil {
169		maps.Copy(headers, call.Headers)
170	}
171	call.Headers = headers
172	return call
173}
174
175// Generate implements Agent.
176func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
177	opts = a.prepareCall(opts)
178	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
179	if err != nil {
180		return nil, err
181	}
182	var responseMessages []Message
183	var steps []StepResult
184
185	for {
186		stepInputMessages := append(initialPrompt, responseMessages...)
187		stepModel := a.settings.model
188		if opts.PrepareStep != nil {
189			prepared := opts.PrepareStep(PrepareStepFunctionOptions{
190				Model:      stepModel,
191				Steps:      steps,
192				StepNumber: len(steps),
193				Messages:   stepInputMessages,
194			})
195			stepInputMessages = prepared.Messages
196			if prepared.Model != nil {
197				stepModel = prepared.Model
198			}
199		}
200
201		preparedTools := a.prepareTools(a.settings.tools, opts.ActiveTools)
202
203		toolChoice := ToolChoiceAuto
204		retryOptions := DefaultRetryOptions()
205		retryOptions.OnRetry = opts.OnRetry
206		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
207
208		result, err := retry(ctx, func() (*Response, error) {
209			return stepModel.Generate(ctx, Call{
210				Prompt:           stepInputMessages,
211				MaxOutputTokens:  opts.MaxOutputTokens,
212				Temperature:      opts.Temperature,
213				TopP:             opts.TopP,
214				TopK:             opts.TopK,
215				PresencePenalty:  opts.PresencePenalty,
216				FrequencyPenalty: opts.FrequencyPenalty,
217				Tools:            preparedTools,
218				ToolChoice:       &toolChoice,
219				Headers:          opts.Headers,
220				ProviderOptions:  opts.ProviderOptions,
221			})
222		})
223		if err != nil {
224			return nil, err
225		}
226
227		var stepToolCalls []ToolCallContent
228		for _, content := range result.Content {
229			if content.GetType() == ContentTypeToolCall {
230				toolCall, ok := AsContentType[ToolCallContent](content)
231				if !ok {
232					continue
233				}
234				stepToolCalls = append(stepToolCalls, toolCall)
235			}
236		}
237
238		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls)
239
240		stepContent := result.Content
241		for _, result := range toolResults {
242			stepContent = append(stepContent, result)
243		}
244		currentStepMessages := toResponseMessages(stepContent)
245		responseMessages = append(responseMessages, currentStepMessages...)
246
247		stepResult := StepResult{
248			Response: Response{
249				Content:          stepContent,
250				FinishReason:     result.FinishReason,
251				Usage:            result.Usage,
252				Warnings:         result.Warnings,
253				ProviderMetadata: result.ProviderMetadata,
254			},
255			Messages: currentStepMessages,
256		}
257		steps = append(steps, stepResult)
258		if opts.OnStepFinished != nil {
259			opts.OnStepFinished(stepResult)
260		}
261
262		shouldStop := isStopConditionMet(opts.StopWhen, steps)
263
264		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
265			break
266		}
267	}
268
269	totalUsage := Usage{}
270
271	for _, step := range steps {
272		usage := step.Usage
273		totalUsage.InputTokens += usage.InputTokens
274		totalUsage.OutputTokens += usage.OutputTokens
275		totalUsage.ReasoningTokens += usage.ReasoningTokens
276		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
277		totalUsage.CacheReadTokens += usage.CacheReadTokens
278		totalUsage.TotalTokens += usage.TotalTokens
279	}
280
281	agentResult := &AgentResult{
282		Steps:      steps,
283		Response:   steps[len(steps)-1].Response,
284		TotalUsage: totalUsage,
285	}
286	return agentResult, nil
287}
288
289func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
290	if len(conditions) == 0 {
291		return false
292	}
293
294	for _, condition := range conditions {
295		if condition(steps) {
296			return true
297		}
298	}
299	return false
300}
301
302func toResponseMessages(content []Content) []Message {
303	var assistantParts []MessagePart
304	var toolParts []MessagePart
305
306	for _, c := range content {
307		switch c.GetType() {
308		case ContentTypeText:
309			text, ok := AsContentType[TextContent](c)
310			if !ok {
311				continue
312			}
313			assistantParts = append(assistantParts, TextPart{
314				Text:            text.Text,
315				ProviderOptions: ProviderOptions(text.ProviderMetadata),
316			})
317		case ContentTypeReasoning:
318			reasoning, ok := AsContentType[ReasoningContent](c)
319			if !ok {
320				continue
321			}
322			assistantParts = append(assistantParts, ReasoningPart{
323				Text:            reasoning.Text,
324				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
325			})
326		case ContentTypeToolCall:
327			toolCall, ok := AsContentType[ToolCallContent](c)
328			if !ok {
329				continue
330			}
331			assistantParts = append(assistantParts, ToolCallPart{
332				ToolCallID:       toolCall.ToolCallID,
333				ToolName:         toolCall.ToolName,
334				Input:            toolCall.Input,
335				ProviderExecuted: toolCall.ProviderExecuted,
336				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
337			})
338		case ContentTypeFile:
339			file, ok := AsContentType[FileContent](c)
340			if !ok {
341				continue
342			}
343			assistantParts = append(assistantParts, FilePart{
344				Data:            file.Data,
345				MediaType:       file.MediaType,
346				ProviderOptions: ProviderOptions(file.ProviderMetadata),
347			})
348		case ContentTypeToolResult:
349			result, ok := AsContentType[ToolResultContent](c)
350			if !ok {
351				continue
352			}
353			toolParts = append(toolParts, ToolResultPart{
354				ToolCallID:      result.ToolCallID,
355				Output:          result.Result,
356				ProviderOptions: ProviderOptions(result.ProviderMetadata),
357			})
358		}
359	}
360
361	var messages []Message
362	if len(assistantParts) > 0 {
363		messages = append(messages, Message{
364			Role:    MessageRoleAssistant,
365			Content: assistantParts,
366		})
367	}
368	if len(toolParts) > 0 {
369		messages = append(messages, Message{
370			Role:    MessageRoleTool,
371			Content: toolParts,
372		})
373	}
374	return messages
375}
376
377func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent) ([]ToolResultContent, error) {
378	if len(toolCalls) == 0 {
379		return nil, nil
380	}
381
382	// Create a map for quick tool lookup
383	toolMap := make(map[string]tools.BaseTool)
384	for _, tool := range allTools {
385		toolMap[tool.Info().Name] = tool
386	}
387
388	// Execute all tool calls in parallel
389	results := make([]ToolResultContent, len(toolCalls))
390	var toolExecutionError error
391	var wg sync.WaitGroup
392
393	for i, toolCall := range toolCalls {
394		wg.Add(1)
395		go func(index int, call ToolCallContent) {
396			defer wg.Done()
397
398			tool, exists := toolMap[call.ToolName]
399			if !exists {
400				results[index] = ToolResultContent{
401					ToolCallID: call.ToolCallID,
402					ToolName:   call.ToolName,
403					Result: ToolResultOutputContentError{
404						Error: errors.New("Error: Tool not found: " + call.ToolName),
405					},
406					ProviderExecuted: false,
407				}
408				return
409			}
410
411			// Execute the tool
412			result, err := tool.Run(ctx, tools.ToolCall{
413				ID:    call.ToolCallID,
414				Name:  call.ToolName,
415				Input: call.Input,
416			})
417			if err != nil {
418				results[index] = ToolResultContent{
419					ToolCallID: call.ToolCallID,
420					ToolName:   call.ToolName,
421					Result: ToolResultOutputContentError{
422						Error: err,
423					},
424					ProviderExecuted: false,
425				}
426				toolExecutionError = err
427				return
428			}
429
430			if result.IsError {
431				results[index] = ToolResultContent{
432					ToolCallID: call.ToolCallID,
433					ToolName:   call.ToolName,
434					Result: ToolResultOutputContentError{
435						Error: errors.New(result.Content),
436					},
437					ProviderExecuted: false,
438				}
439			} else {
440				results[index] = ToolResultContent{
441					ToolCallID: call.ToolCallID,
442					ToolName:   toolCall.ToolName,
443					Result: ToolResultOutputContentText{
444						Text: result.Content,
445					},
446					ProviderExecuted: false,
447				}
448			}
449		}(i, toolCall)
450	}
451
452	// Wait for all tool executions to complete
453	wg.Wait()
454
455	return results, toolExecutionError
456}
457
458// Stream implements Agent.
459func (a *agent) Stream(ctx context.Context, opts AgentCall) (StreamResponse, error) {
460	// TODO: implement the agentic stuff
461	panic("not implemented")
462}
463
464func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string) []Tool {
465	var preparedTools []Tool
466	for _, tool := range tools {
467		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
468			continue
469		}
470		info := tool.Info()
471		preparedTools = append(preparedTools, FunctionTool{
472			Name:        info.Name,
473			Description: info.Description,
474			InputSchema: map[string]any{
475				"type":       "object",
476				"properties": info.Parameters,
477				"required":   info.Required,
478			},
479		})
480	}
481	return preparedTools
482}
483
484func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
485	if prompt == "" {
486		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
487	}
488
489	var preparedPrompt Prompt
490
491	if system != "" {
492		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
493	}
494
495	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
496	preparedPrompt = append(preparedPrompt, messages...)
497	return preparedPrompt, nil
498}
499
500func WithSystemPrompt(prompt string) agentOption {
501	return func(s *AgentSettings) {
502		s.systemPrompt = prompt
503	}
504}
505
506func WithMaxOutputTokens(tokens int64) agentOption {
507	return func(s *AgentSettings) {
508		s.maxOutputTokens = &tokens
509	}
510}
511
512func WithTemperature(temp float64) agentOption {
513	return func(s *AgentSettings) {
514		s.temperature = &temp
515	}
516}
517
518func WithTopP(topP float64) agentOption {
519	return func(s *AgentSettings) {
520		s.topP = &topP
521	}
522}
523
524func WithTopK(topK int64) agentOption {
525	return func(s *AgentSettings) {
526		s.topK = &topK
527	}
528}
529
530func WithPresencePenalty(penalty float64) agentOption {
531	return func(s *AgentSettings) {
532		s.presencePenalty = &penalty
533	}
534}
535
536func WithFrequencyPenalty(penalty float64) agentOption {
537	return func(s *AgentSettings) {
538		s.frequencyPenalty = &penalty
539	}
540}
541
542func WithTools(tools ...tools.BaseTool) agentOption {
543	return func(s *AgentSettings) {
544		s.tools = append(s.tools, tools...)
545	}
546}
547
548func WithStopConditions(conditions ...StopCondition) agentOption {
549	return func(s *AgentSettings) {
550		s.stopWhen = append(s.stopWhen, conditions...)
551	}
552}
553
554func WithPrepareStep(fn PrepareStepFunction) agentOption {
555	return func(s *AgentSettings) {
556		s.prepareStep = fn
557	}
558}
559
560func WithRepairToolCall(fn RepairToolCall) agentOption {
561	return func(s *AgentSettings) {
562		s.repairToolCall = fn
563	}
564}
565
566func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
567	return func(s *AgentSettings) {
568		s.onStepFinished = fn
569	}
570}
571
572func WithHeaders(headers map[string]string) agentOption {
573	return func(s *AgentSettings) {
574		s.headers = headers
575	}
576}
577
578func WithProviderOptions(providerOptions ProviderOptions) agentOption {
579	return func(s *AgentSettings) {
580		s.providerOptions = providerOptions
581	}
582}