diff --git a/internal/agent/agent.go b/internal/agent/agent.go index c64a4c7838b8545eed8bbaa3e32f33bab437f8d6..8d2fa40fd427143bf988587ef7faa3a89c3e23b1 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -445,7 +445,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy Content: content, IsError: true, } - _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{ + _, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: []message.ContentPart{ toolResult, @@ -876,14 +876,17 @@ func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, } func (a *sessionAgent) Cancel(sessionID string) { - // Cancel regular requests. - if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil { + // Cancel regular requests. Don't use Take() here - we need the entry to + // remain in activeRequests so IsBusy() returns true until the goroutine + // fully completes (including error handling that may access the DB). + // The defer in processRequest will clean up the entry. + if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil { slog.Info("Request cancellation initiated", "session_id", sessionID) cancel() } // Also check for summarize requests. - if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil { + if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil { slog.Info("Summarize cancellation initiated", "session_id", sessionID) cancel() } diff --git a/internal/agent/agentic_fetch_tool.go b/internal/agent/agentic_fetch_tool.go index 333ec7926f80735c3798c524378964a8e41fe3e4..89d3535720f8452111f12f4df4eb691e39253bed 100644 --- a/internal/agent/agentic_fetch_tool.go +++ b/internal/agent/agentic_fetch_tool.go @@ -79,7 +79,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( description = "Search the web and analyze results" } - p := c.permissions.Request( + p, err := c.permissions.Request(ctx, permission.CreatePermissionRequest{ SessionID: validationResult.SessionID, Path: c.cfg.WorkingDir(), @@ -90,7 +90,9 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( Params: tools.AgenticFetchPermissionsParams(params), }, ) - + if err != nil { + return fantasy.ToolResponse{}, err + } if !p { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/agent/tools/bash.go b/internal/agent/tools/bash.go index c3f0bc8cd24a6c4ff7c6f775e357c90b3dc99802..ca3612e091f23235688a2a40006469e39093d6a5 100644 --- a/internal/agent/tools/bash.go +++ b/internal/agent/tools/bash.go @@ -215,7 +215,7 @@ func NewBashTool(permissions permission.Service, workingDir string, attribution return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command") } if !isSafeReadOnly { - p := permissions.Request( + p, err := permissions.Request(ctx, permission.CreatePermissionRequest{ SessionID: sessionID, Path: execWorkingDir, @@ -226,6 +226,9 @@ func NewBashTool(permissions permission.Service, workingDir string, attribution Params: BashPermissionsParams(params), }, ) + if err != nil { + return fantasy.ToolResponse{}, err + } if !p { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/agent/tools/download.go b/internal/agent/tools/download.go index 353b312a29d410c6485f76fc8dd42a4b9dcdefb1..8f3f224b9e5647911d3c7e1cc5a668eea18b1785 100644 --- a/internal/agent/tools/download.go +++ b/internal/agent/tools/download.go @@ -70,7 +70,7 @@ func NewDownloadTool(permissions permission.Service, workingDir string, client * return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for downloading files") } - p := permissions.Request( + p, err := permissions.Request(ctx, permission.CreatePermissionRequest{ SessionID: sessionID, Path: filePath, @@ -80,7 +80,9 @@ func NewDownloadTool(permissions permission.Service, workingDir string, client * Params: DownloadPermissionsParams(params), }, ) - + if err != nil { + return fantasy.ToolResponse{}, err + } if !p { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/agent/tools/edit.go b/internal/agent/tools/edit.go index e4503e8127a750647c659353a018d36ee42643a1..a3680d009c6d76f8bcb3e39f1c1ddd2041aa1e52 100644 --- a/internal/agent/tools/edit.go +++ b/internal/agent/tools/edit.go @@ -122,7 +122,7 @@ func createNewFile(edit editContext, filePath, content string, call fantasy.Tool content, strings.TrimPrefix(filePath, edit.workingDir), ) - p := edit.permissions.Request( + p, err := edit.permissions.Request(edit.ctx, permission.CreatePermissionRequest{ SessionID: sessionID, Path: fsext.PathOrPrefix(filePath, edit.workingDir), @@ -137,6 +137,9 @@ func createNewFile(edit editContext, filePath, content string, call fantasy.Tool }, }, ) + if err != nil { + return fantasy.ToolResponse{}, err + } if !p { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } @@ -243,7 +246,7 @@ func deleteContent(edit editContext, filePath, oldString string, replaceAll bool strings.TrimPrefix(filePath, edit.workingDir), ) - p := edit.permissions.Request( + p, err := edit.permissions.Request(edit.ctx, permission.CreatePermissionRequest{ SessionID: sessionID, Path: fsext.PathOrPrefix(filePath, edit.workingDir), @@ -258,6 +261,9 @@ func deleteContent(edit editContext, filePath, oldString string, replaceAll bool }, }, ) + if err != nil { + return fantasy.ToolResponse{}, err + } if !p { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } @@ -378,7 +384,7 @@ func replaceContent(edit editContext, filePath, oldString, newString string, rep strings.TrimPrefix(filePath, edit.workingDir), ) - p := edit.permissions.Request( + p, err := edit.permissions.Request(edit.ctx, permission.CreatePermissionRequest{ SessionID: sessionID, Path: fsext.PathOrPrefix(filePath, edit.workingDir), @@ -393,6 +399,9 @@ func replaceContent(edit editContext, filePath, oldString, newString string, rep }, }, ) + if err != nil { + return fantasy.ToolResponse{}, err + } if !p { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/agent/tools/fetch.go b/internal/agent/tools/fetch.go index b23da7099be7ad0b5e3cc7076426c7494e8a3202..29fa6f15b5a90fe2dd8d34ef383990b892b742c3 100644 --- a/internal/agent/tools/fetch.go +++ b/internal/agent/tools/fetch.go @@ -55,7 +55,7 @@ func NewFetchTool(permissions permission.Service, workingDir string, client *htt return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file") } - p := permissions.Request( + p, err := permissions.Request(ctx, permission.CreatePermissionRequest{ SessionID: sessionID, Path: workingDir, @@ -66,7 +66,9 @@ func NewFetchTool(permissions permission.Service, workingDir string, client *htt Params: FetchPermissionsParams(params), }, ) - + if err != nil { + return fantasy.ToolResponse{}, err + } if !p { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/agent/tools/ls.go b/internal/agent/tools/ls.go index 2a6627741256339a319ec734c4ff766b041e5670..eff7bac0757b5956f669a752c378cab548affb85 100644 --- a/internal/agent/tools/ls.go +++ b/internal/agent/tools/ls.go @@ -79,7 +79,7 @@ func NewLsTool(permissions permission.Service, workingDir string, lsConfig confi return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for accessing directories outside working directory") } - granted := permissions.Request( + granted, err := permissions.Request(ctx, permission.CreatePermissionRequest{ SessionID: sessionID, Path: absSearchPath, @@ -90,7 +90,9 @@ func NewLsTool(permissions permission.Service, workingDir string, lsConfig confi Params: LSPermissionsParams(params), }, ) - + if err != nil { + return fantasy.ToolResponse{}, err + } if !granted { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/agent/tools/mcp-tools.go b/internal/agent/tools/mcp-tools.go index 5b4302cc5e16adedea18bdc767d2312f8d920f82..fa55f03728639a09e6bd2f150338238d30120883 100644 --- a/internal/agent/tools/mcp-tools.go +++ b/internal/agent/tools/mcp-tools.go @@ -89,7 +89,7 @@ func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolRe return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file") } permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name) - p := m.permissions.Request( + p, err := m.permissions.Request(ctx, permission.CreatePermissionRequest{ SessionID: sessionID, ToolCallID: params.ID, @@ -100,6 +100,9 @@ func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolRe Params: params.Input, }, ) + if err != nil { + return fantasy.ToolResponse{}, err + } if !p { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/agent/tools/multiedit.go b/internal/agent/tools/multiedit.go index 9136c37fadb914cb1c560e3fa5f2b6208fc3ead5..0640228d23230e6a49d8e1405f371c099031fbf7 100644 --- a/internal/agent/tools/multiedit.go +++ b/internal/agent/tools/multiedit.go @@ -173,7 +173,7 @@ func processMultiEditWithCreation(edit editContext, params MultiEditParams, call } else { description = fmt.Sprintf("Create file %s with %d edits", params.FilePath, editsApplied) } - p := edit.permissions.Request(permission.CreatePermissionRequest{ + p, err := edit.permissions.Request(edit.ctx, permission.CreatePermissionRequest{ SessionID: sessionID, Path: fsext.PathOrPrefix(params.FilePath, edit.workingDir), ToolCallID: call.ID, @@ -186,12 +186,15 @@ func processMultiEditWithCreation(edit editContext, params MultiEditParams, call NewContent: currentContent, }, }) + if err != nil { + return fantasy.ToolResponse{}, err + } if !p { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } // Write the file - err := os.WriteFile(params.FilePath, []byte(currentContent), 0o644) + err = os.WriteFile(params.FilePath, []byte(currentContent), 0o644) if err != nil { return fantasy.ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } @@ -314,7 +317,7 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call } else { description = fmt.Sprintf("Apply %d edits to file %s", editsApplied, params.FilePath) } - p := edit.permissions.Request(permission.CreatePermissionRequest{ + p, err := edit.permissions.Request(edit.ctx, permission.CreatePermissionRequest{ SessionID: sessionID, Path: fsext.PathOrPrefix(params.FilePath, edit.workingDir), ToolCallID: call.ID, @@ -327,6 +330,9 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call NewContent: currentContent, }, }) + if err != nil { + return fantasy.ToolResponse{}, err + } if !p { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/agent/tools/multiedit_test.go b/internal/agent/tools/multiedit_test.go index 36d0a0d469f67aa11cf36cd0bce3efffb4bab683..b6d575435e63dcd62a4dc9a7efb76cf13c14ad05 100644 --- a/internal/agent/tools/multiedit_test.go +++ b/internal/agent/tools/multiedit_test.go @@ -19,8 +19,8 @@ type mockPermissionService struct { *pubsub.Broker[permission.PermissionRequest] } -func (m *mockPermissionService) Request(req permission.CreatePermissionRequest) bool { - return true +func (m *mockPermissionService) Request(ctx context.Context, req permission.CreatePermissionRequest) (bool, error) { + return true, nil } func (m *mockPermissionService) Grant(req permission.PermissionRequest) {} diff --git a/internal/agent/tools/view.go b/internal/agent/tools/view.go index 7129a91b4b526bfdd27c97987b84aeae38d33068..96150669c292d457f7e8f1bb514f2a209bb73b6e 100644 --- a/internal/agent/tools/view.go +++ b/internal/agent/tools/view.go @@ -88,7 +88,7 @@ func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for accessing files outside working directory") } - granted := permissions.Request( + granted, err := permissions.Request(ctx, permission.CreatePermissionRequest{ SessionID: sessionID, Path: absFilePath, @@ -99,7 +99,9 @@ func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss Params: ViewPermissionsParams(params), }, ) - + if err != nil { + return fantasy.ToolResponse{}, err + } if !granted { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/agent/tools/write.go b/internal/agent/tools/write.go index 4ffd44a0553d1a1646d20dac557ab4e1bc47f45a..bbd5e50cf863d4d13503f6cee926b57df80f69bc 100644 --- a/internal/agent/tools/write.go +++ b/internal/agent/tools/write.go @@ -111,7 +111,7 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis strings.TrimPrefix(filePath, workingDir), ) - p := permissions.Request( + p, err := permissions.Request(ctx, permission.CreatePermissionRequest{ SessionID: sessionID, Path: fsext.PathOrPrefix(filePath, workingDir), @@ -126,6 +126,9 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis }, }, ) + if err != nil { + return fantasy.ToolResponse{}, err + } if !p { return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } diff --git a/internal/app/app.go b/internal/app/app.go index 24f1500cc707a16964f00a6eac1a62c0ac094850..08762f863a7d9cf77751d0c2c4095591f002eab0 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -392,13 +392,16 @@ func (app *App) Subscribe(program *tea.Program) { func (app *App) Shutdown() { start := time.Now() defer func() { slog.Info("Shutdown took " + time.Since(start).String()) }() - var wg sync.WaitGroup + + // First, cancel all agents and wait for them to finish. This must complete + // before closing the DB so agents can finish writing their state. if app.AgentCoordinator != nil { - wg.Go(func() { - app.AgentCoordinator.CancelAll() - }) + app.AgentCoordinator.CancelAll() } + // Now run remaining cleanup tasks in parallel. + var wg sync.WaitGroup + // Kill all background shells. wg.Go(func() { shell.GetBackgroundShellManager().KillAll() diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 829dd2ed90abf4d45b63481eacebb492cadabdfd..9dc85e976238fdbe1ff2d3689b2a2c4160608760 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -47,7 +47,7 @@ type Service interface { GrantPersistent(permission PermissionRequest) Grant(permission PermissionRequest) Deny(permission PermissionRequest) - Request(opts CreatePermissionRequest) bool + Request(ctx context.Context, opts CreatePermissionRequest) (bool, error) AutoApproveSession(sessionID string) SetSkipRequests(skip bool) SkipRequests() bool @@ -122,9 +122,9 @@ func (s *permissionService) Deny(permission PermissionRequest) { } } -func (s *permissionService) Request(opts CreatePermissionRequest) bool { +func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRequest) (bool, error) { if s.skip { - return true + return true, nil } // tell the UI that a permission was requested @@ -137,7 +137,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { // Check if the tool/action combination is in the allowlist commandKey := opts.ToolName + ":" + opts.Action if slices.Contains(s.allowedTools, commandKey) || slices.Contains(s.allowedTools, opts.ToolName) { - return true + return true, nil } s.autoApproveSessionsMu.RLock() @@ -145,7 +145,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { s.autoApproveSessionsMu.RUnlock() if autoApprove { - return true + return true, nil } fileInfo, err := os.Stat(opts.Path) @@ -176,7 +176,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { for _, p := range s.sessionPermissions { if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path { s.sessionPermissionsMu.RUnlock() - return true + return true, nil } } s.sessionPermissionsMu.RUnlock() @@ -185,7 +185,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { for _, p := range s.sessionPermissions { if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path { s.sessionPermissionsMu.RUnlock() - return true + return true, nil } } s.sessionPermissionsMu.RUnlock() @@ -199,7 +199,12 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { // Publish the request s.Publish(pubsub.CreatedEvent, permission) - return <-respCh + select { + case <-ctx.Done(): + return false, ctx.Err() + case granted := <-respCh: + return granted, nil + } } func (s *permissionService) AutoApproveSession(sessionID string) { diff --git a/internal/permission/permission_test.go b/internal/permission/permission_test.go index d1ccd286836768f1bc1119966568941f7494affd..89e06916024cd1669f5e0d0a263d4a71548c8a97 100644 --- a/internal/permission/permission_test.go +++ b/internal/permission/permission_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestPermissionService_AllowedCommands(t *testing.T) { @@ -81,14 +82,16 @@ func TestPermissionService_AllowedCommands(t *testing.T) { func TestPermissionService_SkipMode(t *testing.T) { service := NewPermissionService("/tmp", true, []string{}) - result := service.Request(CreatePermissionRequest{ + result, err := service.Request(t.Context(), CreatePermissionRequest{ SessionID: "test-session", ToolName: "bash", Action: "execute", Description: "test command", Path: "/tmp", }) - + if err != nil { + t.Errorf("unexpected error: %v", err) + } if !result { t.Error("expected permission to be granted in skip mode") } @@ -115,7 +118,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) { go func() { defer wg.Done() - result1 = service.Request(req1) + result1, _ = service.Request(t.Context(), req1) }() var permissionReq PermissionRequest @@ -136,7 +139,8 @@ func TestPermissionService_SequentialProperties(t *testing.T) { Params: map[string]string{"file": "test.txt"}, Path: "/tmp/test.txt", } - result2 := service.Request(req2) + result2, err := service.Request(t.Context(), req2) + require.NoError(t, err) assert.True(t, result2, "Second request should be auto-approved") }) t.Run("Sequential requests with temporary grants", func(t *testing.T) { @@ -156,7 +160,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) { var wg sync.WaitGroup wg.Go(func() { - result1 = service.Request(req) + result1, _ = service.Request(t.Context(), req) }) var permissionReq PermissionRequest @@ -170,7 +174,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) { var result2 bool wg.Go(func() { - result2 = service.Request(req) + result2, _ = service.Request(t.Context(), req) }) event = <-events @@ -215,7 +219,8 @@ func TestPermissionService_SequentialProperties(t *testing.T) { wg.Add(1) go func(index int, request CreatePermissionRequest) { defer wg.Done() - results = append(results, service.Request(request)) + result, _ := service.Request(t.Context(), request) + results = append(results, result) }(i, req) } @@ -241,7 +246,8 @@ func TestPermissionService_SequentialProperties(t *testing.T) { assert.Equal(t, 2, grantedCount, "Should have 2 granted and 1 denied") secondReq := requests[1] secondReq.Description = "Repeat of second request" - result := service.Request(secondReq) + result, err := service.Request(t.Context(), secondReq) + require.NoError(t, err) assert.True(t, result, "Repeated request should be auto-approved due to persistent permission") }) }