agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strings"
  8	"sync"
  9
 10	"github.com/kujtimiihoxha/opencode/internal/config"
 11	"github.com/kujtimiihoxha/opencode/internal/llm/models"
 12	"github.com/kujtimiihoxha/opencode/internal/llm/prompt"
 13	"github.com/kujtimiihoxha/opencode/internal/llm/provider"
 14	"github.com/kujtimiihoxha/opencode/internal/llm/tools"
 15	"github.com/kujtimiihoxha/opencode/internal/logging"
 16	"github.com/kujtimiihoxha/opencode/internal/message"
 17	"github.com/kujtimiihoxha/opencode/internal/permission"
 18	"github.com/kujtimiihoxha/opencode/internal/session"
 19)
 20
// Common errors
var (
	// ErrRequestCancelled is the error carried by the final AgentEvent when
	// response streaming fails with context.Canceled (see processGeneration).
	ErrRequestCancelled = errors.New("request cancelled by user")
	// ErrSessionBusy is returned by Run when the session already has an
	// active request registered.
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
 26
// AgentEvent is the value delivered on the channel returned by Service.Run.
// It holds either the message produced by a generation or the error that
// ended it; read them through Response and Err.
type AgentEvent struct {
	// message is the final assistant message of a successful generation.
	message message.Message
	// err is non-nil when the generation failed or was cancelled.
	err     error
}
 31
// Err returns the error recorded on the event, or nil when none was set.
func (e *AgentEvent) Err() error {
	return e.err
}
 35
// Response returns the message carried by the event. Events built by the
// agent's err helper set only the error, so the message is its zero value
// in that case; check Err first.
func (e *AgentEvent) Response() message.Message {
	return e.message
}
 39
// Service is the agent API exposed to the rest of the application.
type Service interface {
	// Run starts processing content for the given session in the background
	// and returns a channel on which the outcome is delivered. It returns
	// ErrSessionBusy when the session already has a request in flight.
	Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
	// Cancel cancels the in-flight request of the session, if there is one.
	Cancel(sessionID string)
	// IsSessionBusy reports whether the session has a request in flight.
	IsSessionBusy(sessionID string) bool
}
 45
// agent is the Service implementation returned by NewAgent.
type agent struct {
	// sessions is used to read and save session records (title, cost, token counts).
	sessions session.Service
	// messages is used to list, create and update the messages of a session.
	messages message.Service

	// tools are offered to the provider and executed when the model requests them.
	tools    []tools.BaseTool
	// provider streams the model responses for this agent.
	provider provider.Provider

	// titleProvider generates session titles; NewAgent leaves it nil for
	// every agent other than the coder agent.
	titleProvider provider.Provider

	// activeRequests maps a session ID to the context.CancelFunc of its
	// in-flight request; an entry exists only while a request is running.
	activeRequests sync.Map
}
 57
 58func NewAgent(
 59	agentName config.AgentName,
 60	sessions session.Service,
 61	messages message.Service,
 62	agentTools []tools.BaseTool,
 63) (Service, error) {
 64	agentProvider, err := createAgentProvider(agentName)
 65	if err != nil {
 66		return nil, err
 67	}
 68	var titleProvider provider.Provider
 69	// Only generate titles for the coder agent
 70	if agentName == config.AgentCoder {
 71		titleProvider, err = createAgentProvider(config.AgentTitle)
 72		if err != nil {
 73			return nil, err
 74		}
 75	}
 76
 77	agent := &agent{
 78		provider:       agentProvider,
 79		messages:       messages,
 80		sessions:       sessions,
 81		tools:          agentTools,
 82		titleProvider:  titleProvider,
 83		activeRequests: sync.Map{},
 84	}
 85
 86	return agent, nil
 87}
 88
 89func (a *agent) Cancel(sessionID string) {
 90	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
 91		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
 92			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
 93			cancel()
 94		}
 95	}
 96}
 97
// IsSessionBusy reports whether activeRequests currently holds an entry for
// sessionID, i.e. whether a request for that session is in flight.
func (a *agent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Load(sessionID)
	return busy
}
102
103func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
104	if a.titleProvider == nil {
105		return nil
106	}
107	session, err := a.sessions.Get(ctx, sessionID)
108	if err != nil {
109		return err
110	}
111	response, err := a.titleProvider.SendMessages(
112		ctx,
113		[]message.Message{
114			{
115				Role: message.User,
116				Parts: []message.ContentPart{
117					message.TextContent{
118						Text: content,
119					},
120				},
121			},
122		},
123		make([]tools.BaseTool, 0),
124	)
125	if err != nil {
126		return err
127	}
128
129	title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
130	if title == "" {
131		return nil
132	}
133
134	session.Title = title
135	_, err = a.sessions.Save(ctx, session)
136	return err
137}
138
139func (a *agent) err(err error) AgentEvent {
140	return AgentEvent{
141		err: err,
142	}
143}
144
145func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
146	events := make(chan AgentEvent)
147	if a.IsSessionBusy(sessionID) {
148		return nil, ErrSessionBusy
149	}
150
151	genCtx, cancel := context.WithCancel(ctx)
152
153	a.activeRequests.Store(sessionID, cancel)
154	go func() {
155		logging.Debug("Request started", "sessionID", sessionID)
156		defer logging.RecoverPanic("agent.Run", func() {
157			events <- a.err(fmt.Errorf("panic while running the agent"))
158		})
159
160		result := a.processGeneration(genCtx, sessionID, content)
161		if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
162			logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result))
163		}
164		logging.Debug("Request completed", "sessionID", sessionID)
165		a.activeRequests.Delete(sessionID)
166		cancel()
167		events <- result
168		close(events)
169	}()
170	return events, nil
171}
172
173func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
174	// List existing messages; if none, start title generation asynchronously.
175	msgs, err := a.messages.List(ctx, sessionID)
176	if err != nil {
177		return a.err(fmt.Errorf("failed to list messages: %w", err))
178	}
179	if len(msgs) == 0 {
180		go func() {
181			defer logging.RecoverPanic("agent.Run", func() {
182				logging.ErrorPersist("panic while generating title")
183			})
184			titleErr := a.generateTitle(context.Background(), sessionID, content)
185			if titleErr != nil {
186				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
187			}
188		}()
189	}
190
191	userMsg, err := a.createUserMessage(ctx, sessionID, content)
192	if err != nil {
193		return a.err(fmt.Errorf("failed to create user message: %w", err))
194	}
195
196	// Append the new user message to the conversation history.
197	msgHistory := append(msgs, userMsg)
198	for {
199		// Check for cancellation before each iteration
200		select {
201		case <-ctx.Done():
202			return a.err(ctx.Err())
203		default:
204			// Continue processing
205		}
206		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
207		if err != nil {
208			if errors.Is(err, context.Canceled) {
209				return a.err(ErrRequestCancelled)
210			}
211			return a.err(fmt.Errorf("failed to process events: %w", err))
212		}
213		logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
214		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
215			// We are not done, we need to respond with the tool response
216			msgHistory = append(msgHistory, agentMessage, *toolResults)
217			continue
218		}
219		return AgentEvent{
220			message: agentMessage,
221		}
222	}
223}
224
225func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
226	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
227		Role: message.User,
228		Parts: []message.ContentPart{
229			message.TextContent{Text: content},
230		},
231	})
232}
233
234func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
235	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
236
237	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
238		Role:  message.Assistant,
239		Parts: []message.ContentPart{},
240		Model: a.provider.Model().ID,
241	})
242	if err != nil {
243		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
244	}
245
246	// Add the session and message ID into the context if needed by tools.
247	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
248	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
249
250	// Process each event in the stream.
251	for event := range eventChan {
252		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
253			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
254			return assistantMsg, nil, processErr
255		}
256		if ctx.Err() != nil {
257			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
258			return assistantMsg, nil, ctx.Err()
259		}
260	}
261
262	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
263	toolCalls := assistantMsg.ToolCalls()
264	for i, toolCall := range toolCalls {
265		select {
266		case <-ctx.Done():
267			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
268			// Make all future tool calls cancelled
269			for j := i; j < len(toolCalls); j++ {
270				toolResults[j] = message.ToolResult{
271					ToolCallID: toolCalls[j].ID,
272					Content:    "Tool execution canceled by user",
273					IsError:    true,
274				}
275			}
276			goto out
277		default:
278			// Continue processing
279			var tool tools.BaseTool
280			for _, availableTools := range a.tools {
281				if availableTools.Info().Name == toolCall.Name {
282					tool = availableTools
283				}
284			}
285
286			// Tool not found
287			if tool == nil {
288				toolResults[i] = message.ToolResult{
289					ToolCallID: toolCall.ID,
290					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
291					IsError:    true,
292				}
293				continue
294			}
295
296			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
297				ID:    toolCall.ID,
298				Name:  toolCall.Name,
299				Input: toolCall.Input,
300			})
301			if toolErr != nil {
302				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
303					toolResults[i] = message.ToolResult{
304						ToolCallID: toolCall.ID,
305						Content:    "Permission denied",
306						IsError:    true,
307					}
308					for j := i + 1; j < len(toolCalls); j++ {
309						toolResults[j] = message.ToolResult{
310							ToolCallID: toolCalls[j].ID,
311							Content:    "Tool execution canceled by user",
312							IsError:    true,
313						}
314					}
315					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
316				} else {
317					toolResults[i] = message.ToolResult{
318						ToolCallID: toolCall.ID,
319						Content:    toolErr.Error(),
320						IsError:    true,
321					}
322					for j := i; j < len(toolCalls); j++ {
323						toolResults[j] = message.ToolResult{
324							ToolCallID: toolCalls[j].ID,
325							Content:    "Previous tool failed",
326							IsError:    true,
327						}
328					}
329					a.finishMessage(ctx, &assistantMsg, message.FinishReasonError)
330				}
331				// If permission is denied or an error happens we cancel all the following tools
332				break
333			}
334			toolResults[i] = message.ToolResult{
335				ToolCallID: toolCall.ID,
336				Content:    toolResult.Content,
337				Metadata:   toolResult.Metadata,
338				IsError:    toolResult.IsError,
339			}
340		}
341	}
342out:
343	if len(toolResults) == 0 {
344		return assistantMsg, nil, nil
345	}
346	parts := make([]message.ContentPart, 0)
347	for _, tr := range toolResults {
348		parts = append(parts, tr)
349	}
350	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
351		Role:  message.Tool,
352		Parts: parts,
353	})
354	if err != nil {
355		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
356	}
357
358	return assistantMsg, &msg, err
359}
360
// finishMessage records finishReson on msg and persists the message.
// The update is best-effort: an error from messages.Update is deliberately
// discarded, so callers cannot tell whether the finish marker was stored.
func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
	msg.AddFinish(finishReson)
	// Best-effort persistence; the error is intentionally ignored.
	_ = a.messages.Update(ctx, *msg)
}
365
366func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
367	select {
368	case <-ctx.Done():
369		return ctx.Err()
370	default:
371		// Continue processing.
372	}
373
374	switch event.Type {
375	case provider.EventThinkingDelta:
376		assistantMsg.AppendReasoningContent(event.Content)
377		return a.messages.Update(ctx, *assistantMsg)
378	case provider.EventContentDelta:
379		assistantMsg.AppendContent(event.Content)
380		return a.messages.Update(ctx, *assistantMsg)
381	case provider.EventError:
382		if errors.Is(event.Error, context.Canceled) {
383			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
384			return context.Canceled
385		}
386		logging.ErrorPersist(event.Error.Error())
387		return event.Error
388	case provider.EventComplete:
389		assistantMsg.SetToolCalls(event.Response.ToolCalls)
390		assistantMsg.AddFinish(event.Response.FinishReason)
391		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
392			return fmt.Errorf("failed to update message: %w", err)
393		}
394		return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
395	}
396
397	return nil
398}
399
400func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
401	sess, err := a.sessions.Get(ctx, sessionID)
402	if err != nil {
403		return fmt.Errorf("failed to get session: %w", err)
404	}
405
406	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
407		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
408		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
409		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
410
411	sess.Cost += cost
412	sess.CompletionTokens += usage.OutputTokens
413	sess.PromptTokens += usage.InputTokens
414
415	_, err = a.sessions.Save(ctx, sess)
416	if err != nil {
417		return fmt.Errorf("failed to save session: %w", err)
418	}
419	return nil
420}
421
422func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
423	cfg := config.Get()
424	agentConfig, ok := cfg.Agents[agentName]
425	if !ok {
426		return nil, fmt.Errorf("agent %s not found", agentName)
427	}
428	model, ok := models.SupportedModels[agentConfig.Model]
429	if !ok {
430		return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
431	}
432
433	providerCfg, ok := cfg.Providers[model.Provider]
434	if !ok {
435		return nil, fmt.Errorf("provider %s not supported", model.Provider)
436	}
437	if providerCfg.Disabled {
438		return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
439	}
440	agentProvider, err := provider.NewProvider(
441		model.Provider,
442		provider.WithAPIKey(providerCfg.APIKey),
443		provider.WithModel(model),
444		provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
445		provider.WithMaxTokens(agentConfig.MaxTokens),
446	)
447	if err != nil {
448		return nil, fmt.Errorf("could not create provider: %v", err)
449	}
450
451	return agentProvider, nil
452}