1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "sync"
9
10 "github.com/kujtimiihoxha/opencode/internal/config"
11 "github.com/kujtimiihoxha/opencode/internal/llm/models"
12 "github.com/kujtimiihoxha/opencode/internal/llm/prompt"
13 "github.com/kujtimiihoxha/opencode/internal/llm/provider"
14 "github.com/kujtimiihoxha/opencode/internal/llm/tools"
15 "github.com/kujtimiihoxha/opencode/internal/logging"
16 "github.com/kujtimiihoxha/opencode/internal/message"
17 "github.com/kujtimiihoxha/opencode/internal/permission"
18 "github.com/kujtimiihoxha/opencode/internal/session"
19)
20
// Sentinel errors reported by the agent service.
var (
	// ErrRequestCancelled is the error carried by the final AgentEvent when
	// streaming a response fails with context.Canceled.
	ErrRequestCancelled = errors.New("request cancelled by user")

	// ErrSessionBusy is returned by Run when the session already has a request
	// registered as in flight.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
26
27type AgentEvent struct {
28 message message.Message
29 err error
30}
31
32func (e *AgentEvent) Err() error {
33 return e.err
34}
35
36func (e *AgentEvent) Response() message.Message {
37 return e.message
38}
39
40type Service interface {
41 Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
42 Cancel(sessionID string)
43 IsSessionBusy(sessionID string) bool
44}
45
// agent implements Service. It streams responses from an LLM provider,
// executes the tool calls the model requests, and records messages and
// session usage through the session and message services.
type agent struct {
	sessions session.Service
	messages message.Service

	// tools are passed to the provider with every request and looked up by
	// name when the model asks for a tool call.
	tools    []tools.BaseTool
	provider provider.Provider

	// titleProvider generates session titles; nil disables title generation.
	titleProvider provider.Provider

	// activeRequests maps a session ID to the context.CancelFunc of that
	// session's in-flight request.
	activeRequests sync.Map
}
57
58func NewAgent(
59 agentName config.AgentName,
60 sessions session.Service,
61 messages message.Service,
62 agentTools []tools.BaseTool,
63) (Service, error) {
64 agentProvider, err := createAgentProvider(agentName)
65 if err != nil {
66 return nil, err
67 }
68 var titleProvider provider.Provider
69 // Only generate titles for the coder agent
70 if agentName == config.AgentCoder {
71 titleProvider, err = createAgentProvider(config.AgentTitle)
72 if err != nil {
73 return nil, err
74 }
75 }
76
77 agent := &agent{
78 provider: agentProvider,
79 messages: messages,
80 sessions: sessions,
81 tools: agentTools,
82 titleProvider: titleProvider,
83 activeRequests: sync.Map{},
84 }
85
86 return agent, nil
87}
88
89func (a *agent) Cancel(sessionID string) {
90 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
91 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
92 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
93 cancel()
94 }
95 }
96}
97
98func (a *agent) IsSessionBusy(sessionID string) bool {
99 _, busy := a.activeRequests.Load(sessionID)
100 return busy
101}
102
103func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
104 if a.titleProvider == nil {
105 return nil
106 }
107 session, err := a.sessions.Get(ctx, sessionID)
108 if err != nil {
109 return err
110 }
111 response, err := a.titleProvider.SendMessages(
112 ctx,
113 []message.Message{
114 {
115 Role: message.User,
116 Parts: []message.ContentPart{
117 message.TextContent{
118 Text: content,
119 },
120 },
121 },
122 },
123 make([]tools.BaseTool, 0),
124 )
125 if err != nil {
126 return err
127 }
128
129 title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
130 if title == "" {
131 return nil
132 }
133
134 session.Title = title
135 _, err = a.sessions.Save(ctx, session)
136 return err
137}
138
139func (a *agent) err(err error) AgentEvent {
140 return AgentEvent{
141 err: err,
142 }
143}
144
145func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
146 events := make(chan AgentEvent)
147 if a.IsSessionBusy(sessionID) {
148 return nil, ErrSessionBusy
149 }
150
151 genCtx, cancel := context.WithCancel(ctx)
152
153 a.activeRequests.Store(sessionID, cancel)
154 go func() {
155 logging.Debug("Request started", "sessionID", sessionID)
156 defer logging.RecoverPanic("agent.Run", func() {
157 events <- a.err(fmt.Errorf("panic while running the agent"))
158 })
159
160 result := a.processGeneration(genCtx, sessionID, content)
161 if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
162 logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result))
163 }
164 logging.Debug("Request completed", "sessionID", sessionID)
165 a.activeRequests.Delete(sessionID)
166 cancel()
167 events <- result
168 close(events)
169 }()
170 return events, nil
171}
172
173func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
174 // List existing messages; if none, start title generation asynchronously.
175 msgs, err := a.messages.List(ctx, sessionID)
176 if err != nil {
177 return a.err(fmt.Errorf("failed to list messages: %w", err))
178 }
179 if len(msgs) == 0 {
180 go func() {
181 defer logging.RecoverPanic("agent.Run", func() {
182 logging.ErrorPersist("panic while generating title")
183 })
184 titleErr := a.generateTitle(context.Background(), sessionID, content)
185 if titleErr != nil {
186 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
187 }
188 }()
189 }
190
191 userMsg, err := a.createUserMessage(ctx, sessionID, content)
192 if err != nil {
193 return a.err(fmt.Errorf("failed to create user message: %w", err))
194 }
195
196 // Append the new user message to the conversation history.
197 msgHistory := append(msgs, userMsg)
198 for {
199 // Check for cancellation before each iteration
200 select {
201 case <-ctx.Done():
202 return a.err(ctx.Err())
203 default:
204 // Continue processing
205 }
206 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
207 if err != nil {
208 if errors.Is(err, context.Canceled) {
209 return a.err(ErrRequestCancelled)
210 }
211 return a.err(fmt.Errorf("failed to process events: %w", err))
212 }
213 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
214 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
215 // We are not done, we need to respond with the tool response
216 msgHistory = append(msgHistory, agentMessage, *toolResults)
217 continue
218 }
219 return AgentEvent{
220 message: agentMessage,
221 }
222 }
223}
224
225func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
226 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
227 Role: message.User,
228 Parts: []message.ContentPart{
229 message.TextContent{Text: content},
230 },
231 })
232}
233
234func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
235 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
236
237 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
238 Role: message.Assistant,
239 Parts: []message.ContentPart{},
240 Model: a.provider.Model().ID,
241 })
242 if err != nil {
243 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
244 }
245
246 // Add the session and message ID into the context if needed by tools.
247 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
248 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
249
250 // Process each event in the stream.
251 for event := range eventChan {
252 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
253 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
254 return assistantMsg, nil, processErr
255 }
256 if ctx.Err() != nil {
257 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
258 return assistantMsg, nil, ctx.Err()
259 }
260 }
261
262 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
263 toolCalls := assistantMsg.ToolCalls()
264 for i, toolCall := range toolCalls {
265 select {
266 case <-ctx.Done():
267 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
268 // Make all future tool calls cancelled
269 for j := i; j < len(toolCalls); j++ {
270 toolResults[j] = message.ToolResult{
271 ToolCallID: toolCalls[j].ID,
272 Content: "Tool execution canceled by user",
273 IsError: true,
274 }
275 }
276 goto out
277 default:
278 // Continue processing
279 var tool tools.BaseTool
280 for _, availableTools := range a.tools {
281 if availableTools.Info().Name == toolCall.Name {
282 tool = availableTools
283 }
284 }
285
286 // Tool not found
287 if tool == nil {
288 toolResults[i] = message.ToolResult{
289 ToolCallID: toolCall.ID,
290 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
291 IsError: true,
292 }
293 continue
294 }
295
296 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
297 ID: toolCall.ID,
298 Name: toolCall.Name,
299 Input: toolCall.Input,
300 })
301 if toolErr != nil {
302 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
303 toolResults[i] = message.ToolResult{
304 ToolCallID: toolCall.ID,
305 Content: "Permission denied",
306 IsError: true,
307 }
308 for j := i + 1; j < len(toolCalls); j++ {
309 toolResults[j] = message.ToolResult{
310 ToolCallID: toolCalls[j].ID,
311 Content: "Tool execution canceled by user",
312 IsError: true,
313 }
314 }
315 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
316 } else {
317 toolResults[i] = message.ToolResult{
318 ToolCallID: toolCall.ID,
319 Content: toolErr.Error(),
320 IsError: true,
321 }
322 for j := i; j < len(toolCalls); j++ {
323 toolResults[j] = message.ToolResult{
324 ToolCallID: toolCalls[j].ID,
325 Content: "Previous tool failed",
326 IsError: true,
327 }
328 }
329 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError)
330 }
331 // If permission is denied or an error happens we cancel all the following tools
332 break
333 }
334 toolResults[i] = message.ToolResult{
335 ToolCallID: toolCall.ID,
336 Content: toolResult.Content,
337 Metadata: toolResult.Metadata,
338 IsError: toolResult.IsError,
339 }
340 }
341 }
342out:
343 if len(toolResults) == 0 {
344 return assistantMsg, nil, nil
345 }
346 parts := make([]message.ContentPart, 0)
347 for _, tr := range toolResults {
348 parts = append(parts, tr)
349 }
350 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
351 Role: message.Tool,
352 Parts: parts,
353 })
354 if err != nil {
355 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
356 }
357
358 return assistantMsg, &msg, err
359}
360
361func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
362 msg.AddFinish(finishReson)
363 _ = a.messages.Update(ctx, *msg)
364}
365
366func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
367 select {
368 case <-ctx.Done():
369 return ctx.Err()
370 default:
371 // Continue processing.
372 }
373
374 switch event.Type {
375 case provider.EventThinkingDelta:
376 assistantMsg.AppendReasoningContent(event.Content)
377 return a.messages.Update(ctx, *assistantMsg)
378 case provider.EventContentDelta:
379 assistantMsg.AppendContent(event.Content)
380 return a.messages.Update(ctx, *assistantMsg)
381 case provider.EventError:
382 if errors.Is(event.Error, context.Canceled) {
383 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
384 return context.Canceled
385 }
386 logging.ErrorPersist(event.Error.Error())
387 return event.Error
388 case provider.EventComplete:
389 assistantMsg.SetToolCalls(event.Response.ToolCalls)
390 assistantMsg.AddFinish(event.Response.FinishReason)
391 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
392 return fmt.Errorf("failed to update message: %w", err)
393 }
394 return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
395 }
396
397 return nil
398}
399
400func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
401 sess, err := a.sessions.Get(ctx, sessionID)
402 if err != nil {
403 return fmt.Errorf("failed to get session: %w", err)
404 }
405
406 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
407 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
408 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
409 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
410
411 sess.Cost += cost
412 sess.CompletionTokens += usage.OutputTokens
413 sess.PromptTokens += usage.InputTokens
414
415 _, err = a.sessions.Save(ctx, sess)
416 if err != nil {
417 return fmt.Errorf("failed to save session: %w", err)
418 }
419 return nil
420}
421
422func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
423 cfg := config.Get()
424 agentConfig, ok := cfg.Agents[agentName]
425 if !ok {
426 return nil, fmt.Errorf("agent %s not found", agentName)
427 }
428 model, ok := models.SupportedModels[agentConfig.Model]
429 if !ok {
430 return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
431 }
432
433 providerCfg, ok := cfg.Providers[model.Provider]
434 if !ok {
435 return nil, fmt.Errorf("provider %s not supported", model.Provider)
436 }
437 if providerCfg.Disabled {
438 return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
439 }
440 agentProvider, err := provider.NewProvider(
441 model.Provider,
442 provider.WithAPIKey(providerCfg.APIKey),
443 provider.WithModel(model),
444 provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
445 provider.WithMaxTokens(agentConfig.MaxTokens),
446 )
447 if err != nil {
448 return nil, fmt.Errorf("could not create provider: %v", err)
449 }
450
451 return agentProvider, nil
452}