1package ai
2
3import (
4 "context"
5 "errors"
6 "maps"
7 "slices"
8 "sync"
9
10 "github.com/charmbracelet/crush/internal/llm/tools"
11)
12
// StepResult describes one iteration of the agent loop. The embedded Response
// holds the step's output; in Generate its Content is the model's content
// followed by the results of any tools executed during the step.
type StepResult struct {
	Response
	// Messages generated during this step
	Messages []Message
}

// StopCondition reports whether the agent loop should stop, given every step
// completed so far (including the step that just finished). When several
// conditions are configured, the loop stops as soon as any one returns true.
type StopCondition = func(steps []StepResult) bool

// PrepareStepFunctionOptions is passed to a PrepareStepFunction before each
// step is sent to the model.
type PrepareStepFunctionOptions struct {
	// Steps completed so far.
	Steps []StepResult
	// Zero-based index of the step about to run (equal to len(Steps)).
	StepNumber int
	// The agent's configured model.
	Model LanguageModel
	// Input messages the step would send: the initial prompt followed by the
	// messages produced by earlier steps.
	Messages []Message
}

// PrepareStepResult lets a PrepareStepFunction adjust the upcoming step.
type PrepareStepResult struct {
	// Model to use for this step; nil keeps the agent's model.
	Model LanguageModel
	// Input messages to send for this step.
	Messages []Message
}

type (
	// PrepareStepFunction, when set, is called before every step and may
	// override the step's model and input messages.
	PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult
	// OnStepFinishedFunction, when set, is called with each step's result
	// right after the step has been recorded.
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCall presumably rewrites a malformed tool call (inferred from
	// the name). NOTE(review): not invoked by Generate in this file — confirm.
	RepairToolCall = func(ToolCallContent) ToolCallContent
)
38
// AgentSettings holds the agent-level defaults configured through the
// options passed to NewAgent. prepareCall uses them for any AgentCall option
// left unset, and merges headers and providerOptions with the call's own
// (the call's entries win on key conflicts).
type AgentSettings struct {
	// System message placed first in every prompt; empty means none.
	systemPrompt     string
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	headers          map[string]string
	providerOptions  ProviderOptions

	// TODO: add support for provider tools
	// Tools offered to the model and executed locally by the agent.
	tools []tools.BaseTool
	// NOTE(review): no With* option in this file sets maxRetries or onRetry —
	// confirm whether they are set elsewhere.
	maxRetries *int

	// Model used for every step unless a PrepareStepFunction overrides it.
	model LanguageModel

	stopWhen       []StopCondition
	prepareStep    PrepareStepFunction
	repairToolCall RepairToolCall
	onStepFinished OnStepFinishedFunction
	onRetry        OnRetryCallback
}
62
// AgentCall holds the per-call inputs to an Agent. Sampling options, retry
// settings, stop conditions and callbacks left nil/empty fall back to the
// agent defaults in prepareCall; Headers and ProviderOptions are merged over
// the agent's, with the call's entries taking precedence.
type AgentCall struct {
	// New user prompt; createPrompt rejects an empty value.
	Prompt string `json:"prompt"`
	// Files attached to the user prompt message.
	Files []FilePart `json:"files"`
	// Additional messages included in the initial prompt.
	Messages []Message `json:"messages"`
	// NOTE(review): unlike its siblings this field has no json tag, so it
	// encodes as "MaxOutputTokens" — confirm this is intended.
	MaxOutputTokens  *int64
	Temperature      *float64 `json:"temperature"`
	TopP             *float64 `json:"top_p"`
	TopK             *int64   `json:"top_k"`
	PresencePenalty  *float64 `json:"presence_penalty"`
	FrequencyPenalty *float64 `json:"frequency_penalty"`
	// When non-empty, only tools whose names are listed are offered to the model.
	ActiveTools     []string `json:"active_tools"`
	Headers         map[string]string
	ProviderOptions ProviderOptions
	OnRetry         OnRetryCallback
	MaxRetries      *int

	StopWhen       []StopCondition
	PrepareStep    PrepareStepFunction
	RepairToolCall RepairToolCall
	OnStepFinished OnStepFinishedFunction
}
84
// AgentResult is the outcome of a Generate run.
type AgentResult struct {
	// Every step executed, in order.
	Steps []StepResult
	// Final response
	// (the Response of the last step).
	Response Response
	// Token usage summed over all steps.
	TotalUsage Usage
}

// Agent runs a multi-step, tool-using exchange with a language model.
type Agent interface {
	Generate(context.Context, AgentCall) (*AgentResult, error)
	Stream(context.Context, AgentCall) (StreamResponse, error)
}

// agentOption mutates AgentSettings; the With* functions construct them.
type agentOption = func(*AgentSettings)

// agent is the Agent implementation returned by NewAgent.
type agent struct {
	settings AgentSettings
}
102
103func NewAgent(model LanguageModel, opts ...agentOption) Agent {
104 settings := AgentSettings{
105 model: model,
106 }
107 for _, o := range opts {
108 o(&settings)
109 }
110 return &agent{
111 settings: settings,
112 }
113}
114
115func (a *agent) prepareCall(call AgentCall) AgentCall {
116 if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
117 call.MaxOutputTokens = a.settings.maxOutputTokens
118 }
119 if call.Temperature == nil && a.settings.temperature != nil {
120 call.Temperature = a.settings.temperature
121 }
122 if call.TopP == nil && a.settings.topP != nil {
123 call.TopP = a.settings.topP
124 }
125 if call.TopK == nil && a.settings.topK != nil {
126 call.TopK = a.settings.topK
127 }
128 if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
129 call.PresencePenalty = a.settings.presencePenalty
130 }
131 if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
132 call.FrequencyPenalty = a.settings.frequencyPenalty
133 }
134 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
135 call.StopWhen = a.settings.stopWhen
136 }
137 if call.PrepareStep == nil && a.settings.prepareStep != nil {
138 call.PrepareStep = a.settings.prepareStep
139 }
140 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
141 call.RepairToolCall = a.settings.repairToolCall
142 }
143 if call.OnStepFinished == nil && a.settings.onStepFinished != nil {
144 call.OnStepFinished = a.settings.onStepFinished
145 }
146 if call.OnRetry == nil && a.settings.onRetry != nil {
147 call.OnRetry = a.settings.onRetry
148 }
149 if call.MaxRetries == nil && a.settings.maxRetries != nil {
150 call.MaxRetries = a.settings.maxRetries
151 }
152
153 providerOptions := ProviderOptions{}
154 if a.settings.providerOptions != nil {
155 maps.Copy(providerOptions, a.settings.providerOptions)
156 }
157 if call.ProviderOptions != nil {
158 maps.Copy(providerOptions, call.ProviderOptions)
159 }
160 call.ProviderOptions = providerOptions
161
162 headers := map[string]string{}
163
164 if a.settings.headers != nil {
165 maps.Copy(headers, a.settings.headers)
166 }
167
168 if call.Headers != nil {
169 maps.Copy(headers, call.Headers)
170 }
171 call.Headers = headers
172 return call
173}
174
175// Generate implements Agent.
176func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
177 opts = a.prepareCall(opts)
178 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
179 if err != nil {
180 return nil, err
181 }
182 var responseMessages []Message
183 var steps []StepResult
184
185 for {
186 stepInputMessages := append(initialPrompt, responseMessages...)
187 stepModel := a.settings.model
188 if opts.PrepareStep != nil {
189 prepared := opts.PrepareStep(PrepareStepFunctionOptions{
190 Model: stepModel,
191 Steps: steps,
192 StepNumber: len(steps),
193 Messages: stepInputMessages,
194 })
195 stepInputMessages = prepared.Messages
196 if prepared.Model != nil {
197 stepModel = prepared.Model
198 }
199 }
200
201 preparedTools := a.prepareTools(a.settings.tools, opts.ActiveTools)
202
203 toolChoice := ToolChoiceAuto
204 retryOptions := DefaultRetryOptions()
205 retryOptions.OnRetry = opts.OnRetry
206 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
207
208 result, err := retry(ctx, func() (*Response, error) {
209 return stepModel.Generate(ctx, Call{
210 Prompt: stepInputMessages,
211 MaxOutputTokens: opts.MaxOutputTokens,
212 Temperature: opts.Temperature,
213 TopP: opts.TopP,
214 TopK: opts.TopK,
215 PresencePenalty: opts.PresencePenalty,
216 FrequencyPenalty: opts.FrequencyPenalty,
217 Tools: preparedTools,
218 ToolChoice: &toolChoice,
219 Headers: opts.Headers,
220 ProviderOptions: opts.ProviderOptions,
221 })
222 })
223 if err != nil {
224 return nil, err
225 }
226
227 var stepToolCalls []ToolCallContent
228 for _, content := range result.Content {
229 if content.GetType() == ContentTypeToolCall {
230 toolCall, ok := AsContentType[ToolCallContent](content)
231 if !ok {
232 continue
233 }
234 stepToolCalls = append(stepToolCalls, toolCall)
235 }
236 }
237
238 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls)
239
240 stepContent := result.Content
241 for _, result := range toolResults {
242 stepContent = append(stepContent, result)
243 }
244 currentStepMessages := toResponseMessages(stepContent)
245 responseMessages = append(responseMessages, currentStepMessages...)
246
247 stepResult := StepResult{
248 Response: Response{
249 Content: stepContent,
250 FinishReason: result.FinishReason,
251 Usage: result.Usage,
252 Warnings: result.Warnings,
253 ProviderMetadata: result.ProviderMetadata,
254 },
255 Messages: currentStepMessages,
256 }
257 steps = append(steps, stepResult)
258 if opts.OnStepFinished != nil {
259 opts.OnStepFinished(stepResult)
260 }
261
262 shouldStop := isStopConditionMet(opts.StopWhen, steps)
263
264 if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
265 break
266 }
267 }
268
269 totalUsage := Usage{}
270
271 for _, step := range steps {
272 usage := step.Usage
273 totalUsage.InputTokens += usage.InputTokens
274 totalUsage.OutputTokens += usage.OutputTokens
275 totalUsage.ReasoningTokens += usage.ReasoningTokens
276 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
277 totalUsage.CacheReadTokens += usage.CacheReadTokens
278 totalUsage.TotalTokens += usage.TotalTokens
279 }
280
281 agentResult := &AgentResult{
282 Steps: steps,
283 Response: steps[len(steps)-1].Response,
284 TotalUsage: totalUsage,
285 }
286 return agentResult, nil
287}
288
289func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
290 if len(conditions) == 0 {
291 return false
292 }
293
294 for _, condition := range conditions {
295 if condition(steps) {
296 return true
297 }
298 }
299 return false
300}
301
302func toResponseMessages(content []Content) []Message {
303 var assistantParts []MessagePart
304 var toolParts []MessagePart
305
306 for _, c := range content {
307 switch c.GetType() {
308 case ContentTypeText:
309 text, ok := AsContentType[TextContent](c)
310 if !ok {
311 continue
312 }
313 assistantParts = append(assistantParts, TextPart{
314 Text: text.Text,
315 ProviderOptions: ProviderOptions(text.ProviderMetadata),
316 })
317 case ContentTypeReasoning:
318 reasoning, ok := AsContentType[ReasoningContent](c)
319 if !ok {
320 continue
321 }
322 assistantParts = append(assistantParts, ReasoningPart{
323 Text: reasoning.Text,
324 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
325 })
326 case ContentTypeToolCall:
327 toolCall, ok := AsContentType[ToolCallContent](c)
328 if !ok {
329 continue
330 }
331 assistantParts = append(assistantParts, ToolCallPart{
332 ToolCallID: toolCall.ToolCallID,
333 ToolName: toolCall.ToolName,
334 Input: toolCall.Input,
335 ProviderExecuted: toolCall.ProviderExecuted,
336 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
337 })
338 case ContentTypeFile:
339 file, ok := AsContentType[FileContent](c)
340 if !ok {
341 continue
342 }
343 assistantParts = append(assistantParts, FilePart{
344 Data: file.Data,
345 MediaType: file.MediaType,
346 ProviderOptions: ProviderOptions(file.ProviderMetadata),
347 })
348 case ContentTypeToolResult:
349 result, ok := AsContentType[ToolResultContent](c)
350 if !ok {
351 continue
352 }
353 toolParts = append(toolParts, ToolResultPart{
354 ToolCallID: result.ToolCallID,
355 Output: result.Result,
356 ProviderOptions: ProviderOptions(result.ProviderMetadata),
357 })
358 }
359 }
360
361 var messages []Message
362 if len(assistantParts) > 0 {
363 messages = append(messages, Message{
364 Role: MessageRoleAssistant,
365 Content: assistantParts,
366 })
367 }
368 if len(toolParts) > 0 {
369 messages = append(messages, Message{
370 Role: MessageRoleTool,
371 Content: toolParts,
372 })
373 }
374 return messages
375}
376
377func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent) ([]ToolResultContent, error) {
378 if len(toolCalls) == 0 {
379 return nil, nil
380 }
381
382 // Create a map for quick tool lookup
383 toolMap := make(map[string]tools.BaseTool)
384 for _, tool := range allTools {
385 toolMap[tool.Info().Name] = tool
386 }
387
388 // Execute all tool calls in parallel
389 results := make([]ToolResultContent, len(toolCalls))
390 var toolExecutionError error
391 var wg sync.WaitGroup
392
393 for i, toolCall := range toolCalls {
394 wg.Add(1)
395 go func(index int, call ToolCallContent) {
396 defer wg.Done()
397
398 tool, exists := toolMap[call.ToolName]
399 if !exists {
400 results[index] = ToolResultContent{
401 ToolCallID: call.ToolCallID,
402 ToolName: call.ToolName,
403 Result: ToolResultOutputContentError{
404 Error: errors.New("Error: Tool not found: " + call.ToolName),
405 },
406 ProviderExecuted: false,
407 }
408 return
409 }
410
411 // Execute the tool
412 result, err := tool.Run(ctx, tools.ToolCall{
413 ID: call.ToolCallID,
414 Name: call.ToolName,
415 Input: call.Input,
416 })
417 if err != nil {
418 results[index] = ToolResultContent{
419 ToolCallID: call.ToolCallID,
420 ToolName: call.ToolName,
421 Result: ToolResultOutputContentError{
422 Error: err,
423 },
424 ProviderExecuted: false,
425 }
426 toolExecutionError = err
427 return
428 }
429
430 if result.IsError {
431 results[index] = ToolResultContent{
432 ToolCallID: call.ToolCallID,
433 ToolName: call.ToolName,
434 Result: ToolResultOutputContentError{
435 Error: errors.New(result.Content),
436 },
437 ProviderExecuted: false,
438 }
439 } else {
440 results[index] = ToolResultContent{
441 ToolCallID: call.ToolCallID,
442 ToolName: toolCall.ToolName,
443 Result: ToolResultOutputContentText{
444 Text: result.Content,
445 },
446 ProviderExecuted: false,
447 }
448 }
449 }(i, toolCall)
450 }
451
452 // Wait for all tool executions to complete
453 wg.Wait()
454
455 return results, toolExecutionError
456}
457
458// Stream implements Agent.
459func (a *agent) Stream(ctx context.Context, opts AgentCall) (StreamResponse, error) {
460 // TODO: implement the agentic stuff
461 panic("not implemented")
462}
463
464func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string) []Tool {
465 var preparedTools []Tool
466 for _, tool := range tools {
467 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
468 continue
469 }
470 info := tool.Info()
471 preparedTools = append(preparedTools, FunctionTool{
472 Name: info.Name,
473 Description: info.Description,
474 InputSchema: map[string]any{
475 "type": "object",
476 "properties": info.Parameters,
477 "required": info.Required,
478 },
479 })
480 }
481 return preparedTools
482}
483
484func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
485 if prompt == "" {
486 return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
487 }
488
489 var preparedPrompt Prompt
490
491 if system != "" {
492 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
493 }
494
495 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
496 preparedPrompt = append(preparedPrompt, messages...)
497 return preparedPrompt, nil
498}
499
500func WithSystemPrompt(prompt string) agentOption {
501 return func(s *AgentSettings) {
502 s.systemPrompt = prompt
503 }
504}
505
506func WithMaxOutputTokens(tokens int64) agentOption {
507 return func(s *AgentSettings) {
508 s.maxOutputTokens = &tokens
509 }
510}
511
512func WithTemperature(temp float64) agentOption {
513 return func(s *AgentSettings) {
514 s.temperature = &temp
515 }
516}
517
518func WithTopP(topP float64) agentOption {
519 return func(s *AgentSettings) {
520 s.topP = &topP
521 }
522}
523
524func WithTopK(topK int64) agentOption {
525 return func(s *AgentSettings) {
526 s.topK = &topK
527 }
528}
529
530func WithPresencePenalty(penalty float64) agentOption {
531 return func(s *AgentSettings) {
532 s.presencePenalty = &penalty
533 }
534}
535
536func WithFrequencyPenalty(penalty float64) agentOption {
537 return func(s *AgentSettings) {
538 s.frequencyPenalty = &penalty
539 }
540}
541
542func WithTools(tools ...tools.BaseTool) agentOption {
543 return func(s *AgentSettings) {
544 s.tools = append(s.tools, tools...)
545 }
546}
547
548func WithStopConditions(conditions ...StopCondition) agentOption {
549 return func(s *AgentSettings) {
550 s.stopWhen = append(s.stopWhen, conditions...)
551 }
552}
553
554func WithPrepareStep(fn PrepareStepFunction) agentOption {
555 return func(s *AgentSettings) {
556 s.prepareStep = fn
557 }
558}
559
560func WithRepairToolCall(fn RepairToolCall) agentOption {
561 return func(s *AgentSettings) {
562 s.repairToolCall = fn
563 }
564}
565
566func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
567 return func(s *AgentSettings) {
568 s.onStepFinished = fn
569 }
570}
571
572func WithHeaders(headers map[string]string) agentOption {
573 return func(s *AgentSettings) {
574 s.headers = headers
575 }
576}
577
578func WithProviderOptions(providerOptions ProviderOptions) agentOption {
579 return func(s *AgentSettings) {
580 s.providerOptions = providerOptions
581 }
582}