fix: fix permission cancel logic

Kujtim Hoxha created

Change summary

internal/llm/agent/agent.go | 52 +++++++++++++++++++++++++++++++++------
1 file changed, 44 insertions(+), 8 deletions(-)

Detailed changes

internal/llm/agent/agent.go 🔗

@@ -501,12 +501,45 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
 				}
 				continue
 			}
-			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
-				ID:    toolCall.ID,
-				Name:  toolCall.Name,
-				Input: toolCall.Input,
-			})
+
+			// Run tool in goroutine to allow cancellation
+			type toolExecResult struct {
+				response tools.ToolResponse
+				err      error
+			}
+			resultChan := make(chan toolExecResult, 1)
+
+			go func() {
+				response, err := tool.Run(ctx, tools.ToolCall{
+					ID:    toolCall.ID,
+					Name:  toolCall.Name,
+					Input: toolCall.Input,
+				})
+				resultChan <- toolExecResult{response: response, err: err}
+			}()
+
+			var toolResponse tools.ToolResponse
+			var toolErr error
+
+			select {
+			case <-ctx.Done():
+				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
+				// Mark remaining tool calls as cancelled
+				for j := i; j < len(toolCalls); j++ {
+					toolResults[j] = message.ToolResult{
+						ToolCallID: toolCalls[j].ID,
+						Content:    "Tool execution canceled by user",
+						IsError:    true,
+					}
+				}
+				goto out
+			case result := <-resultChan:
+				toolResponse = result.response
+				toolErr = result.err
+			}
+
 			if toolErr != nil {
+				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 					toolResults[i] = message.ToolResult{
 						ToolCallID: toolCall.ID,
@@ -526,9 +559,9 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
 			}
 			toolResults[i] = message.ToolResult{
 				ToolCallID: toolCall.ID,
-				Content:    toolResult.Content,
-				Metadata:   toolResult.Metadata,
-				IsError:    toolResult.IsError,
+				Content:    toolResponse.Content,
+				Metadata:   toolResponse.Metadata,
+				IsError:    toolResponse.IsError,
 			}
 		}
 	}
@@ -796,6 +829,9 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 }
 
 func (a *agent) CancelAll() {
+	if !a.IsBusy() {
+		return
+	}
 	a.activeRequests.Range(func(key, value any) bool {
 		a.Cancel(key.(string)) // key is sessionID
 		return true