1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strings"
  8	"sync"
  9	"time"
 10
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/llm/models"
 13	"github.com/charmbracelet/crush/internal/llm/prompt"
 14	"github.com/charmbracelet/crush/internal/llm/provider"
 15	"github.com/charmbracelet/crush/internal/llm/tools"
 16	"github.com/charmbracelet/crush/internal/logging"
 17	"github.com/charmbracelet/crush/internal/message"
 18	"github.com/charmbracelet/crush/internal/permission"
 19	"github.com/charmbracelet/crush/internal/pubsub"
 20	"github.com/charmbracelet/crush/internal/session"
 21)
 22
 23// Common errors
 24var (
 25	ErrRequestCancelled = errors.New("request cancelled by user")
 26	ErrSessionBusy      = errors.New("session is currently processing another request")
 27)
 28
 29type AgentEventType string
 30
 31const (
 32	AgentEventTypeError     AgentEventType = "error"
 33	AgentEventTypeResponse  AgentEventType = "response"
 34	AgentEventTypeSummarize AgentEventType = "summarize"
 35)
 36
// AgentEvent is emitted by the agent. Run sends its final event on the channel
// it returns and also publishes it on the agent's broker; Summarize publishes
// its progress, error and completion events on the broker.
type AgentEvent struct {
	// Type says which fields are meaningful: Message is set on
	// AgentEventTypeResponse events and Error on AgentEventTypeError events.
	Type    AgentEventType
	Message message.Message
	Error   error

	// SessionID and Progress are filled by Summarize (SessionID only on its
	// final success event). Done is true on Run's successful response and on
	// the final event (success or error) of Summarize.
	SessionID string
	Progress  string
	Done      bool
}
 47
 48type Service interface {
 49	pubsub.Suscriber[AgentEvent]
 50	Model() models.Model
 51	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 52	Cancel(sessionID string)
 53	CancelAll()
 54	IsSessionBusy(sessionID string) bool
 55	IsBusy() bool
 56	Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
 57	Summarize(ctx context.Context, sessionID string) error
 58}
 59
// agent implements Service. The embedded broker publishes the AgentEvents
// produced by Run and Summarize.
type agent struct {
	*pubsub.Broker[AgentEvent]
	sessions session.Service
	messages message.Service

	// tools are offered to the main provider on every turn; provider serves
	// the agent's own model and is replaced by Update.
	tools    []tools.BaseTool
	provider provider.Provider

	// Auxiliary providers; NewAgent leaves both nil unless the agent is
	// config.AgentCoder.
	titleProvider     provider.Provider
	summarizeProvider provider.Provider

	// activeRequests maps in-flight request keys to their context.CancelFunc:
	// the bare session ID for Run, and sessionID+"-summarize" for Summarize.
	activeRequests sync.Map
}
 73
 74func NewAgent(
 75	agentName config.AgentName,
 76	sessions session.Service,
 77	messages message.Service,
 78	agentTools []tools.BaseTool,
 79) (Service, error) {
 80	agentProvider, err := createAgentProvider(agentName)
 81	if err != nil {
 82		return nil, err
 83	}
 84	var titleProvider provider.Provider
 85	// Only generate titles for the coder agent
 86	if agentName == config.AgentCoder {
 87		titleProvider, err = createAgentProvider(config.AgentTitle)
 88		if err != nil {
 89			return nil, err
 90		}
 91	}
 92	var summarizeProvider provider.Provider
 93	if agentName == config.AgentCoder {
 94		summarizeProvider, err = createAgentProvider(config.AgentSummarizer)
 95		if err != nil {
 96			return nil, err
 97		}
 98	}
 99
100	agent := &agent{
101		Broker:            pubsub.NewBroker[AgentEvent](),
102		provider:          agentProvider,
103		messages:          messages,
104		sessions:          sessions,
105		tools:             agentTools,
106		titleProvider:     titleProvider,
107		summarizeProvider: summarizeProvider,
108		activeRequests:    sync.Map{},
109	}
110
111	return agent, nil
112}
113
114func (a *agent) Model() models.Model {
115	return a.provider.Model()
116}
117
118func (a *agent) Cancel(sessionID string) {
119	// Cancel regular requests
120	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
121		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
122			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
123			cancel()
124		}
125	}
126
127	// Also check for summarize requests
128	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
129		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
130			logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
131			cancel()
132		}
133	}
134}
135
136func (a *agent) IsBusy() bool {
137	busy := false
138	a.activeRequests.Range(func(key, value interface{}) bool {
139		if cancelFunc, ok := value.(context.CancelFunc); ok {
140			if cancelFunc != nil {
141				busy = true
142				return false // Stop iterating
143			}
144		}
145		return true // Continue iterating
146	})
147	return busy
148}
149
150func (a *agent) IsSessionBusy(sessionID string) bool {
151	_, busy := a.activeRequests.Load(sessionID)
152	return busy
153}
154
155func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
156	if content == "" {
157		return nil
158	}
159	if a.titleProvider == nil {
160		return nil
161	}
162	session, err := a.sessions.Get(ctx, sessionID)
163	if err != nil {
164		return err
165	}
166	parts := []message.ContentPart{message.TextContent{Text: content}}
167
168	// Use streaming approach like summarization
169	response := a.titleProvider.StreamResponse(
170		ctx,
171		[]message.Message{
172			{
173				Role:  message.User,
174				Parts: parts,
175			},
176		},
177		make([]tools.BaseTool, 0),
178	)
179
180	var finalResponse *provider.ProviderResponse
181	for r := range response {
182		if r.Error != nil {
183			return r.Error
184		}
185		finalResponse = r.Response
186	}
187
188	if finalResponse == nil {
189		return fmt.Errorf("no response received from title provider")
190	}
191
192	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
193	if title == "" {
194		return nil
195	}
196
197	session.Title = title
198	_, err = a.sessions.Save(ctx, session)
199	return err
200}
201
202func (a *agent) err(err error) AgentEvent {
203	return AgentEvent{
204		Type:  AgentEventTypeError,
205		Error: err,
206	}
207}
208
209func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
210	if !a.provider.Model().SupportsAttachments && attachments != nil {
211		attachments = nil
212	}
213	events := make(chan AgentEvent)
214	if a.IsSessionBusy(sessionID) {
215		return nil, ErrSessionBusy
216	}
217
218	genCtx, cancel := context.WithCancel(ctx)
219
220	a.activeRequests.Store(sessionID, cancel)
221	go func() {
222		logging.Debug("Request started", "sessionID", sessionID)
223		defer logging.RecoverPanic("agent.Run", func() {
224			events <- a.err(fmt.Errorf("panic while running the agent"))
225		})
226		var attachmentParts []message.ContentPart
227		for _, attachment := range attachments {
228			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
229		}
230		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
231		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
232			logging.ErrorPersist(result.Error.Error())
233		}
234		logging.Debug("Request completed", "sessionID", sessionID)
235		a.activeRequests.Delete(sessionID)
236		cancel()
237		a.Publish(pubsub.CreatedEvent, result)
238		events <- result
239		close(events)
240	}()
241	return events, nil
242}
243
244func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
245	// List existing messages; if none, start title generation asynchronously.
246	msgs, err := a.messages.List(ctx, sessionID)
247	if err != nil {
248		return a.err(fmt.Errorf("failed to list messages: %w", err))
249	}
250	if len(msgs) == 0 {
251		go func() {
252			defer logging.RecoverPanic("agent.Run", func() {
253				logging.ErrorPersist("panic while generating title")
254			})
255			titleErr := a.generateTitle(context.Background(), sessionID, content)
256			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
257				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
258			}
259		}()
260	}
261	session, err := a.sessions.Get(ctx, sessionID)
262	if err != nil {
263		return a.err(fmt.Errorf("failed to get session: %w", err))
264	}
265	if session.SummaryMessageID != "" {
266		summaryMsgInex := -1
267		for i, msg := range msgs {
268			if msg.ID == session.SummaryMessageID {
269				summaryMsgInex = i
270				break
271			}
272		}
273		if summaryMsgInex != -1 {
274			msgs = msgs[summaryMsgInex:]
275			msgs[0].Role = message.User
276		}
277	}
278
279	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
280	if err != nil {
281		return a.err(fmt.Errorf("failed to create user message: %w", err))
282	}
283	// Append the new user message to the conversation history.
284	msgHistory := append(msgs, userMsg)
285
286	for {
287		// Check for cancellation before each iteration
288		select {
289		case <-ctx.Done():
290			return a.err(ctx.Err())
291		default:
292			// Continue processing
293		}
294		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
295		if err != nil {
296			if errors.Is(err, context.Canceled) {
297				agentMessage.AddFinish(message.FinishReasonCanceled)
298				a.messages.Update(context.Background(), agentMessage)
299				return a.err(ErrRequestCancelled)
300			}
301			return a.err(fmt.Errorf("failed to process events: %w", err))
302		}
303		logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
304		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
305			// We are not done, we need to respond with the tool response
306			msgHistory = append(msgHistory, agentMessage, *toolResults)
307			continue
308		}
309		return AgentEvent{
310			Type:    AgentEventTypeResponse,
311			Message: agentMessage,
312			Done:    true,
313		}
314	}
315}
316
317func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
318	parts := []message.ContentPart{message.TextContent{Text: content}}
319	parts = append(parts, attachmentParts...)
320	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
321		Role:  message.User,
322		Parts: parts,
323	})
324}
325
326func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
327	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
328
329	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
330		Role:  message.Assistant,
331		Parts: []message.ContentPart{},
332		Model: a.provider.Model().ID,
333	})
334	if err != nil {
335		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
336	}
337
338	// Add the session and message ID into the context if needed by tools.
339	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
340	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
341
342	// Process each event in the stream.
343	for event := range eventChan {
344		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
345			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
346			return assistantMsg, nil, processErr
347		}
348		if ctx.Err() != nil {
349			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
350			return assistantMsg, nil, ctx.Err()
351		}
352	}
353
354	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
355	toolCalls := assistantMsg.ToolCalls()
356	for i, toolCall := range toolCalls {
357		select {
358		case <-ctx.Done():
359			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
360			// Make all future tool calls cancelled
361			for j := i; j < len(toolCalls); j++ {
362				toolResults[j] = message.ToolResult{
363					ToolCallID: toolCalls[j].ID,
364					Content:    "Tool execution canceled by user",
365					IsError:    true,
366				}
367			}
368			goto out
369		default:
370			// Continue processing
371			var tool tools.BaseTool
372			for _, availableTools := range a.tools {
373				if availableTools.Info().Name == toolCall.Name {
374					tool = availableTools
375				}
376			}
377
378			// Tool not found
379			if tool == nil {
380				toolResults[i] = message.ToolResult{
381					ToolCallID: toolCall.ID,
382					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
383					IsError:    true,
384				}
385				continue
386			}
387			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
388				ID:    toolCall.ID,
389				Name:  toolCall.Name,
390				Input: toolCall.Input,
391			})
392			if toolErr != nil {
393				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
394					toolResults[i] = message.ToolResult{
395						ToolCallID: toolCall.ID,
396						Content:    "Permission denied",
397						IsError:    true,
398					}
399					for j := i + 1; j < len(toolCalls); j++ {
400						toolResults[j] = message.ToolResult{
401							ToolCallID: toolCalls[j].ID,
402							Content:    "Tool execution canceled by user",
403							IsError:    true,
404						}
405					}
406					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
407					break
408				}
409			}
410			toolResults[i] = message.ToolResult{
411				ToolCallID: toolCall.ID,
412				Content:    toolResult.Content,
413				Metadata:   toolResult.Metadata,
414				IsError:    toolResult.IsError,
415			}
416		}
417	}
418out:
419	if len(toolResults) == 0 {
420		return assistantMsg, nil, nil
421	}
422	parts := make([]message.ContentPart, 0)
423	for _, tr := range toolResults {
424		parts = append(parts, tr)
425	}
426	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
427		Role:  message.Tool,
428		Parts: parts,
429	})
430	if err != nil {
431		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
432	}
433
434	return assistantMsg, &msg, err
435}
436
437func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
438	msg.AddFinish(finishReson)
439	_ = a.messages.Update(ctx, *msg)
440}
441
442func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
443	select {
444	case <-ctx.Done():
445		return ctx.Err()
446	default:
447		// Continue processing.
448	}
449
450	switch event.Type {
451	case provider.EventThinkingDelta:
452		assistantMsg.AppendReasoningContent(event.Content)
453		return a.messages.Update(ctx, *assistantMsg)
454	case provider.EventContentDelta:
455		assistantMsg.AppendContent(event.Content)
456		return a.messages.Update(ctx, *assistantMsg)
457	case provider.EventToolUseStart:
458		logging.Info("Tool call started", "toolCall", event.ToolCall)
459		assistantMsg.AddToolCall(*event.ToolCall)
460		return a.messages.Update(ctx, *assistantMsg)
461	case provider.EventToolUseDelta:
462		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
463		return a.messages.Update(ctx, *assistantMsg)
464	case provider.EventToolUseStop:
465		logging.Info("Finished tool call", "toolCall", event.ToolCall)
466		assistantMsg.FinishToolCall(event.ToolCall.ID)
467		return a.messages.Update(ctx, *assistantMsg)
468	case provider.EventError:
469		if errors.Is(event.Error, context.Canceled) {
470			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
471			return context.Canceled
472		}
473		logging.ErrorPersist(event.Error.Error())
474		return event.Error
475	case provider.EventComplete:
476		assistantMsg.SetToolCalls(event.Response.ToolCalls)
477		assistantMsg.AddFinish(event.Response.FinishReason)
478		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
479			return fmt.Errorf("failed to update message: %w", err)
480		}
481		return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
482	}
483
484	return nil
485}
486
487func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
488	sess, err := a.sessions.Get(ctx, sessionID)
489	if err != nil {
490		return fmt.Errorf("failed to get session: %w", err)
491	}
492
493	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
494		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
495		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
496		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
497
498	sess.Cost += cost
499	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
500	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
501
502	_, err = a.sessions.Save(ctx, sess)
503	if err != nil {
504		return fmt.Errorf("failed to save session: %w", err)
505	}
506	return nil
507}
508
509func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
510	if a.IsBusy() {
511		return models.Model{}, fmt.Errorf("cannot change model while processing requests")
512	}
513
514	if err := config.UpdateAgentModel(agentName, modelID); err != nil {
515		return models.Model{}, fmt.Errorf("failed to update config: %w", err)
516	}
517
518	provider, err := createAgentProvider(agentName)
519	if err != nil {
520		return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
521	}
522
523	a.provider = provider
524
525	return a.provider.Model(), nil
526}
527
528func (a *agent) Summarize(ctx context.Context, sessionID string) error {
529	if a.summarizeProvider == nil {
530		return fmt.Errorf("summarize provider not available")
531	}
532
533	// Check if session is busy
534	if a.IsSessionBusy(sessionID) {
535		return ErrSessionBusy
536	}
537
538	// Create a new context with cancellation
539	summarizeCtx, cancel := context.WithCancel(ctx)
540
541	// Store the cancel function in activeRequests to allow cancellation
542	a.activeRequests.Store(sessionID+"-summarize", cancel)
543
544	go func() {
545		defer a.activeRequests.Delete(sessionID + "-summarize")
546		defer cancel()
547		event := AgentEvent{
548			Type:     AgentEventTypeSummarize,
549			Progress: "Starting summarization...",
550		}
551
552		a.Publish(pubsub.CreatedEvent, event)
553		// Get all messages from the session
554		msgs, err := a.messages.List(summarizeCtx, sessionID)
555		if err != nil {
556			event = AgentEvent{
557				Type:  AgentEventTypeError,
558				Error: fmt.Errorf("failed to list messages: %w", err),
559				Done:  true,
560			}
561			a.Publish(pubsub.CreatedEvent, event)
562			return
563		}
564
565		if len(msgs) == 0 {
566			event = AgentEvent{
567				Type:  AgentEventTypeError,
568				Error: fmt.Errorf("no messages to summarize"),
569				Done:  true,
570			}
571			a.Publish(pubsub.CreatedEvent, event)
572			return
573		}
574
575		event = AgentEvent{
576			Type:     AgentEventTypeSummarize,
577			Progress: "Analyzing conversation...",
578		}
579		a.Publish(pubsub.CreatedEvent, event)
580
581		// Add a system message to guide the summarization
582		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
583
584		// Create a new message with the summarize prompt
585		promptMsg := message.Message{
586			Role:  message.User,
587			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
588		}
589
590		// Append the prompt to the messages
591		msgsWithPrompt := append(msgs, promptMsg)
592
593		event = AgentEvent{
594			Type:     AgentEventTypeSummarize,
595			Progress: "Generating summary...",
596		}
597
598		a.Publish(pubsub.CreatedEvent, event)
599
600		// Send the messages to the summarize provider
601		response := a.summarizeProvider.StreamResponse(
602			summarizeCtx,
603			msgsWithPrompt,
604			make([]tools.BaseTool, 0),
605		)
606		var finalResponse *provider.ProviderResponse
607		for r := range response {
608			if r.Error != nil {
609				event = AgentEvent{
610					Type:  AgentEventTypeError,
611					Error: fmt.Errorf("failed to summarize: %w", err),
612					Done:  true,
613				}
614				a.Publish(pubsub.CreatedEvent, event)
615				return
616			}
617			finalResponse = r.Response
618		}
619
620		summary := strings.TrimSpace(finalResponse.Content)
621		if summary == "" {
622			event = AgentEvent{
623				Type:  AgentEventTypeError,
624				Error: fmt.Errorf("empty summary returned"),
625				Done:  true,
626			}
627			a.Publish(pubsub.CreatedEvent, event)
628			return
629		}
630		event = AgentEvent{
631			Type:     AgentEventTypeSummarize,
632			Progress: "Creating new session...",
633		}
634
635		a.Publish(pubsub.CreatedEvent, event)
636		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
637		if err != nil {
638			event = AgentEvent{
639				Type:  AgentEventTypeError,
640				Error: fmt.Errorf("failed to get session: %w", err),
641				Done:  true,
642			}
643
644			a.Publish(pubsub.CreatedEvent, event)
645			return
646		}
647		// Create a message in the new session with the summary
648		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
649			Role: message.Assistant,
650			Parts: []message.ContentPart{
651				message.TextContent{Text: summary},
652				message.Finish{
653					Reason: message.FinishReasonEndTurn,
654					Time:   time.Now().Unix(),
655				},
656			},
657			Model: a.summarizeProvider.Model().ID,
658		})
659		if err != nil {
660			event = AgentEvent{
661				Type:  AgentEventTypeError,
662				Error: fmt.Errorf("failed to create summary message: %w", err),
663				Done:  true,
664			}
665
666			a.Publish(pubsub.CreatedEvent, event)
667			return
668		}
669		oldSession.SummaryMessageID = msg.ID
670		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
671		oldSession.PromptTokens = 0
672		model := a.summarizeProvider.Model()
673		usage := finalResponse.Usage
674		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
675			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
676			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
677			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
678		oldSession.Cost += cost
679		_, err = a.sessions.Save(summarizeCtx, oldSession)
680		if err != nil {
681			event = AgentEvent{
682				Type:  AgentEventTypeError,
683				Error: fmt.Errorf("failed to save session: %w", err),
684				Done:  true,
685			}
686			a.Publish(pubsub.CreatedEvent, event)
687		}
688
689		event = AgentEvent{
690			Type:      AgentEventTypeSummarize,
691			SessionID: oldSession.ID,
692			Progress:  "Summary complete",
693			Done:      true,
694		}
695		a.Publish(pubsub.CreatedEvent, event)
696		// Send final success event with the new session ID
697	}()
698
699	return nil
700}
701
702func (a *agent) CancelAll() {
703	a.activeRequests.Range(func(key, value any) bool {
704		a.Cancel(key.(string)) // key is sessionID
705		return true
706	})
707}
708
709func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
710	cfg := config.Get()
711	agentConfig, ok := cfg.Agents[agentName]
712	if !ok {
713		return nil, fmt.Errorf("agent %s not found", agentName)
714	}
715	model, ok := models.SupportedModels[agentConfig.Model]
716	if !ok {
717		return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
718	}
719
720	providerCfg, ok := cfg.Providers[model.Provider]
721	if !ok {
722		return nil, fmt.Errorf("provider %s not supported", model.Provider)
723	}
724	if providerCfg.Disabled {
725		return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
726	}
727	maxTokens := model.DefaultMaxTokens
728	if agentConfig.MaxTokens > 0 {
729		maxTokens = agentConfig.MaxTokens
730	}
731	opts := []provider.ProviderClientOption{
732		provider.WithAPIKey(providerCfg.APIKey),
733		provider.WithModel(model),
734		provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
735		provider.WithMaxTokens(maxTokens),
736	}
737	if (model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal) && model.CanReason {
738		opts = append(
739			opts,
740			provider.WithOpenAIOptions(
741				provider.WithReasoningEffort(agentConfig.ReasoningEffort),
742			),
743		)
744	} else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
745		opts = append(
746			opts,
747			provider.WithAnthropicOptions(
748				provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
749			),
750		)
751	}
752	agentProvider, err := provider.NewProvider(
753		model.Provider,
754		opts...,
755	)
756	if err != nil {
757		return nil, fmt.Errorf("could not create provider: %v", err)
758	}
759
760	return agentProvider, nil
761}