From 95d5c900e97427ad902613724ebe549912dfccd4 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 14 Jul 2025 17:41:15 +0200 Subject: [PATCH] fix: fix permission cancel logic --- internal/llm/agent/agent.go | 52 +++++++++++++++++++++++++++++++------ 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 6c7844eaa2811570824fba489a8c7a7581fa201f..990b388c3a6ed9e2dd020994d4c18c97d04ebab4 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -501,12 +501,45 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg } continue } - toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ - ID: toolCall.ID, - Name: toolCall.Name, - Input: toolCall.Input, - }) + + // Run tool in goroutine to allow cancellation + type toolExecResult struct { + response tools.ToolResponse + err error + } + resultChan := make(chan toolExecResult, 1) + + go func() { + response, err := tool.Run(ctx, tools.ToolCall{ + ID: toolCall.ID, + Name: toolCall.Name, + Input: toolCall.Input, + }) + resultChan <- toolExecResult{response: response, err: err} + }() + + var toolResponse tools.ToolResponse + var toolErr error + + select { + case <-ctx.Done(): + a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "") + // Mark remaining tool calls as cancelled + for j := i; j < len(toolCalls); j++ { + toolResults[j] = message.ToolResult{ + ToolCallID: toolCalls[j].ID, + Content: "Tool execution canceled by user", + IsError: true, + } + } + goto out + case result := <-resultChan: + toolResponse = result.response + toolErr = result.err + } + if toolErr != nil { + slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr) if errors.Is(toolErr, permission.ErrorPermissionDenied) { toolResults[i] = message.ToolResult{ ToolCallID: toolCall.ID, @@ -526,9 +559,9 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg } toolResults[i] = message.ToolResult{ ToolCallID: toolCall.ID, - Content: toolResult.Content, - Metadata: toolResult.Metadata, - IsError: toolResult.IsError, + Content: toolResponse.Content, + Metadata: toolResponse.Metadata, + IsError: toolResponse.IsError, } } } @@ -796,6 +829,9 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { } func (a *agent) CancelAll() { + if !a.IsBusy() { + return + } a.activeRequests.Range(func(key, value any) bool { a.Cancel(key.(string)) // key is sessionID return true