agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strings"
  8	"sync"
  9
 10	"github.com/opencode-ai/opencode/internal/config"
 11	"github.com/opencode-ai/opencode/internal/llm/models"
 12	"github.com/opencode-ai/opencode/internal/llm/prompt"
 13	"github.com/opencode-ai/opencode/internal/llm/provider"
 14	"github.com/opencode-ai/opencode/internal/llm/tools"
 15	"github.com/opencode-ai/opencode/internal/logging"
 16	"github.com/opencode-ai/opencode/internal/message"
 17	"github.com/opencode-ai/opencode/internal/permission"
 18	"github.com/opencode-ai/opencode/internal/session"
 19)
 20
 21// Common errors
 22var (
 23	ErrRequestCancelled = errors.New("request cancelled by user")
 24	ErrSessionBusy      = errors.New("session is currently processing another request")
 25)
 26
 27type AgentEvent struct {
 28	message message.Message
 29	err     error
 30}
 31
 32func (e *AgentEvent) Err() error {
 33	return e.err
 34}
 35
 36func (e *AgentEvent) Response() message.Message {
 37	return e.message
 38}
 39
 40type Service interface {
 41	Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
 42	Cancel(sessionID string)
 43	IsSessionBusy(sessionID string) bool
 44	IsBusy() bool
 45}
 46
 47type agent struct {
 48	sessions session.Service
 49	messages message.Service
 50
 51	tools    []tools.BaseTool
 52	provider provider.Provider
 53
 54	titleProvider provider.Provider
 55
 56	activeRequests sync.Map
 57}
 58
 59func NewAgent(
 60	agentName config.AgentName,
 61	sessions session.Service,
 62	messages message.Service,
 63	agentTools []tools.BaseTool,
 64) (Service, error) {
 65	agentProvider, err := createAgentProvider(agentName)
 66	if err != nil {
 67		return nil, err
 68	}
 69	var titleProvider provider.Provider
 70	// Only generate titles for the coder agent
 71	if agentName == config.AgentCoder {
 72		titleProvider, err = createAgentProvider(config.AgentTitle)
 73		if err != nil {
 74			return nil, err
 75		}
 76	}
 77
 78	agent := &agent{
 79		provider:       agentProvider,
 80		messages:       messages,
 81		sessions:       sessions,
 82		tools:          agentTools,
 83		titleProvider:  titleProvider,
 84		activeRequests: sync.Map{},
 85	}
 86
 87	return agent, nil
 88}
 89
 90func (a *agent) Cancel(sessionID string) {
 91	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
 92		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
 93			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
 94			cancel()
 95		}
 96	}
 97}
 98
 99func (a *agent) IsBusy() bool {
100	busy := false
101	a.activeRequests.Range(func(key, value interface{}) bool {
102		if cancelFunc, ok := value.(context.CancelFunc); ok {
103			if cancelFunc != nil {
104				busy = true
105				return false // Stop iterating
106			}
107		}
108		return true // Continue iterating
109	})
110	return busy
111}
112
113func (a *agent) IsSessionBusy(sessionID string) bool {
114	_, busy := a.activeRequests.Load(sessionID)
115	return busy
116}
117
118func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
119	if a.titleProvider == nil {
120		return nil
121	}
122	session, err := a.sessions.Get(ctx, sessionID)
123	if err != nil {
124		return err
125	}
126	response, err := a.titleProvider.SendMessages(
127		ctx,
128		[]message.Message{
129			{
130				Role: message.User,
131				Parts: []message.ContentPart{
132					message.TextContent{
133						Text: content,
134					},
135				},
136			},
137		},
138		make([]tools.BaseTool, 0),
139	)
140	if err != nil {
141		return err
142	}
143
144	title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
145	if title == "" {
146		return nil
147	}
148
149	session.Title = title
150	_, err = a.sessions.Save(ctx, session)
151	return err
152}
153
154func (a *agent) err(err error) AgentEvent {
155	return AgentEvent{
156		err: err,
157	}
158}
159
160func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
161	events := make(chan AgentEvent)
162	if a.IsSessionBusy(sessionID) {
163		return nil, ErrSessionBusy
164	}
165
166	genCtx, cancel := context.WithCancel(ctx)
167
168	a.activeRequests.Store(sessionID, cancel)
169	go func() {
170		logging.Debug("Request started", "sessionID", sessionID)
171		defer logging.RecoverPanic("agent.Run", func() {
172			events <- a.err(fmt.Errorf("panic while running the agent"))
173		})
174
175		result := a.processGeneration(genCtx, sessionID, content)
176		if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
177			logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result))
178		}
179		logging.Debug("Request completed", "sessionID", sessionID)
180		a.activeRequests.Delete(sessionID)
181		cancel()
182		events <- result
183		close(events)
184	}()
185	return events, nil
186}
187
188func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
189	// List existing messages; if none, start title generation asynchronously.
190	msgs, err := a.messages.List(ctx, sessionID)
191	if err != nil {
192		return a.err(fmt.Errorf("failed to list messages: %w", err))
193	}
194	if len(msgs) == 0 {
195		go func() {
196			defer logging.RecoverPanic("agent.Run", func() {
197				logging.ErrorPersist("panic while generating title")
198			})
199			titleErr := a.generateTitle(context.Background(), sessionID, content)
200			if titleErr != nil {
201				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
202			}
203		}()
204	}
205
206	userMsg, err := a.createUserMessage(ctx, sessionID, content)
207	if err != nil {
208		return a.err(fmt.Errorf("failed to create user message: %w", err))
209	}
210
211	// Append the new user message to the conversation history.
212	msgHistory := append(msgs, userMsg)
213	for {
214		// Check for cancellation before each iteration
215		select {
216		case <-ctx.Done():
217			return a.err(ctx.Err())
218		default:
219			// Continue processing
220		}
221		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
222		if err != nil {
223			if errors.Is(err, context.Canceled) {
224				agentMessage.AddFinish(message.FinishReasonCanceled)
225				a.messages.Update(context.Background(), agentMessage)
226				return a.err(ErrRequestCancelled)
227			}
228			return a.err(fmt.Errorf("failed to process events: %w", err))
229		}
230		logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
231		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
232			// We are not done, we need to respond with the tool response
233			msgHistory = append(msgHistory, agentMessage, *toolResults)
234			continue
235		}
236		return AgentEvent{
237			message: agentMessage,
238		}
239	}
240}
241
242func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
243	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
244		Role: message.User,
245		Parts: []message.ContentPart{
246			message.TextContent{Text: content},
247		},
248	})
249}
250
251func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
252	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
253
254	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
255		Role:  message.Assistant,
256		Parts: []message.ContentPart{},
257		Model: a.provider.Model().ID,
258	})
259	if err != nil {
260		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
261	}
262
263	// Add the session and message ID into the context if needed by tools.
264	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
265	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
266
267	// Process each event in the stream.
268	for event := range eventChan {
269		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
270			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
271			return assistantMsg, nil, processErr
272		}
273		if ctx.Err() != nil {
274			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
275			return assistantMsg, nil, ctx.Err()
276		}
277	}
278
279	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
280	toolCalls := assistantMsg.ToolCalls()
281	for i, toolCall := range toolCalls {
282		select {
283		case <-ctx.Done():
284			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
285			// Make all future tool calls cancelled
286			for j := i; j < len(toolCalls); j++ {
287				toolResults[j] = message.ToolResult{
288					ToolCallID: toolCalls[j].ID,
289					Content:    "Tool execution canceled by user",
290					IsError:    true,
291				}
292			}
293			goto out
294		default:
295			// Continue processing
296			var tool tools.BaseTool
297			for _, availableTools := range a.tools {
298				if availableTools.Info().Name == toolCall.Name {
299					tool = availableTools
300				}
301			}
302
303			// Tool not found
304			if tool == nil {
305				toolResults[i] = message.ToolResult{
306					ToolCallID: toolCall.ID,
307					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
308					IsError:    true,
309				}
310				continue
311			}
312
313			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
314				ID:    toolCall.ID,
315				Name:  toolCall.Name,
316				Input: toolCall.Input,
317			})
318			if toolErr != nil {
319				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
320					toolResults[i] = message.ToolResult{
321						ToolCallID: toolCall.ID,
322						Content:    "Permission denied",
323						IsError:    true,
324					}
325					for j := i + 1; j < len(toolCalls); j++ {
326						toolResults[j] = message.ToolResult{
327							ToolCallID: toolCalls[j].ID,
328							Content:    "Tool execution canceled by user",
329							IsError:    true,
330						}
331					}
332					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
333					break
334				}
335			}
336			toolResults[i] = message.ToolResult{
337				ToolCallID: toolCall.ID,
338				Content:    toolResult.Content,
339				Metadata:   toolResult.Metadata,
340				IsError:    toolResult.IsError,
341			}
342		}
343	}
344out:
345	if len(toolResults) == 0 {
346		return assistantMsg, nil, nil
347	}
348	parts := make([]message.ContentPart, 0)
349	for _, tr := range toolResults {
350		parts = append(parts, tr)
351	}
352	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
353		Role:  message.Tool,
354		Parts: parts,
355	})
356	if err != nil {
357		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
358	}
359
360	return assistantMsg, &msg, err
361}
362
363func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
364	msg.AddFinish(finishReson)
365	_ = a.messages.Update(ctx, *msg)
366}
367
368func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
369	select {
370	case <-ctx.Done():
371		return ctx.Err()
372	default:
373		// Continue processing.
374	}
375
376	switch event.Type {
377	case provider.EventThinkingDelta:
378		assistantMsg.AppendReasoningContent(event.Content)
379		return a.messages.Update(ctx, *assistantMsg)
380	case provider.EventContentDelta:
381		assistantMsg.AppendContent(event.Content)
382		return a.messages.Update(ctx, *assistantMsg)
383	case provider.EventToolUseStart:
384		assistantMsg.AddToolCall(*event.ToolCall)
385		return a.messages.Update(ctx, *assistantMsg)
386	// TODO: see how to handle this
387	// case provider.EventToolUseDelta:
388	// 	tm := time.Unix(assistantMsg.UpdatedAt, 0)
389	// 	assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
390	// 	if time.Since(tm) > 1000*time.Millisecond {
391	// 		err := a.messages.Update(ctx, *assistantMsg)
392	// 		assistantMsg.UpdatedAt = time.Now().Unix()
393	// 		return err
394	// 	}
395	case provider.EventToolUseStop:
396		assistantMsg.FinishToolCall(event.ToolCall.ID)
397		return a.messages.Update(ctx, *assistantMsg)
398	case provider.EventError:
399		if errors.Is(event.Error, context.Canceled) {
400			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
401			return context.Canceled
402		}
403		logging.ErrorPersist(event.Error.Error())
404		return event.Error
405	case provider.EventComplete:
406		assistantMsg.SetToolCalls(event.Response.ToolCalls)
407		assistantMsg.AddFinish(event.Response.FinishReason)
408		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
409			return fmt.Errorf("failed to update message: %w", err)
410		}
411		return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
412	}
413
414	return nil
415}
416
417func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
418	sess, err := a.sessions.Get(ctx, sessionID)
419	if err != nil {
420		return fmt.Errorf("failed to get session: %w", err)
421	}
422
423	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
424		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
425		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
426		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
427
428	sess.Cost += cost
429	sess.CompletionTokens += usage.OutputTokens
430	sess.PromptTokens += usage.InputTokens
431
432	_, err = a.sessions.Save(ctx, sess)
433	if err != nil {
434		return fmt.Errorf("failed to save session: %w", err)
435	}
436	return nil
437}
438
439func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
440	cfg := config.Get()
441	agentConfig, ok := cfg.Agents[agentName]
442	if !ok {
443		return nil, fmt.Errorf("agent %s not found", agentName)
444	}
445	model, ok := models.SupportedModels[agentConfig.Model]
446	if !ok {
447		return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
448	}
449
450	providerCfg, ok := cfg.Providers[model.Provider]
451	if !ok {
452		return nil, fmt.Errorf("provider %s not supported", model.Provider)
453	}
454	if providerCfg.Disabled {
455		return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
456	}
457	maxTokens := model.DefaultMaxTokens
458	if agentConfig.MaxTokens > 0 {
459		maxTokens = agentConfig.MaxTokens
460	}
461	opts := []provider.ProviderClientOption{
462		provider.WithAPIKey(providerCfg.APIKey),
463		provider.WithModel(model),
464		provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
465		provider.WithMaxTokens(maxTokens),
466	}
467	if model.Provider == models.ProviderOpenAI && model.CanReason {
468		opts = append(
469			opts,
470			provider.WithOpenAIOptions(
471				provider.WithReasoningEffort(agentConfig.ReasoningEffort),
472			),
473		)
474	} else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
475		opts = append(
476			opts,
477			provider.WithAnthropicOptions(
478				provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
479			),
480		)
481	}
482	agentProvider, err := provider.NewProvider(
483		model.Provider,
484		opts...,
485	)
486	if err != nil {
487		return nil, fmt.Errorf("could not create provider: %v", err)
488	}
489
490	return agentProvider, nil
491}