agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"time"
 10
 11	"github.com/charmbracelet/catwalk/pkg/catwalk"
 12	"github.com/charmbracelet/crush/internal/agent/prompt"
 13	"github.com/charmbracelet/crush/internal/agent/tools"
 14	"github.com/charmbracelet/crush/internal/ai"
 15	"github.com/charmbracelet/crush/internal/ai/providers"
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18	"github.com/charmbracelet/crush/internal/history"
 19	"github.com/charmbracelet/crush/internal/lsp"
 20	"github.com/charmbracelet/crush/internal/message"
 21	"github.com/charmbracelet/crush/internal/permission"
 22	"github.com/charmbracelet/crush/internal/pubsub"
 23	"github.com/charmbracelet/crush/internal/session"
 24)
 25
// Common errors
//
// ErrRequestCancelled carries the message for a request canceled by the user.
// ErrSessionBusy carries the message for a session that is already processing
// another request.
//
// NOTE(review): neither sentinel is referenced by the code visible in this
// file; confirm where callers are expected to receive them.
var (
	ErrRequestCancelled = errors.New("request canceled by user")
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
 31
// AgentEventType is the string tag stored in AgentEvent.Type to tell event
// kinds apart.
type AgentEventType string

// Known event kinds. In this file, Run emits AgentEventTypeError and
// AgentEventTypeResponse; AgentEventTypeSummarize is declared but not emitted
// by the code visible here.
const (
	AgentEventTypeError     AgentEventType = "error"
	AgentEventTypeResponse  AgentEventType = "response"
	AgentEventTypeSummarize AgentEventType = "summarize"
)
 39
// AgentEvent is the value sent on the channel returned by Run.
//
// Run fills Type and Error when a request fails, and Type and Result when it
// succeeds; successful events are additionally published through the agent's
// broker. SessionID, Progress and Done are meant for summarization and are
// not populated by the code visible in this file.
type AgentEvent struct {
	Type   AgentEventType
	Result ai.AgentResult
	Error  error

	// When summarizing
	SessionID string
	Progress  string
	Done      bool
}
 50
// Service is the agent API. It embeds the pubsub subscriber interface for
// AgentEvent values. The per-method notes below describe what the
// implementation in this file does.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model configured for the large model type.
	Model() catwalk.Model
	// Run submits content (and optional attachments) for a session. It
	// returns a channel of events when a request is started, or a nil channel
	// and nil error when the session is busy and the prompt was queued.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the session's active request (and its "-summarize"
	// request, if any) and drops the session's queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// them to finish.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize is currently a stub that returns nil.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel is currently a stub that returns nil.
	UpdateModel() error
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
}
 64
// agent is the Service implementation in this file.
//
// The embedded broker is used to publish AgentEvent values. cfg, permissions,
// sessions, messages, history and lspClients are the dependencies handed to
// NewAgent; they are passed on to the tools and used to load and persist
// conversation state. activeRequests maps a session ID to the cancel function
// of its in-flight request; an entry exists only while a request is running.
type agent struct {
	*pubsub.Broker[AgentEvent]
	cfg            *config.Config
	permissions    permission.Service
	sessions       session.Service
	messages       message.Service
	history        history.Service
	lspClients     map[string]*lsp.Client
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds, per session ID, the prompts submitted while that
	// session was busy.
	promptQueue *csync.Map[string, []string]
}
 77
// AgentOption is a function that receives the agent to adjust it.
// NOTE(review): NewAgent in this file does not accept options yet.
type AgentOption = func(*agent)
 79
 80// WIP this is a work in progress
 81func NewAgent(
 82	cfg *config.Config,
 83	permissions permission.Service,
 84	sessions session.Service,
 85	messages message.Service,
 86	history history.Service,
 87	lspClients map[string]*lsp.Client,
 88) Service {
 89	return &agent{
 90		cfg:            cfg,
 91		Broker:         pubsub.NewBroker[AgentEvent](),
 92		permissions:    permissions,
 93		sessions:       sessions,
 94		messages:       messages,
 95		history:        history,
 96		lspClients:     lspClients,
 97		activeRequests: csync.NewMap[string, context.CancelFunc](),
 98		promptQueue:    csync.NewMap[string, []string](),
 99	}
100}
101
102func (a *agent) getLanguageModel(providerName, modelID string) (ai.LanguageModel, error) {
103	var provider ai.Provider
104	providerCfg, ok := a.cfg.Providers.Get(providerName)
105	if !ok {
106		return nil, errors.New("provider not found")
107	}
108
109	models := providerCfg.Models
110	foundModel := false
111	for _, providerModel := range models {
112		if providerModel.ID == modelID {
113			foundModel = true
114			break
115		}
116	}
117	if !foundModel {
118		return nil, fmt.Errorf("model `%s` not found in provider `%s`", modelID, providerName)
119	}
120	switch providerName {
121	case "openai":
122		apiKey, err := a.cfg.Resolve(providerCfg.APIKey)
123		if err != nil {
124			return nil, err
125		}
126		baseURL, err := a.cfg.Resolve(providerCfg.BaseURL)
127		if err != nil {
128			return nil, err
129		}
130		opts := []providers.OpenAiOption{
131			providers.WithOpenAiAPIKey(apiKey),
132		}
133		if baseURL != "" {
134			opts = append(opts, providers.WithOpenAiBaseURL(baseURL))
135		}
136		provider = providers.NewOpenAiProvider(opts...)
137	default:
138		return nil, errors.New("provider not found")
139	}
140	if provider == nil {
141		return nil, errors.New("provider not found")
142	}
143	return provider.LanguageModel(modelID)
144}
145
146func (a *agent) tools(ctx context.Context) []ai.AgentTool {
147	cwd := a.cfg.WorkingDir()
148	allTools := []ai.AgentTool{
149		tools.NewBashTool(a.permissions, cwd),
150		tools.NewDownloadTool(a.permissions, cwd),
151		tools.NewEditTool(a.lspClients, a.permissions, a.history, cwd),
152		tools.NewMultiEditTool(a.lspClients, a.permissions, a.history, cwd),
153		tools.NewFetchTool(a.permissions, cwd),
154		tools.NewGlobTool(cwd),
155		tools.NewGrepTool(cwd),
156		tools.NewLSTool(a.permissions, cwd),
157		tools.NewSourcegraphTool(),
158		tools.NewViewTool(a.lspClients, a.permissions, cwd),
159		tools.NewWriteTool(a.lspClients, a.permissions, a.history, cwd),
160	}
161	mcpTools := tools.GetMCPTools(ctx, a.permissions, a.cfg)
162
163	allTools = append(allTools, mcpTools...)
164
165	if len(a.lspClients) > 0 {
166		allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients))
167	}
168	// TODO: add agent tool
169	return allTools
170}
171
172// Run implements Service.
173func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
174	// INFO: for now we assume that the agent uses the large model
175	configModel := a.cfg.Models[config.SelectedModelTypeLarge]
176	model, err := a.getLanguageModel(configModel.Provider, configModel.Model)
177	if err != nil {
178		return nil, err
179	}
180
181	modelCfg := a.Model()
182	maxTokens := configModel.MaxTokens
183	if maxTokens == 0 {
184		maxTokens = modelCfg.DefaultMaxTokens
185	}
186
187	if !modelCfg.SupportsImages && attachments != nil {
188		attachments = nil
189	}
190
191	agent := ai.NewAgent(
192		model,
193		ai.WithSystemPrompt(
194			prompt.CoderPrompt(configModel.Provider, a.cfg.Options.ContextPaths...),
195		),
196		ai.WithTools(a.tools(ctx)...),
197		ai.WithMaxOutputTokens(maxTokens),
198	)
199
200	events := make(chan AgentEvent, 1)
201	if a.IsSessionBusy(sessionID) {
202		existing, ok := a.promptQueue.Get(sessionID)
203		if !ok {
204			existing = []string{}
205		}
206		existing = append(existing, content)
207		a.promptQueue.Set(sessionID, existing)
208		return nil, nil
209	}
210
211	genCtx, cancel := context.WithCancel(ctx)
212	a.activeRequests.Set(sessionID, cancel)
213
214	go func() {
215		slog.Debug("Request started", "sessionID", sessionID)
216
217		result, err := a.makeCall(genCtx, agent, sessionID, content, attachments)
218		a.activeRequests.Del(sessionID)
219		cancel()
220		if err != nil {
221			slog.Error(err.Error())
222			events <- AgentEvent{
223				Type:  AgentEventTypeError,
224				Error: err,
225			}
226		} else {
227			result := AgentEvent{
228				Type:   AgentEventTypeResponse,
229				Result: *result,
230			}
231			a.Publish(pubsub.CreatedEvent, result)
232			events <- result
233		}
234		slog.Debug("Request completed", "sessionID", sessionID)
235		// TODO: implement this
236		close(events)
237	}()
238	return events, nil
239}
240
241func (a *agent) makeCall(ctx context.Context, agent ai.Agent, sessionID, prompt string, attachments []message.Attachment) (*ai.AgentResult, error) {
242	msgs, err := a.messages.List(ctx, sessionID)
243	if err != nil {
244		return nil, fmt.Errorf("failed to list messages: %w", err)
245	}
246	if len(msgs) == 0 {
247		go func() {
248			// TODO: generate title
249			// titleErr := a.generateTitle(context.Background(), sessionID, content)
250			// if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
251			// 	slog.Error("failed to generate title", "error", titleErr)
252			// }
253		}()
254	}
255	session, err := a.sessions.Get(ctx, sessionID)
256	if err != nil {
257		return nil, fmt.Errorf("failed to get session: %w", err)
258	}
259	if session.SummaryMessageID != "" {
260		summaryMsgInex := -1
261		for i, msg := range msgs {
262			if msg.ID == session.SummaryMessageID {
263				summaryMsgInex = i
264				break
265			}
266		}
267		if summaryMsgInex != -1 {
268			msgs = msgs[summaryMsgInex:]
269			msgs[0].Role = message.User
270		}
271	}
272
273	// Create the user message
274	var attachmentParts []message.ContentPart
275	for _, attachment := range attachments {
276		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
277	}
278	parts := []message.ContentPart{message.TextContent{Text: prompt}}
279	parts = append(parts, attachmentParts...)
280	_, err = a.messages.Create(ctx, sessionID, message.CreateMessageParams{
281		Role:  message.User,
282		Parts: parts,
283	})
284	if err != nil {
285		return nil, fmt.Errorf("failed to create user message: %w", err)
286	}
287
288	var history []ai.Message
289	for _, m := range msgs {
290		if len(m.Parts) == 0 {
291			continue
292		}
293		// Assistant message without content or tool calls (cancelled before it returned anything)
294		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
295			continue
296		}
297		history = append(history, m.ToAIMessage()...)
298	}
299
300	var files []ai.FilePart
301	for _, attachment := range attachments {
302		files = append(files, ai.FilePart{
303			Filename:  attachment.FileName,
304			Data:      attachment.Content,
305			MediaType: attachment.MimeType,
306		})
307	}
308	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
309	// TODO: see if this is even needed
310	ctx = context.WithValue(ctx, tools.MessageIDContextKey, "mock")
311
312	var currentAssistant *message.Message
313	result, err := agent.Stream(ctx, ai.AgentStreamCall{
314		Prompt:   prompt,
315		Files:    files,
316		Messages: history,
317		// Get's called before each step
318		PrepareStep: func(options ai.PrepareStepFunctionOptions) (ai.PrepareStepResult, error) {
319			prepared := ai.PrepareStepResult{}
320			modelCfg := a.cfg.Models[config.SelectedModelTypeLarge]
321			// Before each step create the new assistant message
322			assistantMsg, createErr := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
323				Role:     message.Assistant,
324				Parts:    []message.ContentPart{},
325				Model:    modelCfg.Model,
326				Provider: modelCfg.Provider,
327			})
328			if createErr != nil {
329				return prepared, createErr
330			}
331			currentAssistant = &assistantMsg
332			return prepared, nil
333		},
334		OnChunk: func(chunk ai.StreamPart) error {
335			data, _ := json.Marshal(chunk)
336			slog.Info("\n" + string(data) + "\n")
337			return nil
338		},
339		// TODO: see how to not swallow the errors on these handlers
340		OnReasoningDelta: func(id string, text string) error {
341			currentAssistant.AppendReasoningContent(text)
342			return a.messages.Update(ctx, *currentAssistant)
343		},
344		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
345			// handle anthropic signature
346			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
347				if signature, ok := anthropicData["signature"]; ok {
348					if str, ok := signature.(string); ok {
349						currentAssistant.AppendReasoningSignature(str)
350					}
351				}
352			}
353			currentAssistant.FinishThinking()
354			return a.messages.Update(ctx, *currentAssistant)
355		},
356		OnTextDelta: func(id string, text string) error {
357			currentAssistant.AppendContent(text)
358			return a.messages.Update(ctx, *currentAssistant)
359		},
360		OnToolInputStart: func(id string, toolName string) error {
361			toolCall := message.ToolCall{
362				ID:               id,
363				Name:             toolName,
364				ProviderExecuted: false,
365				Finished:         false,
366			}
367			slog.Info("Tool call started", "toolCall", toolName)
368			currentAssistant.AddToolCall(toolCall)
369			return a.messages.Update(ctx, *currentAssistant)
370		},
371		OnToolCall: func(tc ai.ToolCallContent) error {
372			toolCall := message.ToolCall{
373				ID:               tc.ToolCallID,
374				Name:             tc.ToolName,
375				Input:            tc.Input,
376				ProviderExecuted: false,
377				Finished:         true,
378			}
379			currentAssistant.AddToolCall(toolCall)
380			return a.messages.Update(ctx, *currentAssistant)
381		},
382		OnToolResult: func(result ai.ToolResultContent) error {
383			var resultContent string
384			isError := false
385			switch result.Result.GetType() {
386			case ai.ToolResultContentTypeText:
387				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
388				if ok {
389					resultContent = r.Text
390				}
391			case ai.ToolResultContentTypeError:
392				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
393				if ok {
394					isError = true
395					resultContent = r.Error.Error()
396				}
397			case ai.ToolResultContentTypeMedia:
398				// TODO: handle this message type
399			}
400			toolResult := message.ToolResult{
401				ToolCallID: result.ToolCallID,
402				Name:       result.ToolName,
403				Content:    resultContent,
404				IsError:    isError,
405				Metadata:   result.ClientMetadata,
406			}
407			currentAssistant.AddToolResult(toolResult)
408			return a.messages.Update(ctx, *currentAssistant)
409		},
410		OnStepFinish: func(stepResult ai.StepResult) error {
411			slog.Info("Step Finished", "result", stepResult)
412			finishReason := message.FinishReasonUnknown
413			switch stepResult.FinishReason {
414			case ai.FinishReasonLength:
415				finishReason = message.FinishReasonMaxTokens
416			case ai.FinishReasonStop:
417				finishReason = message.FinishReasonEndTurn
418			case ai.FinishReasonToolCalls:
419				finishReason = message.FinishReasonToolUse
420			}
421			currentAssistant.AddFinish(finishReason, "", "")
422			return a.messages.Update(ctx, *currentAssistant)
423		},
424	})
425	if err != nil {
426		isCancelErr := errors.Is(err, context.Canceled)
427		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
428		if currentAssistant == nil {
429			return result, err
430		}
431		toolCalls := currentAssistant.ToolCalls()
432		toolResults := currentAssistant.ToolResults()
433		for _, tc := range toolCalls {
434			if !tc.Finished {
435				tc.Finished = true
436				tc.Input = "{}"
437			}
438			currentAssistant.AddToolCall(tc)
439			found := false
440			for _, tr := range toolResults {
441				if tr.ToolCallID == tc.ID {
442					found = true
443					break
444				}
445			}
446			if !found {
447				content := "There was an error while executing the tool"
448				if isCancelErr {
449					content = "Tool execution canceled by user"
450				} else if isPermissionErr {
451					content = "Permission denied"
452				}
453				currentAssistant.AddToolResult(message.ToolResult{
454					ToolCallID: tc.ID,
455					Name:       tc.Name,
456					Content:    content,
457					IsError:    true,
458				})
459			}
460		}
461		if isCancelErr {
462			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
463		} else if isPermissionErr {
464			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
465		} else {
466			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
467		}
468		// TODO: handle error?
469		_ = a.messages.Update(context.Background(), *currentAssistant)
470	}
471	return result, err
472}
473
// Summarize implements Service.
//
// It is currently a stub: ctx and sessionID are ignored and nil is always
// returned.
func (a *agent) Summarize(ctx context.Context, sessionID string) error {
	// TODO: implement
	return nil
}
479
// UpdateModel implements Service.
//
// It is currently a no-op that always returns nil.
func (a *agent) UpdateModel() error {
	return nil
}
484
485func (a *agent) Cancel(sessionID string) {
486	// Cancel regular requests
487	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
488		slog.Info("Request cancellation initiated", "session_id", sessionID)
489		cancel()
490	}
491
492	// Also check for summarize requests
493	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
494		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
495		cancel()
496	}
497
498	if a.QueuedPrompts(sessionID) > 0 {
499		slog.Info("Clearing queued prompts", "session_id", sessionID)
500		a.promptQueue.Del(sessionID)
501	}
502}
503
504func (a *agent) ClearQueue(sessionID string) {
505	if a.QueuedPrompts(sessionID) > 0 {
506		slog.Info("Clearing queued prompts", "session_id", sessionID)
507		a.promptQueue.Del(sessionID)
508	}
509}
510
511func (a *agent) CancelAll() {
512	if !a.IsBusy() {
513		return
514	}
515	for key := range a.activeRequests.Seq2() {
516		a.Cancel(key) // key is sessionID
517	}
518
519	timeout := time.After(5 * time.Second)
520	for a.IsBusy() {
521		select {
522		case <-timeout:
523			return
524		default:
525			time.Sleep(200 * time.Millisecond)
526		}
527	}
528}
529
530func (a *agent) IsBusy() bool {
531	var busy bool
532	for cancelFunc := range a.activeRequests.Seq() {
533		if cancelFunc != nil {
534			busy = true
535			break
536		}
537	}
538	return busy
539}
540
541func (a *agent) IsSessionBusy(sessionID string) bool {
542	_, busy := a.activeRequests.Get(sessionID)
543	return busy
544}
545
546func (a *agent) Model() catwalk.Model {
547	return *a.cfg.GetModelByType(config.SelectedModelTypeLarge)
548}
549
550func (a *agent) QueuedPrompts(sessionID string) int {
551	l, ok := a.promptQueue.Get(sessionID)
552	if !ok {
553		return 0
554	}
555	return len(l)
556}