1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "sync"
9
10 "github.com/opencode-ai/opencode/internal/config"
11 "github.com/opencode-ai/opencode/internal/llm/models"
12 "github.com/opencode-ai/opencode/internal/llm/prompt"
13 "github.com/opencode-ai/opencode/internal/llm/provider"
14 "github.com/opencode-ai/opencode/internal/llm/tools"
15 "github.com/opencode-ai/opencode/internal/logging"
16 "github.com/opencode-ai/opencode/internal/message"
17 "github.com/opencode-ai/opencode/internal/permission"
18 "github.com/opencode-ai/opencode/internal/session"
19)
20
// Sentinel errors reported by the agent service.
var (
	// ErrRequestCancelled is the error carried by the final AgentEvent when
	// generation stops because the request context was cancelled.
	ErrRequestCancelled = errors.New("request cancelled by user")
	// ErrSessionBusy is returned by Run when the session already has a
	// request in flight.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
26
27type AgentEvent struct {
28 message message.Message
29 err error
30}
31
32func (e *AgentEvent) Err() error {
33 return e.err
34}
35
36func (e *AgentEvent) Response() message.Message {
37 return e.message
38}
39
// Service is the agent API: it runs requests against a session, cancels
// them, reports whether work is in flight, and switches the agent's model.
type Service interface {
	// Run starts processing content (plus optional attachments) for
	// sessionID and returns a channel of AgentEvent values for that request.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the active request of sessionID, if there is one.
	Cancel(sessionID string)
	// IsSessionBusy reports whether sessionID has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// Update switches agentName to modelID and returns the model now in use.
	Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
}
47
// agent is the Service implementation returned by NewAgent.
type agent struct {
	sessions session.Service
	messages message.Service

	// tools is the tool set offered to the provider and searched by name
	// when the model requests a tool call.
	tools    []tools.BaseTool
	provider provider.Provider

	// titleProvider generates session titles; it stays nil unless the agent
	// was created as config.AgentCoder.
	titleProvider provider.Provider

	// activeRequests maps a session ID to the context.CancelFunc of its
	// in-flight request.
	activeRequests sync.Map
}
59
60func NewAgent(
61 agentName config.AgentName,
62 sessions session.Service,
63 messages message.Service,
64 agentTools []tools.BaseTool,
65) (Service, error) {
66 agentProvider, err := createAgentProvider(agentName)
67 if err != nil {
68 return nil, err
69 }
70 var titleProvider provider.Provider
71 // Only generate titles for the coder agent
72 if agentName == config.AgentCoder {
73 titleProvider, err = createAgentProvider(config.AgentTitle)
74 if err != nil {
75 return nil, err
76 }
77 }
78
79 agent := &agent{
80 provider: agentProvider,
81 messages: messages,
82 sessions: sessions,
83 tools: agentTools,
84 titleProvider: titleProvider,
85 activeRequests: sync.Map{},
86 }
87
88 return agent, nil
89}
90
91func (a *agent) Cancel(sessionID string) {
92 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
93 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
94 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
95 cancel()
96 }
97 }
98}
99
100func (a *agent) IsBusy() bool {
101 busy := false
102 a.activeRequests.Range(func(key, value interface{}) bool {
103 if cancelFunc, ok := value.(context.CancelFunc); ok {
104 if cancelFunc != nil {
105 busy = true
106 return false // Stop iterating
107 }
108 }
109 return true // Continue iterating
110 })
111 return busy
112}
113
114func (a *agent) IsSessionBusy(sessionID string) bool {
115 _, busy := a.activeRequests.Load(sessionID)
116 return busy
117}
118
119func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
120 if content == "" {
121 return nil
122 }
123 if a.titleProvider == nil {
124 return nil
125 }
126 session, err := a.sessions.Get(ctx, sessionID)
127 if err != nil {
128 return err
129 }
130 parts := []message.ContentPart{message.TextContent{Text: content}}
131 response, err := a.titleProvider.SendMessages(
132 ctx,
133 []message.Message{
134 {
135 Role: message.User,
136 Parts: parts,
137 },
138 },
139 make([]tools.BaseTool, 0),
140 )
141 if err != nil {
142 return err
143 }
144
145 title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
146 if title == "" {
147 return nil
148 }
149
150 session.Title = title
151 _, err = a.sessions.Save(ctx, session)
152 return err
153}
154
155func (a *agent) err(err error) AgentEvent {
156 return AgentEvent{
157 err: err,
158 }
159}
160
161func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
162 if !a.provider.Model().SupportsAttachments && attachments != nil {
163 attachments = nil
164 }
165 events := make(chan AgentEvent)
166 if a.IsSessionBusy(sessionID) {
167 return nil, ErrSessionBusy
168 }
169
170 genCtx, cancel := context.WithCancel(ctx)
171
172 a.activeRequests.Store(sessionID, cancel)
173 go func() {
174 logging.Debug("Request started", "sessionID", sessionID)
175 defer logging.RecoverPanic("agent.Run", func() {
176 events <- a.err(fmt.Errorf("panic while running the agent"))
177 })
178 var attachmentParts []message.ContentPart
179 for _, attachment := range attachments {
180 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
181 }
182 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
183 if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
184 logging.ErrorPersist(result.Err().Error())
185 }
186 logging.Debug("Request completed", "sessionID", sessionID)
187 a.activeRequests.Delete(sessionID)
188 cancel()
189 events <- result
190 close(events)
191 }()
192 return events, nil
193}
194
195func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
196 // List existing messages; if none, start title generation asynchronously.
197 msgs, err := a.messages.List(ctx, sessionID)
198 if err != nil {
199 return a.err(fmt.Errorf("failed to list messages: %w", err))
200 }
201 if len(msgs) == 0 {
202 go func() {
203 defer logging.RecoverPanic("agent.Run", func() {
204 logging.ErrorPersist("panic while generating title")
205 })
206 titleErr := a.generateTitle(context.Background(), sessionID, content)
207 if titleErr != nil {
208 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
209 }
210 }()
211 }
212
213 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
214 if err != nil {
215 return a.err(fmt.Errorf("failed to create user message: %w", err))
216 }
217 // Append the new user message to the conversation history.
218 msgHistory := append(msgs, userMsg)
219
220 for {
221 // Check for cancellation before each iteration
222 select {
223 case <-ctx.Done():
224 return a.err(ctx.Err())
225 default:
226 // Continue processing
227 }
228 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
229 if err != nil {
230 if errors.Is(err, context.Canceled) {
231 agentMessage.AddFinish(message.FinishReasonCanceled)
232 a.messages.Update(context.Background(), agentMessage)
233 return a.err(ErrRequestCancelled)
234 }
235 return a.err(fmt.Errorf("failed to process events: %w", err))
236 }
237 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
238 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
239 // We are not done, we need to respond with the tool response
240 msgHistory = append(msgHistory, agentMessage, *toolResults)
241 continue
242 }
243 return AgentEvent{
244 message: agentMessage,
245 }
246 }
247}
248
249func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
250 parts := []message.ContentPart{message.TextContent{Text: content}}
251 parts = append(parts, attachmentParts...)
252 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
253 Role: message.User,
254 Parts: parts,
255 })
256}
257
258func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
259 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
260
261 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
262 Role: message.Assistant,
263 Parts: []message.ContentPart{},
264 Model: a.provider.Model().ID,
265 })
266 if err != nil {
267 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
268 }
269
270 // Add the session and message ID into the context if needed by tools.
271 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
272 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
273
274 // Process each event in the stream.
275 for event := range eventChan {
276 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
277 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
278 return assistantMsg, nil, processErr
279 }
280 if ctx.Err() != nil {
281 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
282 return assistantMsg, nil, ctx.Err()
283 }
284 }
285
286 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
287 toolCalls := assistantMsg.ToolCalls()
288 for i, toolCall := range toolCalls {
289 select {
290 case <-ctx.Done():
291 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
292 // Make all future tool calls cancelled
293 for j := i; j < len(toolCalls); j++ {
294 toolResults[j] = message.ToolResult{
295 ToolCallID: toolCalls[j].ID,
296 Content: "Tool execution canceled by user",
297 IsError: true,
298 }
299 }
300 goto out
301 default:
302 // Continue processing
303 var tool tools.BaseTool
304 for _, availableTools := range a.tools {
305 if availableTools.Info().Name == toolCall.Name {
306 tool = availableTools
307 }
308 }
309
310 // Tool not found
311 if tool == nil {
312 toolResults[i] = message.ToolResult{
313 ToolCallID: toolCall.ID,
314 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
315 IsError: true,
316 }
317 continue
318 }
319 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
320 ID: toolCall.ID,
321 Name: toolCall.Name,
322 Input: toolCall.Input,
323 })
324 if toolErr != nil {
325 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
326 toolResults[i] = message.ToolResult{
327 ToolCallID: toolCall.ID,
328 Content: "Permission denied",
329 IsError: true,
330 }
331 for j := i + 1; j < len(toolCalls); j++ {
332 toolResults[j] = message.ToolResult{
333 ToolCallID: toolCalls[j].ID,
334 Content: "Tool execution canceled by user",
335 IsError: true,
336 }
337 }
338 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
339 break
340 }
341 }
342 toolResults[i] = message.ToolResult{
343 ToolCallID: toolCall.ID,
344 Content: toolResult.Content,
345 Metadata: toolResult.Metadata,
346 IsError: toolResult.IsError,
347 }
348 }
349 }
350out:
351 if len(toolResults) == 0 {
352 return assistantMsg, nil, nil
353 }
354 parts := make([]message.ContentPart, 0)
355 for _, tr := range toolResults {
356 parts = append(parts, tr)
357 }
358 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
359 Role: message.Tool,
360 Parts: parts,
361 })
362 if err != nil {
363 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
364 }
365
366 return assistantMsg, &msg, err
367}
368
369func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
370 msg.AddFinish(finishReson)
371 _ = a.messages.Update(ctx, *msg)
372}
373
374func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
375 select {
376 case <-ctx.Done():
377 return ctx.Err()
378 default:
379 // Continue processing.
380 }
381
382 switch event.Type {
383 case provider.EventThinkingDelta:
384 assistantMsg.AppendReasoningContent(event.Content)
385 return a.messages.Update(ctx, *assistantMsg)
386 case provider.EventContentDelta:
387 assistantMsg.AppendContent(event.Content)
388 return a.messages.Update(ctx, *assistantMsg)
389 case provider.EventToolUseStart:
390 assistantMsg.AddToolCall(*event.ToolCall)
391 return a.messages.Update(ctx, *assistantMsg)
392 // TODO: see how to handle this
393 // case provider.EventToolUseDelta:
394 // tm := time.Unix(assistantMsg.UpdatedAt, 0)
395 // assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
396 // if time.Since(tm) > 1000*time.Millisecond {
397 // err := a.messages.Update(ctx, *assistantMsg)
398 // assistantMsg.UpdatedAt = time.Now().Unix()
399 // return err
400 // }
401 case provider.EventToolUseStop:
402 assistantMsg.FinishToolCall(event.ToolCall.ID)
403 return a.messages.Update(ctx, *assistantMsg)
404 case provider.EventError:
405 if errors.Is(event.Error, context.Canceled) {
406 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
407 return context.Canceled
408 }
409 logging.ErrorPersist(event.Error.Error())
410 return event.Error
411 case provider.EventComplete:
412 assistantMsg.SetToolCalls(event.Response.ToolCalls)
413 assistantMsg.AddFinish(event.Response.FinishReason)
414 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
415 return fmt.Errorf("failed to update message: %w", err)
416 }
417 return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
418 }
419
420 return nil
421}
422
423func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
424 sess, err := a.sessions.Get(ctx, sessionID)
425 if err != nil {
426 return fmt.Errorf("failed to get session: %w", err)
427 }
428
429 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
430 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
431 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
432 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
433
434 sess.Cost += cost
435 sess.CompletionTokens += usage.OutputTokens
436 sess.PromptTokens += usage.InputTokens
437
438 _, err = a.sessions.Save(ctx, sess)
439 if err != nil {
440 return fmt.Errorf("failed to save session: %w", err)
441 }
442 return nil
443}
444
445func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
446 if a.IsBusy() {
447 return models.Model{}, fmt.Errorf("cannot change model while processing requests")
448 }
449
450 if err := config.UpdateAgentModel(agentName, modelID); err != nil {
451 return models.Model{}, fmt.Errorf("failed to update config: %w", err)
452 }
453
454 provider, err := createAgentProvider(agentName)
455 if err != nil {
456 return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
457 }
458
459 a.provider = provider
460
461 return a.provider.Model(), nil
462}
463
464func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
465 cfg := config.Get()
466 agentConfig, ok := cfg.Agents[agentName]
467 if !ok {
468 return nil, fmt.Errorf("agent %s not found", agentName)
469 }
470 model, ok := models.SupportedModels[agentConfig.Model]
471 if !ok {
472 return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
473 }
474
475 providerCfg, ok := cfg.Providers[model.Provider]
476 if !ok {
477 return nil, fmt.Errorf("provider %s not supported", model.Provider)
478 }
479 if providerCfg.Disabled {
480 return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
481 }
482 maxTokens := model.DefaultMaxTokens
483 if agentConfig.MaxTokens > 0 {
484 maxTokens = agentConfig.MaxTokens
485 }
486 opts := []provider.ProviderClientOption{
487 provider.WithAPIKey(providerCfg.APIKey),
488 provider.WithModel(model),
489 provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
490 provider.WithMaxTokens(maxTokens),
491 }
492 if model.Provider == models.ProviderOpenAI && model.CanReason {
493 opts = append(
494 opts,
495 provider.WithOpenAIOptions(
496 provider.WithReasoningEffort(agentConfig.ReasoningEffort),
497 ),
498 )
499 } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
500 opts = append(
501 opts,
502 provider.WithAnthropicOptions(
503 provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
504 ),
505 )
506 }
507 agentProvider, err := provider.NewProvider(
508 model.Provider,
509 opts...,
510 )
511 if err != nil {
512 return nil, fmt.Errorf("could not create provider: %v", err)
513 }
514
515 return agentProvider, nil
516}