@@ -501,12 +501,45 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
}
continue
}
- toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
- ID: toolCall.ID,
- Name: toolCall.Name,
- Input: toolCall.Input,
- })
+
+ // Run tool in goroutine to allow cancellation
+ type toolExecResult struct {
+ response tools.ToolResponse
+ err error
+ }
+ resultChan := make(chan toolExecResult, 1)
+
+ go func() {
+ response, err := tool.Run(ctx, tools.ToolCall{
+ ID: toolCall.ID,
+ Name: toolCall.Name,
+ Input: toolCall.Input,
+ })
+ resultChan <- toolExecResult{response: response, err: err}
+ }()
+
+ var toolResponse tools.ToolResponse
+ var toolErr error
+
+ select {
+ case <-ctx.Done():
+ a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
+ // Mark remaining tool calls as cancelled
+ for j := i; j < len(toolCalls); j++ {
+ toolResults[j] = message.ToolResult{
+ ToolCallID: toolCalls[j].ID,
+ Content: "Tool execution canceled by user",
+ IsError: true,
+ }
+ }
+ goto out
+ case result := <-resultChan:
+ toolResponse = result.response
+ toolErr = result.err
+ }
+
if toolErr != nil {
+ slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
if errors.Is(toolErr, permission.ErrorPermissionDenied) {
toolResults[i] = message.ToolResult{
ToolCallID: toolCall.ID,
@@ -526,9 +559,9 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
}
toolResults[i] = message.ToolResult{
ToolCallID: toolCall.ID,
- Content: toolResult.Content,
- Metadata: toolResult.Metadata,
- IsError: toolResult.IsError,
+ Content: toolResponse.Content,
+ Metadata: toolResponse.Metadata,
+ IsError: toolResponse.IsError,
}
}
}
@@ -796,6 +829,9 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
}
func (a *agent) CancelAll() {
+ if !a.IsBusy() {
+ return
+ }
a.activeRequests.Range(func(key, value any) bool {
a.Cancel(key.(string)) // key is sessionID
return true