1package ai
2
3import (
4 "context"
5 "errors"
6 "maps"
7 "slices"
8 "sync"
9
10 "github.com/charmbracelet/crush/internal/llm/tools"
11)
12
// StepResult records the outcome of one step of the agent loop. The model
// Response for the step is embedded, so its fields (such as Content and
// Usage) are promoted onto StepResult.
type StepResult struct {
	Response
	// Messages generated during this step: an assistant message built from
	// the model output and, when tools ran, a tool message holding their
	// results.
	Messages []Message
}
18
// StopCondition reports whether the agent loop should stop, given every step
// completed so far (including the one that just finished). The loop stops as
// soon as any configured condition returns true.
type StopCondition = func(steps []StepResult) bool
20
// PrepareStepFunctionOptions is the input handed to a PrepareStepFunction
// right before each model call.
type PrepareStepFunctionOptions struct {
	// Steps completed so far.
	Steps []StepResult
	// StepNumber is the zero-based index of the step about to run; Generate
	// sets it to len(Steps).
	StepNumber int
	// Model is the model the step would use by default (the agent's model).
	Model LanguageModel
	// Messages is the default input for the step: the initial prompt followed
	// by every message produced by earlier steps.
	Messages []Message
}
27
// PrepareStepResult is returned by a PrepareStepFunction to adjust the
// inputs of the upcoming step.
type PrepareStepResult struct {
	// Model, when non-nil, replaces the model for this step only.
	Model LanguageModel
	// Messages supplies the input messages for this step. Return the Messages
	// received in PrepareStepFunctionOptions to keep the default input.
	Messages []Message
}
32
type (
	// PrepareStepFunction is called before every model call in Generate and
	// may replace the model and/or input messages used for that step.
	PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult
	// OnStepFinishedFunction is called after each step has been appended to
	// the list of completed steps.
	OnStepFinishedFunction = func(step StepResult)
	// RepairToolCall maps a tool call to a repaired tool call.
	// NOTE(review): it is stored on the agent/call but not invoked by
	// Generate in this file.
	RepairToolCall = func(ToolCallContent) ToolCallContent
)
38
// AgentSettings holds the agent-level configuration populated by the options
// passed to NewAgent. Most fields act as defaults: prepareCall copies them
// into an AgentCall whenever the call leaves the matching field unset.
type AgentSettings struct {
	// systemPrompt, when non-empty, becomes the first (system) message of the
	// prompt built for each Generate call.
	systemPrompt string
	// Default generation parameters; nil means "not set".
	maxOutputTokens  *int64
	temperature      *float64
	topP             *float64
	topK             *int64
	presencePenalty  *float64
	frequencyPenalty *float64
	// Default headers and provider options. Per-call values are merged over
	// these and win on key collisions.
	headers         map[string]string
	providerOptions ProviderOptions

	// TODO: add support for provider tools
	// tools are offered to the model on each step (subject to
	// AgentCall.ActiveTools) and executed when the model calls them.
	tools []tools.BaseTool
	// maxRetries is the default for AgentCall.MaxRetries.
	maxRetries *int

	// model is the default model for every step; PrepareStep may override it.
	model LanguageModel

	// Defaults for the corresponding AgentCall fields.
	stopWhen    []StopCondition
	prepareStep PrepareStepFunction
	// NOTE(review): copied into AgentCall but never invoked by Generate here.
	repairToolCall RepairToolCall
	onStepFinished OnStepFinishedFunction
	onRetry        OnRetryCallback
}
62
63type AgentCall struct {
64 Prompt string `json:"prompt"`
65 Files []FilePart `json:"files"`
66 Messages []Message `json:"messages"`
67 MaxOutputTokens *int64
68 Temperature *float64 `json:"temperature"`
69 TopP *float64 `json:"top_p"`
70 TopK *int64 `json:"top_k"`
71 PresencePenalty *float64 `json:"presence_penalty"`
72 FrequencyPenalty *float64 `json:"frequency_penalty"`
73 ActiveTools []string `json:"active_tools"`
74 Headers map[string]string
75 ProviderOptions ProviderOptions
76 OnRetry OnRetryCallback
77 MaxRetries *int
78
79 StopWhen []StopCondition
80 PrepareStep PrepareStepFunction
81 RepairToolCall RepairToolCall
82 OnStepFinished OnStepFinishedFunction
83}
84
// AgentResult is the outcome of a Generate call.
type AgentResult struct {
	// Steps holds every executed step, in order.
	Steps []StepResult
	// Final response: the model response of the last step.
	Response Response
	// TotalUsage aggregates the Usage reported by every step.
	TotalUsage Usage
}
91
// Agent drives a multi-step, tool-calling conversation with a language model.
type Agent interface {
	// Generate runs the agent loop to completion and returns every step.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream is the streaming counterpart of Generate.
	// NOTE(review): the implementation in this file does not support it yet.
	Stream(context.Context, AgentCall) (StreamResponse, error)
}
96
// agentOption mutates AgentSettings; it is returned by the With* functions
// and applied in order by NewAgent.
type agentOption = func(*AgentSettings)
98
// agent is the Agent implementation returned by NewAgent.
type agent struct {
	// settings are the agent-level defaults captured at construction time.
	settings AgentSettings
}
102
103func NewAgent(model LanguageModel, opts ...agentOption) Agent {
104 settings := AgentSettings{
105 model: model,
106 }
107 for _, o := range opts {
108 o(&settings)
109 }
110 return &agent{
111 settings: settings,
112 }
113}
114
115func (a *agent) prepareCall(call AgentCall) AgentCall {
116 if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
117 call.MaxOutputTokens = a.settings.maxOutputTokens
118 }
119 if call.Temperature == nil && a.settings.temperature != nil {
120 call.Temperature = a.settings.temperature
121 }
122 if call.TopP == nil && a.settings.topP != nil {
123 call.TopP = a.settings.topP
124 }
125 if call.TopK == nil && a.settings.topK != nil {
126 call.TopK = a.settings.topK
127 }
128 if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
129 call.PresencePenalty = a.settings.presencePenalty
130 }
131 if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
132 call.FrequencyPenalty = a.settings.frequencyPenalty
133 }
134 if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
135 call.StopWhen = a.settings.stopWhen
136 }
137 if call.PrepareStep == nil && a.settings.prepareStep != nil {
138 call.PrepareStep = a.settings.prepareStep
139 }
140 if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
141 call.RepairToolCall = a.settings.repairToolCall
142 }
143 if call.OnStepFinished == nil && a.settings.onStepFinished != nil {
144 call.OnStepFinished = a.settings.onStepFinished
145 }
146 if call.OnRetry == nil && a.settings.onRetry != nil {
147 call.OnRetry = a.settings.onRetry
148 }
149 if call.MaxRetries == nil && a.settings.maxRetries != nil {
150 call.MaxRetries = a.settings.maxRetries
151 }
152
153 providerOptions := ProviderOptions{}
154 if a.settings.providerOptions != nil {
155 maps.Copy(providerOptions, a.settings.providerOptions)
156 }
157 if call.ProviderOptions != nil {
158 maps.Copy(providerOptions, call.ProviderOptions)
159 }
160 call.ProviderOptions = providerOptions
161
162 headers := map[string]string{}
163
164 if a.settings.headers != nil {
165 maps.Copy(headers, a.settings.headers)
166 }
167
168 if call.Headers != nil {
169 maps.Copy(headers, call.Headers)
170 }
171 call.Headers = headers
172 return call
173}
174
175// Generate implements Agent.
176func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
177 opts = a.prepareCall(opts)
178 initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
179 if err != nil {
180 return nil, err
181 }
182 var responseMessages []Message
183 var steps []StepResult
184
185 for {
186 stepInputMessages := append(initialPrompt, responseMessages...)
187 stepModel := a.settings.model
188 if opts.PrepareStep != nil {
189 prepared := opts.PrepareStep(PrepareStepFunctionOptions{
190 Model: stepModel,
191 Steps: steps,
192 StepNumber: len(steps),
193 Messages: stepInputMessages,
194 })
195 stepInputMessages = prepared.Messages
196 if prepared.Model != nil {
197 stepModel = prepared.Model
198 }
199 }
200
201 preparedTools := a.prepareTools(a.settings.tools, opts.ActiveTools)
202
203 toolChoice := ToolChoiceAuto
204 retryOptions := DefaultRetryOptions()
205 retryOptions.OnRetry = opts.OnRetry
206 retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
207
208 result, err := retry(ctx, func() (*Response, error) {
209 return stepModel.Generate(ctx, Call{
210 Prompt: stepInputMessages,
211 MaxOutputTokens: opts.MaxOutputTokens,
212 Temperature: opts.Temperature,
213 TopP: opts.TopP,
214 TopK: opts.TopK,
215 PresencePenalty: opts.PresencePenalty,
216 FrequencyPenalty: opts.FrequencyPenalty,
217 Tools: preparedTools,
218 ToolChoice: &toolChoice,
219 Headers: opts.Headers,
220 ProviderOptions: opts.ProviderOptions,
221 })
222 })
223 if err != nil {
224 return nil, err
225 }
226
227 var stepToolCalls []ToolCallContent
228 for _, content := range result.Content {
229 if content.GetType() == ContentTypeToolCall {
230 toolCall, ok := AsContentType[ToolCallContent](content)
231 if !ok {
232 continue
233 }
234 stepToolCalls = append(stepToolCalls, toolCall)
235 }
236 }
237
238 toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls)
239
240 stepContent := result.Content
241 for _, result := range toolResults {
242 stepContent = append(stepContent, result)
243 }
244 currentStepMessages := toResponseMessages(stepContent)
245 responseMessages = append(responseMessages, currentStepMessages...)
246
247 stepResult := StepResult{
248 Response: *result,
249 Messages: currentStepMessages,
250 }
251 steps = append(steps, stepResult)
252 if opts.OnStepFinished != nil {
253 opts.OnStepFinished(stepResult)
254 }
255
256 shouldStop := isStopConditionMet(opts.StopWhen, steps)
257
258 if shouldStop || err != nil || len(stepToolCalls) == 0 {
259 break
260 }
261 }
262
263 totalUsage := Usage{}
264
265 for _, step := range steps {
266 usage := step.Usage
267 totalUsage.InputTokens += usage.InputTokens
268 totalUsage.OutputTokens += usage.OutputTokens
269 totalUsage.ReasoningTokens += usage.ReasoningTokens
270 totalUsage.CacheCreationTokens += usage.CacheCreationTokens
271 totalUsage.CacheReadTokens += usage.CacheReadTokens
272 totalUsage.TotalTokens += totalUsage.TotalTokens
273 }
274
275 agentResult := &AgentResult{
276 Steps: steps,
277 Response: steps[len(steps)-1].Response,
278 TotalUsage: totalUsage,
279 }
280 return agentResult, nil
281}
282
283func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
284 if len(conditions) == 0 {
285 return false
286 }
287
288 for _, condition := range conditions {
289 if condition(steps) {
290 return true
291 }
292 }
293 return false
294}
295
296func toResponseMessages(content []Content) []Message {
297 var assistantParts []MessagePart
298 var toolParts []MessagePart
299
300 for _, c := range content {
301 switch c.GetType() {
302 case ContentTypeText:
303 text, ok := AsContentType[TextContent](c)
304 if !ok {
305 continue
306 }
307 assistantParts = append(assistantParts, TextPart{
308 Text: text.Text,
309 ProviderOptions: ProviderOptions(text.ProviderMetadata),
310 })
311 case ContentTypeReasoning:
312 reasoning, ok := AsContentType[ReasoningContent](c)
313 if !ok {
314 continue
315 }
316 assistantParts = append(assistantParts, ReasoningPart{
317 Text: reasoning.Text,
318 ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
319 })
320 case ContentTypeToolCall:
321 toolCall, ok := AsContentType[ToolCallContent](c)
322 if !ok {
323 continue
324 }
325 assistantParts = append(assistantParts, ToolCallPart{
326 ToolCallID: toolCall.ToolCallID,
327 ToolName: toolCall.ToolName,
328 Input: toolCall.Input,
329 ProviderExecuted: toolCall.ProviderExecuted,
330 ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
331 })
332 case ContentTypeFile:
333 file, ok := AsContentType[FileContent](c)
334 if !ok {
335 continue
336 }
337 assistantParts = append(assistantParts, FilePart{
338 Data: file.Data,
339 MediaType: file.MediaType,
340 ProviderOptions: ProviderOptions(file.ProviderMetadata),
341 })
342 case ContentTypeToolResult:
343 result, ok := AsContentType[ToolResultContent](c)
344 if !ok {
345 continue
346 }
347 toolParts = append(toolParts, ToolResultPart{
348 ToolCallID: result.ToolCallID,
349 Output: result.Result,
350 ProviderOptions: ProviderOptions(result.ProviderMetadata),
351 })
352 }
353 }
354
355 var messages []Message
356 if len(assistantParts) > 0 {
357 messages = append(messages, Message{
358 Role: MessageRoleAssistant,
359 Content: assistantParts,
360 })
361 }
362 if len(toolParts) > 0 {
363 messages = append(messages, Message{
364 Role: MessageRoleTool,
365 Content: toolParts,
366 })
367 }
368 return messages
369}
370
371func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent) ([]ToolResultContent, error) {
372 if len(toolCalls) == 0 {
373 return nil, nil
374 }
375
376 // Create a map for quick tool lookup
377 toolMap := make(map[string]tools.BaseTool)
378 for _, tool := range allTools {
379 toolMap[tool.Info().Name] = tool
380 }
381
382 // Execute all tool calls in parallel
383 results := make([]ToolResultContent, len(toolCalls))
384 var toolExecutionError error
385 var wg sync.WaitGroup
386
387 for i, toolCall := range toolCalls {
388 wg.Add(1)
389 go func(index int, call ToolCallContent) {
390 defer wg.Done()
391
392 tool, exists := toolMap[call.ToolName]
393 if !exists {
394 results[index] = ToolResultContent{
395 ToolCallID: call.ToolCallID,
396 ToolName: call.ToolName,
397 Result: ToolResultOutputContentError{
398 Error: errors.New("Error: Tool not found: " + call.ToolName),
399 },
400 ProviderExecuted: false,
401 }
402 return
403 }
404
405 // Execute the tool
406 result, err := tool.Run(ctx, tools.ToolCall{
407 ID: call.ToolCallID,
408 Name: call.ToolName,
409 Input: call.Input,
410 })
411 if err != nil {
412 results[index] = ToolResultContent{
413 ToolCallID: call.ToolCallID,
414 ToolName: call.ToolName,
415 Result: ToolResultOutputContentError{
416 Error: err,
417 },
418 ProviderExecuted: false,
419 }
420 toolExecutionError = err
421 return
422 }
423
424 if result.IsError {
425 results[index] = ToolResultContent{
426 ToolCallID: call.ToolCallID,
427 ToolName: call.ToolName,
428 Result: ToolResultOutputContentError{
429 Error: errors.New(result.Content),
430 },
431 ProviderExecuted: false,
432 }
433 } else {
434 results[index] = ToolResultContent{
435 ToolCallID: call.ToolCallID,
436 ToolName: toolCall.ToolName,
437 Result: ToolResultOutputContentText{
438 Text: result.Content,
439 },
440 ProviderExecuted: false,
441 }
442 }
443 }(i, toolCall)
444 }
445
446 // Wait for all tool executions to complete
447 wg.Wait()
448
449 return results, toolExecutionError
450}
451
452// Stream implements Agent.
453func (a *agent) Stream(ctx context.Context, opts AgentCall) (StreamResponse, error) {
454 // TODO: implement the agentic stuff
455 panic("not implemented")
456}
457
458func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string) []Tool {
459 var preparedTools []Tool
460 for _, tool := range tools {
461 if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
462 continue
463 }
464 info := tool.Info()
465 preparedTools = append(preparedTools, FunctionTool{
466 Name: info.Name,
467 Description: info.Description,
468 InputSchema: map[string]any{
469 "type": "object",
470 "properties": info.Parameters,
471 "required": info.Required,
472 },
473 })
474 }
475 return preparedTools
476}
477
478func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
479 if prompt == "" {
480 return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
481 }
482
483 var preparedPrompt Prompt
484
485 if system != "" {
486 preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
487 }
488
489 preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
490 preparedPrompt = append(preparedPrompt, messages...)
491 return preparedPrompt, nil
492}
493
494func WithSystemPrompt(prompt string) agentOption {
495 return func(s *AgentSettings) {
496 s.systemPrompt = prompt
497 }
498}
499
500func WithMaxOutputTokens(tokens int64) agentOption {
501 return func(s *AgentSettings) {
502 s.maxOutputTokens = &tokens
503 }
504}
505
506func WithTemperature(temp float64) agentOption {
507 return func(s *AgentSettings) {
508 s.temperature = &temp
509 }
510}
511
512func WithTopP(topP float64) agentOption {
513 return func(s *AgentSettings) {
514 s.topP = &topP
515 }
516}
517
518func WithTopK(topK int64) agentOption {
519 return func(s *AgentSettings) {
520 s.topK = &topK
521 }
522}
523
524func WithPresencePenalty(penalty float64) agentOption {
525 return func(s *AgentSettings) {
526 s.presencePenalty = &penalty
527 }
528}
529
530func WithFrequencyPenalty(penalty float64) agentOption {
531 return func(s *AgentSettings) {
532 s.frequencyPenalty = &penalty
533 }
534}
535
536func WithTools(tools ...tools.BaseTool) agentOption {
537 return func(s *AgentSettings) {
538 s.tools = append(s.tools, tools...)
539 }
540}
541
542func WithStopConditions(conditions ...StopCondition) agentOption {
543 return func(s *AgentSettings) {
544 s.stopWhen = append(s.stopWhen, conditions...)
545 }
546}
547
548func WithPrepareStep(fn PrepareStepFunction) agentOption {
549 return func(s *AgentSettings) {
550 s.prepareStep = fn
551 }
552}
553
554func WithRepairToolCall(fn RepairToolCall) agentOption {
555 return func(s *AgentSettings) {
556 s.repairToolCall = fn
557 }
558}
559
560func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
561 return func(s *AgentSettings) {
562 s.onStepFinished = fn
563 }
564}
565
566func WithHeaders(headers map[string]string) agentOption {
567 return func(s *AgentSettings) {
568 s.headers = headers
569 }
570}
571
572func WithProviderOptions(providerOptions ProviderOptions) agentOption {
573 return func(s *AgentSettings) {
574 s.providerOptions = providerOptions
575 }
576}