1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "sync"
9
10 "github.com/opencode-ai/opencode/internal/config"
11 "github.com/opencode-ai/opencode/internal/llm/models"
12 "github.com/opencode-ai/opencode/internal/llm/prompt"
13 "github.com/opencode-ai/opencode/internal/llm/provider"
14 "github.com/opencode-ai/opencode/internal/llm/tools"
15 "github.com/opencode-ai/opencode/internal/logging"
16 "github.com/opencode-ai/opencode/internal/message"
17 "github.com/opencode-ai/opencode/internal/permission"
18 "github.com/opencode-ai/opencode/internal/session"
19)
20
// Common errors returned by the agent.
var (
	// ErrRequestCancelled reports that a generation ended because the user
	// cancelled it.
	ErrRequestCancelled = errors.New("request cancelled by user")

	// ErrSessionBusy is returned by Run when the session already has a
	// request in flight.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
26
27type AgentEvent struct {
28 message message.Message
29 err error
30}
31
32func (e *AgentEvent) Err() error {
33 return e.err
34}
35
36func (e *AgentEvent) Response() message.Message {
37 return e.message
38}
39
40type Service interface {
41 Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
42 Cancel(sessionID string)
43 IsSessionBusy(sessionID string) bool
44 IsBusy() bool
45 Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
46}
47
48type agent struct {
49 sessions session.Service
50 messages message.Service
51
52 tools []tools.BaseTool
53 provider provider.Provider
54
55 titleProvider provider.Provider
56
57 activeRequests sync.Map
58}
59
60func NewAgent(
61 agentName config.AgentName,
62 sessions session.Service,
63 messages message.Service,
64 agentTools []tools.BaseTool,
65) (Service, error) {
66 agentProvider, err := createAgentProvider(agentName)
67 if err != nil {
68 return nil, err
69 }
70 var titleProvider provider.Provider
71 // Only generate titles for the coder agent
72 if agentName == config.AgentCoder {
73 titleProvider, err = createAgentProvider(config.AgentTitle)
74 if err != nil {
75 return nil, err
76 }
77 }
78
79 agent := &agent{
80 provider: agentProvider,
81 messages: messages,
82 sessions: sessions,
83 tools: agentTools,
84 titleProvider: titleProvider,
85 activeRequests: sync.Map{},
86 }
87
88 return agent, nil
89}
90
91func (a *agent) Cancel(sessionID string) {
92 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
93 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
94 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
95 cancel()
96 }
97 }
98}
99
100func (a *agent) IsBusy() bool {
101 busy := false
102 a.activeRequests.Range(func(key, value interface{}) bool {
103 if cancelFunc, ok := value.(context.CancelFunc); ok {
104 if cancelFunc != nil {
105 busy = true
106 return false // Stop iterating
107 }
108 }
109 return true // Continue iterating
110 })
111 return busy
112}
113
114func (a *agent) IsSessionBusy(sessionID string) bool {
115 _, busy := a.activeRequests.Load(sessionID)
116 return busy
117}
118
119func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
120 if a.titleProvider == nil {
121 return nil
122 }
123 session, err := a.sessions.Get(ctx, sessionID)
124 if err != nil {
125 return err
126 }
127 response, err := a.titleProvider.SendMessages(
128 ctx,
129 []message.Message{
130 {
131 Role: message.User,
132 Parts: []message.ContentPart{
133 message.TextContent{
134 Text: content,
135 },
136 },
137 },
138 },
139 make([]tools.BaseTool, 0),
140 )
141 if err != nil {
142 return err
143 }
144
145 title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
146 if title == "" {
147 return nil
148 }
149
150 session.Title = title
151 _, err = a.sessions.Save(ctx, session)
152 return err
153}
154
155func (a *agent) err(err error) AgentEvent {
156 return AgentEvent{
157 err: err,
158 }
159}
160
161func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
162 events := make(chan AgentEvent)
163 if a.IsSessionBusy(sessionID) {
164 return nil, ErrSessionBusy
165 }
166
167 genCtx, cancel := context.WithCancel(ctx)
168
169 a.activeRequests.Store(sessionID, cancel)
170 go func() {
171 logging.Debug("Request started", "sessionID", sessionID)
172 defer logging.RecoverPanic("agent.Run", func() {
173 events <- a.err(fmt.Errorf("panic while running the agent"))
174 })
175
176 result := a.processGeneration(genCtx, sessionID, content)
177 if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
178 logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result))
179 }
180 logging.Debug("Request completed", "sessionID", sessionID)
181 a.activeRequests.Delete(sessionID)
182 cancel()
183 events <- result
184 close(events)
185 }()
186 return events, nil
187}
188
189func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
190 // List existing messages; if none, start title generation asynchronously.
191 msgs, err := a.messages.List(ctx, sessionID)
192 if err != nil {
193 return a.err(fmt.Errorf("failed to list messages: %w", err))
194 }
195 if len(msgs) == 0 {
196 go func() {
197 defer logging.RecoverPanic("agent.Run", func() {
198 logging.ErrorPersist("panic while generating title")
199 })
200 titleErr := a.generateTitle(context.Background(), sessionID, content)
201 if titleErr != nil {
202 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
203 }
204 }()
205 }
206
207 userMsg, err := a.createUserMessage(ctx, sessionID, content)
208 if err != nil {
209 return a.err(fmt.Errorf("failed to create user message: %w", err))
210 }
211
212 // Append the new user message to the conversation history.
213 msgHistory := append(msgs, userMsg)
214 for {
215 // Check for cancellation before each iteration
216 select {
217 case <-ctx.Done():
218 return a.err(ctx.Err())
219 default:
220 // Continue processing
221 }
222 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
223 if err != nil {
224 if errors.Is(err, context.Canceled) {
225 agentMessage.AddFinish(message.FinishReasonCanceled)
226 a.messages.Update(context.Background(), agentMessage)
227 return a.err(ErrRequestCancelled)
228 }
229 return a.err(fmt.Errorf("failed to process events: %w", err))
230 }
231 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
232 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
233 // We are not done, we need to respond with the tool response
234 msgHistory = append(msgHistory, agentMessage, *toolResults)
235 continue
236 }
237 return AgentEvent{
238 message: agentMessage,
239 }
240 }
241}
242
243func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
244 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
245 Role: message.User,
246 Parts: []message.ContentPart{
247 message.TextContent{Text: content},
248 },
249 })
250}
251
252func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
253 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
254
255 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
256 Role: message.Assistant,
257 Parts: []message.ContentPart{},
258 Model: a.provider.Model().ID,
259 })
260 if err != nil {
261 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
262 }
263
264 // Add the session and message ID into the context if needed by tools.
265 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
266 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
267
268 // Process each event in the stream.
269 for event := range eventChan {
270 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
271 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
272 return assistantMsg, nil, processErr
273 }
274 if ctx.Err() != nil {
275 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
276 return assistantMsg, nil, ctx.Err()
277 }
278 }
279
280 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
281 toolCalls := assistantMsg.ToolCalls()
282 for i, toolCall := range toolCalls {
283 select {
284 case <-ctx.Done():
285 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
286 // Make all future tool calls cancelled
287 for j := i; j < len(toolCalls); j++ {
288 toolResults[j] = message.ToolResult{
289 ToolCallID: toolCalls[j].ID,
290 Content: "Tool execution canceled by user",
291 IsError: true,
292 }
293 }
294 goto out
295 default:
296 // Continue processing
297 var tool tools.BaseTool
298 for _, availableTools := range a.tools {
299 if availableTools.Info().Name == toolCall.Name {
300 tool = availableTools
301 }
302 }
303
304 // Tool not found
305 if tool == nil {
306 toolResults[i] = message.ToolResult{
307 ToolCallID: toolCall.ID,
308 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
309 IsError: true,
310 }
311 continue
312 }
313
314 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
315 ID: toolCall.ID,
316 Name: toolCall.Name,
317 Input: toolCall.Input,
318 })
319 if toolErr != nil {
320 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
321 toolResults[i] = message.ToolResult{
322 ToolCallID: toolCall.ID,
323 Content: "Permission denied",
324 IsError: true,
325 }
326 for j := i + 1; j < len(toolCalls); j++ {
327 toolResults[j] = message.ToolResult{
328 ToolCallID: toolCalls[j].ID,
329 Content: "Tool execution canceled by user",
330 IsError: true,
331 }
332 }
333 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
334 break
335 }
336 }
337 toolResults[i] = message.ToolResult{
338 ToolCallID: toolCall.ID,
339 Content: toolResult.Content,
340 Metadata: toolResult.Metadata,
341 IsError: toolResult.IsError,
342 }
343 }
344 }
345out:
346 if len(toolResults) == 0 {
347 return assistantMsg, nil, nil
348 }
349 parts := make([]message.ContentPart, 0)
350 for _, tr := range toolResults {
351 parts = append(parts, tr)
352 }
353 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
354 Role: message.Tool,
355 Parts: parts,
356 })
357 if err != nil {
358 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
359 }
360
361 return assistantMsg, &msg, err
362}
363
364func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
365 msg.AddFinish(finishReson)
366 _ = a.messages.Update(ctx, *msg)
367}
368
369func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
370 select {
371 case <-ctx.Done():
372 return ctx.Err()
373 default:
374 // Continue processing.
375 }
376
377 switch event.Type {
378 case provider.EventThinkingDelta:
379 assistantMsg.AppendReasoningContent(event.Content)
380 return a.messages.Update(ctx, *assistantMsg)
381 case provider.EventContentDelta:
382 assistantMsg.AppendContent(event.Content)
383 return a.messages.Update(ctx, *assistantMsg)
384 case provider.EventToolUseStart:
385 assistantMsg.AddToolCall(*event.ToolCall)
386 return a.messages.Update(ctx, *assistantMsg)
387 // TODO: see how to handle this
388 // case provider.EventToolUseDelta:
389 // tm := time.Unix(assistantMsg.UpdatedAt, 0)
390 // assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
391 // if time.Since(tm) > 1000*time.Millisecond {
392 // err := a.messages.Update(ctx, *assistantMsg)
393 // assistantMsg.UpdatedAt = time.Now().Unix()
394 // return err
395 // }
396 case provider.EventToolUseStop:
397 assistantMsg.FinishToolCall(event.ToolCall.ID)
398 return a.messages.Update(ctx, *assistantMsg)
399 case provider.EventError:
400 if errors.Is(event.Error, context.Canceled) {
401 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
402 return context.Canceled
403 }
404 logging.ErrorPersist(event.Error.Error())
405 return event.Error
406 case provider.EventComplete:
407 assistantMsg.SetToolCalls(event.Response.ToolCalls)
408 assistantMsg.AddFinish(event.Response.FinishReason)
409 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
410 return fmt.Errorf("failed to update message: %w", err)
411 }
412 return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
413 }
414
415 return nil
416}
417
418func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
419 sess, err := a.sessions.Get(ctx, sessionID)
420 if err != nil {
421 return fmt.Errorf("failed to get session: %w", err)
422 }
423
424 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
425 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
426 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
427 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
428
429 sess.Cost += cost
430 sess.CompletionTokens += usage.OutputTokens
431 sess.PromptTokens += usage.InputTokens
432
433 _, err = a.sessions.Save(ctx, sess)
434 if err != nil {
435 return fmt.Errorf("failed to save session: %w", err)
436 }
437 return nil
438}
439
440func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
441 if a.IsBusy() {
442 return models.Model{}, fmt.Errorf("cannot change model while processing requests")
443 }
444
445 if err := config.UpdateAgentModel(agentName, modelID); err != nil {
446 return models.Model{}, fmt.Errorf("failed to update config: %w", err)
447 }
448
449 provider, err := createAgentProvider(agentName)
450 if err != nil {
451 return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
452 }
453
454 a.provider = provider
455
456 return a.provider.Model(), nil
457}
458
459func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
460 cfg := config.Get()
461 agentConfig, ok := cfg.Agents[agentName]
462 if !ok {
463 return nil, fmt.Errorf("agent %s not found", agentName)
464 }
465 model, ok := models.SupportedModels[agentConfig.Model]
466 if !ok {
467 return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
468 }
469
470 providerCfg, ok := cfg.Providers[model.Provider]
471 if !ok {
472 return nil, fmt.Errorf("provider %s not supported", model.Provider)
473 }
474 if providerCfg.Disabled {
475 return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
476 }
477 maxTokens := model.DefaultMaxTokens
478 if agentConfig.MaxTokens > 0 {
479 maxTokens = agentConfig.MaxTokens
480 }
481 opts := []provider.ProviderClientOption{
482 provider.WithAPIKey(providerCfg.APIKey),
483 provider.WithModel(model),
484 provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
485 provider.WithMaxTokens(maxTokens),
486 }
487 if model.Provider == models.ProviderOpenAI && model.CanReason {
488 opts = append(
489 opts,
490 provider.WithOpenAIOptions(
491 provider.WithReasoningEffort(agentConfig.ReasoningEffort),
492 ),
493 )
494 } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
495 opts = append(
496 opts,
497 provider.WithAnthropicOptions(
498 provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
499 ),
500 )
501 }
502 agentProvider, err := provider.NewProvider(
503 model.Provider,
504 opts...,
505 )
506 if err != nil {
507 return nil, fmt.Errorf("could not create provider: %v", err)
508 }
509
510 return agentProvider, nil
511}