diff --git a/.cargo/config.toml b/.cargo/config.toml index 8db58d238003c29df6dbc9fa733c6d5521340103..717c5e18c8d294bacf65207bc6b8ecb7dba1b152 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -19,8 +19,6 @@ rustflags = [ "windows_slim_errors", # This cfg will reduce the size of `windows::core::Error` from 16 bytes to 4 bytes "-C", "target-feature=+crt-static", # This fixes the linking issue when compiling livekit on Windows - "-C", - "link-arg=-fuse-ld=lld", ] [env] diff --git a/.config/hakari.toml b/.config/hakari.toml index 2050065cc2d6be2a27ec012dcd125af992793eeb..b1e2954743b404f088c71c28aad1d6a699a22aeb 100644 --- a/.config/hakari.toml +++ b/.config/hakari.toml @@ -25,6 +25,8 @@ third-party = [ { name = "reqwest", version = "0.11.27" }, # build of remote_server should not include scap / its x11 dependency { name = "scap", git = "https://github.com/zed-industries/scap", rev = "808aa5c45b41e8f44729d02e38fd00a2fe2722e7" }, + # build of remote_server should not need to include on libalsa through rodio + { name = "rodio", git = "https://github.com/RustAudio/rodio", branch = "better_wav_output"}, ] [final-excludes] @@ -32,7 +34,6 @@ workspace-members = [ "zed_extension_api", # exclude all extensions - "zed_emmet", "zed_glsl", "zed_html", "zed_proto", @@ -40,5 +41,4 @@ workspace-members = [ "slash_commands_example", "zed_snippets", "zed_test_extension", - "zed_toml", ] diff --git a/.gitattributes b/.gitattributes index 9973cfb4db9ce8e9c79e84b9861a946f2f1c2f15..57afd4ea6942bd3985fb7395101800706d7b4ae6 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,5 @@ # Prevent GitHub from displaying comments within JSON files as errors. 
*.json linguist-language=JSON-with-Comments + +# Ensure the WSL script always has LF line endings, even on Windows +crates/zed/resources/windows/zed.sh text eol=lf diff --git a/.github/ISSUE_TEMPLATE/10_bug_report.yml b/.github/ISSUE_TEMPLATE/10_bug_report.yml index e132eca1e52bc617f35fc2ec6e4e34fe3c796b11..1bf6c80e4073dafa90e736f995053c570f0ba2da 100644 --- a/.github/ISSUE_TEMPLATE/10_bug_report.yml +++ b/.github/ISSUE_TEMPLATE/10_bug_report.yml @@ -14,7 +14,7 @@ body: ### Description diff --git a/.github/actionlint.yml b/.github/actionlint.yml index 0ee6af8a1d38e005f66b79f6c548d9f79396ea35..6d8e0107e9b42e71bb7266c0629393b9057e05bc 100644 --- a/.github/actionlint.yml +++ b/.github/actionlint.yml @@ -19,14 +19,27 @@ self-hosted-runner: - namespace-profile-16x32-ubuntu-2004-arm - namespace-profile-32x64-ubuntu-2004-arm # Namespace Ubuntu 22.04 (Everything else) - - namespace-profile-2x4-ubuntu-2204 - namespace-profile-4x8-ubuntu-2204 - namespace-profile-8x16-ubuntu-2204 - namespace-profile-16x32-ubuntu-2204 - namespace-profile-32x64-ubuntu-2204 + # Namespace Ubuntu 24.04 (like ubuntu-latest) + - namespace-profile-2x4-ubuntu-2404 # Namespace Limited Preview - namespace-profile-8x16-ubuntu-2004-arm-m4 - namespace-profile-8x32-ubuntu-2004-arm-m4 # Self Hosted Runners - self-mini-macos - self-32vcpu-windows-2022 + +# Disable shellcheck because it doesn't like powershell +# This should have been triggered with initial rollout of actionlint +# but https://github.com/zed-industries/zed/pull/36693 +# somehow caused actionlint to actually check those windows jobs +# where previously they were being skipped. Likely caused by an +# unknown bug in actionlint where parsing of `runs-on: [ ]` +# breaks something else. 
(yuck) +paths: + .github/workflows/{ci,release_nightly}.yml: + ignore: + - "shellcheck" diff --git a/.github/actions/run_tests_windows/action.yml b/.github/actions/run_tests_windows/action.yml index e3e3b7142e2223e2b5a7524205dbe21fb963ed86..8392ca1d375856c7f649e73d2445ce4f873924b1 100644 --- a/.github/actions/run_tests_windows/action.yml +++ b/.github/actions/run_tests_windows/action.yml @@ -20,168 +20,8 @@ runs: with: node-version: "18" - - name: Configure crash dumps - shell: powershell - run: | - # Record the start time for this CI run - $runStartTime = Get-Date - $runStartTimeStr = $runStartTime.ToString("yyyy-MM-dd HH:mm:ss") - Write-Host "CI run started at: $runStartTimeStr" - - # Save the timestamp for later use - echo "CI_RUN_START_TIME=$($runStartTime.Ticks)" >> $env:GITHUB_ENV - - # Create crash dump directory in workspace (non-persistent) - $dumpPath = "$env:GITHUB_WORKSPACE\crash_dumps" - New-Item -ItemType Directory -Force -Path $dumpPath | Out-Null - - Write-Host "Setting up crash dump detection..." - Write-Host "Workspace dump path: $dumpPath" - - # Note: We're NOT modifying registry on stateful runners - # Instead, we'll check default Windows crash locations after tests - - name: Run tests shell: powershell working-directory: ${{ inputs.working-directory }} run: | - $env:RUST_BACKTRACE = "full" - - # Enable Windows debugging features - $env:_NT_SYMBOL_PATH = "srv*https://msdl.microsoft.com/download/symbols" - - # .NET crash dump environment variables (ephemeral) - $env:COMPlus_DbgEnableMiniDump = "1" - $env:COMPlus_DbgMiniDumpType = "4" - $env:COMPlus_CreateDumpDiagnostics = "1" - cargo nextest run --workspace --no-fail-fast - continue-on-error: true - - - name: Analyze crash dumps - if: always() - shell: powershell - run: | - Write-Host "Checking for crash dumps..." 
- - # Get the CI run start time from the environment - $runStartTime = [DateTime]::new([long]$env:CI_RUN_START_TIME) - Write-Host "Only analyzing dumps created after: $($runStartTime.ToString('yyyy-MM-dd HH:mm:ss'))" - - # Check all possible crash dump locations - $searchPaths = @( - "$env:GITHUB_WORKSPACE\crash_dumps", - "$env:LOCALAPPDATA\CrashDumps", - "$env:TEMP", - "$env:GITHUB_WORKSPACE", - "$env:USERPROFILE\AppData\Local\CrashDumps", - "C:\Windows\System32\config\systemprofile\AppData\Local\CrashDumps" - ) - - $dumps = @() - foreach ($path in $searchPaths) { - if (Test-Path $path) { - Write-Host "Searching in: $path" - $found = Get-ChildItem "$path\*.dmp" -ErrorAction SilentlyContinue | Where-Object { - $_.CreationTime -gt $runStartTime - } - if ($found) { - $dumps += $found - Write-Host " Found $($found.Count) dump(s) from this CI run" - } - } - } - - if ($dumps) { - Write-Host "Found $($dumps.Count) crash dump(s)" - - # Install debugging tools if not present - $cdbPath = "C:\Program Files (x86)\Windows Kits\10\Debuggers\x64\cdb.exe" - if (-not (Test-Path $cdbPath)) { - Write-Host "Installing Windows Debugging Tools..." 
- $url = "https://go.microsoft.com/fwlink/?linkid=2237387" - Invoke-WebRequest -Uri $url -OutFile winsdksetup.exe - Start-Process -Wait winsdksetup.exe -ArgumentList "/features OptionId.WindowsDesktopDebuggers /quiet" - } - - foreach ($dump in $dumps) { - Write-Host "`n==================================" - Write-Host "Analyzing crash dump: $($dump.Name)" - Write-Host "Size: $([math]::Round($dump.Length / 1MB, 2)) MB" - Write-Host "Time: $($dump.CreationTime)" - Write-Host "==================================" - - # Set symbol path - $env:_NT_SYMBOL_PATH = "srv*C:\symbols*https://msdl.microsoft.com/download/symbols" - - # Run analysis - $analysisOutput = & $cdbPath -z $dump.FullName -c "!analyze -v; ~*k; lm; q" 2>&1 | Out-String - - # Extract key information - if ($analysisOutput -match "ExceptionCode:\s*([\w]+)") { - Write-Host "Exception Code: $($Matches[1])" - if ($Matches[1] -eq "c0000005") { - Write-Host "Exception Type: ACCESS VIOLATION" - } - } - - if ($analysisOutput -match "EXCEPTION_RECORD:\s*(.+)") { - Write-Host "Exception Record: $($Matches[1])" - } - - if ($analysisOutput -match "FAULTING_IP:\s*\n(.+)") { - Write-Host "Faulting Instruction: $($Matches[1])" - } - - # Save full analysis - $analysisFile = "$($dump.FullName).analysis.txt" - $analysisOutput | Out-File -FilePath $analysisFile - Write-Host "`nFull analysis saved to: $analysisFile" - - # Print stack trace section - Write-Host "`n--- Stack Trace Preview ---" - $stackSection = $analysisOutput -split "STACK_TEXT:" | Select-Object -Last 1 - $stackLines = $stackSection -split "`n" | Select-Object -First 20 - $stackLines | ForEach-Object { Write-Host $_ } - Write-Host "--- End Stack Trace Preview ---" - } - - Write-Host "`n⚠️ Crash dumps detected! Download the 'crash-dumps' artifact for detailed analysis." 
- - # Copy dumps to workspace for artifact upload - $artifactPath = "$env:GITHUB_WORKSPACE\crash_dumps_collected" - New-Item -ItemType Directory -Force -Path $artifactPath | Out-Null - - foreach ($dump in $dumps) { - $destName = "$($dump.Directory.Name)_$($dump.Name)" - Copy-Item $dump.FullName -Destination "$artifactPath\$destName" - if (Test-Path "$($dump.FullName).analysis.txt") { - Copy-Item "$($dump.FullName).analysis.txt" -Destination "$artifactPath\$destName.analysis.txt" - } - } - - Write-Host "Copied $($dumps.Count) dump(s) to artifact directory" - } else { - Write-Host "No crash dumps from this CI run found" - } - - - name: Upload crash dumps - if: always() - uses: actions/upload-artifact@v4 - with: - name: crash-dumps-${{ github.run_id }}-${{ github.run_attempt }} - path: | - crash_dumps_collected/*.dmp - crash_dumps_collected/*.txt - if-no-files-found: ignore - retention-days: 7 - - - name: Check test results - shell: powershell - working-directory: ${{ inputs.working-directory }} - run: | - # Re-check test results to fail the job if tests failed - if ($LASTEXITCODE -ne 0) { - Write-Host "Tests failed with exit code: $LASTEXITCODE" - exit $LASTEXITCODE - } diff --git a/.github/workflows/bump_collab_staging.yml b/.github/workflows/bump_collab_staging.yml index d8eaa6019ec29b5dd908564d05f430d3e7f01909..d400905b4da3304a8b916d3a38ae9d8a2855dbf5 100644 --- a/.github/workflows/bump_collab_staging.yml +++ b/.github/workflows/bump_collab_staging.yml @@ -8,7 +8,7 @@ on: jobs: update-collab-staging-tag: if: github.repository_owner == 'zed-industries' - runs-on: ubuntu-latest + runs-on: namespace-profile-2x4-ubuntu-2404 steps: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f4ba227168fb9cec10e1b5e23223b48e7a4ca222..d416b4af0eedf38da249e39181bd8017b57f752c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,7 +37,7 @@ jobs: 
run_nix: ${{ steps.filter.outputs.run_nix }} run_actionlint: ${{ steps.filter.outputs.run_actionlint }} runs-on: - - ubuntu-latest + - namespace-profile-2x4-ubuntu-2404 steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -81,6 +81,7 @@ jobs: echo "run_license=false" >> "$GITHUB_OUTPUT" echo "$CHANGED_FILES" | grep -qP '^(nix/|flake\.|Cargo\.|rust-toolchain.toml|\.cargo/config.toml)' && \ + echo "$GITHUB_REF_NAME" | grep -qvP '^v[0-9]+\.[0-9]+\.[0-9x](-pre)?$' && \ echo "run_nix=true" >> "$GITHUB_OUTPUT" || \ echo "run_nix=false" >> "$GITHUB_OUTPUT" @@ -237,7 +238,7 @@ jobs: uses: ./.github/actions/build_docs actionlint: - runs-on: ubuntu-latest + runs-on: namespace-profile-2x4-ubuntu-2404 if: github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_actionlint == 'true' needs: [job_spec] steps: @@ -418,7 +419,7 @@ jobs: if: | github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_tests == 'true' - runs-on: [self-hosted, Windows, X64] + runs-on: [self-32vcpu-windows-2022] steps: - name: Environment Setup run: | @@ -458,7 +459,7 @@ jobs: tests_pass: name: Tests Pass - runs-on: ubuntu-latest + runs-on: namespace-profile-2x4-ubuntu-2404 needs: - job_spec - style @@ -784,7 +785,7 @@ jobs: bundle-windows-x64: timeout-minutes: 120 name: Create a Windows installer - runs-on: [self-hosted, Windows, X64] + runs-on: [self-32vcpu-windows-2022] if: contains(github.event.pull_request.labels.*.name, 'run-bundling') # if: (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling')) needs: [windows_tests] diff --git a/.github/workflows/danger.yml b/.github/workflows/danger.yml index 15c82643aef1e14c85daaaf2c8c3c61f62f1b3aa..3f84179278d1baaa7a299e2292b3041830d9ca60 100644 --- a/.github/workflows/danger.yml +++ b/.github/workflows/danger.yml @@ -12,7 +12,7 @@ on: jobs: danger: if: github.repository_owner == 'zed-industries' - runs-on: ubuntu-latest + 
runs-on: namespace-profile-2x4-ubuntu-2404 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml index 0cc6737a45106713021c769b75dbbb180008dffe..2026ee7b730698cd7e40eebcd141f5b8a6ee9d04 100644 --- a/.github/workflows/release_nightly.yml +++ b/.github/workflows/release_nightly.yml @@ -59,7 +59,7 @@ jobs: timeout-minutes: 60 name: Run tests on Windows if: github.repository_owner == 'zed-industries' - runs-on: [self-hosted, Windows, X64] + runs-on: [self-32vcpu-windows-2022] steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -206,9 +206,6 @@ jobs: runs-on: github-8vcpu-ubuntu-2404 needs: tests name: Build Zed on FreeBSD - # env: - # MYTOKEN : ${{ secrets.MYTOKEN }} - # MYTOKEN2: "value2" steps: - uses: actions/checkout@v4 - name: Build FreeBSD remote-server @@ -243,7 +240,6 @@ jobs: bundle-nix: name: Build and cache Nix package - if: false needs: tests secrets: inherit uses: ./.github/workflows/nix.yml @@ -252,7 +248,7 @@ jobs: timeout-minutes: 60 name: Create a Windows installer if: github.repository_owner == 'zed-industries' - runs-on: [self-hosted, Windows, X64] + runs-on: [self-32vcpu-windows-2022] needs: windows-tests env: AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }} @@ -294,7 +290,7 @@ jobs: update-nightly-tag: name: Update nightly tag if: github.repository_owner == 'zed-industries' - runs-on: ubuntu-latest + runs-on: namespace-profile-2x4-ubuntu-2404 needs: - bundle-mac - bundle-linux-x86 diff --git a/.github/workflows/script_checks.yml b/.github/workflows/script_checks.yml index c32a433e46a6fc5381fa1abbe19b2814fe423c1d..5dbfc9cb7fa9a51b9e0aca972d125c2a27677584 100644 --- a/.github/workflows/script_checks.yml +++ b/.github/workflows/script_checks.yml @@ -12,7 +12,7 @@ jobs: shellcheck: name: "ShellCheck Scripts" if: github.repository_owner == 'zed-industries' - runs-on: ubuntu-latest + 
runs-on: namespace-profile-2x4-ubuntu-2404 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 diff --git a/.rules b/.rules index da009f1877b4c6ef2f0613995391852d4bf1dc8a..2f2b9cd705d95775bedf092bc4e6254136da6117 100644 --- a/.rules +++ b/.rules @@ -12,6 +12,19 @@ - Example: avoid `let _ = client.request(...).await?;` - use `client.request(...).await?;` instead * When implementing async operations that may fail, ensure errors propagate to the UI layer so users get meaningful feedback. * Never create files with `mod.rs` paths - prefer `src/some_module.rs` instead of `src/some_module/mod.rs`. +* When creating new crates, prefer specifying the library root path in `Cargo.toml` using `[lib] path = "...rs"` instead of the default `lib.rs`, to maintain consistent and descriptive naming (e.g., `gpui.rs` or `main.rs`). +* Avoid creative additions unless explicitly requested +* Use full words for variable names (no abbreviations like "q" for "queue") +* Use variable shadowing to scope clones in async contexts for clarity, minimizing the lifetime of borrowed references. + Example: + ```rust + executor.spawn({ + let task_ran = task_ran.clone(); + async move { + *task_ran.borrow_mut() = true; + } + }); + ``` # GPUI diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 91b1b75f8292f37b122c152d71fe1e38eeccf817..1c0b1e363ed0f04ff33c070a4a84815cece78545 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -27,6 +27,22 @@ By effectively engaging with the Zed team and community early in your process, w We plan to set aside time each week to pair program with contributors on promising pull requests in Zed. This will be an experiment. We tend to prefer pairing over async code review on our team, and we'd like to see how well it works in an open source setting. If we're finding it difficult to get on the same page with async review, we may ask you to pair with us if you're open to it. 
The closer a contribution is to the goals outlined in our roadmap, the more likely we'll be to spend time pairing on it. +## Mandatory PR contents + +Please ensure the PR contains + +- Before & after screenshots, if there are visual adjustments introduced. + +Examples of visual adjustments: tree-sitter query updates, UI changes, etc. + +- A disclosure of the AI assistance usage, if any was used. + +Any kind of AI assistance must be disclosed in the PR, along with the extent to which AI assistance was used (e.g. docs only vs. code generation). + +If the PR responses are being generated by an AI, disclose that as well. + +As a small exception, trivial tab-completion doesn't need to be disclosed, as long as it's limited to single keywords or short phrases. + ## Tips to improve the chances of your PR getting reviewed and merged - Discuss your plans ahead of time with the team @@ -49,6 +65,8 @@ If you would like to add a new icon to the Zed icon theme, [open a Discussion](h ## Bird's-eye view of Zed +We suggest you keep the [zed glossary](docs/src/development/glossary.md) at your side when starting out. It lists and explains some of the structures and terms you will see throughout the codebase. + Zed is made up of several smaller crates - let's go over those you're most likely to interact with: - [`gpui`](/crates/gpui) is a GPU-accelerated UI framework which provides all of the building blocks for Zed. 
**We recommend familiarizing yourself with the root level GPUI documentation.** diff --git a/Cargo.lock b/Cargo.lock index f0fd3049c0a4d6dc8197086066b0a236afe987bb..c1c7e0b2ecf765b5243efd3229aa8d25a5c67b5c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7,27 +7,31 @@ name = "acp_thread" version = "0.1.0" dependencies = [ "action_log", - "agent", "agent-client-protocol", + "agent_settings", "anyhow", "buffer_diff", "collections", "editor", "env_logger 0.11.8", + "file_icons", "futures 0.3.31", "gpui", "indoc", "itertools 0.14.0", "language", + "language_model", "markdown", "parking_lot", + "portable-pty", "project", "prompt_store", - "rand 0.8.5", + "rand 0.9.1", "serde", "serde_json", "settings", "smol", + "task", "tempfile", "terminal", "ui", @@ -35,6 +39,27 @@ dependencies = [ "util", "uuid", "watch", + "which 6.0.3", + "workspace-hack", +] + +[[package]] +name = "acp_tools" +version = "0.1.0" +dependencies = [ + "agent-client-protocol", + "collections", + "gpui", + "language", + "markdown", + "project", + "serde", + "serde_json", + "settings", + "theme", + "ui", + "util", + "workspace", "workspace-hack", ] @@ -54,7 +79,7 @@ dependencies = [ "log", "pretty_assertions", "project", - "rand 0.8.5", + "rand 0.9.1", "serde_json", "settings", "text", @@ -129,7 +154,6 @@ dependencies = [ "component", "context_server", "convert_case 0.8.0", - "feature_flags", "fs", "futures 0.3.31", "git", @@ -148,7 +172,7 @@ dependencies = [ "pretty_assertions", "project", "prompt_store", - "rand 0.8.5", + "rand 0.9.1", "ref-cast", "rope", "schemars", @@ -166,16 +190,18 @@ dependencies = [ "uuid", "workspace", "workspace-hack", + "zed_env_vars", "zstd", ] [[package]] name = "agent-client-protocol" -version = "0.0.23" +version = "0.2.0-alpha.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fad72b7b8ee4331b3a4c8d43c107e982a4725564b4ee658ae5c4e79d2b486e8" +checksum = "08539e8d6b2ccca6cd00afdd42211698f7677adef09108a09414c11f1f45fdaf" dependencies = [ "anyhow", 
+ "async-broadcast", "futures 0.3.31", "log", "parking_lot", @@ -190,10 +216,12 @@ version = "0.1.0" dependencies = [ "acp_thread", "action_log", + "agent", "agent-client-protocol", "agent_servers", "agent_settings", "anyhow", + "assistant_context", "assistant_tool", "assistant_tools", "chrono", @@ -203,10 +231,12 @@ dependencies = [ "collections", "context_server", "ctor", + "db", "editor", "env_logger 0.11.8", "fs", "futures 0.3.31", + "git", "gpui", "gpui_tokio", "handlebars 4.5.0", @@ -220,8 +250,8 @@ dependencies = [ "log", "lsp", "open", + "parking_lot", "paths", - "portable-pty", "pretty_assertions", "project", "prompt_store", @@ -232,11 +262,14 @@ dependencies = [ "serde_json", "settings", "smol", + "sqlez", "task", + "telemetry", "tempfile", "terminal", "text", "theme", + "thiserror 2.0.12", "tree-sitter-rust", "ui", "unindent", @@ -244,10 +277,11 @@ dependencies = [ "uuid", "watch", "web_search", - "which 6.0.3", "workspace-hack", "worktree", + "zed_env_vars", "zlog", + "zstd", ] [[package]] @@ -255,36 +289,37 @@ name = "agent_servers" version = "0.1.0" dependencies = [ "acp_thread", + "acp_tools", + "action_log", "agent-client-protocol", - "agentic-coding-protocol", + "agent_settings", "anyhow", + "client", "collections", - "context_server", "env_logger 0.11.8", + "fs", "futures 0.3.31", "gpui", + "gpui_tokio", "indoc", - "itertools 0.14.0", "language", + "language_model", + "language_models", "libc", "log", "nix 0.29.0", - "paths", "project", - "rand 0.8.5", - "schemars", + "reqwest_client", "serde", "serde_json", "settings", "smol", - "strum 0.27.1", + "task", "tempfile", "thiserror 2.0.12", "ui", "util", - "uuid", "watch", - "which 6.0.3", "workspace-hack", ] @@ -320,6 +355,7 @@ dependencies = [ "agent_settings", "ai_onboarding", "anyhow", + "arrayvec", "assistant_context", "assistant_slash_command", "assistant_slash_commands", @@ -346,7 +382,6 @@ dependencies = [ "gpui", "html_to_markdown", "http_client", - "indexed_docs", "indoc", "inventory", 
"itertools 0.14.0", @@ -365,11 +400,12 @@ dependencies = [ "parking_lot", "paths", "picker", + "postage", "pretty_assertions", "project", "prompt_store", "proto", - "rand 0.8.5", + "rand 0.9.1", "release_channel", "rope", "rules_library", @@ -379,6 +415,7 @@ dependencies = [ "serde_json", "serde_json_lenient", "settings", + "shlex", "smol", "streaming_diff", "task", @@ -404,24 +441,6 @@ dependencies = [ "zed_actions", ] -[[package]] -name = "agentic-coding-protocol" -version = "0.0.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e6ae951b36fa2f8d9dd6e1af6da2fcaba13d7c866cf6a9e65deda9dc6c5fe4" -dependencies = [ - "anyhow", - "chrono", - "derive_more 2.0.1", - "futures 0.3.31", - "log", - "parking_lot", - "schemars", - "semver", - "serde", - "serde_json", -] - [[package]] name = "ahash" version = "0.7.8" @@ -464,6 +483,7 @@ dependencies = [ "client", "cloud_llm_client", "component", + "feature_flags", "gpui", "language_model", "serde", @@ -488,7 +508,7 @@ dependencies = [ "parking_lot", "piper", "polling", - "regex-automata 0.4.9", + "regex-automata", "rustix-openpty", "serde", "signal-hook", @@ -812,7 +832,7 @@ dependencies = [ "project", "prompt_store", "proto", - "rand 0.8.5", + "rand 0.9.1", "regex", "rpc", "serde", @@ -828,6 +848,7 @@ dependencies = [ "uuid", "workspace", "workspace-hack", + "zed_env_vars", ] [[package]] @@ -837,7 +858,7 @@ dependencies = [ "anyhow", "async-trait", "collections", - "derive_more 0.99.19", + "derive_more", "extension", "futures 0.3.31", "gpui", @@ -871,7 +892,6 @@ dependencies = [ "gpui", "html_to_markdown", "http_client", - "indexed_docs", "language", "pretty_assertions", "project", @@ -901,7 +921,7 @@ dependencies = [ "clock", "collections", "ctor", - "derive_more 0.99.19", + "derive_more", "gpui", "icons", "indoc", @@ -911,7 +931,7 @@ dependencies = [ "parking_lot", "pretty_assertions", "project", - "rand 0.8.5", + "rand 0.9.1", "regex", "serde", "serde_json", @@ -938,7 +958,7 @@ dependencies 
= [ "cloud_llm_client", "collections", "component", - "derive_more 0.99.19", + "derive_more", "diffy", "editor", "feature_flags", @@ -963,7 +983,7 @@ dependencies = [ "pretty_assertions", "project", "prompt_store", - "rand 0.8.5", + "rand 0.9.1", "regex", "reqwest_client", "rust-embed", @@ -1261,26 +1281,6 @@ dependencies = [ "syn 2.0.101", ] -[[package]] -name = "async-stripe" -version = "0.40.0" -source = "git+https://github.com/zed-industries/async-stripe?rev=3672dd4efb7181aa597bf580bf5a2f5d23db6735#3672dd4efb7181aa597bf580bf5a2f5d23db6735" -dependencies = [ - "chrono", - "futures-util", - "http-types", - "hyper 0.14.32", - "hyper-rustls 0.24.2", - "serde", - "serde_json", - "serde_path_to_error", - "serde_qs 0.10.1", - "smart-default 0.6.0", - "smol_str 0.1.24", - "thiserror 1.0.69", - "tokio", -] - [[package]] name = "async-tar" version = "0.5.0" @@ -1303,9 +1303,9 @@ checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de" [[package]] name = "async-trait" -version = "0.1.88" +version = "0.1.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", @@ -1383,11 +1383,18 @@ name = "audio" version = "0.1.0" dependencies = [ "anyhow", + "async-tar", "collections", - "derive_more 0.99.19", + "crossbeam", "gpui", + "libwebrtc", + "log", "parking_lot", "rodio", + "schemars", + "serde", + "settings", + "smol", "util", "workspace-hack", ] @@ -2082,12 +2089,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce" -[[package]] -name = "base64" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" - [[package]] name = "base64" 
version = "0.21.7" @@ -2294,7 +2295,7 @@ dependencies = [ [[package]] name = "blade-graphics" version = "0.6.0" -source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5" +source = "git+https://github.com/kvark/blade?rev=bfa594ea697d4b6326ea29f747525c85ecf933b9#bfa594ea697d4b6326ea29f747525c85ecf933b9" dependencies = [ "ash", "ash-window", @@ -2327,7 +2328,7 @@ dependencies = [ [[package]] name = "blade-macros" version = "0.3.0" -source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5" +source = "git+https://github.com/kvark/blade?rev=bfa594ea697d4b6326ea29f747525c85ecf933b9#bfa594ea697d4b6326ea29f747525c85ecf933b9" dependencies = [ "proc-macro2", "quote", @@ -2337,7 +2338,7 @@ dependencies = [ [[package]] name = "blade-util" version = "0.2.0" -source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5" +source = "git+https://github.com/kvark/blade?rev=bfa594ea697d4b6326ea29f747525c85ecf933b9#bfa594ea697d4b6326ea29f747525c85ecf933b9" dependencies = [ "blade-graphics", "bytemuck", @@ -2354,19 +2355,6 @@ dependencies = [ "digest", ] -[[package]] -name = "blake3" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" -dependencies = [ - "arrayref", - "arrayvec", - "cc", - "cfg-if", - "constant_time_eq 0.3.1", -] - [[package]] name = "block" version = "0.1.6" @@ -2464,7 +2452,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" dependencies = [ "memchr", - "regex-automata 0.4.9", + "regex-automata", "serde", ] @@ -2481,7 +2469,7 @@ dependencies = [ "language", "log", "pretty_assertions", - "rand 0.8.5", + "rand 0.9.1", "rope", "serde_json", "sum_tree", 
@@ -2627,6 +2615,7 @@ dependencies = [ "audio", "client", "collections", + "feature_flags", "fs", "futures 0.3.31", "gpui", @@ -2902,11 +2891,9 @@ dependencies = [ "language", "log", "postage", - "rand 0.8.5", "release_channel", "rpc", "settings", - "sum_tree", "text", "time", "util", @@ -3073,10 +3060,9 @@ dependencies = [ "clock", "cloud_api_client", "cloud_llm_client", - "cocoa 0.26.0", "collections", "credentials_provider", - "derive_more 0.99.19", + "derive_more", "feature_flags", "fs", "futures 0.3.31", @@ -3086,10 +3072,11 @@ dependencies = [ "http_client_tls", "httparse", "log", + "objc2-foundation", "parking_lot", "paths", "postage", - "rand 0.8.5", + "rand 0.9.1", "regex", "release_channel", "rpc", @@ -3097,6 +3084,7 @@ dependencies = [ "schemars", "serde", "serde_json", + "serde_urlencoded", "settings", "sha2", "smol", @@ -3280,7 +3268,6 @@ dependencies = [ "anyhow", "assistant_context", "assistant_slash_command", - "async-stripe", "async-trait", "async-tungstenite", "audio", @@ -3296,7 +3283,6 @@ dependencies = [ "chrono", "client", "clock", - "cloud_llm_client", "collab_ui", "collections", "command_palette_hooks", @@ -3307,7 +3293,6 @@ dependencies = [ "dap_adapters", "dashmap 6.1.0", "debugger_ui", - "derive_more 0.99.19", "editor", "envy", "extension", @@ -3323,7 +3308,6 @@ dependencies = [ "http_client", "hyper 0.14.32", "indoc", - "jsonwebtoken", "language", "language_model", "livekit_api", @@ -3341,7 +3325,7 @@ dependencies = [ "prometheus", "prompt_store", "prost 0.9.0", - "rand 0.8.5", + "rand 0.9.1", "recent_projects", "release_channel", "remote", @@ -3369,7 +3353,6 @@ dependencies = [ "telemetry_events", "text", "theme", - "thiserror 2.0.12", "time", "tokio", "toml 0.8.20", @@ -3398,12 +3381,10 @@ dependencies = [ "collections", "db", "editor", - "emojis", "futures 0.3.31", "fuzzy", "gpui", "http_client", - "language", "log", "menu", "notifications", @@ -3411,7 +3392,6 @@ dependencies = [ "pretty_assertions", "project", "release_channel", - 
"rich_text", "rpc", "schemars", "serde", @@ -3512,7 +3492,7 @@ name = "command_palette_hooks" version = "0.1.0" dependencies = [ "collections", - "derive_more 0.99.19", + "derive_more", "gpui", "workspace-hack", ] @@ -3522,6 +3502,7 @@ name = "component" version = "0.1.0" dependencies = [ "collections", + "documented", "gpui", "inventory", "parking_lot", @@ -3584,12 +3565,6 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" -[[package]] -name = "constant_time_eq" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" - [[package]] name = "context_server" version = "0.1.0" @@ -3871,7 +3846,7 @@ dependencies = [ "rustc-hash 1.1.0", "rustybuzz 0.14.1", "self_cell", - "smol_str 0.2.2", + "smol_str", "swash", "sys-locale", "ttf-parser 0.21.1", @@ -3893,7 +3868,7 @@ dependencies = [ "jni", "js-sys", "libc", - "mach2", + "mach2 0.4.2", "ndk", "ndk-context", "num-derive", @@ -4043,7 +4018,7 @@ checksum = "031ed29858d90cfdf27fe49fae28028a1f20466db97962fa2f4ea34809aeebf3" dependencies = [ "cfg-if", "libc", - "mach2", + "mach2 0.4.2", ] [[package]] @@ -4055,7 +4030,7 @@ dependencies = [ "cfg-if", "crash-context", "libc", - "mach2", + "mach2 0.4.2", "parking_lot", ] @@ -4063,13 +4038,19 @@ dependencies = [ name = "crashes" version = "0.1.0" dependencies = [ + "bincode", "crash-handler", "log", + "mach2 0.5.0", "minidumper", "paths", "release_channel", + "serde", + "serde_json", "smol", + "system_specs", "workspace-hack", + "zstd", ] [[package]] @@ -4164,6 +4145,19 @@ dependencies = [ "itertools 0.10.5", ] +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", 
+ "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + [[package]] name = "crossbeam-channel" version = "0.5.15" @@ -4494,6 +4488,7 @@ dependencies = [ "tempfile", "util", "workspace-hack", + "zed_env_vars", ] [[package]] @@ -4670,27 +4665,6 @@ dependencies = [ "syn 2.0.101", ] -[[package]] -name = "derive_more" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" -dependencies = [ - "derive_more-impl", -] - -[[package]] -name = "derive_more-impl" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.101", - "unicode-xid", -] - [[package]] name = "derive_refineable" version = "0.1.0" @@ -4711,7 +4685,6 @@ dependencies = [ "component", "ctor", "editor", - "futures 0.3.31", "gpui", "indoc", "language", @@ -4720,7 +4693,7 @@ dependencies = [ "markdown", "pretty_assertions", "project", - "rand 0.8.5", + "rand 0.9.1", "serde", "serde_json", "settings", @@ -4760,7 +4733,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b545b8c50194bdd008283985ab0b31dba153cfd5b3066a92770634fbc0d7d291" dependencies = [ - "nu-ansi-term 0.50.1", + "nu-ansi-term", ] [[package]] @@ -5065,6 +5038,7 @@ dependencies = [ "clock", "collections", "convert_case 0.8.0", + "criterion", "ctor", "dap", "db", @@ -5091,7 +5065,7 @@ dependencies = [ "parking_lot", "pretty_assertions", "project", - "rand 0.8.5", + "rand 0.9.1", "regex", "release_channel", "rpc", @@ -5586,7 +5560,7 @@ dependencies = [ "parking_lot", "paths", "project", - "rand 0.8.5", + "rand 0.9.1", "release_channel", "remote", "reqwest_client", @@ -5659,8 +5633,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" 
dependencies = [ "bit-set 0.5.3", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", + "regex-automata", + "regex-syntax", ] [[package]] @@ -5670,8 +5644,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e24cb5a94bcae1e5408b0effca5cd7172ea3c5755049c5f3af4cd283a165298" dependencies = [ "bit-set 0.8.0", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", + "regex-automata", + "regex-syntax", ] [[package]] @@ -5748,14 +5722,10 @@ dependencies = [ name = "feedback" version = "0.1.0" dependencies = [ - "client", "editor", "gpui", - "human_bytes", "menu", - "release_channel", - "serde", - "sysinfo", + "system_specs", "ui", "urlencoding", "util", @@ -6190,17 +6160,6 @@ dependencies = [ "futures-util", ] -[[package]] -name = "futures-batch" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f444c45a1cb86f2a7e301469fd50a82084a60dadc25d94529a8312276ecb71a" -dependencies = [ - "futures 0.3.31", - "futures-timer", - "pin-utils", -] - [[package]] name = "futures-channel" version = "0.3.31" @@ -6296,12 +6255,6 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" -[[package]] -name = "futures-timer" -version = "3.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" - [[package]] name = "futures-util" version = "0.3.31" @@ -6375,17 +6328,6 @@ dependencies = [ "windows-targets 0.48.5", ] -[[package]] -name = "getrandom" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" -dependencies = [ - "cfg-if", - "libc", - "wasi 0.9.0+wasi-snapshot-preview1", -] - [[package]] name = "getrandom" version = "0.2.15" @@ -6442,7 +6384,7 @@ dependencies = [ "askpass", "async-trait", "collections", - 
"derive_more 0.99.19", + "derive_more", "futures 0.3.31", "git2", "gpui", @@ -6450,7 +6392,7 @@ dependencies = [ "log", "parking_lot", "pretty_assertions", - "rand 0.8.5", + "rand 0.9.1", "regex", "rope", "schemars", @@ -7336,8 +7278,8 @@ dependencies = [ "aho-corasick", "bstr", "log", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", + "regex-automata", + "regex-syntax", ] [[package]] @@ -7472,7 +7414,7 @@ dependencies = [ "core-video", "cosmic-text", "ctor", - "derive_more 0.99.19", + "derive_more", "embed-resource", "env_logger 0.11.8", "etagere", @@ -7503,7 +7445,7 @@ dependencies = [ "pathfinder_geometry", "postage", "profiling", - "rand 0.8.5", + "rand 0.9.1", "raw-window-handle", "refineable", "reqwest_client", @@ -7518,6 +7460,7 @@ dependencies = [ "slotmap", "smallvec", "smol", + "stacksafe", "strum 0.27.1", "sum_tree", "taffy", @@ -7559,6 +7502,7 @@ dependencies = [ name = "gpui_tokio" version = "0.1.0" dependencies = [ + "anyhow", "gpui", "tokio", "util", @@ -7882,6 +7826,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "hound" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" + [[package]] name = "html5ever" version = "0.27.0" @@ -7983,34 +7933,13 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f" -[[package]] -name = "http-types" -version = "2.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad" -dependencies = [ - "anyhow", - "async-channel 1.9.0", - "base64 0.13.1", - "futures-lite 1.13.0", - "http 0.2.12", - "infer", - "pin-project-lite", - "rand 0.7.3", - "serde", - "serde_json", - "serde_qs 0.8.5", - "serde_urlencoded", - "url", -] - [[package]] name = "http_client" version = "0.1.0" dependencies = [ "anyhow", 
"bytes 1.10.1", - "derive_more 0.99.19", + "derive_more", "futures 0.3.31", "http 1.3.1", "http-body 1.0.1", @@ -8355,7 +8284,7 @@ dependencies = [ "globset", "log", "memchr", - "regex-automata 0.4.9", + "regex-automata", "same-file", "walkdir", "winapi-util", @@ -8438,38 +8367,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408" [[package]] -name = "indexed_docs" -version = "0.1.0" -dependencies = [ - "anyhow", - "async-trait", - "cargo_metadata", - "collections", - "derive_more 0.99.19", - "extension", - "fs", - "futures 0.3.31", - "fuzzy", - "gpui", - "heed", - "html_to_markdown", - "http_client", - "indexmap", - "indoc", - "parking_lot", - "paths", - "pretty_assertions", - "serde", - "strum 0.27.1", - "util", - "workspace-hack", -] - -[[package]] -name = "indexmap" -version = "2.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" +name = "indexmap" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", "hashbrown 0.15.3", @@ -8482,12 +8383,6 @@ version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" -[[package]] -name = "infer" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac" - [[package]] name = "inherent" version = "1.0.12" @@ -8556,6 +8451,7 @@ dependencies = [ "theme", "ui", "util", + "util_macros", "workspace", "workspace-hack", "zed_actions", @@ -8987,7 +8883,7 @@ dependencies = [ "percent-encoding", "referencing", "regex", - "regex-syntax 0.8.5", + "regex-syntax", "reqwest 0.12.15 
(registry+https://github.com/rust-lang/crates.io-index)", "serde", "serde_json", @@ -9040,6 +8936,44 @@ dependencies = [ "uuid", ] +[[package]] +name = "keymap_editor" +version = "0.1.0" +dependencies = [ + "anyhow", + "collections", + "command_palette", + "component", + "db", + "editor", + "fs", + "fuzzy", + "gpui", + "itertools 0.14.0", + "language", + "log", + "menu", + "notifications", + "paths", + "project", + "search", + "serde", + "serde_json", + "settings", + "telemetry", + "tempfile", + "theme", + "tree-sitter-json", + "tree-sitter-rust", + "ui", + "ui_input", + "util", + "vim", + "workspace", + "workspace-hack", + "zed_actions", +] + [[package]] name = "khronos-egl" version = "6.0.0" @@ -9124,7 +9058,7 @@ dependencies = [ "parking_lot", "postage", "pretty_assertions", - "rand 0.8.5", + "rand 0.9.1", "regex", "rpc", "schemars", @@ -9197,6 +9131,7 @@ dependencies = [ "icons", "image", "log", + "open_router", "parking_lot", "proto", "schemars", @@ -9265,6 +9200,19 @@ dependencies = [ "x_ai", ] +[[package]] +name = "language_onboarding" +version = "0.1.0" +dependencies = [ + "db", + "editor", + "gpui", + "project", + "ui", + "workspace", + "workspace-hack", +] + [[package]] name = "language_selector" version = "0.1.0" @@ -9292,6 +9240,7 @@ dependencies = [ "anyhow", "client", "collections", + "command_palette_hooks", "copilot", "editor", "futures 0.3.31", @@ -9300,6 +9249,7 @@ dependencies = [ "language", "lsp", "project", + "proto", "release_channel", "serde_json", "settings", @@ -9325,7 +9275,6 @@ dependencies = [ "chrono", "collections", "dap", - "feature_flags", "futures 0.3.31", "gpui", "http_client", @@ -9561,6 +9510,21 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "line_ending_selector" +version = "0.1.0" +dependencies = [ + "editor", + "gpui", + "language", + "picker", + "project", + "ui", + "util", + "workspace", + "workspace-hack", +] + [[package]] name = "link-cplusplus" version = "1.0.10" @@ -9692,6 +9656,7 @@ version = "0.1.0" dependencies 
= [ "anyhow", "async-trait", + "audio", "collections", "core-foundation 0.10.0", "core-video", @@ -9710,9 +9675,12 @@ dependencies = [ "objc", "parking_lot", "postage", + "rodio", "scap", "serde", "serde_json", + "serde_urlencoded", + "settings", "sha2", "simplelog", "smallvec", @@ -9785,7 +9753,7 @@ dependencies = [ "lazy_static", "proc-macro2", "quote", - "regex-syntax 0.8.5", + "regex-syntax", "rustc_version", "syn 2.0.101", ] @@ -9857,7 +9825,7 @@ dependencies = [ [[package]] name = "lsp-types" version = "0.95.1" -source = "git+https://github.com/zed-industries/lsp-types?rev=39f629bdd03d59abd786ed9fc27e8bca02c0c0ec#39f629bdd03d59abd786ed9fc27e8bca02c0c0ec" +source = "git+https://github.com/zed-industries/lsp-types?rev=0874f8742fe55b4dc94308c1e3c0069710d8eeaf#0874f8742fe55b4dc94308c1e3c0069710d8eeaf" dependencies = [ "bitflags 1.3.2", "serde", @@ -9943,6 +9911,15 @@ dependencies = [ "libc", ] +[[package]] +name = "mach2" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a1b95cd5421ec55b445b5ae102f5ea0e768de1f82bd3001e11f426c269c3aea" +dependencies = [ + "libc", +] + [[package]] name = "malloc_buf" version = "0.0.6" @@ -9991,9 +9968,11 @@ dependencies = [ "editor", "fs", "gpui", + "html5ever 0.27.0", "language", "linkify", "log", + "markup5ever_rcdom", "pretty_assertions", "pulldown-cmark 0.12.2", "settings", @@ -10054,11 +10033,11 @@ dependencies = [ [[package]] name = "matchers" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" dependencies = [ - "regex-automata 0.1.10", + "regex-automata", ] [[package]] @@ -10263,7 +10242,7 @@ dependencies = [ "num-traits", "range-map", "scroll", - "smart-default 0.7.1", + "smart-default", ] [[package]] @@ -10279,7 +10258,7 @@ dependencies = [ "goblin", "libc", "log", - 
"mach2", + "mach2 0.4.2", "memmap2", "memoffset", "minidump-common", @@ -10422,7 +10401,7 @@ dependencies = [ "parking_lot", "pretty_assertions", "project", - "rand 0.8.5", + "rand 0.9.1", "rope", "serde", "settings", @@ -10759,16 +10738,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "nu-ansi-term" -version = "0.46.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" -dependencies = [ - "overload", - "winapi", -] - [[package]] name = "nu-ansi-term" version = "0.50.1" @@ -11267,6 +11236,8 @@ dependencies = [ "schemars", "serde", "serde_json", + "strum 0.27.1", + "thiserror 2.0.12", "workspace-hack", ] @@ -11462,12 +11433,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "p256" version = "0.11.1" @@ -11690,6 +11655,12 @@ dependencies = [ "hmac", ] +[[package]] +name = "pciid-parser" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0008e816fcdaf229cdd540e9b6ca2dc4a10d65c31624abb546c6420a02846e61" + [[package]] name = "pem" version = "3.0.5" @@ -12656,12 +12627,13 @@ dependencies = [ "postage", "prettier", "pretty_assertions", - "rand 0.8.5", + "rand 0.9.1", "regex", "release_channel", "remote", "rpc", "schemars", + "semver", "serde", "serde_json", "settings", @@ -12681,6 +12653,7 @@ dependencies = [ "unindent", "url", "util", + "watch", "which 6.0.3", "workspace-hack", "worktree", @@ -13137,19 +13110,6 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" -[[package]] -name = "rand" 
-version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" -dependencies = [ - "getrandom 0.1.16", - "libc", - "rand_chacha 0.2.2", - "rand_core 0.5.1", - "rand_hc", -] - [[package]] name = "rand" version = "0.8.5" @@ -13171,16 +13131,6 @@ dependencies = [ "rand_core 0.9.3", ] -[[package]] -name = "rand_chacha" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" -dependencies = [ - "ppv-lite86", - "rand_core 0.5.1", -] - [[package]] name = "rand_chacha" version = "0.3.1" @@ -13201,15 +13151,6 @@ dependencies = [ "rand_core 0.9.3", ] -[[package]] -name = "rand_core" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" -dependencies = [ - "getrandom 0.1.16", -] - [[package]] name = "rand_core" version = "0.6.4" @@ -13228,15 +13169,6 @@ dependencies = [ "getrandom 0.3.2", ] -[[package]] -name = "rand_hc" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" -dependencies = [ - "rand_core 0.5.1", -] - [[package]] name = "range-map" version = "0.2.0" @@ -13493,17 +13425,8 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", + "regex-automata", + "regex-syntax", ] [[package]] @@ -13514,7 +13437,7 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" 
dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.5", + "regex-syntax", ] [[package]] @@ -13523,12 +13446,6 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" -[[package]] -name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - [[package]] name = "regex-syntax" version = "0.8.5" @@ -13631,6 +13548,7 @@ dependencies = [ "smol", "sysinfo", "telemetry_events", + "thiserror 2.0.12", "toml 0.8.20", "unindent", "util", @@ -13845,7 +13763,6 @@ dependencies = [ "regex", "reqwest 0.12.15 (git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415)", "serde", - "smol", "tokio", "workspace-hack", ] @@ -13966,14 +13883,15 @@ dependencies = [ [[package]] name = "rodio" version = "0.21.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e40ecf59e742e03336be6a3d53755e789fd05a059fa22dfa0ed624722319e183" +source = "git+https://github.com/RustAudio/rodio?branch=better_wav_output#82514bd1f2c6cfd9a1a885019b26a8ffea75bc5c" dependencies = [ "cpal", "dasp_sample", + "hound", "num-rational", + "rtrb", "symphonia", - "tracing", + "thiserror 2.0.12", ] [[package]] @@ -13985,7 +13903,7 @@ dependencies = [ "ctor", "gpui", "log", - "rand 0.8.5", + "rand 0.9.1", "rayon", "smallvec", "sum_tree", @@ -14014,7 +13932,7 @@ dependencies = [ "gpui", "parking_lot", "proto", - "rand 0.8.5", + "rand 0.9.1", "rsa", "serde", "serde_json", @@ -14047,6 +13965,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rtrb" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad8388ea1a9e0ea807e442e8263a699e7edcb320ecbcd21b4fa8ff859acce3ba" + [[package]] name = "rules_library" version = "0.1.0" @@ -14449,6 +14373,19 @@ dependencies = [ 
"windows-sys 0.59.0", ] +[[package]] +name = "scheduler" +version = "0.1.0" +dependencies = [ + "async-task", + "backtrace", + "chrono", + "futures 0.3.31", + "parking_lot", + "rand 0.9.1", + "workspace-hack", +] + [[package]] name = "schema_generator" version = "0.1.0" @@ -14469,12 +14406,10 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe8c9d1c68d67dd9f97ecbc6f932b60eb289c5dbddd8aa1405484a8fd2fcd984" dependencies = [ - "chrono", "dyn-clone", "indexmap", "ref-cast", "schemars_derive", - "semver", "serde", "serde_json", ] @@ -14753,49 +14688,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0f7d95a54511e0c7be3f51e8867aa8cf35148d7b9445d44de2f943e2b206e749" -[[package]] -name = "semantic_index" -version = "0.1.0" -dependencies = [ - "anyhow", - "arrayvec", - "blake3", - "client", - "clock", - "collections", - "feature_flags", - "fs", - "futures 0.3.31", - "futures-batch", - "gpui", - "heed", - "http_client", - "language", - "language_model", - "languages", - "log", - "open_ai", - "parking_lot", - "project", - "reqwest_client", - "serde", - "serde_json", - "settings", - "sha2", - "smol", - "streaming-iterator", - "tempfile", - "theme", - "tree-sitter", - "ui", - "unindent", - "util", - "workspace", - "workspace-hack", - "worktree", - "zlog", -] - [[package]] name = "semantic_version" version = "0.1.0" @@ -14890,28 +14782,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_qs" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7715380eec75f029a4ef7de39a9200e0a63823176b759d055b613f5a87df6a6" -dependencies = [ - "percent-encoding", - "serde", - "thiserror 1.0.69", -] - -[[package]] -name = "serde_qs" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cac3f1e2ca2fe333923a1ae72caca910b98ed0630bb35ef6f8c8517d6e81afa" -dependencies = [ - "percent-encoding", - "serde", - 
"thiserror 1.0.69", -] - [[package]] name = "serde_repr" version = "0.1.20" @@ -14989,6 +14859,8 @@ dependencies = [ "serde_derive", "serde_json", "serde_json_lenient", + "serde_path_to_error", + "settings_ui_macros", "smallvec", "tree-sitter", "tree-sitter-json", @@ -15024,39 +14896,33 @@ name = "settings_ui" version = "0.1.0" dependencies = [ "anyhow", - "collections", - "command_palette", "command_palette_hooks", - "component", - "db", + "debugger_ui", "editor", "feature_flags", - "fs", - "fuzzy", "gpui", - "itertools 0.14.0", - "language", - "log", "menu", - "notifications", - "paths", - "project", - "search", "serde", "serde_json", "settings", - "telemetry", - "tempfile", + "smallvec", "theme", - "tree-sitter-json", - "tree-sitter-rust", "ui", - "ui_input", - "util", "workspace", "workspace-hack", ] +[[package]] +name = "settings_ui_macros" +version = "0.1.0" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn 2.0.101", + "workspace-hack", +] + [[package]] name = "sha1" version = "0.10.6" @@ -15286,17 +15152,6 @@ dependencies = [ "serde", ] -[[package]] -name = "smart-default" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "133659a15339456eeeb07572eb02a91c91e9815e9cbc89566944d2c8d3efdbf6" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "smart-default" version = "0.7.1" @@ -15325,15 +15180,6 @@ dependencies = [ "futures-lite 2.6.0", ] -[[package]] -name = "smol_str" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fad6c857cbab2627dcf01ec85a623ca4e7dcb5691cbaa3d7fb7653671f0d09c9" -dependencies = [ - "serde", -] - [[package]] name = "smol_str" version = "0.2.2" @@ -15460,6 +15306,7 @@ dependencies = [ "futures 0.3.31", "indoc", "libsqlite3-sys", + "log", "parking_lot", "smol", "sqlformat", @@ -15705,6 +15552,40 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stacker" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys 0.59.0", +] + +[[package]] +name = "stacksafe" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d9c1172965d317e87ddb6d364a040d958b40a1db82b6ef97da26253a8b3d090" +dependencies = [ + "stacker", + "stacksafe-macro", +] + +[[package]] +name = "stacksafe-macro" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "172175341049678163e979d9107ca3508046d4d2a7c6682bee46ac541b17db69" +dependencies = [ + "proc-macro-error2", + "quote", + "syn 2.0.101", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -15763,7 +15644,7 @@ name = "streaming_diff" version = "0.1.0" dependencies = [ "ordered-float 2.10.1", - "rand 0.8.5", + "rand 0.9.1", "rope", "util", "workspace-hack", @@ -15877,7 +15758,7 @@ dependencies = [ "arrayvec", "ctor", "log", - "rand 0.8.5", + "rand 0.9.1", "rayon", "workspace-hack", "zlog", @@ -16019,9 +15900,11 @@ dependencies = [ "editor", "file_icons", "gpui", + "project", "ui", "workspace", "workspace-hack", + "worktree", ] [[package]] @@ -16052,12 +15935,53 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "815c942ae7ee74737bb00f965fa5b5a2ac2ce7b6c01c0cc169bbeaf7abd5f5a9" dependencies = [ "lazy_static", + "symphonia-bundle-flac", + "symphonia-bundle-mp3", + "symphonia-codec-aac", "symphonia-codec-pcm", + "symphonia-codec-vorbis", "symphonia-core", + "symphonia-format-isomp4", + "symphonia-format-ogg", "symphonia-format-riff", "symphonia-metadata", ] +[[package]] +name = "symphonia-bundle-flac" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"72e34f34298a7308d4397a6c7fbf5b84c5d491231ce3dd379707ba673ab3bd97" +dependencies = [ + "log", + "symphonia-core", + "symphonia-metadata", + "symphonia-utils-xiph", +] + +[[package]] +name = "symphonia-bundle-mp3" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c01c2aae70f0f1fb096b6f0ff112a930b1fb3626178fba3ae68b09dce71706d4" +dependencies = [ + "lazy_static", + "log", + "symphonia-core", + "symphonia-metadata", +] + +[[package]] +name = "symphonia-codec-aac" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdbf25b545ad0d3ee3e891ea643ad115aff4ca92f6aec472086b957a58522f70" +dependencies = [ + "lazy_static", + "log", + "symphonia-core", +] + [[package]] name = "symphonia-codec-pcm" version = "0.5.4" @@ -16068,6 +15992,17 @@ dependencies = [ "symphonia-core", ] +[[package]] +name = "symphonia-codec-vorbis" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a98765fb46a0a6732b007f7e2870c2129b6f78d87db7987e6533c8f164a9f30" +dependencies = [ + "log", + "symphonia-core", + "symphonia-utils-xiph", +] + [[package]] name = "symphonia-core" version = "0.5.4" @@ -16081,6 +16016,31 @@ dependencies = [ "log", ] +[[package]] +name = "symphonia-format-isomp4" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abfdf178d697e50ce1e5d9b982ba1b94c47218e03ec35022d9f0e071a16dc844" +dependencies = [ + "encoding_rs", + "log", + "symphonia-core", + "symphonia-metadata", + "symphonia-utils-xiph", +] + +[[package]] +name = "symphonia-format-ogg" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ada3505789516bcf00fc1157c67729eded428b455c27ca370e41f4d785bfa931" +dependencies = [ + "log", + "symphonia-core", + "symphonia-metadata", + "symphonia-utils-xiph", +] + [[package]] name = "symphonia-format-riff" version = "0.5.4" @@ -16105,6 +16065,16 @@ dependencies 
= [ "symphonia-core", ] +[[package]] +name = "symphonia-utils-xiph" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "484472580fa49991afda5f6550ece662237b00c6f562c7d9638d1b086ed010fe" +dependencies = [ + "symphonia-core", + "symphonia-metadata", +] + [[package]] name = "syn" version = "1.0.109" @@ -16256,6 +16226,21 @@ dependencies = [ "winx", ] +[[package]] +name = "system_specs" +version = "0.1.0" +dependencies = [ + "anyhow", + "client", + "gpui", + "human_bytes", + "pciid-parser", + "release_channel", + "serde", + "sysinfo", + "workspace-hack", +] + [[package]] name = "tab_switcher" version = "0.1.0" @@ -16453,7 +16438,7 @@ dependencies = [ "futures 0.3.31", "gpui", "libc", - "rand 0.8.5", + "rand 0.9.1", "regex", "release_channel", "schemars", @@ -16501,7 +16486,7 @@ dependencies = [ "language", "log", "project", - "rand 0.8.5", + "rand 0.9.1", "regex", "schemars", "search", @@ -16533,7 +16518,7 @@ dependencies = [ "log", "parking_lot", "postage", - "rand 0.8.5", + "rand 0.9.1", "regex", "rope", "smallvec", @@ -16549,7 +16534,7 @@ version = "0.1.0" dependencies = [ "anyhow", "collections", - "derive_more 0.99.19", + "derive_more", "fs", "futures 0.3.31", "gpui", @@ -16849,7 +16834,6 @@ dependencies = [ "schemars", "serde", "settings", - "settings_ui", "smallvec", "story", "telemetry", @@ -17064,10 +17048,15 @@ checksum = "bfb942dfe1d8e29a7ee7fcbde5bd2b9a25fb89aa70caea2eba3bee836ff41076" name = "toolchain_selector" version = "0.1.0" dependencies = [ + "anyhow", + "convert_case 0.8.0", "editor", + "file_finder", + "futures 0.3.31", "fuzzy", "gpui", "language", + "menu", "picker", "project", "ui", @@ -17218,14 +17207,14 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = 
"2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "matchers", - "nu-ansi-term 0.46.0", + "nu-ansi-term", "once_cell", - "regex", + "regex-automata", "serde", "serde_json", "sharded-slab", @@ -17256,7 +17245,7 @@ checksum = "a7cf18d43cbf0bfca51f657132cc616a5097edc4424d538bae6fa60142eaf9f0" dependencies = [ "cc", "regex", - "regex-syntax 0.8.5", + "regex-syntax", "serde_json", "streaming-iterator", "tree-sitter-language", @@ -17286,8 +17275,7 @@ dependencies = [ [[package]] name = "tree-sitter-cpp" version = "0.23.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df2196ea9d47b4ab4a31b9297eaa5a5d19a0b121dceb9f118f6790ad0ab94743" +source = "git+https://github.com/tree-sitter/tree-sitter-cpp?rev=5cb9b693cfd7bfacab1d9ff4acac1a4150700609#5cb9b693cfd7bfacab1d9ff4acac1a4150700609" dependencies = [ "cc", "tree-sitter-language", @@ -17891,7 +17879,7 @@ dependencies = [ "libc", "log", "nix 0.29.0", - "rand 0.8.5", + "rand 0.9.1", "regex", "rust-embed", "schemars", @@ -18076,6 +18064,8 @@ version = "0.1.0" dependencies = [ "anyhow", "gpui", + "schemars", + "serde", "settings", "workspace-hack", ] @@ -18182,12 +18172,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "wasi" -version = "0.9.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -18421,7 +18405,7 @@ dependencies = [ "indexmap", "libc", "log", - "mach2", + "mach2 0.4.2", "memfd", "object", "once_cell", @@ -18688,7 +18672,7 @@ dependencies = [ "futures 0.3.31", "gpui", "parking_lot", - "rand 0.8.5", + "rand 0.9.1", "workspace-hack", "zlog", ] @@ -19896,7 +19880,6 @@ dependencies = [ "any_vec", "anyhow", "async-recursion", - "bincode", "call", "client", "clock", @@ -19915,6 +19898,7 @@ dependencies = [ "node_runtime", "parking_lot", "postage", + 
"pretty_assertions", "project", "remote", "schemars", @@ -19980,6 +19964,7 @@ dependencies = [ "core-foundation-sys", "cranelift-codegen", "crc32fast", + "crossbeam-channel", "crossbeam-epoch", "crossbeam-utils", "crypto-common", @@ -20023,6 +20008,7 @@ dependencies = [ "libsqlite3-sys", "linux-raw-sys 0.4.15", "linux-raw-sys 0.9.4", + "livekit-runtime", "log", "lyon", "lyon_path", @@ -20061,8 +20047,8 @@ dependencies = [ "rand_core 0.6.4", "regalloc2", "regex", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", + "regex-automata", + "regex-syntax", "ring", "rust_decimal", "rustc-hash 1.1.0", @@ -20070,7 +20056,6 @@ dependencies = [ "rustix 1.0.7", "rustls 0.23.26", "rustls-webpki 0.103.1", - "schemars", "scopeguard", "sea-orm", "sea-query-binder", @@ -20148,7 +20133,7 @@ dependencies = [ "paths", "postage", "pretty_assertions", - "rand 0.8.5", + "rand 0.9.1", "rpc", "schemars", "serde", @@ -20245,9 +20230,9 @@ dependencies = [ [[package]] name = "xcb" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1e2f212bb1a92cd8caac8051b829a6582ede155ccb60b5d5908b81b100952be" +checksum = "f07c123b796139bfe0603e654eaf08e132e52387ba95b252c78bad3640ba37ea" dependencies = [ "bitflags 1.3.2", "libc", @@ -20274,7 +20259,7 @@ dependencies = [ [[package]] name = "xim" version = "0.4.0" -source = "git+https://github.com/XDeme1/xim-rs?rev=d50d461764c2213655cd9cf65a0ea94c70d3c4fd#d50d461764c2213655cd9cf65a0ea94c70d3c4fd" +source = "git+https://github.com/zed-industries/xim-rs?rev=c0a70c1bd2ce197364216e5e818a2cb3adb99a8d#c0a70c1bd2ce197364216e5e818a2cb3adb99a8d" dependencies = [ "ahash 0.8.11", "hashbrown 0.14.5", @@ -20287,7 +20272,7 @@ dependencies = [ [[package]] name = "xim-ctext" version = "0.3.0" -source = "git+https://github.com/XDeme1/xim-rs?rev=d50d461764c2213655cd9cf65a0ea94c70d3c4fd#d50d461764c2213655cd9cf65a0ea94c70d3c4fd" +source = 
"git+https://github.com/zed-industries/xim-rs?rev=c0a70c1bd2ce197364216e5e818a2cb3adb99a8d#c0a70c1bd2ce197364216e5e818a2cb3adb99a8d" dependencies = [ "encoding_rs", ] @@ -20295,7 +20280,7 @@ dependencies = [ [[package]] name = "xim-parser" version = "0.2.1" -source = "git+https://github.com/XDeme1/xim-rs?rev=d50d461764c2213655cd9cf65a0ea94c70d3c4fd#d50d461764c2213655cd9cf65a0ea94c70d3c4fd" +source = "git+https://github.com/zed-industries/xim-rs?rev=c0a70c1bd2ce197364216e5e818a2cb3adb99a8d#c0a70c1bd2ce197364216e5e818a2cb3adb99a8d" dependencies = [ "bitflags 2.9.0", ] @@ -20371,8 +20356,9 @@ checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" [[package]] name = "yawc" -version = "0.2.4" -source = "git+https://github.com/deviant-forks/yawc?rev=1899688f3e69ace4545aceb97b2a13881cf26142#1899688f3e69ace4545aceb97b2a13881cf26142" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19a5d82922135b4ae73a079a4ffb5501e9aadb4d785b8c660eaa0a8b899028c5" dependencies = [ "base64 0.22.1", "bytes 1.10.1", @@ -20503,11 +20489,11 @@ dependencies = [ [[package]] name = "zed" -version = "0.201.0" +version = "0.205.0" dependencies = [ + "acp_tools", "activity_indicator", "agent", - "agent_servers", "agent_settings", "agent_ui", "anyhow", @@ -20520,6 +20506,7 @@ dependencies = [ "auto_update", "auto_update_ui", "backtrace", + "bincode", "breadcrumbs", "call", "channel", @@ -20565,14 +20552,17 @@ dependencies = [ "itertools 0.14.0", "jj_ui", "journal", + "keymap_editor", "language", "language_extension", "language_model", "language_models", + "language_onboarding", "language_selector", "language_tools", "languages", "libc", + "line_ending_selector", "log", "markdown", "markdown_preview", @@ -20617,6 +20607,7 @@ dependencies = [ "supermaven", "svg_preview", "sysinfo", + "system_specs", "tab_switcher", "task", "tasks_ui", @@ -20648,6 +20639,7 @@ dependencies = [ "workspace", "workspace-hack", "zed_actions", + 
"zed_env_vars", "zeta", "zlog", "zlog_settings", @@ -20665,10 +20657,10 @@ dependencies = [ ] [[package]] -name = "zed_emmet" -version = "0.0.6" +name = "zed_env_vars" +version = "0.1.0" dependencies = [ - "zed_extension_api 0.1.0", + "workspace-hack", ] [[package]] @@ -20700,7 +20692,7 @@ dependencies = [ [[package]] name = "zed_html" -version = "0.2.1" +version = "0.2.2" dependencies = [ "zed_extension_api 0.1.0", ] @@ -20721,7 +20713,7 @@ dependencies = [ [[package]] name = "zed_snippets" -version = "0.0.5" +version = "0.0.6" dependencies = [ "serde_json", "zed_extension_api 0.1.0", @@ -20734,13 +20726,6 @@ dependencies = [ "zed_extension_api 0.6.0", ] -[[package]] -name = "zed_toml" -version = "0.1.4" -dependencies = [ - "zed_extension_api 0.1.0", -] - [[package]] name = "zeno" version = "0.3.2" @@ -20899,13 +20884,15 @@ dependencies = [ "gpui", "http_client", "indoc", + "itertools 0.14.0", "language", "language_model", "log", "menu", + "parking_lot", "postage", "project", - "rand 0.8.5", + "rand 0.9.1", "regex", "release_channel", "reqwest_client", @@ -20913,6 +20900,7 @@ dependencies = [ "serde", "serde_json", "settings", + "strum 0.27.1", "telemetry", "telemetry_events", "theme", @@ -20920,7 +20908,6 @@ dependencies = [ "tree-sitter-go", "tree-sitter-rust", "ui", - "unindent", "util", "uuid", "workspace", @@ -20975,7 +20962,7 @@ dependencies = [ "aes", "byteorder", "bzip2", - "constant_time_eq 0.1.5", + "constant_time_eq", "crc32fast", "crossbeam-utils", "flate2", diff --git a/Cargo.toml b/Cargo.toml index 1baa6d3d7497934b13a368eec5bad9c3c09445d4..e1eca763746e59e4d4ef206dafc3e6b6f3c67190 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] resolver = "2" members = [ + "crates/acp_tools", "crates/acp_thread", "crates/action_log", "crates/activity_indicator", @@ -53,6 +54,8 @@ members = [ "crates/deepseek", "crates/diagnostics", "crates/docs_preprocessor", + "crates/edit_prediction", + "crates/edit_prediction_button", "crates/editor", 
"crates/eval", "crates/explorer_command_injector", @@ -81,21 +84,21 @@ members = [ "crates/http_client_tls", "crates/icons", "crates/image_viewer", - "crates/indexed_docs", - "crates/edit_prediction", - "crates/edit_prediction_button", "crates/inspector_ui", "crates/install_cli", "crates/jj", "crates/jj_ui", "crates/journal", + "crates/keymap_editor", "crates/language", "crates/language_extension", "crates/language_model", "crates/language_models", + "crates/language_onboarding", "crates/language_selector", "crates/language_tools", "crates/languages", + "crates/line_ending_selector", "crates/livekit_api", "crates/livekit_client", "crates/lmstudio", @@ -130,6 +133,7 @@ members = [ "crates/refineable", "crates/refineable/derive_refineable", "crates/release_channel", + "crates/scheduler", "crates/remote", "crates/remote_server", "crates/repl", @@ -140,12 +144,12 @@ members = [ "crates/rules_library", "crates/schema_generator", "crates/search", - "crates/semantic_index", "crates/semantic_version", "crates/session", "crates/settings", "crates/settings_profile_selector", "crates/settings_ui", + "crates/settings_ui_macros", "crates/snippet", "crates/snippet_provider", "crates/snippets_ui", @@ -158,6 +162,7 @@ members = [ "crates/supermaven", "crates/supermaven_api", "crates/svg_preview", + "crates/system_specs", "crates/tab_switcher", "crates/task", "crates/tasks_ui", @@ -190,6 +195,7 @@ members = [ "crates/x_ai", "crates/zed", "crates/zed_actions", + "crates/zed_env_vars", "crates/zeta", "crates/zeta_cli", "crates/zlog", @@ -199,7 +205,6 @@ members = [ # Extensions # - "extensions/emmet", "extensions/glsl", "extensions/html", "extensions/proto", @@ -207,7 +212,6 @@ members = [ "extensions/slash-commands-example", "extensions/snippets", "extensions/test-extension", - "extensions/toml", # # Tooling @@ -228,6 +232,7 @@ edition = "2024" # Workspace member crates # +acp_tools = { path = "crates/acp_tools" } acp_thread = { path = "crates/acp_thread" } action_log = { path = 
"crates/action_log" } agent = { path = "crates/agent" } @@ -272,6 +277,7 @@ context_server = { path = "crates/context_server" } copilot = { path = "crates/copilot" } crashes = { path = "crates/crashes" } credentials_provider = { path = "crates/credentials_provider" } +crossbeam = "0.8.4" dap = { path = "crates/dap" } dap_adapters = { path = "crates/dap_adapters" } db = { path = "crates/db" } @@ -296,9 +302,7 @@ git_hosting_providers = { path = "crates/git_hosting_providers" } git_ui = { path = "crates/git_ui" } go_to_line = { path = "crates/go_to_line" } google_ai = { path = "crates/google_ai" } -gpui = { path = "crates/gpui", default-features = false, features = [ - "http_client", -] } +gpui = { path = "crates/gpui", default-features = false } gpui_macros = { path = "crates/gpui_macros" } gpui_tokio = { path = "crates/gpui_tokio" } html_to_markdown = { path = "crates/html_to_markdown" } @@ -306,7 +310,6 @@ http_client = { path = "crates/http_client" } http_client_tls = { path = "crates/http_client_tls" } icons = { path = "crates/icons" } image_viewer = { path = "crates/image_viewer" } -indexed_docs = { path = "crates/indexed_docs" } edit_prediction = { path = "crates/edit_prediction" } edit_prediction_button = { path = "crates/edit_prediction_button" } inspector_ui = { path = "crates/inspector_ui" } @@ -314,13 +317,16 @@ install_cli = { path = "crates/install_cli" } jj = { path = "crates/jj" } jj_ui = { path = "crates/jj_ui" } journal = { path = "crates/journal" } +keymap_editor = { path = "crates/keymap_editor" } language = { path = "crates/language" } language_extension = { path = "crates/language_extension" } language_model = { path = "crates/language_model" } language_models = { path = "crates/language_models" } +language_onboarding = { path = "crates/language_onboarding" } language_selector = { path = "crates/language_selector" } language_tools = { path = "crates/language_tools" } languages = { path = "crates/languages" } +line_ending_selector = { path = 
"crates/line_ending_selector" } livekit_api = { path = "crates/livekit_api" } livekit_client = { path = "crates/livekit_client" } lmstudio = { path = "crates/lmstudio" } @@ -358,20 +364,22 @@ proto = { path = "crates/proto" } recent_projects = { path = "crates/recent_projects" } refineable = { path = "crates/refineable" } release_channel = { path = "crates/release_channel" } +scheduler = { path = "crates/scheduler" } remote = { path = "crates/remote" } remote_server = { path = "crates/remote_server" } repl = { path = "crates/repl" } reqwest_client = { path = "crates/reqwest_client" } rich_text = { path = "crates/rich_text" } +rodio = { git = "https://github.com/RustAudio/rodio", branch = "better_wav_output"} rope = { path = "crates/rope" } rpc = { path = "crates/rpc" } rules_library = { path = "crates/rules_library" } search = { path = "crates/search" } -semantic_index = { path = "crates/semantic_index" } semantic_version = { path = "crates/semantic_version" } session = { path = "crates/session" } settings = { path = "crates/settings" } settings_ui = { path = "crates/settings_ui" } +settings_ui_macros = { path = "crates/settings_ui_macros" } snippet = { path = "crates/snippet" } snippet_provider = { path = "crates/snippet_provider" } snippets_ui = { path = "crates/snippets_ui" } @@ -383,6 +391,7 @@ streaming_diff = { path = "crates/streaming_diff" } sum_tree = { path = "crates/sum_tree" } supermaven = { path = "crates/supermaven" } supermaven_api = { path = "crates/supermaven_api" } +system_specs = { path = "crates/system_specs" } tab_switcher = { path = "crates/tab_switcher" } task = { path = "crates/task" } tasks_ui = { path = "crates/tasks_ui" } @@ -416,6 +425,7 @@ worktree = { path = "crates/worktree" } x_ai = { path = "crates/x_ai" } zed = { path = "crates/zed" } zed_actions = { path = "crates/zed_actions" } +zed_env_vars = { path = "crates/zed_env_vars" } zeta = { path = "crates/zeta" } zlog = { path = "crates/zlog" } zlog_settings = { path = 
"crates/zlog_settings" } @@ -424,8 +434,7 @@ zlog_settings = { path = "crates/zlog_settings" } # External crates # -agentic-coding-protocol = "0.0.10" -agent-client-protocol = "0.0.23" +agent-client-protocol = { version = "0.2.0-alpha.8", features = ["unstable"] } aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" @@ -439,6 +448,7 @@ async-fs = "2.1" async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553" } async-recursion = "1.0.0" async-tar = "0.5.0" +async-task = "4.7" async-trait = "0.1" async-tungstenite = "0.29.1" async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] } @@ -451,11 +461,13 @@ aws-sdk-bedrockruntime = { version = "1.80.0", features = [ ] } aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] } aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] } +backtrace = "0.3" base64 = "0.22" +bincode = "1.2.1" bitflags = "2.6.0" -blade-graphics = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" } -blade-macros = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" } -blade-util = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" } +blade-graphics = { git = "https://github.com/kvark/blade", rev = "bfa594ea697d4b6326ea29f747525c85ecf933b9" } +blade-macros = { git = "https://github.com/kvark/blade", rev = "bfa594ea697d4b6326ea29f747525c85ecf933b9" } +blade-util = { git = "https://github.com/kvark/blade", rev = "bfa594ea697d4b6326ea29f747525c85ecf933b9" } blake3 = "1.5.3" bytes = "1.0" cargo_metadata = "0.19" @@ -495,6 +507,7 @@ handlebars = "4.3" heck = "0.5" heed = { version = "0.21.0", features = ["read-txn-no-tls"] } hex = "0.4.3" +human_bytes = "0.4.1" html5ever = "0.27.0" http = "1.1" http-body = "1.0" @@ -516,7 
+529,8 @@ libc = "0.2" libsqlite3-sys = { version = "0.30.1", features = ["bundled"] } linkify = "0.10.0" log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] } -lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "39f629bdd03d59abd786ed9fc27e8bca02c0c0ec" } +lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "0874f8742fe55b4dc94308c1e3c0069710d8eeaf" } +mach2 = "0.5" markup5ever_rcdom = "0.3.0" metal = "0.29" minidumper = "0.8" @@ -527,12 +541,38 @@ nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c80421 nix = "0.29" num-format = "0.4.4" objc = "0.2" +objc2-foundation = { version = "0.3", default-features = false, features = [ + "NSArray", + "NSAttributedString", + "NSBundle", + "NSCoder", + "NSData", + "NSDate", + "NSDictionary", + "NSEnumerator", + "NSError", + "NSGeometry", + "NSNotification", + "NSNull", + "NSObjCRuntime", + "NSObject", + "NSProcessInfo", + "NSRange", + "NSRunLoop", + "NSString", + "NSURL", + "NSUndoManager", + "NSValue", + "objc2-core-foundation", + "std" +] } open = "5.0.0" ordered-float = "2.1.1" palette = { version = "0.7.5", default-features = false, features = ["std"] } parking_lot = "0.12.1" partial-json-fixer = "0.5.3" parse_int = "0.9" +pciid-parser = "0.8.0" pathdiff = "0.2" pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" } pet-conda = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" } @@ -551,7 +591,7 @@ prost-build = "0.9" prost-types = "0.9" pulldown-cmark = { version = "0.12.0", default-features = false } quote = "1.0.9" -rand = "0.8.5" +rand = "0.9" rayon = "1.8" ref-cast = "1.0.24" regex = "1.5" @@ -564,7 +604,6 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77 "socks", "stream", ] } -rodio = { version = "0.21.1", default-features = false } rsa = "0.9.6" runtimelib = { 
git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [ "async-dispatcher-runtime", @@ -584,7 +623,9 @@ serde_json_lenient = { version = "0.2", features = [ "preserve_order", "raw_value", ] } +serde_path_to_error = "0.1.17" serde_repr = "0.1" +serde_urlencoded = "0.7" sha2 = "0.10" shellexpand = "2.1.0" shlex = "1.3.0" @@ -592,6 +633,7 @@ simplelog = "0.12.2" smallvec = { version = "1.6", features = ["union"] } smol = "2.0" sqlformat = "0.2" +stacksafe = "0.1" streaming-iterator = "0.1" strsim = "0.11" strum = { version = "0.27.0", features = ["derive"] } @@ -618,7 +660,7 @@ tower-http = "0.4.4" tree-sitter = { version = "0.25.6", features = ["wasm"] } tree-sitter-bash = "0.25.0" tree-sitter-c = "0.23" -tree-sitter-cpp = "0.23" +tree-sitter-cpp = { git = "https://github.com/tree-sitter/tree-sitter-cpp", rev = "5cb9b693cfd7bfacab1d9ff4acac1a4150700609" } tree-sitter-css = "0.23" tree-sitter-diff = "0.1.0" tree-sitter-elixir = "0.3" @@ -662,25 +704,9 @@ which = "6.0.0" windows-core = "0.61" wit-component = "0.221" workspace-hack = "0.1.0" -# We can switch back to the published version once https://github.com/infinitefield/yawc/pull/16 is merged and a new -# version is released. -yawc = { git = "https://github.com/deviant-forks/yawc", rev = "1899688f3e69ace4545aceb97b2a13881cf26142" } +yawc = "0.2.5" zstd = "0.11" -[workspace.dependencies.async-stripe] -git = "https://github.com/zed-industries/async-stripe" -rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735" -default-features = false -features = [ - "runtime-tokio-hyper-rustls", - "billing", - "checkout", - "events", - # The features below are only enabled to get the `events` feature to build. 
- "chrono", - "connect", -] - [workspace.dependencies.windows] version = "0.61" features = [ @@ -701,6 +727,7 @@ features = [ "Win32_Graphics_Dxgi_Common", "Win32_Graphics_Gdi", "Win32_Graphics_Imaging", + "Win32_Graphics_Hlsl", "Win32_Networking_WinSock", "Win32_Security", "Win32_Security_Credentials", @@ -818,39 +845,33 @@ unexpected_cfgs = { level = "allow" } dbg_macro = "deny" todo = "deny" -# Motivation: We use `vec![a..b]` a lot when dealing with ranges in text, so -# warning on this rule produces a lot of noise. -single_range_in_vec_init = "allow" +# This is not a style lint, see https://github.com/rust-lang/rust-clippy/pull/15454 +# Remove when the lint gets promoted to `suspicious`. +declare_interior_mutable_const = "deny" + +redundant_clone = "deny" -# These are all of the rules that currently have violations in the Zed -# codebase. +# We currently do not restrict any style rules +# as it slows down shipping code to Zed. # -# We'll want to drive this list down by either: -# 1. fixing violations of the rule and begin enforcing it -# 2. deciding we want to allow the rule permanently, at which point -# we should codify that separately above. +# Running ./script/clippy can take several minutes, and so it's +# common to skip that step and let CI do it. Any unexpected failures +# (which also take minutes to discover) thus require switching back +# to an old branch, manual fixing, and re-pushing. # -# This list shouldn't be added to; it should only get shorter. -# ============================================================================= - -# There are a bunch of rules currently failing in the `style` group, so -# allow all of those, for now. +# In the future we could improve this by either making sure +# Zed can surface clippy errors in diagnostics (in addition to the +# rust-analyzer errors), or by having CI fix style nits automatically. style = { level = "allow", priority = -1 } -# Temporary list of style lints that we've fixed so far. 
-module_inception = { level = "deny" } -question_mark = { level = "deny" } -redundant_closure = { level = "deny" } -declare_interior_mutable_const = { level = "deny" } # Individual rules that have violations in the codebase: type_complexity = "allow" -# We often return trait objects from `new` functions. -new_ret_no_self = { level = "allow" } -# We have a few `next` functions that differ in lifetimes -# compared to Iterator::next. Yet, clippy complains about those. -should_implement_trait = { level = "allow" } let_underscore_future = "allow" +# Motivation: We use `vec![a..b]` a lot when dealing with ranges in text, so +# warning on this rule produces a lot of noise. +single_range_in_vec_init = "allow" + # in Rust it can be very tedious to reduce argument count without # running afoul of the borrow checker. too_many_arguments = "allow" @@ -858,6 +879,9 @@ too_many_arguments = "allow" # We often have large enum variants yet we rarely actually bother with splitting them up. large_enum_variant = "allow" +# Boolean expressions can be hard to read, requiring only the minimal form gets in the way +nonminimal_bool = "allow" + [workspace.metadata.cargo-machete] ignored = [ "bindgen", diff --git a/Procfile.web b/Procfile.web new file mode 100644 index 0000000000000000000000000000000000000000..814055514498124d1f20b1fed51f23a5809819a9 --- /dev/null +++ b/Procfile.web @@ -0,0 +1,2 @@ +postgrest_llm: postgrest crates/collab/postgrest_llm.conf +website: cd ../zed.dev; npm run dev -- --port=3000 diff --git a/assets/icons/ai.svg b/assets/icons/ai.svg index d60396ad47db1a2068207c52f783c08cd2da4e69..4236d50337bef92cb550cdbf71d83843ab35e2f3 100644 --- a/assets/icons/ai.svg +++ b/assets/icons/ai.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/arrow_circle.svg b/assets/icons/arrow_circle.svg index 76363c6270890d51e5946664fa4943e5b16aca0c..cdfa93979505e45a9e876059eddf5a61ac489e1a 100644 --- a/assets/icons/arrow_circle.svg +++ b/assets/icons/arrow_circle.svg @@ -1,6 +1,6 @@ - - - - + + 
+ + diff --git a/assets/icons/arrow_down.svg b/assets/icons/arrow_down.svg index c71e5437f8cd9424be47da102802a47c30575769..60e6584c4568a5e113e225800024e835ea9743e7 100644 --- a/assets/icons/arrow_down.svg +++ b/assets/icons/arrow_down.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/arrow_down10.svg b/assets/icons/arrow_down10.svg index 8eed82276cc4bf613d5bc026ad8bd59694760787..5933b758d939bef502495cbffcbddc60a3d42691 100644 --- a/assets/icons/arrow_down10.svg +++ b/assets/icons/arrow_down10.svg @@ -1 +1 @@ - + diff --git a/assets/icons/arrow_down_right.svg b/assets/icons/arrow_down_right.svg index 73f72a2c38c6f6833a3c96f74fddafd8d1fb8730..ebdb06d77b24d5aa0d28615e156135495e8e80c4 100644 --- a/assets/icons/arrow_down_right.svg +++ b/assets/icons/arrow_down_right.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/arrow_left.svg b/assets/icons/arrow_left.svg index ca441497a054f2c6a1769f804d4c8aaf7cbc8ccc..f7eacb2a779c94e3f743fdbc594773de73017e41 100644 --- a/assets/icons/arrow_left.svg +++ b/assets/icons/arrow_left.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/arrow_right.svg b/assets/icons/arrow_right.svg index ae148885637563795bec94d85b51f979be4613a4..b9324af5a289ac2b4ae6d7b6374d603587763de0 100644 --- a/assets/icons/arrow_right.svg +++ b/assets/icons/arrow_right.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/arrow_right_left.svg b/assets/icons/arrow_right_left.svg index cfeee0cc24b5c988f15d83a29e7fb32b7427ccb0..2c1211056a17eee8644b07b1fb651818f63db3dc 100644 --- a/assets/icons/arrow_right_left.svg +++ b/assets/icons/arrow_right_left.svg @@ -1,6 +1,6 @@ - - - - + + + + diff --git a/assets/icons/arrow_up.svg b/assets/icons/arrow_up.svg index b98c710374fc5c9c5ef8ddc05fcbdbe2eaa30017..ff3ad441234b8d2ae1aeb17c531a9ecb288dc8d2 100644 --- a/assets/icons/arrow_up.svg +++ b/assets/icons/arrow_up.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/arrow_up_right.svg b/assets/icons/arrow_up_right.svg index 
fb065bc9ce7d90d20db4ee45b7bc2d909dede09f..a948ef8f8130b99339130e4400320309ae3afaec 100644 --- a/assets/icons/arrow_up_right.svg +++ b/assets/icons/arrow_up_right.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/attach.svg b/assets/icons/attach.svg new file mode 100644 index 0000000000000000000000000000000000000000..f923a3c7c8841fd358cf940d99e7371f010a6f4d --- /dev/null +++ b/assets/icons/attach.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/audio_off.svg b/assets/icons/audio_off.svg index dfb5a1c45829119ea0dc89bbca3a3f33228ee88f..43d2a04344748feab9496cd528aacba075d9f7e8 100644 --- a/assets/icons/audio_off.svg +++ b/assets/icons/audio_off.svg @@ -1,7 +1,7 @@ - - - - - + + + + + diff --git a/assets/icons/audio_on.svg b/assets/icons/audio_on.svg index d1bef0d337d6c8a0e79cb0dab8b7d63d5cb2a4d1..6e183bd585461e49418f58af95026590549c950b 100644 --- a/assets/icons/audio_on.svg +++ b/assets/icons/audio_on.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/backspace.svg b/assets/icons/backspace.svg index 679ef1ade19eef8317e0a35547c3b6b212a72499..9ef4432b6f019b1eb71978e214e6ea9a3e680839 100644 --- a/assets/icons/backspace.svg +++ b/assets/icons/backspace.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/bell.svg b/assets/icons/bell.svg index f9b2a97fb34faceb155b5eb6a263ff0752b9e402..70225bb105f24ad42616fb10b4742a2d3176502b 100644 --- a/assets/icons/bell.svg +++ b/assets/icons/bell.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/bell_dot.svg b/assets/icons/bell_dot.svg index 09a17401dabe97c75eb8d8a977aa0ed00d12f1ec..959a7773cf2af4a6520741a40cc6866ffab4bdab 100644 --- a/assets/icons/bell_dot.svg +++ b/assets/icons/bell_dot.svg @@ -1,5 +1,5 @@ - - + + diff --git a/assets/icons/bell_off.svg b/assets/icons/bell_off.svg index 98cbd1eb603c48de6f157b5d4cbcfbf246e05702..5c3c1a0d68680d8d9a7fa42163c40e899259646c 100644 --- a/assets/icons/bell_off.svg +++ b/assets/icons/bell_off.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/bell_ring.svg 
b/assets/icons/bell_ring.svg index e411e7511b0b10be7efd5d85d1257b325f9d64de..838056cc032aa4c47c75ffa1a1f2e189835ff2da 100644 --- a/assets/icons/bell_ring.svg +++ b/assets/icons/bell_ring.svg @@ -1,6 +1,6 @@ - - - - + + + + diff --git a/assets/icons/binary.svg b/assets/icons/binary.svg index bbc375617f061c9feae2de08bcda683f5192c3b8..3c15e9b5470575c6251fef6fe1ae2f035ef677a6 100644 --- a/assets/icons/binary.svg +++ b/assets/icons/binary.svg @@ -1 +1 @@ - + diff --git a/assets/icons/blocks.svg b/assets/icons/blocks.svg index e1690e2642b60d93893e52008c6d10d96e810d48..84725d789233079e9f3f4138392af6b15a104d9c 100644 --- a/assets/icons/blocks.svg +++ b/assets/icons/blocks.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/bolt_outlined.svg b/assets/icons/bolt_outlined.svg index 58fccf778813d3653f1066f45e5573adbf2d9ec2..ca9c75fbfd64beaac0ed544d2718a5ecb59a8243 100644 --- a/assets/icons/bolt_outlined.svg +++ b/assets/icons/bolt_outlined.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/book.svg b/assets/icons/book.svg index 8b0f89e82d073d857582f8364f1f501b8567cccd..a2ab394be4a74b9fb618dd3a7613b70f277181df 100644 --- a/assets/icons/book.svg +++ b/assets/icons/book.svg @@ -1 +1 @@ - + diff --git a/assets/icons/book_copy.svg b/assets/icons/book_copy.svg index f509beffe6da4af84b72268832af94bd6d3568b1..b7afd1df5c1fbaf51936a7170b3cfa0d0622511d 100644 --- a/assets/icons/book_copy.svg +++ b/assets/icons/book_copy.svg @@ -1 +1 @@ - + diff --git a/assets/icons/chat.svg b/assets/icons/chat.svg index a0548c3d3e6917fbea2bfba825761e01cd215a33..c64f6b5e0efb65a7c2e056d55bee1917960b4d29 100644 --- a/assets/icons/chat.svg +++ b/assets/icons/chat.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/check.svg b/assets/icons/check.svg index 4563505aaaecfdfa300d609709a34ecc4a5f3fb5..21e2137965e01f4d384f1f2aad70629e2d75f313 100644 --- a/assets/icons/check.svg +++ b/assets/icons/check.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/check_circle.svg b/assets/icons/check_circle.svg index 
e6ec5d11efffcc64721e444b8fef9a5a94481436..f9b88c4ce1451ef24a4084d6b3bb9469be85b571 100644 --- a/assets/icons/check_circle.svg +++ b/assets/icons/check_circle.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/check_double.svg b/assets/icons/check_double.svg index b52bef81a404d96489121985fa8bafdcfe30753c..fabc7005209070087e8d56e14808e68bd1f4c771 100644 --- a/assets/icons/check_double.svg +++ b/assets/icons/check_double.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/chevron_down.svg b/assets/icons/chevron_down.svg index 7894aae76497858d0db923063eed53dc41db991b..e4ca142a91fa18a252dbba72faffd9d403d29c2a 100644 --- a/assets/icons/chevron_down.svg +++ b/assets/icons/chevron_down.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/chevron_left.svg b/assets/icons/chevron_left.svg index 4be4c95dcac3e79116df2836634a8720bb36a2ef..fbe438fd4bfbcfc0bf08c2bbcf1c416c073ddc6d 100644 --- a/assets/icons/chevron_left.svg +++ b/assets/icons/chevron_left.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/chevron_right.svg b/assets/icons/chevron_right.svg index c8ff84717750a076b39cd4a045667ec1942e4167..4f170717c9b185a3204f0078160d34b0dae4aff1 100644 --- a/assets/icons/chevron_right.svg +++ b/assets/icons/chevron_right.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/chevron_up.svg b/assets/icons/chevron_up.svg index 8e575e2e8d2242c29602655ca08a334edc04690b..bbe6b9762d244af38f4fe258f9f334638829e6f6 100644 --- a/assets/icons/chevron_up.svg +++ b/assets/icons/chevron_up.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/chevron_up_down.svg b/assets/icons/chevron_up_down.svg index c7af01d4a36869c9c7b9e44bd8e0bdc42c5baf44..299f6bce5ad1e5e6d89b0d12d4ce9deb2f3ee193 100644 --- a/assets/icons/chevron_up_down.svg +++ b/assets/icons/chevron_up_down.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/circle_help.svg b/assets/icons/circle_help.svg index 4e2890d3e10e7976c648aec48c524771cce80ba8..0e623bd1da3241616b9b6bd8fb7d6243b12b07b7 100644 --- a/assets/icons/circle_help.svg +++ 
b/assets/icons/circle_help.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/close.svg b/assets/icons/close.svg index ad487e0a4f9fd2d95b26bf5cd84933e7bb817b9e..846b3a703dc6f53f36736b515443830d51205c99 100644 --- a/assets/icons/close.svg +++ b/assets/icons/close.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/cloud_download.svg b/assets/icons/cloud_download.svg index 0efcbe10f13b7b5d79c367824b1dfd88d0de1105..70cda55856cc459674e7641366631a035fcc0251 100644 --- a/assets/icons/cloud_download.svg +++ b/assets/icons/cloud_download.svg @@ -1 +1 @@ - + diff --git a/assets/icons/code.svg b/assets/icons/code.svg index 6a1795b59c9c8fefb9b0df2061e29ac3be2e3e1f..72d145224a16f28184ac0a6bedbb02f608248adf 100644 --- a/assets/icons/code.svg +++ b/assets/icons/code.svg @@ -1 +1 @@ - + diff --git a/assets/icons/cog.svg b/assets/icons/cog.svg index 4f3ada11a632c5d69bb68dda95bfdfa0d2ae6975..7dd3a8befff59b5aaa0506df9b2cd7140725ab81 100644 --- a/assets/icons/cog.svg +++ b/assets/icons/cog.svg @@ -1 +1 @@ - + diff --git a/assets/icons/command.svg b/assets/icons/command.svg index 6602af8e1f1e085e26d4548b93e3fb26964825a6..f361ca2d05f71ec0af5d1e99ac0f0a633a932ebe 100644 --- a/assets/icons/command.svg +++ b/assets/icons/command.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/control.svg b/assets/icons/control.svg index e831968df6d0a6e85d376517add5978c10b56315..f9341b6256143ce250aed35a38cebd2b6ed74207 100644 --- a/assets/icons/control.svg +++ b/assets/icons/control.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/copilot.svg b/assets/icons/copilot.svg index 57c0a5f91ae3641bd35cb41da9dff530d4ae7d51..2584cd631006c10ea9535408657fd881f0748249 100644 --- a/assets/icons/copilot.svg +++ b/assets/icons/copilot.svg @@ -1,9 +1,9 @@ - - - + + + diff --git a/assets/icons/copy.svg b/assets/icons/copy.svg index dfd8d9dbb9d62d09a3d0c7de9da4a9f0a9af3c5f..aba193930bd1e93062b1e7eef3e4a0de2e7f4ab6 100644 --- a/assets/icons/copy.svg +++ b/assets/icons/copy.svg @@ -1 +1,4 @@ - + + + + diff --git 
a/assets/icons/countdown_timer.svg b/assets/icons/countdown_timer.svg index 5e69f1bfb4b47a144d0675c2338f210aae953781..5d1e775e68c8bc3871ad8070faeaea36c7395eec 100644 --- a/assets/icons/countdown_timer.svg +++ b/assets/icons/countdown_timer.svg @@ -1 +1 @@ - + diff --git a/assets/icons/crosshair.svg b/assets/icons/crosshair.svg index 1492bf924543c9c64b06021a93dc44967740ea4c..3af6aa9fa35f29a6635d58027d7774ddd43a510c 100644 --- a/assets/icons/crosshair.svg +++ b/assets/icons/crosshair.svg @@ -1,7 +1,7 @@ - - - - - + + + + + diff --git a/assets/icons/cursor_i_beam.svg b/assets/icons/cursor_i_beam.svg index 3790de6f49d454bc5bb317e64e80a4daffceaa45..2d513181f94d2d2d29ebf1f779925bc09b85b699 100644 --- a/assets/icons/cursor_i_beam.svg +++ b/assets/icons/cursor_i_beam.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/dash.svg b/assets/icons/dash.svg index 9270f80781da095feb8abf5561fa579189574078..3928ee7cfa2b9ce408d4560b6f1aa0c0b298d1fb 100644 --- a/assets/icons/dash.svg +++ b/assets/icons/dash.svg @@ -1 +1 @@ - + diff --git a/assets/icons/database_zap.svg b/assets/icons/database_zap.svg index 160ffa5041957318e0b3f47864c81a50028c62f6..76af0f9251d096d91ee0d054ed867181a288e313 100644 --- a/assets/icons/database_zap.svg +++ b/assets/icons/database_zap.svg @@ -1 +1 @@ - + diff --git a/assets/icons/debug.svg b/assets/icons/debug.svg index 900caf4b983f60a2de7424c62965023f88283a18..6423a2b090c1b838b0a4e84c089f5db694777790 100644 --- a/assets/icons/debug.svg +++ b/assets/icons/debug.svg @@ -1,12 +1,12 @@ - - - - - - - - - - + + + + + + + + + + diff --git a/assets/icons/debug_breakpoint.svg b/assets/icons/debug_breakpoint.svg index 9cab42eecd37ba8dff1892ced7dd8de92dae0998..c09a3c159fed6bb2423fc7e50ce7fcae8d425065 100644 --- a/assets/icons/debug_breakpoint.svg +++ b/assets/icons/debug_breakpoint.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/debug_continue.svg b/assets/icons/debug_continue.svg index 
f663a5a041abc8ddc29938f3f980187c1d3e9f03..f03a8b2364b5fe8555cddb6d0e940480f2d15e75 100644 --- a/assets/icons/debug_continue.svg +++ b/assets/icons/debug_continue.svg @@ -1 +1 @@ - + diff --git a/assets/icons/debug_detach.svg b/assets/icons/debug_detach.svg index a34a0e817146097fce6e4b95919e2415c665fa43..8b3484557148a3e3e54638be5d6579f9741b7e3e 100644 --- a/assets/icons/debug_detach.svg +++ b/assets/icons/debug_detach.svg @@ -1 +1 @@ - + diff --git a/assets/icons/debug_disabled_breakpoint.svg b/assets/icons/debug_disabled_breakpoint.svg index 8b80623b025af88df7a5ff4fcf4871a6b6b86cd2..9a7c896f4709c97591196bc0c3d3813bb2a2a62c 100644 --- a/assets/icons/debug_disabled_breakpoint.svg +++ b/assets/icons/debug_disabled_breakpoint.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/debug_disabled_log_breakpoint.svg b/assets/icons/debug_disabled_log_breakpoint.svg index 2ccc37623d9daa4b5f6c79748a2e4bf3f7a03067..f477f4f32d83ee2ea7a364ac2d22bf7ce72a9f5a 100644 --- a/assets/icons/debug_disabled_log_breakpoint.svg +++ b/assets/icons/debug_disabled_log_breakpoint.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/debug_ignore_breakpoints.svg b/assets/icons/debug_ignore_breakpoints.svg index b2a345d314ec599fb53f09988183dff976e38977..bc95329c7ad1b44e075481ef52d98d69f650b176 100644 --- a/assets/icons/debug_ignore_breakpoints.svg +++ b/assets/icons/debug_ignore_breakpoints.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/debug_step_back.svg b/assets/icons/debug_step_back.svg index d1112d6b8e1725f90a5440b862c680bb6272a792..61d45866f61cbabbd9a7ae9975809d342cb76ed5 100644 --- a/assets/icons/debug_step_back.svg +++ b/assets/icons/debug_step_back.svg @@ -1 +1 @@ - + diff --git a/assets/icons/debug_step_into.svg b/assets/icons/debug_step_into.svg index 02bdd63cb4d0ca5ae7f0e6de167bf8d5c534d560..9a517fc7ca0762b17446a75cd90f39a91e1b51cf 100644 --- a/assets/icons/debug_step_into.svg +++ b/assets/icons/debug_step_into.svg @@ -1 +1 @@ - + diff --git a/assets/icons/debug_step_out.svg 
b/assets/icons/debug_step_out.svg index 48190b704b25ba4631b076606eaefd5090e15d24..147a44f930f34f6c3ddce94693a178a932129cb5 100644 --- a/assets/icons/debug_step_out.svg +++ b/assets/icons/debug_step_out.svg @@ -1 +1 @@ - + diff --git a/assets/icons/debug_step_over.svg b/assets/icons/debug_step_over.svg index 54afac001f3d249af236265da078a92551fb4422..336abc11deb866a128e8418dab47af01b6e4d3f6 100644 --- a/assets/icons/debug_step_over.svg +++ b/assets/icons/debug_step_over.svg @@ -1 +1 @@ - + diff --git a/assets/icons/diff.svg b/assets/icons/diff.svg index 61aa617f5b8ea66c94262f64ca8efb6e6607ae59..9d93b2d5b47f56dd77338c1cf59c912a2cdad294 100644 --- a/assets/icons/diff.svg +++ b/assets/icons/diff.svg @@ -1 +1 @@ - + diff --git a/assets/icons/disconnected.svg b/assets/icons/disconnected.svg index f3069798d0c904fd04313f115a985e70d452a4a2..47bd1db4788825f9475d11a288f5ad715e6de5a5 100644 --- a/assets/icons/disconnected.svg +++ b/assets/icons/disconnected.svg @@ -1 +1 @@ - + diff --git a/assets/icons/download.svg b/assets/icons/download.svg index 6ddcb1e100ec6392ff62c1209f79938cc31f7d8f..6c105d3fd74ac685176f0532e488322c00fb5fef 100644 --- a/assets/icons/download.svg +++ b/assets/icons/download.svg @@ -1 +1 @@ - + diff --git a/assets/icons/envelope.svg b/assets/icons/envelope.svg index 0f5e95f96817aefe0819c4767c7e77b9d1a17638..273cc6de267eeea7b9d893c81df57e806dafe089 100644 --- a/assets/icons/envelope.svg +++ b/assets/icons/envelope.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/eraser.svg b/assets/icons/eraser.svg index 601f2b9b909a8b1c43638ceefdb1a926c842d376..ca6209785fd1a2dd087792c645304dc05d4f4edb 100644 --- a/assets/icons/eraser.svg +++ b/assets/icons/eraser.svg @@ -1 +1 @@ - + diff --git a/assets/icons/escape.svg b/assets/icons/escape.svg index a87f03d2fa07eab2ebaee07cf6800592a0db20c4..1898588a67172f1586c4e40f622b95b8bc971511 100644 --- a/assets/icons/escape.svg +++ b/assets/icons/escape.svg @@ -1 +1 @@ - + diff --git a/assets/icons/exit.svg 
b/assets/icons/exit.svg index 1ff9d7882441548e9c3534ae5ffe6b6331391b45..3619a55c87083c7e53bc2545d7e243dad7d58eca 100644 --- a/assets/icons/exit.svg +++ b/assets/icons/exit.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/expand_down.svg b/assets/icons/expand_down.svg index 07390aad18525f69f7fa6b39a62bf7d64bfc1503..9f85ee67209ff89d31739219f76b89f24c71afec 100644 --- a/assets/icons/expand_down.svg +++ b/assets/icons/expand_down.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/expand_up.svg b/assets/icons/expand_up.svg index 73c1358b995d4e7dac1d058d66c463ac0616f646..49b084fa8f41df5bc16443081ce1617e7d7e1ef9 100644 --- a/assets/icons/expand_up.svg +++ b/assets/icons/expand_up.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/expand_vertical.svg b/assets/icons/expand_vertical.svg index e2a6dd227e0bdbddf8185b053adaed62ae03115a..5a5fa8ccb52019ccfc1afebd3972ef067634589e 100644 --- a/assets/icons/expand_vertical.svg +++ b/assets/icons/expand_vertical.svg @@ -1 +1 @@ - + diff --git a/assets/icons/eye.svg b/assets/icons/eye.svg index 7f10f738015ab077507266826e32fd88a0024460..327fa751e992167fba8901848fb2a639fda39726 100644 --- a/assets/icons/eye.svg +++ b/assets/icons/eye.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/file.svg b/assets/icons/file.svg index 85f3f543a51c8b8b24dcafa013468b646339740b..60cf2537d9e67321caf6b63f2775c1e6b29e6c32 100644 --- a/assets/icons/file.svg +++ b/assets/icons/file.svg @@ -1 +1 @@ - + diff --git a/assets/icons/file_code.svg b/assets/icons/file_code.svg index b0e632b67f86717cf6588fc293b42ce118345101..548d5a153ba243f7ae6890372ec9aaae765a9124 100644 --- a/assets/icons/file_code.svg +++ b/assets/icons/file_code.svg @@ -1 +1 @@ - + diff --git a/assets/icons/file_diff.svg b/assets/icons/file_diff.svg index d6cb4440eacddda1cc0be91a22f705c758263add..193dd7392ff1ff5cf4281921ffc2eb0b2b4697c9 100644 --- a/assets/icons/file_diff.svg +++ b/assets/icons/file_diff.svg @@ -1 +1 @@ - + diff --git a/assets/icons/file_doc.svg 
b/assets/icons/file_doc.svg index 3b11995f36759e6928abbc1cfbaa118345f0a21b..ccd5eeea01b01adc8598b0325bbaec935d272ba5 100644 --- a/assets/icons/file_doc.svg +++ b/assets/icons/file_doc.svg @@ -1,6 +1,6 @@ - + - - + + diff --git a/assets/icons/file_generic.svg b/assets/icons/file_generic.svg index 3c72bd3320d9e851641a4eecbc6d7c6bd3e989e3..790a5f18d723939131d2d7100c50020429eb4ff4 100644 --- a/assets/icons/file_generic.svg +++ b/assets/icons/file_generic.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/file_git.svg b/assets/icons/file_git.svg index 197db2e9e60f260c7a56a6e44c6250c531e0353d..2b36b0ffd3ba1c4389952a35a072027c6dc6de0f 100644 --- a/assets/icons/file_git.svg +++ b/assets/icons/file_git.svg @@ -1,6 +1,6 @@ - - - - + + + + diff --git a/assets/icons/file_icons/ai.svg b/assets/icons/file_icons/ai.svg index d60396ad47db1a2068207c52f783c08cd2da4e69..4236d50337bef92cb550cdbf71d83843ab35e2f3 100644 --- a/assets/icons/file_icons/ai.svg +++ b/assets/icons/file_icons/ai.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/file_icons/audio.svg b/assets/icons/file_icons/audio.svg index 672f736c958662ca157b165de417076a391ac398..7948b046160e92eb9e0d3cce3e28e3c007ce0a83 100644 --- a/assets/icons/file_icons/audio.svg +++ b/assets/icons/file_icons/audio.svg @@ -1,8 +1,8 @@ - - - - - - + + + + + + diff --git a/assets/icons/file_icons/book.svg b/assets/icons/file_icons/book.svg index 3b11995f36759e6928abbc1cfbaa118345f0a21b..ccd5eeea01b01adc8598b0325bbaec935d272ba5 100644 --- a/assets/icons/file_icons/book.svg +++ b/assets/icons/file_icons/book.svg @@ -1,6 +1,6 @@ - + - - + + diff --git a/assets/icons/file_icons/bun.svg b/assets/icons/file_icons/bun.svg index 48af8b3088dd040f6fd5f39d05a9d9e9e8f413ce..ca1ec900bc0a18eb44c9d5e2a810ae2c3730ed8c 100644 --- a/assets/icons/file_icons/bun.svg +++ b/assets/icons/file_icons/bun.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/file_icons/chevron_down.svg b/assets/icons/file_icons/chevron_down.svg index 
9e60e40cf4c6a86f10ed3bc399b068a52208b572..9918f6c9f7188ca0e0de76071649be9a14f36d27 100644 --- a/assets/icons/file_icons/chevron_down.svg +++ b/assets/icons/file_icons/chevron_down.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/file_icons/chevron_left.svg b/assets/icons/file_icons/chevron_left.svg index a2aa9ad996a432362d2c8382cd7c651ff13151b7..3299ee71684be25aa2a5f8d520c12b5509bbcbca 100644 --- a/assets/icons/file_icons/chevron_left.svg +++ b/assets/icons/file_icons/chevron_left.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/file_icons/chevron_right.svg b/assets/icons/file_icons/chevron_right.svg index 06608c95ee11ec5fba5b9f5c235d4604001ab440..140f644127da6b3551b03a921c4a29f8daf0077b 100644 --- a/assets/icons/file_icons/chevron_right.svg +++ b/assets/icons/file_icons/chevron_right.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/file_icons/chevron_up.svg b/assets/icons/file_icons/chevron_up.svg index fd3d5e4470b438119fd5a33245f5583f76dba32a..ae8c12a9899dcc57985c391fc8f382bbba210905 100644 --- a/assets/icons/file_icons/chevron_up.svg +++ b/assets/icons/file_icons/chevron_up.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/file_icons/code.svg b/assets/icons/file_icons/code.svg index 5f012f883837f689da5c38e905b2eb0b9723945a..af2f6c5dc0e4916dd673c2a8a40f6e9b1cb9aa99 100644 --- a/assets/icons/file_icons/code.svg +++ b/assets/icons/file_icons/code.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/file_icons/coffeescript.svg b/assets/icons/file_icons/coffeescript.svg index fc49df62c0b74c73106cde11fa5766154d31db86..e91d187615b78a8f4cb6c5d2c209b2d772e7344a 100644 --- a/assets/icons/file_icons/coffeescript.svg +++ b/assets/icons/file_icons/coffeescript.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/file_icons/conversations.svg b/assets/icons/file_icons/conversations.svg index cef764661fed601146b5a659369999d39c3a44d8..e25ed973ef4e47c4bb1a3c434a5398228f98f488 100644 --- a/assets/icons/file_icons/conversations.svg +++ b/assets/icons/file_icons/conversations.svg 
@@ -1,4 +1,4 @@ - + diff --git a/assets/icons/file_icons/dart.svg b/assets/icons/file_icons/dart.svg index fd3ab01c93a42d7737a2d5af6aca4e8083372954..c9ec3de51a469fbc68712ce15849c144e48c6616 100644 --- a/assets/icons/file_icons/dart.svg +++ b/assets/icons/file_icons/dart.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/file_icons/database.svg b/assets/icons/file_icons/database.svg index 10fbdcbff4ccd6b437e0815750a1fa91fe6bf187..a8226110d3775ce2bd784daccfa52633ab0ab597 100644 --- a/assets/icons/file_icons/database.svg +++ b/assets/icons/file_icons/database.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/file_icons/diff.svg b/assets/icons/file_icons/diff.svg index 07c46f1799604f0ac9581e51760184c142984f3d..ec59a0aabee71abe6fd954e63056425054a4cc60 100644 --- a/assets/icons/file_icons/diff.svg +++ b/assets/icons/file_icons/diff.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/file_icons/eslint.svg b/assets/icons/file_icons/eslint.svg index 0f42abe691b4ea275b8b74f1bf2e9b9ab2bcc2ca..ba72d9166b29bc5feca600c5111a68ae2357db6b 100644 --- a/assets/icons/file_icons/eslint.svg +++ b/assets/icons/file_icons/eslint.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/file_icons/file.svg b/assets/icons/file_icons/file.svg index 3c72bd3320d9e851641a4eecbc6d7c6bd3e989e3..790a5f18d723939131d2d7100c50020429eb4ff4 100644 --- a/assets/icons/file_icons/file.svg +++ b/assets/icons/file_icons/file.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/file_icons/folder.svg b/assets/icons/file_icons/folder.svg index a76dc63d1a663993f02e4b6a88b200e4aea22f0c..e40613000da5ac10282bce1ed74fd6ef07ab566b 100644 --- a/assets/icons/file_icons/folder.svg +++ b/assets/icons/file_icons/folder.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/file_icons/folder_open.svg b/assets/icons/file_icons/folder_open.svg index ef37f55f83a38f2eb5713ae615407276f76028b7..55231fb6abdb876aa86984fffaf9d3993552ebab 100644 --- a/assets/icons/file_icons/folder_open.svg +++ 
b/assets/icons/file_icons/folder_open.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/file_icons/font.svg b/assets/icons/file_icons/font.svg index 4cb01a28f27c1a715bad570aa13ef55e6d9a6412..6f2b734b26307eb2ba8584bf6d673bab701a7de2 100644 --- a/assets/icons/file_icons/font.svg +++ b/assets/icons/file_icons/font.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/file_icons/git.svg b/assets/icons/file_icons/git.svg index 197db2e9e60f260c7a56a6e44c6250c531e0353d..2b36b0ffd3ba1c4389952a35a072027c6dc6de0f 100644 --- a/assets/icons/file_icons/git.svg +++ b/assets/icons/file_icons/git.svg @@ -1,6 +1,6 @@ - - - - + + + + diff --git a/assets/icons/file_icons/gleam.svg b/assets/icons/file_icons/gleam.svg index 6a3dc2c96fe76bee376d6fbd4a72c8d1cc56715d..0399bb4dd2a747845a79bbc78fc1e845768241c0 100644 --- a/assets/icons/file_icons/gleam.svg +++ b/assets/icons/file_icons/gleam.svg @@ -1,7 +1,7 @@ - - + + diff --git a/assets/icons/file_icons/graphql.svg b/assets/icons/file_icons/graphql.svg index 96884725998e29d223b5c7be76d6d1d27cb773c6..e6c0368182e6ed23fb5c36bb1b7d8ad251bc7e53 100644 --- a/assets/icons/file_icons/graphql.svg +++ b/assets/icons/file_icons/graphql.svg @@ -1,6 +1,6 @@ - - + + diff --git a/assets/icons/file_icons/hash.svg b/assets/icons/file_icons/hash.svg index 2241904266fa2f46df1eaeb6956229cf47e553c7..77e6c600725af5387ec1bebb2eb82d4eff6fa756 100644 --- a/assets/icons/file_icons/hash.svg +++ b/assets/icons/file_icons/hash.svg @@ -1,6 +1,6 @@ - - - - + + + + diff --git a/assets/icons/file_icons/heroku.svg b/assets/icons/file_icons/heroku.svg index 826a88646bf3753bd106308b3e211142c1b65280..732adf72cb6097543946d738e0110ccd75115faf 100644 --- a/assets/icons/file_icons/heroku.svg +++ b/assets/icons/file_icons/heroku.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/file_icons/html.svg b/assets/icons/file_icons/html.svg index 41f254dd681530e29f443c0aa78c36ea440b4a69..8832bcba3a71bf8369b7f39913e9a24857e032e4 100644 --- a/assets/icons/file_icons/html.svg +++ 
b/assets/icons/file_icons/html.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/file_icons/image.svg b/assets/icons/file_icons/image.svg index 75e64c0a43a06570f4806944c390f406deeef5a6..c89de1b1285ed22959e39b1b6c6ce21b444b90e9 100644 --- a/assets/icons/file_icons/image.svg +++ b/assets/icons/file_icons/image.svg @@ -1,7 +1,7 @@ - - - + + + diff --git a/assets/icons/file_icons/java.svg b/assets/icons/file_icons/java.svg index 63ce6e768c835007013875c338d52279b8cfe515..70d2d10ed7b8e09d1e6195858d5ec50cb4e7de03 100644 --- a/assets/icons/file_icons/java.svg +++ b/assets/icons/file_icons/java.svg @@ -1,7 +1,7 @@ - - - - - + + + + + diff --git a/assets/icons/file_icons/lock.svg b/assets/icons/file_icons/lock.svg index 6bfef249b4516f3fbbf7f1a4c220b0fe893367d5..10ae33869a610714a66763683b38ff91ea9fa074 100644 --- a/assets/icons/file_icons/lock.svg +++ b/assets/icons/file_icons/lock.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/file_icons/magnifying_glass.svg b/assets/icons/file_icons/magnifying_glass.svg index 75c3e76c80b5c1c577881d9fb7a942f162e395d3..d0440d905c35bce2960f4f9691a585c1d91e91fd 100644 --- a/assets/icons/file_icons/magnifying_glass.svg +++ b/assets/icons/file_icons/magnifying_glass.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/file_icons/nix.svg b/assets/icons/file_icons/nix.svg index 879a4d76aac461739afb97ba6b1d00240c3b4490..215d58a035c2306111665ad7ee2b169950ee59ec 100644 --- a/assets/icons/file_icons/nix.svg +++ b/assets/icons/file_icons/nix.svg @@ -1,8 +1,8 @@ - - - - - - + + + + + + diff --git a/assets/icons/file_icons/notebook.svg b/assets/icons/file_icons/notebook.svg index b72ebc3967c8944163a62be92442315b706a8093..968d5c598297c794c7bf5cc86535ac3a3fa67daf 100644 --- a/assets/icons/file_icons/notebook.svg +++ b/assets/icons/file_icons/notebook.svg @@ -1,8 +1,8 @@ - - - - - + + + + + diff --git a/assets/icons/file_icons/package.svg b/assets/icons/file_icons/package.svg index 
12889e80845869a6ea2453fb60619fa01a578a0b..16bbccb2e63788c0bf442ccd54f34f63a58c28be 100644 --- a/assets/icons/file_icons/package.svg +++ b/assets/icons/file_icons/package.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/file_icons/phoenix.svg b/assets/icons/file_icons/phoenix.svg index b61b8beda7ba55e19f47b88e6e5a8bed9ddc02a3..5db68b4e44b0d13aaad3818f6b7635a8b4cd937f 100644 --- a/assets/icons/file_icons/phoenix.svg +++ b/assets/icons/file_icons/phoenix.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/file_icons/plus.svg b/assets/icons/file_icons/plus.svg index f343d5dd87bf8fe4841de35fa09207906bae9a07..3449da3ecd70868f387937f2f4015c4a63ef2798 100644 --- a/assets/icons/file_icons/plus.svg +++ b/assets/icons/file_icons/plus.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/file_icons/prettier.svg b/assets/icons/file_icons/prettier.svg index 835bd3a1267886671a2e7a28a2efc596d74bb4bf..f01230c33c4eb0de68731d074c8045b4ca877781 100644 --- a/assets/icons/file_icons/prettier.svg +++ b/assets/icons/file_icons/prettier.svg @@ -1,12 +1,12 @@ - - - - - - - - - - + + + + + + + + + + diff --git a/assets/icons/file_icons/project.svg b/assets/icons/file_icons/project.svg index 86a15d41bc41f3652a82ee5d67fd275ecd8c02fc..509cc5f4d0a4d88f392864ada4c02928c4a9c431 100644 --- a/assets/icons/file_icons/project.svg +++ b/assets/icons/file_icons/project.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/file_icons/python.svg b/assets/icons/file_icons/python.svg index de904d8e046a143b16283bbbbdfaf7010e22aed1..b44fdc539d4f08bb49e3810eadfff4c3d3abaf08 100644 --- a/assets/icons/file_icons/python.svg +++ b/assets/icons/file_icons/python.svg @@ -1,6 +1,6 @@ - - + + diff --git a/assets/icons/file_icons/replace.svg b/assets/icons/file_icons/replace.svg index 837cb23b669e2aceca4e27b69eb7ebceb08a9ca4..287328e82e7fc91f697ca19b6b006ee78f08ebd4 100644 --- a/assets/icons/file_icons/replace.svg +++ b/assets/icons/file_icons/replace.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/file_icons/replace_next.svg 
b/assets/icons/file_icons/replace_next.svg index 72511be70a2567c627e11b398bc95a591569675e..a9a9fc91f5816649aa4312285fb90e856558ee07 100644 --- a/assets/icons/file_icons/replace_next.svg +++ b/assets/icons/file_icons/replace_next.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/file_icons/rust.svg b/assets/icons/file_icons/rust.svg index 5db753628af10c679f347c863cf9819b3f9afa14..9e4dc57adb4458f7d860182287fa9b766cd6d1d8 100644 --- a/assets/icons/file_icons/rust.svg +++ b/assets/icons/file_icons/rust.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/file_icons/scala.svg b/assets/icons/file_icons/scala.svg index 9e89d1fa82338b0647f60283ab2bd8cfc8cc9850..0884cc96f4702cdf1ab5d4c7f1bf496752f82834 100644 --- a/assets/icons/file_icons/scala.svg +++ b/assets/icons/file_icons/scala.svg @@ -1,7 +1,7 @@ - + diff --git a/assets/icons/file_icons/settings.svg b/assets/icons/file_icons/settings.svg index 081d25bf482472bc6b1315b012644e8bcf16279f..d308135ff1fdc05166c83e595350facab31f5d30 100644 --- a/assets/icons/file_icons/settings.svg +++ b/assets/icons/file_icons/settings.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/file_icons/tcl.svg b/assets/icons/file_icons/tcl.svg index bb15b0f8e743c0555a6937cfc1a526ced1e5bb95..1bd7c4a5513dea6e375018d0323e8d4d2f013b8f 100644 --- a/assets/icons/file_icons/tcl.svg +++ b/assets/icons/file_icons/tcl.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/file_icons/toml.svg b/assets/icons/file_icons/toml.svg index 9ab78af50f9302615ec56535debe3794c2b73503..ae31911d6a659daec785a717926b0e4281683a69 100644 --- a/assets/icons/file_icons/toml.svg +++ b/assets/icons/file_icons/toml.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/file_icons/video.svg b/assets/icons/file_icons/video.svg index b96e359edbd859a33b4a3d83f93cb1395de60454..c249d4c82b0bbf6ccd17ed6c576ceba34f837a28 100644 --- a/assets/icons/file_icons/video.svg +++ b/assets/icons/file_icons/video.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/file_icons/vue.svg 
b/assets/icons/file_icons/vue.svg index 1cbe08dff52068688c16e773b283a715c8904ce6..1f993e90ef7f1103ecac0be40bff27566c75e338 100644 --- a/assets/icons/file_icons/vue.svg +++ b/assets/icons/file_icons/vue.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/file_lock.svg b/assets/icons/file_lock.svg index 6bfef249b4516f3fbbf7f1a4c220b0fe893367d5..10ae33869a610714a66763683b38ff91ea9fa074 100644 --- a/assets/icons/file_lock.svg +++ b/assets/icons/file_lock.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/file_markdown.svg b/assets/icons/file_markdown.svg index e26d7a532d5bbd2b3a7ff43325ff705d8ab222bd..26688a3db0aa000889436fdf59e63cac2af7b743 100644 --- a/assets/icons/file_markdown.svg +++ b/assets/icons/file_markdown.svg @@ -1 +1 @@ - + diff --git a/assets/icons/file_rust.svg b/assets/icons/file_rust.svg index 5db753628af10c679f347c863cf9819b3f9afa14..9e4dc57adb4458f7d860182287fa9b766cd6d1d8 100644 --- a/assets/icons/file_rust.svg +++ b/assets/icons/file_rust.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/file_text_outlined.svg b/assets/icons/file_text_outlined.svg index bb9b85d62f42c63b7042231d1eecc6baee8c83cc..d2e8897251e31b5ef10d009bbce7aba3a16521fe 100644 --- a/assets/icons/file_text_outlined.svg +++ b/assets/icons/file_text_outlined.svg @@ -1,6 +1,6 @@ - - - - + + + + diff --git a/assets/icons/file_toml.svg b/assets/icons/file_toml.svg index 9ab78af50f9302615ec56535debe3794c2b73503..ae31911d6a659daec785a717926b0e4281683a69 100644 --- a/assets/icons/file_toml.svg +++ b/assets/icons/file_toml.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/file_tree.svg b/assets/icons/file_tree.svg index 74acb1fc257a559a5aad1a3718e851159b735953..baf0e26ce6d8e88a18fd0496e708842e9d9394c3 100644 --- a/assets/icons/file_tree.svg +++ b/assets/icons/file_tree.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/filter.svg b/assets/icons/filter.svg index 7391fea132eac0e394cce97f0b1e630e2255f87d..4aa14e93c003d0770e973656c8af81af16d84b89 100644 --- a/assets/icons/filter.svg +++ 
b/assets/icons/filter.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/flame.svg b/assets/icons/flame.svg index 3215f0d5aee240cc1ae04a896e7e498d513b4aa9..89fc6cab1ef07336d7ae5886c9cbcbec5d419e49 100644 --- a/assets/icons/flame.svg +++ b/assets/icons/flame.svg @@ -1 +1 @@ - + diff --git a/assets/icons/folder.svg b/assets/icons/folder.svg index 0d76b7e3f8bd75aee66f7683a30b01052bc69dfe..35f4c1f8acf6796b14c2fb575e2181b771253706 100644 --- a/assets/icons/folder.svg +++ b/assets/icons/folder.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/folder_open.svg b/assets/icons/folder_open.svg index ef37f55f83a38f2eb5713ae615407276f76028b7..55231fb6abdb876aa86984fffaf9d3993552ebab 100644 --- a/assets/icons/folder_open.svg +++ b/assets/icons/folder_open.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/folder_search.svg b/assets/icons/folder_search.svg index d1bc537c98bb8d4029f2c368297fad51557e966e..207ea5c10e823929cc957e038487cb0f9d2f89ac 100644 --- a/assets/icons/folder_search.svg +++ b/assets/icons/folder_search.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/font.svg b/assets/icons/font.svg index 1cc569ecb7b61a67d4f001b4e41fed663c0173d0..47633a58c93feb5f80da5bf7bd15a382a74ee975 100644 --- a/assets/icons/font.svg +++ b/assets/icons/font.svg @@ -1 +1 @@ - + diff --git a/assets/icons/font_size.svg b/assets/icons/font_size.svg index fd983cb5d3cdf6c69cd81a6845af352937a7b44f..4286277bd900596d861a857a3b1e28b66d53d678 100644 --- a/assets/icons/font_size.svg +++ b/assets/icons/font_size.svg @@ -1 +1 @@ - + diff --git a/assets/icons/font_weight.svg b/assets/icons/font_weight.svg index 73b9852e2fbb674e1bdb1772fd3ea94870c448c5..410f43ec6e983f20e1f697c4a8f47253999ee786 100644 --- a/assets/icons/font_weight.svg +++ b/assets/icons/font_weight.svg @@ -1 +1 @@ - + diff --git a/assets/icons/forward_arrow.svg b/assets/icons/forward_arrow.svg index 503b0b309bfca1c3841de97fc2648bf8a9d26eab..e51796e5546a55b20a31f8bcaa9fbf8da8cb1b25 100644 --- a/assets/icons/forward_arrow.svg +++ 
b/assets/icons/forward_arrow.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/git_branch.svg b/assets/icons/git_branch.svg index 811bc7476211e61ba3d23bfde1e3fed1d6a419f5..fc6dcfe1b275974e64c292e56e7f962aa67cde06 100644 --- a/assets/icons/git_branch.svg +++ b/assets/icons/git_branch.svg @@ -1 +1 @@ - + diff --git a/assets/icons/git_branch_alt.svg b/assets/icons/git_branch_alt.svg index d18b072512c305c88d9daa5861b2413fbd163481..cf40195d8b2faaea629b04ec2430bd9e8afeff5f 100644 --- a/assets/icons/git_branch_alt.svg +++ b/assets/icons/git_branch_alt.svg @@ -1,7 +1,7 @@ - - - - - + + + + + diff --git a/assets/icons/github.svg b/assets/icons/github.svg index fe9186872b27c1c748170ed375f1b327967c95c9..0a12c9b656f659b010d2aaa4f1f89290368d1941 100644 --- a/assets/icons/github.svg +++ b/assets/icons/github.svg @@ -1 +1 @@ - + diff --git a/assets/icons/hash.svg b/assets/icons/hash.svg index 9e4dd7c0689f4945f8059884ebf82a9b8eaa3c64..afc1f9c0b50ceaca0bae9e5c6772d5d53665e31b 100644 --- a/assets/icons/hash.svg +++ b/assets/icons/hash.svg @@ -1 +1 @@ - + diff --git a/assets/icons/history_rerun.svg b/assets/icons/history_rerun.svg index 9ade606b31ed0bb646925210ee7c522e77208ca2..e11e754318192b25362ccc5a3a5ef61b25b1a474 100644 --- a/assets/icons/history_rerun.svg +++ b/assets/icons/history_rerun.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/image.svg b/assets/icons/image.svg index 0a26c35182b2aed46fcad8845f3f6d56bc7cfe52..e0d73d76212f38aef9a2a61eb5fb48447db717b6 100644 --- a/assets/icons/image.svg +++ b/assets/icons/image.svg @@ -1 +1 @@ - + diff --git a/assets/icons/info.svg b/assets/icons/info.svg index f3d2e6644ff2d7119965a08a1a1cfd45e2bf6f0b..c000f25867c4092fe08ed01a48a18e6ce07b2784 100644 --- a/assets/icons/info.svg +++ b/assets/icons/info.svg @@ -1,5 +1,5 @@ - - + + diff --git a/assets/icons/json.svg b/assets/icons/json.svg new file mode 100644 index 0000000000000000000000000000000000000000..af2f6c5dc0e4916dd673c2a8a40f6e9b1cb9aa99 --- /dev/null +++ 
b/assets/icons/json.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/keyboard.svg b/assets/icons/keyboard.svg index de9afd9561a3d332a63f16d8e245b860e3d04130..82791cda3fe31192f3be66fd6f27d2ca3c068bdd 100644 --- a/assets/icons/keyboard.svg +++ b/assets/icons/keyboard.svg @@ -1 +1 @@ - + diff --git a/assets/icons/knockouts/x_fg.svg b/assets/icons/knockouts/x_fg.svg index a3d47f13735e734ca2f801618a4edbee02ea457b..f459954f729f3b80c50b64ec7ad0547d534235e9 100644 --- a/assets/icons/knockouts/x_fg.svg +++ b/assets/icons/knockouts/x_fg.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/library.svg b/assets/icons/library.svg index ed59e1818b4b33f284df570d4faf3c72cf9acf63..fc7f5afcd2fa45626033d8cd7f9963e28b5f8d31 100644 --- a/assets/icons/library.svg +++ b/assets/icons/library.svg @@ -1,6 +1,6 @@ - - - - + + + + diff --git a/assets/icons/line_height.svg b/assets/icons/line_height.svg index 7afa70f767a3e1053129235d06c1423d629fa4e4..3929fc408001022def1e54c60d193058bfa4aabd 100644 --- a/assets/icons/line_height.svg +++ b/assets/icons/line_height.svg @@ -1 +1 @@ - + diff --git a/assets/icons/list_collapse.svg b/assets/icons/list_collapse.svg index 938799b1513fd5625f940cad99839501b4fee837..f18bc550b90228c2f689848b86cfc5bea3d6ff50 100644 --- a/assets/icons/list_collapse.svg +++ b/assets/icons/list_collapse.svg @@ -1 +1 @@ - + diff --git a/assets/icons/list_filter.svg b/assets/icons/list_filter.svg new file mode 100644 index 0000000000000000000000000000000000000000..82f41f5f6832a8cb35e2703e0f8ce36d148454dd --- /dev/null +++ b/assets/icons/list_filter.svg @@ -0,0 +1 @@ + diff --git a/assets/icons/list_todo.svg b/assets/icons/list_todo.svg index 019af957347a4a5f2f70c69be0dab084887a6f07..709f26d89dbb5b5e95869bf4daa41fa5ad230ba9 100644 --- a/assets/icons/list_todo.svg +++ b/assets/icons/list_todo.svg @@ -1 +1 @@ - + diff --git a/assets/icons/list_tree.svg b/assets/icons/list_tree.svg index 09872a60f7ed9c85e89f06b7384b083a7f4b5779..de3e0f3a57b0e0edfc38cfa2b8b364529a741cea 100644 --- 
a/assets/icons/list_tree.svg +++ b/assets/icons/list_tree.svg @@ -1,7 +1,7 @@ - - - - - + + + + + diff --git a/assets/icons/list_x.svg b/assets/icons/list_x.svg index 206faf2ce45dee9b94333ee747f18262a2c6baa5..0fa3bd68fbf362ebb769f840aecf320f13c219da 100644 --- a/assets/icons/list_x.svg +++ b/assets/icons/list_x.svg @@ -1,7 +1,7 @@ - - - - - + + + + + diff --git a/assets/icons/load_circle.svg b/assets/icons/load_circle.svg index 825aa335b00961e77d4f615897a5d7914cccced8..eecf099310e17daeb9e953a816dbca389447292a 100644 --- a/assets/icons/load_circle.svg +++ b/assets/icons/load_circle.svg @@ -1 +1 @@ - + diff --git a/assets/icons/location_edit.svg b/assets/icons/location_edit.svg index 02cd6f3389a499a84e37070aa110f6150fce94f0..e342652eb153caaea18cb40b1154e8444e8554ec 100644 --- a/assets/icons/location_edit.svg +++ b/assets/icons/location_edit.svg @@ -1 +1 @@ - + diff --git a/assets/icons/lock_outlined.svg b/assets/icons/lock_outlined.svg index 0bfd2fdc82ad6cfd21e9fd2c901a7604fb6c0ba9..d69a2456031113e6451b046acc71deaf559f3dc7 100644 --- a/assets/icons/lock_outlined.svg +++ b/assets/icons/lock_outlined.svg @@ -1,6 +1,6 @@ - - + + - + diff --git a/assets/icons/magnifying_glass.svg b/assets/icons/magnifying_glass.svg index b7c22e64bd219c63a9cdfbd9b187bbf49d6a2e8f..24f00bb51bccc34a61a55bea0aa2fcdabfc99b60 100644 --- a/assets/icons/magnifying_glass.svg +++ b/assets/icons/magnifying_glass.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/maximize.svg b/assets/icons/maximize.svg index ee03a2c0210586a0cf0744df051414451e92f2f6..7b6d26fed8fd0a5074def7afe880bb26274afe0a 100644 --- a/assets/icons/maximize.svg +++ b/assets/icons/maximize.svg @@ -1,6 +1,6 @@ - - - - + + + + diff --git a/assets/icons/menu.svg b/assets/icons/menu.svg index 0724fb2816f6dd891779bcb82da35755fb7f521d..f12ce47f7e9d6208accfd97d66ce8b977b36080e 100644 --- a/assets/icons/menu.svg +++ b/assets/icons/menu.svg @@ -1 +1 @@ - + diff --git a/assets/icons/menu_alt.svg b/assets/icons/menu_alt.svg index 
b605e094e37749f64b9d3df55adce3830a2d7eb3..b9cc19e22febe045ca9ccf4a7e86d69b258f875c 100644 --- a/assets/icons/menu_alt.svg +++ b/assets/icons/menu_alt.svg @@ -1 +1,3 @@ - + + + diff --git a/assets/icons/menu_alt_temp.svg b/assets/icons/menu_alt_temp.svg new file mode 100644 index 0000000000000000000000000000000000000000..87add13216d9eb8c4c3d8f345ff1695e98be2d5d --- /dev/null +++ b/assets/icons/menu_alt_temp.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/mic.svg b/assets/icons/mic.svg index 1d9c5bc9edf2a48b3311965fb57758b3ee2e015e..000d135ea54a539be0d381ac22fff334e1bc24df 100644 --- a/assets/icons/mic.svg +++ b/assets/icons/mic.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/mic_mute.svg b/assets/icons/mic_mute.svg index 8c61ae2f1ccedc1b27244ed80e1a3fdd75cd4120..8bc63be610baf13539957a218c0fd9af3425a12f 100644 --- a/assets/icons/mic_mute.svg +++ b/assets/icons/mic_mute.svg @@ -1,8 +1,8 @@ - - - - - - + + + + + + diff --git a/assets/icons/minimize.svg b/assets/icons/minimize.svg index ea825f054ed1813aff11d2838b2c0e8e2211d717..082ade47dbdab16ea69bc1e00f6e0c7bd5a8d936 100644 --- a/assets/icons/minimize.svg +++ b/assets/icons/minimize.svg @@ -1,6 +1,6 @@ - - - - + + + + diff --git a/assets/icons/notepad.svg b/assets/icons/notepad.svg index 48875eedee635472b19e17afaa49e38813c4d58d..27fd35566eee2141946658d315a155016e5ac345 100644 --- a/assets/icons/notepad.svg +++ b/assets/icons/notepad.svg @@ -1 +1 @@ - + diff --git a/assets/icons/option.svg b/assets/icons/option.svg index 676c10c93b78222ad42656cdf8f35e9443ab482a..47201f7c671e4503cf597252973ebe4c4e3b5c7d 100644 --- a/assets/icons/option.svg +++ b/assets/icons/option.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/pencil.svg b/assets/icons/pencil.svg index b913015c08ae5e7fc2ceb6011bd89925cedc27fe..c4d289e9c06c7fdc6d4e1875f4170c87ef6ea425 100644 --- a/assets/icons/pencil.svg +++ b/assets/icons/pencil.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/pencil_unavailable.svg 
b/assets/icons/pencil_unavailable.svg new file mode 100644 index 0000000000000000000000000000000000000000..4241d766ace9ec5873553e0c1d77b8c19f6caa79 --- /dev/null +++ b/assets/icons/pencil_unavailable.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/assets/icons/person.svg b/assets/icons/person.svg index c64167830378093dce799498d528edfcfd9bc6c4..a1c29e4acb29a52c5fb6f08e75875306026c63f0 100644 --- a/assets/icons/person.svg +++ b/assets/icons/person.svg @@ -1 +1 @@ - + diff --git a/assets/icons/pin.svg b/assets/icons/pin.svg index f3f50cc65953d2dc3b3ce038da708de213c6058a..d23daff8b988a5882f95ef62c3de5d99612d9317 100644 --- a/assets/icons/pin.svg +++ b/assets/icons/pin.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/play_filled.svg b/assets/icons/play_filled.svg index c632434305c6bd25da205ca8cee8203b9d3611b1..8075197ad2ae94fe3ffec1a2685e90f8b57ea513 100644 --- a/assets/icons/play_filled.svg +++ b/assets/icons/play_filled.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/play_outlined.svg b/assets/icons/play_outlined.svg index 7e1cacd5af8795501cc30f4e33927f752a1eba7f..ba1ea2693d61646623a998b7d34ae0ca2d716cef 100644 --- a/assets/icons/play_outlined.svg +++ b/assets/icons/play_outlined.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/plus.svg b/assets/icons/plus.svg index e26d430320eae364d012a2482c39de19fce4ed2a..8ac57d8cdde017ef51d622cf8b63af644b3de332 100644 --- a/assets/icons/plus.svg +++ b/assets/icons/plus.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/power.svg b/assets/icons/power.svg index 23f6f48f30befb236346b06d99fe877862d151e1..29bd2127c58e2e940ad044e18c2767d8826de6d5 100644 --- a/assets/icons/power.svg +++ b/assets/icons/power.svg @@ -1 +1 @@ - + diff --git a/assets/icons/public.svg b/assets/icons/public.svg index 574ee1010db1f27d59843c27632091afefa4cb0a..5659b5419f7d4df12ab45d8a7dfbc954ccd4c131 100644 --- a/assets/icons/public.svg +++ b/assets/icons/public.svg @@ -1 +1 @@ - + diff --git a/assets/icons/pull_request.svg b/assets/icons/pull_request.svg 
index ccfaaacfdcb28a25a73d278b7d61195eb28fd299..515462ab64406cf22d83717304cc40293ab21230 100644 --- a/assets/icons/pull_request.svg +++ b/assets/icons/pull_request.svg @@ -1 +1 @@ - + diff --git a/assets/icons/quote.svg b/assets/icons/quote.svg index 5564a60f95e34c16c6d4e820f5ea2901c7884707..a958bc67f2a7c04611499b53d4c25bec5ad1f2ff 100644 --- a/assets/icons/quote.svg +++ b/assets/icons/quote.svg @@ -1 +1 @@ - + diff --git a/assets/icons/reader.svg b/assets/icons/reader.svg index 2ccc37623d9daa4b5f6c79748a2e4bf3f7a03067..f477f4f32d83ee2ea7a364ac2d22bf7ce72a9f5a 100644 --- a/assets/icons/reader.svg +++ b/assets/icons/reader.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/refresh_title.svg b/assets/icons/refresh_title.svg index 8a8fdb04f395749c2cdbf509c6bd38f1bcaae623..c9e670bfabe7940d6936d9e767fa23d73d4703b5 100644 --- a/assets/icons/refresh_title.svg +++ b/assets/icons/refresh_title.svg @@ -1 +1 @@ - + diff --git a/assets/icons/regex.svg b/assets/icons/regex.svg index 0432cd570fe2341829c40f5d4e629e3b27e24379..818c2ba360bc5aca3d4a7bf8ab65a03a2efe235e 100644 --- a/assets/icons/regex.svg +++ b/assets/icons/regex.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/repl_neutral.svg b/assets/icons/repl_neutral.svg index d9c8b001df15bc3812084be29460d43733deffb1..2842e2c4210085cb930efee43aa3340a4d628d6a 100644 --- a/assets/icons/repl_neutral.svg +++ b/assets/icons/repl_neutral.svg @@ -1,6 +1,6 @@ - - - - + + + + diff --git a/assets/icons/repl_off.svg b/assets/icons/repl_off.svg index ac249ad5ffa687c2cbf369ff5bb5434f234960fa..3018ceaf8588cd9b545397d6167d5d68d8e33bfa 100644 --- a/assets/icons/repl_off.svg +++ b/assets/icons/repl_off.svg @@ -1,11 +1,11 @@ - - - - - - - - - + + + + + + + + + diff --git a/assets/icons/repl_pause.svg b/assets/icons/repl_pause.svg index 5273ed60bb5126cb8f1331b9b82651a06e4c9157..5a69a576c1152d71a714252d0cc66699eb39b1b3 100644 --- a/assets/icons/repl_pause.svg +++ b/assets/icons/repl_pause.svg @@ -1,8 +1,8 @@ - - - - - - + + + + + + diff 
--git a/assets/icons/repl_play.svg b/assets/icons/repl_play.svg index 76c292a38236fbe63df045c05ec6737a7a207ebe..0c8f4b0832ba2d74ae793751328e9927e45c950f 100644 --- a/assets/icons/repl_play.svg +++ b/assets/icons/repl_play.svg @@ -1,7 +1,7 @@ - - - - - + + + + + diff --git a/assets/icons/replace.svg b/assets/icons/replace.svg index 837cb23b669e2aceca4e27b69eb7ebceb08a9ca4..287328e82e7fc91f697ca19b6b006ee78f08ebd4 100644 --- a/assets/icons/replace.svg +++ b/assets/icons/replace.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/replace_next.svg b/assets/icons/replace_next.svg index 72511be70a2567c627e11b398bc95a591569675e..a9a9fc91f5816649aa4312285fb90e856558ee07 100644 --- a/assets/icons/replace_next.svg +++ b/assets/icons/replace_next.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/rerun.svg b/assets/icons/rerun.svg index a5daa5de1d748069422a47618076115050c56182..1a03a01ae6403e92565bc85a92eaddcb9edaa601 100644 --- a/assets/icons/rerun.svg +++ b/assets/icons/rerun.svg @@ -1 +1 @@ - + diff --git a/assets/icons/return.svg b/assets/icons/return.svg index aed9242a95bd1f830a5c3722af0dca7412491b07..c605eb6512b3f4314cc1d42936213ad4eef5b041 100644 --- a/assets/icons/return.svg +++ b/assets/icons/return.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/rotate_ccw.svg b/assets/icons/rotate_ccw.svg index 8f6bd6346a067dd1b8dafc46dd71c579ebd61729..cdfa8d0ab4d649e6255061faa54d337f038bc611 100644 --- a/assets/icons/rotate_ccw.svg +++ b/assets/icons/rotate_ccw.svg @@ -1 +1 @@ - + diff --git a/assets/icons/rotate_cw.svg b/assets/icons/rotate_cw.svg index b082096ee4be635dac44bc3308fc64d5ad5ebd0b..2adfa7f972b71b4d7b9194d4dc9488745dc18ce9 100644 --- a/assets/icons/rotate_cw.svg +++ b/assets/icons/rotate_cw.svg @@ -1 +1 @@ - + diff --git a/assets/icons/scissors.svg b/assets/icons/scissors.svg index 430293f9138947b35a96ff04db27db657320f64d..a19580bd89d3b3da56df2089f2de7dbd0cd981fe 100644 --- a/assets/icons/scissors.svg +++ b/assets/icons/scissors.svg @@ -1,3 +1,3 @@ - + diff --git 
a/assets/icons/screen.svg b/assets/icons/screen.svg index 4b686b58f9de2e4993546ddad1a20af395d50330..4bcdf19528a799b4725f617d56d1f84baa33c904 100644 --- a/assets/icons/screen.svg +++ b/assets/icons/screen.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/select_all.svg b/assets/icons/select_all.svg index c15973c419df4426bc47c8d07736a768ffab9d99..4fa17dcf6371838974f393021a39953afbbdcfdf 100644 --- a/assets/icons/select_all.svg +++ b/assets/icons/select_all.svg @@ -1 +1 @@ - + diff --git a/assets/icons/send.svg b/assets/icons/send.svg index 1403a43ff54b25d4424d5a66b3a00199fa8e1b6d..5ceeef2af4721301bdb3e92e06c019aa4440a39f 100644 --- a/assets/icons/send.svg +++ b/assets/icons/send.svg @@ -1 +1 @@ - + diff --git a/assets/icons/server.svg b/assets/icons/server.svg index bde19efd75bb11015ba48687a471d633257c8440..8d851d1328d60b0ba45f89a70fb97ff49bf5b732 100644 --- a/assets/icons/server.svg +++ b/assets/icons/server.svg @@ -1,6 +1,6 @@ - - - - + + + + diff --git a/assets/icons/settings.svg b/assets/icons/settings.svg index 617b14b3cde91918315801220294224c13b47e2a..33ac74f2300ed4ebad1e9ea4f523a10bb4865eea 100644 --- a/assets/icons/settings.svg +++ b/assets/icons/settings.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/shield_check.svg b/assets/icons/shield_check.svg index 6e58c314682a5e87de9b2ca582262a3110f7006d..43b52f43a8d70beb6e69c2271235090db4dc2c00 100644 --- a/assets/icons/shield_check.svg +++ b/assets/icons/shield_check.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/shift.svg b/assets/icons/shift.svg index 35dc2f144cff68641c37eae6d64bb016e55c498c..c38807d8b0b434ff4dd868decfd1d79ef772f720 100644 --- a/assets/icons/shift.svg +++ b/assets/icons/shift.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/slash.svg b/assets/icons/slash.svg index e2313f0099f9158c18c65c2f928a7301a277bcd6..1ebf01eb9f13af5b449dfdfcd075f759ac4a1d1f 100644 --- a/assets/icons/slash.svg +++ b/assets/icons/slash.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/sliders.svg 
b/assets/icons/sliders.svg index 8ab83055eef53a07c84cca255aeca505b07f47c2..20a6a367dc4963f3ecdfe6611076fd6462faa764 100644 --- a/assets/icons/sliders.svg +++ b/assets/icons/sliders.svg @@ -1,8 +1,8 @@ - - - - - - + + + + + + diff --git a/assets/icons/space.svg b/assets/icons/space.svg index 86bd55cd537bf49c955130fedb65ce148bec092d..0294c9bf1e64d9b0e0521b88981c7e704bba28e6 100644 --- a/assets/icons/space.svg +++ b/assets/icons/space.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/sparkle.svg b/assets/icons/sparkle.svg index e5cce9fafdb699e3e7144cd468f7f4747343c9de..535c447723cee7913c8fc40494c49d24760f7569 100644 --- a/assets/icons/sparkle.svg +++ b/assets/icons/sparkle.svg @@ -1 +1 @@ - + diff --git a/assets/icons/split.svg b/assets/icons/split.svg index eb031ab790cbea7fb12d43b9e28a1b5884ab84f6..b2be46a875b461928715129b560062aacd1fc5f4 100644 --- a/assets/icons/split.svg +++ b/assets/icons/split.svg @@ -1,5 +1,5 @@ - - + + diff --git a/assets/icons/split_alt.svg b/assets/icons/split_alt.svg index 5b99b7a26a44f4a97b05fcbcec02eae41554098e..2f99e1436fb71220f6853d7e56f3e748b53fff12 100644 --- a/assets/icons/split_alt.svg +++ b/assets/icons/split_alt.svg @@ -1 +1 @@ - + diff --git a/assets/icons/square_dot.svg b/assets/icons/square_dot.svg index 4bb684afb296218d9fe61a6f803c0aeaa3cb75af..72b32734399a2d9e47ea520a93b6891f0d4664f6 100644 --- a/assets/icons/square_dot.svg +++ b/assets/icons/square_dot.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/square_minus.svg b/assets/icons/square_minus.svg index 4b8fc4d982500fea1c548b6e01c6b80ba90050ee..5ba458e8b53bf6df71a95bd2326e7b1323bae161 100644 --- a/assets/icons/square_minus.svg +++ b/assets/icons/square_minus.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/square_plus.svg b/assets/icons/square_plus.svg index e0ee106b525196d267640a6173e3faddf1858e0b..063c7dbf8261d98957e3b835f4d8262b155dc396 100644 --- a/assets/icons/square_plus.svg +++ b/assets/icons/square_plus.svg @@ -1,5 +1,5 @@ - - - + + + diff --git 
a/assets/icons/star.svg b/assets/icons/star.svg index fd1502ede8a19ea3781c64430f4b10bc2475a630..b39638e386e9913ab12983ebd0805cc9128a955b 100644 --- a/assets/icons/star.svg +++ b/assets/icons/star.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/star_filled.svg b/assets/icons/star_filled.svg index d7de9939db2a57f19497e91ac7f1420c6c698fef..16f64e5cb33c17f879022341094be59316eb5135 100644 --- a/assets/icons/star_filled.svg +++ b/assets/icons/star_filled.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/stop.svg b/assets/icons/stop.svg index 41e4fd35e913f774a4b3674e8505b90ab787cbc6..cc2bbe9207acf5acd44ff13e93140099d222250b 100644 --- a/assets/icons/stop.svg +++ b/assets/icons/stop.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/swatch_book.svg b/assets/icons/swatch_book.svg index 99a1c88bd5fcede4bf0a7638fcf7c6896eecd025..b37d5df8c1a5f0f6b9fa9cb46b3004a2ba55da4f 100644 --- a/assets/icons/swatch_book.svg +++ b/assets/icons/swatch_book.svg @@ -1 +1 @@ - + diff --git a/assets/icons/tab.svg b/assets/icons/tab.svg index f16d51ccf5ae25bc2670d5ad959f2fb0cdca4e9c..db93be4df53cb01e07f1a66773b9118e68ed6609 100644 --- a/assets/icons/tab.svg +++ b/assets/icons/tab.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/terminal_alt.svg b/assets/icons/terminal_alt.svg index 82d88167b2a50fbadc36151354fed6fd65432a69..d03c05423e24fa8d7c8604050545a8ddd26cc9be 100644 --- a/assets/icons/terminal_alt.svg +++ b/assets/icons/terminal_alt.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/terminal_ghost.svg b/assets/icons/terminal_ghost.svg new file mode 100644 index 0000000000000000000000000000000000000000..7d0d0e068e8a6f01837e860e8223690a95541769 --- /dev/null +++ b/assets/icons/terminal_ghost.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/text_snippet.svg b/assets/icons/text_snippet.svg index 12f131fdd526f5d2cc86ac40a80d4fb5fea7983f..b8987546d323de794bf0b7a974162c81a48e5135 100644 --- a/assets/icons/text_snippet.svg +++ b/assets/icons/text_snippet.svg @@ -1 +1 @@ - + diff 
--git a/assets/icons/text_thread.svg b/assets/icons/text_thread.svg index 75afa934a028f1bddd104effe536db70ad4f241c..aa078c72a2f35d2b82e90f2be64d23fcda3418a5 100644 --- a/assets/icons/text_thread.svg +++ b/assets/icons/text_thread.svg @@ -1,7 +1,7 @@ - - - - - + + + + + diff --git a/assets/icons/thread.svg b/assets/icons/thread.svg index 8c2596a4c9fca9f75a122dc85225f33696320030..496cf42e3a3ee1439f36b8e2479d05564362e628 100644 --- a/assets/icons/thread.svg +++ b/assets/icons/thread.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/thread_from_summary.svg b/assets/icons/thread_from_summary.svg index 7519935affc03bf50e9a39bcb5792237fba1e44f..94ce9562da15e2abc53912a8069f1d4e3f3dd3d8 100644 --- a/assets/icons/thread_from_summary.svg +++ b/assets/icons/thread_from_summary.svg @@ -1,6 +1,6 @@ - - - - + + + + diff --git a/assets/icons/thumbs_down.svg b/assets/icons/thumbs_down.svg index 334115a014d386b18f800f3f3ad59b9f8cad49de..a396ff14f614158382172ff3a8ed682612db541c 100644 --- a/assets/icons/thumbs_down.svg +++ b/assets/icons/thumbs_down.svg @@ -1 +1 @@ - + diff --git a/assets/icons/thumbs_up.svg b/assets/icons/thumbs_up.svg index b1e435936b3cd46433e58b98fbf586468b288ff4..73c859c3557c17dd6fe962dec147ae3275a9aae9 100644 --- a/assets/icons/thumbs_up.svg +++ b/assets/icons/thumbs_up.svg @@ -1 +1 @@ - + diff --git a/assets/icons/todo_complete.svg b/assets/icons/todo_complete.svg index d50044e4351126305321ab7ce3afdb2814b78244..5bf70841a8f2876ac2955c6731d66cbfa2f8a7dd 100644 --- a/assets/icons/todo_complete.svg +++ b/assets/icons/todo_complete.svg @@ -1 +1 @@ - + diff --git a/assets/icons/todo_pending.svg b/assets/icons/todo_pending.svg index dfb013b52b987a3f99e1b8304418b847ff1ccf2b..e5e9776f11b2ebdaed8ab42039d1a8de80f29ccb 100644 --- a/assets/icons/todo_pending.svg +++ b/assets/icons/todo_pending.svg @@ -1,10 +1,10 @@ - - - - - - - - + + + + + + + + diff --git a/assets/icons/todo_progress.svg b/assets/icons/todo_progress.svg index 
9b2ed7375d9807139261a2d81f7f1f168470d0f4..b4a3e8c50e75343435849323d49d3c45cfe3069c 100644 --- a/assets/icons/todo_progress.svg +++ b/assets/icons/todo_progress.svg @@ -1,11 +1,11 @@ - - - - - - - - - + + + + + + + + + diff --git a/assets/icons/tool_copy.svg b/assets/icons/tool_copy.svg index e722d8a022fca603b87fc1859436fcc060355095..a497a5c9cba29861710e3810029f2936aa2f02b1 100644 --- a/assets/icons/tool_copy.svg +++ b/assets/icons/tool_copy.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/tool_delete_file.svg b/assets/icons/tool_delete_file.svg index 3276f3d78e8ca1bb6d79a58845577cb150f545aa..e15c0cb568eae4274a9621ac403aa3393b1d5287 100644 --- a/assets/icons/tool_delete_file.svg +++ b/assets/icons/tool_delete_file.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/tool_diagnostics.svg b/assets/icons/tool_diagnostics.svg index c659d967812727450bc3efb825b6492e6d2eda50..414810628d96cbb6fa662e359d0dec3581afa022 100644 --- a/assets/icons/tool_diagnostics.svg +++ b/assets/icons/tool_diagnostics.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/tool_folder.svg b/assets/icons/tool_folder.svg index 0d76b7e3f8bd75aee66f7683a30b01052bc69dfe..35f4c1f8acf6796b14c2fb575e2181b771253706 100644 --- a/assets/icons/tool_folder.svg +++ b/assets/icons/tool_folder.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/tool_hammer.svg b/assets/icons/tool_hammer.svg index e66173ce70f39416bbfdbfdb97dfa6f99e1ef3b7..f725012cdf211fcdd0f739e2693caac4da300b49 100644 --- a/assets/icons/tool_hammer.svg +++ b/assets/icons/tool_hammer.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/tool_notification.svg b/assets/icons/tool_notification.svg index 7510b3204000d714e8fb120179cbfc521e1abdd8..7903a3369a5a620fdbefc3a7f1cfe72ee0b98c6f 100644 --- a/assets/icons/tool_notification.svg +++ b/assets/icons/tool_notification.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/tool_pencil.svg b/assets/icons/tool_pencil.svg index 
b913015c08ae5e7fc2ceb6011bd89925cedc27fe..c4d289e9c06c7fdc6d4e1875f4170c87ef6ea425 100644 --- a/assets/icons/tool_pencil.svg +++ b/assets/icons/tool_pencil.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/tool_read.svg b/assets/icons/tool_read.svg index 458cbb36607a308ae4ce5e6a98006f9ff87461e8..d22e9d8c7da9ba04fe194339d787e40637cf5257 100644 --- a/assets/icons/tool_read.svg +++ b/assets/icons/tool_read.svg @@ -1,7 +1,7 @@ - - - - - + + + + + diff --git a/assets/icons/tool_regex.svg b/assets/icons/tool_regex.svg index 0432cd570fe2341829c40f5d4e629e3b27e24379..818c2ba360bc5aca3d4a7bf8ab65a03a2efe235e 100644 --- a/assets/icons/tool_regex.svg +++ b/assets/icons/tool_regex.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/tool_search.svg b/assets/icons/tool_search.svg index 4f2750cfa2624ff4419c159fda5b62a515b43113..b225a1298eeb627c420cf19c801dded83fffb097 100644 --- a/assets/icons/tool_search.svg +++ b/assets/icons/tool_search.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/tool_terminal.svg b/assets/icons/tool_terminal.svg index 3c4ab42a4dc06f7b2a9aaefff8145764a876d117..24da5e3a10bc47c8a7c73173c21a1ec2c6cf21b6 100644 --- a/assets/icons/tool_terminal.svg +++ b/assets/icons/tool_terminal.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/tool_think.svg b/assets/icons/tool_think.svg index 595f8070d8b6d30ade68b1bff41c141b43050394..773f5e7fa7795d7bc56bba061d808418897f9287 100644 --- a/assets/icons/tool_think.svg +++ b/assets/icons/tool_think.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/tool_web.svg b/assets/icons/tool_web.svg index 6250a9f05ab53d2bc364dc7520d10ee319f29f1f..288b54c432dcb4336161d779771c04b045eb6b4f 100644 --- a/assets/icons/tool_web.svg +++ b/assets/icons/tool_web.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/trash.svg b/assets/icons/trash.svg index 1322e90f9fdc1fad9901febff0f71a938621f900..4a9e9add021be23727ec7c8c69a98a593a6bece7 100644 --- a/assets/icons/trash.svg +++ b/assets/icons/trash.svg @@ -1,5 +1,5 @@ - - - + + + 
diff --git a/assets/icons/undo.svg b/assets/icons/undo.svg index b2407456dcf49102c8cbd4d7ad0580d6cff44535..c714b58747e950ab75d3a02be7eebfe7cd83eda1 100644 --- a/assets/icons/undo.svg +++ b/assets/icons/undo.svg @@ -1 +1 @@ - + diff --git a/assets/icons/user_check.svg b/assets/icons/user_check.svg index cd682b5eda44247245efc278babe4657078c50ab..ee32a525909a738ffa21b75a2a9690fa5ee8dcaf 100644 --- a/assets/icons/user_check.svg +++ b/assets/icons/user_check.svg @@ -1 +1 @@ - + diff --git a/assets/icons/user_group.svg b/assets/icons/user_group.svg index ac1f7bdc633190f88b202d9e5ae7430af225aecd..30d2e5a7eac519246fe4ee176d107bb2b8cd6598 100644 --- a/assets/icons/user_group.svg +++ b/assets/icons/user_group.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/user_round_pen.svg b/assets/icons/user_round_pen.svg index eb755173231ee3ed41147e2ef612e9d716496c66..e684fd1a2006e5435e93e2b6db27d5584ce41090 100644 --- a/assets/icons/user_round_pen.svg +++ b/assets/icons/user_round_pen.svg @@ -1 +1 @@ - + diff --git a/assets/icons/warning.svg b/assets/icons/warning.svg index 456799fa5ae761e04fc4c20d4d31bd7afe5479aa..5af37dab9db2c07c4d6ff505d03e62348d078f53 100644 --- a/assets/icons/warning.svg +++ b/assets/icons/warning.svg @@ -1 +1 @@ - + diff --git a/assets/icons/whole_word.svg b/assets/icons/whole_word.svg index 77cecce38c5700a4bc983f005f05de390e45a521..ce0d1606c8552f002d7bf58ad7a778a3be9561af 100644 --- a/assets/icons/whole_word.svg +++ b/assets/icons/whole_word.svg @@ -1 +1 @@ - + diff --git a/assets/icons/x_circle.svg b/assets/icons/x_circle.svg index 69aaa3f6a166be2834f0db5813041ab86b25e758..8807e5fa1fe6912982ab271744f17d13026c13ad 100644 --- a/assets/icons/x_circle.svg +++ b/assets/icons/x_circle.svg @@ -1 +1 @@ - + diff --git a/assets/icons/x_circle_filled.svg b/assets/icons/x_circle_filled.svg new file mode 100644 index 0000000000000000000000000000000000000000..52215acda8a6b7fc57820fa90f6ed405e6af637c --- /dev/null +++ b/assets/icons/x_circle_filled.svg @@ -0,0 +1,3 
@@ + + + diff --git a/assets/icons/zed_agent.svg b/assets/icons/zed_agent.svg new file mode 100644 index 0000000000000000000000000000000000000000..0c80e22c51233fff40b7605d0835b463786b4e84 --- /dev/null +++ b/assets/icons/zed_agent.svg @@ -0,0 +1,27 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/icons/zed_assistant.svg b/assets/icons/zed_assistant.svg index d21252de8c234611ddd41caff287e3fc0d540ed3..812277a100b7e6e4ad44de357fc3556b686a90a0 100644 --- a/assets/icons/zed_assistant.svg +++ b/assets/icons/zed_assistant.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/zed_burn_mode.svg b/assets/icons/zed_burn_mode.svg index f6192d16e7d3cd0a081fa745524c08b31875e80c..cad6ed666be5edbb1b2d6dced7d0e8990ac90d68 100644 --- a/assets/icons/zed_burn_mode.svg +++ b/assets/icons/zed_burn_mode.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/zed_burn_mode_on.svg b/assets/icons/zed_burn_mode_on.svg index 29a74a3e636a22a568f45bc40c1397ed033ce10d..10e0e42b1302f5323adee6724c3d43cc59a82d31 100644 --- a/assets/icons/zed_burn_mode_on.svg +++ b/assets/icons/zed_burn_mode_on.svg @@ -1 +1 @@ - + diff --git a/assets/icons/zed_mcp_custom.svg b/assets/icons/zed_mcp_custom.svg index 6410a26fcade9d5be5dd494eb24efa6b5985724a..feff2d7d34fb71d4d9064ae0cf5075216f969f75 100644 --- a/assets/icons/zed_mcp_custom.svg +++ b/assets/icons/zed_mcp_custom.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/zed_mcp_extension.svg b/assets/icons/zed_mcp_extension.svg index 996e0c1920c206f1d4ace11179069f77bc103300..00117efcf4e20cb368824ee248b671caedad0b3b 100644 --- a/assets/icons/zed_mcp_extension.svg +++ b/assets/icons/zed_mcp_extension.svg @@ -1,4 +1,4 @@ - + diff --git a/assets/icons/zed_predict.svg b/assets/icons/zed_predict.svg index 79fd8c8fc132d7a2d7bb966086a3ff62c819c5b0..605a0584d52b3163158610ae9a96fbe96fc60806 100644 --- a/assets/icons/zed_predict.svg +++ b/assets/icons/zed_predict.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/zed_predict_down.svg 
b/assets/icons/zed_predict_down.svg index 4532ad7e26cab76bc4e52a68ea5c766a1ffdca81..79eef9b0b4ad3bc2678371120bfbf5154cdd32fd 100644 --- a/assets/icons/zed_predict_down.svg +++ b/assets/icons/zed_predict_down.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/zed_predict_error.svg b/assets/icons/zed_predict_error.svg index b2dc339fe954f4bbb146d792bc653e2c402a3818..6f75326179bf3a6663ff28bfaecb80bf29878d42 100644 --- a/assets/icons/zed_predict_error.svg +++ b/assets/icons/zed_predict_error.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/zed_predict_up.svg b/assets/icons/zed_predict_up.svg index 61ec143022b4f785affa5183d549a750bd741ab2..f77001e4bddf3094532708ef0313e8b938434781 100644 --- a/assets/icons/zed_predict_up.svg +++ b/assets/icons/zed_predict_up.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/images/acp_grid.svg b/assets/images/acp_grid.svg new file mode 100644 index 0000000000000000000000000000000000000000..8ebff8e1bc87b17e536c7f97dfa2118130233258 --- /dev/null +++ b/assets/images/acp_grid.svg @@ -0,0 +1,1257 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/images/acp_logo.svg b/assets/images/acp_logo.svg new file mode 100644 index 0000000000000000000000000000000000000000..efaa46707be0a893917c3fc072a14b9c7b6b0c9b --- /dev/null +++ b/assets/images/acp_logo.svg @@ -0,0 +1 @@ + diff --git a/assets/images/acp_logo_serif.svg b/assets/images/acp_logo_serif.svg new file mode 100644 index 0000000000000000000000000000000000000000..a04d32e51c43acf358baa733f03284dbb6de1369 
--- /dev/null +++ b/assets/images/acp_logo_serif.svg @@ -0,0 +1,46 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index dda26f406bc50b4c0451bcdf89f7bd7f15e6427a..700c7e8d8c441aabc11f61ca4de62d2a3f83245e 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -41,7 +41,7 @@ "shift-f11": "debugger::StepOut", "f11": "zed::ToggleFullScreen", "ctrl-alt-z": "edit_prediction::RateCompletions", - "ctrl-shift-i": "edit_prediction::ToggleMenu", + "ctrl-alt-shift-i": "edit_prediction::ToggleMenu", "ctrl-alt-l": "lsp_tool::ToggleMenu" } }, @@ -64,8 +64,8 @@ "ctrl-k": "editor::CutToEndOfLine", "ctrl-k ctrl-q": "editor::Rewrap", "ctrl-k q": "editor::Rewrap", - "ctrl-backspace": "editor::DeleteToPreviousWordStart", - "ctrl-delete": "editor::DeleteToNextWordEnd", + "ctrl-backspace": ["editor::DeleteToPreviousWordStart", { "ignore_newlines": false, "ignore_brackets": false }], + "ctrl-delete": ["editor::DeleteToNextWordEnd", { "ignore_newlines": false, "ignore_brackets": false }], "cut": "editor::Cut", "shift-delete": "editor::Cut", "ctrl-x": "editor::Cut", @@ -121,7 +121,7 @@ "alt-g m": "git::OpenModifiedFiles", "menu": "editor::OpenContextMenu", "shift-f10": "editor::OpenContextMenu", - "ctrl-shift-e": "editor::ToggleEditPrediction", + "ctrl-alt-shift-e": "editor::ToggleEditPrediction", "f9": "editor::ToggleBreakpoint", "shift-f9": "editor::EditLogBreakpoint" } @@ -131,14 +131,14 @@ "bindings": { "shift-enter": "editor::Newline", "enter": "editor::Newline", - "ctrl-enter": "editor::NewlineAbove", - "ctrl-shift-enter": "editor::NewlineBelow", + "ctrl-enter": "editor::NewlineBelow", + "ctrl-shift-enter": "editor::NewlineAbove", "ctrl-k ctrl-z": "editor::ToggleSoftWrap", "ctrl-k z": "editor::ToggleSoftWrap", "find": "buffer_search::Deploy", "ctrl-f": "buffer_search::Deploy", "ctrl-h": "buffer_search::DeployReplace", - 
"ctrl->": "assistant::QuoteSelection", + "ctrl->": "agent::QuoteSelection", "ctrl-<": "assistant::InsertIntoEditor", "ctrl-alt-e": "editor::SelectEnclosingSymbol", "ctrl-shift-backspace": "editor::GoToPreviousChange", @@ -171,6 +171,7 @@ "context": "Markdown", "bindings": { "copy": "markdown::Copy", + "ctrl-insert": "markdown::Copy", "ctrl-c": "markdown::Copy" } }, @@ -241,12 +242,15 @@ "ctrl-shift-i": "agent::ToggleOptionsMenu", "ctrl-alt-shift-n": "agent::ToggleNewThreadMenu", "shift-alt-escape": "agent::ExpandMessageEditor", - "ctrl->": "assistant::QuoteSelection", + "ctrl->": "agent::QuoteSelection", "ctrl-alt-e": "agent::RemoveAllContext", "ctrl-shift-e": "project_panel::ToggleFocus", "ctrl-shift-enter": "agent::ContinueThread", "super-ctrl-b": "agent::ToggleBurnMode", - "alt-enter": "agent::ContinueWithBurnMode" + "alt-enter": "agent::ContinueWithBurnMode", + "ctrl-y": "agent::AllowOnce", + "ctrl-alt-y": "agent::AllowAlways", + "ctrl-d": "agent::RejectOnce" } }, { @@ -259,6 +263,7 @@ "context": "AgentPanel > Markdown", "bindings": { "copy": "markdown::CopyAsMarkdown", + "ctrl-insert": "markdown::CopyAsMarkdown", "ctrl-c": "markdown::CopyAsMarkdown" } }, @@ -327,17 +332,32 @@ } }, { - "context": "AcpThread > Editor", + "context": "AcpThread > ModeSelector", + "bindings": { + "ctrl-enter": "menu::Confirm" + } + }, + { + "context": "AcpThread > Editor && !use_modifier_to_send", "use_key_equivalents": true, "bindings": { "enter": "agent::Chat", - "up": "agent::PreviousHistoryMessage", - "down": "agent::NextHistoryMessage", "shift-ctrl-r": "agent::OpenAgentDiff", "ctrl-shift-y": "agent::KeepAll", "ctrl-shift-n": "agent::RejectAll" } }, + { + "context": "AcpThread > Editor && use_modifier_to_send", + "use_key_equivalents": true, + "bindings": { + "ctrl-enter": "agent::Chat", + "shift-ctrl-r": "agent::OpenAgentDiff", + "ctrl-shift-y": "agent::KeepAll", + "ctrl-shift-n": "agent::RejectAll", + "shift-tab": "agent::CycleModeSelector" + } + }, { "context": 
"ThreadHistory", "bindings": { @@ -476,8 +496,8 @@ "alt-down": "editor::MoveLineDown", "ctrl-alt-shift-up": "editor::DuplicateLineUp", "ctrl-alt-shift-down": "editor::DuplicateLineDown", - "alt-shift-right": "editor::SelectLargerSyntaxNode", // Expand Selection - "alt-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink Selection + "alt-shift-right": "editor::SelectLargerSyntaxNode", // Expand selection + "alt-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink selection "ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection "ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word "ctrl-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand @@ -573,7 +593,7 @@ "ctrl-n": "workspace::NewFile", "shift-new": "workspace::NewWindow", "ctrl-shift-n": "workspace::NewWindow", - "ctrl-`": "terminal_panel::ToggleFocus", + "ctrl-`": "terminal_panel::Toggle", "f10": ["app_menu::OpenApplicationMenu", "Zed"], "alt-1": ["workspace::ActivatePane", 0], "alt-2": ["workspace::ActivatePane", 1], @@ -618,6 +638,7 @@ "alt-save": "workspace::SaveAll", "ctrl-alt-s": "workspace::SaveAll", "ctrl-k m": "language_selector::Toggle", + "ctrl-k ctrl-m": "toolchain::AddToolchain", "escape": "workspace::Unfollow", "ctrl-k ctrl-left": "workspace::ActivatePaneLeft", "ctrl-k ctrl-right": "workspace::ActivatePaneRight", @@ -848,7 +869,7 @@ "ctrl-backspace": ["project_panel::Delete", { "skip_prompt": false }], "ctrl-delete": ["project_panel::Delete", { "skip_prompt": false }], "alt-ctrl-r": "project_panel::RevealInFileManager", - "ctrl-shift-enter": "project_panel::OpenWithSystem", + "ctrl-shift-enter": "workspace::OpenWithSystem", "alt-d": "project_panel::CompareMarkedFiles", "shift-find": "project_panel::NewSearchInDirectory", "ctrl-alt-shift-f": "project_panel::NewSearchInDirectory", @@ -1018,6 +1039,13 @@ "tab": "channel_modal::ToggleMode" } }, + { + "context": 
"ToolchainSelector", + "use_key_equivalents": true, + "bindings": { + "ctrl-shift-a": "toolchain::AddToolchain" + } + }, { "context": "FileFinder || (FileFinder > Picker > Editor)", "bindings": { @@ -1187,9 +1215,16 @@ "ctrl-1": "onboarding::ActivateBasicsPage", "ctrl-2": "onboarding::ActivateEditingPage", "ctrl-3": "onboarding::ActivateAISetupPage", - "ctrl-escape": "onboarding::Finish", - "alt-tab": "onboarding::SignIn", + "ctrl-enter": "onboarding::Finish", + "alt-shift-l": "onboarding::SignIn", "alt-shift-a": "onboarding::OpenAccount" } + }, + { + "context": "InvalidBuffer", + "use_key_equivalents": true, + "bindings": { + "ctrl-shift-enter": "workspace::OpenWithSystem" + } } ] diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 3966efd8dfce9f1800ad0c9ac1c38b172709ce50..7c85e6e582a8c0b89586d7ae3ee573271a457b38 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -70,9 +70,9 @@ "cmd-k q": "editor::Rewrap", "cmd-backspace": "editor::DeleteToBeginningOfLine", "cmd-delete": "editor::DeleteToEndOfLine", - "alt-backspace": "editor::DeleteToPreviousWordStart", - "ctrl-w": "editor::DeleteToPreviousWordStart", - "alt-delete": "editor::DeleteToNextWordEnd", + "alt-backspace": ["editor::DeleteToPreviousWordStart", { "ignore_newlines": false, "ignore_brackets": false }], + "ctrl-w": ["editor::DeleteToPreviousWordStart", { "ignore_newlines": false, "ignore_brackets": false }], + "alt-delete": ["editor::DeleteToNextWordEnd", { "ignore_newlines": false, "ignore_brackets": false }], "cmd-x": "editor::Cut", "cmd-c": "editor::Copy", "cmd-v": "editor::Paste", @@ -162,7 +162,7 @@ "cmd-alt-f": "buffer_search::DeployReplace", "cmd-alt-l": ["buffer_search::Deploy", { "selection_search_enabled": true }], "cmd-e": ["buffer_search::Deploy", { "focus": false }], - "cmd->": "assistant::QuoteSelection", + "cmd->": "agent::QuoteSelection", "cmd-<": "assistant::InsertIntoEditor", "cmd-alt-e": 
"editor::SelectEnclosingSymbol", "alt-enter": "editor::OpenSelectionsInMultibuffer" @@ -218,7 +218,7 @@ } }, { - "context": "Editor && !agent_diff", + "context": "Editor && !agent_diff && !AgentPanel", "use_key_equivalents": true, "bindings": { "cmd-alt-z": "git::Restore", @@ -281,12 +281,15 @@ "cmd-shift-i": "agent::ToggleOptionsMenu", "cmd-alt-shift-n": "agent::ToggleNewThreadMenu", "shift-alt-escape": "agent::ExpandMessageEditor", - "cmd->": "assistant::QuoteSelection", + "cmd->": "agent::QuoteSelection", "cmd-alt-e": "agent::RemoveAllContext", "cmd-shift-e": "project_panel::ToggleFocus", "cmd-ctrl-b": "agent::ToggleBurnMode", "cmd-shift-enter": "agent::ContinueThread", - "alt-enter": "agent::ContinueWithBurnMode" + "alt-enter": "agent::ContinueWithBurnMode", + "cmd-y": "agent::AllowOnce", + "cmd-alt-y": "agent::AllowAlways", + "cmd-d": "agent::RejectOnce" } }, { @@ -379,15 +382,31 @@ } }, { - "context": "AcpThread > Editor", + "context": "AcpThread > ModeSelector", + "bindings": { + "cmd-enter": "menu::Confirm" + } + }, + { + "context": "AcpThread > Editor && !use_modifier_to_send", "use_key_equivalents": true, "bindings": { "enter": "agent::Chat", - "up": "agent::PreviousHistoryMessage", - "down": "agent::NextHistoryMessage", "shift-ctrl-r": "agent::OpenAgentDiff", "cmd-shift-y": "agent::KeepAll", - "cmd-shift-n": "agent::RejectAll" + "cmd-shift-n": "agent::RejectAll", + "shift-tab": "agent::CycleModeSelector" + } + }, + { + "context": "AcpThread > Editor && use_modifier_to_send", + "use_key_equivalents": true, + "bindings": { + "cmd-enter": "agent::Chat", + "shift-ctrl-r": "agent::OpenAgentDiff", + "cmd-shift-y": "agent::KeepAll", + "cmd-shift-n": "agent::RejectAll", + "shift-tab": "agent::CycleModeSelector" } }, { @@ -528,8 +547,10 @@ "alt-down": "editor::MoveLineDown", "alt-shift-up": "editor::DuplicateLineUp", "alt-shift-down": "editor::DuplicateLineDown", - "ctrl-shift-right": "editor::SelectLargerSyntaxNode", // Expand Selection - "ctrl-shift-left": 
"editor::SelectSmallerSyntaxNode", // Shrink Selection + "cmd-ctrl-left": "editor::SelectSmallerSyntaxNode", // Shrink selection + "cmd-ctrl-right": "editor::SelectLargerSyntaxNode", // Expand selection + "cmd-ctrl-up": "editor::SelectPreviousSyntaxNode", // Move selection up + "cmd-ctrl-down": "editor::SelectNextSyntaxNode", // Move selection down "cmd-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand "cmd-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection "cmd-f2": "editor::SelectAllMatches", // Select all occurrences of current word @@ -641,7 +662,7 @@ "alt-shift-enter": "toast::RunAction", "cmd-shift-s": "workspace::SaveAs", "cmd-shift-n": "workspace::NewWindow", - "ctrl-`": "terminal_panel::ToggleFocus", + "ctrl-`": "terminal_panel::Toggle", "cmd-1": ["workspace::ActivatePane", 0], "cmd-2": ["workspace::ActivatePane", 1], "cmd-3": ["workspace::ActivatePane", 2], @@ -682,6 +703,7 @@ "cmd-?": "agent::ToggleFocus", "cmd-alt-s": "workspace::SaveAll", "cmd-k m": "language_selector::Toggle", + "cmd-k cmd-m": "toolchain::AddToolchain", "escape": "workspace::Unfollow", "cmd-k cmd-left": "workspace::ActivatePaneLeft", "cmd-k cmd-right": "workspace::ActivatePaneRight", @@ -907,7 +929,7 @@ "cmd-backspace": ["project_panel::Trash", { "skip_prompt": true }], "cmd-delete": ["project_panel::Delete", { "skip_prompt": false }], "alt-cmd-r": "project_panel::RevealInFileManager", - "ctrl-shift-enter": "project_panel::OpenWithSystem", + "ctrl-shift-enter": "workspace::OpenWithSystem", "alt-d": "project_panel::CompareMarkedFiles", "cmd-alt-backspace": ["project_panel::Delete", { "skip_prompt": false }], "cmd-alt-shift-f": "project_panel::NewSearchInDirectory", @@ -1086,6 +1108,13 @@ "tab": "channel_modal::ToggleMode" } }, + { + "context": "ToolchainSelector", + "use_key_equivalents": true, + "bindings": { + "cmd-shift-a": "toolchain::AddToolchain" + } + }, { "context": 
"FileFinder || (FileFinder > Picker > Editor)", "use_key_equivalents": true, @@ -1293,5 +1322,12 @@ "alt-tab": "onboarding::SignIn", "alt-shift-a": "onboarding::OpenAccount" } + }, + { + "context": "InvalidBuffer", + "use_key_equivalents": true, + "bindings": { + "ctrl-shift-enter": "workspace::OpenWithSystem" + } } ] diff --git a/assets/keymaps/default-windows.json b/assets/keymaps/default-windows.json new file mode 100644 index 0000000000000000000000000000000000000000..0e9f193bd1a11cd4804878648e4690545fd7ce27 --- /dev/null +++ b/assets/keymaps/default-windows.json @@ -0,0 +1,1241 @@ +[ + // Standard Windows bindings + { + "use_key_equivalents": true, + "bindings": { + "home": "menu::SelectFirst", + "shift-pageup": "menu::SelectFirst", + "pageup": "menu::SelectFirst", + "end": "menu::SelectLast", + "shift-pagedown": "menu::SelectLast", + "pagedown": "menu::SelectLast", + "ctrl-n": "menu::SelectNext", + "tab": "menu::SelectNext", + "down": "menu::SelectNext", + "ctrl-p": "menu::SelectPrevious", + "shift-tab": "menu::SelectPrevious", + "up": "menu::SelectPrevious", + "enter": "menu::Confirm", + "ctrl-enter": "menu::SecondaryConfirm", + "ctrl-escape": "menu::Cancel", + "ctrl-c": "menu::Cancel", + "escape": "menu::Cancel", + "shift-alt-enter": "menu::Restart", + "alt-enter": ["picker::ConfirmInput", { "secondary": false }], + "ctrl-alt-enter": ["picker::ConfirmInput", { "secondary": true }], + "ctrl-shift-w": "workspace::CloseWindow", + "shift-escape": "workspace::ToggleZoom", + "ctrl-o": "workspace::Open", + "ctrl-=": ["zed::IncreaseBufferFontSize", { "persist": false }], + "ctrl-shift-=": ["zed::IncreaseBufferFontSize", { "persist": false }], + "ctrl--": ["zed::DecreaseBufferFontSize", { "persist": false }], + "ctrl-0": ["zed::ResetBufferFontSize", { "persist": false }], + "ctrl-,": "zed::OpenSettings", + "ctrl-q": "zed::Quit", + "f4": "debugger::Start", + "shift-f5": "debugger::Stop", + "ctrl-shift-f5": "debugger::RerunSession", + "f6": "debugger::Pause", + "f7": 
"debugger::StepOver", + "ctrl-f11": "debugger::StepInto", + "shift-f11": "debugger::StepOut", + "f11": "zed::ToggleFullScreen", + "ctrl-shift-i": "edit_prediction::ToggleMenu", + "shift-alt-l": "lsp_tool::ToggleMenu" + } + }, + { + "context": "Picker || menu", + "use_key_equivalents": true, + "bindings": { + "up": "menu::SelectPrevious", + "down": "menu::SelectNext" + } + }, + { + "context": "Editor", + "use_key_equivalents": true, + "bindings": { + "escape": "editor::Cancel", + "shift-backspace": "editor::Backspace", + "backspace": "editor::Backspace", + "delete": "editor::Delete", + "tab": "editor::Tab", + "shift-tab": "editor::Backtab", + "ctrl-k": "editor::CutToEndOfLine", + "ctrl-k ctrl-q": "editor::Rewrap", + "ctrl-k q": "editor::Rewrap", + "ctrl-backspace": ["editor::DeleteToPreviousWordStart", { "ignore_newlines": false, "ignore_brackets": false }], + "ctrl-delete": ["editor::DeleteToNextWordEnd", { "ignore_newlines": false, "ignore_brackets": false }], + "shift-delete": "editor::Cut", + "ctrl-x": "editor::Cut", + "ctrl-insert": "editor::Copy", + "ctrl-c": "editor::Copy", + "shift-insert": "editor::Paste", + "ctrl-v": "editor::Paste", + "ctrl-z": "editor::Undo", + "ctrl-y": "editor::Redo", + "ctrl-shift-z": "editor::Redo", + "up": "editor::MoveUp", + "ctrl-up": "editor::LineUp", + "ctrl-down": "editor::LineDown", + "pageup": "editor::MovePageUp", + "alt-pageup": "editor::PageUp", + "shift-pageup": "editor::SelectPageUp", + "home": ["editor::MoveToBeginningOfLine", { "stop_at_soft_wraps": true, "stop_at_indent": true }], + "down": "editor::MoveDown", + "pagedown": "editor::MovePageDown", + "alt-pagedown": "editor::PageDown", + "shift-pagedown": "editor::SelectPageDown", + "end": ["editor::MoveToEndOfLine", { "stop_at_soft_wraps": true }], + "left": "editor::MoveLeft", + "right": "editor::MoveRight", + "ctrl-left": "editor::MoveToPreviousWordStart", + "ctrl-right": "editor::MoveToNextWordEnd", + "ctrl-home": "editor::MoveToBeginning", + "ctrl-end": 
"editor::MoveToEnd", + "shift-up": "editor::SelectUp", + "shift-down": "editor::SelectDown", + "shift-left": "editor::SelectLeft", + "shift-right": "editor::SelectRight", + "ctrl-shift-left": "editor::SelectToPreviousWordStart", + "ctrl-shift-right": "editor::SelectToNextWordEnd", + "ctrl-shift-home": "editor::SelectToBeginning", + "ctrl-shift-end": "editor::SelectToEnd", + "ctrl-a": "editor::SelectAll", + "ctrl-l": "editor::SelectLine", + "shift-alt-f": "editor::Format", + "shift-alt-o": "editor::OrganizeImports", + "shift-home": ["editor::SelectToBeginningOfLine", { "stop_at_soft_wraps": true, "stop_at_indent": true }], + "shift-end": ["editor::SelectToEndOfLine", { "stop_at_soft_wraps": true }], + "ctrl-alt-space": "editor::ShowCharacterPalette", + "ctrl-;": "editor::ToggleLineNumbers", + "ctrl-'": "editor::ToggleSelectedDiffHunks", + "ctrl-\"": "editor::ExpandAllDiffHunks", + "ctrl-i": "editor::ShowSignatureHelp", + "alt-g b": "git::Blame", + "alt-g m": "git::OpenModifiedFiles", + "menu": "editor::OpenContextMenu", + "shift-f10": "editor::OpenContextMenu", + "ctrl-shift-e": "editor::ToggleEditPrediction", + "f9": "editor::ToggleBreakpoint", + "shift-f9": "editor::EditLogBreakpoint" + } + }, + { + "context": "Editor && mode == full", + "use_key_equivalents": true, + "bindings": { + "shift-enter": "editor::Newline", + "enter": "editor::Newline", + "ctrl-enter": "editor::NewlineBelow", + "ctrl-shift-enter": "editor::NewlineAbove", + "ctrl-k ctrl-z": "editor::ToggleSoftWrap", + "ctrl-k z": "editor::ToggleSoftWrap", + "ctrl-f": "buffer_search::Deploy", + "ctrl-h": "buffer_search::DeployReplace", + "ctrl-shift-.": "assistant::QuoteSelection", + "ctrl-shift-,": "assistant::InsertIntoEditor", + "shift-alt-e": "editor::SelectEnclosingSymbol", + "ctrl-shift-backspace": "editor::GoToPreviousChange", + "ctrl-shift-alt-backspace": "editor::GoToNextChange", + "alt-enter": "editor::OpenSelectionsInMultibuffer" + } + }, + { + "context": "Editor && mode == full && 
edit_prediction", + "use_key_equivalents": true, + "bindings": { + "alt-]": "editor::NextEditPrediction", + "alt-[": "editor::PreviousEditPrediction" + } + }, + { + "context": "Editor && !edit_prediction", + "use_key_equivalents": true, + "bindings": { + "alt-\\": "editor::ShowEditPrediction" + } + }, + { + "context": "Editor && mode == auto_height", + "use_key_equivalents": true, + "bindings": { + "ctrl-enter": "editor::Newline", + "shift-enter": "editor::Newline", + "ctrl-shift-enter": "editor::NewlineBelow" + } + }, + { + "context": "Markdown", + "use_key_equivalents": true, + "bindings": { + "ctrl-c": "markdown::Copy" + } + }, + { + "context": "Editor && jupyter && !ContextEditor", + "use_key_equivalents": true, + "bindings": { + "ctrl-shift-enter": "repl::Run", + "ctrl-alt-enter": "repl::RunInPlace" + } + }, + { + "context": "Editor && !agent_diff", + "use_key_equivalents": true, + "bindings": { + "ctrl-k ctrl-r": "git::Restore", + "alt-y": "git::StageAndNext", + "shift-alt-y": "git::UnstageAndNext" + } + }, + { + "context": "Editor && editor_agent_diff", + "use_key_equivalents": true, + "bindings": { + "ctrl-y": "agent::Keep", + "ctrl-n": "agent::Reject", + "ctrl-shift-y": "agent::KeepAll", + "ctrl-shift-n": "agent::RejectAll", + "ctrl-shift-r": "agent::OpenAgentDiff" + } + }, + { + "context": "AgentDiff", + "use_key_equivalents": true, + "bindings": { + "ctrl-y": "agent::Keep", + "ctrl-n": "agent::Reject", + "ctrl-shift-y": "agent::KeepAll", + "ctrl-shift-n": "agent::RejectAll" + } + }, + { + "context": "ContextEditor > Editor", + "use_key_equivalents": true, + "bindings": { + "ctrl-enter": "assistant::Assist", + "ctrl-s": "workspace::Save", + "ctrl-shift-,": "assistant::InsertIntoEditor", + "shift-enter": "assistant::Split", + "ctrl-r": "assistant::CycleMessageRole", + "enter": "assistant::ConfirmCommand", + "alt-enter": "editor::Newline", + "ctrl-k c": "assistant::CopyCode", + "ctrl-g": "search::SelectNextMatch", + "ctrl-shift-g": 
"search::SelectPreviousMatch", + "ctrl-k l": "agent::OpenRulesLibrary" + } + }, + { + "context": "AgentPanel", + "use_key_equivalents": true, + "bindings": { + "ctrl-n": "agent::NewThread", + "shift-alt-n": "agent::NewTextThread", + "ctrl-shift-h": "agent::OpenHistory", + "shift-alt-c": "agent::OpenSettings", + "shift-alt-p": "agent::OpenRulesLibrary", + "ctrl-i": "agent::ToggleProfileSelector", + "shift-alt-/": "agent::ToggleModelSelector", + "ctrl-shift-a": "agent::ToggleContextPicker", + "ctrl-shift-j": "agent::ToggleNavigationMenu", + "ctrl-shift-i": "agent::ToggleOptionsMenu", + // "ctrl-shift-alt-n": "agent::ToggleNewThreadMenu", + "shift-alt-escape": "agent::ExpandMessageEditor", + "ctrl-shift-.": "assistant::QuoteSelection", + "shift-alt-e": "agent::RemoveAllContext", + "ctrl-shift-e": "project_panel::ToggleFocus", + "ctrl-shift-enter": "agent::ContinueThread", + "super-ctrl-b": "agent::ToggleBurnMode", + "alt-enter": "agent::ContinueWithBurnMode", + "ctrl-y": "agent::AllowOnce", + "ctrl-alt-y": "agent::AllowAlways", + "ctrl-d": "agent::RejectOnce" + } + }, + { + "context": "AgentPanel > NavigationMenu", + "use_key_equivalents": true, + "bindings": { + "shift-backspace": "agent::DeleteRecentlyOpenThread" + } + }, + { + "context": "AgentPanel > Markdown", + "use_key_equivalents": true, + "bindings": { + "ctrl-c": "markdown::CopyAsMarkdown" + } + }, + { + "context": "AgentPanel && prompt_editor", + "use_key_equivalents": true, + "bindings": { + "ctrl-n": "agent::NewTextThread", + "ctrl-alt-t": "agent::NewThread" + } + }, + { + "context": "AgentPanel && external_agent_thread", + "use_key_equivalents": true, + "bindings": { + "ctrl-n": "agent::NewExternalAgentThread", + "ctrl-alt-t": "agent::NewThread" + } + }, + { + "context": "MessageEditor && !Picker > Editor && !use_modifier_to_send", + "use_key_equivalents": true, + "bindings": { + "enter": "agent::Chat", + "ctrl-enter": "agent::ChatWithFollow", + "ctrl-i": "agent::ToggleProfileSelector", + "ctrl-shift-r": 
"agent::OpenAgentDiff", + "ctrl-shift-y": "agent::KeepAll", + "ctrl-shift-n": "agent::RejectAll" + } + }, + { + "context": "MessageEditor && !Picker > Editor && use_modifier_to_send", + "use_key_equivalents": true, + "bindings": { + "ctrl-enter": "agent::Chat", + "enter": "editor::Newline", + "ctrl-i": "agent::ToggleProfileSelector", + "ctrl-shift-r": "agent::OpenAgentDiff", + "ctrl-shift-y": "agent::KeepAll", + "ctrl-shift-n": "agent::RejectAll" + } + }, + { + "context": "EditMessageEditor > Editor", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel", + "enter": "menu::Confirm", + "alt-enter": "editor::Newline" + } + }, + { + "context": "AgentFeedbackMessageEditor > Editor", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel", + "enter": "menu::Confirm", + "alt-enter": "editor::Newline" + } + }, + { + "context": "ContextStrip", + "use_key_equivalents": true, + "bindings": { + "up": "agent::FocusUp", + "right": "agent::FocusRight", + "left": "agent::FocusLeft", + "down": "agent::FocusDown", + "backspace": "agent::RemoveFocusedContext", + "enter": "agent::AcceptSuggestedContext" + } + }, + { + "context": "AcpThread > ModeSelector", + "bindings": { + "ctrl-enter": "menu::Confirm" + } + }, + { + "context": "AcpThread > Editor", + "use_key_equivalents": true, + "bindings": { + "enter": "agent::Chat", + "ctrl-shift-r": "agent::OpenAgentDiff", + "ctrl-shift-y": "agent::KeepAll", + "ctrl-shift-n": "agent::RejectAll", + "shift-tab": "agent::CycleModeSelector" + } + }, + { + "context": "ThreadHistory", + "use_key_equivalents": true, + "bindings": { + "backspace": "agent::RemoveSelectedThread" + } + }, + { + "context": "PromptLibrary", + "use_key_equivalents": true, + "bindings": { + "ctrl-n": "rules_library::NewRule", + "ctrl-shift-s": "rules_library::ToggleDefaultRule" + } + }, + { + "context": "BufferSearchBar", + "use_key_equivalents": true, + "bindings": { + "escape": "buffer_search::Dismiss", + "tab": 
"buffer_search::FocusEditor", + "enter": "search::SelectNextMatch", + "shift-enter": "search::SelectPreviousMatch", + "alt-enter": "search::SelectAllMatches", + "ctrl-f": "search::FocusSearch", + "ctrl-h": "search::ToggleReplace", + "ctrl-l": "search::ToggleSelection" + } + }, + { + "context": "BufferSearchBar && in_replace > Editor", + "use_key_equivalents": true, + "bindings": { + "enter": "search::ReplaceNext", + "ctrl-enter": "search::ReplaceAll" + } + }, + { + "context": "BufferSearchBar && !in_replace > Editor", + "use_key_equivalents": true, + "bindings": { + "up": "search::PreviousHistoryQuery", + "down": "search::NextHistoryQuery" + } + }, + { + "context": "ProjectSearchBar", + "use_key_equivalents": true, + "bindings": { + "escape": "project_search::ToggleFocus", + "ctrl-shift-f": "search::FocusSearch", + "ctrl-shift-h": "search::ToggleReplace", + "alt-r": "search::ToggleRegex" // vscode + } + }, + { + "context": "ProjectSearchBar > Editor", + "use_key_equivalents": true, + "bindings": { + "up": "search::PreviousHistoryQuery", + "down": "search::NextHistoryQuery" + } + }, + { + "context": "ProjectSearchBar && in_replace > Editor", + "use_key_equivalents": true, + "bindings": { + "enter": "search::ReplaceNext", + "ctrl-alt-enter": "search::ReplaceAll" + } + }, + { + "context": "ProjectSearchView", + "use_key_equivalents": true, + "bindings": { + "escape": "project_search::ToggleFocus", + "ctrl-shift-h": "search::ToggleReplace", + "alt-r": "search::ToggleRegex" // vscode + } + }, + { + "context": "Pane", + "use_key_equivalents": true, + "bindings": { + "alt-1": ["pane::ActivateItem", 0], + "alt-2": ["pane::ActivateItem", 1], + "alt-3": ["pane::ActivateItem", 2], + "alt-4": ["pane::ActivateItem", 3], + "alt-5": ["pane::ActivateItem", 4], + "alt-6": ["pane::ActivateItem", 5], + "alt-7": ["pane::ActivateItem", 6], + "alt-8": ["pane::ActivateItem", 7], + "alt-9": ["pane::ActivateItem", 8], + "alt-0": "pane::ActivateLastItem", + "ctrl-pageup": 
"pane::ActivatePreviousItem", + "ctrl-pagedown": "pane::ActivateNextItem", + "ctrl-shift-pageup": "pane::SwapItemLeft", + "ctrl-shift-pagedown": "pane::SwapItemRight", + "ctrl-f4": ["pane::CloseActiveItem", { "close_pinned": false }], + "ctrl-w": ["pane::CloseActiveItem", { "close_pinned": false }], + "ctrl-shift-alt-t": ["pane::CloseOtherItems", { "close_pinned": false }], + "ctrl-shift-alt-w": "workspace::CloseInactiveTabsAndPanes", + "ctrl-k e": ["pane::CloseItemsToTheLeft", { "close_pinned": false }], + "ctrl-k t": ["pane::CloseItemsToTheRight", { "close_pinned": false }], + "ctrl-k u": ["pane::CloseCleanItems", { "close_pinned": false }], + "ctrl-k w": ["pane::CloseAllItems", { "close_pinned": false }], + "ctrl-k ctrl-w": "workspace::CloseAllItemsAndPanes", + "back": "pane::GoBack", + "alt--": "pane::GoBack", + "alt-=": "pane::GoForward", + "forward": "pane::GoForward", + "f3": "search::SelectNextMatch", + "shift-f3": "search::SelectPreviousMatch", + "ctrl-shift-f": "project_search::ToggleFocus", + "shift-alt-h": "search::ToggleReplace", + "alt-l": "search::ToggleSelection", + "alt-enter": "search::SelectAllMatches", + "alt-c": "search::ToggleCaseSensitive", + "alt-w": "search::ToggleWholeWord", + "alt-f": "project_search::ToggleFilters", + "alt-r": "search::ToggleRegex", + // "ctrl-shift-alt-x": "search::ToggleRegex", + "ctrl-k shift-enter": "pane::TogglePinTab" + } + }, + // Bindings from VS Code + { + "context": "Editor", + "use_key_equivalents": true, + "bindings": { + "ctrl-[": "editor::Outdent", + "ctrl-]": "editor::Indent", + "ctrl-shift-alt-up": "editor::AddSelectionAbove", // Insert Cursor Above + "ctrl-shift-alt-down": "editor::AddSelectionBelow", // Insert Cursor Below + "ctrl-shift-k": "editor::DeleteLine", + "alt-up": "editor::MoveLineUp", + "alt-down": "editor::MoveLineDown", + "shift-alt-up": "editor::DuplicateLineUp", + "shift-alt-down": "editor::DuplicateLineDown", + "shift-alt-right": "editor::SelectLargerSyntaxNode", // Expand selection + 
"shift-alt-left": "editor::SelectSmallerSyntaxNode", // Shrink selection + "ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection + "ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word + "ctrl-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand + "ctrl-shift-down": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch + "ctrl-shift-up": ["editor::SelectPrevious", { "replace_newest": false }], // editor.action.addSelectionToPreviousFindMatch + "ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }], // editor.action.moveSelectionToNextFindMatch / find_under_expand_skip + "ctrl-k ctrl-shift-d": ["editor::SelectPrevious", { "replace_newest": true }], // editor.action.moveSelectionToPreviousFindMatch + "ctrl-k ctrl-i": "editor::Hover", + "ctrl-k ctrl-b": "editor::BlameHover", + "ctrl-/": ["editor::ToggleComments", { "advance_downwards": false }], + "f8": ["editor::GoToDiagnostic", { "severity": { "min": "hint", "max": "error" } }], + "shift-f8": ["editor::GoToPreviousDiagnostic", { "severity": { "min": "hint", "max": "error" } }], + "f2": "editor::Rename", + "f12": "editor::GoToDefinition", + "alt-f12": "editor::GoToDefinitionSplit", + "ctrl-shift-f10": "editor::GoToDefinitionSplit", + "ctrl-f12": "editor::GoToImplementation", + "shift-f12": "editor::GoToTypeDefinition", + "ctrl-alt-f12": "editor::GoToTypeDefinitionSplit", + "shift-alt-f12": "editor::FindAllReferences", + "ctrl-m": "editor::MoveToEnclosingBracket", // from jetbrains + "ctrl-shift-\\": "editor::MoveToEnclosingBracket", + "ctrl-shift-[": "editor::Fold", + "ctrl-shift-]": "editor::UnfoldLines", + "ctrl-k ctrl-l": "editor::ToggleFold", + "ctrl-k ctrl-[": "editor::FoldRecursive", + "ctrl-k ctrl-]": "editor::UnfoldRecursive", + "ctrl-k ctrl-1": ["editor::FoldAtLevel", 1], + "ctrl-k ctrl-2": ["editor::FoldAtLevel", 2], + 
"ctrl-k ctrl-3": ["editor::FoldAtLevel", 3], + "ctrl-k ctrl-4": ["editor::FoldAtLevel", 4], + "ctrl-k ctrl-5": ["editor::FoldAtLevel", 5], + "ctrl-k ctrl-6": ["editor::FoldAtLevel", 6], + "ctrl-k ctrl-7": ["editor::FoldAtLevel", 7], + "ctrl-k ctrl-8": ["editor::FoldAtLevel", 8], + "ctrl-k ctrl-9": ["editor::FoldAtLevel", 9], + "ctrl-k ctrl-0": "editor::FoldAll", + "ctrl-k ctrl-j": "editor::UnfoldAll", + "ctrl-space": "editor::ShowCompletions", + "ctrl-shift-space": "editor::ShowWordCompletions", + "ctrl-.": "editor::ToggleCodeActions", + "ctrl-k r": "editor::RevealInFileManager", + "ctrl-k p": "editor::CopyPath", + "ctrl-\\": "pane::SplitRight", + "ctrl-shift-alt-c": "editor::DisplayCursorNames", + "alt-.": "editor::GoToHunk", + "alt-,": "editor::GoToPreviousHunk" + } + }, + { + "context": "Editor && extension == md", + "use_key_equivalents": true, + "bindings": { + "ctrl-k v": "markdown::OpenPreviewToTheSide", + "ctrl-shift-v": "markdown::OpenPreview" + } + }, + { + "context": "Editor && extension == svg", + "use_key_equivalents": true, + "bindings": { + "ctrl-k v": "svg::OpenPreviewToTheSide", + "ctrl-shift-v": "svg::OpenPreview" + } + }, + { + "context": "Editor && mode == full", + "use_key_equivalents": true, + "bindings": { + "ctrl-shift-o": "outline::Toggle", + "ctrl-g": "go_to_line::Toggle" + } + }, + { + "context": "Workspace", + "use_key_equivalents": true, + "bindings": { + // Change the default action on `menu::Confirm` by setting the parameter + // "ctrl-alt-o": ["projects::OpenRecent", { "create_new_window": true }], + "ctrl-r": ["projects::OpenRecent", { "create_new_window": false }], + // Change to open path modal for existing remote connection by setting the parameter + // "ctrl-shift-alt-o": "["projects::OpenRemote", { "from_existing_connection": true }]", + "ctrl-shift-alt-o": ["projects::OpenRemote", { "from_existing_connection": false, "create_new_window": false }], + "shift-alt-b": "branches::OpenRecent", + "shift-alt-enter": 
"toast::RunAction", + "ctrl-shift-`": "workspace::NewTerminal", + "ctrl-s": "workspace::Save", + "ctrl-k ctrl-shift-s": "workspace::SaveWithoutFormat", + "ctrl-shift-s": "workspace::SaveAs", + "ctrl-n": "workspace::NewFile", + "ctrl-shift-n": "workspace::NewWindow", + "ctrl-`": "terminal_panel::Toggle", + "f10": ["app_menu::OpenApplicationMenu", "Zed"], + "alt-1": ["workspace::ActivatePane", 0], + "alt-2": ["workspace::ActivatePane", 1], + "alt-3": ["workspace::ActivatePane", 2], + "alt-4": ["workspace::ActivatePane", 3], + "alt-5": ["workspace::ActivatePane", 4], + "alt-6": ["workspace::ActivatePane", 5], + "alt-7": ["workspace::ActivatePane", 6], + "alt-8": ["workspace::ActivatePane", 7], + "alt-9": ["workspace::ActivatePane", 8], + "ctrl-alt-b": "workspace::ToggleRightDock", + "ctrl-b": "workspace::ToggleLeftDock", + "ctrl-j": "workspace::ToggleBottomDock", + "ctrl-shift-y": "workspace::CloseAllDocks", + "alt-r": "workspace::ResetActiveDockSize", + // For 0px parameter, uses UI font size value. 
+ "shift-alt--": ["workspace::DecreaseActiveDockSize", { "px": 0 }], + "shift-alt-=": ["workspace::IncreaseActiveDockSize", { "px": 0 }], + "shift-alt-0": "workspace::ResetOpenDocksSize", + "ctrl-shift-alt--": ["workspace::DecreaseOpenDocksSize", { "px": 0 }], + "ctrl-shift-alt-=": ["workspace::IncreaseOpenDocksSize", { "px": 0 }], + "ctrl-shift-f": "pane::DeploySearch", + "ctrl-shift-h": ["pane::DeploySearch", { "replace_enabled": true }], + "ctrl-shift-t": "pane::ReopenClosedItem", + "ctrl-k ctrl-s": "zed::OpenKeymapEditor", + "ctrl-k ctrl-t": "theme_selector::Toggle", + "ctrl-alt-super-p": "settings_profile_selector::Toggle", + "ctrl-t": "project_symbols::Toggle", + "ctrl-p": "file_finder::Toggle", + "ctrl-tab": "tab_switcher::Toggle", + "ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }], + "ctrl-e": "file_finder::Toggle", + "f1": "command_palette::Toggle", + "ctrl-shift-p": "command_palette::Toggle", + "ctrl-shift-m": "diagnostics::Deploy", + "ctrl-shift-e": "project_panel::ToggleFocus", + "ctrl-shift-b": "outline_panel::ToggleFocus", + "ctrl-shift-g": "git_panel::ToggleFocus", + "ctrl-shift-d": "debug_panel::ToggleFocus", + "ctrl-shift-/": "agent::ToggleFocus", + "ctrl-k s": "workspace::SaveAll", + "ctrl-k m": "language_selector::Toggle", + "ctrl-m ctrl-m": "toolchain::AddToolchain", + "escape": "workspace::Unfollow", + "ctrl-k ctrl-left": "workspace::ActivatePaneLeft", + "ctrl-k ctrl-right": "workspace::ActivatePaneRight", + "ctrl-k ctrl-up": "workspace::ActivatePaneUp", + "ctrl-k ctrl-down": "workspace::ActivatePaneDown", + "ctrl-k shift-left": "workspace::SwapPaneLeft", + "ctrl-k shift-right": "workspace::SwapPaneRight", + "ctrl-k shift-up": "workspace::SwapPaneUp", + "ctrl-k shift-down": "workspace::SwapPaneDown", + "ctrl-shift-x": "zed::Extensions", + "ctrl-shift-r": "task::Rerun", + "alt-t": "task::Rerun", + "shift-alt-t": "task::Spawn", + "shift-alt-r": ["task::Spawn", { "reveal_target": "center" }], + // also possible to spawn tasks by 
name: + // "foo-bar": ["task::Spawn", { "task_name": "MyTask", "reveal_target": "dock" }] + // or by tag: + // "foo-bar": ["task::Spawn", { "task_tag": "MyTag" }], + "f5": "debugger::Rerun", + "ctrl-f4": "workspace::CloseActiveDock", + "ctrl-w": "workspace::CloseActiveDock" + } + }, + { + "context": "Workspace && debugger_running", + "use_key_equivalents": true, + "bindings": { + "f5": "zed::NoAction" + } + }, + { + "context": "Workspace && debugger_stopped", + "use_key_equivalents": true, + "bindings": { + "f5": "debugger::Continue" + } + }, + { + "context": "ApplicationMenu", + "use_key_equivalents": true, + "bindings": { + "f10": "menu::Cancel", + "left": "app_menu::ActivateMenuLeft", + "right": "app_menu::ActivateMenuRight" + } + }, + // Bindings from Sublime Text + { + "context": "Editor", + "use_key_equivalents": true, + "bindings": { + "ctrl-u": "editor::UndoSelection", + "ctrl-shift-u": "editor::RedoSelection", + "ctrl-shift-j": "editor::JoinLines", + "ctrl-alt-backspace": "editor::DeleteToPreviousSubwordStart", + "shift-alt-h": "editor::DeleteToPreviousSubwordStart", + "ctrl-alt-delete": "editor::DeleteToNextSubwordEnd", + "shift-alt-d": "editor::DeleteToNextSubwordEnd", + "ctrl-alt-left": "editor::MoveToPreviousSubwordStart", + "ctrl-alt-right": "editor::MoveToNextSubwordEnd", + "ctrl-shift-alt-left": "editor::SelectToPreviousSubwordStart", + "ctrl-shift-alt-right": "editor::SelectToNextSubwordEnd" + } + }, + // Bindings from Atom + { + "context": "Pane", + "use_key_equivalents": true, + "bindings": { + "ctrl-k up": "pane::SplitUp", + "ctrl-k down": "pane::SplitDown", + "ctrl-k left": "pane::SplitLeft", + "ctrl-k right": "pane::SplitRight" + } + }, + // Bindings that should be unified with bindings for more general actions + { + "context": "Editor && renaming", + "use_key_equivalents": true, + "bindings": { + "enter": "editor::ConfirmRename" + } + }, + { + "context": "Editor && showing_completions", + "use_key_equivalents": true, + "bindings": { + 
"enter": "editor::ConfirmCompletion", + "shift-enter": "editor::ConfirmCompletionReplace", + "tab": "editor::ComposeCompletion" + } + }, + // Bindings for accepting edit predictions + // + // alt-l is provided as an alternative to tab/alt-tab, and will be displayed in the UI. This is + // because alt-tab may not be available, as it is often used for window switching. + { + "context": "Editor && edit_prediction", + "use_key_equivalents": true, + "bindings": { + "alt-tab": "editor::AcceptEditPrediction", + "alt-l": "editor::AcceptEditPrediction", + "tab": "editor::AcceptEditPrediction", + "alt-right": "editor::AcceptPartialEditPrediction" + } + }, + { + "context": "Editor && edit_prediction_conflict", + "use_key_equivalents": true, + "bindings": { + "alt-tab": "editor::AcceptEditPrediction", + "alt-l": "editor::AcceptEditPrediction", + "alt-right": "editor::AcceptPartialEditPrediction" + } + }, + { + "context": "Editor && showing_code_actions", + "use_key_equivalents": true, + "bindings": { + "enter": "editor::ConfirmCodeAction" + } + }, + { + "context": "Editor && (showing_code_actions || showing_completions)", + "use_key_equivalents": true, + "bindings": { + "ctrl-p": "editor::ContextMenuPrevious", + "up": "editor::ContextMenuPrevious", + "ctrl-n": "editor::ContextMenuNext", + "down": "editor::ContextMenuNext", + "pageup": "editor::ContextMenuFirst", + "pagedown": "editor::ContextMenuLast" + } + }, + { + "context": "Editor && showing_signature_help && !showing_completions", + "use_key_equivalents": true, + "bindings": { + "up": "editor::SignatureHelpPrevious", + "down": "editor::SignatureHelpNext" + } + }, + // Custom bindings + { + "use_key_equivalents": true, + "bindings": { + "ctrl-shift-alt-f": "workspace::FollowNextCollaborator", + // Only available in debug builds: opens an element inspector for development. 
+ "shift-alt-i": "dev::ToggleInspector" + } + }, + { + "context": "!Terminal", + "use_key_equivalents": true, + "bindings": { + "ctrl-shift-c": "collab_panel::ToggleFocus" + } + }, + { + "context": "!ContextEditor > Editor && mode == full", + "use_key_equivalents": true, + "bindings": { + "alt-enter": "editor::OpenExcerpts", + "shift-enter": "editor::ExpandExcerpts", + "ctrl-alt-enter": "editor::OpenExcerptsSplit", + "ctrl-shift-e": "pane::RevealInProjectPanel", + "ctrl-f8": "editor::GoToHunk", + "ctrl-shift-f8": "editor::GoToPreviousHunk", + "ctrl-enter": "assistant::InlineAssist", + "ctrl-shift-;": "editor::ToggleInlayHints" + } + }, + { + "context": "PromptEditor", + "use_key_equivalents": true, + "bindings": { + "ctrl-[": "agent::CyclePreviousInlineAssist", + "ctrl-]": "agent::CycleNextInlineAssist", + "shift-alt-e": "agent::RemoveAllContext" + } + }, + { + "context": "Prompt", + "use_key_equivalents": true, + "bindings": { + "left": "menu::SelectPrevious", + "right": "menu::SelectNext", + "h": "menu::SelectPrevious", + "l": "menu::SelectNext" + } + }, + { + "context": "ProjectSearchBar && !in_replace", + "use_key_equivalents": true, + "bindings": { + "ctrl-enter": "project_search::SearchInNew" + } + }, + { + "context": "OutlinePanel && not_editing", + "use_key_equivalents": true, + "bindings": { + "left": "outline_panel::CollapseSelectedEntry", + "right": "outline_panel::ExpandSelectedEntry", + "shift-alt-c": "outline_panel::CopyPath", + "ctrl-shift-alt-c": "workspace::CopyRelativePath", + "ctrl-alt-r": "outline_panel::RevealInFileManager", + "space": "outline_panel::OpenSelectedEntry", + "shift-down": "menu::SelectNext", + "shift-up": "menu::SelectPrevious", + "alt-enter": "editor::OpenExcerpts", + "ctrl-alt-enter": "editor::OpenExcerptsSplit" + } + }, + { + "context": "ProjectPanel", + "use_key_equivalents": true, + "bindings": { + "left": "project_panel::CollapseSelectedEntry", + "right": "project_panel::ExpandSelectedEntry", + "ctrl-n": 
"project_panel::NewFile", + "alt-n": "project_panel::NewDirectory", + "ctrl-x": "project_panel::Cut", + "ctrl-insert": "project_panel::Copy", + "ctrl-c": "project_panel::Copy", + "shift-insert": "project_panel::Paste", + "ctrl-v": "project_panel::Paste", + "shift-alt-c": "project_panel::CopyPath", + "ctrl-k ctrl-shift-c": "workspace::CopyRelativePath", + "enter": "project_panel::Rename", + "f2": "project_panel::Rename", + "backspace": ["project_panel::Trash", { "skip_prompt": false }], + "delete": ["project_panel::Trash", { "skip_prompt": false }], + "shift-delete": ["project_panel::Delete", { "skip_prompt": false }], + "ctrl-backspace": ["project_panel::Delete", { "skip_prompt": false }], + "ctrl-delete": ["project_panel::Delete", { "skip_prompt": false }], + "ctrl-alt-r": "project_panel::RevealInFileManager", + "ctrl-shift-enter": "project_panel::OpenWithSystem", + "alt-d": "project_panel::CompareMarkedFiles", + "ctrl-k ctrl-shift-f": "project_panel::NewSearchInDirectory", + "shift-down": "menu::SelectNext", + "shift-up": "menu::SelectPrevious", + "escape": "menu::Cancel" + } + }, + { + "context": "ProjectPanel && not_editing", + "use_key_equivalents": true, + "bindings": { + "space": "project_panel::Open" + } + }, + { + "context": "GitPanel && ChangesList", + "use_key_equivalents": true, + "bindings": { + "up": "menu::SelectPrevious", + "down": "menu::SelectNext", + "enter": "menu::Confirm", + "alt-y": "git::StageFile", + "shift-alt-y": "git::UnstageFile", + "space": "git::ToggleStaged", + "shift-space": "git::StageRange", + "tab": "git_panel::FocusEditor", + "shift-tab": "git_panel::FocusEditor", + "escape": "git_panel::ToggleFocus", + "alt-enter": "menu::SecondaryConfirm", + "delete": ["git::RestoreFile", { "skip_prompt": false }], + "backspace": ["git::RestoreFile", { "skip_prompt": false }], + "shift-delete": ["git::RestoreFile", { "skip_prompt": false }], + "ctrl-backspace": ["git::RestoreFile", { "skip_prompt": false }], + "ctrl-delete": 
["git::RestoreFile", { "skip_prompt": false }] + } + }, + { + "context": "GitPanel && CommitEditor", + "use_key_equivalents": true, + "bindings": { + "escape": "git::Cancel" + } + }, + { + "context": "GitCommit > Editor", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel", + "enter": "editor::Newline", + "ctrl-enter": "git::Commit", + "ctrl-shift-enter": "git::Amend", + "alt-l": "git::GenerateCommitMessage" + } + }, + { + "context": "GitPanel", + "use_key_equivalents": true, + "bindings": { + "ctrl-g ctrl-g": "git::Fetch", + "ctrl-g up": "git::Push", + "ctrl-g down": "git::Pull", + "ctrl-g shift-up": "git::ForcePush", + "ctrl-g d": "git::Diff", + "ctrl-g backspace": "git::RestoreTrackedFiles", + "ctrl-g shift-backspace": "git::TrashUntrackedFiles", + "ctrl-space": "git::StageAll", + "ctrl-shift-space": "git::UnstageAll", + "ctrl-enter": "git::Commit", + "ctrl-shift-enter": "git::Amend" + } + }, + { + "context": "GitDiff > Editor", + "use_key_equivalents": true, + "bindings": { + "ctrl-enter": "git::Commit", + "ctrl-shift-enter": "git::Amend", + "ctrl-space": "git::StageAll", + "ctrl-shift-space": "git::UnstageAll" + } + }, + { + "context": "AskPass > Editor", + "use_key_equivalents": true, + "bindings": { + "enter": "menu::Confirm" + } + }, + { + "context": "CommitEditor > Editor", + "use_key_equivalents": true, + "bindings": { + "escape": "git_panel::FocusChanges", + "tab": "git_panel::FocusChanges", + "shift-tab": "git_panel::FocusChanges", + "enter": "editor::Newline", + "ctrl-enter": "git::Commit", + "ctrl-shift-enter": "git::Amend", + "alt-up": "git_panel::FocusChanges", + "alt-l": "git::GenerateCommitMessage" + } + }, + { + "context": "DebugPanel", + "use_key_equivalents": true, + "bindings": { + "ctrl-t": "debugger::ToggleThreadPicker", + "ctrl-i": "debugger::ToggleSessionPicker", + "shift-alt-escape": "debugger::ToggleExpandItem" + } + }, + { + "context": "VariableList", + "use_key_equivalents": true, + "bindings": { + "left": 
"variable_list::CollapseSelectedEntry", + "right": "variable_list::ExpandSelectedEntry", + "enter": "variable_list::EditVariable", + "ctrl-c": "variable_list::CopyVariableValue", + "ctrl-alt-c": "variable_list::CopyVariableName", + "delete": "variable_list::RemoveWatch", + "backspace": "variable_list::RemoveWatch", + "alt-enter": "variable_list::AddWatch" + } + }, + { + "context": "BreakpointList", + "use_key_equivalents": true, + "bindings": { + "space": "debugger::ToggleEnableBreakpoint", + "backspace": "debugger::UnsetBreakpoint", + "left": "debugger::PreviousBreakpointProperty", + "right": "debugger::NextBreakpointProperty" + } + }, + { + "context": "CollabPanel && not_editing", + "use_key_equivalents": true, + "bindings": { + "ctrl-backspace": "collab_panel::Remove", + "space": "menu::Confirm" + } + }, + { + "context": "CollabPanel", + "use_key_equivalents": true, + "bindings": { + "alt-up": "collab_panel::MoveChannelUp", + "alt-down": "collab_panel::MoveChannelDown" + } + }, + { + "context": "(CollabPanel && editing) > Editor", + "use_key_equivalents": true, + "bindings": { + "space": "collab_panel::InsertSpace" + } + }, + { + "context": "ChannelModal", + "use_key_equivalents": true, + "bindings": { + "tab": "channel_modal::ToggleMode" + } + }, + { + "context": "Picker > Editor", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel", + "up": "menu::SelectPrevious", + "down": "menu::SelectNext", + "tab": "picker::ConfirmCompletion", + "alt-enter": ["picker::ConfirmInput", { "secondary": false }] + } + }, + { + "context": "ChannelModal > Picker > Editor", + "use_key_equivalents": true, + "bindings": { + "tab": "channel_modal::ToggleMode" + } + }, + { + "context": "ToolchainSelector", + "use_key_equivalents": true, + "bindings": { + "ctrl-shift-a": "toolchain::AddToolchain" + } + }, + { + "context": "FileFinder || (FileFinder > Picker > Editor)", + "use_key_equivalents": true, + "bindings": { + "ctrl-p": "file_finder::Toggle", + 
"ctrl-shift-a": "file_finder::ToggleSplitMenu", + "ctrl-shift-i": "file_finder::ToggleFilterMenu" + } + }, + { + "context": "FileFinder || (FileFinder > Picker > Editor) || (FileFinder > Picker > menu)", + "use_key_equivalents": true, + "bindings": { + "ctrl-shift-p": "file_finder::SelectPrevious", + "ctrl-j": "pane::SplitDown", + "ctrl-k": "pane::SplitUp", + "ctrl-h": "pane::SplitLeft", + "ctrl-l": "pane::SplitRight" + } + }, + { + "context": "TabSwitcher", + "use_key_equivalents": true, + "bindings": { + "ctrl-shift-tab": "menu::SelectPrevious", + "ctrl-up": "menu::SelectPrevious", + "ctrl-down": "menu::SelectNext", + "ctrl-backspace": "tab_switcher::CloseSelectedItem" + } + }, + { + "context": "Terminal", + "use_key_equivalents": true, + "bindings": { + "ctrl-alt-space": "terminal::ShowCharacterPalette", + "ctrl-insert": "terminal::Copy", + "ctrl-shift-c": "terminal::Copy", + "shift-insert": "terminal::Paste", + "ctrl-shift-v": "terminal::Paste", + "ctrl-enter": "assistant::InlineAssist", + "alt-b": ["terminal::SendText", "\u001bb"], + "alt-f": ["terminal::SendText", "\u001bf"], + "alt-.": ["terminal::SendText", "\u001b."], + "ctrl-delete": ["terminal::SendText", "\u001bd"], + // Overrides for conflicting keybindings + "ctrl-b": ["terminal::SendKeystroke", "ctrl-b"], + "ctrl-c": ["terminal::SendKeystroke", "ctrl-c"], + "ctrl-e": ["terminal::SendKeystroke", "ctrl-e"], + "ctrl-o": ["terminal::SendKeystroke", "ctrl-o"], + "ctrl-w": ["terminal::SendKeystroke", "ctrl-w"], + "ctrl-backspace": ["terminal::SendKeystroke", "ctrl-w"], + "ctrl-shift-a": "editor::SelectAll", + "ctrl-shift-f": "buffer_search::Deploy", + "ctrl-shift-l": "terminal::Clear", + "ctrl-shift-w": "pane::CloseActiveItem", + "up": ["terminal::SendKeystroke", "up"], + "pageup": ["terminal::SendKeystroke", "pageup"], + "down": ["terminal::SendKeystroke", "down"], + "pagedown": ["terminal::SendKeystroke", "pagedown"], + "escape": ["terminal::SendKeystroke", "escape"], + "enter": 
["terminal::SendKeystroke", "enter"], + "shift-pageup": "terminal::ScrollPageUp", + "shift-pagedown": "terminal::ScrollPageDown", + "shift-up": "terminal::ScrollLineUp", + "shift-down": "terminal::ScrollLineDown", + "shift-home": "terminal::ScrollToTop", + "shift-end": "terminal::ScrollToBottom", + "ctrl-shift-space": "terminal::ToggleViMode", + "ctrl-shift-r": "terminal::RerunTask", + "ctrl-alt-r": "terminal::RerunTask", + "alt-t": "terminal::RerunTask" + } + }, + { + "context": "ZedPredictModal", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel" + } + }, + { + "context": "ConfigureContextServerModal > Editor", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel", + "enter": "editor::Newline", + "ctrl-enter": "menu::Confirm" + } + }, + { + "context": "OnboardingAiConfigurationModal", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel" + } + }, + { + "context": "Diagnostics", + "use_key_equivalents": true, + "bindings": { + "ctrl-r": "diagnostics::ToggleDiagnosticsRefresh" + } + }, + { + "context": "DebugConsole > Editor", + "use_key_equivalents": true, + "bindings": { + "enter": "menu::Confirm", + "alt-enter": "console::WatchExpression" + } + }, + { + "context": "RunModal", + "use_key_equivalents": true, + "bindings": { + "ctrl-tab": "pane::ActivateNextItem", + "ctrl-shift-tab": "pane::ActivatePreviousItem" + } + }, + { + "context": "MarkdownPreview", + "use_key_equivalents": true, + "bindings": { + "pageup": "markdown::MovePageUp", + "pagedown": "markdown::MovePageDown" + } + }, + { + "context": "KeymapEditor", + "use_key_equivalents": true, + "bindings": { + "ctrl-f": "search::FocusSearch", + "alt-f": "keymap_editor::ToggleKeystrokeSearch", + "alt-c": "keymap_editor::ToggleConflictFilter", + "enter": "keymap_editor::EditBinding", + "alt-enter": "keymap_editor::CreateBinding", + "ctrl-c": "keymap_editor::CopyAction", + "ctrl-shift-c": "keymap_editor::CopyContext", + "ctrl-t": 
"keymap_editor::ShowMatchingKeybinds" + } + }, + { + "context": "KeystrokeInput", + "use_key_equivalents": true, + "bindings": { + "enter": "keystroke_input::StartRecording", + "escape escape escape": "keystroke_input::StopRecording", + "delete": "keystroke_input::ClearKeystrokes" + } + }, + { + "context": "KeybindEditorModal", + "use_key_equivalents": true, + "bindings": { + "ctrl-enter": "menu::Confirm", + "escape": "menu::Cancel" + } + }, + { + "context": "KeybindEditorModal > Editor", + "use_key_equivalents": true, + "bindings": { + "up": "menu::SelectPrevious", + "down": "menu::SelectNext" + } + }, + { + "context": "Onboarding", + "use_key_equivalents": true, + "bindings": { + "ctrl-1": "onboarding::ActivateBasicsPage", + "ctrl-2": "onboarding::ActivateEditingPage", + "ctrl-3": "onboarding::ActivateAISetupPage", + "ctrl-escape": "onboarding::Finish", + "alt-tab": "onboarding::SignIn", + "shift-alt-a": "onboarding::OpenAccount" + } + } +] diff --git a/assets/keymaps/linux/cursor.json b/assets/keymaps/linux/cursor.json index 1c381b0cf05531e7fd5743d71be1b4d662bb4c0d..2e27158e1167f0840cadfb0d86dc06614f6076c6 100644 --- a/assets/keymaps/linux/cursor.json +++ b/assets/keymaps/linux/cursor.json @@ -17,8 +17,8 @@ "bindings": { "ctrl-i": "agent::ToggleFocus", "ctrl-shift-i": "agent::ToggleFocus", - "ctrl-shift-l": "assistant::QuoteSelection", // In cursor uses "Ask" mode - "ctrl-l": "assistant::QuoteSelection", // In cursor uses "Agent" mode + "ctrl-shift-l": "agent::QuoteSelection", // In cursor uses "Ask" mode + "ctrl-l": "agent::QuoteSelection", // In cursor uses "Agent" mode "ctrl-k": "assistant::InlineAssist", "ctrl-shift-k": "assistant::InsertIntoEditor" } diff --git a/assets/keymaps/linux/emacs.json b/assets/keymaps/linux/emacs.json index 0ff3796f03d85affdae88d009e88e73516ba385a..0f936ba2f968abe0759e4bb294271a5e5f501848 100755 --- a/assets/keymaps/linux/emacs.json +++ b/assets/keymaps/linux/emacs.json @@ -38,10 +38,11 @@ "alt-;": ["editor::ToggleComments", { 
"advance_downwards": false }], "ctrl-x ctrl-;": "editor::ToggleComments", "alt-.": "editor::GoToDefinition", // xref-find-definitions + "alt-?": "editor::FindAllReferences", // xref-find-references "alt-,": "pane::GoBack", // xref-pop-marker-stack "ctrl-x h": "editor::SelectAll", // mark-whole-buffer "ctrl-d": "editor::Delete", // delete-char - "alt-d": "editor::DeleteToNextWordEnd", // kill-word + "alt-d": ["editor::DeleteToNextWordEnd", { "ignore_newlines": false, "ignore_brackets": false }], // kill-word "ctrl-k": "editor::KillRingCut", // kill-line "ctrl-w": "editor::Cut", // kill-region "alt-w": "editor::Copy", // kill-ring-save diff --git a/assets/keymaps/linux/jetbrains.json b/assets/keymaps/linux/jetbrains.json index 3df1243feda88680a4ce03cd0b25ab9ea9a36edd..59a182a968a849edb3359927e7647f611bcd44da 100644 --- a/assets/keymaps/linux/jetbrains.json +++ b/assets/keymaps/linux/jetbrains.json @@ -125,7 +125,7 @@ { "context": "Workspace || Editor", "bindings": { - "alt-f12": "terminal_panel::ToggleFocus", + "alt-f12": "terminal_panel::Toggle", "ctrl-shift-k": "git::Push" } }, diff --git a/assets/keymaps/linux/sublime_text.json b/assets/keymaps/linux/sublime_text.json index ece9d69dd102c019072678373e9328f302d4cb07..f526db45ff29e0828ce58df6ca9816bd71a4cbe5 100644 --- a/assets/keymaps/linux/sublime_text.json +++ b/assets/keymaps/linux/sublime_text.json @@ -50,8 +50,8 @@ "ctrl-k ctrl-u": "editor::ConvertToUpperCase", "ctrl-k ctrl-l": "editor::ConvertToLowerCase", "shift-alt-m": "markdown::OpenPreviewToTheSide", - "ctrl-backspace": "editor::DeleteToPreviousWordStart", - "ctrl-delete": "editor::DeleteToNextWordEnd", + "ctrl-backspace": ["editor::DeleteToPreviousWordStart", { "ignore_newlines": false, "ignore_brackets": false }], + "ctrl-delete": ["editor::DeleteToNextWordEnd", { "ignore_newlines": false, "ignore_brackets": false }], "alt-right": "editor::MoveToNextSubwordEnd", "alt-left": "editor::MoveToPreviousSubwordStart", "alt-shift-right": 
"editor::SelectToNextSubwordEnd", diff --git a/assets/keymaps/macos/cursor.json b/assets/keymaps/macos/cursor.json index fdf9c437cf395c074e42ae9c9dc53c1aa6ff66c2..1d723bd75bb788aa1ea63335f9fa555cb50d2df0 100644 --- a/assets/keymaps/macos/cursor.json +++ b/assets/keymaps/macos/cursor.json @@ -17,8 +17,8 @@ "bindings": { "cmd-i": "agent::ToggleFocus", "cmd-shift-i": "agent::ToggleFocus", - "cmd-shift-l": "assistant::QuoteSelection", // In cursor uses "Ask" mode - "cmd-l": "assistant::QuoteSelection", // In cursor uses "Agent" mode + "cmd-shift-l": "agent::QuoteSelection", // In cursor uses "Ask" mode + "cmd-l": "agent::QuoteSelection", // In cursor uses "Agent" mode "cmd-k": "assistant::InlineAssist", "cmd-shift-k": "assistant::InsertIntoEditor" } diff --git a/assets/keymaps/macos/emacs.json b/assets/keymaps/macos/emacs.json index 0ff3796f03d85affdae88d009e88e73516ba385a..0f936ba2f968abe0759e4bb294271a5e5f501848 100755 --- a/assets/keymaps/macos/emacs.json +++ b/assets/keymaps/macos/emacs.json @@ -38,10 +38,11 @@ "alt-;": ["editor::ToggleComments", { "advance_downwards": false }], "ctrl-x ctrl-;": "editor::ToggleComments", "alt-.": "editor::GoToDefinition", // xref-find-definitions + "alt-?": "editor::FindAllReferences", // xref-find-references "alt-,": "pane::GoBack", // xref-pop-marker-stack "ctrl-x h": "editor::SelectAll", // mark-whole-buffer "ctrl-d": "editor::Delete", // delete-char - "alt-d": "editor::DeleteToNextWordEnd", // kill-word + "alt-d": ["editor::DeleteToNextWordEnd", { "ignore_newlines": false, "ignore_brackets": false }], // kill-word "ctrl-k": "editor::KillRingCut", // kill-line "ctrl-w": "editor::Cut", // kill-region "alt-w": "editor::Copy", // kill-ring-save diff --git a/assets/keymaps/macos/jetbrains.json b/assets/keymaps/macos/jetbrains.json index 66962811f48a429f2f5d036241c64d6549f60334..2c757c3a30a08eb55e8344945ab66baf91ce0c6b 100644 --- a/assets/keymaps/macos/jetbrains.json +++ b/assets/keymaps/macos/jetbrains.json @@ -127,7 +127,7 @@ { 
"context": "Workspace || Editor", "bindings": { - "alt-f12": "terminal_panel::ToggleFocus", + "alt-f12": "terminal_panel::Toggle", "cmd-shift-k": "git::Push" } }, diff --git a/assets/keymaps/macos/sublime_text.json b/assets/keymaps/macos/sublime_text.json index 9fa528c75fa75061c34d767c3e9f9082c9eb2a81..a1e61bf8859e2e4ea227ed3dbe22ec29eb35a149 100644 --- a/assets/keymaps/macos/sublime_text.json +++ b/assets/keymaps/macos/sublime_text.json @@ -52,8 +52,8 @@ "cmd-k cmd-l": "editor::ConvertToLowerCase", "cmd-shift-j": "editor::JoinLines", "shift-alt-m": "markdown::OpenPreviewToTheSide", - "ctrl-backspace": "editor::DeleteToPreviousWordStart", - "ctrl-delete": "editor::DeleteToNextWordEnd", + "ctrl-backspace": ["editor::DeleteToPreviousWordStart", { "ignore_newlines": false, "ignore_brackets": false }], + "ctrl-delete": ["editor::DeleteToNextWordEnd", { "ignore_newlines": false, "ignore_brackets": false }], "ctrl-right": "editor::MoveToNextSubwordEnd", "ctrl-left": "editor::MoveToPreviousSubwordStart", "ctrl-shift-right": "editor::SelectToNextSubwordEnd", diff --git a/assets/keymaps/macos/textmate.json b/assets/keymaps/macos/textmate.json index 0bd8873b1749d2423d97df480b1aadeb28fe9bab..f91f39b7f5c079f81b5fcf8e28e2092a33ff1aa4 100644 --- a/assets/keymaps/macos/textmate.json +++ b/assets/keymaps/macos/textmate.json @@ -21,10 +21,10 @@ { "context": "Editor", "bindings": { - "alt-backspace": "editor::DeleteToPreviousWordStart", - "alt-shift-backspace": "editor::DeleteToNextWordEnd", - "alt-delete": "editor::DeleteToNextWordEnd", - "alt-shift-delete": "editor::DeleteToNextWordEnd", + "alt-backspace": ["editor::DeleteToPreviousWordStart", { "ignore_newlines": false, "ignore_brackets": false }], + "alt-shift-backspace": ["editor::DeleteToNextWordEnd", { "ignore_newlines": false, "ignore_brackets": false }], + "alt-delete": ["editor::DeleteToNextWordEnd", { "ignore_newlines": false, "ignore_brackets": false }], + "alt-shift-delete": ["editor::DeleteToNextWordEnd", { 
"ignore_newlines": false, "ignore_brackets": false }], "ctrl-backspace": "editor::DeleteToPreviousSubwordStart", "ctrl-delete": "editor::DeleteToNextSubwordEnd", "alt-left": ["editor::MoveToPreviousWordStart", { "stop_at_soft_wraps": true }], diff --git a/assets/keymaps/vim.json b/assets/keymaps/vim.json index 98f9cafc40e69f9eb7bcc248e02176f85e5d8838..78d4b3e7072c1ee8db57fdd31ef8afa1f3375b15 100644 --- a/assets/keymaps/vim.json +++ b/assets/keymaps/vim.json @@ -32,32 +32,6 @@ "(": "vim::SentenceBackward", ")": "vim::SentenceForward", "|": "vim::GoToColumn", - "] ]": "vim::NextSectionStart", - "] [": "vim::NextSectionEnd", - "[ [": "vim::PreviousSectionStart", - "[ ]": "vim::PreviousSectionEnd", - "] m": "vim::NextMethodStart", - "] shift-m": "vim::NextMethodEnd", - "[ m": "vim::PreviousMethodStart", - "[ shift-m": "vim::PreviousMethodEnd", - "[ *": "vim::PreviousComment", - "[ /": "vim::PreviousComment", - "] *": "vim::NextComment", - "] /": "vim::NextComment", - "[ -": "vim::PreviousLesserIndent", - "[ +": "vim::PreviousGreaterIndent", - "[ =": "vim::PreviousSameIndent", - "] -": "vim::NextLesserIndent", - "] +": "vim::NextGreaterIndent", - "] =": "vim::NextSameIndent", - "] b": "pane::ActivateNextItem", - "[ b": "pane::ActivatePreviousItem", - "] shift-b": "pane::ActivateLastItem", - "[ shift-b": ["pane::ActivateItem", 0], - "] space": "vim::InsertEmptyLineBelow", - "[ space": "vim::InsertEmptyLineAbove", - "[ e": "editor::MoveLineUp", - "] e": "editor::MoveLineDown", // Word motions "w": "vim::NextWordStart", @@ -81,10 +55,6 @@ "n": "vim::MoveToNextMatch", "shift-n": "vim::MoveToPreviousMatch", "%": "vim::Matching", - "] }": ["vim::UnmatchedForward", { "char": "}" }], - "[ {": ["vim::UnmatchedBackward", { "char": "{" }], - "] )": ["vim::UnmatchedForward", { "char": ")" }], - "[ (": ["vim::UnmatchedBackward", { "char": "(" }], "f": ["vim::PushFindForward", { "before": false, "multiline": false }], "t": ["vim::PushFindForward", { "before": true, "multiline": false 
}], "shift-f": ["vim::PushFindBackward", { "after": false, "multiline": false }], @@ -217,6 +187,46 @@ ".": "vim::Repeat" } }, + { + "context": "vim_mode == normal || vim_mode == visual || vim_mode == operator", + "bindings": { + "] ]": "vim::NextSectionStart", + "] [": "vim::NextSectionEnd", + "[ [": "vim::PreviousSectionStart", + "[ ]": "vim::PreviousSectionEnd", + "] m": "vim::NextMethodStart", + "] shift-m": "vim::NextMethodEnd", + "[ m": "vim::PreviousMethodStart", + "[ shift-m": "vim::PreviousMethodEnd", + "[ *": "vim::PreviousComment", + "[ /": "vim::PreviousComment", + "] *": "vim::NextComment", + "] /": "vim::NextComment", + "[ -": "vim::PreviousLesserIndent", + "[ +": "vim::PreviousGreaterIndent", + "[ =": "vim::PreviousSameIndent", + "] -": "vim::NextLesserIndent", + "] +": "vim::NextGreaterIndent", + "] =": "vim::NextSameIndent", + "] b": "pane::ActivateNextItem", + "[ b": "pane::ActivatePreviousItem", + "] shift-b": "pane::ActivateLastItem", + "[ shift-b": ["pane::ActivateItem", 0], + "] space": "vim::InsertEmptyLineBelow", + "[ space": "vim::InsertEmptyLineAbove", + "[ e": "editor::MoveLineUp", + "] e": "editor::MoveLineDown", + "[ f": "workspace::FollowNextCollaborator", + "] f": "workspace::FollowNextCollaborator", + "] }": ["vim::UnmatchedForward", { "char": "}" }], + "[ {": ["vim::UnmatchedBackward", { "char": "{" }], + "] )": ["vim::UnmatchedForward", { "char": ")" }], + "[ (": ["vim::UnmatchedBackward", { "char": "(" }], + // tree-sitter related commands + "[ x": "vim::SelectLargerSyntaxNode", + "] x": "vim::SelectSmallerSyntaxNode" + } + }, { "context": "vim_mode == normal", "bindings": { @@ -247,9 +257,6 @@ "g w": "vim::PushRewrap", "g q": "vim::PushRewrap", "insert": "vim::InsertBefore", - // tree-sitter related commands - "[ x": "vim::SelectLargerSyntaxNode", - "] x": "vim::SelectSmallerSyntaxNode", "] d": "editor::GoToDiagnostic", "[ d": "editor::GoToPreviousDiagnostic", "] c": "editor::GoToHunk", @@ -315,10 +322,7 @@ "g w": "vim::Rewrap", 
"g ?": "vim::ConvertToRot13", // "g ?": "vim::ConvertToRot47", - "\"": "vim::PushRegister", - // tree-sitter related commands - "[ x": "editor::SelectLargerSyntaxNode", - "] x": "editor::SelectSmallerSyntaxNode" + "\"": "vim::PushRegister" } }, { @@ -335,7 +339,7 @@ "ctrl-x ctrl-z": "editor::Cancel", "ctrl-x ctrl-e": "vim::LineDown", "ctrl-x ctrl-y": "vim::LineUp", - "ctrl-w": "editor::DeleteToPreviousWordStart", + "ctrl-w": ["editor::DeleteToPreviousWordStart", { "ignore_newlines": false, "ignore_brackets": false }], "ctrl-u": "editor::DeleteToBeginningOfLine", "ctrl-t": "vim::Indent", "ctrl-d": "vim::Outdent", @@ -352,6 +356,15 @@ "ctrl-s": "editor::ShowSignatureHelp" } }, + { + "context": "showing_completions", + "bindings": { + "ctrl-d": "vim::ScrollDown", + "ctrl-u": "vim::ScrollUp", + "ctrl-e": "vim::LineDown", + "ctrl-y": "vim::LineUp" + } + }, { "context": "(vim_mode == normal || vim_mode == helix_normal) && !menu", "bindings": { @@ -386,11 +399,14 @@ "ctrl-[": "editor::Cancel", ";": "vim::HelixCollapseSelection", ":": "command_palette::Toggle", + "m": "vim::PushHelixMatch", + "]": ["vim::PushHelixNext", { "around": true }], + "[": ["vim::PushHelixPrevious", { "around": true }], "left": "vim::WrappingLeft", "right": "vim::WrappingRight", "h": "vim::WrappingLeft", "l": "vim::WrappingRight", - "y": "editor::Copy", + "y": "vim::HelixYank", "alt-;": "vim::OtherEnd", "ctrl-r": "vim::Redo", "f": ["vim::PushFindForward", { "before": false, "multiline": true }], @@ -407,13 +423,7 @@ "g w": "vim::PushRewrap", "insert": "vim::InsertBefore", "alt-.": "vim::RepeatFind", - // tree-sitter related commands - "[ x": "editor::SelectLargerSyntaxNode", - "] x": "editor::SelectSmallerSyntaxNode", - "] d": "editor::GoToDiagnostic", - "[ d": "editor::GoToPreviousDiagnostic", - "] c": "editor::GoToHunk", - "[ c": "editor::GoToPreviousHunk", + "alt-s": ["editor::SplitSelectionIntoLines", { "keep_selections": true }], // Goto mode "g n": "pane::ActivateNextItem", "g p": 
"pane::ActivatePreviousItem", @@ -425,12 +435,14 @@ "g h": "vim::StartOfLine", "g s": "vim::FirstNonWhitespace", // "g s" default behavior is "space s" "g e": "vim::EndOfDocument", + "g .": "vim::HelixGotoLastModification", // go to last modification "g r": "editor::FindAllReferences", // zed specific "g t": "vim::WindowTop", "g c": "vim::WindowMiddle", "g b": "vim::WindowBottom", - "x": "editor::SelectLine", + "shift-r": "editor::Paste", + "x": "vim::HelixSelectLine", "shift-x": "editor::SelectLine", "%": "editor::SelectAll", // Window mode @@ -455,9 +467,6 @@ "space c": "editor::ToggleComments", "space y": "editor::Copy", "space p": "editor::Paste", - // Match mode - "m m": "vim::Matching", - "m i w": ["workspace::SendKeystrokes", "v i w"], "shift-u": "editor::Redo", "ctrl-c": "editor::ToggleComments", "d": "vim::HelixDelete", @@ -526,7 +535,7 @@ } }, { - "context": "vim_operator == a || vim_operator == i || vim_operator == cs", + "context": "vim_operator == a || vim_operator == i || vim_operator == cs || vim_operator == helix_next || vim_operator == helix_previous", "bindings": { "w": "vim::Word", "shift-w": ["vim::Word", { "ignore_punctuation": true }], @@ -563,6 +572,48 @@ "e": "vim::EntireFile" } }, + { + "context": "vim_operator == helix_m", + "bindings": { + "m": "vim::Matching" + } + }, + { + "context": "vim_operator == helix_next", + "bindings": { + "z": "vim::NextSectionStart", + "shift-z": "vim::NextSectionEnd", + "*": "vim::NextComment", + "/": "vim::NextComment", + "-": "vim::NextLesserIndent", + "+": "vim::NextGreaterIndent", + "=": "vim::NextSameIndent", + "b": "pane::ActivateNextItem", + "shift-b": "pane::ActivateLastItem", + "x": "editor::SelectSmallerSyntaxNode", + "d": "editor::GoToDiagnostic", + "c": "editor::GoToHunk", + "space": "vim::InsertEmptyLineBelow" + } + }, + { + "context": "vim_operator == helix_previous", + "bindings": { + "z": "vim::PreviousSectionStart", + "shift-z": "vim::PreviousSectionEnd", + "*": "vim::PreviousComment", + "/": 
"vim::PreviousComment", + "-": "vim::PreviousLesserIndent", + "+": "vim::PreviousGreaterIndent", + "=": "vim::PreviousSameIndent", + "b": "pane::ActivatePreviousItem", + "shift-b": ["pane::ActivateItem", 0], + "x": "editor::SelectLargerSyntaxNode", + "d": "editor::GoToPreviousDiagnostic", + "c": "editor::GoToPreviousHunk", + "space": "vim::InsertEmptyLineAbove" + } + }, { "context": "vim_operator == c", "bindings": { @@ -809,14 +860,14 @@ "j": "menu::SelectNext", "k": "menu::SelectPrevious", "l": "project_panel::ExpandSelectedEntry", - "o": "project_panel::OpenPermanent", "shift-d": "project_panel::Delete", "shift-r": "project_panel::Rename", "t": "project_panel::OpenPermanent", - "v": "project_panel::OpenPermanent", + "v": "project_panel::OpenSplitVertical", + "o": "project_panel::OpenSplitHorizontal", "p": "project_panel::Open", "x": "project_panel::RevealInFileManager", - "s": "project_panel::OpenWithSystem", + "s": "workspace::OpenWithSystem", "z d": "project_panel::CompareMarkedFiles", "] c": "project_panel::SelectNextGitEntry", "[ c": "project_panel::SelectPrevGitEntry", diff --git a/assets/prompts/assistant_system_prompt.hbs b/assets/prompts/assistant_system_prompt.hbs index b4545f5a7449bf8c562ea15d722ae8199c42e97a..f47c1ffa908b861eb81d37642a7634616c92a0d9 100644 --- a/assets/prompts/assistant_system_prompt.hbs +++ b/assets/prompts/assistant_system_prompt.hbs @@ -172,7 +172,7 @@ The user has specified the following rules that should be applied: Rules title: {{title}} {{/if}} `````` -{{contents}}} +{{contents}} `````` {{/each}} {{/if}} diff --git a/assets/settings/default.json b/assets/settings/default.json index 2c3bf6930d04a2668d65c20a191de17033de4aac..94df7a62ccf503081f62848c57e421f939cc64ec 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -1,4 +1,5 @@ { + "project_name": null, // The name of the Zed theme to use for the UI. 
// // `mode` is one of: @@ -71,8 +72,8 @@ "ui_font_weight": 400, // The default font size for text in the UI "ui_font_size": 16, - // The default font size for text in the agent panel - "agent_font_size": 16, + // The default font size for text in the agent panel. Falls back to the UI font size if unset. + "agent_font_size": null, // How much to fade out unused code. "unnecessary_code_fade": 0.3, // Active pane styling settings. @@ -162,6 +163,12 @@ // 2. Always quit the application // "on_last_window_closed": "quit_app", "on_last_window_closed": "platform_default", + // Whether to show padding for zoomed panels. + // When enabled, zoomed center panels (e.g. code editor) will have padding all around, + // while zoomed bottom/left/right panels will have padding to the top/right/left (respectively). + // + // Default: true + "zoomed_padding": true, // Whether to use the system provided dialogs for Open and Save As. // When set to false, Zed will use the built-in keyboard-first pickers. "use_system_path_prompts": true, @@ -182,8 +189,8 @@ // 4. A box drawn around the following character // "hollow" // - // Default: not set, defaults to "bar" - "cursor_shape": null, + // Default: "bar" + "cursor_shape": "bar", // Determines when the mouse cursor should be hidden in an editor or input box. // // 1. Never hide the mouse cursor: @@ -217,9 +224,25 @@ "current_line_highlight": "all", // Whether to highlight all occurrences of the selected text in an editor. "selection_highlight": true, + // Whether the text selection should have rounded corners. + "rounded_selection": true, // The debounce delay before querying highlights from the language // server based on the current cursor location. "lsp_highlight_debounce": 75, + // The minimum APCA perceptual contrast between foreground and background colors. + // APCA (Accessible Perceptual Contrast Algorithm) is more accurate than WCAG 2.x, + // especially for dark mode. Values range from 0 to 106. 
+ // + // Based on APCA Readability Criterion (ARC) Bronze Simple Mode: + // https://readtech.org/ARC/tests/bronze-simple-mode/ + // - 0: No contrast adjustment + // - 45: Minimum for large fluent text (36px+) + // - 60: Minimum for other content text + // - 75: Minimum for body text + // - 90: Preferred for body text + // + // This only affects text drawn over highlight backgrounds in the editor. + "minimum_contrast_for_highlights": 45, // Whether to pop the completions menu while typing in an editor without // explicitly requesting it. "show_completions_on_input": true, @@ -260,8 +283,8 @@ // - "warning" // - "info" // - "hint" - // - null — allow all diagnostics (default) - "diagnostics_max_severity": null, + // - "all" — allow all diagnostics (default) + "diagnostics_max_severity": "all", // Whether to show wrap guides (vertical rulers) in the editor. // Setting this to true will show a guide at the 'preferred_line_length' value // if 'soft_wrap' is set to 'preferred_line_length', and will show any @@ -273,6 +296,8 @@ "redact_private_values": false, // The default number of lines to expand excerpts in the multibuffer by. "expand_excerpt_lines": 5, + // The default number of context lines shown in multibuffer excerpts. + "excerpt_context_lines": 2, // Globs to match against file paths to determine if a file is private. "private_files": ["**/.env*", "**/*.pem", "**/*.key", "**/*.cert", "**/*.crt", "**/secrets.yml"], // Whether to use additional LSP queries to format (and amend) the code after @@ -286,6 +311,8 @@ // bracket, brace, single or double quote characters. // For example, when you select text and type (, Zed will surround the text with (). "use_auto_surround": true, + /// Whether indentation should be adjusted based on the context whilst typing. + "auto_indent": true, // Whether indentation of pasted content should be adjusted based on the context. "auto_indent_on_paste": true, // Controls how the editor handles the autoclosed characters. 
@@ -335,6 +362,11 @@ // - It is adjacent to an edge (start or end) // - It is adjacent to a whitespace (left or right) "show_whitespaces": "selection", + // Visible characters used to render whitespace when show_whitespaces is enabled. + "whitespace_map": { + "space": "•", + "tab": "→" + }, // Settings related to calls in Zed "calls": { // Join calls with the microphone live by default @@ -355,6 +387,8 @@ // Whether to show code action buttons in the editor toolbar. "code_actions": false }, + // Whether to allow windows to tab together based on the user’s tabbing preference (macOS only). + "use_system_window_tabs": false, // Titlebar related settings "title_bar": { // Whether to show the branch icon beside branch switcher in the titlebar. @@ -645,6 +679,8 @@ // "never" "show": "always" }, + // Whether to enable drag-and-drop operations in the project panel. + "drag_and_drop": true, // Whether to hide the root entry when only one folder is open in the window. "hide_root": false }, @@ -710,20 +746,10 @@ // Default width of the collaboration panel. "default_width": 240 }, - "chat_panel": { - // When to show the chat panel button in the status bar. - // Can be 'never', 'always', or 'when_in_call', - // or a boolean (interpreted as 'never'/'always'). - "button": "when_in_call", - // Where to the chat panel. Can be 'left' or 'right'. - "dock": "right", - // Default width of the chat panel. - "default_width": 240 - }, "git_panel": { // Whether to show the git panel button in the status bar. "button": true, - // Where to show the git panel. Can be 'left' or 'right'. + // Where to dock the git panel. Can be 'left' or 'right'. "dock": "left", // Default width of the git panel. "default_width": 360, @@ -808,6 +834,9 @@ // } ], // When enabled, the agent can run potentially destructive actions without asking for your confirmation. + // + // Note: This setting has no effect on external agents that support permission modes, such as Claude Code. 
+ // You can set `agent_servers.claude.default_mode` to `bypassPermissions` to skip all permission requests. "always_allow_tool_actions": false, // When enabled, the agent will stream edits. "stream_edits": false, @@ -887,11 +916,6 @@ }, // The settings for slash commands. "slash_commands": { - // Settings for the `/docs` slash command. - "docs": { - // Whether `/docs` is enabled. - "enabled": false - }, // Settings for the `/project` slash command. "project": { // Whether `/project` is enabled. @@ -937,7 +961,7 @@ // Show git status colors in the editor tabs. "git_status": false, // Position of the close button on the editor tabs. - // One of: ["right", "left", "hidden"] + // One of: ["right", "left"] "close_position": "right", // Whether to show the file icon for a tab. "file_icons": false, @@ -1136,11 +1160,6 @@ // The minimum severity of the diagnostics to show inline. // Inherits editor's diagnostics' max severity settings when `null`. "max_severity": null - }, - "cargo": { - // When enabled, Zed disables rust-analyzer's check on save and starts to query - // Cargo diagnostics separately. - "fetch_cargo_diagnostics": false } }, // Files or globs of files that will be excluded by Zed entirely. They will be skipped during file @@ -1185,6 +1204,10 @@ // The minimum column number to show the inline blame information at "min_column": 0 }, + // Control which information is shown in the branch picker. + "branch_picker": { + "show_author_name": true + }, // How git hunks are displayed visually in the editor. // This setting can take two values: // @@ -1256,7 +1279,9 @@ // Status bar-related settings. "status_bar": { // Whether to show the active language button in the status bar. - "active_language_button": true + "active_language_button": true, + // Whether to show the cursor position button in the status bar. 
+ "cursor_position_button": true }, // Settings specific to the terminal "terminal": { @@ -1504,6 +1529,11 @@ // // Default: fallback "words": "fallback", + // Minimum number of characters required to automatically trigger word-based completions. + // Before that value, it's still possible to trigger the words-based completion manually with the corresponding editor command. + // + // Default: 3 + "words_min_length": 3, // Whether to fetch LSP completions or not. // // Default: true @@ -1576,7 +1606,7 @@ "ensure_final_newline_on_save": false }, "Elixir": { - "language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."] + "language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."] }, "Elm": { "tab_size": 4 @@ -1601,7 +1631,7 @@ } }, "HEEX": { - "language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."] + "language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."] }, "HTML": { "prettier": { @@ -1630,6 +1660,9 @@ "allowed": true } }, + "Kotlin": { + "language_servers": ["kotlin-language-server", "!kotlin-lsp", "..."] + }, "LaTeX": { "formatter": "language_server", "language_servers": ["texlab", "..."], @@ -1643,9 +1676,6 @@ "use_on_type_format": false, "allow_rewrap": "anywhere", "soft_wrap": "editor_width", - "completions": { - "words": "disabled" - }, "prettier": { "allowed": true } @@ -1659,9 +1689,6 @@ } }, "Plain Text": { - "completions": { - "words": "disabled" - }, "allow_rewrap": "anywhere" }, "Python": { @@ -1752,7 +1779,7 @@ "api_url": "http://localhost:1234/api/v0" }, "deepseek": { - "api_url": "https://api.deepseek.com" + "api_url": "https://api.deepseek.com/v1" }, "mistral": { "api_url": "https://api.mistral.ai/v1" @@ -1900,7 +1927,10 @@ "debugger": { "stepping_granularity": "line", "save_breakpoints": true, + "timeout": 2000, "dock": "bottom", + "log_dap_communications": true, + "format_dap_log_messages": true, "button": true }, // Configures any number of settings profiles that are temporarily applied on diff 
--git a/assets/settings/initial_tasks.json b/assets/settings/initial_tasks.json index a79c550671f85d7b107db5e85883caa28fe41411..5cead67b6d5bb89e878e3bfb8d250dcbbd2ce447 100644 --- a/assets/settings/initial_tasks.json +++ b/assets/settings/initial_tasks.json @@ -43,8 +43,8 @@ // "args": ["--login"] // } // } - "shell": "system", + "shell": "system" // Represents the tags for inline runnable indicators, or spawning multiple tasks at once. - "tags": [] + // "tags": [] } ] diff --git a/assets/themes/ayu/ayu.json b/assets/themes/ayu/ayu.json index f9f8720729008efb9a17cf45bd23ce51df7d3657..f71048caafba7156b53fd8637ca35c715ad300f2 100644 --- a/assets/themes/ayu/ayu.json +++ b/assets/themes/ayu/ayu.json @@ -93,7 +93,7 @@ "terminal.ansi.bright_cyan": "#4c806fff", "terminal.ansi.dim_cyan": "#cbf2e4ff", "terminal.ansi.white": "#bfbdb6ff", - "terminal.ansi.bright_white": "#bfbdb6ff", + "terminal.ansi.bright_white": "#fafafaff", "terminal.ansi.dim_white": "#787876ff", "link_text.hover": "#5ac1feff", "conflict": "#feb454ff", @@ -316,6 +316,11 @@ "font_style": null, "font_weight": null }, + "punctuation.markup": { + "color": "#a6a5a0ff", + "font_style": null, + "font_weight": null + }, "punctuation.special": { "color": "#d2a6ffff", "font_style": null, @@ -479,7 +484,7 @@ "terminal.ansi.bright_cyan": "#ace0cbff", "terminal.ansi.dim_cyan": "#2a5f4aff", "terminal.ansi.white": "#fcfcfcff", - "terminal.ansi.bright_white": "#fcfcfcff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#bcbec0ff", "link_text.hover": "#3b9ee5ff", "conflict": "#f1ad49ff", @@ -702,6 +707,11 @@ "font_style": null, "font_weight": null }, + "punctuation.markup": { + "color": "#73777bff", + "font_style": null, + "font_weight": null + }, "punctuation.special": { "color": "#a37accff", "font_style": null, @@ -865,7 +875,7 @@ "terminal.ansi.bright_cyan": "#4c806fff", "terminal.ansi.dim_cyan": "#cbf2e4ff", "terminal.ansi.white": "#cccac2ff", - "terminal.ansi.bright_white": "#cccac2ff", + 
"terminal.ansi.bright_white": "#fafafaff", "terminal.ansi.dim_white": "#898a8aff", "link_text.hover": "#72cffeff", "conflict": "#fecf72ff", @@ -1088,6 +1098,11 @@ "font_style": null, "font_weight": null }, + "punctuation.markup": { + "color": "#b4b3aeff", + "font_style": null, + "font_weight": null + }, "punctuation.special": { "color": "#dfbfffff", "font_style": null, diff --git a/assets/themes/gruvbox/gruvbox.json b/assets/themes/gruvbox/gruvbox.json index 459825c733dbf2eae1e5269885b1b2c135bd72c4..fc11cac55f638349778c88869dcb217c89111022 100644 --- a/assets/themes/gruvbox/gruvbox.json +++ b/assets/themes/gruvbox/gruvbox.json @@ -94,7 +94,7 @@ "terminal.ansi.bright_cyan": "#45603eff", "terminal.ansi.dim_cyan": "#c7dfbdff", "terminal.ansi.white": "#fbf1c7ff", - "terminal.ansi.bright_white": "#fbf1c7ff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#b0a189ff", "link_text.hover": "#83a598ff", "version_control.added": "#b7bb26ff", @@ -325,6 +325,11 @@ "font_style": null, "font_weight": null }, + "punctuation.markup": { + "color": "#83a598ff", + "font_style": null, + "font_weight": null + }, "punctuation.special": { "color": "#e5d5adff", "font_style": null, @@ -494,7 +499,7 @@ "terminal.ansi.bright_cyan": "#45603eff", "terminal.ansi.dim_cyan": "#c7dfbdff", "terminal.ansi.white": "#fbf1c7ff", - "terminal.ansi.bright_white": "#fbf1c7ff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#b0a189ff", "link_text.hover": "#83a598ff", "version_control.added": "#b7bb26ff", @@ -725,6 +730,11 @@ "font_style": null, "font_weight": null }, + "punctuation.markup": { + "color": "#83a598ff", + "font_style": null, + "font_weight": null + }, "punctuation.special": { "color": "#e5d5adff", "font_style": null, @@ -894,7 +904,7 @@ "terminal.ansi.bright_cyan": "#45603eff", "terminal.ansi.dim_cyan": "#c7dfbdff", "terminal.ansi.white": "#fbf1c7ff", - "terminal.ansi.bright_white": "#fbf1c7ff", + "terminal.ansi.bright_white": "#ffffffff", 
"terminal.ansi.dim_white": "#b0a189ff", "link_text.hover": "#83a598ff", "version_control.added": "#b7bb26ff", @@ -1125,6 +1135,11 @@ "font_style": null, "font_weight": null }, + "punctuation.markup": { + "color": "#83a598ff", + "font_style": null, + "font_weight": null + }, "punctuation.special": { "color": "#e5d5adff", "font_style": null, @@ -1294,7 +1309,7 @@ "terminal.ansi.bright_cyan": "#9fbca8ff", "terminal.ansi.dim_cyan": "#253e2eff", "terminal.ansi.white": "#fbf1c7ff", - "terminal.ansi.bright_white": "#fbf1c7ff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#b0a189ff", "link_text.hover": "#0b6678ff", "version_control.added": "#797410ff", @@ -1525,6 +1540,11 @@ "font_style": null, "font_weight": null }, + "punctuation.markup": { + "color": "#066578ff", + "font_style": null, + "font_weight": null + }, "punctuation.special": { "color": "#413d3aff", "font_style": null, @@ -1694,7 +1714,7 @@ "terminal.ansi.bright_cyan": "#9fbca8ff", "terminal.ansi.dim_cyan": "#253e2eff", "terminal.ansi.white": "#f9f5d7ff", - "terminal.ansi.bright_white": "#f9f5d7ff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#b0a189ff", "link_text.hover": "#0b6678ff", "version_control.added": "#797410ff", @@ -1925,6 +1945,11 @@ "font_style": null, "font_weight": null }, + "punctuation.markup": { + "color": "#066578ff", + "font_style": null, + "font_weight": null + }, "punctuation.special": { "color": "#413d3aff", "font_style": null, @@ -2094,7 +2119,7 @@ "terminal.ansi.bright_cyan": "#9fbca8ff", "terminal.ansi.dim_cyan": "#253e2eff", "terminal.ansi.white": "#f2e5bcff", - "terminal.ansi.bright_white": "#f2e5bcff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#b0a189ff", "link_text.hover": "#0b6678ff", "version_control.added": "#797410ff", @@ -2325,6 +2350,11 @@ "font_style": null, "font_weight": null }, + "punctuation.markup": { + "color": "#066578ff", + "font_style": null, + "font_weight": null + }, 
"punctuation.special": { "color": "#413d3aff", "font_style": null, diff --git a/assets/themes/one/one.json b/assets/themes/one/one.json index 23ebbcc67efaa9ca45748a5726ac1fd72488c451..7cc8c96a23f32aab69596722188e3c5ec87aba08 100644 --- a/assets/themes/one/one.json +++ b/assets/themes/one/one.json @@ -93,7 +93,7 @@ "terminal.ansi.bright_cyan": "#3a565bff", "terminal.ansi.dim_cyan": "#b9d9dfff", "terminal.ansi.white": "#dce0e5ff", - "terminal.ansi.bright_white": "#dce0e5ff", + "terminal.ansi.bright_white": "#fafafaff", "terminal.ansi.dim_white": "#575d65ff", "link_text.hover": "#74ade8ff", "version_control.added": "#27a657ff", @@ -321,6 +321,11 @@ "font_style": null, "font_weight": null }, + "punctuation.markup": { + "color": "#d07277ff", + "font_style": null, + "font_weight": null + }, "punctuation.special": { "color": "#b1574bff", "font_style": null, @@ -468,7 +473,7 @@ "terminal.bright_foreground": "#242529ff", "terminal.dim_foreground": "#fafafaff", "terminal.ansi.black": "#242529ff", - "terminal.ansi.bright_black": "#242529ff", + "terminal.ansi.bright_black": "#747579ff", "terminal.ansi.dim_black": "#97979aff", "terminal.ansi.red": "#d36151ff", "terminal.ansi.bright_red": "#f0b0a4ff", @@ -489,7 +494,7 @@ "terminal.ansi.bright_cyan": "#a3bedaff", "terminal.ansi.dim_cyan": "#254058ff", "terminal.ansi.white": "#fafafaff", - "terminal.ansi.bright_white": "#fafafaff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#aaaaaaff", "link_text.hover": "#5c78e2ff", "version_control.added": "#27a657ff", @@ -715,6 +720,11 @@ "font_style": null, "font_weight": null }, + "punctuation.markup": { + "color": "#d3604fff", + "font_style": null, + "font_weight": null + }, "punctuation.special": { "color": "#b92b46ff", "font_style": null, diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 2ac15de08f331e555e80883cd66c9e5beefe0a32..a0bbda848f9ec761aebdf66b644a8b2926685122 100644 --- a/crates/acp_thread/Cargo.toml +++ 
b/crates/acp_thread/Cargo.toml @@ -13,33 +13,39 @@ path = "src/acp_thread.rs" doctest = false [features] -test-support = ["gpui/test-support", "project/test-support"] +test-support = ["gpui/test-support", "project/test-support", "dep:parking_lot"] [dependencies] action_log.workspace = true agent-client-protocol.workspace = true -agent.workspace = true +agent_settings.workspace = true anyhow.workspace = true buffer_diff.workspace = true collections.workspace = true editor.workspace = true +file_icons.workspace = true futures.workspace = true gpui.workspace = true itertools.workspace = true language.workspace = true +language_model.workspace = true markdown.workspace = true +parking_lot = { workspace = true, optional = true } +portable-pty.workspace = true project.workspace = true prompt_store.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true smol.workspace = true +task.workspace = true terminal.workspace = true ui.workspace = true url.workspace = true util.workspace = true uuid.workspace = true watch.workspace = true +which.workspace = true workspace-hack.workspace = true [dev-dependencies] diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index a5b512f31abec80cd4c79ef843471f95b0f4b22a..c7279abdc6d63ff77644549bb64db160abc446bf 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -3,13 +3,20 @@ mod diff; mod mention; mod terminal; +use agent_settings::AgentSettings; +use collections::HashSet; pub use connection::*; pub use diff::*; +use futures::future::Shared; +use language::language_settings::FormatOnSave; pub use mention::*; +use project::lsp_store::{FormatTrigger, LspFormatTarget}; +use serde::{Deserialize, Serialize}; +use settings::Settings as _; pub use terminal::*; use action_log::ActionLog; -use agent_client_protocol as acp; +use agent_client_protocol::{self as acp}; use anyhow::{Context as _, Result, anyhow}; use editor::Bias; use 
futures::{FutureExt, channel::oneshot, future::BoxFuture}; @@ -24,21 +31,34 @@ use std::fmt::{Formatter, Write}; use std::ops::Range; use std::process::ExitStatus; use std::rc::Rc; +use std::time::{Duration, Instant}; use std::{fmt::Display, mem, path::PathBuf, sync::Arc}; use ui::App; -use util::ResultExt; +use util::{ResultExt, get_system_shell}; +use uuid::Uuid; #[derive(Debug)] pub struct UserMessage { pub id: Option, pub content: ContentBlock, - pub checkpoint: Option, + pub chunks: Vec, + pub checkpoint: Option, +} + +#[derive(Debug)] +pub struct Checkpoint { + git_checkpoint: GitStoreCheckpoint, + pub show: bool, } impl UserMessage { fn to_markdown(&self, cx: &App) -> String { let mut markdown = String::new(); - if let Some(_) = self.checkpoint { + if self + .checkpoint + .as_ref() + .is_some_and(|checkpoint| checkpoint.show) + { writeln!(markdown, "## User (checkpoint)").unwrap(); } else { writeln!(markdown, "## User").unwrap(); @@ -98,7 +118,7 @@ pub enum AgentThreadEntry { } impl AgentThreadEntry { - fn to_markdown(&self, cx: &App) -> String { + pub fn to_markdown(&self, cx: &App) -> String { match self { Self::UserMessage(message) => message.to_markdown(cx), Self::AssistantMessage(message) => message.to_markdown(cx), @@ -106,6 +126,14 @@ impl AgentThreadEntry { } } + pub fn user_message(&self) -> Option<&UserMessage> { + if let AgentThreadEntry::UserMessage(message) = self { + Some(message) + } else { + None + } + } + pub fn diffs(&self) -> impl Iterator> { if let AgentThreadEntry::ToolCall(call) = self { itertools::Either::Left(call.diffs()) @@ -157,38 +185,46 @@ impl ToolCall { tool_call: acp::ToolCall, status: ToolCallStatus, language_registry: Arc, + terminals: &HashMap>, cx: &mut App, - ) -> Self { - Self { + ) -> Result { + let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") { + first_line.to_owned() + "…" + } else { + tool_call.title + }; + let mut content = Vec::with_capacity(tool_call.content.len()); + for item in 
tool_call.content { + content.push(ToolCallContent::from_acp( + item, + language_registry.clone(), + terminals, + cx, + )?); + } + + let result = Self { id: tool_call.id, - label: cx.new(|cx| { - Markdown::new( - tool_call.title.into(), - Some(language_registry.clone()), - None, - cx, - ) - }), + label: cx + .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)), kind: tool_call.kind, - content: tool_call - .content - .into_iter() - .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx)) - .collect(), + content, locations: tool_call.locations, resolved_locations: Vec::default(), status, raw_input: tool_call.raw_input, raw_output: tool_call.raw_output, - } + }; + Ok(result) } fn update_fields( &mut self, fields: acp::ToolCallUpdateFields, language_registry: Arc, + terminals: &HashMap>, cx: &mut App, - ) { + ) -> Result<()> { let acp::ToolCallUpdateFields { kind, status, @@ -204,20 +240,36 @@ impl ToolCall { } if let Some(status) = status { - self.status = ToolCallStatus::Allowed { status }; + self.status = status.into(); } if let Some(title) = title { self.label.update(cx, |label, cx| { - label.replace(title, cx); + if let Some((first_line, _)) = title.split_once("\n") { + label.replace(first_line.to_owned() + "…", cx) + } else { + label.replace(title, cx); + } }); } if let Some(content) = content { - self.content = content - .into_iter() - .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx)) - .collect(); + let new_content_len = content.len(); + let mut content = content.into_iter(); + + // Reuse existing content if we can + for (old, new) in self.content.iter_mut().zip(content.by_ref()) { + old.update_from_acp(new, language_registry.clone(), terminals, cx)?; + } + for new in content { + self.content.push(ToolCallContent::from_acp( + new, + language_registry.clone(), + terminals, + cx, + )?) 
+ } + self.content.truncate(new_content_len); } if let Some(locations) = locations { @@ -229,17 +281,17 @@ impl ToolCall { } if let Some(raw_output) = raw_output { - if self.content.is_empty() { - if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx) - { - self.content - .push(ToolCallContent::ContentBlock(ContentBlock::Markdown { - markdown, - })); - } + if self.content.is_empty() + && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx) + { + self.content + .push(ToolCallContent::ContentBlock(ContentBlock::Markdown { + markdown, + })); } self.raw_output = Some(raw_output); } + Ok(()) } pub fn diffs(&self) -> impl Iterator> { @@ -278,11 +330,9 @@ impl ToolCall { ) -> Option { let buffer = project .update(cx, |project, cx| { - if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) { - Some(project.open_buffer(path, cx)) - } else { - None - } + project + .project_path_for_absolute_path(&location.path, cx) + .map(|path| project.open_buffer(path, cx)) }) .ok()??; let buffer = buffer.await.log_err()?; @@ -325,30 +375,48 @@ impl ToolCall { #[derive(Debug)] pub enum ToolCallStatus { + /// The tool call hasn't started running yet, but we start showing it to + /// the user. + Pending, + /// The tool call is waiting for confirmation from the user. WaitingForConfirmation { options: Vec, respond_tx: oneshot::Sender, }, - Allowed { - status: acp::ToolCallStatus, - }, + /// The tool call is currently running. + InProgress, + /// The tool call completed successfully. + Completed, + /// The tool call failed. + Failed, + /// The user rejected the tool call. Rejected, + /// The user canceled generation so the tool call was canceled. 
Canceled, } +impl From for ToolCallStatus { + fn from(status: acp::ToolCallStatus) -> Self { + match status { + acp::ToolCallStatus::Pending => Self::Pending, + acp::ToolCallStatus::InProgress => Self::InProgress, + acp::ToolCallStatus::Completed => Self::Completed, + acp::ToolCallStatus::Failed => Self::Failed, + } + } +} + impl Display for ToolCallStatus { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( f, "{}", match self { + ToolCallStatus::Pending => "Pending", ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation", - ToolCallStatus::Allowed { status } => match status { - acp::ToolCallStatus::Pending => "Pending", - acp::ToolCallStatus::InProgress => "In Progress", - acp::ToolCallStatus::Completed => "Completed", - acp::ToolCallStatus::Failed => "Failed", - }, + ToolCallStatus::InProgress => "In Progress", + ToolCallStatus::Completed => "Completed", + ToolCallStatus::Failed => "Failed", ToolCallStatus::Rejected => "Rejected", ToolCallStatus::Canceled => "Canceled", } @@ -392,11 +460,11 @@ impl ContentBlock { language_registry: &Arc, cx: &mut App, ) { - if matches!(self, ContentBlock::Empty) { - if let acp::ContentBlock::ResourceLink(resource_link) = block { - *self = ContentBlock::ResourceLink { resource_link }; - return; - } + if matches!(self, ContentBlock::Empty) + && let acp::ContentBlock::ResourceLink(resource_link) = block + { + *self = ContentBlock::ResourceLink { resource_link }; + return; } let new_content = self.block_string_contents(block); @@ -430,7 +498,7 @@ impl ContentBlock { fn block_string_contents(&self, block: acp::ContentBlock) -> String { match block { - acp::ContentBlock::Text(text_content) => text_content.text.clone(), + acp::ContentBlock::Text(text_content) => text_content.text, acp::ContentBlock::ResourceLink(resource_link) => { Self::resource_link_md(&resource_link.uri) } @@ -442,21 +510,24 @@ impl ContentBlock { }), .. 
}) => Self::resource_link_md(&uri), - acp::ContentBlock::Image(_) - | acp::ContentBlock::Audio(_) - | acp::ContentBlock::Resource(_) => String::new(), + acp::ContentBlock::Image(image) => Self::image_md(&image), + acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(), } } fn resource_link_md(uri: &str) -> String { - if let Some(uri) = MentionUri::parse(&uri).log_err() { + if let Some(uri) = MentionUri::parse(uri).log_err() { uri.as_link().to_string() } else { uri.to_string() } } - fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str { + fn image_md(_image: &acp::ImageContent) -> String { + "`Image`".into() + } + + pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str { match self { ContentBlock::Empty => "", ContentBlock::Markdown { markdown } => markdown.read(cx).source(), @@ -491,16 +562,54 @@ impl ToolCallContent { pub fn from_acp( content: acp::ToolCallContent, language_registry: Arc, + terminals: &HashMap>, cx: &mut App, - ) -> Self { + ) -> Result { match content { - acp::ToolCallContent::Content { content } => { - Self::ContentBlock(ContentBlock::new(content, &language_registry, cx)) - } - acp::ToolCallContent::Diff { diff } => { - Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx))) + acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new( + content, + &language_registry, + cx, + ))), + acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| { + Diff::finalized( + diff.path, + diff.old_text, + diff.new_text, + language_registry, + cx, + ) + }))), + acp::ToolCallContent::Terminal { terminal_id } => terminals + .get(&terminal_id) + .cloned() + .map(Self::Terminal) + .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)), + } + } + + pub fn update_from_acp( + &mut self, + new: acp::ToolCallContent, + language_registry: Arc, + terminals: &HashMap>, + cx: &mut App, + ) -> Result<()> { + let needs_update = match (&self, &new) { + (Self::Diff(old_diff), 
acp::ToolCallContent::Diff { diff: new_diff }) => { + old_diff.read(cx).needs_update( + new_diff.old_text.as_deref().unwrap_or(""), + &new_diff.new_text, + cx, + ) } + _ => true, + }; + + if needs_update { + *self = Self::from_acp(new, language_registry, terminals, cx)?; } + Ok(()) } pub fn to_markdown(&self, cx: &App) -> String { @@ -618,6 +727,52 @@ impl PlanEntry { } } +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct TokenUsage { + pub max_tokens: u64, + pub used_tokens: u64, +} + +impl TokenUsage { + pub fn ratio(&self) -> TokenUsageRatio { + #[cfg(debug_assertions)] + let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD") + .unwrap_or("0.8".to_string()) + .parse() + .unwrap(); + #[cfg(not(debug_assertions))] + let warning_threshold: f32 = 0.8; + + // When the maximum is unknown because there is no selected model, + // avoid showing the token limit warning. + if self.max_tokens == 0 { + TokenUsageRatio::Normal + } else if self.used_tokens >= self.max_tokens { + TokenUsageRatio::Exceeded + } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold { + TokenUsageRatio::Warning + } else { + TokenUsageRatio::Normal + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TokenUsageRatio { + Normal, + Warning, + Exceeded, +} + +#[derive(Debug, Clone)] +pub struct RetryStatus { + pub last_error: SharedString, + pub attempt: usize, + pub max_attempts: usize, + pub started_at: Instant, + pub duration: Duration, +} + pub struct AcpThread { title: SharedString, entries: Vec, @@ -628,44 +783,69 @@ pub struct AcpThread { send_task: Option>, connection: Rc, session_id: acp::SessionId, + token_usage: Option, + prompt_capabilities: acp::PromptCapabilities, + _observe_prompt_capabilities: Task>, + determine_shell: Shared>, + terminals: HashMap>, } +#[derive(Debug)] pub enum AcpThreadEvent { NewEntry, + TitleUpdated, + TokenUsageUpdated, EntryUpdated(usize), EntriesRemoved(Range), 
ToolAuthorizationRequired, + Retry(RetryStatus), Stopped, Error, - ServerExited(ExitStatus), + LoadError(LoadError), + PromptCapabilitiesUpdated, + Refusal, + AvailableCommandsUpdated(Vec), + ModeUpdated(acp::SessionModeId), } impl EventEmitter for AcpThread {} -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Debug)] pub enum ThreadStatus { Idle, - WaitingForToolConfirmation, Generating, } #[derive(Debug, Clone)] pub enum LoadError { Unsupported { - error_message: SharedString, - upgrade_message: SharedString, - upgrade_command: String, + command: SharedString, + current_version: SharedString, + minimum_version: SharedString, + }, + FailedToInstall(SharedString), + Exited { + status: ExitStatus, }, - Exited(i32), Other(SharedString), } impl Display for LoadError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message), - LoadError::Exited(status) => write!(f, "Server exited with status {}", status), - LoadError::Other(msg) => write!(f, "{}", msg), + LoadError::Unsupported { + command: path, + current_version, + minimum_version, + } => { + write!( + f, + "version {current_version} from {path} is not supported (need at least {minimum_version})" + ) + } + LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"), + LoadError::Exited { status } => write!(f, "Server exited with status {status}"), + LoadError::Other(msg) => write!(f, "{msg}"), } } } @@ -677,10 +857,35 @@ impl AcpThread { title: impl Into, connection: Rc, project: Entity, + action_log: Entity, session_id: acp::SessionId, + mut prompt_capabilities_rx: watch::Receiver, cx: &mut Context, ) -> Self { - let action_log = cx.new(|_| ActionLog::new(project.clone())); + let prompt_capabilities = *prompt_capabilities_rx.borrow(); + let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| { + loop { + let caps = prompt_capabilities_rx.recv().await?; + this.update(cx, |this, cx| { + 
this.prompt_capabilities = caps; + cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated); + })?; + } + }); + + let determine_shell = cx + .background_spawn(async move { + if cfg!(windows) { + return get_system_shell(); + } + + if which::which("bash").is_ok() { + "bash".into() + } else { + get_system_shell() + } + }) + .shared(); Self { action_log, @@ -692,9 +897,18 @@ impl AcpThread { send_task: None, connection, session_id, + token_usage: None, + prompt_capabilities, + _observe_prompt_capabilities: task, + terminals: HashMap::default(), + determine_shell, } } + pub fn prompt_capabilities(&self) -> acp::PromptCapabilities { + self.prompt_capabilities + } + pub fn connection(&self) -> &Rc { &self.connection } @@ -721,27 +935,23 @@ impl AcpThread { pub fn status(&self) -> ThreadStatus { if self.send_task.is_some() { - if self.waiting_for_tool_confirmation() { - ThreadStatus::WaitingForToolConfirmation - } else { - ThreadStatus::Generating - } + ThreadStatus::Generating } else { ThreadStatus::Idle } } + pub fn token_usage(&self) -> Option<&TokenUsage> { + self.token_usage.as_ref() + } + pub fn has_pending_edit_tool_calls(&self) -> bool { for entry in self.entries.iter().rev() { match entry { AgentThreadEntry::UserMessage(_) => return false, AgentThreadEntry::ToolCall( call @ ToolCall { - status: - ToolCallStatus::Allowed { - status: - acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending, - }, + status: ToolCallStatus::InProgress | ToolCallStatus::Pending, .. 
}, ) if call.diffs().next().is_some() => { @@ -770,7 +980,7 @@ impl AcpThread { &mut self, update: acp::SessionUpdate, cx: &mut Context, - ) -> Result<()> { + ) -> Result<(), acp::Error> { match update { acp::SessionUpdate::UserMessageChunk { content } => { self.push_user_content_block(None, content, cx); @@ -782,7 +992,7 @@ impl AcpThread { self.push_assistant_content_block(content, true, cx); } acp::SessionUpdate::ToolCall(tool_call) => { - self.upsert_tool_call(tool_call, cx); + self.upsert_tool_call(tool_call, cx)?; } acp::SessionUpdate::ToolCallUpdate(tool_call_update) => { self.update_tool_call(tool_call_update, cx)?; @@ -790,6 +1000,12 @@ impl AcpThread { acp::SessionUpdate::Plan(plan) => { self.update_plan(plan, cx); } + acp::SessionUpdate::AvailableCommandsUpdate { available_commands } => { + cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)) + } + acp::SessionUpdate::CurrentModeUpdate { current_mode_id } => { + cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)) + } } Ok(()) } @@ -804,18 +1020,25 @@ impl AcpThread { let entries_len = self.entries.len(); if let Some(last_entry) = self.entries.last_mut() - && let AgentThreadEntry::UserMessage(UserMessage { id, content, .. }) = last_entry + && let AgentThreadEntry::UserMessage(UserMessage { + id, + content, + chunks, + .. 
+ }) = last_entry { *id = message_id.or(id.take()); - content.append(chunk, &language_registry, cx); + content.append(chunk.clone(), &language_registry, cx); + chunks.push(chunk); let idx = entries_len - 1; cx.emit(AcpThreadEvent::EntryUpdated(idx)); } else { - let content = ContentBlock::new(chunk, &language_registry, cx); + let content = ContentBlock::new(chunk.clone(), &language_registry, cx); self.push_entry( AgentThreadEntry::UserMessage(UserMessage { id: message_id, content, + chunks: vec![chunk], checkpoint: None, }), cx, @@ -872,6 +1095,30 @@ impl AcpThread { cx.emit(AcpThreadEvent::NewEntry); } + pub fn can_set_title(&mut self, cx: &mut Context) -> bool { + self.connection.set_title(&self.session_id, cx).is_some() + } + + pub fn set_title(&mut self, title: SharedString, cx: &mut Context) -> Task> { + if title != self.title { + self.title = title.clone(); + cx.emit(AcpThreadEvent::TitleUpdated); + if let Some(set_title) = self.connection.set_title(&self.session_id, cx) { + return set_title.run(title, cx); + } + } + Task::ready(Ok(())) + } + + pub fn update_token_usage(&mut self, usage: Option, cx: &mut Context) { + self.token_usage = usage; + cx.emit(AcpThreadEvent::TokenUsageUpdated); + } + + pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context) { + cx.emit(AcpThreadEvent::Retry(status)); + } + pub fn update_tool_call( &mut self, update: impl Into, @@ -880,27 +1127,28 @@ impl AcpThread { let update = update.into(); let languages = self.project.read(cx).languages().clone(); - let (ix, current_call) = self - .tool_call_mut(update.id()) + let ix = self + .index_for_tool_call(update.id()) .context("Tool call not found")?; + let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else { + unreachable!() + }; + match update { ToolCallUpdate::UpdateFields(update) => { let location_updated = update.fields.locations.is_some(); - current_call.update_fields(update.fields, languages, cx); + call.update_fields(update.fields, languages, 
&self.terminals, cx)?; if location_updated { - self.resolve_locations(update.id.clone(), cx); + self.resolve_locations(update.id, cx); } } ToolCallUpdate::UpdateDiff(update) => { - current_call.content.clear(); - current_call - .content - .push(ToolCallContent::Diff(update.diff)); + call.content.clear(); + call.content.push(ToolCallContent::Diff(update.diff)); } ToolCallUpdate::UpdateTerminal(update) => { - current_call.content.clear(); - current_call - .content + call.content.clear(); + call.content .push(ToolCallContent::Terminal(update.terminal)); } } @@ -911,32 +1159,63 @@ impl AcpThread { } /// Updates a tool call if id matches an existing entry, otherwise inserts a new one. - pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context) { - let status = ToolCallStatus::Allowed { - status: tool_call.status, - }; - self.upsert_tool_call_inner(tool_call, status, cx) + pub fn upsert_tool_call( + &mut self, + tool_call: acp::ToolCall, + cx: &mut Context, + ) -> Result<(), acp::Error> { + let status = tool_call.status.into(); + self.upsert_tool_call_inner(tool_call.into(), status, cx) } + /// Fails if id does not match an existing entry. 
pub fn upsert_tool_call_inner( &mut self, - tool_call: acp::ToolCall, + update: acp::ToolCallUpdate, status: ToolCallStatus, cx: &mut Context, - ) { + ) -> Result<(), acp::Error> { let language_registry = self.project.read(cx).languages().clone(); - let call = ToolCall::from_acp(tool_call, status, language_registry, cx); - let id = call.id.clone(); + let id = update.id.clone(); - if let Some((ix, current_call)) = self.tool_call_mut(&call.id) { - *current_call = call; + if let Some(ix) = self.index_for_tool_call(&id) { + let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else { + unreachable!() + }; + + call.update_fields(update.fields, language_registry, &self.terminals, cx)?; + call.status = status; cx.emit(AcpThreadEvent::EntryUpdated(ix)); } else { + let call = ToolCall::from_acp( + update.try_into()?, + status, + language_registry, + &self.terminals, + cx, + )?; self.push_entry(AgentThreadEntry::ToolCall(call), cx); }; self.resolve_locations(id, cx); + Ok(()) + } + + fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option { + self.entries + .iter() + .enumerate() + .rev() + .find_map(|(index, entry)| { + if let AgentThreadEntry::ToolCall(tool_call) = entry + && &tool_call.id == id + { + Some(index) + } else { + None + } + }) } fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> { @@ -957,6 +1236,22 @@ impl AcpThread { }) } + pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> { + self.entries + .iter() + .enumerate() + .rev() + .find_map(|(index, tool_call)| { + if let AgentThreadEntry::ToolCall(tool_call) = tool_call + && &tool_call.id == id + { + Some((index, tool_call)) + } else { + None + } + }) + } + pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context) { let project = self.project.clone(); let Some((_, tool_call)) = self.tool_call_mut(&id) else { @@ -1005,20 +1300,50 @@ impl AcpThread { pub fn request_tool_call_authorization( &mut self, - tool_call: 
acp::ToolCall, + tool_call: acp::ToolCallUpdate, options: Vec, + respect_always_allow_setting: bool, cx: &mut Context, - ) -> oneshot::Receiver { + ) -> Result> { let (tx, rx) = oneshot::channel(); + if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions { + // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions, + // some tools would (incorrectly) continue to auto-accept. + if let Some(allow_once_option) = options.iter().find_map(|option| { + if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) { + Some(option.id.clone()) + } else { + None + } + }) { + self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?; + return Ok(async { + acp::RequestPermissionOutcome::Selected { + option_id: allow_once_option, + } + } + .boxed()); + } + } + let status = ToolCallStatus::WaitingForConfirmation { options, respond_tx: tx, }; - self.upsert_tool_call_inner(tool_call, status, cx); + self.upsert_tool_call_inner(tool_call, status, cx)?; cx.emit(AcpThreadEvent::ToolAuthorizationRequired); - rx + + let fut = async { + match rx.await { + Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option }, + Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled, + } + } + .boxed(); + + Ok(fut) } pub fn authorize_tool_call( @@ -1037,9 +1362,7 @@ impl AcpThread { ToolCallStatus::Rejected } acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => { - ToolCallStatus::Allowed { - status: acp::ToolCallStatus::InProgress, - } + ToolCallStatus::InProgress } }; @@ -1054,23 +1377,27 @@ impl AcpThread { cx.emit(AcpThreadEvent::EntryUpdated(ix)); } - /// Returns true if the last turn is awaiting tool authorization - pub fn waiting_for_tool_confirmation(&self) -> bool { + pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> { + let mut first_tool_call = None; + for entry in self.entries.iter().rev() { match &entry { - 
AgentThreadEntry::ToolCall(call) => match call.status { - ToolCallStatus::WaitingForConfirmation { .. } => return true, - ToolCallStatus::Allowed { .. } - | ToolCallStatus::Rejected - | ToolCallStatus::Canceled => continue, - }, + AgentThreadEntry::ToolCall(call) => { + if let ToolCallStatus::WaitingForConfirmation { .. } = call.status { + first_tool_call = Some(call); + } else { + continue; + } + } AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => { - // Reached the beginning of the turn - return false; + // Reached the beginning of the turn. + // If we had pending permission requests in the previous turn, they have been cancelled. + break; } } } - false + + first_tool_call } pub fn plan(&self) -> &Plan { @@ -1134,85 +1461,90 @@ impl AcpThread { self.project.read(cx).languages().clone(), cx, ); + let request = acp::PromptRequest { + prompt: message.clone(), + session_id: self.session_id.clone(), + }; let git_store = self.project.read(cx).git_store().clone(); - let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx)); - let message_id = if self - .connection - .session_editor(&self.session_id, cx) - .is_some() - { + let message_id = if self.connection.truncate(&self.session_id, cx).is_some() { Some(UserMessageId::new()) } else { None }; - self.push_entry( - AgentThreadEntry::UserMessage(UserMessage { - id: message_id.clone(), - content: block, - checkpoint: None, - }), - cx, - ); + + self.run_turn(cx, async move |this, cx| { + this.update(cx, |this, cx| { + this.push_entry( + AgentThreadEntry::UserMessage(UserMessage { + id: message_id.clone(), + content: block, + chunks: message, + checkpoint: None, + }), + cx, + ); + }) + .ok(); + + let old_checkpoint = git_store + .update(cx, |git, cx| git.checkpoint(cx))? 
+ .await + .context("failed to get old checkpoint") + .log_err(); + this.update(cx, |this, cx| { + if let Some((_ix, message)) = this.last_user_message() { + message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint { + git_checkpoint, + show: false, + }); + } + this.connection.prompt(message_id, request, cx) + })? + .await + }) + } + + pub fn can_resume(&self, cx: &App) -> bool { + self.connection.resume(&self.session_id, cx).is_some() + } + + pub fn resume(&mut self, cx: &mut Context) -> BoxFuture<'static, Result<()>> { + self.run_turn(cx, async move |this, cx| { + this.update(cx, |this, cx| { + this.connection + .resume(&this.session_id, cx) + .map(|resume| resume.run(cx)) + })? + .context("resuming a session is not supported")? + .await + }) + } + + fn run_turn( + &mut self, + cx: &mut Context, + f: impl 'static + AsyncFnOnce(WeakEntity, &mut AsyncApp) -> Result, + ) -> BoxFuture<'static, Result<()>> { self.clear_completed_plan_entries(cx); - let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel(); let (tx, rx) = oneshot::channel(); let cancel_task = self.cancel(cx); - let request = acp::PromptRequest { - prompt: message, - session_id: self.session_id.clone(), - }; - self.send_task = Some(cx.spawn({ - let message_id = message_id.clone(); - async move |this, cx| { - cancel_task.await; - - old_checkpoint_tx.send(old_checkpoint.await).ok(); - if let Ok(result) = this.update(cx, |this, cx| { - this.connection.prompt(message_id, request, cx) - }) { - tx.send(result.await).log_err(); - } - } + self.send_task = Some(cx.spawn(async move |this, cx| { + cancel_task.await; + tx.send(f(this, cx).await).ok(); })); cx.spawn(async move |this, cx| { - let old_checkpoint = old_checkpoint_rx - .await - .map_err(|_| anyhow!("send canceled")) - .flatten() - .context("failed to get old checkpoint") - .log_err(); - let response = rx.await; - if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) { - let new_checkpoint = git_store - .update(cx, 
|git, cx| git.checkpoint(cx))? - .await - .context("failed to get new checkpoint") - .log_err(); - if let Some(new_checkpoint) = new_checkpoint { - let equal = git_store - .update(cx, |git, cx| { - git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx) - })? - .await - .unwrap_or(true); - if !equal { - this.update(cx, |this, cx| { - if let Some((ix, message)) = this.user_message_mut(&message_id) { - message.checkpoint = Some(old_checkpoint); - cx.emit(AcpThreadEvent::EntryUpdated(ix)); - } - })?; - } - } - } + this.update(cx, |this, cx| this.update_last_checkpoint(cx))? + .await?; this.update(cx, |this, cx| { + this.project + .update(cx, |project, cx| project.set_agent_location(None, cx)); match response { Ok(Err(e)) => { this.send_task.take(); @@ -1220,22 +1552,60 @@ impl AcpThread { Err(e) } result => { - let cancelled = matches!( + let canceled = matches!( result, Ok(Ok(acp::PromptResponse { stop_reason: acp::StopReason::Cancelled })) ); - // We only take the task if the current prompt wasn't cancelled. + // We only take the task if the current prompt wasn't canceled. // - // This prompt may have been cancelled because another one was sent + // This prompt may have been canceled because another one was sent // while it was still generating. In these cases, dropping `send_task` - // would cause the next generation to be cancelled. - if !cancelled { + // would cause the next generation to be canceled. 
+ if !canceled { this.send_task.take(); } + // Handle refusal - distinguish between user prompt and tool call refusals + if let Ok(Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Refusal, + })) = result + { + if let Some((user_msg_ix, _)) = this.last_user_message() { + // Check if there's a completed tool call with results after the last user message + // This indicates the refusal is in response to tool output, not the user's prompt + let has_completed_tool_call_after_user_msg = + this.entries.iter().skip(user_msg_ix + 1).any(|entry| { + if let AgentThreadEntry::ToolCall(tool_call) = entry { + // Check if the tool call has completed and has output + matches!(tool_call.status, ToolCallStatus::Completed) + && tool_call.raw_output.is_some() + } else { + false + } + }); + + if has_completed_tool_call_after_user_msg { + // Refusal is due to tool output - don't truncate, just notify + // The model refused based on what the tool returned + cx.emit(AcpThreadEvent::Refusal); + } else { + // User prompt was refused - truncate back to before the user message + let range = user_msg_ix..this.entries.len(); + if range.start < range.end { + this.entries.truncate(user_msg_ix); + cx.emit(AcpThreadEvent::EntriesRemoved(range)); + } + cx.emit(AcpThreadEvent::Refusal); + } + } else { + // No user message found, treat as general refusal + cx.emit(AcpThreadEvent::Refusal); + } + } + cx.emit(AcpThreadEvent::Stopped); Ok(()) } @@ -1254,10 +1624,9 @@ impl AcpThread { if let AgentThreadEntry::ToolCall(call) = entry { let cancel = matches!( call.status, - ToolCallStatus::WaitingForConfirmation { .. } - | ToolCallStatus::Allowed { - status: acp::ToolCallStatus::InProgress - } + ToolCallStatus::Pending + | ToolCallStatus::WaitingForConfirmation { .. 
} + | ToolCallStatus::InProgress ); if cancel { @@ -1272,56 +1641,116 @@ impl AcpThread { cx.foreground_executor().spawn(send_task) } - /// Rewinds this thread to before the entry at `index`, removing it and all - /// subsequent entries while reverting any changes made from that point. - pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context) -> Task> { - let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else { - return Task::ready(Err(anyhow!("not supported"))); - }; - let Some(message) = self.user_message(&id) else { + /// Restores the git working tree to the state at the given checkpoint (if one exists) + pub fn restore_checkpoint( + &mut self, + id: UserMessageId, + cx: &mut Context, + ) -> Task> { + let Some((_, message)) = self.user_message_mut(&id) else { return Task::ready(Err(anyhow!("message not found"))); }; - let checkpoint = message.checkpoint.clone(); - + let checkpoint = message + .checkpoint + .as_ref() + .map(|c| c.git_checkpoint.clone()); + let rewind = self.rewind(id.clone(), cx); let git_store = self.project.read(cx).git_store().clone(); - cx.spawn(async move |this, cx| { + + cx.spawn(async move |_, cx| { + rewind.await?; if let Some(checkpoint) = checkpoint { git_store .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))? .await?; } - cx.update(|cx| session_editor.truncate(id.clone(), cx))? - .await?; + Ok(()) + }) + } + + /// Rewinds this thread to before the entry at `index`, removing it and all + /// subsequent entries while rejecting any action_log changes made from that point. + /// Unlike `restore_checkpoint`, this method does not restore from git. 
+ pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context) -> Task> { + let Some(truncate) = self.connection.truncate(&self.session_id, cx) else { + return Task::ready(Err(anyhow!("not supported"))); + }; + + cx.spawn(async move |this, cx| { + cx.update(|cx| truncate.run(id.clone(), cx))?.await?; this.update(cx, |this, cx| { if let Some((ix, _)) = this.user_message_mut(&id) { let range = ix..this.entries.len(); this.entries.truncate(ix); cx.emit(AcpThreadEvent::EntriesRemoved(range)); } - }) - }) + this.action_log() + .update(cx, |action_log, cx| action_log.reject_all_edits(cx)) + })? + .await; + Ok(()) + }) } - fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> { - self.entries.iter().find_map(|entry| { - if let AgentThreadEntry::UserMessage(message) = entry { - if message.id.as_ref() == Some(&id) { - Some(message) - } else { - None - } + fn update_last_checkpoint(&mut self, cx: &mut Context) -> Task> { + let git_store = self.project.read(cx).git_store().clone(); + + let old_checkpoint = if let Some((_, message)) = self.last_user_message() { + if let Some(checkpoint) = message.checkpoint.as_ref() { + checkpoint.git_checkpoint.clone() } else { - None + return Task::ready(Ok(())); + } + } else { + return Task::ready(Ok(())); + }; + + let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx)); + cx.spawn(async move |this, cx| { + let new_checkpoint = new_checkpoint + .await + .context("failed to get new checkpoint") + .log_err(); + if let Some(new_checkpoint) = new_checkpoint { + let equal = git_store + .update(cx, |git, cx| { + git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx) + })? 
+ .await + .unwrap_or(true); + this.update(cx, |this, cx| { + let (ix, message) = this.last_user_message().context("no user message")?; + let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?; + checkpoint.show = !equal; + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + anyhow::Ok(()) + })??; } + + Ok(()) }) } + fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> { + self.entries + .iter_mut() + .enumerate() + .rev() + .find_map(|(ix, entry)| { + if let AgentThreadEntry::UserMessage(message) = entry { + Some((ix, message)) + } else { + None + } + }) + } + fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> { self.entries.iter_mut().enumerate().find_map(|(ix, entry)| { if let AgentThreadEntry::UserMessage(message) = entry { - if message.id.as_ref() == Some(&id) { + if message.id.as_ref() == Some(id) { Some((ix, message)) } else { None @@ -1446,42 +1875,198 @@ impl AcpThread { .collect::>() }) .await; - cx.update(|cx| { - project.update(cx, |project, cx| { - project.set_agent_location( - Some(AgentLocation { - buffer: buffer.downgrade(), - position: edits - .last() - .map(|(range, _)| range.end) - .unwrap_or(Anchor::MIN), - }), - cx, - ); - }); + project.update(cx, |project, cx| { + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position: edits + .last() + .map(|(range, _)| range.end) + .unwrap_or(Anchor::MIN), + }), + cx, + ); + })?; + + let format_on_save = cx.update(|cx| { action_log.update(cx, |action_log, cx| { action_log.buffer_read(buffer.clone(), cx); }); - buffer.update(cx, |buffer, cx| { + + let format_on_save = buffer.update(cx, |buffer, cx| { buffer.edit(edits, None, cx); + + let settings = language::language_settings::language_settings( + buffer.language().map(|l| l.name()), + buffer.file(), + cx, + ); + + settings.format_on_save != FormatOnSave::Off }); action_log.update(cx, |action_log, cx| { action_log.buffer_edited(buffer.clone(), cx); }); + 
format_on_save })?; + + if format_on_save { + let format_task = project.update(cx, |project, cx| { + project.format( + HashSet::from_iter([buffer.clone()]), + LspFormatTarget::Buffers, + false, + FormatTrigger::Save, + cx, + ) + })?; + format_task.await.log_err(); + + action_log.update(cx, |action_log, cx| { + action_log.buffer_edited(buffer.clone(), cx); + })?; + } + project .update(cx, |project, cx| project.save_buffer(buffer, cx))? .await }) } + pub fn create_terminal( + &self, + mut command: String, + args: Vec, + extra_env: Vec, + cwd: Option, + output_byte_limit: Option, + cx: &mut Context, + ) -> Task>> { + for arg in args { + command.push(' '); + command.push_str(&arg); + } + + let shell_command = if cfg!(windows) { + format!("$null | & {{{}}}", command.replace("\"", "'")) + } else if let Some(cwd) = cwd.as_ref().and_then(|cwd| cwd.as_os_str().to_str()) { + // Make sure once we're *inside* the shell, we cd into `cwd` + format!("(cd {cwd}; {}) self.project.update(cx, |project, cx| { + project.directory_environment(dir.as_path().into(), cx) + }), + None => Task::ready(None).shared(), + }; + + let env = cx.spawn(async move |_, _| { + let mut env = env.await.unwrap_or_default(); + if cfg!(unix) { + env.insert("PAGER".into(), "cat".into()); + } + for var in extra_env { + env.insert(var.name, var.value); + } + env + }); + + let project = self.project.clone(); + let language_registry = project.read(cx).languages().clone(); + let determine_shell = self.determine_shell.clone(); + + let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into()); + let terminal_task = cx.spawn({ + let terminal_id = terminal_id.clone(); + async move |_this, cx| { + let program = determine_shell.await; + let env = env.await; + let terminal = project + .update(cx, |project, cx| { + project.create_terminal_task( + task::SpawnInTerminal { + command: Some(program), + args, + cwd: cwd.clone(), + env, + ..Default::default() + }, + cx, + ) + })? 
+ .await?; + + cx.new(|cx| { + Terminal::new( + terminal_id, + command, + cwd, + output_byte_limit.map(|l| l as usize), + terminal, + language_registry, + cx, + ) + }) + } + }); + + cx.spawn(async move |this, cx| { + let terminal = terminal_task.await?; + this.update(cx, |this, _cx| { + this.terminals.insert(terminal_id, terminal.clone()); + terminal + }) + }) + } + + pub fn kill_terminal( + &mut self, + terminal_id: acp::TerminalId, + cx: &mut Context, + ) -> Result<()> { + self.terminals + .get(&terminal_id) + .context("Terminal not found")? + .update(cx, |terminal, cx| { + terminal.kill(cx); + }); + + Ok(()) + } + + pub fn release_terminal( + &mut self, + terminal_id: acp::TerminalId, + cx: &mut Context, + ) -> Result<()> { + self.terminals + .remove(&terminal_id) + .context("Terminal not found")? + .update(cx, |terminal, cx| { + terminal.kill(cx); + }); + + Ok(()) + } + + pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result> { + self.terminals + .get(&terminal_id) + .context("Terminal not found") + .cloned() + } + pub fn to_markdown(&self, cx: &App) -> String { self.entries.iter().map(|e| e.to_markdown(cx)).collect() } - pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context) { - cx.emit(AcpThreadEvent::ServerExited(status)); + pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context) { + cx.emit(AcpThreadEvent::LoadError(error)); } } @@ -1532,14 +2117,15 @@ mod tests { use super::*; use anyhow::anyhow; use futures::{channel::mpsc, future::LocalBoxFuture, select}; - use gpui::{AsyncApp, TestAppContext, WeakEntity}; + use gpui::{App, AsyncApp, TestAppContext, WeakEntity}; use indoc::indoc; use project::{FakeFs, Fs}; - use rand::Rng as _; + use rand::{distr, prelude::*}; use serde_json::json; use settings::SettingsStore; use smol::stream::StreamExt as _; use std::{ + any::Any, cell::RefCell, path::Path, rc::Rc, @@ -1566,11 +2152,7 @@ mod tests { let project = Project::test(fs, [], cx).await; let connection = 
Rc::new(FakeAgentConnection::new()); let thread = cx - .spawn(async move |mut cx| { - connection - .new_thread(project, Path::new(path!("/test")), &mut cx) - .await - }) + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -1690,11 +2272,7 @@ mod tests { )); let thread = cx - .spawn(async move |mut cx| { - connection - .new_thread(project, Path::new(path!("/test")), &mut cx) - .await - }) + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -1777,7 +2355,7 @@ mod tests { .unwrap(); let thread = cx - .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx)) + .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx)) .await .unwrap(); @@ -1840,11 +2418,7 @@ mod tests { })); let thread = cx - .spawn(async move |mut cx| { - connection - .new_thread(project, Path::new(path!("/test")), &mut cx) - .await - }) + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -1858,10 +2432,7 @@ mod tests { assert!(matches!( thread.entries[1], AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Allowed { - status: acp::ToolCallStatus::InProgress, - .. - }, + status: ToolCallStatus::InProgress, .. }) )); @@ -1900,10 +2471,7 @@ mod tests { assert!(matches!( thread.entries[1], AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Completed, - .. - }, + status: ToolCallStatus::Completed, .. 
}) )); @@ -1952,10 +2520,11 @@ mod tests { } })); - let thread = connection - .new_thread(project, Path::new(path!("/test")), &mut cx.to_async()) + let thread = cx + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) .await .unwrap(); + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx))) .await .unwrap(); @@ -2012,8 +2581,8 @@ mod tests { .boxed_local() } })); - let thread = connection - .new_thread(project, Path::new(path!("/test")), &mut cx.to_async()) + let thread = cx + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -2035,7 +2604,7 @@ mod tests { "} ); }); - assert_eq!(fs.files(), vec![Path::new("/test/file-0")]); + assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]); cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx))) .await @@ -2065,7 +2634,10 @@ mod tests { }); assert_eq!( fs.files(), - vec![Path::new("/test/file-0"), Path::new("/test/file-1")] + vec![ + Path::new(path!("/test/file-0")), + Path::new(path!("/test/file-1")) + ] ); // Checkpoint isn't stored when there are no changes. @@ -2106,7 +2678,10 @@ mod tests { }); assert_eq!( fs.files(), - vec![Path::new("/test/file-0"), Path::new("/test/file-1")] + vec![ + Path::new(path!("/test/file-0")), + Path::new(path!("/test/file-1")) + ] ); // Rewinding the conversation truncates the history and restores the checkpoint. 
@@ -2115,7 +2690,7 @@ mod tests { let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else { panic!("unexpected entries {:?}", thread.entries) }; - thread.rewind(message.id.clone().unwrap(), cx) + thread.restore_checkpoint(message.id.clone().unwrap(), cx) }) .await .unwrap(); @@ -2134,7 +2709,274 @@ mod tests { "} ); }); - assert_eq!(fs.files(), vec![Path::new("/test/file-0")]); + assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]); + } + + #[gpui::test] + async fn test_tool_result_refusal(cx: &mut TestAppContext) { + use std::sync::atomic::AtomicUsize; + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, None, cx).await; + + // Create a connection that simulates refusal after tool result + let prompt_count = Arc::new(AtomicUsize::new(0)); + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + let prompt_count = prompt_count.clone(); + move |_request, thread, mut cx| { + let count = prompt_count.fetch_add(1, SeqCst); + async move { + if count == 0 { + // First prompt: Generate a tool call with result + thread.update(&mut cx, |thread, cx| { + thread + .handle_session_update( + acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId("tool1".into()), + title: "Test Tool".into(), + kind: acp::ToolKind::Fetch, + status: acp::ToolCallStatus::Completed, + content: vec![], + locations: vec![], + raw_input: Some(serde_json::json!({"query": "test"})), + raw_output: Some( + serde_json::json!({"result": "inappropriate content"}), + ), + }), + cx, + ) + .unwrap(); + })?; + + // Now return refusal because of the tool result + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Refusal, + }) + } else { + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + } + .boxed_local() + } + })); + + let thread = cx + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .await + .unwrap(); + + // Track if we see a Refusal event + let 
saw_refusal_event = Arc::new(std::sync::Mutex::new(false)); + let saw_refusal_event_captured = saw_refusal_event.clone(); + thread.update(cx, |_thread, cx| { + cx.subscribe( + &thread, + move |_thread, _event_thread, event: &AcpThreadEvent, _cx| { + if matches!(event, AcpThreadEvent::Refusal) { + *saw_refusal_event_captured.lock().unwrap() = true; + } + }, + ) + .detach(); + }); + + // Send a user message - this will trigger tool call and then refusal + let send_task = thread.update(cx, |thread, cx| { + thread.send( + vec![acp::ContentBlock::Text(acp::TextContent { + text: "Hello".into(), + annotations: None, + })], + cx, + ) + }); + cx.background_executor.spawn(send_task).detach(); + cx.run_until_parked(); + + // Verify that: + // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt) + // 2. The user message was NOT truncated + assert!( + *saw_refusal_event.lock().unwrap(), + "Refusal event should be emitted for tool result refusals" + ); + + thread.read_with(cx, |thread, _| { + let entries = thread.entries(); + assert!(entries.len() >= 2, "Should have user message and tool call"); + + // Verify user message is still there + assert!( + matches!(entries[0], AgentThreadEntry::UserMessage(_)), + "User message should not be truncated" + ); + + // Verify tool call is there with result + if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] { + assert!( + tool_call.raw_output.is_some(), + "Tool call should have output" + ); + } else { + panic!("Expected tool call at index 1"); + } + }); + } + + #[gpui::test] + async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, None, cx).await; + + let refuse_next = Arc::new(AtomicBool::new(false)); + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + let refuse_next = refuse_next.clone(); + move |_request, _thread, _cx| { + if refuse_next.load(SeqCst) { + async move { 
+ Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Refusal, + }) + } + .boxed_local() + } else { + async move { + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + .boxed_local() + } + } + })); + + let thread = cx + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .await + .unwrap(); + + // Track if we see a Refusal event + let saw_refusal_event = Arc::new(std::sync::Mutex::new(false)); + let saw_refusal_event_captured = saw_refusal_event.clone(); + thread.update(cx, |_thread, cx| { + cx.subscribe( + &thread, + move |_thread, _event_thread, event: &AcpThreadEvent, _cx| { + if matches!(event, AcpThreadEvent::Refusal) { + *saw_refusal_event_captured.lock().unwrap() = true; + } + }, + ) + .detach(); + }); + + // Send a message that will be refused + refuse_next.store(true, SeqCst); + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx))) + .await + .unwrap(); + + // Verify that a Refusal event WAS emitted for user prompt refusal + assert!( + *saw_refusal_event.lock().unwrap(), + "Refusal event should be emitted for user prompt refusals" + ); + + // Verify the message was truncated (user prompt refusal) + thread.read_with(cx, |thread, cx| { + assert_eq!(thread.to_markdown(cx), ""); + }); + } + + #[gpui::test] + async fn test_refusal(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree(path!("/"), json!({})).await; + let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await; + + let refuse_next = Arc::new(AtomicBool::new(false)); + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + let refuse_next = refuse_next.clone(); + move |request, thread, mut cx| { + let refuse_next = refuse_next.clone(); + async move { + if refuse_next.load(SeqCst) { + return Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Refusal, + }); + } + + let acp::ContentBlock::Text(content) = 
&request.prompt[0] else { + panic!("expected text content block"); + }; + thread.update(&mut cx, |thread, cx| { + thread + .handle_session_update( + acp::SessionUpdate::AgentMessageChunk { + content: content.text.to_uppercase().into(), + }, + cx, + ) + .unwrap(); + })?; + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + .boxed_local() + } + })); + let thread = cx + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .await + .unwrap(); + + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx))) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User + + hello + + ## Assistant + + HELLO + + "} + ); + }); + + // Simulate refusing the second message. The message should be truncated + // when a user prompt is refused. + refuse_next.store(true, SeqCst); + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx))) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! 
{" + ## User + + hello + + ## Assistant + + HELLO + + "} + ); + }); } async fn run_until_first_tool_call( @@ -2218,19 +3060,32 @@ mod tests { self: Rc, project: Entity, _cwd: &Path, - cx: &mut gpui::AsyncApp, + cx: &mut App, ) -> Task>> { let session_id = acp::SessionId( - rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) + rand::rng() + .sample_iter(&distr::Alphanumeric) .take(7) .map(char::from) .collect::() .into(), ); - let thread = cx - .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)) - .unwrap(); + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let thread = cx.new(|cx| { + AcpThread::new( + "Test", + self.clone(), + project, + action_log, + session_id.clone(), + watch::Receiver::constant(acp::PromptCapabilities { + image: true, + audio: true, + embedded_context: true, + }), + cx, + ) + }); self.sessions.lock().insert(session_id, thread.downgrade()); Task::ready(Ok(thread)) } @@ -2264,7 +3119,7 @@ mod tests { fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { let sessions = self.sessions.lock(); - let thread = sessions.get(&session_id).unwrap().clone(); + let thread = sessions.get(session_id).unwrap().clone(); cx.spawn(async move |cx| { thread @@ -2275,23 +3130,27 @@ mod tests { .detach(); } - fn session_editor( + fn truncate( &self, session_id: &acp::SessionId, - _cx: &mut App, - ) -> Option> { + _cx: &App, + ) -> Option> { Some(Rc::new(FakeAgentSessionEditor { _session_id: session_id.clone(), })) } + + fn into_any(self: Rc) -> Rc { + self + } } struct FakeAgentSessionEditor { _session_id: acp::SessionId, } - impl AgentSessionEditor for FakeAgentSessionEditor { - fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task> { + impl AgentSessionTruncate for FakeAgentSessionEditor { + fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task> { Task::ready(Ok(())) } } diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 
c3167eb2d4fbdf66e7c45f574a227d215d18dca0..dfb1e3763d504e65bfbef636fb8c592643ce92c9 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -2,13 +2,15 @@ use crate::AcpThread; use agent_client_protocol::{self as acp}; use anyhow::Result; use collections::IndexMap; -use gpui::{AsyncApp, Entity, SharedString, Task}; +use gpui::{Entity, SharedString, Task}; +use language_model::LanguageModelProviderId; use project::Project; -use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc}; +use serde::{Deserialize, Serialize}; +use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use ui::{App, IconName}; use uuid::Uuid; -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct UserMessageId(Arc); impl UserMessageId { @@ -22,7 +24,7 @@ pub trait AgentConnection { self: Rc, project: Entity, cwd: &Path, - cx: &mut AsyncApp, + cx: &mut App, ) -> Task>>; fn auth_methods(&self) -> &[acp::AuthMethod]; @@ -36,13 +38,29 @@ pub trait AgentConnection { cx: &mut App, ) -> Task>; + fn resume( + &self, + _session_id: &acp::SessionId, + _cx: &App, + ) -> Option> { + None + } + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); - fn session_editor( + fn truncate( + &self, + _session_id: &acp::SessionId, + _cx: &App, + ) -> Option> { + None + } + + fn set_title( &self, _session_id: &acp::SessionId, - _cx: &mut App, - ) -> Option> { + _cx: &App, + ) -> Option> { None } @@ -53,19 +71,90 @@ pub trait AgentConnection { fn model_selector(&self) -> Option> { None } + + fn telemetry(&self) -> Option> { + None + } + + fn session_modes( + &self, + _session_id: &acp::SessionId, + _cx: &App, + ) -> Option> { + None + } + + fn into_any(self: Rc) -> Rc; +} + +impl dyn AgentConnection { + pub fn downcast(self: Rc) -> Option> { + self.into_any().downcast().ok() + } +} + +pub trait AgentSessionTruncate { + fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task>; +} + 
+pub trait AgentSessionResume { + fn run(&self, cx: &mut App) -> Task>; +} + +pub trait AgentSessionSetTitle { + fn run(&self, title: SharedString, cx: &mut App) -> Task>; +} + +pub trait AgentTelemetry { + /// The name of the agent used for telemetry. + fn agent_name(&self) -> String; + + /// A representation of the current thread state that can be serialized for + /// storage with telemetry events. + fn thread_data( + &self, + session_id: &acp::SessionId, + cx: &mut App, + ) -> Task>; } -pub trait AgentSessionEditor { - fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task>; +pub trait AgentSessionModes { + fn current_mode(&self) -> acp::SessionModeId; + + fn all_modes(&self) -> Vec; + + fn set_mode(&self, mode: acp::SessionModeId, cx: &mut App) -> Task>; } #[derive(Debug)] -pub struct AuthRequired; +pub struct AuthRequired { + pub description: Option, + pub provider_id: Option, +} + +impl AuthRequired { + pub fn new() -> Self { + Self { + description: None, + provider_id: None, + } + } + + pub fn with_description(mut self, description: String) -> Self { + self.description = Some(description); + self + } + + pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self { + self.provider_id = Some(provider_id); + self + } +} impl Error for AuthRequired {} impl fmt::Display for AuthRequired { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "AuthRequired") + write!(f, "Authentication required") } } @@ -160,3 +249,228 @@ impl AgentModelList { } } } + +#[cfg(feature = "test-support")] +mod test_support { + use std::sync::Arc; + + use action_log::ActionLog; + use collections::HashMap; + use futures::{channel::oneshot, future::try_join_all}; + use gpui::{AppContext as _, WeakEntity}; + use parking_lot::Mutex; + + use super::*; + + #[derive(Clone, Default)] + pub struct StubAgentConnection { + sessions: Arc>>, + permission_requests: HashMap>, + next_prompt_updates: Arc>>, + } + + struct Session { + 
thread: WeakEntity, + response_tx: Option>, + } + + impl StubAgentConnection { + pub fn new() -> Self { + Self { + next_prompt_updates: Default::default(), + permission_requests: HashMap::default(), + sessions: Arc::default(), + } + } + + pub fn set_next_prompt_updates(&self, updates: Vec) { + *self.next_prompt_updates.lock() = updates; + } + + pub fn with_permission_requests( + mut self, + permission_requests: HashMap>, + ) -> Self { + self.permission_requests = permission_requests; + self + } + + pub fn send_update( + &self, + session_id: acp::SessionId, + update: acp::SessionUpdate, + cx: &mut App, + ) { + assert!( + self.next_prompt_updates.lock().is_empty(), + "Use either send_update or set_next_prompt_updates" + ); + + self.sessions + .lock() + .get(&session_id) + .unwrap() + .thread + .update(cx, |thread, cx| { + thread.handle_session_update(update, cx).unwrap(); + }) + .unwrap(); + } + + pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) { + self.sessions + .lock() + .get_mut(&session_id) + .unwrap() + .response_tx + .take() + .expect("No pending turn") + .send(stop_reason) + .unwrap(); + } + } + + impl AgentConnection for StubAgentConnection { + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::App, + ) -> Task>> { + let session_id = acp::SessionId(self.sessions.lock().len().to_string().into()); + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let thread = cx.new(|cx| { + AcpThread::new( + "Test", + self.clone(), + project, + action_log, + session_id.clone(), + watch::Receiver::constant(acp::PromptCapabilities { + image: true, + audio: true, + embedded_context: true, + }), + cx, + ) + }); + self.sessions.lock().insert( + session_id, + Session { + thread: thread.downgrade(), + response_tx: None, + }, + ); + Task::ready(Ok(thread)) + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> 
Task> { + unimplemented!() + } + + fn prompt( + &self, + _id: Option, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let mut sessions = self.sessions.lock(); + let Session { + thread, + response_tx, + } = sessions.get_mut(¶ms.session_id).unwrap(); + let mut tasks = vec![]; + if self.next_prompt_updates.lock().is_empty() { + let (tx, rx) = oneshot::channel(); + response_tx.replace(tx); + cx.spawn(async move |_| { + let stop_reason = rx.await?; + Ok(acp::PromptResponse { stop_reason }) + }) + } else { + for update in self.next_prompt_updates.lock().drain(..) { + let thread = thread.clone(); + let update = update.clone(); + let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = + &update + && let Some(options) = self.permission_requests.get(&tool_call.id) + { + Some((tool_call.clone(), options.clone())) + } else { + None + }; + let task = cx.spawn(async move |cx| { + if let Some((tool_call, options)) = permission_request { + thread + .update(cx, |thread, cx| { + thread.request_tool_call_authorization( + tool_call.clone().into(), + options.clone(), + false, + cx, + ) + })?? 
+ .await; + } + thread.update(cx, |thread, cx| { + thread.handle_session_update(update.clone(), cx).unwrap(); + })?; + anyhow::Ok(()) + }); + tasks.push(task); + } + + cx.spawn(async move |_| { + try_join_all(tasks).await?; + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + }) + } + } + + fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { + if let Some(end_turn_tx) = self + .sessions + .lock() + .get_mut(session_id) + .unwrap() + .response_tx + .take() + { + end_turn_tx.send(acp::StopReason::Cancelled).unwrap(); + } + } + + fn truncate( + &self, + _session_id: &agent_client_protocol::SessionId, + _cx: &App, + ) -> Option> { + Some(Rc::new(StubAgentSessionEditor)) + } + + fn into_any(self: Rc) -> Rc { + self + } + } + + struct StubAgentSessionEditor; + + impl AgentSessionTruncate for StubAgentSessionEditor { + fn run(&self, _: UserMessageId, _: &mut App) -> Task> { + Task::ready(Ok(())) + } + } +} + +#[cfg(feature = "test-support")] +pub use test_support::*; diff --git a/crates/acp_thread/src/diff.rs b/crates/acp_thread/src/diff.rs index a2c2d6c3229ae96bf45dfc870e8600a5f778a6f0..f75af0543e373b47b0c6de36760ba18b5d9da318 100644 --- a/crates/acp_thread/src/diff.rs +++ b/crates/acp_thread/src/diff.rs @@ -1,7 +1,6 @@ -use agent_client_protocol as acp; use anyhow::Result; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; -use editor::{MultiBuffer, PathKey}; +use editor::{MultiBuffer, PathKey, multibuffer_context_lines}; use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task}; use itertools::Itertools; use language::{ @@ -21,69 +20,54 @@ pub enum Diff { } impl Diff { - pub fn from_acp( - diff: acp::Diff, + pub fn finalized( + path: PathBuf, + old_text: Option, + new_text: String, language_registry: Arc, cx: &mut Context, ) -> Self { - let acp::Diff { - path, - old_text, - new_text, - } = diff; - let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); - let new_buffer = cx.new(|cx| 
Buffer::local(new_text, cx)); - let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx)); - let new_buffer_snapshot = new_buffer.read(cx).text_snapshot(); - let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx)); - + let base_text = old_text.clone().unwrap_or(String::new()).into(); let task = cx.spawn({ let multibuffer = multibuffer.clone(); let path = path.clone(); + let buffer = new_buffer.clone(); async move |_, cx| { let language = language_registry .language_for_file_path(&path) .await .log_err(); - new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?; - - let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| { - buffer.set_language(language, cx); - buffer.snapshot() - })?; + buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?; - buffer_diff - .update(cx, |diff, cx| { - diff.set_base_text( - old_buffer_snapshot, - Some(language_registry), - new_buffer_snapshot, - cx, - ) - })? - .await?; + let diff = build_buffer_diff( + old_text.unwrap_or("".into()).into(), + &buffer, + Some(language_registry.clone()), + cx, + ) + .await?; multibuffer .update(cx, |multibuffer, cx| { let hunk_ranges = { - let buffer = new_buffer.read(cx); - let diff = buffer_diff.read(cx); - diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) - .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) + let buffer = buffer.read(cx); + let diff = diff.read(cx); + diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, buffer, cx) + .map(|diff_hunk| diff_hunk.buffer_range.to_point(buffer)) .collect::>() }; multibuffer.set_excerpts_for_path( - PathKey::for_buffer(&new_buffer, cx), - new_buffer.clone(), + PathKey::for_buffer(&buffer, cx), + buffer.clone(), hunk_ranges, - editor::DEFAULT_MULTIBUFFER_CONTEXT, + multibuffer_context_lines(cx), cx, ); - multibuffer.add_diff(buffer_diff, cx); + multibuffer.add_diff(diff, cx); }) .log_err(); @@ -94,23 +78,26 @@ impl Diff { 
Self::Finalized(FinalizedDiff { multibuffer, path, + base_text, + new_buffer, _update_diff: task, }) } pub fn new(buffer: Entity, cx: &mut Context) -> Self { - let buffer_snapshot = buffer.read(cx).snapshot(); - let base_text = buffer_snapshot.text(); - let language_registry = buffer.read(cx).language_registry(); - let text_snapshot = buffer.read(cx).text_snapshot(); + let buffer_text_snapshot = buffer.read(cx).text_snapshot(); + let base_text_snapshot = buffer.read(cx).snapshot(); + let base_text = base_text_snapshot.text(); + debug_assert_eq!(buffer_text_snapshot.text(), base_text); let buffer_diff = cx.new(|cx| { - let mut diff = BufferDiff::new(&text_snapshot, cx); - let _ = diff.set_base_text( - buffer_snapshot.clone(), - language_registry, - text_snapshot, - cx, - ); + let mut diff = BufferDiff::new_unchanged(&buffer_text_snapshot, base_text_snapshot); + let snapshot = diff.snapshot(cx); + let secondary_diff = cx.new(|cx| { + let mut diff = BufferDiff::new(&buffer_text_snapshot, cx); + diff.set_snapshot(snapshot, &buffer_text_snapshot, cx); + diff + }); + diff.set_secondary_diff(secondary_diff); diff }); @@ -128,7 +115,7 @@ impl Diff { diff.update(cx); } }), - buffer, + new_buffer: buffer, diff: buffer_diff, revealed_ranges: Vec::new(), update_diff: Task::ready(Ok(())), @@ -163,9 +150,9 @@ impl Diff { .map(|buffer| buffer.read(cx).text()) .join("\n"); let path = match self { - Diff::Pending(PendingDiff { buffer, .. }) => { - buffer.read(cx).file().map(|file| file.path().as_ref()) - } + Diff::Pending(PendingDiff { + new_buffer: buffer, .. + }) => buffer.read(cx).file().map(|file| file.path().as_ref()), Diff::Finalized(FinalizedDiff { path, .. 
}) => Some(path.as_path()), }; format!( @@ -178,12 +165,33 @@ impl Diff { pub fn has_revealed_range(&self, cx: &App) -> bool { self.multibuffer().read(cx).excerpt_paths().next().is_some() } + + pub fn needs_update(&self, old_text: &str, new_text: &str, cx: &App) -> bool { + match self { + Diff::Pending(PendingDiff { + base_text, + new_buffer, + .. + }) => { + base_text.as_str() != old_text + || !new_buffer.read(cx).as_rope().chunks().equals_str(new_text) + } + Diff::Finalized(FinalizedDiff { + base_text, + new_buffer, + .. + }) => { + base_text.as_str() != old_text + || !new_buffer.read(cx).as_rope().chunks().equals_str(new_text) + } + } + } } pub struct PendingDiff { multibuffer: Entity, base_text: Arc, - buffer: Entity, + new_buffer: Entity, diff: Entity, revealed_ranges: Vec>, _subscription: Subscription, @@ -192,7 +200,7 @@ pub struct PendingDiff { impl PendingDiff { pub fn update(&mut self, cx: &mut Context) { - let buffer = self.buffer.clone(); + let buffer = self.new_buffer.clone(); let buffer_diff = self.diff.clone(); let base_text = self.base_text.clone(); self.update_diff = cx.spawn(async move |diff, cx| { @@ -209,7 +217,10 @@ impl PendingDiff { ) .await?; buffer_diff.update(cx, |diff, cx| { - diff.set_snapshot(diff_snapshot, &text_snapshot, cx) + diff.set_snapshot(diff_snapshot.clone(), &text_snapshot, cx); + diff.secondary_diff().unwrap().update(cx, |diff, cx| { + diff.set_snapshot(diff_snapshot.clone(), &text_snapshot, cx); + }); })?; diff.update(cx, |diff, cx| { if let Diff::Pending(diff) = diff { @@ -227,10 +238,10 @@ impl PendingDiff { fn finalize(&self, cx: &mut Context) -> FinalizedDiff { let ranges = self.excerpt_ranges(cx); let base_text = self.base_text.clone(); - let language_registry = self.buffer.read(cx).language_registry().clone(); + let language_registry = self.new_buffer.read(cx).language_registry(); let path = self - .buffer + .new_buffer .read(cx) .file() .map(|file| file.path().as_ref()) @@ -239,12 +250,12 @@ impl PendingDiff { // 
Replace the buffer in the multibuffer with the snapshot let buffer = cx.new(|cx| { - let language = self.buffer.read(cx).language().cloned(); + let language = self.new_buffer.read(cx).language().cloned(); let buffer = TextBuffer::new_normalized( 0, cx.entity_id().as_non_zero_u64().into(), - self.buffer.read(cx).line_ending(), - self.buffer.read(cx).as_rope().clone(), + self.new_buffer.read(cx).line_ending(), + self.new_buffer.read(cx).as_rope().clone(), ); let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite); buffer.set_language(language, cx); @@ -253,7 +264,6 @@ impl PendingDiff { let buffer_diff = cx.spawn({ let buffer = buffer.clone(); - let language_registry = language_registry.clone(); async move |_this, cx| { build_buffer_diff(base_text, &buffer, language_registry, cx).await } @@ -269,7 +279,7 @@ impl PendingDiff { path_key, buffer, ranges, - editor::DEFAULT_MULTIBUFFER_CONTEXT, + multibuffer_context_lines(cx), cx, ); multibuffer.add_diff(buffer_diff.clone(), cx); @@ -281,7 +291,9 @@ impl PendingDiff { FinalizedDiff { path, + base_text: self.base_text.clone(), multibuffer: self.multibuffer.clone(), + new_buffer: self.new_buffer.clone(), _update_diff: update_diff, } } @@ -290,10 +302,10 @@ impl PendingDiff { let ranges = self.excerpt_ranges(cx); self.multibuffer.update(cx, |multibuffer, cx| { multibuffer.set_excerpts_for_path( - PathKey::for_buffer(&self.buffer, cx), - self.buffer.clone(), + PathKey::for_buffer(&self.new_buffer, cx), + self.new_buffer.clone(), ranges, - editor::DEFAULT_MULTIBUFFER_CONTEXT, + multibuffer_context_lines(cx), cx, ); let end = multibuffer.len(cx); @@ -303,16 +315,16 @@ impl PendingDiff { } fn excerpt_ranges(&self, cx: &App) -> Vec> { - let buffer = self.buffer.read(cx); + let buffer = self.new_buffer.read(cx); let diff = self.diff.read(cx); let mut ranges = diff - .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) - .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) + 
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, buffer, cx) + .map(|diff_hunk| diff_hunk.buffer_range.to_point(buffer)) .collect::>(); ranges.extend( self.revealed_ranges .iter() - .map(|range| range.to_point(&buffer)), + .map(|range| range.to_point(buffer)), ); ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end))); @@ -337,6 +349,8 @@ impl PendingDiff { pub struct FinalizedDiff { path: PathBuf, + base_text: Arc, + new_buffer: Entity, multibuffer: Entity, _update_diff: Task>, } @@ -390,3 +404,21 @@ async fn build_buffer_diff( diff }) } + +#[cfg(test)] +mod tests { + use gpui::{AppContext as _, TestAppContext}; + use language::Buffer; + + use crate::Diff; + + #[gpui::test] + async fn test_pending_diff(cx: &mut TestAppContext) { + let buffer = cx.new(|cx| Buffer::local("hello!", cx)); + let _diff = cx.new(|cx| Diff::new(buffer.clone(), cx)); + buffer.update(cx, |buffer, cx| { + buffer.set_text("HELLO!", cx); + }); + cx.run_until_parked(); + } +} diff --git a/crates/acp_thread/src/mention.rs b/crates/acp_thread/src/mention.rs index 03174608fb0187687fb987ea640277baa25e01a2..6fa0887e2278467dae9887516d882da90a78d0df 100644 --- a/crates/acp_thread/src/mention.rs +++ b/crates/acp_thread/src/mention.rs @@ -1,23 +1,33 @@ -use agent::ThreadId; +use agent_client_protocol as acp; use anyhow::{Context as _, Result, bail}; +use file_icons::FileIcons; use prompt_store::{PromptId, UserPromptId}; +use serde::{Deserialize, Serialize}; use std::{ fmt, - ops::Range, + ops::RangeInclusive, path::{Path, PathBuf}, + str::FromStr, }; +use ui::{App, IconName, SharedString}; use url::Url; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum MentionUri { - File(PathBuf), + File { + abs_path: PathBuf, + }, + PastedImage, + Directory { + abs_path: PathBuf, + }, Symbol { - path: PathBuf, + abs_path: PathBuf, name: String, - line_range: Range, + line_range: RangeInclusive, }, Thread { - id: ThreadId, + id: 
acp::SessionId, name: String, }, TextThread { @@ -29,8 +39,9 @@ pub enum MentionUri { name: String, }, Selection { - path: PathBuf, - line_range: Range, + #[serde(default, skip_serializing_if = "Option::is_none")] + abs_path: Option, + line_range: RangeInclusive, }, Fetch { url: Url, @@ -39,51 +50,56 @@ pub enum MentionUri { impl MentionUri { pub fn parse(input: &str) -> Result { + fn parse_line_range(fragment: &str) -> Result> { + let range = fragment + .strip_prefix("L") + .context("Line range must start with \"L\"")?; + let (start, end) = range + .split_once(":") + .context("Line range must use colon as separator")?; + let range = start + .parse::() + .context("Parsing line range start")? + .checked_sub(1) + .context("Line numbers should be 1-based")? + ..=end + .parse::() + .context("Parsing line range end")? + .checked_sub(1) + .context("Line numbers should be 1-based")?; + Ok(range) + } + let url = url::Url::parse(input)?; let path = url.path(); match url.scheme() { "file" => { + let path = url.to_file_path().ok().context("Extracting file path")?; if let Some(fragment) = url.fragment() { - let range = fragment - .strip_prefix("L") - .context("Line range must start with \"L\"")?; - let (start, end) = range - .split_once(":") - .context("Line range must use colon as separator")?; - let line_range = start - .parse::() - .context("Parsing line range start")? - .checked_sub(1) - .context("Line numbers should be 1-based")? - ..end - .parse::() - .context("Parsing line range end")? - .checked_sub(1) - .context("Line numbers should be 1-based")?; + let line_range = parse_line_range(fragment)?; if let Some(name) = single_query_param(&url, "symbol")? 
{ Ok(Self::Symbol { name, - path: path.into(), + abs_path: path, line_range, }) } else { Ok(Self::Selection { - path: path.into(), + abs_path: Some(path), line_range, }) } + } else if input.ends_with("/") { + Ok(Self::Directory { abs_path: path }) } else { - let file_path = - PathBuf::from(format!("{}{}", url.host_str().unwrap_or(""), path)); - - Ok(Self::File(file_path)) + Ok(Self::File { abs_path: path }) } } "zed" => { if let Some(thread_id) = path.strip_prefix("/agent/thread/") { let name = single_query_param(&url, "name")?.context("Missing thread name")?; Ok(Self::Thread { - id: thread_id.into(), + id: acp::SessionId(thread_id.into()), name, }) } else if let Some(path) = path.strip_prefix("/agent/text-thread/") { @@ -99,6 +115,17 @@ impl MentionUri { id: rule_id.into(), name, }) + } else if path.starts_with("/agent/pasted-image") { + Ok(Self::PastedImage) + } else if path.starts_with("/agent/untitled-buffer") { + let fragment = url + .fragment() + .context("Missing fragment for untitled buffer selection")?; + let line_range = parse_line_range(fragment)?; + Ok(Self::Selection { + abs_path: None, + line_range, + }) } else { bail!("invalid zed url: {:?}", input); } @@ -108,57 +135,87 @@ impl MentionUri { } } - fn name(&self) -> String { + pub fn name(&self) -> String { match self { - MentionUri::File(path) => path + MentionUri::File { abs_path, .. } | MentionUri::Directory { abs_path, .. } => abs_path .file_name() .unwrap_or_default() .to_string_lossy() .into_owned(), + MentionUri::PastedImage => "Image".to_string(), MentionUri::Symbol { name, .. } => name.clone(), MentionUri::Thread { name, .. } => name.clone(), MentionUri::TextThread { name, .. } => name.clone(), MentionUri::Rule { name, .. } => name.clone(), MentionUri::Selection { - path, line_range, .. - } => selection_name(path, line_range), + abs_path: path, + line_range, + .. 
+ } => selection_name(path.as_deref(), line_range), MentionUri::Fetch { url } => url.to_string(), } } + pub fn icon_path(&self, cx: &mut App) -> SharedString { + match self { + MentionUri::File { abs_path } => { + FileIcons::get_icon(abs_path, cx).unwrap_or_else(|| IconName::File.path().into()) + } + MentionUri::PastedImage => IconName::Image.path().into(), + MentionUri::Directory { .. } => FileIcons::get_folder_icon(false, cx) + .unwrap_or_else(|| IconName::Folder.path().into()), + MentionUri::Symbol { .. } => IconName::Code.path().into(), + MentionUri::Thread { .. } => IconName::Thread.path().into(), + MentionUri::TextThread { .. } => IconName::Thread.path().into(), + MentionUri::Rule { .. } => IconName::Reader.path().into(), + MentionUri::Selection { .. } => IconName::Reader.path().into(), + MentionUri::Fetch { .. } => IconName::ToolWeb.path().into(), + } + } + pub fn as_link<'a>(&'a self) -> MentionLink<'a> { MentionLink(self) } pub fn to_uri(&self) -> Url { match self { - MentionUri::File(path) => { - let mut url = Url::parse("file:///").unwrap(); - url.set_path(&path.to_string_lossy()); - url + MentionUri::File { abs_path } => { + Url::from_file_path(abs_path).expect("mention path should be absolute") + } + MentionUri::PastedImage => Url::parse("zed:///agent/pasted-image").unwrap(), + MentionUri::Directory { abs_path } => { + Url::from_directory_path(abs_path).expect("mention path should be absolute") } MentionUri::Symbol { - path, + abs_path, name, line_range, } => { - let mut url = Url::parse("file:///").unwrap(); - url.set_path(&path.to_string_lossy()); + let mut url = + Url::from_file_path(abs_path).expect("mention path should be absolute"); url.query_pairs_mut().append_pair("symbol", name); url.set_fragment(Some(&format!( "L{}:{}", - line_range.start + 1, - line_range.end + 1 + line_range.start() + 1, + line_range.end() + 1 ))); url } - MentionUri::Selection { path, line_range } => { - let mut url = Url::parse("file:///").unwrap(); - 
url.set_path(&path.to_string_lossy()); + MentionUri::Selection { + abs_path: path, + line_range, + } => { + let mut url = if let Some(path) = path { + Url::from_file_path(path).expect("mention path should be absolute") + } else { + let mut url = Url::parse("zed:///").unwrap(); + url.set_path("/agent/untitled-buffer"); + url + }; url.set_fragment(Some(&format!( "L{}:{}", - line_range.start + 1, - line_range.end + 1 + line_range.start() + 1, + line_range.end() + 1 ))); url } @@ -170,7 +227,10 @@ impl MentionUri { } MentionUri::TextThread { path, name } => { let mut url = Url::parse("zed:///").unwrap(); - url.set_path(&format!("/agent/text-thread/{}", path.to_string_lossy())); + url.set_path(&format!( + "/agent/text-thread/{}", + path.to_string_lossy().trim_start_matches('/') + )); url.query_pairs_mut().append_pair("name", name); url } @@ -185,6 +245,14 @@ impl MentionUri { } } +impl FromStr for MentionUri { + type Err = anyhow::Error; + + fn from_str(s: &str) -> anyhow::Result { + Self::parse(s) + } +} + pub struct MentionLink<'a>(&'a MentionUri); impl fmt::Display for MentionLink<'_> { @@ -208,44 +276,81 @@ fn single_query_param(url: &Url, name: &'static str) -> Result> { } } -pub fn selection_name(path: &Path, line_range: &Range) -> String { +pub fn selection_name(path: Option<&Path>, line_range: &RangeInclusive) -> String { format!( "{} ({}:{})", - path.file_name().unwrap_or_default().display(), - line_range.start + 1, - line_range.end + 1 + path.and_then(|path| path.file_name()) + .unwrap_or("Untitled".as_ref()) + .display(), + *line_range.start() + 1, + *line_range.end() + 1 ) } #[cfg(test)] mod tests { + use util::{path, uri}; + use super::*; #[test] fn test_parse_file_uri() { - let file_uri = "file:///path/to/file.rs"; + let file_uri = uri!("file:///path/to/file.rs"); let parsed = MentionUri::parse(file_uri).unwrap(); match &parsed { - MentionUri::File(path) => assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"), + MentionUri::File { abs_path } => { + 
assert_eq!(abs_path.to_str().unwrap(), path!("/path/to/file.rs")); + } _ => panic!("Expected File variant"), } assert_eq!(parsed.to_uri().to_string(), file_uri); } + #[test] + fn test_parse_directory_uri() { + let file_uri = uri!("file:///path/to/dir/"); + let parsed = MentionUri::parse(file_uri).unwrap(); + match &parsed { + MentionUri::Directory { abs_path } => { + assert_eq!(abs_path.to_str().unwrap(), path!("/path/to/dir/")); + } + _ => panic!("Expected Directory variant"), + } + assert_eq!(parsed.to_uri().to_string(), file_uri); + } + + #[test] + fn test_to_directory_uri_with_slash() { + let uri = MentionUri::Directory { + abs_path: PathBuf::from(path!("/path/to/dir/")), + }; + let expected = uri!("file:///path/to/dir/"); + assert_eq!(uri.to_uri().to_string(), expected); + } + + #[test] + fn test_to_directory_uri_without_slash() { + let uri = MentionUri::Directory { + abs_path: PathBuf::from(path!("/path/to/dir")), + }; + let expected = uri!("file:///path/to/dir/"); + assert_eq!(uri.to_uri().to_string(), expected); + } + #[test] fn test_parse_symbol_uri() { - let symbol_uri = "file:///path/to/file.rs?symbol=MySymbol#L10:20"; + let symbol_uri = uri!("file:///path/to/file.rs?symbol=MySymbol#L10:20"); let parsed = MentionUri::parse(symbol_uri).unwrap(); match &parsed { MentionUri::Symbol { - path, + abs_path: path, name, line_range, } => { - assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"); + assert_eq!(path.to_str().unwrap(), path!("/path/to/file.rs")); assert_eq!(name, "MySymbol"); - assert_eq!(line_range.start, 9); - assert_eq!(line_range.end, 19); + assert_eq!(line_range.start(), &9); + assert_eq!(line_range.end(), &19); } _ => panic!("Expected Symbol variant"), } @@ -254,19 +359,42 @@ mod tests { #[test] fn test_parse_selection_uri() { - let selection_uri = "file:///path/to/file.rs#L5:15"; + let selection_uri = uri!("file:///path/to/file.rs#L5:15"); let parsed = MentionUri::parse(selection_uri).unwrap(); match &parsed { - MentionUri::Selection { path, 
line_range } => { - assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"); - assert_eq!(line_range.start, 4); - assert_eq!(line_range.end, 14); + MentionUri::Selection { + abs_path: path, + line_range, + } => { + assert_eq!( + path.as_ref().unwrap().to_str().unwrap(), + path!("/path/to/file.rs") + ); + assert_eq!(line_range.start(), &4); + assert_eq!(line_range.end(), &14); } _ => panic!("Expected Selection variant"), } assert_eq!(parsed.to_uri().to_string(), selection_uri); } + #[test] + fn test_parse_untitled_selection_uri() { + let selection_uri = uri!("zed:///agent/untitled-buffer#L1:10"); + let parsed = MentionUri::parse(selection_uri).unwrap(); + match &parsed { + MentionUri::Selection { + abs_path: None, + line_range, + } => { + assert_eq!(line_range.start(), &0); + assert_eq!(line_range.end(), &9); + } + _ => panic!("Expected Selection variant without path"), + } + assert_eq!(parsed.to_uri().to_string(), selection_uri); + } + #[test] fn test_parse_thread_uri() { let thread_uri = "zed:///agent/thread/session123?name=Thread+name"; @@ -340,32 +468,35 @@ mod tests { #[test] fn test_invalid_line_range_format() { // Missing L prefix - assert!(MentionUri::parse("file:///path/to/file.rs#10:20").is_err()); + assert!(MentionUri::parse(uri!("file:///path/to/file.rs#10:20")).is_err()); // Missing colon separator - assert!(MentionUri::parse("file:///path/to/file.rs#L1020").is_err()); + assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L1020")).is_err()); // Invalid numbers - assert!(MentionUri::parse("file:///path/to/file.rs#L10:abc").is_err()); - assert!(MentionUri::parse("file:///path/to/file.rs#Labc:20").is_err()); + assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L10:abc")).is_err()); + assert!(MentionUri::parse(uri!("file:///path/to/file.rs#Labc:20")).is_err()); } #[test] fn test_invalid_query_parameters() { // Invalid query parameter name - assert!(MentionUri::parse("file:///path/to/file.rs#L10:20?invalid=test").is_err()); + 
assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L10:20?invalid=test")).is_err()); // Too many query parameters assert!( - MentionUri::parse("file:///path/to/file.rs#L10:20?symbol=test&another=param").is_err() + MentionUri::parse(uri!( + "file:///path/to/file.rs#L10:20?symbol=test&another=param" + )) + .is_err() ); } #[test] fn test_zero_based_line_numbers() { // Test that 0-based line numbers are rejected (should be 1-based) - assert!(MentionUri::parse("file:///path/to/file.rs#L0:10").is_err()); - assert!(MentionUri::parse("file:///path/to/file.rs#L1:0").is_err()); - assert!(MentionUri::parse("file:///path/to/file.rs#L0:0").is_err()); + assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L0:10")).is_err()); + assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L1:0")).is_err()); + assert!(MentionUri::parse(uri!("file:///path/to/file.rs#L0:0")).is_err()); } } diff --git a/crates/acp_thread/src/terminal.rs b/crates/acp_thread/src/terminal.rs index 41d7fb89bb2eb59207bf0a6557129a088b435f3a..6b4cdb73469d9dd7d1a1759bf3aa28d005d1f13e 100644 --- a/crates/acp_thread/src/terminal.rs +++ b/crates/acp_thread/src/terminal.rs @@ -1,34 +1,43 @@ -use gpui::{App, AppContext, Context, Entity}; +use agent_client_protocol as acp; + +use futures::{FutureExt as _, future::Shared}; +use gpui::{App, AppContext, Context, Entity, Task}; use language::LanguageRegistry; use markdown::Markdown; use std::{path::PathBuf, process::ExitStatus, sync::Arc, time::Instant}; pub struct Terminal { + id: acp::TerminalId, command: Entity, working_dir: Option, terminal: Entity, started_at: Instant, output: Option, + output_byte_limit: Option, + _output_task: Shared>, } pub struct TerminalOutput { pub ended_at: Instant, pub exit_status: Option, - pub was_content_truncated: bool, + pub content: String, pub original_content_len: usize, pub content_line_count: usize, - pub finished_with_empty_output: bool, } impl Terminal { pub fn new( + id: acp::TerminalId, command: String, working_dir: 
Option, + output_byte_limit: Option, terminal: Entity, language_registry: Arc, cx: &mut Context, ) -> Self { + let command_task = terminal.read(cx).wait_for_completed_task(cx); Self { + id, command: cx.new(|cx| { Markdown::new( format!("```\n{}\n```", command).into(), @@ -41,27 +50,93 @@ impl Terminal { terminal, started_at: Instant::now(), output: None, + output_byte_limit, + _output_task: cx + .spawn(async move |this, cx| { + let exit_status = command_task.await; + + this.update(cx, |this, cx| { + let (content, original_content_len) = this.truncated_output(cx); + let content_line_count = this.terminal.read(cx).total_lines(); + + this.output = Some(TerminalOutput { + ended_at: Instant::now(), + exit_status, + content, + original_content_len, + content_line_count, + }); + cx.notify(); + }) + .ok(); + + let exit_status = exit_status.map(portable_pty::ExitStatus::from); + + acp::TerminalExitStatus { + exit_code: exit_status.as_ref().map(|e| e.exit_code()), + signal: exit_status.and_then(|e| e.signal().map(Into::into)), + } + }) + .shared(), } } - pub fn finish( - &mut self, - exit_status: Option, - original_content_len: usize, - truncated_content_len: usize, - content_line_count: usize, - finished_with_empty_output: bool, - cx: &mut Context, - ) { - self.output = Some(TerminalOutput { - ended_at: Instant::now(), - exit_status, - was_content_truncated: truncated_content_len < original_content_len, - original_content_len, - content_line_count, - finished_with_empty_output, + pub fn id(&self) -> &acp::TerminalId { + &self.id + } + + pub fn wait_for_exit(&self) -> Shared> { + self._output_task.clone() + } + + pub fn kill(&mut self, cx: &mut App) { + self.terminal.update(cx, |terminal, _cx| { + terminal.kill_active_task(); }); - cx.notify(); + } + + pub fn current_output(&self, cx: &App) -> acp::TerminalOutputResponse { + if let Some(output) = self.output.as_ref() { + let exit_status = output.exit_status.map(portable_pty::ExitStatus::from); + + acp::TerminalOutputResponse 
{ + output: output.content.clone(), + truncated: output.original_content_len > output.content.len(), + exit_status: Some(acp::TerminalExitStatus { + exit_code: exit_status.as_ref().map(|e| e.exit_code()), + signal: exit_status.and_then(|e| e.signal().map(Into::into)), + }), + } + } else { + let (current_content, original_len) = self.truncated_output(cx); + + acp::TerminalOutputResponse { + truncated: current_content.len() < original_len, + output: current_content, + exit_status: None, + } + } + } + + fn truncated_output(&self, cx: &App) -> (String, usize) { + let terminal = self.terminal.read(cx); + let mut content = terminal.get_content(); + + let original_content_len = content.len(); + + if let Some(limit) = self.output_byte_limit + && content.len() > limit + { + let mut end_ix = limit.min(content.len()); + while !content.is_char_boundary(end_ix) { + end_ix -= 1; + } + // Don't truncate mid-line, clear the remainder of the last line + end_ix = content[..end_ix].rfind('\n').unwrap_or(end_ix); + content.truncate(end_ix); + } + + (content, original_content_len) } pub fn command(&self) -> &Entity { diff --git a/crates/acp_tools/Cargo.toml b/crates/acp_tools/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..7a6d8c21a096364a8468671f4186048559ec8a61 --- /dev/null +++ b/crates/acp_tools/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "acp_tools" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + + +[lints] +workspace = true + +[lib] +path = "src/acp_tools.rs" +doctest = false + +[dependencies] +agent-client-protocol.workspace = true +collections.workspace = true +gpui.workspace = true +language.workspace= true +markdown.workspace = true +project.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +theme.workspace = true +ui.workspace = true +util.workspace = true +workspace-hack.workspace = true +workspace.workspace = true diff --git 
a/crates/indexed_docs/LICENSE-GPL b/crates/acp_tools/LICENSE-GPL similarity index 100% rename from crates/indexed_docs/LICENSE-GPL rename to crates/acp_tools/LICENSE-GPL diff --git a/crates/acp_tools/src/acp_tools.rs b/crates/acp_tools/src/acp_tools.rs new file mode 100644 index 0000000000000000000000000000000000000000..e20a040e9da70a40066f3e5534171818de34a936 --- /dev/null +++ b/crates/acp_tools/src/acp_tools.rs @@ -0,0 +1,494 @@ +use std::{ + cell::RefCell, + collections::HashSet, + fmt::Display, + rc::{Rc, Weak}, + sync::Arc, +}; + +use agent_client_protocol as acp; +use collections::HashMap; +use gpui::{ + App, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ListAlignment, ListState, + StyleRefinement, Subscription, Task, TextStyleRefinement, Window, actions, list, prelude::*, +}; +use language::LanguageRegistry; +use markdown::{CodeBlockRenderer, Markdown, MarkdownElement, MarkdownStyle}; +use project::Project; +use settings::Settings; +use theme::ThemeSettings; +use ui::prelude::*; +use util::ResultExt as _; +use workspace::{Item, Workspace}; + +actions!(dev, [OpenAcpLogs]); + +pub fn init(cx: &mut App) { + cx.observe_new( + |workspace: &mut Workspace, _window, _cx: &mut Context| { + workspace.register_action(|workspace, _: &OpenAcpLogs, window, cx| { + let acp_tools = + Box::new(cx.new(|cx| AcpTools::new(workspace.project().clone(), cx))); + workspace.add_item_to_active_pane(acp_tools, None, true, window, cx); + }); + }, + ) + .detach(); +} + +struct GlobalAcpConnectionRegistry(Entity); + +impl Global for GlobalAcpConnectionRegistry {} + +#[derive(Default)] +pub struct AcpConnectionRegistry { + active_connection: RefCell>, +} + +struct ActiveConnection { + server_name: SharedString, + connection: Weak, +} + +impl AcpConnectionRegistry { + pub fn default_global(cx: &mut App) -> Entity { + if cx.has_global::() { + cx.global::().0.clone() + } else { + let registry = cx.new(|_cx| AcpConnectionRegistry::default()); + 
cx.set_global(GlobalAcpConnectionRegistry(registry.clone())); + registry + } + } + + pub fn set_active_connection( + &self, + server_name: impl Into, + connection: &Rc, + cx: &mut Context, + ) { + self.active_connection.replace(Some(ActiveConnection { + server_name: server_name.into(), + connection: Rc::downgrade(connection), + })); + cx.notify(); + } +} + +struct AcpTools { + project: Entity, + focus_handle: FocusHandle, + expanded: HashSet, + watched_connection: Option, + connection_registry: Entity, + _subscription: Subscription, +} + +struct WatchedConnection { + server_name: SharedString, + messages: Vec, + list_state: ListState, + connection: Weak, + incoming_request_methods: HashMap>, + outgoing_request_methods: HashMap>, + _task: Task<()>, +} + +impl AcpTools { + fn new(project: Entity, cx: &mut Context) -> Self { + let connection_registry = AcpConnectionRegistry::default_global(cx); + + let subscription = cx.observe(&connection_registry, |this, _, cx| { + this.update_connection(cx); + cx.notify(); + }); + + let mut this = Self { + project, + focus_handle: cx.focus_handle(), + expanded: HashSet::default(), + watched_connection: None, + connection_registry, + _subscription: subscription, + }; + this.update_connection(cx); + this + } + + fn update_connection(&mut self, cx: &mut Context) { + let active_connection = self.connection_registry.read(cx).active_connection.borrow(); + let Some(active_connection) = active_connection.as_ref() else { + return; + }; + + if let Some(watched_connection) = self.watched_connection.as_ref() { + if Weak::ptr_eq( + &watched_connection.connection, + &active_connection.connection, + ) { + return; + } + } + + if let Some(connection) = active_connection.connection.upgrade() { + let mut receiver = connection.subscribe(); + let task = cx.spawn(async move |this, cx| { + while let Ok(message) = receiver.recv().await { + this.update(cx, |this, cx| { + this.push_stream_message(message, cx); + }) + .ok(); + } + }); + + 
self.watched_connection = Some(WatchedConnection { + server_name: active_connection.server_name.clone(), + messages: vec![], + list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)), + connection: active_connection.connection.clone(), + incoming_request_methods: HashMap::default(), + outgoing_request_methods: HashMap::default(), + _task: task, + }); + } + } + + fn push_stream_message(&mut self, stream_message: acp::StreamMessage, cx: &mut Context) { + let Some(connection) = self.watched_connection.as_mut() else { + return; + }; + let language_registry = self.project.read(cx).languages().clone(); + let index = connection.messages.len(); + + let (request_id, method, message_type, params) = match stream_message.message { + acp::StreamMessageContent::Request { id, method, params } => { + let method_map = match stream_message.direction { + acp::StreamMessageDirection::Incoming => { + &mut connection.incoming_request_methods + } + acp::StreamMessageDirection::Outgoing => { + &mut connection.outgoing_request_methods + } + }; + + method_map.insert(id, method.clone()); + (Some(id), method.into(), MessageType::Request, Ok(params)) + } + acp::StreamMessageContent::Response { id, result } => { + let method_map = match stream_message.direction { + acp::StreamMessageDirection::Incoming => { + &mut connection.outgoing_request_methods + } + acp::StreamMessageDirection::Outgoing => { + &mut connection.incoming_request_methods + } + }; + + if let Some(method) = method_map.remove(&id) { + (Some(id), method.into(), MessageType::Response, result) + } else { + ( + Some(id), + "[unrecognized response]".into(), + MessageType::Response, + result, + ) + } + } + acp::StreamMessageContent::Notification { method, params } => { + (None, method.into(), MessageType::Notification, Ok(params)) + } + }; + + let message = WatchedConnectionMessage { + name: method, + message_type, + request_id, + direction: stream_message.direction, + collapsed_params_md: match params.as_ref() { + Ok(params) 
=> params + .as_ref() + .map(|params| collapsed_params_md(params, &language_registry, cx)), + Err(err) => { + if let Ok(err) = &serde_json::to_value(err) { + Some(collapsed_params_md(&err, &language_registry, cx)) + } else { + None + } + } + }, + + expanded_params_md: None, + params, + }; + + connection.messages.push(message); + connection.list_state.splice(index..index, 1); + cx.notify(); + } + + fn render_message( + &mut self, + index: usize, + window: &mut Window, + cx: &mut Context, + ) -> AnyElement { + let Some(connection) = self.watched_connection.as_ref() else { + return Empty.into_any(); + }; + + let Some(message) = connection.messages.get(index) else { + return Empty.into_any(); + }; + + let base_size = TextSize::Editor.rems(cx); + + let theme_settings = ThemeSettings::get_global(cx); + let text_style = window.text_style(); + + let colors = cx.theme().colors(); + let expanded = self.expanded.contains(&index); + + v_flex() + .w_full() + .px_4() + .py_3() + .border_color(colors.border) + .border_b_1() + .gap_2() + .items_start() + .font_buffer(cx) + .text_size(base_size) + .id(index) + .group("message") + .hover(|this| this.bg(colors.element_background.opacity(0.5))) + .on_click(cx.listener(move |this, _, _, cx| { + if this.expanded.contains(&index) { + this.expanded.remove(&index); + } else { + this.expanded.insert(index); + let Some(connection) = &mut this.watched_connection else { + return; + }; + let Some(message) = connection.messages.get_mut(index) else { + return; + }; + message.expanded(this.project.read(cx).languages().clone(), cx); + connection.list_state.scroll_to_reveal_item(index); + } + cx.notify() + })) + .child( + h_flex() + .w_full() + .gap_2() + .items_center() + .flex_shrink_0() + .child(match message.direction { + acp::StreamMessageDirection::Incoming => { + ui::Icon::new(ui::IconName::ArrowDown).color(Color::Error) + } + acp::StreamMessageDirection::Outgoing => { + ui::Icon::new(ui::IconName::ArrowUp).color(Color::Success) + } + }) + 
.child( + Label::new(message.name.clone()) + .buffer_font(cx) + .color(Color::Muted), + ) + .child(div().flex_1()) + .child( + div() + .child(ui::Chip::new(message.message_type.to_string())) + .visible_on_hover("message"), + ) + .children( + message + .request_id + .map(|req_id| div().child(ui::Chip::new(req_id.to_string()))), + ), + ) + // I'm aware using markdown is a hack. Trying to get something working for the demo. + // Will clean up soon! + .when_some( + if expanded { + message.expanded_params_md.clone() + } else { + message.collapsed_params_md.clone() + }, + |this, params| { + this.child( + div().pl_6().w_full().child( + MarkdownElement::new( + params, + MarkdownStyle { + base_text_style: text_style, + selection_background_color: colors.element_selection_background, + syntax: cx.theme().syntax().clone(), + code_block_overflow_x_scroll: true, + code_block: StyleRefinement { + text: Some(TextStyleRefinement { + font_family: Some( + theme_settings.buffer_font.family.clone(), + ), + font_size: Some((base_size * 0.8).into()), + ..Default::default() + }), + ..Default::default() + }, + ..Default::default() + }, + ) + .code_block_renderer( + CodeBlockRenderer::Default { + copy_button: false, + copy_button_on_hover: expanded, + border: false, + }, + ), + ), + ) + }, + ) + .into_any() + } +} + +struct WatchedConnectionMessage { + name: SharedString, + request_id: Option, + direction: acp::StreamMessageDirection, + message_type: MessageType, + params: Result, acp::Error>, + collapsed_params_md: Option>, + expanded_params_md: Option>, +} + +impl WatchedConnectionMessage { + fn expanded(&mut self, language_registry: Arc, cx: &mut App) { + let params_md = match &self.params { + Ok(Some(params)) => Some(expanded_params_md(params, &language_registry, cx)), + Err(err) => { + if let Some(err) = &serde_json::to_value(err).log_err() { + Some(expanded_params_md(&err, &language_registry, cx)) + } else { + None + } + } + _ => None, + }; + self.expanded_params_md = params_md; + } 
+} + +fn collapsed_params_md( + params: &serde_json::Value, + language_registry: &Arc, + cx: &mut App, +) -> Entity { + let params_json = serde_json::to_string(params).unwrap_or_default(); + let mut spaced_out_json = String::with_capacity(params_json.len() + params_json.len() / 4); + + for ch in params_json.chars() { + match ch { + '{' => spaced_out_json.push_str("{ "), + '}' => spaced_out_json.push_str(" }"), + ':' => spaced_out_json.push_str(": "), + ',' => spaced_out_json.push_str(", "), + c => spaced_out_json.push(c), + } + } + + let params_md = format!("```json\n{}\n```", spaced_out_json); + cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx)) +} + +fn expanded_params_md( + params: &serde_json::Value, + language_registry: &Arc, + cx: &mut App, +) -> Entity { + let params_json = serde_json::to_string_pretty(params).unwrap_or_default(); + let params_md = format!("```json\n{}\n```", params_json); + cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx)) +} + +enum MessageType { + Request, + Response, + Notification, +} + +impl Display for MessageType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MessageType::Request => write!(f, "Request"), + MessageType::Response => write!(f, "Response"), + MessageType::Notification => write!(f, "Notification"), + } + } +} + +enum AcpToolsEvent {} + +impl EventEmitter for AcpTools {} + +impl Item for AcpTools { + type Event = AcpToolsEvent; + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> ui::SharedString { + format!( + "ACP: {}", + self.watched_connection + .as_ref() + .map_or("Disconnected", |connection| &connection.server_name) + ) + .into() + } + + fn tab_icon(&self, _window: &Window, _cx: &App) -> Option { + Some(ui::Icon::new(IconName::Thread)) + } +} + +impl Focusable for AcpTools { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for AcpTools { + 
fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + v_flex() + .track_focus(&self.focus_handle) + .size_full() + .bg(cx.theme().colors().editor_background) + .child(match self.watched_connection.as_ref() { + Some(connection) => { + if connection.messages.is_empty() { + h_flex() + .size_full() + .justify_center() + .items_center() + .child("No messages recorded yet") + .into_any() + } else { + list( + connection.list_state.clone(), + cx.processor(Self::render_message), + ) + .with_sizing_behavior(gpui::ListSizingBehavior::Auto) + .flex_grow() + .into_any() + } + } + None => h_flex() + .size_full() + .justify_center() + .items_center() + .child("No active connection") + .into_any(), + }) + } +} diff --git a/crates/action_log/src/action_log.rs b/crates/action_log/src/action_log.rs index c4eaffc2281de30cf0274539897d5fd70cda1351..11ba596ac5a0ecd4ed49744d0eafa9defcde20c1 100644 --- a/crates/action_log/src/action_log.rs +++ b/crates/action_log/src/action_log.rs @@ -116,7 +116,7 @@ impl ActionLog { } else if buffer .read(cx) .file() - .map_or(false, |file| file.disk_state().exists()) + .is_some_and(|file| file.disk_state().exists()) { TrackedBufferStatus::Created { existing_file_content: Some(buffer.read(cx).as_rope().clone()), @@ -161,7 +161,7 @@ impl ActionLog { diff_base, last_seen_base, unreviewed_edits, - snapshot: text_snapshot.clone(), + snapshot: text_snapshot, status, version: buffer.read(cx).version(), diff, @@ -190,7 +190,7 @@ impl ActionLog { cx: &mut Context, ) { match event { - BufferEvent::Edited { .. 
} => self.handle_buffer_edited(buffer, cx), + BufferEvent::Edited => self.handle_buffer_edited(buffer, cx), BufferEvent::FileHandleChanged => { self.handle_buffer_file_changed(buffer, cx); } @@ -215,7 +215,7 @@ impl ActionLog { if buffer .read(cx) .file() - .map_or(false, |file| file.disk_state() == DiskState::Deleted) + .is_some_and(|file| file.disk_state() == DiskState::Deleted) { // If the buffer had been edited by a tool, but it got // deleted externally, we want to stop tracking it. @@ -227,7 +227,7 @@ impl ActionLog { if buffer .read(cx) .file() - .map_or(false, |file| file.disk_state() != DiskState::Deleted) + .is_some_and(|file| file.disk_state() != DiskState::Deleted) { // If the buffer had been deleted by a tool, but it got // resurrected externally, we want to clear the edits we @@ -264,15 +264,14 @@ impl ActionLog { if let Some((git_diff, (buffer_repo, _))) = git_diff.as_ref().zip(buffer_repo) { cx.update(|cx| { let mut old_head = buffer_repo.read(cx).head_commit.clone(); - Some(cx.subscribe(git_diff, move |_, event, cx| match event { - buffer_diff::BufferDiffEvent::DiffChanged { .. } => { + Some(cx.subscribe(git_diff, move |_, event, cx| { + if let buffer_diff::BufferDiffEvent::DiffChanged { .. } = event { let new_head = buffer_repo.read(cx).head_commit.clone(); if new_head != old_head { old_head = new_head; git_diff_updates_tx.send(()).ok(); } } - _ => {} })) })? 
} else { @@ -290,7 +289,7 @@ impl ActionLog { } _ = git_diff_updates_rx.changed().fuse() => { if let Some(git_diff) = git_diff.as_ref() { - Self::keep_committed_edits(&this, &buffer, &git_diff, cx).await?; + Self::keep_committed_edits(&this, &buffer, git_diff, cx).await?; } } } @@ -462,7 +461,7 @@ impl ActionLog { anyhow::Ok(( tracked_buffer.diff.clone(), buffer.read(cx).language().cloned(), - buffer.read(cx).language_registry().clone(), + buffer.read(cx).language_registry(), )) })??; let diff_snapshot = BufferDiff::update_diff( @@ -498,7 +497,7 @@ impl ActionLog { new: new_range, }, &new_diff_base, - &buffer_snapshot.as_rope(), + buffer_snapshot.as_rope(), )); } unreviewed_edits @@ -530,12 +529,12 @@ impl ActionLog { /// Mark a buffer as created by agent, so we can refresh it in the context pub fn buffer_created(&mut self, buffer: Entity, cx: &mut Context) { - self.track_buffer_internal(buffer.clone(), true, cx); + self.track_buffer_internal(buffer, true, cx); } /// Mark a buffer as edited by agent, so we can refresh it in the context pub fn buffer_edited(&mut self, buffer: Entity, cx: &mut Context) { - let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx); + let tracked_buffer = self.track_buffer_internal(buffer, false, cx); if let TrackedBufferStatus::Deleted = tracked_buffer.status { tracked_buffer.status = TrackedBufferStatus::Modified; } @@ -614,10 +613,10 @@ impl ActionLog { false } }); - if tracked_buffer.unreviewed_edits.is_empty() { - if let TrackedBufferStatus::Created { .. } = &mut tracked_buffer.status { - tracked_buffer.status = TrackedBufferStatus::Modified; - } + if tracked_buffer.unreviewed_edits.is_empty() + && let TrackedBufferStatus::Created { .. 
} = &mut tracked_buffer.status + { + tracked_buffer.status = TrackedBufferStatus::Modified; } tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx); } @@ -811,7 +810,7 @@ impl ActionLog { tracked.version != buffer.version && buffer .file() - .map_or(false, |file| file.disk_state() != DiskState::Deleted) + .is_some_and(|file| file.disk_state() != DiskState::Deleted) }) .map(|(buffer, _)| buffer) } @@ -847,7 +846,7 @@ fn apply_non_conflicting_edits( conflict = true; if new_edits .peek() - .map_or(false, |next_edit| next_edit.old.overlaps(&old_edit.new)) + .is_some_and(|next_edit| next_edit.old.overlaps(&old_edit.new)) { new_edit = new_edits.next().unwrap(); } else { @@ -964,7 +963,7 @@ impl TrackedBuffer { fn has_edits(&self, cx: &App) -> bool { self.diff .read(cx) - .hunks(&self.buffer.read(cx), cx) + .hunks(self.buffer.read(cx), cx) .next() .is_some() } @@ -2219,7 +2218,7 @@ mod tests { action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); for _ in 0..operations { - match rng.gen_range(0..100) { + match rng.random_range(0..100) { 0..25 => { action_log.update(cx, |log, cx| { let range = buffer.read(cx).random_byte_range(0, &mut rng); @@ -2238,7 +2237,7 @@ mod tests { .unwrap(); } _ => { - let is_agent_edit = rng.gen_bool(0.5); + let is_agent_edit = rng.random_bool(0.5); if is_agent_edit { log::info!("agent edit"); } else { @@ -2253,7 +2252,7 @@ mod tests { } } - if rng.gen_bool(0.2) { + if rng.random_bool(0.2) { quiesce(&action_log, &buffer, cx); } } @@ -2268,7 +2267,7 @@ mod tests { log::info!("quiescing..."); cx.run_until_parked(); action_log.update(cx, |log, cx| { - let tracked_buffer = log.tracked_buffers.get(&buffer).unwrap(); + let tracked_buffer = log.tracked_buffers.get(buffer).unwrap(); let mut old_text = tracked_buffer.diff_base.clone(); let new_text = buffer.read(cx).as_rope(); for edit in tracked_buffer.unreviewed_edits.edits() { @@ -2426,7 +2425,7 @@ mod tests { assert_eq!( unreviewed_hunks(&action_log, cx), vec![( - 
buffer.clone(), + buffer, vec![ HunkStatus { range: Point::new(6, 0)..Point::new(7, 0), diff --git a/crates/activity_indicator/src/activity_indicator.rs b/crates/activity_indicator/src/activity_indicator.rs index 7c562aaba4f494d044b3efd4c53344365011257f..1870ab74db214b518bb0b543166067e636f14965 100644 --- a/crates/activity_indicator/src/activity_indicator.rs +++ b/crates/activity_indicator/src/activity_indicator.rs @@ -1,11 +1,10 @@ use auto_update::{AutoUpdateStatus, AutoUpdater, DismissErrorMessage, VersionCheckType}; use editor::Editor; -use extension_host::ExtensionStore; +use extension_host::{ExtensionOperation, ExtensionStore}; use futures::StreamExt; use gpui::{ - Animation, AnimationExt as _, App, Context, CursorStyle, Entity, EventEmitter, - InteractiveElement as _, ParentElement as _, Render, SharedString, StatefulInteractiveElement, - Styled, Transformation, Window, actions, percentage, + App, Context, CursorStyle, Entity, EventEmitter, InteractiveElement as _, ParentElement as _, + Render, SharedString, StatefulInteractiveElement, Styled, Window, actions, }; use language::{ BinaryStatus, LanguageRegistry, LanguageServerId, LanguageServerName, @@ -25,7 +24,10 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; -use ui::{ButtonLike, ContextMenu, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*}; +use ui::{ + ButtonLike, CommonAnimationExt, ContextMenu, PopoverMenu, PopoverMenuHandle, Tooltip, + prelude::*, +}; use util::truncate_and_trailoff; use workspace::{StatusItemView, Workspace, item::ItemHandle}; @@ -82,7 +84,6 @@ impl ActivityIndicator { ) -> Entity { let project = workspace.project().clone(); let auto_updater = AutoUpdater::get(cx); - let workspace_handle = cx.entity(); let this = cx.new(|cx| { let mut status_events = languages.language_server_binary_statuses(); cx.spawn(async move |this, cx| { @@ -100,29 +101,10 @@ impl ActivityIndicator { }) .detach(); - cx.subscribe_in( - &workspace_handle, - window, - |activity_indicator, _, event, 
window, cx| match event { - workspace::Event::ClearActivityIndicator { .. } => { - if activity_indicator.statuses.pop().is_some() { - activity_indicator.dismiss_error_message( - &DismissErrorMessage, - window, - cx, - ); - cx.notify(); - } - } - _ => {} - }, - ) - .detach(); - cx.subscribe( &project.read(cx).lsp_store(), - |activity_indicator, _, event, cx| match event { - LspStoreEvent::LanguageServerUpdate { name, message, .. } => { + |activity_indicator, _, event, cx| { + if let LspStoreEvent::LanguageServerUpdate { name, message, .. } = event { if let proto::update_language_server::Variant::StatusUpdate(status_update) = message { @@ -191,7 +173,6 @@ impl ActivityIndicator { } cx.notify() } - _ => {} }, ) .detach(); @@ -206,9 +187,10 @@ impl ActivityIndicator { cx.subscribe( &project.read(cx).git_store().clone(), - |_, _, event: &GitStoreEvent, cx| match event { - project::git_store::GitStoreEvent::JobsUpdated => cx.notify(), - _ => {} + |_, _, event: &GitStoreEvent, cx| { + if let project::git_store::GitStoreEvent::JobsUpdated = event { + cx.notify() + } }, ) .detach(); @@ -230,7 +212,8 @@ impl ActivityIndicator { server_name, status, } => { - let create_buffer = project.update(cx, |project, cx| project.create_buffer(cx)); + let create_buffer = + project.update(cx, |project, cx| project.create_buffer(false, cx)); let status = status.clone(); let server_name = server_name.clone(); cx.spawn_in(window, async move |workspace, cx| { @@ -410,13 +393,7 @@ impl ActivityIndicator { icon: Some( Icon::new(IconName::ArrowCircle) .size(IconSize::Small) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| { - icon.transform(Transformation::rotate(percentage(delta))) - }, - ) + .with_rotate_animation(2) .into_any_element(), ), message, @@ -438,11 +415,7 @@ impl ActivityIndicator { icon: Some( Icon::new(IconName::ArrowCircle) .size(IconSize::Small) - .with_animation( - "arrow-circle", - 
Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - ) + .with_rotate_animation(2) .into_any_element(), ), message: format!("Debug: {}", session.read(cx).adapter()), @@ -458,26 +431,20 @@ impl ActivityIndicator { .map(|r| r.read(cx)) .and_then(Repository::current_job); // Show any long-running git command - if let Some(job_info) = current_job { - if Instant::now() - job_info.start >= GIT_OPERATION_DELAY { - return Some(Content { - icon: Some( - Icon::new(IconName::ArrowCircle) - .size(IconSize::Small) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| { - icon.transform(Transformation::rotate(percentage(delta))) - }, - ) - .into_any_element(), - ), - message: job_info.message.into(), - on_click: None, - tooltip_message: None, - }); - } + if let Some(job_info) = current_job + && Instant::now() - job_info.start >= GIT_OPERATION_DELAY + { + return Some(Content { + icon: Some( + Icon::new(IconName::ArrowCircle) + .size(IconSize::Small) + .with_rotate_animation(2) + .into_any_element(), + ), + message: job_info.message.into(), + on_click: None, + tooltip_message: None, + }); } // Show any language server installation info. @@ -678,8 +645,9 @@ impl ActivityIndicator { } // Show any application auto-update info. 
- if let Some(updater) = &self.auto_updater { - return match &updater.read(cx).status() { + self.auto_updater + .as_ref() + .and_then(|updater| match &updater.read(cx).status() { AutoUpdateStatus::Checking => Some(Content { icon: Some( Icon::new(IconName::Download) @@ -702,7 +670,7 @@ impl ActivityIndicator { on_click: Some(Arc::new(|this, window, cx| { this.dismiss_error_message(&DismissErrorMessage, window, cx) })), - tooltip_message: Some(Self::version_tooltip_message(&version)), + tooltip_message: Some(Self::version_tooltip_message(version)), }), AutoUpdateStatus::Installing { version } => Some(Content { icon: Some( @@ -714,13 +682,13 @@ impl ActivityIndicator { on_click: Some(Arc::new(|this, window, cx| { this.dismiss_error_message(&DismissErrorMessage, window, cx) })), - tooltip_message: Some(Self::version_tooltip_message(&version)), + tooltip_message: Some(Self::version_tooltip_message(version)), }), AutoUpdateStatus::Updated { version } => Some(Content { icon: None, message: "Click to restart and update Zed".to_string(), on_click: Some(Arc::new(move |_, _, cx| workspace::reload(cx))), - tooltip_message: Some(Self::version_tooltip_message(&version)), + tooltip_message: Some(Self::version_tooltip_message(version)), }), AutoUpdateStatus::Errored => Some(Content { icon: Some( @@ -735,29 +703,49 @@ impl ActivityIndicator { tooltip_message: None, }), AutoUpdateStatus::Idle => None, - }; - } - - if let Some(extension_store) = - ExtensionStore::try_global(cx).map(|extension_store| extension_store.read(cx)) - { - if let Some(extension_id) = extension_store.outstanding_operations().keys().next() { - return Some(Content { - icon: Some( - Icon::new(IconName::Download) - .size(IconSize::Small) - .into_any_element(), - ), - message: format!("Updating {extension_id} extension…"), - on_click: Some(Arc::new(|this, window, cx| { - this.dismiss_error_message(&DismissErrorMessage, window, cx) - })), - tooltip_message: None, - }); - } - } + }) + .or_else(|| { + if let 
Some(extension_store) = + ExtensionStore::try_global(cx).map(|extension_store| extension_store.read(cx)) + && let Some((extension_id, operation)) = + extension_store.outstanding_operations().iter().next() + { + let (message, icon, rotate) = match operation { + ExtensionOperation::Install => ( + format!("Installing {extension_id} extension…"), + IconName::LoadCircle, + true, + ), + ExtensionOperation::Upgrade => ( + format!("Updating {extension_id} extension…"), + IconName::Download, + false, + ), + ExtensionOperation::Remove => ( + format!("Removing {extension_id} extension…"), + IconName::LoadCircle, + true, + ), + }; - None + Some(Content { + icon: Some(Icon::new(icon).size(IconSize::Small).map(|this| { + if rotate { + this.with_rotate_animation(3).into_any_element() + } else { + this.into_any_element() + } + })), + message, + on_click: Some(Arc::new(|this, window, cx| { + this.dismiss_error_message(&Default::default(), window, cx) + })), + tooltip_message: None, + }) + } else { + None + } + }) } fn version_tooltip_message(version: &VersionCheckType) -> String { diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 53ad2f496758bfc288a5c9dc25f8e2e99851d5b2..76f96647c7af5692ca9b4b146e27f9f7c19c7995 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -31,7 +31,6 @@ collections.workspace = true component.workspace = true context_server.workspace = true convert_case.workspace = true -feature_flags.workspace = true fs.workspace = true futures.workspace = true git.workspace = true @@ -64,6 +63,7 @@ time.workspace = true util.workspace = true uuid.workspace = true workspace-hack.workspace = true +zed_env_vars.workspace = true zstd.workspace = true [dev-dependencies] diff --git a/crates/agent/src/agent_profile.rs b/crates/agent/src/agent_profile.rs index 38e697dd9bbd5ede89ad23575bb1e123dfb2c350..c9e73372f60686cf330531926f4129e9c9b25db8 100644 --- a/crates/agent/src/agent_profile.rs +++ b/crates/agent/src/agent_profile.rs @@ -90,7 +90,7 @@ 
impl AgentProfile { return false; }; - return Self::is_enabled(settings, source, tool_name); + Self::is_enabled(settings, source, tool_name) } fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool { @@ -132,7 +132,7 @@ mod tests { }); let tool_set = default_tool_set(cx); - let profile = AgentProfile::new(id.clone(), tool_set); + let profile = AgentProfile::new(id, tool_set); let mut enabled_tools = cx .read(|cx| profile.enabled_tools(cx)) @@ -169,7 +169,7 @@ mod tests { }); let tool_set = default_tool_set(cx); - let profile = AgentProfile::new(id.clone(), tool_set); + let profile = AgentProfile::new(id, tool_set); let mut enabled_tools = cx .read(|cx| profile.enabled_tools(cx)) @@ -202,7 +202,7 @@ mod tests { }); let tool_set = default_tool_set(cx); - let profile = AgentProfile::new(id.clone(), tool_set); + let profile = AgentProfile::new(id, tool_set); let mut enabled_tools = cx .read(|cx| profile.enabled_tools(cx)) diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs index 8cdb87ef8d9f3363e68c14053c01f34ece64b3b9..71fa8176a012569373df927eb37145208d6a105d 100644 --- a/crates/agent/src/context.rs +++ b/crates/agent/src/context.rs @@ -202,23 +202,22 @@ impl FileContextHandle { } if let Ok(snapshot) = buffer.read_with(cx, |buffer, _| buffer.snapshot()) { - if let Some(outline) = snapshot.outline(None) { - let items = outline - .items - .into_iter() - .map(|item| item.to_point(&snapshot)); - - if let Ok(outline_text) = - outline::render_outline(items, None, 0, usize::MAX).await - { - let context = AgentContext::File(FileContext { - handle: self, - full_path, - text: outline_text.into(), - is_outline: true, - }); - return Some((context, vec![buffer])); - } + let items = snapshot + .outline(None) + .items + .into_iter() + .map(|item| item.to_point(&snapshot)); + + if let Ok(outline_text) = + outline::render_outline(items, None, 0, usize::MAX).await + { + let context = AgentContext::File(FileContext { + handle: self, + 
full_path, + text: outline_text.into(), + is_outline: true, + }); + return Some((context, vec![buffer])); } } } @@ -362,7 +361,7 @@ impl Display for DirectoryContext { let mut is_first = true; for descendant in &self.descendants { if !is_first { - write!(f, "\n")?; + writeln!(f)?; } else { is_first = false; } @@ -650,7 +649,7 @@ impl TextThreadContextHandle { impl Display for TextThreadContext { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { // TODO: escape title? - write!(f, "\n", self.title)?; + writeln!(f, "", self.title)?; write!(f, "{}", self.text.trim())?; write!(f, "\n") } @@ -716,7 +715,7 @@ impl RulesContextHandle { impl Display for RulesContext { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if let Some(title) = &self.title { - write!(f, "Rules title: {}\n", title)?; + writeln!(f, "Rules title: {}", title)?; } let code_block = MarkdownCodeBlock { tag: "", diff --git a/crates/agent/src/context_server_tool.rs b/crates/agent/src/context_server_tool.rs index 22d1a72bf5f833a6594f34fd8f5d7b9102740740..696c569356bca36adf54bc84ec52fa7295048b75 100644 --- a/crates/agent/src/context_server_tool.rs +++ b/crates/agent/src/context_server_tool.rs @@ -86,15 +86,13 @@ impl Tool for ContextServerTool { ) -> ToolResult { if let Some(server) = self.store.read(cx).get_running_server(&self.server_id) { let tool_name = self.tool.name.clone(); - let server_clone = server.clone(); - let input_clone = input.clone(); cx.spawn(async move |_cx| { - let Some(protocol) = server_clone.client() else { + let Some(protocol) = server.client() else { bail!("Context server not initialized"); }; - let arguments = if let serde_json::Value::Object(map) = input_clone { + let arguments = if let serde_json::Value::Object(map) = input { Some(map.into_iter().collect()) } else { None diff --git a/crates/agent/src/context_store.rs b/crates/agent/src/context_store.rs index 60ba5527dcca22d81b7da62657c6abc00aa51607..b531852a184ffeaf86862990f03210ceb6033395 100644 --- 
a/crates/agent/src/context_store.rs +++ b/crates/agent/src/context_store.rs @@ -338,11 +338,9 @@ impl ContextStore { image_task, context_id: self.next_context_id.post_inc(), }); - if self.has_context(&context) { - if remove_if_exists { - self.remove_context(&context, cx); - return None; - } + if self.has_context(&context) && remove_if_exists { + self.remove_context(&context, cx); + return None; } self.insert_context(context.clone(), cx); diff --git a/crates/agent/src/history_store.rs b/crates/agent/src/history_store.rs index eb39c3e454c25fdc87baeffd550ea5cb29155aab..8f4c1a1e2e6533b4760956c60bfc1a26123df92a 100644 --- a/crates/agent/src/history_store.rs +++ b/crates/agent/src/history_store.rs @@ -254,10 +254,9 @@ impl HistoryStore { } pub fn remove_recently_opened_thread(&mut self, id: ThreadId, cx: &mut Context) { - self.recently_opened_entries.retain(|entry| match entry { - HistoryEntryId::Thread(thread_id) if thread_id == &id => false, - _ => true, - }); + self.recently_opened_entries.retain( + |entry| !matches!(entry, HistoryEntryId::Thread(thread_id) if thread_id == &id), + ); self.save_recently_opened_entries(cx); } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 1d417efbbae009ef1b08c240fd195534192c28f6..7b70fde56ab1e7acb6705aeace82f142dc28a9f3 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -9,14 +9,16 @@ use crate::{ tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState}, }; use action_log::ActionLog; -use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT}; +use agent_settings::{ + AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT, + SUMMARIZE_THREAD_PROMPT, +}; use anyhow::{Result, anyhow}; use assistant_tool::{AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; use client::{ModelRequestUsage, RequestUsage}; use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit}; use 
collections::HashMap; -use feature_flags::{self, FeatureFlagAppExt}; use futures::{FutureExt, StreamExt as _, future::Shared}; use git::repository::DiffType; use gpui::{ @@ -108,7 +110,7 @@ impl std::fmt::Display for PromptId { } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] -pub struct MessageId(pub(crate) usize); +pub struct MessageId(pub usize); impl MessageId { fn post_inc(&mut self) -> Self { @@ -179,7 +181,7 @@ impl Message { } } - pub fn to_string(&self) -> String { + pub fn to_message_content(&self) -> String { let mut result = String::new(); if !self.loaded_context.text.is_empty() { @@ -385,10 +387,8 @@ pub struct Thread { cumulative_token_usage: TokenUsage, exceeded_window_error: Option, tool_use_limit_reached: bool, - feedback: Option, retry_state: Option, message_feedback: HashMap, - last_auto_capture_at: Option, last_received_chunk_at: Option, request_callback: Option< Box])>, @@ -486,15 +486,13 @@ impl Thread { cumulative_token_usage: TokenUsage::default(), exceeded_window_error: None, tool_use_limit_reached: false, - feedback: None, retry_state: None, message_feedback: HashMap::default(), - last_auto_capture_at: None, last_error_context: None, last_received_chunk_at: None, request_callback: None, remaining_turns: u32::MAX, - configured_model: configured_model.clone(), + configured_model, profile: AgentProfile::new(profile_id, tools), } } @@ -532,7 +530,7 @@ impl Thread { .and_then(|model| { let model = SelectedModel { provider: model.provider.clone().into(), - model: model.model.clone().into(), + model: model.model.into(), }; registry.select_model(&model, cx) }) @@ -612,9 +610,7 @@ impl Thread { cumulative_token_usage: serialized.cumulative_token_usage, exceeded_window_error: None, tool_use_limit_reached: serialized.tool_use_limit_reached, - feedback: None, message_feedback: HashMap::default(), - last_auto_capture_at: None, last_error_context: None, last_received_chunk_at: None, request_callback: None, @@ 
-844,11 +840,17 @@ impl Thread { .await .unwrap_or(false); - if !equal { - this.update(cx, |this, cx| { - this.insert_checkpoint(pending_checkpoint, cx) - })?; - } + this.update(cx, |this, cx| { + this.pending_checkpoint = if equal { + Some(pending_checkpoint) + } else { + this.insert_checkpoint(pending_checkpoint, cx); + Some(ThreadCheckpoint { + message_id: this.next_message_id, + git_checkpoint: final_checkpoint, + }) + } + })?; Ok(()) } @@ -1027,8 +1029,6 @@ impl Thread { }); } - self.auto_capture_telemetry(cx); - message_id } @@ -1643,17 +1643,15 @@ impl Thread { }; self.tool_use - .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx); + .request_tool_use(tool_message_id, tool_use, tool_use_metadata, cx); - let pending_tool_use = self.tool_use.insert_tool_output( - tool_use_id.clone(), + self.tool_use.insert_tool_output( + tool_use_id, tool_name, tool_output, self.configured_model.as_ref(), self.completion_mode, - ); - - pending_tool_use + ) } pub fn stream_completion( @@ -1686,7 +1684,7 @@ impl Thread { self.last_received_chunk_at = Some(Instant::now()); let task = cx.spawn(async move |thread, cx| { - let stream_completion_future = model.stream_completion(request, &cx); + let stream_completion_future = model.stream_completion(request, cx); let initial_token_usage = thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage); let stream_completion = async { @@ -1818,7 +1816,7 @@ impl Thread { let streamed_input = if tool_use.is_input_complete { None } else { - Some((&tool_use.input).clone()) + Some(tool_use.input.clone()) }; let ui_text = thread.tool_use.request_tool_use( @@ -1900,7 +1898,6 @@ impl Thread { cx.emit(ThreadEvent::StreamedCompletion); cx.notify(); - thread.auto_capture_telemetry(cx); Ok(()) })??; @@ -1968,11 +1965,9 @@ impl Thread { if let Some(prev_message) = thread.messages.get(ix - 1) - { - if prev_message.role == Role::Assistant { + && prev_message.role == Role::Assistant { break; } - } } } @@ -2045,7 +2040,7 @@ 
impl Thread { retry_scheduled = thread .handle_retryable_error_with_delay( - &completion_error, + completion_error, Some(retry_strategy), model.clone(), intent, @@ -2075,8 +2070,6 @@ impl Thread { request_callback(request, response_events); } - thread.auto_capture_telemetry(cx); - if let Ok(initial_usage) = initial_token_usage { let usage = thread.cumulative_token_usage - initial_usage; @@ -2124,7 +2117,7 @@ impl Thread { self.pending_summary = cx.spawn(async move |this, cx| { let result = async { - let mut messages = model.model.stream_completion(request, &cx).await?; + let mut messages = model.model.stream_completion(request, cx).await?; let mut new_summary = String::new(); while let Some(event) = messages.next().await { @@ -2432,12 +2425,10 @@ impl Thread { return; } - let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt"); - let request = self.to_summarize_request( &model, CompletionIntent::ThreadContextSummarization, - added_user_message.into(), + SUMMARIZE_THREAD_DETAILED_PROMPT.into(), cx, ); @@ -2450,7 +2441,7 @@ impl Thread { // which result to prefer (the old task could complete after the new one, resulting in a // stale summary). 
self.detailed_summary_task = cx.spawn(async move |thread, cx| { - let stream = model.stream_completion_text(request, &cx); + let stream = model.stream_completion_text(request, cx); let Some(mut messages) = stream.await.log_err() else { thread .update(cx, |thread, _cx| { @@ -2479,13 +2470,13 @@ impl Thread { .ok()?; // Save thread so its summary can be reused later - if let Some(thread) = thread.upgrade() { - if let Ok(Ok(save_task)) = cx.update(|cx| { + if let Some(thread) = thread.upgrade() + && let Ok(Ok(save_task)) = cx.update(|cx| { thread_store .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx)) - }) { - save_task.await.log_err(); - } + }) + { + save_task.await.log_err(); } Some(()) @@ -2530,7 +2521,6 @@ impl Thread { model: Arc, cx: &mut Context, ) -> Vec { - self.auto_capture_telemetry(cx); let request = Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx)); let pending_tool_uses = self @@ -2734,13 +2724,11 @@ impl Thread { window: Option, cx: &mut Context, ) { - if self.all_tools_finished() { - if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() { - if !canceled { - self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx); - } - self.auto_capture_telemetry(cx); - } + if self.all_tools_finished() + && let Some(ConfiguredModel { model, .. 
}) = self.configured_model.as_ref() + && !canceled + { + self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx); } cx.emit(ThreadEvent::ToolFinished { @@ -2796,10 +2784,6 @@ impl Thread { cx.emit(ThreadEvent::CancelEditing); } - pub fn feedback(&self) -> Option { - self.feedback - } - pub fn message_feedback(&self, message_id: MessageId) -> Option { self.message_feedback.get(&message_id).copied() } @@ -2832,7 +2816,7 @@ impl Thread { let message_content = self .message(message_id) - .map(|msg| msg.to_string()) + .map(|msg| msg.to_message_content()) .unwrap_or_default(); cx.background_spawn(async move { @@ -2861,52 +2845,6 @@ impl Thread { }) } - pub fn report_feedback( - &mut self, - feedback: ThreadFeedback, - cx: &mut Context, - ) -> Task> { - let last_assistant_message_id = self - .messages - .iter() - .rev() - .find(|msg| msg.role == Role::Assistant) - .map(|msg| msg.id); - - if let Some(message_id) = last_assistant_message_id { - self.report_message_feedback(message_id, feedback, cx) - } else { - let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx); - let serialized_thread = self.serialize(cx); - let thread_id = self.id().clone(); - let client = self.project.read(cx).client(); - self.feedback = Some(feedback); - cx.notify(); - - cx.background_spawn(async move { - let final_project_snapshot = final_project_snapshot.await; - let serialized_thread = serialized_thread.await?; - let thread_data = serde_json::to_value(serialized_thread) - .unwrap_or_else(|_| serde_json::Value::Null); - - let rating = match feedback { - ThreadFeedback::Positive => "positive", - ThreadFeedback::Negative => "negative", - }; - telemetry::event!( - "Assistant Thread Rated", - rating, - thread_id, - thread_data, - final_project_snapshot - ); - client.telemetry().flush_events().await; - - Ok(()) - }) - } - } - /// Create a snapshot of the current project state including git information and unsaved buffers. 
fn project_snapshot( project: Entity, @@ -2927,11 +2865,11 @@ impl Thread { let buffer_store = project.read(app_cx).buffer_store(); for buffer_handle in buffer_store.read(app_cx).buffers() { let buffer = buffer_handle.read(app_cx); - if buffer.is_dirty() { - if let Some(file) = buffer.file() { - let path = file.path().to_string_lossy().to_string(); - unsaved_buffers.push(path); - } + if buffer.is_dirty() + && let Some(file) = buffer.file() + { + let path = file.path().to_string_lossy().to_string(); + unsaved_buffers.push(path); } } }) @@ -3141,50 +3079,6 @@ impl Thread { &self.project } - pub fn auto_capture_telemetry(&mut self, cx: &mut Context) { - if !cx.has_flag::() { - return; - } - - let now = Instant::now(); - if let Some(last) = self.last_auto_capture_at { - if now.duration_since(last).as_secs() < 10 { - return; - } - } - - self.last_auto_capture_at = Some(now); - - let thread_id = self.id().clone(); - let github_login = self - .project - .read(cx) - .user_store() - .read(cx) - .current_user() - .map(|user| user.github_login.clone()); - let client = self.project.read(cx).client(); - let serialize_task = self.serialize(cx); - - cx.background_executor() - .spawn(async move { - if let Ok(serialized_thread) = serialize_task.await { - if let Ok(thread_data) = serde_json::to_value(serialized_thread) { - telemetry::event!( - "Agent Thread Auto-Captured", - thread_id = thread_id.to_string(), - thread_data = thread_data, - auto_capture_reason = "tracked_user", - github_login = github_login - ); - - client.telemetry().flush_events().await; - } - } - }) - .detach(); - } - pub fn cumulative_token_usage(&self) -> TokenUsage { self.cumulative_token_usage } @@ -3227,13 +3121,13 @@ impl Thread { .model .max_token_count_for_mode(self.completion_mode().into()); - if let Some(exceeded_error) = &self.exceeded_window_error { - if model.model.id() == exceeded_error.model_id { - return Some(TotalTokenUsage { - total: exceeded_error.token_count, - max, - }); - } + if let 
Some(exceeded_error) = &self.exceeded_window_error + && model.model.id() == exceeded_error.model_id + { + return Some(TotalTokenUsage { + total: exceeded_error.token_count, + max, + }); } let total = self @@ -3294,7 +3188,7 @@ impl Thread { self.configured_model.as_ref(), self.completion_mode, ); - self.tool_finished(tool_use_id.clone(), None, true, window, cx); + self.tool_finished(tool_use_id, None, true, window, cx); } } @@ -3926,7 +3820,7 @@ fn main() {{ AgentSettings { model_parameters: vec![LanguageModelParameters { provider: Some(model.provider_id().0.to_string().into()), - model: Some(model.id().0.clone()), + model: Some(model.id().0), temperature: Some(0.66), }], ..AgentSettings::get_global(cx).clone() @@ -3946,7 +3840,7 @@ fn main() {{ AgentSettings { model_parameters: vec![LanguageModelParameters { provider: None, - model: Some(model.id().0.clone()), + model: Some(model.id().0), temperature: Some(0.66), }], ..AgentSettings::get_global(cx).clone() @@ -3986,7 +3880,7 @@ fn main() {{ AgentSettings { model_parameters: vec![LanguageModelParameters { provider: Some("anthropic".into()), - model: Some(model.id().0.clone()), + model: Some(model.id().0), temperature: Some(0.66), }], ..AgentSettings::get_global(cx).clone() @@ -4037,7 +3931,7 @@ fn main() {{ }); let fake_model = model.as_fake(); - simulate_successful_response(&fake_model, cx); + simulate_successful_response(fake_model, cx); // Should start generating summary when there are >= 2 messages thread.read_with(cx, |thread, _| { @@ -4132,7 +4026,7 @@ fn main() {{ }); let fake_model = model.as_fake(); - simulate_successful_response(&fake_model, cx); + simulate_successful_response(fake_model, cx); thread.read_with(cx, |thread, _| { // State is still Error, not Generating @@ -5331,7 +5225,7 @@ fn main() {{ } #[gpui::test] - async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) { + async fn test_retry_canceled_on_stop(cx: &mut TestAppContext) { init_test_settings(cx); let project = 
create_test_project(cx, json!({})).await; @@ -5387,7 +5281,7 @@ fn main() {{ "Should have no pending completions after cancellation" ); - // Verify the retry was cancelled by checking retry state + // Verify the retry was canceled by checking retry state thread.read_with(cx, |thread, _| { if let Some(retry_state) = &thread.retry_state { panic!( @@ -5414,7 +5308,7 @@ fn main() {{ }); let fake_model = model.as_fake(); - simulate_successful_response(&fake_model, cx); + simulate_successful_response(fake_model, cx); thread.read_with(cx, |thread, _| { assert!(matches!(thread.summary(), ThreadSummary::Generating)); diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 12c94a522d52de78e52dab4764a7f187054eca47..2eae758b835d5d79ccf86f18be032f2d9bb87c2b 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -41,8 +41,7 @@ use std::{ }; use util::ResultExt as _; -pub static ZED_STATELESS: std::sync::LazyLock = - std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty())); +use zed_env_vars::ZED_STATELESS; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum DataType { @@ -74,7 +73,7 @@ impl Column for DataType { } } -const RULES_FILE_NAMES: [&'static str; 9] = [ +const RULES_FILE_NAMES: [&str; 9] = [ ".rules", ".cursorrules", ".windsurfrules", @@ -581,33 +580,32 @@ impl ThreadStore { return; }; - if protocol.capable(context_server::protocol::ServerCapability::Tools) { - if let Some(response) = protocol + if protocol.capable(context_server::protocol::ServerCapability::Tools) + && let Some(response) = protocol .request::(()) .await .log_err() - { - let tool_ids = tool_working_set - .update(cx, |tool_working_set, cx| { - tool_working_set.extend( - response.tools.into_iter().map(|tool| { - Arc::new(ContextServerTool::new( - context_server_store.clone(), - server.id(), - tool, - )) as Arc - }), - cx, - ) - }) - .log_err(); - - if let Some(tool_ids) = tool_ids { - 
this.update(cx, |this, _| { - this.context_server_tool_ids.insert(server_id, tool_ids); - }) - .log_err(); - } + { + let tool_ids = tool_working_set + .update(cx, |tool_working_set, cx| { + tool_working_set.extend( + response.tools.into_iter().map(|tool| { + Arc::new(ContextServerTool::new( + context_server_store.clone(), + server.id(), + tool, + )) as Arc + }), + cx, + ) + }) + .log_err(); + + if let Some(tool_ids) = tool_ids { + this.update(cx, |this, _| { + this.context_server_tool_ids.insert(server_id, tool_ids); + }) + .log_err(); } } }) @@ -697,13 +695,14 @@ impl SerializedThreadV0_1_0 { let mut messages: Vec = Vec::with_capacity(self.0.messages.len()); for message in self.0.messages { - if message.role == Role::User && !message.tool_results.is_empty() { - if let Some(last_message) = messages.last_mut() { - debug_assert!(last_message.role == Role::Assistant); - - last_message.tool_results = message.tool_results; - continue; - } + if message.role == Role::User + && !message.tool_results.is_empty() + && let Some(last_message) = messages.last_mut() + { + debug_assert!(last_message.role == Role::Assistant); + + last_message.tool_results = message.tool_results; + continue; } messages.push(message); @@ -895,6 +894,17 @@ impl ThreadsDatabase { let connection = if *ZED_STATELESS { Connection::open_memory(Some("THREAD_FALLBACK_DB")) + } else if cfg!(any(feature = "test-support", test)) { + // rust stores the name of the test on the current thread. + // We use this to automatically create a database that will + // be shared within the test (for the test_retrieve_old_thread) + // but not with concurrent tests. 
+ let thread = std::thread::current(); + let test_name = thread.name(); + Connection::open_memory(Some(&format!( + "THREAD_FALLBACK_{}", + test_name.unwrap_or_default() + ))) } else { Connection::open_file(&sqlite_path.to_string_lossy()) }; diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs index 7392c0878d17adf8038292b10a7a8c349d3ec4e8..962dca591fb66f4679d44b8e8a4733c879bc2e0c 100644 --- a/crates/agent/src/tool_use.rs +++ b/crates/agent/src/tool_use.rs @@ -112,19 +112,13 @@ impl ToolUseState { }, ); - if let Some(window) = &mut window { - if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) { - if let Some(output) = tool_result.output.clone() { - if let Some(card) = tool.deserialize_card( - output, - project.clone(), - window, - cx, - ) { - this.tool_result_cards.insert(tool_use_id, card); - } - } - } + if let Some(window) = &mut window + && let Some(tool) = this.tools.read(cx).tool(tool_use, cx) + && let Some(output) = tool_result.output.clone() + && let Some(card) = + tool.deserialize_card(output, project.clone(), window, cx) + { + this.tool_result_cards.insert(tool_use_id, card); } } } @@ -137,7 +131,7 @@ impl ToolUseState { } pub fn cancel_pending(&mut self) -> Vec { - let mut cancelled_tool_uses = Vec::new(); + let mut canceled_tool_uses = Vec::new(); self.pending_tool_uses_by_id .retain(|tool_use_id, tool_use| { if matches!(tool_use.status, PendingToolUseStatus::Error { .. 
}) { @@ -155,10 +149,10 @@ impl ToolUseState { is_error: true, }, ); - cancelled_tool_uses.push(tool_use.clone()); + canceled_tool_uses.push(tool_use.clone()); false }); - cancelled_tool_uses + canceled_tool_uses } pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> { @@ -281,7 +275,7 @@ impl ToolUseState { pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool { self.tool_uses_by_assistant_message .get(&assistant_message_id) - .map_or(false, |results| !results.is_empty()) + .is_some_and(|results| !results.is_empty()) } pub fn tool_result( diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index ac1840e5e53f7812a792fa207de3ba24a64e355b..b712bed258dfb69ddf81a1ba431ec7a3566b9baf 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -8,24 +8,33 @@ license = "GPL-3.0-or-later" [lib] path = "src/agent2.rs" +[features] +test-support = ["db/test-support"] +e2e = [] + [lints] workspace = true [dependencies] acp_thread.workspace = true action_log.workspace = true +agent.workspace = true agent-client-protocol.workspace = true agent_servers.workspace = true agent_settings.workspace = true anyhow.workspace = true +assistant_context.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true chrono.workspace = true +client.workspace = true cloud_llm_client.workspace = true collections.workspace = true context_server.workspace = true +db.workspace = true fs.workspace = true futures.workspace = true +git.workspace = true gpui.workspace = true handlebars = { workspace = true, features = ["rust-embed"] } html_to_markdown.workspace = true @@ -37,8 +46,8 @@ language_model.workspace = true language_models.workspace = true log.workspace = true open.workspace = true +parking_lot.workspace = true paths.workspace = true -portable-pty.workspace = true project.workspace = true prompt_store.workspace = true rust-embed.workspace = true @@ -47,25 +56,34 @@ serde.workspace = true serde_json.workspace = true 
settings.workspace = true smol.workspace = true +sqlez.workspace = true task.workspace = true +telemetry.workspace = true terminal.workspace = true +thiserror.workspace = true text.workspace = true ui.workspace = true util.workspace = true uuid.workspace = true watch.workspace = true web_search.workspace = true -which.workspace = true workspace-hack.workspace = true +zed_env_vars.workspace = true +zstd.workspace = true [dev-dependencies] +agent = { workspace = true, "features" = ["test-support"] } +agent_servers = { workspace = true, "features" = ["test-support"] } +assistant_context = { workspace = true, "features" = ["test-support"] } ctor.workspace = true client = { workspace = true, "features" = ["test-support"] } clock = { workspace = true, "features" = ["test-support"] } context_server = { workspace = true, "features" = ["test-support"] } +db = { workspace = true, "features" = ["test-support"] } editor = { workspace = true, "features" = ["test-support"] } env_logger.workspace = true fs = { workspace = true, "features" = ["test-support"] } +git = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } gpui_tokio.workspace = true language = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 6ebcece2b566886dc6503d57af1663a51f7e4c01..6e0df0cffd8e83c446c4acd1fde74c0f8e4b5b8c 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,16 +1,17 @@ -use crate::{AgentResponseEvent, Thread, templates::Templates}; use crate::{ - ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool, - EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, - OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent, - WebSearchTool, + ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization, + 
UserMessageContent, templates::Templates, }; -use acp_thread::AgentModelSelector; +use crate::{HistoryStore, TerminalHandle, ThreadEnvironment, TitleUpdated, TokenUsageUpdated}; +use acp_thread::{AcpThread, AgentModelSelector}; +use action_log::ActionLog; use agent_client_protocol as acp; use agent_settings::AgentSettings; use anyhow::{Context as _, Result, anyhow}; use collections::{HashSet, IndexMap}; use fs::Fs; +use futures::channel::{mpsc, oneshot}; +use futures::future::Shared; use futures::{StreamExt, future}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, @@ -21,14 +22,14 @@ use prompt_store::{ ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext, }; use settings::update_settings_file; -use std::cell::RefCell; +use std::any::Any; use std::collections::HashMap; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::rc::Rc; use std::sync::Arc; use util::ResultExt; -const RULES_FILE_NAMES: [&'static str; 9] = [ +const RULES_FILE_NAMES: [&str; 9] = [ ".rules", ".cursorrules", ".windsurfrules", @@ -50,7 +51,8 @@ struct Session { thread: Entity, /// The ACP thread that handles protocol communication acp_thread: WeakEntity, - _subscription: Subscription, + pending_save: Task<()>, + _subscriptions: Vec, } pub struct LanguageModels { @@ -60,16 +62,19 @@ pub struct LanguageModels { model_list: acp_thread::AgentModelList, refresh_models_rx: watch::Receiver<()>, refresh_models_tx: watch::Sender<()>, + _authenticate_all_providers_task: Task<()>, } impl LanguageModels { - fn new(cx: &App) -> Self { + fn new(cx: &mut App) -> Self { let (refresh_models_tx, refresh_models_rx) = watch::channel(()); + let mut this = Self { models: HashMap::default(), model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()), refresh_models_rx, refresh_models_tx, + _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx), }; this.refresh_list(cx); this 
@@ -89,8 +94,8 @@ impl LanguageModels { let mut recommended = Vec::new(); for provider in &providers { for model in provider.recommended_models(cx) { - recommended_models.insert(model.id()); - recommended.push(Self::map_language_model_to_info(&model, &provider)); + recommended_models.insert((model.provider_id(), model.id())); + recommended.push(Self::map_language_model_to_info(&model, provider)); } } if !recommended.is_empty() { @@ -106,7 +111,7 @@ impl LanguageModels { for model in provider.provided_models(cx) { let model_info = Self::map_language_model_to_info(&model, &provider); let model_id = model_info.id.clone(); - if !recommended_models.contains(&model.id()) { + if !recommended_models.contains(&(model.provider_id(), model.id())) { provider_models.push(model_info); } models.insert(model_id, model); @@ -149,13 +154,60 @@ impl LanguageModels { fn model_id(model: &Arc) -> acp_thread::AgentModelId { acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into()) } + + fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> { + let authenticate_all_providers = LanguageModelRegistry::global(cx) + .read(cx) + .providers() + .iter() + .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx))) + .collect::>(); + + cx.background_spawn(async move { + for (provider_id, provider_name, authenticate_task) in authenticate_all_providers { + if let Err(err) = authenticate_task.await { + if matches!(err, language_model::AuthenticateError::CredentialsNotFound) { + // Since we're authenticating these providers in the + // background for the purposes of populating the + // language selector, we don't care about providers + // where the credentials are not found. + } else { + // Some providers have noisy failure states that we + // don't want to spam the logs with every time the + // language model selector is initialized. 
+ // + // Ideally these should have more clear failure modes + // that we know are safe to ignore here, like what we do + // with `CredentialsNotFound` above. + match provider_id.0.as_ref() { + "lmstudio" | "ollama" => { + // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated". + // + // These fail noisily, so we don't log them. + } + "copilot_chat" => { + // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors. + } + _ => { + log::error!( + "Failed to authenticate provider: {}: {err}", + provider_name.0 + ); + } + } + } + } + } + }) + } } pub struct NativeAgent { /// Session ID -> Session mapping sessions: HashMap, + history: Entity, /// Shared project context for all threads - project_context: Rc>, + project_context: Entity, project_context_needs_refresh: watch::Sender<()>, _maintain_project_context: Task>, context_server_registry: Entity, @@ -172,12 +224,13 @@ pub struct NativeAgent { impl NativeAgent { pub async fn new( project: Entity, + history: Entity, templates: Arc, prompt_store: Option>, fs: Arc, cx: &mut AsyncApp, ) -> Result> { - log::info!("Creating new NativeAgent"); + log::debug!("Creating new NativeAgent"); let project_context = cx .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))? 
@@ -199,7 +252,8 @@ impl NativeAgent { watch::channel(()); Self { sessions: HashMap::new(), - project_context: Rc::new(RefCell::new(project_context)), + history, + project_context: cx.new(|_| project_context), project_context_needs_refresh: project_context_needs_refresh_tx, _maintain_project_context: cx.spawn(async move |this, cx| { Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await @@ -217,6 +271,67 @@ impl NativeAgent { }) } + fn register_session( + &mut self, + thread_handle: Entity, + cx: &mut Context, + ) -> Entity { + let connection = Rc::new(NativeAgentConnection(cx.entity())); + + let thread = thread_handle.read(cx); + let session_id = thread.id().clone(); + let title = thread.title(); + let project = thread.project.clone(); + let action_log = thread.action_log.clone(); + let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone(); + let acp_thread = cx.new(|cx| { + acp_thread::AcpThread::new( + title, + connection, + project.clone(), + action_log.clone(), + session_id.clone(), + prompt_capabilities_rx, + cx, + ) + }); + + let registry = LanguageModelRegistry::read_global(cx); + let summarization_model = registry.thread_summary_model().map(|c| c.model); + + thread_handle.update(cx, |thread, cx| { + thread.set_summarization_model(summarization_model, cx); + thread.add_default_tools( + Rc::new(AcpThreadEnvironment { + acp_thread: acp_thread.downgrade(), + }) as _, + cx, + ) + }); + + let subscriptions = vec![ + cx.observe_release(&acp_thread, |this, acp_thread, _cx| { + this.sessions.remove(acp_thread.session_id()); + }), + cx.subscribe(&thread_handle, Self::handle_thread_title_updated), + cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated), + cx.observe(&thread_handle, move |this, thread, cx| { + this.save_thread(thread, cx) + }), + ]; + + self.sessions.insert( + session_id, + Session { + thread: thread_handle, + acp_thread: acp_thread.downgrade(), + _subscriptions: subscriptions, + pending_save: 
Task::ready(()), + }, + ); + acp_thread + } + pub fn models(&self) -> &LanguageModels { &self.models } @@ -232,7 +347,9 @@ impl NativeAgent { Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx) })? .await; - this.update(cx, |this, _| this.project_context.replace(project_context))?; + this.update(cx, |this, cx| { + this.project_context = cx.new(|_| project_context); + })?; } Ok(()) @@ -385,6 +502,43 @@ impl NativeAgent { }) } + fn handle_thread_title_updated( + &mut self, + thread: Entity, + _: &TitleUpdated, + cx: &mut Context, + ) { + let session_id = thread.read(cx).id(); + let Some(session) = self.sessions.get(session_id) else { + return; + }; + let thread = thread.downgrade(); + let acp_thread = session.acp_thread.clone(); + cx.spawn(async move |_, cx| { + let title = thread.read_with(cx, |thread, _| thread.title())?; + let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?; + task.await + }) + .detach_and_log_err(cx); + } + + fn handle_thread_token_usage_updated( + &mut self, + thread: Entity, + usage: &TokenUsageUpdated, + cx: &mut Context, + ) { + let Some(session) = self.sessions.get(thread.read(cx).id()) else { + return; + }; + session + .acp_thread + .update(cx, |acp_thread, cx| { + acp_thread.update_token_usage(usage.0.clone(), cx); + }) + .ok(); + } + fn handle_project_event( &mut self, _project: Entity, @@ -424,21 +578,251 @@ impl NativeAgent { cx: &mut Context, ) { self.models.refresh_list(cx); + + let registry = LanguageModelRegistry::read_global(cx); + let default_model = registry.default_model().map(|m| m.model); + let summarization_model = registry.thread_summary_model().map(|m| m.model); + for session in self.sessions.values_mut() { - session.thread.update(cx, |thread, _| { - let model_id = LanguageModels::model_id(&thread.selected_model); - if let Some(model) = self.models.model_from_id(&model_id) { - thread.selected_model = model.clone(); + session.thread.update(cx, |thread, cx| { + if 
thread.model().is_none() + && let Some(model) = default_model.clone() + { + thread.set_model(model, cx); + cx.notify(); } + thread.set_summarization_model(summarization_model.clone(), cx); }); } } + + pub fn open_thread( + &mut self, + id: acp::SessionId, + cx: &mut Context, + ) -> Task>> { + let database_future = ThreadsDatabase::connect(cx); + cx.spawn(async move |this, cx| { + let database = database_future.await.map_err(|err| anyhow!(err))?; + let db_thread = database + .load_thread(id.clone()) + .await? + .with_context(|| format!("no thread found with ID: {id:?}"))?; + + let thread = this.update(cx, |this, cx| { + let action_log = cx.new(|_cx| ActionLog::new(this.project.clone())); + cx.new(|cx| { + Thread::from_db( + id.clone(), + db_thread, + this.project.clone(), + this.project_context.clone(), + this.context_server_registry.clone(), + action_log.clone(), + this.templates.clone(), + cx, + ) + }) + })?; + let acp_thread = + this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?; + let events = thread.update(cx, |thread, cx| thread.replay(cx))?; + cx.update(|cx| { + NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx) + })? + .await?; + Ok(acp_thread) + }) + } + + pub fn thread_summary( + &mut self, + id: acp::SessionId, + cx: &mut Context, + ) -> Task> { + let thread = self.open_thread(id.clone(), cx); + cx.spawn(async move |this, cx| { + let acp_thread = thread.await?; + let result = this + .update(cx, |this, cx| { + this.sessions + .get(&id) + .unwrap() + .thread + .update(cx, |thread, cx| thread.summary(cx)) + })? 
+ .await?; + drop(acp_thread); + Ok(result) + }) + } + + fn save_thread(&mut self, thread: Entity, cx: &mut Context) { + if thread.read(cx).is_empty() { + return; + } + + let database_future = ThreadsDatabase::connect(cx); + let (id, db_thread) = + thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx))); + let Some(session) = self.sessions.get_mut(&id) else { + return; + }; + let history = self.history.clone(); + session.pending_save = cx.spawn(async move |_, cx| { + let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else { + return; + }; + let db_thread = db_thread.await; + database.save_thread(id, db_thread).await.log_err(); + history.update(cx, |history, cx| history.reload(cx)).ok(); + }); + } } /// Wrapper struct that implements the AgentConnection trait #[derive(Clone)] pub struct NativeAgentConnection(pub Entity); +impl NativeAgentConnection { + pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option> { + self.0 + .read(cx) + .sessions + .get(session_id) + .map(|session| session.thread.clone()) + } + + fn run_turn( + &self, + session_id: acp::SessionId, + cx: &mut App, + f: impl 'static + + FnOnce(Entity, &mut App) -> Result>>, + ) -> Task> { + let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| { + agent + .sessions + .get_mut(&session_id) + .map(|s| (s.thread.clone(), s.acp_thread.clone())) + }) else { + return Task::ready(Err(anyhow!("Session not found"))); + }; + log::debug!("Found session for: {}", session_id); + + let response_stream = match f(thread, cx) { + Ok(stream) => stream, + Err(err) => return Task::ready(Err(err)), + }; + Self::handle_thread_events(response_stream, acp_thread, cx) + } + + fn handle_thread_events( + mut events: mpsc::UnboundedReceiver>, + acp_thread: WeakEntity, + cx: &App, + ) -> Task> { + cx.spawn(async move |cx| { + // Handle response stream and forward to session.acp_thread + while let Some(result) = events.next().await { + match result { + Ok(event) => 
{ + log::trace!("Received completion event: {:?}", event); + + match event { + ThreadEvent::UserMessage(message) => { + acp_thread.update(cx, |thread, cx| { + for content in message.content { + thread.push_user_content_block( + Some(message.id.clone()), + content.into(), + cx, + ); + } + })?; + } + ThreadEvent::AgentText(text) => { + acp_thread.update(cx, |thread, cx| { + thread.push_assistant_content_block( + acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + false, + cx, + ) + })?; + } + ThreadEvent::AgentThinking(text) => { + acp_thread.update(cx, |thread, cx| { + thread.push_assistant_content_block( + acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + true, + cx, + ) + })?; + } + ThreadEvent::ToolCallAuthorization(ToolCallAuthorization { + tool_call, + options, + response, + }) => { + let outcome_task = acp_thread.update(cx, |thread, cx| { + thread.request_tool_call_authorization( + tool_call, options, true, cx, + ) + })??; + cx.background_spawn(async move { + if let acp::RequestPermissionOutcome::Selected { option_id } = + outcome_task.await + { + response + .send(option_id) + .map(|_| anyhow!("authorization receiver was dropped")) + .log_err(); + } + }) + .detach(); + } + ThreadEvent::ToolCall(tool_call) => { + acp_thread.update(cx, |thread, cx| { + thread.upsert_tool_call(tool_call, cx) + })??; + } + ThreadEvent::ToolCallUpdate(update) => { + acp_thread.update(cx, |thread, cx| { + thread.update_tool_call(update, cx) + })??; + } + ThreadEvent::Retry(status) => { + acp_thread.update(cx, |thread, cx| { + thread.update_retry_status(status, cx) + })?; + } + ThreadEvent::Stop(stop_reason) => { + log::debug!("Assistant message complete: {:?}", stop_reason); + return Ok(acp::PromptResponse { stop_reason }); + } + } + } + Err(e) => { + log::error!("Error in model response stream: {:?}", e); + return Err(e); + } + } + } + + log::debug!("Response stream completed"); + anyhow::Ok(acp::PromptResponse { + 
stop_reason: acp::StopReason::EndTurn, + }) + }) + } +} + impl AgentModelSelector for NativeAgentConnection { fn list_models(&self, cx: &mut App) -> Task> { log::debug!("NativeAgentConnection::list_models called"); @@ -456,7 +840,7 @@ impl AgentModelSelector for NativeAgentConnection { model_id: acp_thread::AgentModelId, cx: &mut App, ) -> Task> { - log::info!("Setting model for session {}: {}", session_id, model_id); + log::debug!("Setting model for session {}: {}", session_id, model_id); let Some(thread) = self .0 .read(cx) @@ -471,8 +855,8 @@ impl AgentModelSelector for NativeAgentConnection { return Task::ready(Err(anyhow!("Invalid model ID {}", model_id))); }; - thread.update(cx, |thread, _cx| { - thread.selected_model = model.clone(); + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); }); update_settings_file::( @@ -502,13 +886,15 @@ impl AgentModelSelector for NativeAgentConnection { else { return Task::ready(Err(anyhow!("Session not found"))); }; - let model = thread.read(cx).selected_model.clone(); + let Some(model) = thread.read(cx).model() else { + return Task::ready(Err(anyhow!("Model not found"))); + }; let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id()) else { return Task::ready(Err(anyhow!("Provider not found"))); }; Task::ready(Ok(LanguageModels::map_language_model_to_info( - &model, &provider, + model, &provider, ))) } @@ -522,105 +908,42 @@ impl acp_thread::AgentConnection for NativeAgentConnection { self: Rc, project: Entity, cwd: &Path, - cx: &mut AsyncApp, + cx: &mut App, ) -> Task>> { let agent = self.0.clone(); - log::info!("Creating new thread for project at: {:?}", cwd); + log::debug!("Creating new thread for project at: {:?}", cwd); cx.spawn(async move |cx| { log::debug!("Starting thread creation in async context"); - // Generate session ID - let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); - log::info!("Created session with ID: {}", session_id); - - // 
Create AcpThread - let acp_thread = cx.update(|cx| { - cx.new(|cx| { - acp_thread::AcpThread::new( - "agent2", - self.clone(), - project.clone(), - session_id.clone(), - cx, - ) - }) - })?; - let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?; - // Create Thread let thread = agent.update( cx, |agent, cx: &mut gpui::Context| -> Result<_> { // Fetch default model from registry settings let registry = LanguageModelRegistry::read_global(cx); - // Log available models for debugging let available_count = registry.available_models(cx).count(); log::debug!("Total available models: {}", available_count); - let default_model = registry - .default_model() - .and_then(|default_model| { - agent - .models - .model_from_id(&LanguageModels::model_id(&default_model.model)) - }) - .ok_or_else(|| { - log::warn!("No default model configured in settings"); - anyhow!( - "No default model. Please configure a default model in settings." - ) - })?; - - let thread = cx.new(|cx| { - let mut thread = Thread::new( + let default_model = registry.default_model().and_then(|default_model| { + agent + .models + .model_from_id(&LanguageModels::model_id(&default_model.model)) + }); + Ok(cx.new(|cx| { + Thread::new( project.clone(), agent.project_context.clone(), agent.context_server_registry.clone(), - action_log.clone(), agent.templates.clone(), default_model, cx, - ); - thread.add_tool(CopyPathTool::new(project.clone())); - thread.add_tool(CreateDirectoryTool::new(project.clone())); - thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); - thread.add_tool(DiagnosticsTool::new(project.clone())); - thread.add_tool(EditFileTool::new(cx.entity())); - thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); - thread.add_tool(FindPathTool::new(project.clone())); - thread.add_tool(GrepTool::new(project.clone())); - thread.add_tool(ListDirectoryTool::new(project.clone())); - thread.add_tool(MovePathTool::new(project.clone())); - 
thread.add_tool(NowTool); - thread.add_tool(OpenTool::new(project.clone())); - thread.add_tool(ReadFileTool::new(project.clone(), action_log)); - thread.add_tool(TerminalTool::new(project.clone(), cx)); - thread.add_tool(ThinkingTool); - thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model. - thread - }); - - Ok(thread) + ) + })) }, )??; - - // Store the session - agent.update(cx, |agent, cx| { - agent.sessions.insert( - session_id, - Session { - thread, - acp_thread: acp_thread.downgrade(), - _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { - this.sessions.remove(acp_thread.session_id()); - }), - }, - ); - })?; - - Ok(acp_thread) + agent.update(cx, |agent, cx| agent.register_session(thread, cx)) }) } @@ -644,166 +967,228 @@ impl acp_thread::AgentConnection for NativeAgentConnection { ) -> Task> { let id = id.expect("UserMessageId is required"); let session_id = params.session_id.clone(); - let agent = self.0.clone(); log::info!("Received prompt request for session: {}", session_id); log::debug!("Prompt blocks count: {}", params.prompt.len()); - cx.spawn(async move |cx| { - // Get session - let (thread, acp_thread) = agent - .update(cx, |agent, _| { - agent - .sessions - .get_mut(&session_id) - .map(|s| (s.thread.clone(), s.acp_thread.clone())) - })? 
- .ok_or_else(|| { - log::error!("Session not found: {}", session_id); - anyhow::anyhow!("Session not found") - })?; - log::debug!("Found session for: {}", session_id); - + self.run_turn(session_id, cx, |thread, cx| { let content: Vec = params .prompt .into_iter() .map(Into::into) .collect::>(); - log::info!("Converted prompt to message: {} chars", content.len()); + log::debug!("Converted prompt to message: {} chars", content.len()); log::debug!("Message id: {:?}", id); log::debug!("Message content: {:?}", content); - // Get model using the ModelSelector capability (always available for agent2) - // Get the selected model from the thread directly - let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; - - // Send to thread - log::info!("Sending message to thread with model: {:?}", model.name()); - let mut response_stream = - thread.update(cx, |thread, cx| thread.send(id, content, cx))?; - - // Handle response stream and forward to session.acp_thread - while let Some(result) = response_stream.next().await { - match result { - Ok(event) => { - log::trace!("Received completion event: {:?}", event); - - match event { - AgentResponseEvent::Text(text) => { - acp_thread.update(cx, |thread, cx| { - thread.push_assistant_content_block( - acp::ContentBlock::Text(acp::TextContent { - text, - annotations: None, - }), - false, - cx, - ) - })?; - } - AgentResponseEvent::Thinking(text) => { - acp_thread.update(cx, |thread, cx| { - thread.push_assistant_content_block( - acp::ContentBlock::Text(acp::TextContent { - text, - annotations: None, - }), - true, - cx, - ) - })?; - } - AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization { - tool_call, - options, - response, - }) => { - let recv = acp_thread.update(cx, |thread, cx| { - thread.request_tool_call_authorization(tool_call, options, cx) - })?; - cx.background_spawn(async move { - if let Some(option) = recv - .await - .context("authorization sender was dropped") - .log_err() - { - response - 
.send(option) - .map(|_| anyhow!("authorization receiver was dropped")) - .log_err(); - } - }) - .detach(); - } - AgentResponseEvent::ToolCall(tool_call) => { - acp_thread.update(cx, |thread, cx| { - thread.upsert_tool_call(tool_call, cx) - })?; - } - AgentResponseEvent::ToolCallUpdate(update) => { - acp_thread.update(cx, |thread, cx| { - thread.update_tool_call(update, cx) - })??; - } - AgentResponseEvent::Stop(stop_reason) => { - log::debug!("Assistant message complete: {:?}", stop_reason); - return Ok(acp::PromptResponse { stop_reason }); - } - } - } - Err(e) => { - log::error!("Error in model response stream: {:?}", e); - // TODO: Consider sending an error message to the UI - break; - } - } - } - - log::info!("Response stream completed"); - anyhow::Ok(acp::PromptResponse { - stop_reason: acp::StopReason::EndTurn, - }) + thread.update(cx, |thread, cx| thread.send(id, content, cx)) }) } + fn resume( + &self, + session_id: &acp::SessionId, + _cx: &App, + ) -> Option> { + Some(Rc::new(NativeAgentSessionResume { + connection: self.clone(), + session_id: session_id.clone(), + }) as _) + } + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { log::info!("Cancelling on session: {}", session_id); self.0.update(cx, |agent, cx| { if let Some(agent) = agent.sessions.get(session_id) { - agent.thread.update(cx, |thread, _cx| thread.cancel()); + agent.thread.update(cx, |thread, cx| thread.cancel(cx)); } }); } - fn session_editor( + fn truncate( &self, session_id: &agent_client_protocol::SessionId, + cx: &App, + ) -> Option> { + self.0.read_with(cx, |agent, _cx| { + agent.sessions.get(session_id).map(|session| { + Rc::new(NativeAgentSessionTruncate { + thread: session.thread.clone(), + acp_thread: session.acp_thread.clone(), + }) as _ + }) + }) + } + + fn set_title( + &self, + session_id: &acp::SessionId, + _cx: &App, + ) -> Option> { + Some(Rc::new(NativeAgentSessionSetTitle { + connection: self.clone(), + session_id: session_id.clone(), + }) as _) + } + + fn 
telemetry(&self) -> Option> { + Some(Rc::new(self.clone()) as Rc) + } + + fn into_any(self: Rc) -> Rc { + self + } +} + +impl acp_thread::AgentTelemetry for NativeAgentConnection { + fn agent_name(&self) -> String { + "Zed".into() + } + + fn thread_data( + &self, + session_id: &acp::SessionId, cx: &mut App, - ) -> Option> { - self.0.update(cx, |agent, _cx| { - agent - .sessions - .get(session_id) - .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _) + ) -> Task> { + let Some(session) = self.0.read(cx).sessions.get(session_id) else { + return Task::ready(Err(anyhow!("Session not found"))); + }; + + let task = session.thread.read(cx).to_db(cx); + cx.background_spawn(async move { + serde_json::to_value(task.await).context("Failed to serialize thread") }) } } -struct NativeAgentSessionEditor(Entity); +struct NativeAgentSessionTruncate { + thread: Entity, + acp_thread: WeakEntity, +} -impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { - fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { - Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id))) +impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate { + fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { + match self.thread.update(cx, |thread, cx| { + thread.truncate(message_id.clone(), cx)?; + Ok(thread.latest_token_usage()) + }) { + Ok(usage) => { + self.acp_thread + .update(cx, |thread, cx| { + thread.update_token_usage(usage, cx); + }) + .ok(); + Task::ready(Ok(())) + } + Err(error) => Task::ready(Err(error)), + } + } +} + +struct NativeAgentSessionResume { + connection: NativeAgentConnection, + session_id: acp::SessionId, +} + +impl acp_thread::AgentSessionResume for NativeAgentSessionResume { + fn run(&self, cx: &mut App) -> Task> { + self.connection + .run_turn(self.session_id.clone(), cx, |thread, cx| { + thread.update(cx, |thread, cx| thread.resume(cx)) + }) + } +} + +struct 
NativeAgentSessionSetTitle { + connection: NativeAgentConnection, + session_id: acp::SessionId, +} + +impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle { + fn run(&self, title: SharedString, cx: &mut App) -> Task> { + let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else { + return Task::ready(Err(anyhow!("session not found"))); + }; + let thread = session.thread.clone(); + thread.update(cx, |thread, cx| thread.set_title(title, cx)); + Task::ready(Ok(())) + } +} + +pub struct AcpThreadEnvironment { + acp_thread: WeakEntity, +} + +impl ThreadEnvironment for AcpThreadEnvironment { + fn create_terminal( + &self, + command: String, + cwd: Option, + output_byte_limit: Option, + cx: &mut AsyncApp, + ) -> Task>> { + let task = self.acp_thread.update(cx, |thread, cx| { + thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx) + }); + + let acp_thread = self.acp_thread.clone(); + cx.spawn(async move |cx| { + let terminal = task?.await?; + + let (drop_tx, drop_rx) = oneshot::channel(); + let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?; + + cx.spawn(async move |cx| { + drop_rx.await.ok(); + acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx)) + }) + .detach(); + + let handle = AcpTerminalHandle { + terminal, + _drop_tx: Some(drop_tx), + }; + + Ok(Rc::new(handle) as _) + }) + } +} + +pub struct AcpTerminalHandle { + terminal: Entity, + _drop_tx: Option>, +} + +impl TerminalHandle for AcpTerminalHandle { + fn id(&self, cx: &AsyncApp) -> Result { + self.terminal.read_with(cx, |term, _cx| term.id().clone()) + } + + fn wait_for_exit(&self, cx: &AsyncApp) -> Result>> { + self.terminal + .read_with(cx, |term, _cx| term.wait_for_exit()) + } + + fn current_output(&self, cx: &AsyncApp) -> Result { + self.terminal + .read_with(cx, |term, cx| term.current_output(cx)) } } #[cfg(test)] mod tests { + use crate::HistoryEntryId; + use super::*; - use 
acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo}; + use acp_thread::{ + AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri, + }; use fs::FakeFs; use gpui::TestAppContext; + use indoc::indoc; + use language_model::fake_provider::FakeLanguageModel; use serde_json::json; use settings::SettingsStore; + use util::path; #[gpui::test] async fn test_maintaining_project_context(cx: &mut TestAppContext) { @@ -817,8 +1202,11 @@ mod tests { ) .await; let project = Project::test(fs.clone(), [], cx).await; + let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); let agent = NativeAgent::new( project.clone(), + history_store, Templates::new(), None, fs.clone(), @@ -826,8 +1214,8 @@ mod tests { ) .await .unwrap(); - agent.read_with(cx, |agent, _| { - assert_eq!(agent.project_context.borrow().worktrees, vec![]) + agent.read_with(cx, |agent, cx| { + assert_eq!(agent.project_context.read(cx).worktrees, vec![]) }); let worktree = project @@ -835,9 +1223,9 @@ mod tests { .await .unwrap(); cx.run_until_parked(); - agent.read_with(cx, |agent, _| { + agent.read_with(cx, |agent, cx| { assert_eq!( - agent.project_context.borrow().worktrees, + agent.project_context.read(cx).worktrees, vec![WorktreeContext { root_name: "a".into(), abs_path: Path::new("/a").into(), @@ -852,7 +1240,7 @@ mod tests { agent.read_with(cx, |agent, cx| { let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap(); assert_eq!( - agent.project_context.borrow().worktrees, + agent.project_context.read(cx).worktrees, vec![WorktreeContext { root_name: "a".into(), abs_path: Path::new("/a").into(), @@ -872,9 +1260,12 @@ mod tests { let fs = FakeFs::new(cx.executor()); fs.insert_tree("/", json!({ "a": {} })).await; let project = Project::test(fs.clone(), [], cx).await; + let context_store = cx.new(|cx| 
assistant_context::ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); let connection = NativeAgentConnection( NativeAgent::new( project.clone(), + history_store, Templates::new(), None, fs.clone(), @@ -925,9 +1316,13 @@ mod tests { .await; let project = Project::test(fs.clone(), [], cx).await; + let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + // Create the agent and connection let agent = NativeAgent::new( project.clone(), + history_store, Templates::new(), None, fs.clone(), @@ -940,11 +1335,7 @@ mod tests { // Create a thread/session let acp_thread = cx .update(|cx| { - Rc::new(connection.clone()).new_thread( - project.clone(), - Path::new("/a"), - &mut cx.to_async(), - ) + Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx) }) .await .unwrap(); @@ -961,7 +1352,7 @@ mod tests { agent.read_with(cx, |agent, _| { let session = agent.sessions.get(&session_id).unwrap(); session.thread.read_with(cx, |thread, _| { - assert_eq!(thread.selected_model.id().0, "fake"); + assert_eq!(thread.model().unwrap().id().0, "fake"); }); }); @@ -982,6 +1373,158 @@ mod tests { ); } + #[gpui::test] + #[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows + async fn test_save_load_thread(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": { + "b.md": "Lorem" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; + let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let agent = NativeAgent::new( + project.clone(), + history_store.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let 
connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_thread(project.clone(), Path::new(""), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + + // Ensure empty threads are not saved, even if they get mutated. + let model = Arc::new(FakeLanguageModel::default()); + let summary_model = Arc::new(FakeLanguageModel::default()); + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); + thread.set_summarization_model(Some(summary_model.clone()), cx); + }); + cx.run_until_parked(); + assert_eq!(history_entries(&history_store, cx), vec![]); + + let send = acp_thread.update(cx, |thread, cx| { + thread.send( + vec![ + "What does ".into(), + acp::ContentBlock::ResourceLink(acp::ResourceLink { + name: "b.md".into(), + uri: MentionUri::File { + abs_path: path!("/a/b.md").into(), + } + .to_uri() + .to_string(), + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + }), + " mean?".into(), + ], + cx, + ) + }); + let send = cx.foreground_executor().spawn(send); + cx.run_until_parked(); + + model.send_last_completion_stream_text_chunk("Lorem."); + model.end_last_completion_stream(); + cx.run_until_parked(); + summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md"); + summary_model.end_last_completion_stream(); + + send.await.unwrap(); + acp_thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User + + What does [@b.md](file:///a/b.md) mean? + + ## Assistant + + Lorem. + + "} + ) + }); + + cx.run_until_parked(); + + // Drop the ACP thread, which should cause the session to be dropped as well. 
+ cx.update(|_| { + drop(thread); + drop(acp_thread); + }); + agent.read_with(cx, |agent, _| { + assert_eq!(agent.sessions.keys().cloned().collect::>(), []); + }); + + // Ensure the thread can be reloaded from disk. + assert_eq!( + history_entries(&history_store, cx), + vec![( + HistoryEntryId::AcpThread(session_id.clone()), + "Explaining /a/b.md".into() + )] + ); + let acp_thread = agent + .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx)) + .await + .unwrap(); + acp_thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User + + What does [@b.md](file:///a/b.md) mean? + + ## Assistant + + Lorem. + + "} + ) + }); + } + + fn history_entries( + history: &Entity, + cx: &mut TestAppContext, + ) -> Vec<(HistoryEntryId, String)> { + history.read_with(cx, |history, _| { + history + .entries() + .map(|e| (e.id(), e.title().to_string())) + .collect::>() + }) + } + fn init_test(cx: &mut TestAppContext) { env_logger::try_init().ok(); cx.update(|cx| { diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index f13cd1bd673b5e122333264ac3cbcbe83edd7627..1fc9c1cb956d1676c42713b5d9bb2a0b51e8ac90 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -1,13 +1,18 @@ mod agent; +mod db; +mod history_store; mod native_agent_server; mod templates; mod thread; +mod tool_schema; mod tools; #[cfg(test)] mod tests; pub use agent::*; +pub use db::*; +pub use history_store::*; pub use native_agent_server::NativeAgentServer; pub use templates::*; pub use thread::*; diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs new file mode 100644 index 0000000000000000000000000000000000000000..c78725138ffa081cc5b75c883d883b7a155d482c --- /dev/null +++ b/crates/agent2/src/db.rs @@ -0,0 +1,497 @@ +use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; +use acp_thread::UserMessageId; +use agent::{thread::DetailedSummaryState, thread_store}; +use agent_client_protocol as acp; 
+use agent_settings::{AgentProfileId, CompletionMode}; +use anyhow::{Result, anyhow}; +use chrono::{DateTime, Utc}; +use collections::{HashMap, IndexMap}; +use futures::{FutureExt, future::Shared}; +use gpui::{BackgroundExecutor, Global, Task}; +use indoc::indoc; +use parking_lot::Mutex; +use serde::{Deserialize, Serialize}; +use sqlez::{ + bindable::{Bind, Column}, + connection::Connection, + statement::Statement, +}; +use std::sync::Arc; +use ui::{App, SharedString}; +use zed_env_vars::ZED_STATELESS; + +pub type DbMessage = crate::Message; +pub type DbSummary = DetailedSummaryState; +pub type DbLanguageModel = thread_store::SerializedLanguageModel; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DbThreadMetadata { + pub id: acp::SessionId, + #[serde(alias = "summary")] + pub title: SharedString, + pub updated_at: DateTime, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DbThread { + pub title: SharedString, + pub messages: Vec, + pub updated_at: DateTime, + #[serde(default)] + pub detailed_summary: Option, + #[serde(default)] + pub initial_project_snapshot: Option>, + #[serde(default)] + pub cumulative_token_usage: language_model::TokenUsage, + #[serde(default)] + pub request_token_usage: HashMap, + #[serde(default)] + pub model: Option, + #[serde(default)] + pub completion_mode: Option, + #[serde(default)] + pub profile: Option, +} + +impl DbThread { + pub const VERSION: &'static str = "0.3.0"; + + pub fn from_json(json: &[u8]) -> Result { + let saved_thread_json = serde_json::from_slice::(json)?; + match saved_thread_json.get("version") { + Some(serde_json::Value::String(version)) => match version.as_str() { + Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?), + _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?), + }, + _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?), + } + } + + fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result { + let mut messages = 
Vec::new(); + let mut request_token_usage = HashMap::default(); + + let mut last_user_message_id = None; + for (ix, msg) in thread.messages.into_iter().enumerate() { + let message = match msg.role { + language_model::Role::User => { + let mut content = Vec::new(); + + // Convert segments to content + for segment in msg.segments { + match segment { + thread_store::SerializedMessageSegment::Text { text } => { + content.push(UserMessageContent::Text(text)); + } + thread_store::SerializedMessageSegment::Thinking { text, .. } => { + // User messages don't have thinking segments, but handle gracefully + content.push(UserMessageContent::Text(text)); + } + thread_store::SerializedMessageSegment::RedactedThinking { .. } => { + // User messages don't have redacted thinking, skip. + } + } + } + + // If no content was added, add context as text if available + if content.is_empty() && !msg.context.is_empty() { + content.push(UserMessageContent::Text(msg.context)); + } + + let id = UserMessageId::new(); + last_user_message_id = Some(id.clone()); + + crate::Message::User(UserMessage { + // MessageId from old format can't be meaningfully converted, so generate a new one + id, + content, + }) + } + language_model::Role::Assistant => { + let mut content = Vec::new(); + + // Convert segments to content + for segment in msg.segments { + match segment { + thread_store::SerializedMessageSegment::Text { text } => { + content.push(AgentMessageContent::Text(text)); + } + thread_store::SerializedMessageSegment::Thinking { + text, + signature, + } => { + content.push(AgentMessageContent::Thinking { text, signature }); + } + thread_store::SerializedMessageSegment::RedactedThinking { data } => { + content.push(AgentMessageContent::RedactedThinking(data)); + } + } + } + + // Convert tool uses + let mut tool_names_by_id = HashMap::default(); + for tool_use in msg.tool_uses { + tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone()); + content.push(AgentMessageContent::ToolUse( + 
language_model::LanguageModelToolUse { + id: tool_use.id, + name: tool_use.name.into(), + raw_input: serde_json::to_string(&tool_use.input) + .unwrap_or_default(), + input: tool_use.input, + is_input_complete: true, + }, + )); + } + + // Convert tool results + let mut tool_results = IndexMap::default(); + for tool_result in msg.tool_results { + let name = tool_names_by_id + .remove(&tool_result.tool_use_id) + .unwrap_or_else(|| SharedString::from("unknown")); + tool_results.insert( + tool_result.tool_use_id.clone(), + language_model::LanguageModelToolResult { + tool_use_id: tool_result.tool_use_id, + tool_name: name.into(), + is_error: tool_result.is_error, + content: tool_result.content, + output: tool_result.output, + }, + ); + } + + if let Some(last_user_message_id) = &last_user_message_id + && let Some(token_usage) = thread.request_token_usage.get(ix).copied() + { + request_token_usage.insert(last_user_message_id.clone(), token_usage); + } + + crate::Message::Agent(AgentMessage { + content, + tool_results, + }) + } + language_model::Role::System => { + // Skip system messages as they're not supported in the new format + continue; + } + }; + + messages.push(message); + } + + Ok(Self { + title: thread.summary, + messages, + updated_at: thread.updated_at, + detailed_summary: match thread.detailed_summary_state { + DetailedSummaryState::NotGenerated | DetailedSummaryState::Generating { .. } => { + None + } + DetailedSummaryState::Generated { text, .. 
} => Some(text), + }, + initial_project_snapshot: thread.initial_project_snapshot, + cumulative_token_usage: thread.cumulative_token_usage, + request_token_usage, + model: thread.model, + completion_mode: thread.completion_mode, + profile: thread.profile, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum DataType { + #[serde(rename = "json")] + Json, + #[serde(rename = "zstd")] + Zstd, +} + +impl Bind for DataType { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + let value = match self { + DataType::Json => "json", + DataType::Zstd => "zstd", + }; + value.bind(statement, start_index) + } +} + +impl Column for DataType { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let (value, next_index) = String::column(statement, start_index)?; + let data_type = match value.as_str() { + "json" => DataType::Json, + "zstd" => DataType::Zstd, + _ => anyhow::bail!("Unknown data type: {}", value), + }; + Ok((data_type, next_index)) + } +} + +pub(crate) struct ThreadsDatabase { + executor: BackgroundExecutor, + connection: Arc>, +} + +struct GlobalThreadsDatabase(Shared, Arc>>>); + +impl Global for GlobalThreadsDatabase {} + +impl ThreadsDatabase { + pub fn connect(cx: &mut App) -> Shared, Arc>>> { + if cx.has_global::() { + return cx.global::().0.clone(); + } + let executor = cx.background_executor().clone(); + let task = executor + .spawn({ + let executor = executor.clone(); + async move { + match ThreadsDatabase::new(executor) { + Ok(db) => Ok(Arc::new(db)), + Err(err) => Err(Arc::new(err)), + } + } + }) + .shared(); + + cx.set_global(GlobalThreadsDatabase(task.clone())); + task + } + + pub fn new(executor: BackgroundExecutor) -> Result { + let connection = if *ZED_STATELESS { + Connection::open_memory(Some("THREAD_FALLBACK_DB")) + } else if cfg!(any(feature = "test-support", test)) { + // rust stores the name of the test on the current thread. 
+ // We use this to automatically create a database that will + // be shared within the test (for the test_retrieving_old_thread) + // but not with concurrent tests. + let thread = std::thread::current(); + let test_name = thread.name(); + Connection::open_memory(Some(&format!( + "THREAD_FALLBACK_{}", + test_name.unwrap_or_default() + ))) + } else { + let threads_dir = paths::data_dir().join("threads"); + std::fs::create_dir_all(&threads_dir)?; + let sqlite_path = threads_dir.join("threads.db"); + Connection::open_file(&sqlite_path.to_string_lossy()) + }; + + connection.exec(indoc! {" + CREATE TABLE IF NOT EXISTS threads ( + id TEXT PRIMARY KEY, + summary TEXT NOT NULL, + updated_at TEXT NOT NULL, + data_type TEXT NOT NULL, + data BLOB NOT NULL + ) + "})?() + .map_err(|e| anyhow!("Failed to create threads table: {}", e))?; + + let db = Self { + executor, + connection: Arc::new(Mutex::new(connection)), + }; + + Ok(db) + } + + fn save_thread_sync( + connection: &Arc>, + id: acp::SessionId, + thread: DbThread, + ) -> Result<()> { + const COMPRESSION_LEVEL: i32 = 3; + + #[derive(Serialize)] + struct SerializedThread { + #[serde(flatten)] + thread: DbThread, + version: &'static str, + } + + let title = thread.title.to_string(); + let updated_at = thread.updated_at.to_rfc3339(); + let json_data = serde_json::to_string(&SerializedThread { + thread, + version: DbThread::VERSION, + })?; + + let connection = connection.lock(); + + let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?; + let data_type = DataType::Zstd; + let data = compressed; + + let mut insert = connection.exec_bound::<(Arc, String, String, DataType, Vec)>(indoc! {" + INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
+ "})?; + + insert((id.0, title, updated_at, data_type, data))?; + + Ok(()) + } + + pub fn list_threads(&self) -> Task>> { + let connection = self.connection.clone(); + + self.executor.spawn(async move { + let connection = connection.lock(); + + let mut select = + connection.select_bound::<(), (Arc, String, String)>(indoc! {" + SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC + "})?; + + let rows = select(())?; + let mut threads = Vec::new(); + + for (id, summary, updated_at) in rows { + threads.push(DbThreadMetadata { + id: acp::SessionId(id), + title: summary.into(), + updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc), + }); + } + + Ok(threads) + }) + } + + pub fn load_thread(&self, id: acp::SessionId) -> Task>> { + let connection = self.connection.clone(); + + self.executor.spawn(async move { + let connection = connection.lock(); + let mut select = connection.select_bound::, (DataType, Vec)>(indoc! {" + SELECT data_type, data FROM threads WHERE id = ? LIMIT 1 + "})?; + + let rows = select(id.0)?; + if let Some((data_type, data)) = rows.into_iter().next() { + let json_data = match data_type { + DataType::Zstd => { + let decompressed = zstd::decode_all(&data[..])?; + String::from_utf8(decompressed)? + } + DataType::Json => String::from_utf8(data)?, + }; + let thread = DbThread::from_json(json_data.as_bytes())?; + Ok(Some(thread)) + } else { + Ok(None) + } + }) + } + + pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task> { + let connection = self.connection.clone(); + + self.executor + .spawn(async move { Self::save_thread_sync(&connection, id, thread) }) + } + + pub fn delete_thread(&self, id: acp::SessionId) -> Task> { + let connection = self.connection.clone(); + + self.executor.spawn(async move { + let connection = connection.lock(); + + let mut delete = connection.exec_bound::>(indoc! {" + DELETE FROM threads WHERE id = ? 
+ "})?; + + delete(id.0)?; + + Ok(()) + }) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use agent::MessageSegment; + use agent::context::LoadedContext; + use client::Client; + use fs::FakeFs; + use gpui::AppContext; + use gpui::TestAppContext; + use http_client::FakeHttpClient; + use language_model::Role; + use project::Project; + use settings::SettingsStore; + + fn init_test(cx: &mut TestAppContext) { + env_logger::try_init().ok(); + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + + let http_client = FakeHttpClient::with_404_response(); + let clock = Arc::new(clock::FakeSystemClock::new()); + let client = Client::new(clock, http_client, cx); + agent::init(cx); + agent_settings::init(cx); + language_model::init(client, cx); + }); + } + + #[gpui::test] + async fn test_retrieving_old_thread(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + + // Save a thread using the old agent. + let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx)); + let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx)); + thread.update(cx, |thread, cx| { + thread.insert_message( + Role::User, + vec![MessageSegment::Text("Hey!".into())], + LoadedContext::default(), + vec![], + false, + cx, + ); + thread.insert_message( + Role::Assistant, + vec![MessageSegment::Text("How're you doing?".into())], + LoadedContext::default(), + vec![], + false, + cx, + ) + }); + thread_store + .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx)) + .await + .unwrap(); + + // Open that same thread using the new agent. 
+ let db = cx.update(ThreadsDatabase::connect).await.unwrap(); + let threads = db.list_threads().await.unwrap(); + assert_eq!(threads.len(), 1); + let thread = db + .load_thread(threads[0].id.clone()) + .await + .unwrap() + .unwrap(); + assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n"); + assert_eq!( + thread.messages[1].to_markdown(), + "## Assistant\n\nHow're you doing?\n" + ); + } +} diff --git a/crates/agent2/src/history_store.rs b/crates/agent2/src/history_store.rs new file mode 100644 index 0000000000000000000000000000000000000000..c656456e01780505c355c878c26d2405286e56b2 --- /dev/null +++ b/crates/agent2/src/history_store.rs @@ -0,0 +1,357 @@ +use crate::{DbThreadMetadata, ThreadsDatabase}; +use acp_thread::MentionUri; +use agent_client_protocol as acp; +use anyhow::{Context as _, Result, anyhow}; +use assistant_context::{AssistantContext, SavedContextMetadata}; +use chrono::{DateTime, Utc}; +use db::kvp::KEY_VALUE_STORE; +use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*}; +use itertools::Itertools; +use paths::contexts_dir; +use serde::{Deserialize, Serialize}; +use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration}; +use ui::ElementId; +use util::ResultExt as _; + +const MAX_RECENTLY_OPENED_ENTRIES: usize = 6; +const RECENTLY_OPENED_THREADS_KEY: &str = "recent-agent-threads"; +const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50); + +const DEFAULT_TITLE: &SharedString = &SharedString::new_static("New Thread"); + +#[derive(Clone, Debug)] +pub enum HistoryEntry { + AcpThread(DbThreadMetadata), + TextThread(SavedContextMetadata), +} + +impl HistoryEntry { + pub fn updated_at(&self) -> DateTime { + match self { + HistoryEntry::AcpThread(thread) => thread.updated_at, + HistoryEntry::TextThread(context) => context.mtime.to_utc(), + } + } + + pub fn id(&self) -> HistoryEntryId { + match self { + HistoryEntry::AcpThread(thread) => HistoryEntryId::AcpThread(thread.id.clone()), + 
HistoryEntry::TextThread(context) => HistoryEntryId::TextThread(context.path.clone()), + } + } + + pub fn mention_uri(&self) -> MentionUri { + match self { + HistoryEntry::AcpThread(thread) => MentionUri::Thread { + id: thread.id.clone(), + name: thread.title.to_string(), + }, + HistoryEntry::TextThread(context) => MentionUri::TextThread { + path: context.path.as_ref().to_owned(), + name: context.title.to_string(), + }, + } + } + + pub fn title(&self) -> &SharedString { + match self { + HistoryEntry::AcpThread(thread) if thread.title.is_empty() => DEFAULT_TITLE, + HistoryEntry::AcpThread(thread) => &thread.title, + HistoryEntry::TextThread(context) => &context.title, + } + } +} + +/// Generic identifier for a history entry. +#[derive(Clone, PartialEq, Eq, Debug, Hash)] +pub enum HistoryEntryId { + AcpThread(acp::SessionId), + TextThread(Arc), +} + +impl Into for HistoryEntryId { + fn into(self) -> ElementId { + match self { + HistoryEntryId::AcpThread(session_id) => ElementId::Name(session_id.0.into()), + HistoryEntryId::TextThread(path) => ElementId::Path(path), + } + } +} + +#[derive(Serialize, Deserialize, Debug)] +enum SerializedRecentOpen { + AcpThread(String), + TextThread(String), +} + +pub struct HistoryStore { + threads: Vec, + entries: Vec, + context_store: Entity, + recently_opened_entries: VecDeque, + _subscriptions: Vec, + _save_recently_opened_entries_task: Task<()>, +} + +impl HistoryStore { + pub fn new( + context_store: Entity, + cx: &mut Context, + ) -> Self { + let subscriptions = vec![cx.observe(&context_store, |this, _, cx| this.update_entries(cx))]; + + cx.spawn(async move |this, cx| { + let entries = Self::load_recently_opened_entries(cx).await; + this.update(cx, |this, cx| { + if let Some(entries) = entries.log_err() { + this.recently_opened_entries = entries; + } + + this.reload(cx); + }) + .ok(); + }) + .detach(); + + Self { + context_store, + recently_opened_entries: VecDeque::default(), + threads: Vec::default(), + entries: 
Vec::default(), + _subscriptions: subscriptions, + _save_recently_opened_entries_task: Task::ready(()), + } + } + + pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> { + self.threads.iter().find(|thread| &thread.id == session_id) + } + + pub fn delete_thread( + &mut self, + id: acp::SessionId, + cx: &mut Context, + ) -> Task> { + let database_future = ThreadsDatabase::connect(cx); + cx.spawn(async move |this, cx| { + let database = database_future.await.map_err(|err| anyhow!(err))?; + database.delete_thread(id.clone()).await?; + this.update(cx, |this, cx| this.reload(cx)) + }) + } + + pub fn delete_text_thread( + &mut self, + path: Arc, + cx: &mut Context, + ) -> Task> { + self.context_store.update(cx, |context_store, cx| { + context_store.delete_local_context(path, cx) + }) + } + + pub fn load_text_thread( + &self, + path: Arc, + cx: &mut Context, + ) -> Task>> { + self.context_store.update(cx, |context_store, cx| { + context_store.open_local_context(path, cx) + }) + } + + pub fn reload(&self, cx: &mut Context) { + let database_future = ThreadsDatabase::connect(cx); + cx.spawn(async move |this, cx| { + let threads = database_future + .await + .map_err(|err| anyhow!(err))? 
+ .list_threads() + .await?; + + this.update(cx, |this, cx| { + if this.recently_opened_entries.len() < MAX_RECENTLY_OPENED_ENTRIES { + for thread in threads + .iter() + .take(MAX_RECENTLY_OPENED_ENTRIES - this.recently_opened_entries.len()) + .rev() + { + this.push_recently_opened_entry( + HistoryEntryId::AcpThread(thread.id.clone()), + cx, + ) + } + } + this.threads = threads; + this.update_entries(cx); + }) + }) + .detach_and_log_err(cx); + } + + fn update_entries(&mut self, cx: &mut Context) { + #[cfg(debug_assertions)] + if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() { + return; + } + let mut history_entries = Vec::new(); + history_entries.extend(self.threads.iter().cloned().map(HistoryEntry::AcpThread)); + history_entries.extend( + self.context_store + .read(cx) + .unordered_contexts() + .cloned() + .map(HistoryEntry::TextThread), + ); + + history_entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.updated_at())); + self.entries = history_entries; + cx.notify() + } + + pub fn is_empty(&self, _cx: &App) -> bool { + self.entries.is_empty() + } + + pub fn recently_opened_entries(&self, cx: &App) -> Vec { + #[cfg(debug_assertions)] + if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() { + return Vec::new(); + } + + let thread_entries = self.threads.iter().flat_map(|thread| { + self.recently_opened_entries + .iter() + .enumerate() + .flat_map(|(index, entry)| match entry { + HistoryEntryId::AcpThread(id) if &thread.id == id => { + Some((index, HistoryEntry::AcpThread(thread.clone()))) + } + _ => None, + }) + }); + + let context_entries = + self.context_store + .read(cx) + .unordered_contexts() + .flat_map(|context| { + self.recently_opened_entries + .iter() + .enumerate() + .flat_map(|(index, entry)| match entry { + HistoryEntryId::TextThread(path) if &context.path == path => { + Some((index, HistoryEntry::TextThread(context.clone()))) + } + _ => None, + }) + }); + + thread_entries + .chain(context_entries) + // optimization to halt 
iteration early + .take(self.recently_opened_entries.len()) + .sorted_unstable_by_key(|(index, _)| *index) + .map(|(_, entry)| entry) + .collect() + } + + fn save_recently_opened_entries(&mut self, cx: &mut Context) { + let serialized_entries = self + .recently_opened_entries + .iter() + .filter_map(|entry| match entry { + HistoryEntryId::TextThread(path) => path.file_name().map(|file| { + SerializedRecentOpen::TextThread(file.to_string_lossy().to_string()) + }), + HistoryEntryId::AcpThread(id) => { + Some(SerializedRecentOpen::AcpThread(id.to_string())) + } + }) + .collect::>(); + + self._save_recently_opened_entries_task = cx.spawn(async move |_, cx| { + let content = serde_json::to_string(&serialized_entries).unwrap(); + cx.background_executor() + .timer(SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE) + .await; + + if cfg!(any(feature = "test-support", test)) { + return; + } + KEY_VALUE_STORE + .write_kvp(RECENTLY_OPENED_THREADS_KEY.to_owned(), content) + .await + .log_err(); + }); + } + + fn load_recently_opened_entries(cx: &AsyncApp) -> Task>> { + cx.background_spawn(async move { + if cfg!(any(feature = "test-support", test)) { + anyhow::bail!("history store does not persist in tests"); + } + let json = KEY_VALUE_STORE + .read_kvp(RECENTLY_OPENED_THREADS_KEY)? + .unwrap_or("[]".to_string()); + let entries = serde_json::from_str::>(&json) + .context("deserializing persisted agent panel navigation history")? 
+ .into_iter() + .take(MAX_RECENTLY_OPENED_ENTRIES) + .flat_map(|entry| match entry { + SerializedRecentOpen::AcpThread(id) => Some(HistoryEntryId::AcpThread( + acp::SessionId(id.as_str().into()), + )), + SerializedRecentOpen::TextThread(file_name) => Some( + HistoryEntryId::TextThread(contexts_dir().join(file_name).into()), + ), + }) + .collect(); + Ok(entries) + }) + } + + pub fn push_recently_opened_entry(&mut self, entry: HistoryEntryId, cx: &mut Context) { + self.recently_opened_entries + .retain(|old_entry| old_entry != &entry); + self.recently_opened_entries.push_front(entry); + self.recently_opened_entries + .truncate(MAX_RECENTLY_OPENED_ENTRIES); + self.save_recently_opened_entries(cx); + } + + pub fn remove_recently_opened_thread(&mut self, id: acp::SessionId, cx: &mut Context) { + self.recently_opened_entries.retain( + |entry| !matches!(entry, HistoryEntryId::AcpThread(thread_id) if thread_id == &id), + ); + self.save_recently_opened_entries(cx); + } + + pub fn replace_recently_opened_text_thread( + &mut self, + old_path: &Path, + new_path: &Arc, + cx: &mut Context, + ) { + for entry in &mut self.recently_opened_entries { + match entry { + HistoryEntryId::TextThread(path) if path.as_ref() == old_path => { + *entry = HistoryEntryId::TextThread(new_path.clone()); + break; + } + _ => {} + } + } + self.save_recently_opened_entries(cx); + } + + pub fn remove_recently_opened_entry(&mut self, entry: &HistoryEntryId, cx: &mut Context) { + self.recently_opened_entries + .retain(|old_entry| old_entry != entry); + self.save_recently_opened_entries(cx); + } + + pub fn entries(&self) -> impl Iterator { + self.entries.iter().cloned() + } +} diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs index cadd88a8462ca0c297ef0b7b8cd516f87104c4eb..0dde0ff98552d4292a4391d2aec4f36419228a25 100644 --- a/crates/agent2/src/native_agent_server.rs +++ b/crates/agent2/src/native_agent_server.rs @@ -1,55 +1,56 @@ -use std::{path::Path, 
rc::Rc, sync::Arc}; +use std::{any::Any, path::Path, rc::Rc, sync::Arc}; -use agent_servers::AgentServer; +use agent_servers::{AgentServer, AgentServerDelegate}; use anyhow::Result; use fs::Fs; -use gpui::{App, Entity, Task}; -use project::Project; +use gpui::{App, Entity, SharedString, Task}; use prompt_store::PromptStore; -use crate::{NativeAgent, NativeAgentConnection, templates::Templates}; +use crate::{HistoryStore, NativeAgent, NativeAgentConnection, templates::Templates}; #[derive(Clone)] pub struct NativeAgentServer { fs: Arc, + history: Entity, } impl NativeAgentServer { - pub fn new(fs: Arc) -> Self { - Self { fs } + pub fn new(fs: Arc, history: Entity) -> Self { + Self { fs, history } } } impl AgentServer for NativeAgentServer { - fn name(&self) -> &'static str { - "Native Agent" + fn telemetry_id(&self) -> &'static str { + "zed" } - fn empty_state_headline(&self) -> &'static str { - "Native Agent" - } - - fn empty_state_message(&self) -> &'static str { - "How can I help you today?" 
+ fn name(&self) -> SharedString { + "Zed Agent".into() } fn logo(&self) -> ui::IconName { - // Using the ZedAssistant icon as it's the native built-in agent - ui::IconName::ZedAssistant + ui::IconName::ZedAgent } fn connect( &self, - _root_dir: &Path, - project: &Entity, + _root_dir: Option<&Path>, + delegate: AgentServerDelegate, cx: &mut App, - ) -> Task>> { - log::info!( + ) -> Task< + Result<( + Rc, + Option, + )>, + > { + log::debug!( "NativeAgentServer::connect called for path: {:?}", _root_dir ); - let project = project.clone(); + let project = delegate.project().clone(); let fs = self.fs.clone(); + let history = self.history.clone(); let prompt_store = PromptStore::global(cx); cx.spawn(async move |cx| { log::debug!("Creating templates for native agent"); @@ -57,13 +58,70 @@ impl AgentServer for NativeAgentServer { let prompt_store = prompt_store.await?; log::debug!("Creating native agent entity"); - let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?; + let agent = + NativeAgent::new(project, history, templates, Some(prompt_store), fs, cx).await?; // Create the connection wrapper let connection = NativeAgentConnection(agent); - log::info!("NativeAgentServer connection established successfully"); + log::debug!("NativeAgentServer connection established successfully"); - Ok(Rc::new(connection) as Rc) + Ok(( + Rc::new(connection) as Rc, + None, + )) }) } + + fn into_any(self: Rc) -> Rc { + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use assistant_context::ContextStore; + use gpui::AppContext; + + agent_servers::e2e_tests::common_e2e_tests!( + async |fs, project, cx| { + let auth = cx.update(|cx| { + prompt_store::init(cx); + terminal::init(cx); + + let registry = language_model::LanguageModelRegistry::read_global(cx); + let auth = registry + .provider(&language_model::ANTHROPIC_PROVIDER_ID) + .unwrap() + .authenticate(cx); + + cx.spawn(async move |_| auth.await) + }); + + auth.await.unwrap(); + + cx.update(|cx| 
{ + let registry = language_model::LanguageModelRegistry::global(cx); + + registry.update(cx, |registry, cx| { + registry.select_default_model( + Some(&language_model::SelectedModel { + provider: language_model::ANTHROPIC_PROVIDER_ID, + model: language_model::LanguageModelId("claude-sonnet-4-latest".into()), + }), + cx, + ); + }); + }); + + let history = cx.update(|cx| { + let context_store = cx.new(move |cx| ContextStore::fake(project.clone(), cx)); + cx.new(move |cx| HistoryStore::new(context_store, cx)) + }); + + NativeAgentServer::new(fs.clone(), history) + }, + allow_option_id = "allow" + ); } diff --git a/crates/agent2/src/templates.rs b/crates/agent2/src/templates.rs index a63f0ad206308130712b9481cfd7231eb0fd2696..72a8f6633cb7bb926580dbb4f9e65ec032162d93 100644 --- a/crates/agent2/src/templates.rs +++ b/crates/agent2/src/templates.rs @@ -62,7 +62,7 @@ fn contains( handlebars::RenderError::new("contains: missing or invalid query parameter") })?; - if list.contains(&query) { + if list.contains(query) { out.write("true")?; } diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 637af73d1a3ab9cfe31225a4273d2d675b15e403..884580ed69009d168b3266870acf4f698a2f5450 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,46 +1,63 @@ use super::*; use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId}; -use action_log::ActionLog; use agent_client_protocol::{self as acp}; use agent_settings::AgentProfileId; use anyhow::Result; use client::{Client, UserStore}; +use cloud_llm_client::CompletionIntent; +use collections::IndexMap; +use context_server::{ContextServer, ContextServerCommand, ContextServerId}; use fs::{FakeFs, Fs}; -use futures::channel::mpsc::UnboundedReceiver; +use futures::{ + StreamExt, + channel::{ + mpsc::{self, UnboundedReceiver}, + oneshot, + }, +}; use gpui::{ App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient, }; use 
indoc::indoc; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, - LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason, - fake_provider::FakeLanguageModel, + LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest, + LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat, + LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel, +}; +use pretty_assertions::assert_eq; +use project::{ + Project, context_server_store::ContextServerStore, project_settings::ProjectSettings, }; -use project::Project; use prompt_store::ProjectContext; use reqwest_client::ReqwestClient; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::json; -use settings::SettingsStore; -use smol::stream::StreamExt; -use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration}; +use settings::{Settings, SettingsStore}; +use std::{path::Path, rc::Rc, sync::Arc, time::Duration}; use util::path; mod test_tools; use test_tools::*; #[gpui::test] -#[ignore = "can't run on CI yet"] async fn test_echo(cx: &mut TestAppContext) { - let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await; + let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); let events = thread .update(cx, |thread, cx| { thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx) }) - .collect() - .await; + .unwrap(); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Hello"); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + + let events = events.collect().await; thread.update(cx, |thread, _cx| { assert_eq!( thread.last_message().unwrap().to_markdown(), @@ -55,9 +72,9 @@ async fn test_echo(cx: &mut TestAppContext) { } #[gpui::test] -#[ignore = "can't run on CI yet"] async fn test_thinking(cx: &mut TestAppContext) { - let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await; + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); let events = thread .update(cx, |thread, cx| { @@ -72,8 +89,18 @@ async fn test_thinking(cx: &mut TestAppContext) { cx, ) }) - .collect() - .await; + .unwrap(); + cx.run_until_parked(); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking { + text: "Think".to_string(), + signature: None, + }); + fake_model.send_last_completion_stream_text_chunk("Hello"); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + + let events = events.collect().await; thread.update(cx, |thread, _cx| { assert_eq!( thread.last_message().unwrap().to_markdown(), @@ -98,11 +125,15 @@ async fn test_system_prompt(cx: &mut TestAppContext) { } = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); - project_context.borrow_mut().shell = "test-shell".into(); - thread.update(cx, |thread, _| thread.add_tool(EchoTool)); - thread.update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["abc"], cx) + 
project_context.update(cx, |project_context, _cx| { + project_context.shell = "test-shell".into() }); + thread.update(cx, |thread, _| thread.add_tool(EchoTool)); + thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["abc"], cx) + }) + .unwrap(); cx.run_until_parked(); let mut pending_completions = fake_model.pending_completions(); assert_eq!( @@ -130,7 +161,141 @@ async fn test_system_prompt(cx: &mut TestAppContext) { } #[gpui::test] -#[ignore = "can't run on CI yet"] +async fn test_prompt_caching(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + // Send initial user message and verify it's cached + thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Message 1"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + completion.messages[1..], + vec![LanguageModelRequestMessage { + role: Role::User, + content: vec!["Message 1".into()], + cache: true + }] + ); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text( + "Response to Message 1".into(), + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Send another user message and verify only the latest is cached + thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Message 2"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Message 1".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec!["Response to Message 1".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Message 2".into()], + cache: true + } + ] + ); + 
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text( + "Response to Message 2".into(), + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Simulate a tool call and verify that the latest tool result is cached + thread.update(cx, |thread, _| thread.add_tool(EchoTool)); + thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Use the echo tool"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + let tool_use = LanguageModelToolUse { + id: "tool_1".into(), + name: EchoTool::name().into(), + raw_input: json!({"text": "test"}).to_string(), + input: json!({"text": "test"}), + is_input_complete: true, + }; + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone())); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let completion = fake_model.pending_completions().pop().unwrap(); + let tool_result = LanguageModelToolResult { + tool_use_id: "tool_1".into(), + tool_name: EchoTool::name().into(), + is_error: false, + content: "test".into(), + output: Some("test".into()), + }; + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Message 1".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec!["Response to Message 1".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Message 2".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec!["Response to Message 2".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Use the echo tool".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use)], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: 
vec![MessageContent::ToolResult(tool_result)], + cache: true + } + ] + ); +} + +#[gpui::test] +#[cfg_attr(not(feature = "e2e"), ignore)] async fn test_basic_tool_calls(cx: &mut TestAppContext) { let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await; @@ -144,6 +309,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { cx, ) }) + .unwrap() .collect() .await; assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); @@ -151,7 +317,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { // Test a tool calls that's likely to complete *after* streaming stops. let events = thread .update(cx, |thread, cx| { - thread.remove_tool(&AgentTool::name(&EchoTool)); + thread.remove_tool(&EchoTool::name()); thread.add_tool(DelayTool); thread.send( UserMessageId::new(), @@ -162,6 +328,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { cx, ) }) + .unwrap() .collect() .await; assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); @@ -188,19 +355,21 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { } #[gpui::test] -#[ignore = "can't run on CI yet"] +#[cfg_attr(not(feature = "e2e"), ignore)] async fn test_streaming_tool_calls(cx: &mut TestAppContext) { let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await; // Test a tool call that's likely to complete *before* streaming stops. 
- let mut events = thread.update(cx, |thread, cx| { - thread.add_tool(WordListTool); - thread.send(UserMessageId::new(), ["Test the word_list tool."], cx) - }); + let mut events = thread + .update(cx, |thread, cx| { + thread.add_tool(WordListTool); + thread.send(UserMessageId::new(), ["Test the word_list tool."], cx) + }) + .unwrap(); let mut saw_partial_tool_use = false; while let Some(event) = events.next().await { - if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event { + if let Ok(ThreadEvent::ToolCall(tool_call)) = event { thread.update(cx, |thread, _cx| { // Look for a tool use in the thread's last message let message = thread.last_message().unwrap(); @@ -242,15 +411,17 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); - let mut events = thread.update(cx, |thread, cx| { - thread.add_tool(ToolRequiringPermission); - thread.send(UserMessageId::new(), ["abc"], cx) - }); + let mut events = thread + .update(cx, |thread, cx| { + thread.add_tool(ToolRequiringPermission); + thread.send(UserMessageId::new(), ["abc"], cx) + }) + .unwrap(); cx.run_until_parked(); fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { id: "tool_id_1".into(), - name: ToolRequiringPermission.name().into(), + name: ToolRequiringPermission::name().into(), raw_input: "{}".into(), input: json!({}), is_input_complete: true, @@ -259,7 +430,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { id: "tool_id_2".into(), - name: ToolRequiringPermission.name().into(), + name: ToolRequiringPermission::name().into(), raw_input: "{}".into(), input: json!({}), is_input_complete: true, @@ -290,17 +461,17 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { vec![ 
language_model::MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(), - tool_name: ToolRequiringPermission.name().into(), + tool_name: ToolRequiringPermission::name().into(), is_error: false, content: "Allowed".into(), output: Some("Allowed".into()) }), language_model::MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(), - tool_name: ToolRequiringPermission.name().into(), + tool_name: ToolRequiringPermission::name().into(), is_error: true, content: "Permission to run tool denied by user".into(), - output: None + output: Some("Permission to run tool denied by user".into()) }) ] ); @@ -309,7 +480,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { id: "tool_id_3".into(), - name: ToolRequiringPermission.name().into(), + name: ToolRequiringPermission::name().into(), raw_input: "{}".into(), input: json!({}), is_input_complete: true, @@ -331,7 +502,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { vec![language_model::MessageContent::ToolResult( LanguageModelToolResult { tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(), - tool_name: ToolRequiringPermission.name().into(), + tool_name: ToolRequiringPermission::name().into(), is_error: false, content: "Allowed".into(), output: Some("Allowed".into()) @@ -343,7 +514,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { id: "tool_id_4".into(), - name: ToolRequiringPermission.name().into(), + name: ToolRequiringPermission::name().into(), raw_input: "{}".into(), input: json!({}), is_input_complete: true, @@ -358,7 +529,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { vec![language_model::MessageContent::ToolResult( 
LanguageModelToolResult { tool_use_id: "tool_id_4".into(), - tool_name: ToolRequiringPermission.name().into(), + tool_name: ToolRequiringPermission::name().into(), is_error: false, content: "Allowed".into(), output: Some("Allowed".into()) @@ -372,9 +543,11 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); - let mut events = thread.update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["abc"], cx) - }); + let mut events = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["abc"], cx) + }) + .unwrap(); cx.run_until_parked(); fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { @@ -394,16 +567,197 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) { assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed)); } -async fn expect_tool_call( - events: &mut UnboundedReceiver>, -) -> acp::ToolCall { +#[gpui::test] +async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let events = thread + .update(cx, |thread, cx| { + thread.add_tool(EchoTool); + thread.send(UserMessageId::new(), ["abc"], cx) + }) + .unwrap(); + cx.run_until_parked(); + let tool_use = LanguageModelToolUse { + id: "tool_id_1".into(), + name: EchoTool::name().into(), + raw_input: "{}".into(), + input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(), + is_input_complete: true, + }; + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone())); + fake_model.end_last_completion_stream(); + + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + let tool_result = LanguageModelToolResult { + tool_use_id: "tool_id_1".into(), + tool_name: EchoTool::name().into(), + is_error: false, + content: "def".into(), + output: Some("def".into()), + }; + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["abc".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use.clone())], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::ToolResult(tool_result.clone())], + cache: true + }, + ] + ); + + // Simulate reaching tool use limit. 
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate( + cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached, + )); + fake_model.end_last_completion_stream(); + let last_event = events.collect::>().await.pop().unwrap(); + assert!( + last_event + .unwrap_err() + .is::() + ); + + let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap(); + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["abc".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use)], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::ToolResult(tool_result)], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Continue where you left off".into()], + cache: true + } + ] + ); + + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into())); + fake_model.end_last_completion_stream(); + events.collect::>().await; + thread.read_with(cx, |thread, _cx| { + assert_eq!( + thread.last_message().unwrap().to_markdown(), + indoc! {" + ## Assistant + + Done + "} + ) + }); +} + +#[gpui::test] +async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let events = thread + .update(cx, |thread, cx| { + thread.add_tool(EchoTool); + thread.send(UserMessageId::new(), ["abc"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + let tool_use = LanguageModelToolUse { + id: "tool_id_1".into(), + name: EchoTool::name().into(), + raw_input: "{}".into(), + input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(), + is_input_complete: true, + }; + let tool_result = LanguageModelToolResult { + tool_use_id: "tool_id_1".into(), + tool_name: EchoTool::name().into(), + is_error: false, + content: "def".into(), + output: Some("def".into()), + }; + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone())); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate( + cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached, + )); + fake_model.end_last_completion_stream(); + let last_event = events.collect::>().await.pop().unwrap(); + assert!( + last_event + .unwrap_err() + .is::() + ); + + thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), vec!["ghi"], cx) + }) + .unwrap(); + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["abc".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use)], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::ToolResult(tool_result)], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec!["ghi".into()], + cache: true + } + ] + ); +} + +async fn expect_tool_call(events: &mut UnboundedReceiver>) -> acp::ToolCall { let event = events .next() .await .expect("no tool call authorization event 
received") .unwrap(); match event { - AgentResponseEvent::ToolCall(tool_call) => return tool_call, + ThreadEvent::ToolCall(tool_call) => tool_call, event => { panic!("Unexpected event {event:?}"); } @@ -411,7 +765,7 @@ async fn expect_tool_call( } async fn expect_tool_call_update_fields( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> acp::ToolCallUpdate { let event = events .next() @@ -419,9 +773,7 @@ async fn expect_tool_call_update_fields( .expect("no tool call authorization event received") .unwrap(); match event { - AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => { - return update; - } + ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update, event => { panic!("Unexpected event {event:?}"); } @@ -429,7 +781,7 @@ async fn expect_tool_call_update_fields( } async fn next_tool_call_authorization( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> ToolCallAuthorization { loop { let event = events @@ -437,7 +789,7 @@ async fn next_tool_call_authorization( .await .expect("no tool call authorization event received") .unwrap(); - if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event { + if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event { let permission_kinds = tool_call_authorization .options .iter() @@ -457,7 +809,7 @@ async fn next_tool_call_authorization( } #[gpui::test] -#[ignore = "can't run on CI yet"] +#[cfg_attr(not(feature = "e2e"), ignore)] async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { let ThreadTest { thread, .. 
} = setup(cx, TestModel::Sonnet4).await; @@ -475,6 +827,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { cx, ) }) + .unwrap() .collect() .await; @@ -522,14 +875,14 @@ async fn test_profiles(cx: &mut TestAppContext) { "test-1": { "name": "Test Profile 1", "tools": { - EchoTool.name(): true, - DelayTool.name(): true, + EchoTool::name(): true, + DelayTool::name(): true, } }, "test-2": { "name": "Test Profile 2", "tools": { - InfiniteTool.name(): true, + InfiniteTool::name(): true, } } } @@ -542,10 +895,12 @@ async fn test_profiles(cx: &mut TestAppContext) { cx.run_until_parked(); // Test that test-1 profile (default) has echo and delay tools - thread.update(cx, |thread, cx| { - thread.set_profile(AgentProfileId("test-1".into())); - thread.send(UserMessageId::new(), ["test"], cx); - }); + thread + .update(cx, |thread, cx| { + thread.set_profile(AgentProfileId("test-1".into())); + thread.send(UserMessageId::new(), ["test"], cx) + }) + .unwrap(); cx.run_until_parked(); let mut pending_completions = fake_model.pending_completions(); @@ -556,14 +911,16 @@ async fn test_profiles(cx: &mut TestAppContext) { .iter() .map(|tool| tool.name.clone()) .collect(); - assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]); + assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]); fake_model.end_last_completion_stream(); // Switch to test-2 profile, and verify that it has only the infinite tool. 
- thread.update(cx, |thread, cx| { - thread.set_profile(AgentProfileId("test-2".into())); - thread.send(UserMessageId::new(), ["test2"], cx) - }); + thread + .update(cx, |thread, cx| { + thread.set_profile(AgentProfileId("test-2".into())); + thread.send(UserMessageId::new(), ["test2"], cx) + }) + .unwrap(); cx.run_until_parked(); let mut pending_completions = fake_model.pending_completions(); assert_eq!(pending_completions.len(), 1); @@ -573,60 +930,399 @@ async fn test_profiles(cx: &mut TestAppContext) { .iter() .map(|tool| tool.name.clone()) .collect(); - assert_eq!(tool_names, vec![InfiniteTool.name()]); + assert_eq!(tool_names, vec![InfiniteTool::name()]); } #[gpui::test] -#[ignore = "can't run on CI yet"] -async fn test_cancellation(cx: &mut TestAppContext) { - let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await; - - let mut events = thread.update(cx, |thread, cx| { - thread.add_tool(InfiniteTool); - thread.add_tool(EchoTool); - thread.send( - UserMessageId::new(), - ["Call the echo tool, then call the infinite tool, then explain their output"], - cx, - ) - }); +async fn test_mcp_tools(cx: &mut TestAppContext) { + let ThreadTest { + model, + thread, + context_server_store, + fs, + .. + } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); - // Wait until both tools are called. - let mut expected_tools = vec!["Echo", "Infinite Tool"]; - let mut echo_id = None; - let mut echo_completed = false; - while let Some(event) = events.next().await { - match event.unwrap() { - AgentResponseEvent::ToolCall(tool_call) => { - assert_eq!(tool_call.title, expected_tools.remove(0)); - if tool_call.title == "Echo" { - echo_id = Some(tool_call.id); + // Override profiles and wait for settings to be loaded. 
+ fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "always_allow_tool_actions": true, + "profiles": { + "test": { + "name": "Test Profile", + "enable_all_context_servers": true, + "tools": { + EchoTool::name(): true, + } + }, } } - AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( - acp::ToolCallUpdate { - id, - fields: - acp::ToolCallUpdateFields { - status: Some(acp::ToolCallStatus::Completed), - .. - }, - }, - )) if Some(&id) == echo_id.as_ref() => { - echo_completed = true; - } - _ => {} - } - - if expected_tools.is_empty() && echo_completed { - break; + }) + .to_string() + .into_bytes(), + ) + .await; + cx.run_until_parked(); + thread.update(cx, |thread, _| { + thread.set_profile(AgentProfileId("test".into())) + }); + + let mut mcp_tool_calls = setup_context_server( + "test_server", + vec![context_server::types::Tool { + name: "echo".into(), + description: None, + input_schema: serde_json::to_value( + EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema), + ) + .unwrap(), + output_schema: None, + annotations: None, + }], + &context_server_store, + cx, + ); + + let events = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hey"], cx).unwrap() + }); + cx.run_until_parked(); + + // Simulate the model calling the MCP tool. 
+ let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!(tool_names_for_completion(&completion), vec!["echo"]); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_1".into(), + name: "echo".into(), + raw_input: json!({"text": "test"}).to_string(), + input: json!({"text": "test"}), + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap(); + assert_eq!(tool_call_params.name, "echo"); + assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"}))); + tool_call_response + .send(context_server::types::CallToolResponse { + content: vec![context_server::types::ToolResponseContent::Text { + text: "test".into(), + }], + is_error: None, + meta: None, + structured_content: None, + }) + .unwrap(); + cx.run_until_parked(); + + assert_eq!(tool_names_for_completion(&completion), vec!["echo"]); + fake_model.send_last_completion_stream_text_chunk("Done!"); + fake_model.end_last_completion_stream(); + events.collect::>().await; + + // Send again after adding the echo tool, ensuring the name collision is resolved. 
+ let events = thread.update(cx, |thread, cx| { + thread.add_tool(EchoTool); + thread.send(UserMessageId::new(), ["Go"], cx).unwrap() + }); + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + tool_names_for_completion(&completion), + vec!["echo", "test_server_echo"] + ); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_2".into(), + name: "test_server_echo".into(), + raw_input: json!({"text": "mcp"}).to_string(), + input: json!({"text": "mcp"}), + is_input_complete: true, + }, + )); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_3".into(), + name: "echo".into(), + raw_input: json!({"text": "native"}).to_string(), + input: json!({"text": "native"}), + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap(); + assert_eq!(tool_call_params.name, "echo"); + assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"}))); + tool_call_response + .send(context_server::types::CallToolResponse { + content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }], + is_error: None, + meta: None, + structured_content: None, + }) + .unwrap(); + cx.run_until_parked(); + + // Ensure the tool results were inserted with the correct names. 
+ let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + completion.messages.last().unwrap().content, + vec![ + MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: "tool_3".into(), + tool_name: "echo".into(), + is_error: false, + content: "native".into(), + output: Some("native".into()), + },), + MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: "tool_2".into(), + tool_name: "test_server_echo".into(), + is_error: false, + content: "mcp".into(), + output: Some("mcp".into()), + },), + ] + ); + fake_model.end_last_completion_stream(); + events.collect::>().await; +} + +#[gpui::test] +async fn test_mcp_tool_truncation(cx: &mut TestAppContext) { + let ThreadTest { + model, + thread, + context_server_store, + fs, + .. + } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + // Set up a profile with all tools enabled + fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "profiles": { + "test": { + "name": "Test Profile", + "enable_all_context_servers": true, + "tools": { + EchoTool::name(): true, + DelayTool::name(): true, + WordListTool::name(): true, + ToolRequiringPermission::name(): true, + InfiniteTool::name(): true, + } + }, + } + } + }) + .to_string() + .into_bytes(), + ) + .await; + cx.run_until_parked(); + + thread.update(cx, |thread, _| { + thread.set_profile(AgentProfileId("test".into())); + thread.add_tool(EchoTool); + thread.add_tool(DelayTool); + thread.add_tool(WordListTool); + thread.add_tool(ToolRequiringPermission); + thread.add_tool(InfiniteTool); + }); + + // Set up multiple context servers with some overlapping tool names + let _server1_calls = setup_context_server( + "xxx", + vec![ + context_server::types::Tool { + name: "echo".into(), // Conflicts with native EchoTool + description: None, + input_schema: serde_json::to_value( + EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema), + ) + .unwrap(), + output_schema: None, + annotations: None, 
+ }, + context_server::types::Tool { + name: "unique_tool_1".into(), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }, + ], + &context_server_store, + cx, + ); + + let _server2_calls = setup_context_server( + "yyy", + vec![ + context_server::types::Tool { + name: "echo".into(), // Also conflicts with native EchoTool + description: None, + input_schema: serde_json::to_value( + EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema), + ) + .unwrap(), + output_schema: None, + annotations: None, + }, + context_server::types::Tool { + name: "unique_tool_2".into(), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }, + context_server::types::Tool { + name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }, + context_server::types::Tool { + name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }, + ], + &context_server_store, + cx, + ); + let _server3_calls = setup_context_server( + "zzz", + vec![ + context_server::types::Tool { + name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }, + context_server::types::Tool { + name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }, + context_server::types::Tool { + name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }, + ], + &context_server_store, + cx, + ); + + thread + 
.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Go"], cx) + }) + .unwrap(); + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + tool_names_for_completion(&completion), + vec![ + "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", + "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc", + "delay", + "echo", + "infinite", + "tool_requiring_permission", + "unique_tool_1", + "unique_tool_2", + "word_list", + "xxx_echo", + "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + "yyy_echo", + "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + ] + ); +} + +#[gpui::test] +#[cfg_attr(not(feature = "e2e"), ignore)] +async fn test_cancellation(cx: &mut TestAppContext) { + let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await; + + let mut events = thread + .update(cx, |thread, cx| { + thread.add_tool(InfiniteTool); + thread.add_tool(EchoTool); + thread.send( + UserMessageId::new(), + ["Call the echo tool, then call the infinite tool, then explain their output"], + cx, + ) + }) + .unwrap(); + + // Wait until both tools are called. + let mut expected_tools = vec!["Echo", "Infinite Tool"]; + let mut echo_id = None; + let mut echo_completed = false; + while let Some(event) = events.next().await { + match event.unwrap() { + ThreadEvent::ToolCall(tool_call) => { + assert_eq!(tool_call.title, expected_tools.remove(0)); + if tool_call.title == "Echo" { + echo_id = Some(tool_call.id); + } + } + ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( + acp::ToolCallUpdate { + id, + fields: + acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + .. 
+ }, + }, + )) if Some(&id) == echo_id.as_ref() => { + echo_completed = true; + } + _ => {} + } + + if expected_tools.is_empty() && echo_completed { + break; } } // Cancel the current send and ensure that the event stream is closed, even // if one of the tools is still running. - thread.update(cx, |thread, _cx| thread.cancel()); - events.collect::>().await; + thread.update(cx, |thread, cx| thread.cancel(cx)); + let events = events.collect::>().await; + let last_event = events.last(); + assert!( + matches!( + last_event, + Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled))) + ), + "unexpected event {last_event:?}" + ); // Ensure we can still send a new message after cancellation. let events = thread @@ -637,6 +1333,7 @@ async fn test_cancellation(cx: &mut TestAppContext) { cx, ) }) + .unwrap() .collect::>() .await; thread.update(cx, |thread, _cx| { @@ -650,14 +1347,80 @@ async fn test_cancellation(cx: &mut TestAppContext) { assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); } +#[gpui::test] +async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let events_1 = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello 1"], cx) + }) + .unwrap(); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Hey 1!"); + cx.run_until_parked(); + + let events_2 = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello 2"], cx) + }) + .unwrap(); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Hey 2!"); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + + let events_1 = events_1.collect::>().await; + assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]); + let events_2 = events_2.collect::>().await; + assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]); +} + +#[gpui::test] +async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let events_1 = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello 1"], cx) + }) + .unwrap(); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Hey 1!"); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + let events_1 = events_1.collect::>().await; + + let events_2 = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello 2"], cx) + }) + .unwrap(); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Hey 2!"); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + let events_2 = events_2.collect::>().await; + + assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]); + assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]); +} + #[gpui::test] async fn test_refusal(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); - let events = thread.update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Hello"], cx) - }); + let events = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello"], cx) + }) + .unwrap(); cx.run_until_parked(); thread.read_with(cx, |thread, _| { assert_eq!( @@ -698,14 +1461,16 @@ async fn test_refusal(cx: &mut TestAppContext) { } #[gpui::test] -async fn test_truncate(cx: &mut TestAppContext) { +async fn test_truncate_first_message(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); let message_id = UserMessageId::new(); - thread.update(cx, |thread, cx| { - thread.send(message_id.clone(), ["Hello"], cx) - }); + thread + .update(cx, |thread, cx| { + thread.send(message_id.clone(), ["Hello"], cx) + }) + .unwrap(); cx.run_until_parked(); thread.read_with(cx, |thread, _| { assert_eq!( @@ -716,64 +1481,329 @@ async fn test_truncate(cx: &mut TestAppContext) { Hello "} ); + assert_eq!(thread.latest_token_usage(), None); }); fake_model.send_last_completion_stream_text_chunk("Hey!"); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( + language_model::TokenUsage { + input_tokens: 32_000, + output_tokens: 16_000, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + )); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hello + + ## Assistant + + Hey! + "} + ); + assert_eq!( + thread.latest_token_usage(), + Some(acp_thread::TokenUsage { + used_tokens: 32_000 + 16_000, + max_tokens: 1_000_000, + }) + ); + }); + + thread + .update(cx, |thread, cx| thread.truncate(message_id, cx)) + .unwrap(); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!(thread.to_markdown(), ""); + assert_eq!(thread.latest_token_usage(), None); + }); + + // Ensure we can still send a new message after truncation. + thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hi"], cx) + }) + .unwrap(); + thread.update(cx, |thread, _cx| { + assert_eq!( + thread.to_markdown(), + indoc! 
{" + ## User + + Hi + "} + ); + }); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Ahoy!"); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( + language_model::TokenUsage { + input_tokens: 40_000, + output_tokens: 20_000, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + )); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hi + + ## Assistant + + Ahoy! + "} + ); + + assert_eq!( + thread.latest_token_usage(), + Some(acp_thread::TokenUsage { + used_tokens: 40_000 + 20_000, + max_tokens: 1_000_000, + }) + ); + }); +} + +#[gpui::test] +async fn test_truncate_second_message(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Message 1"], cx) + }) + .unwrap(); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Message 1 response"); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( + language_model::TokenUsage { + input_tokens: 32_000, + output_tokens: 16_000, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let assert_first_message_state = |cx: &mut TestAppContext| { + thread.clone().read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! 
{" + ## User + + Message 1 + + ## Assistant + + Message 1 response + "} + ); + + assert_eq!( + thread.latest_token_usage(), + Some(acp_thread::TokenUsage { + used_tokens: 32_000 + 16_000, + max_tokens: 1_000_000, + }) + ); + }); + }; + + assert_first_message_state(cx); + + let second_message_id = UserMessageId::new(); + thread + .update(cx, |thread, cx| { + thread.send(second_message_id.clone(), ["Message 2"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_text_chunk("Message 2 response"); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( + language_model::TokenUsage { + input_tokens: 40_000, + output_tokens: 20_000, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + )); + fake_model.end_last_completion_stream(); cx.run_until_parked(); + thread.read_with(cx, |thread, _| { assert_eq!( thread.to_markdown(), indoc! {" ## User - Hello + Message 1 ## Assistant - Hey! + Message 1 response + + ## User + + Message 2 + + ## Assistant + + Message 2 response "} ); + + assert_eq!( + thread.latest_token_usage(), + Some(acp_thread::TokenUsage { + used_tokens: 40_000 + 20_000, + max_tokens: 1_000_000, + }) + ); }); thread - .update(cx, |thread, _cx| thread.truncate(message_id)) + .update(cx, |thread, cx| thread.truncate(second_message_id, cx)) .unwrap(); cx.run_until_parked(); - thread.read_with(cx, |thread, _| { - assert_eq!(thread.to_markdown(), ""); - }); - // Ensure we can still send a new message after truncation. + assert_first_message_state(cx); +} + +#[gpui::test] +async fn test_title_generation(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let summary_model = Arc::new(FakeLanguageModel::default()); thread.update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Hi"], cx) + thread.set_summarization_model(Some(summary_model.clone()), cx) }); - thread.update(cx, |thread, _cx| { - assert_eq!( - thread.to_markdown(), - indoc! {" - ## User - Hi - "} - ); - }); + let send = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello"], cx) + }) + .unwrap(); cx.run_until_parked(); - fake_model.send_last_completion_stream_text_chunk("Ahoy!"); + + fake_model.send_last_completion_stream_text_chunk("Hey!"); + fake_model.end_last_completion_stream(); cx.run_until_parked(); - thread.read_with(cx, |thread, _| { - assert_eq!( - thread.to_markdown(), - indoc! {" - ## User + thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread")); + + // Ensure the summary model has been invoked to generate a title. + summary_model.send_last_completion_stream_text_chunk("Hello "); + summary_model.send_last_completion_stream_text_chunk("world\nG"); + summary_model.send_last_completion_stream_text_chunk("oodnight Moon"); + summary_model.end_last_completion_stream(); + send.collect::>().await; + cx.run_until_parked(); + thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world")); - Hi + // Send another message, ensuring no title is generated this time. 
+ let send = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello again"], cx) + }) + .unwrap(); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Hey again!"); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + assert_eq!(summary_model.pending_completions(), Vec::new()); + send.collect::>().await; + thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world")); +} - ## Assistant +#[gpui::test] +async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); - Ahoy! - "} - ); - }); + let _events = thread + .update(cx, |thread, cx| { + thread.add_tool(ToolRequiringPermission); + thread.add_tool(EchoTool); + thread.send(UserMessageId::new(), ["Hey!"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + let permission_tool_use = LanguageModelToolUse { + id: "tool_id_1".into(), + name: ToolRequiringPermission::name().into(), + raw_input: "{}".into(), + input: json!({}), + is_input_complete: true, + }; + let echo_tool_use = LanguageModelToolUse { + id: "tool_id_2".into(), + name: EchoTool::name().into(), + raw_input: json!({"text": "test"}).to_string(), + input: json!({"text": "test"}), + is_input_complete: true, + }; + fake_model.send_last_completion_stream_text_chunk("Hi!"); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + permission_tool_use, + )); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + echo_tool_use.clone(), + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Ensure pending tools are skipped when building a request. 
+ let request = thread + .read_with(cx, |thread, cx| { + thread.build_completion_request(CompletionIntent::EditFile, cx) + }) + .unwrap(); + assert_eq!( + request.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Hey!".into()], + cache: true + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![ + MessageContent::Text("Hi!".into()), + MessageContent::ToolUse(echo_tool_use.clone()) + ], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: echo_tool_use.id.clone(), + tool_name: echo_tool_use.name, + is_error: false, + content: "test".into(), + output: Some("test".into()) + })], + cache: false + }, + ], + ); } #[gpui::test] @@ -791,7 +1821,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) { let client = Client::new(clock, http_client, cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); language_model::init(client.clone(), cx); - language_models::init(user_store.clone(), client.clone(), cx); + language_models::init(user_store, client.clone(), cx); Project::init_settings(cx); LanguageModelRegistry::test(cx); agent_settings::init(cx); @@ -803,10 +1833,13 @@ async fn test_agent_connection(cx: &mut TestAppContext) { fake_fs.insert_tree(path!("/test"), json!({})).await; let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await; let cwd = Path::new("/test"); + let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); // Create agent and connection let agent = NativeAgent::new( project.clone(), + history_store, templates.clone(), None, fake_fs.clone(), @@ -841,7 +1874,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) { // Create a thread using new_thread let connection_rc = Rc::new(connection.clone()); let acp_thread = cx - .update(|cx| 
connection_rc.new_thread(project, cwd, &mut cx.to_async())) + .update(|cx| connection_rc.new_thread(project, cwd, cx)) .await .expect("new_thread should succeed"); @@ -912,9 +1945,11 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool)); let fake_model = model.as_fake(); - let mut events = thread.update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Think"], cx) - }); + let mut events = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Think"], cx) + }) + .unwrap(); cx.run_until_parked(); // Simulate streaming partial input. @@ -922,7 +1957,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { id: "1".into(), - name: ThinkingTool.name().into(), + name: ThinkingTool::name().into(), raw_input: input.to_string(), input, is_input_complete: false, @@ -1006,14 +2041,252 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { ); } +#[gpui::test] +async fn test_send_no_retry_on_success(cx: &mut TestAppContext) { + let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let mut events = thread + .update(cx, |thread, cx| { + thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); + thread.send(UserMessageId::new(), ["Hello!"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_text_chunk("Hey!"); + fake_model.end_last_completion_stream(); + + let mut retry_events = Vec::new(); + while let Some(Ok(event)) = events.next().await { + match event { + ThreadEvent::Retry(retry_status) => { + retry_events.push(retry_status); + } + ThreadEvent::Stop(..) => break, + _ => {} + } + } + + assert_eq!(retry_events.len(), 0); + thread.read_with(cx, |thread, _cx| { + assert_eq!( + thread.to_markdown(), + indoc! 
{" + ## User + + Hello! + + ## Assistant + + Hey! + "} + ) + }); +} + +#[gpui::test] +async fn test_send_retry_on_error(cx: &mut TestAppContext) { + let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let mut events = thread + .update(cx, |thread, cx| { + thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); + thread.send(UserMessageId::new(), ["Hello!"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_text_chunk("Hey,"); + fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded { + provider: LanguageModelProviderName::new("Anthropic"), + retry_after: Some(Duration::from_secs(3)), + }); + fake_model.end_last_completion_stream(); + + cx.executor().advance_clock(Duration::from_secs(3)); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_text_chunk("there!"); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let mut retry_events = Vec::new(); + while let Some(Ok(event)) = events.next().await { + match event { + ThreadEvent::Retry(retry_status) => { + retry_events.push(retry_status); + } + ThreadEvent::Stop(..) => break, + _ => {} + } + } + + assert_eq!(retry_events.len(), 1); + assert!(matches!( + retry_events[0], + acp_thread::RetryStatus { attempt: 1, .. } + )); + thread.read_with(cx, |thread, _cx| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hello! + + ## Assistant + + Hey, + + [resume] + + ## Assistant + + there! + "} + ) + }); +} + +#[gpui::test] +async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) { + let ThreadTest { thread, model, .. 
} = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let events = thread + .update(cx, |thread, cx| { + thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); + thread.add_tool(EchoTool); + thread.send(UserMessageId::new(), ["Call the echo tool!"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + let tool_use_1 = LanguageModelToolUse { + id: "tool_1".into(), + name: EchoTool::name().into(), + raw_input: json!({"text": "test"}).to_string(), + input: json!({"text": "test"}), + is_input_complete: true, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + tool_use_1.clone(), + )); + fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded { + provider: LanguageModelProviderName::new("Anthropic"), + retry_after: Some(Duration::from_secs(3)), + }); + fake_model.end_last_completion_stream(); + + cx.executor().advance_clock(Duration::from_secs(3)); + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Call the echo tool!".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![language_model::MessageContent::ToolResult( + LanguageModelToolResult { + tool_use_id: tool_use_1.id.clone(), + tool_name: tool_use_1.name.clone(), + is_error: false, + content: "test".into(), + output: Some("test".into()) + } + )], + cache: true + }, + ] + ); + + fake_model.send_last_completion_stream_text_chunk("Done"); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + events.collect::>().await; + thread.read_with(cx, |thread, _cx| { + assert_eq!( + thread.last_message(), + Some(Message::Agent(AgentMessage { + content: 
vec![AgentMessageContent::Text("Done".into())], + tool_results: IndexMap::default() + })) + ); + }) +} + +#[gpui::test] +async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) { + let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let mut events = thread + .update(cx, |thread, cx| { + thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); + thread.send(UserMessageId::new(), ["Hello!"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 { + fake_model.send_last_completion_stream_error( + LanguageModelCompletionError::ServerOverloaded { + provider: LanguageModelProviderName::new("Anthropic"), + retry_after: Some(Duration::from_secs(3)), + }, + ); + fake_model.end_last_completion_stream(); + cx.executor().advance_clock(Duration::from_secs(3)); + cx.run_until_parked(); + } + + let mut errors = Vec::new(); + let mut retry_events = Vec::new(); + while let Some(event) = events.next().await { + match event { + Ok(ThreadEvent::Retry(retry_status)) => { + retry_events.push(retry_status); + } + Ok(ThreadEvent::Stop(..)) => break, + Err(error) => errors.push(error), + _ => {} + } + } + + assert_eq!( + retry_events.len(), + crate::thread::MAX_RETRY_ATTEMPTS as usize + ); + for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize { + assert_eq!(retry_events[i].attempt, i + 1); + } + assert_eq!(errors.len(), 1); + let error = errors[0] + .downcast_ref::() + .unwrap(); + assert!(matches!( + error, + LanguageModelCompletionError::ServerOverloaded { .. 
} + )); +} + /// Filters out the stop events for asserting against in tests -fn stop_events( - result_events: Vec>, -) -> Vec { +fn stop_events(result_events: Vec>) -> Vec { result_events .into_iter() .filter_map(|event| match event.unwrap() { - AgentResponseEvent::Stop(stop_reason) => Some(stop_reason), + ThreadEvent::Stop(stop_reason) => Some(stop_reason), _ => None, }) .collect() @@ -1022,13 +2295,13 @@ fn stop_events( struct ThreadTest { model: Arc, thread: Entity, - project_context: Rc>, + project_context: Entity, + context_server_store: Entity, fs: Arc, } enum TestModel { Sonnet4, - Sonnet4Thinking, Fake, } @@ -1036,7 +2309,6 @@ impl TestModel { fn id(&self) -> LanguageModelId { match self { TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()), - TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()), TestModel::Fake => unreachable!(), } } @@ -1058,11 +2330,12 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { "test-profile": { "name": "Test Profile", "tools": { - EchoTool.name(): true, - DelayTool.name(): true, - WordListTool.name(): true, - ToolRequiringPermission.name(): true, - InfiniteTool.name(): true, + EchoTool::name(): true, + DelayTool::name(): true, + WordListTool::name(): true, + ToolRequiringPermission::name(): true, + InfiniteTool::name(): true, + ThinkingTool::name(): true, } } } @@ -1077,15 +2350,20 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { settings::init(cx); Project::init_settings(cx); agent_settings::init(cx); - gpui_tokio::init(cx); - let http_client = ReqwestClient::user_agent("agent tests").unwrap(); - cx.set_http_client(Arc::new(http_client)); - client::init_settings(cx); - let client = Client::production(cx); - let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(client.clone(), cx); - language_models::init(user_store.clone(), client.clone(), cx); + match model { + TestModel::Fake => {} + 
TestModel::Sonnet4 => { + gpui_tokio::init(cx); + let http_client = ReqwestClient::user_agent("agent tests").unwrap(); + cx.set_http_client(Arc::new(http_client)); + client::init_settings(cx); + let client = Client::production(cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store, client.clone(), cx); + } + }; watch_settings(fs.clone(), cx); }); @@ -1118,18 +2396,17 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { }) .await; - let project_context = Rc::new(RefCell::new(ProjectContext::default())); + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let action_log = cx.new(|_| ActionLog::new(project.clone())); + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); let thread = cx.new(|cx| { Thread::new( project, project_context.clone(), context_server_registry, - action_log, templates, - model.clone(), + Some(model.clone()), cx, ) }); @@ -1137,6 +2414,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { model, thread, project_context, + context_server_store, fs, } } @@ -1171,3 +2449,90 @@ fn watch_settings(fs: Arc, cx: &mut App) { }) .detach(); } + +fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec { + completion + .tools + .iter() + .map(|tool| tool.name.clone()) + .collect() +} + +fn setup_context_server( + name: &'static str, + tools: Vec, + context_server_store: &Entity, + cx: &mut TestAppContext, +) -> mpsc::UnboundedReceiver<( + context_server::types::CallToolParams, + oneshot::Sender, +)> { + cx.update(|cx| { + let mut settings = ProjectSettings::get_global(cx).clone(); + settings.context_servers.insert( + name.into(), + 
project::project_settings::ContextServerSettings::Custom { + enabled: true, + command: ContextServerCommand { + path: "somebinary".into(), + args: Vec::new(), + env: None, + timeout: None, + }, + }, + ); + ProjectSettings::override_global(settings, cx); + }); + + let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded(); + let fake_transport = context_server::test::create_fake_transport(name, cx.executor()) + .on_request::(move |_params| async move { + context_server::types::InitializeResponse { + protocol_version: context_server::types::ProtocolVersion( + context_server::types::LATEST_PROTOCOL_VERSION.to_string(), + ), + server_info: context_server::types::Implementation { + name: name.into(), + version: "1.0.0".to_string(), + }, + capabilities: context_server::types::ServerCapabilities { + tools: Some(context_server::types::ToolsCapabilities { + list_changed: Some(true), + }), + ..Default::default() + }, + meta: None, + } + }) + .on_request::(move |_params| { + let tools = tools.clone(); + async move { + context_server::types::ListToolsResponse { + tools, + next_cursor: None, + meta: None, + } + } + }) + .on_request::(move |params| { + let mcp_tool_calls_tx = mcp_tool_calls_tx.clone(); + async move { + let (response_tx, response_rx) = oneshot::channel(); + mcp_tool_calls_tx + .unbounded_send((params, response_tx)) + .unwrap(); + response_rx.await.unwrap() + } + }); + context_server_store.update(cx, |store, cx| { + store.start_server( + Arc::new(ContextServer::new( + ContextServerId(name.into()), + Arc::new(fake_transport), + )), + cx, + ); + }); + cx.run_until_parked(); + mcp_tool_calls_rx +} diff --git a/crates/agent2/src/tests/test_tools.rs b/crates/agent2/src/tests/test_tools.rs index 7c7b81f52fce95c9af181cd7fa03579160021518..2275d23c2f8a924efce2d2d4d8bcf6a6f3a59def 100644 --- a/crates/agent2/src/tests/test_tools.rs +++ b/crates/agent2/src/tests/test_tools.rs @@ -7,7 +7,7 @@ use std::future; #[derive(JsonSchema, Serialize, Deserialize)] pub struct 
EchoToolInput { /// The text to echo. - text: String, + pub text: String, } pub struct EchoTool; @@ -16,15 +16,19 @@ impl AgentTool for EchoTool { type Input = EchoToolInput; type Output = String; - fn name(&self) -> SharedString { - "echo".into() + fn name() -> &'static str { + "echo" } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Other } - fn initial_title(&self, _input: Result) -> SharedString { + fn initial_title( + &self, + _input: Result, + _cx: &mut App, + ) -> SharedString { "Echo".into() } @@ -51,11 +55,15 @@ impl AgentTool for DelayTool { type Input = DelayToolInput; type Output = String; - fn name(&self) -> SharedString { - "delay".into() + fn name() -> &'static str { + "delay" } - fn initial_title(&self, input: Result) -> SharedString { + fn initial_title( + &self, + input: Result, + _cx: &mut App, + ) -> SharedString { if let Ok(input) = input { format!("Delay {}ms", input.ms).into() } else { @@ -63,7 +71,7 @@ impl AgentTool for DelayTool { } } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Other } @@ -92,15 +100,19 @@ impl AgentTool for ToolRequiringPermission { type Input = ToolRequiringPermissionInput; type Output = String; - fn name(&self) -> SharedString { - "tool_requiring_permission".into() + fn name() -> &'static str { + "tool_requiring_permission" } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Other } - fn initial_title(&self, _input: Result) -> SharedString { + fn initial_title( + &self, + _input: Result, + _cx: &mut App, + ) -> SharedString { "This tool requires permission".into() } @@ -127,15 +139,19 @@ impl AgentTool for InfiniteTool { type Input = InfiniteToolInput; type Output = String; - fn name(&self) -> SharedString { - "infinite".into() + fn name() -> &'static str { + "infinite" } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Other } - fn initial_title(&self, _input: Result) -> SharedString 
{ + fn initial_title( + &self, + _input: Result, + _cx: &mut App, + ) -> SharedString { "Infinite Tool".into() } @@ -178,15 +194,19 @@ impl AgentTool for WordListTool { type Input = WordListInput; type Output = String; - fn name(&self) -> SharedString { - "word_list".into() + fn name() -> &'static str { + "word_list" } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Other } - fn initial_title(&self, _input: Result) -> SharedString { + fn initial_title( + &self, + _input: Result, + _cx: &mut App, + ) -> SharedString { "List of random words".into() } diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 4156ec44d2b24ebab3e20f2bab9330e7ca13f53a..20c4cd07533b7cf9bd1dd00e666bbb66552db9d7 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,38 +1,103 @@ -use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates}; +use crate::{ + ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread, + DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, + ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate, + Template, Templates, TerminalTool, ThinkingTool, WebSearchTool, +}; use acp_thread::{MentionUri, UserMessageId}; use action_log::ActionLog; +use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot}; use agent_client_protocol as acp; -use agent_settings::{AgentProfileId, AgentSettings}; +use agent_settings::{ + AgentProfileId, AgentProfileSettings, AgentSettings, CompletionMode, + SUMMARIZE_THREAD_DETAILED_PROMPT, SUMMARIZE_THREAD_PROMPT, +}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; -use cloud_llm_client::{CompletionIntent, CompletionMode}; -use collections::IndexMap; +use chrono::{DateTime, Utc}; +use client::{ModelRequestUsage, RequestUsage}; +use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; +use 
collections::{HashMap, HashSet, IndexMap}; use fs::Fs; use futures::{ + FutureExt, channel::{mpsc, oneshot}, + future::Shared, stream::FuturesUnordered, }; -use gpui::{App, Context, Entity, SharedString, Task}; +use git::repository::DiffType; +use gpui::{ + App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, +}; use language_model::{ - LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, - LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, - LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt, + LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest, + LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, + LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, + LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage, +}; +use project::{ + Project, + git_store::{GitStore, RepositoryState}, }; -use project::Project; use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; use settings::{Settings, update_settings_file}; use smol::stream::StreamExt; -use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc}; -use std::{fmt::Write, ops::Range}; -use util::{ResultExt, markdown::MarkdownCodeBlock}; +use std::{ + collections::BTreeMap, + ops::RangeInclusive, + path::Path, + rc::Rc, + sync::Arc, + time::{Duration, Instant}, +}; +use std::{fmt::Write, path::PathBuf}; +use util::{ResultExt, debug_panic, markdown::MarkdownCodeBlock}; +use uuid::Uuid; + +const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user"; +pub const MAX_TOOL_NAME_LENGTH: usize = 64; + +/// The ID of the user 
prompt that initiated a request. +/// +/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key). +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)] +pub struct PromptId(Arc); + +impl PromptId { + pub fn new() -> Self { + Self(Uuid::new_v4().to_string().into()) + } +} + +impl std::fmt::Display for PromptId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4; +pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5); + +#[derive(Debug, Clone)] +enum RetryStrategy { + ExponentialBackoff { + initial_delay: Duration, + max_attempts: u8, + }, + Fixed { + delay: Duration, + max_attempts: u8, + }, +} -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum Message { User(UserMessage), Agent(AgentMessage), + Resume, } impl Message { @@ -43,21 +108,41 @@ impl Message { } } + pub fn to_request(&self) -> Vec { + match self { + Message::User(message) => vec![message.to_request()], + Message::Agent(message) => message.to_request(), + Message::Resume => vec![LanguageModelRequestMessage { + role: Role::User, + content: vec!["Continue where you left off".into()], + cache: false, + }], + } + } + pub fn to_markdown(&self) -> String { match self { Message::User(message) => message.to_markdown(), Message::Agent(message) => message.to_markdown(), + Message::Resume => "[resume]\n".into(), + } + } + + pub fn role(&self) -> Role { + match self { + Message::User(_) | Message::Resume => Role::User, + Message::Agent(_) => Role::Assistant, } } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct UserMessage { pub id: UserMessageId, pub content: Vec, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum 
UserMessageContent { Text(String), Mention { uri: MentionUri, content: String }, @@ -79,9 +164,9 @@ impl UserMessage { } UserMessageContent::Mention { uri, content } => { if !content.is_empty() { - let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content); + let _ = writeln!(&mut markdown, "{}\n\n{}", uri.as_link(), content); } else { - let _ = write!(&mut markdown, "{}\n", uri.as_link()); + let _ = writeln!(&mut markdown, "{}", uri.as_link()); } } } @@ -102,14 +187,18 @@ impl UserMessage { They are up-to-date and don't need to be re-read.\n\n"; const OPEN_FILES_TAG: &str = ""; + const OPEN_DIRECTORIES_TAG: &str = ""; const OPEN_SYMBOLS_TAG: &str = ""; + const OPEN_SELECTIONS_TAG: &str = ""; const OPEN_THREADS_TAG: &str = ""; const OPEN_FETCH_TAG: &str = ""; const OPEN_RULES_TAG: &str = "\nThe user has specified the following rules that should be applied:\n"; let mut file_context = OPEN_FILES_TAG.to_string(); + let mut directory_context = OPEN_DIRECTORIES_TAG.to_string(); let mut symbol_context = OPEN_SYMBOLS_TAG.to_string(); + let mut selection_context = OPEN_SELECTIONS_TAG.to_string(); let mut thread_context = OPEN_THREADS_TAG.to_string(); let mut fetch_context = OPEN_FETCH_TAG.to_string(); let mut rules_context = OPEN_RULES_TAG.to_string(); @@ -124,29 +213,52 @@ impl UserMessage { } UserMessageContent::Mention { uri, content } => { match uri { - MentionUri::File(path) => { + MentionUri::File { abs_path } => { write!( - &mut symbol_context, + &mut file_context, "\n{}", MarkdownCodeBlock { - tag: &codeblock_tag(&path, None), + tag: &codeblock_tag(abs_path, None), text: &content.to_string(), } ) .ok(); } + MentionUri::PastedImage => { + debug_panic!("pasted image URI should not be used in mention content") + } + MentionUri::Directory { .. } => { + write!(&mut directory_context, "\n{}\n", content).ok(); + } MentionUri::Symbol { - path, line_range, .. + abs_path: path, + line_range, + .. 
+ } => { + write!( + &mut symbol_context, + "\n{}", + MarkdownCodeBlock { + tag: &codeblock_tag(path, Some(line_range)), + text: content + } + ) + .ok(); } - | MentionUri::Selection { - path, line_range, .. + MentionUri::Selection { + abs_path: path, + line_range, + .. } => { write!( - &mut rules_context, + &mut selection_context, "\n{}", MarkdownCodeBlock { - tag: &codeblock_tag(&path, Some(line_range)), - text: &content + tag: &codeblock_tag( + path.as_deref().unwrap_or("Untitled".as_ref()), + Some(line_range) + ), + text: content } ) .ok(); @@ -163,7 +275,7 @@ impl UserMessage { "\n{}", MarkdownCodeBlock { tag: "", - text: &content + text: content } ) .ok(); @@ -189,6 +301,13 @@ impl UserMessage { .push(language_model::MessageContent::Text(file_context)); } + if directory_context.len() > OPEN_DIRECTORIES_TAG.len() { + directory_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(directory_context)); + } + if symbol_context.len() > OPEN_SYMBOLS_TAG.len() { symbol_context.push_str("\n"); message @@ -196,6 +315,13 @@ impl UserMessage { .push(language_model::MessageContent::Text(symbol_context)); } + if selection_context.len() > OPEN_SELECTIONS_TAG.len() { + selection_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(selection_context)); + } + if thread_context.len() > OPEN_THREADS_TAG.len() { thread_context.push_str("\n"); message @@ -231,7 +357,7 @@ impl UserMessage { } } -fn codeblock_tag(full_path: &Path, line_range: Option<&Range>) -> String { +fn codeblock_tag(full_path: &Path, line_range: Option<&RangeInclusive>) -> String { let mut result = String::new(); if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) { @@ -241,10 +367,10 @@ fn codeblock_tag(full_path: &Path, line_range: Option<&Range>) -> String { let _ = write!(result, "{}", full_path.display()); if let Some(range) = line_range { - if range.start == range.end { - let _ = write!(result, ":{}", 
range.start + 1); + if range.start() == range.end() { + let _ = write!(result, ":{}", range.start() + 1); } else { - let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1); + let _ = write!(result, ":{}-{}", range.start() + 1, range.end() + 1); } } @@ -269,9 +395,6 @@ impl AgentMessage { AgentMessageContent::RedactedThinking(_) => { markdown.push_str("\n") } - AgentMessageContent::Image(_) => { - markdown.push_str("\n"); - } AgentMessageContent::ToolUse(tool_use) => { markdown.push_str(&format!( "**Tool Use**: {} (ID: {})\n", @@ -320,62 +443,77 @@ impl AgentMessage { } pub fn to_request(&self) -> Vec { - let mut content = Vec::with_capacity(self.content.len()); + let mut assistant_message = LanguageModelRequestMessage { + role: Role::Assistant, + content: Vec::with_capacity(self.content.len()), + cache: false, + }; for chunk in &self.content { - let chunk = match chunk { + match chunk { AgentMessageContent::Text(text) => { - language_model::MessageContent::Text(text.clone()) + assistant_message + .content + .push(language_model::MessageContent::Text(text.clone())); } AgentMessageContent::Thinking { text, signature } => { - language_model::MessageContent::Thinking { - text: text.clone(), - signature: signature.clone(), - } + assistant_message + .content + .push(language_model::MessageContent::Thinking { + text: text.clone(), + signature: signature.clone(), + }); } AgentMessageContent::RedactedThinking(value) => { - language_model::MessageContent::RedactedThinking(value.clone()) - } - AgentMessageContent::ToolUse(value) => { - language_model::MessageContent::ToolUse(value.clone()) + assistant_message.content.push( + language_model::MessageContent::RedactedThinking(value.clone()), + ); } - AgentMessageContent::Image(value) => { - language_model::MessageContent::Image(value.clone()) + AgentMessageContent::ToolUse(tool_use) => { + if self.tool_results.contains_key(&tool_use.id) { + assistant_message + .content + 
.push(language_model::MessageContent::ToolUse(tool_use.clone())); + } } }; - content.push(chunk); } - let mut messages = vec![LanguageModelRequestMessage { - role: Role::Assistant, - content, + let mut user_message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::new(), cache: false, - }]; + }; - if !self.tool_results.is_empty() { - let mut tool_results = Vec::with_capacity(self.tool_results.len()); - for tool_result in self.tool_results.values() { - tool_results.push(language_model::MessageContent::ToolResult( - tool_result.clone(), - )); + for tool_result in self.tool_results.values() { + let mut tool_result = tool_result.clone(); + // Surprisingly, the API fails if we return an empty string here. + // It thinks we are sending a tool use without a tool result. + if tool_result.content.is_empty() { + tool_result.content = "".into(); } - messages.push(LanguageModelRequestMessage { - role: Role::User, - content: tool_results, - cache: false, - }); + user_message + .content + .push(language_model::MessageContent::ToolResult(tool_result)); } + let mut messages = Vec::new(); + if !assistant_message.content.is_empty() { + messages.push(assistant_message); + } + if !user_message.content.is_empty() { + messages.push(user_message); + } messages } } -#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct AgentMessage { pub content: Vec, pub tool_results: IndexMap, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum AgentMessageContent { Text(String), Thinking { @@ -383,72 +521,470 @@ pub enum AgentMessageContent { signature: Option, }, RedactedThinking(String), - Image(LanguageModelImage), ToolUse(LanguageModelToolUse), } +pub trait TerminalHandle { + fn id(&self, cx: &AsyncApp) -> Result; + fn current_output(&self, cx: &AsyncApp) -> Result; + fn wait_for_exit(&self, cx: &AsyncApp) -> Result>>; +} + +pub trait 
ThreadEnvironment { + fn create_terminal( + &self, + command: String, + cwd: Option, + output_byte_limit: Option, + cx: &mut AsyncApp, + ) -> Task>>; +} + #[derive(Debug)] -pub enum AgentResponseEvent { - Text(String), - Thinking(String), +pub enum ThreadEvent { + UserMessage(UserMessage), + AgentText(String), + AgentThinking(String), ToolCall(acp::ToolCall), ToolCallUpdate(acp_thread::ToolCallUpdate), ToolCallAuthorization(ToolCallAuthorization), + Retry(acp_thread::RetryStatus), Stop(acp::StopReason), } +#[derive(Debug)] +pub struct NewTerminal { + pub command: String, + pub output_byte_limit: Option, + pub cwd: Option, + pub response: oneshot::Sender>>, +} + #[derive(Debug)] pub struct ToolCallAuthorization { - pub tool_call: acp::ToolCall, + pub tool_call: acp::ToolCallUpdate, pub options: Vec, pub response: oneshot::Sender, } +#[derive(Debug, thiserror::Error)] +enum CompletionError { + #[error("max tokens")] + MaxTokens, + #[error("refusal")] + Refusal, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + pub struct Thread { + id: acp::SessionId, + prompt_id: PromptId, + updated_at: DateTime, + title: Option, + pending_title_generation: Option>, + summary: Option, messages: Vec, completion_mode: CompletionMode, /// Holds the task that handles agent interaction until the end of the turn. /// Survives across multiple requests as the model performs tool calls and /// we run tools, report their results. 
- running_turn: Option>, + running_turn: Option, pending_message: Option, tools: BTreeMap>, + tool_use_limit_reached: bool, + request_token_usage: HashMap, + #[allow(unused)] + cumulative_token_usage: TokenUsage, + #[allow(unused)] + initial_project_snapshot: Shared>>>, context_server_registry: Entity, profile_id: AgentProfileId, - project_context: Rc>, + project_context: Entity, templates: Arc, - pub selected_model: Arc, - project: Entity, - action_log: Entity, + model: Option>, + summarization_model: Option>, + prompt_capabilities_tx: watch::Sender, + pub(crate) prompt_capabilities_rx: watch::Receiver, + pub(crate) project: Entity, + pub(crate) action_log: Entity, } impl Thread { + fn prompt_capabilities(model: Option<&dyn LanguageModel>) -> acp::PromptCapabilities { + let image = model.map_or(true, |model| model.supports_images()); + acp::PromptCapabilities { + image, + audio: false, + embedded_context: true, + } + } + pub fn new( project: Entity, - project_context: Rc>, + project_context: Entity, context_server_registry: Entity, - action_log: Entity, templates: Arc, - default_model: Arc, + model: Option>, cx: &mut Context, ) -> Self { let profile_id = AgentSettings::get_global(cx).default_profile.clone(); + let action_log = cx.new(|_cx| ActionLog::new(project.clone())); + let (prompt_capabilities_tx, prompt_capabilities_rx) = + watch::channel(Self::prompt_capabilities(model.as_deref())); Self { + id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()), + prompt_id: PromptId::new(), + updated_at: Utc::now(), + title: None, + pending_title_generation: None, + summary: None, messages: Vec::new(), - completion_mode: CompletionMode::Normal, + completion_mode: AgentSettings::get_global(cx).preferred_completion_mode, + running_turn: None, + pending_message: None, + tools: BTreeMap::default(), + tool_use_limit_reached: false, + request_token_usage: HashMap::default(), + cumulative_token_usage: TokenUsage::default(), + initial_project_snapshot: { + let 
project_snapshot = Self::project_snapshot(project.clone(), cx); + cx.foreground_executor() + .spawn(async move { Some(project_snapshot.await) }) + .shared() + }, + context_server_registry, + profile_id, + project_context, + templates, + model, + summarization_model: None, + prompt_capabilities_tx, + prompt_capabilities_rx, + project, + action_log, + } + } + + pub fn id(&self) -> &acp::SessionId { + &self.id + } + + pub fn replay( + &mut self, + cx: &mut Context, + ) -> mpsc::UnboundedReceiver> { + let (tx, rx) = mpsc::unbounded(); + let stream = ThreadEventStream(tx); + for message in &self.messages { + match message { + Message::User(user_message) => stream.send_user_message(user_message), + Message::Agent(assistant_message) => { + for content in &assistant_message.content { + match content { + AgentMessageContent::Text(text) => stream.send_text(text), + AgentMessageContent::Thinking { text, .. } => { + stream.send_thinking(text) + } + AgentMessageContent::RedactedThinking(_) => {} + AgentMessageContent::ToolUse(tool_use) => { + self.replay_tool_call( + tool_use, + assistant_message.tool_results.get(&tool_use.id), + &stream, + cx, + ); + } + } + } + } + Message::Resume => {} + } + } + rx + } + + fn replay_tool_call( + &self, + tool_use: &LanguageModelToolUse, + tool_result: Option<&LanguageModelToolResult>, + stream: &ThreadEventStream, + cx: &mut Context, + ) { + let tool = self.tools.get(tool_use.name.as_ref()).cloned().or_else(|| { + self.context_server_registry + .read(cx) + .servers() + .find_map(|(_, tools)| { + if let Some(tool) = tools.get(tool_use.name.as_ref()) { + Some(tool.clone()) + } else { + None + } + }) + }); + + let Some(tool) = tool else { + stream + .0 + .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall { + id: acp::ToolCallId(tool_use.id.to_string().into()), + title: tool_use.name.to_string(), + kind: acp::ToolKind::Other, + status: acp::ToolCallStatus::Failed, + content: Vec::new(), + locations: Vec::new(), + raw_input: 
Some(tool_use.input.clone()), + raw_output: None, + }))) + .ok(); + return; + }; + + let title = tool.initial_title(tool_use.input.clone(), cx); + let kind = tool.kind(); + stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); + + let output = tool_result + .as_ref() + .and_then(|result| result.output.clone()); + if let Some(output) = output.clone() { + let tool_event_stream = ToolCallEventStream::new( + tool_use.id.clone(), + stream.clone(), + Some(self.project.read(cx).fs().clone()), + ); + tool.replay(tool_use.input.clone(), output, tool_event_stream, cx) + .log_err(); + } + + stream.update_tool_call_fields( + &tool_use.id, + acp::ToolCallUpdateFields { + status: Some( + tool_result + .as_ref() + .map_or(acp::ToolCallStatus::Failed, |result| { + if result.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + } + }), + ), + raw_output: output, + ..Default::default() + }, + ); + } + + pub fn from_db( + id: acp::SessionId, + db_thread: DbThread, + project: Entity, + project_context: Entity, + context_server_registry: Entity, + action_log: Entity, + templates: Arc, + cx: &mut Context, + ) -> Self { + let profile_id = db_thread + .profile + .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone()); + let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + db_thread + .model + .and_then(|model| { + let model = SelectedModel { + provider: model.provider.clone().into(), + model: model.model.into(), + }; + registry.select_model(&model, cx) + }) + .or_else(|| registry.default_model()) + .map(|model| model.model) + }); + let (prompt_capabilities_tx, prompt_capabilities_rx) = + watch::channel(Self::prompt_capabilities(model.as_deref())); + + Self { + id, + prompt_id: PromptId::new(), + title: if db_thread.title.is_empty() { + None + } else { + Some(db_thread.title.clone()) + }, + pending_title_generation: None, + summary: db_thread.detailed_summary, + messages: db_thread.messages, + 
completion_mode: db_thread.completion_mode.unwrap_or_default(), running_turn: None, pending_message: None, tools: BTreeMap::default(), + tool_use_limit_reached: false, + request_token_usage: db_thread.request_token_usage.clone(), + cumulative_token_usage: db_thread.cumulative_token_usage, + initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(), context_server_registry, profile_id, project_context, templates, - selected_model: default_model, + model, + summarization_model: None, project, action_log, + updated_at: db_thread.updated_at, + prompt_capabilities_tx, + prompt_capabilities_rx, } } + pub fn to_db(&self, cx: &App) -> Task { + let initial_project_snapshot = self.initial_project_snapshot.clone(); + let mut thread = DbThread { + title: self.title(), + messages: self.messages.clone(), + updated_at: self.updated_at, + detailed_summary: self.summary.clone(), + initial_project_snapshot: None, + cumulative_token_usage: self.cumulative_token_usage, + request_token_usage: self.request_token_usage.clone(), + model: self.model.as_ref().map(|model| DbLanguageModel { + provider: model.provider_id().to_string(), + model: model.name().0.to_string(), + }), + completion_mode: Some(self.completion_mode), + profile: Some(self.profile_id.clone()), + }; + + cx.background_spawn(async move { + let initial_project_snapshot = initial_project_snapshot.await; + thread.initial_project_snapshot = initial_project_snapshot; + thread + }) + } + + /// Create a snapshot of the current project state including git information and unsaved buffers. 
+ fn project_snapshot( + project: Entity, + cx: &mut Context, + ) -> Task> { + let git_store = project.read(cx).git_store().clone(); + let worktree_snapshots: Vec<_> = project + .read(cx) + .visible_worktrees(cx) + .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx)) + .collect(); + + cx.spawn(async move |_, cx| { + let worktree_snapshots = futures::future::join_all(worktree_snapshots).await; + + let mut unsaved_buffers = Vec::new(); + cx.update(|app_cx| { + let buffer_store = project.read(app_cx).buffer_store(); + for buffer_handle in buffer_store.read(app_cx).buffers() { + let buffer = buffer_handle.read(app_cx); + if buffer.is_dirty() + && let Some(file) = buffer.file() + { + let path = file.path().to_string_lossy().to_string(); + unsaved_buffers.push(path); + } + } + }) + .ok(); + + Arc::new(ProjectSnapshot { + worktree_snapshots, + unsaved_buffer_paths: unsaved_buffers, + timestamp: Utc::now(), + }) + }) + } + + fn worktree_snapshot( + worktree: Entity, + git_store: Entity, + cx: &App, + ) -> Task { + cx.spawn(async move |cx| { + // Get worktree path and snapshot + let worktree_info = cx.update(|app_cx| { + let worktree = worktree.read(app_cx); + let path = worktree.abs_path().to_string_lossy().to_string(); + let snapshot = worktree.snapshot(); + (path, snapshot) + }); + + let Ok((worktree_path, _snapshot)) = worktree_info else { + return WorktreeSnapshot { + worktree_path: String::new(), + git_state: None, + }; + }; + + let git_state = git_store + .update(cx, |git_store, cx| { + git_store + .repositories() + .values() + .find(|repo| { + repo.read(cx) + .abs_path_to_repo_path(&worktree.read(cx).abs_path()) + .is_some() + }) + .cloned() + }) + .ok() + .flatten() + .map(|repo| { + repo.update(cx, |repo, _| { + let current_branch = + repo.branch.as_ref().map(|branch| branch.name().to_owned()); + repo.send_job(None, |state, _| async move { + let RepositoryState::Local { backend, .. 
} = state else { + return GitState { + remote_url: None, + head_sha: None, + current_branch, + diff: None, + }; + }; + + let remote_url = backend.remote_url("origin"); + let head_sha = backend.head_sha().await; + let diff = backend.diff(DiffType::HeadToWorktree).await.ok(); + + GitState { + remote_url, + head_sha, + current_branch, + diff, + } + }) + }) + }); + + let git_state = match git_state { + Some(git_state) => match git_state.ok() { + Some(git_state) => git_state.await.ok(), + None => None, + }, + None => None, + }; + + WorktreeSnapshot { + worktree_path, + git_state, + } + }) + } + + pub fn project_context(&self) -> &Entity { + &self.project_context + } + pub fn project(&self) -> &Entity { &self.project } @@ -457,8 +993,51 @@ impl Thread { &self.action_log } - pub fn set_mode(&mut self, mode: CompletionMode) { + pub fn is_empty(&self) -> bool { + self.messages.is_empty() && self.title.is_none() + } + + pub fn model(&self) -> Option<&Arc> { + self.model.as_ref() + } + + pub fn set_model(&mut self, model: Arc, cx: &mut Context) { + let old_usage = self.latest_token_usage(); + self.model = Some(model); + let new_caps = Self::prompt_capabilities(self.model.as_deref()); + let new_usage = self.latest_token_usage(); + if old_usage != new_usage { + cx.emit(TokenUsageUpdated(new_usage)); + } + self.prompt_capabilities_tx.send(new_caps).log_err(); + cx.notify() + } + + pub fn summarization_model(&self) -> Option<&Arc> { + self.summarization_model.as_ref() + } + + pub fn set_summarization_model( + &mut self, + model: Option>, + cx: &mut Context, + ) { + self.summarization_model = model; + cx.notify() + } + + pub fn completion_mode(&self) -> CompletionMode { + self.completion_mode + } + + pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context) { + let old_usage = self.latest_token_usage(); self.completion_mode = mode; + let new_usage = self.latest_token_usage(); + if old_usage != new_usage { + cx.emit(TokenUsageUpdated(new_usage)); + } + 
cx.notify() } #[cfg(any(test, feature = "test-support"))] @@ -470,181 +1049,352 @@ impl Thread { } } - pub fn add_tool(&mut self, tool: impl AgentTool) { - self.tools.insert(tool.name(), tool.erase()); + pub fn add_default_tools( + &mut self, + environment: Rc, + cx: &mut Context, + ) { + let language_registry = self.project.read(cx).languages().clone(); + self.add_tool(CopyPathTool::new(self.project.clone())); + self.add_tool(CreateDirectoryTool::new(self.project.clone())); + self.add_tool(DeletePathTool::new( + self.project.clone(), + self.action_log.clone(), + )); + self.add_tool(DiagnosticsTool::new(self.project.clone())); + self.add_tool(EditFileTool::new( + self.project.clone(), + cx.weak_entity(), + language_registry, + )); + self.add_tool(FetchTool::new(self.project.read(cx).client().http_client())); + self.add_tool(FindPathTool::new(self.project.clone())); + self.add_tool(GrepTool::new(self.project.clone())); + self.add_tool(ListDirectoryTool::new(self.project.clone())); + self.add_tool(MovePathTool::new(self.project.clone())); + self.add_tool(NowTool); + self.add_tool(OpenTool::new(self.project.clone())); + self.add_tool(ReadFileTool::new( + self.project.clone(), + self.action_log.clone(), + )); + self.add_tool(TerminalTool::new(self.project.clone(), environment)); + self.add_tool(ThinkingTool); + self.add_tool(WebSearchTool); + } + + pub fn add_tool(&mut self, tool: T) { + self.tools.insert(T::name().into(), tool.erase()); } pub fn remove_tool(&mut self, name: &str) -> bool { self.tools.remove(name).is_some() } + pub fn profile(&self) -> &AgentProfileId { + &self.profile_id + } + pub fn set_profile(&mut self, profile_id: AgentProfileId) { self.profile_id = profile_id; } - pub fn cancel(&mut self) { - // TODO: do we need to emit a stop::cancel for ACP? 
- self.running_turn.take(); - self.flush_pending_message(); + pub fn cancel(&mut self, cx: &mut Context) { + if let Some(running_turn) = self.running_turn.take() { + running_turn.cancel(); + } + self.flush_pending_message(cx); + } + + fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context) { + let Some(last_user_message) = self.last_user_message() else { + return; + }; + + self.request_token_usage + .insert(last_user_message.id.clone(), update); + cx.emit(TokenUsageUpdated(self.latest_token_usage())); + cx.notify(); } - pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> { - self.cancel(); + pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context) -> Result<()> { + self.cancel(cx); let Some(position) = self.messages.iter().position( |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id), ) else { return Err(anyhow!("Message not found")); }; - self.messages.truncate(position); + + for message in self.messages.drain(position..) { + match message { + Message::User(message) => { + self.request_token_usage.remove(&message.id); + } + Message::Agent(_) | Message::Resume => {} + } + } + self.summary = None; + cx.notify(); Ok(()) } + pub fn latest_token_usage(&self) -> Option { + let last_user_message = self.last_user_message()?; + let tokens = self.request_token_usage.get(&last_user_message.id)?; + let model = self.model.clone()?; + + Some(acp_thread::TokenUsage { + max_tokens: model.max_token_count_for_mode(self.completion_mode.into()), + used_tokens: tokens.total_tokens(), + }) + } + + pub fn resume( + &mut self, + cx: &mut Context, + ) -> Result>> { + self.messages.push(Message::Resume); + cx.notify(); + + log::debug!("Total messages in thread: {}", self.messages.len()); + self.run_turn(cx) + } + /// Sending a message results in the model streaming a response, which could include tool calls. 
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent. /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. pub fn send( &mut self, - message_id: UserMessageId, + id: UserMessageId, content: impl IntoIterator, cx: &mut Context, - ) -> mpsc::UnboundedReceiver> + ) -> Result>> where T: Into, { - let model = self.selected_model.clone(); + let model = self.model().context("No language model configured")?; + + log::info!("Thread::send called with model: {}", model.name().0); + self.advance_prompt_id(); + let content = content.into_iter().map(Into::into).collect::>(); - log::info!("Thread::send called with model: {:?}", model.name()); log::debug!("Thread::send content: {:?}", content); + self.messages + .push(Message::User(UserMessage { id, content })); cx.notify(); - let (events_tx, events_rx) = - mpsc::unbounded::>(); - let event_stream = AgentResponseEventStream(events_tx); - self.messages.push(Message::User(UserMessage { - id: message_id.clone(), - content, - })); - log::info!("Total messages in thread: {}", self.messages.len()); - self.running_turn = Some(cx.spawn(async move |this, cx| { - log::info!("Starting agent turn execution"); - let turn_result = async { - let mut completion_intent = CompletionIntent::UserPrompt; - loop { - log::debug!( - "Building completion request with intent: {:?}", - completion_intent - ); - let request = this.update(cx, |this, cx| { - this.build_completion_request(completion_intent, cx) - })?; + log::debug!("Total messages in thread: {}", self.messages.len()); + self.run_turn(cx) + } - log::info!("Calling model.stream_completion"); - let mut events = model.stream_completion(request, cx).await?; - log::debug!("Stream completion started successfully"); - - let mut tool_uses = FuturesUnordered::new(); - while let Some(event) = events.next().await { - match event? 
{ - LanguageModelCompletionEvent::Stop(reason) => { - event_stream.send_stop(reason); - if reason == StopReason::Refusal { - this.update(cx, |this, _cx| this.truncate(message_id))??; - return Ok(()); - } + fn run_turn( + &mut self, + cx: &mut Context, + ) -> Result>> { + self.cancel(cx); + + let model = self.model.clone().context("No language model configured")?; + let profile = AgentSettings::get_global(cx) + .profiles + .get(&self.profile_id) + .context("Profile not found")?; + let (events_tx, events_rx) = mpsc::unbounded::>(); + let event_stream = ThreadEventStream(events_tx); + let message_ix = self.messages.len().saturating_sub(1); + self.tool_use_limit_reached = false; + self.summary = None; + self.running_turn = Some(RunningTurn { + event_stream: event_stream.clone(), + tools: self.enabled_tools(profile, &model, cx), + _task: cx.spawn(async move |this, cx| { + log::debug!("Starting agent turn execution"); + + let turn_result = Self::run_turn_internal(&this, model, &event_stream, cx).await; + _ = this.update(cx, |this, cx| this.flush_pending_message(cx)); + + match turn_result { + Ok(()) => { + log::debug!("Turn execution completed"); + event_stream.send_stop(acp::StopReason::EndTurn); + } + Err(error) => { + log::error!("Turn execution failed: {:?}", error); + match error.downcast::() { + Ok(CompletionError::Refusal) => { + event_stream.send_stop(acp::StopReason::Refusal); + _ = this.update(cx, |this, _| this.messages.truncate(message_ix)); } - event => { - log::trace!("Received completion event: {:?}", event); - this.update(cx, |this, cx| { - tool_uses.extend(this.handle_streamed_completion_event( - event, - &event_stream, - cx, - )); - }) - .ok(); + Ok(CompletionError::MaxTokens) => { + event_stream.send_stop(acp::StopReason::MaxTokens); + } + Ok(CompletionError::Other(error)) | Err(error) => { + event_stream.send_error(error); } } } + } - if tool_uses.is_empty() { - log::info!("No tool uses found, completing turn"); - return Ok(()); + _ = this.update(cx, 
|this, _| this.running_turn.take()); + }), + }); + Ok(events_rx) + } + + async fn run_turn_internal( + this: &WeakEntity, + model: Arc, + event_stream: &ThreadEventStream, + cx: &mut AsyncApp, + ) -> Result<()> { + let mut attempt = 0; + let mut intent = CompletionIntent::UserPrompt; + loop { + let request = + this.update(cx, |this, cx| this.build_completion_request(intent, cx))??; + + telemetry::event!( + "Agent Thread Completion", + thread_id = this.read_with(cx, |this, _| this.id.to_string())?, + prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?, + model = model.telemetry_id(), + model_provider = model.provider_id().to_string(), + attempt + ); + + log::debug!("Calling model.stream_completion, attempt {}", attempt); + let mut events = model + .stream_completion(request, cx) + .await + .map_err(|error| anyhow!(error))?; + let mut tool_results = FuturesUnordered::new(); + let mut error = None; + while let Some(event) = events.next().await { + log::trace!("Received completion event: {:?}", event); + match event { + Ok(event) => { + tool_results.extend(this.update(cx, |this, cx| { + this.handle_completion_event(event, event_stream, cx) + })??); } - log::info!("Found {} tool uses to execute", tool_uses.len()); - - while let Some(tool_result) = tool_uses.next().await { - log::info!("Tool finished {:?}", tool_result); - - event_stream.update_tool_call_fields( - &tool_result.tool_use_id, - acp::ToolCallUpdateFields { - status: Some(if tool_result.is_error { - acp::ToolCallStatus::Failed - } else { - acp::ToolCallStatus::Completed - }), - raw_output: tool_result.output.clone(), - ..Default::default() - }, - ); - this.update(cx, |this, _cx| { - this.pending_message() - .tool_results - .insert(tool_result.tool_use_id.clone(), tool_result); - }) - .ok(); + Err(err) => { + error = Some(err); + break; } - - this.update(cx, |this, _| this.flush_pending_message())?; - completion_intent = CompletionIntent::ToolResults; } } - .await; - this.update(cx, |this, _| 
this.flush_pending_message()).ok(); - if let Err(error) = turn_result { - log::error!("Turn execution failed: {:?}", error); - event_stream.send_error(error); + let end_turn = tool_results.is_empty(); + while let Some(tool_result) = tool_results.next().await { + log::debug!("Tool finished {:?}", tool_result); + + event_stream.update_tool_call_fields( + &tool_result.tool_use_id, + acp::ToolCallUpdateFields { + status: Some(if tool_result.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + }), + raw_output: tool_result.output.clone(), + ..Default::default() + }, + ); + this.update(cx, |this, _cx| { + this.pending_message() + .tool_results + .insert(tool_result.tool_use_id.clone(), tool_result); + })?; + } + + this.update(cx, |this, cx| { + this.flush_pending_message(cx); + if this.title.is_none() && this.pending_title_generation.is_none() { + this.generate_title(cx); + } + })?; + + if let Some(error) = error { + attempt += 1; + let retry = + this.update(cx, |this, _| this.handle_completion_error(error, attempt))??; + let timer = cx.background_executor().timer(retry.duration); + event_stream.send_retry(retry); + timer.await; + this.update(cx, |this, _cx| { + if let Some(Message::Agent(message)) = this.messages.last() { + if message.tool_results.is_empty() { + intent = CompletionIntent::UserPrompt; + this.messages.push(Message::Resume); + } + } + })?; + } else if this.read_with(cx, |this, _| this.tool_use_limit_reached)? 
{ + return Err(language_model::ToolUseLimitReachedError.into()); + } else if end_turn { + return Ok(()); } else { - log::info!("Turn execution completed successfully"); + intent = CompletionIntent::ToolResults; + attempt = 0; } - })); - events_rx + } } - pub fn build_system_message(&self) -> LanguageModelRequestMessage { - log::debug!("Building system message"); - let prompt = SystemPromptTemplate { - project: &self.project_context.borrow(), - available_tools: self.tools.keys().cloned().collect(), + fn handle_completion_error( + &mut self, + error: LanguageModelCompletionError, + attempt: u8, + ) -> Result { + if self.completion_mode == CompletionMode::Normal { + return Err(anyhow!(error)); } - .render(&self.templates) - .context("failed to build system prompt") - .expect("Invalid template"); - log::debug!("System message built"); - LanguageModelRequestMessage { - role: Role::System, - content: vec![prompt.into()], - cache: true, + + let Some(strategy) = Self::retry_strategy_for(&error) else { + return Err(anyhow!(error)); + }; + + let max_attempts = match &strategy { + RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts, + RetryStrategy::Fixed { max_attempts, .. } => *max_attempts, + }; + + if attempt > max_attempts { + return Err(anyhow!(error)); } + + let delay = match &strategy { + RetryStrategy::ExponentialBackoff { initial_delay, .. } => { + let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32); + Duration::from_secs(delay_secs) + } + RetryStrategy::Fixed { delay, .. } => *delay, + }; + log::debug!("Retry attempt {attempt} with delay {delay:?}"); + + Ok(acp_thread::RetryStatus { + last_error: error.to_string().into(), + attempt: attempt as usize, + max_attempts: max_attempts as usize, + started_at: Instant::now(), + duration: delay, + }) } /// A helper method that's called on every streamed completion event. 
- /// Returns an optional tool result task, which the main agentic loop in - /// send will send back to the model when it resolves. - fn handle_streamed_completion_event( + /// Returns an optional tool result task, which the main agentic loop will + /// send back to the model when it resolves. + fn handle_completion_event( &mut self, event: LanguageModelCompletionEvent, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, - ) -> Option> { + ) -> Result>> { log::trace!("Handling streamed completion event: {:?}", event); use LanguageModelCompletionEvent::*; match event { StartMessage { .. } => { - self.flush_pending_message(); + self.flush_pending_message(cx); self.pending_message = Some(AgentMessage::default()); } Text(new_text) => self.handle_text_event(new_text, event_stream, cx), @@ -653,7 +1403,7 @@ impl Thread { } RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx), ToolUse(tool_use) => { - return self.handle_tool_use_event(tool_use, event_stream, cx); + return Ok(self.handle_tool_use_event(tool_use, event_stream, cx)); } ToolUseJsonParseError { id, @@ -661,27 +1411,55 @@ impl Thread { raw_input, json_parse_error, } => { - return Some(Task::ready(self.handle_tool_use_json_parse_error_event( - id, - tool_name, - raw_input, - json_parse_error, + return Ok(Some(Task::ready( + self.handle_tool_use_json_parse_error_event( + id, + tool_name, + raw_input, + json_parse_error, + ), ))); } - UsageUpdate(_) | StatusUpdate(_) => {} - Stop(_) => unreachable!(), + UsageUpdate(usage) => { + telemetry::event!( + "Agent Thread Completion Usage Updated", + thread_id = self.id.to_string(), + prompt_id = self.prompt_id.to_string(), + model = self.model.as_ref().map(|m| m.telemetry_id()), + model_provider = self.model.as_ref().map(|m| m.provider_id().to_string()), + input_tokens = usage.input_tokens, + output_tokens = usage.output_tokens, + cache_creation_input_tokens = usage.cache_creation_input_tokens, + 
cache_read_input_tokens = usage.cache_read_input_tokens, + ); + self.update_token_usage(usage, cx); + } + StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => { + self.update_model_request_usage(amount, limit, cx); + } + StatusUpdate( + CompletionRequestStatus::Started + | CompletionRequestStatus::Queued { .. } + | CompletionRequestStatus::Failed { .. }, + ) => {} + StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => { + self.tool_use_limit_reached = true; + } + Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()), + Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()), + Stop(StopReason::ToolUse | StopReason::EndTurn) => {} } - None + Ok(None) } fn handle_text_event( &mut self, new_text: String, - events_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) { - events_stream.send_text(&new_text); + event_stream.send_text(&new_text); let last_message = self.pending_message(); if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() { @@ -699,7 +1477,7 @@ impl Thread { &mut self, new_text: String, new_signature: Option, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) { event_stream.send_thinking(&new_text); @@ -731,22 +1509,22 @@ impl Thread { fn handle_tool_use_event( &mut self, tool_use: LanguageModelToolUse, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) -> Option> { cx.notify(); - let tool = self.tools.get(tool_use.name.as_ref()).cloned(); + let tool = self.tool(tool_use.name.as_ref()); let mut title = SharedString::from(&tool_use.name); let mut kind = acp::ToolKind::Other; if let Some(tool) = tool.as_ref() { - title = tool.initial_title(tool_use.input.clone()); + title = tool.initial_title(tool_use.input.clone(), cx); kind = tool.kind(); } // Ensure the last message ends in the current tool use let last_message = 
self.pending_message(); - let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| { + let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| { if let AgentMessageContent::ToolUse(last_tool_use) = content { if last_tool_use.id == tool_use.id { *last_tool_use = tool_use.clone(); @@ -793,21 +1571,22 @@ impl Thread { let fs = self.project.read(cx).fs().clone(); let tool_event_stream = - ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs)); + ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs)); tool_event_stream.update_fields(acp::ToolCallUpdateFields { status: Some(acp::ToolCallStatus::InProgress), ..Default::default() }); - let supports_images = self.selected_model.supports_images(); + let supports_images = self.model().is_some_and(|model| model.supports_images()); let tool_result = tool.run(tool_use.input, tool_event_stream, cx); + log::debug!("Running tool {}", tool_use.name); Some(cx.foreground_executor().spawn(async move { let tool_result = tool_result.await.and_then(|output| { - if let LanguageModelToolResultContent::Image(_) = &output.llm_output { - if !supports_images { - return Err(anyhow!( - "Attempted to read an image, but this model doesn't support it.", - )); - } + if let LanguageModelToolResultContent::Image(_) = &output.llm_output + && !supports_images + { + return Err(anyhow!( + "Attempted to read an image, but this model doesn't support it.", + )); } Ok(output) }); @@ -825,7 +1604,7 @@ impl Thread { tool_name: tool_use.name, is_error: true, content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())), - output: None, + output: Some(error.to_string().into()), }, } })) @@ -848,15 +1627,176 @@ impl Thread { } } + fn update_model_request_usage(&self, amount: usize, limit: UsageLimit, cx: &mut Context) { + self.project + .read(cx) + .user_store() + .update(cx, |user_store, cx| { + user_store.update_model_request_usage( + 
ModelRequestUsage(RequestUsage { + amount: amount as i32, + limit, + }), + cx, + ) + }); + } + + pub fn title(&self) -> SharedString { + self.title.clone().unwrap_or("New Thread".into()) + } + + pub fn summary(&mut self, cx: &mut Context) -> Task> { + if let Some(summary) = self.summary.as_ref() { + return Task::ready(Ok(summary.clone())); + } + let Some(model) = self.summarization_model.clone() else { + return Task::ready(Err(anyhow!("No summarization model available"))); + }; + let mut request = LanguageModelRequest { + intent: Some(CompletionIntent::ThreadContextSummarization), + temperature: AgentSettings::temperature_for_model(&model, cx), + ..Default::default() + }; + + for message in &self.messages { + request.messages.extend(message.to_request()); + } + + request.messages.push(LanguageModelRequestMessage { + role: Role::User, + content: vec![SUMMARIZE_THREAD_DETAILED_PROMPT.into()], + cache: false, + }); + cx.spawn(async move |this, cx| { + let mut summary = String::new(); + let mut messages = model.stream_completion(request, cx).await?; + while let Some(event) = messages.next().await { + let event = event?; + let text = match event { + LanguageModelCompletionEvent::Text(text) => text, + LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::UsageUpdated { amount, limit }, + ) => { + this.update(cx, |thread, cx| { + thread.update_model_request_usage(amount, limit, cx); + })?; + continue; + } + _ => continue, + }; + + let mut lines = text.lines(); + summary.extend(lines.next()); + } + + log::debug!("Setting summary: {}", summary); + let summary = SharedString::from(summary); + + this.update(cx, |this, cx| { + this.summary = Some(summary.clone()); + cx.notify() + })?; + + Ok(summary) + }) + } + + fn generate_title(&mut self, cx: &mut Context) { + let Some(model) = self.summarization_model.clone() else { + return; + }; + + log::debug!( + "Generating title with model: {:?}", + self.summarization_model.as_ref().map(|model| model.name()) + ); + 
let mut request = LanguageModelRequest { + intent: Some(CompletionIntent::ThreadSummarization), + temperature: AgentSettings::temperature_for_model(&model, cx), + ..Default::default() + }; + + for message in &self.messages { + request.messages.extend(message.to_request()); + } + + request.messages.push(LanguageModelRequestMessage { + role: Role::User, + content: vec![SUMMARIZE_THREAD_PROMPT.into()], + cache: false, + }); + self.pending_title_generation = Some(cx.spawn(async move |this, cx| { + let mut title = String::new(); + + let generate = async { + let mut messages = model.stream_completion(request, cx).await?; + while let Some(event) = messages.next().await { + let event = event?; + let text = match event { + LanguageModelCompletionEvent::Text(text) => text, + LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::UsageUpdated { amount, limit }, + ) => { + this.update(cx, |thread, cx| { + thread.update_model_request_usage(amount, limit, cx); + })?; + continue; + } + _ => continue, + }; + + let mut lines = text.lines(); + title.extend(lines.next()); + + // Stop if the LLM generated multiple lines. 
+ if lines.next().is_some() { + break; + } + } + anyhow::Ok(()) + }; + + if generate.await.context("failed to generate title").is_ok() { + _ = this.update(cx, |this, cx| this.set_title(title.into(), cx)); + } + _ = this.update(cx, |this, _| this.pending_title_generation = None); + })); + } + + pub fn set_title(&mut self, title: SharedString, cx: &mut Context) { + self.pending_title_generation = None; + if Some(&title) != self.title.as_ref() { + self.title = Some(title); + cx.emit(TitleUpdated); + cx.notify(); + } + } + + fn last_user_message(&self) -> Option<&UserMessage> { + self.messages + .iter() + .rev() + .find_map(|message| match message { + Message::User(user_message) => Some(user_message), + Message::Agent(_) => None, + Message::Resume => None, + }) + } + fn pending_message(&mut self) -> &mut AgentMessage { self.pending_message.get_or_insert_default() } - fn flush_pending_message(&mut self) { + fn flush_pending_message(&mut self, cx: &mut Context) { let Some(mut message) = self.pending_message.take() else { return; }; + if message.content.is_empty() { + return; + } + for content in &message.content { let AgentMessageContent::ToolUse(tool_use) = content else { continue; @@ -869,9 +1809,7 @@ impl Thread { tool_use_id: tool_use.id.clone(), tool_name: tool_use.name.clone(), is_error: true, - content: LanguageModelToolResultContent::Text( - "Tool canceled by user".into(), - ), + content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()), output: None, }, ); @@ -879,99 +1817,154 @@ impl Thread { } self.messages.push(Message::Agent(message)); + self.updated_at = Utc::now(); + self.summary = None; + cx.notify() } pub(crate) fn build_completion_request( &self, completion_intent: CompletionIntent, - cx: &mut App, - ) -> LanguageModelRequest { - log::debug!("Building completion request"); - log::debug!("Completion intent: {:?}", completion_intent); - log::debug!("Completion mode: {:?}", self.completion_mode); - - let messages = 
self.build_request_messages(); - log::info!("Request will include {} messages", messages.len()); - - let tools = if let Some(tools) = self.tools(cx).log_err() { - tools - .filter_map(|tool| { - let tool_name = tool.name().to_string(); + cx: &App, + ) -> Result { + let model = self.model().context("No language model configured")?; + let tools = if let Some(turn) = self.running_turn.as_ref() { + turn.tools + .iter() + .filter_map(|(tool_name, tool)| { log::trace!("Including tool: {}", tool_name); Some(LanguageModelRequestTool { - name: tool_name, + name: tool_name.to_string(), description: tool.description().to_string(), - input_schema: tool - .input_schema(self.selected_model.tool_input_format()) - .log_err()?, + input_schema: tool.input_schema(model.tool_input_format()).log_err()?, }) }) - .collect() + .collect::>() } else { Vec::new() }; - log::info!("Request includes {} tools", tools.len()); + log::debug!("Building completion request"); + log::debug!("Completion intent: {:?}", completion_intent); + log::debug!("Completion mode: {:?}", self.completion_mode); + + let messages = self.build_request_messages(cx); + log::debug!("Request will include {} messages", messages.len()); + log::debug!("Request includes {} tools", tools.len()); let request = LanguageModelRequest { - thread_id: None, - prompt_id: None, + thread_id: Some(self.id.to_string()), + prompt_id: Some(self.prompt_id.to_string()), intent: Some(completion_intent), - mode: Some(self.completion_mode), + mode: Some(self.completion_mode.into()), messages, tools, tool_choice: None, stop: Vec::new(), - temperature: None, + temperature: AgentSettings::temperature_for_model(model, cx), thinking_allowed: true, }; log::debug!("Completion request built successfully"); - request + Ok(request) } - fn tools<'a>(&'a self, cx: &'a App) -> Result>> { - let profile = AgentSettings::get_global(cx) - .profiles - .get(&self.profile_id) - .context("profile not found")?; - let provider_id = self.selected_model.provider_id(); + 
fn enabled_tools( + &self, + profile: &AgentProfileSettings, + model: &Arc, + cx: &App, + ) -> BTreeMap> { + fn truncate(tool_name: &SharedString) -> SharedString { + if tool_name.len() > MAX_TOOL_NAME_LENGTH { + let mut truncated = tool_name.to_string(); + truncated.truncate(MAX_TOOL_NAME_LENGTH); + truncated.into() + } else { + tool_name.clone() + } + } - Ok(self + let mut tools = self .tools .iter() - .filter(move |(_, tool)| tool.supported_provider(&provider_id)) .filter_map(|(tool_name, tool)| { - if profile.is_tool_enabled(tool_name) { - Some(tool) + if tool.supported_provider(&model.provider_id()) + && profile.is_tool_enabled(tool_name) + { + Some((truncate(tool_name), tool.clone())) } else { None } }) - .chain(self.context_server_registry.read(cx).servers().flat_map( - |(server_id, tools)| { - tools.iter().filter_map(|(tool_name, tool)| { - if profile.is_context_server_tool_enabled(&server_id.0, tool_name) { - Some(tool) - } else { - None - } - }) - }, - ))) + .collect::>(); + + let mut context_server_tools = Vec::new(); + let mut seen_tools = tools.keys().cloned().collect::>(); + let mut duplicate_tool_names = HashSet::default(); + for (server_id, server_tools) in self.context_server_registry.read(cx).servers() { + for (tool_name, tool) in server_tools { + if profile.is_context_server_tool_enabled(&server_id.0, &tool_name) { + let tool_name = truncate(tool_name); + if !seen_tools.insert(tool_name.clone()) { + duplicate_tool_names.insert(tool_name.clone()); + } + context_server_tools.push((server_id.clone(), tool_name, tool.clone())); + } + } + } + + // When there are duplicate tool names, disambiguate by prefixing them + // with the server ID. In the rare case there isn't enough space for the + // disambiguated tool name, keep only the last tool with this name. 
+ for (server_id, tool_name, tool) in context_server_tools { + if duplicate_tool_names.contains(&tool_name) { + let available = MAX_TOOL_NAME_LENGTH.saturating_sub(tool_name.len()); + if available >= 2 { + let mut disambiguated = server_id.0.to_string(); + disambiguated.truncate(available - 1); + disambiguated.push('_'); + disambiguated.push_str(&tool_name); + tools.insert(disambiguated.into(), tool.clone()); + } else { + tools.insert(tool_name, tool.clone()); + } + } else { + tools.insert(tool_name, tool.clone()); + } + } + + tools + } + + fn tool(&self, name: &str) -> Option> { + self.running_turn.as_ref()?.tools.get(name).cloned() } - fn build_request_messages(&self) -> Vec { + fn build_request_messages(&self, cx: &App) -> Vec { log::trace!( "Building request messages from {} thread messages", self.messages.len() ); - let mut messages = vec![self.build_system_message()]; + + let system_prompt = SystemPromptTemplate { + project: self.project_context.read(cx), + available_tools: self.tools.keys().cloned().collect(), + } + .render(&self.templates) + .context("failed to build system prompt") + .expect("Invalid template"); + let mut messages = vec![LanguageModelRequestMessage { + role: Role::System, + content: vec![system_prompt.into()], + cache: false, + }]; for message in &self.messages { - match message { - Message::User(message) => messages.push(message.to_request()), - Message::Agent(message) => messages.extend(message.to_request()), - } + messages.extend(message.to_request()); + } + + if let Some(last_message) = messages.last_mut() { + last_message.cache = true; } if let Some(message) = self.pending_message.as_ref() { @@ -997,8 +1990,146 @@ impl Thread { markdown } + + fn advance_prompt_id(&mut self) { + self.prompt_id = PromptId::new(); + } + + fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option { + use LanguageModelCompletionError::*; + use http_client::StatusCode; + + // General strategy here: + // - If retrying won't help (e.g. 
invalid API key or payload too large), return None so we don't retry at all. + // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff. + // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times. + match error { + HttpResponseError { + status_code: StatusCode::TOO_MANY_REQUESTS, + .. + } => Some(RetryStrategy::ExponentialBackoff { + initial_delay: BASE_RETRY_DELAY, + max_attempts: MAX_RETRY_ATTEMPTS, + }), + ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, + }) + } + UpstreamProviderError { + status, + retry_after, + .. + } => match *status { + StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, + }) + } + StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + // Internal Server Error could be anything, retry up to 3 times. + max_attempts: 3, + }), + status => { + // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"), + // but we frequently get them in practice. See https://http.dev/529 + if status.as_u16() == 529 { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, + }) + } else { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: 2, + }) + } + } + }, + ApiInternalServerError { .. } => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 3, + }), + ApiReadResponseError { .. } + | HttpSend { .. } + | DeserializeResponse { .. } + | BadRequestFormat { .. 
} => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 3, + }), + // Retrying these errors definitely shouldn't help. + HttpResponseError { + status_code: + StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED, + .. + } + | AuthenticationError { .. } + | PermissionError { .. } + | NoApiKey { .. } + | ApiEndpointNotFound { .. } + | PromptTooLarge { .. } => None, + // These errors might be transient, so retry them + SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 1, + }), + // Retry all other 4xx and 5xx errors once. + HttpResponseError { status_code, .. } + if status_code.is_client_error() || status_code.is_server_error() => + { + Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 3, + }) + } + Other(err) + if err.is::() + || err.is::() => + { + // Retrying won't help for Payment Required or Model Request Limit errors (where + // the user must upgrade to usage-based billing to get more requests, or else wait + // for a significant amount of time for the request limit to reset). + None + } + // Conservatively assume that any other errors are non-retryable + HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 2, + }), + } + } } +struct RunningTurn { + /// Holds the task that handles agent interaction until the end of the turn. + /// Survives across multiple requests as the model performs tool calls and + /// we run tools, report their results. + _task: Task<()>, + /// The current event stream for the running turn. Used to report a final + /// cancellation event if we cancel the turn. + event_stream: ThreadEventStream, + /// The tools that were enabled for this turn. 
+ tools: BTreeMap>, +} + +impl RunningTurn { + fn cancel(self) { + log::debug!("Cancelling in progress turn"); + self.event_stream.send_canceled(); + } +} + +pub struct TokenUsageUpdated(pub Option); + +impl EventEmitter for Thread {} + +pub struct TitleUpdated; + +impl EventEmitter for Thread {} + pub trait AgentTool where Self: 'static + Sized, @@ -1006,7 +2137,7 @@ where type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema; type Output: for<'de> Deserialize<'de> + Serialize + Into; - fn name(&self) -> SharedString; + fn name() -> &'static str; fn description(&self) -> SharedString { let schema = schemars::schema_for!(Self::Input); @@ -1018,14 +2149,18 @@ where ) } - fn kind(&self) -> acp::ToolKind; + fn kind() -> acp::ToolKind; /// The initial tool title to display. Can be updated during the tool run. - fn initial_title(&self, input: Result) -> SharedString; + fn initial_title( + &self, + input: Result, + cx: &mut App, + ) -> SharedString; /// Returns the JSON schema that describes the tool's input. - fn input_schema(&self) -> Schema { - schemars::schema_for!(Self::Input) + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema { + crate::tool_schema::root_schema_for::(format) } /// Some tools rely on a provider for the underlying billing or other reasons. @@ -1042,6 +2177,17 @@ where cx: &mut App, ) -> Task>; + /// Emits events for a previous execution of the tool. 
+ fn replay( + &self, + _input: Self::Input, + _output: Self::Output, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + Ok(()) + } + fn erase(self) -> Arc { Arc::new(Erased(Arc::new(self))) } @@ -1058,7 +2204,7 @@ pub trait AnyAgentTool { fn name(&self) -> SharedString; fn description(&self) -> SharedString; fn kind(&self) -> acp::ToolKind; - fn initial_title(&self, input: serde_json::Value) -> SharedString; + fn initial_title(&self, input: serde_json::Value, _cx: &mut App) -> SharedString; fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool { true @@ -1069,6 +2215,13 @@ pub trait AnyAgentTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Task>; + fn replay( + &self, + input: serde_json::Value, + output: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()>; } impl AnyAgentTool for Erased> @@ -1076,7 +2229,7 @@ where T: AgentTool, { fn name(&self) -> SharedString { - self.0.name() + T::name().into() } fn description(&self) -> SharedString { @@ -1084,16 +2237,16 @@ where } fn kind(&self) -> agent_client_protocol::ToolKind { - self.0.kind() + T::kind() } - fn initial_title(&self, input: serde_json::Value) -> SharedString { + fn initial_title(&self, input: serde_json::Value, _cx: &mut App) -> SharedString { let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input); - self.0.initial_title(parsed_input) + self.0.initial_title(parsed_input, _cx) } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - let mut json = serde_json::to_value(self.0.input_schema())?; + let mut json = serde_json::to_value(self.0.input_schema(format))?; adapt_schema_to_format(&mut json, format)?; Ok(json) } @@ -1120,23 +2273,39 @@ where }) }) } + + fn replay( + &self, + input: serde_json::Value, + output: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> 
Result<()> { + let input = serde_json::from_value(input)?; + let output = serde_json::from_value(output)?; + self.0.replay(input, output, event_stream, cx) + } } #[derive(Clone)] -struct AgentResponseEventStream( - mpsc::UnboundedSender>, -); +struct ThreadEventStream(mpsc::UnboundedSender>); + +impl ThreadEventStream { + fn send_user_message(&self, message: &UserMessage) { + self.0 + .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone()))) + .ok(); + } -impl AgentResponseEventStream { fn send_text(&self, text: &str) { self.0 - .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string()))) + .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string()))) .ok(); } fn send_thinking(&self, text: &str) { self.0 - .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string()))) + .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string()))) .ok(); } @@ -1148,7 +2317,7 @@ impl AgentResponseEventStream { input: serde_json::Value, ) { self.0 - .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call( + .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call( id, title.to_string(), kind, @@ -1181,7 +2350,7 @@ impl AgentResponseEventStream { fields: acp::ToolCallUpdateFields, ) { self.0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( acp::ToolCallUpdate { id: acp::ToolCallId(tool_use_id.to_string().into()), fields, @@ -1191,73 +2360,49 @@ impl AgentResponseEventStream { .ok(); } - fn send_stop(&self, reason: StopReason) { - match reason { - StopReason::EndTurn => { - self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn))) - .ok(); - } - StopReason::MaxTokens => { - self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens))) - .ok(); - } - StopReason::Refusal => { - self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal))) - .ok(); - } - StopReason::ToolUse => {} - } + fn send_retry(&self, status: 
acp_thread::RetryStatus) { + self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok(); } - fn send_error(&self, error: LanguageModelCompletionError) { - self.0.unbounded_send(Err(error)).ok(); + fn send_stop(&self, reason: acp::StopReason) { + self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok(); + } + + fn send_canceled(&self) { + self.0 + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled))) + .ok(); + } + + fn send_error(&self, error: impl Into) { + self.0.unbounded_send(Err(error.into())).ok(); } } #[derive(Clone)] pub struct ToolCallEventStream { tool_use_id: LanguageModelToolUseId, - kind: acp::ToolKind, - input: serde_json::Value, - stream: AgentResponseEventStream, + stream: ThreadEventStream, fs: Option>, } impl ToolCallEventStream { #[cfg(test)] pub fn test() -> (Self, ToolCallEventStreamReceiver) { - let (events_tx, events_rx) = - mpsc::unbounded::>(); - - let stream = ToolCallEventStream::new( - &LanguageModelToolUse { - id: "test_id".into(), - name: "test_tool".into(), - raw_input: String::new(), - input: serde_json::Value::Null, - is_input_complete: true, - }, - acp::ToolKind::Other, - AgentResponseEventStream(events_tx), - None, - ); + let (events_tx, events_rx) = mpsc::unbounded::>(); + + let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None); (stream, ToolCallEventStreamReceiver(events_rx)) } fn new( - tool_use: &LanguageModelToolUse, - kind: acp::ToolKind, - stream: AgentResponseEventStream, + tool_use_id: LanguageModelToolUseId, + stream: ThreadEventStream, fs: Option>, ) -> Self { Self { - tool_use_id: tool_use.id.clone(), - kind, - input: tool_use.input.clone(), + tool_use_id, stream, fs, } @@ -1271,7 +2416,7 @@ impl ToolCallEventStream { pub fn update_diff(&self, diff: Entity) { self.stream .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( acp_thread::ToolCallUpdateDiff { id: acp::ToolCallId(self.tool_use_id.to_string().into()), 
diff, @@ -1281,19 +2426,6 @@ impl ToolCallEventStream { .ok(); } - pub fn update_terminal(&self, terminal: Entity) { - self.stream - .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( - acp_thread::ToolCallUpdateTerminal { - id: acp::ToolCallId(self.tool_use_id.to_string().into()), - terminal, - } - .into(), - ))) - .ok(); - } - pub fn authorize(&self, title: impl Into, cx: &mut App) -> Task> { if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { return Task::ready(Ok(())); @@ -1302,14 +2434,15 @@ impl ToolCallEventStream { let (response_tx, response_rx) = oneshot::channel(); self.stream .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization( + .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization( ToolCallAuthorization { - tool_call: AgentResponseEventStream::initial_tool_call( - &self.tool_use_id, - title.into(), - self.kind.clone(), - self.input.clone(), - ), + tool_call: acp::ToolCallUpdate { + id: acp::ToolCallId(self.tool_use_id.to_string().into()), + fields: acp::ToolCallUpdateFields { + title: Some(title.into()), + ..Default::default() + }, + }, options: vec![ acp::PermissionOption { id: acp::PermissionOptionId("always_allow".into()), @@ -1351,26 +2484,48 @@ impl ToolCallEventStream { } #[cfg(test)] -pub struct ToolCallEventStreamReceiver( - mpsc::UnboundedReceiver>, -); +pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver>); #[cfg(test)] impl ToolCallEventStreamReceiver { pub async fn expect_authorization(&mut self) -> ToolCallAuthorization { let event = self.0.next().await; - if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event { + if let Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) = event { auth } else { panic!("Expected ToolCallAuthorization but got: {:?}", event); } } + pub async fn expect_update_fields(&mut self) -> acp::ToolCallUpdateFields { + let event = self.0.next().await; + if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( + 
update, + )))) = event + { + update.fields + } else { + panic!("Expected update fields but got: {:?}", event); + } + } + + pub async fn expect_diff(&mut self) -> Entity { + let event = self.0.next().await; + if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateDiff( + update, + )))) = event + { + update.diff + } else { + panic!("Expected diff but got: {:?}", event); + } + } + pub async fn expect_terminal(&mut self) -> Entity { let event = self.0.next().await; - if let Some(Ok(AgentResponseEvent::ToolCallUpdate( - acp_thread::ToolCallUpdate::UpdateTerminal(update), - ))) = event + if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal( + update, + )))) = event { update.terminal } else { @@ -1381,7 +2536,7 @@ impl ToolCallEventStreamReceiver { #[cfg(test)] impl std::ops::Deref for ToolCallEventStreamReceiver { - type Target = mpsc::UnboundedReceiver>; + type Target = mpsc::UnboundedReceiver>; fn deref(&self) -> &Self::Target { &self.0 @@ -1450,6 +2605,35 @@ impl From for UserMessageContent { } } +impl From for acp::ContentBlock { + fn from(content: UserMessageContent) -> Self { + match content { + UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent { + data: image.source.to_string(), + mime_type: "image/png".to_string(), + annotations: None, + uri: None, + }), + UserMessageContent::Mention { uri, content } => { + acp::ContentBlock::Resource(acp::EmbeddedResource { + resource: acp::EmbeddedResourceResource::TextResourceContents( + acp::TextResourceContents { + mime_type: None, + text: content, + uri: uri.to_uri().to_string(), + }, + ), + annotations: None, + }) + } + } + } +} + fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage { LanguageModelImage { source: image_content.data.into(), diff --git a/crates/agent2/src/tool_schema.rs 
b/crates/agent2/src/tool_schema.rs new file mode 100644 index 0000000000000000000000000000000000000000..f608336b416a72885e52abba58ef472029421e4f --- /dev/null +++ b/crates/agent2/src/tool_schema.rs @@ -0,0 +1,43 @@ +use language_model::LanguageModelToolSchemaFormat; +use schemars::{ + JsonSchema, Schema, + generate::SchemaSettings, + transform::{Transform, transform_subschemas}, +}; + +pub(crate) fn root_schema_for(format: LanguageModelToolSchemaFormat) -> Schema { + let mut generator = match format { + LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(), + LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3() + .with(|settings| { + settings.meta_schema = None; + settings.inline_subschemas = true; + }) + .with_transform(ToJsonSchemaSubsetTransform) + .into_generator(), + }; + generator.root_schema_for::() +} + +#[derive(Debug, Clone)] +struct ToJsonSchemaSubsetTransform; + +impl Transform for ToJsonSchemaSubsetTransform { + fn transform(&mut self, schema: &mut Schema) { + // Ensure that the type field is not an array, this happens when we use + // Option, the type will be [T, "null"]. 
+ if let Some(type_field) = schema.get_mut("type") + && let Some(types) = type_field.as_array() + && let Some(first_type) = types.first() + { + *type_field = first_type.clone(); + } + + // oneOf is not supported, use anyOf instead + if let Some(one_of) = schema.remove("oneOf") { + schema.insert("anyOf".to_string(), one_of); + } + + transform_subschemas(self, schema); + } +} diff --git a/crates/agent2/src/tools.rs b/crates/agent2/src/tools.rs index d1f2b3b1c7ad3ed7ade2324c61c1e72d7e7e4006..bcca7eecd185b9381afded26fb573d14f50bc5be 100644 --- a/crates/agent2/src/tools.rs +++ b/crates/agent2/src/tools.rs @@ -16,6 +16,29 @@ mod terminal_tool; mod thinking_tool; mod web_search_tool; +/// A list of all built in tool names, for use in deduplicating MCP tool names +pub fn default_tool_names() -> impl Iterator { + [ + CopyPathTool::name(), + CreateDirectoryTool::name(), + DeletePathTool::name(), + DiagnosticsTool::name(), + EditFileTool::name(), + FetchTool::name(), + FindPathTool::name(), + GrepTool::name(), + ListDirectoryTool::name(), + MovePathTool::name(), + NowTool::name(), + OpenTool::name(), + ReadFileTool::name(), + TerminalTool::name(), + ThinkingTool::name(), + WebSearchTool::name(), + ] + .into_iter() +} + pub use context_server_registry::*; pub use copy_path_tool::*; pub use create_directory_tool::*; @@ -33,3 +56,5 @@ pub use read_file_tool::*; pub use terminal_tool::*; pub use thinking_tool::*; pub use web_search_tool::*; + +use crate::AgentTool; diff --git a/crates/agent2/src/tools/context_server_registry.rs b/crates/agent2/src/tools/context_server_registry.rs index db39e9278c250865d63922cd802e1ce9fb1d003f..46fa0298044de017464dc1a2e5bd21bf57c1bfcf 100644 --- a/crates/agent2/src/tools/context_server_registry.rs +++ b/crates/agent2/src/tools/context_server_registry.rs @@ -103,7 +103,7 @@ impl ContextServerRegistry { self.reload_tools_for_server(server_id.clone(), cx); } ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { - 
self.registered_servers.remove(&server_id); + self.registered_servers.remove(server_id); cx.notify(); } } @@ -145,7 +145,7 @@ impl AnyAgentTool for ContextServerTool { ToolKind::Other } - fn initial_title(&self, _input: serde_json::Value) -> SharedString { + fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString { format!("Run MCP tool `{}`", self.tool.name).into() } @@ -169,22 +169,23 @@ impl AnyAgentTool for ContextServerTool { fn run( self: Arc, input: serde_json::Value, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else { return Task::ready(Err(anyhow!("Context server not found"))); }; let tool_name = self.tool.name.clone(); - let server_clone = server.clone(); - let input_clone = input.clone(); + let authorize = event_stream.authorize(self.initial_title(input.clone(), cx), cx); cx.spawn(async move |_cx| { - let Some(protocol) = server_clone.client() else { + authorize.await?; + + let Some(protocol) = server.client() else { bail!("Context server not initialized"); }; - let arguments = if let serde_json::Value::Object(map) = input_clone { + let arguments = if let serde_json::Value::Object(map) = input { Some(map.into_iter().collect()) } else { None @@ -228,4 +229,14 @@ impl AnyAgentTool for ContextServerTool { }) }) } + + fn replay( + &self, + _input: serde_json::Value, + _output: serde_json::Value, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + Ok(()) + } } diff --git a/crates/agent2/src/tools/copy_path_tool.rs b/crates/agent2/src/tools/copy_path_tool.rs index f973b86990af76ea923d548f95a4f05b4cd32c18..8fcd80391f828c7503701a86e9e1b400115763d6 100644 --- a/crates/agent2/src/tools/copy_path_tool.rs +++ b/crates/agent2/src/tools/copy_path_tool.rs @@ -1,23 +1,18 @@ use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol::ToolKind; use anyhow::{Context as _, Result, 
anyhow}; -use gpui::{App, AppContext, Entity, SharedString, Task}; +use gpui::{App, AppContext, Entity, Task}; use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::sync::Arc; use util::markdown::MarkdownInlineCode; -/// Copies a file or directory in the project, and returns confirmation that the -/// copy succeeded. -/// +/// Copies a file or directory in the project, and returns confirmation that the copy succeeded. /// Directory contents will be copied recursively (like `cp -r`). /// -/// This tool should be used when it's desirable to create a copy of a file or -/// directory without modifying the original. It's much more efficient than -/// doing this by separately reading and then writing the file or directory's -/// contents, so this tool should be preferred over that approach whenever -/// copying is the goal. +/// This tool should be used when it's desirable to create a copy of a file or directory without modifying the original. +/// It's much more efficient than doing this by separately reading and then writing the file or directory's contents, so this tool should be preferred over that approach whenever copying is the goal. #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct CopyPathToolInput { /// The source path of the file or directory to copy. @@ -33,12 +28,10 @@ pub struct CopyPathToolInput { /// You can copy the first file by providing a source_path of "directory1/a/something.txt" /// pub source_path: String, - /// The destination path where the file or directory should be copied to. 
/// /// - /// To copy "directory1/a/something.txt" to "directory2/b/copy.txt", - /// provide a destination_path of "directory2/b/copy.txt" + /// To copy "directory1/a/something.txt" to "directory2/b/copy.txt", provide a destination_path of "directory2/b/copy.txt" /// pub destination_path: String, } @@ -57,15 +50,19 @@ impl AgentTool for CopyPathTool { type Input = CopyPathToolInput; type Output = String; - fn name(&self) -> SharedString { - "copy_path".into() + fn name() -> &'static str { + "copy_path" } - fn kind(&self) -> ToolKind { + fn kind() -> ToolKind { ToolKind::Move } - fn initial_title(&self, input: Result) -> ui::SharedString { + fn initial_title( + &self, + input: Result, + _cx: &mut App, + ) -> ui::SharedString { if let Ok(input) = input { let src = MarkdownInlineCode(&input.source_path); let dest = MarkdownInlineCode(&input.destination_path); diff --git a/crates/agent2/src/tools/create_directory_tool.rs b/crates/agent2/src/tools/create_directory_tool.rs index c173c5ae67512813b610552c2001dc16ceb38212..30bd6418db35182358ed6139a9078e40a29dfac5 100644 --- a/crates/agent2/src/tools/create_directory_tool.rs +++ b/crates/agent2/src/tools/create_directory_tool.rs @@ -9,12 +9,9 @@ use util::markdown::MarkdownInlineCode; use crate::{AgentTool, ToolCallEventStream}; -/// Creates a new directory at the specified path within the project. Returns -/// confirmation that the directory was created. +/// Creates a new directory at the specified path within the project. Returns confirmation that the directory was created. /// -/// This tool creates a directory and all necessary parent directories (similar -/// to `mkdir -p`). It should be used whenever you need to create new -/// directories within the project. +/// This tool creates a directory and all necessary parent directories (similar to `mkdir -p`). It should be used whenever you need to create new directories within the project. 
#[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct CreateDirectoryToolInput { /// The path of the new directory. @@ -44,15 +41,19 @@ impl AgentTool for CreateDirectoryTool { type Input = CreateDirectoryToolInput; type Output = String; - fn name(&self) -> SharedString { - "create_directory".into() + fn name() -> &'static str { + "create_directory" } - fn kind(&self) -> ToolKind { + fn kind() -> ToolKind { ToolKind::Read } - fn initial_title(&self, input: Result) -> SharedString { + fn initial_title( + &self, + input: Result, + _cx: &mut App, + ) -> SharedString { if let Ok(input) = input { format!("Create directory {}", MarkdownInlineCode(&input.path)).into() } else { diff --git a/crates/agent2/src/tools/delete_path_tool.rs b/crates/agent2/src/tools/delete_path_tool.rs index e013b3a3e755cf6662718d620264cb1e38fa5417..01a77f5d811127b3df470ec73fbc91ff7c26fd52 100644 --- a/crates/agent2/src/tools/delete_path_tool.rs +++ b/crates/agent2/src/tools/delete_path_tool.rs @@ -9,8 +9,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::sync::Arc; -/// Deletes the file or directory (and the directory's contents, recursively) at -/// the specified path in the project, and returns confirmation of the deletion. +/// Deletes the file or directory (and the directory's contents, recursively) at the specified path in the project, and returns confirmation of the deletion. #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct DeletePathToolInput { /// The path of the file or directory to delete. 
@@ -45,15 +44,19 @@ impl AgentTool for DeletePathTool { type Input = DeletePathToolInput; type Output = String; - fn name(&self) -> SharedString { - "delete_path".into() + fn name() -> &'static str { + "delete_path" } - fn kind(&self) -> ToolKind { + fn kind() -> ToolKind { ToolKind::Delete } - fn initial_title(&self, input: Result) -> SharedString { + fn initial_title( + &self, + input: Result, + _cx: &mut App, + ) -> SharedString { if let Ok(input) = input { format!("Delete “`{}`”", input.path).into() } else { diff --git a/crates/agent2/src/tools/diagnostics_tool.rs b/crates/agent2/src/tools/diagnostics_tool.rs index 6ba8b7b377a770fa3af35b725b4427e7102d70c1..a38e317d43cb16d8ee652f1a5f7aabd8b1ce4c8f 100644 --- a/crates/agent2/src/tools/diagnostics_tool.rs +++ b/crates/agent2/src/tools/diagnostics_tool.rs @@ -63,15 +63,19 @@ impl AgentTool for DiagnosticsTool { type Input = DiagnosticsToolInput; type Output = String; - fn name(&self) -> SharedString { - "diagnostics".into() + fn name() -> &'static str { + "diagnostics" } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Read } - fn initial_title(&self, input: Result) -> SharedString { + fn initial_title( + &self, + input: Result, + _cx: &mut App, + ) -> SharedString { if let Some(path) = input.ok().and_then(|input| match input.path { Some(path) if !path.is_empty() => Some(path), _ => None, diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index 405afb585f5e9acfabfd632f0374729aa7e5c5af..9237961bce513d740989c7e3076395ed68473859 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -5,10 +5,10 @@ use anyhow::{Context as _, Result, anyhow}; use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}; use cloud_llm_client::CompletionIntent; use collections::HashSet; -use gpui::{App, AppContext, AsyncApp, Entity, Task}; +use gpui::{App, AppContext, 
AsyncApp, Entity, Task, WeakEntity}; use indoc::formatdoc; -use language::ToPoint; use language::language_settings::{self, FormatOnSave}; +use language::{LanguageRegistry, ToPoint}; use language_model::LanguageModelToolResultContent; use paths; use project::lsp_store::{FormatTrigger, LspFormatTarget}; @@ -34,25 +34,21 @@ const DEFAULT_UI_TEXT: &str = "Editing file"; /// - Use the `list_directory` tool to verify the parent directory exists and is the correct location #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct EditFileToolInput { - /// A one-line, user-friendly markdown description of the edit. This will be - /// shown in the UI and also passed to another model to perform the edit. + /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit. /// - /// Be terse, but also descriptive in what you want to achieve with this - /// edit. Avoid generic instructions. + /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions. /// /// NEVER mention the file path in this description. /// /// Fix API endpoint URLs /// Update copyright year in `page_footer` /// - /// Make sure to include this field before all the others in the input object - /// so that we can display it immediately. + /// Make sure to include this field before all the others in the input object so that we can display it immediately. pub display_description: String, /// The full path of the file to create or modify in the project. /// - /// WARNING: When specifying which file path need changing, you MUST - /// start each path with one of the project's root directories. + /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories. 
/// /// The following examples assume we have two root directories in the project: /// - /a/b/backend @@ -61,22 +57,19 @@ pub struct EditFileToolInput { /// /// `backend/src/main.rs` /// - /// Notice how the file path starts with `backend`. Without that, the path - /// would be ambiguous and the call would fail! + /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail! /// /// /// /// `frontend/db.js` /// pub path: PathBuf, - /// The mode of operation on the file. Possible values: /// - 'edit': Make granular edits to an existing file. /// - 'create': Create a new file if it doesn't exist. /// - 'overwrite': Replace the entire contents of an existing file. /// - /// When a file already exists or you just created it, prefer editing - /// it as opposed to recreating it from scratch. + /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch. pub mode: EditFileMode, } @@ -90,6 +83,7 @@ struct EditFileToolPartialInput { #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "lowercase")] +#[schemars(inline)] pub enum EditFileMode { Edit, Create, @@ -98,11 +92,13 @@ pub enum EditFileMode { #[derive(Debug, Serialize, Deserialize)] pub struct EditFileToolOutput { + #[serde(alias = "original_path")] input_path: PathBuf, - project_path: PathBuf, new_text: String, old_text: Arc, + #[serde(default)] diff: String, + #[serde(alias = "raw_output")] edit_agent_output: EditAgentOutput, } @@ -122,12 +118,22 @@ impl From for LanguageModelToolResultContent { } pub struct EditFileTool { - thread: Entity, + thread: WeakEntity, + language_registry: Arc, + project: Entity, } impl EditFileTool { - pub fn new(thread: Entity) -> Self { - Self { thread } + pub fn new( + project: Entity, + thread: WeakEntity, + language_registry: Arc, + ) -> Self { + Self { + project, + thread, + language_registry, + } } fn authorize( @@ -156,19 +162,22 @@ impl 
EditFileTool { // It's also possible that the global config dir is configured to be inside the project, // so check for that edge case too. - if let Ok(canonical_path) = std::fs::canonicalize(&input.path) { - if canonical_path.starts_with(paths::config_dir()) { - return event_stream.authorize( - format!("{} (global settings)", input.display_description), - cx, - ); - } + if let Ok(canonical_path) = std::fs::canonicalize(&input.path) + && canonical_path.starts_with(paths::config_dir()) + { + return event_stream.authorize( + format!("{} (global settings)", input.display_description), + cx, + ); } // Check if path is inside the global config directory // First check if it's already inside project - if not, try to canonicalize - let thread = self.thread.read(cx); - let project_path = thread.project().read(cx).find_project_path(&input.path, cx); + let Ok(project_path) = self.thread.read_with(cx, |thread, cx| { + thread.project().read(cx).find_project_path(&input.path, cx) + }) else { + return Task::ready(Err(anyhow!("thread was dropped"))); + }; // If the path is inside the project, and it's not one of the above edge cases, // then no confirmation is necessary. Otherwise, confirmation is necessary. 
@@ -184,30 +193,58 @@ impl AgentTool for EditFileTool { type Input = EditFileToolInput; type Output = EditFileToolOutput; - fn name(&self) -> SharedString { - "edit_file".into() + fn name() -> &'static str { + "edit_file" } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Edit } - fn initial_title(&self, input: Result) -> SharedString { + fn initial_title( + &self, + input: Result, + cx: &mut App, + ) -> SharedString { match input { - Ok(input) => input.display_description.into(), + Ok(input) => self + .project + .read(cx) + .find_project_path(&input.path, cx) + .and_then(|project_path| { + self.project + .read(cx) + .short_full_path_for_project_path(&project_path, cx) + }) + .unwrap_or(Path::new(&input.path).into()) + .to_string_lossy() + .to_string() + .into(), Err(raw_input) => { if let Some(input) = serde_json::from_value::(raw_input).ok() { + let path = input.path.trim(); + if !path.is_empty() { + return self + .project + .read(cx) + .find_project_path(&input.path, cx) + .and_then(|project_path| { + self.project + .read(cx) + .short_full_path_for_project_path(&project_path, cx) + }) + .unwrap_or(Path::new(&input.path).into()) + .to_string_lossy() + .to_string() + .into(); + } + let description = input.display_description.trim(); if !description.is_empty() { return description.to_string().into(); } - - let path = input.path.trim().to_string(); - if !path.is_empty() { - return path.into(); - } } DEFAULT_UI_TEXT.into() @@ -221,7 +258,12 @@ impl AgentTool for EditFileTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let project = self.thread.read(cx).project().clone(); + let Ok(project) = self + .thread + .read_with(cx, |thread, _cx| thread.project().clone()) + else { + return Task::ready(Err(anyhow!("thread was dropped"))); + }; let project_path = match resolve_path(&input, project.clone(), cx) { Ok(path) => path, Err(err) => return Task::ready(Err(anyhow!(err))), @@ -237,17 +279,17 @@ impl AgentTool for 
EditFileTool { }); } - let request = self.thread.update(cx, |thread, cx| { - thread.build_completion_request(CompletionIntent::ToolResults, cx) - }); - let thread = self.thread.read(cx); - let model = thread.selected_model.clone(); - let action_log = thread.action_log().clone(); - let authorize = self.authorize(&input, &event_stream, cx); cx.spawn(async move |cx: &mut AsyncApp| { authorize.await?; + let (request, model, action_log) = self.thread.update(cx, |thread, cx| { + let request = thread.build_completion_request(CompletionIntent::ToolResults, cx); + (request, thread.model().cloned(), thread.action_log().clone()) + })?; + let request = request?; + let model = model.context("No language model configured")?; + let edit_format = EditFormat::from_model(model.clone())?; let edit_agent = EditAgent::new( model, @@ -266,6 +308,13 @@ impl AgentTool for EditFileTool { let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?; event_stream.update_diff(diff.clone()); + let _finalize_diff = util::defer({ + let diff = diff.downgrade(); + let mut cx = cx.clone(); + move || { + diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok(); + } + }); let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; let old_text = cx @@ -382,8 +431,6 @@ impl AgentTool for EditFileTool { }) .await; - diff.update(cx, |diff, cx| diff.finalize(cx)).ok(); - let input_path = input.path.display(); if unified_diff.is_empty() { anyhow::ensure!( @@ -413,14 +460,32 @@ impl AgentTool for EditFileTool { Ok(EditFileToolOutput { input_path: input.path, - project_path: project_path.path.to_path_buf(), - new_text: new_text.clone(), + new_text, old_text, diff: unified_diff, edit_agent_output, }) }) } + + fn replay( + &self, + _input: Self::Input, + output: Self::Output, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()> { + event_stream.update_diff(cx.new(|cx| { + Diff::finalized( + output.input_path, + Some(output.old_text.to_string()), + output.new_text, + 
self.language_registry.clone(), + cx, + ) + })); + Ok(()) + } } /// Validate that the file path is valid, meaning: @@ -465,7 +530,7 @@ fn resolve_path( let parent_entry = parent_project_path .as_ref() - .and_then(|path| project.entry_for_path(&path, cx)) + .and_then(|path| project.entry_for_path(path, cx)) .context("Can't create file: parent directory doesn't exist")?; anyhow::ensure!( @@ -492,14 +557,13 @@ fn resolve_path( mod tests { use super::*; use crate::{ContextServerRegistry, Templates}; - use action_log::ActionLog; use client::TelemetrySettings; use fs::Fs; use gpui::{TestAppContext, UpdateGlobal}; use language_model::fake_provider::FakeLanguageModel; + use prompt_store::ProjectContext; use serde_json::json; use settings::SettingsStore; - use std::rc::Rc; use util::path; #[gpui::test] @@ -509,18 +573,17 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( - project, - Rc::default(), + project.clone(), + cx.new(|_cx| ProjectContext::default()), context_server_registry, - action_log, Templates::new(), - model, + Some(model), cx, ) }); @@ -531,7 +594,12 @@ mod tests { path: "root/nonexistent_file.txt".into(), mode: EditFileMode::Edit, }; - Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new( + project, + thread.downgrade(), + language_registry, + )) + .run(input, ToolCallEventStream::test().0, cx) }) .await; assert_eq!( @@ -618,8 +686,7 @@ mod tests { mode: mode.clone(), }; - let result = cx.update(|cx| 
resolve_path(&input, project, cx)); - result + cx.update(|cx| resolve_path(&input, project, cx)) } fn assert_resolved_path_eq(path: anyhow::Result, expected: &str) { @@ -706,18 +773,16 @@ mod tests { } }); - let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( - project, - Rc::default(), + project.clone(), + cx.new(|_cx| ProjectContext::default()), context_server_registry, - action_log.clone(), Templates::new(), - model.clone(), + Some(model.clone()), cx, ) }); @@ -744,9 +809,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) + Arc::new(EditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry.clone(), + )) .run(input, ToolCallEventStream::test().0, cx) }); @@ -771,7 +838,9 @@ mod tests { "Code should be formatted when format_on_save is enabled" ); - let stale_buffer_count = action_log.read_with(cx, |log, cx| log.stale_buffers(cx).count()); + let stale_buffer_count = thread + .read_with(cx, |thread, _cx| thread.action_log.clone()) + .read_with(cx, |log, cx| log.stale_buffers(cx).count()); assert_eq!( stale_buffer_count, 0, @@ -800,7 +869,12 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )) + .run(input, ToolCallEventStream::test().0, cx) }); // Stream the unformatted content @@ -844,16 +918,15 @@ mod tests { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let action_log = cx.new(|_| 
ActionLog::new(project.clone())); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( - project, - Rc::default(), + project.clone(), + cx.new(|_cx| ProjectContext::default()), context_server_registry, - action_log.clone(), Templates::new(), - model.clone(), + Some(model.clone()), cx, ) }); @@ -881,9 +954,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) + Arc::new(EditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry.clone(), + )) .run(input, ToolCallEventStream::test().0, cx) }); @@ -932,9 +1007,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) + Arc::new(EditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )) .run(input, ToolCallEventStream::test().0, cx) }); @@ -970,20 +1047,23 @@ mod tests { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let action_log = cx.new(|_| ActionLog::new(project.clone())); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( - project, - Rc::default(), + project.clone(), + cx.new(|_cx| ProjectContext::default()), context_server_registry, - action_log.clone(), Templates::new(), - model.clone(), + Some(model.clone()), cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); fs.insert_tree("/root", json!({})).await; // Test 1: Path with .zed component should require confirmation @@ -1001,7 
+1081,10 @@ mod tests { }); let event = stream_rx.expect_authorization().await; - assert_eq!(event.tool_call.title, "test 1 (local settings)"); + assert_eq!( + event.tool_call.fields.title, + Some("test 1 (local settings)".into()) + ); // Test 2: Path outside project should require confirmation let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); @@ -1018,7 +1101,7 @@ mod tests { }); let event = stream_rx.expect_authorization().await; - assert_eq!(event.tool_call.title, "test 2"); + assert_eq!(event.tool_call.fields.title, Some("test 2".into())); // Test 3: Relative path without .zed should not require confirmation let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); @@ -1051,7 +1134,10 @@ mod tests { ) }); let event = stream_rx.expect_authorization().await; - assert_eq!(event.tool_call.title, "test 4 (local settings)"); + assert_eq!( + event.tool_call.fields.title, + Some("test 4 (local settings)".into()) + ); // Test 5: When always_allow_tool_actions is enabled, no confirmation needed cx.update(|cx| { @@ -1099,22 +1185,25 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/project", json!({})).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( - project, - Rc::default(), + project.clone(), + cx.new(|_cx| ProjectContext::default()), context_server_registry, - action_log.clone(), Templates::new(), - model.clone(), + Some(model.clone()), cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); // Test global config paths 
- these should require confirmation if they exist and are outside the project let test_cases = vec![ @@ -1208,23 +1297,25 @@ mod tests { cx, ) .await; - - let action_log = cx.new(|_| ActionLog::new(project.clone())); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( project.clone(), - Rc::default(), + cx.new(|_cx| ProjectContext::default()), context_server_registry.clone(), - action_log.clone(), Templates::new(), - model.clone(), + Some(model.clone()), cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); // Test files in different worktrees let test_cases = vec![ @@ -1290,22 +1381,25 @@ mod tests { ) .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( project.clone(), - Rc::default(), + cx.new(|_cx| ProjectContext::default()), context_server_registry.clone(), - action_log.clone(), Templates::new(), - model.clone(), + Some(model.clone()), cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); // Test edge cases let test_cases = vec![ @@ -1374,22 +1468,25 @@ mod tests { ) .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let action_log = cx.new(|_| 
ActionLog::new(project.clone())); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( project.clone(), - Rc::default(), + cx.new(|_cx| ProjectContext::default()), context_server_registry.clone(), - action_log.clone(), Templates::new(), - model.clone(), + Some(model.clone()), cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); // Test different EditFileMode values let modes = vec![ @@ -1455,63 +1552,187 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( project.clone(), - Rc::default(), + cx.new(|_cx| ProjectContext::default()), context_server_registry, - action_log.clone(), Templates::new(), - model.clone(), + Some(model.clone()), cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new( + project, + thread.downgrade(), + language_registry, + )); - assert_eq!( - tool.initial_title(Err(json!({ - "path": "src/main.rs", - "display_description": "", - "old_string": "old code", - "new_string": "new code" - }))), - "src/main.rs" - ); - assert_eq!( - tool.initial_title(Err(json!({ - "path": "", - "display_description": "Fix error handling", - "old_string": "old code", - "new_string": "new code" - }))), - "Fix 
error handling" - ); - assert_eq!( - tool.initial_title(Err(json!({ - "path": "src/main.rs", - "display_description": "Fix error handling", - "old_string": "old code", - "new_string": "new code" - }))), - "Fix error handling" - ); - assert_eq!( - tool.initial_title(Err(json!({ - "path": "", - "display_description": "", - "old_string": "old code", - "new_string": "new code" - }))), - DEFAULT_UI_TEXT - ); - assert_eq!( - tool.initial_title(Err(serde_json::Value::Null)), - DEFAULT_UI_TEXT - ); + cx.update(|cx| { + // ... + assert_eq!( + tool.initial_title( + Err(json!({ + "path": "src/main.rs", + "display_description": "", + "old_string": "old code", + "new_string": "new code" + })), + cx + ), + "src/main.rs" + ); + assert_eq!( + tool.initial_title( + Err(json!({ + "path": "", + "display_description": "Fix error handling", + "old_string": "old code", + "new_string": "new code" + })), + cx + ), + "Fix error handling" + ); + assert_eq!( + tool.initial_title( + Err(json!({ + "path": "src/main.rs", + "display_description": "Fix error handling", + "old_string": "old code", + "new_string": "new code" + })), + cx + ), + "src/main.rs" + ); + assert_eq!( + tool.initial_title( + Err(json!({ + "path": "", + "display_description": "", + "old_string": "old code", + "new_string": "new code" + })), + cx + ), + DEFAULT_UI_TEXT + ); + assert_eq!( + tool.initial_title(Err(serde_json::Value::Null), cx), + DEFAULT_UI_TEXT + ); + }); + } + + #[gpui::test] + async fn test_diff_finalization(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/", json!({"main.rs": ""})).await; + + let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await; + let languages = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| 
{ + Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry.clone(), + Templates::new(), + Some(model.clone()), + cx, + ) + }); + + // Ensure the diff is finalized after the edit completes. + { + let tool = Arc::new(EditFileTool::new( + project.clone(), + thread.downgrade(), + languages.clone(), + )); + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let edit = cx.update(|cx| { + tool.run( + EditFileToolInput { + display_description: "Edit file".into(), + path: path!("/main.rs").into(), + mode: EditFileMode::Edit, + }, + stream_tx, + cx, + ) + }); + stream_rx.expect_update_fields().await; + let diff = stream_rx.expect_diff().await; + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_)))); + cx.run_until_parked(); + model.end_last_completion_stream(); + edit.await.unwrap(); + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_)))); + } + + // Ensure the diff is finalized if an error occurs while editing. + { + model.forbid_requests(); + let tool = Arc::new(EditFileTool::new( + project.clone(), + thread.downgrade(), + languages.clone(), + )); + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let edit = cx.update(|cx| { + tool.run( + EditFileToolInput { + display_description: "Edit file".into(), + path: path!("/main.rs").into(), + mode: EditFileMode::Edit, + }, + stream_tx, + cx, + ) + }); + stream_rx.expect_update_fields().await; + let diff = stream_rx.expect_diff().await; + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_)))); + edit.await.unwrap_err(); + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_)))); + model.allow_requests(); + } + + // Ensure the diff is finalized if the tool call gets dropped. 
+ { + let tool = Arc::new(EditFileTool::new( + project.clone(), + thread.downgrade(), + languages.clone(), + )); + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let edit = cx.update(|cx| { + tool.run( + EditFileToolInput { + display_description: "Edit file".into(), + path: path!("/main.rs").into(), + mode: EditFileMode::Edit, + }, + stream_tx, + cx, + ) + }); + stream_rx.expect_update_fields().await; + let diff = stream_rx.expect_diff().await; + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_)))); + drop(edit); + cx.run_until_parked(); + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_)))); + } } fn init_test(cx: &mut TestAppContext) { diff --git a/crates/agent2/src/tools/fetch_tool.rs b/crates/agent2/src/tools/fetch_tool.rs index ae26c5fe195da3d73a8ae1da47d072a3bfc3706f..60654ac863acdc559aeaad90f1c73727f33d1b59 100644 --- a/crates/agent2/src/tools/fetch_tool.rs +++ b/crates/agent2/src/tools/fetch_tool.rs @@ -118,15 +118,19 @@ impl AgentTool for FetchTool { type Input = FetchToolInput; type Output = String; - fn name(&self) -> SharedString { - "fetch".into() + fn name() -> &'static str { + "fetch" } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Fetch } - fn initial_title(&self, input: Result) -> SharedString { + fn initial_title( + &self, + input: Result, + _cx: &mut App, + ) -> SharedString { match input { Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)).into(), Err(_) => "Fetch URL".into(), @@ -136,12 +140,17 @@ impl AgentTool for FetchTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { + let authorize = event_stream.authorize(input.url.clone(), cx); + let text = cx.background_spawn({ let http_client = self.http_client.clone(); - async move { Self::build_message(http_client, &input.url).await } + async move { + authorize.await?; + Self::build_message(http_client, 
&input.url).await + } }); cx.foreground_executor().spawn(async move { diff --git a/crates/agent2/src/tools/find_path_tool.rs b/crates/agent2/src/tools/find_path_tool.rs index 552de144a73365d10d4b9a565d852c1a13672be8..735ec67cffa31969e4eef741d6a23de05f3e15dc 100644 --- a/crates/agent2/src/tools/find_path_tool.rs +++ b/crates/agent2/src/tools/find_path_tool.rs @@ -31,7 +31,6 @@ pub struct FindPathToolInput { /// You can get back the first two paths by providing a glob of "*thing*.txt" /// pub glob: String, - /// Optional starting position for paginated results (0-based). /// When not provided, starts from the beginning. #[serde(default)] @@ -86,15 +85,19 @@ impl AgentTool for FindPathTool { type Input = FindPathToolInput; type Output = FindPathToolOutput; - fn name(&self) -> SharedString { - "find_path".into() + fn name() -> &'static str { + "find_path" } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Search } - fn initial_title(&self, input: Result) -> SharedString { + fn initial_title( + &self, + input: Result, + _cx: &mut App, + ) -> SharedString { let mut title = "Find paths".to_string(); if let Ok(input) = input { title.push_str(&format!(" matching “`{}`”", input.glob)); @@ -116,7 +119,7 @@ impl AgentTool for FindPathTool { ..cmp::min(input.offset + RESULTS_PER_PAGE, matches.len())]; event_stream.update_fields(acp::ToolCallUpdateFields { - title: Some(if paginated_matches.len() == 0 { + title: Some(if paginated_matches.is_empty() { "No matches".into() } else if paginated_matches.len() == 1 { "1 match".into() @@ -166,16 +169,17 @@ fn search_paths(glob: &str, project: Entity, cx: &mut App) -> Task SharedString { - "grep".into() + fn name() -> &'static str { + "grep" } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Search } - fn initial_title(&self, input: Result) -> SharedString { + fn initial_title( + &self, + input: Result, + _cx: &mut App, + ) -> SharedString { match input { Ok(input) => { 
let page = input.page(); @@ -179,15 +182,14 @@ impl AgentTool for GrepTool { // Check if this file should be excluded based on its worktree settings if let Ok(Some(project_path)) = project.read_with(cx, |project, cx| { project.find_project_path(&path, cx) - }) { - if cx.update(|cx| { + }) + && cx.update(|cx| { let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx); worktree_settings.is_path_excluded(&project_path.path) || worktree_settings.is_path_private(&project_path.path) }).unwrap_or(false) { continue; } - } while *parse_status.borrow() != ParseStatus::Idle { parse_status.changed().await?; @@ -259,10 +261,8 @@ impl AgentTool for GrepTool { let end_row = range.end.row; output.push_str("\n### "); - if let Some(parent_symbols) = &parent_symbols { - for symbol in parent_symbols { - write!(output, "{} › ", symbol.text)?; - } + for symbol in parent_symbols { + write!(output, "{} › ", symbol.text)?; } if range.start.row == end_row { @@ -275,12 +275,11 @@ impl AgentTool for GrepTool { output.extend(snapshot.text_for_range(range)); output.push_str("\n```\n"); - if let Some(ancestor_range) = ancestor_range { - if end_row < ancestor_range.end.row { + if let Some(ancestor_range) = ancestor_range + && end_row < ancestor_range.end.row { let remaining_lines = ancestor_range.end.row - end_row; writeln!(output, "\n{} lines remaining in ancestor node. 
Read the file to see all.", remaining_lines)?; } - } matches_found += 1; } @@ -320,7 +319,7 @@ mod tests { init_test(cx); cx.executor().allow_parking(); - let fs = FakeFs::new(cx.executor().clone()); + let fs = FakeFs::new(cx.executor()); fs.insert_tree( path!("/root"), serde_json::json!({ @@ -405,7 +404,7 @@ mod tests { init_test(cx); cx.executor().allow_parking(); - let fs = FakeFs::new(cx.executor().clone()); + let fs = FakeFs::new(cx.executor()); fs.insert_tree( path!("/root"), serde_json::json!({ @@ -480,7 +479,7 @@ mod tests { init_test(cx); cx.executor().allow_parking(); - let fs = FakeFs::new(cx.executor().clone()); + let fs = FakeFs::new(cx.executor()); // Create test file with syntax structures fs.insert_tree( @@ -765,7 +764,7 @@ mod tests { if cfg!(windows) { result.replace("root\\", "root/") } else { - result.to_string() + result } } Err(e) => panic!("Failed to run grep tool: {}", e), diff --git a/crates/agent2/src/tools/list_directory_tool.rs b/crates/agent2/src/tools/list_directory_tool.rs index 61f21d8f95117f0b0a8efccf7481874037af365c..0fbe23fe205e6a9bd5a77e737460c17b997f9175 100644 --- a/crates/agent2/src/tools/list_directory_tool.rs +++ b/crates/agent2/src/tools/list_directory_tool.rs @@ -10,14 +10,12 @@ use std::fmt::Write; use std::{path::Path, sync::Arc}; use util::markdown::MarkdownInlineCode; -/// Lists files and directories in a given path. Prefer the `grep` or -/// `find_path` tools when searching the codebase. +/// Lists files and directories in a given path. Prefer the `grep` or `find_path` tools when searching the codebase. #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct ListDirectoryToolInput { /// The fully-qualified path of the directory to list in the project. /// - /// This path should never be absolute, and the first component - /// of the path should always be a root directory in a project. + /// This path should never be absolute, and the first component of the path should always be a root directory in a project. 
/// /// /// If the project has the following root directories: @@ -53,15 +51,19 @@ impl AgentTool for ListDirectoryTool { type Input = ListDirectoryToolInput; type Output = String; - fn name(&self) -> SharedString { - "list_directory".into() + fn name() -> &'static str { + "list_directory" } - fn kind(&self) -> ToolKind { + fn kind() -> ToolKind { ToolKind::Read } - fn initial_title(&self, input: Result) -> SharedString { + fn initial_title( + &self, + input: Result, + _cx: &mut App, + ) -> SharedString { if let Ok(input) = input { let path = MarkdownInlineCode(&input.path); format!("List the {path} directory's contents").into() diff --git a/crates/agent2/src/tools/move_path_tool.rs b/crates/agent2/src/tools/move_path_tool.rs index f8d5d0d176e5cd53d1f563385797596048c9a87e..91880c1243e0aa48569ab8e6981ddd45b41ab411 100644 --- a/crates/agent2/src/tools/move_path_tool.rs +++ b/crates/agent2/src/tools/move_path_tool.rs @@ -8,14 +8,11 @@ use serde::{Deserialize, Serialize}; use std::{path::Path, sync::Arc}; use util::markdown::MarkdownInlineCode; -/// Moves or rename a file or directory in the project, and returns confirmation -/// that the move succeeded. +/// Moves or rename a file or directory in the project, and returns confirmation that the move succeeded. /// -/// If the source and destination directories are the same, but the filename is -/// different, this performs a rename. Otherwise, it performs a move. +/// If the source and destination directories are the same, but the filename is different, this performs a rename. Otherwise, it performs a move. /// -/// This tool should be used when it's desirable to move or rename a file or -/// directory without changing its contents at all. +/// This tool should be used when it's desirable to move or rename a file or directory without changing its contents at all. #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct MovePathToolInput { /// The source path of the file or directory to move/rename. 
@@ -55,15 +52,19 @@ impl AgentTool for MovePathTool { type Input = MovePathToolInput; type Output = String; - fn name(&self) -> SharedString { - "move_path".into() + fn name() -> &'static str { + "move_path" } - fn kind(&self) -> ToolKind { + fn kind() -> ToolKind { ToolKind::Move } - fn initial_title(&self, input: Result) -> SharedString { + fn initial_title( + &self, + input: Result, + _cx: &mut App, + ) -> SharedString { if let Ok(input) = input { let src = MarkdownInlineCode(&input.source_path); let dest = MarkdownInlineCode(&input.destination_path); diff --git a/crates/agent2/src/tools/now_tool.rs b/crates/agent2/src/tools/now_tool.rs index a72ede26fea1ee42eddb08e2b22b2b2b89c77075..3387c0a617017991f8b2590868864287f399ec28 100644 --- a/crates/agent2/src/tools/now_tool.rs +++ b/crates/agent2/src/tools/now_tool.rs @@ -11,6 +11,7 @@ use crate::{AgentTool, ToolCallEventStream}; #[derive(Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] +#[schemars(inline)] pub enum Timezone { /// Use UTC for the datetime. 
Utc, @@ -32,15 +33,19 @@ impl AgentTool for NowTool { type Input = NowToolInput; type Output = String; - fn name(&self) -> SharedString { - "now".into() + fn name() -> &'static str { + "now" } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Other } - fn initial_title(&self, _input: Result) -> SharedString { + fn initial_title( + &self, + _input: Result, + _cx: &mut App, + ) -> SharedString { "Get current time".into() } diff --git a/crates/agent2/src/tools/open_tool.rs b/crates/agent2/src/tools/open_tool.rs index 36420560c1832d40496a95c69505ab8eb9cbb2c6..595a9f380b752635f97ef5d1819a1140c1db8be0 100644 --- a/crates/agent2/src/tools/open_tool.rs +++ b/crates/agent2/src/tools/open_tool.rs @@ -8,19 +8,15 @@ use serde::{Deserialize, Serialize}; use std::{path::PathBuf, sync::Arc}; use util::markdown::MarkdownEscaped; -/// This tool opens a file or URL with the default application associated with -/// it on the user's operating system: +/// This tool opens a file or URL with the default application associated with it on the user's operating system: /// /// - On macOS, it's equivalent to the `open` command /// - On Windows, it's equivalent to `start` /// - On Linux, it uses something like `xdg-open`, `gio open`, `gnome-open`, `kde-open`, `wslview` as appropriate /// -/// For example, it can open a web browser with a URL, open a PDF file with the -/// default PDF viewer, etc. +/// For example, it can open a web browser with a URL, open a PDF file with the default PDF viewer, etc. /// -/// You MUST ONLY use this tool when the user has explicitly requested opening -/// something. You MUST NEVER assume that the user would like for you to use -/// this tool. +/// You MUST ONLY use this tool when the user has explicitly requested opening something. You MUST NEVER assume that the user would like for you to use this tool. 
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct OpenToolInput { /// The path or URL to open with the default application. @@ -41,15 +37,19 @@ impl AgentTool for OpenTool { type Input = OpenToolInput; type Output = String; - fn name(&self) -> SharedString { - "open".into() + fn name() -> &'static str { + "open" } - fn kind(&self) -> ToolKind { + fn kind() -> ToolKind { ToolKind::Execute } - fn initial_title(&self, input: Result) -> SharedString { + fn initial_title( + &self, + input: Result, + _cx: &mut App, + ) -> SharedString { if let Ok(input) = input { format!("Open `{}`", MarkdownEscaped(&input.path_or_url)).into() } else { @@ -65,7 +65,7 @@ impl AgentTool for OpenTool { ) -> Task> { // If path_or_url turns out to be a path in the project, make it absolute. let abs_path = to_absolute_path(&input.path_or_url, self.project.clone(), cx); - let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx); + let authorize = event_stream.authorize(self.initial_title(Ok(input.clone()), cx), cx); cx.background_spawn(async move { authorize.await?; diff --git a/crates/agent2/src/tools/read_file_tool.rs b/crates/agent2/src/tools/read_file_tool.rs index f21643cbbbffca7a489918c3466ccc369a17156c..99f145901c664624d66d7487cce579f55cff908a 100644 --- a/crates/agent2/src/tools/read_file_tool.rs +++ b/crates/agent2/src/tools/read_file_tool.rs @@ -11,6 +11,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; use std::sync::Arc; +use util::markdown::MarkdownCodeBlock; use crate::{AgentTool, ToolCallEventStream}; @@ -21,8 +22,7 @@ use crate::{AgentTool, ToolCallEventStream}; pub struct ReadFileToolInput { /// The relative path of the file to read. /// - /// This path should never be absolute, and the first component - /// of the path should always be a root directory in a project. + /// This path should never be absolute, and the first component of the path should always be a root directory in a project. 
/// /// /// If the project has the following root directories: @@ -34,11 +34,9 @@ pub struct ReadFileToolInput { /// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`. /// pub path: String, - /// Optional line number to start reading on (1-based index) #[serde(default)] pub start_line: Option, - /// Optional line number to end reading on (1-based index, inclusive) #[serde(default)] pub end_line: Option, @@ -62,31 +60,34 @@ impl AgentTool for ReadFileTool { type Input = ReadFileToolInput; type Output = LanguageModelToolResultContent; - fn name(&self) -> SharedString { - "read_file".into() + fn name() -> &'static str { + "read_file" } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Read } - fn initial_title(&self, input: Result) -> SharedString { - if let Ok(input) = input { - let path = &input.path; + fn initial_title( + &self, + input: Result, + cx: &mut App, + ) -> SharedString { + if let Ok(input) = input + && let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) + && let Some(path) = self + .project + .read(cx) + .short_full_path_for_project_path(&project_path, cx) + { match (input.start_line, input.end_line) { (Some(start), Some(end)) => { - format!( - "[Read file `{}` (lines {}-{})](@selection:{}:({}-{}))", - path, start, end, path, start, end - ) + format!("Read file `{}` (lines {}-{})", path.display(), start, end,) } (Some(start), None) => { - format!( - "[Read file `{}` (from line {})](@selection:{}:({}-{}))", - path, start, path, start, start - ) + format!("Read file `{}` (from line {})", path.display(), start) } - _ => format!("[Read file `{}`](@file:{})", path, path), + _ => format!("Read file `{}`", path.display()), } .into() } else { @@ -103,6 +104,12 @@ impl AgentTool for ReadFileTool { let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else { return Task::ready(Err(anyhow!("Path {} not found in project", 
&input.path))); }; + let Some(abs_path) = self.project.read(cx).absolute_path(&project_path, cx) else { + return Task::ready(Err(anyhow!( + "Failed to convert {} to absolute path", + &input.path + ))); + }; // Error out if this path is either excluded or private in global settings let global_settings = WorktreeSettings::get_global(cx); @@ -138,6 +145,14 @@ impl AgentTool for ReadFileTool { let file_path = input.path.clone(); + event_stream.update_fields(ToolCallUpdateFields { + locations: Some(vec![acp::ToolCallLocation { + path: abs_path, + line: input.start_line.map(|line| line.saturating_sub(1)), + }]), + ..Default::default() + }); + if image_store::is_image_file(&self.project, &project_path, cx) { return cx.spawn(async move |cx| { let image_entity: Entity = cx @@ -175,7 +190,7 @@ impl AgentTool for ReadFileTool { buffer .file() .as_ref() - .map_or(true, |file| !file.disk_state().exists()) + .is_none_or(|file| !file.disk_state().exists()) })? { anyhow::bail!("{file_path} not found"); } @@ -246,21 +261,25 @@ impl AgentTool for ReadFileTool { }; project.update(cx, |project, cx| { - if let Some(abs_path) = project.absolute_path(&project_path, cx) { - project.set_agent_location( - Some(AgentLocation { - buffer: buffer.downgrade(), - position: anchor.unwrap_or(text::Anchor::MIN), - }), - cx, - ); + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position: anchor.unwrap_or(text::Anchor::MIN), + }), + cx, + ); + if let Ok(LanguageModelToolResultContent::Text(text)) = &result { + let markdown = MarkdownCodeBlock { + tag: &input.path, + text, + } + .to_string(); event_stream.update_fields(ToolCallUpdateFields { - locations: Some(vec![acp::ToolCallLocation { - path: abs_path, - line: input.start_line.map(|line| line.saturating_sub(1)), + content: Some(vec![acp::ToolCallContent::Content { + content: markdown.into(), }]), ..Default::default() - }); + }) } })?; diff --git a/crates/agent2/src/tools/terminal_tool.rs 
b/crates/agent2/src/tools/terminal_tool.rs index ecb855ac34d655caefab5ed0bd4f33d60be547f8..7acfc2455093eac0f3d15e840abce47f38a6c8b0 100644 --- a/crates/agent2/src/tools/terminal_tool.rs +++ b/crates/agent2/src/tools/terminal_tool.rs @@ -1,19 +1,19 @@ use agent_client_protocol as acp; use anyhow::Result; -use futures::{FutureExt as _, future::Shared}; -use gpui::{App, AppContext, Entity, SharedString, Task}; -use project::{Project, terminals::TerminalKind}; +use gpui::{App, Entity, SharedString, Task}; +use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::{ path::{Path, PathBuf}, + rc::Rc, sync::Arc, }; -use util::{ResultExt, get_system_shell, markdown::MarkdownInlineCode}; +use util::markdown::MarkdownInlineCode; -use crate::{AgentTool, ToolCallEventStream}; +use crate::{AgentTool, ThreadEnvironment, ToolCallEventStream}; -const COMMAND_OUTPUT_LIMIT: usize = 16 * 1024; +const COMMAND_OUTPUT_LIMIT: u64 = 16 * 1024; /// Executes a shell one-liner and returns the combined output. 
/// @@ -36,28 +36,14 @@ pub struct TerminalToolInput { pub struct TerminalTool { project: Entity, - determine_shell: Shared>, + environment: Rc, } impl TerminalTool { - pub fn new(project: Entity, cx: &mut App) -> Self { - let determine_shell = cx.background_spawn(async move { - if cfg!(windows) { - return get_system_shell(); - } - - if which::which("bash").is_ok() { - log::info!("agent selected bash for terminal tool"); - "bash".into() - } else { - let shell = get_system_shell(); - log::info!("agent selected {shell} for terminal tool"); - shell - } - }); + pub fn new(project: Entity, environment: Rc) -> Self { Self { project, - determine_shell: determine_shell.shared(), + environment, } } } @@ -66,21 +52,25 @@ impl AgentTool for TerminalTool { type Input = TerminalToolInput; type Output = String; - fn name(&self) -> SharedString { - "terminal".into() + fn name() -> &'static str { + "terminal" } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Execute } - fn initial_title(&self, input: Result) -> SharedString { + fn initial_title( + &self, + input: Result, + _cx: &mut App, + ) -> SharedString { if let Ok(input) = input { let mut lines = input.command.lines(); let first_line = lines.next().unwrap_or_default(); let remaining_line_count = lines.count(); match remaining_line_count { - 0 => MarkdownInlineCode(&first_line).to_string().into(), + 0 => MarkdownInlineCode(first_line).to_string().into(), 1 => MarkdownInlineCode(&format!( "{} - {} more line", first_line, remaining_line_count @@ -102,128 +92,49 @@ impl AgentTool for TerminalTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let language_registry = self.project.read(cx).languages().clone(); let working_dir = match working_dir(&input, &self.project, cx) { Ok(dir) => dir, Err(err) => return Task::ready(Err(err)), }; - let program = self.determine_shell.clone(); - let command = if cfg!(windows) { - format!("$null | & {{{}}}", input.command.replace("\"", "'")) - } 
else if let Some(cwd) = working_dir - .as_ref() - .and_then(|cwd| cwd.as_os_str().to_str()) - { - // Make sure once we're *inside* the shell, we cd into `cwd` - format!("(cd {cwd}; {}) self.project.update(cx, |project, cx| { - project.directory_environment(dir.as_path().into(), cx) - }), - None => Task::ready(None).shared(), - }; - let env = cx.spawn(async move |_| { - let mut env = env.await.unwrap_or_default(); - if cfg!(unix) { - env.insert("PAGER".into(), "cat".into()); - } - env - }); - - let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx); - - cx.spawn({ - async move |cx| { - authorize.await?; - - let program = program.await; - let env = env.await; - let terminal = self - .project - .update(cx, |project, cx| { - project.create_terminal( - TerminalKind::Task(task::SpawnInTerminal { - command: Some(program), - args, - cwd: working_dir.clone(), - env, - ..Default::default() - }), - cx, - ) - })? - .await?; - let acp_terminal = cx.new(|cx| { - acp_thread::Terminal::new( - input.command.clone(), - working_dir.clone(), - terminal.clone(), - language_registry, - cx, - ) - })?; - event_stream.update_terminal(acp_terminal.clone()); - - let exit_status = terminal - .update(cx, |terminal, cx| terminal.wait_for_completed_task(cx))? 
- .await; - let (content, content_line_count) = terminal.read_with(cx, |terminal, _| { - (terminal.get_content(), terminal.total_lines()) - })?; + let authorize = event_stream.authorize(self.initial_title(Ok(input.clone()), cx), cx); + cx.spawn(async move |cx| { + authorize.await?; + + let terminal = self + .environment + .create_terminal( + input.command.clone(), + working_dir, + Some(COMMAND_OUTPUT_LIMIT), + cx, + ) + .await?; - let (processed_content, finished_with_empty_output) = process_content( - &content, - &input.command, - exit_status.map(portable_pty::ExitStatus::from), - ); + let terminal_id = terminal.id(cx)?; + event_stream.update_fields(acp::ToolCallUpdateFields { + content: Some(vec![acp::ToolCallContent::Terminal { terminal_id }]), + ..Default::default() + }); - acp_terminal - .update(cx, |terminal, cx| { - terminal.finish( - exit_status, - content.len(), - processed_content.len(), - content_line_count, - finished_with_empty_output, - cx, - ); - }) - .log_err(); + let exit_status = terminal.wait_for_exit(cx)?.await; + let output = terminal.current_output(cx)?; - Ok(processed_content) - } + Ok(process_content(output, &input.command, exit_status)) }) } } fn process_content( - content: &str, + output: acp::TerminalOutputResponse, command: &str, - exit_status: Option, -) -> (String, bool) { - let should_truncate = content.len() > COMMAND_OUTPUT_LIMIT; - - let content = if should_truncate { - let mut end_ix = COMMAND_OUTPUT_LIMIT.min(content.len()); - while !content.is_char_boundary(end_ix) { - end_ix -= 1; - } - // Don't truncate mid-line, clear the remainder of the last line - end_ix = content[..end_ix].rfind('\n').unwrap_or(end_ix); - &content[..end_ix] - } else { - content - }; - let content = content.trim(); + exit_status: acp::TerminalExitStatus, +) -> String { + let content = output.output.trim(); let is_empty = content.is_empty(); + let content = format!("```\n{content}\n```"); - let content = if should_truncate { + let content = if 
output.truncated { format!( "Command output too long. The first {} bytes:\n\n{content}", content.len(), @@ -232,24 +143,21 @@ fn process_content( content }; - let content = match exit_status { - Some(exit_status) if exit_status.success() => { + let content = match exit_status.exit_code { + Some(0) => { if is_empty { "Command executed successfully.".to_string() } else { - content.to_string() + content } } - Some(exit_status) => { + Some(exit_code) => { if is_empty { - format!( - "Command \"{command}\" failed with exit code {}.", - exit_status.exit_code() - ) + format!("Command \"{command}\" failed with exit code {}.", exit_code) } else { format!( "Command \"{command}\" failed with exit code {}.\n\n{content}", - exit_status.exit_code() + exit_code ) } } @@ -260,7 +168,7 @@ fn process_content( ) } }; - (content, is_empty) + content } fn working_dir( @@ -271,7 +179,7 @@ fn working_dir( let project = project.read(cx); let cd = &input.cd; - if cd == "." || cd == "" { + if cd == "." || cd.is_empty() { // Accept "." or "" as meaning "the one worktree" if we only have one worktree. 
let mut worktrees = project.worktrees(cx); @@ -296,178 +204,10 @@ fn working_dir( { return Ok(Some(input_path.into())); } - } else { - if let Some(worktree) = project.worktree_for_root_name(cd, cx) { - return Ok(Some(worktree.read(cx).abs_path().to_path_buf())); - } + } else if let Some(worktree) = project.worktree_for_root_name(cd, cx) { + return Ok(Some(worktree.read(cx).abs_path().to_path_buf())); } anyhow::bail!("`cd` directory {cd:?} was not in any of the project's worktrees."); } } - -#[cfg(test)] -mod tests { - use agent_settings::AgentSettings; - use editor::EditorSettings; - use fs::RealFs; - use gpui::{BackgroundExecutor, TestAppContext}; - use pretty_assertions::assert_eq; - use serde_json::json; - use settings::{Settings, SettingsStore}; - use terminal::terminal_settings::TerminalSettings; - use theme::ThemeSettings; - use util::test::TempTree; - - use crate::AgentResponseEvent; - - use super::*; - - fn init_test(executor: &BackgroundExecutor, cx: &mut TestAppContext) { - zlog::init_test(); - - executor.allow_parking(); - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - ThemeSettings::register(cx); - TerminalSettings::register(cx); - EditorSettings::register(cx); - AgentSettings::register(cx); - }); - } - - #[gpui::test] - async fn test_interactive_command(executor: BackgroundExecutor, cx: &mut TestAppContext) { - if cfg!(windows) { - return; - } - - init_test(&executor, cx); - - let fs = Arc::new(RealFs::new(None, executor)); - let tree = TempTree::new(json!({ - "project": {}, - })); - let project: Entity = - Project::test(fs, [tree.path().join("project").as_path()], cx).await; - - let input = TerminalToolInput { - command: "cat".to_owned(), - cd: tree - .path() - .join("project") - .as_path() - .to_string_lossy() - .to_string(), - }; - let (event_stream_tx, mut event_stream_rx) = ToolCallEventStream::test(); - let result = cx - .update(|cx| 
Arc::new(TerminalTool::new(project, cx)).run(input, event_stream_tx, cx)); - - let auth = event_stream_rx.expect_authorization().await; - auth.response.send(auth.options[0].id.clone()).unwrap(); - event_stream_rx.expect_terminal().await; - assert_eq!(result.await.unwrap(), "Command executed successfully."); - } - - #[gpui::test] - async fn test_working_directory(executor: BackgroundExecutor, cx: &mut TestAppContext) { - if cfg!(windows) { - return; - } - - init_test(&executor, cx); - - let fs = Arc::new(RealFs::new(None, executor)); - let tree = TempTree::new(json!({ - "project": {}, - "other-project": {}, - })); - let project: Entity = - Project::test(fs, [tree.path().join("project").as_path()], cx).await; - - let check = |input, expected, cx: &mut TestAppContext| { - let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - let result = cx.update(|cx| { - Arc::new(TerminalTool::new(project.clone(), cx)).run(input, stream_tx, cx) - }); - cx.run_until_parked(); - let event = stream_rx.try_next(); - if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event { - auth.response.send(auth.options[0].id.clone()).unwrap(); - } - - cx.spawn(async move |_| { - let output = result.await; - assert_eq!(output.ok(), expected); - }) - }; - - check( - TerminalToolInput { - command: "pwd".into(), - cd: ".".into(), - }, - Some(format!( - "```\n{}\n```", - tree.path().join("project").display() - )), - cx, - ) - .await; - - check( - TerminalToolInput { - command: "pwd".into(), - cd: "other-project".into(), - }, - None, // other-project is a dir, but *not* a worktree (yet) - cx, - ) - .await; - - // Absolute path above the worktree root - check( - TerminalToolInput { - command: "pwd".into(), - cd: tree.path().to_string_lossy().into(), - }, - None, - cx, - ) - .await; - - project - .update(cx, |project, cx| { - project.create_worktree(tree.path().join("other-project"), true, cx) - }) - .await - .unwrap(); - - check( - TerminalToolInput { - command: "pwd".into(), 
- cd: "other-project".into(), - }, - Some(format!( - "```\n{}\n```", - tree.path().join("other-project").display() - )), - cx, - ) - .await; - - check( - TerminalToolInput { - command: "pwd".into(), - cd: ".".into(), - }, - None, - cx, - ) - .await; - } -} diff --git a/crates/agent2/src/tools/thinking_tool.rs b/crates/agent2/src/tools/thinking_tool.rs index 43647bb468d808b978a1b5176539a3167c5065f6..0a68f7545f81ce3202c110b1435d33b57adf409c 100644 --- a/crates/agent2/src/tools/thinking_tool.rs +++ b/crates/agent2/src/tools/thinking_tool.rs @@ -11,8 +11,7 @@ use crate::{AgentTool, ToolCallEventStream}; /// Use this tool when you need to work through complex problems, develop strategies, or outline approaches before taking action. #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct ThinkingToolInput { - /// Content to think about. This should be a description of what to think about or - /// a problem to solve. + /// Content to think about. This should be a description of what to think about or a problem to solve. content: String, } @@ -22,15 +21,19 @@ impl AgentTool for ThinkingTool { type Input = ThinkingToolInput; type Output = String; - fn name(&self) -> SharedString { - "thinking".into() + fn name() -> &'static str { + "thinking" } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Think } - fn initial_title(&self, _input: Result) -> SharedString { + fn initial_title( + &self, + _input: Result, + _cx: &mut App, + ) -> SharedString { "Thinking".into() } diff --git a/crates/agent2/src/tools/web_search_tool.rs b/crates/agent2/src/tools/web_search_tool.rs index c1c09707426431bf8a3ad4c59a012a567366d392..ce26bccddeeb998abf6d39cbe2acfe91cecc6d1b 100644 --- a/crates/agent2/src/tools/web_search_tool.rs +++ b/crates/agent2/src/tools/web_search_tool.rs @@ -14,7 +14,7 @@ use ui::prelude::*; use web_search::WebSearchRegistry; /// Search the web for information using your query. 
-/// Use this when you need real-time information, facts, or data that might not be in your training. \ +/// Use this when you need real-time information, facts, or data that might not be in your training. /// Results will include snippets and links from relevant web pages. #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct WebSearchToolInput { @@ -40,15 +40,19 @@ impl AgentTool for WebSearchTool { type Input = WebSearchToolInput; type Output = WebSearchToolOutput; - fn name(&self) -> SharedString { - "web_search".into() + fn name() -> &'static str { + "web_search" } - fn kind(&self) -> acp::ToolKind { + fn kind() -> acp::ToolKind { acp::ToolKind::Fetch } - fn initial_title(&self, _input: Result) -> SharedString { + fn initial_title( + &self, + _input: Result, + _cx: &mut App, + ) -> SharedString { "Searching the Web".into() } @@ -80,33 +84,48 @@ impl AgentTool for WebSearchTool { } }; - let result_text = if response.results.len() == 1 { - "1 result".to_string() - } else { - format!("{} results", response.results.len()) - }; - event_stream.update_fields(acp::ToolCallUpdateFields { - title: Some(format!("Searched the web: {result_text}")), - content: Some( - response - .results - .iter() - .map(|result| acp::ToolCallContent::Content { - content: acp::ContentBlock::ResourceLink(acp::ResourceLink { - name: result.title.clone(), - uri: result.url.clone(), - title: Some(result.title.clone()), - description: Some(result.text.clone()), - mime_type: None, - annotations: None, - size: None, - }), - }) - .collect(), - ), - ..Default::default() - }); + emit_update(&response, &event_stream); Ok(WebSearchToolOutput(response)) }) } + + fn replay( + &self, + _input: Self::Input, + output: Self::Output, + event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + emit_update(&output.0, &event_stream); + Ok(()) + } +} + +fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) { + let result_text = if response.results.len() == 1 { 
+ "1 result".to_string() + } else { + format!("{} results", response.results.len()) + }; + event_stream.update_fields(acp::ToolCallUpdateFields { + title: Some(format!("Searched the web: {result_text}")), + content: Some( + response + .results + .iter() + .map(|result| acp::ToolCallContent::Content { + content: acp::ContentBlock::ResourceLink(acp::ResourceLink { + name: result.title.clone(), + uri: result.url.clone(), + title: Some(result.title.clone()), + description: Some(result.text.clone()), + mime_type: None, + annotations: None, + size: None, + }), + }) + .collect(), + ), + ..Default::default() + }); } diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 81c97c8aa6cc4fa64d017b97ade5ddd535487b81..bb3fe6ff9078535b500e28f4beeab957929546a5 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -6,7 +6,7 @@ publish.workspace = true license = "GPL-3.0-or-later" [features] -test-support = ["acp_thread/test-support", "gpui/test-support", "project/test-support"] +test-support = ["acp_thread/test-support", "gpui/test-support", "project/test-support", "dep:env_logger", "client/test-support", "dep:gpui_tokio", "reqwest_client/test-support"] e2e = [] [lints] @@ -17,33 +17,36 @@ path = "src/agent_servers.rs" doctest = false [dependencies] +acp_tools.workspace = true acp_thread.workspace = true +action_log.workspace = true agent-client-protocol.workspace = true -agentic-coding-protocol.workspace = true +agent_settings.workspace = true anyhow.workspace = true +client.workspace = true collections.workspace = true -context_server.workspace = true +env_logger = { workspace = true, optional = true } +fs.workspace = true futures.workspace = true gpui.workspace = true +gpui_tokio = { workspace = true, optional = true } indoc.workspace = true -itertools.workspace = true +language.workspace = true +language_model.workspace = true +language_models.workspace = true log.workspace = true -paths.workspace = true 
project.workspace = true -rand.workspace = true -schemars.workspace = true +reqwest_client = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true settings.workspace = true smol.workspace = true -strum.workspace = true +task.workspace = true tempfile.workspace = true thiserror.workspace = true ui.workspace = true util.workspace = true -uuid.workspace = true watch.workspace = true -which.workspace = true workspace-hack.workspace = true [target.'cfg(unix)'.dependencies] @@ -51,8 +54,12 @@ libc.workspace = true nix.workspace = true [dev-dependencies] +client = { workspace = true, features = ["test-support"] } env_logger.workspace = true +fs.workspace = true language.workspace = true indoc.workspace = true acp_thread = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } +gpui_tokio.workspace = true +reqwest_client = { workspace = true, features = ["test-support"] } diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index 00e3e3df5093c6f1acef32665ab0d3d8846fc39f..97bc172a41157a6e9434b34bb20b7225e9eb0821 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -1,34 +1,670 @@ -use std::{path::Path, rc::Rc}; - -use crate::AgentServerCommand; use acp_thread::AgentConnection; -use anyhow::Result; -use gpui::AsyncApp; +use acp_tools::AcpConnectionRegistry; +use action_log::ActionLog; +use agent_client_protocol::{self as acp, Agent as _, ErrorCode}; +use anyhow::anyhow; +use collections::HashMap; +use futures::AsyncBufReadExt as _; +use futures::io::BufReader; +use project::Project; +use project::agent_server_store::AgentServerCommand; +use serde::Deserialize; +use util::ResultExt as _; + +use std::path::PathBuf; +use std::{any::Any, cell::RefCell}; +use std::{path::Path, rc::Rc}; use thiserror::Error; -mod v0; -mod v1; +use anyhow::{Context as _, Result}; +use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, 
WeakEntity}; + +use acp_thread::{AcpThread, AuthRequired, LoadError}; #[derive(Debug, Error)] #[error("Unsupported version")] pub struct UnsupportedVersion; +pub struct AcpConnection { + server_name: SharedString, + connection: Rc, + sessions: Rc>>, + auth_methods: Vec, + agent_capabilities: acp::AgentCapabilities, + default_mode: Option, + root_dir: PathBuf, + // NB: Don't move this into the wait_task, since we need to ensure the process is + // killed on drop (setting kill_on_drop on the command seems to not always work). + child: smol::process::Child, + _io_task: Task>, + _wait_task: Task>, + _stderr_task: Task>, +} + +pub struct AcpSession { + thread: WeakEntity, + suppress_abort_err: bool, + session_modes: Option>>, +} + pub async fn connect( - server_name: &'static str, + server_name: SharedString, command: AgentServerCommand, root_dir: &Path, + default_mode: Option, + is_remote: bool, cx: &mut AsyncApp, ) -> Result> { - let conn = v1::AcpConnection::stdio(server_name, command.clone(), &root_dir, cx).await; - - match conn { - Ok(conn) => Ok(Rc::new(conn) as _), - Err(err) if err.is::() => { - // Consider re-using initialize response and subprocess when adding another version here - let conn: Rc = - Rc::new(v0::AcpConnection::stdio(server_name, command, &root_dir, cx).await?); - Ok(conn) + let conn = AcpConnection::stdio( + server_name, + command.clone(), + root_dir, + default_mode, + is_remote, + cx, + ) + .await?; + Ok(Rc::new(conn) as _) +} + +const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1; + +impl AcpConnection { + pub async fn stdio( + server_name: SharedString, + command: AgentServerCommand, + root_dir: &Path, + default_mode: Option, + is_remote: bool, + cx: &mut AsyncApp, + ) -> Result { + let mut child = util::command::new_smol_command(command.path); + child + .args(command.args.iter().map(|arg| arg.as_str())) + .envs(command.env.iter().flatten()) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + 
.stderr(std::process::Stdio::piped()); + if !is_remote { + child.current_dir(root_dir); + } + let mut child = child.spawn()?; + + let stdout = child.stdout.take().context("Failed to take stdout")?; + let stdin = child.stdin.take().context("Failed to take stdin")?; + let stderr = child.stderr.take().context("Failed to take stderr")?; + log::trace!("Spawned (pid: {})", child.id()); + + let sessions = Rc::new(RefCell::new(HashMap::default())); + + let client = ClientDelegate { + sessions: sessions.clone(), + cx: cx.clone(), + }; + let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, { + let foreground_executor = cx.foreground_executor().clone(); + move |fut| { + foreground_executor.spawn(fut).detach(); + } + }); + + let io_task = cx.background_spawn(io_task); + + let stderr_task = cx.background_spawn(async move { + let mut stderr = BufReader::new(stderr); + let mut line = String::new(); + while let Ok(n) = stderr.read_line(&mut line).await + && n > 0 + { + log::warn!("agent stderr: {}", &line); + line.clear(); + } + Ok(()) + }); + + let wait_task = cx.spawn({ + let sessions = sessions.clone(); + let status_fut = child.status(); + async move |cx| { + let status = status_fut.await?; + + for session in sessions.borrow().values() { + session + .thread + .update(cx, |thread, cx| { + thread.emit_load_error(LoadError::Exited { status }, cx) + }) + .ok(); + } + + anyhow::Ok(()) + } + }); + + let connection = Rc::new(connection); + + cx.update(|cx| { + AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| { + registry.set_active_connection(server_name.clone(), &connection, cx) + }); + })?; + + let response = connection + .initialize(acp::InitializeRequest { + protocol_version: acp::VERSION, + client_capabilities: acp::ClientCapabilities { + fs: acp::FileSystemCapability { + read_text_file: true, + write_text_file: true, + }, + terminal: true, + }, + }) + .await?; + + if response.protocol_version < MINIMUM_SUPPORTED_VERSION { + return 
Err(UnsupportedVersion.into()); } - Err(err) => Err(err), + + Ok(Self { + auth_methods: response.auth_methods, + root_dir: root_dir.to_owned(), + connection, + server_name, + sessions, + agent_capabilities: response.agent_capabilities, + default_mode, + _io_task: io_task, + _wait_task: wait_task, + _stderr_task: stderr_task, + child, + }) + } + + pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities { + &self.agent_capabilities.prompt_capabilities + } + + pub fn root_dir(&self) -> &Path { + &self.root_dir + } +} + +impl Drop for AcpConnection { + fn drop(&mut self) { + // See the comment on the child field. + self.child.kill().log_err(); + } +} + +impl AgentConnection for AcpConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut App, + ) -> Task>> { + let name = self.server_name.clone(); + let conn = self.connection.clone(); + let sessions = self.sessions.clone(); + let default_mode = self.default_mode.clone(); + let cwd = cwd.to_path_buf(); + let context_server_store = project.read(cx).context_server_store().read(cx); + let mcp_servers = if project.read(cx).is_local() { + context_server_store + .configured_server_ids() + .iter() + .filter_map(|id| { + let configuration = context_server_store.configuration_for_server(id)?; + let command = configuration.command(); + Some(acp::McpServer::Stdio { + name: id.0.to_string(), + command: command.path.clone(), + args: command.args.clone(), + env: if let Some(env) = command.env.as_ref() { + env.iter() + .map(|(name, value)| acp::EnvVariable { + name: name.clone(), + value: value.clone(), + }) + .collect() + } else { + vec![] + }, + }) + }) + .collect() + } else { + // In SSH projects, the external agent is running on the remote + // machine, and currently we only run MCP servers on the local + // machine. So don't pass any MCP servers to the agent in that case. 
+ Vec::new() + }; + + cx.spawn(async move |cx| { + let response = conn + .new_session(acp::NewSessionRequest { mcp_servers, cwd }) + .await + .map_err(|err| { + if err.code == acp::ErrorCode::AUTH_REQUIRED.code { + let mut error = AuthRequired::new(); + + if err.message != acp::ErrorCode::AUTH_REQUIRED.message { + error = error.with_description(err.message); + } + + anyhow!(error) + } else { + anyhow!(err) + } + })?; + + let modes = response.modes.map(|modes| Rc::new(RefCell::new(modes))); + + if let Some(default_mode) = default_mode { + if let Some(modes) = modes.as_ref() { + let mut modes_ref = modes.borrow_mut(); + let has_mode = modes_ref.available_modes.iter().any(|mode| mode.id == default_mode); + + if has_mode { + let initial_mode_id = modes_ref.current_mode_id.clone(); + + cx.spawn({ + let default_mode = default_mode.clone(); + let session_id = response.session_id.clone(); + let modes = modes.clone(); + async move |_| { + let result = conn.set_session_mode(acp::SetSessionModeRequest { + session_id, + mode_id: default_mode, + }) + .await.log_err(); + + if result.is_none() { + modes.borrow_mut().current_mode_id = initial_mode_id; + } + } + }).detach(); + + modes_ref.current_mode_id = default_mode; + } else { + let available_modes = modes_ref + .available_modes + .iter() + .map(|mode| format!("- `{}`: {}", mode.id, mode.name)) + .collect::>() + .join("\n"); + + log::warn!( + "`{default_mode}` is not valid {name} mode. Available options:\n{available_modes}", + ); + } + } else { + log::warn!( + "`{name}` does not support modes, but `default_mode` was set in settings.", + ); + } + } + + let session_id = response.session_id; + let action_log = cx.new(|_| ActionLog::new(project.clone()))?; + let thread = cx.new(|cx| { + AcpThread::new( + self.server_name.clone(), + self.clone(), + project, + action_log, + session_id.clone(), + // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically. 
+ watch::Receiver::constant(self.agent_capabilities.prompt_capabilities), + cx, + ) + })?; + + let session = AcpSession { + thread: thread.downgrade(), + suppress_abort_err: false, + session_modes: modes + }; + sessions.borrow_mut().insert(session_id, session); + + Ok(thread) + }) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &self.auth_methods + } + + fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { + let conn = self.connection.clone(); + cx.foreground_executor().spawn(async move { + let result = conn + .authenticate(acp::AuthenticateRequest { + method_id: method_id.clone(), + }) + .await?; + + Ok(result) + }) + } + + fn prompt( + &self, + _id: Option, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let conn = self.connection.clone(); + let sessions = self.sessions.clone(); + let session_id = params.session_id.clone(); + cx.foreground_executor().spawn(async move { + let result = conn.prompt(params).await; + + let mut suppress_abort_err = false; + + if let Some(session) = sessions.borrow_mut().get_mut(&session_id) { + suppress_abort_err = session.suppress_abort_err; + session.suppress_abort_err = false; + } + + match result { + Ok(response) => Ok(response), + Err(err) => { + if err.code != ErrorCode::INTERNAL_ERROR.code { + anyhow::bail!(err) + } + + let Some(data) = &err.data else { + anyhow::bail!(err) + }; + + // Temporary workaround until the following PR is generally available: + // https://github.com/google-gemini/gemini-cli/pull/6656 + + #[derive(Deserialize)] + #[serde(deny_unknown_fields)] + struct ErrorDetails { + details: Box, + } + + match serde_json::from_value(data.clone()) { + Ok(ErrorDetails { details }) => { + if suppress_abort_err + && (details.contains("This operation was aborted") + || details.contains("The user aborted a request")) + { + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Cancelled, + }) + } else { + Err(anyhow!(details)) + } + } + Err(_) => Err(anyhow!(err)), + } + } 
+ } + }) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) { + session.suppress_abort_err = true; + } + let conn = self.connection.clone(); + let params = acp::CancelNotification { + session_id: session_id.clone(), + }; + cx.foreground_executor() + .spawn(async move { conn.cancel(params).await }) + .detach(); + } + + fn session_modes( + &self, + session_id: &acp::SessionId, + _cx: &App, + ) -> Option> { + let sessions = self.sessions.clone(); + let sessions_ref = sessions.borrow(); + let Some(session) = sessions_ref.get(session_id) else { + return None; + }; + + if let Some(modes) = session.session_modes.as_ref() { + Some(Rc::new(AcpSessionModes { + connection: self.connection.clone(), + session_id: session_id.clone(), + state: modes.clone(), + }) as _) + } else { + None + } + } + + fn into_any(self: Rc) -> Rc { + self + } +} + +struct AcpSessionModes { + session_id: acp::SessionId, + connection: Rc, + state: Rc>, +} + +impl acp_thread::AgentSessionModes for AcpSessionModes { + fn current_mode(&self) -> acp::SessionModeId { + self.state.borrow().current_mode_id.clone() + } + + fn all_modes(&self) -> Vec { + self.state.borrow().available_modes.clone() + } + + fn set_mode(&self, mode_id: acp::SessionModeId, cx: &mut App) -> Task> { + let connection = self.connection.clone(); + let session_id = self.session_id.clone(); + let old_mode_id; + { + let mut state = self.state.borrow_mut(); + old_mode_id = state.current_mode_id.clone(); + state.current_mode_id = mode_id.clone(); + }; + let state = self.state.clone(); + cx.foreground_executor().spawn(async move { + let result = connection + .set_session_mode(acp::SetSessionModeRequest { + session_id, + mode_id, + }) + .await; + + if result.is_err() { + state.borrow_mut().current_mode_id = old_mode_id; + } + + result?; + + Ok(()) + }) + } +} + +struct ClientDelegate { + sessions: Rc>>, + cx: AsyncApp, +} + +impl acp::Client for 
ClientDelegate { + async fn request_permission( + &self, + arguments: acp::RequestPermissionRequest, + ) -> Result { + let respect_always_allow_setting; + let thread; + { + let sessions_ref = self.sessions.borrow(); + let session = sessions_ref + .get(&arguments.session_id) + .context("Failed to get session")?; + respect_always_allow_setting = session.session_modes.is_none(); + thread = session.thread.clone(); + } + + let cx = &mut self.cx.clone(); + + let task = thread.update(cx, |thread, cx| { + thread.request_tool_call_authorization( + arguments.tool_call, + arguments.options, + respect_always_allow_setting, + cx, + ) + })??; + + let outcome = task.await; + + Ok(acp::RequestPermissionResponse { outcome }) + } + + async fn write_text_file( + &self, + arguments: acp::WriteTextFileRequest, + ) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + let task = self + .session_thread(&arguments.session_id)? + .update(cx, |thread, cx| { + thread.write_text_file(arguments.path, arguments.content, cx) + })?; + + task.await?; + + Ok(()) + } + + async fn read_text_file( + &self, + arguments: acp::ReadTextFileRequest, + ) -> Result { + let task = self.session_thread(&arguments.session_id)?.update( + &mut self.cx.clone(), + |thread, cx| { + thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx) + }, + )?; + + let content = task.await?; + + Ok(acp::ReadTextFileResponse { content }) + } + + async fn session_notification( + &self, + notification: acp::SessionNotification, + ) -> Result<(), acp::Error> { + let sessions = self.sessions.borrow(); + let session = sessions + .get(¬ification.session_id) + .context("Failed to get session")?; + + if let acp::SessionUpdate::CurrentModeUpdate { current_mode_id } = ¬ification.update { + if let Some(session_modes) = &session.session_modes { + session_modes.borrow_mut().current_mode_id = current_mode_id.clone(); + } else { + log::error!( + "Got a `CurrentModeUpdate` notification, but they agent didn't 
specify `modes` during setting setup." + ); + } + } + + session.thread.update(&mut self.cx.clone(), |thread, cx| { + thread.handle_session_update(notification.update, cx) + })??; + + Ok(()) + } + + async fn create_terminal( + &self, + args: acp::CreateTerminalRequest, + ) -> Result { + let terminal = self + .session_thread(&args.session_id)? + .update(&mut self.cx.clone(), |thread, cx| { + thread.create_terminal( + args.command, + args.args, + args.env, + args.cwd, + args.output_byte_limit, + cx, + ) + })? + .await?; + Ok( + terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse { + terminal_id: terminal.id().clone(), + })?, + ) + } + + async fn kill_terminal(&self, args: acp::KillTerminalRequest) -> Result<(), acp::Error> { + self.session_thread(&args.session_id)? + .update(&mut self.cx.clone(), |thread, cx| { + thread.kill_terminal(args.terminal_id, cx) + })??; + + Ok(()) + } + + async fn release_terminal(&self, args: acp::ReleaseTerminalRequest) -> Result<(), acp::Error> { + self.session_thread(&args.session_id)? + .update(&mut self.cx.clone(), |thread, cx| { + thread.release_terminal(args.terminal_id, cx) + })??; + + Ok(()) + } + + async fn terminal_output( + &self, + args: acp::TerminalOutputRequest, + ) -> Result { + self.session_thread(&args.session_id)? + .read_with(&mut self.cx.clone(), |thread, cx| { + let out = thread + .terminal(args.terminal_id)? + .read(cx) + .current_output(cx); + + Ok(out) + })? + } + + async fn wait_for_terminal_exit( + &self, + args: acp::WaitForTerminalExitRequest, + ) -> Result { + let exit_status = self + .session_thread(&args.session_id)? + .update(&mut self.cx.clone(), |thread, cx| { + anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit()) + })?? 
+ .await; + + Ok(acp::WaitForTerminalExitResponse { exit_status }) + } +} + +impl ClientDelegate { + fn session_thread(&self, session_id: &acp::SessionId) -> Result> { + let sessions = self.sessions.borrow(); + sessions + .get(session_id) + .context("Failed to get session") + .map(|session| session.thread.clone()) } } diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs deleted file mode 100644 index 327613de673baeea4964c90cdd88c9075ed56f11..0000000000000000000000000000000000000000 --- a/crates/agent_servers/src/acp/v0.rs +++ /dev/null @@ -1,510 +0,0 @@ -// Translates old acp agents into the new schema -use agent_client_protocol as acp; -use agentic_coding_protocol::{self as acp_old, AgentRequest as _}; -use anyhow::{Context as _, Result, anyhow}; -use futures::channel::oneshot; -use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; -use project::Project; -use std::{cell::RefCell, path::Path, rc::Rc}; -use ui::App; -use util::ResultExt as _; - -use crate::AgentServerCommand; -use acp_thread::{AcpThread, AgentConnection, AuthRequired}; - -#[derive(Clone)] -struct OldAcpClientDelegate { - thread: Rc>>, - cx: AsyncApp, - next_tool_call_id: Rc>, - // sent_buffer_versions: HashMap, HashMap>, -} - -impl OldAcpClientDelegate { - fn new(thread: Rc>>, cx: AsyncApp) -> Self { - Self { - thread, - cx, - next_tool_call_id: Rc::new(RefCell::new(0)), - } - } -} - -impl acp_old::Client for OldAcpClientDelegate { - async fn stream_assistant_message_chunk( - &self, - params: acp_old::StreamAssistantMessageChunkParams, - ) -> Result<(), acp_old::Error> { - let cx = &mut self.cx.clone(); - - cx.update(|cx| { - self.thread - .borrow() - .update(cx, |thread, cx| match params.chunk { - acp_old::AssistantMessageChunk::Text { text } => { - thread.push_assistant_content_block(text.into(), false, cx) - } - acp_old::AssistantMessageChunk::Thought { thought } => { - thread.push_assistant_content_block(thought.into(), true, cx) - } - }) - .log_err(); 
- })?; - - Ok(()) - } - - async fn request_tool_call_confirmation( - &self, - request: acp_old::RequestToolCallConfirmationParams, - ) -> Result { - let cx = &mut self.cx.clone(); - - let old_acp_id = *self.next_tool_call_id.borrow() + 1; - self.next_tool_call_id.replace(old_acp_id); - - let tool_call = into_new_tool_call( - acp::ToolCallId(old_acp_id.to_string().into()), - request.tool_call, - ); - - let mut options = match request.confirmation { - acp_old::ToolCallConfirmation::Edit { .. } => vec![( - acp_old::ToolCallConfirmationOutcome::AlwaysAllow, - acp::PermissionOptionKind::AllowAlways, - "Always Allow Edits".to_string(), - )], - acp_old::ToolCallConfirmation::Execute { root_command, .. } => vec![( - acp_old::ToolCallConfirmationOutcome::AlwaysAllow, - acp::PermissionOptionKind::AllowAlways, - format!("Always Allow {}", root_command), - )], - acp_old::ToolCallConfirmation::Mcp { - server_name, - tool_name, - .. - } => vec![ - ( - acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer, - acp::PermissionOptionKind::AllowAlways, - format!("Always Allow {}", server_name), - ), - ( - acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool, - acp::PermissionOptionKind::AllowAlways, - format!("Always Allow {}", tool_name), - ), - ], - acp_old::ToolCallConfirmation::Fetch { .. } => vec![( - acp_old::ToolCallConfirmationOutcome::AlwaysAllow, - acp::PermissionOptionKind::AllowAlways, - "Always Allow".to_string(), - )], - acp_old::ToolCallConfirmation::Other { .. 
} => vec![( - acp_old::ToolCallConfirmationOutcome::AlwaysAllow, - acp::PermissionOptionKind::AllowAlways, - "Always Allow".to_string(), - )], - }; - - options.extend([ - ( - acp_old::ToolCallConfirmationOutcome::Allow, - acp::PermissionOptionKind::AllowOnce, - "Allow".to_string(), - ), - ( - acp_old::ToolCallConfirmationOutcome::Reject, - acp::PermissionOptionKind::RejectOnce, - "Reject".to_string(), - ), - ]); - - let mut outcomes = Vec::with_capacity(options.len()); - let mut acp_options = Vec::with_capacity(options.len()); - - for (index, (outcome, kind, label)) in options.into_iter().enumerate() { - outcomes.push(outcome); - acp_options.push(acp::PermissionOption { - id: acp::PermissionOptionId(index.to_string().into()), - name: label, - kind, - }) - } - - let response = cx - .update(|cx| { - self.thread.borrow().update(cx, |thread, cx| { - thread.request_tool_call_authorization(tool_call, acp_options, cx) - }) - })? - .context("Failed to update thread")? - .await; - - let outcome = match response { - Ok(option_id) => outcomes[option_id.0.parse::().unwrap_or(0)], - Err(oneshot::Canceled) => acp_old::ToolCallConfirmationOutcome::Cancel, - }; - - Ok(acp_old::RequestToolCallConfirmationResponse { - id: acp_old::ToolCallId(old_acp_id), - outcome: outcome, - }) - } - - async fn push_tool_call( - &self, - request: acp_old::PushToolCallParams, - ) -> Result { - let cx = &mut self.cx.clone(); - - let old_acp_id = *self.next_tool_call_id.borrow() + 1; - self.next_tool_call_id.replace(old_acp_id); - - cx.update(|cx| { - self.thread.borrow().update(cx, |thread, cx| { - thread.upsert_tool_call( - into_new_tool_call(acp::ToolCallId(old_acp_id.to_string().into()), request), - cx, - ) - }) - })? 
- .context("Failed to update thread")?; - - Ok(acp_old::PushToolCallResponse { - id: acp_old::ToolCallId(old_acp_id), - }) - } - - async fn update_tool_call( - &self, - request: acp_old::UpdateToolCallParams, - ) -> Result<(), acp_old::Error> { - let cx = &mut self.cx.clone(); - - cx.update(|cx| { - self.thread.borrow().update(cx, |thread, cx| { - thread.update_tool_call( - acp::ToolCallUpdate { - id: acp::ToolCallId(request.tool_call_id.0.to_string().into()), - fields: acp::ToolCallUpdateFields { - status: Some(into_new_tool_call_status(request.status)), - content: Some( - request - .content - .into_iter() - .map(into_new_tool_call_content) - .collect::>(), - ), - ..Default::default() - }, - }, - cx, - ) - }) - })? - .context("Failed to update thread")??; - - Ok(()) - } - - async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> { - let cx = &mut self.cx.clone(); - - cx.update(|cx| { - self.thread.borrow().update(cx, |thread, cx| { - thread.update_plan( - acp::Plan { - entries: request - .entries - .into_iter() - .map(into_new_plan_entry) - .collect(), - }, - cx, - ) - }) - })? - .context("Failed to update thread")?; - - Ok(()) - } - - async fn read_text_file( - &self, - acp_old::ReadTextFileParams { path, line, limit }: acp_old::ReadTextFileParams, - ) -> Result { - let content = self - .cx - .update(|cx| { - self.thread.borrow().update(cx, |thread, cx| { - thread.read_text_file(path, line, limit, false, cx) - }) - })? - .context("Failed to update thread")? - .await?; - Ok(acp_old::ReadTextFileResponse { content }) - } - - async fn write_text_file( - &self, - acp_old::WriteTextFileParams { path, content }: acp_old::WriteTextFileParams, - ) -> Result<(), acp_old::Error> { - self.cx - .update(|cx| { - self.thread - .borrow() - .update(cx, |thread, cx| thread.write_text_file(path, content, cx)) - })? - .context("Failed to update thread")? 
- .await?; - - Ok(()) - } -} - -fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall { - acp::ToolCall { - id: id, - title: request.label, - kind: acp_kind_from_old_icon(request.icon), - status: acp::ToolCallStatus::InProgress, - content: request - .content - .into_iter() - .map(into_new_tool_call_content) - .collect(), - locations: request - .locations - .into_iter() - .map(into_new_tool_call_location) - .collect(), - raw_input: None, - raw_output: None, - } -} - -fn acp_kind_from_old_icon(icon: acp_old::Icon) -> acp::ToolKind { - match icon { - acp_old::Icon::FileSearch => acp::ToolKind::Search, - acp_old::Icon::Folder => acp::ToolKind::Search, - acp_old::Icon::Globe => acp::ToolKind::Search, - acp_old::Icon::Hammer => acp::ToolKind::Other, - acp_old::Icon::LightBulb => acp::ToolKind::Think, - acp_old::Icon::Pencil => acp::ToolKind::Edit, - acp_old::Icon::Regex => acp::ToolKind::Search, - acp_old::Icon::Terminal => acp::ToolKind::Execute, - } -} - -fn into_new_tool_call_status(status: acp_old::ToolCallStatus) -> acp::ToolCallStatus { - match status { - acp_old::ToolCallStatus::Running => acp::ToolCallStatus::InProgress, - acp_old::ToolCallStatus::Finished => acp::ToolCallStatus::Completed, - acp_old::ToolCallStatus::Error => acp::ToolCallStatus::Failed, - } -} - -fn into_new_tool_call_content(content: acp_old::ToolCallContent) -> acp::ToolCallContent { - match content { - acp_old::ToolCallContent::Markdown { markdown } => markdown.into(), - acp_old::ToolCallContent::Diff { diff } => acp::ToolCallContent::Diff { - diff: into_new_diff(diff), - }, - } -} - -fn into_new_diff(diff: acp_old::Diff) -> acp::Diff { - acp::Diff { - path: diff.path, - old_text: diff.old_text, - new_text: diff.new_text, - } -} - -fn into_new_tool_call_location(location: acp_old::ToolCallLocation) -> acp::ToolCallLocation { - acp::ToolCallLocation { - path: location.path, - line: location.line, - } -} - -fn into_new_plan_entry(entry: 
acp_old::PlanEntry) -> acp::PlanEntry { - acp::PlanEntry { - content: entry.content, - priority: into_new_plan_priority(entry.priority), - status: into_new_plan_status(entry.status), - } -} - -fn into_new_plan_priority(priority: acp_old::PlanEntryPriority) -> acp::PlanEntryPriority { - match priority { - acp_old::PlanEntryPriority::Low => acp::PlanEntryPriority::Low, - acp_old::PlanEntryPriority::Medium => acp::PlanEntryPriority::Medium, - acp_old::PlanEntryPriority::High => acp::PlanEntryPriority::High, - } -} - -fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatus { - match status { - acp_old::PlanEntryStatus::Pending => acp::PlanEntryStatus::Pending, - acp_old::PlanEntryStatus::InProgress => acp::PlanEntryStatus::InProgress, - acp_old::PlanEntryStatus::Completed => acp::PlanEntryStatus::Completed, - } -} - -pub struct AcpConnection { - pub name: &'static str, - pub connection: acp_old::AgentConnection, - pub _child_status: Task>, - pub current_thread: Rc>>, -} - -impl AcpConnection { - pub fn stdio( - name: &'static str, - command: AgentServerCommand, - root_dir: &Path, - cx: &mut AsyncApp, - ) -> Task> { - let root_dir = root_dir.to_path_buf(); - - cx.spawn(async move |cx| { - let mut child = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; - - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - log::trace!("Spawned (pid: {})", child.id()); - - let foreground_executor = cx.foreground_executor().clone(); - - let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); - - let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( - OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), - stdin, - stdout, - move |fut| foreground_executor.spawn(fut).detach(), - ); - - let 
io_task = cx.background_spawn(async move { - io_fut.await.log_err(); - }); - - let child_status = cx.background_spawn(async move { - let result = match child.status().await { - Err(e) => Err(anyhow!(e)), - Ok(result) if result.success() => Ok(()), - Ok(result) => Err(anyhow!(result)), - }; - drop(io_task); - result - }); - - Ok(Self { - name, - connection, - _child_status: child_status, - current_thread: thread_rc, - }) - }) - } -} - -impl AgentConnection for AcpConnection { - fn new_thread( - self: Rc, - project: Entity, - _cwd: &Path, - cx: &mut AsyncApp, - ) -> Task>> { - let task = self.connection.request_any( - acp_old::InitializeParams { - protocol_version: acp_old::ProtocolVersion::latest(), - } - .into_any(), - ); - let current_thread = self.current_thread.clone(); - cx.spawn(async move |cx| { - let result = task.await?; - let result = acp_old::InitializeParams::response_from_any(result)?; - - if !result.is_authenticated { - anyhow::bail!(AuthRequired) - } - - cx.update(|cx| { - let thread = cx.new(|cx| { - let session_id = acp::SessionId("acp-old-no-id".into()); - AcpThread::new(self.name, self.clone(), project, session_id, cx) - }); - current_thread.replace(thread.downgrade()); - thread - }) - }) - } - - fn auth_methods(&self) -> &[acp::AuthMethod] { - &[] - } - - fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task> { - let task = self - .connection - .request_any(acp_old::AuthenticateParams.into_any()); - cx.foreground_executor().spawn(async move { - task.await?; - Ok(()) - }) - } - - fn prompt( - &self, - _id: Option, - params: acp::PromptRequest, - cx: &mut App, - ) -> Task> { - let chunks = params - .prompt - .into_iter() - .filter_map(|block| match block { - acp::ContentBlock::Text(text) => { - Some(acp_old::UserMessageChunk::Text { text: text.text }) - } - acp::ContentBlock::ResourceLink(link) => Some(acp_old::UserMessageChunk::Path { - path: link.uri.into(), - }), - _ => None, - }) - .collect(); - - let task = self - 
.connection - .request_any(acp_old::SendUserMessageParams { chunks }.into_any()); - cx.foreground_executor().spawn(async move { - task.await?; - anyhow::Ok(acp::PromptResponse { - stop_reason: acp::StopReason::EndTurn, - }) - }) - } - - fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) { - let task = self - .connection - .request_any(acp_old::CancelSendMessageParams.into_any()); - cx.foreground_executor() - .spawn(async move { - task.await?; - anyhow::Ok(()) - }) - .detach_and_log_err(cx) - } -} diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs deleted file mode 100644 index de397fddf0f303d64ff772e4a5fc9de27ae5f577..0000000000000000000000000000000000000000 --- a/crates/agent_servers/src/acp/v1.rs +++ /dev/null @@ -1,283 +0,0 @@ -use agent_client_protocol::{self as acp, Agent as _}; -use anyhow::anyhow; -use collections::HashMap; -use futures::channel::oneshot; -use project::Project; -use std::cell::RefCell; -use std::path::Path; -use std::rc::Rc; - -use anyhow::{Context as _, Result}; -use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; - -use crate::{AgentServerCommand, acp::UnsupportedVersion}; -use acp_thread::{AcpThread, AgentConnection, AuthRequired}; - -pub struct AcpConnection { - server_name: &'static str, - connection: Rc, - sessions: Rc>>, - auth_methods: Vec, - _io_task: Task>, -} - -pub struct AcpSession { - thread: WeakEntity, -} - -const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1; - -impl AcpConnection { - pub async fn stdio( - server_name: &'static str, - command: AgentServerCommand, - root_dir: &Path, - cx: &mut AsyncApp, - ) -> Result { - let mut child = util::command::new_smol_command(&command.path) - .args(command.args.iter().map(|arg| arg.as_str())) - .envs(command.env.iter().flatten()) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; - - 
let stdout = child.stdout.take().expect("Failed to take stdout"); - let stdin = child.stdin.take().expect("Failed to take stdin"); - log::trace!("Spawned (pid: {})", child.id()); - - let sessions = Rc::new(RefCell::new(HashMap::default())); - - let client = ClientDelegate { - sessions: sessions.clone(), - cx: cx.clone(), - }; - let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, { - let foreground_executor = cx.foreground_executor().clone(); - move |fut| { - foreground_executor.spawn(fut).detach(); - } - }); - - let io_task = cx.background_spawn(io_task); - - cx.spawn({ - let sessions = sessions.clone(); - async move |cx| { - let status = child.status().await?; - - for session in sessions.borrow().values() { - session - .thread - .update(cx, |thread, cx| thread.emit_server_exited(status, cx)) - .ok(); - } - - anyhow::Ok(()) - } - }) - .detach(); - - let response = connection - .initialize(acp::InitializeRequest { - protocol_version: acp::VERSION, - client_capabilities: acp::ClientCapabilities { - fs: acp::FileSystemCapability { - read_text_file: true, - write_text_file: true, - }, - }, - }) - .await?; - - if response.protocol_version < MINIMUM_SUPPORTED_VERSION { - return Err(UnsupportedVersion.into()); - } - - Ok(Self { - auth_methods: response.auth_methods, - connection: connection.into(), - server_name, - sessions, - _io_task: io_task, - }) - } -} - -impl AgentConnection for AcpConnection { - fn new_thread( - self: Rc, - project: Entity, - cwd: &Path, - cx: &mut AsyncApp, - ) -> Task>> { - let conn = self.connection.clone(); - let sessions = self.sessions.clone(); - let cwd = cwd.to_path_buf(); - cx.spawn(async move |cx| { - let response = conn - .new_session(acp::NewSessionRequest { - mcp_servers: vec![], - cwd, - }) - .await - .map_err(|err| { - if err.code == acp::ErrorCode::AUTH_REQUIRED.code { - anyhow!(AuthRequired) - } else { - anyhow!(err) - } - })?; - - let session_id = response.session_id; - - let thread = cx.new(|cx| { - 
AcpThread::new( - self.server_name, - self.clone(), - project, - session_id.clone(), - cx, - ) - })?; - - let session = AcpSession { - thread: thread.downgrade(), - }; - sessions.borrow_mut().insert(session_id, session); - - Ok(thread) - }) - } - - fn auth_methods(&self) -> &[acp::AuthMethod] { - &self.auth_methods - } - - fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { - let conn = self.connection.clone(); - cx.foreground_executor().spawn(async move { - let result = conn - .authenticate(acp::AuthenticateRequest { - method_id: method_id.clone(), - }) - .await?; - - Ok(result) - }) - } - - fn prompt( - &self, - _id: Option, - params: acp::PromptRequest, - cx: &mut App, - ) -> Task> { - let conn = self.connection.clone(); - cx.foreground_executor().spawn(async move { - let response = conn.prompt(params).await?; - Ok(response) - }) - } - - fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { - let conn = self.connection.clone(); - let params = acp::CancelNotification { - session_id: session_id.clone(), - }; - cx.foreground_executor() - .spawn(async move { conn.cancel(params).await }) - .detach(); - } -} - -struct ClientDelegate { - sessions: Rc>>, - cx: AsyncApp, -} - -impl acp::Client for ClientDelegate { - async fn request_permission( - &self, - arguments: acp::RequestPermissionRequest, - ) -> Result { - let cx = &mut self.cx.clone(); - let rx = self - .sessions - .borrow() - .get(&arguments.session_id) - .context("Failed to get session")? 
- .thread - .update(cx, |thread, cx| { - thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx) - })?; - - let result = rx.await; - - let outcome = match result { - Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option }, - Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled, - }; - - Ok(acp::RequestPermissionResponse { outcome }) - } - - async fn write_text_file( - &self, - arguments: acp::WriteTextFileRequest, - ) -> Result<(), acp::Error> { - let cx = &mut self.cx.clone(); - let task = self - .sessions - .borrow() - .get(&arguments.session_id) - .context("Failed to get session")? - .thread - .update(cx, |thread, cx| { - thread.write_text_file(arguments.path, arguments.content, cx) - })?; - - task.await?; - - Ok(()) - } - - async fn read_text_file( - &self, - arguments: acp::ReadTextFileRequest, - ) -> Result { - let cx = &mut self.cx.clone(); - let task = self - .sessions - .borrow() - .get(&arguments.session_id) - .context("Failed to get session")? 
- .thread - .update(cx, |thread, cx| { - thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx) - })?; - - let content = task.await?; - - Ok(acp::ReadTextFileResponse { content }) - } - - async fn session_notification( - &self, - notification: acp::SessionNotification, - ) -> Result<(), acp::Error> { - let cx = &mut self.cx.clone(); - let sessions = self.sessions.borrow(); - let session = sessions - .get(¬ification.session_id) - .context("Failed to get session")?; - - session.thread.update(cx, |thread, cx| { - thread.handle_session_update(notification.update, cx) - })??; - - Ok(()) - } -} diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index b3b8a3317049927986a6a578bc50c4e5506b7650..2c2900cb79328249355704606652c54d08f072e5 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -1,176 +1,79 @@ mod acp; mod claude; +mod custom; mod gemini; -mod settings; -#[cfg(test)] -mod e2e_tests; +#[cfg(any(test, feature = "test-support"))] +pub mod e2e_tests; pub use claude::*; +pub use custom::*; +use fs::Fs; pub use gemini::*; -pub use settings::*; +use project::agent_server_store::AgentServerStore; use acp_thread::AgentConnection; use anyhow::Result; -use collections::HashMap; -use gpui::{App, AsyncApp, Entity, SharedString, Task}; +use gpui::{App, Entity, SharedString, Task}; use project::Project; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use std::{ - path::{Path, PathBuf}, - rc::Rc, - sync::Arc, -}; -use util::ResultExt as _; +use std::{any::Any, path::Path, rc::Rc, sync::Arc}; -pub fn init(cx: &mut App) { - settings::init(cx); +pub use acp::AcpConnection; + +pub struct AgentServerDelegate { + store: Entity, + project: Entity, + status_tx: Option>, + new_version_available: Option>>, +} + +impl AgentServerDelegate { + pub fn new( + store: Entity, + project: Entity, + status_tx: Option>, + new_version_tx: Option>>, + ) -> 
Self { + Self { + store, + project, + status_tx, + new_version_available: new_version_tx, + } + } + + pub fn project(&self) -> &Entity { + &self.project + } } pub trait AgentServer: Send { fn logo(&self) -> ui::IconName; - fn name(&self) -> &'static str; - fn empty_state_headline(&self) -> &'static str; - fn empty_state_message(&self) -> &'static str; + fn name(&self) -> SharedString; + fn telemetry_id(&self) -> &'static str; + fn default_mode(&self, _cx: &mut App) -> Option { + None + } + fn set_default_mode( + &self, + _mode_id: Option, + _fs: Arc, + _cx: &mut App, + ) { + } fn connect( &self, - root_dir: &Path, - project: &Entity, + root_dir: Option<&Path>, + delegate: AgentServerDelegate, cx: &mut App, - ) -> Task>>; -} - -impl std::fmt::Debug for AgentServerCommand { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let filtered_env = self.env.as_ref().map(|env| { - env.iter() - .map(|(k, v)| { - ( - k, - if util::redact::should_redact(k) { - "[REDACTED]" - } else { - v - }, - ) - }) - .collect::>() - }); + ) -> Task, Option)>>; - f.debug_struct("AgentServerCommand") - .field("path", &self.path) - .field("args", &self.args) - .field("env", &filtered_env) - .finish() - } -} - -pub enum AgentServerVersion { - Supported, - Unsupported { - error_message: SharedString, - upgrade_message: SharedString, - upgrade_command: String, - }, + fn into_any(self: Rc) -> Rc; } -#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)] -pub struct AgentServerCommand { - #[serde(rename = "command")] - pub path: PathBuf, - #[serde(default)] - pub args: Vec, - pub env: Option>, -} - -impl AgentServerCommand { - pub(crate) async fn resolve( - path_bin_name: &'static str, - extra_args: &[&'static str], - fallback_path: Option<&Path>, - settings: Option, - project: &Entity, - cx: &mut AsyncApp, - ) -> Option { - if let Some(agent_settings) = settings { - return Some(Self { - path: agent_settings.command.path, - args: agent_settings - .command - .args 
- .into_iter() - .chain(extra_args.iter().map(|arg| arg.to_string())) - .collect(), - env: agent_settings.command.env, - }); - } else { - match find_bin_in_path(path_bin_name, project, cx).await { - Some(path) => Some(Self { - path, - args: extra_args.iter().map(|arg| arg.to_string()).collect(), - env: None, - }), - None => fallback_path.and_then(|path| { - if path.exists() { - Some(Self { - path: path.to_path_buf(), - args: extra_args.iter().map(|arg| arg.to_string()).collect(), - env: None, - }) - } else { - None - } - }), - } - } +impl dyn AgentServer { + pub fn downcast(self: Rc) -> Option> { + self.into_any().downcast().ok() } } - -async fn find_bin_in_path( - bin_name: &'static str, - project: &Entity, - cx: &mut AsyncApp, -) -> Option { - let (env_task, root_dir) = project - .update(cx, |project, cx| { - let worktree = project.visible_worktrees(cx).next(); - match worktree { - Some(worktree) => { - let env_task = project.environment().update(cx, |env, cx| { - env.get_worktree_environment(worktree.clone(), cx) - }); - - let path = worktree.read(cx).abs_path(); - (env_task, path) - } - None => { - let path: Arc = paths::home_dir().as_path().into(); - let env_task = project.environment().update(cx, |env, cx| { - env.get_directory_environment(path.clone(), cx) - }); - (env_task, path) - } - } - }) - .log_err()?; - - cx.background_executor() - .spawn(async move { - let which_result = if cfg!(windows) { - which::which(bin_name) - } else { - let env = env_task.await.unwrap_or_default(); - let shell_path = env.get("PATH").cloned(); - which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref()) - }; - - if let Err(which::Error::CannotFindBinaryPath) = which_result { - return None; - } - - which_result.log_err() - }) - .await -} diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index c394ec4a9c23b64400e12a94adffcc65499a0ebd..c75c9539abe5fdd03293d98719d4a905b368c4a4 100644 --- a/crates/agent_servers/src/claude.rs +++ 
b/crates/agent_servers/src/claude.rs @@ -1,1066 +1,96 @@ -mod mcp_server; -pub mod tools; - -use collections::HashMap; -use context_server::listener::McpServerTool; -use project::Project; -use settings::SettingsStore; -use smol::process::Child; -use std::cell::RefCell; -use std::fmt::Display; +use agent_client_protocol as acp; +use fs::Fs; +use settings::{SettingsStore, update_settings_file}; use std::path::Path; use std::rc::Rc; -use uuid::Uuid; +use std::sync::Arc; +use std::{any::Any, path::PathBuf}; -use agent_client_protocol as acp; -use anyhow::{Result, anyhow}; -use futures::channel::oneshot; -use futures::{AsyncBufReadExt, AsyncWriteExt}; -use futures::{ - AsyncRead, AsyncWrite, FutureExt, StreamExt, - channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, - io::BufReader, - select_biased, -}; -use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; -use serde::{Deserialize, Serialize}; -use util::{ResultExt, debug_panic}; +use anyhow::{Context as _, Result}; +use gpui::{App, AppContext as _, SharedString, Task}; +use project::agent_server_store::{AllAgentServersSettings, CLAUDE_CODE_NAME}; -use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; -use crate::claude::tools::ClaudeTool; -use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; -use acp_thread::{AcpThread, AgentConnection}; +use crate::{AgentServer, AgentServerDelegate}; +use acp_thread::AgentConnection; #[derive(Clone)] pub struct ClaudeCode; -impl AgentServer for ClaudeCode { - fn name(&self) -> &'static str { - "Claude Code" - } +pub struct AgentServerLoginCommand { + pub path: PathBuf, + pub arguments: Vec, +} - fn empty_state_headline(&self) -> &'static str { - self.name() +impl AgentServer for ClaudeCode { + fn telemetry_id(&self) -> &'static str { + "claude-code" } - fn empty_state_message(&self) -> &'static str { - "How can I help you today?" 
+ fn name(&self) -> SharedString { + "Claude Code".into() } fn logo(&self) -> ui::IconName { ui::IconName::AiClaude } - fn connect( - &self, - _root_dir: &Path, - _project: &Entity, - _cx: &mut App, - ) -> Task>> { - let connection = ClaudeAgentConnection { - sessions: Default::default(), - }; - - Task::ready(Ok(Rc::new(connection) as _)) - } -} - -struct ClaudeAgentConnection { - sessions: Rc>>, -} - -impl AgentConnection for ClaudeAgentConnection { - fn new_thread( - self: Rc, - project: Entity, - cwd: &Path, - cx: &mut AsyncApp, - ) -> Task>> { - let cwd = cwd.to_owned(); - cx.spawn(async move |cx| { - let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); - let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?; - - let mut mcp_servers = HashMap::default(); - mcp_servers.insert( - mcp_server::SERVER_NAME.to_string(), - permission_mcp_server.server_config()?, - ); - let mcp_config = McpConfig { mcp_servers }; - - let mcp_config_file = tempfile::NamedTempFile::new()?; - let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts(); - - let mut mcp_config_file = smol::fs::File::from(mcp_config_file); - mcp_config_file - .write_all(serde_json::to_string(&mcp_config)?.as_bytes()) - .await?; - mcp_config_file.flush().await?; - - let settings = cx.read_global(|settings: &SettingsStore, _| { - settings.get::(None).claude.clone() - })?; - - let Some(command) = AgentServerCommand::resolve( - "claude", - &[], - Some(&util::paths::home_dir().join(".claude/local/claude")), - settings, - &project, - cx, - ) - .await - else { - anyhow::bail!("Failed to find claude binary"); - }; - - let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded(); - let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); - - let session_id = acp::SessionId(Uuid::new_v4().to_string().into()); - - log::trace!("Starting session with id: {}", session_id); - - let mut child = spawn_claude( - &command, - ClaudeSessionMode::Start, - 
session_id.clone(), - &mcp_config_path, - &cwd, - )?; - - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - - let pid = child.id(); - log::trace!("Spawned (pid: {})", pid); - - cx.background_spawn(async move { - let mut outgoing_rx = Some(outgoing_rx); - - ClaudeAgentSession::handle_io( - outgoing_rx.take().unwrap(), - incoming_message_tx.clone(), - stdin, - stdout, - ) - .await?; - - log::trace!("Stopped (pid: {})", pid); - - drop(mcp_config_path); - anyhow::Ok(()) - }) - .detach(); - - let turn_state = Rc::new(RefCell::new(TurnState::None)); - - let handler_task = cx.spawn({ - let turn_state = turn_state.clone(); - let mut thread_rx = thread_rx.clone(); - async move |cx| { - while let Some(message) = incoming_message_rx.next().await { - ClaudeAgentSession::handle_message( - thread_rx.clone(), - message, - turn_state.clone(), - cx, - ) - .await - } - - if let Some(status) = child.status().await.log_err() { - if let Some(thread) = thread_rx.recv().await.ok() { - thread - .update(cx, |thread, cx| { - thread.emit_server_exited(status, cx); - }) - .ok(); - } - } - } - }); - - let thread = cx.new(|cx| { - AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx) - })?; - - thread_tx.send(thread.downgrade())?; - - let session = ClaudeAgentSession { - outgoing_tx, - turn_state, - _handler_task: handler_task, - _mcp_server: Some(permission_mcp_server), - }; - - self.sessions.borrow_mut().insert(session_id, session); - - Ok(thread) - }) - } + fn default_mode(&self, cx: &mut App) -> Option { + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).claude.clone() + }); - fn auth_methods(&self) -> &[acp::AuthMethod] { - &[] + settings + .as_ref() + .and_then(|s| s.default_mode.clone().map(|m| acp::SessionModeId(m.into()))) } - fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task> { - Task::ready(Err(anyhow!("Authentication not supported"))) + fn set_default_mode(&self, 
mode_id: Option, fs: Arc, cx: &mut App) { + update_settings_file::(fs, cx, |settings, _| { + settings.claude.get_or_insert_default().default_mode = mode_id.map(|m| m.to_string()) + }); } - fn prompt( + fn connect( &self, - _id: Option, - params: acp::PromptRequest, + root_dir: Option<&Path>, + delegate: AgentServerDelegate, cx: &mut App, - ) -> Task> { - let sessions = self.sessions.borrow(); - let Some(session) = sessions.get(¶ms.session_id) else { - return Task::ready(Err(anyhow!( - "Attempted to send message to nonexistent session {}", - params.session_id - ))); - }; - - let (end_tx, end_rx) = oneshot::channel(); - session.turn_state.replace(TurnState::InProgress { end_tx }); - - let mut content = String::new(); - for chunk in params.prompt { - match chunk { - acp::ContentBlock::Text(text_content) => { - content.push_str(&text_content.text); - } - acp::ContentBlock::ResourceLink(resource_link) => { - content.push_str(&format!("@{}", resource_link.uri)); - } - acp::ContentBlock::Audio(_) - | acp::ContentBlock::Image(_) - | acp::ContentBlock::Resource(_) => { - // TODO - } - } - } - - if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User { - message: Message { - role: Role::User, - content: Content::UntaggedText(content), - id: None, - model: None, - stop_reason: None, - stop_sequence: None, - usage: None, - }, - session_id: Some(params.session_id.to_string()), - }) { - return Task::ready(Err(anyhow!(err))); - } - - cx.foreground_executor().spawn(async move { end_rx.await? 
}) - } - - fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { - let sessions = self.sessions.borrow(); - let Some(session) = sessions.get(&session_id) else { - log::warn!("Attempted to cancel nonexistent session {}", session_id); - return; - }; - - let request_id = new_request_id(); - - let turn_state = session.turn_state.take(); - let TurnState::InProgress { end_tx } = turn_state else { - // Already cancelled or idle, put it back - session.turn_state.replace(turn_state); - return; - }; - - session.turn_state.replace(TurnState::CancelRequested { - end_tx, - request_id: request_id.clone(), - }); - - session - .outgoing_tx - .unbounded_send(SdkMessage::ControlRequest { - request_id, - request: ControlRequest::Interrupt, - }) - .log_err(); - } -} - -#[derive(Clone, Copy)] -enum ClaudeSessionMode { - Start, - #[expect(dead_code)] - Resume, -} + ) -> Task, Option)>> { + let name = self.name(); + let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().to_string()); + let is_remote = delegate.project.read(cx).is_via_remote_server(); + let store = delegate.store.downgrade(); + let default_mode = self.default_mode(cx); -fn spawn_claude( - command: &AgentServerCommand, - mode: ClaudeSessionMode, - session_id: acp::SessionId, - mcp_config_path: &Path, - root_dir: &Path, -) -> Result { - let child = util::command::new_smol_command(&command.path) - .args([ - "--input-format", - "stream-json", - "--output-format", - "stream-json", - "--print", - "--verbose", - "--mcp-config", - mcp_config_path.to_string_lossy().as_ref(), - "--permission-prompt-tool", - &format!( - "mcp__{}__{}", - mcp_server::SERVER_NAME, - mcp_server::PermissionTool::NAME, - ), - "--allowedTools", - &format!( - "mcp__{}__{},mcp__{}__{}", - mcp_server::SERVER_NAME, - mcp_server::EditTool::NAME, - mcp_server::SERVER_NAME, - mcp_server::ReadTool::NAME - ), - "--disallowedTools", - "Read,Edit", - ]) - .args(match mode { - ClaudeSessionMode::Start => ["--session-id".to_string(), 
session_id.to_string()], - ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()], + cx.spawn(async move |cx| { + let (command, root_dir, login) = store + .update(cx, |store, cx| { + let agent = store + .get_external_agent(&CLAUDE_CODE_NAME.into()) + .context("Claude Code is not registered")?; + anyhow::Ok(agent.get_command( + root_dir.as_deref(), + Default::default(), + delegate.status_tx, + delegate.new_version_available, + &mut cx.to_async(), + )) + })?? + .await?; + let connection = crate::acp::connect( + name, + command, + root_dir.as_ref(), + default_mode, + is_remote, + cx, + ) + .await?; + Ok((connection, login)) }) - .args(command.args.iter().map(|arg| arg.as_str())) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; - - Ok(child) -} - -struct ClaudeAgentSession { - outgoing_tx: UnboundedSender, - turn_state: Rc>, - _mcp_server: Option, - _handler_task: Task<()>, -} - -#[derive(Debug, Default)] -enum TurnState { - #[default] - None, - InProgress { - end_tx: oneshot::Sender>, - }, - CancelRequested { - end_tx: oneshot::Sender>, - request_id: String, - }, - CancelConfirmed { - end_tx: oneshot::Sender>, - }, -} - -impl TurnState { - fn is_cancelled(&self) -> bool { - matches!(self, TurnState::CancelConfirmed { .. }) - } - - fn end_tx(self) -> Option>> { - match self { - TurnState::None => None, - TurnState::InProgress { end_tx, .. } => Some(end_tx), - TurnState::CancelRequested { end_tx, .. 
} => Some(end_tx), - TurnState::CancelConfirmed { end_tx } => Some(end_tx), - } } - fn confirm_cancellation(self, id: &str) -> Self { - match self { - TurnState::CancelRequested { request_id, end_tx } if request_id == id => { - TurnState::CancelConfirmed { end_tx } - } - _ => self, - } - } -} - -impl ClaudeAgentSession { - async fn handle_message( - mut thread_rx: watch::Receiver>, - message: SdkMessage, - turn_state: Rc>, - cx: &mut AsyncApp, - ) { - match message { - // we should only be sending these out, they don't need to be in the thread - SdkMessage::ControlRequest { .. } => {} - SdkMessage::User { - message, - session_id: _, - } => { - let Some(thread) = thread_rx - .recv() - .await - .log_err() - .and_then(|entity| entity.upgrade()) - else { - log::error!("Received an SDK message but thread is gone"); - return; - }; - - for chunk in message.content.chunks() { - match chunk { - ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { - if !turn_state.borrow().is_cancelled() { - thread - .update(cx, |thread, cx| { - thread.push_user_content_block(None, text.into(), cx) - }) - .log_err(); - } - } - ContentChunk::ToolResult { - content, - tool_use_id, - } => { - let content = content.to_string(); - thread - .update(cx, |thread, cx| { - thread.update_tool_call( - acp::ToolCallUpdate { - id: acp::ToolCallId(tool_use_id.into()), - fields: acp::ToolCallUpdateFields { - status: if turn_state.borrow().is_cancelled() { - // Do not set to completed if turn was cancelled - None - } else { - Some(acp::ToolCallStatus::Completed) - }, - content: (!content.is_empty()) - .then(|| vec![content.into()]), - ..Default::default() - }, - }, - cx, - ) - }) - .log_err(); - } - ContentChunk::Thinking { .. } - | ContentChunk::RedactedThinking - | ContentChunk::ToolUse { .. } => { - debug_panic!( - "Should not get {:?} with role: assistant. 
should we handle this?", - chunk - ); - } - - ContentChunk::Image - | ContentChunk::Document - | ContentChunk::WebSearchToolResult => { - thread - .update(cx, |thread, cx| { - thread.push_assistant_content_block( - format!("Unsupported content: {:?}", chunk).into(), - false, - cx, - ) - }) - .log_err(); - } - } - } - } - SdkMessage::Assistant { - message, - session_id: _, - } => { - let Some(thread) = thread_rx - .recv() - .await - .log_err() - .and_then(|entity| entity.upgrade()) - else { - log::error!("Received an SDK message but thread is gone"); - return; - }; - - for chunk in message.content.chunks() { - match chunk { - ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { - thread - .update(cx, |thread, cx| { - thread.push_assistant_content_block(text.into(), false, cx) - }) - .log_err(); - } - ContentChunk::Thinking { thinking } => { - thread - .update(cx, |thread, cx| { - thread.push_assistant_content_block(thinking.into(), true, cx) - }) - .log_err(); - } - ContentChunk::RedactedThinking => { - thread - .update(cx, |thread, cx| { - thread.push_assistant_content_block( - "[REDACTED]".into(), - true, - cx, - ) - }) - .log_err(); - } - ContentChunk::ToolUse { id, name, input } => { - let claude_tool = ClaudeTool::infer(&name, input); - - thread - .update(cx, |thread, cx| { - if let ClaudeTool::TodoWrite(Some(params)) = claude_tool { - thread.update_plan( - acp::Plan { - entries: params - .todos - .into_iter() - .map(Into::into) - .collect(), - }, - cx, - ) - } else { - thread.upsert_tool_call( - claude_tool.as_acp(acp::ToolCallId(id.into())), - cx, - ); - } - }) - .log_err(); - } - ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => { - debug_panic!( - "Should not get tool results with role: assistant. should we handle this?" 
- ); - } - ContentChunk::Image | ContentChunk::Document => { - thread - .update(cx, |thread, cx| { - thread.push_assistant_content_block( - format!("Unsupported content: {:?}", chunk).into(), - false, - cx, - ) - }) - .log_err(); - } - } - } - } - SdkMessage::Result { - is_error, - subtype, - result, - .. - } => { - let turn_state = turn_state.take(); - let was_cancelled = turn_state.is_cancelled(); - let Some(end_turn_tx) = turn_state.end_tx() else { - debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn"); - return; - }; - - if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution) - { - end_turn_tx - .send(Err(anyhow!( - "Error: {}", - result.unwrap_or_else(|| subtype.to_string()) - ))) - .ok(); - } else { - let stop_reason = match subtype { - ResultErrorType::Success => acp::StopReason::EndTurn, - ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests, - ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled, - }; - end_turn_tx - .send(Ok(acp::PromptResponse { stop_reason })) - .ok(); - } - } - SdkMessage::ControlResponse { response } => { - if matches!(response.subtype, ResultErrorType::Success) { - let new_state = turn_state.take().confirm_cancellation(&response.request_id); - turn_state.replace(new_state); - } else { - log::error!("Control response error: {:?}", response); - } - } - SdkMessage::System { .. } => {} - } - } - - async fn handle_io( - mut outgoing_rx: UnboundedReceiver, - incoming_tx: UnboundedSender, - mut outgoing_bytes: impl Unpin + AsyncWrite, - incoming_bytes: impl Unpin + AsyncRead, - ) -> Result> { - let mut output_reader = BufReader::new(incoming_bytes); - let mut outgoing_line = Vec::new(); - let mut incoming_line = String::new(); - loop { - select_biased! 
{ - message = outgoing_rx.next() => { - if let Some(message) = message { - outgoing_line.clear(); - serde_json::to_writer(&mut outgoing_line, &message)?; - log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line)); - outgoing_line.push(b'\n'); - outgoing_bytes.write_all(&outgoing_line).await.ok(); - } else { - break; - } - } - bytes_read = output_reader.read_line(&mut incoming_line).fuse() => { - if bytes_read? == 0 { - break - } - log::trace!("recv: {}", &incoming_line); - match serde_json::from_str::(&incoming_line) { - Ok(message) => { - incoming_tx.unbounded_send(message).log_err(); - } - Err(error) => { - log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}"); - } - } - incoming_line.clear(); - } - } - } - - Ok(outgoing_rx) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct Message { - role: Role, - content: Content, - #[serde(skip_serializing_if = "Option::is_none")] - id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - model: Option, - #[serde(skip_serializing_if = "Option::is_none")] - stop_reason: Option, - #[serde(skip_serializing_if = "Option::is_none")] - stop_sequence: Option, - #[serde(skip_serializing_if = "Option::is_none")] - usage: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] -enum Content { - UntaggedText(String), - Chunks(Vec), -} - -impl Content { - pub fn chunks(self) -> impl Iterator { - match self { - Self::Chunks(chunks) => chunks.into_iter(), - Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(), - } - } -} - -impl Display for Content { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Content::UntaggedText(txt) => write!(f, "{}", txt), - Content::Chunks(chunks) => { - for chunk in chunks { - write!(f, "{}", chunk)?; - } - Ok(()) - } - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum ContentChunk { - Text 
{ - text: String, - }, - ToolUse { - id: String, - name: String, - input: serde_json::Value, - }, - ToolResult { - content: Content, - tool_use_id: String, - }, - Thinking { - thinking: String, - }, - RedactedThinking, - // TODO - Image, - Document, - WebSearchToolResult, - #[serde(untagged)] - UntaggedText(String), -} - -impl Display for ContentChunk { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ContentChunk::Text { text } => write!(f, "{}", text), - ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking), - ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"), - ContentChunk::UntaggedText(text) => write!(f, "{}", text), - ContentChunk::ToolResult { content, .. } => write!(f, "{}", content), - ContentChunk::Image - | ContentChunk::Document - | ContentChunk::ToolUse { .. } - | ContentChunk::WebSearchToolResult => { - write!(f, "\n{:?}\n", &self) - } - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct Usage { - input_tokens: u32, - cache_creation_input_tokens: u32, - cache_read_input_tokens: u32, - output_tokens: u32, - service_tier: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -enum Role { - System, - Assistant, - User, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct MessageParam { - role: Role, - content: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum SdkMessage { - // An assistant message - Assistant { - message: Message, // from Anthropic SDK - #[serde(skip_serializing_if = "Option::is_none")] - session_id: Option, - }, - // A user message - User { - message: Message, // from Anthropic SDK - #[serde(skip_serializing_if = "Option::is_none")] - session_id: Option, - }, - // Emitted as the last message in a conversation - Result { - subtype: ResultErrorType, - duration_ms: f64, - duration_api_ms: f64, - is_error: bool, - num_turns: 
i32, - #[serde(skip_serializing_if = "Option::is_none")] - result: Option, - session_id: String, - total_cost_usd: f64, - }, - // Emitted as the first message at the start of a conversation - System { - cwd: String, - session_id: String, - tools: Vec, - model: String, - mcp_servers: Vec, - #[serde(rename = "apiKeySource")] - api_key_source: String, - #[serde(rename = "permissionMode")] - permission_mode: PermissionMode, - }, - /// Messages used to control the conversation, outside of chat messages to the model - ControlRequest { - request_id: String, - request: ControlRequest, - }, - /// Response to a control request - ControlResponse { response: ControlResponse }, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "subtype", rename_all = "snake_case")] -enum ControlRequest { - /// Cancel the current conversation - Interrupt, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct ControlResponse { - request_id: String, - subtype: ResultErrorType, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] -#[serde(rename_all = "snake_case")] -enum ResultErrorType { - Success, - ErrorMaxTurns, - ErrorDuringExecution, -} - -impl Display for ResultErrorType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ResultErrorType::Success => write!(f, "success"), - ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"), - ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"), - } - } -} - -fn new_request_id() -> String { - use rand::Rng; - // In the Claude Code TS SDK they just generate a random 12 character string, - // `Math.random().toString(36).substring(2, 15)` - rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(12) - .map(char::from) - .collect() -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct McpServer { - name: String, - status: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = 
"camelCase")] -enum PermissionMode { - Default, - AcceptEdits, - BypassPermissions, - Plan, -} - -#[cfg(test)] -pub(crate) mod tests { - use super::*; - use crate::e2e_tests; - use gpui::TestAppContext; - use serde_json::json; - - crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow"); - - pub fn local_command() -> AgentServerCommand { - AgentServerCommand { - path: "claude".into(), - args: vec![], - env: None, - } - } - - #[gpui::test] - #[cfg_attr(not(feature = "e2e"), ignore)] - async fn test_todo_plan(cx: &mut TestAppContext) { - let fs = e2e_tests::init_test(cx).await; - let project = Project::test(fs, [], cx).await; - let thread = - e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await; - - thread - .update(cx, |thread, cx| { - thread.send_raw( - "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.", - cx, - ) - }) - .await - .unwrap(); - - let mut entries_len = 0; - - thread.read_with(cx, |thread, _| { - entries_len = thread.plan().entries.len(); - assert!(thread.plan().entries.len() > 0, "Empty plan"); - }); - - thread - .update(cx, |thread, cx| { - thread.send_raw( - "Mark the first entry status as in progress without acting on it.", - cx, - ) - }) - .await - .unwrap(); - - thread.read_with(cx, |thread, _| { - assert!(matches!( - thread.plan().entries[0].status, - acp::PlanEntryStatus::InProgress - )); - assert_eq!(thread.plan().entries.len(), entries_len); - }); - - thread - .update(cx, |thread, cx| { - thread.send_raw( - "Now mark the first entry as completed without acting on it.", - cx, - ) - }) - .await - .unwrap(); - - thread.read_with(cx, |thread, _| { - assert!(matches!( - thread.plan().entries[0].status, - acp::PlanEntryStatus::Completed - )); - assert_eq!(thread.plan().entries.len(), entries_len); - }); - } - - #[test] - fn test_deserialize_content_untagged_text() { - let json = json!("Hello, world!"); - let content: Content = serde_json::from_value(json).unwrap(); 
- match content { - Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"), - _ => panic!("Expected UntaggedText variant"), - } - } - - #[test] - fn test_deserialize_content_chunks() { - let json = json!([ - { - "type": "text", - "text": "Hello" - }, - { - "type": "tool_use", - "id": "tool_123", - "name": "calculator", - "input": {"operation": "add", "a": 1, "b": 2} - } - ]); - let content: Content = serde_json::from_value(json).unwrap(); - match content { - Content::Chunks(chunks) => { - assert_eq!(chunks.len(), 2); - match &chunks[0] { - ContentChunk::Text { text } => assert_eq!(text, "Hello"), - _ => panic!("Expected Text chunk"), - } - match &chunks[1] { - ContentChunk::ToolUse { id, name, input } => { - assert_eq!(id, "tool_123"); - assert_eq!(name, "calculator"); - assert_eq!(input["operation"], "add"); - assert_eq!(input["a"], 1); - assert_eq!(input["b"], 2); - } - _ => panic!("Expected ToolUse chunk"), - } - } - _ => panic!("Expected Chunks variant"), - } - } - - #[test] - fn test_deserialize_tool_result_untagged_text() { - let json = json!({ - "type": "tool_result", - "content": "Result content", - "tool_use_id": "tool_456" - }); - let chunk: ContentChunk = serde_json::from_value(json).unwrap(); - match chunk { - ContentChunk::ToolResult { - content, - tool_use_id, - } => { - match content { - Content::UntaggedText(text) => assert_eq!(text, "Result content"), - _ => panic!("Expected UntaggedText content"), - } - assert_eq!(tool_use_id, "tool_456"); - } - _ => panic!("Expected ToolResult variant"), - } - } - - #[test] - fn test_deserialize_tool_result_chunks() { - let json = json!({ - "type": "tool_result", - "content": [ - { - "type": "text", - "text": "Processing complete" - }, - { - "type": "text", - "text": "Result: 42" - } - ], - "tool_use_id": "tool_789" - }); - let chunk: ContentChunk = serde_json::from_value(json).unwrap(); - match chunk { - ContentChunk::ToolResult { - content, - tool_use_id, - } => { - match content { - 
Content::Chunks(chunks) => { - assert_eq!(chunks.len(), 2); - match &chunks[0] { - ContentChunk::Text { text } => assert_eq!(text, "Processing complete"), - _ => panic!("Expected Text chunk"), - } - match &chunks[1] { - ContentChunk::Text { text } => assert_eq!(text, "Result: 42"), - _ => panic!("Expected Text chunk"), - } - } - _ => panic!("Expected Chunks content"), - } - assert_eq!(tool_use_id, "tool_789"); - } - _ => panic!("Expected ToolResult variant"), - } + fn into_any(self: Rc) -> Rc { + self } } diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs deleted file mode 100644 index 53a8556e74545bc339936d0b2f9f78444190af0c..0000000000000000000000000000000000000000 --- a/crates/agent_servers/src/claude/mcp_server.rs +++ /dev/null @@ -1,302 +0,0 @@ -use std::path::PathBuf; - -use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams}; -use acp_thread::AcpThread; -use agent_client_protocol as acp; -use anyhow::{Context, Result}; -use collections::HashMap; -use context_server::listener::{McpServerTool, ToolResponse}; -use context_server::types::{ - Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities, - ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests, -}; -use gpui::{App, AsyncApp, Task, WeakEntity}; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; - -pub struct ClaudeZedMcpServer { - server: context_server::listener::McpServer, -} - -pub const SERVER_NAME: &str = "zed"; - -impl ClaudeZedMcpServer { - pub async fn new( - thread_rx: watch::Receiver>, - cx: &AsyncApp, - ) -> Result { - let mut mcp_server = context_server::listener::McpServer::new(cx).await?; - mcp_server.handle_request::(Self::handle_initialize); - - mcp_server.add_tool(PermissionTool { - thread_rx: thread_rx.clone(), - }); - mcp_server.add_tool(ReadTool { - thread_rx: thread_rx.clone(), - }); - mcp_server.add_tool(EditTool { - thread_rx: thread_rx.clone(), - }); - 
- Ok(Self { server: mcp_server }) - } - - pub fn server_config(&self) -> Result { - #[cfg(not(test))] - let zed_path = std::env::current_exe() - .context("finding current executable path for use in mcp_server")?; - - #[cfg(test)] - let zed_path = crate::e2e_tests::get_zed_path(); - - Ok(McpServerConfig { - command: zed_path, - args: vec![ - "--nc".into(), - self.server.socket_path().display().to_string(), - ], - env: None, - }) - } - - fn handle_initialize(_: InitializeParams, cx: &App) -> Task> { - cx.foreground_executor().spawn(async move { - Ok(InitializeResponse { - protocol_version: ProtocolVersion("2025-06-18".into()), - capabilities: ServerCapabilities { - experimental: None, - logging: None, - completions: None, - prompts: None, - resources: None, - tools: Some(ToolsCapabilities { - list_changed: Some(false), - }), - }, - server_info: Implementation { - name: SERVER_NAME.into(), - version: "0.1.0".into(), - }, - meta: None, - }) - }) - } -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -pub struct McpConfig { - pub mcp_servers: HashMap, -} - -#[derive(Serialize, Clone)] -#[serde(rename_all = "camelCase")] -pub struct McpServerConfig { - pub command: PathBuf, - pub args: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub env: Option>, -} - -// Tools - -#[derive(Clone)] -pub struct PermissionTool { - thread_rx: watch::Receiver>, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct PermissionToolParams { - tool_name: String, - input: serde_json::Value, - tool_use_id: Option, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -pub struct PermissionToolResponse { - behavior: PermissionToolBehavior, - updated_input: serde_json::Value, -} - -#[derive(Serialize)] -#[serde(rename_all = "snake_case")] -enum PermissionToolBehavior { - Allow, - Deny, -} - -impl McpServerTool for PermissionTool { - type Input = PermissionToolParams; - type Output = (); - - const NAME: &'static str = "Confirmation"; - - fn description(&self) 
-> &'static str { - "Request permission for tool calls" - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone()); - let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into()); - let allow_option_id = acp::PermissionOptionId("allow".into()); - let reject_option_id = acp::PermissionOptionId("reject".into()); - - let chosen_option = thread - .update(cx, |thread, cx| { - thread.request_tool_call_authorization( - claude_tool.as_acp(tool_call_id), - vec![ - acp::PermissionOption { - id: allow_option_id.clone(), - name: "Allow".into(), - kind: acp::PermissionOptionKind::AllowOnce, - }, - acp::PermissionOption { - id: reject_option_id.clone(), - name: "Reject".into(), - kind: acp::PermissionOptionKind::RejectOnce, - }, - ], - cx, - ) - })? - .await?; - - let response = if chosen_option == allow_option_id { - PermissionToolResponse { - behavior: PermissionToolBehavior::Allow, - updated_input: input.input, - } - } else { - debug_assert_eq!(chosen_option, reject_option_id); - PermissionToolResponse { - behavior: PermissionToolBehavior::Deny, - updated_input: input.input, - } - }; - - Ok(ToolResponse { - content: vec![ToolResponseContent::Text { - text: serde_json::to_string(&response)?, - }], - structured_content: (), - }) - } -} - -#[derive(Clone)] -pub struct ReadTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for ReadTool { - type Input = ReadToolParams; - type Output = (); - - const NAME: &'static str = "Read"; - - fn description(&self) -> &'static str { - "Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents." 
- } - - fn annotations(&self) -> ToolAnnotations { - ToolAnnotations { - title: Some("Read file".to_string()), - read_only_hint: Some(true), - destructive_hint: Some(false), - open_world_hint: Some(false), - idempotent_hint: None, - } - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - let content = thread - .update(cx, |thread, cx| { - thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx) - })? - .await?; - - Ok(ToolResponse { - content: vec![ToolResponseContent::Text { text: content }], - structured_content: (), - }) - } -} - -#[derive(Clone)] -pub struct EditTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for EditTool { - type Input = EditToolParams; - type Output = (); - - const NAME: &'static str = "Edit"; - - fn description(&self) -> &'static str { - "Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better." - } - - fn annotations(&self) -> ToolAnnotations { - ToolAnnotations { - title: Some("Edit file".to_string()), - read_only_hint: Some(false), - destructive_hint: Some(false), - open_world_hint: Some(false), - idempotent_hint: Some(false), - } - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - let content = thread - .update(cx, |thread, cx| { - thread.read_text_file(input.abs_path.clone(), None, None, true, cx) - })? 
- .await?; - - let new_content = content.replace(&input.old_text, &input.new_text); - if new_content == content { - return Err(anyhow::anyhow!("The old_text was not found in the content")); - } - - thread - .update(cx, |thread, cx| { - thread.write_text_file(input.abs_path, new_content, cx) - })? - .await?; - - Ok(ToolResponse { - content: vec![], - structured_content: (), - }) - } -} diff --git a/crates/agent_servers/src/claude/tools.rs b/crates/agent_servers/src/claude/tools.rs deleted file mode 100644 index 7ca150c0bd0b30b958a4791db9d01684d16460d6..0000000000000000000000000000000000000000 --- a/crates/agent_servers/src/claude/tools.rs +++ /dev/null @@ -1,661 +0,0 @@ -use std::path::PathBuf; - -use agent_client_protocol as acp; -use itertools::Itertools; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use util::ResultExt; - -pub enum ClaudeTool { - Task(Option), - NotebookRead(Option), - NotebookEdit(Option), - Edit(Option), - MultiEdit(Option), - ReadFile(Option), - Write(Option), - Ls(Option), - Glob(Option), - Grep(Option), - Terminal(Option), - WebFetch(Option), - WebSearch(Option), - TodoWrite(Option), - ExitPlanMode(Option), - Other { - name: String, - input: serde_json::Value, - }, -} - -impl ClaudeTool { - pub fn infer(tool_name: &str, input: serde_json::Value) -> Self { - match tool_name { - // Known tools - "mcp__zed__Read" => Self::ReadFile(serde_json::from_value(input).log_err()), - "mcp__zed__Edit" => Self::Edit(serde_json::from_value(input).log_err()), - "MultiEdit" => Self::MultiEdit(serde_json::from_value(input).log_err()), - "Write" => Self::Write(serde_json::from_value(input).log_err()), - "LS" => Self::Ls(serde_json::from_value(input).log_err()), - "Glob" => Self::Glob(serde_json::from_value(input).log_err()), - "Grep" => Self::Grep(serde_json::from_value(input).log_err()), - "Bash" => Self::Terminal(serde_json::from_value(input).log_err()), - "WebFetch" => Self::WebFetch(serde_json::from_value(input).log_err()), - "WebSearch" 
=> Self::WebSearch(serde_json::from_value(input).log_err()), - "TodoWrite" => Self::TodoWrite(serde_json::from_value(input).log_err()), - "exit_plan_mode" => Self::ExitPlanMode(serde_json::from_value(input).log_err()), - "Task" => Self::Task(serde_json::from_value(input).log_err()), - "NotebookRead" => Self::NotebookRead(serde_json::from_value(input).log_err()), - "NotebookEdit" => Self::NotebookEdit(serde_json::from_value(input).log_err()), - // Inferred from name - _ => { - let tool_name = tool_name.to_lowercase(); - - if tool_name.contains("edit") || tool_name.contains("write") { - Self::Edit(None) - } else if tool_name.contains("terminal") { - Self::Terminal(None) - } else { - Self::Other { - name: tool_name.to_string(), - input, - } - } - } - } - } - - pub fn label(&self) -> String { - match &self { - Self::Task(Some(params)) => params.description.clone(), - Self::Task(None) => "Task".into(), - Self::NotebookRead(Some(params)) => { - format!("Read Notebook {}", params.notebook_path.display()) - } - Self::NotebookRead(None) => "Read Notebook".into(), - Self::NotebookEdit(Some(params)) => { - format!("Edit Notebook {}", params.notebook_path.display()) - } - Self::NotebookEdit(None) => "Edit Notebook".into(), - Self::Terminal(Some(params)) => format!("`{}`", params.command), - Self::Terminal(None) => "Terminal".into(), - Self::ReadFile(_) => "Read File".into(), - Self::Ls(Some(params)) => { - format!("List Directory {}", params.path.display()) - } - Self::Ls(None) => "List Directory".into(), - Self::Edit(Some(params)) => { - format!("Edit {}", params.abs_path.display()) - } - Self::Edit(None) => "Edit".into(), - Self::MultiEdit(Some(params)) => { - format!("Multi Edit {}", params.file_path.display()) - } - Self::MultiEdit(None) => "Multi Edit".into(), - Self::Write(Some(params)) => { - format!("Write {}", params.file_path.display()) - } - Self::Write(None) => "Write".into(), - Self::Glob(Some(params)) => { - format!("Glob `{params}`") - } - Self::Glob(None) => 
"Glob".into(), - Self::Grep(Some(params)) => format!("`{params}`"), - Self::Grep(None) => "Grep".into(), - Self::WebFetch(Some(params)) => format!("Fetch {}", params.url), - Self::WebFetch(None) => "Fetch".into(), - Self::WebSearch(Some(params)) => format!("Web Search: {}", params), - Self::WebSearch(None) => "Web Search".into(), - Self::TodoWrite(Some(params)) => format!( - "Update TODOs: {}", - params.todos.iter().map(|todo| &todo.content).join(", ") - ), - Self::TodoWrite(None) => "Update TODOs".into(), - Self::ExitPlanMode(_) => "Exit Plan Mode".into(), - Self::Other { name, .. } => name.clone(), - } - } - pub fn content(&self) -> Vec { - match &self { - Self::Other { input, .. } => vec![ - format!( - "```json\n{}```", - serde_json::to_string_pretty(&input).unwrap_or("{}".to_string()) - ) - .into(), - ], - Self::Task(Some(params)) => vec![params.prompt.clone().into()], - Self::NotebookRead(Some(params)) => { - vec![params.notebook_path.display().to_string().into()] - } - Self::NotebookEdit(Some(params)) => vec![params.new_source.clone().into()], - Self::Terminal(Some(params)) => vec![ - format!( - "`{}`\n\n{}", - params.command, - params.description.as_deref().unwrap_or_default() - ) - .into(), - ], - Self::ReadFile(Some(params)) => vec![params.abs_path.display().to_string().into()], - Self::Ls(Some(params)) => vec![params.path.display().to_string().into()], - Self::Glob(Some(params)) => vec![params.to_string().into()], - Self::Grep(Some(params)) => vec![format!("`{params}`").into()], - Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()], - Self::WebSearch(Some(params)) => vec![params.to_string().into()], - Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()], - Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff { - diff: acp::Diff { - path: params.abs_path.clone(), - old_text: Some(params.old_text.clone()), - new_text: params.new_text.clone(), - }, - }], - Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff { - 
diff: acp::Diff { - path: params.file_path.clone(), - old_text: None, - new_text: params.content.clone(), - }, - }], - Self::MultiEdit(Some(params)) => { - // todo: show multiple edits in a multibuffer? - params - .edits - .first() - .map(|edit| { - vec![acp::ToolCallContent::Diff { - diff: acp::Diff { - path: params.file_path.clone(), - old_text: Some(edit.old_string.clone()), - new_text: edit.new_string.clone(), - }, - }] - }) - .unwrap_or_default() - } - Self::TodoWrite(Some(_)) => { - // These are mapped to plan updates later - vec![] - } - Self::Task(None) - | Self::NotebookRead(None) - | Self::NotebookEdit(None) - | Self::Terminal(None) - | Self::ReadFile(None) - | Self::Ls(None) - | Self::Glob(None) - | Self::Grep(None) - | Self::WebFetch(None) - | Self::WebSearch(None) - | Self::TodoWrite(None) - | Self::ExitPlanMode(None) - | Self::Edit(None) - | Self::Write(None) - | Self::MultiEdit(None) => vec![], - } - } - - pub fn kind(&self) -> acp::ToolKind { - match self { - Self::Task(_) => acp::ToolKind::Think, - Self::NotebookRead(_) => acp::ToolKind::Read, - Self::NotebookEdit(_) => acp::ToolKind::Edit, - Self::Edit(_) => acp::ToolKind::Edit, - Self::MultiEdit(_) => acp::ToolKind::Edit, - Self::Write(_) => acp::ToolKind::Edit, - Self::ReadFile(_) => acp::ToolKind::Read, - Self::Ls(_) => acp::ToolKind::Search, - Self::Glob(_) => acp::ToolKind::Search, - Self::Grep(_) => acp::ToolKind::Search, - Self::Terminal(_) => acp::ToolKind::Execute, - Self::WebSearch(_) => acp::ToolKind::Search, - Self::WebFetch(_) => acp::ToolKind::Fetch, - Self::TodoWrite(_) => acp::ToolKind::Think, - Self::ExitPlanMode(_) => acp::ToolKind::Think, - Self::Other { .. } => acp::ToolKind::Other, - } - } - - pub fn locations(&self) -> Vec { - match &self { - Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![acp::ToolCallLocation { - path: abs_path.clone(), - line: None, - }], - Self::MultiEdit(Some(MultiEditToolParams { file_path, .. 
})) => { - vec![acp::ToolCallLocation { - path: file_path.clone(), - line: None, - }] - } - Self::Write(Some(WriteToolParams { file_path, .. })) => { - vec![acp::ToolCallLocation { - path: file_path.clone(), - line: None, - }] - } - Self::ReadFile(Some(ReadToolParams { - abs_path, offset, .. - })) => vec![acp::ToolCallLocation { - path: abs_path.clone(), - line: *offset, - }], - Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => { - vec![acp::ToolCallLocation { - path: notebook_path.clone(), - line: None, - }] - } - Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => { - vec![acp::ToolCallLocation { - path: notebook_path.clone(), - line: None, - }] - } - Self::Glob(Some(GlobToolParams { - path: Some(path), .. - })) => vec![acp::ToolCallLocation { - path: path.clone(), - line: None, - }], - Self::Ls(Some(LsToolParams { path, .. })) => vec![acp::ToolCallLocation { - path: path.clone(), - line: None, - }], - Self::Grep(Some(GrepToolParams { - path: Some(path), .. - })) => vec![acp::ToolCallLocation { - path: PathBuf::from(path), - line: None, - }], - Self::Task(_) - | Self::NotebookRead(None) - | Self::NotebookEdit(None) - | Self::Edit(None) - | Self::MultiEdit(None) - | Self::Write(None) - | Self::ReadFile(None) - | Self::Ls(None) - | Self::Glob(_) - | Self::Grep(_) - | Self::Terminal(_) - | Self::WebFetch(_) - | Self::WebSearch(_) - | Self::TodoWrite(_) - | Self::ExitPlanMode(_) - | Self::Other { .. } => vec![], - } - } - - pub fn as_acp(&self, id: acp::ToolCallId) -> acp::ToolCall { - acp::ToolCall { - id, - kind: self.kind(), - status: acp::ToolCallStatus::InProgress, - title: self.label(), - content: self.content(), - locations: self.locations(), - raw_input: None, - raw_output: None, - } - } -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct EditToolParams { - /// The absolute path to the file to read. 
- pub abs_path: PathBuf, - /// The old text to replace (must be unique in the file) - pub old_text: String, - /// The new text. - pub new_text: String, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct ReadToolParams { - /// The absolute path to the file to read. - pub abs_path: PathBuf, - /// Which line to start reading from. Omit to start from the beginning. - #[serde(skip_serializing_if = "Option::is_none")] - pub offset: Option, - /// How many lines to read. Omit for the whole file. - #[serde(skip_serializing_if = "Option::is_none")] - pub limit: Option, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct WriteToolParams { - /// Absolute path for new file - pub file_path: PathBuf, - /// File content - pub content: String, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct BashToolParams { - /// Shell command to execute - pub command: String, - /// 5-10 word description of what command does - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Timeout in ms (max 600000ms/10min, default 120000ms) - #[serde(skip_serializing_if = "Option::is_none")] - pub timeout: Option, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct GlobToolParams { - /// Glob pattern like **/*.js or src/**/*.ts - pub pattern: String, - /// Directory to search in (omit for current directory) - #[serde(skip_serializing_if = "Option::is_none")] - pub path: Option, -} - -impl std::fmt::Display for GlobToolParams { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(path) = &self.path { - write!(f, "{}", path.display())?; - } - write!(f, "{}", self.pattern) - } -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct LsToolParams { - /// Absolute path to directory - pub path: PathBuf, - /// Array of glob patterns to ignore - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub ignore: Vec, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct GrepToolParams { - /// Regex pattern 
to search for - pub pattern: String, - /// File/directory to search (defaults to current directory) - #[serde(skip_serializing_if = "Option::is_none")] - pub path: Option, - /// "content" (shows lines), "files_with_matches" (default), "count" - #[serde(skip_serializing_if = "Option::is_none")] - pub output_mode: Option, - /// Filter files with glob pattern like "*.js" - #[serde(skip_serializing_if = "Option::is_none")] - pub glob: Option, - /// File type filter like "js", "py", "rust" - #[serde(rename = "type", skip_serializing_if = "Option::is_none")] - pub file_type: Option, - /// Case insensitive search - #[serde(rename = "-i", default, skip_serializing_if = "is_false")] - pub case_insensitive: bool, - /// Show line numbers (content mode only) - #[serde(rename = "-n", default, skip_serializing_if = "is_false")] - pub line_numbers: bool, - /// Lines after match (content mode only) - #[serde(rename = "-A", skip_serializing_if = "Option::is_none")] - pub after_context: Option, - /// Lines before match (content mode only) - #[serde(rename = "-B", skip_serializing_if = "Option::is_none")] - pub before_context: Option, - /// Lines before and after match (content mode only) - #[serde(rename = "-C", skip_serializing_if = "Option::is_none")] - pub context: Option, - /// Enable multiline/cross-line matching - #[serde(default, skip_serializing_if = "is_false")] - pub multiline: bool, - /// Limit output to first N results - #[serde(skip_serializing_if = "Option::is_none")] - pub head_limit: Option, -} - -impl std::fmt::Display for GrepToolParams { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "grep")?; - - // Boolean flags - if self.case_insensitive { - write!(f, " -i")?; - } - if self.line_numbers { - write!(f, " -n")?; - } - - // Context options - if let Some(after) = self.after_context { - write!(f, " -A {}", after)?; - } - if let Some(before) = self.before_context { - write!(f, " -B {}", before)?; - } - if let Some(context) = 
self.context { - write!(f, " -C {}", context)?; - } - - // Output mode - if let Some(mode) = &self.output_mode { - match mode { - GrepOutputMode::FilesWithMatches => write!(f, " -l")?, - GrepOutputMode::Count => write!(f, " -c")?, - GrepOutputMode::Content => {} // Default mode - } - } - - // Head limit - if let Some(limit) = self.head_limit { - write!(f, " | head -{}", limit)?; - } - - // Glob pattern - if let Some(glob) = &self.glob { - write!(f, " --include=\"{}\"", glob)?; - } - - // File type - if let Some(file_type) = &self.file_type { - write!(f, " --type={}", file_type)?; - } - - // Multiline - if self.multiline { - write!(f, " -P")?; // Perl-compatible regex for multiline - } - - // Pattern (escaped if contains special characters) - write!(f, " \"{}\"", self.pattern)?; - - // Path - if let Some(path) = &self.path { - write!(f, " {}", path)?; - } - - Ok(()) - } -} - -#[derive(Default, Deserialize, Serialize, JsonSchema, strum::Display, Debug)] -#[serde(rename_all = "snake_case")] -pub enum TodoPriority { - High, - #[default] - Medium, - Low, -} - -impl Into for TodoPriority { - fn into(self) -> acp::PlanEntryPriority { - match self { - TodoPriority::High => acp::PlanEntryPriority::High, - TodoPriority::Medium => acp::PlanEntryPriority::Medium, - TodoPriority::Low => acp::PlanEntryPriority::Low, - } - } -} - -#[derive(Deserialize, Serialize, JsonSchema, Debug)] -#[serde(rename_all = "snake_case")] -pub enum TodoStatus { - Pending, - InProgress, - Completed, -} - -impl Into for TodoStatus { - fn into(self) -> acp::PlanEntryStatus { - match self { - TodoStatus::Pending => acp::PlanEntryStatus::Pending, - TodoStatus::InProgress => acp::PlanEntryStatus::InProgress, - TodoStatus::Completed => acp::PlanEntryStatus::Completed, - } - } -} - -#[derive(Deserialize, Serialize, JsonSchema, Debug)] -pub struct Todo { - /// Task description - pub content: String, - /// Current status of the todo - pub status: TodoStatus, - /// Priority level of the todo - 
#[serde(default)] - pub priority: TodoPriority, -} - -impl Into for Todo { - fn into(self) -> acp::PlanEntry { - acp::PlanEntry { - content: self.content, - priority: self.priority.into(), - status: self.status.into(), - } - } -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct TodoWriteToolParams { - pub todos: Vec, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct ExitPlanModeToolParams { - /// Implementation plan in markdown format - pub plan: String, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct TaskToolParams { - /// Short 3-5 word description of task - pub description: String, - /// Detailed task for agent to perform - pub prompt: String, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct NotebookReadToolParams { - /// Absolute path to .ipynb file - pub notebook_path: PathBuf, - /// Specific cell ID to read - #[serde(skip_serializing_if = "Option::is_none")] - pub cell_id: Option, -} - -#[derive(Deserialize, Serialize, JsonSchema, Debug)] -#[serde(rename_all = "snake_case")] -pub enum CellType { - Code, - Markdown, -} - -#[derive(Deserialize, Serialize, JsonSchema, Debug)] -#[serde(rename_all = "snake_case")] -pub enum EditMode { - Replace, - Insert, - Delete, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct NotebookEditToolParams { - /// Absolute path to .ipynb file - pub notebook_path: PathBuf, - /// New cell content - pub new_source: String, - /// Cell ID to edit - #[serde(skip_serializing_if = "Option::is_none")] - pub cell_id: Option, - /// Type of cell (code or markdown) - #[serde(skip_serializing_if = "Option::is_none")] - pub cell_type: Option, - /// Edit operation mode - #[serde(skip_serializing_if = "Option::is_none")] - pub edit_mode: Option, -} - -#[derive(Deserialize, Serialize, JsonSchema, Debug)] -pub struct MultiEditItem { - /// The text to search for and replace - pub old_string: String, - /// The replacement text - pub new_string: String, - /// Whether to replace all occurrences or just the 
first - #[serde(default, skip_serializing_if = "is_false")] - pub replace_all: bool, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct MultiEditToolParams { - /// Absolute path to file - pub file_path: PathBuf, - /// List of edits to apply - pub edits: Vec, -} - -fn is_false(v: &bool) -> bool { - !*v -} - -#[derive(Deserialize, JsonSchema, Debug)] -#[serde(rename_all = "snake_case")] -pub enum GrepOutputMode { - Content, - FilesWithMatches, - Count, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct WebFetchToolParams { - /// Valid URL to fetch - #[serde(rename = "url")] - pub url: String, - /// What to extract from content - pub prompt: String, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct WebSearchToolParams { - /// Search query (min 2 chars) - pub query: String, - /// Only include these domains - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub allowed_domains: Vec, - /// Exclude these domains - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub blocked_domains: Vec, -} - -impl std::fmt::Display for WebSearchToolParams { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "\"{}\"", self.query)?; - - if !self.allowed_domains.is_empty() { - write!(f, " (allowed: {})", self.allowed_domains.join(", "))?; - } - - if !self.blocked_domains.is_empty() { - write!(f, " (blocked: {})", self.blocked_domains.join(", "))?; - } - - Ok(()) - } -} diff --git a/crates/agent_servers/src/custom.rs b/crates/agent_servers/src/custom.rs new file mode 100644 index 0000000000000000000000000000000000000000..f035952a7939201e4b7d990b97e1fc695105d505 --- /dev/null +++ b/crates/agent_servers/src/custom.rs @@ -0,0 +1,102 @@ +use crate::AgentServerDelegate; +use acp_thread::AgentConnection; +use agent_client_protocol as acp; +use anyhow::{Context as _, Result}; +use fs::Fs; +use gpui::{App, AppContext as _, SharedString, Task}; +use project::agent_server_store::{AllAgentServersSettings, 
ExternalAgentServerName}; +use settings::{SettingsStore, update_settings_file}; +use std::{path::Path, rc::Rc, sync::Arc}; +use ui::IconName; + +/// A generic agent server implementation for custom user-defined agents +pub struct CustomAgentServer { + name: SharedString, +} + +impl CustomAgentServer { + pub fn new(name: SharedString) -> Self { + Self { name } + } +} + +impl crate::AgentServer for CustomAgentServer { + fn telemetry_id(&self) -> &'static str { + "custom" + } + + fn name(&self) -> SharedString { + self.name.clone() + } + + fn logo(&self) -> IconName { + IconName::Terminal + } + + fn default_mode(&self, cx: &mut App) -> Option { + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings + .get::(None) + .custom + .get(&self.name()) + .cloned() + }); + + settings + .as_ref() + .and_then(|s| s.default_mode.clone().map(|m| acp::SessionModeId(m.into()))) + } + + fn set_default_mode(&self, mode_id: Option, fs: Arc, cx: &mut App) { + let name = self.name(); + update_settings_file::(fs, cx, move |settings, _| { + settings.custom.get_mut(&name).unwrap().default_mode = mode_id.map(|m| m.to_string()) + }); + } + + fn connect( + &self, + root_dir: Option<&Path>, + delegate: AgentServerDelegate, + cx: &mut App, + ) -> Task, Option)>> { + let name = self.name(); + let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().to_string()); + let is_remote = delegate.project.read(cx).is_via_remote_server(); + let default_mode = self.default_mode(cx); + let store = delegate.store.downgrade(); + + cx.spawn(async move |cx| { + let (command, root_dir, login) = store + .update(cx, |store, cx| { + let agent = store + .get_external_agent(&ExternalAgentServerName(name.clone())) + .with_context(|| { + format!("Custom agent server `{}` is not registered", name) + })?; + anyhow::Ok(agent.get_command( + root_dir.as_deref(), + Default::default(), + delegate.status_tx, + delegate.new_version_available, + &mut cx.to_async(), + )) + })?? 
+ .await?; + let connection = crate::acp::connect( + name, + command, + root_dir.as_ref(), + default_mode, + is_remote, + cx, + ) + .await?; + Ok((connection, login)) + }) + } + + fn into_any(self: Rc) -> Rc { + self + } +} diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index ec6ca29b9dd1a902708a8786ddc6853955da5532..a4af1b6ad5c6048d27764653322c116f655f85fb 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -1,24 +1,33 @@ +use crate::{AgentServer, AgentServerDelegate}; +use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus}; +use agent_client_protocol as acp; +use futures::{FutureExt, StreamExt, channel::mpsc, select}; +use gpui::{AppContext, Entity, TestAppContext}; +use indoc::indoc; +#[cfg(test)] +use project::agent_server_store::BuiltinAgentServerSettings; +use project::{FakeFs, Project, agent_server_store::AllAgentServersSettings}; use std::{ path::{Path, PathBuf}, sync::Arc, time::Duration, }; - -use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings}; -use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus}; -use agent_client_protocol as acp; - -use futures::{FutureExt, StreamExt, channel::mpsc, select}; -use gpui::{Entity, TestAppContext}; -use indoc::indoc; -use project::{FakeFs, Project}; -use settings::{Settings, SettingsStore}; use util::path; -pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) { - let fs = init_test(cx).await; - let project = Project::test(fs, [], cx).await; - let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; +pub async fn test_basic(server: F, cx: &mut TestAppContext) +where + T: AgentServer + 'static, + F: AsyncFn(&Arc, &Entity, &mut TestAppContext) -> T, +{ + let fs = init_test(cx).await as Arc; + let project = Project::test(fs.clone(), [], cx).await; + let thread = new_test_thread( + server(&fs, &project, cx).await, + 
project.clone(), + "/private/tmp", + cx, + ) + .await; thread .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx)) @@ -42,8 +51,12 @@ pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppCont }); } -pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) { - let _fs = init_test(cx).await; +pub async fn test_path_mentions(server: F, cx: &mut TestAppContext) +where + T: AgentServer + 'static, + F: AsyncFn(&Arc, &Entity, &mut TestAppContext) -> T, +{ + let fs = init_test(cx).await as _; let tempdir = tempfile::tempdir().unwrap(); std::fs::write( @@ -56,7 +69,13 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes ) .expect("failed to write file"); let project = Project::example([tempdir.path()], &mut cx.to_async()).await; - let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await; + let thread = new_test_thread( + server(&fs, &project, cx).await, + project.clone(), + tempdir.path(), + cx, + ) + .await; thread .update(cx, |thread, cx| { thread.send( @@ -110,15 +129,25 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes drop(tempdir); } -pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) { - let _fs = init_test(cx).await; +pub async fn test_tool_call(server: F, cx: &mut TestAppContext) +where + T: AgentServer + 'static, + F: AsyncFn(&Arc, &Entity, &mut TestAppContext) -> T, +{ + let fs = init_test(cx).await as _; let tempdir = tempfile::tempdir().unwrap(); let foo_path = tempdir.path().join("foo"); std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file"); let project = Project::example([tempdir.path()], &mut cx.to_async()).await; - let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + let thread = new_test_thread( + server(&fs, &project, cx).await, + project.clone(), + "/private/tmp", + cx, + ) + .await; thread .update(cx, 
|thread, cx| { @@ -134,7 +163,9 @@ pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestApp matches!( entry, AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Allowed { .. }, + status: ToolCallStatus::Pending + | ToolCallStatus::InProgress + | ToolCallStatus::Completed, .. }) ) @@ -150,14 +181,23 @@ pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestApp drop(tempdir); } -pub async fn test_tool_call_with_permission( - server: impl AgentServer + 'static, +pub async fn test_tool_call_with_permission( + server: F, allow_option_id: acp::PermissionOptionId, cx: &mut TestAppContext, -) { - let fs = init_test(cx).await; - let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; - let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; +) where + T: AgentServer + 'static, + F: AsyncFn(&Arc, &Entity, &mut TestAppContext) -> T, +{ + let fs = init_test(cx).await as Arc; + let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await; + let thread = new_test_thread( + server(&fs, &project, cx).await, + project.clone(), + "/private/tmp", + cx, + ) + .await; let full_turn = thread.update(cx, |thread, cx| { thread.send_raw( r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, @@ -212,7 +252,9 @@ pub async fn test_tool_call_with_permission( assert!(thread.entries().iter().any(|entry| matches!( entry, AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Allowed { .. }, + status: ToolCallStatus::Pending + | ToolCallStatus::InProgress + | ToolCallStatus::Completed, .. }) ))); @@ -223,7 +265,9 @@ pub async fn test_tool_call_with_permission( thread.read_with(cx, |thread, cx| { let AgentThreadEntry::ToolCall(ToolCall { content, - status: ToolCallStatus::Allowed { .. }, + status: ToolCallStatus::Pending + | ToolCallStatus::InProgress + | ToolCallStatus::Completed, .. 
}) = thread .entries() @@ -241,11 +285,21 @@ pub async fn test_tool_call_with_permission( }); } -pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) { - let fs = init_test(cx).await; - - let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; - let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; +pub async fn test_cancel(server: F, cx: &mut TestAppContext) +where + T: AgentServer + 'static, + F: AsyncFn(&Arc, &Entity, &mut TestAppContext) -> T, +{ + let fs = init_test(cx).await as Arc; + + let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await; + let thread = new_test_thread( + server(&fs, &project, cx).await, + project.clone(), + "/private/tmp", + cx, + ) + .await; let _ = thread.update(cx, |thread, cx| { thread.send_raw( r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, @@ -310,10 +364,20 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon }); } -pub async fn test_thread_drop(server: impl AgentServer + 'static, cx: &mut TestAppContext) { - let fs = init_test(cx).await; - let project = Project::test(fs, [], cx).await; - let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; +pub async fn test_thread_drop(server: F, cx: &mut TestAppContext) +where + T: AgentServer + 'static, + F: AsyncFn(&Arc, &Entity, &mut TestAppContext) -> T, +{ + let fs = init_test(cx).await as Arc; + let project = Project::test(fs.clone(), [], cx).await; + let thread = new_test_thread( + server(&fs, &project, cx).await, + project.clone(), + "/private/tmp", + cx, + ) + .await; thread .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx)) @@ -380,27 +444,43 @@ macro_rules! 
common_e2e_tests { } }; } +pub use common_e2e_tests; // Helpers pub async fn init_test(cx: &mut TestAppContext) -> Arc { + use settings::Settings; + env_logger::try_init().ok(); cx.update(|cx| { - let settings_store = SettingsStore::test(cx); + let settings_store = settings::SettingsStore::test(cx); cx.set_global(settings_store); Project::init_settings(cx); language::init(cx); - crate::settings::init(cx); - - crate::AllAgentServersSettings::override_global( + gpui_tokio::init(cx); + let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap(); + cx.set_http_client(Arc::new(http_client)); + client::init_settings(cx); + let client = client::Client::production(cx); + let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store, client, cx); + agent_settings::init(cx); + AllAgentServersSettings::register(cx); + + #[cfg(test)] + AllAgentServersSettings::override_global( AllAgentServersSettings { - claude: Some(AgentServerSettings { - command: crate::claude::tests::local_command(), - }), - gemini: Some(AgentServerSettings { - command: crate::gemini::tests::local_command(), + claude: Some(BuiltinAgentServerSettings { + path: Some("claude-code-acp".into()), + args: None, + env: None, + ignore_system_version: None, + default_mode: None, }), + gemini: Some(crate::gemini::tests::local_command().into()), + custom: collections::HashMap::default(), }, cx, ); @@ -417,17 +497,17 @@ pub async fn new_test_thread( current_dir: impl AsRef, cx: &mut TestAppContext, ) -> Entity { - let connection = cx - .update(|cx| server.connect(current_dir.as_ref(), &project, cx)) - .await - .unwrap(); + let store = project.read_with(cx, |project, _| project.agent_server_store().clone()); + let delegate = AgentServerDelegate::new(store, project.clone(), None, None); - let thread = connection - .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async()) + let (connection, _) = 
cx + .update(|cx| server.connect(Some(current_dir.as_ref()), delegate, cx)) .await .unwrap(); - thread + cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx)) + .await + .unwrap() } pub async fn run_until_first_tool_call( @@ -465,7 +545,7 @@ pub fn get_zed_path() -> PathBuf { while zed_path .file_name() - .map_or(true, |name| name.to_string_lossy() != "debug") + .is_none_or(|name| name.to_string_lossy() != "debug") { if !zed_path.pop() { panic!("Could not find target directory"); diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index ad883f6da8bd344044e1db0051ca6f24120d5057..01f15557899e1c7826e91d1555320996eccd0f45 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -1,32 +1,26 @@ -use std::path::Path; use std::rc::Rc; - -use crate::{AgentServer, AgentServerCommand}; -use acp_thread::{AgentConnection, LoadError}; -use anyhow::Result; -use gpui::{Entity, Task}; -use project::Project; +use std::{any::Any, path::Path}; + +use crate::{AgentServer, AgentServerDelegate}; +use acp_thread::AgentConnection; +use anyhow::{Context as _, Result}; +use client::ProxySettings; +use collections::HashMap; +use gpui::{App, AppContext, SharedString, Task}; +use language_models::provider::google::GoogleLanguageModelProvider; +use project::agent_server_store::GEMINI_NAME; use settings::SettingsStore; -use ui::App; - -use crate::AllAgentServersSettings; #[derive(Clone)] pub struct Gemini; -const ACP_ARG: &str = "--experimental-acp"; - impl AgentServer for Gemini { - fn name(&self) -> &'static str { - "Gemini" - } - - fn empty_state_headline(&self) -> &'static str { - "Welcome to Gemini" + fn telemetry_id(&self) -> &'static str { + "gemini-cli" } - fn empty_state_message(&self) -> &'static str { - "Ask questions, edit files, run commands.\nBe specific for the best results." 
+ fn name(&self) -> SharedString { + "Gemini CLI".into() } fn logo(&self) -> ui::IconName { @@ -35,66 +29,73 @@ impl AgentServer for Gemini { fn connect( &self, - root_dir: &Path, - project: &Entity, + root_dir: Option<&Path>, + delegate: AgentServerDelegate, cx: &mut App, - ) -> Task>> { - let project = project.clone(); - let root_dir = root_dir.to_path_buf(); - let server_name = self.name(); - cx.spawn(async move |cx| { - let settings = cx.read_global(|settings: &SettingsStore, _| { - settings.get::(None).gemini.clone() - })?; - - let Some(command) = - AgentServerCommand::resolve("gemini", &[ACP_ARG], None, settings, &project, cx).await - else { - anyhow::bail!("Failed to find gemini binary"); - }; - - let result = crate::acp::connect(server_name, command.clone(), &root_dir, cx).await; - if result.is_err() { - let version_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--version") - .kill_on_drop(true) - .output(); - - let help_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--help") - .kill_on_drop(true) - .output(); + ) -> Task, Option)>> { + let name = self.name(); + let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().to_string()); + let is_remote = delegate.project.read(cx).is_via_remote_server(); + let store = delegate.store.downgrade(); + let proxy_url = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).proxy.clone() + }); + let default_mode = self.default_mode(cx); - let (version_output, help_output) = futures::future::join(version_fut, help_fut).await; - - let current_version = String::from_utf8(version_output?.stdout)?; - let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG); - - if !supported { - return Err(LoadError::Unsupported { - error_message: format!( - "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", - current_version - ).into(), - upgrade_message: "Upgrade Gemini to 
Latest".into(), - upgrade_command: "npm install -g @google/gemini-cli@latest".into(), - }.into()) - } + cx.spawn(async move |cx| { + let mut extra_env = HashMap::default(); + if let Some(api_key) = cx.update(GoogleLanguageModelProvider::api_key)?.await.ok() { + extra_env.insert("GEMINI_API_KEY".into(), api_key.key); + } + let (mut command, root_dir, login) = store + .update(cx, |store, cx| { + let agent = store + .get_external_agent(&GEMINI_NAME.into()) + .context("Gemini CLI is not registered")?; + anyhow::Ok(agent.get_command( + root_dir.as_deref(), + extra_env, + delegate.status_tx, + delegate.new_version_available, + &mut cx.to_async(), + )) + })?? + .await?; + + // Add proxy flag if proxy settings are configured in Zed and not in the args + if let Some(proxy_url_value) = &proxy_url + && !command.args.iter().any(|arg| arg.contains("--proxy")) + { + command.args.push("--proxy".into()); + command.args.push(proxy_url_value.clone()); } - result + + let connection = crate::acp::connect( + name, + command, + root_dir.as_ref(), + default_mode, + is_remote, + cx, + ) + .await?; + Ok((connection, login)) }) } + + fn into_any(self: Rc) -> Rc { + self + } } #[cfg(test)] pub(crate) mod tests { + use project::agent_server_store::AgentServerCommand; + use super::*; - use crate::AgentServerCommand; use std::path::Path; - crate::common_e2e_tests!(Gemini, allow_option_id = "proceed_once"); + crate::common_e2e_tests!(async |_, _, _| Gemini, allow_option_id = "proceed_once"); pub fn local_command() -> AgentServerCommand { let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) diff --git a/crates/agent_servers/src/settings.rs b/crates/agent_servers/src/settings.rs index 645674b5f15087250c2364fb9a8a846e163ad54c..9a610465be5516664dafd9cd4cb46be96ad89c8b 100644 --- a/crates/agent_servers/src/settings.rs +++ b/crates/agent_servers/src/settings.rs @@ -1,41 +1,121 @@ +use agent_client_protocol as acp; +use std::path::PathBuf; + use crate::AgentServerCommand; use anyhow::Result; -use 
gpui::App; +use collections::HashMap; +use gpui::{App, SharedString}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; pub fn init(cx: &mut App) { AllAgentServersSettings::register(cx); } -#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug)] +#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug, SettingsUi, SettingsKey)] +#[settings_key(key = "agent_servers")] pub struct AllAgentServersSettings { - pub gemini: Option, - pub claude: Option, + pub gemini: Option, + pub claude: Option, + + /// Custom agent servers configured by the user + #[serde(flatten)] + pub custom: HashMap, +} + +#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug, PartialEq)] +pub struct BuiltinAgentServerSettings { + /// Absolute path to a binary to be used when launching this agent. + /// + /// This can be used to run a specific binary without automatic downloads or searching `$PATH`. + #[serde(rename = "command")] + pub path: Option, + /// If a binary is specified in `command`, it will be passed these arguments. + pub args: Option>, + /// If a binary is specified in `command`, it will be passed these environment variables. + pub env: Option>, + /// Whether to skip searching `$PATH` for an agent server binary when + /// launching this agent. + /// + /// This has no effect if a `command` is specified. Otherwise, when this is + /// `false`, Zed will search `$PATH` for an agent server binary and, if one + /// is found, use it for threads with this agent. If no agent binary is + /// found on `$PATH`, Zed will automatically install and use its own binary. + /// When this is `true`, Zed will not search `$PATH`, and will always use + /// its own binary. + /// + /// Default: true + pub ignore_system_version: Option, + /// The default mode for new threads. + /// + /// Note: Not all agents support modes. 
+ /// + /// Default: None + #[serde(skip_serializing_if = "Option::is_none")] + pub default_mode: Option, +} + +impl BuiltinAgentServerSettings { + pub(crate) fn custom_command(self) -> Option { + self.path.map(|path| AgentServerCommand { + path, + args: self.args.unwrap_or_default(), + env: self.env, + }) + } } -#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)] -pub struct AgentServerSettings { +impl From for BuiltinAgentServerSettings { + fn from(value: AgentServerCommand) -> Self { + BuiltinAgentServerSettings { + path: Some(value.path), + args: Some(value.args), + env: value.env, + ..Default::default() + } + } +} + +#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug, PartialEq)] +pub struct CustomAgentServerSettings { #[serde(flatten)] pub command: AgentServerCommand, + /// The default mode for new threads. + /// + /// Note: Not all agents support modes. + /// + /// Default: None + #[serde(skip_serializing_if = "Option::is_none")] + pub default_mode: Option, } impl settings::Settings for AllAgentServersSettings { - const KEY: Option<&'static str> = Some("agent_servers"); - type FileContent = Self; fn load(sources: SettingsSources, _: &mut App) -> Result { let mut settings = AllAgentServersSettings::default(); - for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() { + for AllAgentServersSettings { + gemini, + claude, + custom, + } in sources.defaults_and_customizations() + { if gemini.is_some() { settings.gemini = gemini.clone(); } if claude.is_some() { settings.claude = claude.clone(); } + + // Merge custom agents + for (name, config) in custom { + // Skip built-in agent names to avoid conflicts + if name != "gemini" && name != "claude" { + settings.custom.insert(name.clone(), config.clone()); + } + } } Ok(settings) diff --git a/crates/agent_settings/src/agent_profile.rs b/crates/agent_settings/src/agent_profile.rs index 402cf81678e02a13c99bf4cdf225406085e3551d..04fdd4a753a3fd015f2710fb9d70770ad960c560 100644 
--- a/crates/agent_settings/src/agent_profile.rs +++ b/crates/agent_settings/src/agent_profile.rs @@ -58,7 +58,7 @@ impl AgentProfileSettings { || self .context_servers .get(server_id) - .map_or(false, |preset| preset.tools.get(tool_name) == Some(&true)) + .is_some_and(|preset| preset.tools.get(tool_name) == Some(&true)) } } diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index d9557c5d008bc902acae7e512c1b8532092f4c34..e850945a40f46f31543fad2631216139706b405a 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -8,13 +8,15 @@ use gpui::{App, Pixels, SharedString}; use language_model::LanguageModel; use schemars::{JsonSchema, json_schema}; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; use std::borrow::Cow; pub use crate::agent_profile::*; pub const SUMMARIZE_THREAD_PROMPT: &str = include_str!("../../agent/src/prompts/summarize_thread_prompt.txt"); +pub const SUMMARIZE_THREAD_DETAILED_PROMPT: &str = + include_str!("../../agent/src/prompts/summarize_thread_detailed_prompt.txt"); pub fn init(cx: &mut App) { AgentSettings::register(cx); @@ -116,15 +118,15 @@ pub struct LanguageModelParameters { impl LanguageModelParameters { pub fn matches(&self, model: &Arc) -> bool { - if let Some(provider) = &self.provider { - if provider.0 != model.provider_id().0 { - return false; - } + if let Some(provider) = &self.provider + && provider.0 != model.provider_id().0 + { + return false; } - if let Some(setting_model) = &self.model { - if *setting_model != model.id().0 { - return false; - } + if let Some(setting_model) = &self.model + && *setting_model != model.id().0 + { + return false; } true } @@ -221,7 +223,8 @@ impl AgentSettingsContent { } } -#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug, Default)] +#[derive(Clone, Serialize, Deserialize, JsonSchema, 
Debug, Default, SettingsUi, SettingsKey)] +#[settings_key(key = "agent", fallback_key = "assistant")] pub struct AgentSettingsContent { /// Whether the Agent is enabled. /// @@ -266,6 +269,10 @@ pub struct AgentSettingsContent { /// Whenever a tool action would normally wait for your confirmation /// that you allow it, always choose to allow it. /// + /// This setting has no effect on external agents that support permission modes, such as Claude Code. + /// + /// Set `agent_servers.claude.default_mode` to `bypassPermissions`, to disable all permission requests when using Claude Code. + /// /// Default: false always_allow_tool_actions: Option, /// Where to show a popup notification when the agent is waiting for user input. @@ -309,7 +316,7 @@ pub struct AgentSettingsContent { /// /// Default: true expand_terminal_card: Option, - /// Whether to always use cmd-enter (or ctrl-enter on Linux) to send messages in the agent panel. + /// Whether to always use cmd-enter (or ctrl-enter on Linux or Windows) to send messages in the agent panel. 
/// /// Default: false use_modifier_to_send: Option, @@ -350,18 +357,19 @@ impl JsonSchema for LanguageModelProviderSetting { fn json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema { json_schema!({ "enum": [ - "anthropic", "amazon-bedrock", + "anthropic", + "copilot_chat", + "deepseek", "google", "lmstudio", + "mistral", "ollama", "openai", - "zed.dev", - "copilot_chat", - "deepseek", "openrouter", - "mistral", - "vercel" + "vercel", + "x_ai", + "zed.dev" ] }) } @@ -396,10 +404,6 @@ pub struct ContextServerPresetContent { } impl Settings for AgentSettings { - const KEY: Option<&'static str> = Some("agent"); - - const FALLBACK_KEY: Option<&'static str> = Some("assistant"); - const PRESERVED_KEYS: Option<&'static [&'static str]> = Some(&["version"]); type FileContent = AgentSettingsContent; @@ -503,9 +507,8 @@ impl Settings for AgentSettings { } } - debug_assert_eq!( - sources.default.always_allow_tool_actions.unwrap_or(false), - false, + debug_assert!( + !sources.default.always_allow_tool_actions.unwrap_or(false), "For security, agent.always_allow_tool_actions should always be false in default.json. If it's true, that is a bug that should be fixed!" 
); diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index b6a5710aa42ce64722985d934967703faf92bbdc..eaa058467f44638db4f0a446444424d706f76608 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -25,6 +25,7 @@ agent_servers.workspace = true agent_settings.workspace = true ai_onboarding.workspace = true anyhow.workspace = true +arrayvec.workspace = true assistant_context.workspace = true assistant_slash_command.workspace = true assistant_slash_commands.workspace = true @@ -50,7 +51,6 @@ fuzzy.workspace = true gpui.workspace = true html_to_markdown.workspace = true http_client.workspace = true -indexed_docs.workspace = true indoc.workspace = true inventory.workspace = true itertools.workspace = true @@ -68,6 +68,7 @@ ordered-float.workspace = true parking_lot.workspace = true paths.workspace = true picker.workspace = true +postage.workspace = true project.workspace = true prompt_store.workspace = true proto.workspace = true @@ -80,6 +81,7 @@ serde.workspace = true serde_json.workspace = true serde_json_lenient.workspace = true settings.workspace = true +shlex.workspace = true smol.workspace = true streaming_diff.workspace = true task.workspace = true @@ -103,10 +105,13 @@ workspace.workspace = true zed_actions.workspace = true [dev-dependencies] +acp_thread = { workspace = true, features = ["test-support"] } agent = { workspace = true, features = ["test-support"] } +agent2 = { workspace = true, features = ["test-support"] } assistant_context = { workspace = true, features = ["test-support"] } assistant_tools.workspace = true buffer_diff = { workspace = true, features = ["test-support"] } +db = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } indoc.workspace = true diff --git a/crates/agent_ui/src/acp.rs b/crates/agent_ui/src/acp.rs index 
b9814adb2dc5fec075bf1128cbbba19a8889b3e6..2e15cd424d6313d981ff8c000f5eeb958aec9370 100644 --- a/crates/agent_ui/src/acp.rs +++ b/crates/agent_ui/src/acp.rs @@ -1,10 +1,14 @@ mod completion_provider; -mod message_history; +mod entry_view_state; +mod message_editor; +mod mode_selector; mod model_selector; mod model_selector_popover; +mod thread_history; mod thread_view; -pub use message_history::MessageHistory; +pub use mode_selector::ModeSelector; pub use model_selector::AcpModelSelector; pub use model_selector_popover::AcpModelSelectorPopover; +pub use thread_history::*; pub use thread_view::AcpThreadView; diff --git a/crates/agent_ui/src/acp/completion_provider.rs b/crates/agent_ui/src/acp/completion_provider.rs index 46c8aa92f1a689e4a58f33843bbc1f7feb2201bc..5ef2e222d05e11c9848d8835800b86579f31f4de 100644 --- a/crates/agent_ui/src/acp/completion_provider.rs +++ b/crates/agent_ui/src/acp/completion_provider.rs @@ -1,207 +1,44 @@ +use std::cell::{Cell, RefCell}; use std::ops::Range; -use std::path::{Path, PathBuf}; +use std::rc::Rc; use std::sync::Arc; use std::sync::atomic::AtomicBool; -use acp_thread::{MentionUri, selection_name}; -use anyhow::{Context as _, Result, anyhow}; -use collections::{HashMap, HashSet}; -use editor::display_map::CreaseId; -use editor::{CompletionProvider, Editor, ExcerptId, ToOffset as _}; -use file_icons::FileIcons; -use futures::future::try_join_all; +use acp_thread::MentionUri; +use agent_client_protocol as acp; +use agent2::{HistoryEntry, HistoryStore}; +use anyhow::Result; +use editor::{CompletionProvider, Editor, ExcerptId}; use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{App, Entity, Task, WeakEntity}; -use http_client::HttpClientWithUrl; -use itertools::Itertools as _; use language::{Buffer, CodeLabel, HighlightId}; use lsp::CompletionContext; -use parking_lot::Mutex; +use project::lsp_store::CompletionDocumentation; use project::{ - Completion, CompletionIntent, CompletionResponse, Project, ProjectPath, Symbol, 
WorktreeId, + Completion, CompletionDisplayOptions, CompletionIntent, CompletionResponse, Project, + ProjectPath, Symbol, WorktreeId, }; use prompt_store::PromptStore; use rope::Point; -use text::{Anchor, OffsetRangeExt as _, ToPoint as _}; +use text::{Anchor, ToPoint as _}; use ui::prelude::*; -use url::Url; use workspace::Workspace; -use workspace::notifications::NotifyResultExt; -use agent::{ - context::RULES_ICON, - thread_store::{TextThreadStore, ThreadStore}, -}; - -use crate::context_picker::fetch_context_picker::fetch_url_content; +use crate::AgentPanel; +use crate::acp::message_editor::{MessageEditor, MessageEditorEvent}; use crate::context_picker::file_context_picker::{FileMatch, search_files}; use crate::context_picker::rules_context_picker::{RulesContextEntry, search_rules}; use crate::context_picker::symbol_context_picker::SymbolMatch; use crate::context_picker::symbol_context_picker::search_symbols; -use crate::context_picker::thread_context_picker::{ - ThreadContextEntry, ThreadMatch, search_threads, -}; use crate::context_picker::{ - ContextPickerAction, ContextPickerEntry, ContextPickerMode, RecentEntry, - available_context_picker_entries, recent_context_picker_entries, selection_ranges, + ContextPickerAction, ContextPickerEntry, ContextPickerMode, selection_ranges, }; -#[derive(Default)] -pub struct MentionSet { - uri_by_crease_id: HashMap, - fetch_results: HashMap, -} - -impl MentionSet { - pub fn insert(&mut self, crease_id: CreaseId, uri: MentionUri) { - self.uri_by_crease_id.insert(crease_id, uri); - } - - pub fn add_fetch_result(&mut self, url: Url, content: String) { - self.fetch_results.insert(url, content); - } - - pub fn drain(&mut self) -> impl Iterator { - self.fetch_results.clear(); - self.uri_by_crease_id.drain().map(|(id, _)| id) - } - - pub fn contents( - &self, - project: Entity, - thread_store: Entity, - text_thread_store: Entity, - window: &mut Window, - cx: &mut App, - ) -> Task>> { - let contents = self - .uri_by_crease_id - 
.iter() - .map(|(&crease_id, uri)| { - match uri { - MentionUri::File(path) => { - let uri = uri.clone(); - let path = path.to_path_buf(); - let buffer_task = project.update(cx, |project, cx| { - let path = project - .find_project_path(path, cx) - .context("Failed to find project path")?; - anyhow::Ok(project.open_buffer(path, cx)) - }); - - cx.spawn(async move |cx| { - let buffer = buffer_task?.await?; - let content = buffer.read_with(cx, |buffer, _cx| buffer.text())?; - - anyhow::Ok((crease_id, Mention { uri, content })) - }) - } - MentionUri::Symbol { - path, line_range, .. - } - | MentionUri::Selection { - path, line_range, .. - } => { - let uri = uri.clone(); - let path_buf = path.clone(); - let line_range = line_range.clone(); - - let buffer_task = project.update(cx, |project, cx| { - let path = project - .find_project_path(&path_buf, cx) - .context("Failed to find project path")?; - anyhow::Ok(project.open_buffer(path, cx)) - }); - - cx.spawn(async move |cx| { - let buffer = buffer_task?.await?; - let content = buffer.read_with(cx, |buffer, _cx| { - buffer - .text_for_range( - Point::new(line_range.start, 0) - ..Point::new( - line_range.end, - buffer.line_len(line_range.end), - ), - ) - .collect() - })?; - - anyhow::Ok((crease_id, Mention { uri, content })) - }) - } - MentionUri::Thread { id: thread_id, .. } => { - let open_task = thread_store.update(cx, |thread_store, cx| { - thread_store.open_thread(&thread_id, window, cx) - }); - - let uri = uri.clone(); - cx.spawn(async move |cx| { - let thread = open_task.await?; - let content = thread.read_with(cx, |thread, _cx| { - thread.latest_detailed_summary_or_text().to_string() - })?; - - anyhow::Ok((crease_id, Mention { uri, content })) - }) - } - MentionUri::TextThread { path, .. 
} => { - let context = text_thread_store.update(cx, |text_thread_store, cx| { - text_thread_store.open_local_context(path.as_path().into(), cx) - }); - let uri = uri.clone(); - cx.spawn(async move |cx| { - let context = context.await?; - let xml = context.update(cx, |context, cx| context.to_xml(cx))?; - anyhow::Ok((crease_id, Mention { uri, content: xml })) - }) - } - MentionUri::Rule { id: prompt_id, .. } => { - let Some(prompt_store) = thread_store.read(cx).prompt_store().clone() - else { - return Task::ready(Err(anyhow!("missing prompt store"))); - }; - let text_task = prompt_store.read(cx).load(*prompt_id, cx); - let uri = uri.clone(); - cx.spawn(async move |_| { - // TODO: report load errors instead of just logging - let text = text_task.await?; - anyhow::Ok((crease_id, Mention { uri, content: text })) - }) - } - MentionUri::Fetch { url } => { - let Some(content) = self.fetch_results.get(&url) else { - return Task::ready(Err(anyhow!("missing fetch result"))); - }; - Task::ready(Ok(( - crease_id, - Mention { - uri: uri.clone(), - content: content.clone(), - }, - ))) - } - } - }) - .collect::>(); - - cx.spawn(async move |_cx| { - let contents = try_join_all(contents).await?.into_iter().collect(); - anyhow::Ok(contents) - }) - } -} - -#[derive(Debug)] -pub struct Mention { - pub uri: MentionUri, - pub content: String, -} - pub(crate) enum Match { File(FileMatch), Symbol(SymbolMatch), - Thread(ThreadMatch), + Thread(HistoryEntry), + RecentThread(HistoryEntry), Fetch(SharedString), Rules(RulesContextEntry), Entry(EntryMatch), @@ -218,6 +55,7 @@ impl Match { Match::File(file) => file.mat.score, Match::Entry(mode) => mode.mat.as_ref().map(|mat| mat.score).unwrap_or(1.), Match::Thread(_) => 1., + Match::RecentThread(_) => 1., Match::Symbol(_) => 1., Match::Rules(_) => 1., Match::Fetch(_) => 1., @@ -225,227 +63,44 @@ impl Match { } } -fn search( - mode: Option, - query: String, - cancellation_flag: Arc, - recent_entries: Vec, - prompt_store: Option>, - thread_store: 
WeakEntity, - text_thread_context_store: WeakEntity, - workspace: Entity, - cx: &mut App, -) -> Task> { - match mode { - Some(ContextPickerMode::File) => { - let search_files_task = - search_files(query.clone(), cancellation_flag.clone(), &workspace, cx); - cx.background_spawn(async move { - search_files_task - .await - .into_iter() - .map(Match::File) - .collect() - }) - } - - Some(ContextPickerMode::Symbol) => { - let search_symbols_task = - search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx); - cx.background_spawn(async move { - search_symbols_task - .await - .into_iter() - .map(Match::Symbol) - .collect() - }) - } - - Some(ContextPickerMode::Thread) => { - if let Some((thread_store, context_store)) = thread_store - .upgrade() - .zip(text_thread_context_store.upgrade()) - { - let search_threads_task = search_threads( - query.clone(), - cancellation_flag.clone(), - thread_store, - context_store, - cx, - ); - cx.background_spawn(async move { - search_threads_task - .await - .into_iter() - .map(Match::Thread) - .collect() - }) - } else { - Task::ready(Vec::new()) - } - } - - Some(ContextPickerMode::Fetch) => { - if !query.is_empty() { - Task::ready(vec![Match::Fetch(query.into())]) - } else { - Task::ready(Vec::new()) - } - } - - Some(ContextPickerMode::Rules) => { - if let Some(prompt_store) = prompt_store.as_ref() { - let search_rules_task = - search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx); - cx.background_spawn(async move { - search_rules_task - .await - .into_iter() - .map(Match::Rules) - .collect::>() - }) - } else { - Task::ready(Vec::new()) - } - } - - None => { - if query.is_empty() { - let mut matches = recent_entries - .into_iter() - .map(|entry| match entry { - RecentEntry::File { - project_path, - path_prefix, - } => Match::File(FileMatch { - mat: fuzzy::PathMatch { - score: 1., - positions: Vec::new(), - worktree_id: project_path.worktree_id.to_usize(), - path: project_path.path, - path_prefix, - is_dir: 
false, - distance_to_relative_ancestor: 0, - }, - is_recent: true, - }), - RecentEntry::Thread(thread_context_entry) => Match::Thread(ThreadMatch { - thread: thread_context_entry, - is_recent: true, - }), - }) - .collect::>(); - - matches.extend( - available_context_picker_entries( - &prompt_store, - &Some(thread_store.clone()), - &workspace, - cx, - ) - .into_iter() - .map(|mode| { - Match::Entry(EntryMatch { - entry: mode, - mat: None, - }) - }), - ); - - Task::ready(matches) - } else { - let executor = cx.background_executor().clone(); - - let search_files_task = - search_files(query.clone(), cancellation_flag.clone(), &workspace, cx); - - let entries = available_context_picker_entries( - &prompt_store, - &Some(thread_store.clone()), - &workspace, - cx, - ); - let entry_candidates = entries - .iter() - .enumerate() - .map(|(ix, entry)| StringMatchCandidate::new(ix, entry.keyword())) - .collect::>(); - - cx.background_spawn(async move { - let mut matches = search_files_task - .await - .into_iter() - .map(Match::File) - .collect::>(); - - let entry_matches = fuzzy::match_strings( - &entry_candidates, - &query, - false, - true, - 100, - &Arc::new(AtomicBool::default()), - executor, - ) - .await; - - matches.extend(entry_matches.into_iter().map(|mat| { - Match::Entry(EntryMatch { - entry: entries[mat.candidate_id], - mat: Some(mat), - }) - })); - - matches.sort_by(|a, b| { - b.score() - .partial_cmp(&a.score()) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - matches - }) - } - } - } -} - pub struct ContextPickerCompletionProvider { - mention_set: Arc>, + message_editor: WeakEntity, workspace: WeakEntity, - thread_store: WeakEntity, - text_thread_store: WeakEntity, - editor: WeakEntity, + history_store: Entity, + prompt_store: Option>, + prompt_capabilities: Rc>, + available_commands: Rc>>, } impl ContextPickerCompletionProvider { pub fn new( - mention_set: Arc>, + message_editor: WeakEntity, workspace: WeakEntity, - thread_store: WeakEntity, - text_thread_store: 
WeakEntity, - editor: WeakEntity, + history_store: Entity, + prompt_store: Option>, + prompt_capabilities: Rc>, + available_commands: Rc>>, ) -> Self { Self { - mention_set, + message_editor, workspace, - thread_store, - text_thread_store, - editor, + history_store, + prompt_store, + prompt_capabilities, + available_commands, } } fn completion_for_entry( entry: ContextPickerEntry, - excerpt_id: ExcerptId, source_range: Range, - editor: Entity, - mention_set: Arc>, + message_editor: WeakEntity, workspace: &Entity, cx: &mut App, ) -> Option { match entry { ContextPickerEntry::Mode(mode) => Some(Completion { - replace_range: source_range.clone(), + replace_range: source_range, new_text: format!("@{} ", mode.keyword()), label: CodeLabel::plain(mode.label().to_string(), None), icon_path: Some(mode.icon().path().into()), @@ -458,134 +113,26 @@ impl ContextPickerCompletionProvider { confirm: Some(Arc::new(|_, _, _| true)), }), ContextPickerEntry::Action(action) => { - let (new_text, on_action) = match action { - ContextPickerAction::AddSelections => { - let selections = selection_ranges(workspace, cx); - - const PLACEHOLDER: &str = "selection "; - - let new_text = std::iter::repeat(PLACEHOLDER) - .take(selections.len()) - .chain(std::iter::once("")) - .join(" "); - - let callback = Arc::new({ - let mention_set = mention_set.clone(); - let selections = selections.clone(); - move |_, window: &mut Window, cx: &mut App| { - let editor = editor.clone(); - let mention_set = mention_set.clone(); - let selections = selections.clone(); - window.defer(cx, move |window, cx| { - let mut current_offset = 0; - - for (buffer, selection_range) in selections { - let snapshot = - editor.read(cx).buffer().read(cx).snapshot(cx); - let Some(start) = snapshot - .anchor_in_excerpt(excerpt_id, source_range.start) - else { - return; - }; - - let offset = start.to_offset(&snapshot) + current_offset; - let text_len = PLACEHOLDER.len() - 1; - - let range = snapshot.anchor_after(offset) - 
..snapshot.anchor_after(offset + text_len); - - let path = buffer - .read(cx) - .file() - .map_or(PathBuf::from("untitled"), |file| { - file.path().to_path_buf() - }); - - let point_range = snapshot - .as_singleton() - .map(|(_, _, snapshot)| { - selection_range.to_point(&snapshot) - }) - .unwrap_or_default(); - let line_range = point_range.start.row..point_range.end.row; - let crease = crate::context_picker::crease_for_mention( - selection_name(&path, &line_range).into(), - IconName::Reader.path().into(), - range, - editor.downgrade(), - ); - - let [crease_id]: [_; 1] = - editor.update(cx, |editor, cx| { - let crease_ids = - editor.insert_creases(vec![crease.clone()], cx); - editor.fold_creases( - vec![crease], - false, - window, - cx, - ); - crease_ids.try_into().unwrap() - }); - - mention_set.lock().insert( - crease_id, - MentionUri::Selection { path, line_range }, - ); - - current_offset += text_len + 1; - } - }); - - false - } - }); - - (new_text, callback) - } - }; - - Some(Completion { - replace_range: source_range.clone(), - new_text, - label: CodeLabel::plain(action.label().to_string(), None), - icon_path: Some(action.icon().path().into()), - documentation: None, - source: project::CompletionSource::Custom, - insert_text_mode: None, - // This ensures that when a user accepts this completion, the - // completion menu will still be shown after "@category " is - // inserted - confirm: Some(on_action), - }) + Self::completion_for_action(action, source_range, message_editor, workspace, cx) } } } fn completion_for_thread( - thread_entry: ThreadContextEntry, - excerpt_id: ExcerptId, + thread_entry: HistoryEntry, source_range: Range, recent: bool, - editor: Entity, - mention_set: Arc>, + editor: WeakEntity, + cx: &mut App, ) -> Completion { + let uri = thread_entry.mention_uri(); + let icon_for_completion = if recent { - IconName::HistoryRerun + IconName::HistoryRerun.path().into() } else { - IconName::Thread + uri.icon_path(cx) }; - let uri = match &thread_entry 
{ - ThreadContextEntry::Thread { id, title } => MentionUri::Thread { - id: id.clone(), - name: title.to_string(), - }, - ThreadContextEntry::Context { path, title } => MentionUri::TextThread { - path: path.to_path_buf(), - name: title.to_string(), - }, - }; let new_text = format!("{} ", uri.as_link()); let new_text_len = new_text.len(); @@ -596,15 +143,12 @@ impl ContextPickerCompletionProvider { documentation: None, insert_text_mode: None, source: project::CompletionSource::Custom, - icon_path: Some(icon_for_completion.path().into()), + icon_path: Some(icon_for_completion), confirm: Some(confirm_completion_callback( - IconName::Thread.path().into(), thread_entry.title().clone(), - excerpt_id, source_range.start, new_text_len - 1, - editor.clone(), - mention_set, + editor, uri, )), } @@ -612,10 +156,9 @@ impl ContextPickerCompletionProvider { fn completion_for_rules( rule: RulesContextEntry, - excerpt_id: ExcerptId, source_range: Range, - editor: Entity, - mention_set: Arc>, + editor: WeakEntity, + cx: &mut App, ) -> Completion { let uri = MentionUri::Rule { id: rule.prompt_id.into(), @@ -623,6 +166,7 @@ impl ContextPickerCompletionProvider { }; let new_text = format!("{} ", uri.as_link()); let new_text_len = new_text.len(); + let icon_path = uri.icon_path(cx); Completion { replace_range: source_range.clone(), new_text, @@ -630,15 +174,12 @@ impl ContextPickerCompletionProvider { documentation: None, insert_text_mode: None, source: project::CompletionSource::Custom, - icon_path: Some(RULES_ICON.path().into()), + icon_path: Some(icon_path), confirm: Some(confirm_completion_callback( - RULES_ICON.path().into(), - rule.title.clone(), - excerpt_id, + rule.title, source_range.start, new_text_len - 1, - editor.clone(), - mention_set, + editor, uri, )), } @@ -649,12 +190,10 @@ impl ContextPickerCompletionProvider { path_prefix: &str, is_recent: bool, is_directory: bool, - excerpt_id: ExcerptId, source_range: Range, - editor: Entity, - mention_set: Arc>, + message_editor: 
WeakEntity, project: Entity, - cx: &App, + cx: &mut App, ) -> Option { let (file_name, directory) = crate::context_picker::file_context_picker::extract_file_name_and_directory( @@ -664,28 +203,23 @@ impl ContextPickerCompletionProvider { let label = build_code_label_for_full_path(&file_name, directory.as_ref().map(|s| s.as_ref()), cx); - let full_path = if let Some(directory) = directory { - format!("{}{}", directory, file_name) - } else { - file_name.to_string() - }; - let crease_icon_path = if is_directory { - FileIcons::get_folder_icon(false, cx).unwrap_or_else(|| IconName::Folder.path().into()) + let abs_path = project.read(cx).absolute_path(&project_path, cx)?; + + let uri = if is_directory { + MentionUri::Directory { abs_path } } else { - FileIcons::get_icon(Path::new(&full_path), cx) - .unwrap_or_else(|| IconName::File.path().into()) + MentionUri::File { abs_path } }; + + let crease_icon_path = uri.icon_path(cx); let completion_icon_path = if is_recent { IconName::HistoryRerun.path().into() } else { - crease_icon_path.clone() + crease_icon_path }; - let abs_path = project.read(cx).absolute_path(&project_path, cx)?; - - let file_uri = MentionUri::File(abs_path); - let new_text = format!("{} ", file_uri.as_link()); + let new_text = format!("{} ", uri.as_link()); let new_text_len = new_text.len(); Some(Completion { replace_range: source_range.clone(), @@ -696,24 +230,19 @@ impl ContextPickerCompletionProvider { icon_path: Some(completion_icon_path), insert_text_mode: None, confirm: Some(confirm_completion_callback( - crease_icon_path, file_name, - excerpt_id, source_range.start, new_text_len - 1, - editor, - mention_set.clone(), - file_uri, + message_editor, + uri, )), }) } fn completion_for_symbol( symbol: Symbol, - excerpt_id: ExcerptId, source_range: Range, - editor: Entity, - mention_set: Arc>, + message_editor: WeakEntity, workspace: Entity, cx: &mut App, ) -> Option { @@ -723,28 +252,26 @@ impl ContextPickerCompletionProvider { let abs_path = 
project.read(cx).absolute_path(&symbol.path, cx)?; let uri = MentionUri::Symbol { - path: abs_path, + abs_path, name: symbol.name.clone(), - line_range: symbol.range.start.0.row..symbol.range.end.0.row, + line_range: symbol.range.start.0.row..=symbol.range.end.0.row, }; let new_text = format!("{} ", uri.as_link()); let new_text_len = new_text.len(); + let icon_path = uri.icon_path(cx); Some(Completion { replace_range: source_range.clone(), new_text, label, documentation: None, source: project::CompletionSource::Custom, - icon_path: Some(IconName::Code.path().into()), + icon_path: Some(icon_path), insert_text_mode: None, confirm: Some(confirm_completion_callback( - IconName::Code.path().into(), - symbol.name.clone().into(), - excerpt_id, + symbol.name.into(), source_range.start, new_text_len - 1, - editor.clone(), - mention_set.clone(), + message_editor, uri, )), }) @@ -753,293 +280,602 @@ impl ContextPickerCompletionProvider { fn completion_for_fetch( source_range: Range, url_to_fetch: SharedString, - excerpt_id: ExcerptId, - editor: Entity, - mention_set: Arc>, - http_client: Arc, + message_editor: WeakEntity, + cx: &mut App, ) -> Option { - let new_text = format!("@fetch {} ", url_to_fetch.clone()); - let new_text_len = new_text.len(); + let new_text = format!("@fetch {} ", url_to_fetch); + let url_to_fetch = url::Url::parse(url_to_fetch.as_ref()) + .or_else(|_| url::Url::parse(&format!("https://{url_to_fetch}"))) + .ok()?; + let mention_uri = MentionUri::Fetch { + url: url_to_fetch.clone(), + }; + let icon_path = mention_uri.icon_path(cx); Some(Completion { replace_range: source_range.clone(), - new_text, + new_text: new_text.clone(), label: CodeLabel::plain(url_to_fetch.to_string(), None), documentation: None, source: project::CompletionSource::Custom, - icon_path: Some(IconName::ToolWeb.path().into()), + icon_path: Some(icon_path), insert_text_mode: None, - confirm: Some({ - let start = source_range.start; - let content_len = new_text_len - 1; - let editor = 
editor.clone(); - let url_to_fetch = url_to_fetch.clone(); - let source_range = source_range.clone(); - Arc::new(move |_, window, cx| { - let Some(url) = url::Url::parse(url_to_fetch.as_ref()) - .or_else(|_| url::Url::parse(&format!("https://{url_to_fetch}"))) - .notify_app_err(cx) - else { - return false; - }; - let mention_uri = MentionUri::Fetch { url: url.clone() }; - - let editor = editor.clone(); - let mention_set = mention_set.clone(); - let http_client = http_client.clone(); - let source_range = source_range.clone(); - window.defer(cx, move |window, cx| { - let url = url.clone(); + confirm: Some(confirm_completion_callback( + url_to_fetch.to_string().into(), + source_range.start, + new_text.len() - 1, + message_editor, + mention_uri, + )), + }) + } - let Some(crease_id) = crate::context_picker::insert_crease_for_mention( - excerpt_id, - start, - content_len, - url.to_string().into(), - IconName::ToolWeb.path().into(), - editor.clone(), - window, - cx, - ) else { - return; - }; + pub(crate) fn completion_for_action( + action: ContextPickerAction, + source_range: Range, + message_editor: WeakEntity, + workspace: &Entity, + cx: &mut App, + ) -> Option { + let (new_text, on_action) = match action { + ContextPickerAction::AddSelections => { + const PLACEHOLDER: &str = "selection "; + let selections = selection_ranges(workspace, cx) + .into_iter() + .enumerate() + .map(|(ix, (buffer, range))| { + ( + buffer, + range, + (PLACEHOLDER.len() * ix)..(PLACEHOLDER.len() * (ix + 1) - 1), + ) + }) + .collect::>(); - let editor = editor.clone(); - let mention_set = mention_set.clone(); - let http_client = http_client.clone(); + let new_text: String = PLACEHOLDER.repeat(selections.len()); + + let callback = Arc::new({ + let source_range = source_range.clone(); + move |_, window: &mut Window, cx: &mut App| { + let selections = selections.clone(); + let message_editor = message_editor.clone(); let source_range = source_range.clone(); - window - .spawn(cx, async move |cx| { - 
if let Some(content) = - fetch_url_content(http_client, url.to_string()) - .await - .notify_async_err(cx) - { - mention_set.lock().add_fetch_result(url, content); - mention_set.lock().insert(crease_id, mention_uri.clone()); - } else { - // Remove crease if we failed to fetch - editor - .update(cx, |editor, cx| { - let snapshot = editor.buffer().read(cx).snapshot(cx); - let Some(anchor) = snapshot - .anchor_in_excerpt(excerpt_id, source_range.start) - else { - return; - }; - editor.display_map.update(cx, |display_map, cx| { - display_map.unfold_intersecting( - vec![anchor..anchor], - true, - cx, - ); - }); - editor.remove_creases([crease_id], cx); - }) - .ok(); - } - Some(()) - }) - .detach(); - }); - false - }) - }), + window.defer(cx, move |window, cx| { + message_editor + .update(cx, |message_editor, cx| { + message_editor.confirm_mention_for_selection( + source_range, + selections, + window, + cx, + ) + }) + .ok(); + }); + false + } + }); + + (new_text, callback) + } + }; + + Some(Completion { + replace_range: source_range, + new_text, + label: CodeLabel::plain(action.label().to_string(), None), + icon_path: Some(action.icon().path().into()), + documentation: None, + source: project::CompletionSource::Custom, + insert_text_mode: None, + // This ensures that when a user accepts this completion, the + // completion menu will still be shown after "@category " is + // inserted + confirm: Some(on_action), }) } -} -fn build_code_label_for_full_path(file_name: &str, directory: Option<&str>, cx: &App) -> CodeLabel { - let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId); - let mut label = CodeLabel::default(); + fn search_slash_commands( + &self, + query: String, + cx: &mut App, + ) -> Task> { + let commands = self.available_commands.borrow().clone(); + if commands.is_empty() { + return Task::ready(Vec::new()); + } - label.push_str(&file_name, None); - label.push_str(" ", None); + cx.spawn(async move |cx| { + let candidates = commands + .iter() 
+ .enumerate() + .map(|(id, command)| StringMatchCandidate::new(id, &command.name)) + .collect::>(); + + let matches = fuzzy::match_strings( + &candidates, + &query, + false, + true, + 100, + &Arc::new(AtomicBool::default()), + cx.background_executor().clone(), + ) + .await; - if let Some(directory) = directory { - label.push_str(&directory, comment_id); + matches + .into_iter() + .map(|mat| commands[mat.candidate_id].clone()) + .collect() + }) } - label.filter_range = 0..label.text().len(); - - label -} - -impl CompletionProvider for ContextPickerCompletionProvider { - fn completions( + fn search_mentions( &self, - excerpt_id: ExcerptId, - buffer: &Entity, - buffer_position: Anchor, - _trigger: CompletionContext, - _window: &mut Window, - cx: &mut Context, - ) -> Task>> { - let state = buffer.update(cx, |buffer, _cx| { - let position = buffer_position.to_point(buffer); - let line_start = Point::new(position.row, 0); - let offset_to_line = buffer.point_to_offset(line_start); - let mut lines = buffer.text_for_range(line_start..position).lines(); - let line = lines.next()?; - MentionCompletion::try_parse(line, offset_to_line) - }); - let Some(state) = state else { - return Task::ready(Ok(Vec::new())); - }; - + mode: Option, + query: String, + cancellation_flag: Arc, + cx: &mut App, + ) -> Task> { let Some(workspace) = self.workspace.upgrade() else { - return Task::ready(Ok(Vec::new())); + return Task::ready(Vec::default()); }; + match mode { + Some(ContextPickerMode::File) => { + let search_files_task = search_files(query, cancellation_flag, &workspace, cx); + cx.background_spawn(async move { + search_files_task + .await + .into_iter() + .map(Match::File) + .collect() + }) + } - let project = workspace.read(cx).project().clone(); - let http_client = workspace.read(cx).client().http_client(); - let snapshot = buffer.read(cx).snapshot(); - let source_range = snapshot.anchor_before(state.source_range.start) - ..snapshot.anchor_after(state.source_range.end); + 
Some(ContextPickerMode::Symbol) => { + let search_symbols_task = search_symbols(query, cancellation_flag, &workspace, cx); + cx.background_spawn(async move { + search_symbols_task + .await + .into_iter() + .map(Match::Symbol) + .collect() + }) + } - let thread_store = self.thread_store.clone(); - let text_thread_store = self.text_thread_store.clone(); - let editor = self.editor.clone(); + Some(ContextPickerMode::Thread) => { + let search_threads_task = + search_threads(query, cancellation_flag, &self.history_store, cx); + cx.background_spawn(async move { + search_threads_task + .await + .into_iter() + .map(Match::Thread) + .collect() + }) + } - let MentionCompletion { mode, argument, .. } = state; - let query = argument.unwrap_or_else(|| "".to_string()); + Some(ContextPickerMode::Fetch) => { + if !query.is_empty() { + Task::ready(vec![Match::Fetch(query.into())]) + } else { + Task::ready(Vec::new()) + } + } - let (exclude_paths, exclude_threads) = { - let mention_set = self.mention_set.lock(); + Some(ContextPickerMode::Rules) => { + if let Some(prompt_store) = self.prompt_store.as_ref() { + let search_rules_task = + search_rules(query, cancellation_flag, prompt_store, cx); + cx.background_spawn(async move { + search_rules_task + .await + .into_iter() + .map(Match::Rules) + .collect::>() + }) + } else { + Task::ready(Vec::new()) + } + } - let mut excluded_paths = HashSet::default(); - let mut excluded_threads = HashSet::default(); + None if query.is_empty() => { + let mut matches = self.recent_context_picker_entries(&workspace, cx); - for uri in mention_set.uri_by_crease_id.values() { - match uri { - MentionUri::File(path) => { - excluded_paths.insert(path.clone()); - } - MentionUri::Thread { id, .. 
} => { - excluded_threads.insert(id.clone()); - } - _ => {} - } + matches.extend( + self.available_context_picker_entries(&workspace, cx) + .into_iter() + .map(|mode| { + Match::Entry(EntryMatch { + entry: mode, + mat: None, + }) + }), + ); + + Task::ready(matches) } + None => { + let executor = cx.background_executor().clone(); - (excluded_paths, excluded_threads) - }; + let search_files_task = + search_files(query.clone(), cancellation_flag, &workspace, cx); - let recent_entries = recent_context_picker_entries( - Some(thread_store.clone()), - Some(text_thread_store.clone()), - workspace.clone(), - &exclude_paths, - &exclude_threads, - cx, - ); + let entries = self.available_context_picker_entries(&workspace, cx); + let entry_candidates = entries + .iter() + .enumerate() + .map(|(ix, entry)| StringMatchCandidate::new(ix, entry.keyword())) + .collect::>(); - let prompt_store = thread_store - .read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone()) - .ok() - .flatten(); + cx.background_spawn(async move { + let mut matches = search_files_task + .await + .into_iter() + .map(Match::File) + .collect::>(); - let search_task = search( - mode, - query, - Arc::::default(), - recent_entries, - prompt_store, - thread_store.clone(), - text_thread_store.clone(), - workspace.clone(), - cx, + let entry_matches = fuzzy::match_strings( + &entry_candidates, + &query, + false, + true, + 100, + &Arc::new(AtomicBool::default()), + executor, + ) + .await; + + matches.extend(entry_matches.into_iter().map(|mat| { + Match::Entry(EntryMatch { + entry: entries[mat.candidate_id], + mat: Some(mat), + }) + })); + + matches.sort_by(|a, b| { + b.score() + .partial_cmp(&a.score()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + matches + }) + } + } + } + + fn recent_context_picker_entries( + &self, + workspace: &Entity, + cx: &mut App, + ) -> Vec { + let mut recent = Vec::with_capacity(6); + + let mut mentions = self + .message_editor + .read_with(cx, |message_editor, _cx| 
message_editor.mentions()) + .unwrap_or_default(); + let workspace = workspace.read(cx); + let project = workspace.project().read(cx); + + if let Some(agent_panel) = workspace.panel::(cx) + && let Some(thread) = agent_panel.read(cx).active_agent_thread(cx) + { + let thread = thread.read(cx); + mentions.insert(MentionUri::Thread { + id: thread.session_id().clone(), + name: thread.title().into(), + }); + } + + recent.extend( + workspace + .recent_navigation_history_iter(cx) + .filter(|(_, abs_path)| { + abs_path.as_ref().is_none_or(|path| { + !mentions.contains(&MentionUri::File { + abs_path: path.clone(), + }) + }) + }) + .take(4) + .filter_map(|(project_path, _)| { + project + .worktree_for_id(project_path.worktree_id, cx) + .map(|worktree| { + let path_prefix = worktree.read(cx).root_name().into(); + Match::File(FileMatch { + mat: fuzzy::PathMatch { + score: 1., + positions: Vec::new(), + worktree_id: project_path.worktree_id.to_usize(), + path: project_path.path, + path_prefix, + is_dir: false, + distance_to_relative_ancestor: 0, + }, + is_recent: true, + }) + }) + }), ); - let mention_set = self.mention_set.clone(); + if self.prompt_capabilities.get().embedded_context { + const RECENT_COUNT: usize = 2; + let threads = self + .history_store + .read(cx) + .recently_opened_entries(cx) + .into_iter() + .filter(|thread| !mentions.contains(&thread.mention_uri())) + .take(RECENT_COUNT) + .collect::>(); + + recent.extend(threads.into_iter().map(Match::RecentThread)); + } - cx.spawn(async move |_, cx| { - let matches = search_task.await; - let Some(editor) = editor.upgrade() else { - return Ok(Vec::new()); - }; + recent + } - let completions = cx.update(|cx| { - matches - .into_iter() - .filter_map(|mat| match mat { - Match::File(FileMatch { mat, is_recent }) => { - let project_path = ProjectPath { - worktree_id: WorktreeId::from_usize(mat.worktree_id), - path: mat.path.clone(), + fn available_context_picker_entries( + &self, + workspace: &Entity, + cx: &mut App, + ) -> 
Vec { + let embedded_context = self.prompt_capabilities.get().embedded_context; + let mut entries = if embedded_context { + vec![ + ContextPickerEntry::Mode(ContextPickerMode::File), + ContextPickerEntry::Mode(ContextPickerMode::Symbol), + ContextPickerEntry::Mode(ContextPickerMode::Thread), + ] + } else { + // File is always available, but we don't need a mode entry + vec![] + }; + + let has_selection = workspace + .read(cx) + .active_item(cx) + .and_then(|item| item.downcast::()) + .is_some_and(|editor| { + editor.update(cx, |editor, cx| editor.has_non_empty_selection(cx)) + }); + if has_selection { + entries.push(ContextPickerEntry::Action( + ContextPickerAction::AddSelections, + )); + } + + if embedded_context { + if self.prompt_store.is_some() { + entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules)); + } + + entries.push(ContextPickerEntry::Mode(ContextPickerMode::Fetch)); + } + + entries + } +} + +fn build_code_label_for_full_path(file_name: &str, directory: Option<&str>, cx: &App) -> CodeLabel { + let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId); + let mut label = CodeLabel::default(); + + label.push_str(file_name, None); + label.push_str(" ", None); + + if let Some(directory) = directory { + label.push_str(directory, comment_id); + } + + label.filter_range = 0..label.text().len(); + + label +} + +impl CompletionProvider for ContextPickerCompletionProvider { + fn completions( + &self, + _excerpt_id: ExcerptId, + buffer: &Entity, + buffer_position: Anchor, + _trigger: CompletionContext, + _window: &mut Window, + cx: &mut Context, + ) -> Task>> { + let state = buffer.update(cx, |buffer, _cx| { + let position = buffer_position.to_point(buffer); + let line_start = Point::new(position.row, 0); + let offset_to_line = buffer.point_to_offset(line_start); + let mut lines = buffer.text_for_range(line_start..position).lines(); + let line = lines.next()?; + ContextCompletion::try_parse( + line, + offset_to_line, + 
self.prompt_capabilities.get().embedded_context, + ) + }); + let Some(state) = state else { + return Task::ready(Ok(Vec::new())); + }; + + let Some(workspace) = self.workspace.upgrade() else { + return Task::ready(Ok(Vec::new())); + }; + + let project = workspace.read(cx).project().clone(); + let snapshot = buffer.read(cx).snapshot(); + let source_range = snapshot.anchor_before(state.source_range().start) + ..snapshot.anchor_after(state.source_range().end); + + let editor = self.message_editor.clone(); + + match state { + ContextCompletion::SlashCommand(SlashCommandCompletion { + command, argument, .. + }) => { + let search_task = self.search_slash_commands(command.unwrap_or_default(), cx); + cx.background_spawn(async move { + let completions = search_task + .await + .into_iter() + .map(|command| { + let new_text = if let Some(argument) = argument.as_ref() { + format!("/{} {}", command.name, argument) + } else { + format!("/{} ", command.name) }; - Self::completion_for_path( - project_path, - &mat.path_prefix, - is_recent, - mat.is_dir, - excerpt_id, - source_range.clone(), - editor.clone(), - mention_set.clone(), - project.clone(), - cx, - ) - } + let is_missing_argument = argument.is_none() && command.input.is_some(); + Completion { + replace_range: source_range.clone(), + new_text, + label: CodeLabel::plain(command.name.to_string(), None), + documentation: Some(CompletionDocumentation::MultiLinePlainText( + command.description.into(), + )), + source: project::CompletionSource::Custom, + icon_path: None, + insert_text_mode: None, + confirm: Some(Arc::new({ + let editor = editor.clone(); + move |intent, _window, cx| { + if !is_missing_argument { + cx.defer({ + let editor = editor.clone(); + move |cx| { + editor + .update(cx, |_editor, cx| { + match intent { + CompletionIntent::Complete + | CompletionIntent::CompleteWithInsert + | CompletionIntent::CompleteWithReplace => { + if !is_missing_argument { + cx.emit(MessageEditorEvent::Send); + } + } + 
CompletionIntent::Compose => {} + } + }) + .ok(); + } + }); + } + is_missing_argument + } + })), + } + }) + .collect(); - Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol( - symbol, - excerpt_id, - source_range.clone(), - editor.clone(), - mention_set.clone(), - workspace.clone(), - cx, - ), - - Match::Thread(ThreadMatch { - thread, is_recent, .. - }) => Some(Self::completion_for_thread( - thread, - excerpt_id, - source_range.clone(), - is_recent, - editor.clone(), - mention_set.clone(), - )), - - Match::Rules(user_rules) => Some(Self::completion_for_rules( - user_rules, - excerpt_id, - source_range.clone(), - editor.clone(), - mention_set.clone(), - )), - - Match::Fetch(url) => Self::completion_for_fetch( - source_range.clone(), - url, - excerpt_id, - editor.clone(), - mention_set.clone(), - http_client.clone(), - ), - - Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry( - entry, - excerpt_id, - source_range.clone(), - editor.clone(), - mention_set.clone(), - &workspace, - cx, - ), - }) - .collect() - })?; - - Ok(vec![CompletionResponse { - completions, - // Since this does its own filtering (see `filter_completions()` returns false), - // there is no benefit to computing whether this set of completions is incomplete. - is_incomplete: true, - }]) - }) + Ok(vec![CompletionResponse { + completions, + display_options: CompletionDisplayOptions { + dynamic_width: true, + }, + // Since this does its own filtering (see `filter_completions()` returns false), + // there is no benefit to computing whether this set of completions is incomplete. + is_incomplete: true, + }]) + }) + } + ContextCompletion::Mention(MentionCompletion { mode, argument, .. 
}) => { + let query = argument.unwrap_or_default(); + let search_task = + self.search_mentions(mode, query, Arc::::default(), cx); + + cx.spawn(async move |_, cx| { + let matches = search_task.await; + + let completions = cx.update(|cx| { + matches + .into_iter() + .filter_map(|mat| match mat { + Match::File(FileMatch { mat, is_recent }) => { + let project_path = ProjectPath { + worktree_id: WorktreeId::from_usize(mat.worktree_id), + path: mat.path.clone(), + }; + + Self::completion_for_path( + project_path, + &mat.path_prefix, + is_recent, + mat.is_dir, + source_range.clone(), + editor.clone(), + project.clone(), + cx, + ) + } + + Match::Symbol(SymbolMatch { symbol, .. }) => { + Self::completion_for_symbol( + symbol, + source_range.clone(), + editor.clone(), + workspace.clone(), + cx, + ) + } + + Match::Thread(thread) => Some(Self::completion_for_thread( + thread, + source_range.clone(), + false, + editor.clone(), + cx, + )), + + Match::RecentThread(thread) => Some(Self::completion_for_thread( + thread, + source_range.clone(), + true, + editor.clone(), + cx, + )), + + Match::Rules(user_rules) => Some(Self::completion_for_rules( + user_rules, + source_range.clone(), + editor.clone(), + cx, + )), + + Match::Fetch(url) => Self::completion_for_fetch( + source_range.clone(), + url, + editor.clone(), + cx, + ), + + Match::Entry(EntryMatch { entry, .. }) => { + Self::completion_for_entry( + entry, + source_range.clone(), + editor.clone(), + &workspace, + cx, + ) + } + }) + .collect() + })?; + + Ok(vec![CompletionResponse { + completions, + display_options: CompletionDisplayOptions { + dynamic_width: true, + }, + // Since this does its own filtering (see `filter_completions()` returns false), + // there is no benefit to computing whether this set of completions is incomplete. 
+ is_incomplete: true, + }]) + }) + } + } } fn is_completion_trigger( @@ -1057,12 +893,16 @@ impl CompletionProvider for ContextPickerCompletionProvider { let offset_to_line = buffer.point_to_offset(line_start); let mut lines = buffer.text_for_range(line_start..position).lines(); if let Some(line) = lines.next() { - MentionCompletion::try_parse(line, offset_to_line) - .map(|completion| { - completion.source_range.start <= offset_to_line + position.column as usize - && completion.source_range.end >= offset_to_line + position.column as usize - }) - .unwrap_or(false) + ContextCompletion::try_parse( + line, + offset_to_line, + self.prompt_capabilities.get().embedded_context, + ) + .map(|completion| { + completion.source_range().start <= offset_to_line + position.column as usize + && completion.source_range().end >= offset_to_line + position.column as usize + }) + .unwrap_or(false) } else { false } @@ -1077,40 +917,145 @@ impl CompletionProvider for ContextPickerCompletionProvider { } } +pub(crate) fn search_threads( + query: String, + cancellation_flag: Arc, + history_store: &Entity, + cx: &mut App, +) -> Task> { + let threads = history_store.read(cx).entries().collect(); + if query.is_empty() { + return Task::ready(threads); + } + + let executor = cx.background_executor().clone(); + cx.background_spawn(async move { + let candidates = threads + .iter() + .enumerate() + .map(|(id, thread)| StringMatchCandidate::new(id, thread.title())) + .collect::>(); + let matches = fuzzy::match_strings( + &candidates, + &query, + false, + true, + 100, + &cancellation_flag, + executor, + ) + .await; + + matches + .into_iter() + .map(|mat| threads[mat.candidate_id].clone()) + .collect() + }) +} + fn confirm_completion_callback( - crease_icon_path: SharedString, crease_text: SharedString, - excerpt_id: ExcerptId, start: Anchor, content_len: usize, - editor: Entity, - mention_set: Arc>, + message_editor: WeakEntity, mention_uri: MentionUri, ) -> Arc bool + Send + Sync> { Arc::new(move 
|_, window, cx| { + let message_editor = message_editor.clone(); let crease_text = crease_text.clone(); - let crease_icon_path = crease_icon_path.clone(); - let editor = editor.clone(); - let mention_set = mention_set.clone(); let mention_uri = mention_uri.clone(); window.defer(cx, move |window, cx| { - if let Some(crease_id) = crate::context_picker::insert_crease_for_mention( - excerpt_id, - start, - content_len, - crease_text.clone(), - crease_icon_path, - editor.clone(), - window, - cx, - ) { - mention_set.lock().insert(crease_id, mention_uri.clone()); - } + message_editor + .clone() + .update(cx, |message_editor, cx| { + message_editor + .confirm_mention_completion( + crease_text, + start, + content_len, + mention_uri, + window, + cx, + ) + .detach(); + }) + .ok(); }); false }) } +enum ContextCompletion { + SlashCommand(SlashCommandCompletion), + Mention(MentionCompletion), +} + +impl ContextCompletion { + fn source_range(&self) -> Range { + match self { + Self::SlashCommand(completion) => completion.source_range.clone(), + Self::Mention(completion) => completion.source_range.clone(), + } + } + + fn try_parse(line: &str, offset_to_line: usize, allow_non_file_mentions: bool) -> Option { + if let Some(command) = SlashCommandCompletion::try_parse(line, offset_to_line) { + Some(Self::SlashCommand(command)) + } else if let Some(mention) = + MentionCompletion::try_parse(allow_non_file_mentions, line, offset_to_line) + { + Some(Self::Mention(mention)) + } else { + None + } + } +} + +#[derive(Debug, Default, PartialEq)] +pub struct SlashCommandCompletion { + pub source_range: Range, + pub command: Option, + pub argument: Option, +} + +impl SlashCommandCompletion { + pub fn try_parse(line: &str, offset_to_line: usize) -> Option { + // If we decide to support commands that are not at the beginning of the prompt, we can remove this check + if !line.starts_with('/') || offset_to_line != 0 { + return None; + } + + let (prefix, last_command) = line.rsplit_once('/')?; + if 
prefix.chars().last().is_some_and(|c| !c.is_whitespace()) + || last_command.starts_with(char::is_whitespace) + { + return None; + } + + let mut argument = None; + let mut command = None; + if let Some((command_text, args)) = last_command.split_once(char::is_whitespace) { + if !args.is_empty() { + argument = Some(args.trim_end().to_string()); + } + command = Some(command_text.to_string()); + } else if !last_command.is_empty() { + command = Some(last_command.to_string()); + }; + + Some(Self { + source_range: prefix.len() + offset_to_line + ..line + .rfind(|c: char| !c.is_whitespace()) + .unwrap_or_else(|| line.len()) + + 1 + + offset_to_line, + command, + argument, + }) + } +} + #[derive(Debug, Default, PartialEq)] struct MentionCompletion { source_range: Range, @@ -1119,16 +1064,24 @@ struct MentionCompletion { } impl MentionCompletion { - fn try_parse(line: &str, offset_to_line: usize) -> Option { + fn try_parse(allow_non_file_mentions: bool, line: &str, offset_to_line: usize) -> Option { let last_mention_start = line.rfind('@')?; - if last_mention_start >= line.len() { - return Some(Self::default()); + + // No whitespace immediately after '@' + if line[last_mention_start + 1..] 
+ .chars() + .next() + .is_some_and(|c| c.is_whitespace()) + { + return None; } + + // Must be a word boundary before '@' if last_mention_start > 0 - && line + && line[..last_mention_start] .chars() - .nth(last_mention_start - 1) - .map_or(false, |c| !c.is_whitespace()) + .last() + .is_some_and(|c| !c.is_whitespace()) { return None; } @@ -1140,10 +1093,14 @@ impl MentionCompletion { let mut parts = rest_of_line.split_whitespace(); let mut end = last_mention_start + 1; + if let Some(mode_text) = parts.next() { + // Safe since we check no leading whitespace above end += mode_text.len(); - if let Some(parsed_mode) = ContextPickerMode::try_from(mode_text).ok() { + if let Some(parsed_mode) = ContextPickerMode::try_from(mode_text).ok() + && (allow_non_file_mentions || matches!(parsed_mode, ContextPickerMode::File)) + { mode = Some(parsed_mode); } else { argument = Some(mode_text.to_string()); @@ -1151,6 +1108,12 @@ impl MentionCompletion { match rest_of_line[mode_text.len()..].find(|c: char| !c.is_whitespace()) { Some(whitespace_count) => { if let Some(argument_text) = parts.next() { + // If mode wasn't recognized but we have an argument, don't suggest completions + // (e.g. 
'@something word') + if mode.is_none() && !argument_text.is_empty() { + return None; + } + argument = Some(argument_text.to_string()); end += whitespace_count + argument_text.len(); } @@ -1173,22 +1136,80 @@ impl MentionCompletion { #[cfg(test)] mod tests { use super::*; - use editor::AnchorRangeExt; - use gpui::{EventEmitter, FocusHandle, Focusable, TestAppContext, VisualTestContext}; - use project::{Project, ProjectPath}; - use serde_json::json; - use settings::SettingsStore; - use smol::stream::StreamExt as _; - use std::{ops::Deref, rc::Rc}; - use util::path; - use workspace::{AppState, Item}; + + #[test] + fn test_slash_command_completion_parse() { + assert_eq!( + SlashCommandCompletion::try_parse("/", 0), + Some(SlashCommandCompletion { + source_range: 0..1, + command: None, + argument: None, + }) + ); + + assert_eq!( + SlashCommandCompletion::try_parse("/help", 0), + Some(SlashCommandCompletion { + source_range: 0..5, + command: Some("help".to_string()), + argument: None, + }) + ); + + assert_eq!( + SlashCommandCompletion::try_parse("/help ", 0), + Some(SlashCommandCompletion { + source_range: 0..5, + command: Some("help".to_string()), + argument: None, + }) + ); + + assert_eq!( + SlashCommandCompletion::try_parse("/help arg1", 0), + Some(SlashCommandCompletion { + source_range: 0..10, + command: Some("help".to_string()), + argument: Some("arg1".to_string()), + }) + ); + + assert_eq!( + SlashCommandCompletion::try_parse("/help arg1 arg2", 0), + Some(SlashCommandCompletion { + source_range: 0..15, + command: Some("help".to_string()), + argument: Some("arg1 arg2".to_string()), + }) + ); + + assert_eq!( + SlashCommandCompletion::try_parse("/拿不到命令 拿不到命令 ", 0), + Some(SlashCommandCompletion { + source_range: 0..30, + command: Some("拿不到命令".to_string()), + argument: Some("拿不到命令".to_string()), + }) + ); + + assert_eq!(SlashCommandCompletion::try_parse("Lorem Ipsum", 0), None); + + assert_eq!(SlashCommandCompletion::try_parse("Lorem /", 0), None); + + 
assert_eq!(SlashCommandCompletion::try_parse("Lorem /help", 0), None); + + assert_eq!(SlashCommandCompletion::try_parse("Lorem/", 0), None); + + assert_eq!(SlashCommandCompletion::try_parse("/ ", 0), None); + } #[test] fn test_mention_completion_parse() { - assert_eq!(MentionCompletion::try_parse("Lorem Ipsum", 0), None); + assert_eq!(MentionCompletion::try_parse(true, "Lorem Ipsum", 0), None); assert_eq!( - MentionCompletion::try_parse("Lorem @", 0), + MentionCompletion::try_parse(true, "Lorem @", 0), Some(MentionCompletion { source_range: 6..7, mode: None, @@ -1197,7 +1218,7 @@ mod tests { ); assert_eq!( - MentionCompletion::try_parse("Lorem @file", 0), + MentionCompletion::try_parse(true, "Lorem @file", 0), Some(MentionCompletion { source_range: 6..11, mode: Some(ContextPickerMode::File), @@ -1206,7 +1227,7 @@ mod tests { ); assert_eq!( - MentionCompletion::try_parse("Lorem @file ", 0), + MentionCompletion::try_parse(true, "Lorem @file ", 0), Some(MentionCompletion { source_range: 6..12, mode: Some(ContextPickerMode::File), @@ -1215,7 +1236,7 @@ mod tests { ); assert_eq!( - MentionCompletion::try_parse("Lorem @file main.rs", 0), + MentionCompletion::try_parse(true, "Lorem @file main.rs", 0), Some(MentionCompletion { source_range: 6..19, mode: Some(ContextPickerMode::File), @@ -1224,7 +1245,7 @@ mod tests { ); assert_eq!( - MentionCompletion::try_parse("Lorem @file main.rs ", 0), + MentionCompletion::try_parse(true, "Lorem @file main.rs ", 0), Some(MentionCompletion { source_range: 6..19, mode: Some(ContextPickerMode::File), @@ -1233,7 +1254,7 @@ mod tests { ); assert_eq!( - MentionCompletion::try_parse("Lorem @file main.rs Ipsum", 0), + MentionCompletion::try_parse(true, "Lorem @file main.rs Ipsum", 0), Some(MentionCompletion { source_range: 6..19, mode: Some(ContextPickerMode::File), @@ -1242,7 +1263,7 @@ mod tests { ); assert_eq!( - MentionCompletion::try_parse("Lorem @main", 0), + MentionCompletion::try_parse(true, "Lorem @main", 0), Some(MentionCompletion { 
source_range: 6..11, mode: None, @@ -1250,475 +1271,52 @@ mod tests { }) ); - assert_eq!(MentionCompletion::try_parse("test@", 0), None); - } - - struct AtMentionEditor(Entity); - - impl Item for AtMentionEditor { - type Event = (); - - fn include_in_nav_history() -> bool { - false - } - - fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { - "Test".into() - } - } - - impl EventEmitter<()> for AtMentionEditor {} - - impl Focusable for AtMentionEditor { - fn focus_handle(&self, cx: &App) -> FocusHandle { - self.0.read(cx).focus_handle(cx).clone() - } - } - - impl Render for AtMentionEditor { - fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - self.0.clone().into_any_element() - } - } - - #[gpui::test] - async fn test_context_completion_provider(cx: &mut TestAppContext) { - init_test(cx); - - let app_state = cx.update(AppState::test); - - cx.update(|cx| { - language::init(cx); - editor::init(cx); - workspace::init(app_state.clone(), cx); - Project::init_settings(cx); - }); - - app_state - .fs - .as_fake() - .insert_tree( - path!("/dir"), - json!({ - "editor": "", - "a": { - "one.txt": "1", - "two.txt": "2", - "three.txt": "3", - "four.txt": "4" - }, - "b": { - "five.txt": "5", - "six.txt": "6", - "seven.txt": "7", - "eight.txt": "8", - } - }), - ) - .await; - - let project = Project::test(app_state.fs.clone(), [path!("/dir").as_ref()], cx).await; - let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); - let workspace = window.root(cx).unwrap(); - - let worktree = project.update(cx, |project, cx| { - let mut worktrees = project.worktrees(cx).collect::>(); - assert_eq!(worktrees.len(), 1); - worktrees.pop().unwrap() - }); - let worktree_id = worktree.read_with(cx, |worktree, _| worktree.id()); - - let mut cx = VisualTestContext::from_window(*window.deref(), cx); - - let paths = vec![ - path!("a/one.txt"), - path!("a/two.txt"), - path!("a/three.txt"), - path!("a/four.txt"), - 
path!("b/five.txt"), - path!("b/six.txt"), - path!("b/seven.txt"), - path!("b/eight.txt"), - ]; - - let mut opened_editors = Vec::new(); - for path in paths { - let buffer = workspace - .update_in(&mut cx, |workspace, window, cx| { - workspace.open_path( - ProjectPath { - worktree_id, - path: Path::new(path).into(), - }, - None, - false, - window, - cx, - ) - }) - .await - .unwrap(); - opened_editors.push(buffer); - } - - let editor = workspace.update_in(&mut cx, |workspace, window, cx| { - let editor = cx.new(|cx| { - Editor::new( - editor::EditorMode::full(), - multi_buffer::MultiBuffer::build_simple("", cx), - None, - window, - cx, - ) - }); - workspace.active_pane().update(cx, |pane, cx| { - pane.add_item( - Box::new(cx.new(|_| AtMentionEditor(editor.clone()))), - true, - true, - None, - window, - cx, - ); - }); - editor - }); - - let mention_set = Arc::new(Mutex::new(MentionSet::default())); - - let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx)); - let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); - - let editor_entity = editor.downgrade(); - editor.update_in(&mut cx, |editor, window, cx| { - window.focus(&editor.focus_handle(cx)); - editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new( - mention_set.clone(), - workspace.downgrade(), - thread_store.downgrade(), - text_thread_store.downgrade(), - editor_entity, - )))); - }); - - cx.simulate_input("Lorem "); - - editor.update(&mut cx, |editor, cx| { - assert_eq!(editor.text(cx), "Lorem "); - assert!(!editor.has_visible_completions_menu()); - }); - - cx.simulate_input("@"); - - editor.update(&mut cx, |editor, cx| { - assert_eq!(editor.text(cx), "Lorem @"); - assert!(editor.has_visible_completions_menu()); - assert_eq!( - current_completion_labels(editor), - &[ - "eight.txt dir/b/", - "seven.txt dir/b/", - "six.txt dir/b/", - "five.txt dir/b/", - "Files & Directories", - "Symbols", - "Threads", - "Fetch" - ] - ); - }); - - // Select and 
confirm "File" - editor.update_in(&mut cx, |editor, window, cx| { - assert!(editor.has_visible_completions_menu()); - editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); - editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); - editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); - editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); - editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); - }); - - cx.run_until_parked(); - - editor.update(&mut cx, |editor, cx| { - assert_eq!(editor.text(cx), "Lorem @file "); - assert!(editor.has_visible_completions_menu()); - }); - - cx.simulate_input("one"); - - editor.update(&mut cx, |editor, cx| { - assert_eq!(editor.text(cx), "Lorem @file one"); - assert!(editor.has_visible_completions_menu()); - assert_eq!(current_completion_labels(editor), vec!["one.txt dir/a/"]); - }); - - editor.update_in(&mut cx, |editor, window, cx| { - assert!(editor.has_visible_completions_menu()); - editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); - }); - - editor.update(&mut cx, |editor, cx| { - assert_eq!(editor.text(cx), "Lorem [@one.txt](file:///dir/a/one.txt) "); - assert!(!editor.has_visible_completions_menu()); - assert_eq!( - fold_ranges(editor, cx), - vec![Point::new(0, 6)..Point::new(0, 39)] - ); - }); - - let contents = cx - .update(|window, cx| { - mention_set.lock().contents( - project.clone(), - thread_store.clone(), - text_thread_store.clone(), - window, - cx, - ) - }) - .await - .unwrap() - .into_values() - .collect::>(); - - assert_eq!(contents.len(), 1); - assert_eq!(contents[0].content, "1"); assert_eq!( - contents[0].uri.to_uri().to_string(), - "file:///dir/a/one.txt" + MentionCompletion::try_parse(true, "Lorem @main ", 0), + Some(MentionCompletion { + source_range: 6..12, + mode: None, + argument: Some("main".to_string()), + }) ); - cx.simulate_input(" "); - - editor.update(&mut cx, 
|editor, cx| { - assert_eq!(editor.text(cx), "Lorem [@one.txt](file:///dir/a/one.txt) "); - assert!(!editor.has_visible_completions_menu()); - assert_eq!( - fold_ranges(editor, cx), - vec![Point::new(0, 6)..Point::new(0, 39)] - ); - }); - - cx.simulate_input("Ipsum "); - - editor.update(&mut cx, |editor, cx| { - assert_eq!( - editor.text(cx), - "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum ", - ); - assert!(!editor.has_visible_completions_menu()); - assert_eq!( - fold_ranges(editor, cx), - vec![Point::new(0, 6)..Point::new(0, 39)] - ); - }); - - cx.simulate_input("@file "); + assert_eq!(MentionCompletion::try_parse(true, "Lorem @main m", 0), None); - editor.update(&mut cx, |editor, cx| { - assert_eq!( - editor.text(cx), - "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum @file ", - ); - assert!(editor.has_visible_completions_menu()); - assert_eq!( - fold_ranges(editor, cx), - vec![Point::new(0, 6)..Point::new(0, 39)] - ); - }); + assert_eq!(MentionCompletion::try_parse(true, "test@", 0), None); - editor.update_in(&mut cx, |editor, window, cx| { - editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); - }); + // Allowed non-file mentions - cx.run_until_parked(); - - let contents = cx - .update(|window, cx| { - mention_set.lock().contents( - project.clone(), - thread_store.clone(), - text_thread_store.clone(), - window, - cx, - ) + assert_eq!( + MentionCompletion::try_parse(true, "Lorem @symbol main", 0), + Some(MentionCompletion { + source_range: 6..18, + mode: Some(ContextPickerMode::Symbol), + argument: Some("main".to_string()), }) - .await - .unwrap() - .into_values() - .collect::>(); + ); - assert_eq!(contents.len(), 2); - let new_mention = contents - .iter() - .find(|mention| mention.uri.to_uri().to_string() == "file:///dir/b/eight.txt") - .unwrap(); - assert_eq!(new_mention.content, "8"); - - editor.update(&mut cx, |editor, cx| { - assert_eq!( - editor.text(cx), - "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum 
[@eight.txt](file:///dir/b/eight.txt) " - ); - assert!(!editor.has_visible_completions_menu()); - assert_eq!( - fold_ranges(editor, cx), - vec![ - Point::new(0, 6)..Point::new(0, 39), - Point::new(0, 47)..Point::new(0, 84) - ] - ); - }); + // Disallowed non-file mentions + assert_eq!( + MentionCompletion::try_parse(false, "Lorem @symbol main", 0), + None + ); - let plain_text_language = Arc::new(language::Language::new( - language::LanguageConfig { - name: "Plain Text".into(), - matcher: language::LanguageMatcher { - path_suffixes: vec!["txt".to_string()], - ..Default::default() - }, - ..Default::default() - }, + assert_eq!( + MentionCompletion::try_parse(true, "Lorem@symbol", 0), None, - )); - - // Register the language and fake LSP - let language_registry = project.read_with(&cx, |project, _| project.languages().clone()); - language_registry.add(plain_text_language); - - let mut fake_language_servers = language_registry.register_fake_lsp( - "Plain Text", - language::FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - workspace_symbol_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() - }, - ..Default::default() - }, + "Should not parse mention inside word" ); - // Open the buffer to trigger LSP initialization - let buffer = project - .update(&mut cx, |project, cx| { - project.open_local_buffer(path!("/dir/a/one.txt"), cx) - }) - .await - .unwrap(); - - // Register the buffer with language servers - let _handle = project.update(&mut cx, |project, cx| { - project.register_buffer_with_language_servers(&buffer, cx) - }); - - cx.run_until_parked(); - - let fake_language_server = fake_language_servers.next().await.unwrap(); - fake_language_server.set_request_handler::( - |_, _| async move { - Ok(Some(lsp::WorkspaceSymbolResponse::Flat(vec![ - #[allow(deprecated)] - lsp::SymbolInformation { - name: "MySymbol".into(), - location: lsp::Location { - uri: lsp::Url::from_file_path(path!("/dir/a/one.txt")).unwrap(), - range: lsp::Range::new( - 
lsp::Position::new(0, 0), - lsp::Position::new(0, 1), - ), - }, - kind: lsp::SymbolKind::CONSTANT, - tags: None, - container_name: None, - deprecated: None, - }, - ]))) - }, + assert_eq!( + MentionCompletion::try_parse(true, "Lorem @ file", 0), + None, + "Should not parse with a space after @" ); - cx.simulate_input("@symbol "); - - editor.update(&mut cx, |editor, cx| { - assert_eq!( - editor.text(cx), - "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum [@eight.txt](file:///dir/b/eight.txt) @symbol " - ); - assert!(editor.has_visible_completions_menu()); - assert_eq!( - current_completion_labels(editor), - &[ - "MySymbol", - ] - ); - }); - - editor.update_in(&mut cx, |editor, window, cx| { - editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); - }); - - let contents = cx - .update(|window, cx| { - mention_set.lock().contents( - project.clone(), - thread_store, - text_thread_store, - window, - cx, - ) - }) - .await - .unwrap() - .into_values() - .collect::>(); - - assert_eq!(contents.len(), 3); - let new_mention = contents - .iter() - .find(|mention| { - mention.uri.to_uri().to_string() == "file:///dir/a/one.txt?symbol=MySymbol#L1:1" - }) - .unwrap(); - assert_eq!(new_mention.content, "1"); - - cx.run_until_parked(); - - editor.read_with(&mut cx, |editor, cx| { - assert_eq!( - editor.text(cx), - "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum [@eight.txt](file:///dir/b/eight.txt) [@MySymbol](file:///dir/a/one.txt?symbol=MySymbol#L1:1) " - ); - }); - } - - fn fold_ranges(editor: &Editor, cx: &mut App) -> Vec> { - let snapshot = editor.buffer().read(cx).snapshot(cx); - editor.display_map.update(cx, |display_map, cx| { - display_map - .snapshot(cx) - .folds_in_range(0..snapshot.len()) - .map(|fold| fold.range.to_point(&snapshot)) - .collect() - }) - } - - fn current_completion_labels(editor: &Editor) -> Vec { - let completions = editor.current_completions().expect("Missing completions"); - completions - .into_iter() - .map(|completion| 
completion.label.text.to_string()) - .collect::>() - } - - pub(crate) fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let store = SettingsStore::test(cx); - cx.set_global(store); - theme::init(theme::LoadThemes::JustBase, cx); - client::init_settings(cx); - language::init(cx); - Project::init_settings(cx); - workspace::init_settings(cx); - editor::init_settings(cx); - }); + assert_eq!( + MentionCompletion::try_parse(true, "@ file", 0), + None, + "Should not parse with a space after @ at the start of the line" + ); } } diff --git a/crates/agent_ui/src/acp/entry_view_state.rs b/crates/agent_ui/src/acp/entry_view_state.rs new file mode 100644 index 0000000000000000000000000000000000000000..ec57ea7e6df3244b6ea1bcb99212d845fa68c457 --- /dev/null +++ b/crates/agent_ui/src/acp/entry_view_state.rs @@ -0,0 +1,554 @@ +use std::{ + cell::{Cell, RefCell}, + ops::Range, + rc::Rc, +}; + +use acp_thread::{AcpThread, AgentThreadEntry}; +use agent_client_protocol::{self as acp, ToolCallId}; +use agent2::HistoryStore; +use collections::HashMap; +use editor::{Editor, EditorMode, MinimapVisibility}; +use gpui::{ + AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable, + ScrollHandle, SharedString, TextStyleRefinement, WeakEntity, Window, +}; +use language::language_settings::SoftWrap; +use project::Project; +use prompt_store::PromptStore; +use settings::Settings as _; +use terminal_view::TerminalView; +use theme::ThemeSettings; +use ui::{Context, TextSize}; +use workspace::Workspace; + +use crate::acp::message_editor::{MessageEditor, MessageEditorEvent}; + +pub struct EntryViewState { + workspace: WeakEntity, + project: Entity, + history_store: Entity, + prompt_store: Option>, + entries: Vec, + prompt_capabilities: Rc>, + available_commands: Rc>>, + agent_name: SharedString, +} + +impl EntryViewState { + pub fn new( + workspace: WeakEntity, + project: Entity, + history_store: Entity, + prompt_store: Option>, + prompt_capabilities: Rc>, + 
available_commands: Rc>>, + agent_name: SharedString, + ) -> Self { + Self { + workspace, + project, + history_store, + prompt_store, + entries: Vec::new(), + prompt_capabilities, + available_commands, + agent_name, + } + } + + pub fn entry(&self, index: usize) -> Option<&Entry> { + self.entries.get(index) + } + + pub fn sync_entry( + &mut self, + index: usize, + thread: &Entity, + window: &mut Window, + cx: &mut Context, + ) { + let Some(thread_entry) = thread.read(cx).entries().get(index) else { + return; + }; + + match thread_entry { + AgentThreadEntry::UserMessage(message) => { + let has_id = message.id.is_some(); + let chunks = message.chunks.clone(); + if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) { + if !editor.focus_handle(cx).is_focused(window) { + // Only update if we are not editing. + // If we are, cancelling the edit will set the message to the newest content. + editor.update(cx, |editor, cx| { + editor.set_message(chunks, window, cx); + }); + } + } else { + let message_editor = cx.new(|cx| { + let mut editor = MessageEditor::new( + self.workspace.clone(), + self.project.clone(), + self.history_store.clone(), + self.prompt_store.clone(), + self.prompt_capabilities.clone(), + self.available_commands.clone(), + self.agent_name.clone(), + "Edit message - @ to include context", + editor::EditorMode::AutoHeight { + min_lines: 1, + max_lines: None, + }, + window, + cx, + ); + if !has_id { + editor.set_read_only(true, cx); + } + editor.set_message(chunks, window, cx); + editor + }); + cx.subscribe(&message_editor, move |_, editor, event, cx| { + cx.emit(EntryViewEvent { + entry_index: index, + view_event: ViewEvent::MessageEditorEvent(editor, *event), + }) + }) + .detach(); + self.set_entry(index, Entry::UserMessage(message_editor)); + } + } + AgentThreadEntry::ToolCall(tool_call) => { + let id = tool_call.id.clone(); + let terminals = tool_call.terminals().cloned().collect::>(); + let diffs = tool_call.diffs().cloned().collect::>(); + 
+ let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) { + views + } else { + self.set_entry(index, Entry::empty()); + let Some(Entry::Content(views)) = self.entries.get_mut(index) else { + unreachable!() + }; + views + }; + + let is_tool_call_completed = + matches!(tool_call.status, acp_thread::ToolCallStatus::Completed); + + for terminal in terminals { + match views.entry(terminal.entity_id()) { + collections::hash_map::Entry::Vacant(entry) => { + let element = create_terminal( + self.workspace.clone(), + self.project.clone(), + terminal.clone(), + window, + cx, + ) + .into_any(); + cx.emit(EntryViewEvent { + entry_index: index, + view_event: ViewEvent::NewTerminal(id.clone()), + }); + entry.insert(element); + } + collections::hash_map::Entry::Occupied(_entry) => { + if is_tool_call_completed && terminal.read(cx).output().is_none() { + cx.emit(EntryViewEvent { + entry_index: index, + view_event: ViewEvent::TerminalMovedToBackground(id.clone()), + }); + } + } + } + } + + for diff in diffs { + views.entry(diff.entity_id()).or_insert_with(|| { + let element = create_editor_diff(diff.clone(), window, cx).into_any(); + cx.emit(EntryViewEvent { + entry_index: index, + view_event: ViewEvent::NewDiff(id.clone()), + }); + element + }); + } + } + AgentThreadEntry::AssistantMessage(message) => { + let entry = if let Some(Entry::AssistantMessage(entry)) = + self.entries.get_mut(index) + { + entry + } else { + self.set_entry( + index, + Entry::AssistantMessage(AssistantMessageEntry::default()), + ); + let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else { + unreachable!() + }; + entry + }; + entry.sync(message); + } + }; + } + + fn set_entry(&mut self, index: usize, entry: Entry) { + if index == self.entries.len() { + self.entries.push(entry); + } else { + self.entries[index] = entry; + } + } + + pub fn remove(&mut self, range: Range) { + self.entries.drain(range); + } + + pub fn agent_font_size_changed(&mut self, cx: &mut App) { 
+ for entry in self.entries.iter() { + match entry { + Entry::UserMessage { .. } | Entry::AssistantMessage { .. } => {} + Entry::Content(response_views) => { + for view in response_views.values() { + if let Ok(diff_editor) = view.clone().downcast::() { + diff_editor.update(cx, |diff_editor, cx| { + diff_editor.set_text_style_refinement( + diff_editor_text_style_refinement(cx), + ); + cx.notify(); + }) + } + } + } + } + } + } +} + +impl EventEmitter for EntryViewState {} + +pub struct EntryViewEvent { + pub entry_index: usize, + pub view_event: ViewEvent, +} + +pub enum ViewEvent { + NewDiff(ToolCallId), + NewTerminal(ToolCallId), + TerminalMovedToBackground(ToolCallId), + MessageEditorEvent(Entity, MessageEditorEvent), +} + +#[derive(Default, Debug)] +pub struct AssistantMessageEntry { + scroll_handles_by_chunk_index: HashMap, +} + +impl AssistantMessageEntry { + pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option { + self.scroll_handles_by_chunk_index.get(&ix).cloned() + } + + pub fn sync(&mut self, message: &acp_thread::AssistantMessage) { + if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() { + let ix = message.chunks.len() - 1; + let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default(); + handle.scroll_to_bottom(); + } + } +} + +#[derive(Debug)] +pub enum Entry { + UserMessage(Entity), + AssistantMessage(AssistantMessageEntry), + Content(HashMap), +} + +impl Entry { + pub fn focus_handle(&self, cx: &App) -> Option { + match self { + Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)), + Self::AssistantMessage(_) | Self::Content(_) => None, + } + } + + pub fn message_editor(&self) -> Option<&Entity> { + match self { + Self::UserMessage(editor) => Some(editor), + Self::AssistantMessage(_) | Self::Content(_) => None, + } + } + + pub fn editor_for_diff(&self, diff: &Entity) -> Option> { + self.content_map()? 
+ .get(&diff.entity_id()) + .cloned() + .map(|entity| entity.downcast::().unwrap()) + } + + pub fn terminal( + &self, + terminal: &Entity, + ) -> Option> { + self.content_map()? + .get(&terminal.entity_id()) + .cloned() + .map(|entity| entity.downcast::().unwrap()) + } + + pub fn scroll_handle_for_assistant_message_chunk( + &self, + chunk_ix: usize, + ) -> Option { + match self { + Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix), + Self::UserMessage(_) | Self::Content(_) => None, + } + } + + fn content_map(&self) -> Option<&HashMap> { + match self { + Self::Content(map) => Some(map), + _ => None, + } + } + + fn empty() -> Self { + Self::Content(HashMap::default()) + } + + #[cfg(test)] + pub fn has_content(&self) -> bool { + match self { + Self::Content(map) => !map.is_empty(), + Self::UserMessage(_) | Self::AssistantMessage(_) => false, + } + } +} + +fn create_terminal( + workspace: WeakEntity, + project: Entity, + terminal: Entity, + window: &mut Window, + cx: &mut App, +) -> Entity { + cx.new(|cx| { + let mut view = TerminalView::new( + terminal.read(cx).inner().clone(), + workspace.clone(), + None, + project.downgrade(), + window, + cx, + ); + view.set_embedded_mode(Some(1000), cx); + view + }) +} + +fn create_editor_diff( + diff: Entity, + window: &mut Window, + cx: &mut App, +) -> Entity { + cx.new(|cx| { + let mut editor = Editor::new( + EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + sized_by_content: true, + }, + diff.read(cx).multibuffer().clone(), + None, + window, + cx, + ); + editor.set_show_gutter(false, cx); + editor.disable_inline_diagnostics(); + editor.disable_expand_excerpt_buttons(cx); + editor.set_show_vertical_scrollbar(false, cx); + editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); + editor.set_soft_wrap_mode(SoftWrap::None, cx); + editor.scroll_manager.set_forbid_vertical_scroll(true); + editor.set_show_indent_guides(false, cx); + 
editor.set_read_only(true); + editor.set_show_breakpoints(false, cx); + editor.set_show_code_actions(false, cx); + editor.set_show_git_diff_gutter(false, cx); + editor.set_expand_all_diff_hunks(cx); + editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); + editor + }) +} + +fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement { + TextStyleRefinement { + font_size: Some( + TextSize::Small + .rems(cx) + .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) + .into(), + ), + ..Default::default() + } +} + +#[cfg(test)] +mod tests { + use std::{path::Path, rc::Rc}; + + use acp_thread::{AgentConnection, StubAgentConnection}; + use agent_client_protocol as acp; + use agent_settings::AgentSettings; + use agent2::HistoryStore; + use assistant_context::ContextStore; + use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind}; + use editor::{EditorSettings, RowInfo}; + use fs::FakeFs; + use gpui::{AppContext as _, SemanticVersion, TestAppContext}; + + use crate::acp::entry_view_state::EntryViewState; + use multi_buffer::MultiBufferRow; + use pretty_assertions::assert_matches; + use project::Project; + use serde_json::json; + use settings::{Settings as _, SettingsStore}; + use theme::ThemeSettings; + use util::path; + use workspace::Workspace; + + #[gpui::test] + async fn test_diff_sync(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + "hello.txt": "hi world" + }), + ) + .await; + let project = Project::test(fs, [Path::new(path!("/project"))], cx).await; + + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let tool_call = acp::ToolCall { + id: acp::ToolCallId("tool".into()), + title: "Tool call".into(), + kind: acp::ToolKind::Other, + status: acp::ToolCallStatus::InProgress, + content: vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: "/project/hello.txt".into(), + old_text: Some("hi 
world".into()), + new_text: "hello world".into(), + }, + }], + locations: vec![], + raw_input: None, + raw_output: None, + }; + let connection = Rc::new(StubAgentConnection::new()); + let thread = cx + .update(|_, cx| { + connection + .clone() + .new_thread(project.clone(), Path::new(path!("/project")), cx) + }) + .await + .unwrap(); + let session_id = thread.update(cx, |thread, _| thread.session_id().clone()); + + cx.update(|_, cx| { + connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx) + }); + + let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + + let view_state = cx.new(|_cx| { + EntryViewState::new( + workspace.downgrade(), + project.clone(), + history_store, + None, + Default::default(), + Default::default(), + "Test Agent".into(), + ) + }); + + view_state.update_in(cx, |view_state, window, cx| { + view_state.sync_entry(0, &thread, window, cx) + }); + + let diff = thread.read_with(cx, |thread, _cx| { + thread + .entries() + .get(0) + .unwrap() + .diffs() + .next() + .unwrap() + .clone() + }); + + cx.run_until_parked(); + + let diff_editor = view_state.read_with(cx, |view_state, _cx| { + view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap() + }); + assert_eq!( + diff_editor.read_with(cx, |editor, cx| editor.text(cx)), + "hi world\nhello world" + ); + let row_infos = diff_editor.read_with(cx, |editor, cx| { + let multibuffer = editor.buffer().read(cx); + multibuffer + .snapshot(cx) + .row_infos(MultiBufferRow(0)) + .collect::>() + }); + assert_matches!( + row_infos.as_slice(), + [ + RowInfo { + multibuffer_row: Some(MultiBufferRow(0)), + diff_status: Some(DiffHunkStatus { + kind: DiffHunkStatusKind::Deleted, + .. + }), + .. + }, + RowInfo { + multibuffer_row: Some(MultiBufferRow(1)), + diff_status: Some(DiffHunkStatus { + kind: DiffHunkStatusKind::Added, + .. + }), + .. 
+ } + ] + ); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + AgentSettings::register(cx); + workspace::init_settings(cx); + ThemeSettings::register(cx); + release_channel::init(SemanticVersion::default(), cx); + EditorSettings::register(cx); + }); + } +} diff --git a/crates/agent_ui/src/acp/message_editor.rs b/crates/agent_ui/src/acp/message_editor.rs new file mode 100644 index 0000000000000000000000000000000000000000..02ee46e840299b9307253603f3c165bbd525d377 --- /dev/null +++ b/crates/agent_ui/src/acp/message_editor.rs @@ -0,0 +1,2587 @@ +use crate::{ + acp::completion_provider::{ContextPickerCompletionProvider, SlashCommandCompletion}, + context_picker::{ContextPickerAction, fetch_context_picker::fetch_url_content}, +}; +use acp_thread::{MentionUri, selection_name}; +use agent_client_protocol as acp; +use agent_servers::{AgentServer, AgentServerDelegate}; +use agent2::HistoryStore; +use anyhow::{Result, anyhow}; +use assistant_slash_commands::codeblock_fence_for_path; +use collections::{HashMap, HashSet}; +use editor::{ + Addon, Anchor, AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, + EditorEvent, EditorMode, EditorSnapshot, EditorStyle, ExcerptId, FoldPlaceholder, InlayId, + MultiBuffer, ToOffset, + actions::Paste, + display_map::{Crease, CreaseId, FoldId, Inlay}, +}; +use futures::{ + FutureExt as _, + future::{Shared, join_all}, +}; +use gpui::{ + Animation, AnimationExt as _, AppContext, ClipboardEntry, Context, Entity, EntityId, + EventEmitter, FocusHandle, Focusable, Image, ImageFormat, Img, KeyContext, SharedString, + Subscription, Task, TextStyle, WeakEntity, pulsating_between, +}; +use language::{Buffer, Language, language_settings::InlayHintKind}; +use language_model::LanguageModelImage; +use postage::stream::Stream as _; +use project::{ + CompletionIntent, InlayHint, 
InlayHintLabel, Project, ProjectItem, ProjectPath, Worktree, +}; +use prompt_store::{PromptId, PromptStore}; +use rope::Point; +use settings::Settings; +use std::{ + cell::{Cell, RefCell}, + ffi::OsStr, + fmt::Write, + ops::{Range, RangeInclusive}, + path::{Path, PathBuf}, + rc::Rc, + sync::Arc, + time::Duration, +}; +use text::OffsetRangeExt; +use theme::ThemeSettings; +use ui::{ + ActiveTheme, AnyElement, App, ButtonCommon, ButtonLike, ButtonStyle, Color, Element as _, + FluentBuilder as _, Icon, IconName, IconSize, InteractiveElement, IntoElement, Label, + LabelCommon, LabelSize, ParentElement, Render, SelectableButton, Styled, TextSize, TintColor, + Toggleable, Window, div, h_flex, +}; +use util::{ResultExt, debug_panic}; +use workspace::{Workspace, notifications::NotifyResultExt as _}; +use zed_actions::agent::Chat; + +pub struct MessageEditor { + mention_set: MentionSet, + editor: Entity, + project: Entity, + workspace: WeakEntity, + history_store: Entity, + prompt_store: Option>, + prompt_capabilities: Rc>, + available_commands: Rc>>, + agent_name: SharedString, + _subscriptions: Vec, + _parse_slash_command_task: Task<()>, +} + +#[derive(Clone, Copy, Debug)] +pub enum MessageEditorEvent { + Send, + Cancel, + Focus, + LostFocus, +} + +impl EventEmitter for MessageEditor {} + +const COMMAND_HINT_INLAY_ID: usize = 0; + +impl MessageEditor { + pub fn new( + workspace: WeakEntity, + project: Entity, + history_store: Entity, + prompt_store: Option>, + prompt_capabilities: Rc>, + available_commands: Rc>>, + agent_name: SharedString, + placeholder: &str, + mode: EditorMode, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let language = Language::new( + language::LanguageConfig { + completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']), + ..Default::default() + }, + None, + ); + let completion_provider = Rc::new(ContextPickerCompletionProvider::new( + cx.weak_entity(), + workspace.clone(), + history_store.clone(), + prompt_store.clone(), + 
prompt_capabilities.clone(), + available_commands.clone(), + )); + let mention_set = MentionSet::default(); + let editor = cx.new(|cx| { + let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx)); + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); + + let mut editor = Editor::new(mode, buffer, None, window, cx); + editor.set_placeholder_text(placeholder, window, cx); + editor.set_show_indent_guides(false, cx); + editor.set_soft_wrap(); + editor.set_use_modal_editing(true); + editor.set_completion_provider(Some(completion_provider.clone())); + editor.set_context_menu_options(ContextMenuOptions { + min_entries_visible: 12, + max_entries_visible: 12, + placement: Some(ContextMenuPlacement::Above), + }); + editor.register_addon(MessageEditorAddon::new()); + editor + }); + + cx.on_focus_in(&editor.focus_handle(cx), window, |_, _, cx| { + cx.emit(MessageEditorEvent::Focus) + }) + .detach(); + cx.on_focus_out(&editor.focus_handle(cx), window, |_, _, _, cx| { + cx.emit(MessageEditorEvent::LostFocus) + }) + .detach(); + + let mut has_hint = false; + let mut subscriptions = Vec::new(); + + subscriptions.push(cx.subscribe_in(&editor, window, { + move |this, editor, event, window, cx| { + if let EditorEvent::Edited { .. 
} = event { + let snapshot = editor.update(cx, |editor, cx| { + let new_hints = this + .command_hint(editor.buffer(), cx) + .into_iter() + .collect::>(); + let has_new_hint = !new_hints.is_empty(); + editor.splice_inlays( + if has_hint { + &[InlayId::Hint(COMMAND_HINT_INLAY_ID)] + } else { + &[] + }, + new_hints, + cx, + ); + has_hint = has_new_hint; + + editor.snapshot(window, cx) + }); + this.mention_set.remove_invalid(snapshot); + + cx.notify(); + } + } + })); + + Self { + editor, + project, + mention_set, + workspace, + history_store, + prompt_store, + prompt_capabilities, + available_commands, + agent_name, + _subscriptions: subscriptions, + _parse_slash_command_task: Task::ready(()), + } + } + + fn command_hint(&self, buffer: &Entity, cx: &App) -> Option { + let available_commands = self.available_commands.borrow(); + if available_commands.is_empty() { + return None; + } + + let snapshot = buffer.read(cx).snapshot(cx); + let parsed_command = SlashCommandCompletion::try_parse(&snapshot.text(), 0)?; + if parsed_command.argument.is_some() { + return None; + } + + let command_name = parsed_command.command?; + let available_command = available_commands + .iter() + .find(|command| command.name == command_name)?; + + let acp::AvailableCommandInput::Unstructured { mut hint } = + available_command.input.clone()?; + + let mut hint_pos = parsed_command.source_range.end + 1; + if hint_pos > snapshot.len() { + hint_pos = snapshot.len(); + hint.insert(0, ' '); + } + + let hint_pos = snapshot.anchor_after(hint_pos); + + Some(Inlay::hint( + COMMAND_HINT_INLAY_ID, + hint_pos, + &InlayHint { + position: hint_pos.text_anchor, + label: InlayHintLabel::String(hint), + kind: Some(InlayHintKind::Parameter), + padding_left: false, + padding_right: false, + tooltip: None, + resolve_state: project::ResolveState::Resolved, + }, + )) + } + + pub fn insert_thread_summary( + &mut self, + thread: agent2::DbThreadMetadata, + window: &mut Window, + cx: &mut Context, + ) { + let start = 
self.editor.update(cx, |editor, cx| { + editor.set_text(format!("{}\n", thread.title), window, cx); + editor + .buffer() + .read(cx) + .snapshot(cx) + .anchor_before(Point::zero()) + .text_anchor + }); + + self.confirm_mention_completion( + thread.title.clone(), + start, + thread.title.len(), + MentionUri::Thread { + id: thread.id.clone(), + name: thread.title.to_string(), + }, + window, + cx, + ) + .detach(); + } + + #[cfg(test)] + pub(crate) fn editor(&self) -> &Entity { + &self.editor + } + + #[cfg(test)] + pub(crate) fn mention_set(&mut self) -> &mut MentionSet { + &mut self.mention_set + } + + pub fn is_empty(&self, cx: &App) -> bool { + self.editor.read(cx).is_empty(cx) + } + + pub fn mentions(&self) -> HashSet { + self.mention_set + .mentions + .values() + .map(|(uri, _)| uri.clone()) + .collect() + } + + pub fn confirm_mention_completion( + &mut self, + crease_text: SharedString, + start: text::Anchor, + content_len: usize, + mention_uri: MentionUri, + window: &mut Window, + cx: &mut Context, + ) -> Task<()> { + let snapshot = self + .editor + .update(cx, |editor, cx| editor.snapshot(window, cx)); + let Some((excerpt_id, _, _)) = snapshot.buffer_snapshot.as_singleton() else { + return Task::ready(()); + }; + let Some(start_anchor) = snapshot + .buffer_snapshot + .anchor_in_excerpt(*excerpt_id, start) + else { + return Task::ready(()); + }; + let end_anchor = snapshot + .buffer_snapshot + .anchor_before(start_anchor.to_offset(&snapshot.buffer_snapshot) + content_len + 1); + + let crease = if let MentionUri::File { abs_path } = &mention_uri + && let Some(extension) = abs_path.extension() + && let Some(extension) = extension.to_str() + && Img::extensions().contains(&extension) + && !extension.contains("svg") + { + let Some(project_path) = self + .project + .read(cx) + .project_path_for_absolute_path(&abs_path, cx) + else { + log::error!("project path not found"); + return Task::ready(()); + }; + let image = self + .project + .update(cx, |project, cx| 
 project.open_image(project_path, cx)); + let image = cx + .spawn(async move |_, cx| { + let image = image.await.map_err(|e| e.to_string())?; + let image = image + .update(cx, |image, _| image.image.clone()) + .map_err(|e| e.to_string())?; + Ok(image) + }) + .shared(); + insert_crease_for_mention( + *excerpt_id, + start, + content_len, + mention_uri.name().into(), + IconName::Image.path().into(), + Some(image), + self.editor.clone(), + window, + cx, + ) + } else { + insert_crease_for_mention( + *excerpt_id, + start, + content_len, + crease_text, + mention_uri.icon_path(cx), + None, + self.editor.clone(), + window, + cx, + ) + }; + let Some((crease_id, tx)) = crease else { + return Task::ready(()); + }; + + let task = match mention_uri.clone() { + MentionUri::Fetch { url } => self.confirm_mention_for_fetch(url, cx), + MentionUri::Directory { abs_path } => self.confirm_mention_for_directory(abs_path, cx), + MentionUri::Thread { id, .. } => self.confirm_mention_for_thread(id, cx), + MentionUri::TextThread { path, .. } => self.confirm_mention_for_text_thread(path, cx), + MentionUri::File { abs_path } => self.confirm_mention_for_file(abs_path, cx), + MentionUri::Symbol { + abs_path, + line_range, + .. + } => self.confirm_mention_for_symbol(abs_path, line_range, cx), + MentionUri::Rule { id, .. } => self.confirm_mention_for_rule(id, cx), + MentionUri::PastedImage => { + debug_panic!("pasted image URI should not be included in completions"); + Task::ready(Err(anyhow!( + "pasted image URI should not be included in completions" + ))) + } + MentionUri::Selection { .. 
} => { + // Handled elsewhere + debug_panic!("unexpected selection URI"); + Task::ready(Err(anyhow!("unexpected selection URI"))) + } + }; + let task = cx + .spawn(async move |_, _| task.await.map_err(|e| e.to_string())) + .shared(); + self.mention_set + .mentions + .insert(crease_id, (mention_uri, task.clone())); + + // Notify the user if we failed to load the mentioned context + cx.spawn_in(window, async move |this, cx| { + let result = task.await.notify_async_err(cx); + drop(tx); + if result.is_none() { + this.update(cx, |this, cx| { + this.editor.update(cx, |editor, cx| { + // Remove mention + editor.edit([(start_anchor..end_anchor, "")], cx); + }); + this.mention_set.mentions.remove(&crease_id); + }) + .ok(); + } + }) + } + + fn confirm_mention_for_file( + &mut self, + abs_path: PathBuf, + cx: &mut Context, + ) -> Task> { + let Some(project_path) = self + .project + .read(cx) + .project_path_for_absolute_path(&abs_path, cx) + else { + return Task::ready(Err(anyhow!("project path not found"))); + }; + let extension = abs_path + .extension() + .and_then(OsStr::to_str) + .unwrap_or_default(); + + if Img::extensions().contains(&extension) && !extension.contains("svg") { + if !self.prompt_capabilities.get().image { + return Task::ready(Err(anyhow!("This model does not support images yet"))); + } + let task = self + .project + .update(cx, |project, cx| project.open_image(project_path, cx)); + return cx.spawn(async move |_, cx| { + let image = task.await?; + let image = image.update(cx, |image, _| image.image.clone())?; + let format = image.format; + let image = cx + .update(|cx| LanguageModelImage::from_image(image, cx))? 
+ .await; + if let Some(image) = image { + Ok(Mention::Image(MentionImage { + data: image.source, + format, + })) + } else { + Err(anyhow!("Failed to convert image")) + } + }); + } + + let buffer = self + .project + .update(cx, |project, cx| project.open_buffer(project_path, cx)); + cx.spawn(async move |_, cx| { + let buffer = buffer.await?; + let mention = buffer.update(cx, |buffer, cx| Mention::Text { + content: buffer.text(), + tracked_buffers: vec![cx.entity()], + })?; + anyhow::Ok(mention) + }) + } + + fn confirm_mention_for_directory( + &mut self, + abs_path: PathBuf, + cx: &mut Context, + ) -> Task> { + fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec<(Arc, PathBuf)> { + let mut files = Vec::new(); + + for entry in worktree.child_entries(path) { + if entry.is_dir() { + files.extend(collect_files_in_path(worktree, &entry.path)); + } else if entry.is_file() { + files.push((entry.path.clone(), worktree.full_path(&entry.path))); + } + } + + files + } + + let Some(project_path) = self + .project + .read(cx) + .project_path_for_absolute_path(&abs_path, cx) + else { + return Task::ready(Err(anyhow!("project path not found"))); + }; + let Some(entry) = self.project.read(cx).entry_for_path(&project_path, cx) else { + return Task::ready(Err(anyhow!("project entry not found"))); + }; + let directory_path = entry.path.clone(); + let worktree_id = project_path.worktree_id; + let Some(worktree) = self.project.read(cx).worktree_for_id(worktree_id, cx) else { + return Task::ready(Err(anyhow!("worktree not found"))); + }; + let project = self.project.clone(); + cx.spawn(async move |_, cx| { + let file_paths = worktree.read_with(cx, |worktree, _cx| { + collect_files_in_path(worktree, &directory_path) + })?; + let descendants_future = cx.update(|cx| { + join_all(file_paths.into_iter().map(|(worktree_path, full_path)| { + let rel_path = worktree_path + .strip_prefix(&directory_path) + .log_err() + .map_or_else(|| worktree_path.clone(), |rel_path| 
rel_path.into()); + + let open_task = project.update(cx, |project, cx| { + project.buffer_store().update(cx, |buffer_store, cx| { + let project_path = ProjectPath { + worktree_id, + path: worktree_path, + }; + buffer_store.open_buffer(project_path, cx) + }) + }); + + // TODO: report load errors instead of just logging + let rope_task = cx.spawn(async move |cx| { + let buffer = open_task.await.log_err()?; + let rope = buffer + .read_with(cx, |buffer, _cx| buffer.as_rope().clone()) + .log_err()?; + Some((rope, buffer)) + }); + + cx.background_spawn(async move { + let (rope, buffer) = rope_task.await?; + Some((rel_path, full_path, rope.to_string(), buffer)) + }) + })) + })?; + + let contents = cx + .background_spawn(async move { + let (contents, tracked_buffers) = descendants_future + .await + .into_iter() + .flatten() + .map(|(rel_path, full_path, rope, buffer)| { + ((rel_path, full_path, rope), buffer) + }) + .unzip(); + Mention::Text { + content: render_directory_contents(contents), + tracked_buffers, + } + }) + .await; + anyhow::Ok(contents) + }) + } + + fn confirm_mention_for_fetch( + &mut self, + url: url::Url, + cx: &mut Context, + ) -> Task> { + let http_client = match self + .workspace + .update(cx, |workspace, _| workspace.client().http_client()) + { + Ok(http_client) => http_client, + Err(e) => return Task::ready(Err(e)), + }; + cx.background_executor().spawn(async move { + let content = fetch_url_content(http_client, url.to_string()).await?; + Ok(Mention::Text { + content, + tracked_buffers: Vec::new(), + }) + }) + } + + fn confirm_mention_for_symbol( + &mut self, + abs_path: PathBuf, + line_range: RangeInclusive, + cx: &mut Context, + ) -> Task> { + let Some(project_path) = self + .project + .read(cx) + .project_path_for_absolute_path(&abs_path, cx) + else { + return Task::ready(Err(anyhow!("project path not found"))); + }; + let buffer = self + .project + .update(cx, |project, cx| project.open_buffer(project_path, cx)); + cx.spawn(async move |_, cx| { + 
let buffer = buffer.await?; + let mention = buffer.update(cx, |buffer, cx| { + let start = Point::new(*line_range.start(), 0).min(buffer.max_point()); + let end = Point::new(*line_range.end() + 1, 0).min(buffer.max_point()); + let content = buffer.text_for_range(start..end).collect(); + Mention::Text { + content, + tracked_buffers: vec![cx.entity()], + } + })?; + anyhow::Ok(mention) + }) + } + + fn confirm_mention_for_rule( + &mut self, + id: PromptId, + cx: &mut Context, + ) -> Task> { + let Some(prompt_store) = self.prompt_store.clone() else { + return Task::ready(Err(anyhow!("missing prompt store"))); + }; + let prompt = prompt_store.read(cx).load(id, cx); + cx.spawn(async move |_, _| { + let prompt = prompt.await?; + Ok(Mention::Text { + content: prompt, + tracked_buffers: Vec::new(), + }) + }) + } + + pub fn confirm_mention_for_selection( + &mut self, + source_range: Range, + selections: Vec<(Entity, Range, Range)>, + window: &mut Window, + cx: &mut Context, + ) { + let snapshot = self.editor.read(cx).buffer().read(cx).snapshot(cx); + let Some((&excerpt_id, _, _)) = snapshot.as_singleton() else { + return; + }; + let Some(start) = snapshot.anchor_in_excerpt(excerpt_id, source_range.start) else { + return; + }; + + let offset = start.to_offset(&snapshot); + + for (buffer, selection_range, range_to_fold) in selections { + let range = snapshot.anchor_after(offset + range_to_fold.start) + ..snapshot.anchor_after(offset + range_to_fold.end); + + let abs_path = buffer + .read(cx) + .project_path(cx) + .and_then(|project_path| self.project.read(cx).absolute_path(&project_path, cx)); + let snapshot = buffer.read(cx).snapshot(); + + let text = snapshot + .text_for_range(selection_range.clone()) + .collect::(); + let point_range = selection_range.to_point(&snapshot); + let line_range = point_range.start.row..=point_range.end.row; + + let uri = MentionUri::Selection { + abs_path: abs_path.clone(), + line_range: line_range.clone(), + }; + let crease = 
crate::context_picker::crease_for_mention( + selection_name(abs_path.as_deref(), &line_range).into(), + uri.icon_path(cx), + range, + self.editor.downgrade(), + ); + + let crease_id = self.editor.update(cx, |editor, cx| { + let crease_ids = editor.insert_creases(vec![crease.clone()], cx); + editor.fold_creases(vec![crease], false, window, cx); + crease_ids.first().copied().unwrap() + }); + + self.mention_set.mentions.insert( + crease_id, + ( + uri, + Task::ready(Ok(Mention::Text { + content: text, + tracked_buffers: vec![buffer], + })) + .shared(), + ), + ); + } + } + + fn confirm_mention_for_thread( + &mut self, + id: acp::SessionId, + cx: &mut Context, + ) -> Task> { + let server = Rc::new(agent2::NativeAgentServer::new( + self.project.read(cx).fs().clone(), + self.history_store.clone(), + )); + let delegate = AgentServerDelegate::new( + self.project.read(cx).agent_server_store().clone(), + self.project.clone(), + None, + None, + ); + let connection = server.connect(None, delegate, cx); + cx.spawn(async move |_, cx| { + let (agent, _) = connection.await?; + let agent = agent.downcast::().unwrap(); + let summary = agent + .0 + .update(cx, |agent, cx| agent.thread_summary(id, cx))? 
+ .await?; + anyhow::Ok(Mention::Text { + content: summary.to_string(), + tracked_buffers: Vec::new(), + }) + }) + } + + fn confirm_mention_for_text_thread( + &mut self, + path: PathBuf, + cx: &mut Context, + ) -> Task> { + let context = self.history_store.update(cx, |text_thread_store, cx| { + text_thread_store.load_text_thread(path.as_path().into(), cx) + }); + cx.spawn(async move |_, cx| { + let context = context.await?; + let xml = context.update(cx, |context, cx| context.to_xml(cx))?; + Ok(Mention::Text { + content: xml, + tracked_buffers: Vec::new(), + }) + }) + } + + fn validate_slash_commands( + text: &str, + available_commands: &[acp::AvailableCommand], + agent_name: &str, + ) -> Result<()> { + if let Some(parsed_command) = SlashCommandCompletion::try_parse(text, 0) { + if let Some(command_name) = parsed_command.command { + // Check if this command is in the list of available commands from the server + let is_supported = available_commands + .iter() + .any(|cmd| cmd.name == command_name); + + if !is_supported { + return Err(anyhow!( + "The /{} command is not supported by {}.\n\nAvailable commands: {}", + command_name, + agent_name, + if available_commands.is_empty() { + "none".to_string() + } else { + available_commands + .iter() + .map(|cmd| format!("/{}", cmd.name)) + .collect::>() + .join(", ") + } + )); + } + } + } + Ok(()) + } + + pub fn contents( + &self, + cx: &mut Context, + ) -> Task, Vec>)>> { + // Check for unsupported slash commands before spawning async task + let text = self.editor.read(cx).text(cx); + let available_commands = self.available_commands.borrow().clone(); + if let Err(err) = + Self::validate_slash_commands(&text, &available_commands, &self.agent_name) + { + return Task::ready(Err(err)); + } + + let contents = self + .mention_set + .contents(&self.prompt_capabilities.get(), cx); + let editor = self.editor.clone(); + + cx.spawn(async move |_, cx| { + let contents = contents.await?; + let mut all_tracked_buffers = Vec::new(); + + 
let result = editor.update(cx, |editor, cx| { + let mut ix = 0; + let mut chunks: Vec = Vec::new(); + let text = editor.text(cx); + editor.display_map.update(cx, |map, cx| { + let snapshot = map.snapshot(cx); + for (crease_id, crease) in snapshot.crease_snapshot.creases() { + let Some((uri, mention)) = contents.get(&crease_id) else { + continue; + }; + + let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot); + if crease_range.start > ix { + //todo(): Custom slash command ContentBlock? + // let chunk = if prevent_slash_commands + // && ix == 0 + // && parse_slash_command(&text[ix..]).is_some() + // { + // format!(" {}", &text[ix..crease_range.start]).into() + // } else { + // text[ix..crease_range.start].into() + // }; + let chunk = text[ix..crease_range.start].into(); + chunks.push(chunk); + } + let chunk = match mention { + Mention::Text { + content, + tracked_buffers, + } => { + all_tracked_buffers.extend(tracked_buffers.iter().cloned()); + acp::ContentBlock::Resource(acp::EmbeddedResource { + annotations: None, + resource: acp::EmbeddedResourceResource::TextResourceContents( + acp::TextResourceContents { + mime_type: None, + text: content.clone(), + uri: uri.to_uri().to_string(), + }, + ), + }) + } + Mention::Image(mention_image) => { + let uri = match uri { + MentionUri::File { .. 
} => Some(uri.to_uri().to_string()), + MentionUri::PastedImage => None, + other => { + debug_panic!( + "unexpected mention uri for image: {:?}", + other + ); + None + } + }; + acp::ContentBlock::Image(acp::ImageContent { + annotations: None, + data: mention_image.data.to_string(), + mime_type: mention_image.format.mime_type().into(), + uri, + }) + } + Mention::UriOnly => { + acp::ContentBlock::ResourceLink(acp::ResourceLink { + name: uri.name(), + uri: uri.to_uri().to_string(), + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + }) + } + }; + chunks.push(chunk); + ix = crease_range.end; + } + + if ix < text.len() { + //todo(): Custom slash command ContentBlock? + // let last_chunk = if prevent_slash_commands + // && ix == 0 + // && parse_slash_command(&text[ix..]).is_some() + // { + // format!(" {}", text[ix..].trim_end()) + // } else { + // text[ix..].trim_end().to_owned() + // }; + let last_chunk = text[ix..].trim_end().to_owned(); + if !last_chunk.is_empty() { + chunks.push(last_chunk.into()); + } + } + }); + Ok((chunks, all_tracked_buffers)) + })?; + result + }) + } + + pub fn clear(&mut self, window: &mut Window, cx: &mut Context) { + self.editor.update(cx, |editor, cx| { + editor.clear(window, cx); + editor.remove_creases( + self.mention_set + .mentions + .drain() + .map(|(crease_id, _)| crease_id), + cx, + ) + }); + } + + fn send(&mut self, _: &Chat, _: &mut Window, cx: &mut Context) { + if self.is_empty(cx) { + return; + } + cx.emit(MessageEditorEvent::Send) + } + + fn cancel(&mut self, _: &editor::actions::Cancel, _: &mut Window, cx: &mut Context) { + cx.emit(MessageEditorEvent::Cancel) + } + + fn paste(&mut self, _: &Paste, window: &mut Window, cx: &mut Context) { + if !self.prompt_capabilities.get().image { + return; + } + + let images = cx + .read_from_clipboard() + .map(|item| { + item.into_entries() + .filter_map(|entry| { + if let ClipboardEntry::Image(image) = entry { + Some(image) + } else { + None + } + }) + 
.collect::>() + }) + .unwrap_or_default(); + + if images.is_empty() { + return; + } + cx.stop_propagation(); + + let replacement_text = MentionUri::PastedImage.as_link().to_string(); + for image in images { + let (excerpt_id, text_anchor, multibuffer_anchor) = + self.editor.update(cx, |message_editor, cx| { + let snapshot = message_editor.snapshot(window, cx); + let (excerpt_id, _, buffer_snapshot) = + snapshot.buffer_snapshot.as_singleton().unwrap(); + + let text_anchor = buffer_snapshot.anchor_before(buffer_snapshot.len()); + let multibuffer_anchor = snapshot + .buffer_snapshot + .anchor_in_excerpt(*excerpt_id, text_anchor); + message_editor.edit( + [( + multi_buffer::Anchor::max()..multi_buffer::Anchor::max(), + format!("{replacement_text} "), + )], + cx, + ); + (*excerpt_id, text_anchor, multibuffer_anchor) + }); + + let content_len = replacement_text.len(); + let Some(start_anchor) = multibuffer_anchor else { + continue; + }; + let end_anchor = self.editor.update(cx, |editor, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + snapshot.anchor_before(start_anchor.to_offset(&snapshot) + content_len) + }); + let image = Arc::new(image); + let Some((crease_id, tx)) = insert_crease_for_mention( + excerpt_id, + text_anchor, + content_len, + MentionUri::PastedImage.name().into(), + IconName::Image.path().into(), + Some(Task::ready(Ok(image.clone())).shared()), + self.editor.clone(), + window, + cx, + ) else { + continue; + }; + let task = cx + .spawn_in(window, { + async move |_, cx| { + let format = image.format; + let image = cx + .update(|_, cx| LanguageModelImage::from_image(image, cx)) + .map_err(|e| e.to_string())? 
+ .await; + drop(tx); + if let Some(image) = image { + Ok(Mention::Image(MentionImage { + data: image.source, + format, + })) + } else { + Err("Failed to convert image".into()) + } + } + }) + .shared(); + + self.mention_set + .mentions + .insert(crease_id, (MentionUri::PastedImage, task.clone())); + + cx.spawn_in(window, async move |this, cx| { + if task.await.notify_async_err(cx).is_none() { + this.update(cx, |this, cx| { + this.editor.update(cx, |editor, cx| { + editor.edit([(start_anchor..end_anchor, "")], cx); + }); + this.mention_set.mentions.remove(&crease_id); + }) + .ok(); + } + }) + .detach(); + } + } + + pub fn insert_dragged_files( + &mut self, + paths: Vec, + added_worktrees: Vec>, + window: &mut Window, + cx: &mut Context, + ) { + let buffer = self.editor.read(cx).buffer().clone(); + let Some(buffer) = buffer.read(cx).as_singleton() else { + return; + }; + let mut tasks = Vec::new(); + for path in paths { + let Some(entry) = self.project.read(cx).entry_for_path(&path, cx) else { + continue; + }; + let Some(abs_path) = self.project.read(cx).absolute_path(&path, cx) else { + continue; + }; + let path_prefix = abs_path + .file_name() + .unwrap_or(path.path.as_os_str()) + .display() + .to_string(); + let (file_name, _) = + crate::context_picker::file_context_picker::extract_file_name_and_directory( + &path.path, + &path_prefix, + ); + + let uri = if entry.is_dir() { + MentionUri::Directory { abs_path } + } else { + MentionUri::File { abs_path } + }; + + let new_text = format!("{} ", uri.as_link()); + let content_len = new_text.len() - 1; + + let anchor = buffer.update(cx, |buffer, _cx| buffer.anchor_before(buffer.len())); + + self.editor.update(cx, |message_editor, cx| { + message_editor.edit( + [( + multi_buffer::Anchor::max()..multi_buffer::Anchor::max(), + new_text, + )], + cx, + ); + }); + tasks.push(self.confirm_mention_completion( + file_name, + anchor, + content_len, + uri, + window, + cx, + )); + } + cx.spawn(async move |_, _| { + 
join_all(tasks).await; + drop(added_worktrees); + }) + .detach(); + } + + pub fn insert_selections(&mut self, window: &mut Window, cx: &mut Context) { + let buffer = self.editor.read(cx).buffer().clone(); + let Some(buffer) = buffer.read(cx).as_singleton() else { + return; + }; + let anchor = buffer.update(cx, |buffer, _cx| buffer.anchor_before(buffer.len())); + let Some(workspace) = self.workspace.upgrade() else { + return; + }; + let Some(completion) = ContextPickerCompletionProvider::completion_for_action( + ContextPickerAction::AddSelections, + anchor..anchor, + cx.weak_entity(), + &workspace, + cx, + ) else { + return; + }; + self.editor.update(cx, |message_editor, cx| { + message_editor.edit( + [( + multi_buffer::Anchor::max()..multi_buffer::Anchor::max(), + completion.new_text, + )], + cx, + ); + }); + if let Some(confirm) = completion.confirm { + confirm(CompletionIntent::Complete, window, cx); + } + } + + pub fn set_read_only(&mut self, read_only: bool, cx: &mut Context) { + self.editor.update(cx, |message_editor, cx| { + message_editor.set_read_only(read_only); + cx.notify() + }) + } + + pub fn set_mode(&mut self, mode: EditorMode, cx: &mut Context) { + self.editor.update(cx, |editor, cx| { + editor.set_mode(mode); + cx.notify() + }); + } + + pub fn set_message( + &mut self, + message: Vec, + window: &mut Window, + cx: &mut Context, + ) { + self.clear(window, cx); + + let mut text = String::new(); + let mut mentions = Vec::new(); + + for chunk in message { + match chunk { + acp::ContentBlock::Text(text_content) => { + text.push_str(&text_content.text); + } + acp::ContentBlock::Resource(acp::EmbeddedResource { + resource: acp::EmbeddedResourceResource::TextResourceContents(resource), + .. 
+ }) => { + let Some(mention_uri) = MentionUri::parse(&resource.uri).log_err() else { + continue; + }; + let start = text.len(); + write!(&mut text, "{}", mention_uri.as_link()).ok(); + let end = text.len(); + mentions.push(( + start..end, + mention_uri, + Mention::Text { + content: resource.text, + tracked_buffers: Vec::new(), + }, + )); + } + acp::ContentBlock::ResourceLink(resource) => { + if let Some(mention_uri) = MentionUri::parse(&resource.uri).log_err() { + let start = text.len(); + write!(&mut text, "{}", mention_uri.as_link()).ok(); + let end = text.len(); + mentions.push((start..end, mention_uri, Mention::UriOnly)); + } + } + acp::ContentBlock::Image(acp::ImageContent { + uri, + data, + mime_type, + annotations: _, + }) => { + let mention_uri = if let Some(uri) = uri { + MentionUri::parse(&uri) + } else { + Ok(MentionUri::PastedImage) + }; + let Some(mention_uri) = mention_uri.log_err() else { + continue; + }; + let Some(format) = ImageFormat::from_mime_type(&mime_type) else { + log::error!("failed to parse MIME type for image: {mime_type:?}"); + continue; + }; + let start = text.len(); + write!(&mut text, "{}", mention_uri.as_link()).ok(); + let end = text.len(); + mentions.push(( + start..end, + mention_uri, + Mention::Image(MentionImage { + data: data.into(), + format, + }), + )); + } + acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => {} + } + } + + let snapshot = self.editor.update(cx, |editor, cx| { + editor.set_text(text, window, cx); + editor.buffer().read(cx).snapshot(cx) + }); + + for (range, mention_uri, mention) in mentions { + let anchor = snapshot.anchor_before(range.start); + let Some((crease_id, tx)) = insert_crease_for_mention( + anchor.excerpt_id, + anchor.text_anchor, + range.end - range.start, + mention_uri.name().into(), + mention_uri.icon_path(cx), + None, + self.editor.clone(), + window, + cx, + ) else { + continue; + }; + drop(tx); + + self.mention_set.mentions.insert( + crease_id, + (mention_uri.clone(), 
Task::ready(Ok(mention)).shared()), + ); + } + cx.notify(); + } + + pub fn text(&self, cx: &App) -> String { + self.editor.read(cx).text(cx) + } + + #[cfg(test)] + pub fn set_text(&mut self, text: &str, window: &mut Window, cx: &mut Context) { + self.editor.update(cx, |editor, cx| { + editor.set_text(text, window, cx); + }); + } +} + +fn render_directory_contents(entries: Vec<(Arc, PathBuf, String)>) -> String { + let mut output = String::new(); + for (_relative_path, full_path, content) in entries { + let fence = codeblock_fence_for_path(Some(&full_path), None); + write!(output, "\n{fence}\n{content}\n```").unwrap(); + } + output +} + +impl Focusable for MessageEditor { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.editor.focus_handle(cx) + } +} + +impl Render for MessageEditor { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + div() + .key_context("MessageEditor") + .on_action(cx.listener(Self::send)) + .on_action(cx.listener(Self::cancel)) + .capture_action(cx.listener(Self::paste)) + .flex_1() + .child({ + let settings = ThemeSettings::get_global(cx); + let font_size = TextSize::Small + .rems(cx) + .to_pixels(settings.agent_font_size(cx)); + let line_height = settings.buffer_line_height.value() * font_size; + + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.buffer_font.family.clone(), + font_fallbacks: settings.buffer_font.fallbacks.clone(), + font_features: settings.buffer_font.features.clone(), + font_size: font_size.into(), + line_height: line_height.into(), + ..Default::default() + }; + + EditorElement::new( + &self.editor, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + syntax: cx.theme().syntax().clone(), + inlay_hints_style: editor::make_inlay_hints_style(cx), + ..Default::default() + }, + ) + }) + } +} + +pub(crate) fn insert_crease_for_mention( + excerpt_id: ExcerptId, + 
anchor: text::Anchor, + content_len: usize, + crease_label: SharedString, + crease_icon: SharedString, + // abs_path: Option>, + image: Option, String>>>>, + editor: Entity, + window: &mut Window, + cx: &mut App, +) -> Option<(CreaseId, postage::barrier::Sender)> { + let (tx, rx) = postage::barrier::channel(); + + let crease_id = editor.update(cx, |editor, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + + let start = snapshot.anchor_in_excerpt(excerpt_id, anchor)?; + + let start = start.bias_right(&snapshot); + let end = snapshot.anchor_before(start.to_offset(&snapshot) + content_len); + + let placeholder = FoldPlaceholder { + render: render_mention_fold_button( + crease_label, + crease_icon, + start..end, + rx, + image, + cx.weak_entity(), + cx, + ), + merge_adjacent: false, + ..Default::default() + }; + + let crease = Crease::Inline { + range: start..end, + placeholder, + render_toggle: None, + render_trailer: None, + metadata: None, + }; + + let ids = editor.insert_creases(vec![crease.clone()], cx); + editor.fold_creases(vec![crease], false, window, cx); + + Some(ids[0]) + })?; + + Some((crease_id, tx)) +} + +fn render_mention_fold_button( + label: SharedString, + icon: SharedString, + range: Range, + mut loading_finished: postage::barrier::Receiver, + image_task: Option, String>>>>, + editor: WeakEntity, + cx: &mut App, +) -> Arc, &mut App) -> AnyElement> { + let loading = cx.new(|cx| { + let loading = cx.spawn(async move |this, cx| { + loading_finished.recv().await; + this.update(cx, |this: &mut LoadingContext, cx| { + this.loading = None; + cx.notify(); + }) + .ok(); + }); + LoadingContext { + id: cx.entity_id(), + label, + icon, + range, + editor, + loading: Some(loading), + image: image_task.clone(), + } + }); + Arc::new(move |_fold_id, _fold_range, _cx| loading.clone().into_any_element()) +} + +struct LoadingContext { + id: EntityId, + label: SharedString, + icon: SharedString, + range: Range, + editor: WeakEntity, + loading: Option>, + 
image: Option, String>>>>, +} + +impl Render for LoadingContext { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let is_in_text_selection = self + .editor + .update(cx, |editor, cx| editor.is_range_selected(&self.range, cx)) + .unwrap_or_default(); + ButtonLike::new(("loading-context", self.id)) + .style(ButtonStyle::Filled) + .selected_style(ButtonStyle::Tinted(TintColor::Accent)) + .toggle_state(is_in_text_selection) + .when_some(self.image.clone(), |el, image_task| { + el.hoverable_tooltip(move |_, cx| { + let image = image_task.peek().cloned().transpose().ok().flatten(); + let image_task = image_task.clone(); + cx.new::(|cx| ImageHover { + image, + _task: cx.spawn(async move |this, cx| { + if let Ok(image) = image_task.clone().await { + this.update(cx, |this, cx| { + if this.image.replace(image).is_none() { + cx.notify(); + } + }) + .ok(); + } + }), + }) + .into() + }) + }) + .child( + h_flex() + .gap_1() + .child( + Icon::from_path(self.icon.clone()) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child( + Label::new(self.label.clone()) + .size(LabelSize::Small) + .buffer_font(cx) + .single_line(), + ) + .map(|el| { + if self.loading.is_some() { + el.with_animation( + "loading-context-crease", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.4, 0.8)), + |label, delta| label.opacity(delta), + ) + .into_any() + } else { + el.into_any() + } + }), + ) + } +} + +struct ImageHover { + image: Option>, + _task: Task<()>, +} + +impl Render for ImageHover { + fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { + if let Some(image) = self.image.clone() { + gpui::img(image).max_w_96().max_h_96().into_any_element() + } else { + gpui::Empty.into_any_element() + } + } +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum Mention { + Text { + content: String, + tracked_buffers: Vec>, + }, + Image(MentionImage), + UriOnly, +} + +#[derive(Clone, Debug, Eq, 
PartialEq)] +pub struct MentionImage { + pub data: SharedString, + pub format: ImageFormat, +} + +#[derive(Default)] +pub struct MentionSet { + mentions: HashMap>>)>, +} + +impl MentionSet { + fn contents( + &self, + prompt_capabilities: &acp::PromptCapabilities, + cx: &mut App, + ) -> Task>> { + if !prompt_capabilities.embedded_context { + let mentions = self + .mentions + .iter() + .map(|(crease_id, (uri, _))| (*crease_id, (uri.clone(), Mention::UriOnly))) + .collect(); + + return Task::ready(Ok(mentions)); + } + + let mentions = self.mentions.clone(); + cx.spawn(async move |_cx| { + let mut contents = HashMap::default(); + for (crease_id, (mention_uri, task)) in mentions { + contents.insert( + crease_id, + (mention_uri, task.await.map_err(|e| anyhow!("{e}"))?), + ); + } + Ok(contents) + }) + } + + fn remove_invalid(&mut self, snapshot: EditorSnapshot) { + for (crease_id, crease) in snapshot.crease_snapshot.creases() { + if !crease.range().start.is_valid(&snapshot.buffer_snapshot) { + self.mentions.remove(&crease_id); + } + } + } +} + +pub struct MessageEditorAddon {} + +impl MessageEditorAddon { + pub fn new() -> Self { + Self {} + } +} + +impl Addon for MessageEditorAddon { + fn to_any(&self) -> &dyn std::any::Any { + self + } + + fn to_any_mut(&mut self) -> Option<&mut dyn std::any::Any> { + Some(self) + } + + fn extend_key_context(&self, key_context: &mut KeyContext, cx: &App) { + let settings = agent_settings::AgentSettings::get_global(cx); + if settings.use_modifier_to_send { + key_context.add("use_modifier_to_send"); + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + cell::{Cell, RefCell}, + ops::Range, + path::Path, + rc::Rc, + sync::Arc, + }; + + use acp_thread::MentionUri; + use agent_client_protocol as acp; + use agent2::HistoryStore; + use assistant_context::ContextStore; + use editor::{AnchorRangeExt as _, Editor, EditorMode}; + use fs::FakeFs; + use futures::StreamExt as _; + use gpui::{ + AppContext, Entity, EventEmitter, FocusHandle, 
Focusable, TestAppContext, VisualTestContext, + }; + use lsp::{CompletionContext, CompletionTriggerKind}; + use project::{CompletionIntent, Project, ProjectPath}; + use serde_json::json; + use text::Point; + use ui::{App, Context, IntoElement, Render, SharedString, Window}; + use util::{path, uri}; + use workspace::{AppState, Item, Workspace}; + + use crate::acp::{ + message_editor::{Mention, MessageEditor}, + thread_view::tests::init_test, + }; + + #[gpui::test] + async fn test_at_mention_removal(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/project", json!({"file": ""})).await; + let project = Project::test(fs, [Path::new(path!("/project"))], cx).await; + + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + + let message_editor = cx.update(|window, cx| { + cx.new(|cx| { + MessageEditor::new( + workspace.downgrade(), + project.clone(), + history_store.clone(), + None, + Default::default(), + Default::default(), + "Test Agent".into(), + "Test", + EditorMode::AutoHeight { + min_lines: 1, + max_lines: None, + }, + window, + cx, + ) + }) + }); + let editor = message_editor.update(cx, |message_editor, _| message_editor.editor.clone()); + + cx.run_until_parked(); + + let excerpt_id = editor.update(cx, |editor, cx| { + editor + .buffer() + .read(cx) + .excerpt_ids() + .into_iter() + .next() + .unwrap() + }); + let completions = editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello @file ", window, cx); + let buffer = editor.buffer().read(cx).as_singleton().unwrap(); + let completion_provider = editor.completion_provider().unwrap(); + completion_provider.completions( + excerpt_id, + &buffer, + text::Anchor::MAX, + CompletionContext { + trigger_kind: CompletionTriggerKind::TRIGGER_CHARACTER, + 
trigger_character: Some("@".into()), + }, + window, + cx, + ) + }); + let [_, completion]: [_; 2] = completions + .await + .unwrap() + .into_iter() + .flat_map(|response| response.completions) + .collect::>() + .try_into() + .unwrap(); + + editor.update_in(cx, |editor, window, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + let start = snapshot + .anchor_in_excerpt(excerpt_id, completion.replace_range.start) + .unwrap(); + let end = snapshot + .anchor_in_excerpt(excerpt_id, completion.replace_range.end) + .unwrap(); + editor.edit([(start..end, completion.new_text)], cx); + (completion.confirm.unwrap())(CompletionIntent::Complete, window, cx); + }); + + cx.run_until_parked(); + + // Backspace over the inserted crease (and the following space). + editor.update_in(cx, |editor, window, cx| { + editor.backspace(&Default::default(), window, cx); + editor.backspace(&Default::default(), window, cx); + }); + + let (content, _) = message_editor + .update(cx, |message_editor, cx| message_editor.contents(cx)) + .await + .unwrap(); + + // We don't send a resource link for the deleted crease. + pretty_assertions::assert_matches!(content.as_slice(), [acp::ContentBlock::Text { .. 
}]); + } + + #[gpui::test] + async fn test_slash_command_validation(cx: &mut gpui::TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/test", + json!({ + ".zed": { + "tasks.json": r#"[{"label": "test", "command": "echo"}]"# + }, + "src": { + "main.rs": "fn main() {}", + }, + }), + ) + .await; + + let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await; + let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let prompt_capabilities = Rc::new(Cell::new(acp::PromptCapabilities::default())); + // Start with no available commands - simulating Claude which doesn't support slash commands + let available_commands = Rc::new(RefCell::new(vec![])); + + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let workspace_handle = workspace.downgrade(); + let message_editor = workspace.update_in(cx, |_, window, cx| { + cx.new(|cx| { + MessageEditor::new( + workspace_handle.clone(), + project.clone(), + history_store.clone(), + None, + prompt_capabilities.clone(), + available_commands.clone(), + "Claude Code".into(), + "Test", + EditorMode::AutoHeight { + min_lines: 1, + max_lines: None, + }, + window, + cx, + ) + }) + }); + let editor = message_editor.update(cx, |message_editor, _| message_editor.editor.clone()); + + // Test that slash commands fail when no available_commands are set (empty list means no commands supported) + editor.update_in(cx, |editor, window, cx| { + editor.set_text("/file test.txt", window, cx); + }); + + let contents_result = message_editor + .update(cx, |message_editor, cx| message_editor.contents(cx)) + .await; + + // Should fail because available_commands is empty (no commands supported) + assert!(contents_result.is_err()); + let error_message = contents_result.unwrap_err().to_string(); + assert!(error_message.contains("not supported by Claude 
Code")); + assert!(error_message.contains("Available commands: none")); + + // Now simulate Claude providing its list of available commands (which doesn't include file) + available_commands.replace(vec![acp::AvailableCommand { + name: "help".to_string(), + description: "Get help".to_string(), + input: None, + }]); + + // Test that unsupported slash commands trigger an error when we have a list of available commands + editor.update_in(cx, |editor, window, cx| { + editor.set_text("/file test.txt", window, cx); + }); + + let contents_result = message_editor + .update(cx, |message_editor, cx| message_editor.contents(cx)) + .await; + + assert!(contents_result.is_err()); + let error_message = contents_result.unwrap_err().to_string(); + assert!(error_message.contains("not supported by Claude Code")); + assert!(error_message.contains("/file")); + assert!(error_message.contains("Available commands: /help")); + + // Test that supported commands work fine + editor.update_in(cx, |editor, window, cx| { + editor.set_text("/help", window, cx); + }); + + let contents_result = message_editor + .update(cx, |message_editor, cx| message_editor.contents(cx)) + .await; + + // Should succeed because /help is in available_commands + assert!(contents_result.is_ok()); + + // Test that regular text works fine + editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello Claude!", window, cx); + }); + + let (content, _) = message_editor + .update(cx, |message_editor, cx| message_editor.contents(cx)) + .await + .unwrap(); + + assert_eq!(content.len(), 1); + if let acp::ContentBlock::Text(text) = &content[0] { + assert_eq!(text.text, "Hello Claude!"); + } else { + panic!("Expected ContentBlock::Text"); + } + + // Test that @ mentions still work + editor.update_in(cx, |editor, window, cx| { + editor.set_text("Check this @", window, cx); + }); + + // The @ mention functionality should not be affected + let (content, _) = message_editor + .update(cx, |message_editor, cx| 
message_editor.contents(cx)) + .await + .unwrap(); + + assert_eq!(content.len(), 1); + if let acp::ContentBlock::Text(text) = &content[0] { + assert_eq!(text.text, "Check this @"); + } else { + panic!("Expected ContentBlock::Text"); + } + } + + struct MessageEditorItem(Entity); + + impl Item for MessageEditorItem { + type Event = (); + + fn include_in_nav_history() -> bool { + false + } + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { + "Test".into() + } + } + + impl EventEmitter<()> for MessageEditorItem {} + + impl Focusable for MessageEditorItem { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.0.read(cx).focus_handle(cx) + } + } + + impl Render for MessageEditorItem { + fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { + self.0.clone().into_any_element() + } + } + + #[gpui::test] + async fn test_completion_provider_commands(cx: &mut TestAppContext) { + init_test(cx); + + let app_state = cx.update(AppState::test); + + cx.update(|cx| { + language::init(cx); + editor::init(cx); + workspace::init(app_state.clone(), cx); + Project::init_settings(cx); + }); + + let project = Project::test(app_state.fs.clone(), [path!("/dir").as_ref()], cx).await; + let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let workspace = window.root(cx).unwrap(); + + let mut cx = VisualTestContext::from_window(*window, cx); + + let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let prompt_capabilities = Rc::new(Cell::new(acp::PromptCapabilities::default())); + let available_commands = Rc::new(RefCell::new(vec![ + acp::AvailableCommand { + name: "quick-math".to_string(), + description: "2 + 2 = 4 - 1 = 3".to_string(), + input: None, + }, + acp::AvailableCommand { + name: "say-hello".to_string(), + description: "Say hello to whoever you want".to_string(), + input: 
Some(acp::AvailableCommandInput::Unstructured { + hint: "".to_string(), + }), + }, + ])); + + let editor = workspace.update_in(&mut cx, |workspace, window, cx| { + let workspace_handle = cx.weak_entity(); + let message_editor = cx.new(|cx| { + MessageEditor::new( + workspace_handle, + project.clone(), + history_store.clone(), + None, + prompt_capabilities.clone(), + available_commands.clone(), + "Test Agent".into(), + "Test", + EditorMode::AutoHeight { + max_lines: None, + min_lines: 1, + }, + window, + cx, + ) + }); + workspace.active_pane().update(cx, |pane, cx| { + pane.add_item( + Box::new(cx.new(|_| MessageEditorItem(message_editor.clone()))), + true, + true, + None, + window, + cx, + ); + }); + message_editor.read(cx).focus_handle(cx).focus(window); + message_editor.read(cx).editor().clone() + }); + + cx.simulate_input("/"); + + editor.update_in(&mut cx, |editor, window, cx| { + assert_eq!(editor.text(cx), "/"); + assert!(editor.has_visible_completions_menu()); + + assert_eq!( + current_completion_labels_with_documentation(editor), + &[ + ("quick-math".into(), "2 + 2 = 4 - 1 = 3".into()), + ("say-hello".into(), "Say hello to whoever you want".into()) + ] + ); + editor.set_text("", window, cx); + }); + + cx.simulate_input("/qui"); + + editor.update_in(&mut cx, |editor, window, cx| { + assert_eq!(editor.text(cx), "/qui"); + assert!(editor.has_visible_completions_menu()); + + assert_eq!( + current_completion_labels_with_documentation(editor), + &[("quick-math".into(), "2 + 2 = 4 - 1 = 3".into())] + ); + editor.set_text("", window, cx); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + assert!(editor.has_visible_completions_menu()); + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + cx.run_until_parked(); + + editor.update_in(&mut cx, |editor, window, cx| { + assert_eq!(editor.display_text(cx), "/quick-math "); + assert!(!editor.has_visible_completions_menu()); + editor.set_text("", window, cx); + }); + + 
cx.simulate_input("/say"); + + editor.update_in(&mut cx, |editor, _window, cx| { + assert_eq!(editor.display_text(cx), "/say"); + assert!(editor.has_visible_completions_menu()); + + assert_eq!( + current_completion_labels_with_documentation(editor), + &[("say-hello".into(), "Say hello to whoever you want".into())] + ); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + assert!(editor.has_visible_completions_menu()); + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + cx.run_until_parked(); + + editor.update_in(&mut cx, |editor, _window, cx| { + assert_eq!(editor.text(cx), "/say-hello "); + assert_eq!(editor.display_text(cx), "/say-hello "); + assert!(editor.has_visible_completions_menu()); + + assert_eq!( + current_completion_labels_with_documentation(editor), + &[("say-hello".into(), "Say hello to whoever you want".into())] + ); + }); + + cx.simulate_input("GPT5"); + + editor.update_in(&mut cx, |editor, window, cx| { + assert!(editor.has_visible_completions_menu()); + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + cx.run_until_parked(); + + editor.update_in(&mut cx, |editor, window, cx| { + assert_eq!(editor.text(cx), "/say-hello GPT5"); + assert_eq!(editor.display_text(cx), "/say-hello GPT5"); + assert!(!editor.has_visible_completions_menu()); + + // Delete argument + for _ in 0..4 { + editor.backspace(&editor::actions::Backspace, window, cx); + } + }); + + cx.run_until_parked(); + + editor.update_in(&mut cx, |editor, window, cx| { + assert_eq!(editor.text(cx), "/say-hello "); + // Hint is visible because argument was deleted + assert_eq!(editor.display_text(cx), "/say-hello "); + + // Delete last command letter + editor.backspace(&editor::actions::Backspace, window, cx); + editor.backspace(&editor::actions::Backspace, window, cx); + }); + + cx.run_until_parked(); + + editor.update_in(&mut cx, |editor, _window, cx| { + // Hint goes away once command no 
longer matches an available one + assert_eq!(editor.text(cx), "/say-hell"); + assert_eq!(editor.display_text(cx), "/say-hell"); + assert!(!editor.has_visible_completions_menu()); + }); + } + + #[gpui::test] + async fn test_context_completion_provider_mentions(cx: &mut TestAppContext) { + init_test(cx); + + let app_state = cx.update(AppState::test); + + cx.update(|cx| { + language::init(cx); + editor::init(cx); + workspace::init(app_state.clone(), cx); + Project::init_settings(cx); + }); + + app_state + .fs + .as_fake() + .insert_tree( + path!("/dir"), + json!({ + "editor": "", + "a": { + "one.txt": "1", + "two.txt": "2", + "three.txt": "3", + "four.txt": "4" + }, + "b": { + "five.txt": "5", + "six.txt": "6", + "seven.txt": "7", + "eight.txt": "8", + }, + "x.png": "", + }), + ) + .await; + + let project = Project::test(app_state.fs.clone(), [path!("/dir").as_ref()], cx).await; + let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let workspace = window.root(cx).unwrap(); + + let worktree = project.update(cx, |project, cx| { + let mut worktrees = project.worktrees(cx).collect::>(); + assert_eq!(worktrees.len(), 1); + worktrees.pop().unwrap() + }); + let worktree_id = worktree.read_with(cx, |worktree, _| worktree.id()); + + let mut cx = VisualTestContext::from_window(*window, cx); + + let paths = vec![ + path!("a/one.txt"), + path!("a/two.txt"), + path!("a/three.txt"), + path!("a/four.txt"), + path!("b/five.txt"), + path!("b/six.txt"), + path!("b/seven.txt"), + path!("b/eight.txt"), + ]; + + let mut opened_editors = Vec::new(); + for path in paths { + let buffer = workspace + .update_in(&mut cx, |workspace, window, cx| { + workspace.open_path( + ProjectPath { + worktree_id, + path: Path::new(path).into(), + }, + None, + false, + window, + cx, + ) + }) + .await + .unwrap(); + opened_editors.push(buffer); + } + + let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| 
HistoryStore::new(context_store, cx)); + let prompt_capabilities = Rc::new(Cell::new(acp::PromptCapabilities::default())); + + let (message_editor, editor) = workspace.update_in(&mut cx, |workspace, window, cx| { + let workspace_handle = cx.weak_entity(); + let message_editor = cx.new(|cx| { + MessageEditor::new( + workspace_handle, + project.clone(), + history_store.clone(), + None, + prompt_capabilities.clone(), + Default::default(), + "Test Agent".into(), + "Test", + EditorMode::AutoHeight { + max_lines: None, + min_lines: 1, + }, + window, + cx, + ) + }); + workspace.active_pane().update(cx, |pane, cx| { + pane.add_item( + Box::new(cx.new(|_| MessageEditorItem(message_editor.clone()))), + true, + true, + None, + window, + cx, + ); + }); + message_editor.read(cx).focus_handle(cx).focus(window); + let editor = message_editor.read(cx).editor().clone(); + (message_editor, editor) + }); + + cx.simulate_input("Lorem @"); + + editor.update_in(&mut cx, |editor, window, cx| { + assert_eq!(editor.text(cx), "Lorem @"); + assert!(editor.has_visible_completions_menu()); + + assert_eq!( + current_completion_labels(editor), + &[ + "eight.txt dir/b/", + "seven.txt dir/b/", + "six.txt dir/b/", + "five.txt dir/b/", + ] + ); + editor.set_text("", window, cx); + }); + + prompt_capabilities.set(acp::PromptCapabilities { + image: true, + audio: true, + embedded_context: true, + }); + + cx.simulate_input("Lorem "); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem "); + assert!(!editor.has_visible_completions_menu()); + }); + + cx.simulate_input("@"); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem @"); + assert!(editor.has_visible_completions_menu()); + assert_eq!( + current_completion_labels(editor), + &[ + "eight.txt dir/b/", + "seven.txt dir/b/", + "six.txt dir/b/", + "five.txt dir/b/", + "Files & Directories", + "Symbols", + "Threads", + "Fetch" + ] + ); + }); + + // Select and confirm "File" + editor.update_in(&mut cx, 
|editor, window, cx| { + assert!(editor.has_visible_completions_menu()); + editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); + editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); + editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); + editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + cx.run_until_parked(); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem @file "); + assert!(editor.has_visible_completions_menu()); + }); + + cx.simulate_input("one"); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem @file one"); + assert!(editor.has_visible_completions_menu()); + assert_eq!(current_completion_labels(editor), vec!["one.txt dir/a/"]); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + assert!(editor.has_visible_completions_menu()); + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + let url_one = uri!("file:///dir/a/one.txt"); + editor.update(&mut cx, |editor, cx| { + let text = editor.text(cx); + assert_eq!(text, format!("Lorem [@one.txt]({url_one}) ")); + assert!(!editor.has_visible_completions_menu()); + assert_eq!(fold_ranges(editor, cx).len(), 1); + }); + + let all_prompt_capabilities = acp::PromptCapabilities { + image: true, + audio: true, + embedded_context: true, + }; + + let contents = message_editor + .update(&mut cx, |message_editor, cx| { + message_editor + .mention_set() + .contents(&all_prompt_capabilities, cx) + }) + .await + .unwrap() + .into_values() + .collect::>(); + + { + let [(uri, Mention::Text { content, .. 
})] = contents.as_slice() else { + panic!("Unexpected mentions"); + }; + pretty_assertions::assert_eq!(content, "1"); + pretty_assertions::assert_eq!(uri, &url_one.parse::().unwrap()); + } + + let contents = message_editor + .update(&mut cx, |message_editor, cx| { + message_editor + .mention_set() + .contents(&acp::PromptCapabilities::default(), cx) + }) + .await + .unwrap() + .into_values() + .collect::>(); + + { + let [(uri, Mention::UriOnly)] = contents.as_slice() else { + panic!("Unexpected mentions"); + }; + pretty_assertions::assert_eq!(uri, &url_one.parse::().unwrap()); + } + + cx.simulate_input(" "); + + editor.update(&mut cx, |editor, cx| { + let text = editor.text(cx); + assert_eq!(text, format!("Lorem [@one.txt]({url_one}) ")); + assert!(!editor.has_visible_completions_menu()); + assert_eq!(fold_ranges(editor, cx).len(), 1); + }); + + cx.simulate_input("Ipsum "); + + editor.update(&mut cx, |editor, cx| { + let text = editor.text(cx); + assert_eq!(text, format!("Lorem [@one.txt]({url_one}) Ipsum "),); + assert!(!editor.has_visible_completions_menu()); + assert_eq!(fold_ranges(editor, cx).len(), 1); + }); + + cx.simulate_input("@file "); + + editor.update(&mut cx, |editor, cx| { + let text = editor.text(cx); + assert_eq!(text, format!("Lorem [@one.txt]({url_one}) Ipsum @file "),); + assert!(editor.has_visible_completions_menu()); + assert_eq!(fold_ranges(editor, cx).len(), 1); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + cx.run_until_parked(); + + let contents = message_editor + .update(&mut cx, |message_editor, cx| { + message_editor + .mention_set() + .contents(&all_prompt_capabilities, cx) + }) + .await + .unwrap() + .into_values() + .collect::>(); + + let url_eight = uri!("file:///dir/b/eight.txt"); + + { + let [_, (uri, Mention::Text { content, .. 
})] = contents.as_slice() else { + panic!("Unexpected mentions"); + }; + pretty_assertions::assert_eq!(content, "8"); + pretty_assertions::assert_eq!(uri, &url_eight.parse::().unwrap()); + } + + editor.update(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) ") + ); + assert!(!editor.has_visible_completions_menu()); + assert_eq!(fold_ranges(editor, cx).len(), 2); + }); + + let plain_text_language = Arc::new(language::Language::new( + language::LanguageConfig { + name: "Plain Text".into(), + matcher: language::LanguageMatcher { + path_suffixes: vec!["txt".to_string()], + ..Default::default() + }, + ..Default::default() + }, + None, + )); + + // Register the language and fake LSP + let language_registry = project.read_with(&cx, |project, _| project.languages().clone()); + language_registry.add(plain_text_language); + + let mut fake_language_servers = language_registry.register_fake_lsp( + "Plain Text", + language::FakeLspAdapter { + capabilities: lsp::ServerCapabilities { + workspace_symbol_provider: Some(lsp::OneOf::Left(true)), + ..Default::default() + }, + ..Default::default() + }, + ); + + // Open the buffer to trigger LSP initialization + let buffer = project + .update(&mut cx, |project, cx| { + project.open_local_buffer(path!("/dir/a/one.txt"), cx) + }) + .await + .unwrap(); + + // Register the buffer with language servers + let _handle = project.update(&mut cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + }); + + cx.run_until_parked(); + + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.set_request_handler::( + move |_, _| async move { + Ok(Some(lsp::WorkspaceSymbolResponse::Flat(vec![ + #[allow(deprecated)] + lsp::SymbolInformation { + name: "MySymbol".into(), + location: lsp::Location { + uri: lsp::Uri::from_file_path(path!("/dir/a/one.txt")).unwrap(), + range: lsp::Range::new( + lsp::Position::new(0, 
0), + lsp::Position::new(0, 1), + ), + }, + kind: lsp::SymbolKind::CONSTANT, + tags: None, + container_name: None, + deprecated: None, + }, + ]))) + }, + ); + + cx.simulate_input("@symbol "); + + editor.update(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) @symbol ") + ); + assert!(editor.has_visible_completions_menu()); + assert_eq!(current_completion_labels(editor), &["MySymbol"]); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + let contents = message_editor + .update(&mut cx, |message_editor, cx| { + message_editor + .mention_set() + .contents(&all_prompt_capabilities, cx) + }) + .await + .unwrap() + .into_values() + .collect::>(); + + { + let [_, _, (uri, Mention::Text { content, .. })] = contents.as_slice() else { + panic!("Unexpected mentions"); + }; + pretty_assertions::assert_eq!(content, "1"); + pretty_assertions::assert_eq!( + uri, + &format!("{url_one}?symbol=MySymbol#L1:1") + .parse::() + .unwrap() + ); + } + + cx.run_until_parked(); + + editor.read_with(&cx, |editor, cx| { + assert_eq!( + editor.text(cx), + format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) [@MySymbol]({url_one}?symbol=MySymbol#L1:1) ") + ); + }); + + // Try to mention an "image" file that will fail to load + cx.simulate_input("@file x.png"); + + editor.update(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) [@MySymbol]({url_one}?symbol=MySymbol#L1:1) @file x.png") + ); + assert!(editor.has_visible_completions_menu()); + assert_eq!(current_completion_labels(editor), &["x.png dir/"]); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + // Getting the message contents fails + message_editor + 
.update(&mut cx, |message_editor, cx| { + message_editor + .mention_set() + .contents(&all_prompt_capabilities, cx) + }) + .await + .expect_err("Should fail to load x.png"); + + cx.run_until_parked(); + + // Mention was removed + editor.read_with(&cx, |editor, cx| { + assert_eq!( + editor.text(cx), + format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) [@MySymbol]({url_one}?symbol=MySymbol#L1:1) ") + ); + }); + + // Once more + cx.simulate_input("@file x.png"); + + editor.update(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) [@MySymbol]({url_one}?symbol=MySymbol#L1:1) @file x.png") + ); + assert!(editor.has_visible_completions_menu()); + assert_eq!(current_completion_labels(editor), &["x.png dir/"]); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + // This time don't immediately get the contents, just let the confirmed completion settle + cx.run_until_parked(); + + // Mention was removed + editor.read_with(&cx, |editor, cx| { + assert_eq!( + editor.text(cx), + format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) [@MySymbol]({url_one}?symbol=MySymbol#L1:1) ") + ); + }); + + // Now getting the contents succeeds, because the invalid mention was removed + let contents = message_editor + .update(&mut cx, |message_editor, cx| { + message_editor + .mention_set() + .contents(&all_prompt_capabilities, cx) + }) + .await + .unwrap(); + assert_eq!(contents.len(), 3); + } + + fn fold_ranges(editor: &Editor, cx: &mut App) -> Vec> { + let snapshot = editor.buffer().read(cx).snapshot(cx); + editor.display_map.update(cx, |display_map, cx| { + display_map + .snapshot(cx) + .folds_in_range(0..snapshot.len()) + .map(|fold| fold.range.to_point(&snapshot)) + .collect() + }) + } + + fn current_completion_labels(editor: &Editor) -> Vec { + let completions = 
editor.current_completions().expect("Missing completions"); + completions + .into_iter() + .map(|completion| completion.label.text) + .collect::>() + } + + fn current_completion_labels_with_documentation(editor: &Editor) -> Vec<(String, String)> { + let completions = editor.current_completions().expect("Missing completions"); + completions + .into_iter() + .map(|completion| { + ( + completion.label.text, + completion + .documentation + .map(|d| d.text().to_string()) + .unwrap_or_default(), + ) + }) + .collect::>() + } +} diff --git a/crates/agent_ui/src/acp/message_history.rs b/crates/agent_ui/src/acp/message_history.rs deleted file mode 100644 index c8280573a0230ccd15890bba10745ab552b703e6..0000000000000000000000000000000000000000 --- a/crates/agent_ui/src/acp/message_history.rs +++ /dev/null @@ -1,88 +0,0 @@ -pub struct MessageHistory { - items: Vec, - current: Option, -} - -impl Default for MessageHistory { - fn default() -> Self { - MessageHistory { - items: Vec::new(), - current: None, - } - } -} - -impl MessageHistory { - pub fn push(&mut self, message: T) { - self.current.take(); - self.items.push(message); - } - - pub fn reset_position(&mut self) { - self.current.take(); - } - - pub fn prev(&mut self) -> Option<&T> { - if self.items.is_empty() { - return None; - } - - let new_ix = self - .current - .get_or_insert(self.items.len()) - .saturating_sub(1); - - self.current = Some(new_ix); - self.items.get(new_ix) - } - - pub fn next(&mut self) -> Option<&T> { - let current = self.current.as_mut()?; - *current += 1; - - self.items.get(*current).or_else(|| { - self.current.take(); - None - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_prev_next() { - let mut history = MessageHistory::default(); - - // Test empty history - assert_eq!(history.prev(), None); - assert_eq!(history.next(), None); - - // Add some messages - history.push("first"); - history.push("second"); - history.push("third"); - - // Test prev navigation - 
assert_eq!(history.prev(), Some(&"third")); - assert_eq!(history.prev(), Some(&"second")); - assert_eq!(history.prev(), Some(&"first")); - assert_eq!(history.prev(), Some(&"first")); - - assert_eq!(history.next(), Some(&"second")); - - // Test mixed navigation - history.push("fourth"); - assert_eq!(history.prev(), Some(&"fourth")); - assert_eq!(history.prev(), Some(&"third")); - assert_eq!(history.next(), Some(&"fourth")); - assert_eq!(history.next(), None); - - // Test that push resets navigation - history.prev(); - history.prev(); - history.push("fifth"); - assert_eq!(history.prev(), Some(&"fifth")); - } -} diff --git a/crates/agent_ui/src/acp/mode_selector.rs b/crates/agent_ui/src/acp/mode_selector.rs new file mode 100644 index 0000000000000000000000000000000000000000..b68643859efdcd7fcac5e2ca5f652372a58cc577 --- /dev/null +++ b/crates/agent_ui/src/acp/mode_selector.rs @@ -0,0 +1,230 @@ +use acp_thread::AgentSessionModes; +use agent_client_protocol as acp; +use agent_servers::AgentServer; +use fs::Fs; +use gpui::{Context, Entity, FocusHandle, WeakEntity, Window, prelude::*}; +use std::{rc::Rc, sync::Arc}; +use ui::{ + Button, ContextMenu, ContextMenuEntry, KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, + prelude::*, +}; + +use crate::{CycleModeSelector, ToggleProfileSelector}; + +pub struct ModeSelector { + connection: Rc, + agent_server: Rc, + menu_handle: PopoverMenuHandle, + focus_handle: FocusHandle, + fs: Arc, + setting_mode: bool, +} + +impl ModeSelector { + pub fn new( + session_modes: Rc, + agent_server: Rc, + fs: Arc, + focus_handle: FocusHandle, + ) -> Self { + Self { + connection: session_modes, + agent_server, + menu_handle: PopoverMenuHandle::default(), + fs, + setting_mode: false, + focus_handle, + } + } + + pub fn menu_handle(&self) -> PopoverMenuHandle { + self.menu_handle.clone() + } + + pub fn cycle_mode(&mut self, _window: &mut Window, cx: &mut Context) { + let all_modes = self.connection.all_modes(); + let current_mode = 
self.connection.current_mode(); + + let current_index = all_modes + .iter() + .position(|mode| mode.id.0 == current_mode.0) + .unwrap_or(0); + + let next_index = (current_index + 1) % all_modes.len(); + self.set_mode(all_modes[next_index].id.clone(), cx); + } + + pub fn set_mode(&mut self, mode: acp::SessionModeId, cx: &mut Context) { + let task = self.connection.set_mode(mode, cx); + self.setting_mode = true; + cx.notify(); + + cx.spawn(async move |this: WeakEntity, cx| { + if let Err(err) = task.await { + log::error!("Failed to set session mode: {:?}", err); + } + this.update(cx, |this, cx| { + this.setting_mode = false; + cx.notify(); + }) + .ok(); + }) + .detach(); + } + + fn build_context_menu( + &self, + window: &mut Window, + cx: &mut Context, + ) -> Entity { + let weak_self = cx.weak_entity(); + + ContextMenu::build(window, cx, move |mut menu, _window, cx| { + let all_modes = self.connection.all_modes(); + let current_mode = self.connection.current_mode(); + let default_mode = self.agent_server.default_mode(cx); + + for mode in all_modes { + let is_selected = &mode.id == ¤t_mode; + let is_default = Some(&mode.id) == default_mode.as_ref(); + let entry = ContextMenuEntry::new(mode.name.clone()) + .toggleable(IconPosition::End, is_selected); + + let entry = if let Some(description) = &mode.description { + entry.documentation_aside(ui::DocumentationSide::Left, { + let description = description.clone(); + + move |cx| { + v_flex() + .gap_1() + .child(Label::new(description.clone())) + .child( + h_flex() + .pt_1() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + .gap_0p5() + .text_sm() + .text_color(Color::Muted.color(cx)) + .child("Hold") + .child(div().pt_0p5().children(ui::render_modifiers( + &gpui::Modifiers::secondary_key(), + PlatformStyle::platform(), + None, + Some(ui::TextSize::Default.rems(cx).into()), + true, + ))) + .child(div().map(|this| { + if is_default { + this.child("to also unset as default") + } else { + this.child("to also 
set as default") + } + })), + ) + .into_any_element() + } + }) + } else { + entry + }; + + menu.push_item(entry.handler({ + let mode_id = mode.id.clone(); + let weak_self = weak_self.clone(); + move |window, cx| { + weak_self + .update(cx, |this, cx| { + if window.modifiers().secondary() { + this.agent_server.set_default_mode( + if is_default { + None + } else { + Some(mode_id.clone()) + }, + this.fs.clone(), + cx, + ); + } + + this.set_mode(mode_id.clone(), cx); + }) + .ok(); + } + })); + } + + menu.key_context("ModeSelector") + }) + } +} + +impl Render for ModeSelector { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let current_mode_id = self.connection.current_mode(); + let current_mode_name = self + .connection + .all_modes() + .iter() + .find(|mode| mode.id == current_mode_id) + .map(|mode| mode.name.clone()) + .unwrap_or_else(|| "Unknown".into()); + + let this = cx.entity(); + + let trigger_button = Button::new("mode-selector-trigger", current_mode_name) + .label_size(LabelSize::Small) + .style(ButtonStyle::Subtle) + .color(Color::Muted) + .icon(IconName::ChevronDown) + .icon_size(IconSize::XSmall) + .icon_position(IconPosition::End) + .icon_color(Color::Muted) + .disabled(self.setting_mode); + + PopoverMenu::new("mode-selector") + .trigger_with_tooltip( + trigger_button, + Tooltip::element({ + let focus_handle = self.focus_handle.clone(); + move |window, cx| { + v_flex() + .gap_1() + .child( + h_flex() + .pb_1() + .gap_2() + .justify_between() + .border_b_1() + .border_color(cx.theme().colors().border_variant) + .child(Label::new("Cycle Through Modes")) + .children(KeyBinding::for_action_in( + &CycleModeSelector, + &focus_handle, + window, + cx, + )), + ) + .child( + h_flex() + .gap_2() + .justify_between() + .child(Label::new("Toggle Mode Menu")) + .children(KeyBinding::for_action_in( + &ToggleProfileSelector, + &focus_handle, + window, + cx, + )), + ) + .into_any() + } + }), + ) + .anchor(gpui::Corner::BottomRight) 
+ .with_handle(self.menu_handle.clone()) + .menu(move |window, cx| { + Some(this.update(cx, |this, cx| this.build_context_menu(window, cx))) + }) + } +} diff --git a/crates/agent_ui/src/acp/model_selector.rs b/crates/agent_ui/src/acp/model_selector.rs index 563afee65f0168232c0461092272f3af4bbb77dd..95c0478aa3cf6b1ca78cf391a5bd734820c41454 100644 --- a/crates/agent_ui/src/acp/model_selector.rs +++ b/crates/agent_ui/src/acp/model_selector.rs @@ -73,11 +73,8 @@ impl AcpModelPickerDelegate { this.update_in(cx, |this, window, cx| { this.delegate.models = models.ok(); this.delegate.selected_model = selected_model.ok(); - this.delegate.update_matches(this.query(cx), window, cx) - })? - .await; - - Ok(()) + this.refresh(window, cx) + }) } refresh(&this, &session_id, cx).await.log_err(); @@ -195,8 +192,10 @@ impl PickerDelegate for AcpModelPickerDelegate { } } - fn dismissed(&mut self, _: &mut Window, cx: &mut Context>) { - cx.emit(DismissEvent); + fn dismissed(&mut self, window: &mut Window, cx: &mut Context>) { + cx.defer_in(window, |picker, window, cx| { + picker.set_query("", window, cx); + }); } fn render_match( @@ -330,7 +329,7 @@ async fn fuzzy_search( .collect::>(); let mut matches = match_strings( &candidates, - &query, + query, false, true, 100, diff --git a/crates/agent_ui/src/acp/model_selector_popover.rs b/crates/agent_ui/src/acp/model_selector_popover.rs index e52101113a61c7379be54e25f1784ac16b660200..fa771c695ecf8175859d145b8d08d2cf3447a77a 100644 --- a/crates/agent_ui/src/acp/model_selector_popover.rs +++ b/crates/agent_ui/src/acp/model_selector_popover.rs @@ -5,7 +5,8 @@ use agent_client_protocol as acp; use gpui::{Entity, FocusHandle}; use picker::popover_menu::PickerPopoverMenu; use ui::{ - ButtonLike, Context, IntoElement, PopoverMenuHandle, SharedString, Tooltip, Window, prelude::*, + ButtonLike, Context, IntoElement, PopoverMenuHandle, SharedString, TintColor, Tooltip, Window, + prelude::*, }; use zed_actions::agent::ToggleModelSelector; @@ -36,6 
+37,14 @@ impl AcpModelSelectorPopover { pub fn toggle(&self, window: &mut Window, cx: &mut Context) { self.menu_handle.toggle(window, cx); } + + pub fn active_model_name(&self, cx: &App) -> Option { + self.selector + .read(cx) + .delegate + .active_model() + .map(|model| model.name.clone()) + } } impl Render for AcpModelSelectorPopover { @@ -50,15 +59,22 @@ impl Render for AcpModelSelectorPopover { let focus_handle = self.focus_handle.clone(); + let color = if self.menu_handle.is_deployed() { + Color::Accent + } else { + Color::Muted + }; + PickerPopoverMenu::new( self.selector.clone(), ButtonLike::new("active-model") .when_some(model_icon, |this, icon| { - this.child(Icon::new(icon).color(Color::Muted).size(IconSize::XSmall)) + this.child(Icon::new(icon).color(color).size(IconSize::XSmall)) }) + .selected_style(ButtonStyle::Tinted(TintColor::Accent)) .child( Label::new(model_name) - .color(Color::Muted) + .color(color) .size(LabelSize::Small) .ml_0p5(), ) diff --git a/crates/agent_ui/src/acp/thread_history.rs b/crates/agent_ui/src/acp/thread_history.rs new file mode 100644 index 0000000000000000000000000000000000000000..015a2548d54ac5545f06984ec31bce2d3d58a56e --- /dev/null +++ b/crates/agent_ui/src/acp/thread_history.rs @@ -0,0 +1,825 @@ +use crate::acp::AcpThreadView; +use crate::{AgentPanel, RemoveSelectedThread}; +use agent2::{HistoryEntry, HistoryStore}; +use chrono::{Datelike as _, Local, NaiveDate, TimeDelta}; +use editor::{Editor, EditorEvent}; +use fuzzy::StringMatchCandidate; +use gpui::{ + App, Entity, EventEmitter, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, + UniformListScrollHandle, WeakEntity, Window, uniform_list, +}; +use std::{fmt::Display, ops::Range}; +use text::Bias; +use time::{OffsetDateTime, UtcOffset}; +use ui::{ + HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Scrollbar, ScrollbarState, + Tooltip, prelude::*, +}; + +pub struct AcpThreadHistory { + pub(crate) history_store: Entity, + scroll_handle: 
UniformListScrollHandle, + selected_index: usize, + hovered_index: Option, + search_editor: Entity, + search_query: SharedString, + + visible_items: Vec, + + scrollbar_visibility: bool, + scrollbar_state: ScrollbarState, + local_timezone: UtcOffset, + + _update_task: Task<()>, + _subscriptions: Vec, +} + +enum ListItemType { + BucketSeparator(TimeBucket), + Entry { + entry: HistoryEntry, + format: EntryTimeFormat, + }, + SearchResult { + entry: HistoryEntry, + positions: Vec, + }, +} + +impl ListItemType { + fn history_entry(&self) -> Option<&HistoryEntry> { + match self { + ListItemType::Entry { entry, .. } => Some(entry), + ListItemType::SearchResult { entry, .. } => Some(entry), + _ => None, + } + } +} + +pub enum ThreadHistoryEvent { + Open(HistoryEntry), +} + +impl EventEmitter for AcpThreadHistory {} + +impl AcpThreadHistory { + pub(crate) fn new( + history_store: Entity, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let search_editor = cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_placeholder_text("Search threads...", window, cx); + editor + }); + + let search_editor_subscription = + cx.subscribe(&search_editor, |this, search_editor, event, cx| { + if let EditorEvent::BufferEdited = event { + let query = search_editor.read(cx).text(cx); + if this.search_query != query { + this.search_query = query.into(); + this.update_visible_items(false, cx); + } + } + }); + + let history_store_subscription = cx.observe(&history_store, |this, _, cx| { + this.update_visible_items(true, cx); + }); + + let scroll_handle = UniformListScrollHandle::default(); + let scrollbar_state = ScrollbarState::new(scroll_handle.clone()); + + let mut this = Self { + history_store, + scroll_handle, + selected_index: 0, + hovered_index: None, + visible_items: Default::default(), + search_editor, + scrollbar_visibility: true, + scrollbar_state, + local_timezone: UtcOffset::from_whole_seconds( + chrono::Local::now().offset().local_minus_utc(), + ) + 
.unwrap(), + search_query: SharedString::default(), + _subscriptions: vec![search_editor_subscription, history_store_subscription], + _update_task: Task::ready(()), + }; + this.update_visible_items(false, cx); + this + } + + fn update_visible_items(&mut self, preserve_selected_item: bool, cx: &mut Context) { + let entries = self + .history_store + .update(cx, |store, _| store.entries().collect()); + let new_list_items = if self.search_query.is_empty() { + self.add_list_separators(entries, cx) + } else { + self.filter_search_results(entries, cx) + }; + let selected_history_entry = if preserve_selected_item { + self.selected_history_entry().cloned() + } else { + None + }; + + self._update_task = cx.spawn(async move |this, cx| { + let new_visible_items = new_list_items.await; + this.update(cx, |this, cx| { + let new_selected_index = if let Some(history_entry) = selected_history_entry { + let history_entry_id = history_entry.id(); + new_visible_items + .iter() + .position(|visible_entry| { + visible_entry + .history_entry() + .is_some_and(|entry| entry.id() == history_entry_id) + }) + .unwrap_or(0) + } else { + 0 + }; + + this.visible_items = new_visible_items; + this.set_selected_index(new_selected_index, Bias::Right, cx); + cx.notify(); + }) + .ok(); + }); + } + + fn add_list_separators(&self, entries: Vec, cx: &App) -> Task> { + cx.background_spawn(async move { + let mut items = Vec::with_capacity(entries.len() + 1); + let mut bucket = None; + let today = Local::now().naive_local().date(); + + for entry in entries.into_iter() { + let entry_date = entry + .updated_at() + .with_timezone(&Local) + .naive_local() + .date(); + let entry_bucket = TimeBucket::from_dates(today, entry_date); + + if Some(entry_bucket) != bucket { + bucket = Some(entry_bucket); + items.push(ListItemType::BucketSeparator(entry_bucket)); + } + + items.push(ListItemType::Entry { + entry, + format: entry_bucket.into(), + }); + } + items + }) + } + + fn filter_search_results( + &self, + entries: 
Vec, + cx: &App, + ) -> Task> { + let query = self.search_query.clone(); + cx.background_spawn({ + let executor = cx.background_executor().clone(); + async move { + let mut candidates = Vec::with_capacity(entries.len()); + + for (idx, entry) in entries.iter().enumerate() { + candidates.push(StringMatchCandidate::new(idx, entry.title())); + } + + const MAX_MATCHES: usize = 100; + + let matches = fuzzy::match_strings( + &candidates, + &query, + false, + true, + MAX_MATCHES, + &Default::default(), + executor, + ) + .await; + + matches + .into_iter() + .map(|search_match| ListItemType::SearchResult { + entry: entries[search_match.candidate_id].clone(), + positions: search_match.positions, + }) + .collect() + } + }) + } + + fn search_produced_no_matches(&self) -> bool { + self.visible_items.is_empty() && !self.search_query.is_empty() + } + + fn selected_history_entry(&self) -> Option<&HistoryEntry> { + self.get_history_entry(self.selected_index) + } + + fn get_history_entry(&self, visible_items_ix: usize) -> Option<&HistoryEntry> { + self.visible_items.get(visible_items_ix)?.history_entry() + } + + fn set_selected_index(&mut self, mut index: usize, bias: Bias, cx: &mut Context) { + if self.visible_items.len() == 0 { + self.selected_index = 0; + return; + } + while matches!( + self.visible_items.get(index), + None | Some(ListItemType::BucketSeparator(..)) + ) { + index = match bias { + Bias::Left => { + if index == 0 { + self.visible_items.len() - 1 + } else { + index - 1 + } + } + Bias::Right => { + if index >= self.visible_items.len() - 1 { + 0 + } else { + index + 1 + } + } + }; + } + self.selected_index = index; + self.scroll_handle + .scroll_to_item(index, ScrollStrategy::Top); + cx.notify() + } + + pub fn select_previous( + &mut self, + _: &menu::SelectPrevious, + _window: &mut Window, + cx: &mut Context, + ) { + if self.selected_index == 0 { + self.set_selected_index(self.visible_items.len() - 1, Bias::Left, cx); + } else { + 
self.set_selected_index(self.selected_index - 1, Bias::Left, cx); + } + } + + pub fn select_next( + &mut self, + _: &menu::SelectNext, + _window: &mut Window, + cx: &mut Context, + ) { + if self.selected_index == self.visible_items.len() - 1 { + self.set_selected_index(0, Bias::Right, cx); + } else { + self.set_selected_index(self.selected_index + 1, Bias::Right, cx); + } + } + + fn select_first( + &mut self, + _: &menu::SelectFirst, + _window: &mut Window, + cx: &mut Context, + ) { + self.set_selected_index(0, Bias::Right, cx); + } + + fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context) { + self.set_selected_index(self.visible_items.len() - 1, Bias::Left, cx); + } + + fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context) { + self.confirm_entry(self.selected_index, cx); + } + + fn confirm_entry(&mut self, ix: usize, cx: &mut Context) { + let Some(entry) = self.get_history_entry(ix) else { + return; + }; + cx.emit(ThreadHistoryEvent::Open(entry.clone())); + } + + fn remove_selected_thread( + &mut self, + _: &RemoveSelectedThread, + _window: &mut Window, + cx: &mut Context, + ) { + self.remove_thread(self.selected_index, cx) + } + + fn remove_thread(&mut self, visible_item_ix: usize, cx: &mut Context) { + let Some(entry) = self.get_history_entry(visible_item_ix) else { + return; + }; + + let task = match entry { + HistoryEntry::AcpThread(thread) => self + .history_store + .update(cx, |this, cx| this.delete_thread(thread.id.clone(), cx)), + HistoryEntry::TextThread(context) => self.history_store.update(cx, |this, cx| { + this.delete_text_thread(context.path.clone(), cx) + }), + }; + task.detach_and_log_err(cx); + } + + fn render_scrollbar(&self, cx: &mut Context) -> Option> { + if !(self.scrollbar_visibility || self.scrollbar_state.is_dragging()) { + return None; + } + + Some( + div() + .occlude() + .id("thread-history-scroll") + .h_full() + .bg(cx.theme().colors().panel_background.opacity(0.8)) + 
.border_l_1() + .border_color(cx.theme().colors().border_variant) + .absolute() + .right_1() + .top_0() + .bottom_0() + .w_4() + .pl_1() + .cursor_default() + .on_mouse_move(cx.listener(|_, _, _window, cx| { + cx.notify(); + cx.stop_propagation() + })) + .on_hover(|_, _window, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _window, cx| { + cx.stop_propagation(); + }) + .on_scroll_wheel(cx.listener(|_, _, _window, cx| { + cx.notify(); + })) + .children(Scrollbar::vertical(self.scrollbar_state.clone())), + ) + } + + fn render_list_items( + &mut self, + range: Range, + _window: &mut Window, + cx: &mut Context, + ) -> Vec { + self.visible_items + .get(range.clone()) + .into_iter() + .flatten() + .enumerate() + .map(|(ix, item)| self.render_list_item(item, range.start + ix, cx)) + .collect() + } + + fn render_list_item(&self, item: &ListItemType, ix: usize, cx: &Context) -> AnyElement { + match item { + ListItemType::Entry { entry, format } => self + .render_history_entry(entry, *format, ix, Vec::default(), cx) + .into_any(), + ListItemType::SearchResult { entry, positions } => self.render_history_entry( + entry, + EntryTimeFormat::DateAndTime, + ix, + positions.clone(), + cx, + ), + ListItemType::BucketSeparator(bucket) => div() + .px(DynamicSpacing::Base06.rems(cx)) + .pt_2() + .pb_1() + .child( + Label::new(bucket.to_string()) + .size(LabelSize::XSmall) + .color(Color::Muted), + ) + .into_any_element(), + } + } + + fn render_history_entry( + &self, + entry: &HistoryEntry, + format: EntryTimeFormat, + ix: usize, + highlight_positions: Vec, + cx: &Context, + ) -> AnyElement { + let selected = ix == self.selected_index; + let hovered = Some(ix) == self.hovered_index; + let timestamp = entry.updated_at().timestamp(); + let thread_timestamp = format.format_timestamp(timestamp, self.local_timezone); + + h_flex() + .w_full() + .pb_1() + .child( + ListItem::new(ix) + .rounded() + .toggle_state(selected) + .spacing(ListItemSpacing::Sparse) + .start_slot( + 
h_flex() + .w_full() + .gap_2() + .justify_between() + .child( + HighlightedLabel::new(entry.title(), highlight_positions) + .size(LabelSize::Small) + .truncate(), + ) + .child( + Label::new(thread_timestamp) + .color(Color::Muted) + .size(LabelSize::XSmall), + ), + ) + .on_hover(cx.listener(move |this, is_hovered, _window, cx| { + if *is_hovered { + this.hovered_index = Some(ix); + } else if this.hovered_index == Some(ix) { + this.hovered_index = None; + } + + cx.notify(); + })) + .end_slot::(if hovered { + Some( + IconButton::new("delete", IconName::Trash) + .shape(IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .tooltip(move |window, cx| { + Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx) + }) + .on_click( + cx.listener(move |this, _, _, cx| this.remove_thread(ix, cx)), + ), + ) + } else { + None + }) + .on_click(cx.listener(move |this, _, _, cx| this.confirm_entry(ix, cx))), + ) + .into_any_element() + } +} + +impl Focusable for AcpThreadHistory { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.search_editor.focus_handle(cx) + } +} + +impl Render for AcpThreadHistory { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + v_flex() + .key_context("ThreadHistory") + .size_full() + .on_action(cx.listener(Self::select_previous)) + .on_action(cx.listener(Self::select_next)) + .on_action(cx.listener(Self::select_first)) + .on_action(cx.listener(Self::select_last)) + .on_action(cx.listener(Self::confirm)) + .on_action(cx.listener(Self::remove_selected_thread)) + .when(!self.history_store.read(cx).is_empty(cx), |parent| { + parent.child( + h_flex() + .h(px(41.)) // Match the toolbar perfectly + .w_full() + .py_1() + .px_2() + .gap_2() + .justify_between() + .border_b_1() + .border_color(cx.theme().colors().border) + .child( + Icon::new(IconName::MagnifyingGlass) + .color(Color::Muted) + .size(IconSize::Small), + ) + .child(self.search_editor.clone()), + ) + }) + .child({ 
+ let view = v_flex() + .id("list-container") + .relative() + .overflow_hidden() + .flex_grow(); + + if self.history_store.read(cx).is_empty(cx) { + view.justify_center() + .child( + h_flex().w_full().justify_center().child( + Label::new("You don't have any past threads yet.") + .size(LabelSize::Small), + ), + ) + } else if self.search_produced_no_matches() { + view.justify_center().child( + h_flex().w_full().justify_center().child( + Label::new("No threads match your search.").size(LabelSize::Small), + ), + ) + } else { + view.pr_5() + .child( + uniform_list( + "thread-history", + self.visible_items.len(), + cx.processor(|this, range: Range, window, cx| { + this.render_list_items(range, window, cx) + }), + ) + .p_1() + .track_scroll(self.scroll_handle.clone()) + .flex_grow(), + ) + .when_some(self.render_scrollbar(cx), |div, scrollbar| { + div.child(scrollbar) + }) + } + }) + } +} + +#[derive(IntoElement)] +pub struct AcpHistoryEntryElement { + entry: HistoryEntry, + thread_view: WeakEntity, + selected: bool, + hovered: bool, + on_hover: Box, +} + +impl AcpHistoryEntryElement { + pub fn new(entry: HistoryEntry, thread_view: WeakEntity) -> Self { + Self { + entry, + thread_view, + selected: false, + hovered: false, + on_hover: Box::new(|_, _, _| {}), + } + } + + pub fn hovered(mut self, hovered: bool) -> Self { + self.hovered = hovered; + self + } + + pub fn on_hover(mut self, on_hover: impl Fn(&bool, &mut Window, &mut App) + 'static) -> Self { + self.on_hover = Box::new(on_hover); + self + } +} + +impl RenderOnce for AcpHistoryEntryElement { + fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { + let id = self.entry.id(); + let title = self.entry.title(); + let timestamp = self.entry.updated_at(); + + let formatted_time = { + let now = chrono::Utc::now(); + let duration = now.signed_duration_since(timestamp); + + if duration.num_days() > 0 { + format!("{}d", duration.num_days()) + } else if duration.num_hours() > 0 { + format!("{}h ago", 
duration.num_hours()) + } else if duration.num_minutes() > 0 { + format!("{}m ago", duration.num_minutes()) + } else { + "Just now".to_string() + } + }; + + ListItem::new(id) + .rounded() + .toggle_state(self.selected) + .spacing(ListItemSpacing::Sparse) + .start_slot( + h_flex() + .w_full() + .gap_2() + .justify_between() + .child(Label::new(title).size(LabelSize::Small).truncate()) + .child( + Label::new(formatted_time) + .color(Color::Muted) + .size(LabelSize::XSmall), + ), + ) + .on_hover(self.on_hover) + .end_slot::(if self.hovered || self.selected { + Some( + IconButton::new("delete", IconName::Trash) + .shape(IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .tooltip(move |window, cx| { + Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx) + }) + .on_click({ + let thread_view = self.thread_view.clone(); + let entry = self.entry.clone(); + + move |_event, _window, cx| { + if let Some(thread_view) = thread_view.upgrade() { + thread_view.update(cx, |thread_view, cx| { + thread_view.delete_history_entry(entry.clone(), cx); + }); + } + } + }), + ) + } else { + None + }) + .on_click({ + let thread_view = self.thread_view.clone(); + let entry = self.entry; + + move |_event, window, cx| { + if let Some(workspace) = thread_view + .upgrade() + .and_then(|view| view.read(cx).workspace().upgrade()) + { + match &entry { + HistoryEntry::AcpThread(thread_metadata) => { + if let Some(panel) = workspace.read(cx).panel::(cx) { + panel.update(cx, |panel, cx| { + panel.load_agent_thread( + thread_metadata.clone(), + window, + cx, + ); + }); + } + } + HistoryEntry::TextThread(context) => { + if let Some(panel) = workspace.read(cx).panel::(cx) { + panel.update(cx, |panel, cx| { + panel + .open_saved_prompt_editor( + context.path.clone(), + window, + cx, + ) + .detach_and_log_err(cx); + }); + } + } + } + } + } + }) + } +} + +#[derive(Clone, Copy)] +pub enum EntryTimeFormat { + DateAndTime, + TimeOnly, +} + +impl EntryTimeFormat { + 
fn format_timestamp(&self, timestamp: i64, timezone: UtcOffset) -> String { + let timestamp = OffsetDateTime::from_unix_timestamp(timestamp).unwrap(); + + match self { + EntryTimeFormat::DateAndTime => time_format::format_localized_timestamp( + timestamp, + OffsetDateTime::now_utc(), + timezone, + time_format::TimestampFormat::EnhancedAbsolute, + ), + EntryTimeFormat::TimeOnly => time_format::format_time(timestamp), + } + } +} + +impl From for EntryTimeFormat { + fn from(bucket: TimeBucket) -> Self { + match bucket { + TimeBucket::Today => EntryTimeFormat::TimeOnly, + TimeBucket::Yesterday => EntryTimeFormat::TimeOnly, + TimeBucket::ThisWeek => EntryTimeFormat::DateAndTime, + TimeBucket::PastWeek => EntryTimeFormat::DateAndTime, + TimeBucket::All => EntryTimeFormat::DateAndTime, + } + } +} + +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +enum TimeBucket { + Today, + Yesterday, + ThisWeek, + PastWeek, + All, +} + +impl TimeBucket { + fn from_dates(reference: NaiveDate, date: NaiveDate) -> Self { + if date == reference { + return TimeBucket::Today; + } + + if date == reference - TimeDelta::days(1) { + return TimeBucket::Yesterday; + } + + let week = date.iso_week(); + + if reference.iso_week() == week { + return TimeBucket::ThisWeek; + } + + let last_week = (reference - TimeDelta::days(7)).iso_week(); + + if week == last_week { + return TimeBucket::PastWeek; + } + + TimeBucket::All + } +} + +impl Display for TimeBucket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TimeBucket::Today => write!(f, "Today"), + TimeBucket::Yesterday => write!(f, "Yesterday"), + TimeBucket::ThisWeek => write!(f, "This Week"), + TimeBucket::PastWeek => write!(f, "Past Week"), + TimeBucket::All => write!(f, "All"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::NaiveDate; + + #[test] + fn test_time_bucket_from_dates() { + let today = NaiveDate::from_ymd_opt(2023, 1, 15).unwrap(); + + let date = today; + 
assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::Today); + + let date = NaiveDate::from_ymd_opt(2023, 1, 14).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::Yesterday); + + let date = NaiveDate::from_ymd_opt(2023, 1, 13).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::ThisWeek); + + let date = NaiveDate::from_ymd_opt(2023, 1, 11).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::ThisWeek); + + let date = NaiveDate::from_ymd_opt(2023, 1, 8).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::PastWeek); + + let date = NaiveDate::from_ymd_opt(2023, 1, 5).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::PastWeek); + + // All: not in this week or last week + let date = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::All); + + // Test year boundary cases + let new_year = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap(); + + let date = NaiveDate::from_ymd_opt(2022, 12, 31).unwrap(); + assert_eq!( + TimeBucket::from_dates(new_year, date), + TimeBucket::Yesterday + ); + + let date = NaiveDate::from_ymd_opt(2022, 12, 28).unwrap(); + assert_eq!(TimeBucket::from_dates(new_year, date), TimeBucket::ThisWeek); + } +} diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 5f67dc15b8b1e6154a1bdfd06d092755c462814c..f8bc6c353bad4cca5cf6c5dbaa14f0e2927a1800 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,84 +1,285 @@ use acp_thread::{ AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, - LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, UserMessageId, + AuthRequired, LoadError, MentionUri, RetryStatus, ThreadStatus, ToolCall, ToolCallContent, + ToolCallStatus, UserMessageId, }; use acp_thread::{AgentConnection, Plan}; use action_log::ActionLog; 
-use agent::{TextThreadStore, ThreadStore}; -use agent_client_protocol as acp; -use agent_servers::AgentServer; -use agent_settings::{AgentSettings, NotifyWhenAgentWaiting}; +use agent_client_protocol::{self as acp, PromptCapabilities}; +use agent_servers::{AgentServer, AgentServerDelegate}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, NotifyWhenAgentWaiting}; +use agent2::{DbThreadMetadata, HistoryEntry, HistoryEntryId, HistoryStore, NativeAgentServer}; +use anyhow::{Context as _, Result, anyhow, bail}; +use arrayvec::ArrayVec; use audio::{Audio, Sound}; use buffer_diff::BufferDiff; +use client::zed_urls; +use cloud_llm_client::PlanV1; use collections::{HashMap, HashSet}; use editor::scroll::Autoscroll; -use editor::{ - AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode, - EditorStyle, MinimapVisibility, MultiBuffer, PathKey, SelectionEffects, -}; +use editor::{Editor, EditorEvent, EditorMode, MultiBuffer, PathKey, SelectionEffects}; use file_icons::FileIcons; +use fs::Fs; +use futures::FutureExt as _; use gpui::{ - Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, - FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay, - SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, - Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, - linear_gradient, list, percentage, point, prelude::*, pulsating_between, + Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, ClipboardItem, + CursorStyle, EdgesRefinement, ElementId, Empty, Entity, FocusHandle, Focusable, Hsla, Length, + ListOffset, ListState, MouseButton, PlatformDisplay, SharedString, Stateful, StyleRefinement, + Subscription, Task, TextStyle, TextStyleRefinement, UnderlineStyle, WeakEntity, Window, + WindowHandle, div, ease_in_out, linear_color_stop, linear_gradient, list, point, 
prelude::*, + pulsating_between, }; -use language::language_settings::SoftWrap; -use language::{Buffer, Language}; +use language::Buffer; + +use language_model::LanguageModelRegistry; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; -use parking_lot::Mutex; -use project::{CompletionIntent, Project}; -use prompt_store::PromptId; +use project::{Project, ProjectEntryId}; +use prompt_store::{PromptId, PromptStore}; use rope::Point; use settings::{Settings as _, SettingsStore}; -use std::fmt::Write as _; -use std::path::PathBuf; -use std::{ - cell::RefCell, collections::BTreeMap, path::Path, process::ExitStatus, rc::Rc, sync::Arc, - time::Duration, -}; -use terminal_view::TerminalView; -use text::{Anchor, BufferSnapshot}; -use theme::ThemeSettings; +use std::cell::{Cell, RefCell}; +use std::path::Path; +use std::sync::Arc; +use std::time::Instant; +use std::{collections::BTreeMap, rc::Rc, time::Duration}; +use terminal_view::terminal_panel::TerminalPanel; +use text::Anchor; +use theme::{AgentFontSize, ThemeSettings}; use ui::{ - Disclosure, Divider, DividerColor, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, - Tooltip, prelude::*, + Callout, CommonAnimationExt, Disclosure, Divider, DividerColor, ElevationIndex, KeyBinding, + PopoverMenuHandle, Scrollbar, ScrollbarState, SpinnerLabel, TintColor, Tooltip, prelude::*, }; use util::{ResultExt, size::format_file_size, time::duration_alt_display}; use workspace::{CollaboratorId, Workspace}; -use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage, ToggleModelSelector}; +use zed_actions::agent::{Chat, ToggleModelSelector}; use zed_actions::assistant::OpenRulesLibrary; +use super::entry_view_state::EntryViewState; use crate::acp::AcpModelSelectorPopover; -use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet}; -use crate::acp::message_history::MessageHistory; +use crate::acp::ModeSelector; +use crate::acp::entry_view_state::{EntryViewEvent, 
ViewEvent}; +use crate::acp::message_editor::{MessageEditor, MessageEditorEvent}; use crate::agent_diff::AgentDiff; -use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES}; -use crate::ui::{AgentNotification, AgentNotificationEvent}; +use crate::profile_selector::{ProfileProvider, ProfileSelector}; + +use crate::ui::preview::UsageCallout; +use crate::ui::{ + AgentNotification, AgentNotificationEvent, BurnModeTooltip, UnavailableEditingTooltip, +}; use crate::{ - AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll, + AgentDiffPane, AgentPanel, AllowAlways, AllowOnce, ContinueThread, ContinueWithBurnMode, + CycleModeSelector, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, OpenHistory, RejectAll, + RejectOnce, ToggleBurnMode, ToggleProfileSelector, }; -const RESPONSE_PADDING_X: Pixels = px(19.); +pub const MIN_EDITOR_LINES: usize = 4; +pub const MAX_EDITOR_LINES: usize = 8; + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum ThreadFeedback { + Positive, + Negative, +} + +#[derive(Debug)] +enum ThreadError { + PaymentRequired, + ModelRequestLimitReached(cloud_llm_client::Plan), + ToolUseLimitReached, + Refusal, + AuthenticationRequired(SharedString), + Other(SharedString), +} + +impl ThreadError { + fn from_err(error: anyhow::Error, agent: &Rc) -> Self { + if error.is::() { + Self::PaymentRequired + } else if error.is::() { + Self::ToolUseLimitReached + } else if let Some(error) = + error.downcast_ref::() + { + Self::ModelRequestLimitReached(error.plan) + } else if let Some(acp_error) = error.downcast_ref::() + && acp_error.code == acp::ErrorCode::AUTH_REQUIRED.code + { + Self::AuthenticationRequired(acp_error.message.clone().into()) + } else { + let string = error.to_string(); + // TODO: we should have Gemini return better errors here. 
+ if agent.clone().downcast::().is_some() + && string.contains("Could not load the default credentials") + || string.contains("API key not valid") + || string.contains("Request had invalid authentication credentials") + { + Self::AuthenticationRequired(string.into()) + } else { + Self::Other(error.to_string().into()) + } + } + } +} + +impl ProfileProvider for Entity { + fn profile_id(&self, cx: &App) -> AgentProfileId { + self.read(cx).profile().clone() + } + + fn set_profile(&self, profile_id: AgentProfileId, cx: &mut App) { + self.update(cx, |thread, _cx| { + thread.set_profile(profile_id); + }); + } + + fn profiles_supported(&self, cx: &App) -> bool { + self.read(cx) + .model() + .is_some_and(|model| model.supports_tools()) + } +} + +#[derive(Default)] +struct ThreadFeedbackState { + feedback: Option, + comments_editor: Option>, +} + +impl ThreadFeedbackState { + pub fn submit( + &mut self, + thread: Entity, + feedback: ThreadFeedback, + window: &mut Window, + cx: &mut App, + ) { + let Some(telemetry) = thread.read(cx).connection().telemetry() else { + return; + }; + + if self.feedback == Some(feedback) { + return; + } + + self.feedback = Some(feedback); + match feedback { + ThreadFeedback::Positive => { + self.comments_editor = None; + } + ThreadFeedback::Negative => { + self.comments_editor = Some(Self::build_feedback_comments_editor(window, cx)); + } + } + let session_id = thread.read(cx).session_id().clone(); + let agent_name = telemetry.agent_name(); + let task = telemetry.thread_data(&session_id, cx); + let rating = match feedback { + ThreadFeedback::Positive => "positive", + ThreadFeedback::Negative => "negative", + }; + cx.background_spawn(async move { + let thread = task.await?; + telemetry::event!( + "Agent Thread Rated", + session_id = session_id, + rating = rating, + agent = agent_name, + thread = thread + ); + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + + pub fn submit_comments(&mut self, thread: Entity, cx: &mut App) { + let 
Some(telemetry) = thread.read(cx).connection().telemetry() else { + return; + }; + + let Some(comments) = self + .comments_editor + .as_ref() + .map(|editor| editor.read(cx).text(cx)) + .filter(|text| !text.trim().is_empty()) + else { + return; + }; + + self.comments_editor.take(); + + let session_id = thread.read(cx).session_id().clone(); + let agent_name = telemetry.agent_name(); + let task = telemetry.thread_data(&session_id, cx); + cx.background_spawn(async move { + let thread = task.await?; + telemetry::event!( + "Agent Thread Feedback Comments", + session_id = session_id, + comments = comments, + agent = agent_name, + thread = thread + ); + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + + pub fn clear(&mut self) { + *self = Self::default() + } + + pub fn dismiss_comments(&mut self) { + self.comments_editor.take(); + } + + fn build_feedback_comments_editor(window: &mut Window, cx: &mut App) -> Entity { + let buffer = cx.new(|cx| { + let empty_string = String::new(); + MultiBuffer::singleton(cx.new(|cx| Buffer::local(empty_string, cx)), cx) + }); + + let editor = cx.new(|cx| { + let mut editor = Editor::new( + editor::EditorMode::AutoHeight { + min_lines: 1, + max_lines: Some(4), + }, + buffer, + None, + window, + cx, + ); + editor.set_placeholder_text( + "What went wrong? 
Share your feedback so we can improve.", + window, + cx, + ); + editor + }); + + editor.read(cx).focus_handle(cx).focus(window); + editor + } +} pub struct AcpThreadView { agent: Rc, workspace: WeakEntity, project: Entity, - thread_store: Entity, - text_thread_store: Entity, thread_state: ThreadState, - diff_editors: HashMap>, - terminal_views: HashMap>, - message_editor: Entity, + login: Option, + history_store: Entity, + hovered_recent_history_item: Option, + entry_view_state: Entity, + message_editor: Entity, + focus_handle: FocusHandle, model_selector: Option>, - message_set_from_history: Option, - _message_editor_subscription: Subscription, - mention_set: Arc>, + profile_selector: Option>, notifications: Vec>, notification_subscriptions: HashMap, Vec>, - last_error: Option>, + thread_retry_status: Option, + thread_error: Option, + thread_feedback: ThreadFeedbackState, list_state: ListState, scrollbar_state: ScrollbarState, auth_task: Option>, @@ -87,167 +288,218 @@ pub struct AcpThreadView { edits_expanded: bool, plan_expanded: bool, editor_expanded: bool, - terminal_expanded: bool, - message_history: Rc>>>, + should_be_following: bool, + editing_message: Option, + prompt_capabilities: Rc>, + available_commands: Rc>>, + is_loading_contents: bool, + new_server_version_available: Option, _cancel_task: Option>, - _subscriptions: [Subscription; 1], + _subscriptions: [Subscription; 4], } enum ThreadState { - Loading { - _task: Task<()>, - }, + Loading(Entity), Ready { thread: Entity, - _subscription: [Subscription; 2], + title_editor: Option>, + mode_selector: Option>, + _subscriptions: Vec, }, LoadError(LoadError), Unauthenticated { connection: Rc, + description: Option>, + configuration_view: Option, + pending_auth_method: Option, + _subscription: Option, }, - ServerExited { - status: ExitStatus, - }, +} + +struct LoadingView { + title: SharedString, + _load_task: Task<()>, + _update_title_task: Task>, } impl AcpThreadView { pub fn new( agent: Rc, + 
resume_thread: Option, + summarize_thread: Option, workspace: WeakEntity, project: Entity, - thread_store: Entity, - text_thread_store: Entity, - message_history: Rc>>>, - min_lines: usize, - max_lines: Option, + history_store: Entity, + prompt_store: Option>, window: &mut Window, cx: &mut Context, ) -> Self { - let language = Language::new( - language::LanguageConfig { - completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']), - ..Default::default() - }, - None, - ); - - let mention_set = Arc::new(Mutex::new(MentionSet::default())); + let prompt_capabilities = Rc::new(Cell::new(acp::PromptCapabilities::default())); + let available_commands = Rc::new(RefCell::new(vec![])); + + let placeholder = if agent.name() == "Zed Agent" { + format!("Message the {} — @ to include context", agent.name()) + } else if agent.name() == "Claude Code" || !available_commands.borrow().is_empty() { + format!( + "Message {} — @ to include context, / for commands", + agent.name() + ) + } else { + format!("Message {} — @ to include context", agent.name()) + }; let message_editor = cx.new(|cx| { - let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx)); - let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); - - let mut editor = Editor::new( + let mut editor = MessageEditor::new( + workspace.clone(), + project.clone(), + history_store.clone(), + prompt_store.clone(), + prompt_capabilities.clone(), + available_commands.clone(), + agent.name(), + &placeholder, editor::EditorMode::AutoHeight { - min_lines, - max_lines: max_lines, + min_lines: MIN_EDITOR_LINES, + max_lines: Some(MAX_EDITOR_LINES), }, - buffer, - None, window, cx, ); - editor.set_placeholder_text("Message the agent - @ to include files", cx); - editor.set_show_indent_guides(false, cx); - editor.set_soft_wrap(); - editor.set_use_modal_editing(true); - editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new( - mention_set.clone(), - workspace.clone(), - 
thread_store.downgrade(), - text_thread_store.downgrade(), - cx.weak_entity(), - )))); - editor.set_context_menu_options(ContextMenuOptions { - min_entries_visible: 12, - max_entries_visible: 12, - placement: Some(ContextMenuPlacement::Above), - }); + if let Some(entry) = summarize_thread { + editor.insert_thread_summary(entry, window, cx); + } editor }); - let message_editor_subscription = - cx.subscribe(&message_editor, |this, editor, event, cx| { - if let editor::EditorEvent::BufferEdited = &event { - let buffer = editor - .read(cx) - .buffer() - .read(cx) - .as_singleton() - .unwrap() - .read(cx) - .snapshot(); - if let Some(message) = this.message_set_from_history.clone() - && message.version() != buffer.version() - { - this.message_set_from_history = None; - } - - if this.message_set_from_history.is_none() { - this.message_history.borrow_mut().reset_position(); - } - } - }); - - let mention_set = mention_set.clone(); - let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0)); - let subscription = cx.observe_global_in::(window, Self::settings_changed); + let entry_view_state = cx.new(|_| { + EntryViewState::new( + workspace.clone(), + project.clone(), + history_store.clone(), + prompt_store.clone(), + prompt_capabilities.clone(), + available_commands.clone(), + agent.name(), + ) + }); + + let subscriptions = [ + cx.observe_global_in::(window, Self::agent_font_size_changed), + cx.observe_global_in::(window, Self::agent_font_size_changed), + cx.subscribe_in(&message_editor, window, Self::handle_message_editor_event), + cx.subscribe_in(&entry_view_state, window, Self::handle_entry_view_event), + ]; Self { agent: agent.clone(), workspace: workspace.clone(), project: project.clone(), - thread_store, - text_thread_store, - thread_state: Self::initial_state(agent, workspace, project, window, cx), + entry_view_state, + thread_state: Self::initial_state(agent, resume_thread, workspace, project, window, cx), + login: None, message_editor, 
model_selector: None, - message_set_from_history: None, - _message_editor_subscription: message_editor_subscription, - mention_set, + profile_selector: None, + notifications: Vec::new(), notification_subscriptions: HashMap::default(), - diff_editors: Default::default(), - terminal_views: Default::default(), list_state: list_state.clone(), scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()), - last_error: None, + thread_retry_status: None, + thread_error: None, + thread_feedback: Default::default(), auth_task: None, expanded_tool_calls: HashSet::default(), expanded_thinking_blocks: HashSet::default(), + editing_message: None, edits_expanded: false, plan_expanded: false, + prompt_capabilities, + available_commands, editor_expanded: false, - terminal_expanded: true, - message_history, - _subscriptions: [subscription], + should_be_following: false, + history_store, + hovered_recent_history_item: None, + is_loading_contents: false, + _subscriptions: subscriptions, _cancel_task: None, + focus_handle: cx.focus_handle(), + new_server_version_available: None, } } + fn reset(&mut self, window: &mut Window, cx: &mut Context) { + self.thread_state = Self::initial_state( + self.agent.clone(), + None, + self.workspace.clone(), + self.project.clone(), + window, + cx, + ); + self.available_commands.replace(vec![]); + self.new_server_version_available.take(); + cx.notify(); + } + fn initial_state( agent: Rc, + resume_thread: Option, workspace: WeakEntity, project: Entity, window: &mut Window, cx: &mut Context, ) -> ThreadState { - let root_dir = project - .read(cx) - .visible_worktrees(cx) - .next() - .map(|worktree| worktree.read(cx).abs_path()) - .unwrap_or_else(|| paths::home_dir().as_path().into()); + if project.read(cx).is_via_collab() + && agent.clone().downcast::().is_none() + { + return ThreadState::LoadError(LoadError::Other( + "External agents are not yet supported in shared projects.".into(), + )); + } + let mut worktrees = 
project.read(cx).visible_worktrees(cx).collect::>(); + // Pick the first non-single-file worktree for the root directory if there are any, + // and otherwise the parent of a single-file worktree, falling back to $HOME if there are no visible worktrees. + worktrees.sort_by(|l, r| { + l.read(cx) + .is_single_file() + .cmp(&r.read(cx).is_single_file()) + }); + let root_dir = worktrees + .into_iter() + .filter_map(|worktree| { + if worktree.read(cx).is_single_file() { + Some(worktree.read(cx).abs_path().parent()?.into()) + } else { + Some(worktree.read(cx).abs_path()) + } + }) + .next(); + let (status_tx, mut status_rx) = watch::channel("Loading…".into()); + let (new_version_available_tx, mut new_version_available_rx) = watch::channel(None); + let delegate = AgentServerDelegate::new( + project.read(cx).agent_server_store().clone(), + project.clone(), + Some(status_tx), + Some(new_version_available_tx), + ); - let connect_task = agent.connect(&root_dir, &project, cx); + let connect_task = agent.connect(root_dir.as_deref(), delegate, cx); let load_task = cx.spawn_in(window, async move |this, cx| { let connection = match connect_task.await { - Ok(connection) => connection, + Ok((connection, login)) => { + this.update(cx, |this, _| this.login = login).ok(); + connection + } Err(err) => { - this.update(cx, |this, cx| { - this.handle_load_error(err, cx); + this.update_in(cx, |this, window, cx| { + if err.downcast_ref::().is_some() { + this.handle_load_error(err, window, cx); + } else { + this.handle_thread_error(err, cx); + } cx.notify(); }) .log_err(); @@ -255,53 +507,79 @@ impl AcpThreadView { } }; - // this.update_in(cx, |_this, _window, cx| { - // let status = connection.exit_status(cx); - // cx.spawn(async move |this, cx| { - // let status = status.await.ok(); - // this.update(cx, |this, cx| { - // this.thread_state = ThreadState::ServerExited { status }; - // cx.notify(); - // }) - // .ok(); - // }) - // .detach(); - // }) - // .ok(); - - let result = match connection 
+ let result = if let Some(native_agent) = connection .clone() - .new_thread(project.clone(), &root_dir, cx) - .await + .downcast::() + && let Some(resume) = resume_thread.clone() { - Err(e) => { - let mut cx = cx.clone(); - if e.is::() { - this.update(&mut cx, |this, cx| { - this.thread_state = ThreadState::Unauthenticated { connection }; - cx.notify(); + cx.update(|_, cx| { + native_agent + .0 + .update(cx, |agent, cx| agent.open_thread(resume.id, cx)) + }) + .log_err() + } else { + let root_dir = if let Some(acp_agent) = connection + .clone() + .downcast::() + { + acp_agent.root_dir().into() + } else { + root_dir.unwrap_or(paths::home_dir().as_path().into()) + }; + cx.update(|_, cx| { + connection + .clone() + .new_thread(project.clone(), &root_dir, cx) + }) + .log_err() + }; + + let Some(result) = result else { + return; + }; + + let result = match result.await { + Err(e) => match e.downcast::() { + Ok(err) => { + cx.update(|window, cx| { + Self::handle_auth_required(this, err, agent, connection, window, cx) }) - .ok(); + .log_err(); return; - } else { - Err(e) } - } + Err(err) => Err(err), + }, Ok(thread) => Ok(thread), }; this.update_in(cx, |this, window, cx| { match result { Ok(thread) => { - let thread_subscription = - cx.subscribe_in(&thread, window, Self::handle_thread_event); - let action_log = thread.read(cx).action_log().clone(); - let action_log_subscription = - cx.observe(&action_log, |_, _, cx| cx.notify()); - this.list_state - .splice(0..0, thread.read(cx).entries().len()); + this.prompt_capabilities + .set(thread.read(cx).prompt_capabilities()); + + let count = thread.read(cx).entries().len(); + this.entry_view_state.update(cx, |view_state, cx| { + for ix in 0..count { + view_state.sync_entry(ix, &thread, window, cx); + } + this.list_state.splice_focusable( + 0..0, + (0..count).map(|ix| view_state.entry(ix)?.focus_handle(cx)), + ); + }); + + if let Some(resume) = resume_thread { + this.history_store.update(cx, |history, cx| { + 
history.push_recently_opened_entry( + HistoryEntryId::AcpThread(resume.id), + cx, + ); + }); + } AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); @@ -323,55 +601,236 @@ impl AcpThreadView { }) }); + let mode_selector = thread + .read(cx) + .connection() + .session_modes(thread.read(cx).session_id(), cx) + .map(|session_modes| { + let fs = this.project.read(cx).fs().clone(); + let focus_handle = this.focus_handle(cx); + cx.new(|_cx| { + ModeSelector::new( + session_modes, + this.agent.clone(), + fs, + focus_handle, + ) + }) + }); + + let mut subscriptions = vec![ + cx.subscribe_in(&thread, window, Self::handle_thread_event), + cx.observe(&action_log, |_, _, cx| cx.notify()), + ]; + + let title_editor = + if thread.update(cx, |thread, cx| thread.can_set_title(cx)) { + let editor = cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_text(thread.read(cx).title(), window, cx); + editor + }); + subscriptions.push(cx.subscribe_in( + &editor, + window, + Self::handle_title_editor_event, + )); + Some(editor) + } else { + None + }; + this.thread_state = ThreadState::Ready { thread, - _subscription: [thread_subscription, action_log_subscription], + title_editor, + mode_selector, + _subscriptions: subscriptions, }; + this.message_editor.focus_handle(cx).focus(window); + + this.profile_selector = this.as_native_thread(cx).map(|thread| { + cx.new(|cx| { + ProfileSelector::new( + ::global(cx), + Arc::new(thread.clone()), + this.focus_handle(cx), + cx, + ) + }) + }); cx.notify(); } Err(err) => { - this.handle_load_error(err, cx); + this.handle_load_error(err, window, cx); } }; }) .log_err(); }); - ThreadState::Loading { _task: load_task } + cx.spawn(async move |this, cx| { + while let Ok(new_version) = new_version_available_rx.recv().await { + if let Some(new_version) = new_version { + this.update(cx, |this, cx| { + this.new_server_version_available = Some(new_version.into()); + cx.notify(); + }) + .log_err(); + } + } + }) + 
.detach(); + + let loading_view = cx.new(|cx| { + let update_title_task = cx.spawn(async move |this, cx| { + loop { + let status = status_rx.recv().await?; + this.update(cx, |this: &mut LoadingView, cx| { + this.title = status; + cx.notify(); + })?; + } + }); + + LoadingView { + title: "Loading…".into(), + _load_task: load_task, + _update_title_task: update_title_task, + } + }); + + ThreadState::Loading(loading_view) + } + + fn handle_auth_required( + this: WeakEntity, + err: AuthRequired, + agent: Rc, + connection: Rc, + window: &mut Window, + cx: &mut App, + ) { + let agent_name = agent.name(); + let (configuration_view, subscription) = if let Some(provider_id) = err.provider_id { + let registry = LanguageModelRegistry::global(cx); + + let sub = window.subscribe(®istry, cx, { + let provider_id = provider_id.clone(); + let this = this.clone(); + move |_, ev, window, cx| { + if let language_model::Event::ProviderStateChanged(updated_provider_id) = &ev + && &provider_id == updated_provider_id + && LanguageModelRegistry::global(cx) + .read(cx) + .provider(&provider_id) + .map_or(false, |provider| provider.is_authenticated(cx)) + { + this.update(cx, |this, cx| { + this.reset(window, cx); + }) + .ok(); + } + } + }); + + let view = registry.read(cx).provider(&provider_id).map(|provider| { + provider.configuration_view( + language_model::ConfigurationViewTargetAgent::Other(agent_name.clone()), + window, + cx, + ) + }); + + (view, Some(sub)) + } else { + (None, None) + }; + + this.update(cx, |this, cx| { + this.thread_state = ThreadState::Unauthenticated { + pending_auth_method: None, + connection, + configuration_view, + description: err + .description + .clone() + .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx))), + _subscription: subscription, + }; + if this.message_editor.focus_handle(cx).is_focused(window) { + this.focus_handle.focus(window) + } + cx.notify(); + }) + .ok(); } - fn handle_load_error(&mut self, err: anyhow::Error, cx: &mut Context) { 
+ fn handle_load_error( + &mut self, + err: anyhow::Error, + window: &mut Window, + cx: &mut Context, + ) { if let Some(load_err) = err.downcast_ref::() { self.thread_state = ThreadState::LoadError(load_err.clone()); } else { self.thread_state = ThreadState::LoadError(LoadError::Other(err.to_string().into())) } + if self.message_editor.focus_handle(cx).is_focused(window) { + self.focus_handle.focus(window) + } cx.notify(); } + pub fn workspace(&self) -> &WeakEntity { + &self.workspace + } + pub fn thread(&self) -> Option<&Entity> { match &self.thread_state { ThreadState::Ready { thread, .. } => Some(thread), ThreadState::Unauthenticated { .. } | ThreadState::Loading { .. } - | ThreadState::LoadError(..) - | ThreadState::ServerExited { .. } => None, + | ThreadState::LoadError { .. } => None, + } + } + + pub fn mode_selector(&self) -> Option<&Entity> { + match &self.thread_state { + ThreadState::Ready { mode_selector, .. } => mode_selector.as_ref(), + ThreadState::Unauthenticated { .. } + | ThreadState::Loading { .. } + | ThreadState::LoadError { .. } => None, } } pub fn title(&self, cx: &App) -> SharedString { match &self.thread_state { - ThreadState::Ready { thread, .. } => thread.read(cx).title(), - ThreadState::Loading { .. } => "Loading…".into(), - ThreadState::LoadError(_) => "Failed to load".into(), - ThreadState::Unauthenticated { .. } => "Not authenticated".into(), - ThreadState::ServerExited { .. } => "Server exited unexpectedly".into(), + ThreadState::Ready { .. } | ThreadState::Unauthenticated { .. } => "New Thread".into(), + ThreadState::Loading(loading_view) => loading_view.read(cx).title.clone(), + ThreadState::LoadError(error) => match error { + LoadError::Unsupported { .. } => format!("Upgrade {}", self.agent.name()).into(), + LoadError::FailedToInstall(_) => { + format!("Failed to Install {}", self.agent.name()).into() + } + LoadError::Exited { .. 
} => format!("{} Exited", self.agent.name()).into(), + LoadError::Other(_) => format!("Error Loading {}", self.agent.name()).into(), + }, + } + } + + pub fn title_editor(&self) -> Option> { + if let ThreadState::Ready { title_editor, .. } = &self.thread_state { + title_editor.clone() + } else { + None } } - pub fn cancel(&mut self, cx: &mut Context) { - self.last_error.take(); + pub fn cancel_generation(&mut self, cx: &mut Context) { + self.thread_error.take(); + self.thread_retry_status.take(); if let Some(thread) = self.thread() { self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx))); @@ -390,193 +849,372 @@ impl AcpThreadView { fn set_editor_is_expanded(&mut self, is_expanded: bool, cx: &mut Context) { self.editor_expanded = is_expanded; - self.message_editor.update(cx, |editor, _| { - if self.editor_expanded { - editor.set_mode(EditorMode::Full { - scale_ui_elements_with_buffer_font_size: false, - show_active_line_background: false, - sized_by_content: false, - }) + self.message_editor.update(cx, |editor, cx| { + if is_expanded { + editor.set_mode( + EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + sized_by_content: false, + }, + cx, + ) } else { - editor.set_mode(EditorMode::AutoHeight { - min_lines: MIN_EDITOR_LINES, - max_lines: Some(MAX_EDITOR_LINES), - }) + editor.set_mode( + EditorMode::AutoHeight { + min_lines: MIN_EDITOR_LINES, + max_lines: Some(MAX_EDITOR_LINES), + }, + cx, + ) } }); cx.notify(); } - fn chat(&mut self, _: &Chat, window: &mut Window, cx: &mut Context) { - self.last_error.take(); - - let mut ix = 0; - let mut chunks: Vec = Vec::new(); - let project = self.project.clone(); - - let thread_store = self.thread_store.clone(); - let text_thread_store = self.text_thread_store.clone(); - - let contents = - self.mention_set - .lock() - .contents(project, thread_store, text_thread_store, window, cx); + pub fn handle_title_editor_event( + &mut self, + title_editor: 
&Entity, + event: &EditorEvent, + window: &mut Window, + cx: &mut Context, + ) { + let Some(thread) = self.thread() else { return }; - cx.spawn_in(window, async move |this, cx| { - let contents = match contents.await { - Ok(contents) => contents, - Err(e) => { - this.update(cx, |this, cx| { - this.last_error = - Some(cx.new(|cx| Markdown::new(e.to_string().into(), None, None, cx))); - }) - .ok(); - return; + match event { + EditorEvent::BufferEdited => { + let new_title = title_editor.read(cx).text(cx); + thread.update(cx, |thread, cx| { + thread + .set_title(new_title.into(), cx) + .detach_and_log_err(cx); + }) + } + EditorEvent::Blurred => { + if title_editor.read(cx).text(cx).is_empty() { + title_editor.update(cx, |editor, cx| { + editor.set_text("New Thread", window, cx); + }); } - }; + } + _ => {} + } + } - this.update_in(cx, |this, window, cx| { - this.message_editor.update(cx, |editor, cx| { - let text = editor.text(cx); - editor.display_map.update(cx, |map, cx| { - let snapshot = map.snapshot(cx); - for (crease_id, crease) in snapshot.crease_snapshot.creases() { - // Skip creases that have been edited out of the message buffer. 
- if !crease.range().start.is_valid(&snapshot.buffer_snapshot) { - continue; - } + pub fn handle_message_editor_event( + &mut self, + _: &Entity, + event: &MessageEditorEvent, + window: &mut Window, + cx: &mut Context, + ) { + match event { + MessageEditorEvent::Send => self.send(window, cx), + MessageEditorEvent::Cancel => self.cancel_generation(cx), + MessageEditorEvent::Focus => { + self.cancel_editing(&Default::default(), window, cx); + } + MessageEditorEvent::LostFocus => {} + } + } - if let Some(mention) = contents.get(&crease_id) { - let crease_range = - crease.range().to_offset(&snapshot.buffer_snapshot); - if crease_range.start > ix { - chunks.push(text[ix..crease_range.start].into()); - } - chunks.push(acp::ContentBlock::Resource(acp::EmbeddedResource { - annotations: None, - resource: acp::EmbeddedResourceResource::TextResourceContents( - acp::TextResourceContents { - mime_type: None, - text: mention.content.clone(), - uri: mention.uri.to_uri().to_string(), - }, - ), - })); - ix = crease_range.end; - } - } + pub fn handle_entry_view_event( + &mut self, + _: &Entity, + event: &EntryViewEvent, + window: &mut Window, + cx: &mut Context, + ) { + match &event.view_event { + ViewEvent::NewDiff(tool_call_id) => { + if AgentSettings::get_global(cx).expand_edit_card { + self.expanded_tool_calls.insert(tool_call_id.clone()); + } + } + ViewEvent::NewTerminal(tool_call_id) => { + if AgentSettings::get_global(cx).expand_terminal_card { + self.expanded_tool_calls.insert(tool_call_id.clone()); + } + } + ViewEvent::TerminalMovedToBackground(tool_call_id) => { + self.expanded_tool_calls.remove(tool_call_id); + } + ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Focus) => { + if let Some(thread) = self.thread() + && let Some(AgentThreadEntry::UserMessage(user_message)) = + thread.read(cx).entries().get(event.entry_index) + && user_message.id.is_some() + { + self.editing_message = Some(event.entry_index); + cx.notify(); + } + } + 
ViewEvent::MessageEditorEvent(editor, MessageEditorEvent::LostFocus) => { + if let Some(thread) = self.thread() + && let Some(AgentThreadEntry::UserMessage(user_message)) = + thread.read(cx).entries().get(event.entry_index) + && user_message.id.is_some() + { + if editor.read(cx).text(cx).as_str() == user_message.content.to_markdown(cx) { + self.editing_message = None; + cx.notify(); + } + } + } + ViewEvent::MessageEditorEvent(editor, MessageEditorEvent::Send) => { + self.regenerate(event.entry_index, editor.clone(), window, cx); + } + ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Cancel) => { + self.cancel_editing(&Default::default(), window, cx); + } + } + } - if ix < text.len() { - let last_chunk = text[ix..].trim_end(); - if !last_chunk.is_empty() { - chunks.push(last_chunk.into()); - } - } - }) - }); + fn resume_chat(&mut self, cx: &mut Context) { + self.thread_error.take(); + let Some(thread) = self.thread() else { + return; + }; + if !thread.read(cx).can_resume(cx) { + return; + } - if chunks.is_empty() { - return; + let task = thread.update(cx, |thread, cx| thread.resume(cx)); + cx.spawn(async move |this, cx| { + let result = task.await; + + this.update(cx, |this, cx| { + if let Err(err) = result { + this.handle_thread_error(err, cx); } + }) + }) + .detach(); + } - let Some(thread) = this.thread() else { - return; - }; - let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx)); + fn send(&mut self, window: &mut Window, cx: &mut Context) { + let Some(thread) = self.thread() else { return }; - cx.spawn(async move |this, cx| { - let result = task.await; + if self.is_loading_contents { + return; + } - this.update(cx, |this, cx| { - if let Err(err) = result { - this.last_error = - Some(cx.new(|cx| { - Markdown::new(err.to_string().into(), None, None, cx) - })) - } - }) - }) - .detach(); + self.history_store.update(cx, |history, cx| { + history.push_recently_opened_entry( + 
HistoryEntryId::AcpThread(thread.read(cx).session_id().clone()), + cx, + ); + }); - let mention_set = this.mention_set.clone(); + if thread.read(cx).status() != ThreadStatus::Idle { + self.stop_current_and_send_new_message(window, cx); + return; + } - this.set_editor_is_expanded(false, cx); + let text = self.message_editor.read(cx).text(cx); + let text = text.trim(); + if text == "/login" || text == "/logout" { + let ThreadState::Ready { thread, .. } = &self.thread_state else { + return; + }; - this.message_editor.update(cx, |editor, cx| { - editor.clear(window, cx); - editor.remove_creases(mention_set.lock().drain(), cx) - }); + let connection = thread.read(cx).connection().clone(); + if !connection + .auth_methods() + .iter() + .any(|method| method.id.0.as_ref() == "claude-login") + { + return; + }; + let this = cx.weak_entity(); + let agent = self.agent.clone(); + window.defer(cx, |window, cx| { + Self::handle_auth_required( + this, + AuthRequired { + description: None, + provider_id: None, + }, + agent, + connection, + window, + cx, + ); + }); + cx.notify(); + return; + } - this.scroll_to_bottom(cx); + let contents = self + .message_editor + .update(cx, |message_editor, cx| message_editor.contents(cx)); + self.send_impl(contents, window, cx) + } + + fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context) { + let Some(thread) = self.thread().cloned() else { + return; + }; - this.message_history.borrow_mut().push(chunks); + let cancelled = thread.update(cx, |thread, cx| thread.cancel(cx)); + + let contents = self + .message_editor + .update(cx, |message_editor, cx| message_editor.contents(cx)); + + cx.spawn_in(window, async move |this, cx| { + cancelled.await; + + this.update_in(cx, |this, window, cx| { + this.send_impl(contents, window, cx); }) .ok(); }) .detach(); } - fn previous_history_message( + fn send_impl( &mut self, - _: &PreviousHistoryMessage, + contents: Task, Vec>)>>, window: &mut Window, cx: &mut Context, ) { - if 
self.message_set_from_history.is_none() && !self.message_editor.read(cx).is_empty(cx) { - self.message_editor.update(cx, |editor, cx| { - editor.move_up(&Default::default(), window, cx); - }); + let agent_telemetry_id = self.agent.telemetry_id(); + + self.thread_error.take(); + self.editing_message.take(); + self.thread_feedback.clear(); + + let Some(thread) = self.thread() else { return; + }; + let thread = thread.downgrade(); + if self.should_be_following { + self.workspace + .update(cx, |workspace, cx| { + workspace.follow(CollaboratorId::Agent, window, cx); + }) + .ok(); } - self.message_set_from_history = Self::set_draft_message( - self.message_editor.clone(), - self.mention_set.clone(), - self.project.clone(), - self.message_history - .borrow_mut() - .prev() - .map(|blocks| blocks.as_slice()), - window, - cx, - ); + self.is_loading_contents = true; + let guard = cx.new(|_| ()); + cx.observe_release(&guard, |this, _guard, cx| { + this.is_loading_contents = false; + cx.notify(); + }) + .detach(); + + let task = cx.spawn_in(window, async move |this, cx| { + let (contents, tracked_buffers) = contents.await?; + + if contents.is_empty() { + return Ok(()); + } + + this.update_in(cx, |this, window, cx| { + this.set_editor_is_expanded(false, cx); + this.scroll_to_bottom(cx); + this.message_editor.update(cx, |message_editor, cx| { + message_editor.clear(window, cx); + }); + })?; + let send = thread.update(cx, |thread, cx| { + thread.action_log().update(cx, |action_log, cx| { + for buffer in tracked_buffers { + action_log.buffer_read(buffer, cx) + } + }); + drop(guard); + + telemetry::event!("Agent Message Sent", agent = agent_telemetry_id); + + thread.send(contents, cx) + })?; + send.await + }); + + cx.spawn(async move |this, cx| { + if let Err(err) = task.await { + this.update(cx, |this, cx| { + this.handle_thread_error(err, cx); + }) + .ok(); + } else { + this.update(cx, |this, cx| { + this.should_be_following = this + .workspace + .update(cx, |workspace, _| { + 
workspace.is_being_followed(CollaboratorId::Agent) + }) + .unwrap_or_default(); + }) + .ok(); + } + }) + .detach(); + } + + fn cancel_editing(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { + let Some(thread) = self.thread().cloned() else { + return; + }; + + if let Some(index) = self.editing_message.take() + && let Some(editor) = self + .entry_view_state + .read(cx) + .entry(index) + .and_then(|e| e.message_editor()) + .cloned() + { + editor.update(cx, |editor, cx| { + if let Some(user_message) = thread + .read(cx) + .entries() + .get(index) + .and_then(|e| e.user_message()) + { + editor.set_message(user_message.chunks.clone(), window, cx); + } + }) + }; + self.focus_handle(cx).focus(window); + cx.notify(); } - fn next_history_message( + fn regenerate( &mut self, - _: &NextHistoryMessage, + entry_ix: usize, + message_editor: Entity, window: &mut Window, cx: &mut Context, ) { - if self.message_set_from_history.is_none() { - self.message_editor.update(cx, |editor, cx| { - editor.move_down(&Default::default(), window, cx); - }); + let Some(thread) = self.thread().cloned() else { + return; + }; + if self.is_loading_contents { return; } - let mut message_history = self.message_history.borrow_mut(); - let next_history = message_history.next(); - - let set_draft_message = Self::set_draft_message( - self.message_editor.clone(), - self.mention_set.clone(), - self.project.clone(), - Some( - next_history - .map(|blocks| blocks.as_slice()) - .unwrap_or_else(|| &[]), - ), - window, - cx, - ); - // If we reset the text to an empty string because we ran out of history, - // we don't want to mark it as coming from the history - self.message_set_from_history = if next_history.is_some() { - set_draft_message - } else { - None + let Some(user_message_id) = thread.update(cx, |thread, _| { + thread.entries().get(entry_ix)?.user_message()?.id.clone() + }) else { + return; }; + + cx.spawn_in(window, async move |this, cx| { + thread + .update(cx, |thread, cx| 
thread.rewind(user_message_id, cx))? + .await?; + let contents = + message_editor.update(cx, |message_editor, cx| message_editor.contents(cx))?; + this.update_in(cx, |this, window, cx| { + this.send_impl(contents, window, cx); + })?; + anyhow::Ok(()) + }) + .detach(); } fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context) { @@ -602,94 +1240,50 @@ impl AcpThreadView { }; diff.update(cx, |diff, cx| { - diff.move_to_path(PathKey::for_buffer(&buffer, cx), window, cx) + diff.move_to_path(PathKey::for_buffer(buffer, cx), window, cx) }) } - fn set_draft_message( - message_editor: Entity, - mention_set: Arc>, - project: Entity, - message: Option<&[acp::ContentBlock]>, - window: &mut Window, - cx: &mut Context, - ) -> Option { - cx.notify(); - - let message = message?; - - let mut text = String::new(); - let mut mentions = Vec::new(); + fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { + let Some(thread) = self.as_native_thread(cx) else { + return; + }; + let project_context = thread.read(cx).project_context().read(cx); - for chunk in message { - match chunk { - acp::ContentBlock::Text(text_content) => { - text.push_str(&text_content.text); - } - acp::ContentBlock::Resource(acp::EmbeddedResource { - resource: acp::EmbeddedResourceResource::TextResourceContents(resource), - .. 
- }) => { - let path = PathBuf::from(&resource.uri); - let project_path = project.read(cx).project_path_for_absolute_path(&path, cx); - let start = text.len(); - let _ = write!(&mut text, "{}", MentionUri::File(path).to_uri()); - let end = text.len(); - if let Some(project_path) = project_path { - let filename: SharedString = project_path - .path - .file_name() - .unwrap_or_default() - .to_string_lossy() - .to_string() - .into(); - mentions.push((start..end, project_path, filename)); - } + let project_entry_ids = project_context + .worktrees + .iter() + .flat_map(|worktree| worktree.rules_file.as_ref()) + .map(|rules_file| ProjectEntryId::from_usize(rules_file.project_entry_id)) + .collect::>(); + + self.workspace + .update(cx, move |workspace, cx| { + // TODO: Open a multibuffer instead? In some cases this doesn't make the set of rules + // files clear. For example, if rules file 1 is already open but rules file 2 is not, + // this would open and focus rules file 2 in a tab that is not next to rules file 1. 
+ let project = workspace.project().read(cx); + let project_paths = project_entry_ids + .into_iter() + .flat_map(|entry_id| project.path_for_entry(entry_id, cx)) + .collect::>(); + for project_path in project_paths { + workspace + .open_path(project_path, None, true, window, cx) + .detach_and_log_err(cx); } - acp::ContentBlock::Image(_) - | acp::ContentBlock::Audio(_) - | acp::ContentBlock::Resource(_) - | acp::ContentBlock::ResourceLink(_) => {} - } - } - - let snapshot = message_editor.update(cx, |editor, cx| { - editor.set_text(text, window, cx); - editor.buffer().read(cx).snapshot(cx) - }); - - for (range, project_path, filename) in mentions { - let crease_icon_path = if project_path.path.is_dir() { - FileIcons::get_folder_icon(false, cx) - .unwrap_or_else(|| IconName::Folder.path().into()) - } else { - FileIcons::get_icon(Path::new(project_path.path.as_ref()), cx) - .unwrap_or_else(|| IconName::File.path().into()) - }; - - let anchor = snapshot.anchor_before(range.start); - if let Some(project_path) = project.read(cx).absolute_path(&project_path, cx) { - let crease_id = crate::context_picker::insert_crease_for_mention( - anchor.excerpt_id, - anchor.text_anchor, - range.end - range.start, - filename, - crease_icon_path, - message_editor.clone(), - window, - cx, - ); + }) + .ok(); + } - if let Some(crease_id) = crease_id { - mention_set - .lock() - .insert(crease_id, MentionUri::File(project_path)); - } - } - } + fn handle_thread_error(&mut self, error: anyhow::Error, cx: &mut Context) { + self.thread_error = Some(ThreadError::from_err(error, &self.agent)); + cx.notify(); + } - let snapshot = snapshot.as_singleton().unwrap().2.clone(); - Some(snapshot.text) + fn clear_thread_error(&mut self, cx: &mut Context) { + self.thread_error = None; + cx.notify(); } fn handle_thread_event( @@ -701,22 +1295,36 @@ impl AcpThreadView { ) { match event { AcpThreadEvent::NewEntry => { - let index = thread.read(cx).entries().len() - 1; - self.sync_thread_entry_view(index, 
window, cx); - self.list_state.splice(index..index, 1); + let len = thread.read(cx).entries().len(); + let index = len - 1; + self.entry_view_state.update(cx, |view_state, cx| { + view_state.sync_entry(index, thread, window, cx); + self.list_state.splice_focusable( + index..index, + [view_state + .entry(index) + .and_then(|entry| entry.focus_handle(cx))], + ); + }); } AcpThreadEvent::EntryUpdated(index) => { - self.sync_thread_entry_view(*index, window, cx); - self.list_state.splice(*index..index + 1, 1); + self.entry_view_state.update(cx, |view_state, cx| { + view_state.sync_entry(*index, thread, window, cx) + }); } AcpThreadEvent::EntriesRemoved(range) => { - // TODO: Clean up unused diff editors and terminal views + self.entry_view_state + .update(cx, |view_state, _cx| view_state.remove(range.clone())); self.list_state.splice(range.clone(), 0); } AcpThreadEvent::ToolAuthorizationRequired => { self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx); } + AcpThreadEvent::Retry(retry) => { + self.thread_retry_status = Some(retry.clone()); + } AcpThreadEvent::Stopped => { + self.thread_retry_status.take(); let used_tools = thread.read(cx).used_tools_since_last_user_message(); self.notify_with_sound( if used_tools { @@ -729,7 +1337,16 @@ impl AcpThreadView { cx, ); } + AcpThreadEvent::Refusal => { + self.thread_retry_status.take(); + self.thread_error = Some(ThreadError::Refusal); + let model_or_agent_name = self.get_current_model_name(cx); + let notification_message = + format!("{} refused to respond to this request", model_or_agent_name); + self.notify_with_sound(¬ification_message, IconName::Warning, window, cx); + } AcpThreadEvent::Error => { + self.thread_retry_status.take(); self.notify_with_sound( "Agent stopped due to an error", IconName::Warning, @@ -737,172 +1354,302 @@ impl AcpThreadView { cx, ); } - AcpThreadEvent::ServerExited(status) => { - self.thread_state = ThreadState::ServerExited { status: *status }; + 
AcpThreadEvent::LoadError(error) => { + self.thread_retry_status.take(); + self.thread_state = ThreadState::LoadError(error.clone()); + if self.message_editor.focus_handle(cx).is_focused(window) { + self.focus_handle.focus(window) + } + } + AcpThreadEvent::TitleUpdated => { + let title = thread.read(cx).title(); + if let Some(title_editor) = self.title_editor() { + title_editor.update(cx, |editor, cx| { + if editor.text(cx) != title { + editor.set_text(title, window, cx); + } + }); + } + } + AcpThreadEvent::PromptCapabilitiesUpdated => { + self.prompt_capabilities + .set(thread.read(cx).prompt_capabilities()); + } + AcpThreadEvent::TokenUsageUpdated => {} + AcpThreadEvent::AvailableCommandsUpdated(available_commands) => { + let mut available_commands = available_commands.clone(); + + if thread + .read(cx) + .connection() + .auth_methods() + .iter() + .any(|method| method.id.0.as_ref() == "claude-login") + { + available_commands.push(acp::AvailableCommand { + name: "login".to_owned(), + description: "Authenticate".to_owned(), + input: None, + }); + available_commands.push(acp::AvailableCommand { + name: "logout".to_owned(), + description: "Authenticate".to_owned(), + input: None, + }); + } + + self.available_commands.replace(available_commands); + } + AcpThreadEvent::ModeUpdated(_mode) => { + // The connection keeps track of the mode + cx.notify(); } } cx.notify(); } - fn sync_thread_entry_view( - &mut self, - entry_ix: usize, - window: &mut Window, - cx: &mut Context, - ) { - self.sync_diff_multibuffers(entry_ix, window, cx); - self.sync_terminals(entry_ix, window, cx); - } - - fn sync_diff_multibuffers( + fn authenticate( &mut self, - entry_ix: usize, + method: acp::AuthMethodId, window: &mut Window, cx: &mut Context, ) { - let Some(multibuffers) = self.entry_diff_multibuffers(entry_ix, cx) else { + let ThreadState::Unauthenticated { + connection, + pending_auth_method, + configuration_view, + .. 
+ } = &mut self.thread_state + else { return; }; - let multibuffers = multibuffers.collect::>(); - - for multibuffer in multibuffers { - if self.diff_editors.contains_key(&multibuffer.entity_id()) { + if method.0.as_ref() == "gemini-api-key" { + let registry = LanguageModelRegistry::global(cx); + let provider = registry + .read(cx) + .provider(&language_model::GOOGLE_PROVIDER_ID) + .unwrap(); + if !provider.is_authenticated(cx) { + let this = cx.weak_entity(); + let agent = self.agent.clone(); + let connection = connection.clone(); + window.defer(cx, |window, cx| { + Self::handle_auth_required( + this, + AuthRequired { + description: Some("GEMINI_API_KEY must be set".to_owned()), + provider_id: Some(language_model::GOOGLE_PROVIDER_ID), + }, + agent, + connection, + window, + cx, + ); + }); return; } - - let editor = cx.new(|cx| { - let mut editor = Editor::new( - EditorMode::Full { - scale_ui_elements_with_buffer_font_size: false, - show_active_line_background: false, - sized_by_content: true, - }, - multibuffer.clone(), - None, - window, - cx, - ); - editor.set_show_gutter(false, cx); - editor.disable_inline_diagnostics(); - editor.disable_expand_excerpt_buttons(cx); - editor.set_show_vertical_scrollbar(false, cx); - editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); - editor.set_soft_wrap_mode(SoftWrap::None, cx); - editor.scroll_manager.set_forbid_vertical_scroll(true); - editor.set_show_indent_guides(false, cx); - editor.set_read_only(true); - editor.set_show_breakpoints(false, cx); - editor.set_show_code_actions(false, cx); - editor.set_show_git_diff_gutter(false, cx); - editor.set_expand_all_diff_hunks(cx); - editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); - editor + } else if method.0.as_ref() == "anthropic-api-key" { + let registry = LanguageModelRegistry::global(cx); + let provider = registry + .read(cx) + .provider(&language_model::ANTHROPIC_PROVIDER_ID) + .unwrap(); + let this = cx.weak_entity(); + let agent 
= self.agent.clone(); + let connection = connection.clone(); + window.defer(cx, move |window, cx| { + if !provider.is_authenticated(cx) { + Self::handle_auth_required( + this, + AuthRequired { + description: Some("ANTHROPIC_API_KEY must be set".to_owned()), + provider_id: Some(language_model::ANTHROPIC_PROVIDER_ID), + }, + agent, + connection, + window, + cx, + ); + } else { + this.update(cx, |this, cx| { + this.thread_state = Self::initial_state( + agent, + None, + this.workspace.clone(), + this.project.clone(), + window, + cx, + ) + }) + .ok(); + } }); - let entity_id = multibuffer.entity_id(); - cx.observe_release(&multibuffer, move |this, _, _| { - this.diff_editors.remove(&entity_id); - }) - .detach(); - - self.diff_editors.insert(entity_id, editor); - } - } - - fn entry_diff_multibuffers( - &self, - entry_ix: usize, - cx: &App, - ) -> Option>> { - let entry = self.thread()?.read(cx).entries().get(entry_ix)?; - Some( - entry - .diffs() - .map(|diff| diff.read(cx).multibuffer().clone()), - ) - } - - fn sync_terminals(&mut self, entry_ix: usize, window: &mut Window, cx: &mut Context) { - let Some(terminals) = self.entry_terminals(entry_ix, cx) else { return; - }; - - let terminals = terminals.collect::>(); + } else if method.0.as_ref() == "vertex-ai" + && std::env::var("GOOGLE_API_KEY").is_err() + && (std::env::var("GOOGLE_CLOUD_PROJECT").is_err() + || (std::env::var("GOOGLE_CLOUD_PROJECT").is_err())) + { + let this = cx.weak_entity(); + let agent = self.agent.clone(); + let connection = connection.clone(); + + window.defer(cx, |window, cx| { + Self::handle_auth_required( + this, + AuthRequired { + description: Some( + "GOOGLE_API_KEY must be set in the environment to use Vertex AI authentication for Gemini CLI. Please export it and restart Zed." 
+ .to_owned(), + ), + provider_id: None, + }, + agent, + connection, + window, + cx, + ) + }); + return; + } - for terminal in terminals { - if self.terminal_views.contains_key(&terminal.entity_id()) { - return; + self.thread_error.take(); + configuration_view.take(); + pending_auth_method.replace(method.clone()); + let authenticate = if (method.0.as_ref() == "claude-login" + || method.0.as_ref() == "spawn-gemini-cli") + && let Some(login) = self.login.clone() + { + if let Some(workspace) = self.workspace.upgrade() { + Self::spawn_external_agent_login(login, workspace, false, window, cx) + } else { + Task::ready(Ok(())) } + } else { + connection.authenticate(method, cx) + }; + cx.notify(); + self.auth_task = + Some(cx.spawn_in(window, { + let agent = self.agent.clone(); + async move |this, cx| { + let result = authenticate.await; + + match &result { + Ok(_) => telemetry::event!( + "Authenticate Agent Succeeded", + agent = agent.telemetry_id() + ), + Err(_) => { + telemetry::event!( + "Authenticate Agent Failed", + agent = agent.telemetry_id(), + ) + } + } - let terminal_view = cx.new(|cx| { - let mut view = TerminalView::new( - terminal.read(cx).inner().clone(), - self.workspace.clone(), - None, - self.project.downgrade(), - window, - cx, - ); - view.set_embedded_mode(Some(1000), cx); - view - }); - - let entity_id = terminal.entity_id(); - cx.observe_release(&terminal, move |this, _, _| { - this.terminal_views.remove(&entity_id); - }) - .detach(); - - self.terminal_views.insert(entity_id, terminal_view); - } - } - - fn entry_terminals( - &self, - entry_ix: usize, - cx: &App, - ) -> Option>> { - let entry = self.thread()?.read(cx).entries().get(entry_ix)?; - Some(entry.terminals().map(|terminal| terminal.clone())) + this.update_in(cx, |this, window, cx| { + if let Err(err) = result { + if let ThreadState::Unauthenticated { + pending_auth_method, + .. 
+ } = &mut this.thread_state + { + pending_auth_method.take(); + } + this.handle_thread_error(err, cx); + } else { + this.reset(window, cx); + } + this.auth_task.take() + }) + .ok(); + } + })); } - fn authenticate( - &mut self, - method: acp::AuthMethodId, + fn spawn_external_agent_login( + login: task::SpawnInTerminal, + workspace: Entity, + previous_attempt: bool, window: &mut Window, - cx: &mut Context, - ) { - let ThreadState::Unauthenticated { ref connection } = self.thread_state else { - return; + cx: &mut App, + ) -> Task> { + let Some(terminal_panel) = workspace.read(cx).panel::(cx) else { + return Task::ready(Ok(())); }; + let project = workspace.read(cx).project().clone(); + let cwd = project.read(cx).first_project_directory(cx); + let shell = project.read(cx).terminal_settings(&cwd, cx).shell.clone(); - self.last_error.take(); - let authenticate = connection.authenticate(method, cx); - self.auth_task = Some(cx.spawn_in(window, { - let project = self.project.clone(); - let agent = self.agent.clone(); - async move |this, cx| { - let result = authenticate.await; - - this.update_in(cx, |this, window, cx| { - if let Err(err) = result { - this.last_error = Some(cx.new(|cx| { - Markdown::new(format!("Error: {err}").into(), None, None, cx) - })) - } else { - this.thread_state = Self::initial_state( - agent, - this.workspace.clone(), - project.clone(), - window, - cx, - ) + window.spawn(cx, async move |cx| { + let mut task = login.clone(); + task.command = task + .command + .map(|command| anyhow::Ok(shlex::try_quote(&command)?.to_string())) + .transpose()?; + task.args = task + .args + .iter() + .map(|arg| { + Ok(shlex::try_quote(arg) + .context("Failed to quote argument")? 
+ .to_string()) + }) + .collect::>>()?; + task.full_label = task.label.clone(); + task.id = task::TaskId(format!("external-agent-{}-login", task.label)); + task.command_label = task.label.clone(); + task.use_new_terminal = true; + task.allow_concurrent_runs = true; + task.hide = task::HideStrategy::Always; + task.shell = shell; + + let terminal = terminal_panel.update_in(cx, |terminal_panel, window, cx| { + terminal_panel.spawn_task(&login, window, cx) + })?; + + let terminal = terminal.await?; + let mut exit_status = terminal + .read_with(cx, |terminal, cx| terminal.wait_for_completed_task(cx))? + .fuse(); + + let logged_in = cx + .spawn({ + let terminal = terminal.clone(); + async move |cx| { + loop { + cx.background_executor().timer(Duration::from_secs(1)).await; + let content = + terminal.update(cx, |terminal, _cx| terminal.get_content())?; + if content.contains("Login successful") + || content.contains("Type your message") + { + return anyhow::Ok(()); + } + } } - this.auth_task.take() }) - .ok(); + .fuse(); + futures::pin_mut!(logged_in); + futures::select_biased! { + result = logged_in => { + if let Err(e) = result { + log::error!("{e}"); + return Err(anyhow!("exited before logging in")); + } + } + _ = exit_status => { + if !previous_attempt && project.read_with(cx, |project, _| project.is_via_remote_server())? 
&& login.label.contains("gemini") { + return cx.update(|window, cx| Self::spawn_external_agent_login(login, workspace, true, window, cx))?.await + } + return Err(anyhow!("exited before logging in")); + } } - })); + terminal.update(cx, |terminal, _| terminal.kill_active_task())?; + Ok(()) + }) } fn authorize_tool_call( @@ -910,6 +1657,7 @@ impl AcpThreadView { tool_call_id: acp::ToolCallId, option_id: acp::PermissionOptionId, option_kind: acp::PermissionOptionKind, + window: &mut Window, cx: &mut Context, ) { let Some(thread) = self.thread() else { @@ -918,69 +1666,202 @@ impl AcpThreadView { thread.update(cx, |thread, cx| { thread.authorize_tool_call(tool_call_id, option_id, option_kind, cx); }); + if self.should_be_following { + self.workspace + .update(cx, |workspace, cx| { + workspace.follow(CollaboratorId::Agent, window, cx); + }) + .ok(); + } cx.notify(); } - fn rewind(&mut self, message_id: &UserMessageId, cx: &mut Context) { + fn restore_checkpoint(&mut self, message_id: &UserMessageId, cx: &mut Context) { let Some(thread) = self.thread() else { return; }; + thread - .update(cx, |thread, cx| thread.rewind(message_id.clone(), cx)) + .update(cx, |thread, cx| { + thread.restore_checkpoint(message_id.clone(), cx) + }) .detach_and_log_err(cx); - cx.notify(); } fn render_entry( &self, - index: usize, + entry_ix: usize, total_entries: usize, entry: &AgentThreadEntry, window: &mut Window, cx: &Context, ) -> AnyElement { let primary = match &entry { - AgentThreadEntry::UserMessage(message) => div() - .id(("user_message", index)) - .py_4() - .px_2() - .children(message.id.clone().and_then(|message_id| { - message.checkpoint.as_ref()?; + AgentThreadEntry::UserMessage(message) => { + let Some(editor) = self + .entry_view_state + .read(cx) + .entry(entry_ix) + .and_then(|entry| entry.message_editor()) + .cloned() + else { + return Empty.into_any_element(); + }; - Some( - Button::new("restore-checkpoint", "Restore Checkpoint") - .icon(IconName::Undo) - 
.icon_size(IconSize::XSmall) - .icon_position(IconPosition::Start) - .label_size(LabelSize::XSmall) - .on_click(cx.listener(move |this, _, _window, cx| { - this.rewind(&message_id, cx); - })), - ) - })) - .child( - v_flex() - .p_3() - .gap_1p5() - .rounded_lg() - .shadow_md() - .bg(cx.theme().colors().editor_background) - .border_1() - .border_color(cx.theme().colors().border) - .text_xs() - .children(message.content.markdown().map(|md| { - self.render_markdown( - md.clone(), - user_message_markdown_style(window, cx), + let editing = self.editing_message == Some(entry_ix); + let editor_focus = editor.focus_handle(cx).is_focused(window); + let focus_border = cx.theme().colors().border_focused; + + let rules_item = if entry_ix == 0 { + self.render_rules_item(cx) + } else { + None + }; + + let has_checkpoint_button = message + .checkpoint + .as_ref() + .is_some_and(|checkpoint| checkpoint.show); + + let agent_name = self.agent.name(); + + v_flex() + .id(("user_message", entry_ix)) + .map(|this| { + if entry_ix == 0 && !has_checkpoint_button && rules_item.is_none() { + this.pt(rems_from_px(18.)) + } else if rules_item.is_some() { + this.pt_3() + } else { + this.pt_2() + } + }) + .pb_3() + .px_2() + .gap_1p5() + .w_full() + .children(rules_item) + .children(message.id.clone().and_then(|message_id| { + message.checkpoint.as_ref()?.show.then(|| { + h_flex() + .px_3() + .gap_2() + .child(Divider::horizontal()) + .child( + Button::new("restore-checkpoint", "Restore Checkpoint") + .icon(IconName::Undo) + .icon_size(IconSize::XSmall) + .icon_position(IconPosition::Start) + .label_size(LabelSize::XSmall) + .icon_color(Color::Muted) + .color(Color::Muted) + .tooltip(Tooltip::text("Restores all files in the project to the content they had at this point in the conversation.")) + .on_click(cx.listener(move |this, _, _window, cx| { + this.restore_checkpoint(&message_id, cx); + })) + ) + .child(Divider::horizontal()) + }) + })) + .child( + div() + .relative() + .child( + div() + 
.py_3() + .px_2() + .rounded_md() + .shadow_md() + .bg(cx.theme().colors().editor_background) + .border_1() + .when(editing && !editor_focus, |this| this.border_dashed()) + .border_color(cx.theme().colors().border) + .map(|this|{ + if editing && editor_focus { + this.border_color(focus_border) + } else if message.id.is_some() { + this.hover(|s| s.border_color(focus_border.opacity(0.8))) + } else { + this + } + }) + .text_xs() + .child(editor.clone().into_any_element()), ) - })), - ) - .into_any(), + .when(editor_focus, |this| { + let base_container = h_flex() + .absolute() + .top_neg_3p5() + .right_3() + .gap_1() + .rounded_sm() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().editor_background) + .overflow_hidden(); + + if message.id.is_some() { + this.child( + base_container + .child( + IconButton::new("cancel", IconName::Close) + .disabled(self.is_loading_contents) + .icon_color(Color::Error) + .icon_size(IconSize::XSmall) + .on_click(cx.listener(Self::cancel_editing)) + ) + .child( + if self.is_loading_contents { + div() + .id("loading-edited-message-content") + .tooltip(Tooltip::text("Loading Added Context…")) + .child(loading_contents_spinner(IconSize::XSmall)) + .into_any_element() + } else { + IconButton::new("regenerate", IconName::Return) + .icon_color(Color::Muted) + .icon_size(IconSize::XSmall) + .tooltip(Tooltip::text( + "Editing will restart the thread from this point." 
+ )) + .on_click(cx.listener({ + let editor = editor.clone(); + move |this, _, window, cx| { + this.regenerate( + entry_ix, editor.clone(), window, cx, + ); + } + })).into_any_element() + } + ) + ) + } else { + this.child( + base_container + .border_dashed() + .child( + IconButton::new("editing_unavailable", IconName::PencilUnavailable) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .style(ButtonStyle::Transparent) + .tooltip(move |_window, cx| { + cx.new(|_| UnavailableEditingTooltip::new(agent_name.clone())) + .into() + }) + ) + ) + } + }), + ) + .into_any() + } AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) => { - let style = default_markdown_style(false, window, cx); + let is_last = entry_ix + 1 == total_entries; + + let style = default_markdown_style(false, false, window, cx); let message_body = v_flex() .w_full() - .gap_2p5() + .gap_3() .children(chunks.iter().enumerate().filter_map( |(chunk_ix, chunk)| match chunk { AssistantMessageChunk::Message { block } => { @@ -992,7 +1873,7 @@ impl AcpThreadView { AssistantMessageChunk::Thought { block } => { block.markdown().map(|md| { self.render_thinking_block( - index, + entry_ix, chunk_ix, md.clone(), window, @@ -1007,8 +1888,8 @@ impl AcpThreadView { v_flex() .px_5() - .py_1() - .when(index + 1 == total_entries, |this| this.pb_4()) + .py_1p5() + .when(is_last, |this| this.pb_4()) .w_full() .text_ui(cx) .child(message_body) @@ -1017,13 +1898,15 @@ impl AcpThreadView { AgentThreadEntry::ToolCall(tool_call) => { let has_terminals = tool_call.terminals().next().is_some(); - div().w_full().py_1p5().px_5().map(|this| { + div().w_full().map(|this| { if has_terminals { this.children(tool_call.terminals().map(|terminal| { - self.render_terminal_tool_call(terminal, tool_call, window, cx) + self.render_terminal_tool_call( + entry_ix, terminal, tool_call, window, cx, + ) })) } else { - this.child(self.render_tool_call(index, tool_call, window, cx)) + this.child(self.render_tool_call(entry_ix, 
tool_call, window, cx)) } }) } @@ -1034,12 +1917,37 @@ impl AcpThreadView { return primary; }; - let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating); - if index == total_entries - 1 && !is_generating { + let primary = if entry_ix == total_entries - 1 { v_flex() .w_full() .child(primary) - .child(self.render_thread_controls(cx)) + .child(self.render_thread_controls(&thread, cx)) + .when_some( + self.thread_feedback.comments_editor.clone(), + |this, editor| this.child(Self::render_feedback_feedback_editor(editor, cx)), + ) + .into_any_element() + } else { + primary + }; + + if let Some(editing_index) = self.editing_message.as_ref() + && *editing_index < entry_ix + { + let backdrop = div() + .id(("backdrop", entry_ix)) + .size_full() + .absolute() + .inset_0() + .bg(cx.theme().colors().panel_background) + .opacity(0.8) + .block_mouse_except_scroll() + .on_click(cx.listener(Self::cancel_editing)); + + div() + .relative() + .child(primary) + .child(backdrop) .into_any_element() } else { primary @@ -1054,7 +1962,7 @@ impl AcpThreadView { } fn tool_card_border_color(&self, cx: &Context) -> Hsla { - cx.theme().colors().border.opacity(0.6) + cx.theme().colors().border.opacity(0.8) } fn tool_name_font_size(&self) -> Rems { @@ -1071,60 +1979,72 @@ impl AcpThreadView { ) -> AnyElement { let header_id = SharedString::from(format!("thinking-block-header-{}", entry_ix)); let card_header_id = SharedString::from("inner-card-header"); + let key = (entry_ix, chunk_ix); + let is_open = self.expanded_thinking_blocks.contains(&key); + let scroll_handle = self + .entry_view_state + .read(cx) + .entry(entry_ix) + .and_then(|entry| entry.scroll_handle_for_assistant_message_chunk(chunk_ix)); + + let thinking_content = { + div() + .id(("thinking-content", chunk_ix)) + .when_some(scroll_handle, |this, scroll_handle| { + this.track_scroll(&scroll_handle) + }) + .text_ui_sm(cx) + .overflow_hidden() + .child( + self.render_markdown(chunk, 
default_markdown_style(false, false, window, cx)), + ) + }; + v_flex() + .gap_1() .child( h_flex() .id(header_id) .group(&card_header_id) .relative() .w_full() - .gap_1p5() - .opacity(0.8) - .hover(|style| style.opacity(1.)) + .pr_1() + .justify_between() .child( h_flex() - .size_4() - .justify_center() + .h(window.line_height() - px(2.)) + .gap_1p5() + .overflow_hidden() .child( - div() - .group_hover(&card_header_id, |s| s.invisible().w_0()) - .child( - Icon::new(IconName::ToolThink) - .size(IconSize::Small) - .color(Color::Muted), - ), + Icon::new(IconName::ToolThink) + .size(IconSize::Small) + .color(Color::Muted), ) .child( - h_flex() - .absolute() - .inset_0() - .invisible() - .justify_center() - .group_hover(&card_header_id, |s| s.visible()) - .child( - Disclosure::new(("expand", entry_ix), is_open) - .opened_icon(IconName::ChevronUp) - .closed_icon(IconName::ChevronRight) - .on_click(cx.listener({ - move |this, _event, _window, cx| { - if is_open { - this.expanded_thinking_blocks.remove(&key); - } else { - this.expanded_thinking_blocks.insert(key); - } - cx.notify(); - } - })), - ), + div() + .text_size(self.tool_name_font_size()) + .text_color(cx.theme().colors().text_muted) + .child("Thinking"), ), ) .child( - div() - .text_size(self.tool_name_font_size()) - .child("Thinking"), + Disclosure::new(("expand", entry_ix), is_open) + .opened_icon(IconName::ChevronUp) + .closed_icon(IconName::ChevronDown) + .visible_on_hover(&card_header_id) + .on_click(cx.listener({ + move |this, _event, _window, cx| { + if is_open { + this.expanded_thinking_blocks.remove(&key); + } else { + this.expanded_thinking_blocks.insert(key); + } + cx.notify(); + } + })), ) .on_click(cx.listener({ move |this, _event, _window, cx| { @@ -1140,82 +2060,16 @@ impl AcpThreadView { .when(is_open, |this| { this.child( div() - .relative() - .mt_1p5() - .ml(px(7.)) - .pl_4() + .ml_1p5() + .pl_3p5() .border_l_1() .border_color(self.tool_card_border_color(cx)) - .text_ui_sm(cx) - .child( - 
self.render_markdown(chunk, default_markdown_style(false, window, cx)), - ), + .child(thinking_content), ) }) .into_any_element() } - fn render_tool_call_icon( - &self, - group_name: SharedString, - entry_ix: usize, - is_collapsible: bool, - is_open: bool, - tool_call: &ToolCall, - cx: &Context, - ) -> Div { - let tool_icon = Icon::new(match tool_call.kind { - acp::ToolKind::Read => IconName::ToolRead, - acp::ToolKind::Edit => IconName::ToolPencil, - acp::ToolKind::Delete => IconName::ToolDeleteFile, - acp::ToolKind::Move => IconName::ArrowRightLeft, - acp::ToolKind::Search => IconName::ToolSearch, - acp::ToolKind::Execute => IconName::ToolTerminal, - acp::ToolKind::Think => IconName::ToolThink, - acp::ToolKind::Fetch => IconName::ToolWeb, - acp::ToolKind::Other => IconName::ToolHammer, - }) - .size(IconSize::Small) - .color(Color::Muted); - - let base_container = h_flex().size_4().justify_center(); - - if is_collapsible { - base_container - .child( - div() - .group_hover(&group_name, |s| s.invisible().w_0()) - .child(tool_icon), - ) - .child( - h_flex() - .absolute() - .inset_0() - .invisible() - .justify_center() - .group_hover(&group_name, |s| s.visible()) - .child( - Disclosure::new(("expand", entry_ix), is_open) - .opened_icon(IconName::ChevronUp) - .closed_icon(IconName::ChevronRight) - .on_click(cx.listener({ - let id = tool_call.id.clone(); - move |this: &mut Self, _, _, cx: &mut Context| { - if is_open { - this.expanded_tool_calls.remove(&id); - } else { - this.expanded_tool_calls.insert(id.clone()); - } - cx.notify(); - } - })), - ), - ) - } else { - base_container.child(tool_icon) - } - } - fn render_tool_call( &self, entry_ix: usize, @@ -1223,174 +2077,203 @@ impl AcpThreadView { window: &Window, cx: &Context, ) -> Div { - let header_id = SharedString::from(format!("outer-tool-call-header-{}", entry_ix)); + let has_location = tool_call.locations.len() == 1; let card_header_id = SharedString::from("inner-tool-call-header"); - let status_icon = match 
&tool_call.status { - ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Pending, - } - | ToolCallStatus::WaitingForConfirmation { .. } => None, - ToolCallStatus::Allowed { - status: acp::ToolCallStatus::InProgress, - .. - } => Some( - Icon::new(IconName::ArrowCircle) - .color(Color::Accent) - .size(IconSize::Small) - .with_animation( - "running", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - ) - .into_any(), - ), - ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Completed, - .. - } => None, - ToolCallStatus::Rejected - | ToolCallStatus::Canceled - | ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Failed, - .. - } => Some( - Icon::new(IconName::Close) - .color(Color::Error) - .size(IconSize::Small) - .into_any_element(), - ), + let tool_icon = if tool_call.kind == acp::ToolKind::Edit && has_location { + FileIcons::get_icon(&tool_call.locations[0].path, cx) + .map(Icon::from_path) + .unwrap_or(Icon::new(IconName::ToolPencil)) + } else { + Icon::new(match tool_call.kind { + acp::ToolKind::Read => IconName::ToolSearch, + acp::ToolKind::Edit => IconName::ToolPencil, + acp::ToolKind::Delete => IconName::ToolDeleteFile, + acp::ToolKind::Move => IconName::ArrowRightLeft, + acp::ToolKind::Search => IconName::ToolSearch, + acp::ToolKind::Execute => IconName::ToolTerminal, + acp::ToolKind::Think => IconName::ToolThink, + acp::ToolKind::Fetch => IconName::ToolWeb, + acp::ToolKind::SwitchMode => IconName::ArrowRightLeft, + acp::ToolKind::Other => IconName::ToolHammer, + }) + } + .size(IconSize::Small) + .color(Color::Muted); + + let failed_or_canceled = match &tool_call.status { + ToolCallStatus::Rejected | ToolCallStatus::Canceled | ToolCallStatus::Failed => true, + _ => false, }; let needs_confirmation = matches!( tool_call.status, ToolCallStatus::WaitingForConfirmation { .. 
} ); - let is_edit = matches!(tool_call.kind, acp::ToolKind::Edit); - let has_diff = tool_call - .content - .iter() - .any(|content| matches!(content, ToolCallContent::Diff { .. })); - let has_nonempty_diff = tool_call.content.iter().any(|content| match content { - ToolCallContent::Diff(diff) => diff.read(cx).has_revealed_range(cx), - _ => false, - }); - let use_card_layout = needs_confirmation || is_edit || has_diff; + let is_edit = + matches!(tool_call.kind, acp::ToolKind::Edit) || tool_call.diffs().next().is_some(); + let use_card_layout = needs_confirmation || is_edit; - let is_collapsible = !tool_call.content.is_empty() && !use_card_layout; + let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation; - let is_open = tool_call.content.is_empty() - || needs_confirmation - || has_nonempty_diff - || self.expanded_tool_calls.contains(&tool_call.id); + let is_open = needs_confirmation || self.expanded_tool_calls.contains(&tool_call.id); - let gradient_overlay = |color: Hsla| { + let gradient_overlay = { div() .absolute() .top_0() .right_0() .w_12() .h_full() - .bg(linear_gradient( - 90., - linear_color_stop(color, 1.), - linear_color_stop(color.opacity(0.2), 0.), - )) - }; - let gradient_color = if use_card_layout { - self.tool_card_header_bg(cx) - } else { - cx.theme().colors().panel_background + .map(|this| { + if use_card_layout { + this.bg(linear_gradient( + 90., + linear_color_stop(self.tool_card_header_bg(cx), 1.), + linear_color_stop(self.tool_card_header_bg(cx).opacity(0.2), 0.), + )) + } else { + this.bg(linear_gradient( + 90., + linear_color_stop(cx.theme().colors().panel_background, 1.), + linear_color_stop( + cx.theme().colors().panel_background.opacity(0.2), + 0., + ), + )) + } + }) }; - let tool_output_display = match &tool_call.status { - ToolCallStatus::WaitingForConfirmation { options, .. 
} => v_flex() - .w_full() - .children(tool_call.content.iter().map(|content| { - div() - .child(self.render_tool_call_content(content, tool_call, window, cx)) - .into_any_element() - })) - .child(self.render_permission_buttons( - options, - entry_ix, - tool_call.id.clone(), - tool_call.content.is_empty(), - cx, - )), - ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => v_flex() - .w_full() - .children(tool_call.content.iter().map(|content| { - div() - .child(self.render_tool_call_content(content, tool_call, window, cx)) - .into_any_element() - })), - ToolCallStatus::Rejected => v_flex().size_0(), - }; + let tool_output_display = + if is_open { + match &tool_call.status { + ToolCallStatus::WaitingForConfirmation { options, .. } => v_flex() + .w_full() + .children(tool_call.content.iter().enumerate().map( + |(content_ix, content)| { + div() + .child(self.render_tool_call_content( + entry_ix, + content, + content_ix, + tool_call, + use_card_layout, + window, + cx, + )) + .into_any_element() + }, + )) + .child(self.render_permission_buttons( + tool_call.kind, + options, + entry_ix, + tool_call.id.clone(), + window, + cx, + )) + .into_any(), + ToolCallStatus::Pending | ToolCallStatus::InProgress + if is_edit + && tool_call.content.is_empty() + && self.as_native_connection(cx).is_some() => + { + self.render_diff_loading(cx).into_any() + } + ToolCallStatus::Pending + | ToolCallStatus::InProgress + | ToolCallStatus::Completed + | ToolCallStatus::Failed + | ToolCallStatus::Canceled => v_flex() + .w_full() + .children(tool_call.content.iter().enumerate().map( + |(content_ix, content)| { + div().child(self.render_tool_call_content( + entry_ix, + content, + content_ix, + tool_call, + use_card_layout, + window, + cx, + )) + }, + )) + .into_any(), + ToolCallStatus::Rejected => Empty.into_any(), + } + .into() + } else { + None + }; v_flex() - .when(use_card_layout, |this| { - this.rounded_lg() - .border_1() - .border_color(self.tool_card_border_color(cx)) - 
.bg(cx.theme().colors().editor_background) - .overflow_hidden() + .map(|this| { + if use_card_layout { + this.my_1p5() + .rounded_md() + .border_1() + .border_color(self.tool_card_border_color(cx)) + .bg(cx.theme().colors().editor_background) + .overflow_hidden() + } else { + this.my_1() + } }) + .map(|this| { + if has_location && !use_card_layout { + this.ml_4() + } else { + this.ml_5() + } + }) + .mr_5() .child( h_flex() - .id(header_id) + .group(&card_header_id) + .relative() .w_full() .gap_1() .justify_between() - .map(|this| { - if use_card_layout { - this.pl_2() - .pr_1() - .py_1() - .rounded_t_md() - .bg(self.tool_card_header_bg(cx)) - } else { - this.opacity(0.8).hover(|style| style.opacity(1.)) - } + .when(use_card_layout, |this| { + this.p_0p5() + .rounded_t(rems_from_px(5.)) + .bg(self.tool_card_header_bg(cx)) }) .child( h_flex() - .group(&card_header_id) .relative() .w_full() + .h(window.line_height() - px(2.)) .text_size(self.tool_name_font_size()) - .child(self.render_tool_call_icon( - card_header_id, - entry_ix, - is_collapsible, - is_open, - tool_call, - cx, - )) - .child(if tool_call.locations.len() == 1 { - let name = tool_call.locations[0] - .path - .file_name() - .unwrap_or_default() - .display() - .to_string(); - + .gap_1p5() + .when(has_location || use_card_layout, |this| this.px_1()) + .when(has_location, |this| { + this.cursor(CursorStyle::PointingHand) + .rounded(rems_from_px(3.)) // Concentric border radius + .hover(|s| s.bg(cx.theme().colors().element_hover.opacity(0.5))) + }) + .overflow_hidden() + .child(tool_icon) + .child(if has_location { h_flex() .id(("open-tool-call-location", entry_ix)) .w_full() - .max_w_full() - .px_1p5() - .rounded_sm() - .overflow_x_scroll() - .opacity(0.8) - .hover(|label| { - label.opacity(1.).bg(cx - .theme() - .colors() - .element_hover - .opacity(0.5)) + .map(|this| { + if use_card_layout { + this.text_color(cx.theme().colors().text) + } else { + this.text_color(cx.theme().colors().text_muted) + } }) - 
.child(name) + .child(self.render_markdown( + tool_call.label.clone(), + MarkdownStyle { + prevent_mouse_interaction: true, + ..default_markdown_style(false, true, window, cx) + }, + )) .tooltip(Tooltip::text("Jump to File")) .on_click(cx.listener(move |this, _, window, cx| { this.open_tool_call_location(entry_ix, 0, window, cx); @@ -1398,50 +2281,59 @@ impl AcpThreadView { .into_any_element() } else { h_flex() - .id("non-card-label-container") .w_full() - .relative() - .ml_1p5() - .overflow_hidden() - .child( - h_flex() - .id("non-card-label") - .pr_8() - .w_full() - .overflow_x_scroll() - .child(self.render_markdown( - tool_call.label.clone(), - default_markdown_style( - needs_confirmation || is_edit || has_diff, - window, - cx, - ), - )), - ) - .child(gradient_overlay(gradient_color)) - .on_click(cx.listener({ - let id = tool_call.id.clone(); - move |this: &mut Self, _, _, cx: &mut Context| { - if is_open { - this.expanded_tool_calls.remove(&id); - } else { - this.expanded_tool_calls.insert(id.clone()); - } - cx.notify(); - } - })) + .child(self.render_markdown( + tool_call.label.clone(), + default_markdown_style(false, true, window, cx), + )) .into_any() - }), + }) + .when(!has_location, |this| this.child(gradient_overlay)), ) - .children(status_icon), + .when(is_collapsible || failed_or_canceled, |this| { + this.child( + h_flex() + .px_1() + .gap_px() + .when(is_collapsible, |this| { + this.child( + Disclosure::new(("expand", entry_ix), is_open) + .opened_icon(IconName::ChevronUp) + .closed_icon(IconName::ChevronDown) + .visible_on_hover(&card_header_id) + .on_click(cx.listener({ + let id = tool_call.id.clone(); + move |this: &mut Self, _, _, cx: &mut Context| { + if is_open { + this.expanded_tool_calls.remove(&id); + } else { + this.expanded_tool_calls.insert(id.clone()); + } + cx.notify(); + } + })), + ) + }) + .when(failed_or_canceled, |this| { + this.child( + Icon::new(IconName::Close) + .color(Color::Error) + .size(IconSize::Small), + ) + }), + ) + }), ) 
- .when(is_open, |this| this.child(tool_output_display)) + .children(tool_output_display) } fn render_tool_call_content( &self, + entry_ix: usize, content: &ToolCallContent, + context_ix: usize, tool_call: &ToolCall, + card_layout: bool, window: &Window, cx: &Context, ) -> AnyElement { @@ -1450,16 +2342,21 @@ impl AcpThreadView { if let Some(resource_link) = content.resource_link() { self.render_resource_link(resource_link, cx) } else if let Some(markdown) = content.markdown() { - self.render_markdown_output(markdown.clone(), tool_call.id.clone(), window, cx) + self.render_markdown_output( + markdown.clone(), + tool_call.id.clone(), + context_ix, + card_layout, + window, + cx, + ) } else { Empty.into_any_element() } } - ToolCallContent::Diff(diff) => { - self.render_diff_editor(&diff.read(cx).multibuffer(), cx) - } + ToolCallContent::Diff(diff) => self.render_diff_editor(entry_ix, diff, tool_call, cx), ToolCallContent::Terminal(terminal) => { - self.render_terminal_tool_call(terminal, tool_call, window, cx) + self.render_terminal_tool_call(entry_ix, terminal, tool_call, window, cx) } } } @@ -1468,37 +2365,46 @@ impl AcpThreadView { &self, markdown: Entity, tool_call_id: acp::ToolCallId, + context_ix: usize, + card_layout: bool, window: &Window, cx: &Context, ) -> AnyElement { - let button_id = SharedString::from(format!("tool_output-{:?}", tool_call_id.clone())); + let button_id = SharedString::from(format!("tool_output-{:?}", tool_call_id)); v_flex() .mt_1p5() - .ml(px(7.)) - .px_3p5() .gap_2() - .border_l_1() - .border_color(self.tool_card_border_color(cx)) - .text_sm() + .when(!card_layout, |this| { + this.ml(rems(0.4)) + .px_3p5() + .border_l_1() + .border_color(self.tool_card_border_color(cx)) + }) + .when(card_layout, |this| { + this.px_2().pb_2().when(context_ix > 0, |this| { + this.border_t_1() + .pt_2() + .border_color(self.tool_card_border_color(cx)) + }) + }) + .text_xs() .text_color(cx.theme().colors().text_muted) - .child(self.render_markdown(markdown, 
default_markdown_style(false, window, cx))) - .child( - Button::new(button_id, "Collapse Output") - .full_width() - .style(ButtonStyle::Outlined) - .label_size(LabelSize::Small) - .icon(IconName::ChevronUp) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click(cx.listener({ - let id = tool_call_id.clone(); - move |this: &mut Self, _, _, cx: &mut Context| { - this.expanded_tool_calls.remove(&id); - cx.notify(); - } - })), - ) + .child(self.render_markdown(markdown, default_markdown_style(false, false, window, cx))) + .when(!card_layout, |this| { + this.child( + IconButton::new(button_id, IconName::ChevronUp) + .full_width() + .style(ButtonStyle::Outlined) + .icon_color(Color::Muted) + .on_click(cx.listener({ + move |this: &mut Self, _, _, cx: &mut Context| { + this.expanded_tool_calls.remove(&tool_call_id); + cx.notify(); + } + })), + ) + }) .into_any_element() } @@ -1508,17 +2414,35 @@ impl AcpThreadView { cx: &Context, ) -> AnyElement { let uri: SharedString = resource_link.uri.clone().into(); + let is_file = resource_link.uri.strip_prefix("file://"); - let label: SharedString = if let Some(path) = resource_link.uri.strip_prefix("file://") { - path.to_string().into() + let label: SharedString = if let Some(abs_path) = is_file { + if let Some(project_path) = self + .project + .read(cx) + .project_path_for_absolute_path(&Path::new(abs_path), cx) + && let Some(worktree) = self + .project + .read(cx) + .worktree_for_id(project_path.worktree_id, cx) + { + worktree + .read(cx) + .full_path(&project_path.path) + .to_string_lossy() + .to_string() + .into() + } else { + abs_path.to_string().into() + } } else { uri.clone() }; - let button_id = SharedString::from(format!("item-{}", uri.clone())); + let button_id = SharedString::from(format!("item-{}", uri)); div() - .ml(px(7.)) + .ml(rems(0.4)) .pl_2p5() .border_l_1() .border_color(self.tool_card_border_color(cx)) @@ -1527,10 +2451,12 @@ impl AcpThreadView { Button::new(button_id, label) 
.label_size(LabelSize::Small) .color(Color::Muted) - .icon(IconName::ArrowUpRight) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) .truncate(true) + .when(is_file.is_none(), |this| { + this.icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + }) .on_click(cx.listener({ let workspace = self.workspace.clone(); move |_, _, window, cx: &mut Context| { @@ -1543,44 +2469,70 @@ impl AcpThreadView { fn render_permission_buttons( &self, + kind: acp::ToolKind, options: &[acp::PermissionOption], entry_ix: usize, tool_call_id: acp::ToolCallId, - empty_content: bool, + window: &Window, cx: &Context, ) -> Div { - h_flex() - .py_1() - .pl_2() - .pr_1() - .gap_1() - .justify_between() - .flex_wrap() - .when(!empty_content, |this| { - this.border_t_1() - .border_color(self.tool_card_border_color(cx)) + let is_first = self.thread().is_some_and(|thread| { + thread + .read(cx) + .first_tool_awaiting_confirmation() + .is_some_and(|call| call.id == tool_call_id) + }); + let mut seen_kinds: ArrayVec = ArrayVec::new(); + + div() + .p_1() + .border_t_1() + .border_color(self.tool_card_border_color(cx)) + .w_full() + .map(|this| { + if kind == acp::ToolKind::SwitchMode { + this.v_flex() + } else { + this.h_flex().justify_end().flex_wrap() + } }) - .child( - div() - .min_w(rems_from_px(145.)) - .child(LoadingLabel::new("Waiting for Confirmation").size(LabelSize::Small)), - ) - .child(h_flex().gap_0p5().children(options.iter().map(|option| { + .gap_0p5() + .children(options.iter().map(move |option| { let option_id = SharedString::from(option.id.0.clone()); Button::new((option_id, entry_ix), option.name.clone()) - .map(|this| match option.kind { - acp::PermissionOptionKind::AllowOnce => { - this.icon(IconName::Check).icon_color(Color::Success) - } - acp::PermissionOptionKind::AllowAlways => { - this.icon(IconName::CheckDouble).icon_color(Color::Success) - } - acp::PermissionOptionKind::RejectOnce => { - 
this.icon(IconName::Close).icon_color(Color::Error) - } - acp::PermissionOptionKind::RejectAlways => { - this.icon(IconName::Close).icon_color(Color::Error) + .map(|this| { + let (this, action) = match option.kind { + acp::PermissionOptionKind::AllowOnce => ( + this.icon(IconName::Check).icon_color(Color::Success), + Some(&AllowOnce as &dyn Action), + ), + acp::PermissionOptionKind::AllowAlways => ( + this.icon(IconName::CheckDouble).icon_color(Color::Success), + Some(&AllowAlways as &dyn Action), + ), + acp::PermissionOptionKind::RejectOnce => ( + this.icon(IconName::Close).icon_color(Color::Error), + Some(&RejectOnce as &dyn Action), + ), + acp::PermissionOptionKind::RejectAlways => { + (this.icon(IconName::Close).icon_color(Color::Error), None) + } + }; + + let Some(action) = action else { + return this; + }; + + if !is_first || seen_kinds.contains(&option.kind) { + return this; } + + seen_kinds.push(option.kind); + + this.key_binding( + KeyBinding::for_action_in(action, &self.focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(10.))), + ) }) .icon_position(IconPosition::Start) .icon_size(IconSize::XSmall) @@ -1589,30 +2541,84 @@ impl AcpThreadView { let tool_call_id = tool_call_id.clone(); let option_id = option.id.clone(); let option_kind = option.kind; - move |this, _, _, cx| { + move |this, _, window, cx| { this.authorize_tool_call( tool_call_id.clone(), option_id.clone(), option_kind, + window, cx, ); } })) - }))) + })) + } + + fn render_diff_loading(&self, cx: &Context) -> AnyElement { + let bar = |n: u64, width_class: &str| { + let bg_color = cx.theme().colors().element_active; + let base = h_flex().h_1().rounded_full(); + + let modified = match width_class { + "w_4_5" => base.w_3_4(), + "w_1_4" => base.w_1_4(), + "w_2_4" => base.w_2_4(), + "w_3_5" => base.w_3_5(), + "w_2_5" => base.w_2_5(), + _ => base.w_1_2(), + }; + + modified.with_animation( + ElementId::Integer(n), + Animation::new(Duration::from_secs(2)).repeat(), + move |tab, delta| { + let 
delta = (delta - 0.15 * n as f32) / 0.7; + let delta = 1.0 - (0.5 - delta).abs() * 2.; + let delta = ease_in_out(delta.clamp(0., 1.)); + let delta = 0.1 + 0.9 * delta; + + tab.bg(bg_color.opacity(delta)) + }, + ) + }; + + v_flex() + .p_3() + .gap_1() + .rounded_b_md() + .bg(cx.theme().colors().editor_background) + .child(bar(0, "w_4_5")) + .child(bar(1, "w_1_4")) + .child(bar(2, "w_2_4")) + .child(bar(3, "w_3_5")) + .child(bar(4, "w_2_5")) + .into_any_element() } fn render_diff_editor( &self, - multibuffer: &Entity, + entry_ix: usize, + diff: &Entity, + tool_call: &ToolCall, cx: &Context, ) -> AnyElement { + let tool_progress = matches!( + &tool_call.status, + ToolCallStatus::InProgress | ToolCallStatus::Pending + ); + v_flex() .h_full() .border_t_1() .border_color(self.tool_card_border_color(cx)) .child( - if let Some(editor) = self.diff_editors.get(&multibuffer.entity_id()) { - editor.clone().into_any_element() + if let Some(entry) = self.entry_view_state.read(cx).entry(entry_ix) + && let Some(editor) = entry.editor_for_diff(diff) + && diff.read(cx).has_revealed_range(cx) + { + editor.into_any_element() + } else if tool_progress && self.as_native_connection(cx).is_some() { + self.render_diff_loading(cx) } else { Empty.into_any() }, @@ -1622,6 +2628,7 @@ impl AcpThreadView { fn render_terminal_tool_call( &self, + entry_ix: usize, terminal: &Entity, tool_call: &ToolCall, window: &Window, @@ -1634,17 +2641,13 @@ impl AcpThreadView { let tool_failed = matches!( &tool_call.status, - ToolCallStatus::Rejected - | ToolCallStatus::Canceled - | ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Failed, - .. 
- } + ToolCallStatus::Rejected | ToolCallStatus::Canceled | ToolCallStatus::Failed ); let output = terminal_data.output(); let command_finished = output.is_some(); - let truncated_output = output.is_some_and(|output| output.was_content_truncated); + let truncated_output = + output.is_some_and(|output| output.original_content_len > output.content.len()); let output_line_count = output.map(|output| output.content_line_count).unwrap_or(0); let command_failed = command_finished @@ -1656,6 +2659,12 @@ impl AcpThreadView { started_at.elapsed() }; + let header_id = + SharedString::from(format!("terminal-tool-header-{}", terminal.entity_id())); + let header_group = SharedString::from(format!( + "terminal-tool-header-group-{}", + terminal.entity_id() + )); let header_bg = cx .theme() .colors() @@ -1668,11 +2677,10 @@ impl AcpThreadView { .map(|path| format!("{}", path.display())) .unwrap_or_else(|| "current directory".to_string()); + let is_expanded = self.expanded_tool_calls.contains(&tool_call.id); + let header = h_flex() - .id(SharedString::from(format!( - "terminal-tool-header-{}", - terminal.entity_id() - ))) + .id(header_id) .flex_none() .gap_1() .justify_between() @@ -1727,43 +2735,20 @@ impl AcpThreadView { Icon::new(IconName::ArrowCircle) .size(IconSize::XSmall) .color(Color::Info) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| { - icon.transform(Transformation::rotate(percentage(delta))) - }, - ), + .with_rotate_animation(2) ) }) - .when(tool_failed || command_failed, |header| { - header.child( - div() - .id(("terminal-tool-error-code-indicator", terminal.entity_id())) - .child( - Icon::new(IconName::Close) - .size(IconSize::Small) - .color(Color::Error), - ) - .when_some(output.and_then(|o| o.exit_status), |this, status| { - this.tooltip(Tooltip::text(format!( - "Exited with code {}", - status.code().unwrap_or(-1), - ))) - }), - ) - }) .when(truncated_output, |header| { let tooltip = if let Some(output) = 
output { if output_line_count + 10 > terminal::MAX_SCROLL_HISTORY_LINES { - "Output exceeded terminal max lines and was \ - truncated, the model received the first 16 KB." - .to_string() + format!("Output exceeded terminal max lines and was \ + truncated, the model received the first {}.", format_file_size(output.content.len() as u64, true)) } else { format!( - "Output is {} long—to avoid unexpected token usage, \ - only 16 KB was sent back to the model.", + "Output is {} long, and to avoid unexpected token usage, \ + only {} was sent back to the agent.", format_file_size(output.original_content_len as u64, true), + format_file_size(output.content.len() as u64, true) ) } } else { @@ -1795,36 +2780,67 @@ impl AcpThreadView { .size(LabelSize::XSmall), ) }) + .when(tool_failed || command_failed, |header| { + header.child( + div() + .id(("terminal-tool-error-code-indicator", terminal.entity_id())) + .child( + Icon::new(IconName::Close) + .size(IconSize::Small) + .color(Color::Error), + ) + .when_some(output.and_then(|o| o.exit_status), |this, status| { + this.tooltip(Tooltip::text(format!( + "Exited with code {}", + status.code().unwrap_or(-1), + ))) + }), + ) + }) .child( Disclosure::new( SharedString::from(format!( "terminal-tool-disclosure-{}", terminal.entity_id() )), - self.terminal_expanded, + is_expanded, ) .opened_icon(IconName::ChevronUp) .closed_icon(IconName::ChevronDown) - .on_click(cx.listener(move |this, _event, _window, _cx| { - this.terminal_expanded = !this.terminal_expanded; + .visible_on_hover(&header_group) + .on_click(cx.listener({ + let id = tool_call.id.clone(); + move |this, _event, _window, _cx| { + if is_expanded { + this.expanded_tool_calls.remove(&id); + } else { + this.expanded_tool_calls.insert(id.clone()); + } + } })), ); - let show_output = - self.terminal_expanded && self.terminal_views.contains_key(&terminal.entity_id()); + let terminal_view = self + .entry_view_state + .read(cx) + .entry(entry_ix) + .and_then(|entry| 
entry.terminal(terminal)); + let show_output = is_expanded && terminal_view.is_some(); v_flex() - .mb_2() + .my_1p5() + .mx_5() .border_1() .when(tool_failed || command_failed, |card| card.border_dashed()) .border_color(border_color) - .rounded_lg() + .rounded_md() .overflow_hidden() .child( v_flex() + .group(&header_group) .py_1p5() - .pl_2() .pr_1p5() + .pl_2() .gap_0p5() .bg(header_bg) .text_xs() @@ -1844,8 +2860,6 @@ impl AcpThreadView { ), ) .when(show_output, |this| { - let terminal_view = self.terminal_views.get(&terminal.entity_id()).unwrap(); - this.child( div() .pt_2() @@ -1855,188 +2869,443 @@ impl AcpThreadView { .bg(cx.theme().colors().editor_background) .rounded_b_md() .text_ui_sm(cx) - .child(terminal_view.clone()), + .h_full() + .children(terminal_view.map(|terminal_view| { + if terminal_view + .read(cx) + .content_mode(window, cx) + .is_scrollable() + { + div().h_72().child(terminal_view).into_any_element() + } else { + terminal_view.into_any_element() + } + })), ) }) .into_any() } - fn render_agent_logo(&self) -> AnyElement { - Icon::new(self.agent.logo()) - .color(Color::Muted) - .size(IconSize::XLarge) - .into_any_element() - } + fn render_rules_item(&self, cx: &Context) -> Option { + let project_context = self + .as_native_thread(cx)? 
+ .read(cx) + .project_context() + .read(cx); - fn render_error_agent_logo(&self) -> AnyElement { - let logo = Icon::new(self.agent.logo()) - .color(Color::Muted) - .size(IconSize::XLarge) - .into_any_element(); + let user_rules_text = if project_context.user_rules.is_empty() { + None + } else if project_context.user_rules.len() == 1 { + let user_rules = &project_context.user_rules[0]; - h_flex() - .relative() - .justify_center() - .child(div().opacity(0.3).child(logo)) - .child( - h_flex().absolute().right_1().bottom_0().child( - Icon::new(IconName::XCircle) - .color(Color::Error) - .size(IconSize::Small), - ), - ) - .into_any_element() - } + match user_rules.title.as_ref() { + Some(title) => Some(format!("Using \"{title}\" user rule")), + None => Some("Using user rule".into()), + } + } else { + Some(format!( + "Using {} user rules", + project_context.user_rules.len() + )) + }; - fn render_empty_state(&self, cx: &App) -> AnyElement { - let loading = matches!(&self.thread_state, ThreadState::Loading { .. 
}); + let first_user_rules_id = project_context + .user_rules + .first() + .map(|user_rules| user_rules.uuid.0); - v_flex() - .size_full() - .items_center() - .justify_center() - .child(if loading { - h_flex() - .justify_center() - .child(self.render_agent_logo()) - .with_animation( - "pulsating_icon", - Animation::new(Duration::from_secs(2)) - .repeat() - .with_easing(pulsating_between(0.4, 1.0)), - |icon, delta| icon.opacity(delta), - ) - .into_any() - } else { - self.render_agent_logo().into_any_element() - }) - .child(h_flex().mt_4().mb_1().justify_center().child(if loading { - div() - .child(LoadingLabel::new("").size(LabelSize::Large)) - .into_any_element() - } else { - Headline::new(self.agent.empty_state_headline()) - .size(HeadlineSize::Medium) - .into_any_element() - })) - .child( - div() - .max_w_1_2() - .text_sm() - .text_center() - .map(|this| { - if loading { - this.invisible() - } else { - this.text_color(cx.theme().colors().text_muted) - } - }) - .child(self.agent.empty_state_message()), - ) - .into_any() - } + let rules_files = project_context + .worktrees + .iter() + .filter_map(|worktree| worktree.rules_file.as_ref()) + .collect::>(); + + let rules_file_text = match rules_files.as_slice() { + &[] => None, + &[rules_file] => Some(format!( + "Using project {:?} file", + rules_file.path_in_worktree + )), + rules_files => Some(format!("Using {} project rules files", rules_files.len())), + }; - fn render_pending_auth_state(&self) -> AnyElement { - v_flex() - .items_center() - .justify_center() - .child(self.render_error_agent_logo()) - .child( - h_flex() - .mt_4() - .mb_1() - .justify_center() - .child(Headline::new("Not Authenticated").size(HeadlineSize::Medium)), - ) - .into_any() - } + if user_rules_text.is_none() && rules_file_text.is_none() { + return None; + } + + let has_both = user_rules_text.is_some() && rules_file_text.is_some(); + + Some( + h_flex() + .px_2p5() + .child( + Icon::new(IconName::Attach) + .size(IconSize::XSmall) + 
.color(Color::Disabled), + ) + .when_some(user_rules_text, |parent, user_rules_text| { + parent.child( + h_flex() + .id("user-rules") + .ml_1() + .mr_1p5() + .child( + Label::new(user_rules_text) + .size(LabelSize::XSmall) + .color(Color::Muted) + .truncate(), + ) + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .tooltip(Tooltip::text("View User Rules")) + .on_click(move |_event, window, cx| { + window.dispatch_action( + Box::new(OpenRulesLibrary { + prompt_to_select: first_user_rules_id, + }), + cx, + ) + }), + ) + }) + .when(has_both, |this| { + this.child( + Label::new("•") + .size(LabelSize::XSmall) + .color(Color::Disabled), + ) + }) + .when_some(rules_file_text, |parent, rules_file_text| { + parent.child( + h_flex() + .id("project-rules") + .ml_1p5() + .child( + Label::new(rules_file_text) + .size(LabelSize::XSmall) + .color(Color::Muted), + ) + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .tooltip(Tooltip::text("View Project Rules")) + .on_click(cx.listener(Self::handle_open_rules)), + ) + }) + .into_any(), + ) + } + + fn render_empty_state_section_header( + &self, + label: impl Into, + action_slot: Option, + cx: &mut Context, + ) -> impl IntoElement { + div().pl_1().pr_1p5().child( + h_flex() + .mt_2() + .pl_1p5() + .pb_1() + .w_full() + .justify_between() + .border_b_1() + .border_color(cx.theme().colors().border_variant) + .child( + Label::new(label.into()) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .children(action_slot), + ) + } + + fn render_recent_history(&self, window: &mut Window, cx: &mut Context) -> AnyElement { + let render_history = self + .agent + .clone() + .downcast::() + .is_some() + && self + .history_store + .update(cx, |history_store, cx| !history_store.is_empty(cx)); - fn render_server_exited(&self, status: ExitStatus, _cx: &Context) -> AnyElement { v_flex() - .items_center() - .justify_center() - .child(self.render_error_agent_logo()) - .child( - v_flex() - .mt_4() - .mb_2() - .gap_0p5() - .text_center() - 
.items_center() - .child(Headline::new("Server exited unexpectedly").size(HeadlineSize::Medium)) - .child( - Label::new(format!("Exit status: {}", status.code().unwrap_or(-127))) - .size(LabelSize::Small) - .color(Color::Muted), - ), - ) - .into_any_element() + .size_full() + .when(render_history, |this| { + let recent_history: Vec<_> = self.history_store.update(cx, |history_store, _| { + history_store.entries().take(3).collect() + }); + this.justify_end().child( + v_flex() + .child( + self.render_empty_state_section_header( + "Recent", + Some( + Button::new("view-history", "View All") + .style(ButtonStyle::Subtle) + .label_size(LabelSize::Small) + .key_binding( + KeyBinding::for_action_in( + &OpenHistory, + &self.focus_handle(cx), + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(move |_event, window, cx| { + window.dispatch_action(OpenHistory.boxed_clone(), cx); + }) + .into_any_element(), + ), + cx, + ), + ) + .child( + v_flex().p_1().pr_1p5().gap_1().children( + recent_history + .into_iter() + .enumerate() + .map(|(index, entry)| { + // TODO: Add keyboard navigation. 
+ let is_hovered = + self.hovered_recent_history_item == Some(index); + crate::acp::thread_history::AcpHistoryEntryElement::new( + entry, + cx.entity().downgrade(), + ) + .hovered(is_hovered) + .on_hover(cx.listener( + move |this, is_hovered, _window, cx| { + if *is_hovered { + this.hovered_recent_history_item = Some(index); + } else if this.hovered_recent_history_item + == Some(index) + { + this.hovered_recent_history_item = None; + } + cx.notify(); + }, + )) + .into_any_element() + }), + ), + ), + ) + }) + .into_any() } - fn render_load_error(&self, e: &LoadError, cx: &Context) -> AnyElement { - let mut container = v_flex() - .items_center() - .justify_center() - .child(self.render_error_agent_logo()) - .child( - v_flex() - .mt_4() - .mb_2() - .gap_0p5() - .text_center() - .items_center() - .child(Headline::new("Failed to launch").size(HeadlineSize::Medium)) - .child( - Label::new(e.to_string()) - .size(LabelSize::Small) - .color(Color::Muted), - ), - ); + fn render_auth_required_state( + &self, + connection: &Rc, + description: Option<&Entity>, + configuration_view: Option<&AnyView>, + pending_auth_method: Option<&acp::AuthMethodId>, + window: &mut Window, + cx: &Context, + ) -> Div { + let show_description = + configuration_view.is_none() && description.is_none() && pending_auth_method.is_none(); - if let LoadError::Unsupported { - upgrade_message, - upgrade_command, - .. 
- } = &e - { - let upgrade_message = upgrade_message.clone(); - let upgrade_command = upgrade_command.clone(); - container = container.child(Button::new("upgrade", upgrade_message).on_click( - cx.listener(move |this, _, window, cx| { - this.workspace - .update(cx, |workspace, cx| { - let project = workspace.project().read(cx); - let cwd = project.first_project_directory(cx); - let shell = project.terminal_settings(&cwd, cx).shell.clone(); - let spawn_in_terminal = task::SpawnInTerminal { - id: task::TaskId("install".to_string()), - full_label: upgrade_command.clone(), - label: upgrade_command.clone(), - command: Some(upgrade_command.clone()), - args: Vec::new(), - command_label: upgrade_command.clone(), - cwd, - env: Default::default(), - use_new_terminal: true, - allow_concurrent_runs: true, - reveal: Default::default(), - reveal_target: Default::default(), - hide: Default::default(), - shell, - show_summary: true, - show_command: true, - show_rerun: false, - }; - workspace - .spawn_in_terminal(spawn_in_terminal, window, cx) - .detach(); - }) - .ok(); + let auth_methods = connection.auth_methods(); + + v_flex().flex_1().size_full().justify_end().child( + v_flex() + .p_2() + .pr_3() + .w_full() + .gap_1() + .border_t_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().status().warning.opacity(0.04)) + .child( + h_flex() + .gap_1p5() + .child( + Icon::new(IconName::Warning) + .color(Color::Warning) + .size(IconSize::Small), + ) + .child(Label::new("Authentication Required").size(LabelSize::Small)), + ) + .children(description.map(|desc| { + div().text_ui(cx).child(self.render_markdown( + desc.clone(), + default_markdown_style(false, false, window, cx), + )) + })) + .children( + configuration_view + .cloned() + .map(|view| div().w_full().child(view)), + ) + .when(show_description, |el| { + el.child( + Label::new(format!( + "You are not currently authenticated with {}.{}", + self.agent.name(), + if auth_methods.len() > 1 { + " Please choose one of the 
following options:" + } else { + "" + } + )) + .size(LabelSize::Small) + .color(Color::Muted) + .mb_1() + .ml_5(), + ) + }) + .when_some(pending_auth_method, |el, _| { + el.child( + h_flex() + .py_4() + .w_full() + .justify_center() + .gap_1() + .child( + Icon::new(IconName::ArrowCircle) + .size(IconSize::Small) + .color(Color::Muted) + .with_rotate_animation(2), + ) + .child(Label::new("Authenticating…").size(LabelSize::Small)), + ) + }) + .when(!auth_methods.is_empty(), |this| { + this.child( + h_flex() + .justify_end() + .flex_wrap() + .gap_1() + .when(!show_description, |this| { + this.border_t_1() + .mt_1() + .pt_2() + .border_color(cx.theme().colors().border.opacity(0.8)) + }) + .children(connection.auth_methods().iter().enumerate().rev().map( + |(ix, method)| { + let (method_id, name) = if self + .project + .read(cx) + .is_via_remote_server() + && method.id.0.as_ref() == "oauth-personal" + && method.name == "Log in with Google" + { + ("spawn-gemini-cli".into(), "Log in with Gemini CLI".into()) + } else { + (method.id.0.clone(), method.name.clone()) + }; + + Button::new(SharedString::from(method_id.clone()), name) + .when(ix == 0, |el| { + el.style(ButtonStyle::Tinted(ui::TintColor::Warning)) + }) + .label_size(LabelSize::Small) + .on_click({ + cx.listener(move |this, _, window, cx| { + telemetry::event!( + "Authenticate Agent Started", + agent = this.agent.telemetry_id(), + method = method_id + ); + + this.authenticate( + acp::AuthMethodId(method_id.clone()), + window, + cx, + ) + }) + }) + }, + )), + ) }), - )); - } + ) + } + + fn render_load_error( + &self, + e: &LoadError, + window: &mut Window, + cx: &mut Context, + ) -> AnyElement { + let (title, message, action_slot): (_, SharedString, _) = match e { + LoadError::Unsupported { + command: path, + current_version, + minimum_version, + } => { + return self.render_unsupported(path, current_version, minimum_version, window, cx); + } + LoadError::FailedToInstall(msg) => ( + "Failed to Install", + msg.into(), 
+ Some(self.create_copy_button(msg.to_string()).into_any_element()), + ), + LoadError::Exited { status } => ( + "Failed to Launch", + format!("Server exited with status {status}").into(), + None, + ), + LoadError::Other(msg) => ( + "Failed to Launch", + msg.into(), + Some(self.create_copy_button(msg.to_string()).into_any_element()), + ), + }; + + Callout::new() + .severity(Severity::Error) + .icon(IconName::XCircleFilled) + .title(title) + .description(message) + .actions_slot(div().children(action_slot)) + .into_any_element() + } + + fn render_unsupported( + &self, + path: &SharedString, + version: &SharedString, + minimum_version: &SharedString, + _window: &mut Window, + cx: &mut Context, + ) -> AnyElement { + let (heading_label, description_label) = ( + format!("Upgrade {} to work with Zed", self.agent.name()), + if version.is_empty() { + format!( + "Currently using {}, which does not report a valid --version", + path, + ) + } else { + format!( + "Currently using {}, which is only version {} (need at least {minimum_version})", + path, version + ) + }, + ); - container.into_any() + v_flex() + .w_full() + .p_3p5() + .gap_2p5() + .border_t_1() + .border_color(cx.theme().colors().border) + .bg(linear_gradient( + 180., + linear_color_stop(cx.theme().colors().editor_background.opacity(0.4), 4.), + linear_color_stop(cx.theme().status().info_background.opacity(0.), 0.), + )) + .child( + v_flex().gap_0p5().child(Label::new(heading_label)).child( + Label::new(description_label) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + .into_any_element() } fn render_activity_bar( @@ -2058,7 +3327,12 @@ impl AcpThreadView { let active_color = cx.theme().colors().element_selected; let bg_edit_files_disclosure = editor_bg_color.blend(active_color.opacity(0.3)); - let pending_edits = thread.has_pending_edit_tool_calls(); + // Temporarily always enable ACP edit controls. 
This is temporary, to lessen the + // impact of a nasty bug that causes them to sometimes be disabled when they shouldn't + // be, which blocks you from being able to accept or reject edits. This switches the + // bug to be that sometimes it's enabled when it shouldn't be, which at least doesn't + // block you from using the panel. + let pending_edits = false; v_flex() .mt_1() @@ -2085,7 +3359,6 @@ impl AcpThreadView { }) .when(!changed_buffers.is_empty(), |this| { this.child(self.render_edits_summary( - action_log, &changed_buffers, self.edits_expanded, pending_edits, @@ -2212,13 +3485,7 @@ impl AcpThreadView { acp::PlanEntryStatus::InProgress => Icon::new(IconName::TodoProgress) .size(IconSize::Small) .color(Color::Accent) - .with_animation( - "running", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| { - icon.transform(Transformation::rotate(percentage(delta))) - }, - ) + .with_rotate_animation(2) .into_any_element(), acp::PlanEntryStatus::Completed => Icon::new(IconName::TodoComplete) .size(IconSize::Small) @@ -2237,7 +3504,6 @@ impl AcpThreadView { fn render_edits_summary( &self, - action_log: &Entity, changed_buffers: &BTreeMap, Entity>, expanded: bool, pending_edits: bool, @@ -2251,13 +3517,13 @@ impl AcpThreadView { h_flex() .p_1() .justify_between() + .flex_wrap() .when(expanded, |this| { this.border_b_1().border_color(cx.theme().colors().border) }) .child( h_flex() .id("edits-container") - .w_full() .gap_1() .child(Disclosure::new("edits-disclosure", expanded)) .map(|this| { @@ -2348,14 +3614,9 @@ impl AcpThreadView { ) .map(|kb| kb.size(rems_from_px(10.))), ) - .on_click({ - let action_log = action_log.clone(); - cx.listener(move |_, _, _, cx| { - action_log.update(cx, |action_log, cx| { - action_log.reject_all_edits(cx).detach(); - }) - }) - }), + .on_click(cx.listener(move |this, _, window, cx| { + this.reject_all(&RejectAll, window, cx); + })), ) .child( Button::new("keep-all-changes", "Keep All") @@ -2368,14 +3629,9 @@ impl 
AcpThreadView { KeyBinding::for_action_in(&KeepAll, &focus_handle, window, cx) .map(|kb| kb.size(rems_from_px(10.))), ) - .on_click({ - let action_log = action_log.clone(); - cx.listener(move |_, _, _, cx| { - action_log.update(cx, |action_log, cx| { - action_log.keep_all_edits(cx); - }) - }) - }), + .on_click(cx.listener(move |this, _, window, cx| { + this.keep_all(&KeepAll, window, cx); + })), ), ) } @@ -2389,7 +3645,7 @@ impl AcpThreadView { ) -> Div { let editor_bg_color = cx.theme().colors().editor_background; - v_flex().children(changed_buffers.into_iter().enumerate().flat_map( + v_flex().children(changed_buffers.iter().enumerate().flat_map( |(index, (buffer, _diff))| { let file = buffer.read(cx).file()?; let path = file.path(); @@ -2415,7 +3671,7 @@ impl AcpThreadView { .buffer_font(cx) }); - let file_icon = FileIcons::get_icon(&path, cx) + let file_icon = FileIcons::get_icon(path, cx) .map(Icon::from_path) .map(|icon| icon.color(Color::Muted).size(IconSize::Small)) .unwrap_or_else(|| { @@ -2433,7 +3689,6 @@ impl AcpThreadView { let element = h_flex() .group("edited-code") .id(("file-container", index)) - .relative() .py_1() .pl_2() .pr_1() @@ -2445,6 +3700,7 @@ impl AcpThreadView { }) .child( h_flex() + .relative() .id(("file-name", index)) .pr_8() .gap_1p5() @@ -2452,6 +3708,16 @@ impl AcpThreadView { .overflow_x_scroll() .child(file_icon) .child(h_flex().gap_0p5().children(file_name).children(file_path)) + .child( + div() + .absolute() + .h_full() + .w_12() + .top_0() + .bottom_0() + .right_0() + .bg(overlay_gradient), + ) .on_click({ let buffer = buffer.clone(); cx.listener(move |this, _, window, cx| { @@ -2512,17 +3778,6 @@ impl AcpThreadView { } }), ), - ) - .child( - div() - .id("gradient-overlay") - .absolute() - .h_full() - .w_12() - .top_0() - .bottom_0() - .right(px(152.)) - .bg(overlay_gradient), ); Some(element) @@ -2539,8 +3794,35 @@ impl AcpThreadView { (IconName::Maximize, "Expand Message Editor") }; + let backdrop = div() + .size_full() + 
.absolute() + .inset_0() + .bg(cx.theme().colors().panel_background) + .opacity(0.8) + .block_mouse_except_scroll(); + + let enable_editor = match self.thread_state { + ThreadState::Loading { .. } | ThreadState::Ready { .. } => true, + ThreadState::Unauthenticated { .. } | ThreadState::LoadError(..) => false, + }; + v_flex() .on_action(cx.listener(Self::expand_message_editor)) + .on_action(cx.listener(|this, _: &ToggleProfileSelector, window, cx| { + if let Some(profile_selector) = this.profile_selector.as_ref() { + profile_selector.read(cx).menu_handle().toggle(window, cx); + } else if let Some(mode_selector) = this.mode_selector() { + mode_selector.read(cx).menu_handle().toggle(window, cx); + } + })) + .on_action(cx.listener(|this, _: &CycleModeSelector, window, cx| { + if let Some(mode_selector) = this.mode_selector() { + mode_selector.update(cx, |mode_selector, cx| { + mode_selector.cycle_mode(window, cx); + }); + } + })) .on_action(cx.listener(|this, _: &ToggleModelSelector, window, cx| { if let Some(model_selector) = this.model_selector.as_ref() { model_selector @@ -2561,34 +3843,7 @@ impl AcpThreadView { .size_full() .pt_1() .pr_2p5() - .child(div().flex_1().child({ - let settings = ThemeSettings::get_global(cx); - let font_size = TextSize::Small - .rems(cx) - .to_pixels(settings.agent_font_size(cx)); - let line_height = settings.buffer_line_height.value() * font_size; - - let text_style = TextStyle { - color: cx.theme().colors().text, - font_family: settings.buffer_font.family.clone(), - font_fallbacks: settings.buffer_font.fallbacks.clone(), - font_features: settings.buffer_font.features.clone(), - font_size: font_size.into(), - line_height: line_height.into(), - ..Default::default() - }; - - EditorElement::new( - &self.message_editor, - EditorStyle { - background: editor_bg_color, - local_player: cx.theme().players().local(), - text: text_style, - syntax: cx.theme().syntax().clone(), - ..Default::default() - }, - ) - })) + 
.child(self.message_editor.clone()) .child( h_flex() .absolute() @@ -2601,7 +3856,6 @@ impl AcpThreadView { .icon_size(IconSize::Small) .icon_color(Color::Muted) .tooltip({ - let focus_handle = focus_handle.clone(); move |window, cx| { Tooltip::for_action_in( expand_tooltip, @@ -2621,85 +3875,316 @@ impl AcpThreadView { .child( h_flex() .flex_none() + .flex_wrap() .justify_between() - .child(self.render_follow_toggle(cx)) + .child( + h_flex() + .child(self.render_follow_toggle(cx)) + .children(self.render_burn_mode_toggle(cx)), + ) .child( h_flex() .gap_1() + .children(self.render_token_usage(cx)) + .children(self.profile_selector.clone()) + .children(self.mode_selector().cloned()) .children(self.model_selector.clone()) .child(self.render_send_button(cx)), ), ) + .when(!enable_editor, |this| this.child(backdrop)) .into_any() } - fn render_send_button(&self, cx: &mut Context) -> AnyElement { - if self.thread().map_or(true, |thread| { - thread.read(cx).status() == ThreadStatus::Idle - }) { - let is_editor_empty = self.message_editor.read(cx).is_empty(cx); - IconButton::new("send-message", IconName::Send) - .icon_color(Color::Accent) - .style(ButtonStyle::Filled) - .disabled(self.thread().is_none() || is_editor_empty) - .when(!is_editor_empty, |button| { - button.tooltip(move |window, cx| Tooltip::for_action("Send", &Chat, window, cx)) - }) - .when(is_editor_empty, |button| { - button.tooltip(Tooltip::text("Type a message to submit")) - }) - .on_click(cx.listener(|this, _, window, cx| { - this.chat(&Chat, window, cx); - })) - .into_any_element() - } else { - IconButton::new("stop-generation", IconName::Stop) - .icon_color(Color::Error) - .style(ButtonStyle::Tinted(ui::TintColor::Error)) - .tooltip(move |window, cx| { - Tooltip::for_action("Stop Generation", &editor::actions::Cancel, window, cx) - }) - .on_click(cx.listener(|this, _event, _, cx| this.cancel(cx))) - .into_any_element() - } + pub(crate) fn as_native_connection( + &self, + cx: &App, + ) -> Option> { + let 
acp_thread = self.thread()?.read(cx); + acp_thread.connection().clone().downcast() } - fn render_follow_toggle(&self, cx: &mut Context) -> impl IntoElement { - let following = self - .workspace - .read_with(cx, |workspace, _| { - workspace.is_being_followed(CollaboratorId::Agent) - }) - .unwrap_or(false); + pub(crate) fn as_native_thread(&self, cx: &App) -> Option> { + let acp_thread = self.thread()?.read(cx); + self.as_native_connection(cx)? + .thread(acp_thread.session_id(), cx) + } - IconButton::new("follow-agent", IconName::Crosshair) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .toggle_state(following) - .selected_icon_color(Some(Color::Custom(cx.theme().players().agent().cursor))) - .tooltip(move |window, cx| { - if following { - Tooltip::for_action("Stop Following Agent", &Follow, window, cx) - } else { - Tooltip::with_meta( - "Follow Agent", - Some(&Follow), - "Track the agent's location as it reads and edits files.", - window, - cx, - ) - } + fn is_using_zed_ai_models(&self, cx: &App) -> bool { + self.as_native_thread(cx) + .and_then(|thread| thread.read(cx).model()) + .is_some_and(|model| model.provider_id() == language_model::ZED_CLOUD_PROVIDER_ID) + } + + fn render_token_usage(&self, cx: &mut Context) -> Option
{ + let thread = self.thread()?.read(cx); + let usage = thread.token_usage()?; + let is_generating = thread.status() != ThreadStatus::Idle; + + let used = crate::text_thread_editor::humanize_token_count(usage.used_tokens); + let max = crate::text_thread_editor::humanize_token_count(usage.max_tokens); + + Some( + h_flex() + .flex_shrink_0() + .gap_0p5() + .mr_1p5() + .child( + Label::new(used) + .size(LabelSize::Small) + .color(Color::Muted) + .map(|label| { + if is_generating { + label + .with_animation( + "used-tokens-label", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.3, 0.8)), + |label, delta| label.alpha(delta), + ) + .into_any() + } else { + label.into_any_element() + } + }), + ) + .child( + Label::new("/") + .size(LabelSize::Small) + .color(Color::Custom(cx.theme().colors().text_muted.opacity(0.5))), + ) + .child(Label::new(max).size(LabelSize::Small).color(Color::Muted)), + ) + } + + fn toggle_burn_mode( + &mut self, + _: &ToggleBurnMode, + _window: &mut Window, + cx: &mut Context, + ) { + let Some(thread) = self.as_native_thread(cx) else { + return; + }; + + thread.update(cx, |thread, cx| { + let current_mode = thread.completion_mode(); + thread.set_completion_mode( + match current_mode { + CompletionMode::Burn => CompletionMode::Normal, + CompletionMode::Normal => CompletionMode::Burn, + }, + cx, + ); + }); + } + + fn keep_all(&mut self, _: &KeepAll, _window: &mut Window, cx: &mut Context) { + let Some(thread) = self.thread() else { + return; + }; + let action_log = thread.read(cx).action_log().clone(); + action_log.update(cx, |action_log, cx| action_log.keep_all_edits(cx)); + } + + fn reject_all(&mut self, _: &RejectAll, _window: &mut Window, cx: &mut Context) { + let Some(thread) = self.thread() else { + return; + }; + let action_log = thread.read(cx).action_log().clone(); + action_log + .update(cx, |action_log, cx| action_log.reject_all_edits(cx)) + .detach(); + } + + fn allow_always(&mut self, _: 
&AllowAlways, window: &mut Window, cx: &mut Context) { + self.authorize_pending_tool_call(acp::PermissionOptionKind::AllowAlways, window, cx); + } + + fn allow_once(&mut self, _: &AllowOnce, window: &mut Window, cx: &mut Context) { + self.authorize_pending_tool_call(acp::PermissionOptionKind::AllowOnce, window, cx); + } + + fn reject_once(&mut self, _: &RejectOnce, window: &mut Window, cx: &mut Context) { + self.authorize_pending_tool_call(acp::PermissionOptionKind::RejectOnce, window, cx); + } + + fn authorize_pending_tool_call( + &mut self, + kind: acp::PermissionOptionKind, + window: &mut Window, + cx: &mut Context, + ) -> Option<()> { + let thread = self.thread()?.read(cx); + let tool_call = thread.first_tool_awaiting_confirmation()?; + let ToolCallStatus::WaitingForConfirmation { options, .. } = &tool_call.status else { + return None; + }; + let option = options.iter().find(|o| o.kind == kind)?; + + self.authorize_tool_call( + tool_call.id.clone(), + option.id.clone(), + option.kind, + window, + cx, + ); + + Some(()) + } + + fn render_burn_mode_toggle(&self, cx: &mut Context) -> Option { + let thread = self.as_native_thread(cx)?.read(cx); + + if thread + .model() + .is_none_or(|model| !model.supports_burn_mode()) + { + return None; + } + + let active_completion_mode = thread.completion_mode(); + let burn_mode_enabled = active_completion_mode == CompletionMode::Burn; + let icon = if burn_mode_enabled { + IconName::ZedBurnModeOn + } else { + IconName::ZedBurnMode + }; + + Some( + IconButton::new("burn-mode", icon) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .toggle_state(burn_mode_enabled) + .selected_icon_color(Color::Error) + .on_click(cx.listener(|this, _event, window, cx| { + this.toggle_burn_mode(&ToggleBurnMode, window, cx); + })) + .tooltip(move |_window, cx| { + cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled)) + .into() + }) + .into_any_element(), + ) + } + + fn render_send_button(&self, cx: &mut Context) -> AnyElement { + 
let is_editor_empty = self.message_editor.read(cx).is_empty(cx); + let is_generating = self + .thread() + .is_some_and(|thread| thread.read(cx).status() != ThreadStatus::Idle); + + if self.is_loading_contents { + div() + .id("loading-message-content") + .px_1() + .tooltip(Tooltip::text("Loading Added Context…")) + .child(loading_contents_spinner(IconSize::default())) + .into_any_element() + } else if is_generating && is_editor_empty { + IconButton::new("stop-generation", IconName::Stop) + .icon_color(Color::Error) + .style(ButtonStyle::Tinted(ui::TintColor::Error)) + .tooltip(move |window, cx| { + Tooltip::for_action("Stop Generation", &editor::actions::Cancel, window, cx) + }) + .on_click(cx.listener(|this, _event, _, cx| this.cancel_generation(cx))) + .into_any_element() + } else { + let send_btn_tooltip = if is_editor_empty && !is_generating { + "Type to Send" + } else if is_generating { + "Stop and Send Message" + } else { + "Send" + }; + + IconButton::new("send-message", IconName::Send) + .style(ButtonStyle::Filled) + .map(|this| { + if is_editor_empty && !is_generating { + this.disabled(true).icon_color(Color::Muted) + } else { + this.icon_color(Color::Accent) + } + }) + .tooltip(move |window, cx| Tooltip::for_action(send_btn_tooltip, &Chat, window, cx)) + .on_click(cx.listener(|this, _, window, cx| { + this.send(window, cx); + })) + .into_any_element() + } + } + + fn is_following(&self, cx: &App) -> bool { + match self.thread().map(|thread| thread.read(cx).status()) { + Some(ThreadStatus::Generating) => self + .workspace + .read_with(cx, |workspace, _| { + workspace.is_being_followed(CollaboratorId::Agent) + }) + .unwrap_or(false), + _ => self.should_be_following, + } + } + + fn toggle_following(&mut self, window: &mut Window, cx: &mut Context) { + let following = self.is_following(cx); + + self.should_be_following = !following; + if self.thread().map(|thread| thread.read(cx).status()) == Some(ThreadStatus::Generating) { + self.workspace + .update(cx, 
|workspace, cx| { + if following { + workspace.unfollow(CollaboratorId::Agent, window, cx); + } else { + workspace.follow(CollaboratorId::Agent, window, cx); + } + }) + .ok(); + } + + telemetry::event!("Follow Agent Selected", following = !following); + } + + fn render_follow_toggle(&self, cx: &mut Context) -> impl IntoElement { + let following = self.is_following(cx); + + let tooltip_label = if following { + if self.agent.name() == "Zed Agent" { + format!("Stop Following the {}", self.agent.name()) + } else { + format!("Stop Following {}", self.agent.name()) + } + } else { + if self.agent.name() == "Zed Agent" { + format!("Follow the {}", self.agent.name()) + } else { + format!("Follow {}", self.agent.name()) + } + }; + + IconButton::new("follow-agent", IconName::Crosshair) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .toggle_state(following) + .selected_icon_color(Some(Color::Custom(cx.theme().players().agent().cursor))) + .tooltip(move |window, cx| { + if following { + Tooltip::for_action(tooltip_label.clone(), &Follow, window, cx) + } else { + Tooltip::with_meta( + tooltip_label.clone(), + Some(&Follow), + "Track the agent's location as it reads and edits files.", + window, + cx, + ) + } }) .on_click(cx.listener(move |this, _, window, cx| { - this.workspace - .update(cx, |workspace, cx| { - if following { - workspace.unfollow(CollaboratorId::Agent, window, cx); - } else { - workspace.follow(CollaboratorId::Agent, window, cx); - } - }) - .ok(); + this.toggle_following(window, cx); })) } @@ -2723,36 +4208,45 @@ impl AcpThreadView { if let Some(mention) = MentionUri::parse(&url).log_err() { workspace.update(cx, |workspace, cx| match mention { - MentionUri::File(path) => { + MentionUri::File { abs_path } => { + let project = workspace.project(); + let Some(path) = + project.update(cx, |project, cx| project.find_project_path(abs_path, cx)) + else { + return; + }; + + workspace + .open_path(path, None, true, window, cx) + .detach_and_log_err(cx); + } + 
MentionUri::PastedImage => {} + MentionUri::Directory { abs_path } => { let project = workspace.project(); - let Some((path, entry)) = project.update(cx, |project, cx| { - let path = project.find_project_path(path, cx)?; - let entry = project.entry_for_path(&path, cx)?; - Some((path, entry)) + let Some(entry_id) = project.update(cx, |project, cx| { + let path = project.find_project_path(abs_path, cx)?; + project.entry_for_path(&path, cx).map(|entry| entry.id) }) else { return; }; - if entry.is_dir() { - project.update(cx, |_, cx| { - cx.emit(project::Event::RevealInProjectPanel(entry.id)); - }); - } else { - workspace - .open_path(path, None, true, window, cx) - .detach_and_log_err(cx); - } + project.update(cx, |_, cx| { + cx.emit(project::Event::RevealInProjectPanel(entry_id)); + }); } MentionUri::Symbol { - path, line_range, .. + abs_path: path, + line_range, + .. } - | MentionUri::Selection { path, line_range } => { + | MentionUri::Selection { + abs_path: Some(path), + line_range, + } => { let project = workspace.project(); - let Some((path, _)) = project.update(cx, |project, cx| { - let path = project.find_project_path(path, cx)?; - let entry = project.entry_for_path(&path, cx)?; - Some((path, entry)) - }) else { + let Some(path) = + project.update(cx, |project, cx| project.find_project_path(path, cx)) + else { return; }; @@ -2762,8 +4256,8 @@ impl AcpThreadView { let Some(editor) = item.await?.downcast::() else { return Ok(()); }; - let range = - Point::new(line_range.start, 0)..Point::new(line_range.start, 0); + let range = Point::new(*line_range.start(), 0) + ..Point::new(*line_range.start(), 0); editor .update_in(cx, |editor, window, cx| { editor.change_selections( @@ -2778,12 +4272,19 @@ impl AcpThreadView { }) .detach_and_log_err(cx); } - MentionUri::Thread { id, .. } => { + MentionUri::Selection { abs_path: None, .. 
} => {} + MentionUri::Thread { id, name } => { if let Some(panel) = workspace.panel::(cx) { panel.update(cx, |panel, cx| { - panel - .open_thread_by_id(&id, window, cx) - .detach_and_log_err(cx) + panel.load_agent_thread( + DbThreadMetadata { + id, + title: name.into(), + updated_at: Default::default(), + }, + window, + cx, + ) }); } } @@ -2882,7 +4383,7 @@ impl AcpThreadView { workspace: Entity, window: &mut Window, cx: &mut App, - ) -> Task> { + ) -> Task> { let markdown_language_task = workspace .read(cx) .app_state() @@ -2903,11 +4404,11 @@ impl AcpThreadView { let project = workspace.project().clone(); if !project.read(cx).is_local() { - anyhow::bail!("failed to open active thread as markdown in remote project"); + bail!("failed to open active thread as markdown in remote project"); } let buffer = project.update(cx, |project, cx| { - project.create_local_buffer(&markdown, Some(markdown_language), cx) + project.create_local_buffer(&markdown, Some(markdown_language), true, cx) }); let buffer = cx.new(|cx| { MultiBuffer::singleton(buffer, cx).with_title(thread_summary.clone()) @@ -2974,7 +4475,8 @@ impl AcpThreadView { return; } - let title = self.title(cx); + // TODO: Change this once we have title summarization for external agents. + let title = self.agent.name(); match AgentSettings::get_global(cx).notify_when_agent_waiting { NotifyWhenAgentWaiting::PrimaryScreen => { @@ -3022,62 +4524,61 @@ impl AcpThreadView { }) }) .log_err() + && let Some(pop_up) = screen_window.entity(cx).log_err() { - if let Some(pop_up) = screen_window.entity(cx).log_err() { - self.notification_subscriptions - .entry(screen_window) - .or_insert_with(Vec::new) - .push(cx.subscribe_in(&pop_up, window, { - |this, _, event, window, cx| match event { - AgentNotificationEvent::Accepted => { - let handle = window.window_handle(); - cx.activate(true); - - let workspace_handle = this.workspace.clone(); - - // If there are multiple Zed windows, activate the correct one. 
- cx.defer(move |cx| { - handle - .update(cx, |_view, window, _cx| { - window.activate_window(); - - if let Some(workspace) = workspace_handle.upgrade() { - workspace.update(_cx, |workspace, cx| { - workspace.focus_panel::(window, cx); - }); - } - }) - .log_err(); - }); + self.notification_subscriptions + .entry(screen_window) + .or_insert_with(Vec::new) + .push(cx.subscribe_in(&pop_up, window, { + |this, _, event, window, cx| match event { + AgentNotificationEvent::Accepted => { + let handle = window.window_handle(); + cx.activate(true); + + let workspace_handle = this.workspace.clone(); + + // If there are multiple Zed windows, activate the correct one. + cx.defer(move |cx| { + handle + .update(cx, |_view, window, _cx| { + window.activate_window(); + + if let Some(workspace) = workspace_handle.upgrade() { + workspace.update(_cx, |workspace, cx| { + workspace.focus_panel::(window, cx); + }); + } + }) + .log_err(); + }); - this.dismiss_notifications(cx); - } - AgentNotificationEvent::Dismissed => { - this.dismiss_notifications(cx); - } + this.dismiss_notifications(cx); } - })); - - self.notifications.push(screen_window); - - // If the user manually refocuses the original window, dismiss the popup. - self.notification_subscriptions - .entry(screen_window) - .or_insert_with(Vec::new) - .push({ - let pop_up_weak = pop_up.downgrade(); - - cx.observe_window_activation(window, move |_, window, cx| { - if window.is_window_active() { - if let Some(pop_up) = pop_up_weak.upgrade() { - pop_up.update(cx, |_, cx| { - cx.emit(AgentNotificationEvent::Dismissed); - }); - } - } - }) - }); - } + AgentNotificationEvent::Dismissed => { + this.dismiss_notifications(cx); + } + } + })); + + self.notifications.push(screen_window); + + // If the user manually refocuses the original window, dismiss the popup. 
+ self.notification_subscriptions + .entry(screen_window) + .or_insert_with(Vec::new) + .push({ + let pop_up_weak = pop_up.downgrade(); + + cx.observe_window_activation(window, move |_, window, cx| { + if window.is_window_active() + && let Some(pop_up) = pop_up_weak.upgrade() + { + pop_up.update(cx, |_, cx| { + cx.emit(AgentNotificationEvent::Dismissed); + }); + } + }) + }); } } @@ -3093,7 +4594,21 @@ impl AcpThreadView { } } - fn render_thread_controls(&self, cx: &Context) -> impl IntoElement { + fn render_thread_controls( + &self, + thread: &Entity, + cx: &Context, + ) -> impl IntoElement { + let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating); + if is_generating { + return h_flex().id("thread-controls-container").child( + div() + .py_2() + .px(rems_from_px(22.)) + .child(SpinnerLabel::new().size(LabelSize::Small)), + ); + } + let open_as_markdown = IconButton::new("open-as-markdown", IconName::FileMarkdown) .shape(ui::IconButtonShape::Square) .icon_size(IconSize::Small) @@ -3115,17 +4630,136 @@ impl AcpThreadView { this.scroll_to_top(cx); })); - h_flex() + let mut container = h_flex() + .id("thread-controls-container") + .group("thread-controls-container") .w_full() - .mr_1() - .pb_2() - .px(RESPONSE_PADDING_X) - .opacity(0.4) + .py_2() + .px_5() + .gap_px() + .opacity(0.6) .hover(|style| style.opacity(1.)) .flex_wrap() - .justify_end() - .child(open_as_markdown) - .child(scroll_to_top) + .justify_end(); + + if AgentSettings::get_global(cx).enable_feedback + && self + .thread() + .is_some_and(|thread| thread.read(cx).connection().telemetry().is_some()) + { + let feedback = self.thread_feedback.feedback; + + container = container + .child( + div().visible_on_hover("thread-controls-container").child( + Label::new(match feedback { + Some(ThreadFeedback::Positive) => "Thanks for your feedback!", + Some(ThreadFeedback::Negative) => { + "We appreciate your feedback and will use it to improve." 
+ } + None => { + "Rating the thread sends all of your current conversation to the Zed team." + } + }) + .color(Color::Muted) + .size(LabelSize::XSmall) + .truncate(), + ), + ) + .child( + IconButton::new("feedback-thumbs-up", IconName::ThumbsUp) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::Small) + .icon_color(match feedback { + Some(ThreadFeedback::Positive) => Color::Accent, + _ => Color::Ignored, + }) + .tooltip(Tooltip::text("Helpful Response")) + .on_click(cx.listener(move |this, _, window, cx| { + this.handle_feedback_click(ThreadFeedback::Positive, window, cx); + })), + ) + .child( + IconButton::new("feedback-thumbs-down", IconName::ThumbsDown) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::Small) + .icon_color(match feedback { + Some(ThreadFeedback::Negative) => Color::Accent, + _ => Color::Ignored, + }) + .tooltip(Tooltip::text("Not Helpful")) + .on_click(cx.listener(move |this, _, window, cx| { + this.handle_feedback_click(ThreadFeedback::Negative, window, cx); + })), + ); + } + + container.child(open_as_markdown).child(scroll_to_top) + } + + fn render_feedback_feedback_editor(editor: Entity, cx: &Context) -> Div { + h_flex() + .key_context("AgentFeedbackMessageEditor") + .on_action(cx.listener(move |this, _: &menu::Cancel, _, cx| { + this.thread_feedback.dismiss_comments(); + cx.notify(); + })) + .on_action(cx.listener(move |this, _: &menu::Confirm, _window, cx| { + this.submit_feedback_message(cx); + })) + .p_2() + .mb_2() + .mx_5() + .gap_1() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().editor_background) + .child(div().w_full().child(editor)) + .child( + h_flex() + .child( + IconButton::new("dismiss-feedback-message", IconName::Close) + .icon_color(Color::Error) + .icon_size(IconSize::XSmall) + .shape(ui::IconButtonShape::Square) + .on_click(cx.listener(move |this, _, _window, cx| { + this.thread_feedback.dismiss_comments(); + cx.notify(); + })), + ) + .child( + 
IconButton::new("submit-feedback-message", IconName::Return) + .icon_size(IconSize::XSmall) + .shape(ui::IconButtonShape::Square) + .on_click(cx.listener(move |this, _, _window, cx| { + this.submit_feedback_message(cx); + })), + ), + ) + } + + fn handle_feedback_click( + &mut self, + feedback: ThreadFeedback, + window: &mut Window, + cx: &mut Context, + ) { + let Some(thread) = self.thread().cloned() else { + return; + }; + + self.thread_feedback.submit(thread, feedback, window, cx); + cx.notify(); + } + + fn submit_feedback_message(&mut self, cx: &mut Context) { + let Some(thread) = self.thread().cloned() else { + return; + }; + + self.thread_feedback.submit_comments(thread, cx); + cx.notify(); } fn render_vertical_scrollbar(&self, cx: &mut Context) -> Stateful
{ @@ -3161,165 +4795,620 @@ impl AcpThreadView { .children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx))) } - fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context) { - for diff_editor in self.diff_editors.values() { - diff_editor.update(cx, |diff_editor, cx| { - diff_editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); - cx.notify(); - }) - } - } + fn render_token_limit_callout( + &self, + line_height: Pixels, + cx: &mut Context, + ) -> Option { + let token_usage = self.thread()?.read(cx).token_usage()?; + let ratio = token_usage.ratio(); + + let (severity, title) = match ratio { + acp_thread::TokenUsageRatio::Normal => return None, + acp_thread::TokenUsageRatio::Warning => { + (Severity::Warning, "Thread reaching the token limit soon") + } + acp_thread::TokenUsageRatio::Exceeded => { + (Severity::Error, "Thread reached the token limit") + } + }; + + let burn_mode_available = self.as_native_thread(cx).is_some_and(|thread| { + thread.read(cx).completion_mode() == CompletionMode::Normal + && thread + .read(cx) + .model() + .is_some_and(|model| model.supports_burn_mode()) + }); + + let description = if burn_mode_available { + "To continue, start a new thread from a summary or turn Burn Mode on." + } else { + "To continue, start a new thread from a summary." 
+ }; + + Some( + Callout::new() + .severity(severity) + .line_height(line_height) + .title(title) + .description(description) + .actions_slot( + h_flex() + .gap_0p5() + .child( + Button::new("start-new-thread", "Start New Thread") + .label_size(LabelSize::Small) + .on_click(cx.listener(|this, _, window, cx| { + let Some(thread) = this.thread() else { + return; + }; + let session_id = thread.read(cx).session_id().clone(); + window.dispatch_action( + crate::NewNativeAgentThreadFromSummary { + from_session_id: session_id, + } + .boxed_clone(), + cx, + ); + })), + ) + .when(burn_mode_available, |this| { + this.child( + IconButton::new("burn-mode-callout", IconName::ZedBurnMode) + .icon_size(IconSize::XSmall) + .on_click(cx.listener(|this, _event, window, cx| { + this.toggle_burn_mode(&ToggleBurnMode, window, cx); + })), + ) + }), + ), + ) + } + + fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context) -> Option
{ + if !self.is_using_zed_ai_models(cx) { + return None; + } + + let user_store = self.project.read(cx).user_store().read(cx); + if user_store.is_usage_based_billing_enabled() { + return None; + } + + let plan = user_store + .plan() + .unwrap_or(cloud_llm_client::Plan::V1(PlanV1::ZedFree)); + + let usage = user_store.model_request_usage()?; + + Some( + div() + .child(UsageCallout::new(plan, usage)) + .line_height(line_height), + ) + } + + fn agent_font_size_changed(&mut self, _window: &mut Window, cx: &mut Context) { + self.entry_view_state.update(cx, |entry_view_state, cx| { + entry_view_state.agent_font_size_changed(cx); + }); + } pub(crate) fn insert_dragged_files( &self, paths: Vec, - _added_worktrees: Vec>, + added_worktrees: Vec>, window: &mut Window, - cx: &mut Context<'_, Self>, + cx: &mut Context, ) { - let buffer = self.message_editor.read(cx).buffer().clone(); - let Some((&excerpt_id, _, _)) = buffer.read(cx).snapshot(cx).as_singleton() else { - return; - }; - let Some(buffer) = buffer.read(cx).as_singleton() else { - return; + self.message_editor.update(cx, |message_editor, cx| { + message_editor.insert_dragged_files(paths, added_worktrees, window, cx); + }) + } + + pub(crate) fn insert_selections(&self, window: &mut Window, cx: &mut Context) { + self.message_editor.update(cx, |message_editor, cx| { + message_editor.insert_selections(window, cx); + }) + } + + fn render_thread_retry_status_callout( + &self, + _window: &mut Window, + _cx: &mut Context, + ) -> Option { + let state = self.thread_retry_status.as_ref()?; + + let next_attempt_in = state + .duration + .saturating_sub(Instant::now().saturating_duration_since(state.started_at)); + if next_attempt_in.is_zero() { + return None; + } + + let next_attempt_in_secs = next_attempt_in.as_secs() + 1; + + let retry_message = if state.max_attempts == 1 { + if next_attempt_in_secs == 1 { + "Retrying. Next attempt in 1 second.".to_string() + } else { + format!("Retrying. 
Next attempt in {next_attempt_in_secs} seconds.") + } + } else if next_attempt_in_secs == 1 { + format!( + "Retrying. Next attempt in 1 second (Attempt {} of {}).", + state.attempt, state.max_attempts, + ) + } else { + format!( + "Retrying. Next attempt in {next_attempt_in_secs} seconds (Attempt {} of {}).", + state.attempt, state.max_attempts, + ) }; - for path in paths { - let Some(entry) = self.project.read(cx).entry_for_path(&path, cx) else { - continue; - }; - let Some(abs_path) = self.project.read(cx).absolute_path(&path, cx) else { - continue; - }; - let anchor = buffer.update(cx, |buffer, _cx| buffer.anchor_before(buffer.len())); - let path_prefix = abs_path - .file_name() - .unwrap_or(path.path.as_os_str()) - .display() - .to_string(); - let Some(completion) = ContextPickerCompletionProvider::completion_for_path( - path, - &path_prefix, - false, - entry.is_dir(), - excerpt_id, - anchor..anchor, - self.message_editor.clone(), - self.mention_set.clone(), - self.project.clone(), - cx, - ) else { - continue; - }; + Some( + Callout::new() + .severity(Severity::Warning) + .title(state.last_error.clone()) + .description(retry_message), + ) + } - self.message_editor.update(cx, |message_editor, cx| { - message_editor.edit( - [( - multi_buffer::Anchor::max()..multi_buffer::Anchor::max(), - completion.new_text, - )], - cx, - ); - }); - if let Some(confirm) = completion.confirm.clone() { - confirm(CompletionIntent::Complete, window, cx); + fn render_thread_error(&self, window: &mut Window, cx: &mut Context) -> Option
{ + let content = match self.thread_error.as_ref()? { + ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx), + ThreadError::Refusal => self.render_refusal_error(cx), + ThreadError::AuthenticationRequired(error) => { + self.render_authentication_required_error(error.clone(), cx) + } + ThreadError::PaymentRequired => self.render_payment_required_error(cx), + ThreadError::ModelRequestLimitReached(plan) => { + self.render_model_request_limit_reached_error(*plan, cx) } + ThreadError::ToolUseLimitReached => { + self.render_tool_use_limit_reached_error(window, cx)? + } + }; + + Some(div().child(content)) + } + + fn render_new_version_callout(&self, version: &SharedString, cx: &mut Context) -> Div { + v_flex().w_full().justify_end().child( + h_flex() + .p_2() + .pr_3() + .w_full() + .gap_1p5() + .border_t_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().element_background) + .child( + h_flex() + .flex_1() + .gap_1p5() + .child( + Icon::new(IconName::Download) + .color(Color::Accent) + .size(IconSize::Small), + ) + .child(Label::new("New version available").size(LabelSize::Small)), + ) + .child( + Button::new("update-button", format!("Update to v{}", version)) + .label_size(LabelSize::Small) + .style(ButtonStyle::Tinted(TintColor::Accent)) + .on_click(cx.listener(|this, _, window, cx| { + this.reset(window, cx); + })), + ), + ) + } + + fn get_current_model_name(&self, cx: &App) -> SharedString { + // For native agent (Zed Agent), use the specific model name (e.g., "Claude 3.5 Sonnet") + // For ACP agents, use the agent name (e.g., "Claude Code", "Gemini CLI") + // This provides better clarity about what refused the request + if self + .agent + .clone() + .downcast::() + .is_some() + { + // Native agent - use the model name + self.model_selector + .as_ref() + .and_then(|selector| selector.read(cx).active_model_name(cx)) + .unwrap_or_else(|| SharedString::from("The model")) + } else { + // ACP agent - use the agent name (e.g., 
"Claude Code", "Gemini CLI") + self.agent.name() } } + + fn render_refusal_error(&self, cx: &mut Context<'_, Self>) -> Callout { + let model_or_agent_name = self.get_current_model_name(cx); + let refusal_message = format!( + "{} refused to respond to this prompt. This can happen when a model believes the prompt violates its content policy or safety guidelines, so rephrasing it can sometimes address the issue.", + model_or_agent_name + ); + + Callout::new() + .severity(Severity::Error) + .title("Request Refused") + .icon(IconName::XCircle) + .description(refusal_message.clone()) + .actions_slot(self.create_copy_button(&refusal_message)) + .dismiss_action(self.dismiss_error_button(cx)) + } + + fn render_any_thread_error(&self, error: SharedString, cx: &mut Context<'_, Self>) -> Callout { + let can_resume = self + .thread() + .map_or(false, |thread| thread.read(cx).can_resume(cx)); + + let can_enable_burn_mode = self.as_native_thread(cx).map_or(false, |thread| { + let thread = thread.read(cx); + let supports_burn_mode = thread + .model() + .map_or(false, |model| model.supports_burn_mode()); + supports_burn_mode && thread.completion_mode() == CompletionMode::Normal + }); + + Callout::new() + .severity(Severity::Error) + .title("Error") + .icon(IconName::XCircle) + .description(error.clone()) + .actions_slot( + h_flex() + .gap_0p5() + .when(can_resume && can_enable_burn_mode, |this| { + this.child( + Button::new("enable-burn-mode-and-retry", "Enable Burn Mode and Retry") + .icon(IconName::ZedBurnMode) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .label_size(LabelSize::Small) + .on_click(cx.listener(|this, _, window, cx| { + this.toggle_burn_mode(&ToggleBurnMode, window, cx); + this.resume_chat(cx); + })), + ) + }) + .when(can_resume, |this| { + this.child( + Button::new("retry", "Retry") + .icon(IconName::RotateCw) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .label_size(LabelSize::Small) + .on_click(cx.listener(|this, 
_, _window, cx| { + this.resume_chat(cx); + })), + ) + }) + .child(self.create_copy_button(error.to_string())), + ) + .dismiss_action(self.dismiss_error_button(cx)) + } + + fn render_payment_required_error(&self, cx: &mut Context) -> Callout { + const ERROR_MESSAGE: &str = + "You reached your free usage limit. Upgrade to Zed Pro for more prompts."; + + Callout::new() + .severity(Severity::Error) + .icon(IconName::XCircle) + .title("Free Usage Exceeded") + .description(ERROR_MESSAGE) + .actions_slot( + h_flex() + .gap_0p5() + .child(self.upgrade_button(cx)) + .child(self.create_copy_button(ERROR_MESSAGE)), + ) + .dismiss_action(self.dismiss_error_button(cx)) + } + + fn render_authentication_required_error( + &self, + error: SharedString, + cx: &mut Context, + ) -> Callout { + Callout::new() + .severity(Severity::Error) + .title("Authentication Required") + .icon(IconName::XCircle) + .description(error.clone()) + .actions_slot( + h_flex() + .gap_0p5() + .child(self.authenticate_button(cx)) + .child(self.create_copy_button(error)), + ) + .dismiss_action(self.dismiss_error_button(cx)) + } + + fn render_model_request_limit_reached_error( + &self, + plan: cloud_llm_client::Plan, + cx: &mut Context, + ) -> Callout { + let error_message = match plan { + cloud_llm_client::Plan::V1(PlanV1::ZedPro) => { + "Upgrade to usage-based billing for more prompts." 
+ } + cloud_llm_client::Plan::V1(PlanV1::ZedProTrial) + | cloud_llm_client::Plan::V1(PlanV1::ZedFree) => "Upgrade to Zed Pro for more prompts.", + cloud_llm_client::Plan::V2(_) => "", + }; + + Callout::new() + .severity(Severity::Error) + .title("Model Prompt Limit Reached") + .icon(IconName::XCircle) + .description(error_message) + .actions_slot( + h_flex() + .gap_0p5() + .child(self.upgrade_button(cx)) + .child(self.create_copy_button(error_message)), + ) + .dismiss_action(self.dismiss_error_button(cx)) + } + + fn render_tool_use_limit_reached_error( + &self, + window: &mut Window, + cx: &mut Context, + ) -> Option { + let thread = self.as_native_thread(cx)?; + let supports_burn_mode = thread + .read(cx) + .model() + .is_some_and(|model| model.supports_burn_mode()); + + let focus_handle = self.focus_handle(cx); + + Some( + Callout::new() + .icon(IconName::Info) + .title("Consecutive tool use limit reached.") + .actions_slot( + h_flex() + .gap_0p5() + .when(supports_burn_mode, |this| { + this.child( + Button::new("continue-burn-mode", "Continue with Burn Mode") + .style(ButtonStyle::Filled) + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .layer(ElevationIndex::ModalSurface) + .label_size(LabelSize::Small) + .key_binding( + KeyBinding::for_action_in( + &ContinueWithBurnMode, + &focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(10.))), + ) + .tooltip(Tooltip::text( + "Enable Burn Mode for unlimited tool use.", + )) + .on_click({ + cx.listener(move |this, _, _window, cx| { + thread.update(cx, |thread, cx| { + thread + .set_completion_mode(CompletionMode::Burn, cx); + }); + this.resume_chat(cx); + }) + }), + ) + }) + .child( + Button::new("continue-conversation", "Continue") + .layer(ElevationIndex::ModalSurface) + .label_size(LabelSize::Small) + .key_binding( + KeyBinding::for_action_in( + &ContinueThread, + &focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(10.))), + ) + .on_click(cx.listener(|this, _, _window, cx| { + 
this.resume_chat(cx); + })), + ), + ), + ) + } + + fn create_copy_button(&self, message: impl Into) -> impl IntoElement { + let message = message.into(); + + IconButton::new("copy", IconName::Copy) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .tooltip(Tooltip::text("Copy Error Message")) + .on_click(move |_, _, cx| { + cx.write_to_clipboard(ClipboardItem::new_string(message.clone())) + }) + } + + fn dismiss_error_button(&self, cx: &mut Context) -> impl IntoElement { + IconButton::new("dismiss", IconName::Close) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .tooltip(Tooltip::text("Dismiss Error")) + .on_click(cx.listener({ + move |this, _, _, cx| { + this.clear_thread_error(cx); + cx.notify(); + } + })) + } + + fn authenticate_button(&self, cx: &mut Context) -> impl IntoElement { + Button::new("authenticate", "Authenticate") + .label_size(LabelSize::Small) + .style(ButtonStyle::Filled) + .on_click(cx.listener({ + move |this, _, window, cx| { + let agent = this.agent.clone(); + let ThreadState::Ready { thread, .. } = &this.thread_state else { + return; + }; + + let connection = thread.read(cx).connection().clone(); + let err = AuthRequired { + description: None, + provider_id: None, + }; + this.clear_thread_error(cx); + let this = cx.weak_entity(); + window.defer(cx, |window, cx| { + Self::handle_auth_required(this, err, agent, connection, window, cx); + }) + } + })) + } + + pub(crate) fn reauthenticate(&mut self, window: &mut Window, cx: &mut Context) { + let agent = self.agent.clone(); + let ThreadState::Ready { thread, .. 
} = &self.thread_state else { + return; + }; + + let connection = thread.read(cx).connection().clone(); + let err = AuthRequired { + description: None, + provider_id: None, + }; + self.clear_thread_error(cx); + let this = cx.weak_entity(); + window.defer(cx, |window, cx| { + Self::handle_auth_required(this, err, agent, connection, window, cx); + }) + } + + fn upgrade_button(&self, cx: &mut Context) -> impl IntoElement { + Button::new("upgrade", "Upgrade") + .label_size(LabelSize::Small) + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(cx.listener({ + move |this, _, _, cx| { + this.clear_thread_error(cx); + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)); + } + })) + } + + pub fn delete_history_entry(&mut self, entry: HistoryEntry, cx: &mut Context) { + let task = match entry { + HistoryEntry::AcpThread(thread) => self.history_store.update(cx, |history, cx| { + history.delete_thread(thread.id.clone(), cx) + }), + HistoryEntry::TextThread(context) => self.history_store.update(cx, |history, cx| { + history.delete_text_thread(context.path.clone(), cx) + }), + }; + task.detach_and_log_err(cx); + } +} + +fn loading_contents_spinner(size: IconSize) -> AnyElement { + Icon::new(IconName::LoadCircle) + .size(size) + .color(Color::Accent) + .with_rotate_animation(3) + .into_any_element() } impl Focusable for AcpThreadView { fn focus_handle(&self, cx: &App) -> FocusHandle { - self.message_editor.focus_handle(cx) + match self.thread_state { + ThreadState::Loading { .. } | ThreadState::Ready { .. } => { + self.message_editor.focus_handle(cx) + } + ThreadState::LoadError(_) | ThreadState::Unauthenticated { .. 
} => { + self.focus_handle.clone() + } + } } } impl Render for AcpThreadView { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let has_messages = self.list_state.item_count() > 0; + let line_height = TextSize::Small.rems(cx).to_pixels(window.rem_size()) * 1.5; v_flex() .size_full() .key_context("AcpThread") - .on_action(cx.listener(Self::chat)) - .on_action(cx.listener(Self::previous_history_message)) - .on_action(cx.listener(Self::next_history_message)) .on_action(cx.listener(Self::open_agent_diff)) + .on_action(cx.listener(Self::toggle_burn_mode)) + .on_action(cx.listener(Self::keep_all)) + .on_action(cx.listener(Self::reject_all)) + .on_action(cx.listener(Self::allow_always)) + .on_action(cx.listener(Self::allow_once)) + .on_action(cx.listener(Self::reject_once)) + .track_focus(&self.focus_handle) .bg(cx.theme().colors().panel_background) .child(match &self.thread_state { - ThreadState::Unauthenticated { connection } => v_flex() - .p_2() + ThreadState::Unauthenticated { + connection, + description, + configuration_view, + pending_auth_method, + .. + } => self.render_auth_required_state( + connection, + description.as_ref(), + configuration_view.as_ref(), + pending_auth_method.as_ref(), + window, + cx, + ), + ThreadState::Loading { .. } => v_flex() .flex_1() - .items_center() - .justify_center() - .child(self.render_pending_auth_state()) - .child(h_flex().mt_1p5().justify_center().children( - connection.auth_methods().into_iter().map(|method| { - Button::new( - SharedString::from(method.id.0.clone()), - method.name.clone(), - ) - .on_click({ - let method_id = method.id.clone(); - cx.listener(move |this, _, window, cx| { - this.authenticate(method_id.clone(), window, cx) - }) - }) - }), - )), - ThreadState::Loading { .. 
} => v_flex().flex_1().child(self.render_empty_state(cx)), + .child(self.render_recent_history(window, cx)), ThreadState::LoadError(e) => v_flex() - .p_2() - .flex_1() - .items_center() - .justify_center() - .child(self.render_load_error(e, cx)), - ThreadState::ServerExited { status } => v_flex() - .p_2() .flex_1() + .size_full() .items_center() - .justify_center() - .child(self.render_server_exited(*status, cx)), - ThreadState::Ready { thread, .. } => { - let thread_clone = thread.clone(); - - v_flex().flex_1().map(|this| { - if has_messages { - this.child( - list( - self.list_state.clone(), - cx.processor(|this, index: usize, window, cx| { - let Some((entry, len)) = this.thread().and_then(|thread| { - let entries = &thread.read(cx).entries(); - Some((entries.get(index)?, entries.len())) - }) else { - return Empty.into_any(); - }; - this.render_entry(index, len, entry, window, cx) - }), - ) - .with_sizing_behavior(gpui::ListSizingBehavior::Auto) - .flex_grow() - .into_any(), - ) - .child(self.render_vertical_scrollbar(cx)) - .children( - match thread_clone.read(cx).status() { - ThreadStatus::Idle - | ThreadStatus::WaitingForToolConfirmation => None, - ThreadStatus::Generating => div() - .px_5() - .py_2() - .child(LoadingLabel::new("").size(LabelSize::Small)) - .into(), - }, + .justify_end() + .child(self.render_load_error(e, window, cx)), + ThreadState::Ready { .. 
} => v_flex().flex_1().map(|this| { + if has_messages { + this.child( + list( + self.list_state.clone(), + cx.processor(|this, index: usize, window, cx| { + let Some((entry, len)) = this.thread().and_then(|thread| { + let entries = &thread.read(cx).entries(); + Some((entries.get(index)?, entries.len())) + }) else { + return Empty.into_any(); + }; + this.render_entry(index, len, entry, window, cx) + }), ) - } else { - this.child(self.render_empty_state(cx)) - } - }) - } + .with_sizing_behavior(gpui::ListSizingBehavior::Auto) + .flex_grow() + .into_any(), + ) + .child(self.render_vertical_scrollbar(cx)) + } else { + this.child(self.render_recent_history(window, cx)) + } + }), }) // The activity bar is intentionally rendered outside of the ThreadState::Ready match // above so that the scrollbar doesn't render behind it. The current setup allows @@ -3330,53 +5419,32 @@ impl Render for AcpThreadView { } _ => this, }) - .when_some(self.last_error.clone(), |el, error| { - el.child( - div() - .p_2() - .text_xs() - .border_t_1() - .border_color(cx.theme().colors().border) - .bg(cx.theme().status().error_background) - .child( - self.render_markdown(error, default_markdown_style(false, window, cx)), - ), - ) - }) + .children(self.render_thread_retry_status_callout(window, cx)) + .children(self.render_thread_error(window, cx)) + .when_some( + self.new_server_version_available.as_ref().filter(|_| { + !has_messages || !matches!(self.thread_state, ThreadState::Ready { .. 
}) + }), + |this, version| this.child(self.render_new_version_callout(&version, cx)), + ) + .children( + if let Some(usage_callout) = self.render_usage_callout(line_height, cx) { + Some(usage_callout.into_any_element()) + } else { + self.render_token_limit_callout(line_height, cx) + .map(|token_limit_callout| token_limit_callout.into_any_element()) + }, + ) .child(self.render_message_editor(window, cx)) } } -fn user_message_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { - let mut style = default_markdown_style(false, window, cx); - let mut text_style = window.text_style(); - let theme_settings = ThemeSettings::get_global(cx); - - let buffer_font = theme_settings.buffer_font.family.clone(); - let buffer_font_size = TextSize::Small.rems(cx); - - text_style.refine(&TextStyleRefinement { - font_family: Some(buffer_font), - font_size: Some(buffer_font_size.into()), - ..Default::default() - }); - - style.base_text_style = text_style; - style.link_callback = Some(Rc::new(move |url, cx| { - if MentionUri::parse(url).is_ok() { - let colors = cx.theme().colors(); - Some(TextStyleRefinement { - background_color: Some(colors.element_background), - ..Default::default() - }) - } else { - None - } - })); - style -} - -fn default_markdown_style(buffer_font: bool, window: &Window, cx: &App) -> MarkdownStyle { +fn default_markdown_style( + buffer_font: bool, + muted_text: bool, + window: &Window, + cx: &App, +) -> MarkdownStyle { let theme_settings = ThemeSettings::get_global(cx); let colors = cx.theme().colors(); @@ -3397,20 +5465,26 @@ fn default_markdown_style(buffer_font: bool, window: &Window, cx: &App) -> Markd TextSize::Default.rems(cx) }; + let text_color = if muted_text { + colors.text_muted + } else { + colors.text + }; + text_style.refine(&TextStyleRefinement { font_family: Some(font_family), font_fallbacks: theme_settings.ui_font.fallbacks.clone(), font_features: Some(theme_settings.ui_font.features.clone()), font_size: Some(font_size.into()), line_height: 
Some(line_height.into()), - color: Some(cx.theme().colors().text), + color: Some(text_color), ..Default::default() }); MarkdownStyle { base_text_style: text_style.clone(), syntax: cx.theme().syntax().clone(), - selection_background_color: cx.theme().colors().element_selection_background, + selection_background_color: colors.element_selection_background, code_block_overflow_x_scroll: true, table_overflow_x_scroll: true, heading_level_styles: Some(HeadingLevelStyles { @@ -3496,7 +5570,7 @@ fn plan_label_markdown_style( window: &Window, cx: &App, ) -> MarkdownStyle { - let default_md_style = default_markdown_style(false, window, cx); + let default_md_style = default_markdown_style(false, false, window, cx); MarkdownStyle { base_text_style: TextStyle { @@ -3515,20 +5589,8 @@ fn plan_label_markdown_style( } } -fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement { - TextStyleRefinement { - font_size: Some( - TextSize::Small - .rems(cx) - .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) - .into(), - ), - ..Default::default() - } -} - fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { - let default_md_style = default_markdown_style(true, window, cx); + let default_md_style = default_markdown_style(true, false, window, cx); MarkdownStyle { base_text_style: TextStyle { @@ -3539,377 +5601,980 @@ fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { } } -#[cfg(test)] -mod tests { - use agent::{TextThreadStore, ThreadStore}; - use agent_client_protocol::SessionId; - use editor::EditorSettings; - use fs::FakeFs; - use futures::future::try_join_all; - use gpui::{SemanticVersion, TestAppContext, VisualTestContext}; - use rand::Rng; - use settings::SettingsStore; +#[cfg(test)] +pub(crate) mod tests { + use acp_thread::StubAgentConnection; + use agent_client_protocol::SessionId; + use assistant_context::ContextStore; + use editor::EditorSettings; + use fs::FakeFs; + use gpui::{EventEmitter, 
SemanticVersion, TestAppContext, VisualTestContext}; + use project::Project; + use serde_json::json; + use settings::SettingsStore; + use std::any::Any; + use std::path::Path; + use workspace::Item; + + use super::*; + + #[gpui::test] + async fn test_drop(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, _cx) = setup_thread_view(StubAgentServer::default_response(), cx).await; + let weak_view = thread_view.downgrade(); + drop(thread_view); + assert!(!weak_view.is_upgradable()); + } + + #[gpui::test] + async fn test_notification_for_stop_event(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, cx) = setup_thread_view(StubAgentServer::default_response(), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + cx.run_until_parked(); + + assert!( + cx.windows() + .iter() + .any(|window| window.downcast::().is_some()) + ); + } + + #[gpui::test] + async fn test_notification_for_error(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, cx) = + setup_thread_view(StubAgentServer::new(SaboteurAgentConnection), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + cx.run_until_parked(); + + assert!( + cx.windows() + .iter() + .any(|window| window.downcast::().is_some()) + ); + } + + #[gpui::test] + async fn test_refusal_handling(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, cx) = + setup_thread_view(StubAgentServer::new(RefusalAgentConnection), cx).await; + + let message_editor = 
cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Do something harmful", window, cx); + }); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + cx.run_until_parked(); + + // Check that the refusal error is set + thread_view.read_with(cx, |thread_view, _cx| { + assert!( + matches!(thread_view.thread_error, Some(ThreadError::Refusal)), + "Expected refusal error to be set" + ); + }); + } + + #[gpui::test] + async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) { + init_test(cx); + + let tool_call_id = acp::ToolCallId("1".into()); + let tool_call = acp::ToolCall { + id: tool_call_id.clone(), + title: "Label".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Pending, + content: vec!["hi".into()], + locations: vec![], + raw_input: None, + raw_output: None, + }; + let connection = + StubAgentConnection::new().with_permission_requests(HashMap::from_iter([( + tool_call_id, + vec![acp::PermissionOption { + id: acp::PermissionOptionId("1".into()), + name: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }], + )])); + + connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(tool_call)]); + + let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + cx.run_until_parked(); + + assert!( + cx.windows() + .iter() + .any(|window| window.downcast::().is_some()) + ); + } + + async fn setup_thread_view( + agent: impl AgentServer + 'static, + cx: &mut TestAppContext, + ) -> (Entity, &mut VisualTestContext) { + let fs = FakeFs::new(cx.executor()); + let 
project = Project::test(fs, [], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let context_store = + cx.update(|_window, cx| cx.new(|cx| ContextStore::fake(project.clone(), cx))); + let history_store = + cx.update(|_window, cx| cx.new(|cx| HistoryStore::new(context_store, cx))); + + let thread_view = cx.update(|window, cx| { + cx.new(|cx| { + AcpThreadView::new( + Rc::new(agent), + None, + None, + workspace.downgrade(), + project, + history_store, + None, + window, + cx, + ) + }) + }); + cx.run_until_parked(); + (thread_view, cx) + } + + fn add_to_workspace(thread_view: Entity, cx: &mut VisualTestContext) { + let workspace = thread_view.read_with(cx, |thread_view, _cx| thread_view.workspace.clone()); + + workspace + .update_in(cx, |workspace, window, cx| { + workspace.add_item_to_active_pane( + Box::new(cx.new(|_| ThreadViewItem(thread_view.clone()))), + None, + true, + window, + cx, + ); + }) + .unwrap(); + } + + struct ThreadViewItem(Entity); + + impl Item for ThreadViewItem { + type Event = (); + + fn include_in_nav_history() -> bool { + false + } + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { + "Test".into() + } + } + + impl EventEmitter<()> for ThreadViewItem {} + + impl Focusable for ThreadViewItem { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.0.read(cx).focus_handle(cx) + } + } + + impl Render for ThreadViewItem { + fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { + self.0.clone().into_any_element() + } + } + + struct StubAgentServer { + connection: C, + } + + impl StubAgentServer { + fn new(connection: C) -> Self { + Self { connection } + } + } + + impl StubAgentServer { + fn default_response() -> Self { + let conn = StubAgentConnection::new(); + conn.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk { + content: "Default response".into(), + }]); + Self::new(conn) + } + } + + impl 
AgentServer for StubAgentServer + where + C: 'static + AgentConnection + Send + Clone, + { + fn telemetry_id(&self) -> &'static str { + "test" + } + + fn logo(&self) -> ui::IconName { + ui::IconName::Ai + } + + fn name(&self) -> SharedString { + "Test".into() + } + + fn connect( + &self, + _root_dir: Option<&Path>, + _delegate: AgentServerDelegate, + _cx: &mut App, + ) -> Task, Option)>> { + Task::ready(Ok((Rc::new(self.connection.clone()), None))) + } + + fn into_any(self: Rc) -> Rc { + self + } + } + + #[derive(Clone)] + struct SaboteurAgentConnection; + + impl AgentConnection for SaboteurAgentConnection { + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::App, + ) -> Task>> { + Task::ready(Ok(cx.new(|cx| { + let action_log = cx.new(|_| ActionLog::new(project.clone())); + AcpThread::new( + "SaboteurAgentConnection", + self, + project, + action_log, + SessionId("test".into()), + watch::Receiver::constant(acp::PromptCapabilities { + image: true, + audio: true, + embedded_context: true, + }), + cx, + ) + }))) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + unimplemented!() + } + + fn prompt( + &self, + _id: Option, + _params: acp::PromptRequest, + _cx: &mut App, + ) -> Task> { + Task::ready(Err(anyhow::anyhow!("Error prompting"))) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + + fn into_any(self: Rc) -> Rc { + self + } + } + + /// Simulates a model which always returns a refusal response + #[derive(Clone)] + struct RefusalAgentConnection; + + impl AgentConnection for RefusalAgentConnection { + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::App, + ) -> Task>> { + Task::ready(Ok(cx.new(|cx| { + let action_log = cx.new(|_| ActionLog::new(project.clone())); + AcpThread::new( + "RefusalAgentConnection", + self, + project, + action_log, + 
SessionId("test".into()), + watch::Receiver::constant(acp::PromptCapabilities { + image: true, + audio: true, + embedded_context: true, + }), + cx, + ) + }))) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + unimplemented!() + } + + fn prompt( + &self, + _id: Option, + _params: acp::PromptRequest, + _cx: &mut App, + ) -> Task> { + Task::ready(Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Refusal, + })) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + + fn into_any(self: Rc) -> Rc { + self + } + } + + pub(crate) fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + AgentSettings::register(cx); + workspace::init_settings(cx); + ThemeSettings::register(cx); + release_channel::init(SemanticVersion::default(), cx); + EditorSettings::register(cx); + prompt_store::init(cx) + }); + } + + #[gpui::test] + async fn test_rewind_views(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + "test1.txt": "old content 1", + "test2.txt": "old content 2" + }), + ) + .await; + let project = Project::test(fs, [Path::new("/project")], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let context_store = + cx.update(|_window, cx| cx.new(|cx| ContextStore::fake(project.clone(), cx))); + let history_store = + cx.update(|_window, cx| cx.new(|cx| HistoryStore::new(context_store, cx))); + + let connection = Rc::new(StubAgentConnection::new()); + let thread_view = cx.update(|window, cx| { + cx.new(|cx| { + AcpThreadView::new( + Rc::new(StubAgentServer::new(connection.as_ref().clone())), + None, + None, + workspace.downgrade(), + 
project.clone(), + history_store.clone(), + None, + window, + cx, + ) + }) + }); + + cx.run_until_parked(); + + let thread = thread_view + .read_with(cx, |view, _| view.thread().cloned()) + .unwrap(); + + // First user message + connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId("tool1".into()), + title: "Edit file 1".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Completed, + content: vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: "/project/test1.txt".into(), + old_text: Some("old content 1".into()), + new_text: "new content 1".into(), + }, + }], + locations: vec![], + raw_input: None, + raw_output: None, + })]); + + thread + .update(cx, |thread, cx| thread.send_raw("Give me a diff", cx)) + .await + .unwrap(); + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert_eq!(thread.entries().len(), 2); + }); + + thread_view.read_with(cx, |view, cx| { + view.entry_view_state.read_with(cx, |entry_view_state, _| { + assert!( + entry_view_state + .entry(0) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(1).unwrap().has_content()); + }); + }); + + // Second user message + connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId("tool2".into()), + title: "Edit file 2".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Completed, + content: vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: "/project/test2.txt".into(), + old_text: Some("old content 2".into()), + new_text: "new content 2".into(), + }, + }], + locations: vec![], + raw_input: None, + raw_output: None, + })]); + + thread + .update(cx, |thread, cx| thread.send_raw("Another one", cx)) + .await + .unwrap(); + cx.run_until_parked(); + + let second_user_message_id = thread.read_with(cx, |thread, _| { + assert_eq!(thread.entries().len(), 4); + let AgentThreadEntry::UserMessage(user_message) = 
&thread.entries()[2] else { + panic!(); + }; + user_message.id.clone().unwrap() + }); + + thread_view.read_with(cx, |view, cx| { + view.entry_view_state.read_with(cx, |entry_view_state, _| { + assert!( + entry_view_state + .entry(0) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(1).unwrap().has_content()); + assert!( + entry_view_state + .entry(2) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(3).unwrap().has_content()); + }); + }); + + // Rewind to first message + thread + .update(cx, |thread, cx| thread.rewind(second_user_message_id, cx)) + .await + .unwrap(); + + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert_eq!(thread.entries().len(), 2); + }); + + thread_view.read_with(cx, |view, cx| { + view.entry_view_state.read_with(cx, |entry_view_state, _| { + assert!( + entry_view_state + .entry(0) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(1).unwrap().has_content()); + + // Old views should be dropped + assert!(entry_view_state.entry(2).is_none()); + assert!(entry_view_state.entry(3).is_none()); + }); + }); + } + + #[gpui::test] + async fn test_message_editing_cancel(cx: &mut TestAppContext) { + init_test(cx); + + let connection = StubAgentConnection::new(); + + connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text: "Response".into(), + annotations: None, + }), + }]); + + let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; + add_to_workspace(thread_view.clone(), cx); + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Original message to edit", window, cx); + }); + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + cx.run_until_parked(); + + let user_message_editor = 
thread_view.read_with(cx, |view, cx| { + assert_eq!(view.editing_message, None); + + view.entry_view_state + .read(cx) + .entry(0) + .unwrap() + .message_editor() + .unwrap() + .clone() + }); + + // Focus + cx.focus(&user_message_editor); + thread_view.read_with(cx, |view, _cx| { + assert_eq!(view.editing_message, Some(0)); + }); + + // Edit + user_message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Edited message content", window, cx); + }); + + // Cancel + user_message_editor.update_in(cx, |_editor, window, cx| { + window.dispatch_action(Box::new(editor::actions::Cancel), cx); + }); + + thread_view.read_with(cx, |view, _cx| { + assert_eq!(view.editing_message, None); + }); + + user_message_editor.read_with(cx, |editor, cx| { + assert_eq!(editor.text(cx), "Original message to edit"); + }); + } + + #[gpui::test] + async fn test_message_doesnt_send_if_empty(cx: &mut TestAppContext) { + init_test(cx); + + let connection = StubAgentConnection::new(); + + let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; + add_to_workspace(thread_view.clone(), cx); + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + let mut events = cx.events(&message_editor); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("", window, cx); + }); + + message_editor.update_in(cx, |_editor, window, cx| { + window.dispatch_action(Box::new(Chat), cx); + }); + cx.run_until_parked(); + // We shouldn't have received any messages + assert!(matches!( + events.try_next(), + Err(futures::channel::mpsc::TryRecvError { .. 
}) + )); + } + + #[gpui::test] + async fn test_message_editing_regenerate(cx: &mut TestAppContext) { + init_test(cx); + + let connection = StubAgentConnection::new(); + + connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text: "Response".into(), + annotations: None, + }), + }]); + + let (thread_view, cx) = + setup_thread_view(StubAgentServer::new(connection.clone()), cx).await; + add_to_workspace(thread_view.clone(), cx); + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Original message to edit", window, cx); + }); + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + cx.run_until_parked(); + + let user_message_editor = thread_view.read_with(cx, |view, cx| { + assert_eq!(view.editing_message, None); + assert_eq!(view.thread().unwrap().read(cx).entries().len(), 2); + + view.entry_view_state + .read(cx) + .entry(0) + .unwrap() + .message_editor() + .unwrap() + .clone() + }); + + // Focus + cx.focus(&user_message_editor); + + // Edit + user_message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Edited message content", window, cx); + }); + + // Send + connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text: "New Response".into(), + annotations: None, + }), + }]); + + user_message_editor.update_in(cx, |_editor, window, cx| { + window.dispatch_action(Box::new(Chat), cx); + }); + + cx.run_until_parked(); - use super::*; + thread_view.read_with(cx, |view, cx| { + assert_eq!(view.editing_message, None); - #[gpui::test] - async fn test_drop(cx: &mut TestAppContext) { - init_test(cx); + let entries = view.thread().unwrap().read(cx).entries(); + assert_eq!(entries.len(), 2); + assert_eq!( + entries[0].to_markdown(cx), + "## User\n\nEdited 
message content\n\n" + ); + assert_eq!( + entries[1].to_markdown(cx), + "## Assistant\n\nNew Response\n\n" + ); - let (thread_view, _cx) = setup_thread_view(StubAgentServer::default(), cx).await; - let weak_view = thread_view.downgrade(); - drop(thread_view); - assert!(!weak_view.is_upgradable()); + let new_editor = view.entry_view_state.read_with(cx, |state, _cx| { + assert!(!state.entry(1).unwrap().has_content()); + state.entry(0).unwrap().message_editor().unwrap().clone() + }); + + assert_eq!(new_editor.read(cx).text(cx), "Edited message content"); + }) } #[gpui::test] - async fn test_notification_for_stop_event(cx: &mut TestAppContext) { + async fn test_message_editing_while_generating(cx: &mut TestAppContext) { init_test(cx); - let (thread_view, cx) = setup_thread_view(StubAgentServer::default(), cx).await; + let connection = StubAgentConnection::new(); + + let (thread_view, cx) = + setup_thread_view(StubAgentServer::new(connection.clone()), cx).await; + add_to_workspace(thread_view.clone(), cx); let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); message_editor.update_in(cx, |editor, window, cx| { - editor.set_text("Hello", window, cx); + editor.set_text("Original message to edit", window, cx); }); - - cx.deactivate_window(); - thread_view.update_in(cx, |thread_view, window, cx| { - thread_view.chat(&Chat, window, cx); + thread_view.send(window, cx); }); cx.run_until_parked(); - assert!( - cx.windows() - .iter() - .any(|window| window.downcast::().is_some()) - ); - } + let (user_message_editor, session_id) = thread_view.read_with(cx, |view, cx| { + let thread = view.thread().unwrap().read(cx); + assert_eq!(thread.entries().len(), 1); - #[gpui::test] - async fn test_notification_for_error(cx: &mut TestAppContext) { - init_test(cx); + let editor = view + .entry_view_state + .read(cx) + .entry(0) + .unwrap() + .message_editor() + .unwrap() + .clone(); - let (thread_view, cx) = - 
setup_thread_view(StubAgentServer::new(SaboteurAgentConnection), cx).await; + (editor, thread.session_id().clone()) + }); - let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); - message_editor.update_in(cx, |editor, window, cx| { - editor.set_text("Hello", window, cx); + // Focus + cx.focus(&user_message_editor); + + thread_view.read_with(cx, |view, _cx| { + assert_eq!(view.editing_message, Some(0)); }); - cx.deactivate_window(); + // Edit + user_message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Edited message content", window, cx); + }); - thread_view.update_in(cx, |thread_view, window, cx| { - thread_view.chat(&Chat, window, cx); + thread_view.read_with(cx, |view, _cx| { + assert_eq!(view.editing_message, Some(0)); + }); + + // Finish streaming response + cx.update(|_, cx| { + connection.send_update( + session_id.clone(), + acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text: "Response".into(), + annotations: None, + }), + }, + cx, + ); + connection.end_turn(session_id, acp::StopReason::EndTurn); + }); + + thread_view.read_with(cx, |view, _cx| { + assert_eq!(view.editing_message, Some(0)); }); cx.run_until_parked(); - assert!( - cx.windows() - .iter() - .any(|window| window.downcast::().is_some()) - ); + // Should still be editing + cx.update(|window, cx| { + assert!(user_message_editor.focus_handle(cx).is_focused(window)); + assert_eq!(thread_view.read(cx).editing_message, Some(0)); + assert_eq!( + user_message_editor.read(cx).text(cx), + "Edited message content" + ); + }); } #[gpui::test] - async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) { + async fn test_interrupt(cx: &mut TestAppContext) { init_test(cx); - let tool_call_id = acp::ToolCallId("1".into()); - let tool_call = acp::ToolCall { - id: tool_call_id.clone(), - title: "Label".into(), - kind: acp::ToolKind::Edit, - status: acp::ToolCallStatus::Pending, - content: vec!["hi".into()], 
- locations: vec![], - raw_input: None, - raw_output: None, - }; - let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)]) - .with_permission_requests(HashMap::from_iter([( - tool_call_id, - vec![acp::PermissionOption { - id: acp::PermissionOptionId("1".into()), - name: "Allow".into(), - kind: acp::PermissionOptionKind::AllowOnce, - }], - )])); - let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; + let connection = StubAgentConnection::new(); + + let (thread_view, cx) = + setup_thread_view(StubAgentServer::new(connection.clone()), cx).await; + add_to_workspace(thread_view.clone(), cx); let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); message_editor.update_in(cx, |editor, window, cx| { - editor.set_text("Hello", window, cx); + editor.set_text("Message 1", window, cx); + }); + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); }); - cx.deactivate_window(); + let (thread, session_id) = thread_view.read_with(cx, |view, cx| { + let thread = view.thread().unwrap(); - thread_view.update_in(cx, |thread_view, window, cx| { - thread_view.chat(&Chat, window, cx); + (thread.clone(), thread.read(cx).session_id().clone()) }); cx.run_until_parked(); - assert!( - cx.windows() - .iter() - .any(|window| window.downcast::().is_some()) - ); - } + cx.update(|_, cx| { + connection.send_update( + session_id.clone(), + acp::SessionUpdate::AgentMessageChunk { + content: "Message 1 resp".into(), + }, + cx, + ); + }); - async fn setup_thread_view( - agent: impl AgentServer + 'static, - cx: &mut TestAppContext, - ) -> (Entity, &mut VisualTestContext) { - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [], cx).await; - let (workspace, cx) = - cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + cx.run_until_parked(); - let thread_store = - cx.update(|_window, cx| cx.new(|cx| 
ThreadStore::fake(project.clone(), cx))); - let text_thread_store = - cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx))); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc::indoc! {" + ## User - let thread_view = cx.update(|window, cx| { - cx.new(|cx| { - AcpThreadView::new( - Rc::new(agent), - workspace.downgrade(), - project, - thread_store.clone(), - text_thread_store.clone(), - Rc::new(RefCell::new(MessageHistory::default())), - 1, - None, - window, - cx, - ) - }) - }); - cx.run_until_parked(); - (thread_view, cx) - } + Message 1 - struct StubAgentServer { - connection: C, - } + ## Assistant - impl StubAgentServer { - fn new(connection: C) -> Self { - Self { connection } - } - } + Message 1 resp - impl StubAgentServer { - fn default() -> Self { - Self::new(StubAgentConnection::default()) - } - } + "} + ) + }); - impl AgentServer for StubAgentServer - where - C: 'static + AgentConnection + Send + Clone, - { - fn logo(&self) -> ui::IconName { - unimplemented!() - } + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Message 2", window, cx); + }); + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); - fn name(&self) -> &'static str { - unimplemented!() - } + cx.update(|_, cx| { + // Simulate a response sent after beginning to cancel + connection.send_update( + session_id.clone(), + acp::SessionUpdate::AgentMessageChunk { + content: "onse".into(), + }, + cx, + ); + }); - fn empty_state_headline(&self) -> &'static str { - unimplemented!() - } + cx.run_until_parked(); - fn empty_state_message(&self) -> &'static str { - unimplemented!() - } + // Last Message 1 response should appear before Message 2 + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc::indoc! 
{" + ## User - fn connect( - &self, - _root_dir: &Path, - _project: &Entity, - _cx: &mut App, - ) -> Task>> { - Task::ready(Ok(Rc::new(self.connection.clone()))) - } - } + Message 1 - #[derive(Clone, Default)] - struct StubAgentConnection { - sessions: Arc>>>, - permission_requests: HashMap>, - updates: Vec, - } + ## Assistant - impl StubAgentConnection { - fn new(updates: Vec) -> Self { - Self { - updates, - permission_requests: HashMap::default(), - sessions: Arc::default(), - } - } + Message 1 response - fn with_permission_requests( - mut self, - permission_requests: HashMap>, - ) -> Self { - self.permission_requests = permission_requests; - self - } - } + ## User - impl AgentConnection for StubAgentConnection { - fn auth_methods(&self) -> &[acp::AuthMethod] { - &[] - } + Message 2 - fn new_thread( - self: Rc, - project: Entity, - _cwd: &Path, - cx: &mut gpui::AsyncApp, - ) -> Task>> { - let session_id = SessionId( - rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(7) - .map(char::from) - .collect::() - .into(), + "} + ) + }); + + cx.update(|_, cx| { + connection.send_update( + session_id.clone(), + acp::SessionUpdate::AgentMessageChunk { + content: "Message 2 response".into(), + }, + cx, ); - let thread = cx - .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)) - .unwrap(); - self.sessions.lock().insert(session_id, thread.downgrade()); - Task::ready(Ok(thread)) - } + connection.end_turn(session_id.clone(), acp::StopReason::EndTurn); + }); - fn authenticate( - &self, - _method_id: acp::AuthMethodId, - _cx: &mut App, - ) -> Task> { - unimplemented!() - } + cx.run_until_parked(); - fn prompt( - &self, - _id: Option, - params: acp::PromptRequest, - cx: &mut App, - ) -> Task> { - let sessions = self.sessions.lock(); - let thread = sessions.get(¶ms.session_id).unwrap(); - let mut tasks = vec![]; - for update in &self.updates { - let thread = thread.clone(); - let update = update.clone(); - let 
permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update - && let Some(options) = self.permission_requests.get(&tool_call.id) - { - Some((tool_call.clone(), options.clone())) - } else { - None - }; - let task = cx.spawn(async move |cx| { - if let Some((tool_call, options)) = permission_request { - let permission = thread.update(cx, |thread, cx| { - thread.request_tool_call_authorization( - tool_call.clone(), - options.clone(), - cx, - ) - })?; - permission.await?; - } - thread.update(cx, |thread, cx| { - thread.handle_session_update(update.clone(), cx).unwrap(); - })?; - anyhow::Ok(()) - }); - tasks.push(task); - } - cx.spawn(async move |_| { - try_join_all(tasks).await?; - Ok(acp::PromptResponse { - stop_reason: acp::StopReason::EndTurn, - }) - }) - } + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc::indoc! {" + ## User - fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { - unimplemented!() - } - } + Message 1 - #[derive(Clone)] - struct SaboteurAgentConnection; + ## Assistant - impl AgentConnection for SaboteurAgentConnection { - fn new_thread( - self: Rc, - project: Entity, - _cwd: &Path, - cx: &mut gpui::AsyncApp, - ) -> Task>> { - Task::ready(Ok(cx - .new(|cx| { - AcpThread::new( - "SaboteurAgentConnection", - self, - project, - SessionId("test".into()), - cx, - ) - }) - .unwrap())) - } + Message 1 response - fn auth_methods(&self) -> &[acp::AuthMethod] { - &[] - } + ## User - fn authenticate( - &self, - _method_id: acp::AuthMethodId, - _cx: &mut App, - ) -> Task> { - unimplemented!() - } + Message 2 - fn prompt( - &self, - _id: Option, - _params: acp::PromptRequest, - _cx: &mut App, - ) -> Task> { - Task::ready(Err(anyhow::anyhow!("Error prompting"))) - } + ## Assistant - fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { - unimplemented!() - } - } + Message 2 response - fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); 
- cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - AgentSettings::register(cx); - workspace::init_settings(cx); - ThemeSettings::register(cx); - release_channel::init(SemanticVersion::default(), cx); - EditorSettings::register(cx); + "} + ) }); } } diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index ffed62d41ff6f76c11ca6def63bff6e99df168c3..6dfadb691f3f9f7200991079da8e18bf11ce8853 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -23,9 +23,8 @@ use gpui::{ AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardEntry, ClipboardItem, DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla, ListAlignment, ListOffset, ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful, - StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation, - UnderlineStyle, WeakEntity, WindowHandle, linear_color_stop, linear_gradient, list, percentage, - pulsating_between, + StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, UnderlineStyle, + WeakEntity, WindowHandle, linear_color_stop, linear_gradient, list, pulsating_between, }; use language::{Buffer, Language, LanguageRegistry}; use language_model::{ @@ -46,8 +45,8 @@ use std::time::Duration; use text::ToPoint; use theme::ThemeSettings; use ui::{ - Banner, Disclosure, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, TextSize, - Tooltip, prelude::*, + Banner, CommonAnimationExt, Disclosure, KeyBinding, PopoverMenuHandle, Scrollbar, + ScrollbarState, TextSize, Tooltip, prelude::*, }; use util::ResultExt as _; use util::markdown::MarkdownCodeBlock; @@ -491,7 +490,7 @@ fn render_markdown_code_block( .on_click({ let active_thread = active_thread.clone(); let parsed_markdown = parsed_markdown.clone(); - let code_block_range = metadata.content_range.clone(); + let code_block_range = metadata.content_range; move 
|_event, _window, cx| { active_thread.update(cx, |this, cx| { this.copied_code_block_ids.insert((message_id, ix)); @@ -532,7 +531,6 @@ fn render_markdown_code_block( "Expand Code" })) .on_click({ - let active_thread = active_thread.clone(); move |_event, _window, cx| { active_thread.update(cx, |this, cx| { this.toggle_codeblock_expanded(message_id, ix); @@ -780,13 +778,11 @@ impl ActiveThread { let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.)); - let workspace_subscription = if let Some(workspace) = workspace.upgrade() { - Some(cx.observe_release(&workspace, |this, _, cx| { + let workspace_subscription = workspace.upgrade().map(|workspace| { + cx.observe_release(&workspace, |this, _, cx| { this.dismiss_notifications(cx); - })) - } else { - None - }; + }) + }); let mut this = Self { language_registry, @@ -916,7 +912,7 @@ impl ActiveThread { ) { let rendered = self .rendered_tool_uses - .entry(tool_use_id.clone()) + .entry(tool_use_id) .or_insert_with(|| RenderedToolUse { label: cx.new(|cx| { Markdown::new("".into(), Some(self.language_registry.clone()), None, cx) @@ -1005,8 +1001,22 @@ impl ActiveThread { // Don't notify for intermediate tool use } Ok(StopReason::Refusal) => { + let model_name = self + .thread + .read(cx) + .configured_model() + .map(|configured| configured.model.name().0.to_string()) + .unwrap_or_else(|| "The model".to_string()); + let refusal_message = format!( + "{} refused to respond to this prompt. 
This can happen when a model believes the prompt violates its content policy or safety guidelines, so rephrasing it can sometimes address the issue.", + model_name + ); + self.last_error = Some(ThreadError::Message { + header: SharedString::from("Request Refused"), + message: SharedString::from(refusal_message), + }); self.notify_with_sound( - "Language model refused to respond", + format!("{} refused to respond", model_name), IconName::Warning, window, cx, @@ -1044,12 +1054,12 @@ impl ActiveThread { ); } ThreadEvent::StreamedAssistantText(message_id, text) => { - if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) { + if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(message_id) { rendered_message.append_text(text, cx); } } ThreadEvent::StreamedAssistantThinking(message_id, text) => { - if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) { + if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(message_id) { rendered_message.append_thinking(text, cx); } } @@ -1072,8 +1082,8 @@ impl ActiveThread { } ThreadEvent::MessageEdited(message_id) => { self.clear_last_error(); - if let Some(index) = self.messages.iter().position(|id| id == message_id) { - if let Some(rendered_message) = self.thread.update(cx, |thread, cx| { + if let Some(index) = self.messages.iter().position(|id| id == message_id) + && let Some(rendered_message) = self.thread.update(cx, |thread, cx| { thread.message(*message_id).map(|message| { let mut rendered_message = RenderedMessage { language_registry: self.language_registry.clone(), @@ -1084,14 +1094,14 @@ impl ActiveThread { } rendered_message }) - }) { - self.list_state.splice(index..index + 1, 1); - self.rendered_messages_by_id - .insert(*message_id, rendered_message); - self.scroll_to_bottom(cx); - self.save_thread(cx); - cx.notify(); - } + }) + { + self.list_state.splice(index..index + 1, 1); + self.rendered_messages_by_id + .insert(*message_id, 
rendered_message); + self.scroll_to_bottom(cx); + self.save_thread(cx); + cx.notify(); } } ThreadEvent::MessageDeleted(message_id) => { @@ -1218,7 +1228,7 @@ impl ActiveThread { match AgentSettings::get_global(cx).notify_when_agent_waiting { NotifyWhenAgentWaiting::PrimaryScreen => { if let Some(primary) = cx.primary_display() { - self.pop_up(icon, caption.into(), title.clone(), window, primary, cx); + self.pop_up(icon, caption.into(), title, window, primary, cx); } } NotifyWhenAgentWaiting::AllScreens => { @@ -1272,62 +1282,61 @@ impl ActiveThread { }) }) .log_err() + && let Some(pop_up) = screen_window.entity(cx).log_err() { - if let Some(pop_up) = screen_window.entity(cx).log_err() { - self.notification_subscriptions - .entry(screen_window) - .or_insert_with(Vec::new) - .push(cx.subscribe_in(&pop_up, window, { - |this, _, event, window, cx| match event { - AgentNotificationEvent::Accepted => { - let handle = window.window_handle(); - cx.activate(true); - - let workspace_handle = this.workspace.clone(); - - // If there are multiple Zed windows, activate the correct one. - cx.defer(move |cx| { - handle - .update(cx, |_view, window, _cx| { - window.activate_window(); - - if let Some(workspace) = workspace_handle.upgrade() { - workspace.update(_cx, |workspace, cx| { - workspace.focus_panel::(window, cx); - }); - } - }) - .log_err(); - }); + self.notification_subscriptions + .entry(screen_window) + .or_insert_with(Vec::new) + .push(cx.subscribe_in(&pop_up, window, { + |this, _, event, window, cx| match event { + AgentNotificationEvent::Accepted => { + let handle = window.window_handle(); + cx.activate(true); + + let workspace_handle = this.workspace.clone(); + + // If there are multiple Zed windows, activate the correct one. 
+ cx.defer(move |cx| { + handle + .update(cx, |_view, window, _cx| { + window.activate_window(); + + if let Some(workspace) = workspace_handle.upgrade() { + workspace.update(_cx, |workspace, cx| { + workspace.focus_panel::(window, cx); + }); + } + }) + .log_err(); + }); - this.dismiss_notifications(cx); - } - AgentNotificationEvent::Dismissed => { - this.dismiss_notifications(cx); - } + this.dismiss_notifications(cx); } - })); - - self.notifications.push(screen_window); - - // If the user manually refocuses the original window, dismiss the popup. - self.notification_subscriptions - .entry(screen_window) - .or_insert_with(Vec::new) - .push({ - let pop_up_weak = pop_up.downgrade(); - - cx.observe_window_activation(window, move |_, window, cx| { - if window.is_window_active() { - if let Some(pop_up) = pop_up_weak.upgrade() { - pop_up.update(cx, |_, cx| { - cx.emit(AgentNotificationEvent::Dismissed); - }); - } - } - }) - }); - } + AgentNotificationEvent::Dismissed => { + this.dismiss_notifications(cx); + } + } + })); + + self.notifications.push(screen_window); + + // If the user manually refocuses the original window, dismiss the popup. 
+ self.notification_subscriptions + .entry(screen_window) + .or_insert_with(Vec::new) + .push({ + let pop_up_weak = pop_up.downgrade(); + + cx.observe_window_activation(window, move |_, window, cx| { + if window.is_window_active() + && let Some(pop_up) = pop_up_weak.upgrade() + { + pop_up.update(cx, |_, cx| { + cx.emit(AgentNotificationEvent::Dismissed); + }); + } + }) + }); } } @@ -1374,12 +1383,12 @@ impl ActiveThread { editor.focus_handle(cx).focus(window); editor.move_to_end(&editor::actions::MoveToEnd, window, cx); }); - let buffer_edited_subscription = cx.subscribe(&editor, |this, _, event, cx| match event { - EditorEvent::BufferEdited => { - this.update_editing_message_token_count(true, cx); - } - _ => {} - }); + let buffer_edited_subscription = + cx.subscribe(&editor, |this, _, event: &EditorEvent, cx| { + if event == &EditorEvent::BufferEdited { + this.update_editing_message_token_count(true, cx); + } + }); let context_picker_menu_handle = PopoverMenuHandle::default(); let context_strip = cx.new(|cx| { @@ -1599,11 +1608,6 @@ impl ActiveThread { return; }; - if model.provider.must_accept_terms(cx) { - cx.notify(); - return; - } - let edited_text = state.editor.read(cx).text(cx); let creases = state.editor.update(cx, extract_message_creases); @@ -1738,6 +1742,7 @@ impl ActiveThread { ); editor.set_placeholder_text( "What went wrong? 
Share your feedback so we can improve.", + window, cx, ); editor @@ -1766,7 +1771,7 @@ impl ActiveThread { .thread .read(cx) .message(message_id) - .map(|msg| msg.to_string()) + .map(|msg| msg.to_message_content()) .unwrap_or_default(); telemetry::event!( @@ -2113,7 +2118,7 @@ impl ActiveThread { .gap_1() .children(message_content) .when_some(editing_message_state, |this, state| { - let focus_handle = state.editor.focus_handle(cx).clone(); + let focus_handle = state.editor.focus_handle(cx); this.child( h_flex() @@ -2174,7 +2179,6 @@ impl ActiveThread { .icon_color(Color::Muted) .icon_size(IconSize::Small) .tooltip({ - let focus_handle = focus_handle.clone(); move |window, cx| { Tooltip::for_action_in( "Regenerate", @@ -2247,9 +2251,7 @@ impl ActiveThread { let after_editing_message = self .editing_message .as_ref() - .map_or(false, |(editing_message_id, _)| { - message_id > *editing_message_id - }); + .is_some_and(|(editing_message_id, _)| message_id > *editing_message_id); let backdrop = div() .id(("backdrop", ix)) @@ -2269,13 +2271,12 @@ impl ActiveThread { let mut error = None; if let Some(last_restore_checkpoint) = self.thread.read(cx).last_restore_checkpoint() + && last_restore_checkpoint.message_id() == message_id { - if last_restore_checkpoint.message_id() == message_id { - match last_restore_checkpoint { - LastRestoreCheckpoint::Pending { .. } => is_pending = true, - LastRestoreCheckpoint::Error { error: err, .. } => { - error = Some(err.clone()); - } + match last_restore_checkpoint { + LastRestoreCheckpoint::Pending { .. } => is_pending = true, + LastRestoreCheckpoint::Error { error: err, .. 
} => { + error = Some(err.clone()); } } } @@ -2316,7 +2317,7 @@ impl ActiveThread { .into_any_element() } else if let Some(error) = error { restore_checkpoint_button - .tooltip(Tooltip::text(error.to_string())) + .tooltip(Tooltip::text(error)) .into_any_element() } else { restore_checkpoint_button.into_any_element() @@ -2357,7 +2358,6 @@ impl ActiveThread { this.submit_feedback_message(message_id, cx); cx.notify(); })) - .on_action(cx.listener(Self::confirm_editing_message)) .mb_2() .mx_4() .p_2() @@ -2473,7 +2473,7 @@ impl ActiveThread { message_id, index, content.clone(), - &scroll_handle, + scroll_handle, Some(index) == pending_thinking_segment_index, window, cx, @@ -2597,7 +2597,7 @@ impl ActiveThread { .id(("message-container", ix)) .py_1() .px_2p5() - .child(Banner::new().severity(ui::Severity::Warning).child(message)) + .child(Banner::new().severity(Severity::Warning).child(message)) } fn render_message_thinking_segment( @@ -2661,15 +2661,7 @@ impl ActiveThread { Icon::new(IconName::ArrowCircle) .color(Color::Accent) .size(IconSize::Small) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| { - icon.transform(Transformation::rotate( - percentage(delta), - )) - }, - ) + .with_rotate_animation(2) }), ), ) @@ -2845,17 +2837,11 @@ impl ActiveThread { } ToolUseStatus::Pending | ToolUseStatus::InputStillStreaming - | ToolUseStatus::Running => { - let icon = Icon::new(IconName::ArrowCircle) - .color(Color::Accent) - .size(IconSize::Small); - icon.with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - ) - .into_any_element() - } + | ToolUseStatus::Running => Icon::new(IconName::ArrowCircle) + .color(Color::Accent) + .size(IconSize::Small) + .with_rotate_animation(2) + .into_any_element(), ToolUseStatus::Finished(_) => div().w_0().into_any_element(), ToolUseStatus::Error(_) => { let icon = 
Icon::new(IconName::Close) @@ -2944,15 +2930,7 @@ impl ActiveThread { Icon::new(IconName::ArrowCircle) .size(IconSize::Small) .color(Color::Accent) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| { - icon.transform(Transformation::rotate(percentage( - delta, - ))) - }, - ), + .with_rotate_animation(2), ) .child( Label::new("Running…") @@ -3608,7 +3586,7 @@ pub(crate) fn open_active_thread_as_markdown( } let buffer = project.update(cx, |project, cx| { - project.create_local_buffer(&markdown, Some(markdown_language), cx) + project.create_local_buffer(&markdown, Some(markdown_language), true, cx) }); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx).with_title(thread_summary.clone())); @@ -4020,7 +3998,7 @@ mod tests { cx.run_until_parked(); - // Verify that the previous completion was cancelled + // Verify that the previous completion was canceled assert_eq!(cancellation_events.lock().unwrap().len(), 1); // Verify that a new request was started after cancellation diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 5f72fa58c8489a21a7f386cea3b2678af37ba44f..4ad5b2d8e84779aeaa928133a11dc1a10d9a02bc 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -3,19 +3,21 @@ mod configure_context_server_modal; mod manage_profiles_modal; mod tool_picker; -use std::{sync::Arc, time::Duration}; +use std::{ops::Range, sync::Arc}; use agent_settings::AgentSettings; +use anyhow::Result; use assistant_tool::{ToolSource, ToolWorkingSet}; -use cloud_llm_client::Plan; +use cloud_llm_client::{Plan, PlanV1, PlanV2}; use collections::HashMap; use context_server::ContextServerId; +use editor::{Editor, SelectionEffects, scroll::Autoscroll}; use extension::ExtensionManifest; use extension_host::ExtensionStore; use fs::Fs; use gpui::{ - Action, Animation, AnimationExt as _, AnyView, App, Corner, Entity, EventEmitter, 
FocusHandle, - Focusable, ScrollHandle, Subscription, Task, Transformation, WeakEntity, percentage, + Action, AnyView, App, AsyncWindowContext, Corner, Entity, EventEmitter, FocusHandle, Focusable, + Hsla, ScrollHandle, Subscription, Task, WeakEntity, }; use language::LanguageRegistry; use language_model::{ @@ -23,29 +25,36 @@ use language_model::{ }; use notifications::status_toast::{StatusToast, ToastIcon}; use project::{ + agent_server_store::{ + AgentServerCommand, AgentServerStore, AllAgentServersSettings, CLAUDE_CODE_NAME, + CustomAgentServerSettings, GEMINI_NAME, + }, context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore}, project_settings::{ContextServerSettings, ProjectSettings}, }; -use settings::{Settings, update_settings_file}; +use settings::{Settings, SettingsStore, update_settings_file}; use ui::{ - Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu, - Scrollbar, ScrollbarState, Switch, SwitchColor, SwitchField, Tooltip, prelude::*, + Chip, CommonAnimationExt, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, + Indicator, PopoverMenu, Scrollbar, ScrollbarState, Switch, SwitchColor, SwitchField, Tooltip, + prelude::*, }; use util::ResultExt as _; -use workspace::Workspace; +use workspace::{Workspace, create_and_open_local_file}; use zed_actions::ExtensionCategoryFilter; pub(crate) use configure_context_server_modal::ConfigureContextServerModal; pub(crate) use manage_profiles_modal::ManageProfilesModal; use crate::{ - AddContextServer, + AddContextServer, ExternalAgent, NewExternalAgentThread, agent_configuration::add_llm_provider_modal::{AddLlmProviderModal, LlmCompatibleProvider}, + placeholder_command, }; pub struct AgentConfiguration { fs: Arc, language_registry: Arc, + agent_server_store: Entity, workspace: WeakEntity, focus_handle: FocusHandle, configuration_views_by_provider: HashMap, @@ -56,11 +65,13 @@ pub struct AgentConfiguration { 
_registry_subscription: Subscription, scroll_handle: ScrollHandle, scrollbar_state: ScrollbarState, + _check_for_gemini: Task<()>, } impl AgentConfiguration { pub fn new( fs: Arc, + agent_server_store: Entity, context_server_store: Entity, tools: Entity, language_registry: Arc, @@ -93,27 +104,21 @@ impl AgentConfiguration { let scroll_handle = ScrollHandle::new(); let scrollbar_state = ScrollbarState::new(scroll_handle.clone()); - let mut expanded_provider_configurations = HashMap::default(); - if LanguageModelRegistry::read_global(cx) - .provider(&ZED_CLOUD_PROVIDER_ID) - .map_or(false, |cloud_provider| cloud_provider.must_accept_terms(cx)) - { - expanded_provider_configurations.insert(ZED_CLOUD_PROVIDER_ID, true); - } - let mut this = Self { fs, language_registry, workspace, focus_handle, configuration_views_by_provider: HashMap::default(), + agent_server_store, context_server_store, expanded_context_server_tools: HashMap::default(), - expanded_provider_configurations, + expanded_provider_configurations: HashMap::default(), tools, _registry_subscription: registry_subscription, scroll_handle, scrollbar_state, + _check_for_gemini: Task::ready(()), }; this.build_provider_configuration_views(window, cx); this @@ -137,7 +142,11 @@ impl AgentConfiguration { window: &mut Window, cx: &mut Context, ) { - let configuration_view = provider.configuration_view(window, cx); + let configuration_view = provider.configuration_view( + language_model::ConfigurationViewTargetAgent::ZedAgent, + window, + cx, + ); self.configuration_views_by_provider .insert(provider.id(), configuration_view); } @@ -161,8 +170,8 @@ impl AgentConfiguration { provider: &Arc, cx: &mut Context, ) -> impl IntoElement + use<> { - let provider_id = provider.id().0.clone(); - let provider_name = provider.name().0.clone(); + let provider_id = provider.id().0; + let provider_name = provider.name().0; let provider_id_string = SharedString::from(format!("provider-disclosure-{provider_id}")); let 
configuration_view = self @@ -188,7 +197,7 @@ impl AgentConfiguration { let is_signed_in = self .workspace .read_with(cx, |workspace, _| { - workspace.client().status().borrow().is_connected() + !workspace.client().status().borrow().is_signed_out() }) .unwrap_or(false); @@ -215,7 +224,6 @@ impl AgentConfiguration { .child( h_flex() .id(provider_id_string.clone()) - .cursor_pointer() .px_2() .py_0p5() .w_full() @@ -235,10 +243,7 @@ impl AgentConfiguration { h_flex() .w_full() .gap_1() - .child( - Label::new(provider_name.clone()) - .size(LabelSize::Large), - ) + .child(Label::new(provider_name.clone())) .map(|this| { if is_zed_provider && is_signed_in { this.child( @@ -265,7 +270,7 @@ impl AgentConfiguration { .closed_icon(IconName::ChevronDown), ) .on_click(cx.listener({ - let provider_id = provider.id().clone(); + let provider_id = provider.id(); move |this, _event, _window, _cx| { let is_expanded = this .expanded_provider_configurations @@ -283,7 +288,7 @@ impl AgentConfiguration { "Start New Thread", ) .icon_position(IconPosition::Start) - .icon(IconName::Plus) + .icon(IconName::Thread) .icon_size(IconSize::Small) .icon_color(Color::Muted) .label_size(LabelSize::Small) @@ -300,6 +305,7 @@ impl AgentConfiguration { ) .child( div() + .w_full() .px_2() .when(is_expanded, |parent| match configuration_view { Some(configuration_view) => parent.child(configuration_view), @@ -332,6 +338,7 @@ impl AgentConfiguration { .gap_0p5() .child( h_flex() + .pr_1() .w_full() .gap_2() .justify_between() @@ -381,7 +388,7 @@ impl AgentConfiguration { ), ) .child( - Label::new("Add at least one provider to use AI-powered features.") + Label::new("Add at least one provider to use AI-powered features with Zed's native agent.") .color(Color::Muted), ), ), @@ -465,7 +472,7 @@ impl AgentConfiguration { "modifier-send", "Use modifier to submit a message", Some( - "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(), + "Make a modifier (cmd-enter on 
macOS, ctrl-enter on Linux or Windows) required to send messages.".into(), ), use_modifier_to_send, move |state, _window, cx| { @@ -508,9 +515,15 @@ impl AgentConfiguration { .blend(cx.theme().colors().text_accent.opacity(0.2)); let (plan_name, label_color, bg_color) = match plan { - Plan::ZedFree => ("Free", Color::Default, free_chip_bg), - Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg), - Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg), + Plan::V1(PlanV1::ZedFree) | Plan::V2(PlanV2::ZedFree) => { + ("Free", Color::Default, free_chip_bg) + } + Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial) => { + ("Pro Trial", Color::Accent, pro_chip_bg) + } + Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro) => { + ("Pro", Color::Accent, pro_chip_bg) + } }; Chip::new(plan_name.to_string()) @@ -522,6 +535,14 @@ impl AgentConfiguration { } } + fn card_item_bg_color(&self, cx: &mut Context) -> Hsla { + cx.theme().colors().background.opacity(0.25) + } + + fn card_item_border_color(&self, cx: &mut Context) -> Hsla { + cx.theme().colors().border.opacity(0.6) + } + fn render_context_servers_section( &mut self, window: &mut Window, @@ -539,7 +560,12 @@ impl AgentConfiguration { v_flex() .gap_0p5() .child(Headline::new("Model Context Protocol (MCP) Servers")) - .child(Label::new("Connect to context servers through the Model Context Protocol, either using Zed extensions or directly.").color(Color::Muted)), + .child( + Label::new( + "All context servers connected through the Model Context Protocol.", + ) + .color(Color::Muted), + ), ) .children( context_server_ids.into_iter().map(|context_server_id| { @@ -549,7 +575,7 @@ impl AgentConfiguration { .child( h_flex() .justify_between() - .gap_2() + .gap_1p5() .child( h_flex().w_full().child( Button::new("add-context-server", "Add Custom Server") @@ -640,8 +666,6 @@ impl AgentConfiguration { .map_or([].as_slice(), |tools| tools.as_slice()); let tool_count = tools.len(); - let border_color = 
cx.theme().colors().border.opacity(0.6); - let (source_icon, source_tooltip) = if is_from_extension { ( IconName::ZedMcpExtension, @@ -659,10 +683,9 @@ impl AgentConfiguration { Icon::new(IconName::LoadCircle) .size(IconSize::XSmall) .color(Color::Accent) - .with_animation( - SharedString::from(format!("{}-starting", context_server_id.0.clone(),)), - Animation::new(Duration::from_secs(3)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), + .with_keyed_rotate_animation( + SharedString::from(format!("{}-starting", context_server_id.0)), + 3, ) .into_any_element(), "Server is starting.", @@ -784,8 +807,8 @@ impl AgentConfiguration { .id(item_id.clone()) .border_1() .rounded_md() - .border_color(border_color) - .bg(cx.theme().colors().background.opacity(0.2)) + .border_color(self.card_item_border_color(cx)) + .bg(self.card_item_bg_color(cx)) .overflow_hidden() .child( h_flex() @@ -793,7 +816,11 @@ impl AgentConfiguration { .justify_between() .when( error.is_some() || are_tools_expanded && tool_count >= 1, - |element| element.border_b_1().border_color(border_color), + |element| { + element + .border_b_1() + .border_color(self.card_item_border_color(cx)) + }, ) .child( h_flex() @@ -860,7 +887,6 @@ impl AgentConfiguration { .on_click({ let context_server_manager = self.context_server_store.clone(); - let context_server_id = context_server_id.clone(); let fs = self.fs.clone(); move |state, _window, cx| { @@ -953,7 +979,7 @@ impl AgentConfiguration { } parent.child(v_flex().py_1p5().px_1().gap_1().children( - tools.into_iter().enumerate().map(|(ix, tool)| { + tools.iter().enumerate().map(|(ix, tool)| { h_flex() .id(("tool-item", ix)) .px_1() @@ -976,6 +1002,149 @@ impl AgentConfiguration { )) }) } + + fn render_agent_servers_section(&mut self, cx: &mut Context) -> impl IntoElement { + let custom_settings = cx + .global::() + .get::(None) + .custom + .clone(); + let user_defined_agents = self + .agent_server_store + .read(cx) + 
.external_agents() + .filter(|name| name.0 != GEMINI_NAME && name.0 != CLAUDE_CODE_NAME) + .cloned() + .collect::>(); + let user_defined_agents = user_defined_agents + .into_iter() + .map(|name| { + self.render_agent_server( + IconName::Ai, + name.clone(), + ExternalAgent::Custom { + name: name.clone().into(), + command: custom_settings + .get(&name.0) + .map(|settings| settings.command.clone()) + .unwrap_or(placeholder_command()), + }, + cx, + ) + .into_any_element() + }) + .collect::>(); + + v_flex() + .border_b_1() + .border_color(cx.theme().colors().border) + .child( + v_flex() + .p(DynamicSpacing::Base16.rems(cx)) + .pr(DynamicSpacing::Base20.rems(cx)) + .gap_2() + .child( + v_flex() + .gap_0p5() + .child( + h_flex() + .pr_1() + .w_full() + .gap_2() + .justify_between() + .child(Headline::new("External Agents")) + .child( + Button::new("add-agent", "Add Agent") + .icon_position(IconPosition::Start) + .icon(IconName::Plus) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .label_size(LabelSize::Small) + .on_click( + move |_, window, cx| { + if let Some(workspace) = window.root().flatten() { + let workspace = workspace.downgrade(); + window + .spawn(cx, async |cx| { + open_new_agent_servers_entry_in_settings_editor( + workspace, + cx, + ).await + }) + .detach_and_log_err(cx); + } + } + ), + ) + ) + .child( + Label::new( + "All agents connected through the Agent Client Protocol.", + ) + .color(Color::Muted), + ), + ) + .child(self.render_agent_server( + IconName::AiGemini, + "Gemini CLI", + ExternalAgent::Gemini, + cx, + )) + .child(self.render_agent_server( + IconName::AiClaude, + "Claude Code", + ExternalAgent::ClaudeCode, + cx, + )) + .children(user_defined_agents), + ) + } + + fn render_agent_server( + &self, + icon: IconName, + name: impl Into, + agent: ExternalAgent, + cx: &mut Context, + ) -> impl IntoElement { + let name = name.into(); + h_flex() + .p_1() + .pl_2() + .gap_1p5() + .justify_between() + .border_1() + .rounded_md() + 
.border_color(self.card_item_border_color(cx)) + .bg(self.card_item_bg_color(cx)) + .overflow_hidden() + .child( + h_flex() + .gap_1p5() + .child(Icon::new(icon).size(IconSize::Small).color(Color::Muted)) + .child(Label::new(name.clone())), + ) + .child( + Button::new( + SharedString::from(format!("start_acp_thread-{name}")), + "Start New Thread", + ) + .label_size(LabelSize::Small) + .icon(IconName::Thread) + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(move |_, window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some(agent.clone()), + } + .boxed_clone(), + cx, + ); + }), + ) + } } impl Render for AgentConfiguration { @@ -995,6 +1164,7 @@ impl Render for AgentConfiguration { .size_full() .overflow_y_scroll() .child(self.render_general_settings_section(cx)) + .child(self.render_agent_servers_section(cx)) .child(self.render_context_servers_section(window, cx)) .child(self.render_provider_configuration_section(cx)), ) @@ -1035,7 +1205,6 @@ fn extension_only_provides_context_server(manifest: &ExtensionManifest) -> bool && manifest.grammars.is_empty() && manifest.language_servers.is_empty() && manifest.slash_commands.is_empty() - && manifest.indexed_docs_providers.is_empty() && manifest.snippets.is_none() && manifest.debug_locators.is_empty() } @@ -1071,7 +1240,6 @@ fn show_unable_to_uninstall_extension_with_context_server( cx, move |this, _cx| { let workspace_handle = workspace_handle.clone(); - let context_server_id = context_server_id.clone(); this.icon(ToastIcon::new(IconName::Warning).color(Color::Warning)) .dismiss_button(true) @@ -1115,3 +1283,110 @@ fn show_unable_to_uninstall_extension_with_context_server( workspace.toggle_status_toast(status_toast, cx); } + +async fn open_new_agent_servers_entry_in_settings_editor( + workspace: WeakEntity, + cx: &mut AsyncWindowContext, +) -> Result<()> { + let settings_editor = workspace + .update_in(cx, |_, window, cx| { + 
create_and_open_local_file(paths::settings_file(), window, cx, || { + settings::initial_user_settings_content().as_ref().into() + }) + })? + .await? + .downcast::() + .unwrap(); + + settings_editor + .downgrade() + .update_in(cx, |item, window, cx| { + let text = item.buffer().read(cx).snapshot(cx).text(); + + let settings = cx.global::(); + + let mut unique_server_name = None; + let edits = settings.edits_for_update::(&text, |file| { + let server_name: Option = (0..u8::MAX) + .map(|i| { + if i == 0 { + "your_agent".into() + } else { + format!("your_agent_{}", i).into() + } + }) + .find(|name| !file.custom.contains_key(name)); + if let Some(server_name) = server_name { + unique_server_name = Some(server_name.clone()); + file.custom.insert( + server_name, + CustomAgentServerSettings { + command: AgentServerCommand { + path: "path_to_executable".into(), + args: vec![], + env: Some(HashMap::default()), + }, + default_mode: None, + }, + ); + } + }); + + if edits.is_empty() { + return; + } + + let ranges = edits + .iter() + .map(|(range, _)| range.clone()) + .collect::>(); + + item.edit(edits, cx); + if let Some((unique_server_name, buffer)) = + unique_server_name.zip(item.buffer().read(cx).as_singleton()) + { + let snapshot = buffer.read(cx).snapshot(); + if let Some(range) = + find_text_in_buffer(&unique_server_name, ranges[0].start, &snapshot) + { + item.change_selections( + SelectionEffects::scroll(Autoscroll::newest()), + window, + cx, + |selections| { + selections.select_ranges(vec![range]); + }, + ); + } + } + }) +} + +fn find_text_in_buffer( + text: &str, + start: usize, + snapshot: &language::BufferSnapshot, +) -> Option> { + let chars = text.chars().collect::>(); + + let mut offset = start; + let mut char_offset = 0; + for c in snapshot.chars_at(start) { + if char_offset >= chars.len() { + break; + } + offset += 1; + + if c == chars[char_offset] { + char_offset += 1; + } else { + char_offset = 0; + } + } + + if char_offset == chars.len() { + 
Some(offset.saturating_sub(chars.len())..offset) + } else { + None + } +} diff --git a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs index 401a6334886e18ef2e53bbd5b68392597d0db1e9..182831f488870997d175cce0ad7e1c94e392f1ea 100644 --- a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs +++ b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs @@ -7,10 +7,12 @@ use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, T use language_model::LanguageModelRegistry; use language_models::{ AllLanguageModelSettings, OpenAiCompatibleSettingsContent, - provider::open_ai_compatible::AvailableModel, + provider::open_ai_compatible::{AvailableModel, ModelCapabilities}, }; use settings::update_settings_file; -use ui::{Banner, KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*}; +use ui::{ + Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState, prelude::*, +}; use ui_input::SingleLineInput; use workspace::{ModalView, Workspace}; @@ -69,11 +71,19 @@ impl AddLlmProviderInput { } } +struct ModelCapabilityToggles { + pub supports_tools: ToggleState, + pub supports_images: ToggleState, + pub supports_parallel_tool_calls: ToggleState, + pub supports_prompt_cache_key: ToggleState, +} + struct ModelInput { name: Entity, max_completion_tokens: Entity, max_output_tokens: Entity, max_tokens: Entity, + capabilities: ModelCapabilityToggles, } impl ModelInput { @@ -100,11 +110,23 @@ impl ModelInput { cx, ); let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx); + let ModelCapabilities { + tools, + images, + parallel_tool_calls, + prompt_cache_key, + } = ModelCapabilities::default(); Self { name: model_name, max_completion_tokens, max_output_tokens, max_tokens, + capabilities: ModelCapabilityToggles { + supports_tools: tools.into(), + supports_images: images.into(), + 
supports_parallel_tool_calls: parallel_tool_calls.into(), + supports_prompt_cache_key: prompt_cache_key.into(), + }, } } @@ -136,6 +158,12 @@ impl ModelInput { .text(cx) .parse::() .map_err(|_| SharedString::from("Max Tokens must be a number"))?, + capabilities: ModelCapabilities { + tools: self.capabilities.supports_tools.selected(), + images: self.capabilities.supports_images.selected(), + parallel_tool_calls: self.capabilities.supports_parallel_tool_calls.selected(), + prompt_cache_key: self.capabilities.supports_prompt_cache_key.selected(), + }, }) } } @@ -322,6 +350,55 @@ impl AddLlmProviderModal { .child(model.max_output_tokens.clone()), ) .child(model.max_tokens.clone()) + .child( + v_flex() + .gap_1() + .child( + Checkbox::new(("supports-tools", ix), model.capabilities.supports_tools) + .label("Supports tools") + .on_click(cx.listener(move |this, checked, _window, cx| { + this.input.models[ix].capabilities.supports_tools = *checked; + cx.notify(); + })), + ) + .child( + Checkbox::new(("supports-images", ix), model.capabilities.supports_images) + .label("Supports images") + .on_click(cx.listener(move |this, checked, _window, cx| { + this.input.models[ix].capabilities.supports_images = *checked; + cx.notify(); + })), + ) + .child( + Checkbox::new( + ("supports-parallel-tool-calls", ix), + model.capabilities.supports_parallel_tool_calls, + ) + .label("Supports parallel_tool_calls") + .on_click(cx.listener( + move |this, checked, _window, cx| { + this.input.models[ix] + .capabilities + .supports_parallel_tool_calls = *checked; + cx.notify(); + }, + )), + ) + .child( + Checkbox::new( + ("supports-prompt-cache-key", ix), + model.capabilities.supports_prompt_cache_key, + ) + .label("Supports prompt_cache_key") + .on_click(cx.listener( + move |this, checked, _window, cx| { + this.input.models[ix].capabilities.supports_prompt_cache_key = + *checked; + cx.notify(); + }, + )), + ), + ) .when(has_more_than_one_model, |this| { this.child( Button::new(("remove-model", 
ix), "Remove Model") @@ -377,7 +454,7 @@ impl Render for AddLlmProviderModal { this.section( Section::new().child( Banner::new() - .severity(ui::Severity::Warning) + .severity(Severity::Warning) .child(div().text_xs().child(error)), ), ) @@ -562,6 +639,93 @@ mod tests { ); } + #[gpui::test] + async fn test_model_input_default_capabilities(cx: &mut TestAppContext) { + let cx = setup_test(cx).await; + + cx.update(|window, cx| { + let model_input = ModelInput::new(window, cx); + model_input.name.update(cx, |input, cx| { + input.editor().update(cx, |editor, cx| { + editor.set_text("somemodel", window, cx); + }); + }); + assert_eq!( + model_input.capabilities.supports_tools, + ToggleState::Selected + ); + assert_eq!( + model_input.capabilities.supports_images, + ToggleState::Unselected + ); + assert_eq!( + model_input.capabilities.supports_parallel_tool_calls, + ToggleState::Unselected + ); + assert_eq!( + model_input.capabilities.supports_prompt_cache_key, + ToggleState::Unselected + ); + + let parsed_model = model_input.parse(cx).unwrap(); + assert!(parsed_model.capabilities.tools); + assert!(!parsed_model.capabilities.images); + assert!(!parsed_model.capabilities.parallel_tool_calls); + assert!(!parsed_model.capabilities.prompt_cache_key); + }); + } + + #[gpui::test] + async fn test_model_input_deselected_capabilities(cx: &mut TestAppContext) { + let cx = setup_test(cx).await; + + cx.update(|window, cx| { + let mut model_input = ModelInput::new(window, cx); + model_input.name.update(cx, |input, cx| { + input.editor().update(cx, |editor, cx| { + editor.set_text("somemodel", window, cx); + }); + }); + + model_input.capabilities.supports_tools = ToggleState::Unselected; + model_input.capabilities.supports_images = ToggleState::Unselected; + model_input.capabilities.supports_parallel_tool_calls = ToggleState::Unselected; + model_input.capabilities.supports_prompt_cache_key = ToggleState::Unselected; + + let parsed_model = model_input.parse(cx).unwrap(); + 
assert!(!parsed_model.capabilities.tools); + assert!(!parsed_model.capabilities.images); + assert!(!parsed_model.capabilities.parallel_tool_calls); + assert!(!parsed_model.capabilities.prompt_cache_key); + }); + } + + #[gpui::test] + async fn test_model_input_with_name_and_capabilities(cx: &mut TestAppContext) { + let cx = setup_test(cx).await; + + cx.update(|window, cx| { + let mut model_input = ModelInput::new(window, cx); + model_input.name.update(cx, |input, cx| { + input.editor().update(cx, |editor, cx| { + editor.set_text("somemodel", window, cx); + }); + }); + + model_input.capabilities.supports_tools = ToggleState::Selected; + model_input.capabilities.supports_images = ToggleState::Unselected; + model_input.capabilities.supports_parallel_tool_calls = ToggleState::Selected; + model_input.capabilities.supports_prompt_cache_key = ToggleState::Unselected; + + let parsed_model = model_input.parse(cx).unwrap(); + assert_eq!(parsed_model.name, "somemodel"); + assert!(parsed_model.capabilities.tools); + assert!(!parsed_model.capabilities.images); + assert!(parsed_model.capabilities.parallel_tool_calls); + assert!(!parsed_model.capabilities.prompt_cache_key); + }); + } + async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext { cx.update(|cx| { let store = SettingsStore::test(cx); diff --git a/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs b/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs index 32360dd56ef925d56310ff7e2e5668de1973f472..4d338840143fbcf007f7d5c66e2406ef4bb9fc88 100644 --- a/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs +++ b/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs @@ -1,16 +1,14 @@ use std::{ path::PathBuf, sync::{Arc, Mutex}, - time::Duration, }; use anyhow::{Context as _, Result}; use context_server::{ContextServerCommand, ContextServerId}; use editor::{Editor, EditorElement, EditorStyle}; use gpui::{ - Animation, 
AnimationExt as _, AsyncWindowContext, DismissEvent, Entity, EventEmitter, - FocusHandle, Focusable, Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, - WeakEntity, percentage, prelude::*, + AsyncWindowContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Task, + TextStyle, TextStyleRefinement, UnderlineStyle, WeakEntity, prelude::*, }; use language::{Language, LanguageRegistry}; use markdown::{Markdown, MarkdownElement, MarkdownStyle}; @@ -24,7 +22,9 @@ use project::{ }; use settings::{Settings as _, update_settings_file}; use theme::ThemeSettings; -use ui::{KeyBinding, Modal, ModalFooter, ModalHeader, Section, Tooltip, prelude::*}; +use ui::{ + CommonAnimationExt, KeyBinding, Modal, ModalFooter, ModalHeader, Section, Tooltip, prelude::*, +}; use util::ResultExt as _; use workspace::{ModalView, Workspace}; @@ -163,10 +163,10 @@ impl ConfigurationSource { .read(cx) .text(cx); let settings = serde_json_lenient::from_str::(&text)?; - if let Some(settings_validator) = settings_validator { - if let Err(error) = settings_validator.validate(&settings) { - return Err(anyhow::anyhow!(error.to_string())); - } + if let Some(settings_validator) = settings_validator + && let Err(error) = settings_validator.validate(&settings) + { + return Err(anyhow::anyhow!(error.to_string())); } Ok(( id.clone(), @@ -251,6 +251,7 @@ pub struct ConfigureContextServerModal { workspace: WeakEntity, source: ConfigurationSource, state: State, + original_server_id: Option, } impl ConfigureContextServerModal { @@ -261,7 +262,6 @@ impl ConfigureContextServerModal { _cx: &mut Context, ) { workspace.register_action({ - let language_registry = language_registry.clone(); move |_workspace, _: &AddContextServer, window, cx| { let workspace_handle = cx.weak_entity(); let language_registry = language_registry.clone(); @@ -349,6 +349,11 @@ impl ConfigureContextServerModal { context_server_store, workspace: workspace_handle, state: State::Idle, + original_server_id: match 
&target { + ConfigurationTarget::Existing { id, .. } => Some(id.clone()), + ConfigurationTarget::Extension { id, .. } => Some(id.clone()), + ConfigurationTarget::New => None, + }, source: ConfigurationSource::from_target( target, language_registry, @@ -416,9 +421,19 @@ impl ConfigureContextServerModal { // When we write the settings to the file, the context server will be restarted. workspace.update(cx, |workspace, cx| { let fs = workspace.app_state().fs.clone(); - update_settings_file::(fs.clone(), cx, |project_settings, _| { - project_settings.context_servers.insert(id.0, settings); - }); + let original_server_id = self.original_server_id.clone(); + update_settings_file::( + fs.clone(), + cx, + move |project_settings, _| { + if let Some(original_id) = original_server_id { + if original_id != id { + project_settings.context_servers.remove(&original_id.0); + } + } + project_settings.context_servers.insert(id.0, settings); + }, + ); }); } else if let Some(existing_server) = existing_server { self.context_server_store @@ -487,7 +502,7 @@ impl ConfigureContextServerModal { } fn render_modal_description(&self, window: &mut Window, cx: &mut Context) -> AnyElement { - const MODAL_DESCRIPTION: &'static str = "Visit the MCP server configuration docs to find all necessary arguments and environment variables."; + const MODAL_DESCRIPTION: &str = "Visit the MCP server configuration docs to find all necessary arguments and environment variables."; if let ConfigurationSource::Extension { installation_instructions: Some(installation_instructions), @@ -639,11 +654,7 @@ impl ConfigureContextServerModal { Icon::new(IconName::ArrowCircle) .size(IconSize::XSmall) .color(Color::Info) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - ) + .with_rotate_animation(2) .into_any_element(), ) .child( @@ -716,24 +727,24 @@ fn wait_for_context_server( 
project::context_server_store::Event::ServerStatusChanged { server_id, status } => { match status { ContextServerStatus::Running => { - if server_id == &context_server_id { - if let Some(tx) = tx.lock().unwrap().take() { - let _ = tx.send(Ok(())); - } + if server_id == &context_server_id + && let Some(tx) = tx.lock().unwrap().take() + { + let _ = tx.send(Ok(())); } } ContextServerStatus::Stopped => { - if server_id == &context_server_id { - if let Some(tx) = tx.lock().unwrap().take() { - let _ = tx.send(Err("Context server stopped running".into())); - } + if server_id == &context_server_id + && let Some(tx) = tx.lock().unwrap().take() + { + let _ = tx.send(Err("Context server stopped running".into())); } } ContextServerStatus::Error(error) => { - if server_id == &context_server_id { - if let Some(tx) = tx.lock().unwrap().take() { - let _ = tx.send(Err(error.clone())); - } + if server_id == &context_server_id + && let Some(tx) = tx.lock().unwrap().take() + { + let _ = tx.send(Err(error.clone())); } } _ => {} diff --git a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs index 09ad013d1ceb56d7c031cfc9eededb429aed2841..3bd5ed40d2f265287f6fe22dfbc2a19487149c37 100644 --- a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs +++ b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs @@ -156,7 +156,7 @@ impl ManageProfilesModal { ) { let name_editor = cx.new(|cx| Editor::single_line(window, cx)); name_editor.update(cx, |editor, cx| { - editor.set_placeholder_text("Profile name", cx); + editor.set_placeholder_text("Profile name", window, cx); }); self.mode = Mode::NewProfile(NewProfileMode { @@ -464,7 +464,7 @@ impl ManageProfilesModal { }, )) .child(ListSeparator) - .child(h_flex().p_2().child(mode.name_editor.clone())) + .child(h_flex().p_2().child(mode.name_editor)) } fn render_view_profile( diff --git a/crates/agent_ui/src/agent_configuration/tool_picker.rs 
b/crates/agent_ui/src/agent_configuration/tool_picker.rs index 8f1e0d71c0bd8ef56a71c1a88db1bf67929b060c..2ba92fa6b7993664d278cfd57d851dcfd9cb0922 100644 --- a/crates/agent_ui/src/agent_configuration/tool_picker.rs +++ b/crates/agent_ui/src/agent_configuration/tool_picker.rs @@ -191,10 +191,10 @@ impl PickerDelegate for ToolPickerDelegate { BTreeMap::default(); for item in all_items.iter() { - if let PickerItem::Tool { server_id, name } = item.clone() { - if name.contains(&query) { - tools_by_provider.entry(server_id).or_default().push(name); - } + if let PickerItem::Tool { server_id, name } = item.clone() + && name.contains(&query) + { + tools_by_provider.entry(server_id).or_default().push(name); } } @@ -318,7 +318,7 @@ impl PickerDelegate for ToolPickerDelegate { _window: &mut Window, cx: &mut Context>, ) -> Option { - let item = &self.filtered_items[ix]; + let item = &self.filtered_items.get(ix)?; match item { PickerItem::ContextServer { server_id, .. } => Some( div() diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index b9e1ea5d0a26262fc24dc58d05d54ed970371ccd..14a60b30148acea49ed81832287a1a8ef51f65a5 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -10,12 +10,12 @@ use editor::{ Direction, Editor, EditorEvent, EditorSettings, MultiBuffer, MultiBufferSnapshot, SelectionEffects, ToPoint, actions::{GoToHunk, GoToPreviousHunk}, + multibuffer_context_lines, scroll::Autoscroll, }; use gpui::{ - Action, Animation, AnimationExt, AnyElement, AnyView, App, AppContext, Empty, Entity, - EventEmitter, FocusHandle, Focusable, Global, SharedString, Subscription, Task, Transformation, - WeakEntity, Window, percentage, prelude::*, + Action, AnyElement, AnyView, App, AppContext, Empty, Entity, EventEmitter, FocusHandle, + Focusable, Global, SharedString, Subscription, Task, WeakEntity, Window, prelude::*, }; use language::{Buffer, Capability, DiskState, OffsetRangeExt, Point}; @@ -28,9 +28,8 @@ use 
std::{ collections::hash_map::Entry, ops::Range, sync::Arc, - time::Duration, }; -use ui::{IconButtonShape, KeyBinding, Tooltip, prelude::*, vertical_divider}; +use ui::{CommonAnimationExt, IconButtonShape, KeyBinding, Tooltip, prelude::*, vertical_divider}; use util::ResultExt; use workspace::{ Item, ItemHandle, ItemNavHistory, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, @@ -185,7 +184,7 @@ impl AgentDiffPane { let focus_handle = cx.focus_handle(); let multibuffer = cx.new(|_| MultiBuffer::new(Capability::ReadWrite)); - let project = thread.project(cx).clone(); + let project = thread.project(cx); let editor = cx.new(|cx| { let mut editor = Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx); @@ -196,27 +195,24 @@ impl AgentDiffPane { editor }); - let action_log = thread.action_log(cx).clone(); + let action_log = thread.action_log(cx); let mut this = Self { - _subscriptions: [ - Some( - cx.observe_in(&action_log, window, |this, _action_log, window, cx| { - this.update_excerpts(window, cx) - }), - ), + _subscriptions: vec![ + cx.observe_in(&action_log, window, |this, _action_log, window, cx| { + this.update_excerpts(window, cx) + }), match &thread { - AgentDiffThread::Native(thread) => { - Some(cx.subscribe(&thread, |this, _thread, event, cx| { - this.handle_thread_event(event, cx) - })) - } - AgentDiffThread::AcpThread(_) => None, + AgentDiffThread::Native(thread) => cx + .subscribe(thread, |this, _thread, event, cx| { + this.handle_native_thread_event(event, cx) + }), + AgentDiffThread::AcpThread(thread) => cx + .subscribe(thread, |this, _thread, event, cx| { + this.handle_acp_thread_event(event, cx) + }), }, - ] - .into_iter() - .flatten() - .collect(), + ], title: SharedString::default(), multibuffer, editor, @@ -260,7 +256,7 @@ impl AgentDiffPane { path_key.clone(), buffer.clone(), diff_hunk_ranges, - editor::DEFAULT_MULTIBUFFER_CONTEXT, + multibuffer_context_lines(cx), cx, ); multibuffer.add_diff(diff_handle, cx); @@ 
-288,7 +284,7 @@ impl AgentDiffPane { && buffer .read(cx) .file() - .map_or(false, |file| file.disk_state() == DiskState::Deleted) + .is_some_and(|file| file.disk_state() == DiskState::Deleted) { editor.fold_buffer(snapshot.text.remote_id(), cx) } @@ -324,10 +320,15 @@ impl AgentDiffPane { } } - fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context) { - match event { - ThreadEvent::SummaryGenerated => self.update_title(cx), - _ => {} + fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context) { + if let ThreadEvent::SummaryGenerated = event { + self.update_title(cx) + } + } + + fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context) { + if let AcpThreadEvent::TitleUpdated = event { + self.update_title(cx) } } @@ -398,7 +399,7 @@ fn keep_edits_in_selection( .disjoint_anchor_ranges() .collect::>(); - keep_edits_in_ranges(editor, buffer_snapshot, &thread, ranges, window, cx) + keep_edits_in_ranges(editor, buffer_snapshot, thread, ranges, window, cx) } fn reject_edits_in_selection( @@ -412,7 +413,7 @@ fn reject_edits_in_selection( .selections .disjoint_anchor_ranges() .collect::>(); - reject_edits_in_ranges(editor, buffer_snapshot, &thread, ranges, window, cx) + reject_edits_in_ranges(editor, buffer_snapshot, thread, ranges, window, cx) } fn keep_edits_in_ranges( @@ -503,8 +504,7 @@ fn update_editor_selection( &[last_kept_hunk_end..editor::Anchor::max()], buffer_snapshot, ) - .skip(1) - .next() + .nth(1) }) .or_else(|| { let first_kept_hunk = diff_hunks.first()?; @@ -1001,7 +1001,7 @@ impl AgentDiffToolbar { return; }; - *state = agent_diff.read(cx).editor_state(&editor); + *state = agent_diff.read(cx).editor_state(editor); self.update_location(cx); cx.notify(); } @@ -1044,23 +1044,23 @@ impl ToolbarItemView for AgentDiffToolbar { return self.location(cx); } - if let Some(editor) = item.act_as::(cx) { - if editor.read(cx).mode().is_full() { - let agent_diff = AgentDiff::global(cx); + if let Some(editor) = 
item.act_as::(cx) + && editor.read(cx).mode().is_full() + { + let agent_diff = AgentDiff::global(cx); - self.active_item = Some(AgentDiffToolbarItem::Editor { - editor: editor.downgrade(), - state: agent_diff.read(cx).editor_state(&editor.downgrade()), - _diff_subscription: cx.observe(&agent_diff, Self::handle_diff_notify), - }); + self.active_item = Some(AgentDiffToolbarItem::Editor { + editor: editor.downgrade(), + state: agent_diff.read(cx).editor_state(&editor.downgrade()), + _diff_subscription: cx.observe(&agent_diff, Self::handle_diff_notify), + }); - return self.location(cx); - } + return self.location(cx); } } self.active_item = None; - return self.location(cx); + self.location(cx) } fn pane_focus_update( @@ -1082,11 +1082,7 @@ impl Render for AgentDiffToolbar { Icon::new(IconName::LoadCircle) .size(IconSize::Small) .color(Color::Accent) - .with_animation( - "load_circle", - Animation::new(Duration::from_secs(3)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - ), + .with_rotate_animation(3), ) .into_any(); @@ -1311,7 +1307,7 @@ impl AgentDiff { let entity = cx.new(|_cx| Self::default()); let global = AgentDiffGlobal(entity.clone()); cx.set_global(global); - entity.clone() + entity }) } @@ -1333,7 +1329,7 @@ impl AgentDiff { window: &mut Window, cx: &mut Context, ) { - let action_log = thread.action_log(cx).clone(); + let action_log = thread.action_log(cx); let action_log_subscription = cx.observe_in(&action_log, window, { let workspace = workspace.clone(); @@ -1343,13 +1339,13 @@ impl AgentDiff { }); let thread_subscription = match &thread { - AgentDiffThread::Native(thread) => cx.subscribe_in(&thread, window, { + AgentDiffThread::Native(thread) => cx.subscribe_in(thread, window, { let workspace = workspace.clone(); move |this, _thread, event, window, cx| { this.handle_native_thread_event(&workspace, event, window, cx) } }), - AgentDiffThread::AcpThread(thread) => cx.subscribe_in(&thread, window, { + 
AgentDiffThread::AcpThread(thread) => cx.subscribe_in(thread, window, { let workspace = workspace.clone(); move |this, thread, event, window, cx| { this.handle_acp_thread_event(&workspace, thread, event, window, cx) @@ -1357,11 +1353,11 @@ impl AgentDiff { }), }; - if let Some(workspace_thread) = self.workspace_threads.get_mut(&workspace) { + if let Some(workspace_thread) = self.workspace_threads.get_mut(workspace) { // replace thread and action log subscription, but keep editors workspace_thread.thread = thread.downgrade(); workspace_thread._thread_subscriptions = (action_log_subscription, thread_subscription); - self.update_reviewing_editors(&workspace, window, cx); + self.update_reviewing_editors(workspace, window, cx); return; } @@ -1506,7 +1502,7 @@ impl AgentDiff { .read(cx) .entries() .last() - .map_or(false, |entry| entry.diffs().next().is_some()) + .is_some_and(|entry| entry.diffs().next().is_some()) { self.update_reviewing_editors(workspace, window, cx); } @@ -1516,16 +1512,25 @@ impl AgentDiff { .read(cx) .entries() .get(*ix) - .map_or(false, |entry| entry.diffs().next().is_some()) + .is_some_and(|entry| entry.diffs().next().is_some()) { self.update_reviewing_editors(workspace, window, cx); } } - AcpThreadEvent::EntriesRemoved(_) - | AcpThreadEvent::Stopped - | AcpThreadEvent::ToolAuthorizationRequired + AcpThreadEvent::Stopped | AcpThreadEvent::Error - | AcpThreadEvent::ServerExited(_) => {} + | AcpThreadEvent::LoadError(_) + | AcpThreadEvent::Refusal => { + self.update_reviewing_editors(workspace, window, cx); + } + AcpThreadEvent::TitleUpdated + | AcpThreadEvent::TokenUsageUpdated + | AcpThreadEvent::EntriesRemoved(_) + | AcpThreadEvent::ToolAuthorizationRequired + | AcpThreadEvent::PromptCapabilitiesUpdated + | AcpThreadEvent::AvailableCommandsUpdated(_) + | AcpThreadEvent::Retry(_) + | AcpThreadEvent::ModeUpdated(_) => {} } } @@ -1536,21 +1541,11 @@ impl AgentDiff { window: &mut Window, cx: &mut Context, ) { - match event { - 
workspace::Event::ItemAdded { item } => { - if let Some(editor) = item.downcast::() { - if let Some(buffer) = Self::full_editor_buffer(editor.read(cx), cx) { - self.register_editor( - workspace.downgrade(), - buffer.clone(), - editor, - window, - cx, - ); - } - } - } - _ => {} + if let workspace::Event::ItemAdded { item } = event + && let Some(editor) = item.downcast::() + && let Some(buffer) = Self::full_editor_buffer(editor.read(cx), cx) + { + self.register_editor(workspace.downgrade(), buffer, editor, window, cx); } } @@ -1649,7 +1644,7 @@ impl AgentDiff { continue; }; - for (weak_editor, _) in buffer_editors { + for weak_editor in buffer_editors.keys() { let Some(editor) = weak_editor.upgrade() else { continue; }; @@ -1677,7 +1672,7 @@ impl AgentDiff { editor.register_addon(EditorAgentDiffAddon); }); } else { - unaffected.remove(&weak_editor); + unaffected.remove(weak_editor); } if new_state == EditorState::Reviewing && previous_state != Some(new_state) { @@ -1710,7 +1705,7 @@ impl AgentDiff { .read_with(cx, |editor, _cx| editor.workspace()) .ok() .flatten() - .map_or(false, |editor_workspace| { + .is_some_and(|editor_workspace| { editor_workspace.entity_id() == workspace.entity_id() }); @@ -1730,7 +1725,7 @@ impl AgentDiff { fn editor_state(&self, editor: &WeakEntity) -> EditorState { self.reviewing_editors - .get(&editor) + .get(editor) .cloned() .unwrap_or(EditorState::Idle) } @@ -1850,26 +1845,26 @@ impl AgentDiff { let thread = thread.upgrade()?; - if let PostReviewState::AllReviewed = review(&editor, &thread, window, cx) { - if let Some(curr_buffer) = editor.read(cx).buffer().read(cx).as_singleton() { - let changed_buffers = thread.action_log(cx).read(cx).changed_buffers(cx); - - let mut keys = changed_buffers.keys().cycle(); - keys.find(|k| *k == &curr_buffer); - let next_project_path = keys - .next() - .filter(|k| *k != &curr_buffer) - .and_then(|after| after.read(cx).project_path(cx)); - - if let Some(path) = next_project_path { - let task = 
workspace.open_path(path, None, true, window, cx); - let task = cx.spawn(async move |_, _cx| task.await.map(|_| ())); - return Some(task); - } + if let PostReviewState::AllReviewed = review(&editor, &thread, window, cx) + && let Some(curr_buffer) = editor.read(cx).buffer().read(cx).as_singleton() + { + let changed_buffers = thread.action_log(cx).read(cx).changed_buffers(cx); + + let mut keys = changed_buffers.keys().cycle(); + keys.find(|k| *k == &curr_buffer); + let next_project_path = keys + .next() + .filter(|k| *k != &curr_buffer) + .and_then(|after| after.read(cx).project_path(cx)); + + if let Some(path) = next_project_path { + let task = workspace.open_path(path, None, true, window, cx); + let task = cx.spawn(async move |_, _cx| task.await.map(|_| ())); + return Some(task); } } - return Some(Task::ready(Ok(()))); + Some(Task::ready(Ok(()))) } } diff --git a/crates/agent_ui/src/agent_model_selector.rs b/crates/agent_ui/src/agent_model_selector.rs index b989e7bf1e9147c7f6beb90b5054120cef7b818f..3de1027d91f6d613e9f3aa723b345e5d5f17ee6f 100644 --- a/crates/agent_ui/src/agent_model_selector.rs +++ b/crates/agent_ui/src/agent_model_selector.rs @@ -66,10 +66,8 @@ impl AgentModelSelector { fs.clone(), cx, move |settings, _cx| { - settings.set_inline_assistant_model( - provider.clone(), - model_id.clone(), - ); + settings + .set_inline_assistant_model(provider.clone(), model_id); }, ); } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index e47cbe3714939aa3b839d575112c3b9699a0eeba..2ace7096c2588699939bc02b92841ea5b5db2e06 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1,17 +1,22 @@ -use std::cell::RefCell; use std::ops::{Not, Range}; use std::path::Path; use std::rc::Rc; use std::sync::Arc; use std::time::Duration; -use agent_servers::AgentServer; +use acp_thread::AcpThread; +use agent2::{DbThreadMetadata, HistoryEntry}; use db::kvp::{Dismissable, KEY_VALUE_STORE}; +use 
project::agent_server_store::{ + AgentServerCommand, AllAgentServersSettings, CLAUDE_CODE_NAME, GEMINI_NAME, +}; use serde::{Deserialize, Serialize}; +use zed_actions::OpenBrowser; +use zed_actions::agent::{OpenClaudeCodeOnboardingModal, ReauthenticateAgent}; -use crate::NewExternalAgentThread; +use crate::acp::{AcpThreadHistory, ThreadHistoryEvent}; use crate::agent_diff::AgentDiffThread; -use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES}; +use crate::ui::{AcpOnboardingModal, ClaudeCodeOnboardingModal}; use crate::{ AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode, DeleteRecentlyOpenThread, ExpandMessageEditor, Follow, InlineAssistant, NewTextThread, @@ -26,11 +31,13 @@ use crate::{ slash_command::SlashCommandCompletionProvider, text_thread_editor::{ AgentPanelDelegate, TextThreadEditor, humanize_token_count, make_lsp_adapter_delegate, - render_remaining_tokens, }, thread_history::{HistoryEntryElement, ThreadHistory}, ui::{AgentOnboardingModal, EndTrialUpsell}, }; +use crate::{ + ExternalAgent, NewExternalAgentThread, NewNativeAgentThreadFromSummary, placeholder_command, +}; use agent::{ Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio, context_store::ContextStore, @@ -44,25 +51,22 @@ use assistant_context::{AssistantContext, ContextEvent, ContextSummary}; use assistant_slash_command::SlashCommandWorkingSet; use assistant_tool::ToolWorkingSet; use client::{UserStore, zed_urls}; -use cloud_llm_client::{CompletionIntent, Plan, UsageLimit}; +use cloud_llm_client::{CompletionIntent, Plan, PlanV1, PlanV2, UsageLimit}; use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer}; -use feature_flags::{self, FeatureFlagAppExt}; +use feature_flags::{self, ClaudeCodeFeatureFlag, FeatureFlagAppExt, GeminiAndNativeFeatureFlag}; use fs::Fs; use gpui::{ Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, ClipboardItem, - Corner, DismissEvent, Entity, EventEmitter, ExternalPaths, 
FocusHandle, Focusable, Hsla, - KeyContext, Pixels, Subscription, Task, UpdateGlobal, WeakEntity, prelude::*, - pulsating_between, + Corner, DismissEvent, Entity, EventEmitter, ExternalPaths, FocusHandle, Focusable, KeyContext, + Pixels, Subscription, Task, UpdateGlobal, WeakEntity, prelude::*, pulsating_between, }; use language::LanguageRegistry; -use language_model::{ - ConfigurationError, ConfiguredModel, LanguageModelProviderTosView, LanguageModelRegistry, -}; +use language_model::{ConfigurationError, ConfiguredModel, LanguageModelRegistry}; use project::{DisableAiSettings, Project, ProjectPath, Worktree}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; use rules_library::{RulesLibrary, open_rules_library}; use search::{BufferSearchBar, buffer_search}; -use settings::{Settings, update_settings_file}; +use settings::{Settings, SettingsStore, update_settings_file}; use theme::ThemeSettings; use time::UtcOffset; use ui::utils::WithRemSize; @@ -77,13 +81,16 @@ use workspace::{ }; use zed_actions::{ DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize, - agent::{OpenOnboardingModal, OpenSettings, ResetOnboarding, ToggleModelSelector}, + agent::{ + OpenAcpOnboardingModal, OpenOnboardingModal, OpenSettings, ResetOnboarding, + ToggleModelSelector, + }, assistant::{OpenRulesLibrary, ToggleFocus}, }; const AGENT_PANEL_KEY: &str = "agent_panel"; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] struct SerializedAgentPanel { width: Option, selected_agent: Option, @@ -99,6 +106,16 @@ pub fn init(cx: &mut App) { workspace.focus_panel::(window, cx); } }) + .register_action( + |workspace, action: &NewNativeAgentThreadFromSummary, window, cx| { + if let Some(panel) = workspace.panel::(cx) { + panel.update(cx, |panel, cx| { + panel.new_native_agent_thread_from_summary(action, window, cx) + }); + workspace.focus_panel::(window, cx); + } + }, + ) .register_action(|workspace, _: &OpenHistory, window, cx| { if let Some(panel) = 
workspace.panel::(cx) { workspace.focus_panel::(window, cx); @@ -121,7 +138,7 @@ pub fn init(cx: &mut App) { if let Some(panel) = workspace.panel::(cx) { workspace.focus_panel::(window, cx); panel.update(cx, |panel, cx| { - panel.new_external_thread(action.agent, window, cx) + panel.external_thread(action.agent.clone(), None, None, window, cx) }); } }) @@ -191,6 +208,12 @@ pub fn init(cx: &mut App) { .register_action(|workspace, _: &OpenOnboardingModal, window, cx| { AgentOnboardingModal::toggle(workspace, window, cx) }) + .register_action(|workspace, _: &OpenAcpOnboardingModal, window, cx| { + AcpOnboardingModal::toggle(workspace, window, cx) + }) + .register_action(|workspace, _: &OpenClaudeCodeOnboardingModal, window, cx| { + ClaudeCodeOnboardingModal::toggle(workspace, window, cx) + }) .register_action(|_workspace, _: &ResetOnboarding, window, cx| { window.dispatch_action(workspace::RestoreBanner.boxed_clone(), cx); window.refresh(); @@ -232,7 +255,8 @@ enum WhichFontSize { None, } -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +// TODO unify this with ExternalAgent +#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] pub enum AgentType { #[default] Zed, @@ -240,24 +264,40 @@ pub enum AgentType { Gemini, ClaudeCode, NativeAgent, + Custom { + name: SharedString, + command: AgentServerCommand, + }, } impl AgentType { - fn label(self) -> impl Into { + fn label(&self) -> SharedString { match self { - Self::Zed | Self::TextThread => "Zed", - Self::NativeAgent => "Agent 2", - Self::Gemini => "Gemini", - Self::ClaudeCode => "Claude Code", + Self::Zed | Self::TextThread => "Zed Agent".into(), + Self::NativeAgent => "Agent 2".into(), + Self::Gemini => "Gemini CLI".into(), + Self::ClaudeCode => "Claude Code".into(), + Self::Custom { name, .. 
} => name.into(), } } - fn icon(self) -> IconName { + fn icon(&self) -> Option { match self { - Self::Zed | Self::TextThread => IconName::AiZed, - Self::NativeAgent => IconName::ZedAssistant, - Self::Gemini => IconName::AiGemini, - Self::ClaudeCode => IconName::AiClaude, + Self::Zed | Self::NativeAgent | Self::TextThread => None, + Self::Gemini => Some(IconName::AiGemini), + Self::ClaudeCode => Some(IconName::AiClaude), + Self::Custom { .. } => Some(IconName::Terminal), + } + } +} + +impl From for AgentType { + fn from(value: ExternalAgent) -> Self { + match value { + ExternalAgent::Gemini => Self::Gemini, + ExternalAgent::ClaudeCode => Self::ClaudeCode, + ExternalAgent::Custom { name, command } => Self::Custom { name, command }, + ExternalAgent::NativeAgent => Self::NativeAgent, } } } @@ -356,7 +396,7 @@ impl ActiveView { Self::Thread { change_title_editor: editor, thread: active_thread, - message_editor: message_editor, + message_editor, _subscriptions: subscriptions, } } @@ -364,6 +404,7 @@ impl ActiveView { pub fn prompt_editor( context_editor: Entity, history_store: Entity, + acp_history_store: Entity, language_registry: Arc, window: &mut Window, cx: &mut App, @@ -441,6 +482,18 @@ impl ActiveView { ); } }); + + acp_history_store.update(cx, |history_store, cx| { + if let Some(old_path) = old_path { + history_store + .replace_recently_opened_text_thread(old_path, new_path, cx); + } else { + history_store.push_recently_opened_entry( + agent2::HistoryEntryId::TextThread(new_path.clone()), + cx, + ); + } + }); } _ => {} } @@ -469,6 +522,8 @@ pub struct AgentPanel { fs: Arc, language_registry: Arc, thread_store: Entity, + acp_history: Entity, + acp_history_store: Entity, _default_model_subscription: Subscription, context_store: Entity, prompt_store: Option>, @@ -477,8 +532,6 @@ pub struct AgentPanel { configuration_subscription: Option, local_timezone: UtcOffset, active_view: ActiveView, - acp_message_history: - Rc>>>, previous_view: Option, history_store: Entity, 
history: Entity, @@ -498,7 +551,7 @@ pub struct AgentPanel { impl AgentPanel { fn serialize(&mut self, cx: &mut Context) { let width = self.width; - let selected_agent = self.selected_agent; + let selected_agent = self.selected_agent.clone(); self.pending_serialization = Some(cx.background_spawn(async move { KEY_VALUE_STORE .write_kvp( @@ -512,6 +565,7 @@ impl AgentPanel { anyhow::Ok(()) })); } + pub fn load( workspace: WeakEntity, prompt_builder: Arc, @@ -556,7 +610,7 @@ impl AgentPanel { .log_err() .flatten() { - Some(serde_json::from_str::(&panel)?) + serde_json::from_str::(&panel).log_err() } else { None }; @@ -576,10 +630,15 @@ impl AgentPanel { panel.update(cx, |panel, cx| { panel.width = serialized_panel.width.map(|w| w.round()); if let Some(selected_agent) = serialized_panel.selected_agent { - panel.selected_agent = selected_agent; + panel.selected_agent = selected_agent.clone(); + panel.new_agent_thread(selected_agent, window, cx); } cx.notify(); }); + } else { + panel.update(cx, |panel, cx| { + panel.new_agent_thread(AgentType::NativeAgent, window, cx); + }); } panel })?; @@ -636,6 +695,29 @@ impl AgentPanel { ) }); + let acp_history_store = cx.new(|cx| agent2::HistoryStore::new(context_store.clone(), cx)); + let acp_history = cx.new(|cx| AcpThreadHistory::new(acp_history_store.clone(), window, cx)); + cx.subscribe_in( + &acp_history, + window, + |this, _, event, window, cx| match event { + ThreadHistoryEvent::Open(HistoryEntry::AcpThread(thread)) => { + this.external_thread( + Some(crate::ExternalAgent::NativeAgent), + Some(thread.clone()), + None, + window, + cx, + ); + } + ThreadHistoryEvent::Open(HistoryEntry::TextThread(thread)) => { + this.open_saved_prompt_editor(thread.path.clone(), window, cx) + .detach_and_log_err(cx); + } + }, + ) + .detach(); + cx.observe(&history_store, |_, _, cx| cx.notify()).detach(); let active_thread = cx.new(|cx| { @@ -674,6 +756,7 @@ impl AgentPanel { ActiveView::prompt_editor( context_editor, history_store.clone(), + 
acp_history_store.clone(), language_registry.clone(), window, cx, @@ -690,7 +773,11 @@ impl AgentPanel { let assistant_navigation_menu = ContextMenu::build_persistent(window, cx, move |mut menu, _window, cx| { if let Some(panel) = panel.upgrade() { - menu = Self::populate_recently_opened_menu_section(menu, panel, cx); + if cx.has_flag::() { + menu = Self::populate_recently_opened_menu_section_new(menu, panel, cx); + } else { + menu = Self::populate_recently_opened_menu_section_old(menu, panel, cx); + } } menu.action("View All", Box::new(OpenHistory)) .end_slot_action(DeleteRecentlyOpenThread.boxed_clone()) @@ -716,25 +803,25 @@ impl AgentPanel { .ok(); }); - let _default_model_subscription = cx.subscribe( - &LanguageModelRegistry::global(cx), - |this, _, event: &language_model::Event, cx| match event { - language_model::Event::DefaultModelChanged => match &this.active_view { - ActiveView::Thread { thread, .. } => { - thread - .read(cx) - .thread() - .clone() - .update(cx, |thread, cx| thread.get_or_init_configured_model(cx)); + let _default_model_subscription = + cx.subscribe( + &LanguageModelRegistry::global(cx), + |this, _, event: &language_model::Event, cx| { + if let language_model::Event::DefaultModelChanged = event { + match &this.active_view { + ActiveView::Thread { thread, .. } => { + thread.read(cx).thread().clone().update(cx, |thread, cx| { + thread.get_or_init_configured_model(cx) + }); + } + ActiveView::ExternalAgentThread { .. } + | ActiveView::TextThread { .. } + | ActiveView::History + | ActiveView::Configuration => {} + } } - ActiveView::ExternalAgentThread { .. } - | ActiveView::TextThread { .. 
} - | ActiveView::History - | ActiveView::Configuration => {} }, - _ => {} - }, - ); + ); let onboarding = cx.new(|cx| { AgentPanelOnboarding::new( @@ -766,7 +853,6 @@ impl AgentPanel { .unwrap(), inline_assist_context_store, previous_view: None, - acp_message_history: Default::default(), history_store: history_store.clone(), history: cx.new(|cx| ThreadHistory::new(weak_self, history_store, window, cx)), hovered_recent_history_item: None, @@ -779,6 +865,8 @@ impl AgentPanel { zoomed: false, pending_serialization: None, onboarding, + acp_history, + acp_history_store, selected_agent: AgentType::default(), } } @@ -823,10 +911,10 @@ impl AgentPanel { ActiveView::Thread { thread, .. } => { thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx)); } - ActiveView::ExternalAgentThread { thread_view, .. } => { - thread_view.update(cx, |thread_element, cx| thread_element.cancel(cx)); - } - ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} + ActiveView::ExternalAgentThread { .. } + | ActiveView::TextThread { .. } + | ActiveView::History + | ActiveView::Configuration => {} } } @@ -840,7 +928,20 @@ impl AgentPanel { } } + fn active_thread_view(&self) -> Option<&Entity> { + match &self.active_view { + ActiveView::ExternalAgentThread { thread_view, .. } => Some(thread_view), + ActiveView::Thread { .. } + | ActiveView::TextThread { .. 
} + | ActiveView::History + | ActiveView::Configuration => None, + } + } + fn new_thread(&mut self, action: &NewThread, window: &mut Window, cx: &mut Context) { + if cx.has_flag::() { + return self.new_agent_thread(AgentType::NativeAgent, window, cx); + } // Preserve chat box text when using creating new thread let preserved_text = self .active_message_editor() @@ -913,13 +1014,38 @@ impl AgentPanel { message_editor.focus_handle(cx).focus(window); - let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx); + let thread_view = ActiveView::thread(active_thread, message_editor, window, cx); self.set_active_view(thread_view, window, cx); AgentDiff::set_active_thread(&self.workspace, thread.clone(), window, cx); } + fn new_native_agent_thread_from_summary( + &mut self, + action: &NewNativeAgentThreadFromSummary, + window: &mut Window, + cx: &mut Context, + ) { + let Some(thread) = self + .acp_history_store + .read(cx) + .thread_from_session_id(&action.from_session_id) + else { + return; + }; + + self.external_thread( + Some(ExternalAgent::NativeAgent), + None, + Some(thread.clone()), + window, + cx, + ); + } + fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context) { + telemetry::event!("Agent Thread Started", agent = "zed-text"); + let context = self .context_store .update(cx, |context_store, cx| context_store.create(cx)); @@ -941,10 +1067,16 @@ impl AgentPanel { editor }); + if self.selected_agent != AgentType::TextThread { + self.selected_agent = AgentType::TextThread; + self.serialize(cx); + } + self.set_active_view( ActiveView::prompt_editor( context_editor.clone(), self.history_store.clone(), + self.acp_history_store.clone(), self.language_registry.clone(), window, cx, @@ -955,16 +1087,18 @@ impl AgentPanel { context_editor.focus_handle(cx).focus(window); } - fn new_external_thread( + fn external_thread( &mut self, agent_choice: Option, + resume_thread: Option, + summarize_thread: Option, window: &mut Window, cx: &mut 
Context, ) { let workspace = self.workspace.clone(); let project = self.project.clone(); - let message_history = self.acp_message_history.clone(); let fs = self.fs.clone(); + let is_via_collab = self.project.read(cx).is_via_collab(); const LAST_USED_EXTERNAL_AGENT_KEY: &str = "agent_panel__last_used_external_agent"; @@ -973,52 +1107,82 @@ impl AgentPanel { agent: crate::ExternalAgent, } - let thread_store = self.thread_store.clone(); - let text_thread_store = self.context_store.clone(); + let history = self.acp_history_store.clone(); cx.spawn_in(window, async move |this, cx| { - let server: Rc = match agent_choice { + let ext_agent = match agent_choice { Some(agent) => { - cx.background_spawn(async move { - if let Some(serialized) = - serde_json::to_string(&LastUsedExternalAgent { agent }).log_err() - { - KEY_VALUE_STORE - .write_kvp(LAST_USED_EXTERNAL_AGENT_KEY.to_string(), serialized) - .await - .log_err(); + cx.background_spawn({ + let agent = agent.clone(); + async move { + if let Some(serialized) = + serde_json::to_string(&LastUsedExternalAgent { agent }).log_err() + { + KEY_VALUE_STORE + .write_kvp(LAST_USED_EXTERNAL_AGENT_KEY.to_string(), serialized) + .await + .log_err(); + } } }) .detach(); - agent.server(fs) + agent + } + None => { + if is_via_collab { + ExternalAgent::NativeAgent + } else { + cx.background_spawn(async move { + KEY_VALUE_STORE.read_kvp(LAST_USED_EXTERNAL_AGENT_KEY) + }) + .await + .log_err() + .flatten() + .and_then(|value| { + serde_json::from_str::(&value).log_err() + }) + .unwrap_or_default() + .agent + } } - None => cx - .background_spawn(async move { - KEY_VALUE_STORE.read_kvp(LAST_USED_EXTERNAL_AGENT_KEY) - }) - .await - .log_err() - .flatten() - .and_then(|value| { - serde_json::from_str::(&value).log_err() - }) - .unwrap_or_default() - .agent - .server(fs), }; + telemetry::event!("Agent Thread Started", agent = ext_agent.name()); + + let server = ext_agent.server(fs, history); + this.update_in(cx, |this, window, cx| { + match 
ext_agent { + crate::ExternalAgent::Gemini + | crate::ExternalAgent::NativeAgent + | crate::ExternalAgent::Custom { .. } => { + if !cx.has_flag::() { + return; + } + } + crate::ExternalAgent::ClaudeCode => { + if !cx.has_flag::() { + return; + } + } + } + + let selected_agent = ext_agent.into(); + if this.selected_agent != selected_agent { + this.selected_agent = selected_agent; + this.serialize(cx); + } + let thread_view = cx.new(|cx| { crate::acp::AcpThreadView::new( server, + resume_thread, + summarize_thread, workspace.clone(), project, - thread_store.clone(), - text_thread_store.clone(), - message_history, - MIN_EDITOR_LINES, - Some(MAX_EDITOR_LINES), + this.acp_history_store.clone(), + this.prompt_store.clone(), window, cx, ) @@ -1105,10 +1269,17 @@ impl AgentPanel { cx, ) }); + + if self.selected_agent != AgentType::TextThread { + self.selected_agent = AgentType::TextThread; + self.serialize(cx); + } + self.set_active_view( ActiveView::prompt_editor( - editor.clone(), + editor, self.history_store.clone(), + self.acp_history_store.clone(), self.language_registry.clone(), window, cx, @@ -1179,7 +1350,7 @@ impl AgentPanel { }); message_editor.focus_handle(cx).focus(window); - let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx); + let thread_view = ActiveView::thread(active_thread, message_editor, window, cx); self.set_active_view(thread_view, window, cx); AgentDiff::set_active_thread(&self.workspace, thread.clone(), window, cx); } @@ -1266,13 +1437,11 @@ impl AgentPanel { ThemeSettings::get_global(cx).agent_font_size(cx) + delta; let _ = settings .agent_font_size - .insert(theme::clamp_font_size(agent_font_size).0); + .insert(Some(theme::clamp_font_size(agent_font_size).into())); }, ); } else { - theme::adjust_agent_font_size(cx, |size| { - *size += delta; - }); + theme::adjust_agent_font_size(cx, |size| size + delta); } } WhichFontSize::BufferFont => { @@ -1338,6 +1507,7 @@ impl AgentPanel { } pub(crate) fn 
open_configuration(&mut self, window: &mut Window, cx: &mut Context) { + let agent_server_store = self.project.read(cx).agent_server_store().clone(); let context_server_store = self.project.read(cx).context_server_store(); let tools = self.thread_store.read(cx).tools(); let fs = self.fs.clone(); @@ -1346,6 +1516,7 @@ impl AgentPanel { self.configuration = Some(cx.new(|cx| { AgentConfiguration::new( fs, + agent_server_store, context_server_store, tools, self.language_registry.clone(), @@ -1408,15 +1579,14 @@ impl AgentPanel { AssistantConfigurationEvent::NewThread(provider) => { if LanguageModelRegistry::read_global(cx) .default_model() - .map_or(true, |model| model.provider.id() != provider.id()) + .is_none_or(|model| model.provider.id() != provider.id()) + && let Some(model) = provider.default_model(cx) { - if let Some(model) = provider.default_model(cx) { - update_settings_file::( - self.fs.clone(), - cx, - move |settings, _| settings.set_model(model), - ); - } + update_settings_file::( + self.fs.clone(), + cx, + move |settings, _| settings.set_model(model), + ); } self.new_thread(&NewThread::default(), window, cx); @@ -1443,6 +1613,14 @@ impl AgentPanel { _ => None, } } + pub(crate) fn active_agent_thread(&self, cx: &App) -> Option> { + match &self.active_view { + ActiveView::ExternalAgentThread { thread_view, .. } => { + thread_view.read(cx).thread().cloned() + } + _ => None, + } + } pub(crate) fn delete_thread( &mut self, @@ -1463,7 +1641,7 @@ impl AgentPanel { return; } - let model = thread_state.configured_model().map(|cm| cm.model.clone()); + let model = thread_state.configured_model().map(|cm| cm.model); if let Some(model) = model { thread.update(cx, |active_thread, cx| { active_thread.thread().update(cx, |thread, cx| { @@ -1535,17 +1713,14 @@ impl AgentPanel { let current_is_special = current_is_history || current_is_config; let new_is_special = new_is_history || new_is_config; - match &self.active_view { - ActiveView::Thread { thread, .. 
} => { - let thread = thread.read(cx); - if thread.is_empty() { - let id = thread.thread().read(cx).id().clone(); - self.history_store.update(cx, |store, cx| { - store.remove_recently_opened_thread(id, cx); - }); - } + if let ActiveView::Thread { thread, .. } = &self.active_view { + let thread = thread.read(cx); + if thread.is_empty() { + let id = thread.thread().read(cx).id().clone(); + self.history_store.update(cx, |store, cx| { + store.remove_recently_opened_thread(id, cx); + }); } - _ => {} } match &new_view { @@ -1558,6 +1733,14 @@ impl AgentPanel { if let Some(path) = context_editor.read(cx).context().read(cx).path() { store.push_recently_opened_entry(HistoryEntryId::Context(path.clone()), cx) } + }); + self.acp_history_store.update(cx, |store, cx| { + if let Some(path) = context_editor.read(cx).context().read(cx).path() { + store.push_recently_opened_entry( + agent2::HistoryEntryId::TextThread(path.clone()), + cx, + ) + } }) } ActiveView::ExternalAgentThread { .. } => {} @@ -1575,12 +1758,10 @@ impl AgentPanel { self.active_view = new_view; } - self.acp_message_history.borrow_mut().reset_position(); - self.focus_handle(cx).focus(window); } - fn populate_recently_opened_menu_section( + fn populate_recently_opened_menu_section_old( mut menu: ContextMenu, panel: Entity, cx: &mut Context, @@ -1615,7 +1796,7 @@ impl AgentPanel { .open_thread_by_id(&id, window, cx) .detach_and_log_err(cx), HistoryEntryId::Context(path) => this - .open_saved_prompt_editor(path.clone(), window, cx) + .open_saved_prompt_editor(path, window, cx) .detach_and_log_err(cx), }) .ok(); @@ -1644,15 +1825,140 @@ impl AgentPanel { menu } - pub fn set_selected_agent(&mut self, agent: AgentType, cx: &mut Context) { - if self.selected_agent != agent { - self.selected_agent = agent; - self.serialize(cx); + fn populate_recently_opened_menu_section_new( + mut menu: ContextMenu, + panel: Entity, + cx: &mut Context, + ) -> ContextMenu { + let entries = panel + .read(cx) + .acp_history_store + 
.read(cx) + .recently_opened_entries(cx); + + if entries.is_empty() { + return menu; + } + + menu = menu.header("Recently Opened"); + + for entry in entries { + let title = entry.title().clone(); + + menu = menu.entry_with_end_slot_on_hover( + title, + None, + { + let panel = panel.downgrade(); + let entry = entry.clone(); + move |window, cx| { + let entry = entry.clone(); + panel + .update(cx, move |this, cx| match &entry { + agent2::HistoryEntry::AcpThread(entry) => this.external_thread( + Some(ExternalAgent::NativeAgent), + Some(entry.clone()), + None, + window, + cx, + ), + agent2::HistoryEntry::TextThread(entry) => this + .open_saved_prompt_editor(entry.path.clone(), window, cx) + .detach_and_log_err(cx), + }) + .ok(); + } + }, + IconName::Close, + "Close Entry".into(), + { + let panel = panel.downgrade(); + let id = entry.id(); + move |_window, cx| { + panel + .update(cx, |this, cx| { + this.acp_history_store.update(cx, |history_store, cx| { + history_store.remove_recently_opened_entry(&id, cx); + }); + }) + .ok(); + } + }, + ); } + + menu = menu.separator(); + + menu } pub fn selected_agent(&self) -> AgentType { - self.selected_agent + self.selected_agent.clone() + } + + pub fn new_agent_thread( + &mut self, + agent: AgentType, + window: &mut Window, + cx: &mut Context, + ) { + match agent { + AgentType::Zed => { + window.dispatch_action( + NewThread { + from_thread_id: None, + } + .boxed_clone(), + cx, + ); + } + AgentType::TextThread => { + window.dispatch_action(NewTextThread.boxed_clone(), cx); + } + AgentType::NativeAgent => self.external_thread( + Some(crate::ExternalAgent::NativeAgent), + None, + None, + window, + cx, + ), + AgentType::Gemini => { + self.external_thread(Some(crate::ExternalAgent::Gemini), None, None, window, cx) + } + AgentType::ClaudeCode => { + self.selected_agent = AgentType::ClaudeCode; + self.serialize(cx); + self.external_thread( + Some(crate::ExternalAgent::ClaudeCode), + None, + None, + window, + cx, + ) + } + 
AgentType::Custom { name, command } => self.external_thread( + Some(crate::ExternalAgent::Custom { name, command }), + None, + None, + window, + cx, + ), + } + } + + pub fn load_agent_thread( + &mut self, + thread: DbThreadMetadata, + window: &mut Window, + cx: &mut Context, + ) { + self.external_thread( + Some(ExternalAgent::NativeAgent), + Some(thread), + None, + window, + cx, + ); } } @@ -1661,7 +1967,13 @@ impl Focusable for AgentPanel { match &self.active_view { ActiveView::Thread { message_editor, .. } => message_editor.focus_handle(cx), ActiveView::ExternalAgentThread { thread_view, .. } => thread_view.focus_handle(cx), - ActiveView::History => self.history.focus_handle(cx), + ActiveView::History => { + if cx.has_flag::() { + self.acp_history.focus_handle(cx) + } else { + self.history.focus_handle(cx) + } + } ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx), ActiveView::Configuration => { if let Some(configuration) = self.configuration.as_ref() { @@ -1783,11 +2095,13 @@ impl AgentPanel { }; match state { - ThreadSummary::Pending => Label::new(ThreadSummary::DEFAULT.clone()) + ThreadSummary::Pending => Label::new(ThreadSummary::DEFAULT) .truncate() + .color(Color::Muted) .into_any_element(), ThreadSummary::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER) .truncate() + .color(Color::Muted) .into_any_element(), ThreadSummary::Ready(_) => div() .w_full() @@ -1797,7 +2111,8 @@ impl AgentPanel { .w_full() .child(change_title_editor.clone()) .child( - ui::IconButton::new("retry-summary-generation", IconName::RotateCcw) + IconButton::new("retry-summary-generation", IconName::RotateCcw) + .icon_size(IconSize::Small) .on_click({ let active_thread = active_thread.clone(); move |_, _window, cx| { @@ -1818,9 +2133,33 @@ impl AgentPanel { } } ActiveView::ExternalAgentThread { thread_view } => { - Label::new(thread_view.read(cx).title(cx)) - .truncate() - .into_any_element() + if let Some(title_editor) = 
thread_view.read(cx).title_editor() { + div() + .w_full() + .on_action({ + let thread_view = thread_view.downgrade(); + move |_: &menu::Confirm, window, cx| { + if let Some(thread_view) = thread_view.upgrade() { + thread_view.focus_handle(cx).focus(window); + } + } + }) + .on_action({ + let thread_view = thread_view.downgrade(); + move |_: &editor::actions::Cancel, window, cx| { + if let Some(thread_view) = thread_view.upgrade() { + thread_view.focus_handle(cx).focus(window); + } + } + }) + .child(title_editor) + .into_any_element() + } else { + Label::new(thread_view.read(cx).title(cx)) + .color(Color::Muted) + .truncate() + .into_any_element() + } } ActiveView::TextThread { title_editor, @@ -1831,6 +2170,7 @@ impl AgentPanel { match summary { ContextSummary::Pending => Label::new(ContextSummary::DEFAULT) + .color(Color::Muted) .truncate() .into_any_element(), ContextSummary::Content(summary) => { @@ -1842,6 +2182,7 @@ impl AgentPanel { } else { Label::new(LOADING_SUMMARY_PLACEHOLDER) .truncate() + .color(Color::Muted) .into_any_element() } } @@ -1849,7 +2190,8 @@ impl AgentPanel { .w_full() .child(title_editor.clone()) .child( - ui::IconButton::new("retry-summary-generation", IconName::RotateCcw) + IconButton::new("retry-summary-generation", IconName::RotateCcw) + .icon_size(IconSize::Small) .on_click({ let context_editor = context_editor.clone(); move |_, _window, cx| { @@ -1901,6 +2243,8 @@ impl AgentPanel { "Enable Full Screen" }; + let selected_agent = self.selected_agent.clone(); + PopoverMenu::new("agent-options-menu") .trigger_with_tooltip( IconButton::new("agent-options-menu", IconName::Ellipsis) @@ -1921,7 +2265,6 @@ impl AgentPanel { .anchor(Corner::TopRight) .with_handle(self.agent_panel_menu_handle.clone()) .menu({ - let focus_handle = focus_handle.clone(); move |window, cx| { Some(ContextMenu::build(window, cx, |mut menu, _window, _| { menu = menu.context(focus_handle.clone()); @@ -1981,6 +2324,11 @@ impl AgentPanel { .action("Settings", 
Box::new(OpenSettings)) .separator() .action(full_screen_label, Box::new(ToggleZoom)); + + if selected_agent == AgentType::Gemini { + menu = menu.action("Reauthenticate", Box::new(ReauthenticateAgent)) + } + menu })) } @@ -1990,6 +2338,7 @@ impl AgentPanel { fn render_recent_entries_menu( &self, icon: IconName, + corner: Corner, cx: &mut Context, ) -> impl IntoElement { let focus_handle = self.focus_handle(cx); @@ -1998,10 +2347,9 @@ impl AgentPanel { .trigger_with_tooltip( IconButton::new("agent-nav-menu", icon).icon_size(IconSize::Small), { - let focus_handle = focus_handle.clone(); move |window, cx| { Tooltip::for_action_in( - "Toggle Panel Menu", + "Toggle Recent Threads", &ToggleNavigationMenu, &focus_handle, window, @@ -2010,11 +2358,13 @@ impl AgentPanel { } }, ) - .anchor(Corner::TopLeft) + .anchor(corner) .with_handle(self.assistant_navigation_menu_handle.clone()) .menu({ let menu = self.assistant_navigation_menu.clone(); move |window, cx| { + telemetry::event!("View Thread History Clicked"); + if let Some(menu) = menu.as_ref() { menu.update(cx, |_, cx| { cx.defer_in(window, |menu, window, cx| { @@ -2036,8 +2386,6 @@ impl AgentPanel { this.go_back(&workspace::GoBack, window, cx); })) .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { Tooltip::for_action_in("Go Back", &workspace::GoBack, &focus_handle, window, cx) } @@ -2063,7 +2411,6 @@ impl AgentPanel { .anchor(Corner::TopRight) .with_handle(self.new_thread_menu_handle.clone()) .menu({ - let focus_handle = focus_handle.clone(); move |window, cx| { let active_thread = active_thread.clone(); Some(ContextMenu::build(window, cx, |mut menu, _window, cx| { @@ -2138,7 +2485,7 @@ impl AgentPanel { .child(self.render_toolbar_back_button(cx)) .into_any_element(), _ => self - .render_recent_entries_menu(IconName::MenuAlt, cx) + .render_recent_entries_menu(IconName::MenuAlt, Corner::TopLeft, cx) .into_any_element(), }) .child(self.render_title_view(window, cx)), @@ -2162,11 +2509,14 @@ impl 
AgentPanel { } fn render_toolbar_new(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + let agent_server_store = self.project.read(cx).agent_server_store().clone(); let focus_handle = self.focus_handle(cx); let active_thread = match &self.active_view { - ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()), - ActiveView::ExternalAgentThread { .. } + ActiveView::ExternalAgentThread { thread_view } => { + thread_view.read(cx).as_native_thread(cx) + } + ActiveView::Thread { .. } | ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => None, @@ -2188,13 +2538,19 @@ impl AgentPanel { } }, ) - .anchor(Corner::TopLeft) + .anchor(Corner::TopRight) .with_handle(self.new_thread_menu_handle.clone()) .menu({ - let focus_handle = focus_handle.clone(); let workspace = self.workspace.clone(); + let is_via_collab = workspace + .update(cx, |workspace, cx| { + workspace.project().read(cx).is_via_collab() + }) + .unwrap_or_default(); move |window, cx| { + telemetry::event!("New Thread Clicked"); + let active_thread = active_thread.clone(); Some(ContextMenu::build(window, cx, |mut menu, _window, cx| { menu = menu @@ -2204,15 +2560,15 @@ impl AgentPanel { let thread = active_thread.read(cx); if !thread.is_empty() { - let thread_id = thread.id().clone(); + let session_id = thread.id().clone(); this.item( ContextMenuEntry::new("New From Summary") .icon(IconName::ThreadFromSummary) .icon_color(Color::Muted) .handler(move |window, cx| { window.dispatch_action( - Box::new(NewThread { - from_thread_id: Some(thread_id.clone()), + Box::new(NewNativeAgentThreadFromSummary { + from_session_id: session_id.clone(), }), cx, ); @@ -2224,9 +2580,9 @@ impl AgentPanel { }) .item( ContextMenuEntry::new("New Thread") + .action(NewThread::default().boxed_clone()) .icon(IconName::Thread) .icon_color(Color::Muted) - .action(NewThread::default().boxed_clone()) .handler({ let workspace = workspace.clone(); move |window, cx| { @@ -2236,18 
+2592,15 @@ impl AgentPanel { workspace.panel::(cx) { panel.update(cx, |panel, cx| { - panel.set_selected_agent( - AgentType::Zed, + panel.new_agent_thread( + AgentType::NativeAgent, + window, cx, ); }); } }); } - window.dispatch_action( - NewThread::default().boxed_clone(), - cx, - ); } }), ) @@ -2265,118 +2618,159 @@ impl AgentPanel { workspace.panel::(cx) { panel.update(cx, |panel, cx| { - panel.set_selected_agent( + panel.new_agent_thread( AgentType::TextThread, + window, cx, ); }); } }); } - window.dispatch_action(NewTextThread.boxed_clone(), cx); - } - }), - ) - .item( - ContextMenuEntry::new("New Native Agent Thread") - .icon(IconName::ZedAssistant) - .icon_color(Color::Muted) - .handler({ - let workspace = workspace.clone(); - move |window, cx| { - if let Some(workspace) = workspace.upgrade() { - workspace.update(cx, |workspace, cx| { - if let Some(panel) = - workspace.panel::(cx) - { - panel.update(cx, |panel, cx| { - panel.set_selected_agent( - AgentType::NativeAgent, - cx, - ); - }); - } - }); - } - window.dispatch_action( - NewExternalAgentThread { - agent: Some(crate::ExternalAgent::NativeAgent), - } - .boxed_clone(), - cx, - ); } }), ) .separator() .header("External Agents") - .item( - ContextMenuEntry::new("New Gemini Thread") - .icon(IconName::AiGemini) - .icon_color(Color::Muted) - .handler({ - let workspace = workspace.clone(); - move |window, cx| { - if let Some(workspace) = workspace.upgrade() { - workspace.update(cx, |workspace, cx| { - if let Some(panel) = - workspace.panel::(cx) - { - panel.update(cx, |panel, cx| { - panel.set_selected_agent( - AgentType::Gemini, - cx, - ); - }); - } - }); + .when(cx.has_flag::(), |menu| { + menu.item( + ContextMenuEntry::new("New Gemini CLI Thread") + .icon(IconName::AiGemini) + .icon_color(Color::Muted) + .disabled(is_via_collab) + .handler({ + let workspace = workspace.clone(); + move |window, cx| { + if let Some(workspace) = workspace.upgrade() { + workspace.update(cx, |workspace, cx| { + if let 
Some(panel) = + workspace.panel::(cx) + { + panel.update(cx, |panel, cx| { + panel.new_agent_thread( + AgentType::Gemini, + window, + cx, + ); + }); + } + }); + } } - window.dispatch_action( - NewExternalAgentThread { - agent: Some(crate::ExternalAgent::Gemini), + }), + ) + }) + .when(cx.has_flag::(), |menu| { + menu.item( + ContextMenuEntry::new("New Claude Code Thread") + .icon(IconName::AiClaude) + .disabled(is_via_collab) + .icon_color(Color::Muted) + .handler({ + let workspace = workspace.clone(); + move |window, cx| { + if let Some(workspace) = workspace.upgrade() { + workspace.update(cx, |workspace, cx| { + if let Some(panel) = + workspace.panel::(cx) + { + panel.update(cx, |panel, cx| { + panel.new_agent_thread( + AgentType::ClaudeCode, + window, + cx, + ); + }); + } + }); } - .boxed_clone(), - cx, - ); - } - }), - ) - .item( - ContextMenuEntry::new("New Claude Code Thread") - .icon(IconName::AiClaude) - .icon_color(Color::Muted) - .handler({ - let workspace = workspace.clone(); - move |window, cx| { - if let Some(workspace) = workspace.upgrade() { - workspace.update(cx, |workspace, cx| { - if let Some(panel) = - workspace.panel::(cx) - { - panel.update(cx, |panel, cx| { - panel.set_selected_agent( - AgentType::ClaudeCode, - cx, - ); + } + }), + ) + }) + .when(cx.has_flag::(), |mut menu| { + let agent_names = agent_server_store + .read(cx) + .external_agents() + .filter(|name| { + name.0 != GEMINI_NAME && name.0 != CLAUDE_CODE_NAME + }) + .cloned() + .collect::>(); + let custom_settings = cx.global::().get::(None).custom.clone(); + for agent_name in agent_names { + menu = menu.item( + ContextMenuEntry::new(format!("New {} Thread", agent_name)) + .icon(IconName::Terminal) + .icon_color(Color::Muted) + .disabled(is_via_collab) + .handler({ + let workspace = workspace.clone(); + let agent_name = agent_name.clone(); + let custom_settings = custom_settings.clone(); + move |window, cx| { + if let Some(workspace) = workspace.upgrade() { + workspace.update(cx, 
|workspace, cx| { + if let Some(panel) = + workspace.panel::(cx) + { + panel.update(cx, |panel, cx| { + panel.new_agent_thread( + AgentType::Custom { + name: agent_name.clone().into(), + command: custom_settings + .get(&agent_name.0) + .map(|settings| { + settings.command.clone() + }) + .unwrap_or(placeholder_command()), + }, + window, + cx, + ); + }); + } }); } - }); - } - window.dispatch_action( - NewExternalAgentThread { - agent: Some(crate::ExternalAgent::ClaudeCode), } - .boxed_clone(), - cx, - ); - } - }), - ); + }), + ); + } + + menu + }) + .when(cx.has_flag::(), |menu| { + menu.separator().link( + "Add Other Agents", + OpenBrowser { + url: zed_urls::external_agents_docs(cx), + } + .boxed_clone(), + ) + }); menu })) } }); + let selected_agent_label = self.selected_agent.label(); + let selected_agent = div() + .id("selected_agent_icon") + .when_some(self.selected_agent.icon(), |this, icon| { + this.px(DynamicSpacing::Base02.rems(cx)) + .child(Icon::new(icon).color(Color::Muted)) + .tooltip(move |window, cx| { + Tooltip::with_meta( + selected_agent_label.clone(), + None, + "Selected Agent", + window, + cx, + ) + }) + }) + .into_any_element(); + h_flex() .id("agent-panel-toolbar") .h(Tab::container_height(cx)) @@ -2390,52 +2784,36 @@ impl AgentPanel { .child( h_flex() .size_full() - .gap(DynamicSpacing::Base08.rems(cx)) + .gap(DynamicSpacing::Base04.rems(cx)) + .pl(DynamicSpacing::Base04.rems(cx)) .child(match &self.active_view { - ActiveView::History | ActiveView::Configuration => div() - .pl(DynamicSpacing::Base04.rems(cx)) - .child(self.render_toolbar_back_button(cx)) - .into_any_element(), - _ => h_flex() - .h_full() - .px(DynamicSpacing::Base04.rems(cx)) - .border_r_1() - .border_color(cx.theme().colors().border) - .child( - h_flex() - .px_0p5() - .gap_1p5() - .child( - Icon::new(self.selected_agent.icon()).color(Color::Muted), - ) - .child(Label::new(self.selected_agent.label())), - ) - .into_any_element(), + ActiveView::History | 
ActiveView::Configuration => { + self.render_toolbar_back_button(cx).into_any_element() + } + _ => selected_agent.into_any_element(), }) .child(self.render_title_view(window, cx)), ) .child( h_flex() - .h_full() - .gap_2() - .children(self.render_token_count(cx)) - .child( - h_flex() - .h_full() - .gap(DynamicSpacing::Base02.rems(cx)) - .pl(DynamicSpacing::Base04.rems(cx)) - .pr(DynamicSpacing::Base06.rems(cx)) - .border_l_1() - .border_color(cx.theme().colors().border) - .child(new_thread_menu) - .child(self.render_recent_entries_menu(IconName::HistoryRerun, cx)) - .child(self.render_panel_options_menu(window, cx)), - ), + .flex_none() + .gap(DynamicSpacing::Base02.rems(cx)) + .pl(DynamicSpacing::Base04.rems(cx)) + .pr(DynamicSpacing::Base06.rems(cx)) + .child(new_thread_menu) + .child(self.render_recent_entries_menu( + IconName::MenuAltTemp, + Corner::TopRight, + cx, + )) + .child(self.render_panel_options_menu(window, cx)), ) } fn render_toolbar(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - if cx.has_flag::() { + if cx.has_flag::() + || cx.has_flag::() + { self.render_toolbar_new(window, cx).into_any_element() } else { self.render_toolbar_old(window, cx).into_any_element() @@ -2553,16 +2931,10 @@ impl AgentPanel { Some(token_count) } - ActiveView::TextThread { context_editor, .. } => { - let element = render_remaining_tokens(context_editor, cx)?; - - Some(element.into_any_element()) - } ActiveView::ExternalAgentThread { .. } + | ActiveView::TextThread { .. 
} | ActiveView::History - | ActiveView::Configuration => { - return None; - } + | ActiveView::Configuration => None, } } @@ -2578,7 +2950,7 @@ impl AgentPanel { .thread() .read(cx) .configured_model() - .map_or(false, |model| { + .is_some_and(|model| { model.provider.id() != language_model::ZED_CLOUD_PROVIDER_ID }) { @@ -2589,7 +2961,7 @@ impl AgentPanel { if LanguageModelRegistry::global(cx) .read(cx) .default_model() - .map_or(false, |model| { + .is_some_and(|model| { model.provider.id() != language_model::ZED_CLOUD_PROVIDER_ID }) { @@ -2604,7 +2976,10 @@ impl AgentPanel { let plan = self.user_store.read(cx).plan(); let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some(); - matches!(plan, Some(Plan::ZedFree)) && has_previous_trial + matches!( + plan, + Some(Plan::V1(PlanV1::ZedFree) | Plan::V2(PlanV2::ZedFree)) + ) && has_previous_trial } fn should_render_onboarding(&self, cx: &mut Context) -> bool { @@ -2612,11 +2987,37 @@ impl AgentPanel { return false; } + let user_store = self.user_store.read(cx); + + if user_store + .plan() + .is_some_and(|plan| matches!(plan, Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro))) + && user_store + .subscription_period() + .and_then(|period| period.0.checked_add_days(chrono::Days::new(1))) + .is_some_and(|date| date < chrono::Utc::now()) + { + OnboardingUpsell::set_dismissed(true, cx); + return false; + } + match &self.active_view { - ActiveView::Thread { .. } | ActiveView::TextThread { .. } => { - let history_is_empty = self - .history_store - .update(cx, |store, cx| store.recent_entries(1, cx).is_empty()); + ActiveView::History | ActiveView::Configuration => false, + ActiveView::ExternalAgentThread { thread_view, .. 
} + if thread_view.read(cx).as_native_thread(cx).is_none() => + { + false + } + _ => { + let history_is_empty = if cx.has_flag::() { + self.acp_history_store.read(cx).is_empty(cx) + && self + .history_store + .update(cx, |store, cx| store.recent_entries(1, cx).is_empty()) + } else { + self.history_store + .update(cx, |store, cx| store.recent_entries(1, cx).is_empty()) + }; let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx) .providers() @@ -2628,9 +3029,6 @@ impl AgentPanel { history_is_empty || !has_configured_non_zed_providers } - ActiveView::ExternalAgentThread { .. } - | ActiveView::History - | ActiveView::Configuration => false, } } @@ -2677,6 +3075,8 @@ impl AgentPanel { return None; } + let plan = self.user_store.read(cx).plan()?; + Some( v_flex() .absolute() @@ -2685,15 +3085,18 @@ impl AgentPanel { .bg(cx.theme().colors().panel_background) .opacity(0.85) .block_mouse_except_scroll() - .child(EndTrialUpsell::new(Arc::new({ - let this = cx.entity(); - move |_, cx| { - this.update(cx, |_this, cx| { - TrialEndUpsell::set_dismissed(true, cx); - cx.notify(); - }); - } - }))), + .child(EndTrialUpsell::new( + plan, + Arc::new({ + let this = cx.entity(); + move |_, cx| { + this.update(cx, |_this, cx| { + TrialEndUpsell::set_dismissed(true, cx); + cx.notify(); + }); + } + }), + )), ) } @@ -2703,20 +3106,22 @@ impl AgentPanel { action_slot: Option, cx: &mut Context, ) -> impl IntoElement { - h_flex() - .mt_2() - .pl_1p5() - .pb_1() - .w_full() - .justify_between() - .border_b_1() - .border_color(cx.theme().colors().border_variant) - .child( - Label::new(label.into()) - .size(LabelSize::Small) - .color(Color::Muted), - ) - .children(action_slot) + div().pl_1().pr_1p5().child( + h_flex() + .mt_2() + .pl_1p5() + .pb_1() + .w_full() + .justify_between() + .border_b_1() + .border_color(cx.theme().colors().border_variant) + .child( + Label::new(label.into()) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .children(action_slot), + ) } fn 
render_thread_empty_state( @@ -2822,22 +3227,12 @@ impl AgentPanel { }), ), ) - }) - .when_some(configuration_error.as_ref(), |this, err| { - this.child(self.render_configuration_error( - err, - &focus_handle, - window, - cx, - )) }), ) }) .when(!recent_history.is_empty(), |parent| { - let focus_handle = focus_handle.clone(); parent .overflow_hidden() - .p_1p5() .justify_end() .gap_1() .child( @@ -2865,14 +3260,15 @@ impl AgentPanel { ), ) .child( - v_flex() - .gap_1() - .children(recent_history.into_iter().enumerate().map( - |(index, entry)| { + v_flex().p_1().pr_1p5().gap_1().children( + recent_history + .into_iter() + .enumerate() + .map(|(index, entry)| { // TODO: Add keyboard navigation. let is_hovered = self.hovered_recent_history_item == Some(index); - HistoryEntryElement::new(entry.clone(), cx.entity().downgrade()) + HistoryEntryElement::new(entry, cx.entity().downgrade()) .hovered(is_hovered) .on_hover(cx.listener( move |this, is_hovered, _window, cx| { @@ -2887,50 +3283,82 @@ impl AgentPanel { }, )) .into_any_element() - }, - )), + }), + ), ) - .when_some(configuration_error.as_ref(), |this, err| { - this.child(self.render_configuration_error(err, &focus_handle, window, cx)) - }) + }) + .when_some(configuration_error.as_ref(), |this, err| { + this.child(self.render_configuration_error(false, err, &focus_handle, window, cx)) }) } fn render_configuration_error( &self, + border_bottom: bool, configuration_error: &ConfigurationError, focus_handle: &FocusHandle, window: &mut Window, cx: &mut App, ) -> impl IntoElement { - match configuration_error { - ConfigurationError::ModelNotFound - | ConfigurationError::ProviderNotAuthenticated(_) - | ConfigurationError::NoProvider => Banner::new() - .severity(ui::Severity::Warning) - .child(Label::new(configuration_error.to_string())) - .action_slot( - Button::new("settings", "Configure Provider") + let zed_provider_configured = AgentSettings::get_global(cx) + .default_model + .as_ref() + .is_some_and(|selection| 
selection.provider.0.as_str() == "zed.dev"); + + let callout = if zed_provider_configured { + Callout::new() + .icon(IconName::Warning) + .severity(Severity::Warning) + .when(border_bottom, |this| { + this.border_position(ui::BorderPosition::Bottom) + }) + .title("Sign in to continue using Zed as your LLM provider.") + .actions_slot( + Button::new("sign_in", "Sign In") + .style(ButtonStyle::Tinted(ui::TintColor::Warning)) + .label_size(LabelSize::Small) + .on_click({ + let workspace = self.workspace.clone(); + move |_, _, cx| { + let Ok(client) = + workspace.update(cx, |workspace, _| workspace.client().clone()) + else { + return; + }; + + cx.spawn(async move |cx| { + client.sign_in_with_optional_connect(true, cx).await + }) + .detach_and_log_err(cx); + } + }), + ) + } else { + Callout::new() + .icon(IconName::Warning) + .severity(Severity::Warning) + .when(border_bottom, |this| { + this.border_position(ui::BorderPosition::Bottom) + }) + .title(configuration_error.to_string()) + .actions_slot( + Button::new("settings", "Configure") .style(ButtonStyle::Tinted(ui::TintColor::Warning)) .label_size(LabelSize::Small) .key_binding( - KeyBinding::for_action_in(&OpenSettings, &focus_handle, window, cx) + KeyBinding::for_action_in(&OpenSettings, focus_handle, window, cx) .map(|kb| kb.size(rems_from_px(12.))), ) .on_click(|_event, window, cx| { window.dispatch_action(OpenSettings.boxed_clone(), cx) }), - ), - ConfigurationError::ProviderPendingTermsAcceptance(provider) => { - Banner::new().severity(ui::Severity::Warning).child( - h_flex().w_full().children( - provider.render_accept_terms( - LanguageModelProviderTosView::ThreadEmptyState, - cx, - ), - ), ) - } + }; + + match configuration_error { + ConfigurationError::ModelNotFound + | ConfigurationError::ProviderNotAuthenticated(_) + | ConfigurationError::NoProvider => callout.into_any_element(), } } @@ -2961,7 +3389,7 @@ impl AgentPanel { let focus_handle = self.focus_handle(cx); let banner = Banner::new() - 
.severity(ui::Severity::Info) + .severity(Severity::Info) .child(Label::new("Consecutive tool use limit reached.").size(LabelSize::Small)) .action_slot( h_flex() @@ -3072,10 +3500,6 @@ impl AgentPanel { })) } - fn error_callout_bg(&self, cx: &Context) -> Hsla { - cx.theme().status().error.opacity(0.08) - } - fn render_payment_required_error( &self, thread: &Entity, @@ -3084,23 +3508,18 @@ impl AgentPanel { const ERROR_MESSAGE: &str = "You reached your free usage limit. Upgrade to Zed Pro for more prompts."; - let icon = Icon::new(IconName::XCircle) - .size(IconSize::Small) - .color(Color::Error); - - div() - .border_t_1() - .border_color(cx.theme().colors().border) - .child( - Callout::new() - .icon(icon) - .title("Free Usage Exceeded") - .description(ERROR_MESSAGE) - .tertiary_action(self.upgrade_button(thread, cx)) - .secondary_action(self.create_copy_button(ERROR_MESSAGE)) - .primary_action(self.dismiss_error_button(thread, cx)) - .bg_color(self.error_callout_bg(cx)), + Callout::new() + .severity(Severity::Error) + .icon(IconName::XCircle) + .title("Free Usage Exceeded") + .description(ERROR_MESSAGE) + .actions_slot( + h_flex() + .gap_0p5() + .child(self.upgrade_button(thread, cx)) + .child(self.create_copy_button(ERROR_MESSAGE)), ) + .dismiss_action(self.dismiss_error_button(thread, cx)) .into_any_element() } @@ -3111,44 +3530,29 @@ impl AgentPanel { cx: &mut Context, ) -> AnyElement { let error_message = match plan { - Plan::ZedPro => "Upgrade to usage-based billing for more prompts.", - Plan::ZedProTrial | Plan::ZedFree => "Upgrade to Zed Pro for more prompts.", + Plan::V1(PlanV1::ZedPro) => "Upgrade to usage-based billing for more prompts.", + Plan::V1(PlanV1::ZedProTrial) | Plan::V1(PlanV1::ZedFree) => { + "Upgrade to Zed Pro for more prompts." 
+ } + Plan::V2(_) => "", }; - let icon = Icon::new(IconName::XCircle) - .size(IconSize::Small) - .color(Color::Error); - - div() - .border_t_1() - .border_color(cx.theme().colors().border) - .child( - Callout::new() - .icon(icon) - .title("Model Prompt Limit Reached") - .description(error_message) - .tertiary_action(self.upgrade_button(thread, cx)) - .secondary_action(self.create_copy_button(error_message)) - .primary_action(self.dismiss_error_button(thread, cx)) - .bg_color(self.error_callout_bg(cx)), + Callout::new() + .severity(Severity::Error) + .title("Model Prompt Limit Reached") + .description(error_message) + .actions_slot( + h_flex() + .gap_0p5() + .child(self.upgrade_button(thread, cx)) + .child(self.create_copy_button(error_message)), ) + .dismiss_action(self.dismiss_error_button(thread, cx)) .into_any_element() } - fn render_error_message( - &self, - header: SharedString, - message: SharedString, - thread: &Entity, - cx: &mut Context, - ) -> AnyElement { - let message_with_header = format!("{}\n{}", header, message); - - let icon = Icon::new(IconName::XCircle) - .size(IconSize::Small) - .color(Color::Error); - - let retry_button = Button::new("retry", "Retry") + fn render_retry_button(&self, thread: &Entity) -> AnyElement { + Button::new("retry", "Retry") .icon(IconName::RotateCw) .icon_position(IconPosition::Start) .icon_size(IconSize::Small) @@ -3163,21 +3567,36 @@ impl AgentPanel { }); }); } - }); + }) + .into_any_element() + } - div() - .border_t_1() - .border_color(cx.theme().colors().border) - .child( - Callout::new() - .icon(icon) - .title(header) - .description(message.clone()) - .primary_action(retry_button) - .secondary_action(self.dismiss_error_button(thread, cx)) - .tertiary_action(self.create_copy_button(message_with_header)) - .bg_color(self.error_callout_bg(cx)), + fn render_error_message( + &self, + header: SharedString, + message: SharedString, + thread: &Entity, + cx: &mut Context, + ) -> AnyElement { + let message_with_header = 
format!("{}\n{}", header, message); + + // Don't show Retry button for refusals + let is_refusal = header == "Request Refused"; + let retry_button = self.render_retry_button(thread); + let copy_button = self.create_copy_button(message_with_header); + + Callout::new() + .severity(Severity::Error) + .icon(IconName::XCircle) + .title(header) + .description(message) + .actions_slot( + h_flex() + .gap_0p5() + .when(!is_refusal, |this| this.child(retry_button)) + .child(copy_button), ) + .dismiss_action(self.dismiss_error_button(thread, cx)) .into_any_element() } @@ -3186,60 +3605,39 @@ impl AgentPanel { message: SharedString, can_enable_burn_mode: bool, thread: &Entity, - cx: &mut Context, ) -> AnyElement { - let icon = Icon::new(IconName::XCircle) - .size(IconSize::Small) - .color(Color::Error); - - let retry_button = Button::new("retry", "Retry") - .icon(IconName::RotateCw) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - .label_size(LabelSize::Small) - .on_click({ - let thread = thread.clone(); - move |_, window, cx| { - thread.update(cx, |thread, cx| { - thread.clear_last_error(); - thread.thread().update(cx, |thread, cx| { - thread.retry_last_completion(Some(window.window_handle()), cx); - }); - }); - } - }); - - let mut callout = Callout::new() - .icon(icon) + Callout::new() + .severity(Severity::Error) .title("Error") - .description(message.clone()) - .bg_color(self.error_callout_bg(cx)) - .primary_action(retry_button); - - if can_enable_burn_mode { - let burn_mode_button = Button::new("enable_burn_retry", "Enable Burn Mode and Retry") - .icon(IconName::ZedBurnMode) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - .label_size(LabelSize::Small) - .on_click({ - let thread = thread.clone(); - move |_, window, cx| { - thread.update(cx, |thread, cx| { - thread.clear_last_error(); - thread.thread().update(cx, |thread, cx| { - thread.enable_burn_mode_and_retry(Some(window.window_handle()), cx); - }); - }); - } - }); - callout = 
callout.secondary_action(burn_mode_button); - } - - div() - .border_t_1() - .border_color(cx.theme().colors().border) - .child(callout) + .description(message) + .actions_slot( + h_flex() + .gap_0p5() + .when(can_enable_burn_mode, |this| { + this.child( + Button::new("enable_burn_retry", "Enable Burn Mode and Retry") + .icon(IconName::ZedBurnMode) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .label_size(LabelSize::Small) + .on_click({ + let thread = thread.clone(); + move |_, window, cx| { + thread.update(cx, |thread, cx| { + thread.clear_last_error(); + thread.thread().update(cx, |thread, cx| { + thread.enable_burn_mode_and_retry( + Some(window.window_handle()), + cx, + ); + }); + }); + } + }), + ) + }) + .child(self.render_retry_button(thread)), + ) .into_any_element() } @@ -3318,9 +3716,9 @@ impl AgentPanel { .on_drop(cx.listener(move |this, paths: &ExternalPaths, window, cx| { let tasks = paths .paths() - .into_iter() + .iter() .map(|path| { - Workspace::project_path_for_path(this.project.clone(), &path, false, cx) + Workspace::project_path_for_path(this.project.clone(), path, false, cx) }) .collect::>(); cx.spawn_in(window, async move |this, cx| { @@ -3458,6 +3856,11 @@ impl Render for AgentPanel { } })) .on_action(cx.listener(Self::toggle_burn_mode)) + .on_action(cx.listener(|this, _: &ReauthenticateAgent, window, cx| { + if let Some(thread_view) = this.active_thread_view() { + thread_view.update(cx, |thread_view, cx| thread_view.reauthenticate(window, cx)) + } + })) .child(self.render_toolbar(window, cx)) .children(self.render_onboarding(window, cx)) .map(|parent| match &self.active_view { @@ -3494,7 +3897,6 @@ impl Render for AgentPanel { message, can_enable_burn_mode, thread, - cx, ), }) .into_any(), @@ -3508,7 +3910,13 @@ impl Render for AgentPanel { ActiveView::ExternalAgentThread { thread_view, .. 
} => parent .child(thread_view.clone()) .child(self.render_drag_target(cx)), - ActiveView::History => parent.child(self.history.clone()), + ActiveView::History => { + if cx.has_flag::() { + parent.child(self.acp_history.clone()) + } else { + parent.child(self.history.clone()) + } + } ActiveView::TextThread { context_editor, buffer_search_bar, @@ -3522,16 +3930,13 @@ impl Render for AgentPanel { if !self.should_render_onboarding(cx) && let Some(err) = configuration_error.as_ref() { - this.child( - div().bg(cx.theme().colors().editor_background).p_2().child( - self.render_configuration_error( - err, - &self.focus_handle(cx), - window, - cx, - ), - ), - ) + this.child(self.render_configuration_error( + true, + err, + &self.focus_handle(cx), + window, + cx, + )) } else { this } @@ -3590,7 +3995,7 @@ impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist { let text_thread_store = None; let context_store = cx.new(|_| ContextStore::new(project.clone(), None)); assistant.assist( - &prompt_editor, + prompt_editor, self.workspace.clone(), context_store, project, @@ -3673,7 +4078,11 @@ impl AgentPanelDelegate for ConcreteAssistantPanelDelegate { // Wait to create a new context until the workspace is no longer // being updated. 
cx.defer_in(window, move |panel, window, cx| { - if let Some(message_editor) = panel.active_message_editor() { + if let Some(thread_view) = panel.active_thread_view() { + thread_view.update(cx, |thread_view, cx| { + thread_view.insert_selections(window, cx); + }); + } else if let Some(message_editor) = panel.active_message_editor() { message_editor.update(cx, |message_editor, cx| { message_editor.context_store().update(cx, |store, cx| { let buffer = buffer.read(cx); diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 231b9cfb38a8c6e8ce4dc102fd14e06703b3e1c5..09d2179fc3a2ec4ff4288da4062365c51ad4444f 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -5,7 +5,6 @@ mod agent_diff; mod agent_model_selector; mod agent_panel; mod buffer_codegen; -mod burn_mode_tooltip; mod context_picker; mod context_server_configuration; mod context_strip; @@ -35,12 +34,13 @@ use client::Client; use command_palette_hooks::CommandPaletteFilter; use feature_flags::FeatureFlagAppExt as _; use fs::Fs; -use gpui::{Action, App, Entity, actions}; +use gpui::{Action, App, Entity, SharedString, actions}; use language::LanguageRegistry; use language_model::{ ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, }; use project::DisableAiSettings; +use project::agent_server_store::AgentServerCommand; use prompt_store::PromptBuilder; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -72,8 +72,10 @@ actions!( ToggleOptionsMenu, /// Deletes the recently opened thread from history. DeleteRecentlyOpenThread, - /// Toggles the profile selector for switching between agent profiles. + /// Toggles the profile or mode selector for switching between agent profiles. ToggleProfileSelector, + /// Cycles through available session modes. + CycleModeSelector, /// Removes all added context from the current conversation. RemoveAllContext, /// Expands the message editor to full size. 
@@ -114,6 +116,12 @@ actions!( RejectAll, /// Keeps all suggestions or changes. KeepAll, + /// Allow this operation only this time. + AllowOnce, + /// Allow this operation and remember the choice. + AllowAlways, + /// Reject this operation only this time. + RejectOnce, /// Follows the agent's suggestions. Follow, /// Resets the trial upsell notification. @@ -129,6 +137,12 @@ actions!( ] ); +#[derive(Clone, Copy, Debug, PartialEq, Eq, Action)] +#[action(namespace = agent)] +#[action(deprecated_aliases = ["assistant::QuoteSelection"])] +/// Quotes the current selection in the agent panel's message editor. +pub struct QuoteSelection; + /// Creates a new conversation thread, optionally based on an existing thread. #[derive(Default, Clone, PartialEq, Deserialize, JsonSchema, Action)] #[action(namespace = agent)] @@ -147,21 +161,57 @@ pub struct NewExternalAgentThread { agent: Option, } -#[derive(Default, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, PartialEq, Deserialize, JsonSchema, Action)] +#[action(namespace = agent)] +#[serde(deny_unknown_fields)] +pub struct NewNativeAgentThreadFromSummary { + from_session_id: agent_client_protocol::SessionId, +} + +// TODO unify this with AgentType +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] enum ExternalAgent { #[default] Gemini, ClaudeCode, NativeAgent, + Custom { + name: SharedString, + command: AgentServerCommand, + }, +} + +fn placeholder_command() -> AgentServerCommand { + AgentServerCommand { + path: "/placeholder".into(), + args: vec![], + env: None, + } } impl ExternalAgent { - pub fn server(&self, fs: Arc) -> Rc { + fn name(&self) -> &'static str { + match self { + Self::NativeAgent => "zed", + Self::Gemini => "gemini-cli", + Self::ClaudeCode => "claude-code", + Self::Custom { .. 
} => "custom", + } + } + + pub fn server( + &self, + fs: Arc, + history: Entity, + ) -> Rc { match self { - ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), - ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), - ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs)), + Self::Gemini => Rc::new(agent_servers::Gemini), + Self::ClaudeCode => Rc::new(agent_servers::ClaudeCode), + Self::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs, history)), + Self::Custom { name, command: _ } => { + Rc::new(agent_servers::CustomAgentServer::new(name.clone())) + } } } } @@ -237,13 +287,7 @@ pub fn init( client.telemetry().clone(), cx, ); - terminal_inline_assistant::init( - fs.clone(), - prompt_builder.clone(), - client.telemetry().clone(), - cx, - ); - indexed_docs::init(cx); + terminal_inline_assistant::init(fs.clone(), prompt_builder, client.telemetry().clone(), cx); cx.observe_new(move |workspace, window, cx| { ConfigureContextServerModal::register(workspace, language_registry.clone(), window, cx) }) @@ -308,8 +352,7 @@ fn update_command_palette_filter(cx: &mut App) { ]; filter.show_action_types(edit_prediction_actions.iter()); - filter - .show_action_types([TypeId::of::()].iter()); + filter.show_action_types(&[TypeId::of::()]); } }); } @@ -322,7 +365,7 @@ fn init_language_model_settings(cx: &mut App) { cx.subscribe( &LanguageModelRegistry::global(cx), |_, event: &language_model::Event, cx| match event { - language_model::Event::ProviderStateChanged + language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { update_active_language_model_from_settings(cx); @@ -389,7 +432,6 @@ fn register_slash_commands(cx: &mut App) { slash_command_registry.register_command(assistant_slash_commands::FetchSlashCommand, true); cx.observe_flag::({ - let slash_command_registry = slash_command_registry.clone(); move |is_enabled, _cx| { if is_enabled { 
slash_command_registry.register_command( @@ -410,12 +452,6 @@ fn update_slash_commands_from_settings(cx: &mut App) { let slash_command_registry = SlashCommandRegistry::global(cx); let settings = SlashCommandSettings::get_global(cx); - if settings.docs.enabled { - slash_command_registry.register_command(assistant_slash_commands::DocsSlashCommand, true); - } else { - slash_command_registry.unregister_command(assistant_slash_commands::DocsSlashCommand); - } - if settings.cargo_workspace.enabled { slash_command_registry .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true); diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 615142b73dfd6eed59f635af780310290e3f6f25..2309aad754aee55af5ad040c39d22304486446a4 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -352,12 +352,12 @@ impl CodegenAlternative { event: &multi_buffer::Event, cx: &mut Context, ) { - if let multi_buffer::Event::TransactionUndone { transaction_id } = event { - if self.transformation_transaction_id == Some(*transaction_id) { - self.transformation_transaction_id = None; - self.generation = Task::ready(()); - cx.emit(CodegenEvent::Undone); - } + if let multi_buffer::Event::TransactionUndone { transaction_id } = event + && self.transformation_transaction_id == Some(*transaction_id) + { + self.transformation_transaction_id = None; + self.generation = Task::ready(()); + cx.emit(CodegenEvent::Undone); } } @@ -388,7 +388,7 @@ impl CodegenAlternative { } else { let request = self.build_request(&model, user_prompt, cx)?; cx.spawn(async move |_, cx| { - Ok(model.stream_completion_text(request.await, &cx).await?) + Ok(model.stream_completion_text(request.await, cx).await?) 
}) .boxed_local() }; @@ -447,7 +447,7 @@ impl CodegenAlternative { } }); - let temperature = AgentSettings::temperature_for_model(&model, cx); + let temperature = AgentSettings::temperature_for_model(model, cx); Ok(cx.spawn(async move |_cx| { let mut request_message = LanguageModelRequestMessage { @@ -576,38 +576,34 @@ impl CodegenAlternative { let mut lines = chunk.split('\n').peekable(); while let Some(line) = lines.next() { new_text.push_str(line); - if line_indent.is_none() { - if let Some(non_whitespace_ch_ix) = + if line_indent.is_none() + && let Some(non_whitespace_ch_ix) = new_text.find(|ch: char| !ch.is_whitespace()) - { - line_indent = Some(non_whitespace_ch_ix); - base_indent = base_indent.or(line_indent); - - let line_indent = line_indent.unwrap(); - let base_indent = base_indent.unwrap(); - let indent_delta = - line_indent as i32 - base_indent as i32; - let mut corrected_indent_len = cmp::max( - 0, - suggested_line_indent.len as i32 + indent_delta, - ) - as usize; - if first_line { - corrected_indent_len = corrected_indent_len - .saturating_sub( - selection_start.column as usize, - ); - } - - let indent_char = suggested_line_indent.char(); - let mut indent_buffer = [0; 4]; - let indent_str = - indent_char.encode_utf8(&mut indent_buffer); - new_text.replace_range( - ..line_indent, - &indent_str.repeat(corrected_indent_len), - ); + { + line_indent = Some(non_whitespace_ch_ix); + base_indent = base_indent.or(line_indent); + + let line_indent = line_indent.unwrap(); + let base_indent = base_indent.unwrap(); + let indent_delta = line_indent as i32 - base_indent as i32; + let mut corrected_indent_len = cmp::max( + 0, + suggested_line_indent.len as i32 + indent_delta, + ) + as usize; + if first_line { + corrected_indent_len = corrected_indent_len + .saturating_sub(selection_start.column as usize); } + + let indent_char = suggested_line_indent.char(); + let mut indent_buffer = [0; 4]; + let indent_str = + indent_char.encode_utf8(&mut indent_buffer); + 
new_text.replace_range( + ..line_indent, + &indent_str.repeat(corrected_indent_len), + ); } if line_indent.is_some() { @@ -1028,7 +1024,7 @@ where chunk.push('\n'); } - chunk.push_str(&line); + chunk.push_str(line); } consumed += line.len(); @@ -1133,7 +1129,7 @@ mod tests { ) }); - let chunks_tx = simulate_response_stream(codegen.clone(), cx); + let chunks_tx = simulate_response_stream(&codegen, cx); let mut new_text = concat!( " let mut x = 0;\n", @@ -1143,7 +1139,7 @@ mod tests { ); while !new_text.is_empty() { let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); + let len = rng.random_range(1..=max_len); let (chunk, suffix) = new_text.split_at(len); chunks_tx.unbounded_send(chunk.to_string()).unwrap(); new_text = suffix; @@ -1200,7 +1196,7 @@ mod tests { ) }); - let chunks_tx = simulate_response_stream(codegen.clone(), cx); + let chunks_tx = simulate_response_stream(&codegen, cx); cx.background_executor.run_until_parked(); @@ -1212,7 +1208,7 @@ mod tests { ); while !new_text.is_empty() { let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); + let len = rng.random_range(1..=max_len); let (chunk, suffix) = new_text.split_at(len); chunks_tx.unbounded_send(chunk.to_string()).unwrap(); new_text = suffix; @@ -1269,7 +1265,7 @@ mod tests { ) }); - let chunks_tx = simulate_response_stream(codegen.clone(), cx); + let chunks_tx = simulate_response_stream(&codegen, cx); cx.background_executor.run_until_parked(); @@ -1281,7 +1277,7 @@ mod tests { ); while !new_text.is_empty() { let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); + let len = rng.random_range(1..=max_len); let (chunk, suffix) = new_text.split_at(len); chunks_tx.unbounded_send(chunk.to_string()).unwrap(); new_text = suffix; @@ -1338,7 +1334,7 @@ mod tests { ) }); - let chunks_tx = simulate_response_stream(codegen.clone(), cx); + let chunks_tx = simulate_response_stream(&codegen, cx); let new_text = concat!( "func 
main() {\n", "\tx := 0\n", @@ -1395,7 +1391,7 @@ mod tests { ) }); - let chunks_tx = simulate_response_stream(codegen.clone(), cx); + let chunks_tx = simulate_response_stream(&codegen, cx); chunks_tx .unbounded_send("let mut x = 0;\nx += 1;".to_string()) .unwrap(); @@ -1477,7 +1473,7 @@ mod tests { } fn simulate_response_stream( - codegen: Entity, + codegen: &Entity, cx: &mut TestAppContext, ) -> mpsc::UnboundedSender { let (chunks_tx, chunks_rx) = mpsc::unbounded(); diff --git a/crates/agent_ui/src/burn_mode_tooltip.rs b/crates/agent_ui/src/burn_mode_tooltip.rs deleted file mode 100644 index 6354c07760f5aa0261b69e8dd08ce1f1b1be6023..0000000000000000000000000000000000000000 --- a/crates/agent_ui/src/burn_mode_tooltip.rs +++ /dev/null @@ -1,61 +0,0 @@ -use gpui::{Context, FontWeight, IntoElement, Render, Window}; -use ui::{prelude::*, tooltip_container}; - -pub struct BurnModeTooltip { - selected: bool, -} - -impl BurnModeTooltip { - pub fn new() -> Self { - Self { selected: false } - } - - pub fn selected(mut self, selected: bool) -> Self { - self.selected = selected; - self - } -} - -impl Render for BurnModeTooltip { - fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let (icon, color) = if self.selected { - (IconName::ZedBurnModeOn, Color::Error) - } else { - (IconName::ZedBurnMode, Color::Default) - }; - - let turned_on = h_flex() - .h_4() - .px_1() - .border_1() - .border_color(cx.theme().colors().border) - .bg(cx.theme().colors().text_accent.opacity(0.1)) - .rounded_sm() - .child( - Label::new("ON") - .size(LabelSize::XSmall) - .weight(FontWeight::SEMIBOLD) - .color(Color::Accent), - ); - - let title = h_flex() - .gap_1p5() - .child(Icon::new(icon).size(IconSize::Small).color(color)) - .child(Label::new("Burn Mode")) - .when(self.selected, |title| title.child(turned_on)); - - tooltip_container(window, cx, |this, _, _| { - this - .child(title) - .child( - div() - .max_w_64() - .child( - Label::new("Enables models to use large 
context windows, unlimited tool calls, and other capabilities for expanded reasoning.") - .size(LabelSize::Small) - .color(Color::Muted) - ) - ) - }) - } -} diff --git a/crates/agent_ui/src/context_picker.rs b/crates/agent_ui/src/context_picker.rs index 7dc00bfae2ecd5404b5c3ae3617f6387791f857b..b225fbf34058604cfb3f306a9cee14f69bb5edaa 100644 --- a/crates/agent_ui/src/context_picker.rs +++ b/crates/agent_ui/src/context_picker.rs @@ -13,7 +13,7 @@ use anyhow::{Result, anyhow}; use collections::HashSet; pub use completion_provider::ContextPickerCompletionProvider; use editor::display_map::{Crease, CreaseId, CreaseMetadata, FoldId}; -use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset}; +use editor::{Anchor, Editor, ExcerptId, FoldPlaceholder, ToOffset}; use fetch_context_picker::FetchContextPicker; use file_context_picker::FileContextPicker; use file_context_picker::render_file_context_entry; @@ -228,7 +228,7 @@ impl ContextPicker { } fn build_menu(&mut self, window: &mut Window, cx: &mut Context) -> Entity { - let context_picker = cx.entity().clone(); + let context_picker = cx.entity(); let menu = ContextMenu::build(window, cx, move |menu, _window, cx| { let recent = self.recent_entries(cx); @@ -385,12 +385,11 @@ impl ContextPicker { } pub fn select_first(&mut self, window: &mut Window, cx: &mut Context) { - match &self.mode { - ContextPickerState::Default(entity) => entity.update(cx, |entity, cx| { + // Other variants already select their first entry on open automatically + if let ContextPickerState::Default(entity) = &self.mode { + entity.update(cx, |entity, cx| { entity.select_first(&Default::default(), window, cx) - }), - // Other variants already select their first entry on open automatically - _ => {} + }) } } @@ -610,9 +609,7 @@ pub(crate) fn available_context_picker_entries( .read(cx) .active_item(cx) .and_then(|item| item.downcast::()) - .map_or(false, |editor| { - editor.update(cx, |editor, cx| 
editor.has_non_empty_selection(cx)) - }); + .is_some_and(|editor| editor.update(cx, |editor, cx| editor.has_non_empty_selection(cx))); if has_selection { entries.push(ContextPickerEntry::Action( ContextPickerAction::AddSelections, @@ -680,7 +677,7 @@ pub(crate) fn recent_context_picker_entries( .filter(|(_, abs_path)| { abs_path .as_ref() - .map_or(true, |path| !exclude_paths.contains(path.as_path())) + .is_none_or(|path| !exclude_paths.contains(path.as_path())) }) .take(4) .filter_map(|(project_path, _)| { @@ -821,13 +818,8 @@ pub fn crease_for_mention( let render_trailer = move |_row, _unfold, _window: &mut Window, _cx: &mut App| Empty.into_any(); - Crease::inline( - range, - placeholder.clone(), - fold_toggle("mention"), - render_trailer, - ) - .with_metadata(CreaseMetadata { icon_path, label }) + Crease::inline(range, placeholder, fold_toggle("mention"), render_trailer) + .with_metadata(CreaseMetadata { icon_path, label }) } fn render_fold_icon_button( @@ -837,42 +829,9 @@ fn render_fold_icon_button( ) -> Arc, &mut App) -> AnyElement> { Arc::new({ move |fold_id, fold_range, cx| { - let is_in_text_selection = editor.upgrade().is_some_and(|editor| { - editor.update(cx, |editor, cx| { - let snapshot = editor - .buffer() - .update(cx, |multi_buffer, cx| multi_buffer.snapshot(cx)); - - let is_in_pending_selection = || { - editor - .selections - .pending - .as_ref() - .is_some_and(|pending_selection| { - pending_selection - .selection - .range() - .includes(&fold_range, &snapshot) - }) - }; - - let mut is_in_complete_selection = || { - editor - .selections - .disjoint_in_range::(fold_range.clone(), cx) - .into_iter() - .any(|selection| { - // This is needed to cover a corner case, if we just check for an existing - // selection in the fold range, having a cursor at the start of the fold - // marks it as selected. Non-empty selections don't cause this. 
- let length = selection.end - selection.start; - length > 0 - }) - }; - - is_in_pending_selection() || is_in_complete_selection() - }) - }); + let is_in_text_selection = editor + .update(cx, |editor, cx| editor.is_range_selected(&fold_range, cx)) + .unwrap_or_default(); ButtonLike::new(fold_id) .style(ButtonStyle::Filled) @@ -1028,7 +987,8 @@ impl MentionLink { .read(cx) .project() .read(cx) - .entry_for_path(&project_path, cx)?; + .entry_for_path(&project_path, cx)? + .clone(); Some(MentionLink::File(project_path, entry)) } Self::SYMBOL => { diff --git a/crates/agent_ui/src/context_picker/completion_provider.rs b/crates/agent_ui/src/context_picker/completion_provider.rs index 962c0df03db99ba8739df2c0eb8713d0e25f7f75..b67b463e3bfa654baefece2c97fc505460830f2d 100644 --- a/crates/agent_ui/src/context_picker/completion_provider.rs +++ b/crates/agent_ui/src/context_picker/completion_provider.rs @@ -13,7 +13,10 @@ use http_client::HttpClientWithUrl; use itertools::Itertools; use language::{Buffer, CodeLabel, HighlightId}; use lsp::CompletionContext; -use project::{Completion, CompletionIntent, CompletionResponse, ProjectPath, Symbol, WorktreeId}; +use project::{ + Completion, CompletionDisplayOptions, CompletionIntent, CompletionResponse, ProjectPath, + Symbol, WorktreeId, +}; use prompt_store::PromptStore; use rope::Point; use text::{Anchor, OffsetRangeExt, ToPoint}; @@ -79,8 +82,7 @@ fn search( ) -> Task> { match mode { Some(ContextPickerMode::File) => { - let search_files_task = - search_files(query.clone(), cancellation_flag.clone(), &workspace, cx); + let search_files_task = search_files(query, cancellation_flag, &workspace, cx); cx.background_spawn(async move { search_files_task .await @@ -91,8 +93,7 @@ fn search( } Some(ContextPickerMode::Symbol) => { - let search_symbols_task = - search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx); + let search_symbols_task = search_symbols(query, cancellation_flag, &workspace, cx); 
cx.background_spawn(async move { search_symbols_task .await @@ -108,13 +109,8 @@ fn search( .and_then(|t| t.upgrade()) .zip(text_thread_context_store.as_ref().and_then(|t| t.upgrade())) { - let search_threads_task = search_threads( - query.clone(), - cancellation_flag.clone(), - thread_store, - context_store, - cx, - ); + let search_threads_task = + search_threads(query, cancellation_flag, thread_store, context_store, cx); cx.background_spawn(async move { search_threads_task .await @@ -137,8 +133,7 @@ fn search( Some(ContextPickerMode::Rules) => { if let Some(prompt_store) = prompt_store.as_ref() { - let search_rules_task = - search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx); + let search_rules_task = search_rules(query, cancellation_flag, prompt_store, cx); cx.background_spawn(async move { search_rules_task .await @@ -196,7 +191,7 @@ fn search( let executor = cx.background_executor().clone(); let search_files_task = - search_files(query.clone(), cancellation_flag.clone(), &workspace, cx); + search_files(query.clone(), cancellation_flag, &workspace, cx); let entries = available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx); @@ -283,7 +278,7 @@ impl ContextPickerCompletionProvider { ) -> Option { match entry { ContextPickerEntry::Mode(mode) => Some(Completion { - replace_range: source_range.clone(), + replace_range: source_range, new_text: format!("@{} ", mode.keyword()), label: CodeLabel::plain(mode.label().to_string(), None), icon_path: Some(mode.icon().path().into()), @@ -330,9 +325,6 @@ impl ContextPickerCompletionProvider { ); let callback = Arc::new({ - let context_store = context_store.clone(); - let selections = selections.clone(); - let selection_infos = selection_infos.clone(); move |_, window: &mut Window, cx: &mut App| { context_store.update(cx, |context_store, cx| { for (buffer, range) in &selections { @@ -441,7 +433,7 @@ impl ContextPickerCompletionProvider { excerpt_id, source_range.start, new_text_len - 
1, - editor.clone(), + editor, context_store.clone(), move |window, cx| match &thread_entry { ThreadContextEntry::Thread { id, .. } => { @@ -510,7 +502,7 @@ impl ContextPickerCompletionProvider { excerpt_id, source_range.start, new_text_len - 1, - editor.clone(), + editor, context_store.clone(), move |_, cx| { let user_prompt_id = rules.prompt_id; @@ -547,7 +539,7 @@ impl ContextPickerCompletionProvider { excerpt_id, source_range.start, new_text_len - 1, - editor.clone(), + editor, context_store.clone(), move |_, cx| { let context_store = context_store.clone(); @@ -704,16 +696,16 @@ impl ContextPickerCompletionProvider { excerpt_id, source_range.start, new_text_len - 1, - editor.clone(), + editor, context_store.clone(), move |_, cx| { let symbol = symbol.clone(); let context_store = context_store.clone(); let workspace = workspace.clone(); let result = super::symbol_context_picker::add_symbol( - symbol.clone(), + symbol, false, - workspace.clone(), + workspace, context_store.downgrade(), cx, ); @@ -728,11 +720,11 @@ fn build_code_label_for_full_path(file_name: &str, directory: Option<&str>, cx: let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId); let mut label = CodeLabel::default(); - label.push_str(&file_name, None); + label.push_str(file_name, None); label.push_str(" ", None); if let Some(directory) = directory { - label.push_str(&directory, comment_id); + label.push_str(directory, comment_id); } label.filter_range = 0..label.text().len(); @@ -908,6 +900,7 @@ impl CompletionProvider for ContextPickerCompletionProvider { Ok(vec![CompletionResponse { completions, + display_options: CompletionDisplayOptions::default(), // Since this does its own filtering (see `filter_completions()` returns false), // there is no benefit to computing whether this set of completions is incomplete. 
is_incomplete: true, @@ -1020,7 +1013,7 @@ impl MentionCompletion { && line .chars() .nth(last_mention_start - 1) - .map_or(false, |c| !c.is_whitespace()) + .is_some_and(|c| !c.is_whitespace()) { return None; } @@ -1162,7 +1155,7 @@ mod tests { impl Focusable for AtMentionEditor { fn focus_handle(&self, cx: &App) -> FocusHandle { - self.0.read(cx).focus_handle(cx).clone() + self.0.read(cx).focus_handle(cx) } } @@ -1480,7 +1473,7 @@ mod tests { let completions = editor.current_completions().expect("Missing completions"); completions .into_iter() - .map(|completion| completion.label.text.to_string()) + .map(|completion| completion.label.text) .collect::>() } diff --git a/crates/agent_ui/src/context_picker/fetch_context_picker.rs b/crates/agent_ui/src/context_picker/fetch_context_picker.rs index 8ff68a8365ee01ac79d707abf00197bf5175e43a..dd558b2a1c88f60e68313b208b076a0974b30f85 100644 --- a/crates/agent_ui/src/context_picker/fetch_context_picker.rs +++ b/crates/agent_ui/src/context_picker/fetch_context_picker.rs @@ -226,9 +226,10 @@ impl PickerDelegate for FetchContextPickerDelegate { _window: &mut Window, cx: &mut Context>, ) -> Option { - let added = self.context_store.upgrade().map_or(false, |context_store| { - context_store.read(cx).includes_url(&self.url) - }); + let added = self + .context_store + .upgrade() + .is_some_and(|context_store| context_store.read(cx).includes_url(&self.url)); Some( ListItem::new(ix) diff --git a/crates/agent_ui/src/context_picker/file_context_picker.rs b/crates/agent_ui/src/context_picker/file_context_picker.rs index eaf9ed16d6fc7a09854d9f0160d87e23f3c5ffd8..43b1fa5e92fcd792ee1e8567ac558652e933bbfa 100644 --- a/crates/agent_ui/src/context_picker/file_context_picker.rs +++ b/crates/agent_ui/src/context_picker/file_context_picker.rs @@ -160,7 +160,7 @@ impl PickerDelegate for FileContextPickerDelegate { _window: &mut Window, cx: &mut Context>, ) -> Option { - let FileMatch { mat, .. } = &self.matches[ix]; + let FileMatch { mat, .. 
} = &self.matches.get(ix)?; Some( ListItem::new(ix) @@ -239,9 +239,7 @@ pub(crate) fn search_files( PathMatchCandidateSet { snapshot: worktree.snapshot(), - include_ignored: worktree - .root_entry() - .map_or(false, |entry| entry.is_ignored), + include_ignored: worktree.root_entry().is_some_and(|entry| entry.is_ignored), include_root_name: true, candidates: project::Candidates::Entries, } @@ -315,7 +313,7 @@ pub fn render_file_context_entry( context_store: WeakEntity, cx: &App, ) -> Stateful
{ - let (file_name, directory) = extract_file_name_and_directory(&path, path_prefix); + let (file_name, directory) = extract_file_name_and_directory(path, path_prefix); let added = context_store.upgrade().and_then(|context_store| { let project_path = ProjectPath { @@ -334,7 +332,7 @@ pub fn render_file_context_entry( let file_icon = if is_directory { FileIcons::get_folder_icon(false, cx) } else { - FileIcons::get_icon(&path, cx) + FileIcons::get_icon(path, cx) } .map(Icon::from_path) .unwrap_or_else(|| Icon::new(IconName::File)); diff --git a/crates/agent_ui/src/context_picker/rules_context_picker.rs b/crates/agent_ui/src/context_picker/rules_context_picker.rs index 8ce821cfaaab0a49f4af70fca13c1ed202de20a1..677011577aef23296a34203acdb10e5228ca7cd7 100644 --- a/crates/agent_ui/src/context_picker/rules_context_picker.rs +++ b/crates/agent_ui/src/context_picker/rules_context_picker.rs @@ -146,7 +146,7 @@ impl PickerDelegate for RulesContextPickerDelegate { _window: &mut Window, cx: &mut Context>, ) -> Option { - let thread = &self.matches[ix]; + let thread = &self.matches.get(ix)?; Some(ListItem::new(ix).inset(true).toggle_state(selected).child( render_thread_context_entry(thread, self.context_store.clone(), cx), @@ -159,7 +159,7 @@ pub fn render_thread_context_entry( context_store: WeakEntity, cx: &mut App, ) -> Div { - let added = context_store.upgrade().map_or(false, |context_store| { + let added = context_store.upgrade().is_some_and(|context_store| { context_store .read(cx) .includes_user_rules(user_rules.prompt_id) diff --git a/crates/agent_ui/src/context_picker/symbol_context_picker.rs b/crates/agent_ui/src/context_picker/symbol_context_picker.rs index 05e77deece6117d250d6efedbd9d24c6716b757e..993d65bd12ee4e01ca8d9767ccd46dd3fd645dd3 100644 --- a/crates/agent_ui/src/context_picker/symbol_context_picker.rs +++ b/crates/agent_ui/src/context_picker/symbol_context_picker.rs @@ -169,7 +169,7 @@ impl PickerDelegate for SymbolContextPickerDelegate { _window: &mut 
Window, _: &mut Context>, ) -> Option { - let mat = &self.matches[ix]; + let mat = &self.matches.get(ix)?; Some(ListItem::new(ix).inset(true).toggle_state(selected).child( render_symbol_context_entry(ElementId::named_usize("symbol-ctx-picker", ix), mat), @@ -289,12 +289,12 @@ pub(crate) fn search_symbols( .iter() .enumerate() .map(|(id, symbol)| { - StringMatchCandidate::new(id, &symbol.label.filter_text()) + StringMatchCandidate::new(id, symbol.label.filter_text()) }) .partition(|candidate| { project .entry_for_path(&symbols[candidate.id].path, cx) - .map_or(false, |e| !e.is_ignored) + .is_some_and(|e| !e.is_ignored) }) }) .log_err() diff --git a/crates/agent_ui/src/context_picker/thread_context_picker.rs b/crates/agent_ui/src/context_picker/thread_context_picker.rs index 15cc731f8f2b7c82885c566273bc1cda9f3c156a..9e843779c2216a89fe23dce514553e50043b8187 100644 --- a/crates/agent_ui/src/context_picker/thread_context_picker.rs +++ b/crates/agent_ui/src/context_picker/thread_context_picker.rs @@ -167,7 +167,7 @@ impl PickerDelegate for ThreadContextPickerDelegate { return; }; let open_thread_task = - thread_store.update(cx, |this, cx| this.open_thread(&id, window, cx)); + thread_store.update(cx, |this, cx| this.open_thread(id, window, cx)); cx.spawn(async move |this, cx| { let thread = open_thread_task.await?; @@ -220,7 +220,7 @@ impl PickerDelegate for ThreadContextPickerDelegate { _window: &mut Window, cx: &mut Context>, ) -> Option { - let thread = &self.matches[ix]; + let thread = &self.matches.get(ix)?; Some(ListItem::new(ix).inset(true).toggle_state(selected).child( render_thread_context_entry(thread, self.context_store.clone(), cx), @@ -236,12 +236,10 @@ pub fn render_thread_context_entry( let is_added = match entry { ThreadContextEntry::Thread { id, .. } => context_store .upgrade() - .map_or(false, |ctx_store| ctx_store.read(cx).includes_thread(&id)), - ThreadContextEntry::Context { path, .. 
} => { - context_store.upgrade().map_or(false, |ctx_store| { - ctx_store.read(cx).includes_text_thread(path) - }) - } + .is_some_and(|ctx_store| ctx_store.read(cx).includes_thread(id)), + ThreadContextEntry::Context { path, .. } => context_store + .upgrade() + .is_some_and(|ctx_store| ctx_store.read(cx).includes_text_thread(path)), }; h_flex() @@ -338,7 +336,7 @@ pub(crate) fn search_threads( let candidates = threads .iter() .enumerate() - .map(|(id, (_, thread))| StringMatchCandidate::new(id, &thread.title())) + .map(|(id, (_, thread))| StringMatchCandidate::new(id, thread.title())) .collect::>(); let matches = fuzzy::match_strings( &candidates, diff --git a/crates/agent_ui/src/context_strip.rs b/crates/agent_ui/src/context_strip.rs index 369964f165dc4d4460fd446c949538ec820fb82e..d25d7d35443e6ca7c28bb0894f72c0063f500721 100644 --- a/crates/agent_ui/src/context_strip.rs +++ b/crates/agent_ui/src/context_strip.rs @@ -145,7 +145,7 @@ impl ContextStrip { } let file_name = active_buffer.file()?.file_name(cx); - let icon_path = FileIcons::get_icon(&Path::new(&file_name), cx); + let icon_path = FileIcons::get_icon(Path::new(&file_name), cx); Some(SuggestedContext::File { name: file_name.to_string_lossy().into_owned().into(), buffer: active_buffer_entity.downgrade(), @@ -368,16 +368,16 @@ impl ContextStrip { _window: &mut Window, cx: &mut Context, ) { - if let Some(suggested) = self.suggested_context(cx) { - if self.is_suggested_focused(&self.added_contexts(cx)) { - self.add_suggested_context(&suggested, cx); - } + if let Some(suggested) = self.suggested_context(cx) + && self.is_suggested_focused(&self.added_contexts(cx)) + { + self.add_suggested_context(&suggested, cx); } } fn add_suggested_context(&mut self, suggested: &SuggestedContext, cx: &mut Context) { self.context_store.update(cx, |context_store, cx| { - context_store.add_suggested_context(&suggested, cx) + context_store.add_suggested_context(suggested, cx) }); cx.notify(); } diff --git 
a/crates/agent_ui/src/debug.rs b/crates/agent_ui/src/debug.rs index bd34659210e933ad99357e7e1ceeedb6b53c5ee0..227528e8ae13dc861fe55da7698c86b485f8ae0a 100644 --- a/crates/agent_ui/src/debug.rs +++ b/crates/agent_ui/src/debug.rs @@ -1,7 +1,7 @@ #![allow(unused, dead_code)] use client::{ModelRequestUsage, RequestUsage}; -use cloud_llm_client::{Plan, UsageLimit}; +use cloud_llm_client::{Plan, PlanV1, UsageLimit}; use gpui::Global; use std::ops::{Deref, DerefMut}; use ui::prelude::*; @@ -75,7 +75,7 @@ impl Default for DebugAccountState { Self { enabled: false, trial_expired: false, - plan: Plan::ZedFree, + plan: Plan::V1(PlanV1::ZedFree), custom_prompt_usage: ModelRequestUsage(RequestUsage { limit: UsageLimit::Unlimited, amount: 0, diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 4a4a747899ecc310666685de336bacffc4b271e6..4ac88e6daa3d3623580e206c2759f27b218d1bac 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -72,7 +72,7 @@ pub fn init( let Some(window) = window else { return; }; - let workspace = cx.entity().clone(); + let workspace = cx.entity(); InlineAssistant::update_global(cx, |inline_assistant, cx| { inline_assistant.register_workspace(&workspace, window, cx) }); @@ -144,7 +144,8 @@ impl InlineAssistant { let Some(terminal_panel) = workspace.read(cx).panel::(cx) else { return; }; - let enabled = AgentSettings::get_global(cx).enabled; + let enabled = !DisableAiSettings::get_global(cx).disable_ai + && AgentSettings::get_global(cx).enabled; terminal_panel.update(cx, |terminal_panel, cx| { terminal_panel.set_assistant_enabled(enabled, cx) }); @@ -182,13 +183,13 @@ impl InlineAssistant { match event { workspace::Event::UserSavedItem { item, .. } => { // When the user manually saves an editor, automatically accepts all finished transformations. 
- if let Some(editor) = item.upgrade().and_then(|item| item.act_as::(cx)) { - if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) { - for assist_id in editor_assists.assist_ids.clone() { - let assist = &self.assists[&assist_id]; - if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) { - self.finish_assist(assist_id, false, window, cx) - } + if let Some(editor) = item.upgrade().and_then(|item| item.act_as::(cx)) + && let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) + { + for assist_id in editor_assists.assist_ids.clone() { + let assist = &self.assists[&assist_id]; + if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) { + self.finish_assist(assist_id, false, window, cx) } } } @@ -342,13 +343,11 @@ impl InlineAssistant { ) .await .ok(); - if let Some(answer) = answer { - if answer == 0 { - cx.update(|window, cx| { - window.dispatch_action(Box::new(OpenSettings), cx) - }) + if let Some(answer) = answer + && answer == 0 + { + cx.update(|window, cx| window.dispatch_action(Box::new(OpenSettings), cx)) .ok(); - } } anyhow::Ok(()) }) @@ -435,11 +434,11 @@ impl InlineAssistant { } } - if let Some(prev_selection) = selections.last_mut() { - if selection.start <= prev_selection.end { - prev_selection.end = selection.end; - continue; - } + if let Some(prev_selection) = selections.last_mut() + && selection.start <= prev_selection.end + { + prev_selection.end = selection.end; + continue; } let latest_selection = newest_selection.get_or_insert_with(|| selection.clone()); @@ -526,9 +525,9 @@ impl InlineAssistant { if assist_to_focus.is_none() { let focus_assist = if newest_selection.reversed { - range.start.to_point(&snapshot) == newest_selection.start + range.start.to_point(snapshot) == newest_selection.start } else { - range.end.to_point(&snapshot) == newest_selection.end + range.end.to_point(snapshot) == newest_selection.end }; if focus_assist { assist_to_focus = Some(assist_id); @@ -550,7 +549,7 @@ impl 
InlineAssistant { let editor_assists = self .assists_by_editor .entry(editor.downgrade()) - .or_insert_with(|| EditorInlineAssists::new(&editor, window, cx)); + .or_insert_with(|| EditorInlineAssists::new(editor, window, cx)); let mut assist_group = InlineAssistGroup::new(); for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists { let codegen = prompt_editor.read(cx).codegen().clone(); @@ -649,7 +648,7 @@ impl InlineAssistant { let editor_assists = self .assists_by_editor .entry(editor.downgrade()) - .or_insert_with(|| EditorInlineAssists::new(&editor, window, cx)); + .or_insert_with(|| EditorInlineAssists::new(editor, window, cx)); let mut assist_group = InlineAssistGroup::new(); self.assists.insert( @@ -985,14 +984,13 @@ impl InlineAssistant { EditorEvent::SelectionsChanged { .. } => { for assist_id in editor_assists.assist_ids.clone() { let assist = &self.assists[&assist_id]; - if let Some(decorations) = assist.decorations.as_ref() { - if decorations + if let Some(decorations) = assist.decorations.as_ref() + && decorations .prompt_editor .focus_handle(cx) .is_focused(window) - { - return; - } + { + return; } } @@ -1123,7 +1121,7 @@ impl InlineAssistant { if editor_assists .scroll_lock .as_ref() - .map_or(false, |lock| lock.assist_id == assist_id) + .is_some_and(|lock| lock.assist_id == assist_id) { editor_assists.scroll_lock = None; } @@ -1503,20 +1501,18 @@ impl InlineAssistant { window: &mut Window, cx: &mut App, ) -> Option { - if let Some(terminal_panel) = workspace.panel::(cx) { - if terminal_panel + if let Some(terminal_panel) = workspace.panel::(cx) + && terminal_panel .read(cx) .focus_handle(cx) .contains_focused(window, cx) - { - if let Some(terminal_view) = terminal_panel.read(cx).pane().and_then(|pane| { - pane.read(cx) - .active_item() - .and_then(|t| t.downcast::()) - }) { - return Some(InlineAssistTarget::Terminal(terminal_view)); - } - } + && let Some(terminal_view) = terminal_panel.read(cx).pane().and_then(|pane| { + 
pane.read(cx) + .active_item() + .and_then(|t| t.downcast::()) + }) + { + return Some(InlineAssistTarget::Terminal(terminal_view)); } let context_editor = agent_panel @@ -1537,13 +1533,11 @@ impl InlineAssistant { .and_then(|item| item.act_as::(cx)) { Some(InlineAssistTarget::Editor(workspace_editor)) - } else if let Some(terminal_view) = workspace - .active_item(cx) - .and_then(|item| item.act_as::(cx)) - { - Some(InlineAssistTarget::Terminal(terminal_view)) } else { - None + workspace + .active_item(cx) + .and_then(|item| item.act_as::(cx)) + .map(InlineAssistTarget::Terminal) } } } @@ -1698,7 +1692,7 @@ impl InlineAssist { }), range, codegen: codegen.clone(), - workspace: workspace.clone(), + workspace, _subscriptions: vec![ window.on_focus_in(&prompt_editor_focus_handle, cx, move |_, cx| { InlineAssistant::update_global(cx, |this, cx| { @@ -1741,22 +1735,20 @@ impl InlineAssist { return; }; - if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) { - if assist.decorations.is_none() { - if let Some(workspace) = assist.workspace.upgrade() { - let error = format!("Inline assistant error: {}", error); - workspace.update(cx, |workspace, cx| { - struct InlineAssistantError; - - let id = - NotificationId::composite::( - assist_id.0, - ); - - workspace.show_toast(Toast::new(id, error), cx); - }) - } - } + if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) + && assist.decorations.is_none() + && let Some(workspace) = assist.workspace.upgrade() + { + let error = format!("Inline assistant error: {}", error); + workspace.update(cx, |workspace, cx| { + struct InlineAssistantError; + + let id = NotificationId::composite::( + assist_id.0, + ); + + workspace.show_toast(Toast::new(id, error), cx); + }) } if assist.decorations.is_none() { @@ -1821,18 +1813,15 @@ impl CodeActionProvider for AssistantCodeActionProvider { has_diagnostics = true; } if has_diagnostics { - if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) { - 
if let Some(symbol) = symbols_containing_start.last() { - range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot)); - range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot)); - } + let symbols_containing_start = snapshot.symbols_containing(range.start, None); + if let Some(symbol) = symbols_containing_start.last() { + range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot)); + range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot)); } - - if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) { - if let Some(symbol) = symbols_containing_end.last() { - range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot)); - range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot)); - } + let symbols_containing_end = snapshot.symbols_containing(range.end, None); + if let Some(symbol) = symbols_containing_end.last() { + range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot)); + range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot)); } Task::ready(Ok(vec![CodeAction { diff --git a/crates/agent_ui/src/inline_prompt_editor.rs b/crates/agent_ui/src/inline_prompt_editor.rs index e6fca1698496b064917a6b1b8257388e81a00df7..0e817ca8073d71022f47b0aa08d34101f622f470 100644 --- a/crates/agent_ui/src/inline_prompt_editor.rs +++ b/crates/agent_ui/src/inline_prompt_editor.rs @@ -1,29 +1,18 @@ -use crate::agent_model_selector::AgentModelSelector; -use crate::buffer_codegen::BufferCodegen; -use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider}; -use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; -use crate::message_editor::{ContextCreasesAddon, extract_message_creases, insert_message_creases}; -use crate::terminal_codegen::TerminalCodegen; -use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist, ModelUsageContext}; -use crate::{RemoveAllContext, ToggleContextPicker}; use agent::{ 
context_store::ContextStore, thread_store::{TextThreadStore, ThreadStore}, }; -use client::ErrorExt; use collections::VecDeque; -use db::kvp::Dismissable; use editor::actions::Paste; use editor::display_map::EditorMargins; use editor::{ ContextMenuOptions, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, MultiBuffer, actions::{MoveDown, MoveUp}, }; -use feature_flags::{FeatureFlagAppExt as _, ZedProFeatureFlag}; use fs::Fs; use gpui::{ - AnyElement, App, ClickEvent, Context, CursorStyle, Entity, EventEmitter, FocusHandle, - Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point, + AnyElement, App, Context, CursorStyle, Entity, EventEmitter, FocusHandle, Focusable, + Subscription, TextStyle, WeakEntity, Window, }; use language_model::{LanguageModel, LanguageModelRegistry}; use parking_lot::Mutex; @@ -33,12 +22,19 @@ use std::rc::Rc; use std::sync::Arc; use theme::ThemeSettings; use ui::utils::WithRemSize; -use ui::{ - CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, PopoverMenuHandle, Tooltip, prelude::*, -}; +use ui::{IconButtonShape, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*}; use workspace::Workspace; use zed_actions::agent::ToggleModelSelector; +use crate::agent_model_selector::AgentModelSelector; +use crate::buffer_codegen::BufferCodegen; +use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider}; +use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; +use crate::message_editor::{ContextCreasesAddon, extract_message_creases, insert_message_creases}; +use crate::terminal_codegen::TerminalCodegen; +use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist, ModelUsageContext}; +use crate::{RemoveAllContext, ToggleContextPicker}; + pub struct PromptEditor { pub editor: Entity, mode: PromptEditorMode, @@ -75,7 +71,7 @@ impl Render for PromptEditor { let codegen = codegen.read(cx); if codegen.alternative_count(cx) > 1 { - 
buttons.push(self.render_cycle_controls(&codegen, cx)); + buttons.push(self.render_cycle_controls(codegen, cx)); } let editor_margins = editor_margins.lock(); @@ -93,8 +89,8 @@ impl Render for PromptEditor { }; let bottom_padding = match &self.mode { - PromptEditorMode::Buffer { .. } => Pixels::from(0.), - PromptEditorMode::Terminal { .. } => Pixels::from(8.0), + PromptEditorMode::Buffer { .. } => rems_from_px(2.0), + PromptEditorMode::Terminal { .. } => rems_from_px(8.0), }; buttons.extend(self.render_buttons(window, cx)); @@ -144,47 +140,16 @@ impl Render for PromptEditor { }; let error_message = SharedString::from(error.to_string()); - if error.error_code() == proto::ErrorCode::RateLimitExceeded - && cx.has_flag::() - { - el.child( - v_flex() - .child( - IconButton::new( - "rate-limit-error", - IconName::XCircle, - ) - .toggle_state(self.show_rate_limit_notice) - .shape(IconButtonShape::Square) - .icon_size(IconSize::Small) - .on_click( - cx.listener(Self::toggle_rate_limit_notice), - ), - ) - .children(self.show_rate_limit_notice.then(|| { - deferred( - anchored() - .position_mode( - gpui::AnchoredPositionMode::Local, - ) - .position(point(px(0.), px(24.))) - .anchor(gpui::Corner::TopLeft) - .child(self.render_rate_limit_notice(cx)), - ) - })), - ) - } else { - el.child( - div() - .id("error") - .tooltip(Tooltip::text(error_message)) - .child( - Icon::new(IconName::XCircle) - .size(IconSize::Small) - .color(Color::Error), - ), - ) - } + el.child( + div() + .id("error") + .tooltip(Tooltip::text(error_message)) + .child( + Icon::new(IconName::XCircle) + .size(IconSize::Small) + .color(Color::Error), + ), + ) }), ) .child( @@ -264,7 +229,7 @@ impl PromptEditor { self.editor = cx.new(|cx| { let mut editor = Editor::auto_height(1, Self::MAX_LINES as usize, window, cx); editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx); - editor.set_placeholder_text("Add a prompt…", cx); + editor.set_placeholder_text("Add a prompt…", window, cx); 
editor.set_text(prompt, window, cx); insert_message_creases( &mut editor, @@ -310,19 +275,6 @@ impl PromptEditor { crate::active_thread::attach_pasted_images_as_context(&self.context_store, cx); } - fn toggle_rate_limit_notice( - &mut self, - _: &ClickEvent, - window: &mut Window, - cx: &mut Context, - ) { - self.show_rate_limit_notice = !self.show_rate_limit_notice; - if self.show_rate_limit_notice { - window.focus(&self.editor.focus_handle(cx)); - } - cx.notify(); - } - fn handle_prompt_editor_events( &mut self, _: &Entity, @@ -334,7 +286,7 @@ impl PromptEditor { EditorEvent::Edited { .. } => { if let Some(workspace) = window.root::().flatten() { workspace.update(cx, |workspace, cx| { - let is_via_ssh = workspace.project().read(cx).is_via_ssh(); + let is_via_ssh = workspace.project().read(cx).is_via_remote_server(); workspace .client() @@ -345,7 +297,7 @@ impl PromptEditor { let prompt = self.editor.read(cx).text(cx); if self .prompt_history_ix - .map_or(true, |ix| self.prompt_history[ix] != prompt) + .is_none_or(|ix| self.prompt_history[ix] != prompt) { self.prompt_history_ix.take(); self.pending_prompt = prompt; @@ -707,75 +659,22 @@ impl PromptEditor { .into_any_element() } - fn render_rate_limit_notice(&self, cx: &mut Context) -> impl IntoElement { - Popover::new().child( - v_flex() - .occlude() - .p_2() - .child( - Label::new("Out of Tokens") - .size(LabelSize::Small) - .weight(FontWeight::BOLD), - ) - .child(Label::new( - "Try Zed Pro for higher limits, a wider range of models, and more.", - )) - .child( - h_flex() - .justify_between() - .child(CheckboxWithLabel::new( - "dont-show-again", - Label::new("Don't show again"), - if RateLimitNotice::dismissed() { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - |selection, _, cx| { - let is_dismissed = match selection { - ui::ToggleState::Unselected => false, - ui::ToggleState::Indeterminate => return, - ui::ToggleState::Selected => true, - }; - - 
RateLimitNotice::set_dismissed(is_dismissed, cx); - }, - )) - .child( - h_flex() - .gap_2() - .child( - Button::new("dismiss", "Dismiss") - .style(ButtonStyle::Transparent) - .on_click(cx.listener(Self::toggle_rate_limit_notice)), - ) - .child(Button::new("more-info", "More Info").on_click( - |_event, window, cx| { - window.dispatch_action( - Box::new(zed_actions::OpenAccountSettings), - cx, - ) - }, - )), - ), - ), - ) - } - - fn render_editor(&mut self, window: &mut Window, cx: &mut Context) -> AnyElement { - let font_size = TextSize::Default.rems(cx); - let line_height = font_size.to_pixels(window.rem_size()) * 1.3; + fn render_editor(&mut self, _window: &mut Window, cx: &mut Context) -> AnyElement { + let colors = cx.theme().colors(); div() .key_context("InlineAssistEditor") .size_full() .p_2() .pl_1() - .bg(cx.theme().colors().editor_background) + .bg(colors.editor_background) .child({ let settings = ThemeSettings::get_global(cx); + let font_size = settings.buffer_font_size(cx); + let line_height = font_size * 1.2; + let text_style = TextStyle { - color: cx.theme().colors().editor_foreground, + color: colors.editor_foreground, font_family: settings.buffer_font.family.clone(), font_features: settings.buffer_font.features.clone(), font_size: font_size.into(), @@ -786,7 +685,7 @@ impl PromptEditor { EditorElement::new( &self.editor, EditorStyle { - background: cx.theme().colors().editor_background, + background: colors.editor_background, local_player: cx.theme().players().local(), text: text_style, ..Default::default() @@ -883,7 +782,7 @@ impl PromptEditor { // always show the cursor (even when it isn't focused) because // typing in one will make what you typed appear in all of them. 
editor.set_show_cursor_when_unfocused(true, cx); - editor.set_placeholder_text(Self::placeholder_text(&mode, window, cx), cx); + editor.set_placeholder_text(&Self::placeholder_text(&mode, window, cx), window, cx); editor.register_addon(ContextCreasesAddon::new()); editor.set_context_menu_options(ContextMenuOptions { min_entries_visible: 12, @@ -976,15 +875,7 @@ impl PromptEditor { self.editor .update(cx, |editor, _| editor.set_read_only(false)); } - CodegenStatus::Error(error) => { - if cx.has_flag::() - && error.error_code() == proto::ErrorCode::RateLimitExceeded - && !RateLimitNotice::dismissed() - { - self.show_rate_limit_notice = true; - cx.notify(); - } - + CodegenStatus::Error(_error) => { self.edited_since_done = false; self.editor .update(cx, |editor, _| editor.set_read_only(false)); @@ -1058,7 +949,7 @@ impl PromptEditor { cx, ); editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx); - editor.set_placeholder_text(Self::placeholder_text(&mode, window, cx), cx); + editor.set_placeholder_text(&Self::placeholder_text(&mode, window, cx), window, cx); editor.set_context_menu_options(ContextMenuOptions { min_entries_visible: 12, max_entries_visible: 12, @@ -1187,12 +1078,6 @@ impl PromptEditor { } } -struct RateLimitNotice; - -impl Dismissable for RateLimitNotice { - const KEY: &'static str = "dismissed-rate-limit-notice"; -} - pub enum CodegenStatus { Idle, Pending, @@ -1229,27 +1114,27 @@ pub enum GenerationMode { impl GenerationMode { fn start_label(self) -> &'static str { match self { - GenerationMode::Generate { .. } => "Generate", + GenerationMode::Generate => "Generate", GenerationMode::Transform => "Transform", } } fn tooltip_interrupt(self) -> &'static str { match self { - GenerationMode::Generate { .. } => "Interrupt Generation", + GenerationMode::Generate => "Interrupt Generation", GenerationMode::Transform => "Interrupt Transform", } } fn tooltip_restart(self) -> &'static str { match self { - GenerationMode::Generate { .. 
} => "Restart Generation", + GenerationMode::Generate => "Restart Generation", GenerationMode::Transform => "Restart Transform", } } fn tooltip_accept(self) -> &'static str { match self { - GenerationMode::Generate { .. } => "Accept Generation", + GenerationMode::Generate => "Accept Generation", GenerationMode::Transform => "Accept Transform", } } diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 7121624c87f6e44ba73f8380bfdf60227cba5b90..eb5a734b4ca57c2b79ac0dd004e42fc59c195fed 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -1,7 +1,6 @@ use std::{cmp::Reverse, sync::Arc}; use collections::{HashSet, IndexMap}; -use feature_flags::ZedProFeatureFlag; use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task}; use language_model::{ @@ -10,11 +9,8 @@ use language_model::{ }; use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; -use proto::Plan; use ui::{ListItem, ListItemSpacing, prelude::*}; -const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro"; - type OnModelChanged = Arc, &mut App) + 'static>; type GetActiveModel = Arc Option + 'static>; @@ -93,7 +89,7 @@ impl LanguageModelPickerDelegate { let entries = models.entries(); Self { - on_model_changed: on_model_changed.clone(), + on_model_changed, all_models: Arc::new(models), selected_index: Self::get_active_model_index(&entries, get_active_model(cx)), filtered_entries: entries, @@ -104,7 +100,7 @@ impl LanguageModelPickerDelegate { window, |picker, _, event, window, cx| { match event { - language_model::Event::ProviderStateChanged + language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { let query = picker.query(cx); @@ -296,7 +292,7 @@ impl ModelMatcher { pub fn fuzzy_search(&self, query: &str) -> 
Vec { let mut matches = self.bg_executor.block(match_strings( &self.candidates, - &query, + query, false, true, 100, @@ -514,7 +510,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { .pl_0p5() .gap_1p5() .w(px(240.)) - .child(Label::new(model_info.model.name().0.clone()).truncate()), + .child(Label::new(model_info.model.name().0).truncate()), ) .end_slot(div().pr_3().when(is_selected, |this| { this.child( @@ -531,13 +527,9 @@ impl PickerDelegate for LanguageModelPickerDelegate { fn render_footer( &self, - _: &mut Window, + _window: &mut Window, cx: &mut Context>, ) -> Option { - use feature_flags::FeatureFlagAppExt; - - let plan = proto::Plan::ZedPro; - Some( h_flex() .w_full() @@ -546,28 +538,6 @@ impl PickerDelegate for LanguageModelPickerDelegate { .p_1() .gap_4() .justify_between() - .when(cx.has_flag::(), |this| { - this.child(match plan { - Plan::ZedPro => Button::new("zed-pro", "Zed Pro") - .icon(IconName::ZedAssistant) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click(|_, window, cx| { - window - .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx) - }), - Plan::Free | Plan::ZedProTrial => Button::new( - "try-pro", - if plan == Plan::ZedProTrial { - "Upgrade to Pro" - } else { - "Try Pro" - }, - ) - .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)), - }) - }) .child( Button::new("configure", "Configure") .icon(IconName::Settings) diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 4b6d51c4c13d92b309d7bb10a6a753076344e4fa..e9a482c5f425f7559df2178f802b390cc53f2f61 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -6,7 +6,7 @@ use crate::agent_diff::AgentDiffThread; use crate::agent_model_selector::AgentModelSelector; use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip}; use crate::ui::{ - MaxModeTooltip, + BurnModeTooltip, preview::{AgentPreview, UsageCallout}, 
}; use agent::history_store::HistoryStore; @@ -14,10 +14,10 @@ use agent::{ context::{AgentContextKey, ContextLoadResult, load_context}, context_store::ContextStoreEvent, }; -use agent_settings::{AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use ai_onboarding::ApiKeysWithProviders; use buffer_diff::BufferDiff; -use cloud_llm_client::CompletionIntent; +use cloud_llm_client::{CompletionIntent, PlanV1}; use collections::{HashMap, HashSet}; use editor::actions::{MoveUp, Paste}; use editor::display_map::CreaseId; @@ -55,7 +55,7 @@ use zed_actions::agent::ToggleModelSelector; use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; -use crate::profile_selector::ProfileSelector; +use crate::profile_selector::{ProfileProvider, ProfileSelector}; use crate::{ ActiveThread, AgentDiffPane, ChatWithFollow, ExpandMessageEditor, Follow, KeepAll, ModelUsageContext, NewThread, OpenAgentDiff, RejectAll, RemoveAllContext, ToggleBurnMode, @@ -117,14 +117,15 @@ pub(crate) fn create_editor( let mut editor = Editor::new( editor::EditorMode::AutoHeight { min_lines, - max_lines: max_lines, + max_lines, }, buffer, None, window, cx, ); - editor.set_placeholder_text("Message the agent – @ to include context", cx); + editor.set_placeholder_text("Message the agent – @ to include context", window, cx); + editor.disable_word_completions(); editor.set_show_indent_guides(false, cx); editor.set_soft_wrap(); editor.set_use_modal_editing(true); @@ -152,6 +153,24 @@ pub(crate) fn create_editor( editor } +impl ProfileProvider for Entity { + fn profiles_supported(&self, cx: &App) -> bool { + self.read(cx) + .configured_model() + .is_some_and(|model| model.model.supports_tools()) + } + + fn profile_id(&self, cx: &App) -> AgentProfileId { + self.read(cx).profile().id().clone() + } + + fn set_profile(&self, profile_id: 
AgentProfileId, cx: &mut App) { + self.update(cx, |this, cx| { + this.set_profile(profile_id, cx); + }); + } +} + impl MessageEditor { pub fn new( fs: Arc, @@ -197,9 +216,10 @@ impl MessageEditor { let subscriptions = vec![ cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event), - cx.subscribe(&editor, |this, _, event, cx| match event { - EditorEvent::BufferEdited => this.handle_message_changed(cx), - _ => {} + cx.subscribe(&editor, |this, _, event: &EditorEvent, cx| { + if event == &EditorEvent::BufferEdited { + this.handle_message_changed(cx) + } }), cx.observe(&context_store, |this, _, cx| { // When context changes, reload it for token counting. @@ -221,14 +241,15 @@ impl MessageEditor { ) }); - let profile_selector = - cx.new(|cx| ProfileSelector::new(fs, thread.clone(), editor.focus_handle(cx), cx)); + let profile_selector = cx.new(|cx| { + ProfileSelector::new(fs, Arc::new(thread.clone()), editor.focus_handle(cx), cx) + }); Self { editor: editor.clone(), project: thread.read(cx).project().clone(), thread, - incompatible_tools_state: incompatible_tools.clone(), + incompatible_tools_state: incompatible_tools, workspace, context_store, prompt_store, @@ -358,18 +379,13 @@ impl MessageEditor { } fn send_to_model(&mut self, window: &mut Window, cx: &mut Context) { - let Some(ConfiguredModel { model, provider }) = self + let Some(ConfiguredModel { model, .. 
}) = self .thread .update(cx, |thread, cx| thread.get_or_init_configured_model(cx)) else { return; }; - if provider.must_accept_terms(cx) { - cx.notify(); - return; - } - let (user_message, user_message_creases) = self.editor.update(cx, |editor, cx| { let creases = extract_message_creases(editor, cx); let text = editor.text(cx); @@ -422,11 +438,11 @@ impl MessageEditor { thread.cancel_editing(cx); }); - let cancelled = self.thread.update(cx, |thread, cx| { + let canceled = self.thread.update(cx, |thread, cx| { thread.cancel_last_completion(Some(window.window_handle()), cx) }); - if cancelled { + if canceled { self.set_editor_is_expanded(false, cx); self.send_to_model(window, cx); } @@ -605,7 +621,7 @@ impl MessageEditor { this.toggle_burn_mode(&ToggleBurnMode, window, cx); })) .tooltip(move |_window, cx| { - cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled)) + cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled)) .into() }) .into_any_element(), @@ -671,11 +687,7 @@ impl MessageEditor { .as_ref() .map(|model| { self.incompatible_tools_state.update(cx, |state, cx| { - state - .incompatible_tools(&model.model, cx) - .iter() - .cloned() - .collect::>() + state.incompatible_tools(&model.model, cx).to_vec() }) }) .unwrap_or_default(); @@ -823,7 +835,6 @@ impl MessageEditor { .child(self.profile_selector.clone()) .child(self.model_selector.clone()) .map({ - let focus_handle = focus_handle.clone(); move |parent| { if is_generating { parent @@ -1117,7 +1128,7 @@ impl MessageEditor { ) .when(is_edit_changes_expanded, |parent| { parent.child( - v_flex().children(changed_buffers.into_iter().enumerate().flat_map( + v_flex().children(changed_buffers.iter().enumerate().flat_map( |(index, (buffer, _diff))| { let file = buffer.read(cx).file()?; let path = file.path(); @@ -1147,7 +1158,7 @@ impl MessageEditor { .buffer_font(cx) }); - let file_icon = FileIcons::get_icon(&path, cx) + let file_icon = FileIcons::get_icon(path, cx) .map(Icon::from_path) .map(|icon| 
icon.color(Color::Muted).size(IconSize::Small)) .unwrap_or_else(|| { @@ -1274,7 +1285,7 @@ impl MessageEditor { self.thread .read(cx) .configured_model() - .map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID) + .is_some_and(|model| model.provider.id() == ZED_CLOUD_PROVIDER_ID) } fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context) -> Option
{ @@ -1287,7 +1298,9 @@ impl MessageEditor { return None; } - let plan = user_store.plan().unwrap_or(cloud_llm_client::Plan::ZedFree); + let plan = user_store + .plan() + .unwrap_or(cloud_llm_client::Plan::V1(PlanV1::ZedFree)); let usage = user_store.model_request_usage()?; @@ -1304,14 +1317,10 @@ impl MessageEditor { token_usage_ratio: TokenUsageRatio, cx: &mut Context, ) -> Option
{ - let icon = if token_usage_ratio == TokenUsageRatio::Exceeded { - Icon::new(IconName::Close) - .color(Color::Error) - .size(IconSize::XSmall) + let (icon, severity) = if token_usage_ratio == TokenUsageRatio::Exceeded { + (IconName::Close, Severity::Error) } else { - Icon::new(IconName::Warning) - .color(Color::Warning) - .size(IconSize::XSmall) + (IconName::Warning, Severity::Warning) }; let title = if token_usage_ratio == TokenUsageRatio::Exceeded { @@ -1326,29 +1335,33 @@ impl MessageEditor { "To continue, start a new thread from a summary." }; - let mut callout = Callout::new() + let callout = Callout::new() .line_height(line_height) + .severity(severity) .icon(icon) .title(title) .description(description) - .primary_action( - Button::new("start-new-thread", "Start New Thread") - .label_size(LabelSize::Small) - .on_click(cx.listener(|this, _, window, cx| { - let from_thread_id = Some(this.thread.read(cx).id().clone()); - window.dispatch_action(Box::new(NewThread { from_thread_id }), cx); - })), - ); - - if self.is_using_zed_provider(cx) { - callout = callout.secondary_action( - IconButton::new("burn-mode-callout", IconName::ZedBurnMode) - .icon_size(IconSize::XSmall) - .on_click(cx.listener(|this, _event, window, cx| { - this.toggle_burn_mode(&ToggleBurnMode, window, cx); - })), + .actions_slot( + h_flex() + .gap_0p5() + .when(self.is_using_zed_provider(cx), |this| { + this.child( + IconButton::new("burn-mode-callout", IconName::ZedBurnMode) + .icon_size(IconSize::XSmall) + .on_click(cx.listener(|this, _event, window, cx| { + this.toggle_burn_mode(&ToggleBurnMode, window, cx); + })), + ) + }) + .child( + Button::new("start-new-thread", "Start New Thread") + .label_size(LabelSize::Small) + .on_click(cx.listener(|this, _, window, cx| { + let from_thread_id = Some(this.thread.read(cx).id().clone()); + window.dispatch_action(Box::new(NewThread { from_thread_id }), cx); + })), + ), ); - } Some( div() @@ -1385,7 +1398,7 @@ impl MessageEditor { }) .ok(); }); - // 
Replace existing load task, if any, causing it to be cancelled. + // Replace existing load task, if any, causing it to be canceled. let load_task = load_task.shared(); self.load_context_task = Some(load_task.clone()); cx.spawn(async move |this, cx| { @@ -1427,7 +1440,7 @@ impl MessageEditor { let message_text = editor.read(cx).text(cx); if message_text.is_empty() - && loaded_context.map_or(true, |loaded_context| loaded_context.is_empty()) + && loaded_context.is_none_or(|loaded_context| loaded_context.is_empty()) { return None; } @@ -1540,9 +1553,8 @@ impl ContextCreasesAddon { cx: &mut Context, ) { self.creases.entry(key).or_default().extend(creases); - self._subscription = Some(cx.subscribe( - &context_store, - |editor, _, event, cx| match event { + self._subscription = Some( + cx.subscribe(context_store, |editor, _, event, cx| match event { ContextStoreEvent::ContextRemoved(key) => { let Some(this) = editor.addon_mut::() else { return; @@ -1562,8 +1574,8 @@ impl ContextCreasesAddon { editor.edit(ranges.into_iter().zip(replacement_texts), cx); cx.notify(); } - }, - )) + }), + ) } pub fn into_inner(self) -> HashMap> { @@ -1591,7 +1603,8 @@ pub fn extract_message_creases( .collect::>(); // Filter the addon's list of creases based on what the editor reports, // since the addon might have removed creases in it. 
- let creases = editor.display_map.update(cx, |display_map, cx| { + + editor.display_map.update(cx, |display_map, cx| { display_map .snapshot(cx) .crease_snapshot @@ -1615,8 +1628,7 @@ pub fn extract_message_creases( } }) .collect() - }); - creases + }) } impl EventEmitter for MessageEditor {} @@ -1668,7 +1680,7 @@ impl Render for MessageEditor { let has_history = self .history_store .as_ref() - .and_then(|hs| hs.update(cx, |hs, cx| hs.entries(cx).len() > 0).ok()) + .and_then(|hs| hs.update(cx, |hs, cx| !hs.entries(cx).is_empty()).ok()) .unwrap_or(false) || self .thread @@ -1681,7 +1693,7 @@ impl Render for MessageEditor { !has_history && is_signed_out && has_configured_providers, |this| this.child(cx.new(ApiKeysWithProviders::new)), ) - .when(changed_buffers.len() > 0, |parent| { + .when(!changed_buffers.is_empty(), |parent| { parent.child(self.render_edits_bar(&changed_buffers, window, cx)) }) .child(self.render_editor(window, cx)) @@ -1786,7 +1798,7 @@ impl AgentPreview for MessageEditor { .bg(cx.theme().colors().panel_background) .border_1() .border_color(cx.theme().colors().border) - .child(default_message_editor.clone()) + .child(default_message_editor) .into_any_element(), )]) .into_any_element(), diff --git a/crates/agent_ui/src/profile_selector.rs b/crates/agent_ui/src/profile_selector.rs index ddcb44d46b800f257314a8802ad01abc98560ce0..6ae4a73598a8e0e48509dda7a9bdd5e4fa2ea0ff 100644 --- a/crates/agent_ui/src/profile_selector.rs +++ b/crates/agent_ui/src/profile_selector.rs @@ -1,23 +1,31 @@ use crate::{ManageProfiles, ToggleProfileSelector}; -use agent::{ - Thread, - agent_profile::{AgentProfile, AvailableProfiles}, -}; +use agent::agent_profile::{AgentProfile, AvailableProfiles}; use agent_settings::{AgentDockPosition, AgentProfileId, AgentSettings, builtin_profiles}; use fs::Fs; -use gpui::{Action, Empty, Entity, FocusHandle, Subscription, prelude::*}; -use language_model::LanguageModelRegistry; +use gpui::{Action, Entity, FocusHandle, Subscription, 
prelude::*}; use settings::{Settings as _, SettingsStore, update_settings_file}; use std::sync::Arc; use ui::{ - ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, Tooltip, - prelude::*, + ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, TintColor, + Tooltip, prelude::*, }; +/// Trait for types that can provide and manage agent profiles +pub trait ProfileProvider { + /// Get the current profile ID + fn profile_id(&self, cx: &App) -> AgentProfileId; + + /// Set the profile ID + fn set_profile(&self, profile_id: AgentProfileId, cx: &mut App); + + /// Check if profiles are supported in the current context (e.g. if the model that is selected has tool support) + fn profiles_supported(&self, cx: &App) -> bool; +} + pub struct ProfileSelector { profiles: AvailableProfiles, fs: Arc, - thread: Entity, + provider: Arc, menu_handle: PopoverMenuHandle, focus_handle: FocusHandle, _subscriptions: Vec, @@ -26,7 +34,7 @@ pub struct ProfileSelector { impl ProfileSelector { pub fn new( fs: Arc, - thread: Entity, + provider: Arc, focus_handle: FocusHandle, cx: &mut Context, ) -> Self { @@ -37,7 +45,7 @@ impl ProfileSelector { Self { profiles: AgentProfile::available_profiles(cx), fs, - thread, + provider, menu_handle: PopoverMenuHandle::default(), focus_handle, _subscriptions: vec![settings_subscription], @@ -113,10 +121,10 @@ impl ProfileSelector { builtin_profiles::MINIMAL => Some("Chat about anything with no tools."), _ => None, }; - let thread_profile_id = self.thread.read(cx).profile().id(); + let thread_profile_id = self.provider.profile_id(cx); let entry = ContextMenuEntry::new(profile_name.clone()) - .toggleable(IconPosition::End, &profile_id == thread_profile_id); + .toggleable(IconPosition::End, profile_id == thread_profile_id); let entry = if let Some(doc_text) = documentation { entry.documentation_aside(documentation_side(settings.dock), move |_| { @@ -128,19 +136,16 @@ impl ProfileSelector { entry.handler({ 
let fs = self.fs.clone(); - let thread = self.thread.clone(); - let profile_id = profile_id.clone(); + let provider = self.provider.clone(); move |_window, cx| { update_settings_file::(fs.clone(), cx, { let profile_id = profile_id.clone(); move |settings, _cx| { - settings.set_profile(profile_id.clone()); + settings.set_profile(profile_id); } }); - thread.update(cx, |this, cx| { - this.set_profile(profile_id.clone(), cx); - }); + provider.set_profile(profile_id.clone(), cx); } }) } @@ -149,23 +154,15 @@ impl ProfileSelector { impl Render for ProfileSelector { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { let settings = AgentSettings::get_global(cx); - let profile_id = self.thread.read(cx).profile().id(); - let profile = settings.profiles.get(profile_id); + let profile_id = self.provider.profile_id(cx); + let profile = settings.profiles.get(&profile_id); let selected_profile = profile .map(|profile| profile.name.clone()) .unwrap_or_else(|| "Unknown".into()); - let configured_model = self.thread.read(cx).configured_model().or_else(|| { - let model_registry = LanguageModelRegistry::read_global(cx); - model_registry.default_model() - }); - let Some(configured_model) = configured_model else { - return Empty.into_any_element(); - }; - - if configured_model.model.supports_tools() { - let this = cx.entity().clone(); + if self.provider.profiles_supported(cx) { + let this = cx.entity(); let focus_handle = self.focus_handle.clone(); let trigger_button = Button::new("profile-selector-model", selected_profile) .label_size(LabelSize::Small) @@ -173,11 +170,11 @@ impl Render for ProfileSelector { .icon(IconName::ChevronDown) .icon_size(IconSize::XSmall) .icon_position(IconPosition::End) - .icon_color(Color::Muted); + .icon_color(Color::Muted) + .selected_style(ButtonStyle::Tinted(TintColor::Accent)); PopoverMenu::new("profile-selector") .trigger_with_tooltip(trigger_button, { - let focus_handle = focus_handle.clone(); move |window, cx| { 
Tooltip::for_action_in( "Toggle Profile Menu", @@ -199,6 +196,10 @@ impl Render for ProfileSelector { .menu(move |window, cx| { Some(this.update(cx, |this, cx| this.build_context_menu(window, cx))) }) + .offset(gpui::Point { + x: px(0.0), + y: px(-2.0), + }) .into_any_element() } else { Button::new("tools-not-supported-button", "Tools Unsupported") diff --git a/crates/agent_ui/src/slash_command.rs b/crates/agent_ui/src/slash_command.rs index 6b37c5a2d7d6aaf2c9878efb90a22d11ddac2419..c2f26c4f2ed33860196790746dd296e8c617b810 100644 --- a/crates/agent_ui/src/slash_command.rs +++ b/crates/agent_ui/src/slash_command.rs @@ -7,7 +7,10 @@ use fuzzy::{StringMatchCandidate, match_strings}; use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity, Window}; use language::{Anchor, Buffer, ToPoint}; use parking_lot::Mutex; -use project::{CompletionIntent, CompletionSource, lsp_store::CompletionDocumentation}; +use project::{ + CompletionDisplayOptions, CompletionIntent, CompletionSource, + lsp_store::CompletionDocumentation, +}; use rope::Point; use std::{ ops::Range, @@ -88,8 +91,6 @@ impl SlashCommandCompletionProvider { .map(|(editor, workspace)| { let command_name = mat.string.clone(); let command_range = command_range.clone(); - let editor = editor.clone(); - let workspace = workspace.clone(); Arc::new( move |intent: CompletionIntent, window: &mut Window, @@ -135,6 +136,7 @@ impl SlashCommandCompletionProvider { vec![project::CompletionResponse { completions, + display_options: CompletionDisplayOptions::default(), is_incomplete: false, }] }) @@ -158,7 +160,7 @@ impl SlashCommandCompletionProvider { if let Some(command) = self.slash_commands.command(command_name, cx) { let completions = command.complete_argument( arguments, - new_cancel_flag.clone(), + new_cancel_flag, self.workspace.clone(), window, cx, @@ -239,6 +241,7 @@ impl SlashCommandCompletionProvider { Ok(vec![project::CompletionResponse { completions, + display_options: 
CompletionDisplayOptions::default(), // TODO: Could have slash commands indicate whether their completions are incomplete. is_incomplete: true, }]) @@ -246,6 +249,7 @@ impl SlashCommandCompletionProvider { } else { Task::ready(Ok(vec![project::CompletionResponse { completions: Vec::new(), + display_options: CompletionDisplayOptions::default(), is_incomplete: true, }])) } @@ -307,6 +311,7 @@ impl CompletionProvider for SlashCommandCompletionProvider { else { return Task::ready(Ok(vec![project::CompletionResponse { completions: Vec::new(), + display_options: CompletionDisplayOptions::default(), is_incomplete: false, }])); }; diff --git a/crates/agent_ui/src/slash_command_picker.rs b/crates/agent_ui/src/slash_command_picker.rs index 678562e0594b69f43524155c70bb176727f57b46..a6bb61510cbeb557e22018c73082bba17d177d7e 100644 --- a/crates/agent_ui/src/slash_command_picker.rs +++ b/crates/agent_ui/src/slash_command_picker.rs @@ -140,12 +140,10 @@ impl PickerDelegate for SlashCommandDelegate { ); ret.push(index - 1); } - } else { - if let SlashCommandEntry::Advert { .. } = command { - previous_is_advert = true; - if index != 0 { - ret.push(index - 1); - } + } else if let SlashCommandEntry::Advert { .. 
} = command { + previous_is_advert = true; + if index != 0 { + ret.push(index - 1); } } } @@ -214,7 +212,7 @@ impl PickerDelegate for SlashCommandDelegate { let mut label = format!("{}", info.name); if let Some(args) = info.args.as_ref().filter(|_| selected) { - label.push_str(&args); + label.push_str(args); } Label::new(label) .single_line() @@ -329,9 +327,7 @@ where }; let picker_view = cx.new(|cx| { - let picker = - Picker::uniform_list(delegate, window, cx).max_height(Some(rems(20.).into())); - picker + Picker::uniform_list(delegate, window, cx).max_height(Some(rems(20.).into())) }); let handle = self diff --git a/crates/agent_ui/src/slash_command_settings.rs b/crates/agent_ui/src/slash_command_settings.rs index f254d00ec60b08197237d0f179bc755c16ed7d40..9580ffef0f317fbe726c57041fad4f0fa438e143 100644 --- a/crates/agent_ui/src/slash_command_settings.rs +++ b/crates/agent_ui/src/slash_command_settings.rs @@ -2,27 +2,17 @@ use anyhow::Result; use gpui::App; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; /// Settings for slash commands. -#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)] +#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema, SettingsUi, SettingsKey)] +#[settings_key(key = "slash_commands")] pub struct SlashCommandSettings { - /// Settings for the `/docs` slash command. - #[serde(default)] - pub docs: DocsCommandSettings, /// Settings for the `/cargo-workspace` slash command. #[serde(default)] pub cargo_workspace: CargoWorkspaceCommandSettings, } -/// Settings for the `/docs` slash command. -#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)] -pub struct DocsCommandSettings { - /// Whether `/docs` is enabled. - #[serde(default)] - pub enabled: bool, -} - /// Settings for the `/cargo-workspace` slash command. 
#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)] pub struct CargoWorkspaceCommandSettings { @@ -32,8 +22,6 @@ pub struct CargoWorkspaceCommandSettings { } impl Settings for SlashCommandSettings { - const KEY: Option<&'static str> = Some("slash_commands"); - type FileContent = Self; fn load(sources: SettingsSources, _cx: &mut App) -> Result { diff --git a/crates/agent_ui/src/terminal_codegen.rs b/crates/agent_ui/src/terminal_codegen.rs index 54f5b52f584cb87fd2953a148aa8ae48ea38b862..5a4a9d560a16e858dcaedf706f2067a24bc12c5f 100644 --- a/crates/agent_ui/src/terminal_codegen.rs +++ b/crates/agent_ui/src/terminal_codegen.rs @@ -48,7 +48,7 @@ impl TerminalCodegen { let prompt = prompt_task.await; let model_telemetry_id = model.telemetry_id(); let model_provider_id = model.provider_id(); - let response = model.stream_completion_text(prompt, &cx).await; + let response = model.stream_completion_text(prompt, cx).await; let generate = async { let message_id = response .as_ref() diff --git a/crates/agent_ui/src/terminal_inline_assistant.rs b/crates/agent_ui/src/terminal_inline_assistant.rs index bcbc308c99da7b80e716fce9e60461352dcb814c..e7070c0d7fc4878c1f73a6d5f874607422ae53d6 100644 --- a/crates/agent_ui/src/terminal_inline_assistant.rs +++ b/crates/agent_ui/src/terminal_inline_assistant.rs @@ -388,20 +388,20 @@ impl TerminalInlineAssistant { window: &mut Window, cx: &mut App, ) { - if let Some(assist) = self.assists.get_mut(&assist_id) { - if let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned() { - assist - .terminal - .update(cx, |terminal, cx| { - terminal.clear_block_below_cursor(cx); - let block = terminal_view::BlockProperties { - height, - render: Box::new(move |_| prompt_editor.clone().into_any_element()), - }; - terminal.set_block_below_cursor(block, window, cx); - }) - .log_err(); - } + if let Some(assist) = self.assists.get_mut(&assist_id) + && let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned() + { + assist + .terminal 
+ .update(cx, |terminal, cx| { + terminal.clear_block_below_cursor(cx); + let block = terminal_view::BlockProperties { + height, + render: Box::new(move |_| prompt_editor.clone().into_any_element()), + }; + terminal.set_block_below_cursor(block, window, cx); + }) + .log_err(); } } } @@ -432,7 +432,7 @@ impl TerminalInlineAssist { terminal: terminal.downgrade(), prompt_editor: Some(prompt_editor.clone()), codegen: codegen.clone(), - workspace: workspace.clone(), + workspace, context_store, prompt_store, _subscriptions: vec![ @@ -450,23 +450,20 @@ impl TerminalInlineAssist { return; }; - if let CodegenStatus::Error(error) = &codegen.read(cx).status { - if assist.prompt_editor.is_none() { - if let Some(workspace) = assist.workspace.upgrade() { - let error = - format!("Terminal inline assistant error: {}", error); - workspace.update(cx, |workspace, cx| { - struct InlineAssistantError; - - let id = - NotificationId::composite::( - assist_id.0, - ); - - workspace.show_toast(Toast::new(id, error), cx); - }) - } - } + if let CodegenStatus::Error(error) = &codegen.read(cx).status + && assist.prompt_editor.is_none() + && let Some(workspace) = assist.workspace.upgrade() + { + let error = format!("Terminal inline assistant error: {}", error); + workspace.update(cx, |workspace, cx| { + struct InlineAssistantError; + + let id = NotificationId::composite::( + assist_id.0, + ); + + workspace.show_toast(Toast::new(id, error), cx); + }) } if assist.prompt_editor.is_none() { diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index 49a37002f76873aaebbd3bcf32a3e7f7608ffa35..d979db5e0468b696d32ed755aec1ef47e2fd3df3 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -1,14 +1,12 @@ use crate::{ - burn_mode_tooltip::BurnModeTooltip, + QuoteSelection, language_model_selector::{LanguageModelSelector, language_model_selector}, + ui::BurnModeTooltip, }; use agent_settings::{AgentSettings, 
CompletionMode}; use anyhow::Result; use assistant_slash_command::{SlashCommand, SlashCommandOutputSection, SlashCommandWorkingSet}; -use assistant_slash_commands::{ - DefaultSlashCommand, DocsSlashCommand, DocsSlashCommandArgs, FileSlashCommand, - selections_creases, -}; +use assistant_slash_commands::{DefaultSlashCommand, FileSlashCommand, selections_creases}; use client::{proto, zed_urls}; use collections::{BTreeSet, HashMap, HashSet, hash_map}; use editor::{ @@ -27,10 +25,9 @@ use gpui::{ Action, Animation, AnimationExt, AnyElement, AnyView, App, ClipboardEntry, ClipboardItem, Empty, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, Global, InteractiveElement, IntoElement, ParentElement, Pixels, Render, RenderImage, SharedString, Size, - StatefulInteractiveElement, Styled, Subscription, Task, Transformation, WeakEntity, actions, - div, img, percentage, point, prelude::*, pulsating_between, size, + StatefulInteractiveElement, Styled, Subscription, Task, WeakEntity, actions, div, img, point, + prelude::*, pulsating_between, size, }; -use indexed_docs::IndexedDocsStore; use language::{ BufferSnapshot, LspAdapterDelegate, ToOffset, language_settings::{SoftWrap, all_language_settings}, @@ -56,8 +53,8 @@ use std::{ }; use text::SelectionGoal; use ui::{ - ButtonLike, Disclosure, ElevationIndex, KeyBinding, PopoverMenuHandle, TintColor, Tooltip, - prelude::*, + ButtonLike, CommonAnimationExt, Disclosure, ElevationIndex, KeyBinding, PopoverMenuHandle, + TintColor, Tooltip, prelude::*, }; use util::{ResultExt, maybe}; use workspace::{ @@ -77,7 +74,7 @@ use crate::{slash_command::SlashCommandCompletionProvider, slash_command_picker} use assistant_context::{ AssistantContext, CacheStatus, Content, ContextEvent, ContextId, InvokedSlashCommandId, InvokedSlashCommandStatus, Message, MessageId, MessageMetadata, MessageStatus, - ParsedSlashCommand, PendingSlashCommandStatus, ThoughtProcessOutputSection, + PendingSlashCommandStatus, ThoughtProcessOutputSection, }; 
actions!( @@ -93,8 +90,6 @@ actions!( CycleMessageRole, /// Inserts the selected text into the active editor. InsertIntoEditor, - /// Quotes the current selection in the assistant conversation. - QuoteSelection, /// Splits the conversation at the current cursor position. Split, ] @@ -195,7 +190,6 @@ pub struct TextThreadEditor { invoked_slash_command_creases: HashMap, _subscriptions: Vec, last_error: Option, - show_accept_terms: bool, pub(crate) slash_menu_handle: PopoverMenuHandle>, // dragged_file_worktrees is used to keep references to worktrees that were added @@ -294,7 +288,6 @@ impl TextThreadEditor { invoked_slash_command_creases: HashMap::default(), _subscriptions, last_error: None, - show_accept_terms: false, slash_menu_handle: Default::default(), dragged_file_worktrees: Vec::new(), language_model_selector: cx.new(|cx| { @@ -368,24 +361,12 @@ impl TextThreadEditor { if self.sending_disabled(cx) { return; } + telemetry::event!("Agent Message Sent", agent = "zed-text"); self.send_to_model(window, cx); } fn send_to_model(&mut self, window: &mut Window, cx: &mut Context) { - let provider = LanguageModelRegistry::read_global(cx) - .default_model() - .map(|default| default.provider); - if provider - .as_ref() - .map_or(false, |provider| provider.must_accept_terms(cx)) - { - self.show_accept_terms = true; - cx.notify(); - return; - } - self.last_error = None; - if let Some(user_message) = self.context.update(cx, |context, cx| context.assist(cx)) { let new_selection = { let cursor = user_message @@ -461,7 +442,7 @@ impl TextThreadEditor { || snapshot .chars_at(newest_cursor) .next() - .map_or(false, |ch| ch != '\n') + .is_some_and(|ch| ch != '\n') { editor.move_to_end_of_line( &MoveToEndOfLine { @@ -544,7 +525,7 @@ impl TextThreadEditor { let context = self.context.read(cx); let sections = context .slash_command_output_sections() - .into_iter() + .iter() .filter(|section| section.is_valid(context.buffer().read(cx))) .cloned() .collect::>(); @@ -701,19 +682,7 @@ 
impl TextThreadEditor { } }; let render_trailer = { - let command = command.clone(); - move |row, _unfold, _window: &mut Window, cx: &mut App| { - // TODO: In the future we should investigate how we can expose - // this as a hook on the `SlashCommand` trait so that we don't - // need to special-case it here. - if command.name == DocsSlashCommand::NAME { - return render_docs_slash_command_trailer( - row, - command.clone(), - cx, - ); - } - + move |_row, _unfold, _window: &mut Window, _cx: &mut App| { Empty.into_any() } }; @@ -761,32 +730,27 @@ impl TextThreadEditor { ) { if let Some(invoked_slash_command) = self.context.read(cx).invoked_slash_command(&command_id) + && let InvokedSlashCommandStatus::Finished = invoked_slash_command.status { - if let InvokedSlashCommandStatus::Finished = invoked_slash_command.status { - let run_commands_in_ranges = invoked_slash_command - .run_commands_in_ranges - .iter() - .cloned() - .collect::>(); - for range in run_commands_in_ranges { - let commands = self.context.update(cx, |context, cx| { - context.reparse(cx); - context - .pending_commands_for_range(range.clone(), cx) - .to_vec() - }); + let run_commands_in_ranges = invoked_slash_command.run_commands_in_ranges.clone(); + for range in run_commands_in_ranges { + let commands = self.context.update(cx, |context, cx| { + context.reparse(cx); + context + .pending_commands_for_range(range.clone(), cx) + .to_vec() + }); - for command in commands { - self.run_command( - command.source_range, - &command.name, - &command.arguments, - false, - self.workspace.clone(), - window, - cx, - ); - } + for command in commands { + self.run_command( + command.source_range, + &command.name, + &command.arguments, + false, + self.workspace.clone(), + window, + cx, + ); } } } @@ -1097,15 +1061,7 @@ impl TextThreadEditor { Icon::new(IconName::ArrowCircle) .size(IconSize::XSmall) .color(Color::Info) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| { - 
icon.transform(Transformation::rotate( - percentage(delta), - )) - }, - ) + .with_rotate_animation(2) .into_any_element(), ); note = Some(Self::esc_kbd(cx).into_any_element()); @@ -1258,7 +1214,7 @@ impl TextThreadEditor { let mut new_blocks = vec![]; let mut block_index_to_message = vec![]; for message in self.context.read(cx).messages(cx) { - if let Some(_) = blocks_to_remove.remove(&message.id) { + if blocks_to_remove.remove(&message.id).is_some() { // This is an old message that we might modify. let Some((meta, block_id)) = old_blocks.get_mut(&message.id) else { debug_assert!( @@ -1296,7 +1252,7 @@ impl TextThreadEditor { context_editor_view: &Entity, cx: &mut Context, ) -> Option<(String, bool)> { - const CODE_FENCE_DELIMITER: &'static str = "```"; + const CODE_FENCE_DELIMITER: &str = "```"; let context_editor = context_editor_view.read(cx).editor.clone(); context_editor.update(cx, |context_editor, cx| { @@ -1760,7 +1716,7 @@ impl TextThreadEditor { render_slash_command_output_toggle, |_, _, _, _| Empty.into_any(), ) - .with_metadata(metadata.crease.clone()) + .with_metadata(metadata.crease) }), cx, ); @@ -1831,7 +1787,7 @@ impl TextThreadEditor { .filter_map(|(anchor, render_image)| { const MAX_HEIGHT_IN_LINES: u32 = 8; let anchor = buffer.anchor_in_excerpt(excerpt_id, anchor).unwrap(); - let image = render_image.clone(); + let image = render_image; anchor.is_valid(&buffer).then(|| BlockProperties { placement: BlockPlacement::Above(anchor), height: Some(MAX_HEIGHT_IN_LINES), @@ -1893,8 +1849,55 @@ impl TextThreadEditor { .update(cx, |context, cx| context.summarize(true, cx)); } + fn render_remaining_tokens(&self, cx: &App) -> Option> { + let (token_count_color, token_count, max_token_count, tooltip) = + match token_state(&self.context, cx)? 
{ + TokenState::NoTokensLeft { + max_token_count, + token_count, + } => ( + Color::Error, + token_count, + max_token_count, + Some("Token Limit Reached"), + ), + TokenState::HasMoreTokens { + max_token_count, + token_count, + over_warn_threshold, + } => { + let (color, tooltip) = if over_warn_threshold { + (Color::Warning, Some("Token Limit is Close to Exhaustion")) + } else { + (Color::Muted, None) + }; + (color, token_count, max_token_count, tooltip) + } + }; + + Some( + h_flex() + .id("token-count") + .gap_0p5() + .child( + Label::new(humanize_token_count(token_count)) + .size(LabelSize::Small) + .color(token_count_color), + ) + .child(Label::new("/").size(LabelSize::Small).color(Color::Muted)) + .child( + Label::new(humanize_token_count(max_token_count)) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .when_some(tooltip, |element, tooltip| { + element.tooltip(Tooltip::text(tooltip)) + }), + ) + } + fn render_send_button(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let focus_handle = self.focus_handle(cx).clone(); + let focus_handle = self.focus_handle(cx); let (style, tooltip) = match token_state(&self.context, cx) { Some(TokenState::NoTokensLeft { .. }) => ( @@ -1952,7 +1955,6 @@ impl TextThreadEditor { ConfigurationError::NoProvider | ConfigurationError::ModelNotFound | ConfigurationError::ProviderNotAuthenticated(_) => true, - ConfigurationError::ProviderPendingTermsAcceptance(_) => self.show_accept_terms, } } @@ -2036,7 +2038,7 @@ impl TextThreadEditor { None => IconName::Ai, }; - let focus_handle = self.editor().focus_handle(cx).clone(); + let focus_handle = self.editor().focus_handle(cx); PickerPopoverMenu::new( self.language_model_selector.clone(), @@ -2182,8 +2184,8 @@ impl TextThreadEditor { /// Returns the contents of the *outermost* fenced code block that contains the given offset. 
fn find_surrounding_code_block(snapshot: &BufferSnapshot, offset: usize) -> Option> { - const CODE_BLOCK_NODE: &'static str = "fenced_code_block"; - const CODE_BLOCK_CONTENT: &'static str = "code_fence_content"; + const CODE_BLOCK_NODE: &str = "fenced_code_block"; + const CODE_BLOCK_CONTENT: &str = "code_fence_content"; let layer = snapshot.syntax_layers().next()?; @@ -2398,70 +2400,6 @@ fn render_pending_slash_command_gutter_decoration( icon.into_any_element() } -fn render_docs_slash_command_trailer( - row: MultiBufferRow, - command: ParsedSlashCommand, - cx: &mut App, -) -> AnyElement { - if command.arguments.is_empty() { - return Empty.into_any(); - } - let args = DocsSlashCommandArgs::parse(&command.arguments); - - let Some(store) = args - .provider() - .and_then(|provider| IndexedDocsStore::try_global(provider, cx).ok()) - else { - return Empty.into_any(); - }; - - let Some(package) = args.package() else { - return Empty.into_any(); - }; - - let mut children = Vec::new(); - - if store.is_indexing(&package) { - children.push( - div() - .id(("crates-being-indexed", row.0)) - .child(Icon::new(IconName::ArrowCircle).with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(4)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - )) - .tooltip({ - let package = package.clone(); - Tooltip::text(format!("Indexing {package}…")) - }) - .into_any_element(), - ); - } - - if let Some(latest_error) = store.latest_error_for_package(&package) { - children.push( - div() - .id(("latest-error", row.0)) - .child( - Icon::new(IconName::Warning) - .size(IconSize::Small) - .color(Color::Warning), - ) - .tooltip(Tooltip::text(format!("Failed to index: {latest_error}"))) - .into_any_element(), - ) - } - - let is_indexing = store.is_indexing(&package); - let latest_error = store.latest_error_for_package(&package); - - if !is_indexing && latest_error.is_none() { - return Empty.into_any(); - } - - 
h_flex().gap_2().children(children).into_any_element() -} - #[derive(Debug, Clone, Serialize, Deserialize)] struct CopyMetadata { creases: Vec, @@ -2521,9 +2459,14 @@ impl Render for TextThreadEditor { ) .child( h_flex() - .gap_1() - .child(self.render_language_model_selector(window, cx)) - .child(self.render_send_button(window, cx)), + .gap_2p5() + .children(self.render_remaining_tokens(cx)) + .child( + h_flex() + .gap_1() + .child(self.render_language_model_selector(window, cx)) + .child(self.render_send_button(window, cx)), + ), ), ) } @@ -2811,58 +2754,6 @@ impl FollowableItem for TextThreadEditor { } } -pub fn render_remaining_tokens( - context_editor: &Entity, - cx: &App, -) -> Option> { - let context = &context_editor.read(cx).context; - - let (token_count_color, token_count, max_token_count, tooltip) = match token_state(context, cx)? - { - TokenState::NoTokensLeft { - max_token_count, - token_count, - } => ( - Color::Error, - token_count, - max_token_count, - Some("Token Limit Reached"), - ), - TokenState::HasMoreTokens { - max_token_count, - token_count, - over_warn_threshold, - } => { - let (color, tooltip) = if over_warn_threshold { - (Color::Warning, Some("Token Limit is Close to Exhaustion")) - } else { - (Color::Muted, None) - }; - (color, token_count, max_token_count, tooltip) - } - }; - - Some( - h_flex() - .id("token-count") - .gap_0p5() - .child( - Label::new(humanize_token_count(token_count)) - .size(LabelSize::Small) - .color(token_count_color), - ) - .child(Label::new("/").size(LabelSize::Small).color(Color::Muted)) - .child( - Label::new(humanize_token_count(max_token_count)) - .size(LabelSize::Small) - .color(Color::Muted), - ) - .when_some(tooltip, |element, tooltip| { - element.tooltip(Tooltip::text(tooltip)) - }), - ) -} - enum PendingSlashCommand {} fn invoked_slash_command_fold_placeholder( @@ -2891,11 +2782,7 @@ fn invoked_slash_command_fold_placeholder( .child(Label::new(format!("/{}", command.name))) .map(|parent| match 
&command.status { InvokedSlashCommandStatus::Running(_) => { - parent.child(Icon::new(IconName::ArrowCircle).with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(4)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - )) + parent.child(Icon::new(IconName::ArrowCircle).with_rotate_animation(4)) } InvokedSlashCommandStatus::Error(message) => parent.child( Label::new(format!("error: {message}")) @@ -3214,7 +3101,7 @@ mod tests { let context_editor = window .update(&mut cx, |_, window, cx| { cx.new(|cx| { - let editor = TextThreadEditor::for_context( + TextThreadEditor::for_context( context.clone(), fs, workspace.downgrade(), @@ -3222,8 +3109,7 @@ mod tests { None, window, cx, - ); - editor + ) }) }) .unwrap(); diff --git a/crates/agent_ui/src/thread_history.rs b/crates/agent_ui/src/thread_history.rs index b8d1db88d6e3164b32ade0f2137ad7ca37a0650a..73d3b705b74c18e367298cf5ed74852459e68b2e 100644 --- a/crates/agent_ui/src/thread_history.rs +++ b/crates/agent_ui/src/thread_history.rs @@ -73,7 +73,7 @@ impl ThreadHistory { ) -> Self { let search_editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text("Search threads...", cx); + editor.set_placeholder_text("Search threads...", window, cx); editor }); @@ -166,14 +166,13 @@ impl ThreadHistory { this.all_entries.len().saturating_sub(1), cx, ); - } else if let Some(prev_id) = previously_selected_entry { - if let Some(new_ix) = this + } else if let Some(prev_id) = previously_selected_entry + && let Some(new_ix) = this .all_entries .iter() .position(|probe| probe.id() == prev_id) - { - this.set_selected_entry_index(new_ix, cx); - } + { + this.set_selected_entry_index(new_ix, cx); } } SearchState::Searching { query, .. } | SearchState::Searched { query, .. 
} => { @@ -541,6 +540,7 @@ impl Render for ThreadHistory { v_flex() .key_context("ThreadHistory") .size_full() + .bg(cx.theme().colors().panel_background) .on_action(cx.listener(Self::select_previous)) .on_action(cx.listener(Self::select_next)) .on_action(cx.listener(Self::select_first)) diff --git a/crates/agent_ui/src/tool_compatibility.rs b/crates/agent_ui/src/tool_compatibility.rs index d4e1da5bb0a532c8307364582349378d98c51a26..046c0a4abc5e3ac0130b2af5406cfbb3f977b00c 100644 --- a/crates/agent_ui/src/tool_compatibility.rs +++ b/crates/agent_ui/src/tool_compatibility.rs @@ -14,13 +14,11 @@ pub struct IncompatibleToolsState { impl IncompatibleToolsState { pub fn new(thread: Entity, cx: &mut Context) -> Self { - let _tool_working_set_subscription = - cx.subscribe(&thread, |this, _, event, _| match event { - ThreadEvent::ProfileChanged => { - this.cache.clear(); - } - _ => {} - }); + let _tool_working_set_subscription = cx.subscribe(&thread, |this, _, event, _| { + if let ThreadEvent::ProfileChanged = event { + this.cache.clear(); + } + }); Self { cache: HashMap::default(), diff --git a/crates/agent_ui/src/ui.rs b/crates/agent_ui/src/ui.rs index beeaf0c43bbaa9384030879654bfaada1e4d9cd1..1a3264bd77ccda1a27ffd19f3c61c3635fe78dc9 100644 --- a/crates/agent_ui/src/ui.rs +++ b/crates/agent_ui/src/ui.rs @@ -1,14 +1,18 @@ +mod acp_onboarding_modal; mod agent_notification; mod burn_mode_tooltip; +mod claude_code_onboarding_modal; mod context_pill; mod end_trial_upsell; -// mod new_thread_button; mod onboarding_modal; pub mod preview; +mod unavailable_editing_tooltip; +pub use acp_onboarding_modal::*; pub use agent_notification::*; pub use burn_mode_tooltip::*; +pub use claude_code_onboarding_modal::*; pub use context_pill::*; pub use end_trial_upsell::*; -// pub use new_thread_button::*; pub use onboarding_modal::*; +pub use unavailable_editing_tooltip::*; diff --git a/crates/agent_ui/src/ui/acp_onboarding_modal.rs b/crates/agent_ui/src/ui/acp_onboarding_modal.rs new file 
mode 100644 index 0000000000000000000000000000000000000000..8433904fb3b540c2d78c8634b7a6755303d6e15c --- /dev/null +++ b/crates/agent_ui/src/ui/acp_onboarding_modal.rs @@ -0,0 +1,246 @@ +use client::zed_urls; +use gpui::{ + ClickEvent, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent, Render, + linear_color_stop, linear_gradient, +}; +use ui::{TintColor, Vector, VectorName, prelude::*}; +use workspace::{ModalView, Workspace}; + +use crate::agent_panel::{AgentPanel, AgentType}; + +macro_rules! acp_onboarding_event { + ($name:expr) => { + telemetry::event!($name, source = "ACP Onboarding"); + }; + ($name:expr, $($key:ident $(= $value:expr)?),+ $(,)?) => { + telemetry::event!($name, source = "ACP Onboarding", $($key $(= $value)?),+); + }; +} + +pub struct AcpOnboardingModal { + focus_handle: FocusHandle, + workspace: Entity, +} + +impl AcpOnboardingModal { + pub fn toggle(workspace: &mut Workspace, window: &mut Window, cx: &mut Context) { + let workspace_entity = cx.entity(); + workspace.toggle_modal(window, cx, |_window, cx| Self { + workspace: workspace_entity, + focus_handle: cx.focus_handle(), + }); + } + + fn open_panel(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { + self.workspace.update(cx, |workspace, cx| { + workspace.focus_panel::(window, cx); + + if let Some(panel) = workspace.panel::(cx) { + panel.update(cx, |panel, cx| { + panel.new_agent_thread(AgentType::Gemini, window, cx); + }); + } + }); + + cx.emit(DismissEvent); + + acp_onboarding_event!("Open Panel Clicked"); + } + + fn view_docs(&mut self, _: &ClickEvent, _: &mut Window, cx: &mut Context) { + cx.open_url(&zed_urls::external_agents_docs(cx)); + cx.notify(); + + acp_onboarding_event!("Documentation Link Clicked"); + } + + fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { + cx.emit(DismissEvent); + } +} + +impl EventEmitter for AcpOnboardingModal {} + +impl Focusable for AcpOnboardingModal { + fn focus_handle(&self, _cx: &App) 
-> FocusHandle { + self.focus_handle.clone() + } +} + +impl ModalView for AcpOnboardingModal {} + +impl Render for AcpOnboardingModal { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + let illustration_element = |label: bool, opacity: f32| { + h_flex() + .px_1() + .py_0p5() + .gap_1() + .rounded_sm() + .bg(cx.theme().colors().element_active.opacity(0.05)) + .border_1() + .border_color(cx.theme().colors().border) + .border_dashed() + .child( + Icon::new(IconName::Stop) + .size(IconSize::Small) + .color(Color::Custom(cx.theme().colors().text_muted.opacity(0.15))), + ) + .map(|this| { + if label { + this.child( + Label::new("Your Agent Here") + .size(LabelSize::Small) + .color(Color::Muted), + ) + } else { + this.child( + div().w_16().h_1().rounded_full().bg(cx + .theme() + .colors() + .element_active + .opacity(0.6)), + ) + } + }) + .opacity(opacity) + }; + + let illustration = h_flex() + .relative() + .h(rems_from_px(126.)) + .bg(cx.theme().colors().editor_background) + .border_b_1() + .border_color(cx.theme().colors().border_variant) + .justify_center() + .gap_8() + .rounded_t_md() + .overflow_hidden() + .child( + div().absolute().inset_0().w(px(515.)).h(px(126.)).child( + Vector::new(VectorName::AcpGrid, rems_from_px(515.), rems_from_px(126.)) + .color(ui::Color::Custom(cx.theme().colors().text.opacity(0.02))), + ), + ) + .child(div().absolute().inset_0().size_full().bg(linear_gradient( + 0., + linear_color_stop( + cx.theme().colors().elevated_surface_background.opacity(0.1), + 0.9, + ), + linear_color_stop( + cx.theme().colors().elevated_surface_background.opacity(0.), + 0., + ), + ))) + .child( + div() + .absolute() + .inset_0() + .size_full() + .bg(gpui::black().opacity(0.15)), + ) + .child( + Vector::new( + VectorName::AcpLogoSerif, + rems_from_px(257.), + rems_from_px(47.), + ) + .color(ui::Color::Custom(cx.theme().colors().text.opacity(0.8))), + ) + .child( + v_flex() + .gap_1p5() + .child(illustration_element(false, 0.15)) + 
.child(illustration_element(true, 0.3)) + .child( + h_flex() + .pl_1() + .pr_2() + .py_0p5() + .gap_1() + .rounded_sm() + .bg(cx.theme().colors().element_active.opacity(0.2)) + .border_1() + .border_color(cx.theme().colors().border) + .child( + Icon::new(IconName::AiGemini) + .size(IconSize::Small) + .color(Color::Muted), + ) + .child(Label::new("New Gemini CLI Thread").size(LabelSize::Small)), + ) + .child(illustration_element(true, 0.3)) + .child(illustration_element(false, 0.15)), + ); + + let heading = v_flex() + .w_full() + .gap_1() + .child( + Label::new("Now Available") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child(Headline::new("Bring Your Own Agent to Zed").size(HeadlineSize::Large)); + + let copy = "Bring the agent of your choice to Zed via our new Agent Client Protocol (ACP), starting with Google's Gemini CLI integration."; + + let open_panel_button = Button::new("open-panel", "Start with Gemini CLI") + .icon_size(IconSize::Indicator) + .style(ButtonStyle::Tinted(TintColor::Accent)) + .full_width() + .on_click(cx.listener(Self::open_panel)); + + let docs_button = Button::new("add-other-agents", "Add Other Agents") + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::Indicator) + .icon_color(Color::Muted) + .full_width() + .on_click(cx.listener(Self::view_docs)); + + let close_button = h_flex().absolute().top_2().right_2().child( + IconButton::new("cancel", IconName::Close).on_click(cx.listener( + |_, _: &ClickEvent, _window, cx| { + acp_onboarding_event!("Canceled", trigger = "X click"); + cx.emit(DismissEvent); + }, + )), + ); + + v_flex() + .id("acp-onboarding") + .key_context("AcpOnboardingModal") + .relative() + .w(rems(34.)) + .h_full() + .elevation_3(cx) + .track_focus(&self.focus_handle(cx)) + .overflow_hidden() + .on_action(cx.listener(Self::cancel)) + .on_action(cx.listener(|_, _: &menu::Cancel, _window, cx| { + acp_onboarding_event!("Canceled", trigger = "Action"); + cx.emit(DismissEvent); + })) + 
.on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _cx| { + this.focus_handle.focus(window); + })) + .child(illustration) + .child( + v_flex() + .p_4() + .gap_2() + .child(heading) + .child(Label::new(copy).color(Color::Muted)) + .child( + v_flex() + .w_full() + .mt_2() + .gap_1() + .child(open_panel_button) + .child(docs_button), + ), + ) + .child(close_button) + } +} diff --git a/crates/agent_ui/src/ui/agent_notification.rs b/crates/agent_ui/src/ui/agent_notification.rs index 68480c047f9cab4cd72f1998422bc727993e1f5e..af2a022f147b79a0a299c17dd26c7e9a8b62aeb9 100644 --- a/crates/agent_ui/src/ui/agent_notification.rs +++ b/crates/agent_ui/src/ui/agent_notification.rs @@ -62,6 +62,8 @@ impl AgentNotification { app_id: Some(app_id.to_owned()), window_min_size: None, window_decorations: Some(WindowDecorations::Client), + tabbing_identifier: None, + ..Default::default() } } } diff --git a/crates/agent_ui/src/ui/burn_mode_tooltip.rs b/crates/agent_ui/src/ui/burn_mode_tooltip.rs index 97f7853a61bc2bc2766492e077e34c3f1b534abe..72faaa614d0d531365fef9ba5ff0e62a6fbcf145 100644 --- a/crates/agent_ui/src/ui/burn_mode_tooltip.rs +++ b/crates/agent_ui/src/ui/burn_mode_tooltip.rs @@ -2,11 +2,11 @@ use crate::ToggleBurnMode; use gpui::{Context, FontWeight, IntoElement, Render, Window}; use ui::{KeyBinding, prelude::*, tooltip_container}; -pub struct MaxModeTooltip { +pub struct BurnModeTooltip { selected: bool, } -impl MaxModeTooltip { +impl BurnModeTooltip { pub fn new() -> Self { Self { selected: false } } @@ -17,7 +17,7 @@ impl MaxModeTooltip { } } -impl Render for MaxModeTooltip { +impl Render for BurnModeTooltip { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let (icon, color) = if self.selected { (IconName::ZedBurnModeOn, Color::Error) diff --git a/crates/agent_ui/src/ui/claude_code_onboarding_modal.rs b/crates/agent_ui/src/ui/claude_code_onboarding_modal.rs new file mode 100644 index 
0000000000000000000000000000000000000000..06980f18977aefe228bb7f09962e69fe2b3a5068 --- /dev/null +++ b/crates/agent_ui/src/ui/claude_code_onboarding_modal.rs @@ -0,0 +1,254 @@ +use client::zed_urls; +use gpui::{ + ClickEvent, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent, Render, + linear_color_stop, linear_gradient, +}; +use ui::{TintColor, Vector, VectorName, prelude::*}; +use workspace::{ModalView, Workspace}; + +use crate::agent_panel::{AgentPanel, AgentType}; + +macro_rules! claude_code_onboarding_event { + ($name:expr) => { + telemetry::event!($name, source = "ACP Claude Code Onboarding"); + }; + ($name:expr, $($key:ident $(= $value:expr)?),+ $(,)?) => { + telemetry::event!($name, source = "ACP Claude Code Onboarding", $($key $(= $value)?),+); + }; +} + +pub struct ClaudeCodeOnboardingModal { + focus_handle: FocusHandle, + workspace: Entity, +} + +impl ClaudeCodeOnboardingModal { + pub fn toggle(workspace: &mut Workspace, window: &mut Window, cx: &mut Context) { + let workspace_entity = cx.entity(); + workspace.toggle_modal(window, cx, |_window, cx| Self { + workspace: workspace_entity, + focus_handle: cx.focus_handle(), + }); + } + + fn open_panel(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { + self.workspace.update(cx, |workspace, cx| { + workspace.focus_panel::(window, cx); + + if let Some(panel) = workspace.panel::(cx) { + panel.update(cx, |panel, cx| { + panel.new_agent_thread(AgentType::ClaudeCode, window, cx); + }); + } + }); + + cx.emit(DismissEvent); + + claude_code_onboarding_event!("Open Panel Clicked"); + } + + fn view_docs(&mut self, _: &ClickEvent, _: &mut Window, cx: &mut Context) { + cx.open_url(&zed_urls::external_agents_docs(cx)); + cx.notify(); + + claude_code_onboarding_event!("Documentation Link Clicked"); + } + + fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { + cx.emit(DismissEvent); + } +} + +impl EventEmitter for ClaudeCodeOnboardingModal {} + +impl 
Focusable for ClaudeCodeOnboardingModal { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl ModalView for ClaudeCodeOnboardingModal {} + +impl Render for ClaudeCodeOnboardingModal { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + let illustration_element = |icon: IconName, label: Option, opacity: f32| { + h_flex() + .px_1() + .py_0p5() + .gap_1() + .rounded_sm() + .bg(cx.theme().colors().element_active.opacity(0.05)) + .border_1() + .border_color(cx.theme().colors().border) + .border_dashed() + .child( + Icon::new(icon) + .size(IconSize::Small) + .color(Color::Custom(cx.theme().colors().text_muted.opacity(0.15))), + ) + .map(|this| { + if let Some(label_text) = label { + this.child( + Label::new(label_text) + .size(LabelSize::Small) + .color(Color::Muted), + ) + } else { + this.child( + div().w_16().h_1().rounded_full().bg(cx + .theme() + .colors() + .element_active + .opacity(0.6)), + ) + } + }) + .opacity(opacity) + }; + + let illustration = h_flex() + .relative() + .h(rems_from_px(126.)) + .bg(cx.theme().colors().editor_background) + .border_b_1() + .border_color(cx.theme().colors().border_variant) + .justify_center() + .gap_8() + .rounded_t_md() + .overflow_hidden() + .child( + div().absolute().inset_0().w(px(515.)).h(px(126.)).child( + Vector::new(VectorName::AcpGrid, rems_from_px(515.), rems_from_px(126.)) + .color(ui::Color::Custom(cx.theme().colors().text.opacity(0.02))), + ), + ) + .child(div().absolute().inset_0().size_full().bg(linear_gradient( + 0., + linear_color_stop( + cx.theme().colors().elevated_surface_background.opacity(0.1), + 0.9, + ), + linear_color_stop( + cx.theme().colors().elevated_surface_background.opacity(0.), + 0., + ), + ))) + .child( + div() + .absolute() + .inset_0() + .size_full() + .bg(gpui::black().opacity(0.15)), + ) + .child( + Vector::new( + VectorName::AcpLogoSerif, + rems_from_px(257.), + rems_from_px(47.), + ) + 
.color(ui::Color::Custom(cx.theme().colors().text.opacity(0.8))), + ) + .child( + v_flex() + .gap_1p5() + .child(illustration_element(IconName::Stop, None, 0.15)) + .child(illustration_element( + IconName::AiGemini, + Some("New Gemini CLI Thread".into()), + 0.3, + )) + .child( + h_flex() + .pl_1() + .pr_2() + .py_0p5() + .gap_1() + .rounded_sm() + .bg(cx.theme().colors().element_active.opacity(0.2)) + .border_1() + .border_color(cx.theme().colors().border) + .child( + Icon::new(IconName::AiClaude) + .size(IconSize::Small) + .color(Color::Muted), + ) + .child(Label::new("New Claude Code Thread").size(LabelSize::Small)), + ) + .child(illustration_element( + IconName::Stop, + Some("Your Agent Here".into()), + 0.3, + )) + .child(illustration_element(IconName::Stop, None, 0.15)), + ); + + let heading = v_flex() + .w_full() + .gap_1() + .child( + Label::new("Beta Release") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child(Headline::new("Claude Code: Natively in Zed").size(HeadlineSize::Large)); + + let copy = "Powered by the Agent Client Protocol, you can now run Claude Code as\na first-class citizen in Zed's agent panel."; + + let open_panel_button = Button::new("open-panel", "Start with Claude Code") + .icon_size(IconSize::Indicator) + .style(ButtonStyle::Tinted(TintColor::Accent)) + .full_width() + .on_click(cx.listener(Self::open_panel)); + + let docs_button = Button::new("add-other-agents", "Add Other Agents") + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::Indicator) + .icon_color(Color::Muted) + .full_width() + .on_click(cx.listener(Self::view_docs)); + + let close_button = h_flex().absolute().top_2().right_2().child( + IconButton::new("cancel", IconName::Close).on_click(cx.listener( + |_, _: &ClickEvent, _window, cx| { + claude_code_onboarding_event!("Canceled", trigger = "X click"); + cx.emit(DismissEvent); + }, + )), + ); + + v_flex() + .id("acp-onboarding") + .key_context("AcpOnboardingModal") + .relative() + .w(rems(34.)) + .h_full() + 
.elevation_3(cx) + .track_focus(&self.focus_handle(cx)) + .overflow_hidden() + .on_action(cx.listener(Self::cancel)) + .on_action(cx.listener(|_, _: &menu::Cancel, _window, cx| { + claude_code_onboarding_event!("Canceled", trigger = "Action"); + cx.emit(DismissEvent); + })) + .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _cx| { + this.focus_handle.focus(window); + })) + .child(illustration) + .child( + v_flex() + .p_4() + .gap_2() + .child(heading) + .child(Label::new(copy).color(Color::Muted)) + .child( + v_flex() + .w_full() + .mt_2() + .gap_1() + .child(open_panel_button) + .child(docs_button), + ), + ) + .child(close_button) + } +} diff --git a/crates/agent_ui/src/ui/context_pill.rs b/crates/agent_ui/src/ui/context_pill.rs index 5dd57de24490df03ce0f2c41a844be33fb675793..7c7fbd27f0d4ebe3b5c42cc6c5a244ae6add5614 100644 --- a/crates/agent_ui/src/ui/context_pill.rs +++ b/crates/agent_ui/src/ui/context_pill.rs @@ -353,7 +353,7 @@ impl AddedContext { name, parent, tooltip: Some(full_path_string), - icon_path: FileIcons::get_icon(&full_path, cx), + icon_path: FileIcons::get_icon(full_path, cx), status: ContextStatus::Ready, render_hover: None, handle: AgentContextHandle::File(handle), @@ -499,7 +499,7 @@ impl AddedContext { let thread = handle.thread.clone(); Some(Rc::new(move |_, cx| { let text = thread.read(cx).latest_detailed_summary_or_text(); - ContextPillHover::new_text(text.clone(), cx).into() + ContextPillHover::new_text(text, cx).into() })) }, handle: AgentContextHandle::Thread(handle), @@ -574,7 +574,7 @@ impl AddedContext { .unwrap_or_else(|| "Unnamed Rule".into()); Some(AddedContext { kind: ContextKind::Rules, - name: title.clone(), + name: title, parent: None, tooltip: None, icon_path: None, @@ -615,7 +615,7 @@ impl AddedContext { let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into(); let (name, parent) = extract_file_name_and_directory_from_full_path(full_path, &full_path_string); - let icon_path = 
FileIcons::get_icon(&full_path, cx); + let icon_path = FileIcons::get_icon(full_path, cx); (name, parent, icon_path) } else { ("Image".into(), None, None) @@ -706,7 +706,7 @@ impl ContextFileExcerpt { .and_then(|p| p.file_name()) .map(|n| n.to_string_lossy().into_owned().into()); - let icon_path = FileIcons::get_icon(&full_path, cx); + let icon_path = FileIcons::get_icon(full_path, cx); ContextFileExcerpt { file_name_and_range: file_name_and_range.into(), diff --git a/crates/agent_ui/src/ui/end_trial_upsell.rs b/crates/agent_ui/src/ui/end_trial_upsell.rs index 3a8a119800543ad033efd563d7896ccc80add373..4db9244469cf1ad7fab414874a64e45f3b97e377 100644 --- a/crates/agent_ui/src/ui/end_trial_upsell.rs +++ b/crates/agent_ui/src/ui/end_trial_upsell.rs @@ -2,24 +2,27 @@ use std::sync::Arc; use ai_onboarding::{AgentPanelOnboardingCard, PlanDefinitions}; use client::zed_urls; +use cloud_llm_client::{Plan, PlanV1}; use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; use ui::{Divider, Tooltip, prelude::*}; #[derive(IntoElement, RegisterComponent)] pub struct EndTrialUpsell { + plan: Plan, dismiss_upsell: Arc, } impl EndTrialUpsell { - pub fn new(dismiss_upsell: Arc) -> Self { - Self { dismiss_upsell } + pub fn new(plan: Plan, dismiss_upsell: Arc) -> Self { + Self { + plan, + dismiss_upsell, + } } } impl RenderOnce for EndTrialUpsell { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - let plan_definitions = PlanDefinitions; - let pro_section = v_flex() .gap_1() .child( @@ -33,7 +36,7 @@ impl RenderOnce for EndTrialUpsell { ) .child(Divider::horizontal()), ) - .child(plan_definitions.pro_plan(false)) + .child(PlanDefinitions.pro_plan(self.plan.is_v2(), false)) .child( Button::new("cta-button", "Upgrade to Zed Pro") .full_width() @@ -64,7 +67,7 @@ impl RenderOnce for EndTrialUpsell { ) .child(Divider::horizontal()), ) - .child(plan_definitions.free_plan()); + .child(PlanDefinitions.free_plan(self.plan.is_v2())); AgentPanelOnboardingCard::new() 
.child(Headline::new("Your Zed Pro Trial has expired")) @@ -109,6 +112,7 @@ impl Component for EndTrialUpsell { Some( v_flex() .child(EndTrialUpsell { + plan: Plan::V1(PlanV1::ZedFree), dismiss_upsell: Arc::new(|_, _| {}), }) .into_any_element(), diff --git a/crates/agent_ui/src/ui/new_thread_button.rs b/crates/agent_ui/src/ui/new_thread_button.rs deleted file mode 100644 index 347d6adcaf14221fef31f87303028e30091d2ec4..0000000000000000000000000000000000000000 --- a/crates/agent_ui/src/ui/new_thread_button.rs +++ /dev/null @@ -1,75 +0,0 @@ -use gpui::{ClickEvent, ElementId, IntoElement, ParentElement, Styled}; -use ui::prelude::*; - -#[derive(IntoElement)] -pub struct NewThreadButton { - id: ElementId, - label: SharedString, - icon: IconName, - keybinding: Option, - on_click: Option>, -} - -impl NewThreadButton { - fn new(id: impl Into, label: impl Into, icon: IconName) -> Self { - Self { - id: id.into(), - label: label.into(), - icon, - keybinding: None, - on_click: None, - } - } - - fn keybinding(mut self, keybinding: Option) -> Self { - self.keybinding = keybinding; - self - } - - fn on_click(mut self, handler: F) -> Self - where - F: Fn(&mut Window, &mut App) + 'static, - { - self.on_click = Some(Box::new( - move |_: &ClickEvent, window: &mut Window, cx: &mut App| handler(window, cx), - )); - self - } -} - -impl RenderOnce for NewThreadButton { - fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - h_flex() - .id(self.id) - .w_full() - .py_1p5() - .px_2() - .gap_1() - .justify_between() - .rounded_md() - .border_1() - .border_color(cx.theme().colors().border.opacity(0.4)) - .bg(cx.theme().colors().element_active.opacity(0.2)) - .hover(|style| { - style - .bg(cx.theme().colors().element_hover) - .border_color(cx.theme().colors().border) - }) - .child( - h_flex() - .gap_1p5() - .child( - Icon::new(self.icon) - .size(IconSize::XSmall) - .color(Color::Muted), - ) - .child(Label::new(self.label).size(LabelSize::Small)), - ) - 
.when_some(self.keybinding, |this, keybinding| { - this.child(keybinding.size(rems_from_px(10.))) - }) - .when_some(self.on_click, |this, on_click| { - this.on_click(move |event, window, cx| on_click(event, window, cx)) - }) - } -} diff --git a/crates/agent_ui/src/ui/preview/usage_callouts.rs b/crates/agent_ui/src/ui/preview/usage_callouts.rs index eef878a9d1b9cac72cb13cde0c8fbd92c1519afc..e31af9f49aca781e735a96384dbaddbd3b446ef7 100644 --- a/crates/agent_ui/src/ui/preview/usage_callouts.rs +++ b/crates/agent_ui/src/ui/preview/usage_callouts.rs @@ -1,5 +1,5 @@ use client::{ModelRequestUsage, RequestUsage, zed_urls}; -use cloud_llm_client::{Plan, UsageLimit}; +use cloud_llm_client::{Plan, PlanV1, PlanV2, UsageLimit}; use component::{empty_example, example_group_with_title, single_example}; use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; use ui::{Callout, prelude::*}; @@ -38,20 +38,20 @@ impl RenderOnce for UsageCallout { let (title, message, button_text, url) = if is_limit_reached { match self.plan { - Plan::ZedFree => ( + Plan::V1(PlanV1::ZedFree) | Plan::V2(PlanV2::ZedFree) => ( "Out of free prompts", "Upgrade to continue, wait for the next reset, or switch to API key." 
.to_string(), "Upgrade", zed_urls::account_url(cx), ), - Plan::ZedProTrial => ( + Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial) => ( "Out of trial prompts", "Upgrade to Zed Pro to continue, or switch to API key.".to_string(), "Upgrade", zed_urls::account_url(cx), ), - Plan::ZedPro => ( + Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro) => ( "Out of included prompts", "Enable usage-based billing to continue.".to_string(), "Manage", @@ -60,7 +60,7 @@ impl RenderOnce for UsageCallout { } } else { match self.plan { - Plan::ZedFree => ( + Plan::V1(PlanV1::ZedFree) => ( "Reaching free plan limit soon", format!( "{remaining} remaining - Upgrade to increase limit, or switch providers", @@ -68,7 +68,7 @@ impl RenderOnce for UsageCallout { "Upgrade", zed_urls::account_url(cx), ), - Plan::ZedProTrial => ( + Plan::V1(PlanV1::ZedProTrial) => ( "Reaching trial limit soon", format!( "{remaining} remaining - Upgrade to increase limit, or switch providers", @@ -76,35 +76,28 @@ impl RenderOnce for UsageCallout { "Upgrade", zed_urls::account_url(cx), ), - _ => return div().into_any_element(), + Plan::V1(PlanV1::ZedPro) | Plan::V2(_) => return div().into_any_element(), } }; - let icon = if is_limit_reached { - Icon::new(IconName::Close) - .color(Color::Error) - .size(IconSize::XSmall) + let (icon, severity) = if is_limit_reached { + (IconName::Close, Severity::Error) } else { - Icon::new(IconName::Warning) - .color(Color::Warning) - .size(IconSize::XSmall) + (IconName::Warning, Severity::Warning) }; - div() - .border_t_1() - .border_color(cx.theme().colors().border) - .child( - Callout::new() - .icon(icon) - .title(title) - .description(message) - .primary_action( - Button::new("upgrade", button_text) - .label_size(LabelSize::Small) - .on_click(move |_, _, cx| { - cx.open_url(&url); - }), - ), + Callout::new() + .icon(icon) + .severity(severity) + .icon(icon) + .title(title) + .description(message) + .actions_slot( + Button::new("upgrade", button_text) + 
.label_size(LabelSize::Small) + .on_click(move |_, _, cx| { + cx.open_url(&url); + }), ) .into_any_element() } @@ -126,7 +119,7 @@ impl Component for UsageCallout { single_example( "Approaching limit (90%)", UsageCallout::new( - Plan::ZedFree, + Plan::V1(PlanV1::ZedFree), ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(50), amount: 45, // 90% of limit @@ -137,7 +130,7 @@ impl Component for UsageCallout { single_example( "Limit reached (100%)", UsageCallout::new( - Plan::ZedFree, + Plan::V1(PlanV1::ZedFree), ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(50), amount: 50, // 100% of limit @@ -154,7 +147,7 @@ impl Component for UsageCallout { single_example( "Approaching limit (90%)", UsageCallout::new( - Plan::ZedProTrial, + Plan::V1(PlanV1::ZedProTrial), ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(150), amount: 135, // 90% of limit @@ -165,7 +158,7 @@ impl Component for UsageCallout { single_example( "Limit reached (100%)", UsageCallout::new( - Plan::ZedProTrial, + Plan::V1(PlanV1::ZedProTrial), ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(150), amount: 150, // 100% of limit @@ -182,7 +175,7 @@ impl Component for UsageCallout { single_example( "Limit reached (100%)", UsageCallout::new( - Plan::ZedPro, + Plan::V1(PlanV1::ZedPro), ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(500), amount: 500, // 100% of limit diff --git a/crates/agent_ui/src/ui/unavailable_editing_tooltip.rs b/crates/agent_ui/src/ui/unavailable_editing_tooltip.rs new file mode 100644 index 0000000000000000000000000000000000000000..78d4c64e0acc7bff86516657f76007e78a54d304 --- /dev/null +++ b/crates/agent_ui/src/ui/unavailable_editing_tooltip.rs @@ -0,0 +1,29 @@ +use gpui::{Context, IntoElement, Render, Window}; +use ui::{prelude::*, tooltip_container}; + +pub struct UnavailableEditingTooltip { + agent_name: SharedString, +} + +impl UnavailableEditingTooltip { + pub fn new(agent_name: SharedString) -> Self { + Self { agent_name } + 
} +} + +impl Render for UnavailableEditingTooltip { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + tooltip_container(window, cx, |this, _, _| { + this.child(Label::new("Unavailable Editing")).child( + div().max_w_64().child( + Label::new(format!( + "Editing previous messages is not available for {} yet.", + self.agent_name + )) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + }) + } +} diff --git a/crates/ai_onboarding/Cargo.toml b/crates/ai_onboarding/Cargo.toml index 95a45b1a6fbe103f02532d33c21af707f2f51d45..cf3e6e9cd66eff0ce412436d4dc1d2b4b01c0041 100644 --- a/crates/ai_onboarding/Cargo.toml +++ b/crates/ai_onboarding/Cargo.toml @@ -18,6 +18,7 @@ default = [] client.workspace = true cloud_llm_client.workspace = true component.workspace = true +feature_flags.workspace = true gpui.workspace = true language_model.workspace = true serde.workspace = true diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs index b55ad4c89549a8843fe2d8273da60236400cb565..fadc4222ae44f3dbad862fd9479b89321dbd3016 100644 --- a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -11,7 +11,7 @@ impl ApiKeysWithProviders { cx.subscribe( &LanguageModelRegistry::global(cx), |this: &mut Self, _registry, event: &language_model::Event, cx| match event { - language_model::Event::ProviderStateChanged + language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { this.configured_providers = Self::compute_configured_providers(cx) @@ -33,7 +33,7 @@ impl ApiKeysWithProviders { .filter(|provider| { provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID }) - .map(|provider| (provider.icon(), provider.name().0.clone())) + .map(|provider| (provider.icon(), provider.name().0)) .collect() } } diff --git 
a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs index f1629eeff81ef51bf2ff823eef0db64c1585a669..3c8ffc1663e0660829698b5449a006de5b3c6009 100644 --- a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use client::{Client, UserStore}; -use cloud_llm_client::Plan; +use cloud_llm_client::{Plan, PlanV1, PlanV2}; use gpui::{Entity, IntoElement, ParentElement}; use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; use ui::prelude::*; @@ -25,7 +25,7 @@ impl AgentPanelOnboarding { cx.subscribe( &LanguageModelRegistry::global(cx), |this: &mut Self, _registry, event: &language_model::Event, cx| match event { - language_model::Event::ProviderStateChanged + language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { this.configured_providers = Self::compute_available_providers(cx) @@ -50,15 +50,22 @@ impl AgentPanelOnboarding { .filter(|provider| { provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID }) - .map(|provider| (provider.icon(), provider.name().0.clone())) + .map(|provider| (provider.icon(), provider.name().0)) .collect() } } impl Render for AgentPanelOnboarding { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let enrolled_in_trial = self.user_store.read(cx).plan() == Some(Plan::ZedProTrial); - let is_pro_user = self.user_store.read(cx).plan() == Some(Plan::ZedPro); + let enrolled_in_trial = self.user_store.read(cx).plan().is_some_and(|plan| { + matches!( + plan, + Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial) + ) + }); + let is_pro_user = self.user_store.read(cx).plan().is_some_and(|plan| { + matches!(plan, Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro)) + }); AgentPanelOnboardingCard::new() .child( @@ -74,7 +81,7 @@ 
impl Render for AgentPanelOnboarding { }), ) .map(|this| { - if enrolled_in_trial || is_pro_user || self.configured_providers.len() >= 1 { + if enrolled_in_trial || is_pro_user || !self.configured_providers.is_empty() { this } else { this.child(ApiKeysWithoutProviders::new()) diff --git a/crates/ai_onboarding/src/ai_onboarding.rs b/crates/ai_onboarding/src/ai_onboarding.rs index 75177d4bd2bf22b203cf9f50134bb821438a433f..e131a60b12c2b330ff3a6c099713fa4a4a083358 100644 --- a/crates/ai_onboarding/src/ai_onboarding.rs +++ b/crates/ai_onboarding/src/ai_onboarding.rs @@ -10,7 +10,7 @@ pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProvider pub use agent_panel_onboarding_card::AgentPanelOnboardingCard; pub use agent_panel_onboarding_content::AgentPanelOnboarding; pub use ai_upsell_card::AiUpsellCard; -use cloud_llm_client::Plan; +use cloud_llm_client::{Plan, PlanV1, PlanV2}; pub use edit_prediction_onboarding_content::EditPredictionOnboarding; pub use plan_definitions::PlanDefinitions; pub use young_account_banner::YoungAccountBanner; @@ -18,8 +18,9 @@ pub use young_account_banner::YoungAccountBanner; use std::sync::Arc; use client::{Client, UserStore, zed_urls}; +use feature_flags::{BillingV2FeatureFlag, FeatureFlagAppExt as _}; use gpui::{AnyElement, Entity, IntoElement, ParentElement}; -use ui::{Divider, RegisterComponent, TintColor, Tooltip, prelude::*}; +use ui::{Divider, RegisterComponent, Tooltip, prelude::*}; #[derive(PartialEq)] pub enum SignInStatus { @@ -43,12 +44,10 @@ impl From for SignInStatus { #[derive(RegisterComponent, IntoElement)] pub struct ZedAiOnboarding { pub sign_in_status: SignInStatus, - pub has_accepted_terms_of_service: bool, pub plan: Option, pub account_too_young: bool, pub continue_with_zed_ai: Arc, pub sign_in: Arc, - pub accept_terms_of_service: Arc, pub dismiss_onboarding: Option>, } @@ -64,17 +63,9 @@ impl ZedAiOnboarding { Self { sign_in_status: status.into(), - has_accepted_terms_of_service: 
store.has_accepted_terms_of_service(), plan: store.plan(), account_too_young: store.account_too_young(), continue_with_zed_ai, - accept_terms_of_service: Arc::new({ - let store = user_store.clone(); - move |_window, cx| { - let task = store.update(cx, |store, cx| store.accept_terms_of_service(cx)); - task.detach_and_log_err(cx); - } - }), sign_in: Arc::new(move |_window, cx| { cx.spawn({ let client = client.clone(); @@ -94,45 +85,8 @@ impl ZedAiOnboarding { self } - fn render_accept_terms_of_service(&self) -> AnyElement { - v_flex() - .gap_1() - .w_full() - .child(Headline::new("Accept Terms of Service")) - .child( - Label::new("We don’t sell your data, track you across the web, or compromise your privacy.") - .color(Color::Muted) - .mb_2(), - ) - .child( - Button::new("terms_of_service", "Review Terms of Service") - .full_width() - .style(ButtonStyle::Outlined) - .icon(IconName::ArrowUpRight) - .icon_color(Color::Muted) - .icon_size(IconSize::Small) - .on_click(move |_, _window, cx| { - telemetry::event!("Review Terms of Service Clicked"); - cx.open_url(&zed_urls::terms_of_service(cx)) - }), - ) - .child( - Button::new("accept_terms", "Accept") - .full_width() - .style(ButtonStyle::Tinted(TintColor::Accent)) - .on_click({ - let callback = self.accept_terms_of_service.clone(); - move |_, window, cx| { - telemetry::event!("Terms of Service Accepted"); - (callback)(window, cx)} - }), - ) - .into_any_element() - } - - fn render_sign_in_disclaimer(&self, _cx: &mut App) -> AnyElement { + fn render_sign_in_disclaimer(&self, cx: &mut App) -> AnyElement { let signing_in = matches!(self.sign_in_status, SignInStatus::SigningIn); - let plan_definitions = PlanDefinitions; v_flex() .gap_1() @@ -142,7 +96,7 @@ impl ZedAiOnboarding { .color(Color::Muted) .mb_2(), ) - .child(plan_definitions.pro_plan(false)) + .child(PlanDefinitions.pro_plan(cx.has_flag::(), false)) .child( Button::new("sign_in", "Try Zed Pro for Free") .disabled(signing_in) @@ -159,17 +113,14 @@ impl 
ZedAiOnboarding { .into_any_element() } - fn render_free_plan_state(&self, cx: &mut App) -> AnyElement { - let young_account_banner = YoungAccountBanner; - let plan_definitions = PlanDefinitions; - + fn render_free_plan_state(&self, is_v2: bool, cx: &mut App) -> AnyElement { if self.account_too_young { v_flex() .relative() .max_w_full() .gap_1() .child(Headline::new("Welcome to Zed AI")) - .child(young_account_banner) + .child(YoungAccountBanner) .child( v_flex() .mt_2() @@ -185,7 +136,7 @@ impl ZedAiOnboarding { ) .child(Divider::horizontal()), ) - .child(plan_definitions.pro_plan(true)) + .child(PlanDefinitions.pro_plan(is_v2, true)) .child( Button::new("pro", "Get Started") .full_width() @@ -228,7 +179,7 @@ impl ZedAiOnboarding { ) .child(Divider::horizontal()), ) - .child(plan_definitions.free_plan()), + .child(PlanDefinitions.free_plan(is_v2)), ) .when_some( self.dismiss_onboarding.as_ref(), @@ -266,7 +217,7 @@ impl ZedAiOnboarding { ) .child(Divider::horizontal()), ) - .child(plan_definitions.pro_trial(true)) + .child(PlanDefinitions.pro_trial(is_v2, true)) .child( Button::new("pro", "Start Free Trial") .full_width() @@ -284,9 +235,7 @@ impl ZedAiOnboarding { } } - fn render_trial_state(&self, _cx: &mut App) -> AnyElement { - let plan_definitions = PlanDefinitions; - + fn render_trial_state(&self, is_v2: bool, _cx: &mut App) -> AnyElement { v_flex() .relative() .gap_1() @@ -296,7 +245,7 @@ impl ZedAiOnboarding { .color(Color::Muted) .mb_2(), ) - .child(plan_definitions.pro_trial(false)) + .child(PlanDefinitions.pro_trial(is_v2, false)) .when_some( self.dismiss_onboarding.as_ref(), |this, dismiss_callback| { @@ -320,9 +269,7 @@ impl ZedAiOnboarding { .into_any_element() } - fn render_pro_plan_state(&self, _cx: &mut App) -> AnyElement { - let plan_definitions = PlanDefinitions; - + fn render_pro_plan_state(&self, is_v2: bool, _cx: &mut App) -> AnyElement { v_flex() .gap_1() .child(Headline::new("Welcome to Zed Pro")) @@ -331,18 +278,26 @@ impl ZedAiOnboarding { 
.color(Color::Muted) .mb_2(), ) - .child(plan_definitions.pro_plan(false)) - .child( - Button::new("pro", "Continue with Zed Pro") - .full_width() - .style(ButtonStyle::Outlined) - .on_click({ - let callback = self.continue_with_zed_ai.clone(); - move |_, window, cx| { - telemetry::event!("Banner Dismissed", source = "AI Onboarding"); - callback(window, cx) - } - }), + .child(PlanDefinitions.pro_plan(is_v2, false)) + .when_some( + self.dismiss_onboarding.as_ref(), + |this, dismiss_callback| { + let callback = dismiss_callback.clone(); + this.child( + h_flex().absolute().top_0().right_0().child( + IconButton::new("dismiss_onboarding", IconName::Close) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Dismiss")) + .on_click(move |_, window, cx| { + telemetry::event!( + "Banner Dismissed", + source = "AI Onboarding", + ); + callback(window, cx) + }), + ), + ) + }, ) .into_any_element() } @@ -351,14 +306,17 @@ impl ZedAiOnboarding { impl RenderOnce for ZedAiOnboarding { fn render(self, _window: &mut ui::Window, cx: &mut App) -> impl IntoElement { if matches!(self.sign_in_status, SignInStatus::SignedIn) { - if self.has_accepted_terms_of_service { - match self.plan { - None | Some(Plan::ZedFree) => self.render_free_plan_state(cx), - Some(Plan::ZedProTrial) => self.render_trial_state(cx), - Some(Plan::ZedPro) => self.render_pro_plan_state(cx), + match self.plan { + None => self.render_free_plan_state(cx.has_flag::(), cx), + Some(plan @ (Plan::V1(PlanV1::ZedFree) | Plan::V2(PlanV2::ZedFree))) => { + self.render_free_plan_state(plan.is_v2(), cx) + } + Some(plan @ (Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial))) => { + self.render_trial_state(plan.is_v2(), cx) + } + Some(plan @ (Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro))) => { + self.render_pro_plan_state(plan.is_v2(), cx) } - } else { - self.render_accept_terms_of_service() } } else { self.render_sign_in_disclaimer(cx) @@ -382,18 +340,15 @@ impl Component for ZedAiOnboarding { fn 
preview(_window: &mut Window, _cx: &mut App) -> Option { fn onboarding( sign_in_status: SignInStatus, - has_accepted_terms_of_service: bool, plan: Option, account_too_young: bool, ) -> AnyElement { ZedAiOnboarding { sign_in_status, - has_accepted_terms_of_service, plan, account_too_young, continue_with_zed_ai: Arc::new(|_, _| {}), sign_in: Arc::new(|_, _| {}), - accept_terms_of_service: Arc::new(|_, _| {}), dismiss_onboarding: None, } .into_any_element() @@ -407,27 +362,35 @@ impl Component for ZedAiOnboarding { .children(vec![ single_example( "Not Signed-in", - onboarding(SignInStatus::SignedOut, false, None, false), - ), - single_example( - "Not Accepted ToS", - onboarding(SignInStatus::SignedIn, false, None, false), + onboarding(SignInStatus::SignedOut, None, false), ), single_example( "Young Account", - onboarding(SignInStatus::SignedIn, true, None, true), + onboarding(SignInStatus::SignedIn, None, true), ), single_example( "Free Plan", - onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedFree), false), + onboarding( + SignInStatus::SignedIn, + Some(Plan::V1(PlanV1::ZedFree)), + false, + ), ), single_example( "Pro Trial", - onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedProTrial), false), + onboarding( + SignInStatus::SignedIn, + Some(Plan::V1(PlanV1::ZedProTrial)), + false, + ), ), single_example( "Pro Plan", - onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedPro), false), + onboarding( + SignInStatus::SignedIn, + Some(Plan::V1(PlanV1::ZedPro)), + false, + ), ), ]) .into_any_element(), diff --git a/crates/ai_onboarding/src/ai_upsell_card.rs b/crates/ai_onboarding/src/ai_upsell_card.rs index e9639ca075d1190ef6ab13f1bb01dd7333010d86..51758dd9ac123b309450018bd254a2aae31d68af 100644 --- a/crates/ai_onboarding/src/ai_upsell_card.rs +++ b/crates/ai_onboarding/src/ai_upsell_card.rs @@ -1,22 +1,20 @@ -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; use client::{Client, UserStore, zed_urls}; -use cloud_llm_client::Plan; -use gpui::{ - 
Animation, AnimationExt, AnyElement, App, Entity, IntoElement, RenderOnce, Transformation, - Window, percentage, -}; -use ui::{Divider, Vector, VectorName, prelude::*}; +use cloud_llm_client::{Plan, PlanV1, PlanV2}; +use feature_flags::{BillingV2FeatureFlag, FeatureFlagAppExt}; +use gpui::{AnyElement, App, Entity, IntoElement, RenderOnce, Window}; +use ui::{CommonAnimationExt, Divider, Vector, VectorName, prelude::*}; use crate::{SignInStatus, YoungAccountBanner, plan_definitions::PlanDefinitions}; #[derive(IntoElement, RegisterComponent)] pub struct AiUpsellCard { - pub sign_in_status: SignInStatus, - pub sign_in: Arc, - pub account_too_young: bool, - pub user_plan: Option, - pub tab_index: Option, + sign_in_status: SignInStatus, + sign_in: Arc, + account_too_young: bool, + user_plan: Option, + tab_index: Option, } impl AiUpsellCard { @@ -43,12 +41,18 @@ impl AiUpsellCard { tab_index: None, } } + + pub fn tab_index(mut self, tab_index: Option) -> Self { + self.tab_index = tab_index; + self + } } impl RenderOnce for AiUpsellCard { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - let plan_definitions = PlanDefinitions; - let young_account_banner = YoungAccountBanner; + let is_v2_plan = self + .user_plan + .map_or(cx.has_flag::(), |plan| plan.is_v2()); let pro_section = v_flex() .flex_grow() @@ -65,7 +69,7 @@ impl RenderOnce for AiUpsellCard { ) .child(Divider::horizontal()), ) - .child(plan_definitions.pro_plan(false)); + .child(PlanDefinitions.pro_plan(is_v2_plan, false)); let free_section = v_flex() .flex_grow() @@ -82,12 +86,18 @@ impl RenderOnce for AiUpsellCard { ) .child(Divider::horizontal()), ) - .child(plan_definitions.free_plan()); + .child(PlanDefinitions.free_plan(is_v2_plan)); - let grid_bg = h_flex().absolute().inset_0().w_full().h(px(240.)).child( - Vector::new(VectorName::Grid, rems_from_px(500.), rems_from_px(240.)) - .color(Color::Custom(cx.theme().colors().border.opacity(0.05))), - ); + let grid_bg = h_flex() + 
.absolute() + .inset_0() + .w_full() + .h(px(240.)) + .bg(gpui::pattern_slash( + cx.theme().colors().border.opacity(0.1), + 2., + 25., + )); let gradient_bg = div() .absolute() @@ -142,11 +152,7 @@ impl RenderOnce for AiUpsellCard { rems_from_px(72.), ) .color(Color::Custom(cx.theme().colors().text_accent.alpha(0.3))) - .with_animation( - "loading_stamp", - Animation::new(Duration::from_secs(10)).repeat(), - |this, delta| this.transform(Transformation::rotate(percentage(delta))), - ), + .with_rotate_animation(10), ); let pro_trial_stamp = div() @@ -165,11 +171,11 @@ impl RenderOnce for AiUpsellCard { match self.sign_in_status { SignInStatus::SignedIn => match self.user_plan { - None | Some(Plan::ZedFree) => card + None | Some(Plan::V1(PlanV1::ZedFree) | Plan::V2(PlanV2::ZedFree)) => card .child(Label::new("Try Zed AI").size(LabelSize::Large)) .map(|this| { if self.account_too_young { - this.child(young_account_banner).child( + this.child(YoungAccountBanner).child( v_flex() .mt_2() .gap_1() @@ -184,7 +190,7 @@ impl RenderOnce for AiUpsellCard { ) .child(Divider::horizontal()), ) - .child(plan_definitions.pro_plan(true)) + .child(PlanDefinitions.pro_plan(is_v2_plan, true)) .child( Button::new("pro", "Get Started") .full_width() @@ -231,16 +237,17 @@ impl RenderOnce for AiUpsellCard { ) } }), - Some(Plan::ZedProTrial) => card - .child(pro_trial_stamp) - .child(Label::new("You're in the Zed Pro Trial").size(LabelSize::Large)) - .child( - Label::new("Here's what you get for the next 14 days:") - .color(Color::Muted) - .mb_2(), - ) - .child(plan_definitions.pro_trial(false)), - Some(Plan::ZedPro) => card + Some(plan @ (Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial))) => { + card.child(pro_trial_stamp) + .child(Label::new("You're in the Zed Pro Trial").size(LabelSize::Large)) + .child( + Label::new("Here's what you get for the next 14 days:") + .color(Color::Muted) + .mb_2(), + ) + .child(PlanDefinitions.pro_trial(plan.is_v2(), false)) + } + Some(plan @ 
(Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro))) => card .child(certified_user_stamp) .child(Label::new("You're in the Zed Pro plan").size(LabelSize::Large)) .child( @@ -248,7 +255,7 @@ impl RenderOnce for AiUpsellCard { .color(Color::Muted) .mb_2(), ) - .child(plan_definitions.pro_plan(false)), + .child(PlanDefinitions.pro_plan(plan.is_v2(), false)), }, // Signed Out State _ => card @@ -320,7 +327,7 @@ impl Component for AiUpsellCard { sign_in_status: SignInStatus::SignedIn, sign_in: Arc::new(|_, _| {}), account_too_young: false, - user_plan: Some(Plan::ZedFree), + user_plan: Some(Plan::V1(PlanV1::ZedFree)), tab_index: Some(1), } .into_any_element(), @@ -331,7 +338,7 @@ impl Component for AiUpsellCard { sign_in_status: SignInStatus::SignedIn, sign_in: Arc::new(|_, _| {}), account_too_young: true, - user_plan: Some(Plan::ZedFree), + user_plan: Some(Plan::V1(PlanV1::ZedFree)), tab_index: Some(1), } .into_any_element(), @@ -342,7 +349,7 @@ impl Component for AiUpsellCard { sign_in_status: SignInStatus::SignedIn, sign_in: Arc::new(|_, _| {}), account_too_young: false, - user_plan: Some(Plan::ZedProTrial), + user_plan: Some(Plan::V1(PlanV1::ZedProTrial)), tab_index: Some(1), } .into_any_element(), @@ -353,7 +360,7 @@ impl Component for AiUpsellCard { sign_in_status: SignInStatus::SignedIn, sign_in: Arc::new(|_, _| {}), account_too_young: false, - user_plan: Some(Plan::ZedPro), + user_plan: Some(Plan::V1(PlanV1::ZedPro)), tab_index: Some(1), } .into_any_element(), diff --git a/crates/ai_onboarding/src/edit_prediction_onboarding_content.rs b/crates/ai_onboarding/src/edit_prediction_onboarding_content.rs index e883d8da8ce01bfea3f08676666c308a90f6d650..571f0f8e450ac2974cea2f4b2a7085069bc45c7c 100644 --- a/crates/ai_onboarding/src/edit_prediction_onboarding_content.rs +++ b/crates/ai_onboarding/src/edit_prediction_onboarding_content.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use client::{Client, UserStore}; +use cloud_llm_client::{Plan, PlanV1, PlanV2}; use gpui::{Entity, 
IntoElement, ParentElement}; use ui::prelude::*; @@ -35,6 +36,10 @@ impl EditPredictionOnboarding { impl Render for EditPredictionOnboarding { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let is_free_plan = self.user_store.read(cx).plan().is_some_and(|plan| { + matches!(plan, Plan::V1(PlanV1::ZedFree) | Plan::V2(PlanV2::ZedFree)) + }); + let github_copilot = v_flex() .gap_1() .child(Label::new(if self.copilot_is_configured { @@ -67,7 +72,8 @@ impl Render for EditPredictionOnboarding { self.continue_with_zed_ai.clone(), cx, )) - .child(ui::Divider::horizontal()) - .child(github_copilot) + .when(is_free_plan, |this| { + this.child(ui::Divider::horizontal()).child(github_copilot) + }) } } diff --git a/crates/ai_onboarding/src/plan_definitions.rs b/crates/ai_onboarding/src/plan_definitions.rs index 8d66f6c3563c482b2356e081b5786219f5bf1de3..dce67d421006ce918018923b86dbe22012efef01 100644 --- a/crates/ai_onboarding/src/plan_definitions.rs +++ b/crates/ai_onboarding/src/plan_definitions.rs @@ -7,13 +7,13 @@ pub struct PlanDefinitions; impl PlanDefinitions { pub const AI_DESCRIPTION: &'static str = "Zed offers a complete agentic experience, with robust editing and reviewing features to collaborate with AI."; - pub fn free_plan(&self) -> impl IntoElement { + pub fn free_plan(&self, _is_v2: bool) -> impl IntoElement { List::new() .child(ListBulletItem::new("50 prompts with Claude models")) .child(ListBulletItem::new("2,000 accepted edit predictions")) } - pub fn pro_trial(&self, period: bool) -> impl IntoElement { + pub fn pro_trial(&self, _is_v2: bool, period: bool) -> impl IntoElement { List::new() .child(ListBulletItem::new("150 prompts with Claude models")) .child(ListBulletItem::new( @@ -26,7 +26,7 @@ impl PlanDefinitions { }) } - pub fn pro_plan(&self, price: bool) -> impl IntoElement { + pub fn pro_plan(&self, _is_v2: bool, price: bool) -> impl IntoElement { List::new() .child(ListBulletItem::new("500 prompts with Claude 
models")) .child(ListBulletItem::new( diff --git a/crates/ai_onboarding/src/young_account_banner.rs b/crates/ai_onboarding/src/young_account_banner.rs index 54f563e4aac8ca71fff16199cd6c2e8f81ad5376..ae13b9556885c1552f7e90935f844347cd76a778 100644 --- a/crates/ai_onboarding/src/young_account_banner.rs +++ b/crates/ai_onboarding/src/young_account_banner.rs @@ -6,7 +6,7 @@ pub struct YoungAccountBanner; impl RenderOnce for YoungAccountBanner { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - const YOUNG_ACCOUNT_DISCLAIMER: &str = "To prevent abuse of our service, we cannot offer plans to GitHub accounts created fewer than 30 days ago. To request an exception, reach out to billing-support@zed.dev."; + const YOUNG_ACCOUNT_DISCLAIMER: &str = "To prevent abuse of our service, GitHub accounts created fewer than 30 days ago are not eligible for free plan usage or Pro plan free trial. To request an exception, reach out to billing-support@zed.dev."; let label = div() .w_full() @@ -17,6 +17,6 @@ impl RenderOnce for YoungAccountBanner { div() .max_w_full() .my_1() - .child(Banner::new().severity(ui::Severity::Warning).child(label)) + .child(Banner::new().severity(Severity::Warning).child(label)) } } diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 3ff1666755d439cf52a14ea635a06a7c3414d9f6..7fd0fb4bc5abd983c57507522c2a37dffcbfa258 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -363,17 +363,15 @@ pub async fn complete( api_url: &str, api_key: &str, request: Request, + beta_headers: String, ) -> Result { let uri = format!("{api_url}/v1/messages"); - let beta_headers = Model::from_id(&request.model) - .map(|model| model.beta_headers()) - .unwrap_or_else(|_| Model::DEFAULT_BETA_HEADERS.join(",")); let request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) .header("Anthropic-Version", "2023-06-01") .header("Anthropic-Beta", beta_headers) - 
.header("X-Api-Key", api_key) + .header("X-Api-Key", api_key.trim()) .header("Content-Type", "application/json"); let serialized_request = @@ -409,8 +407,9 @@ pub async fn stream_completion( api_url: &str, api_key: &str, request: Request, + beta_headers: String, ) -> Result>, AnthropicError> { - stream_completion_with_rate_limit_info(client, api_url, api_key, request) + stream_completion_with_rate_limit_info(client, api_url, api_key, request, beta_headers) .await .map(|output| output.0) } @@ -506,6 +505,7 @@ pub async fn stream_completion_with_rate_limit_info( api_url: &str, api_key: &str, request: Request, + beta_headers: String, ) -> Result< ( BoxStream<'static, Result>, @@ -518,15 +518,13 @@ pub async fn stream_completion_with_rate_limit_info( stream: true, }; let uri = format!("{api_url}/v1/messages"); - let beta_headers = Model::from_id(&request.base.model) - .map(|model| model.beta_headers()) - .unwrap_or_else(|_| Model::DEFAULT_BETA_HEADERS.join(",")); + let request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) .header("Anthropic-Version", "2023-06-01") .header("Anthropic-Beta", beta_headers) - .header("X-Api-Key", api_key) + .header("X-Api-Key", api_key.trim()) .header("Content-Type", "application/json"); let serialized_request = serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?; diff --git a/crates/askpass/src/askpass.rs b/crates/askpass/src/askpass.rs index f085a2be72d04d7c1d16f855230011639853ddf2..9e84a9fed03c8a620c7cb33cc76ef22c000c3fa6 100644 --- a/crates/askpass/src/askpass.rs +++ b/crates/askpass/src/askpass.rs @@ -177,11 +177,11 @@ impl AskPassSession { _ = askpass_opened_rx.fuse() => { // Note: this await can only resolve after we are dropped. 
askpass_kill_master_rx.await.ok(); - return AskPassResult::CancelledByUser + AskPassResult::CancelledByUser } _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => { - return AskPassResult::Timedout + AskPassResult::Timedout } } } @@ -215,7 +215,7 @@ pub fn main(socket: &str) { } #[cfg(target_os = "windows")] - while buffer.last().map_or(false, |&b| b == b'\n' || b == b'\r') { + while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') { buffer.pop(); } if buffer.last() != Some(&b'\0') { diff --git a/crates/assistant_context/Cargo.toml b/crates/assistant_context/Cargo.toml index 45c0072418782909829ba3186138f0c6a9456654..3e2761a84674c6c4201165edf856b675843315d9 100644 --- a/crates/assistant_context/Cargo.toml +++ b/crates/assistant_context/Cargo.toml @@ -50,8 +50,9 @@ text.workspace = true ui.workspace = true util.workspace = true uuid.workspace = true -workspace-hack.workspace = true workspace.workspace = true +workspace-hack.workspace = true +zed_env_vars.workspace = true [dev-dependencies] indoc.workspace = true diff --git a/crates/assistant_context/src/assistant_context.rs b/crates/assistant_context/src/assistant_context.rs index 557f9592e4d12e86c4e73d1bc742dfa74535d66c..12eda0954a2e1cca9ddc7df9816b8f5a37d0ce10 100644 --- a/crates/assistant_context/src/assistant_context.rs +++ b/crates/assistant_context/src/assistant_context.rs @@ -590,17 +590,16 @@ impl From<&Message> for MessageMetadata { impl MessageMetadata { pub fn is_cache_valid(&self, buffer: &BufferSnapshot, range: &Range) -> bool { - let result = match &self.cache { + match &self.cache { Some(MessageCacheMetadata { cached_at, .. }) => !buffer.has_edits_since_in_range( - &cached_at, + cached_at, Range { start: buffer.anchor_at(range.start, Bias::Right), end: buffer.anchor_at(range.end, Bias::Left), }, ), _ => false, - }; - result + } } } @@ -1023,9 +1022,11 @@ impl AssistantContext { summary: new_summary, .. 
} => { - if self.summary.timestamp().map_or(true, |current_timestamp| { - new_summary.timestamp > current_timestamp - }) { + if self + .summary + .timestamp() + .is_none_or(|current_timestamp| new_summary.timestamp > current_timestamp) + { self.summary = ContextSummary::Content(new_summary); summary_generated = true; } @@ -1076,20 +1077,20 @@ impl AssistantContext { timestamp, .. } => { - if let Some(slash_command) = self.invoked_slash_commands.get_mut(&id) { - if timestamp > slash_command.timestamp { - slash_command.timestamp = timestamp; - match error_message { - Some(message) => { - slash_command.status = - InvokedSlashCommandStatus::Error(message.into()); - } - None => { - slash_command.status = InvokedSlashCommandStatus::Finished; - } + if let Some(slash_command) = self.invoked_slash_commands.get_mut(&id) + && timestamp > slash_command.timestamp + { + slash_command.timestamp = timestamp; + match error_message { + Some(message) => { + slash_command.status = + InvokedSlashCommandStatus::Error(message.into()); + } + None => { + slash_command.status = InvokedSlashCommandStatus::Finished; } - cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id: id }); } + cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id: id }); } } ContextOperation::BufferOperation(_) => unreachable!(), @@ -1339,7 +1340,7 @@ impl AssistantContext { let is_invalid = self .messages_metadata .get(&message_id) - .map_or(true, |metadata| { + .is_none_or(|metadata| { !metadata.is_cache_valid(&buffer, &message.offset_range) || *encountered_invalid }); @@ -1368,10 +1369,10 @@ impl AssistantContext { continue; } - if let Some(last_anchor) = last_anchor { - if message.id == last_anchor { - hit_last_anchor = true; - } + if let Some(last_anchor) = last_anchor + && message.id == last_anchor + { + hit_last_anchor = true; } new_anchor_needs_caching = new_anchor_needs_caching @@ -1406,14 +1407,14 @@ impl AssistantContext { if !self.pending_completions.is_empty() { return; } - if let 
Some(cache_configuration) = cache_configuration { - if !cache_configuration.should_speculate { - return; - } + if let Some(cache_configuration) = cache_configuration + && !cache_configuration.should_speculate + { + return; } let request = { - let mut req = self.to_completion_request(Some(&model), cx); + let mut req = self.to_completion_request(Some(model), cx); // Skip the last message because it's likely to change and // therefore would be a waste to cache. req.messages.pop(); @@ -1428,7 +1429,7 @@ impl AssistantContext { let model = Arc::clone(model); self.pending_cache_warming_task = cx.spawn(async move |this, cx| { async move { - match model.stream_completion(request, &cx).await { + match model.stream_completion(request, cx).await { Ok(mut stream) => { stream.next().await; log::info!("Cache warming completed successfully"); @@ -1552,25 +1553,24 @@ impl AssistantContext { }) .map(ToOwned::to_owned) .collect::>(); - if let Some(command) = self.slash_commands.command(name, cx) { - if !command.requires_argument() || !arguments.is_empty() { - let start_ix = offset + command_line.name.start - 1; - let end_ix = offset - + command_line - .arguments - .last() - .map_or(command_line.name.end, |argument| argument.end); - let source_range = - buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix); - let pending_command = ParsedSlashCommand { - name: name.to_string(), - arguments, - source_range, - status: PendingSlashCommandStatus::Idle, - }; - updated.push(pending_command.clone()); - new_commands.push(pending_command); - } + if let Some(command) = self.slash_commands.command(name, cx) + && (!command.requires_argument() || !arguments.is_empty()) + { + let start_ix = offset + command_line.name.start - 1; + let end_ix = offset + + command_line + .arguments + .last() + .map_or(command_line.name.end, |argument| argument.end); + let source_range = buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix); + let pending_command = ParsedSlashCommand { + name: 
name.to_string(), + arguments, + source_range, + status: PendingSlashCommandStatus::Idle, + }; + updated.push(pending_command.clone()); + new_commands.push(pending_command); } } @@ -1661,12 +1661,12 @@ impl AssistantContext { ) -> Range { let buffer = self.buffer.read(cx); let start_ix = match all_annotations - .binary_search_by(|probe| probe.range().end.cmp(&range.start, &buffer)) + .binary_search_by(|probe| probe.range().end.cmp(&range.start, buffer)) { Ok(ix) | Err(ix) => ix, }; let end_ix = match all_annotations - .binary_search_by(|probe| probe.range().start.cmp(&range.end, &buffer)) + .binary_search_by(|probe| probe.range().start.cmp(&range.end, buffer)) { Ok(ix) => ix + 1, Err(ix) => ix, @@ -1799,14 +1799,13 @@ impl AssistantContext { }); let end = this.buffer.read(cx).anchor_before(insert_position); - if run_commands_in_text { - if let Some(invoked_slash_command) = + if run_commands_in_text + && let Some(invoked_slash_command) = this.invoked_slash_commands.get_mut(&command_id) - { - invoked_slash_command - .run_commands_in_ranges - .push(start..end); - } + { + invoked_slash_command + .run_commands_in_ranges + .push(start..end); } } SlashCommandEvent::EndSection => { @@ -1862,7 +1861,7 @@ impl AssistantContext { { let newline_offset = insert_position.saturating_sub(1); if buffer.contains_str_at(newline_offset, "\n") - && last_section_range.map_or(true, |last_section_range| { + && last_section_range.is_none_or(|last_section_range| { !last_section_range .to_offset(buffer) .contains(&newline_offset) @@ -2045,7 +2044,7 @@ impl AssistantContext { let task = cx.spawn({ async move |this, cx| { - let stream = model.stream_completion(request, &cx); + let stream = model.stream_completion(request, cx); let assistant_message_id = assistant_message.id; let mut response_latency = None; let stream_completion = async { @@ -2081,15 +2080,12 @@ impl AssistantContext { match event { LanguageModelCompletionEvent::StatusUpdate(status_update) => { - match status_update { - 
CompletionRequestStatus::UsageUpdated { amount, limit } => { - this.update_model_request_usage( - amount as u32, - limit, - cx, - ); - } - _ => {} + if let CompletionRequestStatus::UsageUpdated { amount, limit } = status_update { + this.update_model_request_usage( + amount as u32, + limit, + cx, + ); } } LanguageModelCompletionEvent::StartMessage { .. } => {} @@ -2286,7 +2282,7 @@ impl AssistantContext { let mut contents = self.contents(cx).peekable(); fn collect_text_content(buffer: &Buffer, range: Range) -> Option { - let text: String = buffer.text_for_range(range.clone()).collect(); + let text: String = buffer.text_for_range(range).collect(); if text.trim().is_empty() { None } else { @@ -2315,10 +2311,7 @@ impl AssistantContext { let mut request_message = LanguageModelRequestMessage { role: message.role, content: Vec::new(), - cache: message - .cache - .as_ref() - .map_or(false, |cache| cache.is_anchor), + cache: message.cache.as_ref().is_some_and(|cache| cache.is_anchor), }; while let Some(content) = contents.peek() { @@ -2708,7 +2701,7 @@ impl AssistantContext { self.summary_task = cx.spawn(async move |this, cx| { let result = async { - let stream = model.model.stream_completion_text(request, &cx); + let stream = model.model.stream_completion_text(request, cx); let mut messages = stream.await?; let mut replaced = !replace_old; @@ -2741,10 +2734,10 @@ impl AssistantContext { } this.read_with(cx, |this, _cx| { - if let Some(summary) = this.summary.content() { - if summary.text.is_empty() { - bail!("Model generated an empty summary"); - } + if let Some(summary) = this.summary.content() + && summary.text.is_empty() + { + bail!("Model generated an empty summary"); } Ok(()) })??; @@ -2799,7 +2792,7 @@ impl AssistantContext { let mut current_message = messages.next(); while let Some(offset) = offsets.next() { // Locate the message that contains the offset. 
- while current_message.as_ref().map_or(false, |message| { + while current_message.as_ref().is_some_and(|message| { !message.offset_range.contains(&offset) && messages.peek().is_some() }) { current_message = messages.next(); @@ -2809,7 +2802,7 @@ impl AssistantContext { }; // Skip offsets that are in the same message. - while offsets.peek().map_or(false, |offset| { + while offsets.peek().is_some_and(|offset| { message.offset_range.contains(offset) || messages.peek().is_none() }) { offsets.next(); @@ -2924,18 +2917,18 @@ impl AssistantContext { fs.create_dir(contexts_dir().as_ref()).await?; // rename before write ensures that only one file exists - if let Some(old_path) = old_path.as_ref() { - if new_path.as_path() != old_path.as_ref() { - fs.rename( - &old_path, - &new_path, - RenameOptions { - overwrite: true, - ignore_if_exists: true, - }, - ) - .await?; - } + if let Some(old_path) = old_path.as_ref() + && new_path.as_path() != old_path.as_ref() + { + fs.rename( + old_path, + &new_path, + RenameOptions { + overwrite: true, + ignore_if_exists: true, + }, + ) + .await?; } // update path before write in case it fails diff --git a/crates/assistant_context/src/assistant_context_tests.rs b/crates/assistant_context/src/assistant_context_tests.rs index efcad8ed9654449c747ee4853c7e7aa689c0568b..8b182685cfeb4e3ae1b9df8c532b8f0c5ad91235 100644 --- a/crates/assistant_context/src/assistant_context_tests.rs +++ b/crates/assistant_context/src/assistant_context_tests.rs @@ -764,7 +764,7 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std let network = Arc::new(Mutex::new(Network::new(rng.clone()))); let mut contexts = Vec::new(); - let num_peers = rng.gen_range(min_peers..=max_peers); + let num_peers = rng.random_range(min_peers..=max_peers); let context_id = ContextId::new(); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); for i in 0..num_peers { @@ -806,10 +806,10 @@ async fn test_random_context_collaboration(cx: &mut 
TestAppContext, mut rng: Std || !network.lock().is_idle() || network.lock().contains_disconnected_peers() { - let context_index = rng.gen_range(0..contexts.len()); + let context_index = rng.random_range(0..contexts.len()); let context = &contexts[context_index]; - match rng.gen_range(0..100) { + match rng.random_range(0..100) { 0..=29 if mutation_count > 0 => { log::info!("Context {}: edit buffer", context_index); context.update(cx, |context, cx| { @@ -874,10 +874,10 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std merge_same_roles: true, })]; - let num_sections = rng.gen_range(0..=3); + let num_sections = rng.random_range(0..=3); let mut section_start = 0; for _ in 0..num_sections { - let mut section_end = rng.gen_range(section_start..=output_text.len()); + let mut section_end = rng.random_range(section_start..=output_text.len()); while !output_text.is_char_boundary(section_end) { section_end += 1; } @@ -924,7 +924,7 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std 75..=84 if mutation_count > 0 => { context.update(cx, |context, cx| { if let Some(message) = context.messages(cx).choose(&mut rng) { - let new_status = match rng.gen_range(0..3) { + let new_status = match rng.random_range(0..3) { 0 => MessageStatus::Done, 1 => MessageStatus::Pending, _ => MessageStatus::Error(SharedString::from("Random error")), @@ -971,7 +971,7 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std network.lock().broadcast(replica_id, ops_to_send); context.update(cx, |context, cx| context.apply_ops(ops_to_receive, cx)); - } else if rng.gen_bool(0.1) && replica_id != 0 { + } else if rng.random_bool(0.1) && replica_id != 0 { log::info!("Context {}: disconnecting", context_index); network.lock().disconnect_peer(replica_id); } else if network.lock().has_unreceived(replica_id) { @@ -1055,7 +1055,7 @@ fn test_mark_cache_anchors(cx: &mut App) { assert_eq!( messages_cache(&context, cx) .iter() - 
.filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor)) + .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor)) .count(), 0, "Empty messages should not have any cache anchors." @@ -1083,7 +1083,7 @@ fn test_mark_cache_anchors(cx: &mut App) { assert_eq!( messages_cache(&context, cx) .iter() - .filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor)) + .filter(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor)) .count(), 0, "Messages should not be marked for cache before going over the token minimum." @@ -1098,7 +1098,7 @@ fn test_mark_cache_anchors(cx: &mut App) { assert_eq!( messages_cache(&context, cx) .iter() - .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor)) + .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor)) .collect::>(), vec![true, true, false], "Last message should not be an anchor on speculative request." @@ -1116,7 +1116,7 @@ fn test_mark_cache_anchors(cx: &mut App) { assert_eq!( messages_cache(&context, cx) .iter() - .map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor)) + .map(|(_, cache)| cache.as_ref().is_some_and(|cache| cache.is_anchor)) .collect::>(), vec![false, true, true, false], "Most recent message should also be cached if not a speculative request." 
@@ -1300,7 +1300,7 @@ fn test_summarize_error( context.assist(cx); }); - simulate_successful_response(&model, cx); + simulate_successful_response(model, cx); context.read_with(cx, |context, _| { assert!(!context.summary().content().unwrap().done); @@ -1321,7 +1321,7 @@ fn test_summarize_error( fn setup_context_editor_with_fake_model( cx: &mut TestAppContext, ) -> (Entity, Arc) { - let registry = Arc::new(LanguageRegistry::test(cx.executor().clone())); + let registry = Arc::new(LanguageRegistry::test(cx.executor())); let fake_provider = Arc::new(FakeLanguageModelProvider::default()); let fake_model = Arc::new(fake_provider.test_model()); @@ -1376,7 +1376,7 @@ fn messages_cache( context .read(cx) .messages(cx) - .map(|message| (message.id, message.cache.clone())) + .map(|message| (message.id, message.cache)) .collect() } @@ -1436,6 +1436,6 @@ impl SlashCommand for FakeSlashCommand { sections: vec![], run_commands_in_text: false, } - .to_event_stream())) + .into_event_stream())) } } diff --git a/crates/assistant_context/src/context_store.rs b/crates/assistant_context/src/context_store.rs index 622d8867a7194924f0a7eacb520fe4e26f29539b..5fac44e31f4cc073af8fe6bbb57f75fc03b27f45 100644 --- a/crates/assistant_context/src/context_store.rs +++ b/crates/assistant_context/src/context_store.rs @@ -24,6 +24,7 @@ use rpc::AnyProtoClient; use std::sync::LazyLock; use std::{cmp::Reverse, ffi::OsStr, mem, path::Path, sync::Arc, time::Duration}; use util::{ResultExt, TryFutureExt}; +use zed_env_vars::ZED_STATELESS; pub(crate) fn init(client: &AnyProtoClient) { client.add_entity_message_handler(ContextStore::handle_advertise_contexts); @@ -320,7 +321,7 @@ impl ContextStore { .client .subscribe_to_entity(remote_id) .log_err() - .map(|subscription| subscription.set_entity(&cx.entity(), &mut cx.to_async())); + .map(|subscription| subscription.set_entity(&cx.entity(), &cx.to_async())); self.advertise_contexts(cx); } else { self.client_subscription = None; @@ -788,8 +789,6 @@ impl 
ContextStore { fn reload(&mut self, cx: &mut Context) -> Task> { let fs = self.fs.clone(); cx.spawn(async move |this, cx| { - pub static ZED_STATELESS: LazyLock = - LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty())); if *ZED_STATELESS { return Ok(()); } @@ -862,7 +861,7 @@ impl ContextStore { ContextServerStatus::Running => { self.load_context_server_slash_commands( server_id.clone(), - context_server_store.clone(), + context_server_store, cx, ); } @@ -894,34 +893,33 @@ impl ContextStore { return; }; - if protocol.capable(context_server::protocol::ServerCapability::Prompts) { - if let Some(response) = protocol + if protocol.capable(context_server::protocol::ServerCapability::Prompts) + && let Some(response) = protocol .request::(()) .await .log_err() - { - let slash_command_ids = response - .prompts - .into_iter() - .filter(assistant_slash_commands::acceptable_prompt) - .map(|prompt| { - log::info!("registering context server command: {:?}", prompt.name); - slash_command_working_set.insert(Arc::new( - assistant_slash_commands::ContextServerSlashCommand::new( - context_server_store.clone(), - server.id(), - prompt, - ), - )) - }) - .collect::>(); - - this.update(cx, |this, _cx| { - this.context_server_slash_command_ids - .insert(server_id.clone(), slash_command_ids); + { + let slash_command_ids = response + .prompts + .into_iter() + .filter(assistant_slash_commands::acceptable_prompt) + .map(|prompt| { + log::info!("registering context server command: {:?}", prompt.name); + slash_command_working_set.insert(Arc::new( + assistant_slash_commands::ContextServerSlashCommand::new( + context_server_store.clone(), + server.id(), + prompt, + ), + )) }) - .log_err(); - } + .collect::>(); + + this.update(cx, |this, _cx| { + this.context_server_slash_command_ids + .insert(server_id.clone(), slash_command_ids); + }) + .log_err(); } }) .detach(); diff --git a/crates/assistant_slash_command/src/assistant_slash_command.rs 
b/crates/assistant_slash_command/src/assistant_slash_command.rs index 828f115bf5ed8cfedf14c67243b4a8048d07ebd0..4b85fa2edf2afd6b3ea7df154b5e14ab492a8013 100644 --- a/crates/assistant_slash_command/src/assistant_slash_command.rs +++ b/crates/assistant_slash_command/src/assistant_slash_command.rs @@ -161,7 +161,7 @@ impl SlashCommandOutput { } /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s. - pub fn to_event_stream(mut self) -> BoxStream<'static, Result> { + pub fn into_event_stream(mut self) -> BoxStream<'static, Result> { self.ensure_valid_section_ranges(); let mut events = Vec::new(); @@ -363,7 +363,7 @@ mod tests { run_commands_in_text: false, }; - let events = output.clone().to_event_stream().collect::>().await; + let events = output.clone().into_event_stream().collect::>().await; let events = events .into_iter() .filter_map(|event| event.ok()) @@ -386,7 +386,7 @@ mod tests { ); let new_output = - SlashCommandOutput::from_event_stream(output.clone().to_event_stream()) + SlashCommandOutput::from_event_stream(output.clone().into_event_stream()) .await .unwrap(); @@ -415,7 +415,7 @@ mod tests { run_commands_in_text: false, }; - let events = output.clone().to_event_stream().collect::>().await; + let events = output.clone().into_event_stream().collect::>().await; let events = events .into_iter() .filter_map(|event| event.ok()) @@ -452,7 +452,7 @@ mod tests { ); let new_output = - SlashCommandOutput::from_event_stream(output.clone().to_event_stream()) + SlashCommandOutput::from_event_stream(output.clone().into_event_stream()) .await .unwrap(); @@ -493,7 +493,7 @@ mod tests { run_commands_in_text: false, }; - let events = output.clone().to_event_stream().collect::>().await; + let events = output.clone().into_event_stream().collect::>().await; let events = events .into_iter() .filter_map(|event| event.ok()) @@ -562,7 +562,7 @@ mod tests { ); let new_output = - SlashCommandOutput::from_event_stream(output.clone().to_event_stream()) + 
SlashCommandOutput::from_event_stream(output.clone().into_event_stream()) .await .unwrap(); diff --git a/crates/assistant_slash_command/src/extension_slash_command.rs b/crates/assistant_slash_command/src/extension_slash_command.rs index 74c46ffb5ffefb2ccbefdba8edec4e9e778489b5..e47ae52c98740af17c90fe657386bb0120773d9b 100644 --- a/crates/assistant_slash_command/src/extension_slash_command.rs +++ b/crates/assistant_slash_command/src/extension_slash_command.rs @@ -166,7 +166,7 @@ impl SlashCommand for ExtensionSlashCommand { .collect(), run_commands_in_text: false, } - .to_event_stream()) + .into_event_stream()) }) } } diff --git a/crates/assistant_slash_commands/Cargo.toml b/crates/assistant_slash_commands/Cargo.toml index f703a753f5d261f4151d0d6a47eb3753fd18afb8..c054c3ced84825bcd131bdd76644c00595c4c4a9 100644 --- a/crates/assistant_slash_commands/Cargo.toml +++ b/crates/assistant_slash_commands/Cargo.toml @@ -27,7 +27,6 @@ globset.workspace = true gpui.workspace = true html_to_markdown.workspace = true http_client.workspace = true -indexed_docs.workspace = true language.workspace = true project.workspace = true prompt_store.workspace = true diff --git a/crates/assistant_slash_commands/src/assistant_slash_commands.rs b/crates/assistant_slash_commands/src/assistant_slash_commands.rs index fa5dd8b683d4404365db252e27f9e8e30db6ca30..fb00a912197e07942a67ad92418b85c4920ad66b 100644 --- a/crates/assistant_slash_commands/src/assistant_slash_commands.rs +++ b/crates/assistant_slash_commands/src/assistant_slash_commands.rs @@ -3,7 +3,6 @@ mod context_server_command; mod default_command; mod delta_command; mod diagnostics_command; -mod docs_command; mod fetch_command; mod file_command; mod now_command; @@ -18,7 +17,6 @@ pub use crate::context_server_command::*; pub use crate::default_command::*; pub use crate::delta_command::*; pub use crate::diagnostics_command::*; -pub use crate::docs_command::*; pub use crate::fetch_command::*; pub use crate::file_command::*; pub use 
crate::now_command::*; diff --git a/crates/assistant_slash_commands/src/cargo_workspace_command.rs b/crates/assistant_slash_commands/src/cargo_workspace_command.rs index 8b088ea012de5f1ef6f7c787924c3cb2c6ec44c8..d58b2edc4c3dffd799dd9eb1c104686dc6488687 100644 --- a/crates/assistant_slash_commands/src/cargo_workspace_command.rs +++ b/crates/assistant_slash_commands/src/cargo_workspace_command.rs @@ -150,7 +150,7 @@ impl SlashCommand for CargoWorkspaceSlashCommand { }], run_commands_in_text: false, } - .to_event_stream()) + .into_event_stream()) }) }); output.unwrap_or_else(|error| Task::ready(Err(error))) diff --git a/crates/assistant_slash_commands/src/context_server_command.rs b/crates/assistant_slash_commands/src/context_server_command.rs index f223d3b184ccf6d795b80caca9a6a616aafc7f33..ee0cbf54c23a595f6503162c91dd1df3be019dd5 100644 --- a/crates/assistant_slash_commands/src/context_server_command.rs +++ b/crates/assistant_slash_commands/src/context_server_command.rs @@ -39,12 +39,12 @@ impl SlashCommand for ContextServerSlashCommand { fn label(&self, cx: &App) -> language::CodeLabel { let mut parts = vec![self.prompt.name.as_str()]; - if let Some(args) = &self.prompt.arguments { - if let Some(arg) = args.first() { - parts.push(arg.name.as_str()); - } + if let Some(args) = &self.prompt.arguments + && let Some(arg) = args.first() + { + parts.push(arg.name.as_str()); } - create_label_for_command(&parts[0], &parts[1..], cx) + create_label_for_command(parts[0], &parts[1..], cx) } fn description(&self) -> String { @@ -62,9 +62,10 @@ impl SlashCommand for ContextServerSlashCommand { } fn requires_argument(&self) -> bool { - self.prompt.arguments.as_ref().map_or(false, |args| { - args.iter().any(|arg| arg.required == Some(true)) - }) + self.prompt + .arguments + .as_ref() + .is_some_and(|args| args.iter().any(|arg| arg.required == Some(true))) } fn complete_argument( @@ -190,7 +191,7 @@ impl SlashCommand for ContextServerSlashCommand { text: prompt, run_commands_in_text: 
false, } - .to_event_stream()) + .into_event_stream()) }) } else { Task::ready(Err(anyhow!("Context server not found"))) diff --git a/crates/assistant_slash_commands/src/default_command.rs b/crates/assistant_slash_commands/src/default_command.rs index 6fce7f07a46d3d248c1c1292a67f1ad577c43645..01eff881cff0f07db9bf34e25853432e413ed79f 100644 --- a/crates/assistant_slash_commands/src/default_command.rs +++ b/crates/assistant_slash_commands/src/default_command.rs @@ -85,7 +85,7 @@ impl SlashCommand for DefaultSlashCommand { text, run_commands_in_text: true, } - .to_event_stream()) + .into_event_stream()) }) } } diff --git a/crates/assistant_slash_commands/src/delta_command.rs b/crates/assistant_slash_commands/src/delta_command.rs index 8c840c17b2c7fe9d8c8995b21c35cb35980dd71b..ea05fca588d0a496eeb3a2d2128b3861ba8a1e30 100644 --- a/crates/assistant_slash_commands/src/delta_command.rs +++ b/crates/assistant_slash_commands/src/delta_command.rs @@ -66,23 +66,22 @@ impl SlashCommand for DeltaSlashCommand { .metadata .as_ref() .and_then(|value| serde_json::from_value::(value.clone()).ok()) + && paths.insert(metadata.path.clone()) { - if paths.insert(metadata.path.clone()) { - file_command_old_outputs.push( - context_buffer - .as_rope() - .slice(section.range.to_offset(&context_buffer)), - ); - file_command_new_outputs.push(Arc::new(FileSlashCommand).run( - std::slice::from_ref(&metadata.path), - context_slash_command_output_sections, - context_buffer.clone(), - workspace.clone(), - delegate.clone(), - window, - cx, - )); - } + file_command_old_outputs.push( + context_buffer + .as_rope() + .slice(section.range.to_offset(&context_buffer)), + ); + file_command_new_outputs.push(Arc::new(FileSlashCommand).run( + std::slice::from_ref(&metadata.path), + context_slash_command_output_sections, + context_buffer.clone(), + workspace.clone(), + delegate.clone(), + window, + cx, + )); } } @@ -95,31 +94,31 @@ impl SlashCommand for DeltaSlashCommand { .into_iter() 
.zip(file_command_new_outputs) { - if let Ok(new_output) = new_output { - if let Ok(new_output) = SlashCommandOutput::from_event_stream(new_output).await - { - if let Some(file_command_range) = new_output.sections.first() { - let new_text = &new_output.text[file_command_range.range.clone()]; - if old_text.chars().ne(new_text.chars()) { - changes_detected = true; - output.sections.extend(new_output.sections.into_iter().map( - |section| SlashCommandOutputSection { - range: output.text.len() + section.range.start - ..output.text.len() + section.range.end, - icon: section.icon, - label: section.label, - metadata: section.metadata, - }, - )); - output.text.push_str(&new_output.text); - } - } + if let Ok(new_output) = new_output + && let Ok(new_output) = SlashCommandOutput::from_event_stream(new_output).await + && let Some(file_command_range) = new_output.sections.first() + { + let new_text = &new_output.text[file_command_range.range.clone()]; + if old_text.chars().ne(new_text.chars()) { + changes_detected = true; + output + .sections + .extend(new_output.sections.into_iter().map(|section| { + SlashCommandOutputSection { + range: output.text.len() + section.range.start + ..output.text.len() + section.range.end, + icon: section.icon, + label: section.label, + metadata: section.metadata, + } + })); + output.text.push_str(&new_output.text); } } } anyhow::ensure!(changes_detected, "no new changes detected"); - Ok(output.to_event_stream()) + Ok(output.into_event_stream()) }) } } diff --git a/crates/assistant_slash_commands/src/diagnostics_command.rs b/crates/assistant_slash_commands/src/diagnostics_command.rs index 2feabd8b1e018cc6495a88fe5a89276e3e19dfb1..8b1dbd515cabeb498d2a639387b426527dcda651 100644 --- a/crates/assistant_slash_commands/src/diagnostics_command.rs +++ b/crates/assistant_slash_commands/src/diagnostics_command.rs @@ -44,7 +44,7 @@ impl DiagnosticsSlashCommand { score: 0., positions: Vec::new(), worktree_id: entry.worktree_id.to_usize(), - path: 
entry.path.clone(), + path: entry.path, path_prefix: path_prefix.clone(), is_dir: false, // Diagnostics can't be produced for directories distance_to_relative_ancestor: 0, @@ -61,7 +61,7 @@ impl DiagnosticsSlashCommand { snapshot: worktree.snapshot(), include_ignored: worktree .root_entry() - .map_or(false, |entry| entry.is_ignored), + .is_some_and(|entry| entry.is_ignored), include_root_name: true, candidates: project::Candidates::Entries, } @@ -189,7 +189,7 @@ impl SlashCommand for DiagnosticsSlashCommand { window.spawn(cx, async move |_| { task.await? - .map(|output| output.to_event_stream()) + .map(|output| output.into_event_stream()) .context("No diagnostics found") }) } @@ -249,7 +249,7 @@ fn collect_diagnostics( let worktree = worktree.read(cx); let worktree_root_path = Path::new(worktree.root_name()); let relative_path = path.strip_prefix(worktree_root_path).ok()?; - worktree.absolutize(&relative_path).ok() + worktree.absolutize(relative_path).ok() }) }) .is_some() @@ -280,10 +280,10 @@ fn collect_diagnostics( let mut project_summary = DiagnosticSummary::default(); for (project_path, path, summary) in diagnostic_summaries { - if let Some(path_matcher) = &options.path_matcher { - if !path_matcher.is_match(&path) { - continue; - } + if let Some(path_matcher) = &options.path_matcher + && !path_matcher.is_match(&path) + { + continue; } project_summary.error_count += summary.error_count; @@ -365,7 +365,7 @@ pub fn collect_buffer_diagnostics( ) { for (_, group) in snapshot.diagnostic_groups(None) { let entry = &group.entries[group.primary_ix]; - collect_diagnostic(output, entry, &snapshot, include_warnings) + collect_diagnostic(output, entry, snapshot, include_warnings) } } @@ -396,7 +396,7 @@ fn collect_diagnostic( let start_row = range.start.row.saturating_sub(EXCERPT_EXPANSION_SIZE); let end_row = (range.end.row + EXCERPT_EXPANSION_SIZE).min(snapshot.max_point().row) + 1; let excerpt_range = - Point::new(start_row, 0).to_offset(&snapshot)..Point::new(end_row, 
0).to_offset(&snapshot); + Point::new(start_row, 0).to_offset(snapshot)..Point::new(end_row, 0).to_offset(snapshot); output.text.push_str("```"); if let Some(language_name) = snapshot.language().map(|l| l.code_fence_block_name()) { diff --git a/crates/assistant_slash_commands/src/docs_command.rs b/crates/assistant_slash_commands/src/docs_command.rs deleted file mode 100644 index bd87c72849e1eb54ca782d978f319676c1e8b3fe..0000000000000000000000000000000000000000 --- a/crates/assistant_slash_commands/src/docs_command.rs +++ /dev/null @@ -1,543 +0,0 @@ -use std::path::Path; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use std::time::Duration; - -use anyhow::{Context as _, Result, anyhow, bail}; -use assistant_slash_command::{ - ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection, - SlashCommandResult, -}; -use gpui::{App, BackgroundExecutor, Entity, Task, WeakEntity}; -use indexed_docs::{ - DocsDotRsProvider, IndexedDocsRegistry, IndexedDocsStore, LocalRustdocProvider, PackageName, - ProviderId, -}; -use language::{BufferSnapshot, LspAdapterDelegate}; -use project::{Project, ProjectPath}; -use ui::prelude::*; -use util::{ResultExt, maybe}; -use workspace::Workspace; - -pub struct DocsSlashCommand; - -impl DocsSlashCommand { - pub const NAME: &'static str = "docs"; - - fn path_to_cargo_toml(project: Entity, cx: &mut App) -> Option> { - let worktree = project.read(cx).worktrees(cx).next()?; - let worktree = worktree.read(cx); - let entry = worktree.entry_for_path("Cargo.toml")?; - let path = ProjectPath { - worktree_id: worktree.id(), - path: entry.path.clone(), - }; - Some(Arc::from( - project.read(cx).absolute_path(&path, cx)?.as_path(), - )) - } - - /// Ensures that the indexed doc providers for Rust are registered. - /// - /// Ideally we would do this sooner, but we need to wait until we're able to - /// access the workspace so we can read the project. 
- fn ensure_rust_doc_providers_are_registered( - &self, - workspace: Option>, - cx: &mut App, - ) { - let indexed_docs_registry = IndexedDocsRegistry::global(cx); - if indexed_docs_registry - .get_provider_store(LocalRustdocProvider::id()) - .is_none() - { - let index_provider_deps = maybe!({ - let workspace = workspace - .as_ref() - .context("no workspace")? - .upgrade() - .context("workspace dropped")?; - let project = workspace.read(cx).project().clone(); - let fs = project.read(cx).fs().clone(); - let cargo_workspace_root = Self::path_to_cargo_toml(project, cx) - .and_then(|path| path.parent().map(|path| path.to_path_buf())) - .context("no Cargo workspace root found")?; - - anyhow::Ok((fs, cargo_workspace_root)) - }); - - if let Some((fs, cargo_workspace_root)) = index_provider_deps.log_err() { - indexed_docs_registry.register_provider(Box::new(LocalRustdocProvider::new( - fs, - cargo_workspace_root, - ))); - } - } - - if indexed_docs_registry - .get_provider_store(DocsDotRsProvider::id()) - .is_none() - { - let http_client = maybe!({ - let workspace = workspace - .as_ref() - .context("no workspace")? - .upgrade() - .context("workspace was dropped")?; - let project = workspace.read(cx).project().clone(); - anyhow::Ok(project.read(cx).client().http_client()) - }); - - if let Some(http_client) = http_client.log_err() { - indexed_docs_registry - .register_provider(Box::new(DocsDotRsProvider::new(http_client))); - } - } - } - - /// Runs just-in-time indexing for a given package, in case the slash command - /// is run without any entries existing in the index. 
- fn run_just_in_time_indexing( - store: Arc, - key: String, - package: PackageName, - executor: BackgroundExecutor, - ) -> Task<()> { - executor.clone().spawn(async move { - let (prefix, needs_full_index) = if let Some((prefix, _)) = key.split_once('*') { - // If we have a wildcard in the search, we want to wait until - // we've completely finished indexing so we get a full set of - // results for the wildcard. - (prefix.to_string(), true) - } else { - (key, false) - }; - - // If we already have some entries, we assume that we've indexed the package before - // and don't need to do it again. - let has_any_entries = store - .any_with_prefix(prefix.clone()) - .await - .unwrap_or_default(); - if has_any_entries { - return (); - }; - - let index_task = store.clone().index(package.clone()); - - if needs_full_index { - _ = index_task.await; - } else { - loop { - executor.timer(Duration::from_millis(200)).await; - - if store - .any_with_prefix(prefix.clone()) - .await - .unwrap_or_default() - || !store.is_indexing(&package) - { - break; - } - } - } - }) - } -} - -impl SlashCommand for DocsSlashCommand { - fn name(&self) -> String { - Self::NAME.into() - } - - fn description(&self) -> String { - "insert docs".into() - } - - fn menu_text(&self) -> String { - "Insert Documentation".into() - } - - fn requires_argument(&self) -> bool { - true - } - - fn complete_argument( - self: Arc, - arguments: &[String], - _cancel: Arc, - workspace: Option>, - _: &mut Window, - cx: &mut App, - ) -> Task>> { - self.ensure_rust_doc_providers_are_registered(workspace, cx); - - let indexed_docs_registry = IndexedDocsRegistry::global(cx); - let args = DocsSlashCommandArgs::parse(arguments); - let store = args - .provider() - .context("no docs provider specified") - .and_then(|provider| IndexedDocsStore::try_global(provider, cx)); - cx.background_spawn(async move { - fn build_completions(items: Vec) -> Vec { - items - .into_iter() - .map(|item| ArgumentCompletion { - label: item.clone().into(), 
- new_text: item.to_string(), - after_completion: assistant_slash_command::AfterCompletion::Run, - replace_previous_arguments: false, - }) - .collect() - } - - match args { - DocsSlashCommandArgs::NoProvider => { - let providers = indexed_docs_registry.list_providers(); - if providers.is_empty() { - return Ok(vec![ArgumentCompletion { - label: "No available docs providers.".into(), - new_text: String::new(), - after_completion: false.into(), - replace_previous_arguments: false, - }]); - } - - Ok(providers - .into_iter() - .map(|provider| ArgumentCompletion { - label: provider.to_string().into(), - new_text: provider.to_string(), - after_completion: false.into(), - replace_previous_arguments: false, - }) - .collect()) - } - DocsSlashCommandArgs::SearchPackageDocs { - provider, - package, - index, - } => { - let store = store?; - - if index { - // We don't need to hold onto this task, as the `IndexedDocsStore` will hold it - // until it completes. - drop(store.clone().index(package.as_str().into())); - } - - let suggested_packages = store.clone().suggest_packages().await?; - let search_results = store.search(package).await; - - let mut items = build_completions(search_results); - let workspace_crate_completions = suggested_packages - .into_iter() - .filter(|package_name| { - !items - .iter() - .any(|item| item.label.text() == package_name.as_ref()) - }) - .map(|package_name| ArgumentCompletion { - label: format!("{package_name} (unindexed)").into(), - new_text: format!("{package_name}"), - after_completion: true.into(), - replace_previous_arguments: false, - }) - .collect::>(); - items.extend(workspace_crate_completions); - - if items.is_empty() { - return Ok(vec![ArgumentCompletion { - label: format!( - "Enter a {package_term} name.", - package_term = package_term(&provider) - ) - .into(), - new_text: provider.to_string(), - after_completion: false.into(), - replace_previous_arguments: false, - }]); - } - - Ok(items) - } - DocsSlashCommandArgs::SearchItemDocs { 
item_path, .. } => { - let store = store?; - let items = store.search(item_path).await; - Ok(build_completions(items)) - } - } - }) - } - - fn run( - self: Arc, - arguments: &[String], - _context_slash_command_output_sections: &[SlashCommandOutputSection], - _context_buffer: BufferSnapshot, - _workspace: WeakEntity, - _delegate: Option>, - _: &mut Window, - cx: &mut App, - ) -> Task { - if arguments.is_empty() { - return Task::ready(Err(anyhow!("missing an argument"))); - }; - - let args = DocsSlashCommandArgs::parse(arguments); - let executor = cx.background_executor().clone(); - let task = cx.background_spawn({ - let store = args - .provider() - .context("no docs provider specified") - .and_then(|provider| IndexedDocsStore::try_global(provider, cx)); - async move { - let (provider, key) = match args.clone() { - DocsSlashCommandArgs::NoProvider => bail!("no docs provider specified"), - DocsSlashCommandArgs::SearchPackageDocs { - provider, package, .. - } => (provider, package), - DocsSlashCommandArgs::SearchItemDocs { - provider, - item_path, - .. 
- } => (provider, item_path), - }; - - if key.trim().is_empty() { - bail!( - "no {package_term} name provided", - package_term = package_term(&provider) - ); - } - - let store = store?; - - if let Some(package) = args.package() { - Self::run_just_in_time_indexing(store.clone(), key.clone(), package, executor) - .await; - } - - let (text, ranges) = if let Some((prefix, _)) = key.split_once('*') { - let docs = store.load_many_by_prefix(prefix.to_string()).await?; - - let mut text = String::new(); - let mut ranges = Vec::new(); - - for (key, docs) in docs { - let prev_len = text.len(); - - text.push_str(&docs.0); - text.push_str("\n"); - ranges.push((key, prev_len..text.len())); - text.push_str("\n"); - } - - (text, ranges) - } else { - let item_docs = store.load(key.clone()).await?; - let text = item_docs.to_string(); - let range = 0..text.len(); - - (text, vec![(key, range)]) - }; - - anyhow::Ok((provider, text, ranges)) - } - }); - - cx.foreground_executor().spawn(async move { - let (provider, text, ranges) = task.await?; - Ok(SlashCommandOutput { - text, - sections: ranges - .into_iter() - .map(|(key, range)| SlashCommandOutputSection { - range, - icon: IconName::FileDoc, - label: format!("docs ({provider}): {key}",).into(), - metadata: None, - }) - .collect(), - run_commands_in_text: false, - } - .to_event_stream()) - }) - } -} - -fn is_item_path_delimiter(char: char) -> bool { - !char.is_alphanumeric() && char != '-' && char != '_' -} - -#[derive(Debug, PartialEq, Clone)] -pub enum DocsSlashCommandArgs { - NoProvider, - SearchPackageDocs { - provider: ProviderId, - package: String, - index: bool, - }, - SearchItemDocs { - provider: ProviderId, - package: String, - item_path: String, - }, -} - -impl DocsSlashCommandArgs { - pub fn parse(arguments: &[String]) -> Self { - let Some(provider) = arguments - .get(0) - .cloned() - .filter(|arg| !arg.trim().is_empty()) - else { - return Self::NoProvider; - }; - let provider = ProviderId(provider.into()); - let 
Some(argument) = arguments.get(1) else { - return Self::NoProvider; - }; - - if let Some((package, rest)) = argument.split_once(is_item_path_delimiter) { - if rest.trim().is_empty() { - Self::SearchPackageDocs { - provider, - package: package.to_owned(), - index: true, - } - } else { - Self::SearchItemDocs { - provider, - package: package.to_owned(), - item_path: argument.to_owned(), - } - } - } else { - Self::SearchPackageDocs { - provider, - package: argument.to_owned(), - index: false, - } - } - } - - pub fn provider(&self) -> Option { - match self { - Self::NoProvider => None, - Self::SearchPackageDocs { provider, .. } | Self::SearchItemDocs { provider, .. } => { - Some(provider.clone()) - } - } - } - - pub fn package(&self) -> Option { - match self { - Self::NoProvider => None, - Self::SearchPackageDocs { package, .. } | Self::SearchItemDocs { package, .. } => { - Some(package.as_str().into()) - } - } - } -} - -/// Returns the term used to refer to a package. -fn package_term(provider: &ProviderId) -> &'static str { - if provider == &DocsDotRsProvider::id() || provider == &LocalRustdocProvider::id() { - return "crate"; - } - - "package" -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_docs_slash_command_args() { - assert_eq!( - DocsSlashCommandArgs::parse(&["".to_string()]), - DocsSlashCommandArgs::NoProvider - ); - assert_eq!( - DocsSlashCommandArgs::parse(&["rustdoc".to_string()]), - DocsSlashCommandArgs::NoProvider - ); - - assert_eq!( - DocsSlashCommandArgs::parse(&["rustdoc".to_string(), "".to_string()]), - DocsSlashCommandArgs::SearchPackageDocs { - provider: ProviderId("rustdoc".into()), - package: "".into(), - index: false - } - ); - assert_eq!( - DocsSlashCommandArgs::parse(&["gleam".to_string(), "".to_string()]), - DocsSlashCommandArgs::SearchPackageDocs { - provider: ProviderId("gleam".into()), - package: "".into(), - index: false - } - ); - - assert_eq!( - DocsSlashCommandArgs::parse(&["rustdoc".to_string(), 
"gpui".to_string()]), - DocsSlashCommandArgs::SearchPackageDocs { - provider: ProviderId("rustdoc".into()), - package: "gpui".into(), - index: false, - } - ); - assert_eq!( - DocsSlashCommandArgs::parse(&["gleam".to_string(), "gleam_stdlib".to_string()]), - DocsSlashCommandArgs::SearchPackageDocs { - provider: ProviderId("gleam".into()), - package: "gleam_stdlib".into(), - index: false - } - ); - - // Adding an item path delimiter indicates we can start indexing. - assert_eq!( - DocsSlashCommandArgs::parse(&["rustdoc".to_string(), "gpui:".to_string()]), - DocsSlashCommandArgs::SearchPackageDocs { - provider: ProviderId("rustdoc".into()), - package: "gpui".into(), - index: true, - } - ); - assert_eq!( - DocsSlashCommandArgs::parse(&["gleam".to_string(), "gleam_stdlib/".to_string()]), - DocsSlashCommandArgs::SearchPackageDocs { - provider: ProviderId("gleam".into()), - package: "gleam_stdlib".into(), - index: true - } - ); - - assert_eq!( - DocsSlashCommandArgs::parse(&[ - "rustdoc".to_string(), - "gpui::foo::bar::Baz".to_string() - ]), - DocsSlashCommandArgs::SearchItemDocs { - provider: ProviderId("rustdoc".into()), - package: "gpui".into(), - item_path: "gpui::foo::bar::Baz".into() - } - ); - assert_eq!( - DocsSlashCommandArgs::parse(&[ - "gleam".to_string(), - "gleam_stdlib/gleam/int".to_string() - ]), - DocsSlashCommandArgs::SearchItemDocs { - provider: ProviderId("gleam".into()), - package: "gleam_stdlib".into(), - item_path: "gleam_stdlib/gleam/int".into() - } - ); - } -} diff --git a/crates/assistant_slash_commands/src/fetch_command.rs b/crates/assistant_slash_commands/src/fetch_command.rs index 4e0bb3d05a7f3c2828206a6c4deeaee8c505ed7e..6d3f66c9a23c896c765ba6c0a43b7a99dbc7ee73 100644 --- a/crates/assistant_slash_commands/src/fetch_command.rs +++ b/crates/assistant_slash_commands/src/fetch_command.rs @@ -177,7 +177,7 @@ impl SlashCommand for FetchSlashCommand { }], run_commands_in_text: false, } - .to_event_stream()) + .into_event_stream()) }) } } diff --git 
a/crates/assistant_slash_commands/src/file_command.rs b/crates/assistant_slash_commands/src/file_command.rs index c913ccc0f199cb5d03cf0a91d67459f3728b55a9..261e15bc0ae8b9e886d4d146696db78e5c0c831d 100644 --- a/crates/assistant_slash_commands/src/file_command.rs +++ b/crates/assistant_slash_commands/src/file_command.rs @@ -92,7 +92,7 @@ impl FileSlashCommand { snapshot: worktree.snapshot(), include_ignored: worktree .root_entry() - .map_or(false, |entry| entry.is_ignored), + .is_some_and(|entry| entry.is_ignored), include_root_name: true, candidates: project::Candidates::Entries, } @@ -223,7 +223,7 @@ fn collect_files( cx: &mut App, ) -> impl Stream> + use<> { let Ok(matchers) = glob_inputs - .into_iter() + .iter() .map(|glob_input| { custom_path_matcher::PathMatcher::new(&[glob_input.to_owned()]) .with_context(|| format!("invalid path {glob_input}")) @@ -371,7 +371,7 @@ fn collect_files( &mut output, ) .log_err(); - let mut buffer_events = output.to_event_stream(); + let mut buffer_events = output.into_event_stream(); while let Some(event) = buffer_events.next().await { events_tx.unbounded_send(event)?; } @@ -379,7 +379,7 @@ fn collect_files( } } - while let Some(_) = directory_stack.pop() { + while directory_stack.pop().is_some() { events_tx.unbounded_send(Ok(SlashCommandEvent::EndSection))?; } } @@ -491,8 +491,8 @@ mod custom_path_matcher { impl PathMatcher { pub fn new(globs: &[String]) -> Result { let globs = globs - .into_iter() - .map(|glob| Glob::new(&SanitizedPath::from(glob).to_glob_string())) + .iter() + .map(|glob| Glob::new(&SanitizedPath::new(glob).to_glob_string())) .collect::, _>>()?; let sources = globs.iter().map(|glob| glob.glob().to_owned()).collect(); let sources_with_trailing_slash = globs @@ -536,7 +536,7 @@ mod custom_path_matcher { let path_str = path.to_string_lossy(); let separator = std::path::MAIN_SEPARATOR_STR; if path_str.ends_with(separator) { - return false; + false } else { self.glob.is_match(path_str.to_string() + separator) } diff 
--git a/crates/assistant_slash_commands/src/now_command.rs b/crates/assistant_slash_commands/src/now_command.rs index e4abef2a7c80fbdc96df28cbd1072d180fd864f3..aec21e7173bafd4cb07e7c37135fa0ad6fa88812 100644 --- a/crates/assistant_slash_commands/src/now_command.rs +++ b/crates/assistant_slash_commands/src/now_command.rs @@ -66,6 +66,6 @@ impl SlashCommand for NowSlashCommand { }], run_commands_in_text: false, } - .to_event_stream())) + .into_event_stream())) } } diff --git a/crates/assistant_slash_commands/src/prompt_command.rs b/crates/assistant_slash_commands/src/prompt_command.rs index c177f9f3599525924aa18700ea09d5fe977a5698..bbd6d3e3ad201c06940d6dc986616f61c8e15547 100644 --- a/crates/assistant_slash_commands/src/prompt_command.rs +++ b/crates/assistant_slash_commands/src/prompt_command.rs @@ -80,7 +80,7 @@ impl SlashCommand for PromptSlashCommand { }; let store = PromptStore::global(cx); - let title = SharedString::from(title.clone()); + let title = SharedString::from(title); let prompt = cx.spawn({ let title = title.clone(); async move |cx| { @@ -117,7 +117,7 @@ impl SlashCommand for PromptSlashCommand { }], run_commands_in_text: true, } - .to_event_stream()) + .into_event_stream()) }) } } diff --git a/crates/assistant_slash_commands/src/symbols_command.rs b/crates/assistant_slash_commands/src/symbols_command.rs index ef9314643116689d36e99b2a9bcb7d69982a776f..c700319800769e3a6e45234355c850e746231200 100644 --- a/crates/assistant_slash_commands/src/symbols_command.rs +++ b/crates/assistant_slash_commands/src/symbols_command.rs @@ -1,4 +1,4 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use assistant_slash_command::{ ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection, SlashCommandResult, @@ -70,9 +70,7 @@ impl SlashCommand for OutlineSlashCommand { let path = snapshot.resolve_file_path(cx, true); cx.background_spawn(async move { - let outline = snapshot - .outline(None) - .context("no symbols for 
active tab")?; + let outline = snapshot.outline(None); let path = path.as_deref().unwrap_or(Path::new("untitled")); let mut outline_text = format!("Symbols for {}:\n", path.display()); @@ -92,7 +90,7 @@ impl SlashCommand for OutlineSlashCommand { text: outline_text, run_commands_in_text: false, } - .to_event_stream()) + .into_event_stream()) }) }); diff --git a/crates/assistant_slash_commands/src/tab_command.rs b/crates/assistant_slash_commands/src/tab_command.rs index ca7601bc4c3a48d9d9c352ad545d72c032e7c47e..a124beed6302d6c67085ccb70f4c3aa58834d3f2 100644 --- a/crates/assistant_slash_commands/src/tab_command.rs +++ b/crates/assistant_slash_commands/src/tab_command.rs @@ -157,7 +157,7 @@ impl SlashCommand for TabSlashCommand { for (full_path, buffer, _) in tab_items_search.await? { append_buffer_to_output(&buffer, full_path.as_deref(), &mut output).log_err(); } - Ok(output.to_event_stream()) + Ok(output.into_event_stream()) }) } } @@ -195,16 +195,14 @@ fn tab_items_for_queries( } for editor in workspace.items_of_type::(cx) { - if let Some(buffer) = editor.read(cx).buffer().read(cx).as_singleton() { - if let Some(timestamp) = + if let Some(buffer) = editor.read(cx).buffer().read(cx).as_singleton() + && let Some(timestamp) = timestamps_by_entity_id.get(&editor.entity_id()) - { - if visited_buffers.insert(buffer.read(cx).remote_id()) { - let snapshot = buffer.read(cx).snapshot(); - let full_path = snapshot.resolve_file_path(cx, true); - open_buffers.push((full_path, snapshot, *timestamp)); - } - } + && visited_buffers.insert(buffer.read(cx).remote_id()) + { + let snapshot = buffer.read(cx).snapshot(); + let full_path = snapshot.resolve_file_path(cx, true); + open_buffers.push((full_path, snapshot, *timestamp)); } } diff --git a/crates/assistant_tool/src/outline.rs b/crates/assistant_tool/src/outline.rs index 4f8bde5456073912185fe160d48363eac7601ef5..d9bf64cf059a33f1cc4e6d2833e77a0554b82d93 100644 --- a/crates/assistant_tool/src/outline.rs +++ 
b/crates/assistant_tool/src/outline.rs @@ -41,9 +41,7 @@ pub async fn file_outline( } let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?; - let outline = snapshot - .outline(None) - .context("No outline information available for this file at path {path}")?; + let outline = snapshot.outline(None); render_outline( outline diff --git a/crates/assistant_tool/src/tool_schema.rs b/crates/assistant_tool/src/tool_schema.rs index 7b48f93ba6d23bcc1a6e2cf051737efaf69fa595..192f7c8a2bb565ece01a3472a9e46dad316377f4 100644 --- a/crates/assistant_tool/src/tool_schema.rs +++ b/crates/assistant_tool/src/tool_schema.rs @@ -24,16 +24,16 @@ pub fn adapt_schema_to_format( fn preprocess_json_schema(json: &mut Value) -> Result<()> { // `additionalProperties` defaults to `false` unless explicitly specified. // This prevents models from hallucinating tool parameters. - if let Value::Object(obj) = json { - if matches!(obj.get("type"), Some(Value::String(s)) if s == "object") { - if !obj.contains_key("additionalProperties") { - obj.insert("additionalProperties".to_string(), Value::Bool(false)); - } + if let Value::Object(obj) = json + && matches!(obj.get("type"), Some(Value::String(s)) if s == "object") + { + if !obj.contains_key("additionalProperties") { + obj.insert("additionalProperties".to_string(), Value::Bool(false)); + } - // OpenAI API requires non-missing `properties` - if !obj.contains_key("properties") { - obj.insert("properties".to_string(), Value::Object(Default::default())); - } + // OpenAI API requires non-missing `properties` + if !obj.contains_key("properties") { + obj.insert("properties".to_string(), Value::Object(Default::default())); } } Ok(()) @@ -59,10 +59,10 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { ("optional", |value| value.is_boolean()), ]; for (key, predicate) in KEYS_TO_REMOVE { - if let Some(value) = obj.get(key) { - if predicate(value) { - obj.remove(key); - } + if let Some(value) = obj.get(key) + && predicate(value) + 
{ + obj.remove(key); } } @@ -77,12 +77,12 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { } // Handle oneOf -> anyOf conversion - if let Some(subschemas) = obj.get_mut("oneOf") { - if subschemas.is_array() { - let subschemas_clone = subschemas.clone(); - obj.remove("oneOf"); - obj.insert("anyOf".to_string(), subschemas_clone); - } + if let Some(subschemas) = obj.get_mut("oneOf") + && subschemas.is_array() + { + let subschemas_clone = subschemas.clone(); + obj.remove("oneOf"); + obj.insert("anyOf".to_string(), subschemas_clone); } // Recursively process all nested objects and arrays diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index c0a358917b499908d85fbc157212cf6db5b5e0eb..61f57affc76aad9e4d2185665b539f9092e3491c 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -156,13 +156,13 @@ fn resolve_context_server_tool_name_conflicts( if duplicated_tool_names.is_empty() { return context_server_tools - .into_iter() + .iter() .map(|tool| (resolve_tool_name(tool).into(), tool.clone())) .collect(); } context_server_tools - .into_iter() + .iter() .filter_map(|tool| { let mut tool_name = resolve_tool_name(tool); if !duplicated_tool_names.contains(&tool_name) { diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index bf668e691885d328ecd34b22d0a4e14633be565a..ce3b639cb2c46d3f736490c0b2153260f970963c 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -72,11 +72,10 @@ pub fn init(http_client: Arc, cx: &mut App) { register_web_search_tool(&LanguageModelRegistry::global(cx), cx); cx.subscribe( &LanguageModelRegistry::global(cx), - move |registry, event, cx| match event { - language_model::Event::DefaultModelChanged => { + move |registry, event, cx| { + if let language_model::Event::DefaultModelChanged = event { 
register_web_search_tool(®istry, cx); } - _ => {} }, ) .detach(); @@ -86,7 +85,7 @@ fn register_web_search_tool(registry: &Entity, cx: &mut A let using_zed_provider = registry .read(cx) .default_model() - .map_or(false, |default| default.is_provided_by_zed()); + .is_some_and(|default| default.is_provided_by_zed()); if using_zed_provider { ToolRegistry::global(cx).register_tool(WebSearchTool); } else { diff --git a/crates/assistant_tools/src/delete_path_tool.rs b/crates/assistant_tools/src/delete_path_tool.rs index b181eeff5ca0f1a45176921ed9e24973aae3839f..7c85f1ed7552931822500f76bb9f3b1b1f47fd0c 100644 --- a/crates/assistant_tools/src/delete_path_tool.rs +++ b/crates/assistant_tools/src/delete_path_tool.rs @@ -35,7 +35,7 @@ impl Tool for DeletePathTool { } fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - false + true } fn may_perform_edits(&self) -> bool { diff --git a/crates/assistant_tools/src/edit_agent.rs b/crates/assistant_tools/src/edit_agent.rs index aa321aa8f30117e21a04e4acb52b5c5cdbfedfaa..29ac53e2a606d63873f515aff25326debf0486f1 100644 --- a/crates/assistant_tools/src/edit_agent.rs +++ b/crates/assistant_tools/src/edit_agent.rs @@ -672,29 +672,30 @@ impl EditAgent { cx: &mut AsyncApp, ) -> Result>> { let mut messages_iter = conversation.messages.iter_mut(); - if let Some(last_message) = messages_iter.next_back() { - if last_message.role == Role::Assistant { - let old_content_len = last_message.content.len(); - last_message - .content - .retain(|content| !matches!(content, MessageContent::ToolUse(_))); - let new_content_len = last_message.content.len(); - - // We just removed pending tool uses from the content of the - // last message, so it doesn't make sense to cache it anymore - // (e.g., the message will look very different on the next - // request). Thus, we move the flag to the message prior to it, - // as it will still be a valid prefix of the conversation. 
- if old_content_len != new_content_len && last_message.cache { - if let Some(prev_message) = messages_iter.next_back() { - last_message.cache = false; - prev_message.cache = true; - } - } + if let Some(last_message) = messages_iter.next_back() + && last_message.role == Role::Assistant + { + let old_content_len = last_message.content.len(); + last_message + .content + .retain(|content| !matches!(content, MessageContent::ToolUse(_))); + let new_content_len = last_message.content.len(); + + // We just removed pending tool uses from the content of the + // last message, so it doesn't make sense to cache it anymore + // (e.g., the message will look very different on the next + // request). Thus, we move the flag to the message prior to it, + // as it will still be a valid prefix of the conversation. + if old_content_len != new_content_len + && last_message.cache + && let Some(prev_message) = messages_iter.next_back() + { + last_message.cache = false; + prev_message.cache = true; + } - if last_message.content.is_empty() { - conversation.messages.pop(); - } + if last_message.content.is_empty() { + conversation.messages.pop(); } } @@ -1314,17 +1315,17 @@ mod tests { #[gpui::test(iterations = 100)] async fn test_random_indents(mut rng: StdRng) { - let len = rng.gen_range(1..=100); + let len = rng.random_range(1..=100); let new_text = util::RandomCharIter::new(&mut rng) .with_simple_text() .take(len) .collect::(); let new_text = new_text .split('\n') - .map(|line| format!("{}{}", " ".repeat(rng.gen_range(0..=8)), line)) + .map(|line| format!("{}{}", " ".repeat(rng.random_range(0..=8)), line)) .collect::>() .join("\n"); - let delta = IndentDelta::Spaces(rng.gen_range(-4..=4)); + let delta = IndentDelta::Spaces(rng.random_range(-4i8..=4i8) as isize); let chunks = to_random_chunks(&mut rng, &new_text); let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| { @@ -1356,7 +1357,7 @@ mod tests { } fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec 
{ - let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50)); + let chunk_count = rng.random_range(1..=cmp::min(input.len(), 50)); let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count); chunk_indices.sort(); chunk_indices.push(input.len()); diff --git a/crates/assistant_tools/src/edit_agent/create_file_parser.rs b/crates/assistant_tools/src/edit_agent/create_file_parser.rs index 0aad9ecb87c1426486b531ac4291913cd0d74092..5126f9c6b1fe4ee5cc600ae93b7300b7af09451f 100644 --- a/crates/assistant_tools/src/edit_agent/create_file_parser.rs +++ b/crates/assistant_tools/src/edit_agent/create_file_parser.rs @@ -204,7 +204,7 @@ mod tests { } fn parse_random_chunks(input: &str, parser: &mut CreateFileParser, rng: &mut StdRng) -> String { - let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50)); + let chunk_count = rng.random_range(1..=cmp::min(input.len(), 50)); let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count); chunk_indices.sort(); chunk_indices.push(input.len()); diff --git a/crates/assistant_tools/src/edit_agent/edit_parser.rs b/crates/assistant_tools/src/edit_agent/edit_parser.rs index db58c2bf3685030abfa6cfdd506c068c6643dce8..8411171ba4ea491d2603014a0715ce471b34e36f 100644 --- a/crates/assistant_tools/src/edit_agent/edit_parser.rs +++ b/crates/assistant_tools/src/edit_agent/edit_parser.rs @@ -996,7 +996,7 @@ mod tests { } fn parse_random_chunks(input: &str, parser: &mut EditParser, rng: &mut StdRng) -> Vec { - let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50)); + let chunk_count = rng.random_range(1..=cmp::min(input.len(), 50)); let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count); chunk_indices.sort(); chunk_indices.push(input.len()); diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs index 9a8e7624559e9a1284ace7c932f428c7389b6254..515e22d5f8b184a875cd91038d7bfa0a7d8127a7 100644 --- 
a/crates/assistant_tools/src/edit_agent/evals.rs +++ b/crates/assistant_tools/src/edit_agent/evals.rs @@ -1153,8 +1153,7 @@ impl EvalInput { .expect("Conversation must end with an edit_file tool use") .clone(); - let edit_file_input: EditFileToolInput = - serde_json::from_value(tool_use.input.clone()).unwrap(); + let edit_file_input: EditFileToolInput = serde_json::from_value(tool_use.input).unwrap(); EvalInput { conversation, @@ -1283,14 +1282,14 @@ impl EvalAssertion { // Parse the score from the response let re = regex::Regex::new(r"(\d+)").unwrap(); - if let Some(captures) = re.captures(&output) { - if let Some(score_match) = captures.get(1) { - let score = score_match.as_str().parse().unwrap_or(0); - return Ok(EvalAssertionOutcome { - score, - message: Some(output), - }); - } + if let Some(captures) = re.captures(&output) + && let Some(score_match) = captures.get(1) + { + let score = score_match.as_str().parse().unwrap_or(0); + return Ok(EvalAssertionOutcome { + score, + message: Some(output), + }); } anyhow::bail!("No score found in response. 
Raw output: {output}"); @@ -1400,7 +1399,7 @@ fn eval( } fn run_eval(eval: EvalInput, tx: mpsc::Sender>) { - let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy()); + let dispatcher = gpui::TestDispatcher::new(StdRng::from_os_rng()); let mut cx = TestAppContext::build(dispatcher, None); let output = cx.executor().block_test(async { let test = EditAgentTest::new(&mut cx).await; @@ -1460,7 +1459,7 @@ impl EditAgentTest { async fn new(cx: &mut TestAppContext) -> Self { cx.executor().allow_parking(); - let fs = FakeFs::new(cx.executor().clone()); + let fs = FakeFs::new(cx.executor()); cx.update(|cx| { settings::init(cx); gpui_tokio::init(cx); @@ -1475,7 +1474,7 @@ impl EditAgentTest { Project::init_settings(cx); language::init(cx); language_model::init(client.clone(), cx); - language_models::init(user_store.clone(), client.clone(), cx); + language_models::init(user_store, client.clone(), cx); crate::init(client.http_client(), cx); }); @@ -1521,7 +1520,15 @@ impl EditAgentTest { selected_model: &SelectedModel, cx: &mut AsyncApp, ) -> Result> { - let (provider, model) = cx.update(|cx| { + cx.update(|cx| { + let registry = LanguageModelRegistry::read_global(cx); + let provider = registry + .provider(&selected_model.provider) + .expect("Provider not found"); + provider.authenticate(cx) + })? 
+ .await?; + cx.update(|cx| { let models = LanguageModelRegistry::read_global(cx); let model = models .available_models(cx) @@ -1530,11 +1537,8 @@ impl EditAgentTest { && model.id() == selected_model.model }) .expect("Model not found"); - let provider = models.provider(&model.provider_id()).unwrap(); - (provider, model) - })?; - cx.update(|cx| provider.authenticate(cx))?.await?; - Ok(model) + model + }) } async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result { @@ -1586,7 +1590,7 @@ impl EditAgentTest { let has_system_prompt = eval .conversation .first() - .map_or(false, |msg| msg.role == Role::System); + .is_some_and(|msg| msg.role == Role::System); let messages = if has_system_prompt { eval.conversation } else { @@ -1708,7 +1712,7 @@ async fn retry_on_rate_limit(mut request: impl AsyncFnMut() -> Result) -> }; if let Some(retry_after) = retry_delay { - let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); + let jitter = retry_after.mul_f64(rand::rng().random_range(0.0..1.0)); eprintln!("Attempt #{attempt}: Retry after {retry_after:?} + jitter of {jitter:?}"); Timer::after(retry_after + jitter).await; } else { diff --git a/crates/assistant_tools/src/edit_agent/streaming_fuzzy_matcher.rs b/crates/assistant_tools/src/edit_agent/streaming_fuzzy_matcher.rs index 092bdce8b347ee5bcb5849703533710652b5b01c..386b8204400a157b37b2f356829fa27df3abca92 100644 --- a/crates/assistant_tools/src/edit_agent/streaming_fuzzy_matcher.rs +++ b/crates/assistant_tools/src/edit_agent/streaming_fuzzy_matcher.rs @@ -319,7 +319,7 @@ mod tests { ); let snapshot = buffer.snapshot(); - let mut finder = StreamingFuzzyMatcher::new(snapshot.clone()); + let mut finder = StreamingFuzzyMatcher::new(snapshot); assert_eq!(push(&mut finder, ""), None); assert_eq!(finish(finder), None); } @@ -333,7 +333,7 @@ mod tests { ); let snapshot = buffer.snapshot(); - let mut finder = StreamingFuzzyMatcher::new(snapshot.clone()); + let mut finder = 
StreamingFuzzyMatcher::new(snapshot); // Push partial query assert_eq!(push(&mut finder, "This"), None); @@ -365,7 +365,7 @@ mod tests { ); let snapshot = buffer.snapshot(); - let mut finder = StreamingFuzzyMatcher::new(snapshot.clone()); + let mut finder = StreamingFuzzyMatcher::new(snapshot); // Push a fuzzy query that should match the first function assert_eq!( @@ -391,7 +391,7 @@ mod tests { ); let snapshot = buffer.snapshot(); - let mut finder = StreamingFuzzyMatcher::new(snapshot.clone()); + let mut finder = StreamingFuzzyMatcher::new(snapshot); // No match initially assert_eq!(push(&mut finder, "Lin"), None); @@ -420,7 +420,7 @@ mod tests { ); let snapshot = buffer.snapshot(); - let mut finder = StreamingFuzzyMatcher::new(snapshot.clone()); + let mut finder = StreamingFuzzyMatcher::new(snapshot); // Push text in small chunks across line boundaries assert_eq!(push(&mut finder, "jumps "), None); // No newline yet @@ -458,7 +458,7 @@ mod tests { ); let snapshot = buffer.snapshot(); - let mut finder = StreamingFuzzyMatcher::new(snapshot.clone()); + let mut finder = StreamingFuzzyMatcher::new(snapshot); assert_eq!( push(&mut finder, "impl Debug for User {\n"), @@ -711,7 +711,7 @@ mod tests { "Expected to match `second_function` based on the line hint" ); - let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone()); + let mut matcher = StreamingFuzzyMatcher::new(snapshot); matcher.push(query, None); matcher.finish(); let best_match = matcher.select_best_match(); @@ -727,7 +727,7 @@ mod tests { let buffer = TextBuffer::new(0, BufferId::new(1).unwrap(), text.clone()); let snapshot = buffer.snapshot(); - let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone()); + let mut matcher = StreamingFuzzyMatcher::new(snapshot); // Split query into random chunks let chunks = to_random_chunks(rng, query); @@ -771,7 +771,7 @@ mod tests { } fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec { - let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50)); + let 
chunk_count = rng.random_range(1..=cmp::min(input.len(), 50)); let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count); chunk_indices.sort(); chunk_indices.push(input.len()); @@ -794,10 +794,8 @@ mod tests { fn finish(mut finder: StreamingFuzzyMatcher) -> Option { let snapshot = finder.snapshot.clone(); let matches = finder.finish(); - if let Some(range) = matches.first() { - Some(snapshot.text_for_range(range.clone()).collect::()) - } else { - None - } + matches + .first() + .map(|range| snapshot.text_for_range(range.clone()).collect::()) } } diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index e819c51e1edb841954508dbfad0fd1d2e85b51c4..d13f9891c3af1933ee49428c223d3e6737871047 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -11,11 +11,13 @@ use assistant_tool::{ AnyToolCard, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, }; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; -use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer, PathKey}; +use editor::{ + Editor, EditorMode, MinimapVisibility, MultiBuffer, PathKey, multibuffer_context_lines, +}; use futures::StreamExt; use gpui::{ Animation, AnimationExt, AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task, - TextStyleRefinement, Transformation, WeakEntity, percentage, pulsating_between, px, + TextStyleRefinement, WeakEntity, pulsating_between, px, }; use indoc::formatdoc; use language::{ @@ -42,7 +44,7 @@ use std::{ time::Duration, }; use theme::ThemeSettings; -use ui::{Disclosure, Tooltip, prelude::*}; +use ui::{CommonAnimationExt, Disclosure, Tooltip, prelude::*}; use util::ResultExt; use workspace::Workspace; @@ -155,10 +157,10 @@ impl Tool for EditFileTool { // It's also possible that the global config dir is configured to be inside the project, // so check for that edge case too. 
- if let Ok(canonical_path) = std::fs::canonicalize(&input.path) { - if canonical_path.starts_with(paths::config_dir()) { - return true; - } + if let Ok(canonical_path) = std::fs::canonicalize(&input.path) + && canonical_path.starts_with(paths::config_dir()) + { + return true; } // Check if path is inside the global config directory @@ -199,10 +201,10 @@ impl Tool for EditFileTool { .any(|c| c.as_os_str() == local_settings_folder.as_os_str()) { description.push_str(" (local settings)"); - } else if let Ok(canonical_path) = std::fs::canonicalize(&input.path) { - if canonical_path.starts_with(paths::config_dir()) { - description.push_str(" (global settings)"); - } + } else if let Ok(canonical_path) = std::fs::canonicalize(&input.path) + && canonical_path.starts_with(paths::config_dir()) + { + description.push_str(" (global settings)"); } description @@ -376,7 +378,7 @@ impl Tool for EditFileTool { let output = EditFileToolOutput { original_path: project_path.path.to_path_buf(), - new_text: new_text.clone(), + new_text, old_text, raw_output: Some(agent_output), }; @@ -474,7 +476,7 @@ impl Tool for EditFileTool { PathKey::for_buffer(&buffer, cx), buffer, diff_hunk_ranges, - editor::DEFAULT_MULTIBUFFER_CONTEXT, + multibuffer_context_lines(cx), cx, ); multibuffer.add_diff(buffer_diff, cx); @@ -536,7 +538,7 @@ fn resolve_path( let parent_entry = parent_project_path .as_ref() - .and_then(|path| project.entry_for_path(&path, cx)) + .and_then(|path| project.entry_for_path(path, cx)) .context("Can't create file: parent directory doesn't exist")?; anyhow::ensure!( @@ -643,7 +645,7 @@ impl EditFileToolCard { diff }); - self.buffer = Some(buffer.clone()); + self.buffer = Some(buffer); self.base_text = Some(base_text.into()); self.buffer_diff = Some(buffer_diff.clone()); @@ -703,7 +705,7 @@ impl EditFileToolCard { PathKey::for_buffer(buffer, cx), buffer.clone(), ranges, - editor::DEFAULT_MULTIBUFFER_CONTEXT, + multibuffer_context_lines(cx), cx, ); let end = multibuffer.len(cx); 
@@ -723,13 +725,13 @@ impl EditFileToolCard { let buffer = buffer.read(cx); let diff = diff.read(cx); let mut ranges = diff - .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) - .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) + .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, buffer, cx) + .map(|diff_hunk| diff_hunk.buffer_range.to_point(buffer)) .collect::>(); ranges.extend( self.revealed_ranges .iter() - .map(|range| range.to_point(&buffer)), + .map(|range| range.to_point(buffer)), ); ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end))); @@ -776,7 +778,6 @@ impl EditFileToolCard { let buffer_diff = cx.spawn({ let buffer = buffer.clone(); - let language_registry = language_registry.clone(); async move |_this, cx| { build_buffer_diff(base_text, &buffer, &language_registry, cx).await } @@ -792,7 +793,7 @@ impl EditFileToolCard { path_key, buffer, ranges, - editor::DEFAULT_MULTIBUFFER_CONTEXT, + multibuffer_context_lines(cx), cx, ); multibuffer.add_diff(buffer_diff.clone(), cx); @@ -863,7 +864,6 @@ impl ToolCard for EditFileToolCard { ) .on_click({ let path = self.path.clone(); - let workspace = workspace.clone(); move |_, window, cx| { workspace .update(cx, { @@ -939,11 +939,7 @@ impl ToolCard for EditFileToolCard { Icon::new(IconName::ArrowCircle) .size(IconSize::XSmall) .color(Color::Info) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - ), + .with_rotate_animation(2), ) }) .when_some(error_message, |header, error_message| { @@ -1356,8 +1352,7 @@ mod tests { mode: mode.clone(), }; - let result = cx.update(|cx| resolve_path(&input, project, cx)); - result + cx.update(|cx| resolve_path(&input, project, cx)) } fn assert_resolved_path_eq(path: anyhow::Result, expected: &str) { diff --git a/crates/assistant_tools/src/fetch_tool.rs b/crates/assistant_tools/src/fetch_tool.rs index 
79e205f205d02ba2a3f977163d2296423f71d9da..cc22c9fc09f73914720c4b639f8d273207d7ca53 100644 --- a/crates/assistant_tools/src/fetch_tool.rs +++ b/crates/assistant_tools/src/fetch_tool.rs @@ -118,7 +118,7 @@ impl Tool for FetchTool { } fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - false + true } fn may_perform_edits(&self) -> bool { diff --git a/crates/assistant_tools/src/find_path_tool.rs b/crates/assistant_tools/src/find_path_tool.rs index 6b62638a4c33a3a4d29f7af51d3688a06f9c1dee..d1451132aeb066a5d4ff9e05f81db3855c1d513a 100644 --- a/crates/assistant_tools/src/find_path_tool.rs +++ b/crates/assistant_tools/src/find_path_tool.rs @@ -234,7 +234,7 @@ impl ToolCard for FindPathToolCard { workspace: WeakEntity, cx: &mut Context, ) -> impl IntoElement { - let matches_label: SharedString = if self.paths.len() == 0 { + let matches_label: SharedString = if self.paths.is_empty() { "No matches".into() } else if self.paths.len() == 1 { "1 match".into() @@ -435,8 +435,8 @@ mod test { assert_eq!( matches, &[ - PathBuf::from("root/apple/banana/carrot"), - PathBuf::from("root/apple/bandana/carbonara") + PathBuf::from(path!("root/apple/banana/carrot")), + PathBuf::from(path!("root/apple/bandana/carbonara")) ] ); @@ -447,8 +447,8 @@ mod test { assert_eq!( matches, &[ - PathBuf::from("root/apple/banana/carrot"), - PathBuf::from("root/apple/bandana/carbonara") + PathBuf::from(path!("root/apple/banana/carrot")), + PathBuf::from(path!("root/apple/bandana/carbonara")) ] ); } diff --git a/crates/assistant_tools/src/grep_tool.rs b/crates/assistant_tools/src/grep_tool.rs index a5ce07823fd68ff9531c7d834973166384645601..e43a54661ca146902a49fa1d975e44d486e18587 100644 --- a/crates/assistant_tools/src/grep_tool.rs +++ b/crates/assistant_tools/src/grep_tool.rs @@ -188,15 +188,14 @@ impl Tool for GrepTool { // Check if this file should be excluded based on its worktree settings if let Ok(Some(project_path)) = project.read_with(cx, |project, cx| { 
project.find_project_path(&path, cx) - }) { - if cx.update(|cx| { + }) + && cx.update(|cx| { let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx); worktree_settings.is_path_excluded(&project_path.path) || worktree_settings.is_path_private(&project_path.path) }).unwrap_or(false) { continue; } - } while *parse_status.borrow() != ParseStatus::Idle { parse_status.changed().await?; @@ -268,10 +267,8 @@ impl Tool for GrepTool { let end_row = range.end.row; output.push_str("\n### "); - if let Some(parent_symbols) = &parent_symbols { - for symbol in parent_symbols { - write!(output, "{} › ", symbol.text)?; - } + for symbol in parent_symbols { + write!(output, "{} › ", symbol.text)?; } if range.start.row == end_row { @@ -284,12 +281,11 @@ impl Tool for GrepTool { output.extend(snapshot.text_for_range(range)); output.push_str("\n```\n"); - if let Some(ancestor_range) = ancestor_range { - if end_row < ancestor_range.end.row { + if let Some(ancestor_range) = ancestor_range + && end_row < ancestor_range.end.row { let remaining_lines = ancestor_range.end.row - end_row; writeln!(output, "\n{} lines remaining in ancestor node. 
Read the file to see all.", remaining_lines)?; } - } matches_found += 1; } @@ -329,7 +325,7 @@ mod tests { init_test(cx); cx.executor().allow_parking(); - let fs = FakeFs::new(cx.executor().clone()); + let fs = FakeFs::new(cx.executor()); fs.insert_tree( path!("/root"), serde_json::json!({ @@ -417,7 +413,7 @@ mod tests { init_test(cx); cx.executor().allow_parking(); - let fs = FakeFs::new(cx.executor().clone()); + let fs = FakeFs::new(cx.executor()); fs.insert_tree( path!("/root"), serde_json::json!({ @@ -496,7 +492,7 @@ mod tests { init_test(cx); cx.executor().allow_parking(); - let fs = FakeFs::new(cx.executor().clone()); + let fs = FakeFs::new(cx.executor()); // Create test file with syntax structures fs.insert_tree( @@ -894,7 +890,7 @@ mod tests { }) .await; let results = result.unwrap(); - let paths = extract_paths_from_results(&results.content.as_str().unwrap()); + let paths = extract_paths_from_results(results.content.as_str().unwrap()); assert!( paths.is_empty(), "grep_tool should not find files outside the project worktree" @@ -920,7 +916,7 @@ mod tests { }) .await; let results = result.unwrap(); - let paths = extract_paths_from_results(&results.content.as_str().unwrap()); + let paths = extract_paths_from_results(results.content.as_str().unwrap()); assert!( paths.iter().any(|p| p.contains("allowed_file.rs")), "grep_tool should be able to search files inside worktrees" @@ -946,7 +942,7 @@ mod tests { }) .await; let results = result.unwrap(); - let paths = extract_paths_from_results(&results.content.as_str().unwrap()); + let paths = extract_paths_from_results(results.content.as_str().unwrap()); assert!( paths.is_empty(), "grep_tool should not search files in .secretdir (file_scan_exclusions)" @@ -971,7 +967,7 @@ mod tests { }) .await; let results = result.unwrap(); - let paths = extract_paths_from_results(&results.content.as_str().unwrap()); + let paths = extract_paths_from_results(results.content.as_str().unwrap()); assert!( paths.is_empty(), "grep_tool 
should not search .mymetadata files (file_scan_exclusions)" @@ -997,7 +993,7 @@ mod tests { }) .await; let results = result.unwrap(); - let paths = extract_paths_from_results(&results.content.as_str().unwrap()); + let paths = extract_paths_from_results(results.content.as_str().unwrap()); assert!( paths.is_empty(), "grep_tool should not search .mysecrets (private_files)" @@ -1022,7 +1018,7 @@ mod tests { }) .await; let results = result.unwrap(); - let paths = extract_paths_from_results(&results.content.as_str().unwrap()); + let paths = extract_paths_from_results(results.content.as_str().unwrap()); assert!( paths.is_empty(), "grep_tool should not search .privatekey files (private_files)" @@ -1047,7 +1043,7 @@ mod tests { }) .await; let results = result.unwrap(); - let paths = extract_paths_from_results(&results.content.as_str().unwrap()); + let paths = extract_paths_from_results(results.content.as_str().unwrap()); assert!( paths.is_empty(), "grep_tool should not search .mysensitive files (private_files)" @@ -1073,7 +1069,7 @@ mod tests { }) .await; let results = result.unwrap(); - let paths = extract_paths_from_results(&results.content.as_str().unwrap()); + let paths = extract_paths_from_results(results.content.as_str().unwrap()); assert!( paths.iter().any(|p| p.contains("normal_file.rs")), "Should be able to search normal files" @@ -1100,7 +1096,7 @@ mod tests { }) .await; let results = result.unwrap(); - let paths = extract_paths_from_results(&results.content.as_str().unwrap()); + let paths = extract_paths_from_results(results.content.as_str().unwrap()); assert!( paths.is_empty(), "grep_tool should not allow escaping project boundaries with relative paths" @@ -1206,7 +1202,7 @@ mod tests { .unwrap(); let content = result.content.as_str().unwrap(); - let paths = extract_paths_from_results(&content); + let paths = extract_paths_from_results(content); // Should find matches in non-private files assert!( @@ -1271,7 +1267,7 @@ mod tests { .unwrap(); let content = 
result.content.as_str().unwrap(); - let paths = extract_paths_from_results(&content); + let paths = extract_paths_from_results(content); // Should only find matches in worktree1 *.rs files (excluding private ones) assert!( diff --git a/crates/assistant_tools/src/project_notifications_tool.rs b/crates/assistant_tools/src/project_notifications_tool.rs index c65cfd0ca76d91f454982ded5f2893159ab7a32a..e30d80207dae4de1e69efe99724a2a5343b57664 100644 --- a/crates/assistant_tools/src/project_notifications_tool.rs +++ b/crates/assistant_tools/src/project_notifications_tool.rs @@ -81,7 +81,7 @@ fn fit_patch_to_size(patch: &str, max_size: usize) -> String { // Compression level 1: remove context lines in diff bodies, but // leave the counts and positions of inserted/deleted lines let mut current_size = patch.len(); - let mut file_patches = split_patch(&patch); + let mut file_patches = split_patch(patch); file_patches.sort_by_key(|patch| patch.len()); let compressed_patches = file_patches .iter() diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index 68b870e40f4af23f4eb27f68c8d45d4789f6bc48..a6e984fca6f2704a6dbe4c16d5e659f0c8bfe141 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -68,7 +68,7 @@ impl Tool for ReadFileTool { } fn icon(&self) -> IconName { - IconName::ToolRead + IconName::ToolSearch } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { @@ -201,7 +201,7 @@ impl Tool for ReadFileTool { buffer .file() .as_ref() - .map_or(true, |file| !file.disk_state().exists()) + .is_none_or(|file| !file.disk_state().exists()) })? 
{ anyhow::bail!("{file_path} not found"); } diff --git a/crates/assistant_tools/src/schema.rs b/crates/assistant_tools/src/schema.rs index 10a8bf0acd99131d2c0a80411072f312c9a42f50..dab7384efd8ba23669db645c87dcf79e95538d3a 100644 --- a/crates/assistant_tools/src/schema.rs +++ b/crates/assistant_tools/src/schema.rs @@ -43,12 +43,11 @@ impl Transform for ToJsonSchemaSubsetTransform { fn transform(&mut self, schema: &mut Schema) { // Ensure that the type field is not an array, this happens when we use // Option, the type will be [T, "null"]. - if let Some(type_field) = schema.get_mut("type") { - if let Some(types) = type_field.as_array() { - if let Some(first_type) = types.first() { - *type_field = first_type.clone(); - } - } + if let Some(type_field) = schema.get_mut("type") + && let Some(types) = type_field.as_array() + && let Some(first_type) = types.first() + { + *type_field = first_type.clone(); } // oneOf is not supported, use anyOf instead diff --git a/crates/assistant_tools/src/terminal_tool.rs b/crates/assistant_tools/src/terminal_tool.rs index 46227f130d6c706c598466045057075786b21cd3..1605003671621b90e58a5f62e521c0aba2c990c6 100644 --- a/crates/assistant_tools/src/terminal_tool.rs +++ b/crates/assistant_tools/src/terminal_tool.rs @@ -8,14 +8,14 @@ use anyhow::{Context as _, Result, anyhow}; use assistant_tool::{Tool, ToolCard, ToolResult, ToolUseStatus}; use futures::{FutureExt as _, future::Shared}; use gpui::{ - Animation, AnimationExt, AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, - TextStyleRefinement, Transformation, WeakEntity, Window, percentage, + AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, TextStyleRefinement, + WeakEntity, Window, }; use language::LineEnding; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use markdown::{Markdown, MarkdownElement, MarkdownStyle}; use portable_pty::{CommandBuilder, PtySize, native_pty_system}; -use project::{Project, 
terminals::TerminalKind}; +use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; @@ -28,7 +28,7 @@ use std::{ }; use terminal_view::TerminalView; use theme::ThemeSettings; -use ui::{Disclosure, Tooltip, prelude::*}; +use ui::{CommonAnimationExt, Disclosure, Tooltip, prelude::*}; use util::{ ResultExt, get_system_shell, markdown::MarkdownInlineCode, size::format_file_size, time::duration_alt_display, @@ -59,12 +59,9 @@ impl TerminalTool { } if which::which("bash").is_ok() { - log::info!("agent selected bash for terminal tool"); "bash".into() } else { - let shell = get_system_shell(); - log::info!("agent selected {shell} for terminal tool"); - shell + get_system_shell() } }); Self { @@ -105,7 +102,7 @@ impl Tool for TerminalTool { let first_line = lines.next().unwrap_or_default(); let remaining_line_count = lines.count(); match remaining_line_count { - 0 => MarkdownInlineCode(&first_line).to_string(), + 0 => MarkdownInlineCode(first_line).to_string(), 1 => MarkdownInlineCode(&format!( "{} - {} more line", first_line, remaining_line_count @@ -216,21 +213,20 @@ impl Tool for TerminalTool { async move |cx| { let program = program.await; let env = env.await; - let terminal = project + project .update(cx, |project, cx| { - project.create_terminal( - TerminalKind::Task(task::SpawnInTerminal { + project.create_terminal_task( + task::SpawnInTerminal { command: Some(program), args, cwd, env, ..Default::default() - }), + }, cx, ) })? - .await; - terminal + .await } }); @@ -353,7 +349,7 @@ fn process_content( if is_empty { "Command executed successfully.".to_string() } else { - content.to_string() + content } } Some(exit_status) => { @@ -387,7 +383,7 @@ fn working_dir( let project = project.read(cx); let cd = &input.cd; - if cd == "." || cd == "" { + if cd == "." || cd.is_empty() { // Accept "." or "" as meaning "the one worktree" if we only have one worktree. 
let mut worktrees = project.worktrees(cx); @@ -412,10 +408,8 @@ fn working_dir( { return Ok(Some(input_path.into())); } - } else { - if let Some(worktree) = project.worktree_for_root_name(cd, cx) { - return Ok(Some(worktree.read(cx).abs_path().to_path_buf())); - } + } else if let Some(worktree) = project.worktree_for_root_name(cd, cx) { + return Ok(Some(worktree.read(cx).abs_path().to_path_buf())); } anyhow::bail!("`cd` directory {cd:?} was not in any of the project's worktrees."); @@ -528,11 +522,7 @@ impl ToolCard for TerminalToolCard { Icon::new(IconName::ArrowCircle) .size(IconSize::XSmall) .color(Color::Info) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - ), + .with_rotate_animation(2), ) }) .when(tool_failed || command_failed, |header| { diff --git a/crates/assistant_tools/src/ui/tool_call_card_header.rs b/crates/assistant_tools/src/ui/tool_call_card_header.rs index b71453373feb84d91168576a5bc7c22f8d883aa9..b41f19432f99685cf745f684228169b53939fffb 100644 --- a/crates/assistant_tools/src/ui/tool_call_card_header.rs +++ b/crates/assistant_tools/src/ui/tool_call_card_header.rs @@ -101,14 +101,11 @@ impl RenderOnce for ToolCallCardHeader { }) .when_some(secondary_text, |this, secondary_text| { this.child(bullet_divider()) - .child(div().text_size(font_size).child(secondary_text.clone())) + .child(div().text_size(font_size).child(secondary_text)) }) .when_some(code_path, |this, code_path| { - this.child(bullet_divider()).child( - Label::new(code_path.clone()) - .size(LabelSize::Small) - .inline_code(cx), - ) + this.child(bullet_divider()) + .child(Label::new(code_path).size(LabelSize::Small).inline_code(cx)) }) .with_animation( "loading-label", diff --git a/crates/assistant_tools/src/web_search_tool.rs b/crates/assistant_tools/src/web_search_tool.rs index 47a6958b7ad278f01fb654d23b68360d562d73e9..dbcca0a1f6f2d5f679fd240a5bfe64c6c9705256 100644 --- 
a/crates/assistant_tools/src/web_search_tool.rs +++ b/crates/assistant_tools/src/web_search_tool.rs @@ -193,10 +193,7 @@ impl ToolCard for WebSearchToolCard { ) } }) - .on_click({ - let url = url.clone(); - move |_, _, cx| cx.open_url(&url) - }) + .on_click(move |_, _, cx| cx.open_url(&url)) })) .into_any(), ), diff --git a/crates/audio/Cargo.toml b/crates/audio/Cargo.toml index f1f40ad6540eea847313efdb4dceedbd4b27f6df..08e0df424dcdaa15cfd78fddaf5758fb9b8d7e0b 100644 --- a/crates/audio/Cargo.toml +++ b/crates/audio/Cargo.toml @@ -14,10 +14,19 @@ doctest = false [dependencies] anyhow.workspace = true +async-tar.workspace = true collections.workspace = true -derive_more.workspace = true +crossbeam.workspace = true gpui.workspace = true +log.workspace = true parking_lot.workspace = true -rodio = { workspace = true, features = ["wav", "playback", "tracing"] } +rodio = { workspace = true, features = [ "wav", "playback", "wav_output" ] } +schemars.workspace = true +serde.workspace = true +settings.workspace = true +smol.workspace = true util.workspace = true workspace-hack.workspace = true + +[target.'cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))'.dependencies] +libwebrtc = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks" } diff --git a/crates/audio/src/assets.rs b/crates/audio/src/assets.rs deleted file mode 100644 index fd5c935d875960f4fd9bf30494301f4811b22448..0000000000000000000000000000000000000000 --- a/crates/audio/src/assets.rs +++ /dev/null @@ -1,54 +0,0 @@ -use std::{io::Cursor, sync::Arc}; - -use anyhow::{Context as _, Result}; -use collections::HashMap; -use gpui::{App, AssetSource, Global}; -use rodio::{Decoder, Source, source::Buffered}; - -type Sound = Buffered>>>; - -pub struct SoundRegistry { - cache: Arc>>, - assets: Box, -} - -struct GlobalSoundRegistry(Arc); - -impl Global for GlobalSoundRegistry {} - -impl SoundRegistry { - pub fn new(source: impl 
AssetSource) -> Arc { - Arc::new(Self { - cache: Default::default(), - assets: Box::new(source), - }) - } - - pub fn global(cx: &App) -> Arc { - cx.global::().0.clone() - } - - pub(crate) fn set_global(source: impl AssetSource, cx: &mut App) { - cx.set_global(GlobalSoundRegistry(SoundRegistry::new(source))); - } - - pub fn get(&self, name: &str) -> Result + use<>> { - if let Some(wav) = self.cache.lock().get(name) { - return Ok(wav.clone()); - } - - let path = format!("sounds/{}.wav", name); - let bytes = self - .assets - .load(&path)? - .map(anyhow::Ok) - .with_context(|| format!("No asset available for path {path}"))?? - .into_owned(); - let cursor = Cursor::new(bytes); - let source = Decoder::new(cursor)?.buffered(); - - self.cache.lock().insert(name.to_string(), source.clone()); - - Ok(source) - } -} diff --git a/crates/audio/src/audio.rs b/crates/audio/src/audio.rs index 44baa16aa20a3e4b7651744974cfc085dcde7fb1..511d00671ae99789610bac1f7e30b63ca29ac480 100644 --- a/crates/audio/src/audio.rs +++ b/crates/audio/src/audio.rs @@ -1,16 +1,54 @@ -use assets::SoundRegistry; -use derive_more::{Deref, DerefMut}; -use gpui::{App, AssetSource, BorrowAppContext, Global}; -use rodio::{OutputStream, OutputStreamBuilder}; +use anyhow::{Context as _, Result}; +use collections::HashMap; +use gpui::{App, AsyncApp, BackgroundExecutor, BorrowAppContext, Global}; +use libwebrtc::native::apm; +use log::info; +use parking_lot::Mutex; +use rodio::{ + Decoder, OutputStream, OutputStreamBuilder, Source, + cpal::Sample, + mixer::Mixer, + nz, + source::{Buffered, LimitSettings, UniformSourceIterator}, +}; +use settings::Settings; +use std::{ + io::Cursor, + num::NonZero, + path::PathBuf, + sync::{Arc, atomic::Ordering}, + time::Duration, +}; use util::ResultExt; -mod assets; +mod audio_settings; +mod replays; +mod rodio_ext; +pub use audio_settings::AudioSettings; +pub use rodio_ext::RodioExt; -pub fn init(source: impl AssetSource, cx: &mut App) { - SoundRegistry::set_global(source, cx); 
- cx.set_global(GlobalAudio(Audio::new())); +use crate::audio_settings::LIVE_SETTINGS; + +// NOTE: We used to use WebRTC's mixer which only supported +// 16kHz, 32kHz and 48kHz. As 48 is the most common "next step up" +// for audio output devices like speakers/bluetooth, we just hard-code +// this; and downsample when we need to. +// +// Since most noise cancelling requires 16kHz we will move to +// that in the future. +pub const SAMPLE_RATE: NonZero = nz!(48000); +pub const CHANNEL_COUNT: NonZero = nz!(2); +pub const BUFFER_SIZE: usize = // echo canceller and livekit want 10ms of audio + (SAMPLE_RATE.get() as usize / 100) * CHANNEL_COUNT.get() as usize; + +pub const REPLAY_DURATION: Duration = Duration::from_secs(30); + +pub fn init(cx: &mut App) { + AudioSettings::register(cx); + LIVE_SETTINGS.initialize(cx); } +#[derive(Debug, Copy, Clone, Eq, Hash, PartialEq)] pub enum Sound { Joined, Leave, @@ -35,49 +73,196 @@ impl Sound { } } -#[derive(Default)] pub struct Audio { output_handle: Option, + output_mixer: Option, + pub echo_canceller: Arc>, + source_cache: HashMap>>>>, + replays: replays::Replays, } -#[derive(Deref, DerefMut)] -struct GlobalAudio(Audio); +impl Default for Audio { + fn default() -> Self { + Self { + output_handle: Default::default(), + output_mixer: Default::default(), + echo_canceller: Arc::new(Mutex::new(apm::AudioProcessingModule::new( + true, false, false, false, + ))), + source_cache: Default::default(), + replays: Default::default(), + } + } +} -impl Global for GlobalAudio {} +impl Global for Audio {} impl Audio { - pub fn new() -> Self { - Self::default() - } - - fn ensure_output_exists(&mut self) -> Option<&OutputStream> { + fn ensure_output_exists(&mut self) -> Result<&Mixer> { if self.output_handle.is_none() { - self.output_handle = OutputStreamBuilder::open_default_stream().log_err(); + self.output_handle = Some( + OutputStreamBuilder::open_default_stream() + .context("Could not open default output stream")?, + ); + if let 
Some(output_handle) = &self.output_handle { + let (mixer, source) = rodio::mixer::mixer(CHANNEL_COUNT, SAMPLE_RATE); + // or the mixer will end immediately as it's empty. + mixer.add(rodio::source::Zero::new(CHANNEL_COUNT, SAMPLE_RATE)); + self.output_mixer = Some(mixer); + + let echo_canceller = Arc::clone(&self.echo_canceller); + let source = source.inspect_buffer::(move |buffer| { + let mut buf: [i16; _] = buffer.map(|s| s.to_sample()); + echo_canceller + .lock() + .process_reverse_stream( + &mut buf, + SAMPLE_RATE.get() as i32, + CHANNEL_COUNT.get().into(), + ) + .expect("Audio input and output threads should not panic"); + }); + output_handle.mixer().add(source); + } } - self.output_handle.as_ref() + Ok(self + .output_mixer + .as_ref() + .expect("we only get here if opening the outputstream succeeded")) + } + + pub fn save_replays( + &self, + executor: BackgroundExecutor, + ) -> gpui::Task> { + self.replays.replays_to_tar(executor) + } + + pub fn open_microphone(voip_parts: VoipParts) -> anyhow::Result { + let stream = rodio::microphone::MicrophoneBuilder::new() + .default_device()? + .default_config()? + .prefer_sample_rates([SAMPLE_RATE, SAMPLE_RATE.saturating_mul(nz!(2))]) + .prefer_channel_counts([nz!(1), nz!(2)]) + .prefer_buffer_sizes(512..)
+ .open_stream()?; + info!("Opened microphone: {:?}", stream.config()); + + let (replay, stream) = UniformSourceIterator::new(stream, CHANNEL_COUNT, SAMPLE_RATE) + .limit(LimitSettings::live_performance()) + .process_buffer::(move |buffer| { + let mut int_buffer: [i16; _] = buffer.map(|s| s.to_sample()); + if voip_parts + .echo_canceller + .lock() + .process_stream( + &mut int_buffer, + SAMPLE_RATE.get() as i32, + CHANNEL_COUNT.get() as i32, + ) + .context("livekit audio processor error") + .log_err() + .is_some() + { + for (sample, processed) in buffer.iter_mut().zip(&int_buffer) { + *sample = (*processed).to_sample(); + } + } + }) + .automatic_gain_control(1.0, 4.0, 0.0, 5.0) + .periodic_access(Duration::from_millis(100), move |agc_source| { + agc_source.set_enabled(LIVE_SETTINGS.control_input_volume.load(Ordering::Relaxed)); + }) + .replayable(REPLAY_DURATION) + .expect("REPLAY_DURATION is longer then 100ms"); + + voip_parts + .replays + .add_voip_stream("local microphone".to_string(), replay); + Ok(stream) + } + + pub fn play_voip_stream( + source: impl rodio::Source + Send + 'static, + speaker_name: String, + is_staff: bool, + cx: &mut App, + ) -> anyhow::Result<()> { + let (replay_source, source) = source + .automatic_gain_control(1.0, 4.0, 0.0, 5.0) + .periodic_access(Duration::from_millis(100), move |agc_source| { + agc_source.set_enabled(LIVE_SETTINGS.control_input_volume.load(Ordering::Relaxed)); + }) + .replayable(REPLAY_DURATION) + .expect("REPLAY_DURATION is longer then 100ms"); + + cx.update_default_global(|this: &mut Self, _cx| { + let output_mixer = this + .ensure_output_exists() + .context("Could not get output mixer")?; + output_mixer.add(source); + if is_staff { + this.replays.add_voip_stream(speaker_name, replay_source); + } + Ok(()) + }) } pub fn play_sound(sound: Sound, cx: &mut App) { - if !cx.has_global::() { - return; - } + cx.update_default_global(|this: &mut Self, cx| { + let source = this.sound_source(sound, cx).log_err()?; + let 
output_mixer = this + .ensure_output_exists() + .context("Could not get output mixer") + .log_err()?; - cx.update_global::(|this, cx| { - let output_handle = this.ensure_output_exists()?; - let source = SoundRegistry::global(cx).get(sound.file()).log_err()?; - output_handle.mixer().add(source); + output_mixer.add(source); Some(()) }); } pub fn end_call(cx: &mut App) { - if !cx.has_global::() { - return; - } - - cx.update_global::(|this, _| { + cx.update_default_global(|this: &mut Self, _cx| { this.output_handle.take(); }); } + + fn sound_source(&mut self, sound: Sound, cx: &App) -> Result> { + if let Some(wav) = self.source_cache.get(&sound) { + return Ok(wav.clone()); + } + + let path = format!("sounds/{}.wav", sound.file()); + let bytes = cx + .asset_source() + .load(&path)? + .map(anyhow::Ok) + .with_context(|| format!("No asset available for path {path}"))?? + .into_owned(); + let cursor = Cursor::new(bytes); + let source = Decoder::new(cursor)?.buffered(); + + self.source_cache.insert(sound, source.clone()); + + Ok(source) + } +} + +pub struct VoipParts { + echo_canceller: Arc>, + replays: replays::Replays, +} + +impl VoipParts { + pub fn new(cx: &AsyncApp) -> anyhow::Result { + let (apm, replays) = cx.try_read_default_global::(|audio, _| { + (Arc::clone(&audio.echo_canceller), audio.replays.clone()) + })?; + + Ok(Self { + echo_canceller: apm, + replays, + }) + } } diff --git a/crates/audio/src/audio_settings.rs b/crates/audio/src/audio_settings.rs new file mode 100644 index 0000000000000000000000000000000000000000..43edb8d60d96122d5515ec7274a6b5725b247ca0 --- /dev/null +++ b/crates/audio/src/audio_settings.rs @@ -0,0 +1,96 @@ +use std::sync::atomic::{AtomicBool, Ordering}; + +use anyhow::Result; +use gpui::App; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsStore, SettingsUi}; + +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug, SettingsUi)] +pub struct 
AudioSettings { + /// Opt into the new audio system. + #[serde(rename = "experimental.rodio_audio", default)] + pub rodio_audio: bool, // default is false + /// Requires 'rodio_audio: true' + /// + /// Use the new audio system's automatic gain control for your microphone. + /// This affects how loud you sound to others. + #[serde(rename = "experimental.control_input_volume", default)] + pub control_input_volume: bool, + /// Requires 'rodio_audio: true' + /// + /// Use the new audio system's automatic gain control on everyone in the + /// call. This makes call members who are too quiet louder and those who are + /// too loud quieter. This only affects how things sound for you. + #[serde(rename = "experimental.control_output_volume", default)] + pub control_output_volume: bool, +} + +/// Configuration of audio in Zed. +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug, SettingsUi, SettingsKey)] +#[serde(default)] +#[settings_key(key = "audio")] +pub struct AudioSettingsContent { + /// Opt into the new audio system. + #[serde(rename = "experimental.rodio_audio", default)] + pub rodio_audio: bool, // default is false + /// Requires 'rodio_audio: true' + /// + /// Use the new audio system's automatic gain control for your microphone. + /// This affects how loud you sound to others. + #[serde(rename = "experimental.control_input_volume", default)] + pub control_input_volume: bool, + /// Requires 'rodio_audio: true' + /// + /// Use the new audio system's automatic gain control on everyone in the + /// call. This makes call members who are too quiet louder and those who are + /// too loud quieter. This only affects how things sound for you.
+ #[serde(rename = "experimental.control_output_volume", default)] + pub control_output_volume: bool, +} + +impl Settings for AudioSettings { + type FileContent = AudioSettingsContent; + + fn load(sources: SettingsSources, _cx: &mut App) -> Result { + sources.json_merge() + } + + fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} +} + +pub(crate) struct LiveSettings { + pub(crate) control_input_volume: AtomicBool, + pub(crate) control_output_volume: AtomicBool, +} + +impl LiveSettings { + pub(crate) fn initialize(&self, cx: &mut App) { + cx.observe_global::(move |cx| { + LIVE_SETTINGS.control_input_volume.store( + AudioSettings::get_global(cx).control_input_volume, + Ordering::Relaxed, + ); + LIVE_SETTINGS.control_output_volume.store( + AudioSettings::get_global(cx).control_output_volume, + Ordering::Relaxed, + ); + }) + .detach(); + + let init_settings = AudioSettings::get_global(cx); + LIVE_SETTINGS + .control_input_volume + .store(init_settings.control_input_volume, Ordering::Relaxed); + LIVE_SETTINGS + .control_output_volume + .store(init_settings.control_output_volume, Ordering::Relaxed); + } +} + +/// Allows access to settings from the audio thread. Updated by +/// observer of SettingsStore. 
+pub(crate) static LIVE_SETTINGS: LiveSettings = LiveSettings { + control_input_volume: AtomicBool::new(true), + control_output_volume: AtomicBool::new(true), +}; diff --git a/crates/audio/src/replays.rs b/crates/audio/src/replays.rs new file mode 100644 index 0000000000000000000000000000000000000000..bb21df51e5642bf633d068d544690cb26a239151 --- /dev/null +++ b/crates/audio/src/replays.rs @@ -0,0 +1,77 @@ +use anyhow::{Context, anyhow}; +use async_tar::{Builder, Header}; +use gpui::{BackgroundExecutor, Task}; + +use collections::HashMap; +use parking_lot::Mutex; +use rodio::Source; +use smol::fs::File; +use std::{io, path::PathBuf, sync::Arc, time::Duration}; + +use crate::{REPLAY_DURATION, rodio_ext::Replay}; + +#[derive(Default, Clone)] +pub(crate) struct Replays(Arc>>); + +impl Replays { + pub(crate) fn add_voip_stream(&self, stream_name: String, source: Replay) { + let mut map = self.0.lock(); + map.retain(|_, replay| replay.source_is_active()); + map.insert(stream_name, source); + } + + pub(crate) fn replays_to_tar( + &self, + executor: BackgroundExecutor, + ) -> Task> { + let map = Arc::clone(&self.0); + executor.spawn(async move { + let recordings: Vec<_> = map + .lock() + .iter_mut() + .map(|(name, replay)| { + let queued = REPLAY_DURATION.min(replay.duration_ready()); + (name.clone(), replay.take_duration(queued).record()) + }) + .collect(); + let longest = recordings + .iter() + .map(|(_, r)| { + r.total_duration() + .expect("SamplesBuffer always returns a total duration") + }) + .max() + .ok_or(anyhow!("There is no audio to capture"))?; + + let path = std::env::current_dir() + .context("Could not get current dir")? 
+ .join("replays.tar"); + let tar = File::create(&path) + .await + .context("Could not create file for tar")?; + + let mut tar = Builder::new(tar); + + for (name, recording) in recordings { + let mut writer = io::Cursor::new(Vec::new()); + rodio::wav_to_writer(recording, &mut writer).context("failed to encode wav")?; + let wav_data = writer.into_inner(); + let path = name.replace(' ', "_") + ".wav"; + let mut header = Header::new_gnu(); + // rw permissions for everyone + header.set_mode(0o666); + header.set_size(wav_data.len() as u64); + tar.append_data(&mut header, path, wav_data.as_slice()) + .await + .context("failed to apped wav to tar")?; + } + tar.into_inner() + .await + .context("Could not finish writing tar")? + .sync_all() + .await + .context("Could not flush tar file to disk")?; + Ok((path, longest)) + }) + } +} diff --git a/crates/audio/src/rodio_ext.rs b/crates/audio/src/rodio_ext.rs new file mode 100644 index 0000000000000000000000000000000000000000..4e9430a0b9462448b879f653f9ddcb06ef892cdb --- /dev/null +++ b/crates/audio/src/rodio_ext.rs @@ -0,0 +1,593 @@ +use std::{ + sync::{ + Arc, Mutex, + atomic::{AtomicBool, Ordering}, + }, + time::Duration, +}; + +use crossbeam::queue::ArrayQueue; +use rodio::{ChannelCount, Sample, SampleRate, Source}; + +#[derive(Debug)] +pub struct ReplayDurationTooShort; + +pub trait RodioExt: Source + Sized { + fn process_buffer(self, callback: F) -> ProcessBuffer + where + F: FnMut(&mut [Sample; N]); + fn inspect_buffer(self, callback: F) -> InspectBuffer + where + F: FnMut(&[Sample; N]); + fn replayable( + self, + duration: Duration, + ) -> Result<(Replay, Replayable), ReplayDurationTooShort>; + fn take_samples(self, n: usize) -> TakeSamples; +} + +impl RodioExt for S { + fn process_buffer(self, callback: F) -> ProcessBuffer + where + F: FnMut(&mut [Sample; N]), + { + ProcessBuffer { + inner: self, + callback, + buffer: [0.0; N], + next: N, + } + } + fn inspect_buffer(self, callback: F) -> InspectBuffer + where + F: 
FnMut(&[Sample; N]), + { + InspectBuffer { + inner: self, + callback, + buffer: [0.0; N], + free: 0, + } + } + /// Maintains a live replay with a history of at least `duration` seconds. + /// + /// Note: + /// History can be 100ms longer if the source drops before or while the + /// replay is being read + /// + /// # Errors + /// If duration is smaller than 100ms + fn replayable( + self, + duration: Duration, + ) -> Result<(Replay, Replayable), ReplayDurationTooShort> { + if duration < Duration::from_millis(100) { + return Err(ReplayDurationTooShort); + } + + let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize; + let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64; + let samples_to_queue = + (samples_to_queue as usize).next_multiple_of(self.channels().get().into()); + + let chunk_size = + (samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize); + let chunks_to_queue = samples_to_queue.div_ceil(chunk_size); + + let is_active = Arc::new(AtomicBool::new(true)); + let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size)); + Ok(( + Replay { + rx: Arc::clone(&queue), + buffer: Vec::new().into_iter(), + sleep_duration: duration / 2, + sample_rate: self.sample_rate(), + channel_count: self.channels(), + source_is_active: is_active.clone(), + }, + Replayable { + tx: queue, + inner: self, + buffer: Vec::with_capacity(chunk_size), + chunk_size, + is_active, + }, + )) + } + fn take_samples(self, n: usize) -> TakeSamples { + TakeSamples { + inner: self, + left_to_take: n, + } + } +} + +pub struct TakeSamples { + inner: S, + left_to_take: usize, +} + +impl Iterator for TakeSamples { + type Item = Sample; + + fn next(&mut self) -> Option { + if self.left_to_take == 0 { + None + } else { + self.left_to_take -= 1; + self.inner.next() + } + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.left_to_take)) + } +} + +impl Source for TakeSamples { + fn current_span_len(&self)
-> Option { + None // does not support spans + } + + fn channels(&self) -> ChannelCount { + self.inner.channels() + } + + fn sample_rate(&self) -> SampleRate { + self.inner.sample_rate() + } + + fn total_duration(&self) -> Option { + Some(Duration::from_secs_f64( + self.left_to_take as f64 + / self.sample_rate().get() as f64 + / self.channels().get() as f64, + )) + } +} + +#[derive(Debug)] +struct ReplayQueue { + inner: ArrayQueue>, + normal_chunk_len: usize, + /// The last chunk in the queue may be smaller than + /// the normal chunk size. This is always equal to the + /// size of the last element in the queue. + /// (so normally chunk_size) + last_chunk: Mutex>, +} + +impl ReplayQueue { + fn new(queue_len: usize, chunk_size: usize) -> Self { + Self { + inner: ArrayQueue::new(queue_len), + normal_chunk_len: chunk_size, + last_chunk: Mutex::new(Vec::new()), + } + } + /// Returns the length in samples + fn len(&self) -> usize { + self.inner.len().saturating_sub(1) * self.normal_chunk_len + + self + .last_chunk + .lock() + .expect("Self::push_last can not poison this lock") + .len() + } + + fn pop(&self) -> Option> { + self.inner.pop() // removes element that was inserted first + } + + fn push_last(&self, mut samples: Vec) { + let mut last_chunk = self + .last_chunk + .lock() + .expect("Self::len can not poison this lock"); + std::mem::swap(&mut *last_chunk, &mut samples); + } + + fn push_normal(&self, samples: Vec) { + let _pushed_out_of_ringbuf = self.inner.force_push(samples); + } +} + +pub struct ProcessBuffer +where + S: Source + Sized, + F: FnMut(&mut [Sample; N]), +{ + inner: S, + callback: F, + /// Buffer used for both input and output. + buffer: [Sample; N], + /// Next already processed sample is at this index + /// in buffer.
+ /// + /// If this is equal to the length of the buffer we have no more samples and + we must get new ones and process them + next: usize, +} + +impl Iterator for ProcessBuffer +where + S: Source + Sized, + F: FnMut(&mut [Sample; N]), +{ + type Item = Sample; + + fn next(&mut self) -> Option { + self.next += 1; + if self.next < self.buffer.len() { + let sample = self.buffer[self.next]; + return Some(sample); + } + + for sample in &mut self.buffer { + *sample = self.inner.next()? + } + (self.callback)(&mut self.buffer); + + self.next = 0; + Some(self.buffer[0]) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +impl Source for ProcessBuffer +where + S: Source + Sized, + F: FnMut(&mut [Sample; N]), +{ + fn current_span_len(&self) -> Option { + None + } + + fn channels(&self) -> rodio::ChannelCount { + self.inner.channels() + } + + fn sample_rate(&self) -> rodio::SampleRate { + self.inner.sample_rate() + } + + fn total_duration(&self) -> Option { + self.inner.total_duration() + } +} + +pub struct InspectBuffer +where + S: Source + Sized, + F: FnMut(&[Sample; N]), +{ + inner: S, + callback: F, + /// Stores already emitted samples, once it's full we call the callback. + buffer: [Sample; N], + /// Next free element in buffer. If this is equal to the buffer length + /// we have no more free elements.
+ free: usize, +} + +impl Iterator for InspectBuffer +where + S: Source + Sized, + F: FnMut(&[Sample; N]), +{ + type Item = Sample; + + fn next(&mut self) -> Option { + let Some(sample) = self.inner.next() else { + return None; + }; + + self.buffer[self.free] = sample; + self.free += 1; + + if self.free == self.buffer.len() { + (self.callback)(&self.buffer); + self.free = 0 + } + + Some(sample) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +impl Source for InspectBuffer +where + S: Source + Sized, + F: FnMut(&[Sample; N]), +{ + fn current_span_len(&self) -> Option { + None + } + + fn channels(&self) -> rodio::ChannelCount { + self.inner.channels() + } + + fn sample_rate(&self) -> rodio::SampleRate { + self.inner.sample_rate() + } + + fn total_duration(&self) -> Option { + self.inner.total_duration() + } +} + +#[derive(Debug)] +pub struct Replayable { + inner: S, + buffer: Vec, + chunk_size: usize, + tx: Arc, + is_active: Arc, +} + +impl Iterator for Replayable { + type Item = Sample; + + fn next(&mut self) -> Option { + if let Some(sample) = self.inner.next() { + self.buffer.push(sample); + if self.buffer.len() == self.chunk_size { + self.tx.push_normal(std::mem::take(&mut self.buffer)); + } + Some(sample) + } else { + let last_chunk = std::mem::take(&mut self.buffer); + self.tx.push_last(last_chunk); + self.is_active.store(false, Ordering::Relaxed); + None + } + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +impl Source for Replayable { + fn current_span_len(&self) -> Option { + self.inner.current_span_len() + } + + fn channels(&self) -> ChannelCount { + self.inner.channels() + } + + fn sample_rate(&self) -> SampleRate { + self.inner.sample_rate() + } + + fn total_duration(&self) -> Option { + self.inner.total_duration() + } +} + +#[derive(Debug)] +pub struct Replay { + rx: Arc, + buffer: std::vec::IntoIter, + sleep_duration: Duration, + sample_rate: SampleRate, + channel_count: ChannelCount, 
+ source_is_active: Arc, +} + +impl Replay { + pub fn source_is_active(&self) -> bool { + // - source could return None and not drop + // - source could be dropped before returning None + self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2 + } + + /// Duration of what is in the buffer and can be returned without blocking. + pub fn duration_ready(&self) -> Duration { + let samples_per_second = self.channels().get() as u32 * self.sample_rate().get(); + + let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64; + Duration::from_secs_f64(seconds_queued) + } + + /// Number of samples in the buffer and can be returned without blocking. + pub fn samples_ready(&self) -> usize { + self.rx.len() + self.buffer.len() + } +} + +impl Iterator for Replay { + type Item = Sample; + + fn next(&mut self) -> Option { + if let Some(sample) = self.buffer.next() { + return Some(sample); + } + + loop { + if let Some(new_buffer) = self.rx.pop() { + self.buffer = new_buffer.into_iter(); + return self.buffer.next(); + } + + if !self.source_is_active() { + return None; + } + + std::thread::sleep(self.sleep_duration); + } + } + + fn size_hint(&self) -> (usize, Option) { + ((self.rx.len() + self.buffer.len()), None) + } +} + +impl Source for Replay { + fn current_span_len(&self) -> Option { + None // source is not compatible with spans + } + + fn channels(&self) -> ChannelCount { + self.channel_count + } + + fn sample_rate(&self) -> SampleRate { + self.sample_rate + } + + fn total_duration(&self) -> Option { + None + } +} + +#[cfg(test)] +mod tests { + use rodio::{nz, static_buffer::StaticSamplesBuffer}; + + use super::*; + + const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0]; + + fn test_source() -> StaticSamplesBuffer { + StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES) + } + + mod process_buffer { + use super::*; + + #[test] + fn callback_gets_all_samples() { + let input = test_source(); + + let _ = input + .process_buffer::<{ 
SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES)) + .count(); + } + #[test] + fn callback_modifies_yielded() { + let input = test_source(); + + let yielded: Vec<_> = input + .process_buffer::<{ SAMPLES.len() }, _>(|buffer| { + for sample in buffer { + *sample += 1.0; + } + }) + .collect(); + assert_eq!( + yielded, + SAMPLES.into_iter().map(|s| s + 1.0).collect::>() + ) + } + #[test] + fn source_truncates_to_whole_buffers() { + let input = test_source(); + + let yielded = input + .process_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3])) + .count(); + assert_eq!(yielded, 3) + } + } + + mod inspect_buffer { + use super::*; + + #[test] + fn callback_gets_all_samples() { + let input = test_source(); + + let _ = input + .inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES)) + .count(); + } + #[test] + fn source_does_not_truncate() { + let input = test_source(); + + let yielded = input + .inspect_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3])) + .count(); + assert_eq!(yielded, SAMPLES.len()) + } + } + + mod instant_replay { + use super::*; + + #[test] + fn continues_after_history() { + let input = test_source(); + + let (mut replay, mut source) = input + .replayable(Duration::from_secs(3)) + .expect("longer then 100ms"); + + source.by_ref().take(3).count(); + let yielded: Vec = replay.by_ref().take(3).collect(); + assert_eq!(&yielded, &SAMPLES[0..3],); + + source.count(); + let yielded: Vec = replay.collect(); + assert_eq!(&yielded, &SAMPLES[3..5],); + } + + #[test] + fn keeps_only_latest() { + let input = test_source(); + + let (mut replay, mut source) = input + .replayable(Duration::from_secs(2)) + .expect("longer then 100ms"); + + source.by_ref().take(5).count(); // get all items but do not end the source + let yielded: Vec = replay.by_ref().take(2).collect(); + assert_eq!(&yielded, &SAMPLES[3..5]); + source.count(); // exhaust source + assert_eq!(replay.next(), None); + } + + #[test] + fn 
keeps_correct_amount_of_seconds() { + let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]); + + let (replay, mut source) = input + .replayable(Duration::from_secs(2)) + .expect("longer then 100ms"); + + // exhaust but do not yet end source + source.by_ref().take(40_000).count(); + + // take all samples we can without blocking + let ready = replay.samples_ready(); + let n_yielded = replay.take_samples(ready).count(); + + let max = source.sample_rate().get() * source.channels().get() as u32 * 2; + let margin = 16_000 / 10; // 100ms + assert!(n_yielded as u32 >= max - margin); + } + + #[test] + fn samples_ready() { + let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]); + let (mut replay, source) = input + .replayable(Duration::from_secs(2)) + .expect("longer then 100ms"); + assert_eq!(replay.by_ref().samples_ready(), 0); + + source.take(8000).count(); // half a second + let margin = 16_000 / 10; // 100ms + let ready = replay.samples_ready(); + assert!(ready >= 8000 - margin); + } + } +} diff --git a/crates/auto_update/src/auto_update.rs b/crates/auto_update/src/auto_update.rs index 4d0d2d59843d4cde885340319d261a4e7315765e..f5d4533a9ee042e62752f26b989bc75561c534ae 100644 --- a/crates/auto_update/src/auto_update.rs +++ b/crates/auto_update/src/auto_update.rs @@ -10,7 +10,7 @@ use paths::remote_servers_dir; use release_channel::{AppCommitSha, ReleaseChannel}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources, SettingsStore}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsStore, SettingsUi}; use smol::{fs, io::AsyncReadExt}; use smol::{fs::File, process::Command}; use std::{ @@ -118,14 +118,13 @@ struct AutoUpdateSetting(bool); /// Whether or not to automatically check for updates. 
/// /// Default: true -#[derive(Clone, Copy, Default, JsonSchema, Deserialize, Serialize)] +#[derive(Clone, Copy, Default, JsonSchema, Deserialize, Serialize, SettingsUi, SettingsKey)] #[serde(transparent)] +#[settings_key(key = "auto_update")] struct AutoUpdateSettingContent(bool); impl Settings for AutoUpdateSetting { - const KEY: Option<&'static str> = Some("auto_update"); - - type FileContent = Option; + type FileContent = AutoUpdateSettingContent; fn load(sources: SettingsSources, _: &mut App) -> Result { let auto_update = [ @@ -135,17 +134,19 @@ impl Settings for AutoUpdateSetting { sources.user, ] .into_iter() - .find_map(|value| value.copied().flatten()) - .unwrap_or(sources.default.ok_or_else(Self::missing_default)?); + .find_map(|value| value.copied()) + .unwrap_or(*sources.default); Ok(Self(auto_update.0)) } fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) { - vscode.enum_setting("update.mode", current, |s| match s { + let mut cur = &mut Some(*current); + vscode.enum_setting("update.mode", &mut cur, |s| match s { "none" | "manual" => Some(AutoUpdateSettingContent(false)), _ => Some(AutoUpdateSettingContent(true)), }); + *current = cur.unwrap(); } } @@ -543,7 +544,7 @@ impl AutoUpdater { async fn update(this: Entity, mut cx: AsyncApp) -> Result<()> { let (client, installed_version, previous_status, release_channel) = - this.read_with(&mut cx, |this, cx| { + this.read_with(&cx, |this, cx| { ( this.http_client.clone(), this.current_version, diff --git a/crates/auto_update_helper/src/auto_update_helper.rs b/crates/auto_update_helper/src/auto_update_helper.rs index 3aa57094d38f07400d6077e42203746d0cbb5bff..21ead701b2629960a9f2b5bc639f5d6dcdbc96c5 100644 --- a/crates/auto_update_helper/src/auto_update_helper.rs +++ b/crates/auto_update_helper/src/auto_update_helper.rs @@ -128,23 +128,20 @@ mod windows_impl { #[test] fn test_parse_args() { // launch can be specified via two separate arguments - 
assert_eq!(parse_args(["--launch".into(), "true".into()]).launch, true); - assert_eq!( - parse_args(["--launch".into(), "false".into()]).launch, - false - ); + assert!(parse_args(["--launch".into(), "true".into()]).launch); + assert!(!parse_args(["--launch".into(), "false".into()]).launch); // launch can be specified via one single argument - assert_eq!(parse_args(["--launch=true".into()]).launch, true); - assert_eq!(parse_args(["--launch=false".into()]).launch, false); + assert!(parse_args(["--launch=true".into()]).launch); + assert!(!parse_args(["--launch=false".into()]).launch); // launch defaults to true on no arguments - assert_eq!(parse_args([]).launch, true); + assert!(parse_args([]).launch); // launch defaults to true on invalid arguments - assert_eq!(parse_args(["--launch".into()]).launch, true); - assert_eq!(parse_args(["--launch=".into()]).launch, true); - assert_eq!(parse_args(["--launch=invalid".into()]).launch, true); + assert!(parse_args(["--launch".into()]).launch); + assert!(parse_args(["--launch=".into()]).launch); + assert!(parse_args(["--launch=invalid".into()]).launch); } } } diff --git a/crates/auto_update_helper/src/dialog.rs b/crates/auto_update_helper/src/dialog.rs index 757819df519a533fb79aa21bec5bed8c5a077590..903ac34da227b2929705ff2af72db3770cff6532 100644 --- a/crates/auto_update_helper/src/dialog.rs +++ b/crates/auto_update_helper/src/dialog.rs @@ -186,11 +186,11 @@ unsafe extern "system" fn wnd_proc( }), WM_TERMINATE => { with_dialog_data(hwnd, |data| { - if let Ok(result) = data.borrow_mut().rx.recv() { - if let Err(e) = result { - log::error!("Failed to update Zed: {:?}", e); - show_error(format!("Error: {:?}", e)); - } + if let Ok(result) = data.borrow_mut().rx.recv() + && let Err(e) = result + { + log::error!("Failed to update Zed: {:?}", e); + show_error(format!("Error: {:?}", e)); } }); unsafe { PostQuitMessage(0) }; diff --git a/crates/auto_update_helper/src/updater.rs b/crates/auto_update_helper/src/updater.rs index 
920f8d5fcf3224f8842ff888249c2281412c478a..a48bbccec304a1b49bb0496c21b299f5dd176076 100644 --- a/crates/auto_update_helper/src/updater.rs +++ b/crates/auto_update_helper/src/updater.rs @@ -16,7 +16,7 @@ use crate::windows_impl::WM_JOB_UPDATED; type Job = fn(&Path) -> Result<()>; #[cfg(not(test))] -pub(crate) const JOBS: [Job; 6] = [ +pub(crate) const JOBS: &[Job] = &[ // Delete old files |app_dir| { let zed_executable = app_dir.join("Zed.exe"); @@ -32,6 +32,12 @@ pub(crate) const JOBS: [Job; 6] = [ std::fs::remove_file(&zed_cli) .context(format!("Failed to remove old file {}", zed_cli.display())) }, + |app_dir| { + let zed_wsl = app_dir.join("bin\\zed"); + log::info!("Removing old file: {}", zed_wsl.display()); + std::fs::remove_file(&zed_wsl) + .context(format!("Failed to remove old file {}", zed_wsl.display())) + }, // Copy new files |app_dir| { let zed_executable_source = app_dir.join("install\\Zed.exe"); @@ -65,6 +71,22 @@ pub(crate) const JOBS: [Job; 6] = [ zed_cli_dest.display() )) }, + |app_dir| { + let zed_wsl_source = app_dir.join("install\\bin\\zed"); + let zed_wsl_dest = app_dir.join("bin\\zed"); + log::info!( + "Copying new file {} to {}", + zed_wsl_source.display(), + zed_wsl_dest.display() + ); + std::fs::copy(&zed_wsl_source, &zed_wsl_dest) + .map(|_| ()) + .context(format!( + "Failed to copy new file {} to {}", + zed_wsl_source.display(), + zed_wsl_dest.display() + )) + }, // Clean up installer folder and updates folder |app_dir| { let updates_folder = app_dir.join("updates"); @@ -85,16 +107,12 @@ pub(crate) const JOBS: [Job; 6] = [ ]; #[cfg(test)] -pub(crate) const JOBS: [Job; 2] = [ +pub(crate) const JOBS: &[Job] = &[ |_| { std::thread::sleep(Duration::from_millis(1000)); if let Ok(config) = std::env::var("ZED_AUTO_UPDATE") { match config.as_str() { - "err" => Err(std::io::Error::new( - std::io::ErrorKind::Other, - "Simulated error", - )) - .context("Anyhow!"), + "err" => Err(std::io::Error::other("Simulated error")).context("Anyhow!"), _ => 
panic!("Unknown ZED_AUTO_UPDATE value: {}", config), } } else { @@ -105,11 +123,7 @@ pub(crate) const JOBS: [Job; 2] = [ std::thread::sleep(Duration::from_millis(1000)); if let Ok(config) = std::env::var("ZED_AUTO_UPDATE") { match config.as_str() { - "err" => Err(std::io::Error::new( - std::io::ErrorKind::Other, - "Simulated error", - )) - .context("Anyhow!"), + "err" => Err(std::io::Error::other("Simulated error")).context("Anyhow!"), _ => panic!("Unknown ZED_AUTO_UPDATE value: {}", config), } } else { diff --git a/crates/auto_update_ui/src/auto_update_ui.rs b/crates/auto_update_ui/src/auto_update_ui.rs index 63baef1f7d178045a2a2b5c976ede9ad75adb646..efac14968ea48d93ae35089d239916de1f0a5253 100644 --- a/crates/auto_update_ui/src/auto_update_ui.rs +++ b/crates/auto_update_ui/src/auto_update_ui.rs @@ -1,5 +1,4 @@ use auto_update::AutoUpdater; -use client::proto::UpdateNotification; use editor::{Editor, MultiBuffer}; use gpui::{App, Context, DismissEvent, Entity, Window, actions, prelude::*}; use http_client::HttpClient; @@ -88,10 +87,7 @@ fn view_release_notes_locally( .update_in(cx, |workspace, window, cx| { let project = workspace.project().clone(); let buffer = project.update(cx, |project, cx| { - let buffer = project.create_local_buffer("", markdown, cx); - project - .mark_buffer_as_non_searchable(buffer.read(cx).remote_id(), cx); - buffer + project.create_local_buffer("", markdown, false, cx) }); buffer.update(cx, |buffer, cx| { buffer.edit([(0..0, body.release_notes)], None, cx) @@ -114,7 +110,7 @@ fn view_release_notes_locally( cx, ); workspace.add_item_to_active_pane( - Box::new(markdown_preview.clone()), + Box::new(markdown_preview), None, true, window, @@ -141,6 +137,8 @@ pub fn notify_if_app_was_updated(cx: &mut App) { return; } + struct UpdateNotification; + let should_show_notification = updater.read(cx).should_show_update_notification(cx); cx.spawn(async move |cx| { let should_show_notification = should_show_notification.await?; diff --git 
a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs index 1c6a9bd0a1e745da1dd4577741fc7cb4cab771ad..ec0b4070906fdfd31195668312b3e7b425cd28ee 100644 --- a/crates/bedrock/src/bedrock.rs +++ b/crates/bedrock/src/bedrock.rs @@ -3,6 +3,7 @@ mod models; use anyhow::{Context, Error, Result, anyhow}; use aws_sdk_bedrockruntime as bedrock; pub use aws_sdk_bedrockruntime as bedrock_client; +use aws_sdk_bedrockruntime::types::InferenceConfiguration; pub use aws_sdk_bedrockruntime::types::{ AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice, ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice, @@ -17,7 +18,8 @@ pub use bedrock::types::{ ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse, ImageBlock as BedrockImageBlock, Message as BedrockMessage, ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock, - ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock, + ResponseStream as BedrockResponseStream, SystemContentBlock as BedrockSystemContentBlock, + ToolResultBlock as BedrockToolResultBlock, ToolResultContentBlock as BedrockToolResultContentBlock, ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock, }; @@ -54,14 +56,24 @@ pub async fn stream_completion( )]))); } - if request - .tools - .as_ref() - .map_or(false, |t| !t.tools.is_empty()) - { + if request.tools.as_ref().is_some_and(|t| !t.tools.is_empty()) { response = response.set_tool_config(request.tools); } + let inference_config = InferenceConfiguration::builder() + .max_tokens(request.max_tokens as i32) + .set_temperature(request.temperature) + .set_top_p(request.top_p) + .build(); + + response = response.inference_config(inference_config); + + if let Some(system) = request.system { + if !system.is_empty() { + response = response.system(BedrockSystemContentBlock::Text(system)); + } + } + let output = response .send() 
.await diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs index 69d2ffb84569ef848f88de47f5394a6b25b18e02..c3a793d69d086a8a8c607d34debc5a7034f33f32 100644 --- a/crates/bedrock/src/models.rs +++ b/crates/bedrock/src/models.rs @@ -151,12 +151,12 @@ impl Model { pub fn id(&self) -> &str { match self { - Model::ClaudeSonnet4 => "claude-4-sonnet", - Model::ClaudeSonnet4Thinking => "claude-4-sonnet-thinking", - Model::ClaudeOpus4 => "claude-4-opus", - Model::ClaudeOpus4_1 => "claude-4-opus-1", - Model::ClaudeOpus4Thinking => "claude-4-opus-thinking", - Model::ClaudeOpus4_1Thinking => "claude-4-opus-1-thinking", + Model::ClaudeSonnet4 => "claude-sonnet-4", + Model::ClaudeSonnet4Thinking => "claude-sonnet-4-thinking", + Model::ClaudeOpus4 => "claude-opus-4", + Model::ClaudeOpus4_1 => "claude-opus-4-1", + Model::ClaudeOpus4Thinking => "claude-opus-4-thinking", + Model::ClaudeOpus4_1Thinking => "claude-opus-4-1-thinking", Model::Claude3_5SonnetV2 => "claude-3-5-sonnet-v2", Model::Claude3_5Sonnet => "claude-3-5-sonnet", Model::Claude3Opus => "claude-3-opus", @@ -359,14 +359,12 @@ impl Model { pub fn max_output_tokens(&self) -> u64 { match self { Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096, - Self::Claude3_7Sonnet - | Self::Claude3_7SonnetThinking - | Self::ClaudeSonnet4 - | Self::ClaudeSonnet4Thinking - | Self::ClaudeOpus4 - | Model::ClaudeOpus4Thinking + Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => 128_000, + Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking => 64_000, + Self::ClaudeOpus4 + | Self::ClaudeOpus4Thinking | Self::ClaudeOpus4_1 - | Model::ClaudeOpus4_1Thinking => 128_000, + | Self::ClaudeOpus4_1Thinking => 32_000, Self::Claude3_5SonnetV2 | Self::PalmyraWriterX4 | Self::PalmyraWriterX5 => 8_192, Self::Custom { max_output_tokens, .. 
@@ -784,10 +782,10 @@ mod tests { ); // Test thinking models have different friendly IDs but same request IDs - assert_eq!(Model::ClaudeSonnet4.id(), "claude-4-sonnet"); + assert_eq!(Model::ClaudeSonnet4.id(), "claude-sonnet-4"); assert_eq!( Model::ClaudeSonnet4Thinking.id(), - "claude-4-sonnet-thinking" + "claude-sonnet-4-thinking" ); assert_eq!( Model::ClaudeSonnet4.request_id(), diff --git a/crates/breadcrumbs/src/breadcrumbs.rs b/crates/breadcrumbs/src/breadcrumbs.rs index 8eed7497da0fea8cb0227b22885599b446e5aac0..a6b27476fe36b1143103e1acd035bda6cda15132 100644 --- a/crates/breadcrumbs/src/breadcrumbs.rs +++ b/crates/breadcrumbs/src/breadcrumbs.rs @@ -82,11 +82,12 @@ impl Render for Breadcrumbs { } text_style.color = Color::Muted.color(cx); - if index == 0 && !TabBarSettings::get_global(cx).show && active_item.is_dirty(cx) { - if let Some(styled_element) = apply_dirty_filename_style(&segment, &text_style, cx) - { - return styled_element; - } + if index == 0 + && !TabBarSettings::get_global(cx).show + && active_item.is_dirty(cx) + && let Some(styled_element) = apply_dirty_filename_style(&segment, &text_style, cx) + { + return styled_element; } StyledText::new(segment.text.replace('\n', "⏎")) @@ -231,7 +232,7 @@ fn apply_dirty_filename_style( let highlight = vec![(filename_position..text.len(), highlight_style)]; Some( StyledText::new(text) - .with_default_highlights(&text_style, highlight) + .with_default_highlights(text_style, highlight) .into_any(), ) } diff --git a/crates/buffer_diff/src/buffer_diff.rs b/crates/buffer_diff/src/buffer_diff.rs index 97f529fe377c0eaa3d74dca1600e1b3f0c3499db..22ee20e0db2810610dc2e7a4cae86dca90681337 100644 --- a/crates/buffer_diff/src/buffer_diff.rs +++ b/crates/buffer_diff/src/buffer_diff.rs @@ -162,6 +162,22 @@ impl BufferDiffSnapshot { } } + fn unchanged( + buffer: &text::BufferSnapshot, + base_text: language::BufferSnapshot, + ) -> BufferDiffSnapshot { + debug_assert_eq!(buffer.text(), base_text.text()); + BufferDiffSnapshot 
{ + inner: BufferDiffInner { + base_text, + hunks: SumTree::new(buffer), + pending_hunks: SumTree::new(buffer), + base_text_exists: false, + }, + secondary_diff: None, + } + } + fn new_with_base_text( buffer: text::BufferSnapshot, base_text: Option>, @@ -175,12 +191,8 @@ impl BufferDiffSnapshot { if let Some(text) = &base_text { let base_text_rope = Rope::from(text.as_str()); base_text_pair = Some((text.clone(), base_text_rope.clone())); - let snapshot = language::Buffer::build_snapshot( - base_text_rope, - language.clone(), - language_registry.clone(), - cx, - ); + let snapshot = + language::Buffer::build_snapshot(base_text_rope, language, language_registry, cx); base_text_snapshot = cx.background_spawn(snapshot); base_text_exists = true; } else { @@ -217,7 +229,10 @@ impl BufferDiffSnapshot { cx: &App, ) -> impl Future + use<> { let base_text_exists = base_text.is_some(); - let base_text_pair = base_text.map(|text| (text, base_text_snapshot.as_rope().clone())); + let base_text_pair = base_text.map(|text| { + debug_assert_eq!(&*text, &base_text_snapshot.text()); + (text, base_text_snapshot.as_rope().clone()) + }); cx.background_executor() .spawn_labeled(*CALCULATE_DIFF_TASK, async move { Self { @@ -572,14 +587,14 @@ impl BufferDiffInner { pending_range.end.column = 0; } - if pending_range == (start_point..end_point) { - if !buffer.has_edits_since_in_range( + if pending_range == (start_point..end_point) + && !buffer.has_edits_since_in_range( &pending_hunk.buffer_version, start_anchor..end_anchor, - ) { - has_pending = true; - secondary_status = pending_hunk.new_status; - } + ) + { + has_pending = true; + secondary_status = pending_hunk.new_status; } } @@ -877,6 +892,18 @@ impl BufferDiff { } } + pub fn new_unchanged( + buffer: &text::BufferSnapshot, + base_text: language::BufferSnapshot, + ) -> Self { + debug_assert_eq!(buffer.text(), base_text.text()); + BufferDiff { + buffer_id: buffer.remote_id(), + inner: BufferDiffSnapshot::unchanged(buffer, base_text).inner, 
+ secondary_diff: None, + } + } + #[cfg(any(test, feature = "test-support"))] pub fn new_with_base_text( base_text: &str, @@ -928,7 +955,7 @@ impl BufferDiff { let new_index_text = self.inner.stage_or_unstage_hunks_impl( &self.secondary_diff.as_ref()?.read(cx).inner, stage, - &hunks, + hunks, buffer, file_exists, ); @@ -952,12 +979,12 @@ impl BufferDiff { cx: &App, ) -> Option> { let start = self - .hunks_intersecting_range(range.clone(), &buffer, cx) + .hunks_intersecting_range(range.clone(), buffer, cx) .next()? .buffer_range .start; let end = self - .hunks_intersecting_range_rev(range.clone(), &buffer) + .hunks_intersecting_range_rev(range, buffer) .next()? .buffer_range .end; @@ -1031,21 +1058,20 @@ impl BufferDiff { && state.base_text.syntax_update_count() == new_state.base_text.syntax_update_count() => { - (false, new_state.compare(&state, buffer)) + (false, new_state.compare(state, buffer)) } _ => (true, Some(text::Anchor::MIN..text::Anchor::MAX)), }; - if let Some(secondary_changed_range) = secondary_diff_change { - if let Some(secondary_hunk_range) = - self.range_to_hunk_range(secondary_changed_range, &buffer, cx) - { - if let Some(range) = &mut changed_range { - range.start = secondary_hunk_range.start.min(&range.start, &buffer); - range.end = secondary_hunk_range.end.max(&range.end, &buffer); - } else { - changed_range = Some(secondary_hunk_range); - } + if let Some(secondary_changed_range) = secondary_diff_change + && let Some(secondary_hunk_range) = + self.range_to_hunk_range(secondary_changed_range, buffer, cx) + { + if let Some(range) = &mut changed_range { + range.start = secondary_hunk_range.start.min(&range.start, buffer); + range.end = secondary_hunk_range.end.max(&range.end, buffer); + } else { + changed_range = Some(secondary_hunk_range); } } @@ -1057,8 +1083,8 @@ impl BufferDiff { if let Some((first, last)) = state.pending_hunks.first().zip(state.pending_hunks.last()) { if let Some(range) = &mut changed_range { - range.start = 
range.start.min(&first.buffer_range.start, &buffer); - range.end = range.end.max(&last.buffer_range.end, &buffer); + range.start = range.start.min(&first.buffer_range.start, buffer); + range.end = range.end.max(&last.buffer_range.end, buffer); } else { changed_range = Some(first.buffer_range.start..last.buffer_range.end); } @@ -1442,7 +1468,7 @@ mod tests { .unindent(); let buffer = Buffer::new(0, BufferId::new(1).unwrap(), buffer_text); - let unstaged_diff = BufferDiffSnapshot::new_sync(buffer.clone(), index_text.clone(), cx); + let unstaged_diff = BufferDiffSnapshot::new_sync(buffer.clone(), index_text, cx); let mut uncommitted_diff = BufferDiffSnapshot::new_sync(buffer.clone(), head_text.clone(), cx); uncommitted_diff.secondary_diff = Some(Box::new(unstaged_diff)); @@ -1797,7 +1823,7 @@ mod tests { uncommitted_diff.update(cx, |diff, cx| { let hunks = diff - .hunks_intersecting_range(hunk_range.clone(), &buffer, &cx) + .hunks_intersecting_range(hunk_range.clone(), &buffer, cx) .collect::>(); for hunk in &hunks { assert_ne!( @@ -1812,7 +1838,7 @@ mod tests { .to_string(); let hunks = diff - .hunks_intersecting_range(hunk_range.clone(), &buffer, &cx) + .hunks_intersecting_range(hunk_range.clone(), &buffer, cx) .collect::>(); for hunk in &hunks { assert_eq!( @@ -1870,7 +1896,7 @@ mod tests { .to_string(); assert_eq!(new_index_text, buffer_text); - let hunk = diff.hunks(&buffer, &cx).next().unwrap(); + let hunk = diff.hunks(&buffer, cx).next().unwrap(); assert_eq!( hunk.secondary_status, DiffHunkSecondaryStatus::SecondaryHunkRemovalPending @@ -1882,7 +1908,7 @@ mod tests { .to_string(); assert_eq!(index_text, head_text); - let hunk = diff.hunks(&buffer, &cx).next().unwrap(); + let hunk = diff.hunks(&buffer, cx).next().unwrap(); // optimistically unstaged (fine, could also be HasSecondaryHunk) assert_eq!( hunk.secondary_status, @@ -2018,10 +2044,10 @@ mod tests { #[gpui::test(iterations = 100)] async fn test_staging_and_unstaging_hunks(cx: &mut TestAppContext, mut 
rng: StdRng) { fn gen_line(rng: &mut StdRng) -> String { - if rng.gen_bool(0.2) { + if rng.random_bool(0.2) { "\n".to_owned() } else { - let c = rng.gen_range('A'..='Z'); + let c = rng.random_range('A'..='Z'); format!("{c}{c}{c}\n") } } @@ -2029,8 +2055,8 @@ mod tests { fn gen_working_copy(rng: &mut StdRng, head: &str) -> String { let mut old_lines = { let mut old_lines = Vec::new(); - let mut old_lines_iter = head.lines(); - while let Some(line) = old_lines_iter.next() { + let old_lines_iter = head.lines(); + for line in old_lines_iter { assert!(!line.ends_with("\n")); old_lines.push(line.to_owned()); } @@ -2040,7 +2066,7 @@ mod tests { old_lines.into_iter() }; let mut result = String::new(); - let unchanged_count = rng.gen_range(0..=old_lines.len()); + let unchanged_count = rng.random_range(0..=old_lines.len()); result += &old_lines .by_ref() @@ -2050,14 +2076,14 @@ mod tests { s }); while old_lines.len() > 0 { - let deleted_count = rng.gen_range(0..=old_lines.len()); + let deleted_count = rng.random_range(0..=old_lines.len()); let _advance = old_lines .by_ref() .take(deleted_count) .map(|line| line.len() + 1) .sum::(); let minimum_added = if deleted_count == 0 { 1 } else { 0 }; - let added_count = rng.gen_range(minimum_added..=5); + let added_count = rng.random_range(minimum_added..=5); let addition = (0..added_count).map(|_| gen_line(rng)).collect::(); result += &addition; @@ -2066,7 +2092,8 @@ mod tests { if blank_lines == old_lines.len() { break; }; - let unchanged_count = rng.gen_range((blank_lines + 1).max(1)..=old_lines.len()); + let unchanged_count = + rng.random_range((blank_lines + 1).max(1)..=old_lines.len()); result += &old_lines.by_ref().take(unchanged_count).fold( String::new(), |mut s, line| { @@ -2123,7 +2150,7 @@ mod tests { ) }); let working_copy = working_copy.read_with(cx, |working_copy, _| working_copy.snapshot()); - let mut index_text = if rng.r#gen() { + let mut index_text = if rng.random() { Rope::from(head_text.as_str()) } else { 
working_copy.as_rope().clone() @@ -2134,12 +2161,12 @@ mod tests { diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &working_copy, cx) .collect::>() }); - if hunks.len() == 0 { + if hunks.is_empty() { return; } for _ in 0..operations { - let i = rng.gen_range(0..hunks.len()); + let i = rng.random_range(0..hunks.len()); let hunk = &mut hunks[i]; let hunk_to_change = hunk.clone(); let stage = match hunk.secondary_status { diff --git a/crates/call/Cargo.toml b/crates/call/Cargo.toml index 30e2943af3fcb9e8d5141568b2602a8db9a69a6c..ad3d569d61482ad71ee98e636db8c20274d56820 100644 --- a/crates/call/Cargo.toml +++ b/crates/call/Cargo.toml @@ -29,6 +29,7 @@ client.workspace = true collections.workspace = true fs.workspace = true futures.workspace = true +feature_flags.workspace = true gpui = { workspace = true, features = ["screen-capture"] } language.workspace = true log.workspace = true diff --git a/crates/call/src/call_impl/mod.rs b/crates/call/src/call_impl/mod.rs index 71c314932419e1228c74e2d3de547a4e21b152c6..156a80faba61d2a4946bafa5943c167284d14a97 100644 --- a/crates/call/src/call_impl/mod.rs +++ b/crates/call/src/call_impl/mod.rs @@ -116,7 +116,7 @@ impl ActiveCall { envelope: TypedEnvelope, mut cx: AsyncApp, ) -> Result { - let user_store = this.read_with(&mut cx, |this, _| this.user_store.clone())?; + let user_store = this.read_with(&cx, |this, _| this.user_store.clone())?; let call = IncomingCall { room_id: envelope.payload.room_id, participants: user_store @@ -147,7 +147,7 @@ impl ActiveCall { let mut incoming_call = this.incoming_call.0.borrow_mut(); if incoming_call .as_ref() - .map_or(false, |call| call.room_id == envelope.payload.room_id) + .is_some_and(|call| call.room_id == envelope.payload.room_id) { incoming_call.take(); } diff --git a/crates/call/src/call_impl/participant.rs b/crates/call/src/call_impl/participant.rs index 8e1e264a23d7c58c927d182bbac811a0beb4f02a..6fb6a2eb79b537aa9d7296a323f7d45221a4b05d 100644 --- 
a/crates/call/src/call_impl/participant.rs +++ b/crates/call/src/call_impl/participant.rs @@ -64,7 +64,7 @@ pub struct RemoteParticipant { impl RemoteParticipant { pub fn has_video_tracks(&self) -> bool { - return !self.video_tracks.is_empty(); + !self.video_tracks.is_empty() } pub fn can_write(&self) -> bool { diff --git a/crates/call/src/call_impl/room.rs b/crates/call/src/call_impl/room.rs index afeee4c924feb2990668f953d5b2f7dfcff26f34..930846ab8ff37272f9b0fc0652319318c676f3f7 100644 --- a/crates/call/src/call_impl/room.rs +++ b/crates/call/src/call_impl/room.rs @@ -9,11 +9,12 @@ use client::{ proto::{self, PeerId}, }; use collections::{BTreeMap, HashMap, HashSet}; +use feature_flags::FeatureFlagAppExt; use fs::Fs; -use futures::{FutureExt, StreamExt}; +use futures::StreamExt; use gpui::{ - App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, ScreenCaptureSource, - ScreenCaptureStream, Task, WeakEntity, + App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, FutureExt as _, + ScreenCaptureSource, ScreenCaptureStream, Task, Timeout, WeakEntity, }; use gpui_tokio::Tokio; use language::LanguageRegistry; @@ -370,57 +371,53 @@ impl Room { })?; // Wait for client to re-establish a connection to the server. 
- { - let mut reconnection_timeout = - cx.background_executor().timer(RECONNECT_TIMEOUT).fuse(); - let client_reconnection = async { - let mut remaining_attempts = 3; - while remaining_attempts > 0 { - if client_status.borrow().is_connected() { - log::info!("client reconnected, attempting to rejoin room"); - - let Some(this) = this.upgrade() else { break }; - match this.update(cx, |this, cx| this.rejoin(cx)) { - Ok(task) => { - if task.await.log_err().is_some() { - return true; - } else { - remaining_attempts -= 1; - } + let executor = cx.background_executor().clone(); + let client_reconnection = async { + let mut remaining_attempts = 3; + while remaining_attempts > 0 { + if client_status.borrow().is_connected() { + log::info!("client reconnected, attempting to rejoin room"); + + let Some(this) = this.upgrade() else { break }; + match this.update(cx, |this, cx| this.rejoin(cx)) { + Ok(task) => { + if task.await.log_err().is_some() { + return true; + } else { + remaining_attempts -= 1; } - Err(_app_dropped) => return false, } - } else if client_status.borrow().is_signed_out() { - return false; + Err(_app_dropped) => return false, } - - log::info!( - "waiting for client status change, remaining attempts {}", - remaining_attempts - ); - client_status.next().await; + } else if client_status.borrow().is_signed_out() { + return false; } - false + + log::info!( + "waiting for client status change, remaining attempts {}", + remaining_attempts + ); + client_status.next().await; } - .fuse(); - futures::pin_mut!(client_reconnection); - - futures::select_biased! { - reconnected = client_reconnection => { - if reconnected { - log::info!("successfully reconnected to room"); - // If we successfully joined the room, go back around the loop - // waiting for future connection status changes. 
- continue; - } - } - _ = reconnection_timeout => { - log::info!("room reconnection timeout expired"); - } + false + }; + + match client_reconnection + .with_timeout(RECONNECT_TIMEOUT, &executor) + .await + { + Ok(true) => { + log::info!("successfully reconnected to room"); + // If we successfully joined the room, go back around the loop + // waiting for future connection status changes. + continue; + } + Ok(false) => break, + Err(Timeout) => { + log::info!("room reconnection timeout expired"); + break; } } - - break; } } @@ -831,24 +828,23 @@ impl Room { ); Audio::play_sound(Sound::Joined, cx); - if let Some(livekit_participants) = &livekit_participants { - if let Some(livekit_participant) = livekit_participants + if let Some(livekit_participants) = &livekit_participants + && let Some(livekit_participant) = livekit_participants .get(&ParticipantIdentity(user.id.to_string())) + { + for publication in + livekit_participant.track_publications().into_values() { - for publication in - livekit_participant.track_publications().into_values() - { - if let Some(track) = publication.track() { - this.livekit_room_updated( - RoomEvent::TrackSubscribed { - track, - publication, - participant: livekit_participant.clone(), - }, - cx, - ) - .warn_on_err(); - } + if let Some(track) = publication.track() { + this.livekit_room_updated( + RoomEvent::TrackSubscribed { + track, + publication, + participant: livekit_participant.clone(), + }, + cx, + ) + .warn_on_err(); } } } @@ -944,10 +940,8 @@ impl Room { self.client.user_id() ) })?; - if self.live_kit.as_ref().map_or(true, |kit| kit.deafened) { - if publication.is_audio() { - publication.set_enabled(false, cx); - } + if self.live_kit.as_ref().is_none_or(|kit| kit.deafened) && publication.is_audio() { + publication.set_enabled(false, cx); } match track { livekit_client::RemoteTrack::Audio(track) => { @@ -1009,10 +1003,10 @@ impl Room { for (sid, participant) in &mut self.remote_participants { participant.speaking = 
speaker_ids.binary_search(sid).is_ok(); } - if let Some(id) = self.client.user_id() { - if let Some(room) = &mut self.live_kit { - room.speaking = speaker_ids.binary_search(&id).is_ok(); - } + if let Some(id) = self.client.user_id() + && let Some(room) = &mut self.live_kit + { + room.speaking = speaker_ids.binary_search(&id).is_ok(); } } @@ -1046,18 +1040,16 @@ impl Room { if let LocalTrack::Published { track_publication, .. } = &room.microphone_track + && track_publication.sid() == publication.sid() { - if track_publication.sid() == publication.sid() { - room.microphone_track = LocalTrack::None; - } + room.microphone_track = LocalTrack::None; } if let LocalTrack::Published { track_publication, .. } = &room.screen_track + && track_publication.sid() == publication.sid() { - if track_publication.sid() == publication.sid() { - room.screen_track = LocalTrack::None; - } + room.screen_track = LocalTrack::None; } } } @@ -1170,7 +1162,7 @@ impl Room { let request = self.client.request(proto::ShareProject { room_id: self.id(), worktrees: project.read(cx).worktree_metadata_protos(cx), - is_ssh_project: project.read(cx).is_via_ssh(), + is_ssh_project: project.read(cx).is_via_remote_server(), }); cx.spawn(async move |this, cx| { @@ -1182,7 +1174,7 @@ impl Room { this.update(cx, |this, cx| { this.shared_projects.insert(project.downgrade()); let active_project = this.local_participant.active_project.as_ref(); - if active_project.map_or(false, |location| *location == project) { + if active_project.is_some_and(|location| *location == project) { this.set_location(Some(&project), cx) } else { Task::ready(Ok(())) @@ -1255,9 +1247,9 @@ impl Room { } pub fn is_sharing_screen(&self) -> bool { - self.live_kit.as_ref().map_or(false, |live_kit| { - !matches!(live_kit.screen_track, LocalTrack::None) - }) + self.live_kit + .as_ref() + .is_some_and(|live_kit| !matches!(live_kit.screen_track, LocalTrack::None)) } pub fn shared_screen_id(&self) -> Option { @@ -1270,13 +1262,13 @@ impl Room { } 
pub fn is_sharing_mic(&self) -> bool { - self.live_kit.as_ref().map_or(false, |live_kit| { - !matches!(live_kit.microphone_track, LocalTrack::None) - }) + self.live_kit + .as_ref() + .is_some_and(|live_kit| !matches!(live_kit.microphone_track, LocalTrack::None)) } pub fn is_muted(&self) -> bool { - self.live_kit.as_ref().map_or(false, |live_kit| { + self.live_kit.as_ref().is_some_and(|live_kit| { matches!(live_kit.microphone_track, LocalTrack::None) || live_kit.muted_by_user || live_kit.deafened @@ -1286,13 +1278,13 @@ impl Room { pub fn muted_by_user(&self) -> bool { self.live_kit .as_ref() - .map_or(false, |live_kit| live_kit.muted_by_user) + .is_some_and(|live_kit| live_kit.muted_by_user) } pub fn is_speaking(&self) -> bool { self.live_kit .as_ref() - .map_or(false, |live_kit| live_kit.speaking) + .is_some_and(|live_kit| live_kit.speaking) } pub fn is_deafened(&self) -> Option { @@ -1331,8 +1323,18 @@ impl Room { return Task::ready(Err(anyhow!("live-kit was not initialized"))); }; + let is_staff = cx.is_staff(); + let user_name = self + .user_store + .read(cx) + .current_user() + .and_then(|user| user.name.clone()) + .unwrap_or_else(|| "unknown".to_string()); + cx.spawn(async move |this, cx| { - let publication = room.publish_local_microphone_track(cx).await; + let publication = room + .publish_local_microphone_track(user_name, is_staff, cx) + .await; this.update(cx, |this, cx| { let live_kit = this .live_kit @@ -1488,10 +1490,8 @@ impl Room { self.set_deafened(deafened, cx); - if should_change_mute { - if let Some(task) = self.set_mute(deafened, cx) { - task.detach_and_log_err(cx); - } + if should_change_mute && let Some(task) = self.set_mute(deafened, cx) { + task.detach_and_log_err(cx); } } } diff --git a/crates/call/src/call_settings.rs b/crates/call/src/call_settings.rs index c8f51e0c1a2019dd2c266210e469989946ed8a35..b0677e3c3bcb5112fdd9ad2abc4bf188b225aeac 100644 --- a/crates/call/src/call_settings.rs +++ b/crates/call/src/call_settings.rs @@ -2,7 +2,7 @@ 
use anyhow::Result; use gpui::App; use schemars::JsonSchema; use serde_derive::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; #[derive(Deserialize, Debug)] pub struct CallSettings { @@ -11,7 +11,8 @@ pub struct CallSettings { } /// Configuration of voice calls in Zed. -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug, SettingsUi, SettingsKey)] +#[settings_key(key = "calls")] pub struct CallSettingsContent { /// Whether the microphone should be muted when joining a channel or a call. /// @@ -25,8 +26,6 @@ pub struct CallSettingsContent { } impl Settings for CallSettings { - const KEY: Option<&'static str> = Some("calls"); - type FileContent = CallSettingsContent; fn load(sources: SettingsSources, _: &mut App) -> Result { diff --git a/crates/channel/Cargo.toml b/crates/channel/Cargo.toml index 962847f3f1cf21f361b6e2f1b9299c0c66992b3e..ab6e1dfc2b8dd0f89c4e6cd03e5ee66840003d6a 100644 --- a/crates/channel/Cargo.toml +++ b/crates/channel/Cargo.toml @@ -25,11 +25,9 @@ gpui.workspace = true language.workspace = true log.workspace = true postage.workspace = true -rand.workspace = true release_channel.workspace = true rpc.workspace = true settings.workspace = true -sum_tree.workspace = true text.workspace = true time.workspace = true util.workspace = true diff --git a/crates/channel/src/channel.rs b/crates/channel/src/channel.rs index 63865c574ecc36da27e18f02ccb8c44138cef3ba..6cc5a0e8815a4f24f41b3677622f8c200a4f59d9 100644 --- a/crates/channel/src/channel.rs +++ b/crates/channel/src/channel.rs @@ -1,5 +1,4 @@ mod channel_buffer; -mod channel_chat; mod channel_store; use client::{Client, UserStore}; @@ -7,10 +6,6 @@ use gpui::{App, Entity}; use std::sync::Arc; pub use channel_buffer::{ACKNOWLEDGE_DEBOUNCE_INTERVAL, ChannelBuffer, ChannelBufferEvent}; -pub use channel_chat::{ - ChannelChat, 
ChannelChatEvent, ChannelMessage, ChannelMessageId, MessageParams, - mentions_to_proto, -}; pub use channel_store::{Channel, ChannelEvent, ChannelMembership, ChannelStore}; #[cfg(test)] @@ -19,5 +14,4 @@ mod channel_store_tests; pub fn init(client: &Arc, user_store: Entity, cx: &mut App) { channel_store::init(client, user_store, cx); channel_buffer::init(&client.clone().into()); - channel_chat::init(&client.clone().into()); } diff --git a/crates/channel/src/channel_buffer.rs b/crates/channel/src/channel_buffer.rs index 183f7eb3c6a47dad4cb35b95dd2d3e096e0a612f..828248b330b6ef6cfe0e13eab426de2900d364b2 100644 --- a/crates/channel/src/channel_buffer.rs +++ b/crates/channel/src/channel_buffer.rs @@ -82,7 +82,7 @@ impl ChannelBuffer { collaborators: Default::default(), acknowledge_task: None, channel_id: channel.id, - subscription: Some(subscription.set_entity(&cx.entity(), &mut cx.to_async())), + subscription: Some(subscription.set_entity(&cx.entity(), &cx.to_async())), user_store, channel_store, }; @@ -110,7 +110,7 @@ impl ChannelBuffer { let Ok(subscription) = self.client.subscribe_to_entity(self.channel_id.0) else { return; }; - self.subscription = Some(subscription.set_entity(&cx.entity(), &mut cx.to_async())); + self.subscription = Some(subscription.set_entity(&cx.entity(), &cx.to_async())); cx.emit(ChannelBufferEvent::Connected); } } @@ -135,7 +135,7 @@ impl ChannelBuffer { } } - for (_, old_collaborator) in &self.collaborators { + for old_collaborator in self.collaborators.values() { if !new_collaborators.contains_key(&old_collaborator.peer_id) { self.buffer.update(cx, |buffer, cx| { buffer.remove_peer(old_collaborator.replica_id, cx) @@ -191,12 +191,11 @@ impl ChannelBuffer { operation, is_local: true, } => { - if *ZED_ALWAYS_ACTIVE { - if let language::Operation::UpdateSelections { selections, .. } = operation { - if selections.is_empty() { - return; - } - } + if *ZED_ALWAYS_ACTIVE + && let language::Operation::UpdateSelections { selections, .. 
} = operation + && selections.is_empty() + { + return; } let operation = language::proto::serialize_operation(operation); self.client diff --git a/crates/channel/src/channel_chat.rs b/crates/channel/src/channel_chat.rs deleted file mode 100644 index 4ac37ffd14ca2602756afecc788aae9f6065cad9..0000000000000000000000000000000000000000 --- a/crates/channel/src/channel_chat.rs +++ /dev/null @@ -1,862 +0,0 @@ -use crate::{Channel, ChannelStore}; -use anyhow::{Context as _, Result}; -use client::{ - ChannelId, Client, Subscription, TypedEnvelope, UserId, proto, - user::{User, UserStore}, -}; -use collections::HashSet; -use futures::lock::Mutex; -use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity}; -use rand::prelude::*; -use rpc::AnyProtoClient; -use std::{ - ops::{ControlFlow, Range}, - sync::Arc, -}; -use sum_tree::{Bias, Dimensions, SumTree}; -use time::OffsetDateTime; -use util::{ResultExt as _, TryFutureExt, post_inc}; - -pub struct ChannelChat { - pub channel_id: ChannelId, - messages: SumTree, - acknowledged_message_ids: HashSet, - channel_store: Entity, - loaded_all_messages: bool, - last_acknowledged_id: Option, - next_pending_message_id: usize, - first_loaded_message_id: Option, - user_store: Entity, - rpc: Arc, - outgoing_messages_lock: Arc>, - rng: StdRng, - _subscription: Subscription, -} - -#[derive(Debug, PartialEq, Eq)] -pub struct MessageParams { - pub text: String, - pub mentions: Vec<(Range, UserId)>, - pub reply_to_message_id: Option, -} - -#[derive(Clone, Debug)] -pub struct ChannelMessage { - pub id: ChannelMessageId, - pub body: String, - pub timestamp: OffsetDateTime, - pub sender: Arc, - pub nonce: u128, - pub mentions: Vec<(Range, UserId)>, - pub reply_to_message_id: Option, - pub edited_at: Option, -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum ChannelMessageId { - Saved(u64), - Pending(usize), -} - -impl From for Option { - fn from(val: ChannelMessageId) -> Self { - 
match val { - ChannelMessageId::Saved(id) => Some(id), - ChannelMessageId::Pending(_) => None, - } - } -} - -#[derive(Clone, Debug, Default)] -pub struct ChannelMessageSummary { - max_id: ChannelMessageId, - count: usize, -} - -#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] -struct Count(usize); - -#[derive(Clone, Debug, PartialEq)] -pub enum ChannelChatEvent { - MessagesUpdated { - old_range: Range, - new_count: usize, - }, - UpdateMessage { - message_id: ChannelMessageId, - message_ix: usize, - }, - NewMessage { - channel_id: ChannelId, - message_id: u64, - }, -} - -impl EventEmitter for ChannelChat {} -pub fn init(client: &AnyProtoClient) { - client.add_entity_message_handler(ChannelChat::handle_message_sent); - client.add_entity_message_handler(ChannelChat::handle_message_removed); - client.add_entity_message_handler(ChannelChat::handle_message_updated); -} - -impl ChannelChat { - pub async fn new( - channel: Arc, - channel_store: Entity, - user_store: Entity, - client: Arc, - cx: &mut AsyncApp, - ) -> Result> { - let channel_id = channel.id; - let subscription = client.subscribe_to_entity(channel_id.0).unwrap(); - - let response = client - .request(proto::JoinChannelChat { - channel_id: channel_id.0, - }) - .await?; - - let handle = cx.new(|cx| { - cx.on_release(Self::release).detach(); - Self { - channel_id: channel.id, - user_store: user_store.clone(), - channel_store, - rpc: client.clone(), - outgoing_messages_lock: Default::default(), - messages: Default::default(), - acknowledged_message_ids: Default::default(), - loaded_all_messages: false, - next_pending_message_id: 0, - last_acknowledged_id: None, - rng: StdRng::from_entropy(), - first_loaded_message_id: None, - _subscription: subscription.set_entity(&cx.entity(), &cx.to_async()), - } - })?; - Self::handle_loaded_messages( - handle.downgrade(), - user_store, - client, - response.messages, - response.done, - cx, - ) - .await?; - Ok(handle) - } - - fn release(&mut self, _: &mut 
App) { - self.rpc - .send(proto::LeaveChannelChat { - channel_id: self.channel_id.0, - }) - .log_err(); - } - - pub fn channel(&self, cx: &App) -> Option> { - self.channel_store - .read(cx) - .channel_for_id(self.channel_id) - .cloned() - } - - pub fn client(&self) -> &Arc { - &self.rpc - } - - pub fn send_message( - &mut self, - message: MessageParams, - cx: &mut Context, - ) -> Result>> { - anyhow::ensure!( - !message.text.trim().is_empty(), - "message body can't be empty" - ); - - let current_user = self - .user_store - .read(cx) - .current_user() - .context("current_user is not present")?; - - let channel_id = self.channel_id; - let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id)); - let nonce = self.rng.r#gen(); - self.insert_messages( - SumTree::from_item( - ChannelMessage { - id: pending_id, - body: message.text.clone(), - sender: current_user, - timestamp: OffsetDateTime::now_utc(), - mentions: message.mentions.clone(), - nonce, - reply_to_message_id: message.reply_to_message_id, - edited_at: None, - }, - &(), - ), - cx, - ); - let user_store = self.user_store.clone(); - let rpc = self.rpc.clone(); - let outgoing_messages_lock = self.outgoing_messages_lock.clone(); - - // todo - handle messages that fail to send (e.g. 
>1024 chars) - Ok(cx.spawn(async move |this, cx| { - let outgoing_message_guard = outgoing_messages_lock.lock().await; - let request = rpc.request(proto::SendChannelMessage { - channel_id: channel_id.0, - body: message.text, - nonce: Some(nonce.into()), - mentions: mentions_to_proto(&message.mentions), - reply_to_message_id: message.reply_to_message_id, - }); - let response = request.await?; - drop(outgoing_message_guard); - let response = response.message.context("invalid message")?; - let id = response.id; - let message = ChannelMessage::from_proto(response, &user_store, cx).await?; - this.update(cx, |this, cx| { - this.insert_messages(SumTree::from_item(message, &()), cx); - if this.first_loaded_message_id.is_none() { - this.first_loaded_message_id = Some(id); - } - })?; - Ok(id) - })) - } - - pub fn remove_message(&mut self, id: u64, cx: &mut Context) -> Task> { - let response = self.rpc.request(proto::RemoveChannelMessage { - channel_id: self.channel_id.0, - message_id: id, - }); - cx.spawn(async move |this, cx| { - response.await?; - this.update(cx, |this, cx| { - this.message_removed(id, cx); - })?; - Ok(()) - }) - } - - pub fn update_message( - &mut self, - id: u64, - message: MessageParams, - cx: &mut Context, - ) -> Result>> { - self.message_update( - ChannelMessageId::Saved(id), - message.text.clone(), - message.mentions.clone(), - Some(OffsetDateTime::now_utc()), - cx, - ); - - let nonce: u128 = self.rng.r#gen(); - - let request = self.rpc.request(proto::UpdateChannelMessage { - channel_id: self.channel_id.0, - message_id: id, - body: message.text, - nonce: Some(nonce.into()), - mentions: mentions_to_proto(&message.mentions), - }); - Ok(cx.spawn(async move |_, _| { - request.await?; - Ok(()) - })) - } - - pub fn load_more_messages(&mut self, cx: &mut Context) -> Option>> { - if self.loaded_all_messages { - return None; - } - - let rpc = self.rpc.clone(); - let user_store = self.user_store.clone(); - let channel_id = self.channel_id; - let 
before_message_id = self.first_loaded_message_id()?; - Some(cx.spawn(async move |this, cx| { - async move { - let response = rpc - .request(proto::GetChannelMessages { - channel_id: channel_id.0, - before_message_id, - }) - .await?; - Self::handle_loaded_messages( - this, - user_store, - rpc, - response.messages, - response.done, - cx, - ) - .await?; - - anyhow::Ok(()) - } - .log_err() - .await - })) - } - - pub fn first_loaded_message_id(&mut self) -> Option { - self.first_loaded_message_id - } - - /// Load a message by its id, if it's already stored locally. - pub fn find_loaded_message(&self, id: u64) -> Option<&ChannelMessage> { - self.messages.iter().find(|message| match message.id { - ChannelMessageId::Saved(message_id) => message_id == id, - ChannelMessageId::Pending(_) => false, - }) - } - - /// Load all of the chat messages since a certain message id. - /// - /// For now, we always maintain a suffix of the channel's messages. - pub async fn load_history_since_message( - chat: Entity, - message_id: u64, - mut cx: AsyncApp, - ) -> Option { - loop { - let step = chat - .update(&mut cx, |chat, cx| { - if let Some(first_id) = chat.first_loaded_message_id() { - if first_id <= message_id { - let mut cursor = chat - .messages - .cursor::>(&()); - let message_id = ChannelMessageId::Saved(message_id); - cursor.seek(&message_id, Bias::Left); - return ControlFlow::Break( - if cursor - .item() - .map_or(false, |message| message.id == message_id) - { - Some(cursor.start().1.0) - } else { - None - }, - ); - } - } - ControlFlow::Continue(chat.load_more_messages(cx)) - }) - .log_err()?; - match step { - ControlFlow::Break(ix) => return ix, - ControlFlow::Continue(task) => task?.await?, - } - } - } - - pub fn acknowledge_last_message(&mut self, cx: &mut Context) { - if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id { - if self - .last_acknowledged_id - .map_or(true, |acknowledged_id| acknowledged_id < latest_message_id) - { - self.rpc - 
.send(proto::AckChannelMessage { - channel_id: self.channel_id.0, - message_id: latest_message_id, - }) - .ok(); - self.last_acknowledged_id = Some(latest_message_id); - self.channel_store.update(cx, |store, cx| { - store.acknowledge_message_id(self.channel_id, latest_message_id, cx); - }); - } - } - } - - async fn handle_loaded_messages( - this: WeakEntity, - user_store: Entity, - rpc: Arc, - proto_messages: Vec, - loaded_all_messages: bool, - cx: &mut AsyncApp, - ) -> Result<()> { - let loaded_messages = messages_from_proto(proto_messages, &user_store, cx).await?; - - let first_loaded_message_id = loaded_messages.first().map(|m| m.id); - let loaded_message_ids = this.read_with(cx, |this, _| { - let mut loaded_message_ids: HashSet = HashSet::default(); - for message in loaded_messages.iter() { - if let Some(saved_message_id) = message.id.into() { - loaded_message_ids.insert(saved_message_id); - } - } - for message in this.messages.iter() { - if let Some(saved_message_id) = message.id.into() { - loaded_message_ids.insert(saved_message_id); - } - } - loaded_message_ids - })?; - - let missing_ancestors = loaded_messages - .iter() - .filter_map(|message| { - if let Some(ancestor_id) = message.reply_to_message_id { - if !loaded_message_ids.contains(&ancestor_id) { - return Some(ancestor_id); - } - } - None - }) - .collect::>(); - - let loaded_ancestors = if missing_ancestors.is_empty() { - None - } else { - let response = rpc - .request(proto::GetChannelMessagesById { - message_ids: missing_ancestors, - }) - .await?; - Some(messages_from_proto(response.messages, &user_store, cx).await?) 
- }; - this.update(cx, |this, cx| { - this.first_loaded_message_id = first_loaded_message_id.and_then(|msg_id| msg_id.into()); - this.loaded_all_messages = loaded_all_messages; - this.insert_messages(loaded_messages, cx); - if let Some(loaded_ancestors) = loaded_ancestors { - this.insert_messages(loaded_ancestors, cx); - } - })?; - - Ok(()) - } - - pub fn rejoin(&mut self, cx: &mut Context) { - let user_store = self.user_store.clone(); - let rpc = self.rpc.clone(); - let channel_id = self.channel_id; - cx.spawn(async move |this, cx| { - async move { - let response = rpc - .request(proto::JoinChannelChat { - channel_id: channel_id.0, - }) - .await?; - Self::handle_loaded_messages( - this.clone(), - user_store.clone(), - rpc.clone(), - response.messages, - response.done, - cx, - ) - .await?; - - let pending_messages = this.read_with(cx, |this, _| { - this.pending_messages().cloned().collect::>() - })?; - - for pending_message in pending_messages { - let request = rpc.request(proto::SendChannelMessage { - channel_id: channel_id.0, - body: pending_message.body, - mentions: mentions_to_proto(&pending_message.mentions), - nonce: Some(pending_message.nonce.into()), - reply_to_message_id: pending_message.reply_to_message_id, - }); - let response = request.await?; - let message = ChannelMessage::from_proto( - response.message.context("invalid message")?, - &user_store, - cx, - ) - .await?; - this.update(cx, |this, cx| { - this.insert_messages(SumTree::from_item(message, &()), cx); - })?; - } - - anyhow::Ok(()) - } - .log_err() - .await - }) - .detach(); - } - - pub fn message_count(&self) -> usize { - self.messages.summary().count - } - - pub fn messages(&self) -> &SumTree { - &self.messages - } - - pub fn message(&self, ix: usize) -> &ChannelMessage { - let mut cursor = self.messages.cursor::(&()); - cursor.seek(&Count(ix), Bias::Right); - cursor.item().unwrap() - } - - pub fn acknowledge_message(&mut self, id: u64) { - if self.acknowledged_message_ids.insert(id) { - 
self.rpc - .send(proto::AckChannelMessage { - channel_id: self.channel_id.0, - message_id: id, - }) - .ok(); - } - } - - pub fn messages_in_range(&self, range: Range) -> impl Iterator { - let mut cursor = self.messages.cursor::(&()); - cursor.seek(&Count(range.start), Bias::Right); - cursor.take(range.len()) - } - - pub fn pending_messages(&self) -> impl Iterator { - let mut cursor = self.messages.cursor::(&()); - cursor.seek(&ChannelMessageId::Pending(0), Bias::Left); - cursor - } - - async fn handle_message_sent( - this: Entity, - message: TypedEnvelope, - mut cx: AsyncApp, - ) -> Result<()> { - let user_store = this.read_with(&mut cx, |this, _| this.user_store.clone())?; - let message = message.payload.message.context("empty message")?; - let message_id = message.id; - - let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?; - this.update(&mut cx, |this, cx| { - this.insert_messages(SumTree::from_item(message, &()), cx); - cx.emit(ChannelChatEvent::NewMessage { - channel_id: this.channel_id, - message_id, - }) - })?; - - Ok(()) - } - - async fn handle_message_removed( - this: Entity, - message: TypedEnvelope, - mut cx: AsyncApp, - ) -> Result<()> { - this.update(&mut cx, |this, cx| { - this.message_removed(message.payload.message_id, cx) - })?; - Ok(()) - } - - async fn handle_message_updated( - this: Entity, - message: TypedEnvelope, - mut cx: AsyncApp, - ) -> Result<()> { - let user_store = this.read_with(&mut cx, |this, _| this.user_store.clone())?; - let message = message.payload.message.context("empty message")?; - - let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?; - - this.update(&mut cx, |this, cx| { - this.message_update( - message.id, - message.body, - message.mentions, - message.edited_at, - cx, - ) - })?; - Ok(()) - } - - fn insert_messages(&mut self, messages: SumTree, cx: &mut Context) { - if let Some((first_message, last_message)) = messages.first().zip(messages.last()) { - let nonces = 
messages - .cursor::<()>(&()) - .map(|m| m.nonce) - .collect::>(); - - let mut old_cursor = self - .messages - .cursor::>(&()); - let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left); - let start_ix = old_cursor.start().1.0; - let removed_messages = old_cursor.slice(&last_message.id, Bias::Right); - let removed_count = removed_messages.summary().count; - let new_count = messages.summary().count; - let end_ix = start_ix + removed_count; - - new_messages.append(messages, &()); - - let mut ranges = Vec::>::new(); - if new_messages.last().unwrap().is_pending() { - new_messages.append(old_cursor.suffix(), &()); - } else { - new_messages.append( - old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left), - &(), - ); - - while let Some(message) = old_cursor.item() { - let message_ix = old_cursor.start().1.0; - if nonces.contains(&message.nonce) { - if ranges.last().map_or(false, |r| r.end == message_ix) { - ranges.last_mut().unwrap().end += 1; - } else { - ranges.push(message_ix..message_ix + 1); - } - } else { - new_messages.push(message.clone(), &()); - } - old_cursor.next(); - } - } - - drop(old_cursor); - self.messages = new_messages; - - for range in ranges.into_iter().rev() { - cx.emit(ChannelChatEvent::MessagesUpdated { - old_range: range, - new_count: 0, - }); - } - cx.emit(ChannelChatEvent::MessagesUpdated { - old_range: start_ix..end_ix, - new_count, - }); - - cx.notify(); - } - } - - fn message_removed(&mut self, id: u64, cx: &mut Context) { - let mut cursor = self.messages.cursor::(&()); - let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left); - if let Some(item) = cursor.item() { - if item.id == ChannelMessageId::Saved(id) { - let deleted_message_ix = messages.summary().count; - cursor.next(); - messages.append(cursor.suffix(), &()); - drop(cursor); - self.messages = messages; - - // If the message that was deleted was the last acknowledged message, - // replace the acknowledged message with an earlier one. 
- self.channel_store.update(cx, |store, _| { - let summary = self.messages.summary(); - if summary.count == 0 { - store.set_acknowledged_message_id(self.channel_id, None); - } else if deleted_message_ix == summary.count { - if let ChannelMessageId::Saved(id) = summary.max_id { - store.set_acknowledged_message_id(self.channel_id, Some(id)); - } - } - }); - - cx.emit(ChannelChatEvent::MessagesUpdated { - old_range: deleted_message_ix..deleted_message_ix + 1, - new_count: 0, - }); - } - } - } - - fn message_update( - &mut self, - id: ChannelMessageId, - body: String, - mentions: Vec<(Range, u64)>, - edited_at: Option, - cx: &mut Context, - ) { - let mut cursor = self.messages.cursor::(&()); - let mut messages = cursor.slice(&id, Bias::Left); - let ix = messages.summary().count; - - if let Some(mut message_to_update) = cursor.item().cloned() { - message_to_update.body = body; - message_to_update.mentions = mentions; - message_to_update.edited_at = edited_at; - messages.push(message_to_update, &()); - cursor.next(); - } - - messages.append(cursor.suffix(), &()); - drop(cursor); - self.messages = messages; - - cx.emit(ChannelChatEvent::UpdateMessage { - message_ix: ix, - message_id: id, - }); - - cx.notify(); - } -} - -async fn messages_from_proto( - proto_messages: Vec, - user_store: &Entity, - cx: &mut AsyncApp, -) -> Result> { - let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?; - let mut result = SumTree::default(); - result.extend(messages, &()); - Ok(result) -} - -impl ChannelMessage { - pub async fn from_proto( - message: proto::ChannelMessage, - user_store: &Entity, - cx: &mut AsyncApp, - ) -> Result { - let sender = user_store - .update(cx, |user_store, cx| { - user_store.get_user(message.sender_id, cx) - })? 
- .await?; - - let edited_at = message.edited_at.and_then(|t| -> Option { - if let Ok(a) = OffsetDateTime::from_unix_timestamp(t as i64) { - return Some(a); - } - - None - }); - - Ok(ChannelMessage { - id: ChannelMessageId::Saved(message.id), - body: message.body, - mentions: message - .mentions - .into_iter() - .filter_map(|mention| { - let range = mention.range?; - Some((range.start as usize..range.end as usize, mention.user_id)) - }) - .collect(), - timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?, - sender, - nonce: message.nonce.context("nonce is required")?.into(), - reply_to_message_id: message.reply_to_message_id, - edited_at, - }) - } - - pub fn is_pending(&self) -> bool { - matches!(self.id, ChannelMessageId::Pending(_)) - } - - pub async fn from_proto_vec( - proto_messages: Vec, - user_store: &Entity, - cx: &mut AsyncApp, - ) -> Result> { - let unique_user_ids = proto_messages - .iter() - .map(|m| m.sender_id) - .collect::>() - .into_iter() - .collect(); - user_store - .update(cx, |user_store, cx| { - user_store.get_users(unique_user_ids, cx) - })? 
- .await?; - - let mut messages = Vec::with_capacity(proto_messages.len()); - for message in proto_messages { - messages.push(ChannelMessage::from_proto(message, user_store, cx).await?); - } - Ok(messages) - } -} - -pub fn mentions_to_proto(mentions: &[(Range, UserId)]) -> Vec { - mentions - .iter() - .map(|(range, user_id)| proto::ChatMention { - range: Some(proto::Range { - start: range.start as u64, - end: range.end as u64, - }), - user_id: *user_id, - }) - .collect() -} - -impl sum_tree::Item for ChannelMessage { - type Summary = ChannelMessageSummary; - - fn summary(&self, _cx: &()) -> Self::Summary { - ChannelMessageSummary { - max_id: self.id, - count: 1, - } - } -} - -impl Default for ChannelMessageId { - fn default() -> Self { - Self::Saved(0) - } -} - -impl sum_tree::Summary for ChannelMessageSummary { - type Context = (); - - fn zero(_cx: &Self::Context) -> Self { - Default::default() - } - - fn add_summary(&mut self, summary: &Self, _: &()) { - self.max_id = summary.max_id; - self.count += summary.count; - } -} - -impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId { - fn zero(_cx: &()) -> Self { - Default::default() - } - - fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) { - debug_assert!(summary.max_id > *self); - *self = summary.max_id; - } -} - -impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count { - fn zero(_cx: &()) -> Self { - Default::default() - } - - fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) { - self.0 += summary.count; - } -} - -impl<'a> From<&'a str> for MessageParams { - fn from(value: &'a str) -> Self { - Self { - text: value.into(), - mentions: Vec::new(), - reply_to_message_id: None, - } - } -} diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index 4ad156b9fb08e8af95e5ea49132c4c4786e348a1..e983d03e0d6758f681de9e4a3e6fd13dc7075b01 100644 --- a/crates/channel/src/channel_store.rs +++ 
b/crates/channel/src/channel_store.rs @@ -1,6 +1,6 @@ mod channel_index; -use crate::{ChannelMessage, channel_buffer::ChannelBuffer, channel_chat::ChannelChat}; +use crate::channel_buffer::ChannelBuffer; use anyhow::{Context as _, Result, anyhow}; use channel_index::ChannelIndex; use client::{ChannelId, Client, ClientSettings, Subscription, User, UserId, UserStore}; @@ -41,7 +41,6 @@ pub struct ChannelStore { outgoing_invites: HashSet<(ChannelId, UserId)>, update_channels_tx: mpsc::UnboundedSender, opened_buffers: HashMap>, - opened_chats: HashMap>, client: Arc, did_subscribe: bool, channels_loaded: (watch::Sender, watch::Receiver), @@ -63,10 +62,8 @@ pub struct Channel { #[derive(Default, Debug)] pub struct ChannelState { - latest_chat_message: Option, latest_notes_version: NotesVersion, observed_notes_version: NotesVersion, - observed_chat_message: Option, role: Option, } @@ -196,7 +193,6 @@ impl ChannelStore { channel_participants: Default::default(), outgoing_invites: Default::default(), opened_buffers: Default::default(), - opened_chats: Default::default(), update_channels_tx, client, user_store, @@ -262,13 +258,12 @@ impl ChannelStore { } } status = status_receiver.next().fuse() => { - if let Some(status) = status { - if status.is_connected() { + if let Some(status) = status + && status.is_connected() { this.update(cx, |this, _cx| { this.initialize(); }).ok(); } - } continue; } _ = timer => { @@ -336,10 +331,10 @@ impl ChannelStore { } pub fn has_open_channel_buffer(&self, channel_id: ChannelId, _cx: &App) -> bool { - if let Some(buffer) = self.opened_buffers.get(&channel_id) { - if let OpenEntityHandle::Open(buffer) = buffer { - return buffer.upgrade().is_some(); - } + if let Some(buffer) = self.opened_buffers.get(&channel_id) + && let OpenEntityHandle::Open(buffer) = buffer + { + return buffer.upgrade().is_some(); } false } @@ -363,90 +358,12 @@ impl ChannelStore { ) } - pub fn fetch_channel_messages( - &self, - message_ids: Vec, - cx: &mut Context, - ) -> 
Task>> { - let request = if message_ids.is_empty() { - None - } else { - Some( - self.client - .request(proto::GetChannelMessagesById { message_ids }), - ) - }; - cx.spawn(async move |this, cx| { - if let Some(request) = request { - let response = request.await?; - let this = this.upgrade().context("channel store dropped")?; - let user_store = this.read_with(cx, |this, _| this.user_store.clone())?; - ChannelMessage::from_proto_vec(response.messages, &user_store, cx).await - } else { - Ok(Vec::new()) - } - }) - } - pub fn has_channel_buffer_changed(&self, channel_id: ChannelId) -> bool { self.channel_states .get(&channel_id) .is_some_and(|state| state.has_channel_buffer_changed()) } - pub fn has_new_messages(&self, channel_id: ChannelId) -> bool { - self.channel_states - .get(&channel_id) - .is_some_and(|state| state.has_new_messages()) - } - - pub fn set_acknowledged_message_id(&mut self, channel_id: ChannelId, message_id: Option) { - if let Some(state) = self.channel_states.get_mut(&channel_id) { - state.latest_chat_message = message_id; - } - } - - pub fn last_acknowledge_message_id(&self, channel_id: ChannelId) -> Option { - self.channel_states.get(&channel_id).and_then(|state| { - if let Some(last_message_id) = state.latest_chat_message { - if state - .last_acknowledged_message_id() - .is_some_and(|id| id < last_message_id) - { - return state.last_acknowledged_message_id(); - } - } - - None - }) - } - - pub fn acknowledge_message_id( - &mut self, - channel_id: ChannelId, - message_id: u64, - cx: &mut Context, - ) { - self.channel_states - .entry(channel_id) - .or_default() - .acknowledge_message_id(message_id); - cx.notify(); - } - - pub fn update_latest_message_id( - &mut self, - channel_id: ChannelId, - message_id: u64, - cx: &mut Context, - ) { - self.channel_states - .entry(channel_id) - .or_default() - .update_latest_message_id(message_id); - cx.notify(); - } - pub fn acknowledge_notes_version( &mut self, channel_id: ChannelId, @@ -475,23 +392,6 @@ impl 
ChannelStore { cx.notify() } - pub fn open_channel_chat( - &mut self, - channel_id: ChannelId, - cx: &mut Context, - ) -> Task>> { - let client = self.client.clone(); - let user_store = self.user_store.clone(); - let this = cx.entity(); - self.open_channel_resource( - channel_id, - "chat", - |this| &mut this.opened_chats, - async move |channel, cx| ChannelChat::new(channel, this, user_store, client, cx).await, - cx, - ) - } - /// Asynchronously open a given resource associated with a channel. /// /// Make sure that the resource is only opened once, even if this method @@ -570,16 +470,14 @@ impl ChannelStore { self.channel_index .by_id() .get(&channel_id) - .map_or(false, |channel| channel.is_root_channel()) + .is_some_and(|channel| channel.is_root_channel()) } pub fn is_public_channel(&self, channel_id: ChannelId) -> bool { self.channel_index .by_id() .get(&channel_id) - .map_or(false, |channel| { - channel.visibility == ChannelVisibility::Public - }) + .is_some_and(|channel| channel.visibility == ChannelVisibility::Public) } pub fn channel_capability(&self, channel_id: ChannelId) -> Capability { @@ -910,9 +808,9 @@ impl ChannelStore { async fn handle_update_channels( this: Entity, message: TypedEnvelope, - mut cx: AsyncApp, + cx: AsyncApp, ) -> Result<()> { - this.read_with(&mut cx, |this, _| { + this.read_with(&cx, |this, _| { this.update_channels_tx .unbounded_send(message.payload) .unwrap(); @@ -935,13 +833,6 @@ impl ChannelStore { cx, ); } - for message_id in message.payload.observed_channel_message_id { - this.acknowledge_message_id( - ChannelId(message_id.channel_id), - message_id.message_id, - cx, - ); - } for membership in message.payload.channel_memberships { if let Some(role) = ChannelRole::from_i32(membership.role) { this.channel_states @@ -961,28 +852,18 @@ impl ChannelStore { self.outgoing_invites.clear(); self.disconnect_channel_buffers_task.take(); - for chat in self.opened_chats.values() { - if let OpenEntityHandle::Open(chat) = chat { - if let 
Some(chat) = chat.upgrade() { - chat.update(cx, |chat, cx| { - chat.rejoin(cx); - }); - } - } - } - let mut buffer_versions = Vec::new(); for buffer in self.opened_buffers.values() { - if let OpenEntityHandle::Open(buffer) = buffer { - if let Some(buffer) = buffer.upgrade() { - let channel_buffer = buffer.read(cx); - let buffer = channel_buffer.buffer().read(cx); - buffer_versions.push(proto::ChannelBufferVersion { - channel_id: channel_buffer.channel_id.0, - epoch: channel_buffer.epoch(), - version: language::proto::serialize_version(&buffer.version()), - }); - } + if let OpenEntityHandle::Open(buffer) = buffer + && let Some(buffer) = buffer.upgrade() + { + let channel_buffer = buffer.read(cx); + let buffer = channel_buffer.buffer().read(cx); + buffer_versions.push(proto::ChannelBufferVersion { + channel_id: channel_buffer.channel_id.0, + epoch: channel_buffer.epoch(), + version: language::proto::serialize_version(&buffer.version()), + }); } } @@ -1077,11 +958,11 @@ impl ChannelStore { if let Some(this) = this.upgrade() { this.update(cx, |this, cx| { - for (_, buffer) in &this.opened_buffers { - if let OpenEntityHandle::Open(buffer) = &buffer { - if let Some(buffer) = buffer.upgrade() { - buffer.update(cx, |buffer, cx| buffer.disconnect(cx)); - } + for buffer in this.opened_buffers.values() { + if let OpenEntityHandle::Open(buffer) = &buffer + && let Some(buffer) = buffer.upgrade() + { + buffer.update(cx, |buffer, cx| buffer.disconnect(cx)); } } }) @@ -1098,7 +979,6 @@ impl ChannelStore { self.channel_participants.clear(); self.outgoing_invites.clear(); self.opened_buffers.clear(); - self.opened_chats.clear(); self.disconnect_channel_buffers_task = None; self.channel_states.clear(); } @@ -1135,7 +1015,6 @@ impl ChannelStore { let channels_changed = !payload.channels.is_empty() || !payload.delete_channels.is_empty() - || !payload.latest_channel_message_ids.is_empty() || !payload.latest_channel_buffer_versions.is_empty(); if channels_changed { @@ -1157,10 +1036,9 @@ 
impl ChannelStore { } if let Some(OpenEntityHandle::Open(buffer)) = self.opened_buffers.remove(&channel_id) + && let Some(buffer) = buffer.upgrade() { - if let Some(buffer) = buffer.upgrade() { - buffer.update(cx, ChannelBuffer::disconnect); - } + buffer.update(cx, ChannelBuffer::disconnect); } } } @@ -1170,12 +1048,11 @@ impl ChannelStore { let id = ChannelId(channel.id); let channel_changed = index.insert(channel); - if channel_changed { - if let Some(OpenEntityHandle::Open(buffer)) = self.opened_buffers.get(&id) { - if let Some(buffer) = buffer.upgrade() { - buffer.update(cx, ChannelBuffer::channel_changed); - } - } + if channel_changed + && let Some(OpenEntityHandle::Open(buffer)) = self.opened_buffers.get(&id) + && let Some(buffer) = buffer.upgrade() + { + buffer.update(cx, ChannelBuffer::channel_changed); } } @@ -1187,13 +1064,6 @@ impl ChannelStore { .update_latest_notes_version(latest_buffer_version.epoch, &version) } - for latest_channel_message in payload.latest_channel_message_ids { - self.channel_states - .entry(ChannelId(latest_channel_message.channel_id)) - .or_default() - .update_latest_message_id(latest_channel_message.message_id); - } - self.channels_loaded.0.try_send(true).log_err(); } @@ -1257,29 +1127,6 @@ impl ChannelState { .changed_since(&self.observed_notes_version.version)) } - fn has_new_messages(&self) -> bool { - let latest_message_id = self.latest_chat_message; - let observed_message_id = self.observed_chat_message; - - latest_message_id.is_some_and(|latest_message_id| { - latest_message_id > observed_message_id.unwrap_or_default() - }) - } - - fn last_acknowledged_message_id(&self) -> Option { - self.observed_chat_message - } - - fn acknowledge_message_id(&mut self, message_id: u64) { - let observed = self.observed_chat_message.get_or_insert(message_id); - *observed = (*observed).max(message_id); - } - - fn update_latest_message_id(&mut self, message_id: u64) { - self.latest_chat_message = - 
Some(message_id.max(self.latest_chat_message.unwrap_or_default())); - } - fn acknowledge_notes_version(&mut self, epoch: u64, version: &clock::Global) { if self.observed_notes_version.epoch == epoch { self.observed_notes_version.version.join(version); diff --git a/crates/channel/src/channel_store_tests.rs b/crates/channel/src/channel_store_tests.rs index c92226eeebd131170b0a5b04e4ed7f42c19a64fc..fbdfe9f8b59f2b5e47720bb497c56b47c8abb77e 100644 --- a/crates/channel/src/channel_store_tests.rs +++ b/crates/channel/src/channel_store_tests.rs @@ -1,9 +1,7 @@ -use crate::channel_chat::ChannelChatEvent; - use super::*; -use client::{Client, UserStore, test::FakeServer}; +use client::{Client, UserStore}; use clock::FakeSystemClock; -use gpui::{App, AppContext as _, Entity, SemanticVersion, TestAppContext}; +use gpui::{App, AppContext as _, Entity, SemanticVersion}; use http_client::FakeHttpClient; use rpc::proto::{self}; use settings::SettingsStore; @@ -235,201 +233,6 @@ fn test_dangling_channel_paths(cx: &mut App) { assert_channels(&channel_store, &[(0, "a".to_string())], cx); } -#[gpui::test] -async fn test_channel_messages(cx: &mut TestAppContext) { - let user_id = 5; - let channel_id = 5; - let channel_store = cx.update(init_test); - let client = channel_store.read_with(cx, |s, _| s.client()); - let server = FakeServer::for_client(user_id, &client, cx).await; - - // Get the available channels. - server.send(proto::UpdateChannels { - channels: vec![proto::Channel { - id: channel_id, - name: "the-channel".to_string(), - visibility: proto::ChannelVisibility::Members as i32, - parent_path: vec![], - channel_order: 1, - }], - ..Default::default() - }); - cx.executor().run_until_parked(); - cx.update(|cx| { - assert_channels(&channel_store, &[(0, "the-channel".to_string())], cx); - }); - - // Join a channel and populate its existing messages. 
- let channel = channel_store.update(cx, |store, cx| { - let channel_id = store.ordered_channels().next().unwrap().1.id; - store.open_channel_chat(channel_id, cx) - }); - let join_channel = server.receive::().await.unwrap(); - server.respond( - join_channel.receipt(), - proto::JoinChannelChatResponse { - messages: vec![ - proto::ChannelMessage { - id: 10, - body: "a".into(), - timestamp: 1000, - sender_id: 5, - mentions: vec![], - nonce: Some(1.into()), - reply_to_message_id: None, - edited_at: None, - }, - proto::ChannelMessage { - id: 11, - body: "b".into(), - timestamp: 1001, - sender_id: 6, - mentions: vec![], - nonce: Some(2.into()), - reply_to_message_id: None, - edited_at: None, - }, - ], - done: false, - }, - ); - - cx.executor().start_waiting(); - - // Client requests all users for the received messages - let mut get_users = server.receive::().await.unwrap(); - get_users.payload.user_ids.sort(); - assert_eq!(get_users.payload.user_ids, vec![6]); - server.respond( - get_users.receipt(), - proto::UsersResponse { - users: vec![proto::User { - id: 6, - github_login: "maxbrunsfeld".into(), - avatar_url: "http://avatar.com/maxbrunsfeld".into(), - name: None, - }], - }, - ); - - let channel = channel.await.unwrap(); - channel.update(cx, |channel, _| { - assert_eq!( - channel - .messages_in_range(0..2) - .map(|message| (message.sender.github_login.clone(), message.body.clone())) - .collect::>(), - &[ - ("user-5".into(), "a".into()), - ("maxbrunsfeld".into(), "b".into()) - ] - ); - }); - - // Receive a new message. 
- server.send(proto::ChannelMessageSent { - channel_id, - message: Some(proto::ChannelMessage { - id: 12, - body: "c".into(), - timestamp: 1002, - sender_id: 7, - mentions: vec![], - nonce: Some(3.into()), - reply_to_message_id: None, - edited_at: None, - }), - }); - - // Client requests user for message since they haven't seen them yet - let get_users = server.receive::().await.unwrap(); - assert_eq!(get_users.payload.user_ids, vec![7]); - server.respond( - get_users.receipt(), - proto::UsersResponse { - users: vec![proto::User { - id: 7, - github_login: "as-cii".into(), - avatar_url: "http://avatar.com/as-cii".into(), - name: None, - }], - }, - ); - - assert_eq!( - channel.next_event(cx).await, - ChannelChatEvent::MessagesUpdated { - old_range: 2..2, - new_count: 1, - } - ); - channel.update(cx, |channel, _| { - assert_eq!( - channel - .messages_in_range(2..3) - .map(|message| (message.sender.github_login.clone(), message.body.clone())) - .collect::>(), - &[("as-cii".into(), "c".into())] - ) - }); - - // Scroll up to view older messages. 
- channel.update(cx, |channel, cx| { - channel.load_more_messages(cx).unwrap().detach(); - }); - let get_messages = server.receive::().await.unwrap(); - assert_eq!(get_messages.payload.channel_id, 5); - assert_eq!(get_messages.payload.before_message_id, 10); - server.respond( - get_messages.receipt(), - proto::GetChannelMessagesResponse { - done: true, - messages: vec![ - proto::ChannelMessage { - id: 8, - body: "y".into(), - timestamp: 998, - sender_id: 5, - nonce: Some(4.into()), - mentions: vec![], - reply_to_message_id: None, - edited_at: None, - }, - proto::ChannelMessage { - id: 9, - body: "z".into(), - timestamp: 999, - sender_id: 6, - nonce: Some(5.into()), - mentions: vec![], - reply_to_message_id: None, - edited_at: None, - }, - ], - }, - ); - - assert_eq!( - channel.next_event(cx).await, - ChannelChatEvent::MessagesUpdated { - old_range: 0..0, - new_count: 2, - } - ); - channel.update(cx, |channel, _| { - assert_eq!( - channel - .messages_in_range(0..2) - .map(|message| (message.sender.github_login.clone(), message.body.clone())) - .collect::>(), - &[ - ("user-5".into(), "y".into()), - ("maxbrunsfeld".into(), "z".into()) - ] - ); - }); -} - fn init_test(cx: &mut App) -> Entity { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); @@ -438,7 +241,7 @@ fn init_test(cx: &mut App) -> Entity { let clock = Arc::new(FakeSystemClock::new()); let http = FakeHttpClient::with_404_response(); - let client = Client::new(clock, http.clone(), cx); + let client = Client::new(clock, http, cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); client::init(&client, cx); diff --git a/crates/cli/src/cli.rs b/crates/cli/src/cli.rs index 6274f69035a02bed20d1a85608371744395c951a..79a10fa2b0936b44d9500fd9990ffa4c6ac62e85 100644 --- a/crates/cli/src/cli.rs +++ b/crates/cli/src/cli.rs @@ -14,6 +14,7 @@ pub enum CliRequest { paths: Vec, urls: Vec, diff_paths: Vec<[String; 2]>, + wsl: Option, wait: bool, open_new_workspace: Option, env: 
Option>, diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 67591167dfdbfd758f480b5538471aa65175e859..d4b4a350f61b5bd1249b33ff3925dd281e9d529c 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -6,7 +6,6 @@ use anyhow::{Context as _, Result}; use clap::Parser; use cli::{CliRequest, CliResponse, IpcHandshake, ipc::IpcOneShotServer}; -use collections::HashMap; use parking_lot::Mutex; use std::{ env, fs, io, @@ -85,6 +84,18 @@ struct Args { /// Run zed in dev-server mode #[arg(long)] dev_server_token: Option, + /// The username and WSL distribution to use when opening paths. If not specified, + /// Zed will attempt to open the paths directly. + /// + /// The username is optional, and if not specified, the default user for the distribution + /// will be used. + /// + /// Example: `me@Ubuntu` or `Ubuntu`. + /// + /// WARN: You should not fill in this field by hand. + #[cfg(target_os = "windows")] + #[arg(long, value_name = "USER@DISTRO")] + wsl: Option, /// Not supported in Zed CLI, only supported on Zed binary /// Will attempt to give the correct command to run #[arg(long)] @@ -129,14 +140,41 @@ fn parse_path_with_position(argument_str: &str) -> anyhow::Result { Ok(canonicalized.to_string(|path| path.to_string_lossy().to_string())) } -fn main() -> Result<()> { - #[cfg(all(not(debug_assertions), target_os = "windows"))] - unsafe { - use ::windows::Win32::System::Console::{ATTACH_PARENT_PROCESS, AttachConsole}; +fn parse_path_in_wsl(source: &str, wsl: &str) -> Result { + let mut command = util::command::new_std_command("wsl.exe"); - let _ = AttachConsole(ATTACH_PARENT_PROCESS); + let (user, distro_name) = if let Some((user, distro)) = wsl.split_once('@') { + if user.is_empty() { + anyhow::bail!("user is empty in wsl argument"); + } + (Some(user), distro) + } else { + (None, wsl) + }; + + if let Some(user) = user { + command.arg("--user").arg(user); } + let output = command + .arg("--distribution") + .arg(distro_name) + .arg("wslpath") + 
.arg("-m") + .arg(source) + .output()?; + + let result = String::from_utf8_lossy(&output.stdout); + let prefix = format!("//wsl.localhost/{}", distro_name); + + Ok(result + .trim() + .strip_prefix(&prefix) + .unwrap_or(&result) + .to_string()) +} + +fn main() -> Result<()> { #[cfg(unix)] util::prevent_root_execution(); @@ -223,6 +261,8 @@ fn main() -> Result<()> { let env = { #[cfg(any(target_os = "linux", target_os = "freebsd"))] { + use collections::HashMap; + // On Linux, the desktop entry uses `cli` to spawn `zed`. // We need to handle env vars correctly since std::env::vars() may not contain // project-specific vars (e.g. those set by direnv). @@ -235,8 +275,19 @@ fn main() -> Result<()> { } } - #[cfg(not(any(target_os = "linux", target_os = "freebsd")))] - Some(std::env::vars().collect::>()) + #[cfg(target_os = "windows")] + { + // On Windows, by default, a child process inherits a copy of the environment block of the parent process. + // So we don't need to pass env vars explicitly. + None + } + + #[cfg(not(any(target_os = "linux", target_os = "freebsd", target_os = "windows")))] + { + use collections::HashMap; + + Some(std::env::vars().collect::>()) + } }; let exit_status = Arc::new(Mutex::new(None)); @@ -253,6 +304,11 @@ fn main() -> Result<()> { ]); } + #[cfg(target_os = "windows")] + let wsl = args.wsl.as_ref(); + #[cfg(not(target_os = "windows"))] + let wsl = None; + for path in args.paths_with_position.iter() { if path.starts_with("zed://") || path.starts_with("http://") @@ -271,8 +327,10 @@ fn main() -> Result<()> { paths.push(tmp_file.path().to_string_lossy().to_string()); let (tmp_file, _) = tmp_file.keep()?; anonymous_fd_tmp_files.push((file, tmp_file)); + } else if let Some(wsl) = wsl { + urls.push(format!("file://{}", parse_path_in_wsl(path, wsl)?)); } else { - paths.push(parse_path_with_position(path)?) 
+ paths.push(parse_path_with_position(path)?); } } @@ -288,10 +346,16 @@ fn main() -> Result<()> { let (_, handshake) = server.accept().context("Handshake after Zed spawn")?; let (tx, rx) = (handshake.requests, handshake.responses); + #[cfg(target_os = "windows")] + let wsl = args.wsl; + #[cfg(not(target_os = "windows"))] + let wsl = None; + tx.send(CliRequest::Open { paths, urls, diff_paths, + wsl, wait: args.wait, open_new_workspace, env, @@ -363,7 +427,7 @@ fn anonymous_fd(path: &str) -> Option { let fd: fd::RawFd = fd_str.parse().ok()?; let file = unsafe { fs::File::from_raw_fd(fd) }; - return Some(file); + Some(file) } #[cfg(any(target_os = "macos", target_os = "freebsd"))] { @@ -381,13 +445,13 @@ fn anonymous_fd(path: &str) -> Option { } let fd: fd::RawFd = fd_str.parse().ok()?; let file = unsafe { fs::File::from_raw_fd(fd) }; - return Some(file); + Some(file) } #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "freebsd")))] { _ = path; // not implemented for bsd, windows. 
Could be, but isn't yet - return None; + None } } @@ -494,11 +558,11 @@ mod linux { Ok(Fork::Parent(_)) => Ok(()), Ok(Fork::Child) => { unsafe { std::env::set_var(FORCE_CLI_MODE_ENV_VAR_NAME, "") }; - if let Err(_) = fork::setsid() { + if fork::setsid().is_err() { eprintln!("failed to setsid: {}", std::io::Error::last_os_error()); process::exit(1); } - if let Err(_) = fork::close_fd() { + if fork::close_fd().is_err() { eprintln!("failed to close_fd: {}", std::io::Error::last_os_error()); } let error = @@ -518,11 +582,11 @@ mod linux { ) -> Result<(), std::io::Error> { for _ in 0..100 { thread::sleep(Duration::from_millis(10)); - if sock.connect_addr(&sock_addr).is_ok() { + if sock.connect_addr(sock_addr).is_ok() { return Ok(()); } } - sock.connect_addr(&sock_addr) + sock.connect_addr(sock_addr) } } } @@ -534,8 +598,8 @@ mod flatpak { use std::process::Command; use std::{env, process}; - const EXTRA_LIB_ENV_NAME: &'static str = "ZED_FLATPAK_LIB_PATH"; - const NO_ESCAPE_ENV_NAME: &'static str = "ZED_FLATPAK_NO_ESCAPE"; + const EXTRA_LIB_ENV_NAME: &str = "ZED_FLATPAK_LIB_PATH"; + const NO_ESCAPE_ENV_NAME: &str = "ZED_FLATPAK_NO_ESCAPE"; /// Adds bundled libraries to LD_LIBRARY_PATH if running under flatpak pub fn ld_extra_libs() { @@ -586,14 +650,11 @@ mod flatpak { pub fn set_bin_if_no_escape(mut args: super::Args) -> super::Args { if env::var(NO_ESCAPE_ENV_NAME).is_ok() - && env::var("FLATPAK_ID").map_or(false, |id| id.starts_with("dev.zed.Zed")) + && env::var("FLATPAK_ID").is_ok_and(|id| id.starts_with("dev.zed.Zed")) + && args.zed.is_none() { - if args.zed.is_none() { - args.zed = Some("/app/libexec/zed-editor".into()); - unsafe { - env::set_var("ZED_UPDATE_EXPLANATION", "Please use flatpak to update zed") - }; - } + args.zed = Some("/app/libexec/zed-editor".into()); + unsafe { env::set_var("ZED_UPDATE_EXPLANATION", "Please use flatpak to update zed") }; } args } @@ -647,15 +708,15 @@ mod windows { Storage::FileSystem::{ CreateFileW, FILE_FLAGS_AND_ATTRIBUTES, 
FILE_SHARE_MODE, OPEN_EXISTING, WriteFile, }, - System::Threading::CreateMutexW, + System::Threading::{CREATE_NEW_PROCESS_GROUP, CreateMutexW}, }, core::HSTRING, }; use crate::{Detect, InstalledApp}; - use std::io; use std::path::{Path, PathBuf}; use std::process::ExitStatus; + use std::{io, os::windows::process::CommandExt}; fn check_single_instance() -> bool { let mutex = unsafe { @@ -694,6 +755,7 @@ mod windows { fn launch(&self, ipc_url: String) -> anyhow::Result<()> { if check_single_instance() { std::process::Command::new(self.0.clone()) + .creation_flags(CREATE_NEW_PROCESS_GROUP.0) .arg(ipc_url) .spawn()?; } else { @@ -929,7 +991,7 @@ mod mac_os { fn path(&self) -> PathBuf { match self { - Bundle::App { app_bundle, .. } => app_bundle.join("Contents/MacOS/zed").clone(), + Bundle::App { app_bundle, .. } => app_bundle.join("Contents/MacOS/zed"), Bundle::LocalPath { executable, .. } => executable.clone(), } } diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index 365625b44535e474baecf058c98f54aaf05b5e49..01007cdc6618996735c859284e3860b936f540e8 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -44,6 +44,7 @@ rpc = { workspace = true, features = ["gpui"] } schemars.workspace = true serde.workspace = true serde_json.workspace = true +serde_urlencoded.workspace = true settings.workspace = true sha2.workspace = true smol.workspace = true @@ -74,7 +75,7 @@ util = { workspace = true, features = ["test-support"] } windows.workspace = true [target.'cfg(target_os = "macos")'.dependencies] -cocoa.workspace = true +objc2-foundation.workspace = true [target.'cfg(any(target_os = "windows", target_os = "macos"))'.dependencies] tokio-native-tls = "0.3" diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index f09c012a858e3cf97166dae9dbdbeb3da51b96b6..cb8185c7ed326ed7d45726a99077c53903118316 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -31,7 +31,7 @@ use release_channel::{AppVersion, 
ReleaseChannel}; use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; use std::{ any::TypeId, convert::TryFrom, @@ -66,6 +66,8 @@ pub static IMPERSONATE_LOGIN: LazyLock> = LazyLock::new(|| { .and_then(|s| if s.is_empty() { None } else { Some(s) }) }); +pub static USE_WEB_LOGIN: LazyLock = LazyLock::new(|| std::env::var("ZED_WEB_LOGIN").is_ok()); + pub static ADMIN_API_TOKEN: LazyLock> = LazyLock::new(|| { std::env::var("ZED_ADMIN_API_TOKEN") .ok() @@ -76,7 +78,7 @@ pub static ZED_APP_PATH: LazyLock> = LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from)); pub static ZED_ALWAYS_ACTIVE: LazyLock = - LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty())); + LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").is_ok_and(|e| !e.is_empty())); pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500); pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30); @@ -94,7 +96,8 @@ actions!( ] ); -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, SettingsUi, SettingsKey)] +#[settings_key(None)] pub struct ClientSettingsContent { server_url: Option, } @@ -105,8 +108,6 @@ pub struct ClientSettings { } impl Settings for ClientSettings { - const KEY: Option<&'static str> = None; - type FileContent = ClientSettingsContent; fn load(sources: SettingsSources, _: &mut App) -> Result { @@ -120,7 +121,8 @@ impl Settings for ClientSettings { fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} } -#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)] +#[derive(Default, Clone, Serialize, Deserialize, JsonSchema, SettingsUi, SettingsKey)] +#[settings_key(None)] pub struct ProxySettingsContent { 
proxy: Option, } @@ -131,8 +133,6 @@ pub struct ProxySettings { } impl Settings for ProxySettings { - const KEY: Option<&'static str> = None; - type FileContent = ProxySettingsContent; fn load(sources: SettingsSources, _: &mut App) -> Result { @@ -162,7 +162,7 @@ pub fn init(client: &Arc, cx: &mut App) { let client = client.clone(); move |_: &SignIn, cx| { if let Some(client) = client.upgrade() { - cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, &cx).await) + cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, cx).await) .detach_and_log_err(cx); } } @@ -173,7 +173,7 @@ pub fn init(client: &Arc, cx: &mut App) { move |_: &SignOut, cx| { if let Some(client) = client.upgrade() { cx.spawn(async move |cx| { - client.sign_out(&cx).await; + client.sign_out(cx).await; }) .detach(); } @@ -181,11 +181,11 @@ pub fn init(client: &Arc, cx: &mut App) { }); cx.on_action({ - let client = client.clone(); + let client = client; move |_: &Reconnect, cx| { if let Some(client) = client.upgrade() { cx.spawn(async move |cx| { - client.reconnect(&cx); + client.reconnect(cx); }) .detach(); } @@ -285,6 +285,7 @@ pub enum Status { }, ConnectionLost, Reauthenticating, + Reauthenticated, Reconnecting, ReconnectionError { next_reconnection: Instant, @@ -296,6 +297,21 @@ impl Status { matches!(self, Self::Connected { .. }) } + pub fn was_connected(&self) -> bool { + matches!( + self, + Self::ConnectionLost + | Self::Reauthenticating + | Self::Reauthenticated + | Self::Reconnecting + ) + } + + /// Returns whether the client is currently connected or was connected at some point. + pub fn is_or_was_connected(&self) -> bool { + self.is_connected() || self.was_connected() + } + pub fn is_signing_in(&self) -> bool { matches!( self, @@ -509,7 +525,8 @@ pub struct TelemetrySettings { } /// Control what info is collected by Zed. 
-#[derive(Default, Clone, Serialize, Deserialize, JsonSchema, Debug)] +#[derive(Default, Clone, Serialize, Deserialize, JsonSchema, Debug, SettingsUi, SettingsKey)] +#[settings_key(key = "telemetry")] pub struct TelemetrySettingsContent { /// Send debug info like crash reports. /// @@ -522,8 +539,6 @@ pub struct TelemetrySettingsContent { } impl settings::Settings for TelemetrySettings { - const KEY: Option<&'static str> = Some("telemetry"); - type FileContent = TelemetrySettingsContent; fn load(sources: SettingsSources, _: &mut App) -> Result { @@ -673,11 +688,11 @@ impl Client { #[cfg(any(test, feature = "test-support"))] let mut rng = StdRng::seed_from_u64(0); #[cfg(not(any(test, feature = "test-support")))] - let mut rng = StdRng::from_entropy(); + let mut rng = StdRng::from_os_rng(); let mut delay = INITIAL_RECONNECTION_DELAY; loop { - match client.connect(true, &cx).await { + match client.connect(true, cx).await { ConnectionResult::Timeout => { log::error!("client connect attempt timed out") } @@ -701,10 +716,11 @@ impl Client { Status::ReconnectionError { next_reconnection: Instant::now() + delay, }, - &cx, + cx, + ); + let jitter = Duration::from_millis( + rng.random_range(0..delay.as_millis() as u64), ); - let jitter = - Duration::from_millis(rng.gen_range(0..delay.as_millis() as u64)); cx.background_executor().timer(delay + jitter).await; delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY); } else { @@ -791,7 +807,7 @@ impl Client { Arc::new(move |subscriber, envelope, client, cx| { let subscriber = subscriber.downcast::().unwrap(); let envelope = envelope.into_any().downcast::>().unwrap(); - handler(subscriber, *envelope, client.clone(), cx).boxed_local() + handler(subscriber, *envelope, client, cx).boxed_local() }), ); if prev_handler.is_some() { @@ -855,31 +871,34 @@ impl Client { try_provider: bool, cx: &AsyncApp, ) -> Result { - if self.status().borrow().is_signed_out() { + let is_reauthenticating = if self.status().borrow().is_signed_out() { 
self.set_status(Status::Authenticating, cx); + false } else { self.set_status(Status::Reauthenticating, cx); - } + true + }; let mut credentials = None; let old_credentials = self.state.read().credentials.clone(); - if let Some(old_credentials) = old_credentials { - if self.validate_credentials(&old_credentials, cx).await? { - credentials = Some(old_credentials); - } + if let Some(old_credentials) = old_credentials + && self.validate_credentials(&old_credentials, cx).await? + { + credentials = Some(old_credentials); } - if credentials.is_none() && try_provider { - if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await { - if self.validate_credentials(&stored_credentials, cx).await? { - credentials = Some(stored_credentials); - } else { - self.credentials_provider - .delete_credentials(cx) - .await - .log_err(); - } + if credentials.is_none() + && try_provider + && let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await + { + if self.validate_credentials(&stored_credentials, cx).await? { + credentials = Some(stored_credentials); + } else { + self.credentials_provider + .delete_credentials(cx) + .await + .log_err(); } } @@ -916,7 +935,14 @@ impl Client { self.cloud_client .set_credentials(credentials.user_id as u32, credentials.access_token.clone()); self.state.write().credentials = Some(credentials.clone()); - self.set_status(Status::Authenticated, cx); + self.set_status( + if is_reauthenticating { + Status::Reauthenticated + } else { + Status::Authenticated + }, + cx, + ); Ok(credentials) } @@ -973,6 +999,11 @@ impl Client { try_provider: bool, cx: &AsyncApp, ) -> Result<()> { + // Don't try to sign in again if we're already connected to Collab, as it will temporarily disconnect us. 
+ if self.status().borrow().is_connected() { + return Ok(()); + } + let (is_staff_tx, is_staff_rx) = oneshot::channel::(); let mut is_staff_tx = Some(is_staff_tx); cx.update(|cx| { @@ -1023,11 +1054,12 @@ impl Client { Status::SignedOut | Status::Authenticated => true, Status::ConnectionError | Status::ConnectionLost - | Status::Authenticating { .. } + | Status::Authenticating | Status::AuthenticationError - | Status::Reauthenticating { .. } + | Status::Reauthenticating + | Status::Reauthenticated | Status::ReconnectionError { .. } => false, - Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => { + Status::Connected { .. } | Status::Connecting | Status::Reconnecting => { return ConnectionResult::Result(Ok(())); } Status::UpgradeRequired => { @@ -1151,7 +1183,7 @@ impl Client { let this = self.clone(); async move |cx| { while let Some(message) = incoming.next().await { - this.handle_message(message, &cx); + this.handle_message(message, cx); // Don't starve the main thread when receiving lots of messages at once. 
smol::future::yield_now().await; } @@ -1169,12 +1201,12 @@ impl Client { peer_id, }) { - this.set_status(Status::SignedOut, &cx); + this.set_status(Status::SignedOut, cx); } } Err(err) => { log::error!("connection error: {:?}", err); - this.set_status(Status::ConnectionLost, &cx); + this.set_status(Status::ConnectionLost, cx); } } }) @@ -1284,19 +1316,21 @@ impl Client { "http" => Http, _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?, }; - let rpc_host = rpc_url - .host_str() - .zip(rpc_url.port_or_known_default()) - .context("missing host in rpc url")?; - - let stream = { - let handle = cx.update(|cx| gpui_tokio::Tokio::handle(cx)).ok().unwrap(); - let _guard = handle.enter(); - match proxy { - Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?, - None => Box::new(TcpStream::connect(rpc_host).await?), + + let stream = gpui_tokio::Tokio::spawn_result(cx, { + let rpc_url = rpc_url.clone(); + async move { + let rpc_host = rpc_url + .host_str() + .zip(rpc_url.port_or_known_default()) + .context("missing host in rpc url")?; + Ok(match proxy { + Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?, + None => Box::new(TcpStream::connect(rpc_host).await?), + }) } - }; + })? + .await?; log::info!("connected to rpc endpoint {}", rpc_url); @@ -1384,11 +1418,13 @@ impl Client { if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) { - eprintln!("authenticate as admin {login}, {token}"); + if !*USE_WEB_LOGIN { + eprintln!("authenticate as admin {login}, {token}"); - return this - .authenticate_as_admin(http, login.clone(), token.clone()) - .await; + return this + .authenticate_as_admin(http, login.clone(), token.clone()) + .await; + } } // Start an HTTP server to receive the redirect from Zed's sign-in page. 
@@ -1410,6 +1446,12 @@ impl Client { open_url_tx.send(url).log_err(); + #[derive(Deserialize)] + struct CallbackParams { + pub user_id: String, + pub access_token: String, + } + // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted // access token from the query params. // @@ -1420,17 +1462,13 @@ impl Client { for _ in 0..100 { if let Some(req) = server.recv_timeout(Duration::from_secs(1))? { let path = req.url(); - let mut user_id = None; - let mut access_token = None; let url = Url::parse(&format!("http://example.com{}", path)) .context("failed to parse login notification url")?; - for (key, value) in url.query_pairs() { - if key == "access_token" { - access_token = Some(value.to_string()); - } else if key == "user_id" { - user_id = Some(value.to_string()); - } - } + let callback_params: CallbackParams = + serde_urlencoded::from_str(url.query().unwrap_or_default()) + .context( + "failed to parse sign-in callback query parameters", + )?; let post_auth_url = http.build_url("/native_app_signin_succeeded"); @@ -1445,8 +1483,8 @@ impl Client { ) .context("failed to respond to login http request")?; return Ok(( - user_id.context("missing user_id parameter")?, - access_token.context("missing access_token parameter")?, + callback_params.user_id, + callback_params.access_token, )); } } @@ -1656,21 +1694,10 @@ impl Client { ); cx.spawn(async move |_| match future.await { Ok(()) => { - log::debug!( - "rpc message handled. client_id:{}, sender_id:{:?}, type:{}", - client_id, - original_sender_id, - type_name - ); + log::debug!("rpc message handled. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}"); } Err(error) => { - log::error!( - "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}", - client_id, - original_sender_id, - type_name, - error - ); + log::error!("error handling message. 
client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}, error:{error:#}"); } }) .detach(); @@ -1894,10 +1921,7 @@ mod tests { assert!(matches!(status.next().await, Some(Status::Connecting))); executor.advance_clock(CONNECTION_TIMEOUT); - assert!(matches!( - status.next().await, - Some(Status::ConnectionError { .. }) - )); + assert!(matches!(status.next().await, Some(Status::ConnectionError))); auth_and_connect.await.into_response().unwrap_err(); // Allow the connection to be established. @@ -1921,10 +1945,7 @@ mod tests { }) }); executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY); - assert!(matches!( - status.next().await, - Some(Status::Reconnecting { .. }) - )); + assert!(matches!(status.next().await, Some(Status::Reconnecting))); executor.advance_clock(CONNECTION_TIMEOUT); assert!(matches!( @@ -2040,10 +2061,7 @@ mod tests { assert_eq!(*auth_count.lock(), 1); assert_eq!(*dropped_auth_count.lock(), 0); - let _authenticate = cx.spawn({ - let client = client.clone(); - |cx| async move { client.connect(false, &cx).await } - }); + let _authenticate = cx.spawn(|cx| async move { client.connect(false, &cx).await }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 2); assert_eq!(*dropped_auth_count.lock(), 1); @@ -2065,8 +2083,8 @@ mod tests { let (done_tx1, done_rx1) = smol::channel::unbounded(); let (done_tx2, done_rx2) = smol::channel::unbounded(); AnyProtoClient::from(client.clone()).add_entity_message_handler( - move |entity: Entity, _: TypedEnvelope, mut cx| { - match entity.read_with(&mut cx, |entity, _| entity.id).unwrap() { + move |entity: Entity, _: TypedEnvelope, cx| { + match entity.read_with(&cx, |entity, _| entity.id).unwrap() { 1 => done_tx1.try_send(()).unwrap(), 2 => done_tx2.try_send(()).unwrap(), _ => unreachable!(), @@ -2090,17 +2108,17 @@ mod tests { let _subscription1 = client .subscribe_to_entity(1) .unwrap() - .set_entity(&entity1, &mut cx.to_async()); + .set_entity(&entity1, &cx.to_async()); let _subscription2 = 
client .subscribe_to_entity(2) .unwrap() - .set_entity(&entity2, &mut cx.to_async()); + .set_entity(&entity2, &cx.to_async()); // Ensure dropping a subscription for the same entity type still allows receiving of // messages for other entity IDs of the same type. let subscription3 = client .subscribe_to_entity(3) .unwrap() - .set_entity(&entity3, &mut cx.to_async()); + .set_entity(&entity3, &cx.to_async()); drop(subscription3); server.send(proto::JoinProject { diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs index 54b3d3f801ff45c7ef13ebadbd38b8a81f76d644..e3123400866516bda26b071e288bdad9dd5964e0 100644 --- a/crates/client/src/telemetry.rs +++ b/crates/client/src/telemetry.rs @@ -76,7 +76,7 @@ static ZED_CLIENT_CHECKSUM_SEED: LazyLock>> = LazyLock::new(|| { pub static MINIDUMP_ENDPOINT: LazyLock> = LazyLock::new(|| { option_env!("ZED_MINIDUMP_ENDPOINT") - .map(|s| s.to_owned()) + .map(str::to_string) .or_else(|| env::var("ZED_MINIDUMP_ENDPOINT").ok()) }); @@ -84,6 +84,10 @@ static DOTNET_PROJECT_FILES_REGEX: LazyLock = LazyLock::new(|| { Regex::new(r"^(global\.json|Directory\.Build\.props|.*\.(csproj|fsproj|vbproj|sln))$").unwrap() }); +#[cfg(target_os = "macos")] +static MACOS_VERSION_REGEX: LazyLock = + LazyLock::new(|| Regex::new(r"(\s*\(Build [^)]*[0-9]\))").unwrap()); + pub fn os_name() -> String { #[cfg(target_os = "macos")] { @@ -108,19 +112,16 @@ pub fn os_name() -> String { pub fn os_version() -> String { #[cfg(target_os = "macos")] { - use cocoa::base::nil; - use cocoa::foundation::NSProcessInfo; - - unsafe { - let process_info = cocoa::foundation::NSProcessInfo::processInfo(nil); - let version = process_info.operatingSystemVersion(); - gpui::SemanticVersion::new( - version.majorVersion as usize, - version.minorVersion as usize, - version.patchVersion as usize, - ) + use objc2_foundation::NSProcessInfo; + let process_info = NSProcessInfo::processInfo(); + let version_nsstring = unsafe { process_info.operatingSystemVersionString() 
}; + // "Version 15.6.1 (Build 24G90)" -> "15.6.1 (Build 24G90)" + let version_string = version_nsstring.to_string().replace("Version ", ""); + // "15.6.1 (Build 24G90)" -> "15.6.1" + // "26.0.0 (Build 25A5349a)" -> unchanged (Beta or Rapid Security Response; ends with letter) + MACOS_VERSION_REGEX + .replace_all(&version_string, "") .to_string() - } } #[cfg(any(target_os = "linux", target_os = "freebsd"))] { @@ -739,7 +740,7 @@ mod tests { ); // Third scan of worktree does not double report, as we already reported - test_project_discovery_helper(telemetry.clone(), vec!["package.json"], None, worktree_id); + test_project_discovery_helper(telemetry, vec!["package.json"], None, worktree_id); } #[gpui::test] @@ -751,7 +752,7 @@ mod tests { let telemetry = cx.update(|cx| Telemetry::new(clock.clone(), http, cx)); test_project_discovery_helper( - telemetry.clone(), + telemetry, vec!["package.json", "pnpm-lock.yaml"], Some(vec!["node", "pnpm"]), 1, @@ -767,7 +768,7 @@ mod tests { let telemetry = cx.update(|cx| Telemetry::new(clock.clone(), http, cx)); test_project_discovery_helper( - telemetry.clone(), + telemetry, vec!["package.json", "yarn.lock"], Some(vec!["node", "yarn"]), 1, @@ -786,7 +787,7 @@ mod tests { // project type for the same worktree multiple times test_project_discovery_helper( - telemetry.clone().clone(), + telemetry.clone(), vec!["global.json"], Some(vec!["dotnet"]), 1, diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 439fb100d2244499fa59a81495e282673305e00b..da0e8a0f55a7256bd880b642a4d8cb2f744450d4 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -1,16 +1,12 @@ use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; use anyhow::{Context as _, Result, anyhow}; -use chrono::Duration; use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo}; -use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit}; +use cloud_llm_client::{CurrentUsage, PlanV1, 
UsageData, UsageLimit}; use futures::{StreamExt, stream::BoxStream}; use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext}; use http_client::{AsyncBody, Method, Request, http}; use parking_lot::Mutex; -use rpc::{ - ConnectionId, Peer, Receipt, TypedEnvelope, - proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse}, -}; +use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto}; use std::sync::Arc; pub struct FakeServer { @@ -187,50 +183,27 @@ impl FakeServer { pub async fn receive(&self) -> Result> { self.executor.start_waiting(); - loop { - let message = self - .state - .lock() - .incoming - .as_mut() - .expect("not connected") - .next() - .await - .context("other half hung up")?; - self.executor.finish_waiting(); - let type_name = message.payload_type_name(); - let message = message.into_any(); - - if message.is::>() { - return Ok(*message.downcast().unwrap()); - } - - let accepted_tos_at = chrono::Utc::now() - .checked_sub_signed(Duration::hours(5)) - .expect("failed to build accepted_tos_at") - .timestamp() as u64; - - if message.is::>() { - self.respond( - message - .downcast::>() - .unwrap() - .receipt(), - GetPrivateUserInfoResponse { - metrics_id: "the-metrics-id".into(), - staff: false, - flags: Default::default(), - accepted_tos_at: Some(accepted_tos_at), - }, - ); - continue; - } + let message = self + .state + .lock() + .incoming + .as_mut() + .expect("not connected") + .next() + .await + .context("other half hung up")?; + self.executor.finish_waiting(); + let type_name = message.payload_type_name(); + let message = message.into_any(); - panic!( - "fake server received unexpected message type: {:?}", - type_name - ); + if message.is::>() { + return Ok(*message.downcast().unwrap()); } + + panic!( + "fake server received unexpected message type: {:?}", + type_name + ); } pub fn respond(&self, receipt: Receipt, response: T::Response) { @@ -296,7 +269,8 @@ pub fn make_get_authenticated_user_response( }, feature_flags: vec![], 
plan: PlanInfo { - plan: Plan::ZedPro, + plan: PlanV1::ZedPro, + plan_v2: None, subscription_period: None, usage: CurrentUsage { model_requests: UsageData { diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index faf46945d888d3a3da18f69a16cdcc11009e1937..63626e8ce1f3b25c742f227a56556545762367c3 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -1,11 +1,12 @@ use super::{Client, Status, TypedEnvelope, proto}; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Context as _, Result}; use chrono::{DateTime, Utc}; use cloud_api_client::websocket_protocol::MessageToClient; use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo}; use cloud_llm_client::{ EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, - MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, + MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, Plan, + UsageLimit, }; use collections::{HashMap, HashSet, hash_map::Entry}; use derive_more::Deref; @@ -41,16 +42,11 @@ impl std::fmt::Display for ChannelId { pub struct ProjectId(pub u64); impl ProjectId { - pub fn to_proto(&self) -> u64 { + pub fn to_proto(self) -> u64 { self.0 } } -#[derive( - Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize, -)] -pub struct DevServerProjectId(pub u64); - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ParticipantIndex(pub u32); @@ -116,7 +112,6 @@ pub struct UserStore { edit_prediction_usage: Option, plan_info: Option, current_user: watch::Receiver>>, - accepted_tos_at: Option>, contacts: Vec>, incoming_contact_requests: Vec>, outgoing_contact_requests: Vec>, @@ -177,7 +172,6 @@ impl UserStore { let (mut current_user_tx, current_user_rx) = watch::channel(); let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded(); let rpc_subscriptions = vec![ - client.add_message_handler(cx.weak_entity(), 
Self::handle_update_plan), client.add_message_handler(cx.weak_entity(), Self::handle_update_contacts), client.add_message_handler(cx.weak_entity(), Self::handle_update_invite_info), client.add_message_handler(cx.weak_entity(), Self::handle_show_contacts), @@ -195,7 +189,6 @@ impl UserStore { plan_info: None, model_request_usage: None, edit_prediction_usage: None, - accepted_tos_at: None, contacts: Default::default(), incoming_contact_requests: Default::default(), participant_indices: Default::default(), @@ -224,7 +217,9 @@ impl UserStore { return Ok(()); }; match status { - Status::Authenticated | Status::Connected { .. } => { + Status::Authenticated + | Status::Reauthenticated + | Status::Connected { .. } => { if let Some(user_id) = client.user_id() { let response = client .cloud_client() @@ -272,7 +267,6 @@ impl UserStore { Status::SignedOut => { current_user_tx.send(None).await.ok(); this.update(cx, |this, cx| { - this.accepted_tos_at = None; cx.emit(Event::PrivateUserInfoUpdated); cx.notify(); this.clear_contacts() @@ -333,9 +327,9 @@ impl UserStore { async fn handle_update_contacts( this: Entity, message: TypedEnvelope, - mut cx: AsyncApp, + cx: AsyncApp, ) -> Result<()> { - this.read_with(&mut cx, |this, _| { + this.read_with(&cx, |this, _| { this.update_contacts_tx .unbounded_send(UpdateContacts::Update(message.payload)) .unwrap(); @@ -343,26 +337,6 @@ impl UserStore { Ok(()) } - async fn handle_update_plan( - this: Entity, - _message: TypedEnvelope, - mut cx: AsyncApp, - ) -> Result<()> { - let client = this - .read_with(&cx, |this, _| this.client.upgrade())? 
- .context("client was dropped")?; - - let response = client - .cloud_client() - .get_authenticated_user() - .await - .context("failed to fetch authenticated user")?; - - this.update(&mut cx, |this, cx| { - this.update_authenticated_user(response, cx); - }) - } - fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { match message { UpdateContacts::Wait(barrier) => { @@ -719,20 +693,22 @@ impl UserStore { self.current_user.borrow().clone() } - pub fn plan(&self) -> Option { + pub fn plan(&self) -> Option { #[cfg(debug_assertions)] if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() { + use cloud_llm_client::PlanV1; + return match plan.as_str() { - "free" => Some(cloud_llm_client::Plan::ZedFree), - "trial" => Some(cloud_llm_client::Plan::ZedProTrial), - "pro" => Some(cloud_llm_client::Plan::ZedPro), + "free" => Some(Plan::V1(PlanV1::ZedFree)), + "trial" => Some(Plan::V1(PlanV1::ZedProTrial)), + "pro" => Some(Plan::V1(PlanV1::ZedPro)), _ => { panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'"); } }; } - self.plan_info.as_ref().map(|info| info.plan) + self.plan_info.as_ref().map(|info| info.plan()) } pub fn subscription_period(&self) -> Option<(DateTime, DateTime)> { @@ -812,19 +788,6 @@ impl UserStore { .set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff); } - let accepted_tos_at = { - #[cfg(debug_assertions)] - if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() { - None - } else { - response.user.accepted_tos_at - } - - #[cfg(not(debug_assertions))] - response.user.accepted_tos_at - }; - - self.accepted_tos_at = Some(accepted_tos_at); self.model_request_usage = Some(ModelRequestUsage(RequestUsage { limit: response.plan.usage.model_requests.limit, amount: response.plan.usage.model_requests.used as i32, @@ -867,32 +830,6 @@ impl UserStore { self.current_user.clone() } - pub fn has_accepted_terms_of_service(&self) -> bool { - self.accepted_tos_at - .map_or(false, |accepted_tos_at| 
accepted_tos_at.is_some()) - } - - pub fn accept_terms_of_service(&self, cx: &Context) -> Task> { - if self.current_user().is_none() { - return Task::ready(Err(anyhow!("no current user"))); - }; - - let client = self.client.clone(); - cx.spawn(async move |this, cx| -> anyhow::Result<()> { - let client = client.upgrade().context("client not found")?; - let response = client - .cloud_client() - .accept_terms_of_service() - .await - .context("error accepting tos")?; - this.update(cx, |this, cx| { - this.accepted_tos_at = Some(response.user.accepted_tos_at); - cx.emit(Event::PrivateUserInfoUpdated); - })?; - Ok(()) - }) - } - fn load_users( &self, request: impl RequestMessage, @@ -915,10 +852,10 @@ impl UserStore { let mut ret = Vec::with_capacity(users.len()); for user in users { let user = User::new(user); - if let Some(old) = self.users.insert(user.id, user.clone()) { - if old.github_login != user.github_login { - self.by_github_login.remove(&old.github_login); - } + if let Some(old) = self.users.insert(user.id, user.clone()) + && old.github_login != user.github_login + { + self.by_github_login.remove(&old.github_login); } self.by_github_login .insert(user.github_login.clone(), user.id); @@ -1019,19 +956,6 @@ impl RequestUsage { } } - pub fn from_proto(amount: u32, limit: proto::UsageLimit) -> Option { - let limit = match limit.variant? 
{ - proto::usage_limit::Variant::Limited(limited) => { - UsageLimit::Limited(limited.limit as i32) - } - proto::usage_limit::Variant::Unlimited(_) => UsageLimit::Unlimited, - }; - Some(RequestUsage { - limit, - amount: amount as i32, - }) - } - fn from_headers( limit_name: &str, amount_name: &str, diff --git a/crates/client/src/zed_urls.rs b/crates/client/src/zed_urls.rs index 9df41906d79b4d43234a28dde19bd6862469de8c..7193c099473c95794796c2fc4d3eaaf2f06eb1ac 100644 --- a/crates/client/src/zed_urls.rs +++ b/crates/client/src/zed_urls.rs @@ -43,3 +43,11 @@ pub fn ai_privacy_and_security(cx: &App) -> String { server_url = server_url(cx) ) } + +/// Returns the URL to Zed AI's external agents documentation. +pub fn external_agents_docs(cx: &App) -> String { + format!( + "{server_url}/docs/ai/external-agents", + server_url = server_url(cx) + ) +} diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs index ef9a1a9a553596baf737c4e1ee60d9b3344f4ecf..7fd96fcef0e8fd764bbcaa8ab59a9666095f9db9 100644 --- a/crates/cloud_api_client/src/cloud_api_client.rs +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -102,13 +102,7 @@ impl CloudApiClient { let credentials = credentials.as_ref().context("no credentials provided")?; let authorization_header = format!("{} {}", credentials.user_id, credentials.access_token); - Ok(cx.spawn(async move |cx| { - let handle = cx - .update(|cx| Tokio::handle(cx)) - .ok() - .context("failed to get Tokio handle")?; - let _guard = handle.enter(); - + Ok(Tokio::spawn_result(cx, async move { let ws = WebSocket::connect(connect_url) .with_request( request::Builder::new() @@ -121,34 +115,6 @@ impl CloudApiClient { })) } - pub async fn accept_terms_of_service(&self) -> Result { - let request = self.build_request( - Request::builder().method(Method::POST).uri( - self.http_client - .build_zed_cloud_url("/client/terms_of_service/accept", &[])? 
- .as_ref(), - ), - AsyncBody::default(), - )?; - - let mut response = self.http_client.send(request).await?; - - if !response.status().is_success() { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - anyhow::bail!( - "Failed to accept terms of service.\nStatus: {:?}\nBody: {body}", - response.status() - ) - } - - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - Ok(serde_json::from_str(&body)?) - } - pub async fn create_llm_token( &self, system_id: Option, @@ -205,12 +171,12 @@ impl CloudApiClient { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; if response.status() == StatusCode::UNAUTHORIZED { - return Ok(false); + Ok(false) } else { - return Err(anyhow!( + Err(anyhow!( "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}", response.status() - )); + )) } } } diff --git a/crates/cloud_api_types/src/cloud_api_types.rs b/crates/cloud_api_types/src/cloud_api_types.rs index fa189cd3b5ed7e87e2f3f2a6803d9095c6305105..fddf505f1075ae98f9b0260b2bc0391c7a008942 100644 --- a/crates/cloud_api_types/src/cloud_api_types.rs +++ b/crates/cloud_api_types/src/cloud_api_types.rs @@ -1,6 +1,7 @@ mod timestamp; pub mod websocket_protocol; +use cloud_llm_client::Plan; use serde::{Deserialize, Serialize}; pub use crate::timestamp::Timestamp; @@ -27,7 +28,9 @@ pub struct AuthenticatedUser { #[derive(Debug, PartialEq, Serialize, Deserialize)] pub struct PlanInfo { - pub plan: cloud_llm_client::Plan, + pub plan: cloud_llm_client::PlanV1, + #[serde(default)] + pub plan_v2: Option, pub subscription_period: Option, pub usage: cloud_llm_client::CurrentUsage, pub trial_started_at: Option, @@ -36,6 +39,12 @@ pub struct PlanInfo { pub has_overdue_invoices: bool, } +impl PlanInfo { + pub fn plan(&self) -> Plan { + self.plan_v2.map(Plan::V2).unwrap_or(Plan::V1(self.plan)) + } +} + #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] pub struct 
SubscriptionPeriod { pub started_at: Timestamp, diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 741945af1087e7a4ff5edfc32cca4d080db3982f..16267d86d806387140016dc0a25021ad92607ff2 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -74,9 +74,21 @@ impl FromStr for UsageLimit { } } +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Plan { + V1(PlanV1), + V2(PlanV2), +} + +impl Plan { + pub fn is_v2(&self) -> bool { + matches!(self, Self::V2(_)) + } +} + #[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] -pub enum Plan { +pub enum PlanV1 { #[default] #[serde(alias = "Free")] ZedFree, @@ -86,40 +98,36 @@ pub enum Plan { ZedProTrial, } -impl Plan { - pub fn as_str(&self) -> &'static str { - match self { - Plan::ZedFree => "zed_free", - Plan::ZedPro => "zed_pro", - Plan::ZedProTrial => "zed_pro_trial", - } - } +impl FromStr for PlanV1 { + type Err = anyhow::Error; - pub fn model_requests_limit(&self) -> UsageLimit { - match self { - Plan::ZedPro => UsageLimit::Limited(500), - Plan::ZedProTrial => UsageLimit::Limited(150), - Plan::ZedFree => UsageLimit::Limited(50), + fn from_str(value: &str) -> Result { + match value { + "zed_free" => Ok(Self::ZedFree), + "zed_pro" => Ok(Self::ZedPro), + "zed_pro_trial" => Ok(Self::ZedProTrial), + plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")), } } +} - pub fn edit_predictions_limit(&self) -> UsageLimit { - match self { - Plan::ZedPro => UsageLimit::Unlimited, - Plan::ZedProTrial => UsageLimit::Unlimited, - Plan::ZedFree => UsageLimit::Limited(2_000), - } - } +#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PlanV2 { + #[default] + ZedFree, + ZedPro, + ZedProTrial, } -impl FromStr for Plan { +impl FromStr for PlanV2 { type Err = anyhow::Error; fn from_str(value: &str) 
-> Result { match value { - "zed_free" => Ok(Plan::ZedFree), - "zed_pro" => Ok(Plan::ZedPro), - "zed_pro_trial" => Ok(Plan::ZedProTrial), + "zed_free" => Ok(Self::ZedFree), + "zed_pro" => Ok(Self::ZedPro), + "zed_pro_trial" => Ok(Self::ZedProTrial), plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")), } } @@ -320,7 +328,7 @@ pub struct ListModelsResponse { #[derive(Debug, Serialize, Deserialize)] pub struct GetSubscriptionResponse { - pub plan: Plan, + pub plan: PlanV1, pub usage: Option, } @@ -344,27 +352,39 @@ mod tests { use super::*; #[test] - fn test_plan_deserialize_snake_case() { - let plan = serde_json::from_value::(json!("zed_free")).unwrap(); - assert_eq!(plan, Plan::ZedFree); + fn test_plan_v1_deserialize_snake_case() { + let plan = serde_json::from_value::(json!("zed_free")).unwrap(); + assert_eq!(plan, PlanV1::ZedFree); + + let plan = serde_json::from_value::(json!("zed_pro")).unwrap(); + assert_eq!(plan, PlanV1::ZedPro); + + let plan = serde_json::from_value::(json!("zed_pro_trial")).unwrap(); + assert_eq!(plan, PlanV1::ZedProTrial); + } + + #[test] + fn test_plan_v1_deserialize_aliases() { + let plan = serde_json::from_value::(json!("Free")).unwrap(); + assert_eq!(plan, PlanV1::ZedFree); - let plan = serde_json::from_value::(json!("zed_pro")).unwrap(); - assert_eq!(plan, Plan::ZedPro); + let plan = serde_json::from_value::(json!("ZedPro")).unwrap(); + assert_eq!(plan, PlanV1::ZedPro); - let plan = serde_json::from_value::(json!("zed_pro_trial")).unwrap(); - assert_eq!(plan, Plan::ZedProTrial); + let plan = serde_json::from_value::(json!("ZedProTrial")).unwrap(); + assert_eq!(plan, PlanV1::ZedProTrial); } #[test] - fn test_plan_deserialize_aliases() { - let plan = serde_json::from_value::(json!("Free")).unwrap(); - assert_eq!(plan, Plan::ZedFree); + fn test_plan_v2_deserialize_snake_case() { + let plan = serde_json::from_value::(json!("zed_free")).unwrap(); + assert_eq!(plan, PlanV2::ZedFree); - let plan = 
serde_json::from_value::(json!("ZedPro")).unwrap(); - assert_eq!(plan, Plan::ZedPro); + let plan = serde_json::from_value::(json!("zed_pro")).unwrap(); + assert_eq!(plan, PlanV2::ZedPro); - let plan = serde_json::from_value::(json!("ZedProTrial")).unwrap(); - assert_eq!(plan, Plan::ZedProTrial); + let plan = serde_json::from_value::(json!("zed_pro_trial")).unwrap(); + assert_eq!(plan, PlanV2::ZedProTrial); } #[test] diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 9af95317e60db78fc93b9a1fa01eaee687fac4fc..4fccd3be7ff8b4d44daf5f761695bdba81bd2ad8 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -19,7 +19,6 @@ test-support = ["sqlite"] [dependencies] anyhow.workspace = true -async-stripe.workspace = true async-trait.workspace = true async-tungstenite.workspace = true aws-config = { version = "1.1.5" } @@ -30,16 +29,13 @@ axum-extra = { version = "0.4", features = ["erased-json"] } base64.workspace = true chrono.workspace = true clock.workspace = true -cloud_llm_client.workspace = true collections.workspace = true dashmap.workspace = true -derive_more.workspace = true envy = "0.4.2" futures.workspace = true gpui.workspace = true hex.workspace = true http_client.workspace = true -jsonwebtoken.workspace = true livekit_api.workspace = true log.workspace = true nanoid.workspace = true @@ -65,7 +61,6 @@ subtle.workspace = true supermaven_api.workspace = true telemetry_events.workspace = true text.workspace = true -thiserror.workspace = true time.workspace = true tokio = { workspace = true, features = ["full"] } toml.workspace = true @@ -136,6 +131,3 @@ util.workspace = true workspace = { workspace = true, features = ["test-support"] } worktree = { workspace = true, features = ["test-support"] } zlog.workspace = true - -[package.metadata.cargo-machete] -ignored = ["async-stripe"] diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml index 
45fc018a4afbd22180419742ab50197d13ac3a59..214b550ac20499b8b03cfafeefab9b45d51fcc24 100644 --- a/crates/collab/k8s/collab.template.yml +++ b/crates/collab/k8s/collab.template.yml @@ -219,12 +219,6 @@ spec: secretKeyRef: name: slack key: panics_webhook - - name: STRIPE_API_KEY - valueFrom: - secretKeyRef: - name: stripe - key: api_key - optional: true - name: COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR value: "1000" - name: SUPERMAVEN_ADMIN_API_KEY diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 73d473ab767e633ae2cefc309d87074523811851..b2e25458ef98b295b4d056a7f59521f4fa896f1a 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -116,6 +116,7 @@ CREATE TABLE "project_repositories" ( "scan_id" INTEGER NOT NULL, "is_deleted" BOOL NOT NULL, "current_merge_conflicts" VARCHAR, + "merge_message" VARCHAR, "branch_summary" VARCHAR, "head_commit_details" VARCHAR, PRIMARY KEY (project_id, id) @@ -174,6 +175,7 @@ CREATE TABLE "language_servers" ( "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "name" VARCHAR NOT NULL, "capabilities" TEXT NOT NULL, + "worktree_id" BIGINT, PRIMARY KEY (project_id, id) ); @@ -474,67 +476,6 @@ CREATE UNIQUE INDEX "index_extensions_external_id" ON "extensions" ("external_id CREATE INDEX "index_extensions_total_download_count" ON "extensions" ("total_download_count"); -CREATE TABLE rate_buckets ( - user_id INT NOT NULL, - rate_limit_name VARCHAR(255) NOT NULL, - token_count INT NOT NULL, - last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL, - PRIMARY KEY (user_id, rate_limit_name), - FOREIGN KEY (user_id) REFERENCES users (id) -); - -CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name); - -CREATE TABLE IF NOT EXISTS billing_preferences ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - created_at TIMESTAMP NOT 
NULL DEFAULT CURRENT_TIMESTAMP, - user_id INTEGER NOT NULL REFERENCES users (id), - max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL, - model_request_overages_enabled bool NOT NULL DEFAULT FALSE, - model_request_overages_spend_limit_in_cents integer NOT NULL DEFAULT 0 -); - -CREATE UNIQUE INDEX "uix_billing_preferences_on_user_id" ON billing_preferences (user_id); - -CREATE TABLE IF NOT EXISTS billing_customers ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - user_id INTEGER NOT NULL REFERENCES users (id), - has_overdue_invoices BOOLEAN NOT NULL DEFAULT FALSE, - stripe_customer_id TEXT NOT NULL, - trial_started_at TIMESTAMP -); - -CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id); - -CREATE UNIQUE INDEX "uix_billing_customers_on_stripe_customer_id" ON billing_customers (stripe_customer_id); - -CREATE TABLE IF NOT EXISTS billing_subscriptions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - billing_customer_id INTEGER NOT NULL REFERENCES billing_customers (id), - stripe_subscription_id TEXT NOT NULL, - stripe_subscription_status TEXT NOT NULL, - stripe_cancel_at TIMESTAMP, - stripe_cancellation_reason TEXT, - kind TEXT, - stripe_current_period_start BIGINT, - stripe_current_period_end BIGINT -); - -CREATE INDEX "ix_billing_subscriptions_on_billing_customer_id" ON billing_subscriptions (billing_customer_id); - -CREATE UNIQUE INDEX "uix_billing_subscriptions_on_stripe_subscription_id" ON billing_subscriptions (stripe_subscription_id); - -CREATE TABLE IF NOT EXISTS processed_stripe_events ( - stripe_event_id TEXT PRIMARY KEY, - stripe_event_type TEXT NOT NULL, - stripe_event_created_timestamp INTEGER NOT NULL, - processed_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -); - -CREATE INDEX "ix_processed_stripe_events_on_stripe_event_created_timestamp" ON processed_stripe_events (stripe_event_created_timestamp); - CREATE 
TABLE IF NOT EXISTS "breakpoints" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, diff --git a/crates/collab/migrations/20250816124707_make_admin_required_on_users.sql b/crates/collab/migrations/20250816124707_make_admin_required_on_users.sql new file mode 100644 index 0000000000000000000000000000000000000000..e372723d6d5f5e822a2e437cfac4b95bc2023998 --- /dev/null +++ b/crates/collab/migrations/20250816124707_make_admin_required_on_users.sql @@ -0,0 +1,2 @@ +alter table users +alter column admin set not null; diff --git a/crates/collab/migrations/20250816133027_add_orb_customer_id_to_billing_customers.sql b/crates/collab/migrations/20250816133027_add_orb_customer_id_to_billing_customers.sql new file mode 100644 index 0000000000000000000000000000000000000000..ea5e4de52a829413030bb5e206f5c7401381adcf --- /dev/null +++ b/crates/collab/migrations/20250816133027_add_orb_customer_id_to_billing_customers.sql @@ -0,0 +1,2 @@ +alter table billing_customers + add column orb_customer_id text; diff --git a/crates/collab/migrations/20250816135346_drop_rate_buckets_table.sql b/crates/collab/migrations/20250816135346_drop_rate_buckets_table.sql new file mode 100644 index 0000000000000000000000000000000000000000..f51a33ed30d7fb88bc9dc6c82e7217c7e4634b28 --- /dev/null +++ b/crates/collab/migrations/20250816135346_drop_rate_buckets_table.sql @@ -0,0 +1 @@ +drop table rate_buckets; diff --git a/crates/collab/migrations/20250818192156_add_git_merge_message.sql b/crates/collab/migrations/20250818192156_add_git_merge_message.sql new file mode 100644 index 0000000000000000000000000000000000000000..335ea2f82493082e0e20d7762b5282696dc50224 --- /dev/null +++ b/crates/collab/migrations/20250818192156_add_git_merge_message.sql @@ -0,0 +1 @@ +ALTER TABLE "project_repositories" ADD COLUMN "merge_message" VARCHAR; diff --git a/crates/collab/migrations/20250819022421_add_orb_subscription_id_to_billing_subscriptions.sql 
b/crates/collab/migrations/20250819022421_add_orb_subscription_id_to_billing_subscriptions.sql new file mode 100644 index 0000000000000000000000000000000000000000..317f6a7653e3d1762f74e795a17d2f99b3831201 --- /dev/null +++ b/crates/collab/migrations/20250819022421_add_orb_subscription_id_to_billing_subscriptions.sql @@ -0,0 +1,2 @@ +alter table billing_subscriptions + add column orb_subscription_id text; diff --git a/crates/collab/migrations/20250819225916_make_stripe_fields_optional_on_billing_subscription.sql b/crates/collab/migrations/20250819225916_make_stripe_fields_optional_on_billing_subscription.sql new file mode 100644 index 0000000000000000000000000000000000000000..cf3b79da60be98da8dd78a2bcb01f7532be7fc59 --- /dev/null +++ b/crates/collab/migrations/20250819225916_make_stripe_fields_optional_on_billing_subscription.sql @@ -0,0 +1,3 @@ +alter table billing_subscriptions + alter column stripe_subscription_id drop not null, + alter column stripe_subscription_status drop not null; diff --git a/crates/collab/migrations/20250821133754_add_orb_subscription_status_and_period_to_billing_subscriptions.sql b/crates/collab/migrations/20250821133754_add_orb_subscription_status_and_period_to_billing_subscriptions.sql new file mode 100644 index 0000000000000000000000000000000000000000..89a42ab82bd97f487a426ef1fa0a08aa5b0c8396 --- /dev/null +++ b/crates/collab/migrations/20250821133754_add_orb_subscription_status_and_period_to_billing_subscriptions.sql @@ -0,0 +1,4 @@ +alter table billing_subscriptions + add column orb_subscription_status text, + add column orb_current_billing_period_start_date timestamp without time zone, + add column orb_current_billing_period_end_date timestamp without time zone; diff --git a/crates/collab/migrations/20250827084812_worktree_in_servers.sql b/crates/collab/migrations/20250827084812_worktree_in_servers.sql new file mode 100644 index 0000000000000000000000000000000000000000..d4c6ffbbcccb2d2f23654cfc287b45bb8ea20508 --- /dev/null +++ 
b/crates/collab/migrations/20250827084812_worktree_in_servers.sql @@ -0,0 +1,2 @@ +ALTER TABLE language_servers + ADD COLUMN worktree_id BIGINT; diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 6cf3f68f54eda75ac19950c53cf535ff30a107a9..0cc7e2b2e93969ba7b8942838e4afcee251a20d9 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -1,19 +1,11 @@ -pub mod billing; pub mod contributors; pub mod events; pub mod extensions; pub mod ips_file; pub mod slack; -use crate::db::Database; -use crate::{ - AppState, Error, Result, auth, - db::{User, UserId}, - rpc, -}; -use ::rpc::proto; +use crate::{AppState, Error, Result, auth, db::UserId, rpc}; use anyhow::Context as _; -use axum::extract; use axum::{ Extension, Json, Router, body::Body, @@ -25,7 +17,6 @@ use axum::{ routing::{get, post}, }; use axum_extra::response::ErasedJson; -use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::sync::{Arc, OnceLock}; use tower::ServiceBuilder; @@ -100,10 +91,7 @@ impl std::fmt::Display for SystemIdHeader { pub fn routes(rpc_server: Arc) -> Router<(), Body> { Router::new() - .route("/users/look_up", get(look_up_user)) .route("/users/:id/access_tokens", post(create_access_token)) - .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens)) - .route("/users/:id/update_plan", post(update_plan)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) .merge(contributors::router()) .layer( @@ -144,99 +132,6 @@ pub async fn validate_api_token(req: Request, next: Next) -> impl IntoR Ok::<_, Error>(next.run(req).await) } -#[derive(Debug, Deserialize)] -struct LookUpUserParams { - identifier: String, -} - -#[derive(Debug, Serialize)] -struct LookUpUserResponse { - user: Option, -} - -async fn look_up_user( - Query(params): Query, - Extension(app): Extension>, -) -> Result> { - let user = resolve_identifier_to_user(&app.db, ¶ms.identifier).await?; - let user = if let Some(user) = user { - match user { - UserOrId::User(user) => 
Some(user), - UserOrId::Id(id) => app.db.get_user_by_id(id).await?, - } - } else { - None - }; - - Ok(Json(LookUpUserResponse { user })) -} - -enum UserOrId { - User(User), - Id(UserId), -} - -async fn resolve_identifier_to_user( - db: &Arc, - identifier: &str, -) -> Result> { - if let Some(identifier) = identifier.parse::().ok() { - let user = db.get_user_by_id(UserId(identifier)).await?; - - return Ok(user.map(UserOrId::User)); - } - - if identifier.starts_with("cus_") { - let billing_customer = db - .get_billing_customer_by_stripe_customer_id(&identifier) - .await?; - - return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))); - } - - if identifier.starts_with("sub_") { - let billing_subscription = db - .get_billing_subscription_by_stripe_subscription_id(&identifier) - .await?; - - if let Some(billing_subscription) = billing_subscription { - let billing_customer = db - .get_billing_customer_by_id(billing_subscription.billing_customer_id) - .await?; - - return Ok( - billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)) - ); - } else { - return Ok(None); - } - } - - if identifier.contains('@') { - let user = db.get_user_by_email(identifier).await?; - - return Ok(user.map(UserOrId::User)); - } - - if let Some(user) = db.get_user_by_github_login(identifier).await? 
{ - return Ok(Some(UserOrId::User(user))); - } - - Ok(None) -} - -#[derive(Deserialize, Debug)] -struct CreateUserParams { - github_user_id: i32, - github_login: String, - email_address: String, - email_confirmation_code: Option, - #[serde(default)] - admin: bool, - #[serde(default)] - invite_count: i32, -} - async fn get_rpc_server_snapshot( Extension(rpc_server): Extension>, ) -> Result { @@ -295,90 +190,3 @@ async fn create_access_token( encrypted_access_token, })) } - -#[derive(Serialize)] -struct RefreshLlmTokensResponse {} - -async fn refresh_llm_tokens( - Path(user_id): Path, - Extension(rpc_server): Extension>, -) -> Result> { - rpc_server.refresh_llm_tokens_for_user(user_id).await; - - Ok(Json(RefreshLlmTokensResponse {})) -} - -#[derive(Debug, Serialize, Deserialize)] -struct UpdatePlanBody { - pub plan: cloud_llm_client::Plan, - pub subscription_period: SubscriptionPeriod, - pub usage: cloud_llm_client::CurrentUsage, - pub trial_started_at: Option>, - pub is_usage_based_billing_enabled: bool, - pub is_account_too_young: bool, - pub has_overdue_invoices: bool, -} - -#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] -struct SubscriptionPeriod { - pub started_at: DateTime, - pub ended_at: DateTime, -} - -#[derive(Serialize)] -struct UpdatePlanResponse {} - -async fn update_plan( - Path(user_id): Path, - Extension(rpc_server): Extension>, - extract::Json(body): extract::Json, -) -> Result> { - let plan = match body.plan { - cloud_llm_client::Plan::ZedFree => proto::Plan::Free, - cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro, - cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial, - }; - - let update_user_plan = proto::UpdateUserPlan { - plan: plan.into(), - trial_started_at: body - .trial_started_at - .map(|trial_started_at| trial_started_at.timestamp() as u64), - is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled), - usage: Some(proto::SubscriptionUsage { - model_requests_usage_amount: 
body.usage.model_requests.used, - model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)), - edit_predictions_usage_amount: body.usage.edit_predictions.used, - edit_predictions_usage_limit: Some(usage_limit_to_proto( - body.usage.edit_predictions.limit, - )), - }), - subscription_period: Some(proto::SubscriptionPeriod { - started_at: body.subscription_period.started_at.timestamp() as u64, - ended_at: body.subscription_period.ended_at.timestamp() as u64, - }), - account_too_young: Some(body.is_account_too_young), - has_overdue_invoices: Some(body.has_overdue_invoices), - }; - - rpc_server - .update_plan_for_user(user_id, update_user_plan) - .await?; - - Ok(Json(UpdatePlanResponse {})) -} - -fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit { - proto::UsageLimit { - variant: Some(match limit { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - } -} diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs deleted file mode 100644 index a0325d14c4a1b9f4221b17b446983b17f767fcbe..0000000000000000000000000000000000000000 --- a/crates/collab/src/api/billing.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::sync::Arc; -use stripe::SubscriptionStatus; - -use crate::AppState; -use crate::db::billing_subscription::StripeSubscriptionStatus; -use crate::db::{CreateBillingCustomerParams, billing_customer}; -use crate::stripe_client::{StripeClient, StripeCustomerId}; - -impl From for StripeSubscriptionStatus { - fn from(value: SubscriptionStatus) -> Self { - match value { - SubscriptionStatus::Incomplete => Self::Incomplete, - SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired, - SubscriptionStatus::Trialing => Self::Trialing, - SubscriptionStatus::Active 
=> Self::Active, - SubscriptionStatus::PastDue => Self::PastDue, - SubscriptionStatus::Canceled => Self::Canceled, - SubscriptionStatus::Unpaid => Self::Unpaid, - SubscriptionStatus::Paused => Self::Paused, - } - } -} - -/// Finds or creates a billing customer using the provided customer. -pub async fn find_or_create_billing_customer( - app: &Arc, - stripe_client: &dyn StripeClient, - customer_id: &StripeCustomerId, -) -> anyhow::Result> { - // If we already have a billing customer record associated with the Stripe customer, - // there's nothing more we need to do. - if let Some(billing_customer) = app - .db - .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref()) - .await? - { - return Ok(Some(billing_customer)); - } - - let customer = stripe_client.get_customer(customer_id).await?; - - let Some(email) = customer.email else { - return Ok(None); - }; - - let Some(user) = app.db.get_user_by_email(&email).await? else { - return Ok(None); - }; - - let billing_customer = app - .db - .create_billing_customer(&CreateBillingCustomerParams { - user_id: user.id, - stripe_customer_id: customer.id.to_string(), - }) - .await?; - - Ok(Some(billing_customer)) -} diff --git a/crates/collab/src/api/events.rs b/crates/collab/src/api/events.rs index 2f34a843a860d9d2933a4819788d0f9285473edf..da78a980693bec2243d872092a4f373698958b7a 100644 --- a/crates/collab/src/api/events.rs +++ b/crates/collab/src/api/events.rs @@ -149,35 +149,35 @@ pub async fn post_crash( "crash report" ); - if let Some(kinesis_client) = app.kinesis_client.clone() { - if let Some(stream) = app.config.kinesis_stream.clone() { - let properties = json!({ - "app_version": report.header.app_version, - "os_version": report.header.os_version, - "os_name": "macOS", - "bundle_id": report.header.bundle_id, - "incident_id": report.header.incident_id, - "installation_id": installation_id, - "description": description, - "backtrace": summary, - }); - let row = SnowflakeRow::new( - "Crash Reported", - None, - 
false, - Some(installation_id), - properties, - ); - let data = serde_json::to_vec(&row)?; - kinesis_client - .put_record() - .stream_name(stream) - .partition_key(row.insert_id.unwrap_or_default()) - .data(data.into()) - .send() - .await - .log_err(); - } + if let Some(kinesis_client) = app.kinesis_client.clone() + && let Some(stream) = app.config.kinesis_stream.clone() + { + let properties = json!({ + "app_version": report.header.app_version, + "os_version": report.header.os_version, + "os_name": "macOS", + "bundle_id": report.header.bundle_id, + "incident_id": report.header.incident_id, + "installation_id": installation_id, + "description": description, + "backtrace": summary, + }); + let row = SnowflakeRow::new( + "Crash Reported", + None, + false, + Some(installation_id), + properties, + ); + let data = serde_json::to_vec(&row)?; + kinesis_client + .put_record() + .stream_name(stream) + .partition_key(row.insert_id.unwrap_or_default()) + .data(data.into()) + .send() + .await + .log_err(); } if let Some(slack_panics_webhook) = app.config.slack_panics_webhook.clone() { @@ -280,7 +280,7 @@ pub async fn post_hang( service = "client", version = %report.app_version.unwrap_or_default().to_string(), os_name = %report.os_name, - os_version = report.os_version.unwrap_or_default().to_string(), + os_version = report.os_version.unwrap_or_default(), incident_id = %incident_id, installation_id = %report.installation_id.unwrap_or_default(), backtrace = %backtrace, @@ -359,34 +359,34 @@ pub async fn post_panic( "panic report" ); - if let Some(kinesis_client) = app.kinesis_client.clone() { - if let Some(stream) = app.config.kinesis_stream.clone() { - let properties = json!({ - "app_version": panic.app_version, - "os_name": panic.os_name, - "os_version": panic.os_version, - "incident_id": incident_id, - "installation_id": panic.installation_id, - "description": panic.payload, - "backtrace": backtrace, - }); - let row = SnowflakeRow::new( - "Panic Reported", - None, - false, - 
panic.installation_id.clone(), - properties, - ); - let data = serde_json::to_vec(&row)?; - kinesis_client - .put_record() - .stream_name(stream) - .partition_key(row.insert_id.unwrap_or_default()) - .data(data.into()) - .send() - .await - .log_err(); - } + if let Some(kinesis_client) = app.kinesis_client.clone() + && let Some(stream) = app.config.kinesis_stream.clone() + { + let properties = json!({ + "app_version": panic.app_version, + "os_name": panic.os_name, + "os_version": panic.os_version, + "incident_id": incident_id, + "installation_id": panic.installation_id, + "description": panic.payload, + "backtrace": backtrace, + }); + let row = SnowflakeRow::new( + "Panic Reported", + None, + false, + panic.installation_id.clone(), + properties, + ); + let data = serde_json::to_vec(&row)?; + kinesis_client + .put_record() + .stream_name(stream) + .partition_key(row.insert_id.unwrap_or_default()) + .data(data.into()) + .send() + .await + .log_err(); } if !report_to_slack(&panic) { @@ -518,31 +518,31 @@ pub async fn post_events( let first_event_at = chrono::Utc::now() - chrono::Duration::milliseconds(last_event.milliseconds_since_first_event); - if let Some(kinesis_client) = app.kinesis_client.clone() { - if let Some(stream) = app.config.kinesis_stream.clone() { - let mut request = kinesis_client.put_records().stream_name(stream); - let mut has_records = false; - for row in for_snowflake( - request_body.clone(), - first_event_at, - country_code.clone(), - checksum_matched, - ) { - if let Some(data) = serde_json::to_vec(&row).log_err() { - request = request.records( - aws_sdk_kinesis::types::PutRecordsRequestEntry::builder() - .partition_key(request_body.system_id.clone().unwrap_or_default()) - .data(data.into()) - .build() - .unwrap(), - ); - has_records = true; - } - } - if has_records { - request.send().await.log_err(); + if let Some(kinesis_client) = app.kinesis_client.clone() + && let Some(stream) = app.config.kinesis_stream.clone() + { + let mut request = 
kinesis_client.put_records().stream_name(stream); + let mut has_records = false; + for row in for_snowflake( + request_body.clone(), + first_event_at, + country_code.clone(), + checksum_matched, + ) { + if let Some(data) = serde_json::to_vec(&row).log_err() { + request = request.records( + aws_sdk_kinesis::types::PutRecordsRequestEntry::builder() + .partition_key(request_body.system_id.clone().unwrap_or_default()) + .data(data.into()) + .build() + .unwrap(), + ); + has_records = true; } } + if has_records { + request.send().await.log_err(); + } }; Ok(()) @@ -564,170 +564,10 @@ fn for_snowflake( country_code: Option, checksum_matched: bool, ) -> impl Iterator { - body.events.into_iter().filter_map(move |event| { + body.events.into_iter().map(move |event| { let timestamp = first_event_at + Duration::milliseconds(event.milliseconds_since_first_event); - // We will need to double check, but I believe all of the events that - // are being transformed here are now migrated over to use the - // telemetry::event! macro, as of this commit so this code can go away - // when we feel enough users have upgraded past this point. 
let (event_type, mut event_properties) = match &event.event { - Event::Editor(e) => ( - match e.operation.as_str() { - "open" => "Editor Opened".to_string(), - "save" => "Editor Saved".to_string(), - _ => format!("Unknown Editor Event: {}", e.operation), - }, - serde_json::to_value(e).unwrap(), - ), - Event::EditPrediction(e) => ( - format!( - "Edit Prediction {}", - if e.suggestion_accepted { - "Accepted" - } else { - "Discarded" - } - ), - serde_json::to_value(e).unwrap(), - ), - Event::EditPredictionRating(e) => ( - "Edit Prediction Rated".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Call(e) => { - let event_type = match e.operation.trim() { - "unshare project" => "Project Unshared".to_string(), - "open channel notes" => "Channel Notes Opened".to_string(), - "share project" => "Project Shared".to_string(), - "join channel" => "Channel Joined".to_string(), - "hang up" => "Call Ended".to_string(), - "accept incoming" => "Incoming Call Accepted".to_string(), - "invite" => "Participant Invited".to_string(), - "disable microphone" => "Microphone Disabled".to_string(), - "enable microphone" => "Microphone Enabled".to_string(), - "enable screen share" => "Screen Share Enabled".to_string(), - "disable screen share" => "Screen Share Disabled".to_string(), - "decline incoming" => "Incoming Call Declined".to_string(), - _ => format!("Unknown Call Event: {}", e.operation), - }; - - (event_type, serde_json::to_value(e).unwrap()) - } - Event::Assistant(e) => ( - match e.phase { - telemetry_events::AssistantPhase::Response => "Assistant Responded".to_string(), - telemetry_events::AssistantPhase::Invoked => "Assistant Invoked".to_string(), - telemetry_events::AssistantPhase::Accepted => { - "Assistant Response Accepted".to_string() - } - telemetry_events::AssistantPhase::Rejected => { - "Assistant Response Rejected".to_string() - } - }, - serde_json::to_value(e).unwrap(), - ), - Event::Cpu(_) | Event::Memory(_) => return None, - Event::App(e) => { - let mut 
properties = json!({}); - let event_type = match e.operation.trim() { - // App - "open" => "App Opened".to_string(), - "first open" => "App First Opened".to_string(), - "first open for release channel" => { - "App First Opened For Release Channel".to_string() - } - "close" => "App Closed".to_string(), - - // Project - "open project" => "Project Opened".to_string(), - "open node project" => { - properties["project_type"] = json!("node"); - "Project Opened".to_string() - } - "open pnpm project" => { - properties["project_type"] = json!("pnpm"); - "Project Opened".to_string() - } - "open yarn project" => { - properties["project_type"] = json!("yarn"); - "Project Opened".to_string() - } - - // SSH - "create ssh server" => "SSH Server Created".to_string(), - "create ssh project" => "SSH Project Created".to_string(), - "open ssh project" => "SSH Project Opened".to_string(), - - // Welcome Page - "welcome page: change keymap" => "Welcome Keymap Changed".to_string(), - "welcome page: change theme" => "Welcome Theme Changed".to_string(), - "welcome page: close" => "Welcome Page Closed".to_string(), - "welcome page: edit settings" => "Welcome Settings Edited".to_string(), - "welcome page: install cli" => "Welcome CLI Installed".to_string(), - "welcome page: open" => "Welcome Page Opened".to_string(), - "welcome page: open extensions" => "Welcome Extensions Page Opened".to_string(), - "welcome page: sign in to copilot" => "Welcome Copilot Signed In".to_string(), - "welcome page: toggle diagnostic telemetry" => { - "Welcome Diagnostic Telemetry Toggled".to_string() - } - "welcome page: toggle metric telemetry" => { - "Welcome Metric Telemetry Toggled".to_string() - } - "welcome page: toggle vim" => "Welcome Vim Mode Toggled".to_string(), - "welcome page: view docs" => "Welcome Documentation Viewed".to_string(), - - // Extensions - "extensions page: open" => "Extensions Page Opened".to_string(), - "extensions: install extension" => "Extension Installed".to_string(), - 
"extensions: uninstall extension" => "Extension Uninstalled".to_string(), - - // Misc - "markdown preview: open" => "Markdown Preview Opened".to_string(), - "project diagnostics: open" => "Project Diagnostics Opened".to_string(), - "project search: open" => "Project Search Opened".to_string(), - "repl sessions: open" => "REPL Session Started".to_string(), - - // Feature Upsell - "feature upsell: toggle vim" => { - properties["source"] = json!("Feature Upsell"); - "Vim Mode Toggled".to_string() - } - _ => e - .operation - .strip_prefix("feature upsell: viewed docs (") - .and_then(|s| s.strip_suffix(')')) - .map_or_else( - || format!("Unknown App Event: {}", e.operation), - |docs_url| { - properties["url"] = json!(docs_url); - properties["source"] = json!("Feature Upsell"); - "Documentation Viewed".to_string() - }, - ), - }; - (event_type, properties) - } - Event::Setting(e) => ( - "Settings Changed".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Extension(e) => ( - "Extension Loaded".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Edit(e) => ( - "Editor Edited".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Action(e) => ( - "Action Invoked".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Repl(e) => ( - "Kernel Status Changed".to_string(), - serde_json::to_value(e).unwrap(), - ), Event::Flexible(e) => ( e.event_type.clone(), serde_json::to_value(&e.event_properties).unwrap(), @@ -759,7 +599,7 @@ fn for_snowflake( }) }); - Some(SnowflakeRow { + SnowflakeRow { time: timestamp, user_id: body.metrics_id.clone(), device_id: body.system_id.clone(), @@ -767,7 +607,7 @@ fn for_snowflake( event_properties, user_properties, insert_id: Some(Uuid::new_v4().to_string()), - }) + } }) } diff --git a/crates/collab/src/api/extensions.rs b/crates/collab/src/api/extensions.rs index 9170c39e472d33420fc889972e4c96e44914db15..1ace433db298be7ffd159128b54b194395ba4fe5 100644 --- a/crates/collab/src/api/extensions.rs +++ 
b/crates/collab/src/api/extensions.rs @@ -337,8 +337,7 @@ async fn fetch_extensions_from_blob_store( if known_versions .binary_search_by_key(&published_version, |known_version| known_version) .is_err() - { - if let Some(extension) = fetch_extension_manifest( + && let Some(extension) = fetch_extension_manifest( blob_store_client, blob_store_bucket, extension_id, @@ -346,12 +345,11 @@ async fn fetch_extensions_from_blob_store( ) .await .log_err() - { - new_versions - .entry(extension_id) - .or_default() - .push(extension); - } + { + new_versions + .entry(extension_id) + .or_default() + .push(extension); } } } diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 00f37c675874ce200cf79f2e8763450f4494fc79..13296b79ae8b3df97753e7adf4f2078990c187b0 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -79,27 +79,27 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into verify_access_token(access_token, user_id, &state.db).await }; - if let Ok(validate_result) = validate_result { - if validate_result.is_valid { - let user = state + if let Ok(validate_result) = validate_result + && validate_result.is_valid + { + let user = state + .db + .get_user_by_id(user_id) + .await? + .with_context(|| format!("user {user_id} not found"))?; + + if let Some(impersonator_id) = validate_result.impersonator_id { + let admin = state .db - .get_user_by_id(user_id) + .get_user_by_id(impersonator_id) .await? - .with_context(|| format!("user {user_id} not found"))?; - - if let Some(impersonator_id) = validate_result.impersonator_id { - let admin = state - .db - .get_user_by_id(impersonator_id) - .await? 
- .with_context(|| format!("user {impersonator_id} not found"))?; - req.extensions_mut() - .insert(Principal::Impersonated { user, admin }); - } else { - req.extensions_mut().insert(Principal::User(user)); - }; - return Ok::<_, Error>(next.run(req).await); - } + .with_context(|| format!("user {impersonator_id} not found"))?; + req.extensions_mut() + .insert(Principal::Impersonated { user, admin }); + } else { + req.extensions_mut().insert(Principal::User(user)); + }; + return Ok::<_, Error>(next.run(req).await); } Err(Error::http( @@ -227,7 +227,7 @@ pub async fn verify_access_token( #[cfg(test)] mod test { - use rand::thread_rng; + use rand::prelude::*; use scrypt::password_hash::{PasswordHasher, SaltString}; use sea_orm::EntityTrait; @@ -236,7 +236,7 @@ mod test { #[gpui::test] async fn test_verify_access_token(cx: &mut gpui::TestAppContext) { - let test_db = crate::db::TestDb::sqlite(cx.executor().clone()); + let test_db = crate::db::TestDb::sqlite(cx.executor()); let db = test_db.db(); let user = db @@ -358,9 +358,42 @@ mod test { None, None, params, - &SaltString::generate(thread_rng()), + &SaltString::generate(PasswordHashRngCompat::new()), ) .map_err(anyhow::Error::new)? .to_string()) } + + // TODO: remove once we password_hash v0.6 is released. 
+ struct PasswordHashRngCompat(rand::rngs::ThreadRng); + + impl PasswordHashRngCompat { + fn new() -> Self { + Self(rand::rng()) + } + } + + impl scrypt::password_hash::rand_core::RngCore for PasswordHashRngCompat { + fn next_u32(&mut self) -> u32 { + self.0.next_u32() + } + + fn next_u64(&mut self) -> u64 { + self.0.next_u64() + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.0.fill_bytes(dest); + } + + fn try_fill_bytes( + &mut self, + dest: &mut [u8], + ) -> Result<(), scrypt::password_hash::rand_core::Error> { + self.fill_bytes(dest); + Ok(()) + } + } + + impl scrypt::password_hash::rand_core::CryptoRng for PasswordHashRngCompat {} } diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 2c22ca206945eb02752680b6149d7796643ee938..6ec57ce95e1863d973624f57947b28fffec042b1 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -26,7 +26,6 @@ use semantic_version::SemanticVersion; use serde::{Deserialize, Serialize}; use std::ops::RangeInclusive; use std::{ - fmt::Write as _, future::Future, marker::PhantomData, ops::{Deref, DerefMut}, @@ -41,12 +40,7 @@ use worktree_settings_file::LocalSettingsKind; pub use tests::TestDb; pub use ids::*; -pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams}; -pub use queries::billing_subscriptions::{ - CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams, -}; pub use queries::contributors::ContributorSelector; -pub use queries::processed_stripe_events::CreateProcessedStripeEventParams; pub use sea_orm::ConnectOptions; pub use tables::user::Model as User; pub use tables::*; @@ -261,7 +255,7 @@ impl Database { let test_options = self.test_options.as_ref().unwrap(); test_options.executor.simulate_random_delay().await; let fail_probability = *test_options.query_failure_probability.lock(); - if test_options.executor.rng().gen_bool(fail_probability) { + if test_options.executor.rng().random_bool(fail_probability) { return Err(anyhow!("simulated 
query failure"))?; } @@ -491,9 +485,7 @@ pub struct ChannelsForUser { pub invited_channels: Vec, pub observed_buffer_versions: Vec, - pub observed_channel_messages: Vec, pub latest_buffer_versions: Vec, - pub latest_channel_messages: Vec, } #[derive(Debug)] @@ -690,7 +682,7 @@ impl LocalSettingsKind { } } - pub fn to_proto(&self) -> proto::LocalSettingsKind { + pub fn to_proto(self) -> proto::LocalSettingsKind { match self { Self::Settings => proto::LocalSettingsKind::Settings, Self::Tasks => proto::LocalSettingsKind::Tasks, diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index 2ba7ec10514d8b0bbf4b26eab0b9384b3911204e..8f116cfd633749b21ff197a723f9e779a750b561 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -70,9 +70,6 @@ macro_rules! id_type { } id_type!(AccessTokenId); -id_type!(BillingCustomerId); -id_type!(BillingSubscriptionId); -id_type!(BillingPreferencesId); id_type!(BufferId); id_type!(ChannelBufferCollaboratorId); id_type!(ChannelChatParticipantId); diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs index 64b627e47518e0c18eedd8c625a0c98b678a96cc..7b457a5da438e0a9ab7c6cd79368b2845e962318 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -1,18 +1,13 @@ use super::*; pub mod access_tokens; -pub mod billing_customers; -pub mod billing_preferences; -pub mod billing_subscriptions; pub mod buffers; pub mod channels; pub mod contacts; pub mod contributors; pub mod embeddings; pub mod extensions; -pub mod messages; pub mod notifications; -pub mod processed_stripe_events; pub mod projects; pub mod rooms; pub mod servers; diff --git a/crates/collab/src/db/queries/billing_customers.rs b/crates/collab/src/db/queries/billing_customers.rs deleted file mode 100644 index ead9e6cd32dc4e52a5c0e2438e9e8ff97735a255..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/queries/billing_customers.rs +++ /dev/null @@ -1,100 +0,0 @@ -use super::*; - 
-#[derive(Debug)] -pub struct CreateBillingCustomerParams { - pub user_id: UserId, - pub stripe_customer_id: String, -} - -#[derive(Debug, Default)] -pub struct UpdateBillingCustomerParams { - pub user_id: ActiveValue, - pub stripe_customer_id: ActiveValue, - pub has_overdue_invoices: ActiveValue, - pub trial_started_at: ActiveValue>, -} - -impl Database { - /// Creates a new billing customer. - pub async fn create_billing_customer( - &self, - params: &CreateBillingCustomerParams, - ) -> Result { - self.transaction(|tx| async move { - let customer = billing_customer::Entity::insert(billing_customer::ActiveModel { - user_id: ActiveValue::set(params.user_id), - stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()), - ..Default::default() - }) - .exec_with_returning(&*tx) - .await?; - - Ok(customer) - }) - .await - } - - /// Updates the specified billing customer. - pub async fn update_billing_customer( - &self, - id: BillingCustomerId, - params: &UpdateBillingCustomerParams, - ) -> Result<()> { - self.transaction(|tx| async move { - billing_customer::Entity::update(billing_customer::ActiveModel { - id: ActiveValue::set(id), - user_id: params.user_id.clone(), - stripe_customer_id: params.stripe_customer_id.clone(), - has_overdue_invoices: params.has_overdue_invoices.clone(), - trial_started_at: params.trial_started_at.clone(), - created_at: ActiveValue::not_set(), - }) - .exec(&*tx) - .await?; - - Ok(()) - }) - .await - } - - pub async fn get_billing_customer_by_id( - &self, - id: BillingCustomerId, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_customer::Entity::find() - .filter(billing_customer::Column::Id.eq(id)) - .one(&*tx) - .await?) - }) - .await - } - - /// Returns the billing customer for the user with the specified ID. 
- pub async fn get_billing_customer_by_user_id( - &self, - user_id: UserId, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_customer::Entity::find() - .filter(billing_customer::Column::UserId.eq(user_id)) - .one(&*tx) - .await?) - }) - .await - } - - /// Returns the billing customer for the user with the specified Stripe customer ID. - pub async fn get_billing_customer_by_stripe_customer_id( - &self, - stripe_customer_id: &str, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_customer::Entity::find() - .filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id)) - .one(&*tx) - .await?) - }) - .await - } -} diff --git a/crates/collab/src/db/queries/billing_preferences.rs b/crates/collab/src/db/queries/billing_preferences.rs deleted file mode 100644 index f370964ecd7d5c762c88e5fb572fde84ce81935d..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/queries/billing_preferences.rs +++ /dev/null @@ -1,17 +0,0 @@ -use super::*; - -impl Database { - /// Returns the billing preferences for the given user, if they exist. - pub async fn get_billing_preferences( - &self, - user_id: UserId, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_preference::Entity::find() - .filter(billing_preference::Column::UserId.eq(user_id)) - .one(&*tx) - .await?) 
- }) - .await - } -} diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs deleted file mode 100644 index 8361d6b4d07f8e6b59f9c7b39b18057e6f62b3c0..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ /dev/null @@ -1,158 +0,0 @@ -use anyhow::Context as _; - -use crate::db::billing_subscription::{ - StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind, -}; - -use super::*; - -#[derive(Debug)] -pub struct CreateBillingSubscriptionParams { - pub billing_customer_id: BillingCustomerId, - pub kind: Option, - pub stripe_subscription_id: String, - pub stripe_subscription_status: StripeSubscriptionStatus, - pub stripe_cancellation_reason: Option, - pub stripe_current_period_start: Option, - pub stripe_current_period_end: Option, -} - -#[derive(Debug, Default)] -pub struct UpdateBillingSubscriptionParams { - pub billing_customer_id: ActiveValue, - pub kind: ActiveValue>, - pub stripe_subscription_id: ActiveValue, - pub stripe_subscription_status: ActiveValue, - pub stripe_cancel_at: ActiveValue>, - pub stripe_cancellation_reason: ActiveValue>, - pub stripe_current_period_start: ActiveValue>, - pub stripe_current_period_end: ActiveValue>, -} - -impl Database { - /// Creates a new billing subscription. 
- pub async fn create_billing_subscription( - &self, - params: &CreateBillingSubscriptionParams, - ) -> Result { - self.transaction(|tx| async move { - let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel { - billing_customer_id: ActiveValue::set(params.billing_customer_id), - kind: ActiveValue::set(params.kind), - stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()), - stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status), - stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason), - stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start), - stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end), - ..Default::default() - }) - .exec(&*tx) - .await? - .last_insert_id; - - Ok(billing_subscription::Entity::find_by_id(id) - .one(&*tx) - .await? - .context("failed to retrieve inserted billing subscription")?) - }) - .await - } - - /// Updates the specified billing subscription. 
- pub async fn update_billing_subscription( - &self, - id: BillingSubscriptionId, - params: &UpdateBillingSubscriptionParams, - ) -> Result<()> { - self.transaction(|tx| async move { - billing_subscription::Entity::update(billing_subscription::ActiveModel { - id: ActiveValue::set(id), - billing_customer_id: params.billing_customer_id.clone(), - kind: params.kind.clone(), - stripe_subscription_id: params.stripe_subscription_id.clone(), - stripe_subscription_status: params.stripe_subscription_status.clone(), - stripe_cancel_at: params.stripe_cancel_at.clone(), - stripe_cancellation_reason: params.stripe_cancellation_reason.clone(), - stripe_current_period_start: params.stripe_current_period_start.clone(), - stripe_current_period_end: params.stripe_current_period_end.clone(), - created_at: ActiveValue::not_set(), - }) - .exec(&*tx) - .await?; - - Ok(()) - }) - .await - } - - /// Returns the billing subscription with the specified Stripe subscription ID. - pub async fn get_billing_subscription_by_stripe_subscription_id( - &self, - stripe_subscription_id: &str, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_subscription::Entity::find() - .filter( - billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id), - ) - .one(&*tx) - .await?) - }) - .await - } - - pub async fn get_active_billing_subscription( - &self, - user_id: UserId, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .filter(billing_customer::Column::UserId.eq(user_id)) - .filter( - Condition::all() - .add( - Condition::any() - .add( - billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Active), - ) - .add( - billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Trialing), - ), - ) - .add(billing_subscription::Column::Kind.is_not_null()), - ) - .one(&*tx) - .await?) 
- }) - .await - } - - /// Returns whether the user has an active billing subscription. - pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result { - Ok(self.count_active_billing_subscriptions(user_id).await? > 0) - } - - /// Returns the count of the active billing subscriptions for the user with the specified ID. - pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result { - self.transaction(|tx| async move { - let count = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .filter( - billing_customer::Column::UserId.eq(user_id).and( - billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Active) - .or(billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Trialing)), - ), - ) - .count(&*tx) - .await?; - - Ok(count as usize) - }) - .await - } -} diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index 5e296e0a3b8e3cb16bd0a1820688d808e10a8193..4bb82865e73968e2861777d5cd0f700675366e81 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -618,25 +618,17 @@ impl Database { } drop(rows); - let latest_channel_messages = self.latest_channel_messages(&channel_ids, tx).await?; - let observed_buffer_versions = self .observed_channel_buffer_changes(&channel_ids_by_buffer_id, user_id, tx) .await?; - let observed_channel_messages = self - .observed_channel_messages(&channel_ids, user_id, tx) - .await?; - Ok(ChannelsForUser { channel_memberships, channels, invited_channels, channel_participants, latest_buffer_versions, - latest_channel_messages, observed_buffer_versions, - observed_channel_messages, }) } diff --git a/crates/collab/src/db/queries/extensions.rs b/crates/collab/src/db/queries/extensions.rs index 7d8aad2be4bd3581cbdbe3dc3a1dfbc935f81966..f218ff28507cf51a72cd0aa00a044ad75f64f839 100644 --- a/crates/collab/src/db/queries/extensions.rs 
+++ b/crates/collab/src/db/queries/extensions.rs @@ -87,10 +87,10 @@ impl Database { continue; }; - if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id) { - if max_extension_version > &extension_version { - continue; - } + if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id) + && max_extension_version > &extension_version + { + continue; } if let Some(constraints) = constraints { @@ -331,10 +331,10 @@ impl Database { .exec_without_returning(&*tx) .await?; - if let Ok(db_version) = semver::Version::parse(&extension.latest_version) { - if db_version >= latest_version.version { - continue; - } + if let Ok(db_version) = semver::Version::parse(&extension.latest_version) + && db_version >= latest_version.version + { + continue; } let mut extension = extension.into_active_model(); diff --git a/crates/collab/src/db/queries/messages.rs b/crates/collab/src/db/queries/messages.rs deleted file mode 100644 index 38e100053c0e88311aacd69a14fd8cb98e43ee28..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/queries/messages.rs +++ /dev/null @@ -1,725 +0,0 @@ -use super::*; -use anyhow::Context as _; -use rpc::Notification; -use sea_orm::{SelectColumns, TryInsertResult}; -use time::OffsetDateTime; -use util::ResultExt; - -impl Database { - /// Inserts a record representing a user joining the chat for a given channel. 
- pub async fn join_channel_chat( - &self, - channel_id: ChannelId, - connection_id: ConnectionId, - user_id: UserId, - ) -> Result<()> { - self.transaction(|tx| async move { - let channel = self.get_channel_internal(channel_id, &tx).await?; - self.check_user_is_channel_participant(&channel, user_id, &tx) - .await?; - channel_chat_participant::ActiveModel { - id: ActiveValue::NotSet, - channel_id: ActiveValue::Set(channel_id), - user_id: ActiveValue::Set(user_id), - connection_id: ActiveValue::Set(connection_id.id as i32), - connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)), - } - .insert(&*tx) - .await?; - Ok(()) - }) - .await - } - - /// Removes `channel_chat_participant` records associated with the given connection ID. - pub async fn channel_chat_connection_lost( - &self, - connection_id: ConnectionId, - tx: &DatabaseTransaction, - ) -> Result<()> { - channel_chat_participant::Entity::delete_many() - .filter( - Condition::all() - .add( - channel_chat_participant::Column::ConnectionServerId - .eq(connection_id.owner_id), - ) - .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)), - ) - .exec(tx) - .await?; - Ok(()) - } - - /// Removes `channel_chat_participant` records associated with the given user ID so they - /// will no longer get chat notifications. - pub async fn leave_channel_chat( - &self, - channel_id: ChannelId, - connection_id: ConnectionId, - _user_id: UserId, - ) -> Result<()> { - self.transaction(|tx| async move { - channel_chat_participant::Entity::delete_many() - .filter( - Condition::all() - .add( - channel_chat_participant::Column::ConnectionServerId - .eq(connection_id.owner_id), - ) - .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)) - .add(channel_chat_participant::Column::ChannelId.eq(channel_id)), - ) - .exec(&*tx) - .await?; - - Ok(()) - }) - .await - } - - /// Retrieves the messages in the specified channel. 
- /// - /// Use `before_message_id` to paginate through the channel's messages. - pub async fn get_channel_messages( - &self, - channel_id: ChannelId, - user_id: UserId, - count: usize, - before_message_id: Option, - ) -> Result> { - self.transaction(|tx| async move { - let channel = self.get_channel_internal(channel_id, &tx).await?; - self.check_user_is_channel_participant(&channel, user_id, &tx) - .await?; - - let mut condition = - Condition::all().add(channel_message::Column::ChannelId.eq(channel_id)); - - if let Some(before_message_id) = before_message_id { - condition = condition.add(channel_message::Column::Id.lt(before_message_id)); - } - - let rows = channel_message::Entity::find() - .filter(condition) - .order_by_desc(channel_message::Column::Id) - .limit(count as u64) - .all(&*tx) - .await?; - - self.load_channel_messages(rows, &tx).await - }) - .await - } - - /// Returns the channel messages with the given IDs. - pub async fn get_channel_messages_by_id( - &self, - user_id: UserId, - message_ids: &[MessageId], - ) -> Result> { - self.transaction(|tx| async move { - let rows = channel_message::Entity::find() - .filter(channel_message::Column::Id.is_in(message_ids.iter().copied())) - .order_by_desc(channel_message::Column::Id) - .all(&*tx) - .await?; - - let mut channels = HashMap::::default(); - for row in &rows { - channels.insert( - row.channel_id, - self.get_channel_internal(row.channel_id, &tx).await?, - ); - } - - for (_, channel) in channels { - self.check_user_is_channel_participant(&channel, user_id, &tx) - .await?; - } - - let messages = self.load_channel_messages(rows, &tx).await?; - Ok(messages) - }) - .await - } - - async fn load_channel_messages( - &self, - rows: Vec, - tx: &DatabaseTransaction, - ) -> Result> { - let mut messages = rows - .into_iter() - .map(|row| { - let nonce = row.nonce.as_u64_pair(); - proto::ChannelMessage { - id: row.id.to_proto(), - sender_id: row.sender_id.to_proto(), - body: row.body, - timestamp: 
row.sent_at.assume_utc().unix_timestamp() as u64, - mentions: vec![], - nonce: Some(proto::Nonce { - upper_half: nonce.0, - lower_half: nonce.1, - }), - reply_to_message_id: row.reply_to_message_id.map(|id| id.to_proto()), - edited_at: row - .edited_at - .map(|t| t.assume_utc().unix_timestamp() as u64), - } - }) - .collect::>(); - messages.reverse(); - - let mut mentions = channel_message_mention::Entity::find() - .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id))) - .order_by_asc(channel_message_mention::Column::MessageId) - .order_by_asc(channel_message_mention::Column::StartOffset) - .stream(tx) - .await?; - - let mut message_ix = 0; - while let Some(mention) = mentions.next().await { - let mention = mention?; - let message_id = mention.message_id.to_proto(); - while let Some(message) = messages.get_mut(message_ix) { - if message.id < message_id { - message_ix += 1; - } else { - if message.id == message_id { - message.mentions.push(proto::ChatMention { - range: Some(proto::Range { - start: mention.start_offset as u64, - end: mention.end_offset as u64, - }), - user_id: mention.user_id.to_proto(), - }); - } - break; - } - } - } - - Ok(messages) - } - - fn format_mentions_to_entities( - &self, - message_id: MessageId, - body: &str, - mentions: &[proto::ChatMention], - ) -> Result> { - Ok(mentions - .iter() - .filter_map(|mention| { - let range = mention.range.as_ref()?; - if !body.is_char_boundary(range.start as usize) - || !body.is_char_boundary(range.end as usize) - { - return None; - } - Some(channel_message_mention::ActiveModel { - message_id: ActiveValue::Set(message_id), - start_offset: ActiveValue::Set(range.start as i32), - end_offset: ActiveValue::Set(range.end as i32), - user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)), - }) - }) - .collect::>()) - } - - /// Creates a new channel message. 
- pub async fn create_channel_message( - &self, - channel_id: ChannelId, - user_id: UserId, - body: &str, - mentions: &[proto::ChatMention], - timestamp: OffsetDateTime, - nonce: u128, - reply_to_message_id: Option, - ) -> Result { - self.transaction(|tx| async move { - let channel = self.get_channel_internal(channel_id, &tx).await?; - self.check_user_is_channel_participant(&channel, user_id, &tx) - .await?; - - let mut rows = channel_chat_participant::Entity::find() - .filter(channel_chat_participant::Column::ChannelId.eq(channel_id)) - .stream(&*tx) - .await?; - - let mut is_participant = false; - let mut participant_connection_ids = HashSet::default(); - let mut participant_user_ids = Vec::new(); - while let Some(row) = rows.next().await { - let row = row?; - if row.user_id == user_id { - is_participant = true; - } - participant_user_ids.push(row.user_id); - participant_connection_ids.insert(row.connection()); - } - drop(rows); - - if !is_participant { - Err(anyhow!("not a chat participant"))?; - } - - let timestamp = timestamp.to_offset(time::UtcOffset::UTC); - let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time()); - - let result = channel_message::Entity::insert(channel_message::ActiveModel { - channel_id: ActiveValue::Set(channel_id), - sender_id: ActiveValue::Set(user_id), - body: ActiveValue::Set(body.to_string()), - sent_at: ActiveValue::Set(timestamp), - nonce: ActiveValue::Set(Uuid::from_u128(nonce)), - id: ActiveValue::NotSet, - reply_to_message_id: ActiveValue::Set(reply_to_message_id), - edited_at: ActiveValue::NotSet, - }) - .on_conflict( - OnConflict::columns([ - channel_message::Column::SenderId, - channel_message::Column::Nonce, - ]) - .do_nothing() - .to_owned(), - ) - .do_nothing() - .exec(&*tx) - .await?; - - let message_id; - let mut notifications = Vec::new(); - match result { - TryInsertResult::Inserted(result) => { - message_id = result.last_insert_id; - let mentioned_user_ids = - mentions.iter().map(|m| 
m.user_id).collect::>(); - - let mentions = self.format_mentions_to_entities(message_id, body, mentions)?; - if !mentions.is_empty() { - channel_message_mention::Entity::insert_many(mentions) - .exec(&*tx) - .await?; - } - - for mentioned_user in mentioned_user_ids { - notifications.extend( - self.create_notification( - UserId::from_proto(mentioned_user), - rpc::Notification::ChannelMessageMention { - message_id: message_id.to_proto(), - sender_id: user_id.to_proto(), - channel_id: channel_id.to_proto(), - }, - false, - &tx, - ) - .await?, - ); - } - - self.observe_channel_message_internal(channel_id, user_id, message_id, &tx) - .await?; - } - _ => { - message_id = channel_message::Entity::find() - .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce))) - .one(&*tx) - .await? - .context("failed to insert message")? - .id; - } - } - - Ok(CreatedChannelMessage { - message_id, - participant_connection_ids, - notifications, - }) - }) - .await - } - - pub async fn observe_channel_message( - &self, - channel_id: ChannelId, - user_id: UserId, - message_id: MessageId, - ) -> Result { - self.transaction(|tx| async move { - self.observe_channel_message_internal(channel_id, user_id, message_id, &tx) - .await?; - let mut batch = NotificationBatch::default(); - batch.extend( - self.mark_notification_as_read( - user_id, - &Notification::ChannelMessageMention { - message_id: message_id.to_proto(), - sender_id: Default::default(), - channel_id: Default::default(), - }, - &tx, - ) - .await?, - ); - Ok(batch) - }) - .await - } - - async fn observe_channel_message_internal( - &self, - channel_id: ChannelId, - user_id: UserId, - message_id: MessageId, - tx: &DatabaseTransaction, - ) -> Result<()> { - observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel { - user_id: ActiveValue::Set(user_id), - channel_id: ActiveValue::Set(channel_id), - channel_message_id: ActiveValue::Set(message_id), - }) - .on_conflict( - OnConflict::columns([ - 
observed_channel_messages::Column::ChannelId, - observed_channel_messages::Column::UserId, - ]) - .update_column(observed_channel_messages::Column::ChannelMessageId) - .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id)) - .to_owned(), - ) - // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug - .exec_without_returning(tx) - .await?; - Ok(()) - } - - pub async fn observed_channel_messages( - &self, - channel_ids: &[ChannelId], - user_id: UserId, - tx: &DatabaseTransaction, - ) -> Result> { - let rows = observed_channel_messages::Entity::find() - .filter(observed_channel_messages::Column::UserId.eq(user_id)) - .filter( - observed_channel_messages::Column::ChannelId - .is_in(channel_ids.iter().map(|id| id.0)), - ) - .all(tx) - .await?; - - Ok(rows - .into_iter() - .map(|message| proto::ChannelMessageId { - channel_id: message.channel_id.to_proto(), - message_id: message.channel_message_id.to_proto(), - }) - .collect()) - } - - pub async fn latest_channel_messages( - &self, - channel_ids: &[ChannelId], - tx: &DatabaseTransaction, - ) -> Result> { - let mut values = String::new(); - for id in channel_ids { - if !values.is_empty() { - values.push_str(", "); - } - write!(&mut values, "({})", id).unwrap(); - } - - if values.is_empty() { - return Ok(Vec::default()); - } - - let sql = format!( - r#" - SELECT - * - FROM ( - SELECT - *, - row_number() OVER ( - PARTITION BY channel_id - ORDER BY id DESC - ) as row_number - FROM channel_messages - WHERE - channel_id in ({values}) - ) AS messages - WHERE - row_number = 1 - "#, - ); - - let stmt = Statement::from_string(self.pool.get_database_backend(), sql); - let mut last_messages = channel_message::Model::find_by_statement(stmt) - .stream(tx) - .await?; - - let mut results = Vec::new(); - while let Some(result) = last_messages.next().await { - let message = result?; - results.push(proto::ChannelMessageId { - channel_id: message.channel_id.to_proto(), - message_id: 
message.id.to_proto(), - }); - } - - Ok(results) - } - - fn get_notification_kind_id_by_name(&self, notification_kind: &str) -> Option { - self.notification_kinds_by_id - .iter() - .find(|(_, kind)| **kind == notification_kind) - .map(|kind| kind.0.0) - } - - /// Removes the channel message with the given ID. - pub async fn remove_channel_message( - &self, - channel_id: ChannelId, - message_id: MessageId, - user_id: UserId, - ) -> Result<(Vec, Vec)> { - self.transaction(|tx| async move { - let mut rows = channel_chat_participant::Entity::find() - .filter(channel_chat_participant::Column::ChannelId.eq(channel_id)) - .stream(&*tx) - .await?; - - let mut is_participant = false; - let mut participant_connection_ids = Vec::new(); - while let Some(row) = rows.next().await { - let row = row?; - if row.user_id == user_id { - is_participant = true; - } - participant_connection_ids.push(row.connection()); - } - drop(rows); - - if !is_participant { - Err(anyhow!("not a chat participant"))?; - } - - let result = channel_message::Entity::delete_by_id(message_id) - .filter(channel_message::Column::SenderId.eq(user_id)) - .exec(&*tx) - .await?; - - if result.rows_affected == 0 { - let channel = self.get_channel_internal(channel_id, &tx).await?; - if self - .check_user_is_channel_admin(&channel, user_id, &tx) - .await - .is_ok() - { - let result = channel_message::Entity::delete_by_id(message_id) - .exec(&*tx) - .await?; - if result.rows_affected == 0 { - Err(anyhow!("no such message"))?; - } - } else { - Err(anyhow!("operation could not be completed"))?; - } - } - - let notification_kind_id = - self.get_notification_kind_id_by_name("ChannelMessageMention"); - - let existing_notifications = notification::Entity::find() - .filter(notification::Column::EntityId.eq(message_id)) - .filter(notification::Column::Kind.eq(notification_kind_id)) - .select_column(notification::Column::Id) - .all(&*tx) - .await?; - - let existing_notification_ids = existing_notifications - .into_iter() - 
.map(|notification| notification.id) - .collect(); - - // remove all the mention notifications for this message - notification::Entity::delete_many() - .filter(notification::Column::EntityId.eq(message_id)) - .filter(notification::Column::Kind.eq(notification_kind_id)) - .exec(&*tx) - .await?; - - Ok((participant_connection_ids, existing_notification_ids)) - }) - .await - } - - /// Updates the channel message with the given ID, body and timestamp(edited_at). - pub async fn update_channel_message( - &self, - channel_id: ChannelId, - message_id: MessageId, - user_id: UserId, - body: &str, - mentions: &[proto::ChatMention], - edited_at: OffsetDateTime, - ) -> Result { - self.transaction(|tx| async move { - let channel = self.get_channel_internal(channel_id, &tx).await?; - self.check_user_is_channel_participant(&channel, user_id, &tx) - .await?; - - let mut rows = channel_chat_participant::Entity::find() - .filter(channel_chat_participant::Column::ChannelId.eq(channel_id)) - .stream(&*tx) - .await?; - - let mut is_participant = false; - let mut participant_connection_ids = Vec::new(); - let mut participant_user_ids = Vec::new(); - while let Some(row) = rows.next().await { - let row = row?; - if row.user_id == user_id { - is_participant = true; - } - participant_user_ids.push(row.user_id); - participant_connection_ids.push(row.connection()); - } - drop(rows); - - if !is_participant { - Err(anyhow!("not a chat participant"))?; - } - - let channel_message = channel_message::Entity::find_by_id(message_id) - .filter(channel_message::Column::SenderId.eq(user_id)) - .one(&*tx) - .await?; - - let Some(channel_message) = channel_message else { - Err(anyhow!("Channel message not found"))? 
- }; - - let edited_at = edited_at.to_offset(time::UtcOffset::UTC); - let edited_at = time::PrimitiveDateTime::new(edited_at.date(), edited_at.time()); - - let updated_message = channel_message::ActiveModel { - body: ActiveValue::Set(body.to_string()), - edited_at: ActiveValue::Set(Some(edited_at)), - reply_to_message_id: ActiveValue::Unchanged(channel_message.reply_to_message_id), - id: ActiveValue::Unchanged(message_id), - channel_id: ActiveValue::Unchanged(channel_id), - sender_id: ActiveValue::Unchanged(user_id), - sent_at: ActiveValue::Unchanged(channel_message.sent_at), - nonce: ActiveValue::Unchanged(channel_message.nonce), - }; - - let result = channel_message::Entity::update_many() - .set(updated_message) - .filter(channel_message::Column::Id.eq(message_id)) - .filter(channel_message::Column::SenderId.eq(user_id)) - .exec(&*tx) - .await?; - if result.rows_affected == 0 { - return Err(anyhow!( - "Attempted to edit a message (id: {message_id}) which does not exist anymore." - ))?; - } - - // we have to fetch the old mentions, - // so we don't send a notification when the message has been edited that you are mentioned in - let old_mentions = channel_message_mention::Entity::find() - .filter(channel_message_mention::Column::MessageId.eq(message_id)) - .all(&*tx) - .await?; - - // remove all existing mentions - channel_message_mention::Entity::delete_many() - .filter(channel_message_mention::Column::MessageId.eq(message_id)) - .exec(&*tx) - .await?; - - let new_mentions = self.format_mentions_to_entities(message_id, body, mentions)?; - if !new_mentions.is_empty() { - // insert new mentions - channel_message_mention::Entity::insert_many(new_mentions) - .exec(&*tx) - .await?; - } - - let mut update_mention_user_ids = HashSet::default(); - let mut new_mention_user_ids = - mentions.iter().map(|m| m.user_id).collect::>(); - // Filter out users that were mentioned before - for mention in &old_mentions { - if new_mention_user_ids.contains(&mention.user_id.to_proto()) 
{ - update_mention_user_ids.insert(mention.user_id.to_proto()); - } - - new_mention_user_ids.remove(&mention.user_id.to_proto()); - } - - let notification_kind_id = - self.get_notification_kind_id_by_name("ChannelMessageMention"); - - let existing_notifications = notification::Entity::find() - .filter(notification::Column::EntityId.eq(message_id)) - .filter(notification::Column::Kind.eq(notification_kind_id)) - .all(&*tx) - .await?; - - // determine which notifications should be updated or deleted - let mut deleted_notification_ids = HashSet::default(); - let mut updated_mention_notifications = Vec::new(); - for notification in existing_notifications { - if update_mention_user_ids.contains(¬ification.recipient_id.to_proto()) { - if let Some(notification) = - self::notifications::model_to_proto(self, notification).log_err() - { - updated_mention_notifications.push(notification); - } - } else { - deleted_notification_ids.insert(notification.id); - } - } - - let mut notifications = Vec::new(); - for mentioned_user in new_mention_user_ids { - notifications.extend( - self.create_notification( - UserId::from_proto(mentioned_user), - rpc::Notification::ChannelMessageMention { - message_id: message_id.to_proto(), - sender_id: user_id.to_proto(), - channel_id: channel_id.to_proto(), - }, - false, - &tx, - ) - .await?, - ); - } - - Ok(UpdatedChannelMessage { - message_id, - participant_connection_ids, - notifications, - reply_to_message_id: channel_message.reply_to_message_id, - timestamp: channel_message.sent_at, - deleted_mention_notification_ids: deleted_notification_ids - .into_iter() - .collect::>(), - updated_mention_notifications, - }) - }) - .await - } -} diff --git a/crates/collab/src/db/queries/processed_stripe_events.rs b/crates/collab/src/db/queries/processed_stripe_events.rs deleted file mode 100644 index f14ad480e09fb4c0d6d43569b03e7888e9929cf4..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/queries/processed_stripe_events.rs +++ /dev/null 
@@ -1,69 +0,0 @@ -use super::*; - -#[derive(Debug)] -pub struct CreateProcessedStripeEventParams { - pub stripe_event_id: String, - pub stripe_event_type: String, - pub stripe_event_created_timestamp: i64, -} - -impl Database { - /// Creates a new processed Stripe event. - pub async fn create_processed_stripe_event( - &self, - params: &CreateProcessedStripeEventParams, - ) -> Result<()> { - self.transaction(|tx| async move { - processed_stripe_event::Entity::insert(processed_stripe_event::ActiveModel { - stripe_event_id: ActiveValue::set(params.stripe_event_id.clone()), - stripe_event_type: ActiveValue::set(params.stripe_event_type.clone()), - stripe_event_created_timestamp: ActiveValue::set( - params.stripe_event_created_timestamp, - ), - ..Default::default() - }) - .exec_without_returning(&*tx) - .await?; - - Ok(()) - }) - .await - } - - /// Returns the processed Stripe event with the specified event ID. - pub async fn get_processed_stripe_event_by_event_id( - &self, - event_id: &str, - ) -> Result> { - self.transaction(|tx| async move { - Ok(processed_stripe_event::Entity::find_by_id(event_id) - .one(&*tx) - .await?) - }) - .await - } - - /// Returns the processed Stripe events with the specified event IDs. - pub async fn get_processed_stripe_events_by_event_ids( - &self, - event_ids: &[&str], - ) -> Result> { - self.transaction(|tx| async move { - Ok(processed_stripe_event::Entity::find() - .filter( - processed_stripe_event::Column::StripeEventId.is_in(event_ids.iter().copied()), - ) - .all(&*tx) - .await?) - }) - .await - } - - /// Returns whether the Stripe event with the specified ID has already been processed. - pub async fn already_processed_stripe_event(&self, event_id: &str) -> Result { - Ok(self - .get_processed_stripe_event_by_event_id(event_id) - .await? 
- .is_some()) - } -} diff --git a/crates/collab/src/db/queries/projects.rs b/crates/collab/src/db/queries/projects.rs index 82f74d910ba0d12c1473719189e066eb9d0307eb..a3f0ea6cbc6e762e365f82e74b886234e62da109 100644 --- a/crates/collab/src/db/queries/projects.rs +++ b/crates/collab/src/db/queries/projects.rs @@ -349,11 +349,11 @@ impl Database { serde_json::to_string(&repository.current_merge_conflicts) .unwrap(), )), - - // Old clients do not use abs path, entry ids or head_commit_details. + // Old clients do not use abs path, entry ids, head_commit_details, or merge_message. abs_path: ActiveValue::set(String::new()), entry_ids: ActiveValue::set("[]".into()), head_commit_details: ActiveValue::set(None), + merge_message: ActiveValue::set(None), } }), ) @@ -502,6 +502,7 @@ impl Database { current_merge_conflicts: ActiveValue::Set(Some( serde_json::to_string(&update.current_merge_conflicts).unwrap(), )), + merge_message: ActiveValue::set(update.merge_message.clone()), }) .on_conflict( OnConflict::columns([ @@ -515,6 +516,7 @@ impl Database { project_repository::Column::AbsPath, project_repository::Column::CurrentMergeConflicts, project_repository::Column::HeadCommitDetails, + project_repository::Column::MergeMessage, ]) .to_owned(), ) @@ -692,6 +694,7 @@ impl Database { project_id: ActiveValue::set(project_id), id: ActiveValue::set(server.id as i64), name: ActiveValue::set(server.name.clone()), + worktree_id: ActiveValue::set(server.worktree_id.map(|id| id as i64)), capabilities: ActiveValue::set(update.capabilities.clone()), }) .on_conflict( @@ -702,6 +705,7 @@ impl Database { .update_columns([ language_server::Column::Name, language_server::Column::Capabilities, + language_server::Column::WorktreeId, ]) .to_owned(), ) @@ -943,21 +947,21 @@ impl Database { let current_merge_conflicts = db_repository_entry .current_merge_conflicts .as_ref() - .map(|conflicts| serde_json::from_str(&conflicts)) + .map(|conflicts| serde_json::from_str(conflicts)) .transpose()? 
.unwrap_or_default(); let branch_summary = db_repository_entry .branch_summary .as_ref() - .map(|branch_summary| serde_json::from_str(&branch_summary)) + .map(|branch_summary| serde_json::from_str(branch_summary)) .transpose()? .unwrap_or_default(); let head_commit_details = db_repository_entry .head_commit_details .as_ref() - .map(|head_commit_details| serde_json::from_str(&head_commit_details)) + .map(|head_commit_details| serde_json::from_str(head_commit_details)) .transpose()? .unwrap_or_default(); @@ -990,6 +994,7 @@ impl Database { head_commit_details, scan_id: db_repository_entry.scan_id as u64, is_last_update: true, + merge_message: db_repository_entry.merge_message, }); } } @@ -1062,7 +1067,7 @@ impl Database { server: proto::LanguageServer { id: language_server.id as u64, name: language_server.name, - worktree_id: None, + worktree_id: language_server.worktree_id.map(|id| id as u64), }, capabilities: language_server.capabilities, }) @@ -1318,10 +1323,10 @@ impl Database { .await?; let mut connection_ids = HashSet::default(); - if let Some(host_connection) = project.host_connection().log_err() { - if !exclude_dev_server { - connection_ids.insert(host_connection); - } + if let Some(host_connection) = project.host_connection().log_err() + && !exclude_dev_server + { + connection_ids.insert(host_connection); } while let Some(collaborator) = collaborators.next().await { diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index c63d7133be2ec616a95fa73359a5050c289501bf..b4cca2a2b15de0c10a641e847c32d2dfe300deb2 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -746,21 +746,21 @@ impl Database { let current_merge_conflicts = db_repository .current_merge_conflicts .as_ref() - .map(|conflicts| serde_json::from_str(&conflicts)) + .map(|conflicts| serde_json::from_str(conflicts)) .transpose()? 
.unwrap_or_default(); let branch_summary = db_repository .branch_summary .as_ref() - .map(|branch_summary| serde_json::from_str(&branch_summary)) + .map(|branch_summary| serde_json::from_str(branch_summary)) .transpose()? .unwrap_or_default(); let head_commit_details = db_repository .head_commit_details .as_ref() - .map(|head_commit_details| serde_json::from_str(&head_commit_details)) + .map(|head_commit_details| serde_json::from_str(head_commit_details)) .transpose()? .unwrap_or_default(); @@ -793,6 +793,7 @@ impl Database { abs_path: db_repository.abs_path, scan_id: db_repository.scan_id as u64, is_last_update: true, + merge_message: db_repository.merge_message, }); } } @@ -808,7 +809,7 @@ impl Database { server: proto::LanguageServer { id: language_server.id as u64, name: language_server.name, - worktree_id: None, + worktree_id: language_server.worktree_id.map(|id| id as u64), }, capabilities: language_server.capabilities, }) @@ -1192,7 +1193,6 @@ impl Database { self.transaction(|tx| async move { self.room_connection_lost(connection, &tx).await?; self.channel_buffer_connection_lost(connection, &tx).await?; - self.channel_chat_connection_lost(connection, &tx).await?; Ok(()) }) .await diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index d87ab174bd7a70be7ad57fd1871853018fc25763..0082a9fb030a27e4be13af725f08ea9c82217377 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -1,7 +1,4 @@ pub mod access_token; -pub mod billing_customer; -pub mod billing_preference; -pub mod billing_subscription; pub mod buffer; pub mod buffer_operation; pub mod buffer_snapshot; @@ -23,7 +20,6 @@ pub mod notification; pub mod notification_kind; pub mod observed_buffer_edits; pub mod observed_channel_messages; -pub mod processed_stripe_event; pub mod project; pub mod project_collaborator; pub mod project_repository; diff --git a/crates/collab/src/db/tables/billing_customer.rs b/crates/collab/src/db/tables/billing_customer.rs 
deleted file mode 100644 index e7d4a216e348a74b0cc79a308626fc1a80c508f6..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/tables/billing_customer.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::db::{BillingCustomerId, UserId}; -use sea_orm::entity::prelude::*; - -/// A billing customer. -#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] -#[sea_orm(table_name = "billing_customers")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: BillingCustomerId, - pub user_id: UserId, - pub stripe_customer_id: String, - pub has_overdue_invoices: bool, - pub trial_started_at: Option, - pub created_at: DateTime, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm( - belongs_to = "super::user::Entity", - from = "Column::UserId", - to = "super::user::Column::Id" - )] - User, - #[sea_orm(has_many = "super::billing_subscription::Entity")] - BillingSubscription, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::User.def() - } -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::BillingSubscription.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/billing_preference.rs b/crates/collab/src/db/tables/billing_preference.rs deleted file mode 100644 index c1888d3b2f9c954f0b9bcd38376f191ed383b973..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/tables/billing_preference.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::db::{BillingPreferencesId, UserId}; -use sea_orm::entity::prelude::*; - -#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] -#[sea_orm(table_name = "billing_preferences")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: BillingPreferencesId, - pub created_at: DateTime, - pub user_id: UserId, - pub max_monthly_llm_usage_spending_in_cents: i32, - pub model_request_overages_enabled: bool, - pub model_request_overages_spend_limit_in_cents: i32, -} - -#[derive(Copy, Clone, 
Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm( - belongs_to = "super::user::Entity", - from = "Column::UserId", - to = "super::user::Column::Id" - )] - User, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::User.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/billing_subscription.rs b/crates/collab/src/db/tables/billing_subscription.rs deleted file mode 100644 index 522973dbc970b69947b8e790e370bfc9fa93aa99..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/tables/billing_subscription.rs +++ /dev/null @@ -1,176 +0,0 @@ -use crate::db::{BillingCustomerId, BillingSubscriptionId}; -use crate::stripe_client; -use chrono::{Datelike as _, NaiveDate, Utc}; -use sea_orm::entity::prelude::*; -use serde::Serialize; - -/// A billing subscription. -#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] -#[sea_orm(table_name = "billing_subscriptions")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: BillingSubscriptionId, - pub billing_customer_id: BillingCustomerId, - pub kind: Option, - pub stripe_subscription_id: String, - pub stripe_subscription_status: StripeSubscriptionStatus, - pub stripe_cancel_at: Option, - pub stripe_cancellation_reason: Option, - pub stripe_current_period_start: Option, - pub stripe_current_period_end: Option, - pub created_at: DateTime, -} - -impl Model { - pub fn current_period_start_at(&self) -> Option { - let period_start = self.stripe_current_period_start?; - chrono::DateTime::from_timestamp(period_start, 0) - } - - pub fn current_period_end_at(&self) -> Option { - let period_end = self.stripe_current_period_end?; - chrono::DateTime::from_timestamp(period_end, 0) - } - - pub fn current_period( - subscription: Option, - is_staff: bool, - ) -> Option<(DateTimeUtc, DateTimeUtc)> { - if is_staff { - let now = Utc::now(); - let year = now.year(); - let month = now.month(); - - let first_day_of_this_month = - 
NaiveDate::from_ymd_opt(year, month, 1)?.and_hms_opt(0, 0, 0)?; - - let next_month = if month == 12 { 1 } else { month + 1 }; - let next_month_year = if month == 12 { year + 1 } else { year }; - let first_day_of_next_month = - NaiveDate::from_ymd_opt(next_month_year, next_month, 1)?.and_hms_opt(23, 59, 59)?; - - let last_day_of_this_month = first_day_of_next_month - chrono::Days::new(1); - - Some(( - first_day_of_this_month.and_utc(), - last_day_of_this_month.and_utc(), - )) - } else { - let subscription = subscription?; - let period_start_at = subscription.current_period_start_at()?; - let period_end_at = subscription.current_period_end_at()?; - - Some((period_start_at, period_end_at)) - } - } -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm( - belongs_to = "super::billing_customer::Entity", - from = "Column::BillingCustomerId", - to = "super::billing_customer::Column::Id" - )] - BillingCustomer, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::BillingCustomer.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} - -#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)] -#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")] -#[serde(rename_all = "snake_case")] -pub enum SubscriptionKind { - #[sea_orm(string_value = "zed_pro")] - ZedPro, - #[sea_orm(string_value = "zed_pro_trial")] - ZedProTrial, - #[sea_orm(string_value = "zed_free")] - ZedFree, -} - -impl From for cloud_llm_client::Plan { - fn from(value: SubscriptionKind) -> Self { - match value { - SubscriptionKind::ZedPro => Self::ZedPro, - SubscriptionKind::ZedProTrial => Self::ZedProTrial, - SubscriptionKind::ZedFree => Self::ZedFree, - } - } -} - -/// The status of a Stripe subscription. 
-/// -/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-status) -#[derive( - Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash, Serialize, -)] -#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")] -#[serde(rename_all = "snake_case")] -pub enum StripeSubscriptionStatus { - #[default] - #[sea_orm(string_value = "incomplete")] - Incomplete, - #[sea_orm(string_value = "incomplete_expired")] - IncompleteExpired, - #[sea_orm(string_value = "trialing")] - Trialing, - #[sea_orm(string_value = "active")] - Active, - #[sea_orm(string_value = "past_due")] - PastDue, - #[sea_orm(string_value = "canceled")] - Canceled, - #[sea_orm(string_value = "unpaid")] - Unpaid, - #[sea_orm(string_value = "paused")] - Paused, -} - -impl StripeSubscriptionStatus { - pub fn is_cancelable(&self) -> bool { - match self { - Self::Trialing | Self::Active | Self::PastDue => true, - Self::Incomplete - | Self::IncompleteExpired - | Self::Canceled - | Self::Unpaid - | Self::Paused => false, - } - } -} - -/// The cancellation reason for a Stripe subscription. 
-/// -/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-cancellation_details-reason) -#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)] -#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")] -#[serde(rename_all = "snake_case")] -pub enum StripeCancellationReason { - #[sea_orm(string_value = "cancellation_requested")] - CancellationRequested, - #[sea_orm(string_value = "payment_disputed")] - PaymentDisputed, - #[sea_orm(string_value = "payment_failed")] - PaymentFailed, -} - -impl From for StripeCancellationReason { - fn from(value: stripe_client::StripeCancellationDetailsReason) -> Self { - match value { - stripe_client::StripeCancellationDetailsReason::CancellationRequested => { - Self::CancellationRequested - } - stripe_client::StripeCancellationDetailsReason::PaymentDisputed => { - Self::PaymentDisputed - } - stripe_client::StripeCancellationDetailsReason::PaymentFailed => Self::PaymentFailed, - } - } -} diff --git a/crates/collab/src/db/tables/language_server.rs b/crates/collab/src/db/tables/language_server.rs index 34c7514d917b313990521acf8542c31394d009fc..705aae292ba456622e9808f033a348f60c3835a4 100644 --- a/crates/collab/src/db/tables/language_server.rs +++ b/crates/collab/src/db/tables/language_server.rs @@ -10,6 +10,7 @@ pub struct Model { pub id: i64, pub name: String, pub capabilities: String, + pub worktree_id: Option, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/db/tables/processed_stripe_event.rs b/crates/collab/src/db/tables/processed_stripe_event.rs deleted file mode 100644 index 7b6f0cdc31d951caee57dc45c357178d375af9c8..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/tables/processed_stripe_event.rs +++ /dev/null @@ -1,16 +0,0 @@ -use sea_orm::entity::prelude::*; - -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "processed_stripe_events")] -pub struct Model { - 
#[sea_orm(primary_key)] - pub stripe_event_id: String, - pub stripe_event_type: String, - pub stripe_event_created_timestamp: i64, - pub processed_at: DateTime, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/project_repository.rs b/crates/collab/src/db/tables/project_repository.rs index 665e87cd1fe8c492170a8459dbf7ac6c086f9e00..eb653ecee37d48ce79e26450eb85d87dec411c1e 100644 --- a/crates/collab/src/db/tables/project_repository.rs +++ b/crates/collab/src/db/tables/project_repository.rs @@ -16,6 +16,8 @@ pub struct Model { pub is_deleted: bool, // JSON array typed string pub current_merge_conflicts: Option, + // The suggested merge commit message + pub merge_message: Option, // A JSON object representing the current Branch values pub branch_summary: Option, // A JSON object representing the current Head commit values diff --git a/crates/collab/src/db/tables/user.rs b/crates/collab/src/db/tables/user.rs index 49fe3eb58f3ee149d9cfee88fd9c4b175854373b..af43fe300a6cc1224487541ca72af9d887a6fae3 100644 --- a/crates/collab/src/db/tables/user.rs +++ b/crates/collab/src/db/tables/user.rs @@ -29,8 +29,6 @@ pub struct Model { pub enum Relation { #[sea_orm(has_many = "super::access_token::Entity")] AccessToken, - #[sea_orm(has_one = "super::billing_customer::Entity")] - BillingCustomer, #[sea_orm(has_one = "super::room_participant::Entity")] RoomParticipant, #[sea_orm(has_many = "super::project::Entity")] @@ -68,12 +66,6 @@ impl Related for Entity { } } -impl Related for Entity { - fn to() -> RelationDef { - Relation::BillingCustomer.def() - } -} - impl Related for Entity { fn to() -> RelationDef { Relation::RoomParticipant.def() diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 6c2f9dc82a88c159df1111d01a213259ab3a6c76..25e03f1320a25455ede347b43477761d591fbd57 100644 --- a/crates/collab/src/db/tests.rs +++ 
b/crates/collab/src/db/tests.rs @@ -7,8 +7,6 @@ mod db_tests; mod embedding_tests; mod extension_tests; mod feature_flag_tests; -mod message_tests; -mod processed_stripe_event_tests; mod user_tests; use crate::migrations::run_database_migrations; @@ -22,7 +20,7 @@ use sqlx::migrate::MigrateDatabase; use std::{ sync::{ Arc, - atomic::{AtomicI32, AtomicU32, Ordering::SeqCst}, + atomic::{AtomicI32, Ordering::SeqCst}, }, time::Duration, }; @@ -76,10 +74,10 @@ impl TestDb { static LOCK: Mutex<()> = Mutex::new(()); let _guard = LOCK.lock(); - let mut rng = StdRng::from_entropy(); + let mut rng = StdRng::from_os_rng(); let url = format!( "postgres://postgres@localhost/zed-test-{}", - rng.r#gen::() + rng.random::() ); let runtime = tokio::runtime::Builder::new_current_thread() .enable_io() @@ -225,11 +223,3 @@ async fn new_test_user(db: &Arc, email: &str) -> UserId { .unwrap() .user_id } - -static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1); -fn new_test_connection(server: ServerId) -> ConnectionId { - ConnectionId { - id: TEST_CONNECTION_ID.fetch_add(1, SeqCst), - owner_id: server.0 as u32, - } -} diff --git a/crates/collab/src/db/tests/channel_tests.rs b/crates/collab/src/db/tests/channel_tests.rs index 1dd16fb50a8d002d01e27cec0a959fd9ea9ecde7..705dbba5ead0170acd629149b8d77b847a5784b0 100644 --- a/crates/collab/src/db/tests/channel_tests.rs +++ b/crates/collab/src/db/tests/channel_tests.rs @@ -1,7 +1,7 @@ use crate::{ db::{ Channel, ChannelId, ChannelRole, Database, NewUserParams, RoomId, UserId, - tests::{assert_channel_tree_matches, channel_tree, new_test_connection, new_test_user}, + tests::{assert_channel_tree_matches, channel_tree, new_test_user}, }, test_both_dbs, }; @@ -949,41 +949,6 @@ async fn test_user_is_channel_participant(db: &Arc) { ) } -test_both_dbs!( - test_guest_access, - test_guest_access_postgres, - test_guest_access_sqlite -); - -async fn test_guest_access(db: &Arc) { - let server = db.create_server("test").await.unwrap(); - - let admin = 
new_test_user(db, "admin@example.com").await; - let guest = new_test_user(db, "guest@example.com").await; - let guest_connection = new_test_connection(server); - - let zed_channel = db.create_root_channel("zed", admin).await.unwrap(); - db.set_channel_visibility(zed_channel, crate::db::ChannelVisibility::Public, admin) - .await - .unwrap(); - - assert!( - db.join_channel_chat(zed_channel, guest_connection, guest) - .await - .is_err() - ); - - db.join_channel(zed_channel, guest, guest_connection) - .await - .unwrap(); - - assert!( - db.join_channel_chat(zed_channel, guest_connection, guest) - .await - .is_ok() - ) -} - #[track_caller] fn assert_channel_tree(actual: Vec, expected: &[(ChannelId, &[ChannelId])]) { let actual = actual diff --git a/crates/collab/src/db/tests/embedding_tests.rs b/crates/collab/src/db/tests/embedding_tests.rs index 367e89f87bff827fe321b0935d52647a9034794a..5d8d69c0304d3a16b55e9d7b1477fe62cc22024a 100644 --- a/crates/collab/src/db/tests/embedding_tests.rs +++ b/crates/collab/src/db/tests/embedding_tests.rs @@ -8,7 +8,7 @@ use time::{Duration, OffsetDateTime, PrimitiveDateTime}; // SQLite does not support array arguments, so we only test this against a real postgres instance #[gpui::test] async fn test_get_embeddings_postgres(cx: &mut gpui::TestAppContext) { - let test_db = TestDb::postgres(cx.executor().clone()); + let test_db = TestDb::postgres(cx.executor()); let db = test_db.db(); let provider = "test_model"; @@ -38,7 +38,7 @@ async fn test_get_embeddings_postgres(cx: &mut gpui::TestAppContext) { #[gpui::test] async fn test_purge_old_embeddings(cx: &mut gpui::TestAppContext) { - let test_db = TestDb::postgres(cx.executor().clone()); + let test_db = TestDb::postgres(cx.executor()); let db = test_db.db(); let model = "test_model"; diff --git a/crates/collab/src/db/tests/message_tests.rs b/crates/collab/src/db/tests/message_tests.rs deleted file mode 100644 index 
e20473d3bdd4179309c4d392f1df93f20f1e928c..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/tests/message_tests.rs +++ /dev/null @@ -1,421 +0,0 @@ -use super::new_test_user; -use crate::{ - db::{ChannelRole, Database, MessageId}, - test_both_dbs, -}; -use channel::mentions_to_proto; -use std::sync::Arc; -use time::OffsetDateTime; - -test_both_dbs!( - test_channel_message_retrieval, - test_channel_message_retrieval_postgres, - test_channel_message_retrieval_sqlite -); - -async fn test_channel_message_retrieval(db: &Arc) { - let user = new_test_user(db, "user@example.com").await; - let channel = db.create_channel("channel", None, user).await.unwrap().0; - - let owner_id = db.create_server("test").await.unwrap().0 as u32; - db.join_channel_chat(channel.id, rpc::ConnectionId { owner_id, id: 0 }, user) - .await - .unwrap(); - - let mut all_messages = Vec::new(); - for i in 0..10 { - all_messages.push( - db.create_channel_message( - channel.id, - user, - &i.to_string(), - &[], - OffsetDateTime::now_utc(), - i, - None, - ) - .await - .unwrap() - .message_id - .to_proto(), - ); - } - - let messages = db - .get_channel_messages(channel.id, user, 3, None) - .await - .unwrap() - .into_iter() - .map(|message| message.id) - .collect::>(); - assert_eq!(messages, &all_messages[7..10]); - - let messages = db - .get_channel_messages( - channel.id, - user, - 4, - Some(MessageId::from_proto(all_messages[6])), - ) - .await - .unwrap() - .into_iter() - .map(|message| message.id) - .collect::>(); - assert_eq!(messages, &all_messages[2..6]); -} - -test_both_dbs!( - test_channel_message_nonces, - test_channel_message_nonces_postgres, - test_channel_message_nonces_sqlite -); - -async fn test_channel_message_nonces(db: &Arc) { - let user_a = new_test_user(db, "user_a@example.com").await; - let user_b = new_test_user(db, "user_b@example.com").await; - let user_c = new_test_user(db, "user_c@example.com").await; - let channel = db.create_root_channel("channel", 
user_a).await.unwrap(); - db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) - .await - .unwrap(); - db.invite_channel_member(channel, user_c, user_a, ChannelRole::Member) - .await - .unwrap(); - db.respond_to_channel_invite(channel, user_b, true) - .await - .unwrap(); - db.respond_to_channel_invite(channel, user_c, true) - .await - .unwrap(); - - let owner_id = db.create_server("test").await.unwrap().0 as u32; - db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user_a) - .await - .unwrap(); - db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 1 }, user_b) - .await - .unwrap(); - - // As user A, create messages that reuse the same nonces. The requests - // succeed, but return the same ids. - let id1 = db - .create_channel_message( - channel, - user_a, - "hi @user_b", - &mentions_to_proto(&[(3..10, user_b.to_proto())]), - OffsetDateTime::now_utc(), - 100, - None, - ) - .await - .unwrap() - .message_id; - let id2 = db - .create_channel_message( - channel, - user_a, - "hello, fellow users", - &mentions_to_proto(&[]), - OffsetDateTime::now_utc(), - 200, - None, - ) - .await - .unwrap() - .message_id; - let id3 = db - .create_channel_message( - channel, - user_a, - "bye @user_c (same nonce as first message)", - &mentions_to_proto(&[(4..11, user_c.to_proto())]), - OffsetDateTime::now_utc(), - 100, - None, - ) - .await - .unwrap() - .message_id; - let id4 = db - .create_channel_message( - channel, - user_a, - "omg (same nonce as second message)", - &mentions_to_proto(&[]), - OffsetDateTime::now_utc(), - 200, - None, - ) - .await - .unwrap() - .message_id; - - // As a different user, reuse one of the same nonces. This request succeeds - // and returns a different id. 
- let id5 = db - .create_channel_message( - channel, - user_b, - "omg @user_a (same nonce as user_a's first message)", - &mentions_to_proto(&[(4..11, user_a.to_proto())]), - OffsetDateTime::now_utc(), - 100, - None, - ) - .await - .unwrap() - .message_id; - - assert_ne!(id1, id2); - assert_eq!(id1, id3); - assert_eq!(id2, id4); - assert_ne!(id5, id1); - - let messages = db - .get_channel_messages(channel, user_a, 5, None) - .await - .unwrap() - .into_iter() - .map(|m| (m.id, m.body, m.mentions)) - .collect::>(); - assert_eq!( - messages, - &[ - ( - id1.to_proto(), - "hi @user_b".into(), - mentions_to_proto(&[(3..10, user_b.to_proto())]), - ), - ( - id2.to_proto(), - "hello, fellow users".into(), - mentions_to_proto(&[]) - ), - ( - id5.to_proto(), - "omg @user_a (same nonce as user_a's first message)".into(), - mentions_to_proto(&[(4..11, user_a.to_proto())]), - ), - ] - ); -} - -test_both_dbs!( - test_unseen_channel_messages, - test_unseen_channel_messages_postgres, - test_unseen_channel_messages_sqlite -); - -async fn test_unseen_channel_messages(db: &Arc) { - let user = new_test_user(db, "user_a@example.com").await; - let observer = new_test_user(db, "user_b@example.com").await; - - let channel_1 = db.create_root_channel("channel", user).await.unwrap(); - let channel_2 = db.create_root_channel("channel-2", user).await.unwrap(); - - db.invite_channel_member(channel_1, observer, user, ChannelRole::Member) - .await - .unwrap(); - db.invite_channel_member(channel_2, observer, user, ChannelRole::Member) - .await - .unwrap(); - - db.respond_to_channel_invite(channel_1, observer, true) - .await - .unwrap(); - db.respond_to_channel_invite(channel_2, observer, true) - .await - .unwrap(); - - let owner_id = db.create_server("test").await.unwrap().0 as u32; - let user_connection_id = rpc::ConnectionId { owner_id, id: 0 }; - - db.join_channel_chat(channel_1, user_connection_id, user) - .await - .unwrap(); - - let _ = db - .create_channel_message( - channel_1, - user, - 
"1_1", - &[], - OffsetDateTime::now_utc(), - 1, - None, - ) - .await - .unwrap(); - - let _ = db - .create_channel_message( - channel_1, - user, - "1_2", - &[], - OffsetDateTime::now_utc(), - 2, - None, - ) - .await - .unwrap(); - - let third_message = db - .create_channel_message( - channel_1, - user, - "1_3", - &[], - OffsetDateTime::now_utc(), - 3, - None, - ) - .await - .unwrap() - .message_id; - - db.join_channel_chat(channel_2, user_connection_id, user) - .await - .unwrap(); - - let fourth_message = db - .create_channel_message( - channel_2, - user, - "2_1", - &[], - OffsetDateTime::now_utc(), - 4, - None, - ) - .await - .unwrap() - .message_id; - - // Check that observer has new messages - let latest_messages = db - .transaction(|tx| async move { - db.latest_channel_messages(&[channel_1, channel_2], &tx) - .await - }) - .await - .unwrap(); - - assert_eq!( - latest_messages, - [ - rpc::proto::ChannelMessageId { - channel_id: channel_1.to_proto(), - message_id: third_message.to_proto(), - }, - rpc::proto::ChannelMessageId { - channel_id: channel_2.to_proto(), - message_id: fourth_message.to_proto(), - }, - ] - ); -} - -test_both_dbs!( - test_channel_message_mentions, - test_channel_message_mentions_postgres, - test_channel_message_mentions_sqlite -); - -async fn test_channel_message_mentions(db: &Arc) { - let user_a = new_test_user(db, "user_a@example.com").await; - let user_b = new_test_user(db, "user_b@example.com").await; - let user_c = new_test_user(db, "user_c@example.com").await; - - let channel = db - .create_channel("channel", None, user_a) - .await - .unwrap() - .0 - .id; - db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) - .await - .unwrap(); - db.respond_to_channel_invite(channel, user_b, true) - .await - .unwrap(); - - let owner_id = db.create_server("test").await.unwrap().0 as u32; - let connection_id = rpc::ConnectionId { owner_id, id: 0 }; - db.join_channel_chat(channel, connection_id, user_a) - .await - .unwrap(); - - 
db.create_channel_message( - channel, - user_a, - "hi @user_b and @user_c", - &mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]), - OffsetDateTime::now_utc(), - 1, - None, - ) - .await - .unwrap(); - db.create_channel_message( - channel, - user_a, - "bye @user_c", - &mentions_to_proto(&[(4..11, user_c.to_proto())]), - OffsetDateTime::now_utc(), - 2, - None, - ) - .await - .unwrap(); - db.create_channel_message( - channel, - user_a, - "umm", - &mentions_to_proto(&[]), - OffsetDateTime::now_utc(), - 3, - None, - ) - .await - .unwrap(); - db.create_channel_message( - channel, - user_a, - "@user_b, stop.", - &mentions_to_proto(&[(0..7, user_b.to_proto())]), - OffsetDateTime::now_utc(), - 4, - None, - ) - .await - .unwrap(); - - let messages = db - .get_channel_messages(channel, user_b, 5, None) - .await - .unwrap() - .into_iter() - .map(|m| (m.body, m.mentions)) - .collect::>(); - assert_eq!( - &messages, - &[ - ( - "hi @user_b and @user_c".into(), - mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]), - ), - ( - "bye @user_c".into(), - mentions_to_proto(&[(4..11, user_c.to_proto())]), - ), - ("umm".into(), mentions_to_proto(&[]),), - ( - "@user_b, stop.".into(), - mentions_to_proto(&[(0..7, user_b.to_proto())]), - ), - ] - ); -} diff --git a/crates/collab/src/db/tests/processed_stripe_event_tests.rs b/crates/collab/src/db/tests/processed_stripe_event_tests.rs deleted file mode 100644 index ad93b5a6589dd3a413bd5738dfc2e7debb9228d0..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/tests/processed_stripe_event_tests.rs +++ /dev/null @@ -1,38 +0,0 @@ -use std::sync::Arc; - -use crate::test_both_dbs; - -use super::{CreateProcessedStripeEventParams, Database}; - -test_both_dbs!( - test_already_processed_stripe_event, - test_already_processed_stripe_event_postgres, - test_already_processed_stripe_event_sqlite -); - -async fn test_already_processed_stripe_event(db: &Arc) { - let unprocessed_event_id = 
"evt_1PiJOuRxOf7d5PNaw2zzWiyO".to_string(); - let processed_event_id = "evt_1PiIfMRxOf7d5PNakHrAUe8P".to_string(); - - db.create_processed_stripe_event(&CreateProcessedStripeEventParams { - stripe_event_id: processed_event_id.clone(), - stripe_event_type: "customer.created".into(), - stripe_event_created_timestamp: 1722355968, - }) - .await - .unwrap(); - - assert!( - db.already_processed_stripe_event(&processed_event_id) - .await - .unwrap(), - "Expected {processed_event_id} to already be processed" - ); - - assert!( - !db.already_processed_stripe_event(&unprocessed_event_id) - .await - .unwrap(), - "Expected {unprocessed_event_id} to be unprocessed" - ); -} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 905859ca6996c3593e1f13fbcb0e723531595ff6..191025df3770db78df3a12bc16d5c8f32d54571c 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -7,8 +7,6 @@ pub mod llm; pub mod migrations; pub mod rpc; pub mod seed; -pub mod stripe_billing; -pub mod stripe_client; pub mod user_backfiller; #[cfg(test)] @@ -22,21 +20,16 @@ use axum::{ }; use db::{ChannelId, Database}; use executor::Executor; -use llm::db::LlmDatabase; use serde::Deserialize; use std::{path::PathBuf, sync::Arc}; use util::ResultExt; -use crate::stripe_billing::StripeBilling; -use crate::stripe_client::{RealStripeClient, StripeClient}; - pub type Result = std::result::Result; pub enum Error { Http(StatusCode, String, HeaderMap), Database(sea_orm::error::DbErr), Internal(anyhow::Error), - Stripe(stripe::StripeError), } impl From for Error { @@ -51,12 +44,6 @@ impl From for Error { } } -impl From for Error { - fn from(error: stripe::StripeError) -> Self { - Self::Stripe(error) - } -} - impl From for Error { fn from(error: axum::Error) -> Self { Self::Internal(error.into()) @@ -104,14 +91,6 @@ impl IntoResponse for Error { ); (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() } - Error::Stripe(error) => { - log::error!( - "HTTP error {}: {:?}", 
- StatusCode::INTERNAL_SERVER_ERROR, - &error - ); - (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() - } } } } @@ -122,7 +101,6 @@ impl std::fmt::Debug for Error { Error::Http(code, message, _headers) => (code, message).fmt(f), Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), - Error::Stripe(error) => error.fmt(f), } } } @@ -133,7 +111,6 @@ impl std::fmt::Display for Error { Error::Http(code, message, _) => write!(f, "{code}: {message}"), Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), - Error::Stripe(error) => error.fmt(f), } } } @@ -179,7 +156,6 @@ pub struct Config { pub zed_client_checksum_seed: Option, pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, - pub stripe_api_key: Option, pub supermaven_admin_api_key: Option>, pub user_backfiller_github_access_token: Option>, } @@ -234,7 +210,6 @@ impl Config { auto_join_channel_id: None, migrations_path: None, seed_path: None, - stripe_api_key: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, kinesis_region: None, @@ -266,14 +241,8 @@ impl ServiceMode { pub struct AppState { pub db: Arc, - pub llm_db: Option>, pub livekit_client: Option>, pub blob_store_client: Option, - /// This is a real instance of the Stripe client; we're working to replace references to this with the - /// [`StripeClient`] trait. 
- pub real_stripe_client: Option>, - pub stripe_client: Option>, - pub stripe_billing: Option>, pub executor: Executor, pub kinesis_client: Option<::aws_sdk_kinesis::Client>, pub config: Config, @@ -286,20 +255,6 @@ impl AppState { let mut db = Database::new(db_options).await?; db.initialize_notification_kinds().await?; - let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config - .llm_database_url - .clone() - .zip(config.llm_database_max_connections) - { - let mut llm_db_options = db::ConnectOptions::new(llm_database_url); - llm_db_options.max_connections(llm_database_max_connections); - let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?; - llm_db.initialize().await?; - Some(Arc::new(llm_db)) - } else { - None - }; - let livekit_client = if let Some(((server, key), secret)) = config .livekit_server .as_ref() @@ -316,18 +271,10 @@ impl AppState { }; let db = Arc::new(db); - let stripe_client = build_stripe_client(&config).map(Arc::new).log_err(); let this = Self { db: db.clone(), - llm_db, livekit_client, blob_store_client: build_blob_store_client(&config).await.log_err(), - stripe_billing: stripe_client - .clone() - .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))), - real_stripe_client: stripe_client.clone(), - stripe_client: stripe_client - .map(|stripe_client| Arc::new(RealStripeClient::new(stripe_client)) as _), executor, kinesis_client: if config.kinesis_access_key.is_some() { build_kinesis_client(&config).await.log_err() @@ -340,14 +287,6 @@ impl AppState { } } -fn build_stripe_client(config: &Config) -> anyhow::Result { - let api_key = config - .stripe_api_key - .as_ref() - .context("missing stripe_api_key")?; - Ok(stripe::Client::new(api_key)) -} - async fn build_blob_store_client(config: &Config) -> anyhow::Result { let keys = aws_sdk_s3::config::Credentials::new( config diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 
de74858168fd94ab677cee03f721a1e3fbbdfd46..dec10232bdb000acef9def25cad519ceb213956b 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -1,12 +1 @@ pub mod db; -mod token; - -pub use token::*; - -pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial"; - -/// The name of the feature flag that bypasses the account age check. -pub const BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG: &str = "bypass-account-age-check"; - -/// The minimum account age an account must have in order to use the LLM service. -pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30); diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index 18ad624dab840c47df766a55c2f59cf9a17c55e6..b15d5a42b5f183831b34552beba3f616d3a7c3f0 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -1,30 +1,9 @@ -mod ids; -mod queries; -mod seed; -mod tables; - -#[cfg(test)] -mod tests; - -use cloud_llm_client::LanguageModelProvider; -use collections::HashMap; -pub use ids::*; -pub use seed::*; -pub use tables::*; - -#[cfg(test)] -pub use tests::TestLlmDb; -use usage_measure::UsageMeasure; - use std::future::Future; use std::sync::Arc; use anyhow::Context; pub use sea_orm::ConnectOptions; -use sea_orm::prelude::*; -use sea_orm::{ - ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait, -}; +use sea_orm::{DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait}; use crate::Result; use crate::db::TransactionHandle; @@ -36,9 +15,6 @@ pub struct LlmDatabase { pool: DatabaseConnection, #[allow(unused)] executor: Executor, - provider_ids: HashMap, - models: HashMap<(LanguageModelProvider, String), model::Model>, - usage_measure_ids: HashMap, #[cfg(test)] runtime: Option, } @@ -51,59 +27,11 @@ impl LlmDatabase { options: options.clone(), pool: sea_orm::Database::connect(options).await?, executor, - provider_ids: HashMap::default(), - models: HashMap::default(), - 
usage_measure_ids: HashMap::default(), #[cfg(test)] runtime: None, }) } - pub async fn initialize(&mut self) -> Result<()> { - self.initialize_providers().await?; - self.initialize_models().await?; - self.initialize_usage_measures().await?; - Ok(()) - } - - /// Returns the list of all known models, with their [`LanguageModelProvider`]. - pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> { - self.models - .iter() - .map(|((model_provider, _model_name), model)| (*model_provider, model.clone())) - .collect::>() - } - - /// Returns the names of the known models for the given [`LanguageModelProvider`]. - pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec { - self.models - .keys() - .filter_map(|(model_provider, model_name)| { - if model_provider == &provider { - Some(model_name) - } else { - None - } - }) - .cloned() - .collect::>() - } - - pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> { - Ok(self - .models - .get(&(provider, name.to_string())) - .with_context(|| format!("unknown model {provider:?}:{name}"))?) - } - - pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> { - Ok(self - .models - .values() - .find(|model| model.id == id) - .with_context(|| format!("no model for ID {id:?}"))?) 
- } - pub fn options(&self) -> &ConnectOptions { &self.options } diff --git a/crates/collab/src/llm/db/ids.rs b/crates/collab/src/llm/db/ids.rs deleted file mode 100644 index 03cab6cee0b9e7a07f2d4d43aa7e556615e34494..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/ids.rs +++ /dev/null @@ -1,11 +0,0 @@ -use sea_orm::{DbErr, entity::prelude::*}; -use serde::{Deserialize, Serialize}; - -use crate::id_type; - -id_type!(BillingEventId); -id_type!(ModelId); -id_type!(ProviderId); -id_type!(RevokedAccessTokenId); -id_type!(UsageId); -id_type!(UsageMeasureId); diff --git a/crates/collab/src/llm/db/queries.rs b/crates/collab/src/llm/db/queries.rs deleted file mode 100644 index 0087218b3ff9fe81850870bc8022bd81fe0ee48d..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/queries.rs +++ /dev/null @@ -1,5 +0,0 @@ -use super::*; - -pub mod providers; -pub mod subscription_usages; -pub mod usages; diff --git a/crates/collab/src/llm/db/queries/providers.rs b/crates/collab/src/llm/db/queries/providers.rs deleted file mode 100644 index 9c7dbdd1847ea1d087582ffd959497bc41757b75..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/queries/providers.rs +++ /dev/null @@ -1,134 +0,0 @@ -use super::*; -use sea_orm::{QueryOrder, sea_query::OnConflict}; -use std::str::FromStr; -use strum::IntoEnumIterator as _; - -pub struct ModelParams { - pub provider: LanguageModelProvider, - pub name: String, - pub max_requests_per_minute: i64, - pub max_tokens_per_minute: i64, - pub max_tokens_per_day: i64, - pub price_per_million_input_tokens: i32, - pub price_per_million_output_tokens: i32, -} - -impl LlmDatabase { - pub async fn initialize_providers(&mut self) -> Result<()> { - self.provider_ids = self - .transaction(|tx| async move { - let existing_providers = provider::Entity::find().all(&*tx).await?; - - let mut new_providers = LanguageModelProvider::iter() - .filter(|provider| { - !existing_providers - .iter() - .any(|p| p.name == 
provider.to_string()) - }) - .map(|provider| provider::ActiveModel { - name: ActiveValue::set(provider.to_string()), - ..Default::default() - }) - .peekable(); - - if new_providers.peek().is_some() { - provider::Entity::insert_many(new_providers) - .exec(&*tx) - .await?; - } - - let all_providers: HashMap<_, _> = provider::Entity::find() - .all(&*tx) - .await? - .iter() - .filter_map(|provider| { - LanguageModelProvider::from_str(&provider.name) - .ok() - .map(|p| (p, provider.id)) - }) - .collect(); - - Ok(all_providers) - }) - .await?; - Ok(()) - } - - pub async fn initialize_models(&mut self) -> Result<()> { - let all_provider_ids = &self.provider_ids; - self.models = self - .transaction(|tx| async move { - let all_models: HashMap<_, _> = model::Entity::find() - .all(&*tx) - .await? - .into_iter() - .filter_map(|model| { - let provider = all_provider_ids.iter().find_map(|(provider, id)| { - if *id == model.provider_id { - Some(provider) - } else { - None - } - })?; - Some(((*provider, model.name.clone()), model)) - }) - .collect(); - Ok(all_models) - }) - .await?; - Ok(()) - } - - pub async fn insert_models(&mut self, models: &[ModelParams]) -> Result<()> { - let all_provider_ids = &self.provider_ids; - self.transaction(|tx| async move { - model::Entity::insert_many(models.iter().map(|model_params| { - let provider_id = all_provider_ids[&model_params.provider]; - model::ActiveModel { - provider_id: ActiveValue::set(provider_id), - name: ActiveValue::set(model_params.name.clone()), - max_requests_per_minute: ActiveValue::set(model_params.max_requests_per_minute), - max_tokens_per_minute: ActiveValue::set(model_params.max_tokens_per_minute), - max_tokens_per_day: ActiveValue::set(model_params.max_tokens_per_day), - price_per_million_input_tokens: ActiveValue::set( - model_params.price_per_million_input_tokens, - ), - price_per_million_output_tokens: ActiveValue::set( - model_params.price_per_million_output_tokens, - ), - ..Default::default() - } - })) - 
.on_conflict( - OnConflict::columns([model::Column::ProviderId, model::Column::Name]) - .update_columns([ - model::Column::MaxRequestsPerMinute, - model::Column::MaxTokensPerMinute, - model::Column::MaxTokensPerDay, - model::Column::PricePerMillionInputTokens, - model::Column::PricePerMillionOutputTokens, - ]) - .to_owned(), - ) - .exec_without_returning(&*tx) - .await?; - Ok(()) - }) - .await?; - self.initialize_models().await - } - - /// Returns the list of LLM providers. - pub async fn list_providers(&self) -> Result> { - self.transaction(|tx| async move { - Ok(provider::Entity::find() - .order_by_asc(provider::Column::Name) - .all(&*tx) - .await? - .into_iter() - .filter_map(|p| LanguageModelProvider::from_str(&p.name).ok()) - .collect()) - }) - .await - } -} diff --git a/crates/collab/src/llm/db/queries/subscription_usages.rs b/crates/collab/src/llm/db/queries/subscription_usages.rs deleted file mode 100644 index 8a519790753099be62868e94e8b068958095d320..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/queries/subscription_usages.rs +++ /dev/null @@ -1,38 +0,0 @@ -use crate::db::UserId; - -use super::*; - -impl LlmDatabase { - pub async fn get_subscription_usage_for_period( - &self, - user_id: UserId, - period_start_at: DateTimeUtc, - period_end_at: DateTimeUtc, - ) -> Result> { - self.transaction(|tx| async move { - self.get_subscription_usage_for_period_in_tx( - user_id, - period_start_at, - period_end_at, - &tx, - ) - .await - }) - .await - } - - async fn get_subscription_usage_for_period_in_tx( - &self, - user_id: UserId, - period_start_at: DateTimeUtc, - period_end_at: DateTimeUtc, - tx: &DatabaseTransaction, - ) -> Result> { - Ok(subscription_usage::Entity::find() - .filter(subscription_usage::Column::UserId.eq(user_id)) - .filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at)) - .filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at)) - .one(tx) - .await?) 
- } -} diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs deleted file mode 100644 index a917703f960e657f3ebe345a59558525c7aaa4bb..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/queries/usages.rs +++ /dev/null @@ -1,44 +0,0 @@ -use std::str::FromStr; -use strum::IntoEnumIterator as _; - -use super::*; - -impl LlmDatabase { - pub async fn initialize_usage_measures(&mut self) -> Result<()> { - let all_measures = self - .transaction(|tx| async move { - let existing_measures = usage_measure::Entity::find().all(&*tx).await?; - - let new_measures = UsageMeasure::iter() - .filter(|measure| { - !existing_measures - .iter() - .any(|m| m.name == measure.to_string()) - }) - .map(|measure| usage_measure::ActiveModel { - name: ActiveValue::set(measure.to_string()), - ..Default::default() - }) - .collect::>(); - - if !new_measures.is_empty() { - usage_measure::Entity::insert_many(new_measures) - .exec(&*tx) - .await?; - } - - Ok(usage_measure::Entity::find().all(&*tx).await?) 
- }) - .await?; - - self.usage_measure_ids = all_measures - .into_iter() - .filter_map(|measure| { - UsageMeasure::from_str(&measure.name) - .ok() - .map(|um| (um, measure.id)) - }) - .collect(); - Ok(()) - } -} diff --git a/crates/collab/src/llm/db/seed.rs b/crates/collab/src/llm/db/seed.rs deleted file mode 100644 index 55c6c30cd5d8bf3c6755c3f9b9faaa6fc689370e..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/seed.rs +++ /dev/null @@ -1,45 +0,0 @@ -use super::*; -use crate::{Config, Result}; -use queries::providers::ModelParams; - -pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool) -> Result<()> { - db.insert_models(&[ - ModelParams { - provider: LanguageModelProvider::Anthropic, - name: "claude-3-5-sonnet".into(), - max_requests_per_minute: 5, - max_tokens_per_minute: 20_000, - max_tokens_per_day: 300_000, - price_per_million_input_tokens: 300, // $3.00/MTok - price_per_million_output_tokens: 1500, // $15.00/MTok - }, - ModelParams { - provider: LanguageModelProvider::Anthropic, - name: "claude-3-opus".into(), - max_requests_per_minute: 5, - max_tokens_per_minute: 10_000, - max_tokens_per_day: 300_000, - price_per_million_input_tokens: 1500, // $15.00/MTok - price_per_million_output_tokens: 7500, // $75.00/MTok - }, - ModelParams { - provider: LanguageModelProvider::Anthropic, - name: "claude-3-sonnet".into(), - max_requests_per_minute: 5, - max_tokens_per_minute: 20_000, - max_tokens_per_day: 300_000, - price_per_million_input_tokens: 1500, // $15.00/MTok - price_per_million_output_tokens: 7500, // $75.00/MTok - }, - ModelParams { - provider: LanguageModelProvider::Anthropic, - name: "claude-3-haiku".into(), - max_requests_per_minute: 5, - max_tokens_per_minute: 25_000, - max_tokens_per_day: 300_000, - price_per_million_input_tokens: 25, // $0.25/MTok - price_per_million_output_tokens: 125, // $1.25/MTok - }, - ]) - .await -} diff --git a/crates/collab/src/llm/db/tables.rs b/crates/collab/src/llm/db/tables.rs 
deleted file mode 100644 index 75ea8f51409ec28ec546db5a360b935ef04fb7f9..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/tables.rs +++ /dev/null @@ -1,6 +0,0 @@ -pub mod model; -pub mod provider; -pub mod subscription_usage; -pub mod subscription_usage_meter; -pub mod usage; -pub mod usage_measure; diff --git a/crates/collab/src/llm/db/tables/model.rs b/crates/collab/src/llm/db/tables/model.rs deleted file mode 100644 index f0a858b4a681ce930f9e8d57f5289950a5476ef1..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/tables/model.rs +++ /dev/null @@ -1,48 +0,0 @@ -use sea_orm::entity::prelude::*; - -use crate::llm::db::{ModelId, ProviderId}; - -/// An LLM model. -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "models")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: ModelId, - pub provider_id: ProviderId, - pub name: String, - pub max_requests_per_minute: i64, - pub max_tokens_per_minute: i64, - pub max_input_tokens_per_minute: i64, - pub max_output_tokens_per_minute: i64, - pub max_tokens_per_day: i64, - pub price_per_million_input_tokens: i32, - pub price_per_million_cache_creation_input_tokens: i32, - pub price_per_million_cache_read_input_tokens: i32, - pub price_per_million_output_tokens: i32, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm( - belongs_to = "super::provider::Entity", - from = "Column::ProviderId", - to = "super::provider::Column::Id" - )] - Provider, - #[sea_orm(has_many = "super::usage::Entity")] - Usages, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::Provider.def() - } -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::Usages.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/provider.rs b/crates/collab/src/llm/db/tables/provider.rs deleted file mode 100644 index 
90838f7c65511e83cd7192676e0dafefdd05896a..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/tables/provider.rs +++ /dev/null @@ -1,25 +0,0 @@ -use crate::llm::db::ProviderId; -use sea_orm::entity::prelude::*; - -/// An LLM provider. -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "providers")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: ProviderId, - pub name: String, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm(has_many = "super::model::Entity")] - Models, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::Models.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/subscription_usage.rs b/crates/collab/src/llm/db/tables/subscription_usage.rs deleted file mode 100644 index dd93b03d051ef9752b1c777d24205085fca4487e..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/tables/subscription_usage.rs +++ /dev/null @@ -1,22 +0,0 @@ -use crate::db::UserId; -use crate::db::billing_subscription::SubscriptionKind; -use sea_orm::entity::prelude::*; -use time::PrimitiveDateTime; - -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "subscription_usages_v2")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: Uuid, - pub user_id: UserId, - pub period_start_at: PrimitiveDateTime, - pub period_end_at: PrimitiveDateTime, - pub plan: SubscriptionKind, - pub model_requests: i32, - pub edit_predictions: i32, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/subscription_usage_meter.rs b/crates/collab/src/llm/db/tables/subscription_usage_meter.rs deleted file mode 100644 index c082cf3bc132aa4df2c3c7b5422a0c53ec235579..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/tables/subscription_usage_meter.rs +++ 
/dev/null @@ -1,55 +0,0 @@ -use sea_orm::entity::prelude::*; -use serde::Serialize; - -use crate::llm::db::ModelId; - -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "subscription_usage_meters_v2")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: Uuid, - pub subscription_usage_id: Uuid, - pub model_id: ModelId, - pub mode: CompletionMode, - pub requests: i32, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm( - belongs_to = "super::subscription_usage::Entity", - from = "Column::SubscriptionUsageId", - to = "super::subscription_usage::Column::Id" - )] - SubscriptionUsage, - #[sea_orm( - belongs_to = "super::model::Entity", - from = "Column::ModelId", - to = "super::model::Column::Id" - )] - Model, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::SubscriptionUsage.def() - } -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::Model.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} - -#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)] -#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")] -#[serde(rename_all = "snake_case")] -pub enum CompletionMode { - #[sea_orm(string_value = "normal")] - Normal, - #[sea_orm(string_value = "max")] - Max, -} diff --git a/crates/collab/src/llm/db/tables/usage.rs b/crates/collab/src/llm/db/tables/usage.rs deleted file mode 100644 index 331c94a8a90df2e38601603a746f97ebbf703461..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/tables/usage.rs +++ /dev/null @@ -1,52 +0,0 @@ -use crate::{ - db::UserId, - llm::db::{ModelId, UsageId, UsageMeasureId}, -}; -use sea_orm::entity::prelude::*; - -/// An LLM usage record. -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "usages")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: UsageId, - /// The ID of the Zed user. 
- /// - /// Corresponds to the `users` table in the primary collab database. - pub user_id: UserId, - pub model_id: ModelId, - pub measure_id: UsageMeasureId, - pub timestamp: DateTime, - pub buckets: Vec, - pub is_staff: bool, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm( - belongs_to = "super::model::Entity", - from = "Column::ModelId", - to = "super::model::Column::Id" - )] - Model, - #[sea_orm( - belongs_to = "super::usage_measure::Entity", - from = "Column::MeasureId", - to = "super::usage_measure::Column::Id" - )] - UsageMeasure, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::Model.def() - } -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::UsageMeasure.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/usage_measure.rs b/crates/collab/src/llm/db/tables/usage_measure.rs deleted file mode 100644 index 4f75577ed4684ff73b98389eaa08aefbabc08a16..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/tables/usage_measure.rs +++ /dev/null @@ -1,36 +0,0 @@ -use crate::llm::db::UsageMeasureId; -use sea_orm::entity::prelude::*; - -#[derive( - Copy, Clone, Debug, PartialEq, Eq, Hash, strum::EnumString, strum::Display, strum::EnumIter, -)] -#[strum(serialize_all = "snake_case")] -pub enum UsageMeasure { - RequestsPerMinute, - TokensPerMinute, - InputTokensPerMinute, - OutputTokensPerMinute, - TokensPerDay, -} - -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "usage_measures")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: UsageMeasureId, - pub name: String, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm(has_many = "super::usage::Entity")] - Usages, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::Usages.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git 
a/crates/collab/src/llm/db/tests.rs b/crates/collab/src/llm/db/tests.rs deleted file mode 100644 index 43a1b8b0d457817d1e94d72d0cad094011424c83..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/tests.rs +++ /dev/null @@ -1,107 +0,0 @@ -mod provider_tests; - -use gpui::BackgroundExecutor; -use parking_lot::Mutex; -use rand::prelude::*; -use sea_orm::ConnectionTrait; -use sqlx::migrate::MigrateDatabase; -use std::time::Duration; - -use crate::migrations::run_database_migrations; - -use super::*; - -pub struct TestLlmDb { - pub db: Option, - pub connection: Option, -} - -impl TestLlmDb { - pub fn postgres(background: BackgroundExecutor) -> Self { - static LOCK: Mutex<()> = Mutex::new(()); - - let _guard = LOCK.lock(); - let mut rng = StdRng::from_entropy(); - let url = format!( - "postgres://postgres@localhost/zed-llm-test-{}", - rng.r#gen::() - ); - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .unwrap(); - - let mut db = runtime.block_on(async { - sqlx::Postgres::create_database(&url) - .await - .expect("failed to create test db"); - let mut options = ConnectOptions::new(url); - options - .max_connections(5) - .idle_timeout(Duration::from_secs(0)); - let db = LlmDatabase::new(options, Executor::Deterministic(background)) - .await - .unwrap(); - let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm"); - run_database_migrations(db.options(), migrations_path) - .await - .unwrap(); - db - }); - - db.runtime = Some(runtime); - - Self { - db: Some(db), - connection: None, - } - } - - pub fn db(&mut self) -> &mut LlmDatabase { - self.db.as_mut().unwrap() - } -} - -#[macro_export] -macro_rules! 
test_llm_db { - ($test_name:ident, $postgres_test_name:ident) => { - #[gpui::test] - async fn $postgres_test_name(cx: &mut gpui::TestAppContext) { - if !cfg!(target_os = "macos") { - return; - } - - let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone()); - $test_name(test_db.db()).await; - } - }; -} - -impl Drop for TestLlmDb { - fn drop(&mut self) { - let db = self.db.take().unwrap(); - if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() { - db.runtime.as_ref().unwrap().block_on(async { - use util::ResultExt; - let query = " - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE - pg_stat_activity.datname = current_database() AND - pid <> pg_backend_pid(); - "; - db.pool - .execute(sea_orm::Statement::from_string( - db.pool.get_database_backend(), - query, - )) - .await - .log_err(); - sqlx::Postgres::drop_database(db.options.get_url()) - .await - .log_err(); - }) - } - } -} diff --git a/crates/collab/src/llm/db/tests/provider_tests.rs b/crates/collab/src/llm/db/tests/provider_tests.rs deleted file mode 100644 index f4e1de40ec10705ed9b740619754fcf9ec5f3e1e..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/db/tests/provider_tests.rs +++ /dev/null @@ -1,31 +0,0 @@ -use cloud_llm_client::LanguageModelProvider; -use pretty_assertions::assert_eq; - -use crate::llm::db::LlmDatabase; -use crate::test_llm_db; - -test_llm_db!( - test_initialize_providers, - test_initialize_providers_postgres -); - -async fn test_initialize_providers(db: &mut LlmDatabase) { - let initial_providers = db.list_providers().await.unwrap(); - assert_eq!(initial_providers, vec![]); - - db.initialize_providers().await.unwrap(); - - // Do it twice, to make sure the operation is idempotent. 
- db.initialize_providers().await.unwrap(); - - let providers = db.list_providers().await.unwrap(); - - assert_eq!( - providers, - &[ - LanguageModelProvider::Anthropic, - LanguageModelProvider::Google, - LanguageModelProvider::OpenAi, - ] - ) -} diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs deleted file mode 100644 index da01c7f3bed5cab1e7dbd6cfdef8cd4d7643044c..0000000000000000000000000000000000000000 --- a/crates/collab/src/llm/token.rs +++ /dev/null @@ -1,146 +0,0 @@ -use crate::db::billing_subscription::SubscriptionKind; -use crate::db::{billing_customer, billing_subscription, user}; -use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG}; -use crate::{Config, db::billing_preference}; -use anyhow::{Context as _, Result}; -use chrono::{NaiveDateTime, Utc}; -use cloud_llm_client::Plan; -use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; -use serde::{Deserialize, Serialize}; -use std::time::Duration; -use thiserror::Error; -use uuid::Uuid; - -#[derive(Clone, Debug, Default, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct LlmTokenClaims { - pub iat: u64, - pub exp: u64, - pub jti: String, - pub user_id: u64, - pub system_id: Option, - pub metrics_id: Uuid, - pub github_user_login: String, - pub account_created_at: NaiveDateTime, - pub is_staff: bool, - pub has_llm_closed_beta_feature_flag: bool, - pub bypass_account_age_check: bool, - pub use_llm_request_queue: bool, - pub plan: Plan, - pub has_extended_trial: bool, - pub subscription_period: (NaiveDateTime, NaiveDateTime), - pub enable_model_request_overages: bool, - pub model_request_overages_spend_limit_in_cents: u32, - pub can_use_web_search_tool: bool, - #[serde(default)] - pub has_overdue_invoices: bool, -} - -const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60); - -impl LlmTokenClaims { - pub fn create( - user: &user::Model, - is_staff: bool, - billing_customer: billing_customer::Model, - 
billing_preferences: Option, - feature_flags: &Vec, - subscription: billing_subscription::Model, - system_id: Option, - config: &Config, - ) -> Result { - let secret = config - .llm_api_secret - .as_ref() - .context("no LLM API secret")?; - - let plan = if is_staff { - Plan::ZedPro - } else { - subscription.kind.map_or(Plan::ZedFree, |kind| match kind { - SubscriptionKind::ZedFree => Plan::ZedFree, - SubscriptionKind::ZedPro => Plan::ZedPro, - SubscriptionKind::ZedProTrial => Plan::ZedProTrial, - }) - }; - let subscription_period = - billing_subscription::Model::current_period(Some(subscription), is_staff) - .map(|(start, end)| (start.naive_utc(), end.naive_utc())) - .context("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started.")?; - - let now = Utc::now(); - let claims = Self { - iat: now.timestamp() as u64, - exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64, - jti: uuid::Uuid::new_v4().to_string(), - user_id: user.id.to_proto(), - system_id, - metrics_id: user.metrics_id, - github_user_login: user.github_login.clone(), - account_created_at: user.account_created_at(), - is_staff, - has_llm_closed_beta_feature_flag: feature_flags - .iter() - .any(|flag| flag == "llm-closed-beta"), - bypass_account_age_check: feature_flags - .iter() - .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG), - can_use_web_search_tool: true, - use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"), - plan, - has_extended_trial: feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG), - subscription_period, - enable_model_request_overages: billing_preferences - .as_ref() - .map_or(false, |preferences| { - preferences.model_request_overages_enabled - }), - model_request_overages_spend_limit_in_cents: billing_preferences - .as_ref() - .map_or(0, |preferences| { - preferences.model_request_overages_spend_limit_in_cents as u32 - }), - has_overdue_invoices: 
billing_customer.has_overdue_invoices, - }; - - Ok(jsonwebtoken::encode( - &Header::default(), - &claims, - &EncodingKey::from_secret(secret.as_ref()), - )?) - } - - pub fn validate(token: &str, config: &Config) -> Result { - let secret = config - .llm_api_secret - .as_ref() - .context("no LLM API secret")?; - - match jsonwebtoken::decode::( - token, - &DecodingKey::from_secret(secret.as_ref()), - &Validation::default(), - ) { - Ok(token) => Ok(token.claims), - Err(e) => { - if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature { - Err(ValidateLlmTokenError::Expired) - } else { - Err(ValidateLlmTokenError::JwtError(e)) - } - } - } - } -} - -#[derive(Error, Debug)] -pub enum ValidateLlmTokenError { - #[error("access token is expired")] - Expired, - #[error("access token validation error: {0}")] - JwtError(#[from] jsonwebtoken::errors::Error), - #[error("{0}")] - Other(#[from] anyhow::Error), -} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 20641cb2322a6aa10372064ca208eef091b2ae5a..cb6f6cad1dd483c463bcda5d8a4ff914f4bf10aa 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -62,13 +62,6 @@ async fn main() -> Result<()> { db.initialize_notification_kinds().await?; collab::seed::seed(&config, &db, false).await?; - - if let Some(llm_database_url) = config.llm_database_url.clone() { - let db_options = db::ConnectOptions::new(llm_database_url); - let mut db = LlmDatabase::new(db_options.clone(), Executor::Production).await?; - db.initialize().await?; - collab::llm::db::seed_database(&config, &mut db, true).await?; - } } Some("serve") => { let mode = match args.next().as_deref() { @@ -102,13 +95,6 @@ async fn main() -> Result<()> { let state = AppState::new(config, Executor::Production).await?; - if let Some(stripe_billing) = state.stripe_billing.clone() { - let executor = state.executor.clone(); - executor.spawn_detached(async move { - stripe_billing.initialize().await.trace_err(); - }); - } - if mode.is_collab() 
{ state.db.purge_old_embeddings().await.trace_err(); @@ -270,9 +256,6 @@ async fn setup_llm_database(config: &Config) -> Result<()> { .llm_database_migrations_path .as_deref() .unwrap_or_else(|| { - #[cfg(feature = "sqlite")] - let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm.sqlite"); - #[cfg(not(feature = "sqlite"))] let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm"); Path::new(default_migrations) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 18eb1457dc336c5e2bf32a3d8430514b29bb6966..e19c59f9974f243a585b02baac8d87dc82e0d405 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,21 +1,12 @@ mod connection_pool; -use crate::api::billing::find_or_create_billing_customer; use crate::api::{CloudflareIpCountryHeader, SystemIdHeader}; -use crate::db::billing_subscription::SubscriptionKind; -use crate::llm::db::LlmDatabase; -use crate::llm::{ - AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims, - MIN_ACCOUNT_AGE_FOR_LLM_USE, -}; -use crate::stripe_client::StripeCustomerId; use crate::{ AppState, Error, Result, auth, db::{ - self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, - CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId, - NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult, - RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId, + self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database, + InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject, + RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId, }, executor::Executor, }; @@ -37,7 +28,6 @@ use axum::{ response::IntoResponse, routing::get, }; -use chrono::Utc; use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, 
Formatter}; @@ -75,7 +65,6 @@ use std::{ }, time::{Duration, Instant}, }; -use time::OffsetDateTime; use tokio::sync::{Semaphore, watch}; use tower::ServiceBuilder; use tracing::{ @@ -89,8 +78,6 @@ pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); // kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources. pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15); -const MESSAGE_COUNT_PER_PAGE: usize = 100; -const MAX_MESSAGE_LEN: usize = 1024; const NOTIFICATION_COUNT_PER_PAGE: usize = 50; const MAX_CONCURRENT_CONNECTIONS: usize = 512; @@ -148,13 +135,6 @@ pub enum Principal { } impl Principal { - fn user(&self) -> &User { - match self { - Principal::User(user) => user, - Principal::Impersonated { user, .. } => user, - } - } - fn update_span(&self, span: &tracing::Span) { match &self { Principal::User(user) => { @@ -218,6 +198,7 @@ struct Session { /// The GeoIP country code for the user. #[allow(unused)] geoip_country_code: Option, + #[allow(unused)] system_id: Option, _executor: Executor, } @@ -325,7 +306,7 @@ impl Server { let mut server = Self { id: parking_lot::Mutex::new(id), peer: Peer::new(id.0 as u32), - app_state: app_state.clone(), + app_state, connection_pool: Default::default(), handlers: Default::default(), teardown: watch::channel(false).0, @@ -415,6 +396,8 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(multi_lsp_query) + .add_request_handler(lsp_query) + .add_message_handler(broadcast_project_message_from_host::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) @@ -463,9 +446,6 @@ impl Server { .add_request_handler(follow) .add_message_handler(unfollow) .add_message_handler(update_followers) - .add_request_handler(get_private_user_info) - 
.add_request_handler(get_llm_api_token) - .add_request_handler(accept_terms_of_service) .add_message_handler(acknowledge_channel_message) .add_message_handler(acknowledge_buffer_version) .add_request_handler(get_supermaven_api_key) @@ -492,7 +472,9 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_message_handler(broadcast_project_message_from_host::) - .add_message_handler(update_context); + .add_message_handler(update_context) + .add_request_handler(forward_mutating_project_request::) + .add_message_handler(broadcast_project_message_from_host::); Arc::new(server) } @@ -634,10 +616,10 @@ impl Server { } } - if let Some(live_kit) = livekit_client.as_ref() { - if delete_livekit_room { - live_kit.delete_room(livekit_room).await.trace_err(); - } + if let Some(live_kit) = livekit_client.as_ref() + && delete_livekit_room + { + live_kit.delete_room(livekit_room).await.trace_err(); } } } @@ -928,7 +910,9 @@ impl Server { user_id=field::Empty, login=field::Empty, impersonator=field::Empty, + // todo(lsp) remove after Zed Stable hits v0.204.x multi_lsp_query_request=field::Empty, + lsp_query_request=field::Empty, release_channel=field::Empty, { TOTAL_DURATION_MS }=field::Empty, { PROCESSING_DURATION_MS }=field::Empty, @@ -1000,8 +984,6 @@ impl Server { .await?; } - update_user_plan(session).await?; - let contacts = self.app_state.db.get_contacts(user.id).await?; { @@ -1035,99 +1017,52 @@ impl Server { inviter_id: UserId, invitee_id: UserId, ) -> Result<()> { - if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? 
{ - if let Some(code) = &user.invite_code { - let pool = self.connection_pool.lock(); - let invitee_contact = contact_for_user(invitee_id, false, &pool); - for connection_id in pool.user_connection_ids(inviter_id) { - self.peer.send( - connection_id, - proto::UpdateContacts { - contacts: vec![invitee_contact.clone()], - ..Default::default() - }, - )?; - self.peer.send( - connection_id, - proto::UpdateInviteInfo { - url: format!("{}{}", self.app_state.config.invite_link_prefix, &code), - count: user.invite_count as u32, - }, - )?; - } + if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? + && let Some(code) = &user.invite_code + { + let pool = self.connection_pool.lock(); + let invitee_contact = contact_for_user(invitee_id, false, &pool); + for connection_id in pool.user_connection_ids(inviter_id) { + self.peer.send( + connection_id, + proto::UpdateContacts { + contacts: vec![invitee_contact.clone()], + ..Default::default() + }, + )?; + self.peer.send( + connection_id, + proto::UpdateInviteInfo { + url: format!("{}{}", self.app_state.config.invite_link_prefix, &code), + count: user.invite_count as u32, + }, + )?; } } Ok(()) } pub async fn invite_count_updated(self: &Arc, user_id: UserId) -> Result<()> { - if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? { - if let Some(invite_code) = &user.invite_code { - let pool = self.connection_pool.lock(); - for connection_id in pool.user_connection_ids(user_id) { - self.peer.send( - connection_id, - proto::UpdateInviteInfo { - url: format!( - "{}{}", - self.app_state.config.invite_link_prefix, invite_code - ), - count: user.invite_count as u32, - }, - )?; - } + if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? 
+ && let Some(invite_code) = &user.invite_code + { + let pool = self.connection_pool.lock(); + for connection_id in pool.user_connection_ids(user_id) { + self.peer.send( + connection_id, + proto::UpdateInviteInfo { + url: format!( + "{}{}", + self.app_state.config.invite_link_prefix, invite_code + ), + count: user.invite_count as u32, + }, + )?; } } Ok(()) } - pub async fn update_plan_for_user( - self: &Arc, - user_id: UserId, - update_user_plan: proto::UpdateUserPlan, - ) -> Result<()> { - let pool = self.connection_pool.lock(); - for connection_id in pool.user_connection_ids(user_id) { - self.peer - .send(connection_id, update_user_plan.clone()) - .trace_err(); - } - - Ok(()) - } - - /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan` - /// message on the Collab server. - /// - /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint. - pub async fn update_plan_for_user_legacy(self: &Arc, user_id: UserId) -> Result<()> { - let user = self - .app_state - .db - .get_user_by_id(user_id) - .await? 
- .context("user not found")?; - - let update_user_plan = make_update_user_plan_message( - &user, - user.admin, - &self.app_state.db, - self.app_state.llm_db.clone(), - ) - .await?; - - self.update_plan_for_user(user_id, update_user_plan).await - } - - pub async fn refresh_llm_tokens_for_user(self: &Arc, user_id: UserId) { - let pool = self.connection_pool.lock(); - for connection_id in pool.user_connection_ids(user_id) { - self.peer - .send(connection_id, proto::RefreshLlmToken {}) - .trace_err(); - } - } - pub async fn snapshot(self: &Arc) -> ServerSnapshot<'_> { ServerSnapshot { connection_pool: ConnectionPoolGuard { @@ -1168,10 +1103,10 @@ fn broadcast( F: FnMut(ConnectionId) -> anyhow::Result<()>, { for receiver_id in receiver_ids { - if Some(receiver_id) != sender_id { - if let Err(error) = f(receiver_id) { - tracing::error!("failed to send to {:?} {}", receiver_id, error); - } + if Some(receiver_id) != sender_id + && let Err(error) = f(receiver_id) + { + tracing::error!("failed to send to {:?} {}", receiver_id, error); } } } @@ -1453,9 +1388,7 @@ async fn create_room( let live_kit = live_kit?; let user_id = session.user_id().to_string(); - let token = live_kit - .room_token(&livekit_room, &user_id.to_string()) - .trace_err()?; + let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?; Some(proto::LiveKitConnectionInfo { server_url: live_kit.url().into(), @@ -2082,9 +2015,9 @@ async fn join_project( .unzip(); response.send(proto::JoinProjectResponse { project_id: project.id.0 as u64, - worktrees: worktrees.clone(), + worktrees, replica_id: replica_id.0 as u32, - collaborators: collaborators.clone(), + collaborators, language_servers, language_server_capabilities, role: project.role.into(), @@ -2361,11 +2294,10 @@ async fn update_language_server( let db = session.db().await; if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant + && let Some(capabilities) = update.capabilities.clone() { - if let 
Some(capabilities) = update.capabilities.clone() { - db.update_server_capabilities(project_id, request.language_server_id, capabilities) - .await?; - } + db.update_server_capabilities(project_id, request.language_server_id, capabilities) + .await?; } let project_connection_ids = db @@ -2426,6 +2358,7 @@ where Ok(()) } +// todo(lsp) remove after Zed Stable hits v0.204.x async fn multi_lsp_query( request: MultiLspQuery, response: Response, @@ -2436,6 +2369,21 @@ async fn multi_lsp_query( forward_mutating_project_request(request, response, session).await } +async fn lsp_query( + request: proto::LspQuery, + response: Response, + session: MessageContext, +) -> Result<()> { + let (name, should_write) = request.query_name_and_write_permissions(); + tracing::Span::current().record("lsp_query_request", name); + tracing::info!("lsp_query message received"); + if should_write { + forward_mutating_project_request(request, response, session).await + } else { + forward_read_only_project_request(request, response, session).await + } +} + /// Notify other participants that a new buffer has been created async fn create_buffer_for_peer( request: proto::CreateBufferForPeer, @@ -2882,214 +2830,6 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool { version.0.minor() < 139 } -async fn current_plan(db: &Arc, user_id: UserId, is_staff: bool) -> Result { - if is_staff { - return Ok(proto::Plan::ZedPro); - } - - let subscription = db.get_active_billing_subscription(user_id).await?; - let subscription_kind = subscription.and_then(|subscription| subscription.kind); - - let plan = if let Some(subscription_kind) = subscription_kind { - match subscription_kind { - SubscriptionKind::ZedPro => proto::Plan::ZedPro, - SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial, - SubscriptionKind::ZedFree => proto::Plan::Free, - } - } else { - proto::Plan::Free - }; - - Ok(plan) -} - -async fn make_update_user_plan_message( - user: &User, - is_staff: bool, - db: &Arc, - llm_db: 
Option>, -) -> Result { - let feature_flags = db.get_user_flags(user.id).await?; - let plan = current_plan(db, user.id, is_staff).await?; - let billing_customer = db.get_billing_customer_by_user_id(user.id).await?; - let billing_preferences = db.get_billing_preferences(user.id).await?; - - let (subscription_period, usage) = if let Some(llm_db) = llm_db { - let subscription = db.get_active_billing_subscription(user.id).await?; - - let subscription_period = - crate::db::billing_subscription::Model::current_period(subscription, is_staff); - - let usage = if let Some((period_start_at, period_end_at)) = subscription_period { - llm_db - .get_subscription_usage_for_period(user.id, period_start_at, period_end_at) - .await? - } else { - None - }; - - (subscription_period, usage) - } else { - (None, None) - }; - - let bypass_account_age_check = feature_flags - .iter() - .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG); - let account_too_young = !matches!(plan, proto::Plan::ZedPro) - && !bypass_account_age_check - && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE; - - Ok(proto::UpdateUserPlan { - plan: plan.into(), - trial_started_at: billing_customer - .as_ref() - .and_then(|billing_customer| billing_customer.trial_started_at) - .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64), - is_usage_based_billing_enabled: if is_staff { - Some(true) - } else { - billing_preferences.map(|preferences| preferences.model_request_overages_enabled) - }, - subscription_period: subscription_period.map(|(started_at, ended_at)| { - proto::SubscriptionPeriod { - started_at: started_at.timestamp() as u64, - ended_at: ended_at.timestamp() as u64, - } - }), - account_too_young: Some(account_too_young), - has_overdue_invoices: billing_customer - .map(|billing_customer| billing_customer.has_overdue_invoices), - usage: Some( - usage - .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags)) - .unwrap_or_else(|| make_default_subscription_usage(plan, 
&feature_flags)), - ), - }) -} - -fn model_requests_limit( - plan: cloud_llm_client::Plan, - feature_flags: &Vec, -) -> cloud_llm_client::UsageLimit { - match plan.model_requests_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { - let limit = if plan == cloud_llm_client::Plan::ZedProTrial - && feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG) - { - 1_000 - } else { - limit - }; - - cloud_llm_client::UsageLimit::Limited(limit) - } - cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited, - } -} - -fn subscription_usage_to_proto( - plan: proto::Plan, - usage: crate::llm::db::subscription_usage::Model, - feature_flags: &Vec, -) -> proto::SubscriptionUsage { - let plan = match plan { - proto::Plan::Free => cloud_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, - }; - - proto::SubscriptionUsage { - model_requests_usage_amount: usage.model_requests as u32, - model_requests_usage_limit: Some(proto::UsageLimit { - variant: Some(match model_requests_limit(plan, feature_flags) { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - edit_predictions_usage_amount: usage.edit_predictions as u32, - edit_predictions_usage_limit: Some(proto::UsageLimit { - variant: Some(match plan.edit_predictions_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - } -} - -fn make_default_subscription_usage( - plan: proto::Plan, - feature_flags: 
&Vec, -) -> proto::SubscriptionUsage { - let plan = match plan { - proto::Plan::Free => cloud_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, - }; - - proto::SubscriptionUsage { - model_requests_usage_amount: 0, - model_requests_usage_limit: Some(proto::UsageLimit { - variant: Some(match model_requests_limit(plan, feature_flags) { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - edit_predictions_usage_amount: 0, - edit_predictions_usage_limit: Some(proto::UsageLimit { - variant: Some(match plan.edit_predictions_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - } -} - -async fn update_user_plan(session: &Session) -> Result<()> { - let db = session.db().await; - - let update_user_plan = make_update_user_plan_message( - session.principal.user(), - session.is_staff(), - &db.0, - session.app_state.llm_db.clone(), - ) - .await?; - - session - .peer - .send(session.connection_id, update_user_plan) - .trace_err(); - - Ok(()) -} - async fn subscribe_to_channels( _: proto::SubscribeToChannels, session: MessageContext, @@ -3853,235 +3593,36 @@ fn send_notifications( /// Send a message to the channel async fn send_channel_message( - request: proto::SendChannelMessage, - response: Response, - session: MessageContext, + _request: proto::SendChannelMessage, + _response: Response, + _session: MessageContext, ) -> Result<()> { - // Validate the message body. 
- let body = request.body.trim().to_string(); - if body.len() > MAX_MESSAGE_LEN { - return Err(anyhow!("message is too long"))?; - } - if body.is_empty() { - return Err(anyhow!("message can't be blank"))?; - } - - // TODO: adjust mentions if body is trimmed - - let timestamp = OffsetDateTime::now_utc(); - let nonce = request.nonce.context("nonce can't be blank")?; - - let channel_id = ChannelId::from_proto(request.channel_id); - let CreatedChannelMessage { - message_id, - participant_connection_ids, - notifications, - } = session - .db() - .await - .create_channel_message( - channel_id, - session.user_id(), - &body, - &request.mentions, - timestamp, - nonce.clone().into(), - request.reply_to_message_id.map(MessageId::from_proto), - ) - .await?; - - let message = proto::ChannelMessage { - sender_id: session.user_id().to_proto(), - id: message_id.to_proto(), - body, - mentions: request.mentions, - timestamp: timestamp.unix_timestamp() as u64, - nonce: Some(nonce), - reply_to_message_id: request.reply_to_message_id, - edited_at: None, - }; - broadcast( - Some(session.connection_id), - participant_connection_ids.clone(), - |connection| { - session.peer.send( - connection, - proto::ChannelMessageSent { - channel_id: channel_id.to_proto(), - message: Some(message.clone()), - }, - ) - }, - ); - response.send(proto::SendChannelMessageResponse { - message: Some(message), - })?; - - let pool = &*session.connection_pool().await; - let non_participants = - pool.channel_connection_ids(channel_id) - .filter_map(|(connection_id, _)| { - if participant_connection_ids.contains(&connection_id) { - None - } else { - Some(connection_id) - } - }); - broadcast(None, non_participants, |peer_id| { - session.peer.send( - peer_id, - proto::UpdateChannels { - latest_channel_message_ids: vec![proto::ChannelMessageId { - channel_id: channel_id.to_proto(), - message_id: message_id.to_proto(), - }], - ..Default::default() - }, - ) - }); - send_notifications(pool, &session.peer, notifications); - 
- Ok(()) + Err(anyhow!("chat has been removed in the latest version of Zed").into()) } /// Delete a channel message async fn remove_channel_message( - request: proto::RemoveChannelMessage, - response: Response, - session: MessageContext, + _request: proto::RemoveChannelMessage, + _response: Response, + _session: MessageContext, ) -> Result<()> { - let channel_id = ChannelId::from_proto(request.channel_id); - let message_id = MessageId::from_proto(request.message_id); - let (connection_ids, existing_notification_ids) = session - .db() - .await - .remove_channel_message(channel_id, message_id, session.user_id()) - .await?; - - broadcast( - Some(session.connection_id), - connection_ids, - move |connection| { - session.peer.send(connection, request.clone())?; - - for notification_id in &existing_notification_ids { - session.peer.send( - connection, - proto::DeleteNotification { - notification_id: (*notification_id).to_proto(), - }, - )?; - } - - Ok(()) - }, - ); - response.send(proto::Ack {})?; - Ok(()) + Err(anyhow!("chat has been removed in the latest version of Zed").into()) } async fn update_channel_message( - request: proto::UpdateChannelMessage, - response: Response, - session: MessageContext, + _request: proto::UpdateChannelMessage, + _response: Response, + _session: MessageContext, ) -> Result<()> { - let channel_id = ChannelId::from_proto(request.channel_id); - let message_id = MessageId::from_proto(request.message_id); - let updated_at = OffsetDateTime::now_utc(); - let UpdatedChannelMessage { - message_id, - participant_connection_ids, - notifications, - reply_to_message_id, - timestamp, - deleted_mention_notification_ids, - updated_mention_notifications, - } = session - .db() - .await - .update_channel_message( - channel_id, - message_id, - session.user_id(), - request.body.as_str(), - &request.mentions, - updated_at, - ) - .await?; - - let nonce = request.nonce.clone().context("nonce can't be blank")?; - - let message = proto::ChannelMessage { - sender_id: 
session.user_id().to_proto(), - id: message_id.to_proto(), - body: request.body.clone(), - mentions: request.mentions.clone(), - timestamp: timestamp.assume_utc().unix_timestamp() as u64, - nonce: Some(nonce), - reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()), - edited_at: Some(updated_at.unix_timestamp() as u64), - }; - - response.send(proto::Ack {})?; - - let pool = &*session.connection_pool().await; - broadcast( - Some(session.connection_id), - participant_connection_ids, - |connection| { - session.peer.send( - connection, - proto::ChannelMessageUpdate { - channel_id: channel_id.to_proto(), - message: Some(message.clone()), - }, - )?; - - for notification_id in &deleted_mention_notification_ids { - session.peer.send( - connection, - proto::DeleteNotification { - notification_id: (*notification_id).to_proto(), - }, - )?; - } - - for notification in &updated_mention_notifications { - session.peer.send( - connection, - proto::UpdateNotification { - notification: Some(notification.clone()), - }, - )?; - } - - Ok(()) - }, - ); - - send_notifications(pool, &session.peer, notifications); - - Ok(()) + Err(anyhow!("chat has been removed in the latest version of Zed").into()) } /// Mark a channel message as read async fn acknowledge_channel_message( - request: proto::AckChannelMessage, - session: MessageContext, + _request: proto::AckChannelMessage, + _session: MessageContext, ) -> Result<()> { - let channel_id = ChannelId::from_proto(request.channel_id); - let message_id = MessageId::from_proto(request.message_id); - let notifications = session - .db() - .await - .observe_channel_message(channel_id, session.user_id(), message_id) - .await?; - send_notifications( - &*session.connection_pool().await, - &session.peer, - notifications, - ); - Ok(()) + Err(anyhow!("chat has been removed in the latest version of Zed").into()) } /// Mark a buffer version as synced @@ -4134,84 +3675,37 @@ async fn get_supermaven_api_key( /// Start receiving chat updates for a 
channel async fn join_channel_chat( - request: proto::JoinChannelChat, - response: Response, - session: MessageContext, + _request: proto::JoinChannelChat, + _response: Response, + _session: MessageContext, ) -> Result<()> { - let channel_id = ChannelId::from_proto(request.channel_id); - - let db = session.db().await; - db.join_channel_chat(channel_id, session.connection_id, session.user_id()) - .await?; - let messages = db - .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None) - .await?; - response.send(proto::JoinChannelChatResponse { - done: messages.len() < MESSAGE_COUNT_PER_PAGE, - messages, - })?; - Ok(()) + Err(anyhow!("chat has been removed in the latest version of Zed").into()) } /// Stop receiving chat updates for a channel async fn leave_channel_chat( - request: proto::LeaveChannelChat, - session: MessageContext, + _request: proto::LeaveChannelChat, + _session: MessageContext, ) -> Result<()> { - let channel_id = ChannelId::from_proto(request.channel_id); - session - .db() - .await - .leave_channel_chat(channel_id, session.connection_id, session.user_id()) - .await?; - Ok(()) + Err(anyhow!("chat has been removed in the latest version of Zed").into()) } /// Retrieve the chat history for a channel async fn get_channel_messages( - request: proto::GetChannelMessages, - response: Response, - session: MessageContext, + _request: proto::GetChannelMessages, + _response: Response, + _session: MessageContext, ) -> Result<()> { - let channel_id = ChannelId::from_proto(request.channel_id); - let messages = session - .db() - .await - .get_channel_messages( - channel_id, - session.user_id(), - MESSAGE_COUNT_PER_PAGE, - Some(MessageId::from_proto(request.before_message_id)), - ) - .await?; - response.send(proto::GetChannelMessagesResponse { - done: messages.len() < MESSAGE_COUNT_PER_PAGE, - messages, - })?; - Ok(()) + Err(anyhow!("chat has been removed in the latest version of Zed").into()) } /// Retrieve specific chat messages async fn 
get_channel_messages_by_id( - request: proto::GetChannelMessagesById, - response: Response, - session: MessageContext, + _request: proto::GetChannelMessagesById, + _response: Response, + _session: MessageContext, ) -> Result<()> { - let message_ids = request - .message_ids - .iter() - .map(|id| MessageId::from_proto(*id)) - .collect::>(); - let messages = session - .db() - .await - .get_channel_messages_by_id(session.user_id(), &message_ids) - .await?; - response.send(proto::GetChannelMessagesResponse { - done: messages.len() < MESSAGE_COUNT_PER_PAGE, - messages, - })?; - Ok(()) + Err(anyhow!("chat has been removed in the latest version of Zed").into()) } /// Retrieve the current users notifications @@ -4258,139 +3752,6 @@ async fn mark_notification_as_read( Ok(()) } -/// Get the current users information -async fn get_private_user_info( - _request: proto::GetPrivateUserInfo, - response: Response, - session: MessageContext, -) -> Result<()> { - let db = session.db().await; - - let metrics_id = db.get_user_metrics_id(session.user_id()).await?; - let user = db - .get_user_by_id(session.user_id()) - .await? 
- .context("user not found")?; - let flags = db.get_user_flags(session.user_id()).await?; - - response.send(proto::GetPrivateUserInfoResponse { - metrics_id, - staff: user.admin, - flags, - accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64), - })?; - Ok(()) -} - -/// Accept the terms of service (tos) on behalf of the current user -async fn accept_terms_of_service( - _request: proto::AcceptTermsOfService, - response: Response, - session: MessageContext, -) -> Result<()> { - let db = session.db().await; - - let accepted_tos_at = Utc::now(); - db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc())) - .await?; - - response.send(proto::AcceptTermsOfServiceResponse { - accepted_tos_at: accepted_tos_at.timestamp() as u64, - })?; - - // When the user accepts the terms of service, we want to refresh their LLM - // token to grant access. - session - .peer - .send(session.connection_id, proto::RefreshLlmToken {})?; - - Ok(()) -} - -async fn get_llm_api_token( - _request: proto::GetLlmToken, - response: Response, - session: MessageContext, -) -> Result<()> { - let db = session.db().await; - - let flags = db.get_user_flags(session.user_id()).await?; - - let user_id = session.user_id(); - let user = db - .get_user_by_id(user_id) - .await? - .with_context(|| format!("user {user_id} not found"))?; - - if user.accepted_tos_at.is_none() { - Err(anyhow!("terms of service not accepted"))? - } - - let stripe_client = session - .app_state - .stripe_client - .as_ref() - .context("failed to retrieve Stripe client")?; - - let stripe_billing = session - .app_state - .stripe_billing - .as_ref() - .context("failed to retrieve Stripe billing object")?; - - let billing_customer = if let Some(billing_customer) = - db.get_billing_customer_by_user_id(user.id).await? 
- { - billing_customer - } else { - let customer_id = stripe_billing - .find_or_create_customer_by_email(user.email_address.as_deref()) - .await?; - - find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id) - .await? - .context("billing customer not found")? - }; - - let billing_subscription = - if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? { - billing_subscription - } else { - let stripe_customer_id = - StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); - - let stripe_subscription = stripe_billing - .subscribe_to_zed_free(stripe_customer_id) - .await?; - - db.create_billing_subscription(&db::CreateBillingSubscriptionParams { - billing_customer_id: billing_customer.id, - kind: Some(SubscriptionKind::ZedFree), - stripe_subscription_id: stripe_subscription.id.to_string(), - stripe_subscription_status: stripe_subscription.status.into(), - stripe_cancellation_reason: None, - stripe_current_period_start: Some(stripe_subscription.current_period_start), - stripe_current_period_end: Some(stripe_subscription.current_period_end), - }) - .await? 
- }; - - let billing_preferences = db.get_billing_preferences(user.id).await?; - - let token = LlmTokenClaims::create( - &user, - session.is_staff(), - billing_customer, - billing_preferences, - &flags, - billing_subscription, - session.system_id.clone(), - &session.app_state.config, - )?; - response.send(proto::GetLlmTokenResponse { token })?; - Ok(()) -} - fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result { let message = match message { TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()), @@ -4484,7 +3845,6 @@ fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserCh }) .collect(), observed_channel_buffer_version: channels.observed_buffer_versions.clone(), - observed_channel_message_id: channels.observed_channel_messages.clone(), } } @@ -4496,7 +3856,6 @@ fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels { } update.latest_channel_buffer_versions = channels.latest_buffer_versions; - update.latest_channel_message_ids = channels.latest_channel_messages; for (channel_id, participants) in channels.channel_participants { update diff --git a/crates/collab/src/rpc/connection_pool.rs b/crates/collab/src/rpc/connection_pool.rs index 35290fa697680140e52a147cc25cd87b6afee31e..729e7c8533460c0789d74040e883d48c8b94af92 100644 --- a/crates/collab/src/rpc/connection_pool.rs +++ b/crates/collab/src/rpc/connection_pool.rs @@ -30,7 +30,19 @@ impl fmt::Display for ZedVersion { impl ZedVersion { pub fn can_collaborate(&self) -> bool { - self.0 >= SemanticVersion::new(0, 157, 0) + // v0.198.4 is the first version where we no longer connect to Collab automatically. + // We reject any clients older than that to prevent them from connecting to Collab just for authentication. 
+ if self.0 < SemanticVersion::new(0, 198, 4) { + return false; + } + + // Since we hotfixed the changes to no longer connect to Collab automatically to Preview, we also need to reject + // versions in the range [v0.199.0, v0.199.1]. + if self.0 >= SemanticVersion::new(0, 199, 0) && self.0 < SemanticVersion::new(0, 199, 2) { + return false; + } + + true } } diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs deleted file mode 100644 index ef5bef3e7e5d6c687e4b963f820d5d484e6c4537..0000000000000000000000000000000000000000 --- a/crates/collab/src/stripe_billing.rs +++ /dev/null @@ -1,156 +0,0 @@ -use std::sync::Arc; - -use anyhow::anyhow; -use collections::HashMap; -use stripe::SubscriptionStatus; -use tokio::sync::RwLock; - -use crate::Result; -use crate::stripe_client::{ - RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems, - StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId, - StripeSubscription, -}; - -pub struct StripeBilling { - state: RwLock, - client: Arc, -} - -#[derive(Default)] -struct StripeBillingState { - prices_by_lookup_key: HashMap, -} - -impl StripeBilling { - pub fn new(client: Arc) -> Self { - Self { - client: Arc::new(RealStripeClient::new(client.clone())), - state: RwLock::default(), - } - } - - #[cfg(test)] - pub fn test(client: Arc) -> Self { - Self { - client, - state: RwLock::default(), - } - } - - pub fn client(&self) -> &Arc { - &self.client - } - - pub async fn initialize(&self) -> Result<()> { - log::info!("StripeBilling: initializing"); - - let mut state = self.state.write().await; - - let prices = self.client.list_prices().await?; - - for price in prices { - if let Some(lookup_key) = price.lookup_key.clone() { - state.prices_by_lookup_key.insert(lookup_key, price); - } - } - - log::info!("StripeBilling: initialized"); - - Ok(()) - } - - pub async fn zed_pro_price_id(&self) -> Result { - self.find_price_id_by_lookup_key("zed-pro").await - } - - 
pub async fn zed_free_price_id(&self) -> Result { - self.find_price_id_by_lookup_key("zed-free").await - } - - pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result { - self.state - .read() - .await - .prices_by_lookup_key - .get(lookup_key) - .map(|price| price.id.clone()) - .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}"))) - } - - pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result { - self.state - .read() - .await - .prices_by_lookup_key - .get(lookup_key) - .cloned() - .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}"))) - } - - /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does - /// not already exist. - /// - /// Always returns a new Stripe customer if the email address is `None`. - pub async fn find_or_create_customer_by_email( - &self, - email_address: Option<&str>, - ) -> Result { - let existing_customer = if let Some(email) = email_address { - let customers = self.client.list_customers_by_email(email).await?; - - customers.first().cloned() - } else { - None - }; - - let customer_id = if let Some(existing_customer) = existing_customer { - existing_customer.id - } else { - let customer = self - .client - .create_customer(crate::stripe_client::CreateCustomerParams { - email: email_address, - }) - .await?; - - customer.id - }; - - Ok(customer_id) - } - - pub async fn subscribe_to_zed_free( - &self, - customer_id: StripeCustomerId, - ) -> Result { - let zed_free_price_id = self.zed_free_price_id().await?; - - let existing_subscriptions = self - .client - .list_subscriptions_for_customer(&customer_id) - .await?; - - let existing_active_subscription = - existing_subscriptions.into_iter().find(|subscription| { - subscription.status == SubscriptionStatus::Active - || subscription.status == SubscriptionStatus::Trialing - }); - if let Some(subscription) = existing_active_subscription { - 
return Ok(subscription); - } - - let params = StripeCreateSubscriptionParams { - customer: customer_id, - items: vec![StripeCreateSubscriptionItems { - price: Some(zed_free_price_id), - quantity: Some(1), - }], - automatic_tax: Some(StripeAutomaticTax { enabled: true }), - }; - - let subscription = self.client.create_subscription(params).await?; - - Ok(subscription) - } -} diff --git a/crates/collab/src/stripe_client.rs b/crates/collab/src/stripe_client.rs deleted file mode 100644 index 6e75a4d874bf41e7cb4418d4b56cfeb6040e5ff8..0000000000000000000000000000000000000000 --- a/crates/collab/src/stripe_client.rs +++ /dev/null @@ -1,285 +0,0 @@ -#[cfg(test)] -mod fake_stripe_client; -mod real_stripe_client; - -use std::collections::HashMap; -use std::sync::Arc; - -use anyhow::Result; -use async_trait::async_trait; - -#[cfg(test)] -pub use fake_stripe_client::*; -pub use real_stripe_client::*; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Serialize)] -pub struct StripeCustomerId(pub Arc); - -#[derive(Debug, Clone)] -pub struct StripeCustomer { - pub id: StripeCustomerId, - pub email: Option, -} - -#[derive(Debug)] -pub struct CreateCustomerParams<'a> { - pub email: Option<&'a str>, -} - -#[derive(Debug)] -pub struct UpdateCustomerParams<'a> { - pub email: Option<&'a str>, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] -pub struct StripeSubscriptionId(pub Arc); - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscription { - pub id: StripeSubscriptionId, - pub customer: StripeCustomerId, - // TODO: Create our own version of this enum. 
- pub status: stripe::SubscriptionStatus, - pub current_period_end: i64, - pub current_period_start: i64, - pub items: Vec, - pub cancel_at: Option, - pub cancellation_details: Option, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] -pub struct StripeSubscriptionItemId(pub Arc); - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscriptionItem { - pub id: StripeSubscriptionItemId, - pub price: Option, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct StripeCancellationDetails { - pub reason: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCancellationDetailsReason { - CancellationRequested, - PaymentDisputed, - PaymentFailed, -} - -#[derive(Debug)] -pub struct StripeCreateSubscriptionParams { - pub customer: StripeCustomerId, - pub items: Vec, - pub automatic_tax: Option, -} - -#[derive(Debug)] -pub struct StripeCreateSubscriptionItems { - pub price: Option, - pub quantity: Option, -} - -#[derive(Debug, Clone)] -pub struct UpdateSubscriptionParams { - pub items: Option>, - pub trial_settings: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct UpdateSubscriptionItems { - pub price: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscriptionTrialSettings { - pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscriptionTrialSettingsEndBehavior { - pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod { - Cancel, - CreateInvoice, - Pause, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] -pub struct StripePriceId(pub Arc); - -#[derive(Debug, PartialEq, Clone)] -pub struct StripePrice { - pub id: StripePriceId, - pub unit_amount: Option, - pub lookup_key: Option, - pub recurring: Option, -} - -#[derive(Debug, PartialEq, Clone)] 
-pub struct StripePriceRecurring { - pub meter: Option, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)] -pub struct StripeMeterId(pub Arc); - -#[derive(Debug, Clone, Deserialize)] -pub struct StripeMeter { - pub id: StripeMeterId, - pub event_name: String, -} - -#[derive(Debug, Serialize)] -pub struct StripeCreateMeterEventParams<'a> { - pub identifier: &'a str, - pub event_name: &'a str, - pub payload: StripeCreateMeterEventPayload<'a>, - pub timestamp: Option, -} - -#[derive(Debug, Serialize)] -pub struct StripeCreateMeterEventPayload<'a> { - pub value: u64, - pub stripe_customer_id: &'a StripeCustomerId, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeBillingAddressCollection { - Auto, - Required, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeCustomerUpdate { - pub address: Option, - pub name: Option, - pub shipping: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCustomerUpdateAddress { - Auto, - Never, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCustomerUpdateName { - Auto, - Never, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCustomerUpdateShipping { - Auto, - Never, -} - -#[derive(Debug, Default)] -pub struct StripeCreateCheckoutSessionParams<'a> { - pub customer: Option<&'a StripeCustomerId>, - pub client_reference_id: Option<&'a str>, - pub mode: Option, - pub line_items: Option>, - pub payment_method_collection: Option, - pub subscription_data: Option, - pub success_url: Option<&'a str>, - pub billing_address_collection: Option, - pub customer_update: Option, - pub tax_id_collection: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCheckoutSessionMode { - Payment, - Setup, - Subscription, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeCreateCheckoutSessionLineItems { - pub price: Option, - pub quantity: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum 
StripeCheckoutSessionPaymentMethodCollection { - Always, - IfRequired, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeCreateCheckoutSessionSubscriptionData { - pub metadata: Option>, - pub trial_period_days: Option, - pub trial_settings: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeTaxIdCollection { - pub enabled: bool, -} - -#[derive(Debug, Clone)] -pub struct StripeAutomaticTax { - pub enabled: bool, -} - -#[derive(Debug)] -pub struct StripeCheckoutSession { - pub url: Option, -} - -#[async_trait] -pub trait StripeClient: Send + Sync { - async fn list_customers_by_email(&self, email: &str) -> Result>; - - async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result; - - async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result; - - async fn update_customer( - &self, - customer_id: &StripeCustomerId, - params: UpdateCustomerParams<'_>, - ) -> Result; - - async fn list_subscriptions_for_customer( - &self, - customer_id: &StripeCustomerId, - ) -> Result>; - - async fn get_subscription( - &self, - subscription_id: &StripeSubscriptionId, - ) -> Result; - - async fn create_subscription( - &self, - params: StripeCreateSubscriptionParams, - ) -> Result; - - async fn update_subscription( - &self, - subscription_id: &StripeSubscriptionId, - params: UpdateSubscriptionParams, - ) -> Result<()>; - - async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()>; - - async fn list_prices(&self) -> Result>; - - async fn list_meters(&self) -> Result>; - - async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>; - - async fn create_checkout_session( - &self, - params: StripeCreateCheckoutSessionParams<'_>, - ) -> Result; -} diff --git a/crates/collab/src/stripe_client/fake_stripe_client.rs b/crates/collab/src/stripe_client/fake_stripe_client.rs deleted file mode 100644 index 9bb08443ec6a5fd04ad11a8e24b1a71b03e4867b..0000000000000000000000000000000000000000 
--- a/crates/collab/src/stripe_client/fake_stripe_client.rs +++ /dev/null @@ -1,247 +0,0 @@ -use std::sync::Arc; - -use anyhow::{Result, anyhow}; -use async_trait::async_trait; -use chrono::{Duration, Utc}; -use collections::HashMap; -use parking_lot::Mutex; -use uuid::Uuid; - -use crate::stripe_client::{ - CreateCustomerParams, StripeBillingAddressCollection, StripeCheckoutSession, - StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient, - StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, - StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, - StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate, - StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription, - StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, StripeTaxIdCollection, - UpdateCustomerParams, UpdateSubscriptionParams, -}; - -#[derive(Debug, Clone)] -pub struct StripeCreateMeterEventCall { - pub identifier: Arc, - pub event_name: Arc, - pub value: u64, - pub stripe_customer_id: StripeCustomerId, - pub timestamp: Option, -} - -#[derive(Debug, Clone)] -pub struct StripeCreateCheckoutSessionCall { - pub customer: Option, - pub client_reference_id: Option, - pub mode: Option, - pub line_items: Option>, - pub payment_method_collection: Option, - pub subscription_data: Option, - pub success_url: Option, - pub billing_address_collection: Option, - pub customer_update: Option, - pub tax_id_collection: Option, -} - -pub struct FakeStripeClient { - pub customers: Arc>>, - pub subscriptions: Arc>>, - pub update_subscription_calls: - Arc>>, - pub prices: Arc>>, - pub meters: Arc>>, - pub create_meter_event_calls: Arc>>, - pub create_checkout_session_calls: Arc>>, -} - -impl FakeStripeClient { - pub fn new() -> Self { - Self { - customers: Arc::new(Mutex::new(HashMap::default())), - subscriptions: Arc::new(Mutex::new(HashMap::default())), - update_subscription_calls: 
Arc::new(Mutex::new(Vec::new())), - prices: Arc::new(Mutex::new(HashMap::default())), - meters: Arc::new(Mutex::new(HashMap::default())), - create_meter_event_calls: Arc::new(Mutex::new(Vec::new())), - create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())), - } - } -} - -#[async_trait] -impl StripeClient for FakeStripeClient { - async fn list_customers_by_email(&self, email: &str) -> Result> { - Ok(self - .customers - .lock() - .values() - .filter(|customer| customer.email.as_deref() == Some(email)) - .cloned() - .collect()) - } - - async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result { - self.customers - .lock() - .get(customer_id) - .cloned() - .ok_or_else(|| anyhow!("no customer found for {customer_id:?}")) - } - - async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result { - let customer = StripeCustomer { - id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()), - email: params.email.map(|email| email.to_string()), - }; - - self.customers - .lock() - .insert(customer.id.clone(), customer.clone()); - - Ok(customer) - } - - async fn update_customer( - &self, - customer_id: &StripeCustomerId, - params: UpdateCustomerParams<'_>, - ) -> Result { - let mut customers = self.customers.lock(); - if let Some(customer) = customers.get_mut(customer_id) { - if let Some(email) = params.email { - customer.email = Some(email.to_string()); - } - Ok(customer.clone()) - } else { - Err(anyhow!("no customer found for {customer_id:?}")) - } - } - - async fn list_subscriptions_for_customer( - &self, - customer_id: &StripeCustomerId, - ) -> Result> { - let subscriptions = self - .subscriptions - .lock() - .values() - .filter(|subscription| subscription.customer == *customer_id) - .cloned() - .collect(); - - Ok(subscriptions) - } - - async fn get_subscription( - &self, - subscription_id: &StripeSubscriptionId, - ) -> Result { - self.subscriptions - .lock() - .get(subscription_id) - .cloned() - .ok_or_else(|| anyhow!("no subscription 
found for {subscription_id:?}")) - } - - async fn create_subscription( - &self, - params: StripeCreateSubscriptionParams, - ) -> Result { - let now = Utc::now(); - - let subscription = StripeSubscription { - id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()), - customer: params.customer, - status: stripe::SubscriptionStatus::Active, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(30)).timestamp(), - items: params - .items - .into_iter() - .map(|item| StripeSubscriptionItem { - id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()), - price: item - .price - .and_then(|price_id| self.prices.lock().get(&price_id).cloned()), - }) - .collect(), - cancel_at: None, - cancellation_details: None, - }; - - self.subscriptions - .lock() - .insert(subscription.id.clone(), subscription.clone()); - - Ok(subscription) - } - - async fn update_subscription( - &self, - subscription_id: &StripeSubscriptionId, - params: UpdateSubscriptionParams, - ) -> Result<()> { - let subscription = self.get_subscription(subscription_id).await?; - - self.update_subscription_calls - .lock() - .push((subscription.id, params)); - - Ok(()) - } - - async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> { - // TODO: Implement fake subscription cancellation. 
- let _ = subscription_id; - - Ok(()) - } - - async fn list_prices(&self) -> Result> { - let prices = self.prices.lock().values().cloned().collect(); - - Ok(prices) - } - - async fn list_meters(&self) -> Result> { - let meters = self.meters.lock().values().cloned().collect(); - - Ok(meters) - } - - async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> { - self.create_meter_event_calls - .lock() - .push(StripeCreateMeterEventCall { - identifier: params.identifier.into(), - event_name: params.event_name.into(), - value: params.payload.value, - stripe_customer_id: params.payload.stripe_customer_id.clone(), - timestamp: params.timestamp, - }); - - Ok(()) - } - - async fn create_checkout_session( - &self, - params: StripeCreateCheckoutSessionParams<'_>, - ) -> Result { - self.create_checkout_session_calls - .lock() - .push(StripeCreateCheckoutSessionCall { - customer: params.customer.cloned(), - client_reference_id: params.client_reference_id.map(|id| id.to_string()), - mode: params.mode, - line_items: params.line_items, - payment_method_collection: params.payment_method_collection, - subscription_data: params.subscription_data, - success_url: params.success_url.map(|url| url.to_string()), - billing_address_collection: params.billing_address_collection, - customer_update: params.customer_update, - tax_id_collection: params.tax_id_collection, - }); - - Ok(StripeCheckoutSession { - url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()), - }) - } -} diff --git a/crates/collab/src/stripe_client/real_stripe_client.rs b/crates/collab/src/stripe_client/real_stripe_client.rs deleted file mode 100644 index 07c191ff30400ccbf4b73c4c84f09aa47e0fd9aa..0000000000000000000000000000000000000000 --- a/crates/collab/src/stripe_client/real_stripe_client.rs +++ /dev/null @@ -1,612 +0,0 @@ -use std::str::FromStr as _; -use std::sync::Arc; - -use anyhow::{Context as _, Result, anyhow}; -use async_trait::async_trait; -use serde::{Deserialize, 
Serialize}; -use stripe::{ - CancellationDetails, CancellationDetailsReason, CheckoutSession, CheckoutSessionMode, - CheckoutSessionPaymentMethodCollection, CreateCheckoutSession, CreateCheckoutSessionLineItems, - CreateCheckoutSessionSubscriptionData, CreateCheckoutSessionSubscriptionDataTrialSettings, - CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior, - CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod, - CreateCustomer, CreateSubscriptionAutomaticTax, Customer, CustomerId, ListCustomers, Price, - PriceId, Recurring, Subscription, SubscriptionId, SubscriptionItem, SubscriptionItemId, - UpdateCustomer, UpdateSubscriptionItems, UpdateSubscriptionTrialSettings, - UpdateSubscriptionTrialSettingsEndBehavior, - UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, -}; - -use crate::stripe_client::{ - CreateCustomerParams, StripeAutomaticTax, StripeBillingAddressCollection, - StripeCancellationDetails, StripeCancellationDetailsReason, StripeCheckoutSession, - StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient, - StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, - StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, - StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate, - StripeCustomerUpdateAddress, StripeCustomerUpdateName, StripeCustomerUpdateShipping, - StripeMeter, StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription, - StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, - StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, StripeTaxIdCollection, - UpdateCustomerParams, UpdateSubscriptionParams, -}; - -pub struct RealStripeClient { - client: Arc, -} - -impl RealStripeClient { - pub fn new(client: Arc) -> Self { - Self { client } - } -} - -#[async_trait] -impl StripeClient 
for RealStripeClient { - async fn list_customers_by_email(&self, email: &str) -> Result> { - let response = Customer::list( - &self.client, - &ListCustomers { - email: Some(email), - ..Default::default() - }, - ) - .await?; - - Ok(response - .data - .into_iter() - .map(StripeCustomer::from) - .collect()) - } - - async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result { - let customer_id = customer_id.try_into()?; - - let customer = Customer::retrieve(&self.client, &customer_id, &[]).await?; - - Ok(StripeCustomer::from(customer)) - } - - async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result { - let customer = Customer::create( - &self.client, - CreateCustomer { - email: params.email, - ..Default::default() - }, - ) - .await?; - - Ok(StripeCustomer::from(customer)) - } - - async fn update_customer( - &self, - customer_id: &StripeCustomerId, - params: UpdateCustomerParams<'_>, - ) -> Result { - let customer = Customer::update( - &self.client, - &customer_id.try_into()?, - UpdateCustomer { - email: params.email, - ..Default::default() - }, - ) - .await?; - - Ok(StripeCustomer::from(customer)) - } - - async fn list_subscriptions_for_customer( - &self, - customer_id: &StripeCustomerId, - ) -> Result> { - let customer_id = customer_id.try_into()?; - - let subscriptions = stripe::Subscription::list( - &self.client, - &stripe::ListSubscriptions { - customer: Some(customer_id), - status: None, - ..Default::default() - }, - ) - .await?; - - Ok(subscriptions - .data - .into_iter() - .map(StripeSubscription::from) - .collect()) - } - - async fn get_subscription( - &self, - subscription_id: &StripeSubscriptionId, - ) -> Result { - let subscription_id = subscription_id.try_into()?; - - let subscription = Subscription::retrieve(&self.client, &subscription_id, &[]).await?; - - Ok(StripeSubscription::from(subscription)) - } - - async fn create_subscription( - &self, - params: StripeCreateSubscriptionParams, - ) -> Result { - let customer_id = 
params.customer.try_into()?; - - let mut create_subscription = stripe::CreateSubscription::new(customer_id); - create_subscription.items = Some( - params - .items - .into_iter() - .map(|item| stripe::CreateSubscriptionItems { - price: item.price.map(|price| price.to_string()), - quantity: item.quantity, - ..Default::default() - }) - .collect(), - ); - create_subscription.automatic_tax = params.automatic_tax.map(Into::into); - - let subscription = Subscription::create(&self.client, create_subscription).await?; - - Ok(StripeSubscription::from(subscription)) - } - - async fn update_subscription( - &self, - subscription_id: &StripeSubscriptionId, - params: UpdateSubscriptionParams, - ) -> Result<()> { - let subscription_id = subscription_id.try_into()?; - - stripe::Subscription::update( - &self.client, - &subscription_id, - stripe::UpdateSubscription { - items: params.items.map(|items| { - items - .into_iter() - .map(|item| UpdateSubscriptionItems { - price: item.price.map(|price| price.to_string()), - ..Default::default() - }) - .collect() - }), - trial_settings: params.trial_settings.map(Into::into), - ..Default::default() - }, - ) - .await?; - - Ok(()) - } - - async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> { - let subscription_id = subscription_id.try_into()?; - - Subscription::cancel( - &self.client, - &subscription_id, - stripe::CancelSubscription { - invoice_now: None, - ..Default::default() - }, - ) - .await?; - - Ok(()) - } - - async fn list_prices(&self) -> Result> { - let response = stripe::Price::list( - &self.client, - &stripe::ListPrices { - limit: Some(100), - ..Default::default() - }, - ) - .await?; - - Ok(response.data.into_iter().map(StripePrice::from).collect()) - } - - async fn list_meters(&self) -> Result> { - #[derive(Serialize)] - struct Params { - #[serde(skip_serializing_if = "Option::is_none")] - limit: Option, - } - - let response = self - .client - .get_query::, _>( - "/billing/meters", - Params { 
limit: Some(100) }, - ) - .await?; - - Ok(response.data) - } - - async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> { - #[derive(Deserialize)] - struct StripeMeterEvent { - pub identifier: String, - } - - let identifier = params.identifier; - match self - .client - .post_form::("/billing/meter_events", params) - .await - { - Ok(_event) => Ok(()), - Err(stripe::StripeError::Stripe(error)) => { - if error.http_status == 400 - && error - .message - .as_ref() - .map_or(false, |message| message.contains(identifier)) - { - Ok(()) - } else { - Err(anyhow!(stripe::StripeError::Stripe(error))) - } - } - Err(error) => Err(anyhow!("failed to create meter event: {error:?}")), - } - } - - async fn create_checkout_session( - &self, - params: StripeCreateCheckoutSessionParams<'_>, - ) -> Result { - let params = params.try_into()?; - let session = CheckoutSession::create(&self.client, params).await?; - - Ok(session.into()) - } -} - -impl From for StripeCustomerId { - fn from(value: CustomerId) -> Self { - Self(value.as_str().into()) - } -} - -impl TryFrom for CustomerId { - type Error = anyhow::Error; - - fn try_from(value: StripeCustomerId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID") - } -} - -impl TryFrom<&StripeCustomerId> for CustomerId { - type Error = anyhow::Error; - - fn try_from(value: &StripeCustomerId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID") - } -} - -impl From for StripeCustomer { - fn from(value: Customer) -> Self { - StripeCustomer { - id: value.id.into(), - email: value.email, - } - } -} - -impl From for StripeSubscriptionId { - fn from(value: SubscriptionId) -> Self { - Self(value.as_str().into()) - } -} - -impl TryFrom<&StripeSubscriptionId> for SubscriptionId { - type Error = anyhow::Error; - - fn try_from(value: &StripeSubscriptionId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe 
subscription ID") - } -} - -impl From for StripeSubscription { - fn from(value: Subscription) -> Self { - Self { - id: value.id.into(), - customer: value.customer.id().into(), - status: value.status, - current_period_start: value.current_period_start, - current_period_end: value.current_period_end, - items: value.items.data.into_iter().map(Into::into).collect(), - cancel_at: value.cancel_at, - cancellation_details: value.cancellation_details.map(Into::into), - } - } -} - -impl From for StripeCancellationDetails { - fn from(value: CancellationDetails) -> Self { - Self { - reason: value.reason.map(Into::into), - } - } -} - -impl From for StripeCancellationDetailsReason { - fn from(value: CancellationDetailsReason) -> Self { - match value { - CancellationDetailsReason::CancellationRequested => Self::CancellationRequested, - CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed, - CancellationDetailsReason::PaymentFailed => Self::PaymentFailed, - } - } -} - -impl From for StripeSubscriptionItemId { - fn from(value: SubscriptionItemId) -> Self { - Self(value.as_str().into()) - } -} - -impl From for StripeSubscriptionItem { - fn from(value: SubscriptionItem) -> Self { - Self { - id: value.id.into(), - price: value.price.map(Into::into), - } - } -} - -impl From for CreateSubscriptionAutomaticTax { - fn from(value: StripeAutomaticTax) -> Self { - Self { - enabled: value.enabled, - liability: None, - } - } -} - -impl From for UpdateSubscriptionTrialSettings { - fn from(value: StripeSubscriptionTrialSettings) -> Self { - Self { - end_behavior: value.end_behavior.into(), - } - } -} - -impl From - for UpdateSubscriptionTrialSettingsEndBehavior -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self { - Self { - missing_payment_method: value.missing_payment_method.into(), - } - } -} - -impl From - for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> 
Self { - match value { - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => { - Self::CreateInvoice - } - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause, - } - } -} - -impl From for StripePriceId { - fn from(value: PriceId) -> Self { - Self(value.as_str().into()) - } -} - -impl TryFrom for PriceId { - type Error = anyhow::Error; - - fn try_from(value: StripePriceId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID") - } -} - -impl From for StripePrice { - fn from(value: Price) -> Self { - Self { - id: value.id.into(), - unit_amount: value.unit_amount, - lookup_key: value.lookup_key, - recurring: value.recurring.map(StripePriceRecurring::from), - } - } -} - -impl From for StripePriceRecurring { - fn from(value: Recurring) -> Self { - Self { meter: value.meter } - } -} - -impl<'a> TryFrom> for CreateCheckoutSession<'a> { - type Error = anyhow::Error; - - fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result { - Ok(Self { - customer: value - .customer - .map(|customer_id| customer_id.try_into()) - .transpose()?, - client_reference_id: value.client_reference_id, - mode: value.mode.map(Into::into), - line_items: value - .line_items - .map(|line_items| line_items.into_iter().map(Into::into).collect()), - payment_method_collection: value.payment_method_collection.map(Into::into), - subscription_data: value.subscription_data.map(Into::into), - success_url: value.success_url, - billing_address_collection: value.billing_address_collection.map(Into::into), - customer_update: value.customer_update.map(Into::into), - tax_id_collection: value.tax_id_collection.map(Into::into), - ..Default::default() - }) - } -} - -impl From for CheckoutSessionMode { - fn from(value: StripeCheckoutSessionMode) -> Self { - match value { - StripeCheckoutSessionMode::Payment => Self::Payment, 
- StripeCheckoutSessionMode::Setup => Self::Setup, - StripeCheckoutSessionMode::Subscription => Self::Subscription, - } - } -} - -impl From for CreateCheckoutSessionLineItems { - fn from(value: StripeCreateCheckoutSessionLineItems) -> Self { - Self { - price: value.price, - quantity: value.quantity, - ..Default::default() - } - } -} - -impl From for CheckoutSessionPaymentMethodCollection { - fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self { - match value { - StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always, - StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired, - } - } -} - -impl From for CreateCheckoutSessionSubscriptionData { - fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self { - Self { - trial_period_days: value.trial_period_days, - trial_settings: value.trial_settings.map(Into::into), - metadata: value.metadata, - ..Default::default() - } - } -} - -impl From for CreateCheckoutSessionSubscriptionDataTrialSettings { - fn from(value: StripeSubscriptionTrialSettings) -> Self { - Self { - end_behavior: value.end_behavior.into(), - } - } -} - -impl From - for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self { - Self { - missing_payment_method: value.missing_payment_method.into(), - } - } -} - -impl From - for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self { - match value { - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => { - Self::CreateInvoice - } - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause, - } - } -} - -impl From for StripeCheckoutSession { - fn from(value: CheckoutSession) -> Self { - Self { url: value.url } 
- } -} - -impl From for stripe::CheckoutSessionBillingAddressCollection { - fn from(value: StripeBillingAddressCollection) -> Self { - match value { - StripeBillingAddressCollection::Auto => { - stripe::CheckoutSessionBillingAddressCollection::Auto - } - StripeBillingAddressCollection::Required => { - stripe::CheckoutSessionBillingAddressCollection::Required - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdateAddress { - fn from(value: StripeCustomerUpdateAddress) -> Self { - match value { - StripeCustomerUpdateAddress::Auto => { - stripe::CreateCheckoutSessionCustomerUpdateAddress::Auto - } - StripeCustomerUpdateAddress::Never => { - stripe::CreateCheckoutSessionCustomerUpdateAddress::Never - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdateName { - fn from(value: StripeCustomerUpdateName) -> Self { - match value { - StripeCustomerUpdateName::Auto => stripe::CreateCheckoutSessionCustomerUpdateName::Auto, - StripeCustomerUpdateName::Never => { - stripe::CreateCheckoutSessionCustomerUpdateName::Never - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdateShipping { - fn from(value: StripeCustomerUpdateShipping) -> Self { - match value { - StripeCustomerUpdateShipping::Auto => { - stripe::CreateCheckoutSessionCustomerUpdateShipping::Auto - } - StripeCustomerUpdateShipping::Never => { - stripe::CreateCheckoutSessionCustomerUpdateShipping::Never - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdate { - fn from(value: StripeCustomerUpdate) -> Self { - stripe::CreateCheckoutSessionCustomerUpdate { - address: value.address.map(Into::into), - name: value.name.map(Into::into), - shipping: value.shipping.map(Into::into), - } - } -} - -impl From for stripe::CreateCheckoutSessionTaxIdCollection { - fn from(value: StripeTaxIdCollection) -> Self { - stripe::CreateCheckoutSessionTaxIdCollection { - enabled: value.enabled, - } - } -} diff --git a/crates/collab/src/tests.rs 
b/crates/collab/src/tests.rs index 8d5d076780733406904cd1c0431d56d6ebbc776f..7d07360b8042ed54a9f19a82a2876e448e8a14a4 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -6,9 +6,7 @@ use gpui::{Entity, TestAppContext}; mod channel_buffer_tests; mod channel_guest_tests; -mod channel_message_tests; mod channel_tests; -// mod debug_panel_tests; mod editor_tests; mod following_tests; mod git_tests; @@ -18,7 +16,6 @@ mod random_channel_buffer_tests; mod random_project_collaboration_tests; mod randomized_test_helpers; mod remote_editing_collaboration_tests; -mod stripe_billing_tests; mod test_server; use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; diff --git a/crates/collab/src/tests/channel_message_tests.rs b/crates/collab/src/tests/channel_message_tests.rs deleted file mode 100644 index dbc5cd86c2582719bbb0782e1b3630f08e4cacaf..0000000000000000000000000000000000000000 --- a/crates/collab/src/tests/channel_message_tests.rs +++ /dev/null @@ -1,725 +0,0 @@ -use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; -use channel::{ChannelChat, ChannelMessageId, MessageParams}; -use collab_ui::chat_panel::ChatPanel; -use gpui::{BackgroundExecutor, Entity, TestAppContext}; -use rpc::Notification; -use workspace::dock::Panel; - -#[gpui::test] -async fn test_basic_channel_messages( - executor: BackgroundExecutor, - mut cx_a: &mut TestAppContext, - mut cx_b: &mut TestAppContext, - mut cx_c: &mut TestAppContext, -) { - let mut server = TestServer::start(executor.clone()).await; - let client_a = server.create_client(cx_a, "user_a").await; - let client_b = server.create_client(cx_b, "user_b").await; - let client_c = server.create_client(cx_c, "user_c").await; - - let channel_id = server - .make_channel( - "the-channel", - None, - (&client_a, cx_a), - &mut [(&client_b, cx_b), (&client_c, cx_c)], - ) - .await; - - let channel_chat_a = client_a - .channel_store() - .update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx)) - 
.await - .unwrap(); - let channel_chat_b = client_b - .channel_store() - .update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx)) - .await - .unwrap(); - - let message_id = channel_chat_a - .update(cx_a, |c, cx| { - c.send_message( - MessageParams { - text: "hi @user_c!".into(), - mentions: vec![(3..10, client_c.id())], - reply_to_message_id: None, - }, - cx, - ) - .unwrap() - }) - .await - .unwrap(); - channel_chat_a - .update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap()) - .await - .unwrap(); - - executor.run_until_parked(); - channel_chat_b - .update(cx_b, |c, cx| c.send_message("three".into(), cx).unwrap()) - .await - .unwrap(); - - executor.run_until_parked(); - - let channel_chat_c = client_c - .channel_store() - .update(cx_c, |store, cx| store.open_channel_chat(channel_id, cx)) - .await - .unwrap(); - - for (chat, cx) in [ - (&channel_chat_a, &mut cx_a), - (&channel_chat_b, &mut cx_b), - (&channel_chat_c, &mut cx_c), - ] { - chat.update(*cx, |c, _| { - assert_eq!( - c.messages() - .iter() - .map(|m| (m.body.as_str(), m.mentions.as_slice())) - .collect::>(), - vec![ - ("hi @user_c!", [(3..10, client_c.id())].as_slice()), - ("two", &[]), - ("three", &[]) - ], - "results for user {}", - c.client().id(), - ); - }); - } - - client_c.notification_store().update(cx_c, |store, _| { - assert_eq!(store.notification_count(), 2); - assert_eq!(store.unread_notification_count(), 1); - assert_eq!( - store.notification_at(0).unwrap().notification, - Notification::ChannelMessageMention { - message_id, - sender_id: client_a.id(), - channel_id: channel_id.0, - } - ); - assert_eq!( - store.notification_at(1).unwrap().notification, - Notification::ChannelInvitation { - channel_id: channel_id.0, - channel_name: "the-channel".to_string(), - inviter_id: client_a.id() - } - ); - }); -} - -#[gpui::test] -async fn test_rejoin_channel_chat( - executor: BackgroundExecutor, - cx_a: &mut TestAppContext, - cx_b: &mut TestAppContext, -) { - let mut server = 
TestServer::start(executor.clone()).await; - let client_a = server.create_client(cx_a, "user_a").await; - let client_b = server.create_client(cx_b, "user_b").await; - - let channel_id = server - .make_channel( - "the-channel", - None, - (&client_a, cx_a), - &mut [(&client_b, cx_b)], - ) - .await; - - let channel_chat_a = client_a - .channel_store() - .update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx)) - .await - .unwrap(); - let channel_chat_b = client_b - .channel_store() - .update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx)) - .await - .unwrap(); - - channel_chat_a - .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap()) - .await - .unwrap(); - channel_chat_b - .update(cx_b, |c, cx| c.send_message("two".into(), cx).unwrap()) - .await - .unwrap(); - - server.forbid_connections(); - server.disconnect_client(client_a.peer_id().unwrap()); - - // While client A is disconnected, clients A and B both send new messages. - channel_chat_a - .update(cx_a, |c, cx| c.send_message("three".into(), cx).unwrap()) - .await - .unwrap_err(); - channel_chat_a - .update(cx_a, |c, cx| c.send_message("four".into(), cx).unwrap()) - .await - .unwrap_err(); - channel_chat_b - .update(cx_b, |c, cx| c.send_message("five".into(), cx).unwrap()) - .await - .unwrap(); - channel_chat_b - .update(cx_b, |c, cx| c.send_message("six".into(), cx).unwrap()) - .await - .unwrap(); - - // Client A reconnects. - server.allow_connections(); - executor.advance_clock(RECONNECT_TIMEOUT); - - // Client A fetches the messages that were sent while they were disconnected - // and resends their own messages which failed to send. 
- let expected_messages = &["one", "two", "five", "six", "three", "four"]; - assert_messages(&channel_chat_a, expected_messages, cx_a); - assert_messages(&channel_chat_b, expected_messages, cx_b); -} - -#[gpui::test] -async fn test_remove_channel_message( - executor: BackgroundExecutor, - cx_a: &mut TestAppContext, - cx_b: &mut TestAppContext, - cx_c: &mut TestAppContext, -) { - let mut server = TestServer::start(executor.clone()).await; - let client_a = server.create_client(cx_a, "user_a").await; - let client_b = server.create_client(cx_b, "user_b").await; - let client_c = server.create_client(cx_c, "user_c").await; - - let channel_id = server - .make_channel( - "the-channel", - None, - (&client_a, cx_a), - &mut [(&client_b, cx_b), (&client_c, cx_c)], - ) - .await; - - let channel_chat_a = client_a - .channel_store() - .update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx)) - .await - .unwrap(); - let channel_chat_b = client_b - .channel_store() - .update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx)) - .await - .unwrap(); - - // Client A sends some messages. - channel_chat_a - .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap()) - .await - .unwrap(); - let msg_id_2 = channel_chat_a - .update(cx_a, |c, cx| { - c.send_message( - MessageParams { - text: "two @user_b".to_string(), - mentions: vec![(4..12, client_b.id())], - reply_to_message_id: None, - }, - cx, - ) - .unwrap() - }) - .await - .unwrap(); - channel_chat_a - .update(cx_a, |c, cx| c.send_message("three".into(), cx).unwrap()) - .await - .unwrap(); - - // Clients A and B see all of the messages. - executor.run_until_parked(); - let expected_messages = &["one", "two @user_b", "three"]; - assert_messages(&channel_chat_a, expected_messages, cx_a); - assert_messages(&channel_chat_b, expected_messages, cx_b); - - // Ensure that client B received a notification for the mention. 
- client_b.notification_store().read_with(cx_b, |store, _| { - assert_eq!(store.notification_count(), 2); - let entry = store.notification_at(0).unwrap(); - assert_eq!( - entry.notification, - Notification::ChannelMessageMention { - message_id: msg_id_2, - sender_id: client_a.id(), - channel_id: channel_id.0, - } - ); - }); - - // Client A deletes one of their messages. - channel_chat_a - .update(cx_a, |c, cx| { - let ChannelMessageId::Saved(id) = c.message(1).id else { - panic!("message not saved") - }; - c.remove_message(id, cx) - }) - .await - .unwrap(); - - // Client B sees that the message is gone. - executor.run_until_parked(); - let expected_messages = &["one", "three"]; - assert_messages(&channel_chat_a, expected_messages, cx_a); - assert_messages(&channel_chat_b, expected_messages, cx_b); - - // Client C joins the channel chat, and does not see the deleted message. - let channel_chat_c = client_c - .channel_store() - .update(cx_c, |store, cx| store.open_channel_chat(channel_id, cx)) - .await - .unwrap(); - assert_messages(&channel_chat_c, expected_messages, cx_c); - - // Ensure we remove the notifications when the message is removed - client_b.notification_store().read_with(cx_b, |store, _| { - // First notification is the channel invitation, second would be the mention - // notification, which should now be removed. 
- assert_eq!(store.notification_count(), 1); - }); -} - -#[track_caller] -fn assert_messages(chat: &Entity, messages: &[&str], cx: &mut TestAppContext) { - assert_eq!( - chat.read_with(cx, |chat, _| { - chat.messages() - .iter() - .map(|m| m.body.clone()) - .collect::>() - }), - messages - ); -} - -#[gpui::test] -async fn test_channel_message_changes( - executor: BackgroundExecutor, - cx_a: &mut TestAppContext, - cx_b: &mut TestAppContext, -) { - let mut server = TestServer::start(executor.clone()).await; - let client_a = server.create_client(cx_a, "user_a").await; - let client_b = server.create_client(cx_b, "user_b").await; - - let channel_id = server - .make_channel( - "the-channel", - None, - (&client_a, cx_a), - &mut [(&client_b, cx_b)], - ) - .await; - - // Client A sends a message, client B should see that there is a new message. - let channel_chat_a = client_a - .channel_store() - .update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx)) - .await - .unwrap(); - - channel_chat_a - .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap()) - .await - .unwrap(); - - executor.run_until_parked(); - - let b_has_messages = cx_b.update(|cx| { - client_b - .channel_store() - .read(cx) - .has_new_messages(channel_id) - }); - - assert!(b_has_messages); - - // Opening the chat should clear the changed flag. 
- cx_b.update(|cx| { - collab_ui::init(&client_b.app_state, cx); - }); - let project_b = client_b.build_empty_local_project(cx_b); - let (workspace_b, cx_b) = client_b.build_workspace(&project_b, cx_b); - - let chat_panel_b = workspace_b.update_in(cx_b, ChatPanel::new); - chat_panel_b - .update_in(cx_b, |chat_panel, window, cx| { - chat_panel.set_active(true, window, cx); - chat_panel.select_channel(channel_id, None, cx) - }) - .await - .unwrap(); - - executor.run_until_parked(); - - let b_has_messages = cx_b.update(|_, cx| { - client_b - .channel_store() - .read(cx) - .has_new_messages(channel_id) - }); - - assert!(!b_has_messages); - - // Sending a message while the chat is open should not change the flag. - channel_chat_a - .update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap()) - .await - .unwrap(); - - executor.run_until_parked(); - - let b_has_messages = cx_b.update(|_, cx| { - client_b - .channel_store() - .read(cx) - .has_new_messages(channel_id) - }); - - assert!(!b_has_messages); - - // Sending a message while the chat is closed should change the flag. - chat_panel_b.update_in(cx_b, |chat_panel, window, cx| { - chat_panel.set_active(false, window, cx); - }); - - // Sending a message while the chat is open should not change the flag. 
- channel_chat_a - .update(cx_a, |c, cx| c.send_message("three".into(), cx).unwrap()) - .await - .unwrap(); - - executor.run_until_parked(); - - let b_has_messages = cx_b.update(|_, cx| { - client_b - .channel_store() - .read(cx) - .has_new_messages(channel_id) - }); - - assert!(b_has_messages); - - // Closing the chat should re-enable change tracking - cx_b.update(|_, _| drop(chat_panel_b)); - - channel_chat_a - .update(cx_a, |c, cx| c.send_message("four".into(), cx).unwrap()) - .await - .unwrap(); - - executor.run_until_parked(); - - let b_has_messages = cx_b.update(|_, cx| { - client_b - .channel_store() - .read(cx) - .has_new_messages(channel_id) - }); - - assert!(b_has_messages); -} - -#[gpui::test] -async fn test_chat_replies(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { - let mut server = TestServer::start(cx_a.executor()).await; - let client_a = server.create_client(cx_a, "user_a").await; - let client_b = server.create_client(cx_b, "user_b").await; - - let channel_id = server - .make_channel( - "the-channel", - None, - (&client_a, cx_a), - &mut [(&client_b, cx_b)], - ) - .await; - - // Client A sends a message, client B should see that there is a new message. 
- let channel_chat_a = client_a - .channel_store() - .update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx)) - .await - .unwrap(); - - let channel_chat_b = client_b - .channel_store() - .update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx)) - .await - .unwrap(); - - let msg_id = channel_chat_a - .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap()) - .await - .unwrap(); - - cx_a.run_until_parked(); - - let reply_id = channel_chat_b - .update(cx_b, |c, cx| { - c.send_message( - MessageParams { - text: "reply".into(), - reply_to_message_id: Some(msg_id), - mentions: Vec::new(), - }, - cx, - ) - .unwrap() - }) - .await - .unwrap(); - - cx_a.run_until_parked(); - - channel_chat_a.update(cx_a, |channel_chat, _| { - assert_eq!( - channel_chat - .find_loaded_message(reply_id) - .unwrap() - .reply_to_message_id, - Some(msg_id), - ) - }); -} - -#[gpui::test] -async fn test_chat_editing(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { - let mut server = TestServer::start(cx_a.executor()).await; - let client_a = server.create_client(cx_a, "user_a").await; - let client_b = server.create_client(cx_b, "user_b").await; - - let channel_id = server - .make_channel( - "the-channel", - None, - (&client_a, cx_a), - &mut [(&client_b, cx_b)], - ) - .await; - - // Client A sends a message, client B should see that there is a new message. 
- let channel_chat_a = client_a - .channel_store() - .update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx)) - .await - .unwrap(); - - let channel_chat_b = client_b - .channel_store() - .update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx)) - .await - .unwrap(); - - let msg_id = channel_chat_a - .update(cx_a, |c, cx| { - c.send_message( - MessageParams { - text: "Initial message".into(), - reply_to_message_id: None, - mentions: Vec::new(), - }, - cx, - ) - .unwrap() - }) - .await - .unwrap(); - - cx_a.run_until_parked(); - - channel_chat_a - .update(cx_a, |c, cx| { - c.update_message( - msg_id, - MessageParams { - text: "Updated body".into(), - reply_to_message_id: None, - mentions: Vec::new(), - }, - cx, - ) - .unwrap() - }) - .await - .unwrap(); - - cx_a.run_until_parked(); - cx_b.run_until_parked(); - - channel_chat_a.update(cx_a, |channel_chat, _| { - let update_message = channel_chat.find_loaded_message(msg_id).unwrap(); - - assert_eq!(update_message.body, "Updated body"); - assert_eq!(update_message.mentions, Vec::new()); - }); - channel_chat_b.update(cx_b, |channel_chat, _| { - let update_message = channel_chat.find_loaded_message(msg_id).unwrap(); - - assert_eq!(update_message.body, "Updated body"); - assert_eq!(update_message.mentions, Vec::new()); - }); - - // test mentions are updated correctly - - client_b.notification_store().read_with(cx_b, |store, _| { - assert_eq!(store.notification_count(), 1); - let entry = store.notification_at(0).unwrap(); - assert!(matches!( - entry.notification, - Notification::ChannelInvitation { .. 
} - ),); - }); - - channel_chat_a - .update(cx_a, |c, cx| { - c.update_message( - msg_id, - MessageParams { - text: "Updated body including a mention for @user_b".into(), - reply_to_message_id: None, - mentions: vec![(37..45, client_b.id())], - }, - cx, - ) - .unwrap() - }) - .await - .unwrap(); - - cx_a.run_until_parked(); - cx_b.run_until_parked(); - - channel_chat_a.update(cx_a, |channel_chat, _| { - assert_eq!( - channel_chat.find_loaded_message(msg_id).unwrap().body, - "Updated body including a mention for @user_b", - ) - }); - channel_chat_b.update(cx_b, |channel_chat, _| { - assert_eq!( - channel_chat.find_loaded_message(msg_id).unwrap().body, - "Updated body including a mention for @user_b", - ) - }); - client_b.notification_store().read_with(cx_b, |store, _| { - assert_eq!(store.notification_count(), 2); - let entry = store.notification_at(0).unwrap(); - assert_eq!( - entry.notification, - Notification::ChannelMessageMention { - message_id: msg_id, - sender_id: client_a.id(), - channel_id: channel_id.0, - } - ); - }); - - // Test update message and keep the mention and check that the body is updated correctly - - channel_chat_a - .update(cx_a, |c, cx| { - c.update_message( - msg_id, - MessageParams { - text: "Updated body v2 including a mention for @user_b".into(), - reply_to_message_id: None, - mentions: vec![(37..45, client_b.id())], - }, - cx, - ) - .unwrap() - }) - .await - .unwrap(); - - cx_a.run_until_parked(); - cx_b.run_until_parked(); - - channel_chat_a.update(cx_a, |channel_chat, _| { - assert_eq!( - channel_chat.find_loaded_message(msg_id).unwrap().body, - "Updated body v2 including a mention for @user_b", - ) - }); - channel_chat_b.update(cx_b, |channel_chat, _| { - assert_eq!( - channel_chat.find_loaded_message(msg_id).unwrap().body, - "Updated body v2 including a mention for @user_b", - ) - }); - - client_b.notification_store().read_with(cx_b, |store, _| { - let message = store.channel_message_for_id(msg_id); - assert!(message.is_some()); - 
assert_eq!( - message.unwrap().body, - "Updated body v2 including a mention for @user_b" - ); - assert_eq!(store.notification_count(), 2); - let entry = store.notification_at(0).unwrap(); - assert_eq!( - entry.notification, - Notification::ChannelMessageMention { - message_id: msg_id, - sender_id: client_a.id(), - channel_id: channel_id.0, - } - ); - }); - - // If we remove a mention from a message the corresponding mention notification - // should also be removed. - - channel_chat_a - .update(cx_a, |c, cx| { - c.update_message( - msg_id, - MessageParams { - text: "Updated body without a mention".into(), - reply_to_message_id: None, - mentions: vec![], - }, - cx, - ) - .unwrap() - }) - .await - .unwrap(); - - cx_a.run_until_parked(); - cx_b.run_until_parked(); - - channel_chat_a.update(cx_a, |channel_chat, _| { - assert_eq!( - channel_chat.find_loaded_message(msg_id).unwrap().body, - "Updated body without a mention", - ) - }); - channel_chat_b.update(cx_b, |channel_chat, _| { - assert_eq!( - channel_chat.find_loaded_message(msg_id).unwrap().body, - "Updated body without a mention", - ) - }); - client_b.notification_store().read_with(cx_b, |store, _| { - // First notification is the channel invitation, second would be the mention - // notification, which should now be removed. 
- assert_eq!(store.notification_count(), 1); - }); -} diff --git a/crates/collab/src/tests/editor_tests.rs b/crates/collab/src/tests/editor_tests.rs index 7b95fdd45803554492e2541a83e1ede526e11753..a3f63c527693a19bb7ac1cd87c104cee3d5cfa6e 100644 --- a/crates/collab/src/tests/editor_tests.rs +++ b/crates/collab/src/tests/editor_tests.rs @@ -15,13 +15,14 @@ use editor::{ }, }; use fs::Fs; -use futures::{StreamExt, lock::Mutex}; +use futures::{SinkExt, StreamExt, channel::mpsc, lock::Mutex}; use gpui::{App, Rgba, TestAppContext, UpdateGlobal, VisualContext, VisualTestContext}; use indoc::indoc; use language::{ FakeLspAdapter, language_settings::{AllLanguageSettings, InlayHintSettings}, }; +use lsp::LSP_REQUEST_TIMEOUT; use project::{ ProjectPath, SERVER_PROGRESS_THROTTLE_TIMEOUT, lsp_store::lsp_ext_command::{ExpandedMacro, LspExtExpandMacro}, @@ -368,7 +369,7 @@ async fn test_collaborating_with_completion(cx_a: &mut TestAppContext, cx_b: &mu .set_request_handler::(|params, _| async move { assert_eq!( params.text_document_position.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); assert_eq!( params.text_document_position.position, @@ -487,7 +488,7 @@ async fn test_collaborating_with_completion(cx_a: &mut TestAppContext, cx_b: &mu .set_request_handler::(|params, _| async move { assert_eq!( params.text_document_position.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); assert_eq!( params.text_document_position.position, @@ -614,7 +615,7 @@ async fn test_collaborating_with_code_actions( .set_request_handler::(|params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); assert_eq!(params.range.start, lsp::Position::new(0, 0)); assert_eq!(params.range.end, lsp::Position::new(0, 0)); 
@@ -636,7 +637,7 @@ async fn test_collaborating_with_code_actions( .set_request_handler::(|params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); assert_eq!(params.range.start, lsp::Position::new(1, 31)); assert_eq!(params.range.end, lsp::Position::new(1, 31)); @@ -648,7 +649,7 @@ async fn test_collaborating_with_code_actions( changes: Some( [ ( - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), vec![lsp::TextEdit::new( lsp::Range::new( lsp::Position::new(1, 22), @@ -658,7 +659,7 @@ async fn test_collaborating_with_code_actions( )], ), ( - lsp::Url::from_file_path(path!("/a/other.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/other.rs")).unwrap(), vec![lsp::TextEdit::new( lsp::Range::new( lsp::Position::new(0, 0), @@ -720,7 +721,7 @@ async fn test_collaborating_with_code_actions( changes: Some( [ ( - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), vec![lsp::TextEdit::new( lsp::Range::new( lsp::Position::new(1, 22), @@ -730,7 +731,7 @@ async fn test_collaborating_with_code_actions( )], ), ( - lsp::Url::from_file_path(path!("/a/other.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/other.rs")).unwrap(), vec![lsp::TextEdit::new( lsp::Range::new( lsp::Position::new(0, 0), @@ -948,14 +949,14 @@ async fn test_collaborating_with_renames(cx_a: &mut TestAppContext, cx_b: &mut T changes: Some( [ ( - lsp::Url::from_file_path(path!("/dir/one.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/dir/one.rs")).unwrap(), vec![lsp::TextEdit::new( lsp::Range::new(lsp::Position::new(0, 6), lsp::Position::new(0, 9)), "THREE".to_string(), )], ), ( - lsp::Url::from_file_path(path!("/dir/two.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/dir/two.rs")).unwrap(), vec![ lsp::TextEdit::new( lsp::Range::new( @@ -1017,6 +1018,211 
@@ async fn test_collaborating_with_renames(cx_a: &mut TestAppContext, cx_b: &mut T }) } +#[gpui::test] +async fn test_slow_lsp_server(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { + let mut server = TestServer::start(cx_a.executor()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + cx_b.update(editor::init); + + let command_name = "test_command"; + let capabilities = lsp::ServerCapabilities { + code_lens_provider: Some(lsp::CodeLensOptions { + resolve_provider: None, + }), + execute_command_provider: Some(lsp::ExecuteCommandOptions { + commands: vec![command_name.to_string()], + ..lsp::ExecuteCommandOptions::default() + }), + ..lsp::ServerCapabilities::default() + }; + client_a.language_registry().add(rust_lang()); + let mut fake_language_servers = client_a.language_registry().register_fake_lsp( + "Rust", + FakeLspAdapter { + capabilities: capabilities.clone(), + ..FakeLspAdapter::default() + }, + ); + client_b.language_registry().add(rust_lang()); + client_b.language_registry().register_fake_lsp_adapter( + "Rust", + FakeLspAdapter { + capabilities, + ..FakeLspAdapter::default() + }, + ); + + client_a + .fs() + .insert_tree( + path!("/dir"), + json!({ + "one.rs": "const ONE: usize = 1;" + }), + ) + .await; + let (project_a, worktree_id) = client_a.build_local_project(path!("/dir"), cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.join_remote_project(project_id, cx_b).await; + + let (workspace_b, cx_b) = client_b.build_workspace(&project_b, cx_b); + let editor_b = workspace_b + .update_in(cx_b, |workspace, window, cx| { + workspace.open_path((worktree_id, "one.rs"), None, true, window, cx) + }) + .await + .unwrap() + .downcast::() 
+ .unwrap(); + let (lsp_store_b, buffer_b) = editor_b.update(cx_b, |editor, cx| { + let lsp_store = editor.project().unwrap().read(cx).lsp_store(); + let buffer = editor.buffer().read(cx).as_singleton().unwrap(); + (lsp_store, buffer) + }); + let fake_language_server = fake_language_servers.next().await.unwrap(); + cx_a.run_until_parked(); + cx_b.run_until_parked(); + + let long_request_time = LSP_REQUEST_TIMEOUT / 2; + let (request_started_tx, mut request_started_rx) = mpsc::unbounded(); + let requests_started = Arc::new(AtomicUsize::new(0)); + let requests_completed = Arc::new(AtomicUsize::new(0)); + let _lens_requests = fake_language_server + .set_request_handler::({ + let request_started_tx = request_started_tx.clone(); + let requests_started = requests_started.clone(); + let requests_completed = requests_completed.clone(); + move |params, cx| { + let mut request_started_tx = request_started_tx.clone(); + let requests_started = requests_started.clone(); + let requests_completed = requests_completed.clone(); + async move { + assert_eq!( + params.text_document.uri.as_str(), + uri!("file:///dir/one.rs") + ); + requests_started.fetch_add(1, atomic::Ordering::Release); + request_started_tx.send(()).await.unwrap(); + cx.background_executor().timer(long_request_time).await; + let i = requests_completed.fetch_add(1, atomic::Ordering::Release) + 1; + Ok(Some(vec![lsp::CodeLens { + range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 9)), + command: Some(lsp::Command { + title: format!("LSP Command {i}"), + command: command_name.to_string(), + arguments: None, + }), + data: None, + }])) + } + } + }); + + // Move cursor to a location, this should trigger the code lens call. 
+ editor_b.update_in(cx_b, |editor, window, cx| { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_ranges([7..7]) + }); + }); + let () = request_started_rx.next().await.unwrap(); + assert_eq!( + requests_started.load(atomic::Ordering::Acquire), + 1, + "Selection change should have initiated the first request" + ); + assert_eq!( + requests_completed.load(atomic::Ordering::Acquire), + 0, + "Slow requests should be running still" + ); + let _first_task = lsp_store_b.update(cx_b, |lsp_store, cx| { + lsp_store + .forget_code_lens_task(buffer_b.read(cx).remote_id()) + .expect("Should have the fetch task started") + }); + + editor_b.update_in(cx_b, |editor, window, cx| { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_ranges([1..1]) + }); + }); + let () = request_started_rx.next().await.unwrap(); + assert_eq!( + requests_started.load(atomic::Ordering::Acquire), + 2, + "Selection change should have initiated the second request" + ); + assert_eq!( + requests_completed.load(atomic::Ordering::Acquire), + 0, + "Slow requests should be running still" + ); + let _second_task = lsp_store_b.update(cx_b, |lsp_store, cx| { + lsp_store + .forget_code_lens_task(buffer_b.read(cx).remote_id()) + .expect("Should have the fetch task started for the 2nd time") + }); + + editor_b.update_in(cx_b, |editor, window, cx| { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_ranges([2..2]) + }); + }); + let () = request_started_rx.next().await.unwrap(); + assert_eq!( + requests_started.load(atomic::Ordering::Acquire), + 3, + "Selection change should have initiated the third request" + ); + assert_eq!( + requests_completed.load(atomic::Ordering::Acquire), + 0, + "Slow requests should be running still" + ); + + _first_task.await.unwrap(); + _second_task.await.unwrap(); + cx_b.run_until_parked(); + assert_eq!( + requests_started.load(atomic::Ordering::Acquire), + 3, + "No selection changes 
should trigger no more code lens requests" + ); + assert_eq!( + requests_completed.load(atomic::Ordering::Acquire), + 3, + "After enough time, all 3 LSP requests should have been served by the language server" + ); + let resulting_lens_actions = editor_b + .update(cx_b, |editor, cx| { + let lsp_store = editor.project().unwrap().read(cx).lsp_store(); + lsp_store.update(cx, |lsp_store, cx| { + lsp_store.code_lens_actions(&buffer_b, cx) + }) + }) + .await + .unwrap() + .unwrap(); + assert_eq!( + resulting_lens_actions.len(), + 1, + "Should have fetched one code lens action, but got: {resulting_lens_actions:?}" + ); + assert_eq!( + resulting_lens_actions.first().unwrap().lsp_action.title(), + "LSP Command 3", + "Only the final code lens action should be in the data" + ) +} + #[gpui::test(iterations = 10)] async fn test_language_server_statuses(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { let mut server = TestServer::start(cx_a.executor()).await; @@ -1368,7 +1574,7 @@ async fn test_on_input_format_from_host_to_guest( |params, _| async move { assert_eq!( params.text_document_position.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); assert_eq!( params.text_document_position.position, @@ -1511,7 +1717,7 @@ async fn test_on_input_format_from_guest_to_host( .set_request_handler::(|params, _| async move { assert_eq!( params.text_document_position.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); assert_eq!( params.text_document_position.position, @@ -1695,7 +1901,7 @@ async fn test_mutual_editor_inlay_hint_cache_update( async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); let edits_made = task_edits_made.load(atomic::Ordering::Acquire); Ok(Some(vec![lsp::InlayHint { @@ -1945,7 
+2151,7 @@ async fn test_inlay_hint_refresh_is_forwarded( async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); let other_hints = task_other_hints.load(atomic::Ordering::Acquire); let character = if other_hints { 0 } else { 2 }; @@ -2126,7 +2332,7 @@ async fn test_lsp_document_color(cx_a: &mut TestAppContext, cx_b: &mut TestAppCo async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); requests_made.fetch_add(1, atomic::Ordering::Release); Ok(vec![lsp::ColorInformation { @@ -2415,11 +2621,11 @@ async fn test_lsp_pull_diagnostics( let requests_made = closure_diagnostics_pulls_made.clone(); let diagnostics_pulls_result_ids = closure_diagnostics_pulls_result_ids.clone(); async move { - let message = if lsp::Url::from_file_path(path!("/a/main.rs")).unwrap() + let message = if lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap() == params.text_document.uri { expected_pull_diagnostic_main_message.to_string() - } else if lsp::Url::from_file_path(path!("/a/lib.rs")).unwrap() + } else if lsp::Uri::from_file_path(path!("/a/lib.rs")).unwrap() == params.text_document.uri { expected_pull_diagnostic_lib_message.to_string() @@ -2511,7 +2717,7 @@ async fn test_lsp_pull_diagnostics( items: vec![ lsp::WorkspaceDocumentDiagnosticReport::Full( lsp::WorkspaceFullDocumentDiagnosticReport { - uri: lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), version: None, full_document_diagnostic_report: lsp::FullDocumentDiagnosticReport { @@ -2540,7 +2746,7 @@ async fn test_lsp_pull_diagnostics( ), lsp::WorkspaceDocumentDiagnosticReport::Full( lsp::WorkspaceFullDocumentDiagnosticReport { - uri: lsp::Url::from_file_path(path!("/a/lib.rs")).unwrap(), + uri: 
lsp::Uri::from_file_path(path!("/a/lib.rs")).unwrap(), version: None, full_document_diagnostic_report: lsp::FullDocumentDiagnosticReport { @@ -2615,7 +2821,7 @@ async fn test_lsp_pull_diagnostics( fake_language_server.notify::( &lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), diagnostics: vec![lsp::Diagnostic { range: lsp::Range { start: lsp::Position { @@ -2636,7 +2842,7 @@ async fn test_lsp_pull_diagnostics( ); fake_language_server.notify::( &lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/a/lib.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/a/lib.rs")).unwrap(), diagnostics: vec![lsp::Diagnostic { range: lsp::Range { start: lsp::Position { @@ -2664,7 +2870,7 @@ async fn test_lsp_pull_diagnostics( items: vec![ lsp::WorkspaceDocumentDiagnosticReport::Full( lsp::WorkspaceFullDocumentDiagnosticReport { - uri: lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), version: None, full_document_diagnostic_report: lsp::FullDocumentDiagnosticReport { @@ -2696,7 +2902,7 @@ async fn test_lsp_pull_diagnostics( ), lsp::WorkspaceDocumentDiagnosticReport::Full( lsp::WorkspaceFullDocumentDiagnosticReport { - uri: lsp::Url::from_file_path(path!("/a/lib.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/a/lib.rs")).unwrap(), version: None, full_document_diagnostic_report: lsp::FullDocumentDiagnosticReport { @@ -2845,7 +3051,7 @@ async fn test_lsp_pull_diagnostics( lsp::WorkspaceDiagnosticReportResult::Report(lsp::WorkspaceDiagnosticReport { items: vec![lsp::WorkspaceDocumentDiagnosticReport::Full( lsp::WorkspaceFullDocumentDiagnosticReport { - uri: lsp::Url::from_file_path(path!("/a/lib.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/a/lib.rs")).unwrap(), version: None, full_document_diagnostic_report: lsp::FullDocumentDiagnosticReport { result_id: Some(format!( 
@@ -2908,7 +3114,7 @@ async fn test_lsp_pull_diagnostics( { assert!( - diagnostics_pulls_result_ids.lock().await.len() > 0, + !diagnostics_pulls_result_ids.lock().await.is_empty(), "Initial diagnostics pulls should report None at least" ); assert_eq!( @@ -3219,16 +3425,16 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA assert_eq!( entries, vec![ - Some(blame_entry("1b1b1b", 0..1)), - Some(blame_entry("0d0d0d", 1..2)), - Some(blame_entry("3a3a3a", 2..3)), - Some(blame_entry("4c4c4c", 3..4)), + Some((buffer_id_b, blame_entry("1b1b1b", 0..1))), + Some((buffer_id_b, blame_entry("0d0d0d", 1..2))), + Some((buffer_id_b, blame_entry("3a3a3a", 2..3))), + Some((buffer_id_b, blame_entry("4c4c4c", 3..4))), ] ); blame.update(cx, |blame, _| { - for (idx, entry) in entries.iter().flatten().enumerate() { - let details = blame.details_for_entry(entry).unwrap(); + for (idx, (buffer, entry)) in entries.iter().flatten().enumerate() { + let details = blame.details_for_entry(*buffer, entry).unwrap(); assert_eq!(details.message, format!("message for idx-{}", idx)); assert_eq!( details.permalink.unwrap().to_string(), @@ -3268,9 +3474,9 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA entries, vec![ None, - Some(blame_entry("0d0d0d", 1..2)), - Some(blame_entry("3a3a3a", 2..3)), - Some(blame_entry("4c4c4c", 3..4)), + Some((buffer_id_b, blame_entry("0d0d0d", 1..2))), + Some((buffer_id_b, blame_entry("3a3a3a", 2..3))), + Some((buffer_id_b, blame_entry("4c4c4c", 3..4))), ] ); }); @@ -3305,8 +3511,8 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA vec![ None, None, - Some(blame_entry("3a3a3a", 2..3)), - Some(blame_entry("4c4c4c", 3..4)), + Some((buffer_id_b, blame_entry("3a3a3a", 2..3))), + Some((buffer_id_b, blame_entry("4c4c4c", 3..4))), ] ); }); @@ -3593,7 +3799,7 @@ async fn test_add_breakpoints(cx_a: &mut TestAppContext, cx_b: &mut TestAppConte let abs_path = project_a.read_with(cx_a, 
|project, cx| { project .absolute_path(&project_path, cx) - .map(|path_buf| Arc::from(path_buf.to_owned())) + .map(Arc::from) .unwrap() }); @@ -3647,20 +3853,16 @@ async fn test_add_breakpoints(cx_a: &mut TestAppContext, cx_b: &mut TestAppConte let breakpoints_a = editor_a.update(cx_a, |editor, cx| { editor .breakpoint_store() - .clone() .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); let breakpoints_b = editor_b.update(cx_b, |editor, cx| { editor .breakpoint_store() - .clone() .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); assert_eq!(1, breakpoints_a.len()); @@ -3680,20 +3882,16 @@ async fn test_add_breakpoints(cx_a: &mut TestAppContext, cx_b: &mut TestAppConte let breakpoints_a = editor_a.update(cx_a, |editor, cx| { editor .breakpoint_store() - .clone() .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); let breakpoints_b = editor_b.update(cx_b, |editor, cx| { editor .breakpoint_store() - .clone() .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); assert_eq!(1, breakpoints_a.len()); @@ -3713,20 +3911,16 @@ async fn test_add_breakpoints(cx_a: &mut TestAppContext, cx_b: &mut TestAppConte let breakpoints_a = editor_a.update(cx_a, |editor, cx| { editor .breakpoint_store() - .clone() .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); let breakpoints_b = editor_b.update(cx_b, |editor, cx| { editor .breakpoint_store() - .clone() .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); assert_eq!(1, breakpoints_a.len()); @@ -3746,20 +3940,16 @@ async fn test_add_breakpoints(cx_a: &mut TestAppContext, cx_b: &mut TestAppConte let breakpoints_a = editor_a.update(cx_a, |editor, cx| { editor .breakpoint_store() - .clone() .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); let breakpoints_b = editor_b.update(cx_b, |editor, cx| { editor .breakpoint_store() - .clone() .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); assert_eq!(0, breakpoints_a.len()); @@ -3850,7 +4040,7 @@ async fn 
test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes |params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); assert_eq!(params.position, lsp::Position::new(0, 0)); Ok(Some(ExpandedMacro { @@ -3885,7 +4075,7 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes |params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); assert_eq!( params.position, diff --git a/crates/collab/src/tests/following_tests.rs b/crates/collab/src/tests/following_tests.rs index d9fd8ffeb2a6c693c3570409070f7a0fbfe33ea2..0a9a69bfca9cdda3fc446ac48e9c63da5e75fe28 100644 --- a/crates/collab/src/tests/following_tests.rs +++ b/crates/collab/src/tests/following_tests.rs @@ -970,7 +970,7 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T // the follow. workspace_b.update_in(cx_b, |workspace, window, cx| { workspace.active_pane().update(cx, |pane, cx| { - pane.activate_prev_item(true, window, cx); + pane.activate_previous_item(&Default::default(), window, cx); }); }); executor.run_until_parked(); @@ -1073,7 +1073,7 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T // Client A cycles through some tabs. 
workspace_a.update_in(cx_a, |workspace, window, cx| { workspace.active_pane().update(cx, |pane, cx| { - pane.activate_prev_item(true, window, cx); + pane.activate_previous_item(&Default::default(), window, cx); }); }); executor.run_until_parked(); @@ -1117,7 +1117,7 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T workspace_a.update_in(cx_a, |workspace, window, cx| { workspace.active_pane().update(cx, |pane, cx| { - pane.activate_prev_item(true, window, cx); + pane.activate_previous_item(&Default::default(), window, cx); }); }); executor.run_until_parked(); @@ -1164,7 +1164,7 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T workspace_a.update_in(cx_a, |workspace, window, cx| { workspace.active_pane().update(cx, |pane, cx| { - pane.activate_prev_item(true, window, cx); + pane.activate_previous_item(&Default::default(), window, cx); }); }); executor.run_until_parked(); @@ -2098,7 +2098,7 @@ async fn test_following_after_replacement(cx_a: &mut TestAppContext, cx_b: &mut share_workspace(&workspace, cx_a).await.unwrap(); let buffer = workspace.update(cx_a, |workspace, cx| { workspace.project().update(cx, |project, cx| { - project.create_local_buffer(&sample_text(26, 5, 'a'), None, cx) + project.create_local_buffer(&sample_text(26, 5, 'a'), None, false, cx) }) }); let multibuffer = cx_a.new(|cx| { diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index 5a2c40b890cfe32510347c33a1257af4cbea0768..646dbfbd1575756e6955c0d60ae5af64a2760328 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -2506,7 +2506,7 @@ async fn test_propagate_saves_and_fs_changes( }); let new_buffer_a = project_a - .update(cx_a, |p, cx| p.create_buffer(cx)) + .update(cx_a, |p, cx| p.create_buffer(false, cx)) .await .unwrap(); @@ -3208,7 +3208,7 @@ async fn test_fs_operations( }) .await .unwrap() - .to_included() + 
.into_included() .unwrap(); worktree_a.read_with(cx_a, |worktree, _| { @@ -3237,7 +3237,7 @@ async fn test_fs_operations( }) .await .unwrap() - .to_included() + .into_included() .unwrap(); worktree_a.read_with(cx_a, |worktree, _| { @@ -3266,7 +3266,7 @@ async fn test_fs_operations( }) .await .unwrap() - .to_included() + .into_included() .unwrap(); worktree_a.read_with(cx_a, |worktree, _| { @@ -3295,7 +3295,7 @@ async fn test_fs_operations( }) .await .unwrap() - .to_included() + .into_included() .unwrap(); project_b @@ -3304,7 +3304,7 @@ async fn test_fs_operations( }) .await .unwrap() - .to_included() + .into_included() .unwrap(); project_b @@ -3313,7 +3313,7 @@ async fn test_fs_operations( }) .await .unwrap() - .to_included() + .into_included() .unwrap(); worktree_a.read_with(cx_a, |worktree, _| { @@ -4075,7 +4075,7 @@ async fn test_collaborating_with_diagnostics( .await; fake_language_server.notify::( &lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/a/a.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/a/a.rs")).unwrap(), version: None, diagnostics: vec![lsp::Diagnostic { severity: Some(lsp::DiagnosticSeverity::WARNING), @@ -4095,7 +4095,7 @@ async fn test_collaborating_with_diagnostics( .unwrap(); fake_language_server.notify::( &lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/a/a.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/a/a.rs")).unwrap(), version: None, diagnostics: vec![lsp::Diagnostic { severity: Some(lsp::DiagnosticSeverity::ERROR), @@ -4169,7 +4169,7 @@ async fn test_collaborating_with_diagnostics( // Simulate a language server reporting more errors for a file. 
fake_language_server.notify::( &lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/a/a.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/a/a.rs")).unwrap(), version: None, diagnostics: vec![ lsp::Diagnostic { @@ -4265,7 +4265,7 @@ async fn test_collaborating_with_diagnostics( // Simulate a language server reporting no errors for a file. fake_language_server.notify::( &lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/a/a.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/a/a.rs")).unwrap(), version: None, diagnostics: Vec::new(), }, @@ -4372,7 +4372,7 @@ async fn test_collaborating_with_lsp_progress_updates_and_diagnostics_ordering( for file_name in file_names { fake_language_server.notify::( &lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(Path::new(path!("/test")).join(file_name)).unwrap(), + uri: lsp::Uri::from_file_path(Path::new(path!("/test")).join(file_name)).unwrap(), version: None, diagnostics: vec![lsp::Diagnostic { severity: Some(lsp::DiagnosticSeverity::WARNING), @@ -4838,7 +4838,7 @@ async fn test_definition( |_, _| async move { Ok(Some(lsp::GotoDefinitionResponse::Scalar( lsp::Location::new( - lsp::Url::from_file_path(path!("/root/dir-2/b.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/root/dir-2/b.rs")).unwrap(), lsp::Range::new(lsp::Position::new(0, 6), lsp::Position::new(0, 9)), ), ))) @@ -4850,6 +4850,7 @@ async fn test_definition( let definitions_1 = project_b .update(cx_b, |p, cx| p.definitions(&buffer_b, 23, cx)) .await + .unwrap() .unwrap(); cx_b.read(|cx| { assert_eq!( @@ -4875,7 +4876,7 @@ async fn test_definition( |_, _| async move { Ok(Some(lsp::GotoDefinitionResponse::Scalar( lsp::Location::new( - lsp::Url::from_file_path(path!("/root/dir-2/b.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/root/dir-2/b.rs")).unwrap(), lsp::Range::new(lsp::Position::new(1, 6), lsp::Position::new(1, 11)), ), ))) @@ -4885,6 +4886,7 @@ async fn test_definition( let definitions_2 = 
project_b .update(cx_b, |p, cx| p.definitions(&buffer_b, 33, cx)) .await + .unwrap() .unwrap(); cx_b.read(|cx| { assert_eq!(definitions_2.len(), 1); @@ -4912,7 +4914,7 @@ async fn test_definition( ); Ok(Some(lsp::GotoDefinitionResponse::Scalar( lsp::Location::new( - lsp::Url::from_file_path(path!("/root/dir-2/c.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/root/dir-2/c.rs")).unwrap(), lsp::Range::new(lsp::Position::new(0, 5), lsp::Position::new(0, 7)), ), ))) @@ -4922,6 +4924,7 @@ async fn test_definition( let type_definitions = project_b .update(cx_b, |p, cx| p.type_definitions(&buffer_b, 7, cx)) .await + .unwrap() .unwrap(); cx_b.read(|cx| { assert_eq!( @@ -4970,7 +4973,7 @@ async fn test_references( "Rust", FakeLspAdapter { name: "my-fake-lsp-adapter", - capabilities: capabilities, + capabilities, ..FakeLspAdapter::default() }, ); @@ -5046,21 +5049,21 @@ async fn test_references( lsp_response_tx .unbounded_send(Ok(Some(vec![ lsp::Location { - uri: lsp::Url::from_file_path(path!("/root/dir-1/two.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/root/dir-1/two.rs")).unwrap(), range: lsp::Range::new(lsp::Position::new(0, 24), lsp::Position::new(0, 27)), }, lsp::Location { - uri: lsp::Url::from_file_path(path!("/root/dir-1/two.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/root/dir-1/two.rs")).unwrap(), range: lsp::Range::new(lsp::Position::new(0, 35), lsp::Position::new(0, 38)), }, lsp::Location { - uri: lsp::Url::from_file_path(path!("/root/dir-2/three.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/root/dir-2/three.rs")).unwrap(), range: lsp::Range::new(lsp::Position::new(0, 37), lsp::Position::new(0, 40)), }, ]))) .unwrap(); - let references = references.await.unwrap(); + let references = references.await.unwrap().unwrap(); executor.run_until_parked(); project_b.read_with(cx_b, |project, cx| { // User is informed that a request is no longer pending. 
@@ -5104,7 +5107,7 @@ async fn test_references( lsp_response_tx .unbounded_send(Err(anyhow!("can't find references"))) .unwrap(); - assert_eq!(references.await.unwrap(), []); + assert_eq!(references.await.unwrap().unwrap(), []); // User is informed that the request is no longer pending. executor.run_until_parked(); @@ -5505,7 +5508,8 @@ async fn test_lsp_hover( // Request hover information as the guest. let mut hovers = project_b .update(cx_b, |p, cx| p.hover(&buffer_b, 22, cx)) - .await; + .await + .unwrap(); assert_eq!( hovers.len(), 2, @@ -5621,7 +5625,7 @@ async fn test_project_symbols( lsp::SymbolInformation { name: "TWO".into(), location: lsp::Location { - uri: lsp::Url::from_file_path(path!("/code/crate-2/two.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/code/crate-2/two.rs")).unwrap(), range: lsp::Range::new(lsp::Position::new(0, 6), lsp::Position::new(0, 9)), }, kind: lsp::SymbolKind::CONSTANT, @@ -5733,7 +5737,7 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it( |_, _| async move { Ok(Some(lsp::GotoDefinitionResponse::Scalar( lsp::Location::new( - lsp::Url::from_file_path(path!("/root/b.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/root/b.rs")).unwrap(), lsp::Range::new(lsp::Position::new(0, 6), lsp::Position::new(0, 9)), ), ))) @@ -5742,7 +5746,7 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it( let definitions; let buffer_b2; - if rng.r#gen() { + if rng.random() { cx_a.run_until_parked(); cx_b.run_until_parked(); definitions = project_b.update(cx_b, |p, cx| p.definitions(&buffer_b1, 23, cx)); @@ -5764,7 +5768,7 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it( definitions = project_b.update(cx_b, |p, cx| p.definitions(&buffer_b1, 23, cx)); } - let definitions = definitions.await.unwrap(); + let definitions = definitions.await.unwrap().unwrap(); assert_eq!( definitions.len(), 1, diff --git a/crates/collab/src/tests/random_channel_buffer_tests.rs 
b/crates/collab/src/tests/random_channel_buffer_tests.rs index c283a9fcd1741ad62c15e4c514df0a81ffb42062..9451090af2198117ddb20241b99be5b208daa729 100644 --- a/crates/collab/src/tests/random_channel_buffer_tests.rs +++ b/crates/collab/src/tests/random_channel_buffer_tests.rs @@ -84,7 +84,7 @@ impl RandomizedTest for RandomChannelBufferTest { } loop { - match rng.gen_range(0..100_u32) { + match rng.random_range(0..100_u32) { 0..=29 => { let channel_name = client.channel_store().read_with(cx, |store, cx| { store.ordered_channels().find_map(|(_, channel)| { @@ -266,7 +266,7 @@ impl RandomizedTest for RandomChannelBufferTest { "client {user_id} has different text than client {prev_user_id} for channel {channel_name}", ); } else { - prev_text = Some((user_id, text.clone())); + prev_text = Some((user_id, text)); } // Assert that all clients and the server agree about who is present in the diff --git a/crates/collab/src/tests/random_project_collaboration_tests.rs b/crates/collab/src/tests/random_project_collaboration_tests.rs index 4d94d041b9b5ca5e6d0ed3bd1f54b5f86f224c56..326f64cb244b88a64728f4347e3cfc31a8c252bf 100644 --- a/crates/collab/src/tests/random_project_collaboration_tests.rs +++ b/crates/collab/src/tests/random_project_collaboration_tests.rs @@ -17,7 +17,7 @@ use project::{ DEFAULT_COMPLETION_CONTEXT, Project, ProjectPath, search::SearchQuery, search::SearchResult, }; use rand::{ - distributions::{Alphanumeric, DistString}, + distr::{self, SampleString}, prelude::*, }; use serde::{Deserialize, Serialize}; @@ -168,19 +168,19 @@ impl RandomizedTest for ProjectCollaborationTest { ) -> ClientOperation { let call = cx.read(ActiveCall::global); loop { - match rng.gen_range(0..100_u32) { + match rng.random_range(0..100_u32) { // Mutate the call 0..=29 => { // Respond to an incoming call if call.read_with(cx, |call, _| call.incoming().borrow().is_some()) { - break if rng.gen_bool(0.7) { + break if rng.random_bool(0.7) { ClientOperation::AcceptIncomingCall } else { 
ClientOperation::RejectIncomingCall }; } - match rng.gen_range(0..100_u32) { + match rng.random_range(0..100_u32) { // Invite a contact to the current call 0..=70 => { let available_contacts = @@ -212,7 +212,7 @@ impl RandomizedTest for ProjectCollaborationTest { } // Mutate projects - 30..=59 => match rng.gen_range(0..100_u32) { + 30..=59 => match rng.random_range(0..100_u32) { // Open a new project 0..=70 => { // Open a remote project @@ -270,7 +270,7 @@ impl RandomizedTest for ProjectCollaborationTest { } // Mutate project worktrees - 81.. => match rng.gen_range(0..100_u32) { + 81.. => match rng.random_range(0..100_u32) { // Add a worktree to a local project 0..=50 => { let Some(project) = client.local_projects().choose(rng).cloned() else { @@ -279,7 +279,7 @@ impl RandomizedTest for ProjectCollaborationTest { let project_root_name = root_name_for_project(&project, cx); let mut paths = client.fs().paths(false); paths.remove(0); - let new_root_path = if paths.is_empty() || rng.r#gen() { + let new_root_path = if paths.is_empty() || rng.random() { Path::new(path!("/")).join(plan.next_root_dir_name()) } else { paths.choose(rng).unwrap().clone() @@ -304,12 +304,12 @@ impl RandomizedTest for ProjectCollaborationTest { let worktree = worktree.read(cx); worktree.is_visible() && worktree.entries(false, 0).any(|e| e.is_file()) - && worktree.root_entry().map_or(false, |e| e.is_dir()) + && worktree.root_entry().is_some_and(|e| e.is_dir()) }) .choose(rng) }); let Some(worktree) = worktree else { continue }; - let is_dir = rng.r#gen::(); + let is_dir = rng.random::(); let mut full_path = worktree.read_with(cx, |w, _| PathBuf::from(w.root_name())); full_path.push(gen_file_name(rng)); @@ -334,7 +334,7 @@ impl RandomizedTest for ProjectCollaborationTest { let project_root_name = root_name_for_project(&project, cx); let is_local = project.read_with(cx, |project, _| project.is_local()); - match rng.gen_range(0..100_u32) { + match rng.random_range(0..100_u32) { // Manipulate an 
existing buffer 0..=70 => { let Some(buffer) = client @@ -349,7 +349,7 @@ impl RandomizedTest for ProjectCollaborationTest { let full_path = buffer .read_with(cx, |buffer, cx| buffer.file().unwrap().full_path(cx)); - match rng.gen_range(0..100_u32) { + match rng.random_range(0..100_u32) { // Close the buffer 0..=15 => { break ClientOperation::CloseBuffer { @@ -360,7 +360,7 @@ impl RandomizedTest for ProjectCollaborationTest { } // Save the buffer 16..=29 if buffer.read_with(cx, |b, _| b.is_dirty()) => { - let detach = rng.gen_bool(0.3); + let detach = rng.random_bool(0.3); break ClientOperation::SaveBuffer { project_root_name, is_local, @@ -383,17 +383,17 @@ impl RandomizedTest for ProjectCollaborationTest { _ => { let offset = buffer.read_with(cx, |buffer, _| { buffer.clip_offset( - rng.gen_range(0..=buffer.len()), + rng.random_range(0..=buffer.len()), language::Bias::Left, ) }); - let detach = rng.r#gen(); + let detach = rng.random(); break ClientOperation::RequestLspDataInBuffer { project_root_name, full_path, offset, is_local, - kind: match rng.gen_range(0..5_u32) { + kind: match rng.random_range(0..5_u32) { 0 => LspRequestKind::Rename, 1 => LspRequestKind::Highlights, 2 => LspRequestKind::Definition, @@ -407,8 +407,8 @@ impl RandomizedTest for ProjectCollaborationTest { } 71..=80 => { - let query = rng.gen_range('a'..='z').to_string(); - let detach = rng.gen_bool(0.3); + let query = rng.random_range('a'..='z').to_string(); + let detach = rng.random_bool(0.3); break ClientOperation::SearchProject { project_root_name, is_local, @@ -460,7 +460,7 @@ impl RandomizedTest for ProjectCollaborationTest { // Create or update a file or directory 96.. 
=> { - let is_dir = rng.r#gen::(); + let is_dir = rng.random::(); let content; let mut path; let dir_paths = client.fs().directories(false); @@ -470,11 +470,11 @@ impl RandomizedTest for ProjectCollaborationTest { path = dir_paths.choose(rng).unwrap().clone(); path.push(gen_file_name(rng)); } else { - content = Alphanumeric.sample_string(rng, 16); + content = distr::Alphanumeric.sample_string(rng, 16); // Create a new file or overwrite an existing file let file_paths = client.fs().files(); - if file_paths.is_empty() || rng.gen_bool(0.5) { + if file_paths.is_empty() || rng.random_bool(0.5) { path = dir_paths.choose(rng).unwrap().clone(); path.push(gen_file_name(rng)); path.set_extension("rs"); @@ -643,7 +643,7 @@ impl RandomizedTest for ProjectCollaborationTest { ); let project = project.await?; - client.dev_server_projects_mut().push(project.clone()); + client.dev_server_projects_mut().push(project); } ClientOperation::CreateWorktreeEntry { @@ -1090,7 +1090,7 @@ impl RandomizedTest for ProjectCollaborationTest { move |_, cx| { let background = cx.background_executor(); let mut rng = background.rng(); - let count = rng.gen_range::(1..3); + let count = rng.random_range::(1..3); let files = fs.as_fake().files(); let files = (0..count) .map(|_| files.choose(&mut rng).unwrap().clone()) @@ -1101,7 +1101,7 @@ impl RandomizedTest for ProjectCollaborationTest { files .into_iter() .map(|file| lsp::Location { - uri: lsp::Url::from_file_path(file).unwrap(), + uri: lsp::Uri::from_file_path(file).unwrap(), range: Default::default(), }) .collect(), @@ -1117,12 +1117,12 @@ impl RandomizedTest for ProjectCollaborationTest { let background = cx.background_executor(); let mut rng = background.rng(); - let highlight_count = rng.gen_range(1..=5); + let highlight_count = rng.random_range(1..=5); for _ in 0..highlight_count { - let start_row = rng.gen_range(0..100); - let start_column = rng.gen_range(0..100); - let end_row = rng.gen_range(0..100); - let end_column = 
rng.gen_range(0..100); + let start_row = rng.random_range(0..100); + let start_column = rng.random_range(0..100); + let end_row = rng.random_range(0..100); + let end_column = rng.random_range(0..100); let start = PointUtf16::new(start_row, start_column); let end = PointUtf16::new(end_row, end_column); let range = @@ -1162,8 +1162,8 @@ impl RandomizedTest for ProjectCollaborationTest { Some((project, cx)) }); - if !guest_project.is_disconnected(cx) { - if let Some((host_project, host_cx)) = host_project { + if !guest_project.is_disconnected(cx) + && let Some((host_project, host_cx)) = host_project { let host_worktree_snapshots = host_project.read_with(host_cx, |host_project, cx| { host_project @@ -1219,8 +1219,8 @@ impl RandomizedTest for ProjectCollaborationTest { guest_project.remote_id(), ); assert_eq!( - guest_snapshot.entries(false, 0).collect::>(), - host_snapshot.entries(false, 0).collect::>(), + guest_snapshot.entries(false, 0).map(null_out_entry_size).collect::>(), + host_snapshot.entries(false, 0).map(null_out_entry_size).collect::>(), "{} has different snapshot than the host for worktree {:?} ({:?}) and project {:?}", client.username, host_snapshot.abs_path(), @@ -1235,7 +1235,6 @@ impl RandomizedTest for ProjectCollaborationTest { ); } } - } for buffer in guest_project.opened_buffers(cx) { let buffer = buffer.read(cx); @@ -1249,6 +1248,18 @@ impl RandomizedTest for ProjectCollaborationTest { ); } }); + + // A hack to work around a hack in + // https://github.com/zed-industries/zed/pull/16696 that wasn't + // detected until we upgraded the rng crate. This whole crate is + // going away with DeltaDB soon, so we hold our nose and + // continue. 
+ fn null_out_entry_size(entry: &project::Entry) -> project::Entry { + project::Entry { + size: 0, + ..entry.clone() + } + } } let buffers = client.buffers().clone(); @@ -1423,7 +1434,7 @@ fn generate_git_operation(rng: &mut StdRng, client: &TestClient) -> GitOperation .filter(|path| path.starts_with(repo_path)) .collect::>(); - let count = rng.gen_range(0..=paths.len()); + let count = rng.random_range(0..=paths.len()); paths.shuffle(rng); paths.truncate(count); @@ -1435,13 +1446,13 @@ fn generate_git_operation(rng: &mut StdRng, client: &TestClient) -> GitOperation let repo_path = client.fs().directories(false).choose(rng).unwrap().clone(); - match rng.gen_range(0..100_u32) { + match rng.random_range(0..100_u32) { 0..=25 => { let file_paths = generate_file_paths(&repo_path, rng, client); let contents = file_paths .into_iter() - .map(|path| (path, Alphanumeric.sample_string(rng, 16))) + .map(|path| (path, distr::Alphanumeric.sample_string(rng, 16))) .collect(); GitOperation::WriteGitIndex { @@ -1450,7 +1461,8 @@ fn generate_git_operation(rng: &mut StdRng, client: &TestClient) -> GitOperation } } 26..=63 => { - let new_branch = (rng.gen_range(0..10) > 3).then(|| Alphanumeric.sample_string(rng, 8)); + let new_branch = + (rng.random_range(0..10) > 3).then(|| distr::Alphanumeric.sample_string(rng, 8)); GitOperation::WriteGitBranch { repo_path, @@ -1597,7 +1609,7 @@ fn choose_random_project(client: &TestClient, rng: &mut StdRng) -> Option String { let mut name = String::new(); for _ in 0..10 { - let letter = rng.gen_range('a'..='z'); + let letter = rng.random_range('a'..='z'); name.push(letter); } name @@ -1605,7 +1617,7 @@ fn gen_file_name(rng: &mut StdRng) -> String { fn gen_status(rng: &mut StdRng) -> FileStatus { fn gen_tracked_status(rng: &mut StdRng) -> TrackedStatus { - match rng.gen_range(0..3) { + match rng.random_range(0..3) { 0 => TrackedStatus { index_status: StatusCode::Unmodified, worktree_status: StatusCode::Unmodified, @@ -1627,7 +1639,7 @@ fn 
gen_status(rng: &mut StdRng) -> FileStatus { } fn gen_unmerged_status_code(rng: &mut StdRng) -> UnmergedStatusCode { - match rng.gen_range(0..3) { + match rng.random_range(0..3) { 0 => UnmergedStatusCode::Updated, 1 => UnmergedStatusCode::Added, 2 => UnmergedStatusCode::Deleted, @@ -1635,7 +1647,7 @@ fn gen_status(rng: &mut StdRng) -> FileStatus { } } - match rng.gen_range(0..2) { + match rng.random_range(0..2) { 0 => FileStatus::Unmerged(UnmergedStatus { first_head: gen_unmerged_status_code(rng), second_head: gen_unmerged_status_code(rng), diff --git a/crates/collab/src/tests/randomized_test_helpers.rs b/crates/collab/src/tests/randomized_test_helpers.rs index cabf10cfbcec2e13a322ed742745b410ba760fd9..9a372017e34f575f780d56f3936fefec832e160c 100644 --- a/crates/collab/src/tests/randomized_test_helpers.rs +++ b/crates/collab/src/tests/randomized_test_helpers.rs @@ -198,19 +198,19 @@ pub async fn run_randomized_test( } pub fn save_randomized_test_plan() { - if let Some(serialize_plan) = LAST_PLAN.lock().take() { - if let Some(path) = plan_save_path() { - eprintln!("saved test plan to path {:?}", path); - std::fs::write(path, serialize_plan()).unwrap(); - } + if let Some(serialize_plan) = LAST_PLAN.lock().take() + && let Some(path) = plan_save_path() + { + eprintln!("saved test plan to path {:?}", path); + std::fs::write(path, serialize_plan()).unwrap(); } } impl TestPlan { pub async fn new(server: &mut TestServer, mut rng: StdRng) -> Arc> { - let allow_server_restarts = rng.gen_bool(0.7); - let allow_client_reconnection = rng.gen_bool(0.7); - let allow_client_disconnection = rng.gen_bool(0.1); + let allow_server_restarts = rng.random_bool(0.7); + let allow_client_reconnection = rng.random_bool(0.7); + let allow_client_disconnection = rng.random_bool(0.1); let mut users = Vec::new(); for ix in 0..max_peers() { @@ -290,10 +290,9 @@ impl TestPlan { if let StoredOperation::Client { user_id, batch_id, .. 
} = operation + && batch_id == current_batch_id { - if batch_id == current_batch_id { - return Some(user_id); - } + return Some(user_id); } None })); @@ -366,10 +365,9 @@ impl TestPlan { }, applied, ) = stored_operation + && user_id == ¤t_user_id { - if user_id == ¤t_user_id { - return Some((operation.clone(), applied.clone())); - } + return Some((operation.clone(), applied.clone())); } } None @@ -409,7 +407,7 @@ impl TestPlan { } Some(loop { - break match self.rng.gen_range(0..100) { + break match self.rng.random_range(0..100) { 0..=29 if clients.len() < self.users.len() => { let user = self .users @@ -423,13 +421,13 @@ impl TestPlan { } } 30..=34 if clients.len() > 1 && self.allow_client_disconnection => { - let (client, cx) = &clients[self.rng.gen_range(0..clients.len())]; + let (client, cx) = &clients[self.rng.random_range(0..clients.len())]; let user_id = client.current_user_id(cx); self.operation_ix += 1; ServerOperation::RemoveConnection { user_id } } 35..=39 if clients.len() > 1 && self.allow_client_reconnection => { - let (client, cx) = &clients[self.rng.gen_range(0..clients.len())]; + let (client, cx) = &clients[self.rng.random_range(0..clients.len())]; let user_id = client.current_user_id(cx); self.operation_ix += 1; ServerOperation::BounceConnection { user_id } @@ -441,12 +439,12 @@ impl TestPlan { _ if !clients.is_empty() => { let count = self .rng - .gen_range(1..10) + .random_range(1..10) .min(self.max_operations - self.operation_ix); let batch_id = util::post_inc(&mut self.next_batch_id); let mut user_ids = (0..count) .map(|_| { - let ix = self.rng.gen_range(0..clients.len()); + let ix = self.rng.random_range(0..clients.len()); let (client, cx) = &clients[ix]; client.current_user_id(cx) }) @@ -455,7 +453,7 @@ impl TestPlan { ServerOperation::MutateClients { user_ids, batch_id, - quiesce: self.rng.gen_bool(0.7), + quiesce: self.rng.random_bool(0.7), } } _ => continue, @@ -550,11 +548,11 @@ impl TestPlan { .unwrap(); let pool = 
server.connection_pool.lock(); for contact in contacts { - if let db::Contact::Accepted { user_id, busy, .. } = contact { - if user_id == removed_user_id { - assert!(!pool.is_user_online(user_id)); - assert!(!busy); - } + if let db::Contact::Accepted { user_id, busy, .. } = contact + && user_id == removed_user_id + { + assert!(!pool.is_user_online(user_id)); + assert!(!busy); } } } diff --git a/crates/collab/src/tests/remote_editing_collaboration_tests.rs b/crates/collab/src/tests/remote_editing_collaboration_tests.rs index 8ab6e6910c88880bc8b6451d972e39b5c2315812..6b46459a59b16717d965b42c4e19820f6d1dc062 100644 --- a/crates/collab/src/tests/remote_editing_collaboration_tests.rs +++ b/crates/collab/src/tests/remote_editing_collaboration_tests.rs @@ -26,7 +26,7 @@ use project::{ debugger::session::ThreadId, lsp_store::{FormatTrigger, LspFormatTarget}, }; -use remote::SshRemoteClient; +use remote::RemoteClient; use remote_server::{HeadlessAppState, HeadlessProject}; use rpc::proto; use serde_json::json; @@ -59,7 +59,7 @@ async fn test_sharing_an_ssh_remote_project( .await; // Set up project on remote FS - let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx); + let (opts, server_ssh) = RemoteClient::fake_server(cx_a, server_cx); let remote_fs = FakeFs::new(server_cx.executor()); remote_fs .insert_tree( @@ -101,7 +101,7 @@ async fn test_sharing_an_ssh_remote_project( ) }); - let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await; + let client_ssh = RemoteClient::fake_client(opts, cx_a).await; let (project_a, worktree_id) = client_a .build_ssh_project(path!("/code/project1"), client_ssh, cx_a) .await; @@ -235,7 +235,7 @@ async fn test_ssh_collaboration_git_branches( .await; // Set up project on remote FS - let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx); + let (opts, server_ssh) = RemoteClient::fake_server(cx_a, server_cx); let remote_fs = FakeFs::new(server_cx.executor()); remote_fs .insert_tree("/project", 
serde_json::json!({ ".git":{} })) @@ -268,7 +268,7 @@ async fn test_ssh_collaboration_git_branches( ) }); - let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await; + let client_ssh = RemoteClient::fake_client(opts, cx_a).await; let (project_a, _) = client_a .build_ssh_project("/project", client_ssh, cx_a) .await; @@ -420,7 +420,7 @@ async fn test_ssh_collaboration_formatting_with_prettier( .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) .await; - let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx); + let (opts, server_ssh) = RemoteClient::fake_server(cx_a, server_cx); let remote_fs = FakeFs::new(server_cx.executor()); let buffer_text = "let one = \"two\""; let prettier_format_suffix = project::TEST_PRETTIER_FORMAT_SUFFIX; @@ -473,7 +473,7 @@ async fn test_ssh_collaboration_formatting_with_prettier( ) }); - let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await; + let client_ssh = RemoteClient::fake_client(opts, cx_a).await; let (project_a, worktree_id) = client_a .build_ssh_project(path!("/project"), client_ssh, cx_a) .await; @@ -602,7 +602,7 @@ async fn test_remote_server_debugger( release_channel::init(SemanticVersion::default(), cx); dap_adapters::init(cx); }); - let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx); + let (opts, server_ssh) = RemoteClient::fake_server(cx_a, server_cx); let remote_fs = FakeFs::new(server_cx.executor()); remote_fs .insert_tree( @@ -633,7 +633,7 @@ async fn test_remote_server_debugger( ) }); - let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await; + let client_ssh = RemoteClient::fake_client(opts, cx_a).await; let mut server = TestServer::start(server_cx.executor()).await; let client_a = server.create_client(cx_a, "user_a").await; cx_a.update(|cx| { @@ -711,7 +711,7 @@ async fn test_slow_adapter_startup_retries( release_channel::init(SemanticVersion::default(), cx); dap_adapters::init(cx); }); - let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, 
server_cx); + let (opts, server_ssh) = RemoteClient::fake_server(cx_a, server_cx); let remote_fs = FakeFs::new(server_cx.executor()); remote_fs .insert_tree( @@ -742,7 +742,7 @@ async fn test_slow_adapter_startup_retries( ) }); - let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await; + let client_ssh = RemoteClient::fake_client(opts, cx_a).await; let mut server = TestServer::start(server_cx.executor()).await; let client_a = server.create_client(cx_a, "user_a").await; cx_a.update(|cx| { diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs deleted file mode 100644 index bb84bedfcfc1fb4f95724f60bbd80707b12c215a..0000000000000000000000000000000000000000 --- a/crates/collab/src/tests/stripe_billing_tests.rs +++ /dev/null @@ -1,123 +0,0 @@ -use std::sync::Arc; - -use pretty_assertions::assert_eq; - -use crate::stripe_billing::StripeBilling; -use crate::stripe_client::{FakeStripeClient, StripePrice, StripePriceId, StripePriceRecurring}; - -fn make_stripe_billing() -> (StripeBilling, Arc) { - let stripe_client = Arc::new(FakeStripeClient::new()); - let stripe_billing = StripeBilling::test(stripe_client.clone()); - - (stripe_billing, stripe_client) -} - -#[gpui::test] -async fn test_initialize() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - // Add test prices - let price1 = StripePrice { - id: StripePriceId("price_1".into()), - unit_amount: Some(1_000), - lookup_key: Some("zed-pro".to_string()), - recurring: None, - }; - let price2 = StripePrice { - id: StripePriceId("price_2".into()), - unit_amount: Some(0), - lookup_key: Some("zed-free".to_string()), - recurring: None, - }; - let price3 = StripePrice { - id: StripePriceId("price_3".into()), - unit_amount: Some(500), - lookup_key: None, - recurring: Some(StripePriceRecurring { - meter: Some("meter_1".to_string()), - }), - }; - stripe_client - .prices - .lock() - .insert(price1.id.clone(), price1); - stripe_client - .prices - .lock() 
- .insert(price2.id.clone(), price2); - stripe_client - .prices - .lock() - .insert(price3.id.clone(), price3); - - // Initialize the billing system - stripe_billing.initialize().await.unwrap(); - - // Verify that prices can be found by lookup key - let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap(); - assert_eq!(zed_pro_price_id.to_string(), "price_1"); - - let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap(); - assert_eq!(zed_free_price_id.to_string(), "price_2"); - - // Verify that a price can be found by lookup key - let zed_pro_price = stripe_billing - .find_price_by_lookup_key("zed-pro") - .await - .unwrap(); - assert_eq!(zed_pro_price.id.to_string(), "price_1"); - assert_eq!(zed_pro_price.unit_amount, Some(1_000)); - - // Verify that finding a non-existent lookup key returns an error - let result = stripe_billing - .find_price_by_lookup_key("non-existent") - .await; - assert!(result.is_err()); -} - -#[gpui::test] -async fn test_find_or_create_customer_by_email() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - // Create a customer with an email that doesn't yet correspond to a customer. - { - let email = "user@example.com"; - - let customer_id = stripe_billing - .find_or_create_customer_by_email(Some(email)) - .await - .unwrap(); - - let customer = stripe_client - .customers - .lock() - .get(&customer_id) - .unwrap() - .clone(); - assert_eq!(customer.email.as_deref(), Some(email)); - } - - // Create a customer with an email that corresponds to an existing customer. 
- { - let email = "user2@example.com"; - - let existing_customer_id = stripe_billing - .find_or_create_customer_by_email(Some(email)) - .await - .unwrap(); - - let customer_id = stripe_billing - .find_or_create_customer_by_email(Some(email)) - .await - .unwrap(); - assert_eq!(customer_id, existing_customer_id); - - let customer = stripe_client - .customers - .lock() - .get(&customer_id) - .unwrap() - .clone(); - assert_eq!(customer.email.as_deref(), Some(email)); - } -} diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index f5a0e8ea81f0befbb3bae44ab516a7b8f4b04b52..eb7df28478158a10a0c2d52c3560cad391937383 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -1,4 +1,3 @@ -use crate::stripe_client::FakeStripeClient; use crate::{ AppState, Config, db::{NewUserParams, UserId, tests::TestDb}, @@ -27,7 +26,7 @@ use node_runtime::NodeRuntime; use notifications::NotificationStore; use parking_lot::Mutex; use project::{Project, WorktreeId}; -use remote::SshRemoteClient; +use remote::RemoteClient; use rpc::{ RECEIVE_TIMEOUT, proto::{self, ChannelRole}, @@ -371,8 +370,8 @@ impl TestServer { let client = TestClient { app_state, username: name.to_string(), - channel_store: cx.read(ChannelStore::global).clone(), - notification_store: cx.read(NotificationStore::global).clone(), + channel_store: cx.read(ChannelStore::global), + notification_store: cx.read(NotificationStore::global), state: Default::default(), }; client.wait_for_current_user(cx).await; @@ -566,12 +565,8 @@ impl TestServer { ) -> Arc { Arc::new(AppState { db: test_db.db().clone(), - llm_db: None, livekit_client: Some(Arc::new(livekit_test_server.create_api_client())), blob_store_client: None, - real_stripe_client: None, - stripe_client: Some(Arc::new(FakeStripeClient::new())), - stripe_billing: None, executor, kinesis_client: None, config: Config { @@ -608,7 +603,6 @@ impl TestServer { auto_join_channel_id: None, 
migrations_path: None, seed_path: None, - stripe_api_key: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, kinesis_region: None, @@ -771,11 +765,11 @@ impl TestClient { pub async fn build_ssh_project( &self, root_path: impl AsRef, - ssh: Entity, + ssh: Entity, cx: &mut TestAppContext, ) -> (Entity, WorktreeId) { let project = cx.update(|cx| { - Project::ssh( + Project::remote( ssh, self.client().clone(), self.app_state.node_runtime.clone(), @@ -903,7 +897,7 @@ impl TestClient { let window = cx.update(|cx| cx.active_window().unwrap().downcast::().unwrap()); let entity = window.root(cx).unwrap(); - let cx = VisualTestContext::from_window(*window.deref(), cx).as_mut(); + let cx = VisualTestContext::from_window(*window.deref(), cx).into_mut(); // it might be nice to try and cleanup these at the end of each test. (entity, cx) } diff --git a/crates/collab/src/user_backfiller.rs b/crates/collab/src/user_backfiller.rs index 71b99a3d4c62560597ce47de926816e741507e44..569a298c9cd5bca6c65e5b7a39b45a784635ad35 100644 --- a/crates/collab/src/user_backfiller.rs +++ b/crates/collab/src/user_backfiller.rs @@ -130,17 +130,17 @@ impl UserBackfiller { .and_then(|value| value.parse::().ok()) .and_then(|value| DateTime::from_timestamp(value, 0)); - if rate_limit_remaining == Some(0) { - if let Some(reset_at) = rate_limit_reset { - let now = Utc::now(); - if reset_at > now { - let sleep_duration = reset_at - now; - log::info!( - "rate limit reached. Sleeping for {} seconds", - sleep_duration.num_seconds() - ); - self.executor.sleep(sleep_duration.to_std().unwrap()).await; - } + if rate_limit_remaining == Some(0) + && let Some(reset_at) = rate_limit_reset + { + let now = Utc::now(); + if reset_at > now { + let sleep_duration = reset_at - now; + log::info!( + "rate limit reached. 
Sleeping for {} seconds", + sleep_duration.num_seconds() + ); + self.executor.sleep(sleep_duration.to_std().unwrap()).await; } } diff --git a/crates/collab_ui/Cargo.toml b/crates/collab_ui/Cargo.toml index 46ba3ae49639a77dc1e93d0422290fd333acb3ad..34e40d767ea5a9cab115b4186a642ee234337845 100644 --- a/crates/collab_ui/Cargo.toml +++ b/crates/collab_ui/Cargo.toml @@ -37,18 +37,15 @@ client.workspace = true collections.workspace = true db.workspace = true editor.workspace = true -emojis.workspace = true futures.workspace = true fuzzy.workspace = true gpui.workspace = true -language.workspace = true log.workspace = true menu.workspace = true notifications.workspace = true picker.workspace = true project.workspace = true release_channel.workspace = true -rich_text.workspace = true rpc.workspace = true schemars.workspace = true serde.workspace = true diff --git a/crates/collab_ui/src/channel_view.rs b/crates/collab_ui/src/channel_view.rs index b86d72d92faede8c52e40a8e209fde5bf1ea9f0b..61b3e05e48a9fe3da35957b05fcd7dbf7206f146 100644 --- a/crates/collab_ui/src/channel_view.rs +++ b/crates/collab_ui/src/channel_view.rs @@ -66,7 +66,7 @@ impl ChannelView { channel_id, link_position, pane.clone(), - workspace.clone(), + workspace, window, cx, ); @@ -107,43 +107,32 @@ impl ChannelView { .find(|view| view.read(cx).channel_buffer.read(cx).remote_id(cx) == buffer_id); // If this channel buffer is already open in this pane, just return it. 
- if let Some(existing_view) = existing_view.clone() { - if existing_view.read(cx).channel_buffer == channel_view.read(cx).channel_buffer - { - if let Some(link_position) = link_position { - existing_view.update(cx, |channel_view, cx| { - channel_view.focus_position_from_link( - link_position, - true, - window, - cx, - ) - }); - } - return existing_view; + if let Some(existing_view) = existing_view.clone() + && existing_view.read(cx).channel_buffer == channel_view.read(cx).channel_buffer + { + if let Some(link_position) = link_position { + existing_view.update(cx, |channel_view, cx| { + channel_view.focus_position_from_link(link_position, true, window, cx) + }); } + return existing_view; } // If the pane contained a disconnected view for this channel buffer, // replace that. - if let Some(existing_item) = existing_view { - if let Some(ix) = pane.index_for_item(&existing_item) { - pane.close_item_by_id( - existing_item.entity_id(), - SaveIntent::Skip, - window, - cx, - ) + if let Some(existing_item) = existing_view + && let Some(ix) = pane.index_for_item(&existing_item) + { + pane.close_item_by_id(existing_item.entity_id(), SaveIntent::Skip, window, cx) .detach(); - pane.add_item( - Box::new(channel_view.clone()), - true, - true, - Some(ix), - window, - cx, - ); - } + pane.add_item( + Box::new(channel_view.clone()), + true, + true, + Some(ix), + window, + cx, + ); } if let Some(link_position) = link_position { @@ -259,26 +248,21 @@ impl ChannelView { .editor .update(cx, |editor, cx| editor.snapshot(window, cx)); - if let Some(outline) = snapshot.buffer_snapshot.outline(None) { - if let Some(item) = outline + if let Some(outline) = snapshot.buffer_snapshot.outline(None) + && let Some(item) = outline .items .iter() .find(|item| &Channel::slug(&item.text).to_lowercase() == &position) - { - self.editor.update(cx, |editor, cx| { - editor.change_selections( - SelectionEffects::scroll(Autoscroll::focused()), - window, - cx, - |s| { - s.replace_cursors_with(|map| { - 
vec![item.range.start.to_display_point(map)] - }) - }, - ) - }); - return; - } + { + self.editor.update(cx, |editor, cx| { + editor.change_selections( + SelectionEffects::scroll(Autoscroll::focused()), + window, + cx, + |s| s.replace_cursors_with(|map| vec![item.range.start.to_display_point(map)]), + ) + }); + return; } if !first_attempt { diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs deleted file mode 100644 index 51d9f003f813212d40ff8e0716c86b1439fd4de6..0000000000000000000000000000000000000000 --- a/crates/collab_ui/src/chat_panel.rs +++ /dev/null @@ -1,1381 +0,0 @@ -use crate::{ChatPanelButton, ChatPanelSettings, collab_panel}; -use anyhow::Result; -use call::{ActiveCall, room}; -use channel::{ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId, ChannelStore}; -use client::{ChannelId, Client}; -use collections::HashMap; -use db::kvp::KEY_VALUE_STORE; -use editor::{Editor, actions}; -use gpui::{ - Action, App, AsyncWindowContext, ClipboardItem, Context, CursorStyle, DismissEvent, ElementId, - Entity, EventEmitter, FocusHandle, Focusable, FontWeight, HighlightStyle, ListOffset, - ListScrollEvent, ListState, Render, Stateful, Subscription, Task, WeakEntity, Window, actions, - div, list, prelude::*, px, -}; -use language::LanguageRegistry; -use menu::Confirm; -use message_editor::MessageEditor; -use project::Fs; -use rich_text::{Highlight, RichText}; -use serde::{Deserialize, Serialize}; -use settings::Settings; -use std::{sync::Arc, time::Duration}; -use time::{OffsetDateTime, UtcOffset}; -use ui::{ - Avatar, Button, ContextMenu, IconButton, IconName, KeyBinding, Label, PopoverMenu, Tab, TabBar, - Tooltip, prelude::*, -}; -use util::{ResultExt, TryFutureExt}; -use workspace::{ - Workspace, - dock::{DockPosition, Panel, PanelEvent}, -}; - -mod message_editor; - -const MESSAGE_LOADING_THRESHOLD: usize = 50; -const CHAT_PANEL_KEY: &str = "ChatPanel"; - -pub fn init(cx: &mut App) { - cx.observe_new(|workspace: &mut 
Workspace, _, _| { - workspace.register_action(|workspace, _: &ToggleFocus, window, cx| { - workspace.toggle_panel_focus::(window, cx); - }); - }) - .detach(); -} - -pub struct ChatPanel { - client: Arc, - channel_store: Entity, - languages: Arc, - message_list: ListState, - active_chat: Option<(Entity, Subscription)>, - message_editor: Entity, - local_timezone: UtcOffset, - fs: Arc, - width: Option, - active: bool, - pending_serialization: Task>, - subscriptions: Vec, - is_scrolled_to_bottom: bool, - markdown_data: HashMap, - focus_handle: FocusHandle, - open_context_menu: Option<(u64, Subscription)>, - highlighted_message: Option<(u64, Task<()>)>, - last_acknowledged_message_id: Option, -} - -#[derive(Serialize, Deserialize)] -struct SerializedChatPanel { - width: Option, -} - -actions!( - chat_panel, - [ - /// Toggles focus on the chat panel. - ToggleFocus - ] -); - -impl ChatPanel { - pub fn new( - workspace: &mut Workspace, - window: &mut Window, - cx: &mut Context, - ) -> Entity { - let fs = workspace.app_state().fs.clone(); - let client = workspace.app_state().client.clone(); - let channel_store = ChannelStore::global(cx); - let user_store = workspace.app_state().user_store.clone(); - let languages = workspace.app_state().languages.clone(); - - let input_editor = cx.new(|cx| { - MessageEditor::new( - languages.clone(), - user_store.clone(), - None, - cx.new(|cx| Editor::auto_height(1, 4, window, cx)), - window, - cx, - ) - }); - - cx.new(|cx| { - let message_list = ListState::new(0, gpui::ListAlignment::Bottom, px(1000.)); - - message_list.set_scroll_handler(cx.listener( - |this: &mut Self, event: &ListScrollEvent, _, cx| { - if event.visible_range.start < MESSAGE_LOADING_THRESHOLD { - this.load_more_messages(cx); - } - this.is_scrolled_to_bottom = !event.is_scrolled; - }, - )); - - let local_offset = chrono::Local::now().offset().local_minus_utc(); - let mut this = Self { - fs, - client, - channel_store, - languages, - message_list, - active_chat: 
Default::default(), - pending_serialization: Task::ready(None), - message_editor: input_editor, - local_timezone: UtcOffset::from_whole_seconds(local_offset).unwrap(), - subscriptions: Vec::new(), - is_scrolled_to_bottom: true, - active: false, - width: None, - markdown_data: Default::default(), - focus_handle: cx.focus_handle(), - open_context_menu: None, - highlighted_message: None, - last_acknowledged_message_id: None, - }; - - if let Some(channel_id) = ActiveCall::global(cx) - .read(cx) - .room() - .and_then(|room| room.read(cx).channel_id()) - { - this.select_channel(channel_id, None, cx) - .detach_and_log_err(cx); - } - - this.subscriptions.push(cx.subscribe( - &ActiveCall::global(cx), - move |this: &mut Self, call, event: &room::Event, cx| match event { - room::Event::RoomJoined { channel_id } => { - if let Some(channel_id) = channel_id { - this.select_channel(*channel_id, None, cx) - .detach_and_log_err(cx); - - if call - .read(cx) - .room() - .is_some_and(|room| room.read(cx).contains_guests()) - { - cx.emit(PanelEvent::Activate) - } - } - } - room::Event::RoomLeft { channel_id } => { - if channel_id == &this.channel_id(cx) { - cx.emit(PanelEvent::Close) - } - } - _ => {} - }, - )); - - this - }) - } - - pub fn channel_id(&self, cx: &App) -> Option { - self.active_chat - .as_ref() - .map(|(chat, _)| chat.read(cx).channel_id) - } - - pub fn is_scrolled_to_bottom(&self) -> bool { - self.is_scrolled_to_bottom - } - - pub fn active_chat(&self) -> Option> { - self.active_chat.as_ref().map(|(chat, _)| chat.clone()) - } - - pub fn load( - workspace: WeakEntity, - cx: AsyncWindowContext, - ) -> Task>> { - cx.spawn(async move |cx| { - let serialized_panel = if let Some(panel) = cx - .background_spawn(async move { KEY_VALUE_STORE.read_kvp(CHAT_PANEL_KEY) }) - .await - .log_err() - .flatten() - { - Some(serde_json::from_str::(&panel)?) 
- } else { - None - }; - - workspace.update_in(cx, |workspace, window, cx| { - let panel = Self::new(workspace, window, cx); - if let Some(serialized_panel) = serialized_panel { - panel.update(cx, |panel, cx| { - panel.width = serialized_panel.width.map(|r| r.round()); - cx.notify(); - }); - } - panel - }) - }) - } - - fn serialize(&mut self, cx: &mut Context) { - let width = self.width; - self.pending_serialization = cx.background_spawn( - async move { - KEY_VALUE_STORE - .write_kvp( - CHAT_PANEL_KEY.into(), - serde_json::to_string(&SerializedChatPanel { width })?, - ) - .await?; - anyhow::Ok(()) - } - .log_err(), - ); - } - - fn set_active_chat(&mut self, chat: Entity, cx: &mut Context) { - if self.active_chat.as_ref().map(|e| &e.0) != Some(&chat) { - self.markdown_data.clear(); - self.message_list.reset(chat.read(cx).message_count()); - self.message_editor.update(cx, |editor, cx| { - editor.set_channel_chat(chat.clone(), cx); - editor.clear_reply_to_message_id(); - }); - let subscription = cx.subscribe(&chat, Self::channel_did_change); - self.active_chat = Some((chat, subscription)); - self.acknowledge_last_message(cx); - cx.notify(); - } - } - - fn channel_did_change( - &mut self, - _: Entity, - event: &ChannelChatEvent, - cx: &mut Context, - ) { - match event { - ChannelChatEvent::MessagesUpdated { - old_range, - new_count, - } => { - self.message_list.splice(old_range.clone(), *new_count); - if self.active { - self.acknowledge_last_message(cx); - } - } - ChannelChatEvent::UpdateMessage { - message_id, - message_ix, - } => { - self.message_list.splice(*message_ix..*message_ix + 1, 1); - self.markdown_data.remove(message_id); - } - ChannelChatEvent::NewMessage { - channel_id, - message_id, - } => { - if !self.active { - self.channel_store.update(cx, |store, cx| { - store.update_latest_message_id(*channel_id, *message_id, cx) - }) - } - } - } - cx.notify(); - } - - fn acknowledge_last_message(&mut self, cx: &mut Context) { - if self.active && 
self.is_scrolled_to_bottom { - if let Some((chat, _)) = &self.active_chat { - if let Some(channel_id) = self.channel_id(cx) { - self.last_acknowledged_message_id = self - .channel_store - .read(cx) - .last_acknowledge_message_id(channel_id); - } - - chat.update(cx, |chat, cx| { - chat.acknowledge_last_message(cx); - }); - } - } - } - - fn render_replied_to_message( - &mut self, - message_id: Option, - reply_to_message: &Option, - cx: &mut Context, - ) -> impl IntoElement { - let reply_to_message = match reply_to_message { - None => { - return div().child( - h_flex() - .text_ui_xs(cx) - .my_0p5() - .px_0p5() - .gap_x_1() - .rounded_sm() - .child(Icon::new(IconName::ReplyArrowRight).color(Color::Muted)) - .when(reply_to_message.is_none(), |el| { - el.child( - Label::new("Message has been deleted...") - .size(LabelSize::XSmall) - .color(Color::Muted), - ) - }), - ); - } - Some(val) => val, - }; - - let user_being_replied_to = reply_to_message.sender.clone(); - let message_being_replied_to = reply_to_message.clone(); - - let message_element_id: ElementId = match message_id { - Some(ChannelMessageId::Saved(id)) => ("reply-to-saved-message-container", id).into(), - Some(ChannelMessageId::Pending(id)) => { - ("reply-to-pending-message-container", id).into() - } // This should never happen - None => ("composing-reply-container").into(), - }; - - let current_channel_id = self.channel_id(cx); - let reply_to_message_id = reply_to_message.id; - - div().child( - h_flex() - .id(message_element_id) - .text_ui_xs(cx) - .my_0p5() - .px_0p5() - .gap_x_1() - .rounded_sm() - .overflow_hidden() - .hover(|style| style.bg(cx.theme().colors().element_background)) - .child(Icon::new(IconName::ReplyArrowRight).color(Color::Muted)) - .child(Avatar::new(user_being_replied_to.avatar_uri.clone()).size(rems(0.7))) - .child( - Label::new(format!("@{}", user_being_replied_to.github_login)) - .size(LabelSize::XSmall) - .weight(FontWeight::SEMIBOLD) - .color(Color::Muted), - ) - .child( - 
div().overflow_y_hidden().child( - Label::new(message_being_replied_to.body.replace('\n', " ")) - .size(LabelSize::XSmall) - .color(Color::Default), - ), - ) - .cursor(CursorStyle::PointingHand) - .tooltip(Tooltip::text("Go to message")) - .on_click(cx.listener(move |chat_panel, _, _, cx| { - if let Some(channel_id) = current_channel_id { - chat_panel - .select_channel(channel_id, reply_to_message_id.into(), cx) - .detach_and_log_err(cx) - } - })), - ) - } - - fn render_message( - &mut self, - ix: usize, - window: &mut Window, - cx: &mut Context, - ) -> AnyElement { - let active_chat = &self.active_chat.as_ref().unwrap().0; - let (message, is_continuation_from_previous, is_admin) = - active_chat.update(cx, |active_chat, cx| { - let is_admin = self - .channel_store - .read(cx) - .is_channel_admin(active_chat.channel_id); - - let last_message = active_chat.message(ix.saturating_sub(1)); - let this_message = active_chat.message(ix).clone(); - - let duration_since_last_message = this_message.timestamp - last_message.timestamp; - let is_continuation_from_previous = last_message.sender.id - == this_message.sender.id - && last_message.id != this_message.id - && duration_since_last_message < Duration::from_secs(5 * 60); - - if let ChannelMessageId::Saved(id) = this_message.id { - if this_message - .mentions - .iter() - .any(|(_, user_id)| Some(*user_id) == self.client.user_id()) - { - active_chat.acknowledge_message(id); - } - } - - (this_message, is_continuation_from_previous, is_admin) - }); - - let _is_pending = message.is_pending(); - - let belongs_to_user = Some(message.sender.id) == self.client.user_id(); - let can_delete_message = belongs_to_user || is_admin; - let can_edit_message = belongs_to_user; - - let element_id: ElementId = match message.id { - ChannelMessageId::Saved(id) => ("saved-message", id).into(), - ChannelMessageId::Pending(id) => ("pending-message", id).into(), - }; - - let mentioning_you = message - .mentions - .iter() - .any(|m| Some(m.1) == 
self.client.user_id()); - - let message_id = match message.id { - ChannelMessageId::Saved(id) => Some(id), - ChannelMessageId::Pending(_) => None, - }; - - let reply_to_message = message - .reply_to_message_id - .and_then(|id| active_chat.read(cx).find_loaded_message(id)) - .cloned(); - - let replied_to_you = - reply_to_message.as_ref().map(|m| m.sender.id) == self.client.user_id(); - - let is_highlighted_message = self - .highlighted_message - .as_ref() - .is_some_and(|(id, _)| Some(id) == message_id.as_ref()); - let background = if is_highlighted_message { - cx.theme().status().info_background - } else if mentioning_you || replied_to_you { - cx.theme().colors().background - } else { - cx.theme().colors().panel_background - }; - - let reply_to_message_id = self.message_editor.read(cx).reply_to_message_id(); - - v_flex() - .w_full() - .relative() - .group("") - .when(!is_continuation_from_previous, |this| this.pt_2()) - .child( - div() - .group("") - .bg(background) - .rounded_sm() - .overflow_hidden() - .px_1p5() - .py_0p5() - .when_some(reply_to_message_id, |el, reply_id| { - el.when_some(message_id, |el, message_id| { - el.when(reply_id == message_id, |el| { - el.bg(cx.theme().colors().element_selected) - }) - }) - }) - .when(!self.has_open_menu(message_id), |this| { - this.hover(|style| style.bg(cx.theme().colors().element_hover)) - }) - .when(message.reply_to_message_id.is_some(), |el| { - el.child(self.render_replied_to_message( - Some(message.id), - &reply_to_message, - cx, - )) - .when(is_continuation_from_previous, |this| this.mt_2()) - }) - .when( - !is_continuation_from_previous || message.reply_to_message_id.is_some(), - |this| { - this.child( - h_flex() - .gap_2() - .text_ui_sm(cx) - .child( - Avatar::new(message.sender.avatar_uri.clone()) - .size(rems(1.)), - ) - .child( - Label::new(message.sender.github_login.clone()) - .size(LabelSize::Small) - .weight(FontWeight::BOLD), - ) - .child( - Label::new(time_format::format_localized_timestamp( - 
message.timestamp, - OffsetDateTime::now_utc(), - self.local_timezone, - time_format::TimestampFormat::EnhancedAbsolute, - )) - .size(LabelSize::Small) - .color(Color::Muted), - ), - ) - }, - ) - .when(mentioning_you || replied_to_you, |this| this.my_0p5()) - .map(|el| { - let text = self.markdown_data.entry(message.id).or_insert_with(|| { - Self::render_markdown_with_mentions( - &self.languages, - self.client.id(), - &message, - self.local_timezone, - cx, - ) - }); - el.child( - v_flex() - .w_full() - .text_ui_sm(cx) - .id(element_id) - .child(text.element("body".into(), window, cx)), - ) - .when(self.has_open_menu(message_id), |el| { - el.bg(cx.theme().colors().element_selected) - }) - }), - ) - .when( - self.last_acknowledged_message_id - .is_some_and(|l| Some(l) == message_id), - |this| { - this.child( - h_flex() - .py_2() - .gap_1() - .items_center() - .child(div().w_full().h_0p5().bg(cx.theme().colors().border)) - .child( - div() - .px_1() - .rounded_sm() - .text_ui_xs(cx) - .bg(cx.theme().colors().background) - .child("New messages"), - ) - .child(div().w_full().h_0p5().bg(cx.theme().colors().border)), - ) - }, - ) - .child( - self.render_popover_buttons(message_id, can_delete_message, can_edit_message, cx) - .mt_neg_2p5(), - ) - .into_any_element() - } - - fn has_open_menu(&self, message_id: Option) -> bool { - match self.open_context_menu.as_ref() { - Some((id, _)) => Some(*id) == message_id, - None => false, - } - } - - fn render_popover_button(&self, cx: &mut Context, child: Stateful
) -> Div { - div() - .w_6() - .bg(cx.theme().colors().element_background) - .hover(|style| style.bg(cx.theme().colors().element_hover).rounded_sm()) - .child(child) - } - - fn render_popover_buttons( - &self, - message_id: Option, - can_delete_message: bool, - can_edit_message: bool, - cx: &mut Context, - ) -> Div { - h_flex() - .absolute() - .right_2() - .overflow_hidden() - .rounded_sm() - .border_color(cx.theme().colors().element_selected) - .border_1() - .when(!self.has_open_menu(message_id), |el| { - el.visible_on_hover("") - }) - .bg(cx.theme().colors().element_background) - .when_some(message_id, |el, message_id| { - el.child( - self.render_popover_button( - cx, - div() - .id("reply") - .child( - IconButton::new(("reply", message_id), IconName::ReplyArrowRight) - .on_click(cx.listener(move |this, _, window, cx| { - this.cancel_edit_message(cx); - - this.message_editor.update(cx, |editor, cx| { - editor.set_reply_to_message_id(message_id); - window.focus(&editor.focus_handle(cx)); - }) - })), - ) - .tooltip(Tooltip::text("Reply")), - ), - ) - }) - .when_some(message_id, |el, message_id| { - el.when(can_edit_message, |el| { - el.child( - self.render_popover_button( - cx, - div() - .id("edit") - .child( - IconButton::new(("edit", message_id), IconName::Pencil) - .on_click(cx.listener(move |this, _, window, cx| { - this.message_editor.update(cx, |editor, cx| { - editor.clear_reply_to_message_id(); - - let message = this - .active_chat() - .and_then(|active_chat| { - active_chat - .read(cx) - .find_loaded_message(message_id) - }) - .cloned(); - - if let Some(message) = message { - let buffer = editor - .editor - .read(cx) - .buffer() - .read(cx) - .as_singleton() - .expect("message editor must be singleton"); - - buffer.update(cx, |buffer, cx| { - buffer.set_text(message.body.clone(), cx) - }); - - editor.set_edit_message_id(message_id); - editor.focus_handle(cx).focus(window); - } - }) - })), - ) - .tooltip(Tooltip::text("Edit")), - ), - ) - }) - }) - 
.when_some(message_id, |el, message_id| { - let this = cx.entity().clone(); - - el.child( - self.render_popover_button( - cx, - div() - .child( - PopoverMenu::new(("menu", message_id)) - .trigger(IconButton::new( - ("trigger", message_id), - IconName::Ellipsis, - )) - .menu(move |window, cx| { - Some(Self::render_message_menu( - &this, - message_id, - can_delete_message, - window, - cx, - )) - }), - ) - .id("more") - .tooltip(Tooltip::text("More")), - ), - ) - }) - } - - fn render_message_menu( - this: &Entity, - message_id: u64, - can_delete_message: bool, - window: &mut Window, - cx: &mut App, - ) -> Entity { - let menu = { - ContextMenu::build(window, cx, move |menu, window, _| { - menu.entry( - "Copy message text", - None, - window.handler_for(this, move |this, _, cx| { - if let Some(message) = this.active_chat().and_then(|active_chat| { - active_chat.read(cx).find_loaded_message(message_id) - }) { - let text = message.body.clone(); - cx.write_to_clipboard(ClipboardItem::new_string(text)) - } - }), - ) - .when(can_delete_message, |menu| { - menu.entry( - "Delete message", - None, - window.handler_for(this, move |this, _, cx| { - this.remove_message(message_id, cx) - }), - ) - }) - }) - }; - this.update(cx, |this, cx| { - let subscription = cx.subscribe_in( - &menu, - window, - |this: &mut Self, _, _: &DismissEvent, _, _| { - this.open_context_menu = None; - }, - ); - this.open_context_menu = Some((message_id, subscription)); - }); - menu - } - - fn render_markdown_with_mentions( - language_registry: &Arc, - current_user_id: u64, - message: &channel::ChannelMessage, - local_timezone: UtcOffset, - cx: &App, - ) -> RichText { - let mentions = message - .mentions - .iter() - .map(|(range, user_id)| rich_text::Mention { - range: range.clone(), - is_self_mention: *user_id == current_user_id, - }) - .collect::>(); - - const MESSAGE_EDITED: &str = " (edited)"; - - let mut body = message.body.clone(); - - if message.edited_at.is_some() { - body.push_str(MESSAGE_EDITED); 
- } - - let mut rich_text = RichText::new(body, &mentions, language_registry); - - if message.edited_at.is_some() { - let range = (rich_text.text.len() - MESSAGE_EDITED.len())..rich_text.text.len(); - rich_text.highlights.push(( - range.clone(), - Highlight::Highlight(HighlightStyle { - color: Some(cx.theme().colors().text_muted), - ..Default::default() - }), - )); - - if let Some(edit_timestamp) = message.edited_at { - let edit_timestamp_text = time_format::format_localized_timestamp( - edit_timestamp, - OffsetDateTime::now_utc(), - local_timezone, - time_format::TimestampFormat::Absolute, - ); - - rich_text.custom_ranges.push(range); - rich_text.set_tooltip_builder_for_custom_ranges(move |_, _, _, cx| { - Some(Tooltip::simple(edit_timestamp_text.clone(), cx)) - }) - } - } - rich_text - } - - fn send(&mut self, _: &Confirm, window: &mut Window, cx: &mut Context) { - if let Some((chat, _)) = self.active_chat.as_ref() { - let message = self - .message_editor - .update(cx, |editor, cx| editor.take_message(window, cx)); - - if let Some(id) = self.message_editor.read(cx).edit_message_id() { - self.message_editor.update(cx, |editor, _| { - editor.clear_edit_message_id(); - }); - - if let Some(task) = chat - .update(cx, |chat, cx| chat.update_message(id, message, cx)) - .log_err() - { - task.detach(); - } - } else if let Some(task) = chat - .update(cx, |chat, cx| chat.send_message(message, cx)) - .log_err() - { - task.detach(); - } - } - } - - fn remove_message(&mut self, id: u64, cx: &mut Context) { - if let Some((chat, _)) = self.active_chat.as_ref() { - chat.update(cx, |chat, cx| chat.remove_message(id, cx).detach()) - } - } - - fn load_more_messages(&mut self, cx: &mut Context) { - if let Some((chat, _)) = self.active_chat.as_ref() { - chat.update(cx, |channel, cx| { - if let Some(task) = channel.load_more_messages(cx) { - task.detach(); - } - }) - } - } - - pub fn select_channel( - &mut self, - selected_channel_id: ChannelId, - scroll_to_message_id: Option, - cx: 
&mut Context, - ) -> Task> { - let open_chat = self - .active_chat - .as_ref() - .and_then(|(chat, _)| { - (chat.read(cx).channel_id == selected_channel_id) - .then(|| Task::ready(anyhow::Ok(chat.clone()))) - }) - .unwrap_or_else(|| { - self.channel_store.update(cx, |store, cx| { - store.open_channel_chat(selected_channel_id, cx) - }) - }); - - cx.spawn(async move |this, cx| { - let chat = open_chat.await?; - let highlight_message_id = scroll_to_message_id; - let scroll_to_message_id = this.update(cx, |this, cx| { - this.set_active_chat(chat.clone(), cx); - - scroll_to_message_id.or(this.last_acknowledged_message_id) - })?; - - if let Some(message_id) = scroll_to_message_id { - if let Some(item_ix) = - ChannelChat::load_history_since_message(chat.clone(), message_id, cx.clone()) - .await - { - this.update(cx, |this, cx| { - if let Some(highlight_message_id) = highlight_message_id { - let task = cx.spawn(async move |this, cx| { - cx.background_executor().timer(Duration::from_secs(2)).await; - this.update(cx, |this, cx| { - this.highlighted_message.take(); - cx.notify(); - }) - .ok(); - }); - - this.highlighted_message = Some((highlight_message_id, task)); - } - - if this.active_chat.as_ref().map_or(false, |(c, _)| *c == chat) { - this.message_list.scroll_to(ListOffset { - item_ix, - offset_in_item: px(0.0), - }); - cx.notify(); - } - })?; - } - } - - Ok(()) - }) - } - - fn close_reply_preview(&mut self, cx: &mut Context) { - self.message_editor - .update(cx, |editor, _| editor.clear_reply_to_message_id()); - } - - fn cancel_edit_message(&mut self, cx: &mut Context) { - self.message_editor.update(cx, |editor, cx| { - // only clear the editor input if we were editing a message - if editor.edit_message_id().is_none() { - return; - } - - editor.clear_edit_message_id(); - - let buffer = editor - .editor - .read(cx) - .buffer() - .read(cx) - .as_singleton() - .expect("message editor must be singleton"); - - buffer.update(cx, |buffer, cx| buffer.set_text("", cx)); - }); - 
} -} - -impl Render for ChatPanel { - fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let channel_id = self - .active_chat - .as_ref() - .map(|(c, _)| c.read(cx).channel_id); - let message_editor = self.message_editor.read(cx); - - let reply_to_message_id = message_editor.reply_to_message_id(); - let edit_message_id = message_editor.edit_message_id(); - - v_flex() - .key_context("ChatPanel") - .track_focus(&self.focus_handle) - .size_full() - .on_action(cx.listener(Self::send)) - .child( - h_flex().child( - TabBar::new("chat_header").child( - h_flex() - .w_full() - .h(Tab::container_height(cx)) - .px_2() - .child(Label::new( - self.active_chat - .as_ref() - .and_then(|c| { - Some(format!("#{}", c.0.read(cx).channel(cx)?.name)) - }) - .unwrap_or("Chat".to_string()), - )), - ), - ), - ) - .child(div().flex_grow().px_2().map(|this| { - if self.active_chat.is_some() { - this.child( - list( - self.message_list.clone(), - cx.processor(Self::render_message), - ) - .size_full(), - ) - } else { - this.child( - div() - .size_full() - .p_4() - .child( - Label::new("Select a channel to chat in.") - .size(LabelSize::Small) - .color(Color::Muted), - ) - .child( - div().pt_1().w_full().items_center().child( - Button::new("toggle-collab", "Open") - .full_width() - .key_binding(KeyBinding::for_action( - &collab_panel::ToggleFocus, - window, - cx, - )) - .on_click(|_, window, cx| { - window.dispatch_action( - collab_panel::ToggleFocus.boxed_clone(), - cx, - ) - }), - ), - ), - ) - } - })) - .when(!self.is_scrolled_to_bottom, |el| { - el.child(div().border_t_1().border_color(cx.theme().colors().border)) - }) - .when_some(edit_message_id, |el, _| { - el.child( - h_flex() - .px_2() - .text_ui_xs(cx) - .justify_between() - .border_t_1() - .border_color(cx.theme().colors().border) - .bg(cx.theme().colors().background) - .child("Editing message") - .child( - IconButton::new("cancel-edit-message", IconName::Close) - .shape(ui::IconButtonShape::Square) - 
.tooltip(Tooltip::text("Cancel edit message")) - .on_click(cx.listener(move |this, _, _, cx| { - this.cancel_edit_message(cx); - })), - ), - ) - }) - .when_some(reply_to_message_id, |el, reply_to_message_id| { - let reply_message = self - .active_chat() - .and_then(|active_chat| { - active_chat - .read(cx) - .find_loaded_message(reply_to_message_id) - }) - .cloned(); - - el.when_some(reply_message, |el, reply_message| { - let user_being_replied_to = reply_message.sender.clone(); - - el.child( - h_flex() - .when(!self.is_scrolled_to_bottom, |el| { - el.border_t_1().border_color(cx.theme().colors().border) - }) - .justify_between() - .overflow_hidden() - .items_start() - .py_1() - .px_2() - .bg(cx.theme().colors().background) - .child( - div().flex_shrink().overflow_hidden().child( - h_flex() - .id(("reply-preview", reply_to_message_id)) - .child(Label::new("Replying to ").size(LabelSize::Small)) - .child( - Label::new(format!( - "@{}", - user_being_replied_to.github_login - )) - .size(LabelSize::Small) - .weight(FontWeight::BOLD), - ) - .when_some(channel_id, |this, channel_id| { - this.cursor_pointer().on_click(cx.listener( - move |chat_panel, _, _, cx| { - chat_panel - .select_channel( - channel_id, - reply_to_message_id.into(), - cx, - ) - .detach_and_log_err(cx) - }, - )) - }), - ), - ) - .child( - IconButton::new("close-reply-preview", IconName::Close) - .shape(ui::IconButtonShape::Square) - .tooltip(Tooltip::text("Close reply")) - .on_click(cx.listener(move |this, _, _, cx| { - this.close_reply_preview(cx); - })), - ), - ) - }) - }) - .children( - Some( - h_flex() - .p_2() - .on_action(cx.listener(|this, _: &actions::Cancel, _, cx| { - this.cancel_edit_message(cx); - this.close_reply_preview(cx); - })) - .map(|el| el.child(self.message_editor.clone())), - ) - .filter(|_| self.active_chat.is_some()), - ) - .into_any() - } -} - -impl Focusable for ChatPanel { - fn focus_handle(&self, cx: &App) -> gpui::FocusHandle { - if self.active_chat.is_some() { - 
self.message_editor.read(cx).focus_handle(cx) - } else { - self.focus_handle.clone() - } - } -} - -impl Panel for ChatPanel { - fn position(&self, _: &Window, cx: &App) -> DockPosition { - ChatPanelSettings::get_global(cx).dock - } - - fn position_is_valid(&self, position: DockPosition) -> bool { - matches!(position, DockPosition::Left | DockPosition::Right) - } - - fn set_position(&mut self, position: DockPosition, _: &mut Window, cx: &mut Context) { - settings::update_settings_file::( - self.fs.clone(), - cx, - move |settings, _| settings.dock = Some(position), - ); - } - - fn size(&self, _: &Window, cx: &App) -> Pixels { - self.width - .unwrap_or_else(|| ChatPanelSettings::get_global(cx).default_width) - } - - fn set_size(&mut self, size: Option, _: &mut Window, cx: &mut Context) { - self.width = size; - self.serialize(cx); - cx.notify(); - } - - fn set_active(&mut self, active: bool, _: &mut Window, cx: &mut Context) { - self.active = active; - if active { - self.acknowledge_last_message(cx); - } - } - - fn persistent_name() -> &'static str { - "ChatPanel" - } - - fn icon(&self, _window: &Window, cx: &App) -> Option { - self.enabled(cx).then(|| ui::IconName::Chat) - } - - fn icon_tooltip(&self, _: &Window, _: &App) -> Option<&'static str> { - Some("Chat Panel") - } - - fn toggle_action(&self) -> Box { - Box::new(ToggleFocus) - } - - fn starts_open(&self, _: &Window, cx: &App) -> bool { - ActiveCall::global(cx) - .read(cx) - .room() - .is_some_and(|room| room.read(cx).contains_guests()) - } - - fn activation_priority(&self) -> u32 { - 7 - } - - fn enabled(&self, cx: &App) -> bool { - match ChatPanelSettings::get_global(cx).button { - ChatPanelButton::Never => false, - ChatPanelButton::Always => true, - ChatPanelButton::WhenInCall => { - let is_in_call = ActiveCall::global(cx) - .read(cx) - .room() - .map_or(false, |room| room.read(cx).contains_guests()); - - self.active || is_in_call - } - } - } -} - -impl EventEmitter for ChatPanel {} - -#[cfg(test)] -mod tests 
{ - use super::*; - use gpui::HighlightStyle; - use pretty_assertions::assert_eq; - use rich_text::Highlight; - use time::OffsetDateTime; - use util::test::marked_text_ranges; - - #[gpui::test] - fn test_render_markdown_with_mentions(cx: &mut App) { - let language_registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let (body, ranges) = marked_text_ranges("*hi*, «@abc», let's **call** «@fgh»", false); - let message = channel::ChannelMessage { - id: ChannelMessageId::Saved(0), - body, - timestamp: OffsetDateTime::now_utc(), - sender: Arc::new(client::User { - github_login: "fgh".into(), - avatar_uri: "avatar_fgh".into(), - id: 103, - name: None, - }), - nonce: 5, - mentions: vec![(ranges[0].clone(), 101), (ranges[1].clone(), 102)], - reply_to_message_id: None, - edited_at: None, - }; - - let message = ChatPanel::render_markdown_with_mentions( - &language_registry, - 102, - &message, - UtcOffset::UTC, - cx, - ); - - // Note that the "'" was replaced with ’ due to smart punctuation. 
- let (body, ranges) = marked_text_ranges("«hi», «@abc», let’s «call» «@fgh»", false); - assert_eq!(message.text, body); - assert_eq!( - message.highlights, - vec![ - ( - ranges[0].clone(), - HighlightStyle { - font_style: Some(gpui::FontStyle::Italic), - ..Default::default() - } - .into() - ), - (ranges[1].clone(), Highlight::Mention), - ( - ranges[2].clone(), - HighlightStyle { - font_weight: Some(gpui::FontWeight::BOLD), - ..Default::default() - } - .into() - ), - (ranges[3].clone(), Highlight::SelfMention) - ] - ); - } - - #[gpui::test] - fn test_render_markdown_with_auto_detect_links(cx: &mut App) { - let language_registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let message = channel::ChannelMessage { - id: ChannelMessageId::Saved(0), - body: "Here is a link https://zed.dev to zeds website".to_string(), - timestamp: OffsetDateTime::now_utc(), - sender: Arc::new(client::User { - github_login: "fgh".into(), - avatar_uri: "avatar_fgh".into(), - id: 103, - name: None, - }), - nonce: 5, - mentions: Vec::new(), - reply_to_message_id: None, - edited_at: None, - }; - - let message = ChatPanel::render_markdown_with_mentions( - &language_registry, - 102, - &message, - UtcOffset::UTC, - cx, - ); - - // Note that the "'" was replaced with ’ due to smart punctuation. 
- let (body, ranges) = - marked_text_ranges("Here is a link «https://zed.dev» to zeds website", false); - assert_eq!(message.text, body); - assert_eq!(1, ranges.len()); - assert_eq!( - message.highlights, - vec![( - ranges[0].clone(), - HighlightStyle { - underline: Some(gpui::UnderlineStyle { - thickness: 1.0.into(), - ..Default::default() - }), - ..Default::default() - } - .into() - ),] - ); - } - - #[gpui::test] - fn test_render_markdown_with_auto_detect_links_and_additional_formatting(cx: &mut App) { - let language_registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let message = channel::ChannelMessage { - id: ChannelMessageId::Saved(0), - body: "**Here is a link https://zed.dev to zeds website**".to_string(), - timestamp: OffsetDateTime::now_utc(), - sender: Arc::new(client::User { - github_login: "fgh".into(), - avatar_uri: "avatar_fgh".into(), - id: 103, - name: None, - }), - nonce: 5, - mentions: Vec::new(), - reply_to_message_id: None, - edited_at: None, - }; - - let message = ChatPanel::render_markdown_with_mentions( - &language_registry, - 102, - &message, - UtcOffset::UTC, - cx, - ); - - // Note that the "'" was replaced with ’ due to smart punctuation. 
- let (body, ranges) = marked_text_ranges( - "«Here is a link »«https://zed.dev»« to zeds website»", - false, - ); - assert_eq!(message.text, body); - assert_eq!(3, ranges.len()); - assert_eq!( - message.highlights, - vec![ - ( - ranges[0].clone(), - HighlightStyle { - font_weight: Some(gpui::FontWeight::BOLD), - ..Default::default() - } - .into() - ), - ( - ranges[1].clone(), - HighlightStyle { - font_weight: Some(gpui::FontWeight::BOLD), - underline: Some(gpui::UnderlineStyle { - thickness: 1.0.into(), - ..Default::default() - }), - ..Default::default() - } - .into() - ), - ( - ranges[2].clone(), - HighlightStyle { - font_weight: Some(gpui::FontWeight::BOLD), - ..Default::default() - } - .into() - ), - ] - ); - } -} diff --git a/crates/collab_ui/src/chat_panel/message_editor.rs b/crates/collab_ui/src/chat_panel/message_editor.rs deleted file mode 100644 index 03d39cb8ced169f59167b1a1f6e91102a268a37d..0000000000000000000000000000000000000000 --- a/crates/collab_ui/src/chat_panel/message_editor.rs +++ /dev/null @@ -1,548 +0,0 @@ -use anyhow::{Context as _, Result}; -use channel::{ChannelChat, ChannelStore, MessageParams}; -use client::{UserId, UserStore}; -use collections::HashSet; -use editor::{AnchorRangeExt, CompletionProvider, Editor, EditorElement, EditorStyle, ExcerptId}; -use fuzzy::{StringMatch, StringMatchCandidate}; -use gpui::{ - AsyncApp, AsyncWindowContext, Context, Entity, Focusable, FontStyle, FontWeight, - HighlightStyle, IntoElement, Render, Task, TextStyle, WeakEntity, Window, -}; -use language::{ - Anchor, Buffer, BufferSnapshot, CodeLabel, LanguageRegistry, ToOffset, - language_settings::SoftWrap, -}; -use project::{Completion, CompletionResponse, CompletionSource, search::SearchQuery}; -use settings::Settings; -use std::{ - ops::Range, - rc::Rc, - sync::{Arc, LazyLock}, - time::Duration, -}; -use theme::ThemeSettings; -use ui::{TextSize, prelude::*}; - -use crate::panel_settings::MessageEditorSettings; - -const MENTIONS_DEBOUNCE_INTERVAL: 
Duration = Duration::from_millis(50); - -static MENTIONS_SEARCH: LazyLock = LazyLock::new(|| { - SearchQuery::regex( - "@[-_\\w]+", - false, - false, - false, - false, - Default::default(), - Default::default(), - false, - None, - ) - .unwrap() -}); - -pub struct MessageEditor { - pub editor: Entity, - user_store: Entity, - channel_chat: Option>, - mentions: Vec, - mentions_task: Option>, - reply_to_message_id: Option, - edit_message_id: Option, -} - -struct MessageEditorCompletionProvider(WeakEntity); - -impl CompletionProvider for MessageEditorCompletionProvider { - fn completions( - &self, - _excerpt_id: ExcerptId, - buffer: &Entity, - buffer_position: language::Anchor, - _: editor::CompletionContext, - _window: &mut Window, - cx: &mut Context, - ) -> Task>> { - let Some(handle) = self.0.upgrade() else { - return Task::ready(Ok(Vec::new())); - }; - handle.update(cx, |message_editor, cx| { - message_editor.completions(buffer, buffer_position, cx) - }) - } - - fn is_completion_trigger( - &self, - _buffer: &Entity, - _position: language::Anchor, - text: &str, - _trigger_in_words: bool, - _menu_is_open: bool, - _cx: &mut Context, - ) -> bool { - text == "@" - } -} - -impl MessageEditor { - pub fn new( - language_registry: Arc, - user_store: Entity, - channel_chat: Option>, - editor: Entity, - window: &mut Window, - cx: &mut Context, - ) -> Self { - let this = cx.entity().downgrade(); - editor.update(cx, |editor, cx| { - editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx); - editor.set_offset_content(false, cx); - editor.set_use_autoclose(false); - editor.set_show_gutter(false, cx); - editor.set_show_wrap_guides(false, cx); - editor.set_show_indent_guides(false, cx); - editor.set_completion_provider(Some(Rc::new(MessageEditorCompletionProvider(this)))); - editor.set_auto_replace_emoji_shortcode( - MessageEditorSettings::get_global(cx) - .auto_replace_emoji_shortcode - .unwrap_or_default(), - ); - }); - - let buffer = editor - .read(cx) - .buffer() - .read(cx) - 
.as_singleton() - .expect("message editor must be singleton"); - - cx.subscribe_in(&buffer, window, Self::on_buffer_event) - .detach(); - cx.observe_global::(|this, cx| { - this.editor.update(cx, |editor, cx| { - editor.set_auto_replace_emoji_shortcode( - MessageEditorSettings::get_global(cx) - .auto_replace_emoji_shortcode - .unwrap_or_default(), - ) - }) - }) - .detach(); - - let markdown = language_registry.language_for_name("Markdown"); - cx.spawn_in(window, async move |_, cx| { - let markdown = markdown.await.context("failed to load Markdown language")?; - buffer.update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx)) - }) - .detach_and_log_err(cx); - - Self { - editor, - user_store, - channel_chat, - mentions: Vec::new(), - mentions_task: None, - reply_to_message_id: None, - edit_message_id: None, - } - } - - pub fn reply_to_message_id(&self) -> Option { - self.reply_to_message_id - } - - pub fn set_reply_to_message_id(&mut self, reply_to_message_id: u64) { - self.reply_to_message_id = Some(reply_to_message_id); - } - - pub fn clear_reply_to_message_id(&mut self) { - self.reply_to_message_id = None; - } - - pub fn edit_message_id(&self) -> Option { - self.edit_message_id - } - - pub fn set_edit_message_id(&mut self, edit_message_id: u64) { - self.edit_message_id = Some(edit_message_id); - } - - pub fn clear_edit_message_id(&mut self) { - self.edit_message_id = None; - } - - pub fn set_channel_chat(&mut self, chat: Entity, cx: &mut Context) { - let channel_id = chat.read(cx).channel_id; - self.channel_chat = Some(chat); - let channel_name = ChannelStore::global(cx) - .read(cx) - .channel_for_id(channel_id) - .map(|channel| channel.name.clone()); - self.editor.update(cx, |editor, cx| { - if let Some(channel_name) = channel_name { - editor.set_placeholder_text(format!("Message #{channel_name}"), cx); - } else { - editor.set_placeholder_text("Message Channel", cx); - } - }); - } - - pub fn take_message(&mut self, window: &mut Window, cx: &mut Context) -> 
MessageParams { - self.editor.update(cx, |editor, cx| { - let highlights = editor.text_highlights::(cx); - let text = editor.text(cx); - let snapshot = editor.buffer().read(cx).snapshot(cx); - let mentions = if let Some((_, ranges)) = highlights { - ranges - .iter() - .map(|range| range.to_offset(&snapshot)) - .zip(self.mentions.iter().copied()) - .collect() - } else { - Vec::new() - }; - - editor.clear(window, cx); - self.mentions.clear(); - let reply_to_message_id = std::mem::take(&mut self.reply_to_message_id); - - MessageParams { - text, - mentions, - reply_to_message_id, - } - }) - } - - fn on_buffer_event( - &mut self, - buffer: &Entity, - event: &language::BufferEvent, - window: &mut Window, - cx: &mut Context, - ) { - if let language::BufferEvent::Reparsed | language::BufferEvent::Edited = event { - let buffer = buffer.read(cx).snapshot(); - self.mentions_task = Some(cx.spawn_in(window, async move |this, cx| { - cx.background_executor() - .timer(MENTIONS_DEBOUNCE_INTERVAL) - .await; - Self::find_mentions(this, buffer, cx).await; - })); - } - } - - fn completions( - &mut self, - buffer: &Entity, - end_anchor: Anchor, - cx: &mut Context, - ) -> Task>> { - if let Some((start_anchor, query, candidates)) = - self.collect_mention_candidates(buffer, end_anchor, cx) - { - if !candidates.is_empty() { - return cx.spawn(async move |_, cx| { - let completion_response = Self::completions_for_candidates( - &cx, - query.as_str(), - &candidates, - start_anchor..end_anchor, - Self::completion_for_mention, - ) - .await; - Ok(vec![completion_response]) - }); - } - } - - if let Some((start_anchor, query, candidates)) = - self.collect_emoji_candidates(buffer, end_anchor, cx) - { - if !candidates.is_empty() { - return cx.spawn(async move |_, cx| { - let completion_response = Self::completions_for_candidates( - &cx, - query.as_str(), - candidates, - start_anchor..end_anchor, - Self::completion_for_emoji, - ) - .await; - Ok(vec![completion_response]) - }); - } - } - - 
Task::ready(Ok(vec![CompletionResponse { - completions: Vec::new(), - is_incomplete: false, - }])) - } - - async fn completions_for_candidates( - cx: &AsyncApp, - query: &str, - candidates: &[StringMatchCandidate], - range: Range, - completion_fn: impl Fn(&StringMatch) -> (String, CodeLabel), - ) -> CompletionResponse { - const LIMIT: usize = 10; - let matches = fuzzy::match_strings( - candidates, - query, - true, - true, - LIMIT, - &Default::default(), - cx.background_executor().clone(), - ) - .await; - - let completions = matches - .into_iter() - .map(|mat| { - let (new_text, label) = completion_fn(&mat); - Completion { - replace_range: range.clone(), - new_text, - label, - icon_path: None, - confirm: None, - documentation: None, - insert_text_mode: None, - source: CompletionSource::Custom, - } - }) - .collect::>(); - - CompletionResponse { - is_incomplete: completions.len() >= LIMIT, - completions, - } - } - - fn completion_for_mention(mat: &StringMatch) -> (String, CodeLabel) { - let label = CodeLabel { - filter_range: 1..mat.string.len() + 1, - text: format!("@{}", mat.string), - runs: Vec::new(), - }; - (mat.string.clone(), label) - } - - fn completion_for_emoji(mat: &StringMatch) -> (String, CodeLabel) { - let emoji = emojis::get_by_shortcode(&mat.string).unwrap(); - let label = CodeLabel { - filter_range: 1..mat.string.len() + 1, - text: format!(":{}: {}", mat.string, emoji), - runs: Vec::new(), - }; - (emoji.to_string(), label) - } - - fn collect_mention_candidates( - &mut self, - buffer: &Entity, - end_anchor: Anchor, - cx: &mut Context, - ) -> Option<(Anchor, String, Vec)> { - let end_offset = end_anchor.to_offset(buffer.read(cx)); - - let query = buffer.read_with(cx, |buffer, _| { - let mut query = String::new(); - for ch in buffer.reversed_chars_at(end_offset).take(100) { - if ch == '@' { - return Some(query.chars().rev().collect::()); - } - if ch.is_whitespace() || !ch.is_ascii() { - break; - } - query.push(ch); - } - None - })?; - - let start_offset 
= end_offset - query.len(); - let start_anchor = buffer.read(cx).anchor_before(start_offset); - - let mut names = HashSet::default(); - if let Some(chat) = self.channel_chat.as_ref() { - let chat = chat.read(cx); - for participant in ChannelStore::global(cx) - .read(cx) - .channel_participants(chat.channel_id) - { - names.insert(participant.github_login.clone()); - } - for message in chat - .messages_in_range(chat.message_count().saturating_sub(100)..chat.message_count()) - { - names.insert(message.sender.github_login.clone()); - } - } - - let candidates = names - .into_iter() - .map(|user| StringMatchCandidate::new(0, &user)) - .collect::>(); - - Some((start_anchor, query, candidates)) - } - - fn collect_emoji_candidates( - &mut self, - buffer: &Entity, - end_anchor: Anchor, - cx: &mut Context, - ) -> Option<(Anchor, String, &'static [StringMatchCandidate])> { - static EMOJI_FUZZY_MATCH_CANDIDATES: LazyLock> = - LazyLock::new(|| { - let emojis = emojis::iter() - .flat_map(|s| s.shortcodes()) - .map(|emoji| StringMatchCandidate::new(0, emoji)) - .collect::>(); - emojis - }); - - let end_offset = end_anchor.to_offset(buffer.read(cx)); - - let query = buffer.read_with(cx, |buffer, _| { - let mut query = String::new(); - for ch in buffer.reversed_chars_at(end_offset).take(100) { - if ch == ':' { - let next_char = buffer - .reversed_chars_at(end_offset - query.len() - 1) - .next(); - // Ensure we are at the start of the message or that the previous character is a whitespace - if next_char.is_none() || next_char.unwrap().is_whitespace() { - return Some(query.chars().rev().collect::()); - } - - // If the previous character is not a whitespace, we are in the middle of a word - // and we only want to complete the shortcode if the word is made up of other emojis - let mut containing_word = String::new(); - for ch in buffer - .reversed_chars_at(end_offset - query.len() - 1) - .take(100) - { - if ch.is_whitespace() { - break; - } - containing_word.push(ch); - } - let 
containing_word = containing_word.chars().rev().collect::(); - if util::word_consists_of_emojis(containing_word.as_str()) { - return Some(query.chars().rev().collect::()); - } - break; - } - if ch.is_whitespace() || !ch.is_ascii() { - break; - } - query.push(ch); - } - None - })?; - - let start_offset = end_offset - query.len() - 1; - let start_anchor = buffer.read(cx).anchor_before(start_offset); - - Some((start_anchor, query, &EMOJI_FUZZY_MATCH_CANDIDATES)) - } - - async fn find_mentions( - this: WeakEntity, - buffer: BufferSnapshot, - cx: &mut AsyncWindowContext, - ) { - let (buffer, ranges) = cx - .background_spawn(async move { - let ranges = MENTIONS_SEARCH.search(&buffer, None).await; - (buffer, ranges) - }) - .await; - - this.update(cx, |this, cx| { - let mut anchor_ranges = Vec::new(); - let mut mentioned_user_ids = Vec::new(); - let mut text = String::new(); - - this.editor.update(cx, |editor, cx| { - let multi_buffer = editor.buffer().read(cx).snapshot(cx); - for range in ranges { - text.clear(); - text.extend(buffer.text_for_range(range.clone())); - if let Some(username) = text.strip_prefix('@') { - if let Some(user) = this - .user_store - .read(cx) - .cached_user_by_github_login(username) - { - let start = multi_buffer.anchor_after(range.start); - let end = multi_buffer.anchor_after(range.end); - - mentioned_user_ids.push(user.id); - anchor_ranges.push(start..end); - } - } - } - - editor.clear_highlights::(cx); - editor.highlight_text::( - anchor_ranges, - HighlightStyle { - font_weight: Some(FontWeight::BOLD), - ..Default::default() - }, - cx, - ) - }); - - this.mentions = mentioned_user_ids; - this.mentions_task.take(); - }) - .ok(); - } - - pub(crate) fn focus_handle(&self, cx: &gpui::App) -> gpui::FocusHandle { - self.editor.read(cx).focus_handle(cx) - } -} - -impl Render for MessageEditor { - fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let settings = ThemeSettings::get_global(cx); - let text_style = TextStyle { - 
color: if self.editor.read(cx).read_only(cx) { - cx.theme().colors().text_disabled - } else { - cx.theme().colors().text - }, - font_family: settings.ui_font.family.clone(), - font_features: settings.ui_font.features.clone(), - font_fallbacks: settings.ui_font.fallbacks.clone(), - font_size: TextSize::Small.rems(cx).into(), - font_weight: settings.ui_font.weight, - font_style: FontStyle::Normal, - line_height: relative(1.3), - ..Default::default() - }; - - div() - .w_full() - .px_2() - .py_1() - .bg(cx.theme().colors().editor_background) - .rounded_sm() - .child(EditorElement::new( - &self.editor, - EditorStyle { - local_player: cx.theme().players().local(), - text: text_style, - ..Default::default() - }, - )) - } -} diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 430b447580ed385f5b483f8d9fff8a6492c005d7..82e0f84105b57baa47999db9e086542a4f99adf7 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -2,7 +2,7 @@ mod channel_modal; mod contact_finder; use self::channel_modal::ChannelModal; -use crate::{CollaborationPanelSettings, channel_view::ChannelView, chat_panel::ChatPanel}; +use crate::{CollaborationPanelSettings, channel_view::ChannelView}; use anyhow::Context as _; use call::ActiveCall; use channel::{Channel, ChannelEvent, ChannelStore}; @@ -38,7 +38,7 @@ use util::{ResultExt, TryFutureExt, maybe}; use workspace::{ Deafen, LeaveCall, Mute, OpenChannelNotes, ScreenShare, ShareProject, Workspace, dock::{DockPosition, Panel, PanelEvent}, - notifications::{DetachAndPromptErr, NotifyResultExt, NotifyTaskExt}, + notifications::{DetachAndPromptErr, NotifyResultExt}, }; actions!( @@ -95,7 +95,7 @@ pub fn init(cx: &mut App) { .and_then(|room| room.read(cx).channel_id()); if let Some(channel_id) = channel_id { - let workspace = cx.entity().clone(); + let workspace = cx.entity(); window.defer(cx, move |window, cx| { ChannelView::open(channel_id, None, workspace, window, cx) 
.detach_and_log_err(cx) @@ -261,9 +261,6 @@ enum ListEntry { ChannelNotes { channel_id: ChannelId, }, - ChannelChat { - channel_id: ChannelId, - }, ChannelEditor { depth: usize, }, @@ -283,7 +280,7 @@ impl CollabPanel { cx.new(|cx| { let filter_editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text("Filter...", cx); + editor.set_placeholder_text("Filter...", window, cx); editor }); @@ -311,10 +308,10 @@ impl CollabPanel { window, |this: &mut Self, _, event, window, cx| { if let editor::EditorEvent::Blurred = event { - if let Some(state) = &this.channel_editing_state { - if state.pending_name().is_some() { - return; - } + if let Some(state) = &this.channel_editing_state + && state.pending_name().is_some() + { + return; } this.take_editing_state(window, cx); this.update_entries(false, cx); @@ -491,11 +488,10 @@ impl CollabPanel { if !self.collapsed_sections.contains(&Section::ActiveCall) { let room = room.read(cx); - if query.is_empty() { - if let Some(channel_id) = room.channel_id() { - self.entries.push(ListEntry::ChannelNotes { channel_id }); - self.entries.push(ListEntry::ChannelChat { channel_id }); - } + if query.is_empty() + && let Some(channel_id) = room.channel_id() + { + self.entries.push(ListEntry::ChannelNotes { channel_id }); } // Populate the active user. @@ -639,10 +635,10 @@ impl CollabPanel { &Default::default(), executor.clone(), )); - if let Some(state) = &self.channel_editing_state { - if matches!(state, ChannelEditingState::Create { location: None, .. }) { - self.entries.push(ListEntry::ChannelEditor { depth: 0 }); - } + if let Some(state) = &self.channel_editing_state + && matches!(state, ChannelEditingState::Create { location: None, .. 
}) + { + self.entries.push(ListEntry::ChannelEditor { depth: 0 }); } let mut collapse_depth = None; for mat in matches { @@ -664,9 +660,7 @@ impl CollabPanel { let has_children = channel_store .channel_at_index(mat.candidate_id + 1) - .map_or(false, |next_channel| { - next_channel.parent_path.ends_with(&[channel.id]) - }); + .is_some_and(|next_channel| next_channel.parent_path.ends_with(&[channel.id])); match &self.channel_editing_state { Some(ChannelEditingState::Create { @@ -1091,41 +1085,8 @@ impl CollabPanel { .tooltip(Tooltip::text("Open Channel Notes")) } - fn render_channel_chat( - &self, - channel_id: ChannelId, - is_selected: bool, - window: &mut Window, - cx: &mut Context, - ) -> impl IntoElement { - let channel_store = self.channel_store.read(cx); - let has_messages_notification = channel_store.has_new_messages(channel_id); - ListItem::new("channel-chat") - .toggle_state(is_selected) - .on_click(cx.listener(move |this, _, window, cx| { - this.join_channel_chat(channel_id, window, cx); - })) - .start_slot( - h_flex() - .relative() - .gap_1() - .child(render_tree_branch(false, false, window, cx)) - .child(IconButton::new(0, IconName::Chat)) - .children(has_messages_notification.then(|| { - div() - .w_1p5() - .absolute() - .right(px(2.)) - .top(px(4.)) - .child(Indicator::dot().color(Color::Info)) - })), - ) - .child(Label::new("chat")) - .tooltip(Tooltip::text("Open Chat")) - } - fn has_subchannels(&self, ix: usize) -> bool { - self.entries.get(ix).map_or(false, |entry| { + self.entries.get(ix).is_some_and(|entry| { if let ListEntry::Channel { has_children, .. 
} = entry { *has_children } else { @@ -1142,7 +1103,7 @@ impl CollabPanel { window: &mut Window, cx: &mut Context, ) { - let this = cx.entity().clone(); + let this = cx.entity(); if !(role == proto::ChannelRole::Guest || role == proto::ChannelRole::Talker || role == proto::ChannelRole::Member) @@ -1272,7 +1233,7 @@ impl CollabPanel { .channel_for_id(clipboard.channel_id) .map(|channel| channel.name.clone()) }); - let this = cx.entity().clone(); + let this = cx.entity(); let context_menu = ContextMenu::build(window, cx, |mut context_menu, window, cx| { if self.has_subchannels(ix) { @@ -1298,13 +1259,6 @@ impl CollabPanel { this.open_channel_notes(channel_id, window, cx) }), ) - .entry( - "Open Chat", - None, - window.handler_for(&this, move |this, window, cx| { - this.join_channel_chat(channel_id, window, cx) - }), - ) .entry( "Copy Channel Link", None, @@ -1439,7 +1393,7 @@ impl CollabPanel { window: &mut Window, cx: &mut Context, ) { - let this = cx.entity().clone(); + let this = cx.entity(); let in_room = ActiveCall::global(cx).read(cx).room().is_some(); let context_menu = ContextMenu::build(window, cx, |mut context_menu, _, _| { @@ -1552,98 +1506,90 @@ impl CollabPanel { return; } - if let Some(selection) = self.selection { - if let Some(entry) = self.entries.get(selection) { - match entry { - ListEntry::Header(section) => match section { - Section::ActiveCall => Self::leave_call(window, cx), - Section::Channels => self.new_root_channel(window, cx), - Section::Contacts => self.toggle_contact_finder(window, cx), - Section::ContactRequests - | Section::Online - | Section::Offline - | Section::ChannelInvites => { - self.toggle_section_expanded(*section, cx); - } - }, - ListEntry::Contact { contact, calling } => { - if contact.online && !contact.busy && !calling { - self.call(contact.user.id, window, cx); - } + if let Some(selection) = self.selection + && let Some(entry) = self.entries.get(selection) + { + match entry { + ListEntry::Header(section) => match section 
{ + Section::ActiveCall => Self::leave_call(window, cx), + Section::Channels => self.new_root_channel(window, cx), + Section::Contacts => self.toggle_contact_finder(window, cx), + Section::ContactRequests + | Section::Online + | Section::Offline + | Section::ChannelInvites => { + self.toggle_section_expanded(*section, cx); } - ListEntry::ParticipantProject { - project_id, - host_user_id, - .. - } => { - if let Some(workspace) = self.workspace.upgrade() { - let app_state = workspace.read(cx).app_state().clone(); - workspace::join_in_room_project( - *project_id, - *host_user_id, - app_state, - cx, - ) + }, + ListEntry::Contact { contact, calling } => { + if contact.online && !contact.busy && !calling { + self.call(contact.user.id, window, cx); + } + } + ListEntry::ParticipantProject { + project_id, + host_user_id, + .. + } => { + if let Some(workspace) = self.workspace.upgrade() { + let app_state = workspace.read(cx).app_state().clone(); + workspace::join_in_room_project(*project_id, *host_user_id, app_state, cx) .detach_and_prompt_err( "Failed to join project", window, cx, |_, _, _| None, ); - } } - ListEntry::ParticipantScreen { peer_id, .. } => { - let Some(peer_id) = peer_id else { - return; - }; - if let Some(workspace) = self.workspace.upgrade() { - workspace.update(cx, |workspace, cx| { - workspace.open_shared_screen(*peer_id, window, cx) - }); - } - } - ListEntry::Channel { channel, .. } => { - let is_active = maybe!({ - let call_channel = ActiveCall::global(cx) - .read(cx) - .room()? - .read(cx) - .channel_id()?; - - Some(call_channel == channel.id) - }) - .unwrap_or(false); - if is_active { - self.open_channel_notes(channel.id, window, cx) - } else { - self.join_channel(channel.id, window, cx) - } - } - ListEntry::ContactPlaceholder => self.toggle_contact_finder(window, cx), - ListEntry::CallParticipant { user, peer_id, .. 
} => { - if Some(user) == self.user_store.read(cx).current_user().as_ref() { - Self::leave_call(window, cx); - } else if let Some(peer_id) = peer_id { - self.workspace - .update(cx, |workspace, cx| workspace.follow(*peer_id, window, cx)) - .ok(); - } - } - ListEntry::IncomingRequest(user) => { - self.respond_to_contact_request(user.id, true, window, cx) - } - ListEntry::ChannelInvite(channel) => { - self.respond_to_channel_invite(channel.id, true, cx) + } + ListEntry::ParticipantScreen { peer_id, .. } => { + let Some(peer_id) = peer_id else { + return; + }; + if let Some(workspace) = self.workspace.upgrade() { + workspace.update(cx, |workspace, cx| { + workspace.open_shared_screen(*peer_id, window, cx) + }); } - ListEntry::ChannelNotes { channel_id } => { - self.open_channel_notes(*channel_id, window, cx) + } + ListEntry::Channel { channel, .. } => { + let is_active = maybe!({ + let call_channel = ActiveCall::global(cx) + .read(cx) + .room()? + .read(cx) + .channel_id()?; + + Some(call_channel == channel.id) + }) + .unwrap_or(false); + if is_active { + self.open_channel_notes(channel.id, window, cx) + } else { + self.join_channel(channel.id, window, cx) } - ListEntry::ChannelChat { channel_id } => { - self.join_channel_chat(*channel_id, window, cx) + } + ListEntry::ContactPlaceholder => self.toggle_contact_finder(window, cx), + ListEntry::CallParticipant { user, peer_id, .. } => { + if Some(user) == self.user_store.read(cx).current_user().as_ref() { + Self::leave_call(window, cx); + } else if let Some(peer_id) = peer_id { + self.workspace + .update(cx, |workspace, cx| workspace.follow(*peer_id, window, cx)) + .ok(); } - ListEntry::OutgoingRequest(_) => {} - ListEntry::ChannelEditor { .. 
} => {} } + ListEntry::IncomingRequest(user) => { + self.respond_to_contact_request(user.id, true, window, cx) + } + ListEntry::ChannelInvite(channel) => { + self.respond_to_channel_invite(channel.id, true, cx) + } + ListEntry::ChannelNotes { channel_id } => { + self.open_channel_notes(*channel_id, window, cx) + } + ListEntry::OutgoingRequest(_) => {} + ListEntry::ChannelEditor { .. } => {} } } } @@ -1828,10 +1774,10 @@ impl CollabPanel { } fn select_channel_editor(&mut self) { - self.selection = self.entries.iter().position(|entry| match entry { - ListEntry::ChannelEditor { .. } => true, - _ => false, - }); + self.selection = self + .entries + .iter() + .position(|entry| matches!(entry, ListEntry::ChannelEditor { .. })); } fn new_subchannel( @@ -2265,28 +2211,6 @@ impl CollabPanel { .detach_and_prompt_err("Failed to join channel", window, cx, |_, _, _| None) } - fn join_channel_chat( - &mut self, - channel_id: ChannelId, - window: &mut Window, - cx: &mut Context, - ) { - let Some(workspace) = self.workspace.upgrade() else { - return; - }; - window.defer(cx, move |window, cx| { - workspace.update(cx, |workspace, cx| { - if let Some(panel) = workspace.focus_panel::(window, cx) { - panel.update(cx, |panel, cx| { - panel - .select_channel(channel_id, None, cx) - .detach_and_notify_err(window, cx); - }); - } - }); - }); - } - fn copy_channel_link(&mut self, channel_id: ChannelId, cx: &mut Context) { let channel_store = self.channel_store.read(cx); let Some(channel) = channel_store.channel_for_id(channel_id) else { @@ -2317,7 +2241,7 @@ impl CollabPanel { let client = this.client.clone(); cx.spawn_in(window, async move |_, cx| { client - .connect(true, &cx) + .connect(true, cx) .await .into_response() .notify_async_err(cx); @@ -2405,9 +2329,6 @@ impl CollabPanel { ListEntry::ChannelNotes { channel_id } => self .render_channel_notes(*channel_id, is_selected, window, cx) .into_any_element(), - ListEntry::ChannelChat { channel_id } => self - 
.render_channel_chat(*channel_id, is_selected, window, cx) - .into_any_element(), } } @@ -2514,7 +2435,7 @@ impl CollabPanel { let button = match section { Section::ActiveCall => channel_link.map(|channel_link| { - let channel_link_copy = channel_link.clone(); + let channel_link_copy = channel_link; IconButton::new("channel-link", IconName::Copy) .icon_size(IconSize::Small) .size(ButtonSize::None) @@ -2698,7 +2619,7 @@ impl CollabPanel { h_flex() .w_full() .justify_between() - .child(Label::new(github_login.clone())) + .child(Label::new(github_login)) .child(h_flex().children(controls)), ) .start_slot(Avatar::new(user.avatar_uri.clone())) @@ -2788,7 +2709,6 @@ impl CollabPanel { let disclosed = has_children.then(|| self.collapsed_channels.binary_search(&channel.id).is_err()); - let has_messages_notification = channel_store.has_new_messages(channel_id); let has_notes_notification = channel_store.has_channel_buffer_changed(channel_id); const FACEPILE_LIMIT: usize = 3; @@ -2912,24 +2832,10 @@ impl CollabPanel { h_flex().absolute().right(rems(0.)).h_full().child( h_flex() .h_full() + .bg(cx.theme().colors().background) + .rounded_l_sm() .gap_1() .px_1() - .child( - IconButton::new("channel_chat", IconName::Chat) - .style(ButtonStyle::Filled) - .shape(ui::IconButtonShape::Square) - .icon_size(IconSize::Small) - .icon_color(if has_messages_notification { - Color::Default - } else { - Color::Muted - }) - .on_click(cx.listener(move |this, _, window, cx| { - this.join_channel_chat(channel_id, window, cx) - })) - .tooltip(Tooltip::text("Open channel chat")) - .visible_on_hover(""), - ) .child( IconButton::new("channel_notes", IconName::Reader) .style(ButtonStyle::Filled) @@ -2943,9 +2849,9 @@ impl CollabPanel { .on_click(cx.listener(move |this, _, window, cx| { this.open_channel_notes(channel_id, window, cx) })) - .tooltip(Tooltip::text("Open channel notes")) - .visible_on_hover(""), - ), + .tooltip(Tooltip::text("Open channel notes")), + ) + .visible_on_hover(""), ), ) 
.tooltip({ @@ -3053,7 +2959,7 @@ impl Render for CollabPanel { .on_action(cx.listener(CollabPanel::move_channel_down)) .track_focus(&self.focus_handle) .size_full() - .child(if !self.client.status().borrow().is_connected() { + .child(if !self.client.status().borrow().is_or_was_connected() { self.render_signed_out(cx) } else { self.render_signed_in(window, cx) @@ -3132,7 +3038,7 @@ impl Panel for CollabPanel { impl Focusable for CollabPanel { fn focus_handle(&self, cx: &App) -> gpui::FocusHandle { - self.filter_editor.focus_handle(cx).clone() + self.filter_editor.focus_handle(cx) } } @@ -3189,14 +3095,6 @@ impl PartialEq for ListEntry { return channel_id == other_id; } } - ListEntry::ChannelChat { channel_id } => { - if let ListEntry::ChannelChat { - channel_id: other_id, - } = other - { - return channel_id == other_id; - } - } ListEntry::ChannelInvite(channel_1) => { if let ListEntry::ChannelInvite(channel_2) = other { return channel_1.id == channel_2.id; diff --git a/crates/collab_ui/src/collab_panel/channel_modal.rs b/crates/collab_ui/src/collab_panel/channel_modal.rs index c0d3130ee997e3fe2ffffc4b228de9e512f18340..e558835dbaf0e34e2efa1b4f64fd8f6cb96016c5 100644 --- a/crates/collab_ui/src/collab_panel/channel_modal.rs +++ b/crates/collab_ui/src/collab_panel/channel_modal.rs @@ -586,7 +586,7 @@ impl ChannelModalDelegate { return; }; let user_id = membership.user.id; - let picker = cx.entity().clone(); + let picker = cx.entity(); let context_menu = ContextMenu::build(window, cx, |mut menu, _window, _cx| { let role = membership.role; diff --git a/crates/collab_ui/src/collab_panel/contact_finder.rs b/crates/collab_ui/src/collab_panel/contact_finder.rs index 3c23ccc017838e8b97ec334dd432840e516ed413..e5823d0e78d9bf73ae3ded307116f608d7c06b22 100644 --- a/crates/collab_ui/src/collab_panel/contact_finder.rs +++ b/crates/collab_ui/src/collab_panel/contact_finder.rs @@ -148,7 +148,7 @@ impl PickerDelegate for ContactFinderDelegate { _: &mut Window, cx: &mut Context>, ) -> 
Option { - let user = &self.potential_contacts[ix]; + let user = &self.potential_contacts.get(ix)?; let request_status = self.user_store.read(cx).contact_request_status(user); let icon_path = match request_status { diff --git a/crates/collab_ui/src/collab_ui.rs b/crates/collab_ui/src/collab_ui.rs index f9a2fa492562a89f66459510b1c4aa99edf57080..f75dd663c838c84f167b3070b50a4e1f44e9aa2d 100644 --- a/crates/collab_ui/src/collab_ui.rs +++ b/crates/collab_ui/src/collab_ui.rs @@ -1,5 +1,4 @@ pub mod channel_view; -pub mod chat_panel; pub mod collab_panel; pub mod notification_panel; pub mod notifications; @@ -13,9 +12,7 @@ use gpui::{ WindowDecorations, WindowKind, WindowOptions, point, }; use panel_settings::MessageEditorSettings; -pub use panel_settings::{ - ChatPanelButton, ChatPanelSettings, CollaborationPanelSettings, NotificationPanelSettings, -}; +pub use panel_settings::{CollaborationPanelSettings, NotificationPanelSettings}; use release_channel::ReleaseChannel; use settings::Settings; use ui::px; @@ -23,12 +20,10 @@ use workspace::AppState; pub fn init(app_state: &Arc, cx: &mut App) { CollaborationPanelSettings::register(cx); - ChatPanelSettings::register(cx); NotificationPanelSettings::register(cx); MessageEditorSettings::register(cx); channel_view::init(cx); - chat_panel::init(cx); collab_panel::init(cx); notification_panel::init(cx); notifications::init(app_state, cx); @@ -66,5 +61,7 @@ fn notification_window_options( app_id: Some(app_id.to_owned()), window_min_size: None, window_decorations: Some(WindowDecorations::Client), + tabbing_identifier: None, + ..Default::default() } } diff --git a/crates/collab_ui/src/notification_panel.rs b/crates/collab_ui/src/notification_panel.rs index 3a280ff6677c9a5f9598d5ecaf473af232a8fed1..9731b89521e29ebda21ad5ce2cfca6e0531ae437 100644 --- a/crates/collab_ui/src/notification_panel.rs +++ b/crates/collab_ui/src/notification_panel.rs @@ -1,4 +1,4 @@ -use crate::{NotificationPanelSettings, chat_panel::ChatPanel}; +use 
crate::NotificationPanelSettings; use anyhow::Result; use channel::ChannelStore; use client::{ChannelId, Client, Notification, User, UserStore}; @@ -6,8 +6,8 @@ use collections::HashMap; use db::kvp::KEY_VALUE_STORE; use futures::StreamExt; use gpui::{ - AnyElement, App, AsyncWindowContext, ClickEvent, Context, CursorStyle, DismissEvent, Element, - Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, ListAlignment, + AnyElement, App, AsyncWindowContext, ClickEvent, Context, DismissEvent, Element, Entity, + EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, ListAlignment, ListScrollEvent, ListState, ParentElement, Render, StatefulInteractiveElement, Styled, Task, WeakEntity, Window, actions, div, img, list, px, }; @@ -71,7 +71,6 @@ pub struct NotificationPresenter { pub text: String, pub icon: &'static str, pub needs_response: bool, - pub can_navigate: bool, } actions!( @@ -121,13 +120,12 @@ impl NotificationPanel { let notification_list = ListState::new(0, ListAlignment::Top, px(1000.)); notification_list.set_scroll_handler(cx.listener( |this, event: &ListScrollEvent, _, cx| { - if event.count.saturating_sub(event.visible_range.end) < LOADING_THRESHOLD { - if let Some(task) = this + if event.count.saturating_sub(event.visible_range.end) < LOADING_THRESHOLD + && let Some(task) = this .notification_store .update(cx, |store, cx| store.load_more_notifications(false, cx)) - { - task.detach(); - } + { + task.detach(); } }, )); @@ -235,7 +233,6 @@ impl NotificationPanel { actor, text, needs_response, - can_navigate, .. 
} = self.present_notification(entry, cx)?; @@ -270,14 +267,6 @@ impl NotificationPanel { .py_1() .gap_2() .hover(|style| style.bg(cx.theme().colors().element_hover)) - .when(can_navigate, |el| { - el.cursor(CursorStyle::PointingHand).on_click({ - let notification = notification.clone(); - cx.listener(move |this, _, window, cx| { - this.did_click_notification(¬ification, window, cx) - }) - }) - }) .children(actor.map(|actor| { img(actor.avatar_uri.clone()) .flex_none() @@ -290,7 +279,7 @@ impl NotificationPanel { .gap_1() .size_full() .overflow_hidden() - .child(Label::new(text.clone())) + .child(Label::new(text)) .child( h_flex() .child( @@ -321,7 +310,7 @@ impl NotificationPanel { .justify_end() .child(Button::new("decline", "Decline").on_click({ let notification = notification.clone(); - let entity = cx.entity().clone(); + let entity = cx.entity(); move |_, _, cx| { entity.update(cx, |this, cx| { this.respond_to_notification( @@ -334,7 +323,7 @@ impl NotificationPanel { })) .child(Button::new("accept", "Accept").on_click({ let notification = notification.clone(); - let entity = cx.entity().clone(); + let entity = cx.entity(); move |_, _, cx| { entity.update(cx, |this, cx| { this.respond_to_notification( @@ -370,7 +359,6 @@ impl NotificationPanel { text: format!("{} wants to add you as a contact", requester.github_login), needs_response: user_store.has_incoming_contact_request(requester.id), actor: Some(requester), - can_navigate: false, }) } Notification::ContactRequestAccepted { responder_id } => { @@ -380,7 +368,6 @@ impl NotificationPanel { text: format!("{} accepted your contact invite", responder.github_login), needs_response: false, actor: Some(responder), - can_navigate: false, }) } Notification::ChannelInvitation { @@ -397,29 +384,6 @@ impl NotificationPanel { ), needs_response: channel_store.has_channel_invitation(ChannelId(channel_id)), actor: Some(inviter), - can_navigate: false, - }) - } - Notification::ChannelMessageMention { - sender_id, - 
channel_id, - message_id, - } => { - let sender = user_store.get_cached_user(sender_id)?; - let channel = channel_store.channel_for_id(ChannelId(channel_id))?; - let message = self - .notification_store - .read(cx) - .channel_message_for_id(message_id)?; - Some(NotificationPresenter { - icon: "icons/conversations.svg", - text: format!( - "{} mentioned you in #{}:\n{}", - sender.github_login, channel.name, message.body, - ), - needs_response: false, - actor: Some(sender), - can_navigate: true, }) } } @@ -434,9 +398,7 @@ impl NotificationPanel { ) { let should_mark_as_read = match notification { Notification::ContactRequestAccepted { .. } => true, - Notification::ContactRequest { .. } - | Notification::ChannelInvitation { .. } - | Notification::ChannelMessageMention { .. } => false, + Notification::ContactRequest { .. } | Notification::ChannelInvitation { .. } => false, }; if should_mark_as_read { @@ -458,56 +420,6 @@ impl NotificationPanel { } } - fn did_click_notification( - &mut self, - notification: &Notification, - window: &mut Window, - cx: &mut Context, - ) { - if let Notification::ChannelMessageMention { - message_id, - channel_id, - .. - } = notification.clone() - { - if let Some(workspace) = self.workspace.upgrade() { - window.defer(cx, move |window, cx| { - workspace.update(cx, |workspace, cx| { - if let Some(panel) = workspace.focus_panel::(window, cx) { - panel.update(cx, |panel, cx| { - panel - .select_channel(ChannelId(channel_id), Some(message_id), cx) - .detach_and_log_err(cx); - }); - } - }); - }); - } - } - } - - fn is_showing_notification(&self, notification: &Notification, cx: &mut Context) -> bool { - if !self.active { - return false; - } - - if let Notification::ChannelMessageMention { channel_id, .. 
} = ¬ification { - if let Some(workspace) = self.workspace.upgrade() { - return if let Some(panel) = workspace.read(cx).panel::(cx) { - let panel = panel.read(cx); - panel.is_scrolled_to_bottom() - && panel - .active_chat() - .map_or(false, |chat| chat.read(cx).channel_id.0 == *channel_id) - } else { - false - }; - } - } - - false - } - fn on_notification_event( &mut self, _: &Entity, @@ -517,9 +429,7 @@ impl NotificationPanel { ) { match event { NotificationEvent::NewNotification { entry } => { - if !self.is_showing_notification(&entry.notification, cx) { - self.unseen_notifications.push(entry.clone()); - } + self.unseen_notifications.push(entry.clone()); self.add_toast(entry, window, cx); } NotificationEvent::NotificationRemoved { entry } @@ -543,10 +453,6 @@ impl NotificationPanel { window: &mut Window, cx: &mut Context, ) { - if self.is_showing_notification(&entry.notification, cx) { - return; - } - let Some(NotificationPresenter { actor, text, .. }) = self.present_notification(entry, cx) else { return; @@ -570,7 +476,6 @@ impl NotificationPanel { workspace.show_notification(id, cx, |cx| { let workspace = cx.entity().downgrade(); cx.new(|cx| NotificationToast { - notification_id, actor, text, workspace, @@ -582,16 +487,16 @@ impl NotificationPanel { } fn remove_toast(&mut self, notification_id: u64, cx: &mut Context) { - if let Some((current_id, _)) = &self.current_notification_toast { - if *current_id == notification_id { - self.current_notification_toast.take(); - self.workspace - .update(cx, |workspace, cx| { - let id = NotificationId::unique::(); - workspace.dismiss_notification(&id, cx) - }) - .ok(); - } + if let Some((current_id, _)) = &self.current_notification_toast + && *current_id == notification_id + { + self.current_notification_toast.take(); + self.workspace + .update(cx, |workspace, cx| { + let id = NotificationId::unique::(); + workspace.dismiss_notification(&id, cx) + }) + .ok(); } } @@ -643,7 +548,7 @@ impl Render for NotificationPanel { let 
client = client.clone(); window .spawn(cx, async move |cx| { - match client.connect(true, &cx).await { + match client.connect(true, cx).await { util::ConnectionResult::Timeout => { log::error!("Connection timeout"); } @@ -783,7 +688,6 @@ impl Panel for NotificationPanel { } pub struct NotificationToast { - notification_id: u64, actor: Option>, text: String, workspace: WeakEntity, @@ -801,22 +705,10 @@ impl WorkspaceNotification for NotificationToast {} impl NotificationToast { fn focus_notification_panel(&self, window: &mut Window, cx: &mut Context) { let workspace = self.workspace.clone(); - let notification_id = self.notification_id; window.defer(cx, move |window, cx| { workspace .update(cx, |workspace, cx| { - if let Some(panel) = workspace.focus_panel::(window, cx) { - panel.update(cx, |panel, cx| { - let store = panel.notification_store.read(cx); - if let Some(entry) = store.notification_for_id(notification_id) { - panel.did_click_notification( - &entry.clone().notification, - window, - cx, - ); - } - }); - } + workspace.focus_panel::(window, cx) }) .ok(); }) diff --git a/crates/collab_ui/src/panel_settings.rs b/crates/collab_ui/src/panel_settings.rs index 652d9eb67f6ce1f0ab583e20e4feab05cfb743e3..98559ffd34006bf2f65427a899fd1fe5d41a4d11 100644 --- a/crates/collab_ui/src/panel_settings.rs +++ b/crates/collab_ui/src/panel_settings.rs @@ -1,7 +1,7 @@ use gpui::Pixels; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; use workspace::dock::DockPosition; #[derive(Deserialize, Debug)] @@ -11,31 +11,16 @@ pub struct CollaborationPanelSettings { pub default_width: Pixels, } -#[derive(Clone, Copy, Default, Serialize, Deserialize, JsonSchema, Debug)] -#[serde(rename_all = "snake_case")] -pub enum ChatPanelButton { - Never, - Always, - #[default] - WhenInCall, -} - -#[derive(Deserialize, Debug)] -pub struct ChatPanelSettings { - pub button: 
ChatPanelButton, - pub dock: DockPosition, - pub default_width: Pixels, -} - -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] -pub struct ChatPanelSettingsContent { - /// When to show the panel button in the status bar. +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug, SettingsUi, SettingsKey)] +#[settings_key(key = "collaboration_panel")] +pub struct PanelSettingsContent { + /// Whether to show the panel button in the status bar. /// - /// Default: only when in a call - pub button: Option, + /// Default: true + pub button: Option, /// Where to dock the panel. /// - /// Default: right + /// Default: left pub dock: Option, /// Default width of the panel in pixels. /// @@ -50,23 +35,25 @@ pub struct NotificationPanelSettings { pub default_width: Pixels, } -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] -pub struct PanelSettingsContent { +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug, SettingsUi, SettingsKey)] +#[settings_key(key = "notification_panel")] +pub struct NotificationPanelSettingsContent { /// Whether to show the panel button in the status bar. /// /// Default: true pub button: Option, /// Where to dock the panel. /// - /// Default: left + /// Default: right pub dock: Option, /// Default width of the panel in pixels. /// - /// Default: 240 + /// Default: 300 pub default_width: Option, } -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug, SettingsUi, SettingsKey)] +#[settings_key(key = "message_editor")] pub struct MessageEditorSettings { /// Whether to automatically replace emoji shortcodes with emoji characters. /// For example: typing `:wave:` gets replaced with `👋`. 
@@ -76,8 +63,6 @@ pub struct MessageEditorSettings { } impl Settings for CollaborationPanelSettings { - const KEY: Option<&'static str> = Some("collaboration_panel"); - type FileContent = PanelSettingsContent; fn load( @@ -90,25 +75,8 @@ impl Settings for CollaborationPanelSettings { fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} } -impl Settings for ChatPanelSettings { - const KEY: Option<&'static str> = Some("chat_panel"); - - type FileContent = ChatPanelSettingsContent; - - fn load( - sources: SettingsSources, - _: &mut gpui::App, - ) -> anyhow::Result { - sources.json_merge() - } - - fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} -} - impl Settings for NotificationPanelSettings { - const KEY: Option<&'static str> = Some("notification_panel"); - - type FileContent = PanelSettingsContent; + type FileContent = NotificationPanelSettingsContent; fn load( sources: SettingsSources, @@ -121,8 +89,6 @@ impl Settings for NotificationPanelSettings { } impl Settings for MessageEditorSettings { - const KEY: Option<&'static str> = Some("message_editor"); - type FileContent = MessageEditorSettings; fn load( diff --git a/crates/command_palette/src/command_palette.rs b/crates/command_palette/src/command_palette.rs index b8800ff91284e6f105c029f7fffe9b4b83b6bcd1..227d246f04cecf8a5c58f2361d0b543ff678eac6 100644 --- a/crates/command_palette/src/command_palette.rs +++ b/crates/command_palette/src/command_palette.rs @@ -206,7 +206,7 @@ impl CommandPaletteDelegate { if parse_zed_link(&query, cx).is_some() { intercept_results = vec![CommandInterceptResult { action: OpenZedUrl { url: query.clone() }.boxed_clone(), - string: query.clone(), + string: query, positions: vec![], }] } diff --git a/crates/command_palette/src/persistence.rs b/crates/command_palette/src/persistence.rs index 5be97c36bc57cea59b51272270fd39ae1a9ab70d..01cf403083b2de4ed7919801ab33e4aae947007e 100644 --- 
a/crates/command_palette/src/persistence.rs +++ b/crates/command_palette/src/persistence.rs @@ -1,7 +1,10 @@ use anyhow::Result; use db::{ - define_connection, query, - sqlez::{bindable::Column, statement::Statement}, + query, + sqlez::{ + bindable::Column, domain::Domain, statement::Statement, + thread_safe_connection::ThreadSafeConnection, + }, sqlez_macros::sql, }; use serde::{Deserialize, Serialize}; @@ -50,8 +53,11 @@ impl Column for SerializedCommandInvocation { } } -define_connection!(pub static ref COMMAND_PALETTE_HISTORY: CommandPaletteDB<()> = - &[sql!( +pub struct CommandPaletteDB(ThreadSafeConnection); + +impl Domain for CommandPaletteDB { + const NAME: &str = stringify!(CommandPaletteDB); + const MIGRATIONS: &[&str] = &[sql!( CREATE TABLE IF NOT EXISTS command_invocations( id INTEGER PRIMARY KEY AUTOINCREMENT, command_name TEXT NOT NULL, @@ -59,7 +65,9 @@ define_connection!(pub static ref COMMAND_PALETTE_HISTORY: CommandPaletteDB<()> last_invoked INTEGER DEFAULT (unixepoch()) NOT NULL ) STRICT; )]; -); +} + +db::static_connection!(COMMAND_PALETTE_HISTORY, CommandPaletteDB, []); impl CommandPaletteDB { pub async fn write_command_invocation( diff --git a/crates/command_palette_hooks/src/command_palette_hooks.rs b/crates/command_palette_hooks/src/command_palette_hooks.rs index df64d53874b4907b3bf586ee7935302c2e6979ae..f1344c5ba6d46fce966ace60d483e3c0fc717f80 100644 --- a/crates/command_palette_hooks/src/command_palette_hooks.rs +++ b/crates/command_palette_hooks/src/command_palette_hooks.rs @@ -76,7 +76,7 @@ impl CommandPaletteFilter { } /// Hides all actions with the given types. - pub fn hide_action_types(&mut self, action_types: &[TypeId]) { + pub fn hide_action_types<'a>(&mut self, action_types: impl IntoIterator) { for action_type in action_types { self.hidden_action_types.insert(*action_type); self.shown_action_types.remove(action_type); @@ -84,7 +84,7 @@ impl CommandPaletteFilter { } /// Shows all actions with the given types. 
- pub fn show_action_types<'a>(&mut self, action_types: impl Iterator) { + pub fn show_action_types<'a>(&mut self, action_types: impl IntoIterator) { for action_type in action_types { self.shown_action_types.insert(*action_type); self.hidden_action_types.remove(action_type); diff --git a/crates/component/Cargo.toml b/crates/component/Cargo.toml index 92249de454d7140343cc6f814f6ac1bd99685cda..74481834f1cab5047dec3cd32121eb002fabbbbd 100644 --- a/crates/component/Cargo.toml +++ b/crates/component/Cargo.toml @@ -20,5 +20,8 @@ strum.workspace = true theme.workspace = true workspace-hack.workspace = true +[dev-dependencies] +documented.workspace = true + [features] default = [] diff --git a/crates/component/src/component.rs b/crates/component/src/component.rs index 0c05ba4a97f4598e9f7982cbc294831a955f1fc6..8c7b7ea4d7347ff087c84880c31df5d355870f65 100644 --- a/crates/component/src/component.rs +++ b/crates/component/src/component.rs @@ -227,6 +227,8 @@ pub trait Component { /// Example: /// /// ``` + /// use documented::Documented; + /// /// /// This is a doc comment. 
/// #[derive(Documented)] /// struct MyComponent; diff --git a/crates/component/src/component_layout.rs b/crates/component/src/component_layout.rs index 58bf1d8f0c85533a4a06bd38c07f840c08cc6de3..a840d520a62b57516f20c190f2a5148505ccfed4 100644 --- a/crates/component/src/component_layout.rs +++ b/crates/component/src/component_layout.rs @@ -42,7 +42,7 @@ impl RenderOnce for ComponentExample { div() .text_size(rems(0.875)) .text_color(cx.theme().colors().text_muted) - .child(description.clone()), + .child(description), ) }), ) diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 65283afa87d94fae3ec51f8a89574713080bded2..b3b44dbde67d92ce620d85a39a0925f27a4e2086 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -25,7 +25,7 @@ use crate::{ }; const JSON_RPC_VERSION: &str = "2.0"; -const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); +const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(60); // Standard JSON-RPC error codes pub const PARSE_ERROR: i32 = -32700; @@ -60,6 +60,7 @@ pub(crate) struct Client { executor: BackgroundExecutor, #[allow(dead_code)] transport: Arc, + request_timeout: Option, } #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -67,11 +68,7 @@ pub(crate) struct Client { pub(crate) struct ContextServerId(pub Arc); fn is_null_value(value: &T) -> bool { - if let Ok(Value::Null) = serde_json::to_value(value) { - true - } else { - false - } + matches!(serde_json::to_value(value), Ok(Value::Null)) } #[derive(Serialize, Deserialize)] @@ -147,6 +144,7 @@ pub struct ModelContextServerBinary { pub executable: PathBuf, pub args: Vec, pub env: Option>, + pub timeout: Option, } impl Client { @@ -161,7 +159,7 @@ impl Client { working_directory: &Option, cx: AsyncApp, ) -> Result { - log::info!( + log::debug!( "starting context server (executable={:?}, args={:?})", binary.executable, &binary.args @@ -173,8 +171,9 @@ impl Client { .map(|name| 
name.to_string_lossy().to_string()) .unwrap_or_else(String::new); + let timeout = binary.timeout.map(Duration::from_millis); let transport = Arc::new(StdioTransport::new(binary, working_directory, &cx)?); - Self::new(server_id, server_name.into(), transport, cx) + Self::new(server_id, server_name.into(), transport, timeout, cx) } /// Creates a new Client instance for a context server. @@ -182,6 +181,7 @@ impl Client { server_id: ContextServerId, server_name: Arc, transport: Arc, + request_timeout: Option, cx: AsyncApp, ) -> Result { let (outbound_tx, outbound_rx) = channel::unbounded::(); @@ -241,6 +241,7 @@ impl Client { io_tasks: Mutex::new(Some((input_task, output_task))), output_done_rx: Mutex::new(Some(output_done_rx)), transport, + request_timeout, }) } @@ -271,10 +272,10 @@ impl Client { ); } } else if let Ok(response) = serde_json::from_str::(&message) { - if let Some(handlers) = response_handlers.lock().as_mut() { - if let Some(handler) = handlers.remove(&response.id) { - handler(Ok(message.to_string())); - } + if let Some(handlers) = response_handlers.lock().as_mut() + && let Some(handler) = handlers.remove(&response.id) + { + handler(Ok(message.to_string())); } } else if let Ok(notification) = serde_json::from_str::(&message) { let mut notification_handlers = notification_handlers.lock(); @@ -295,7 +296,7 @@ impl Client { /// Continuously reads and logs any error messages from the server. 
async fn handle_err(transport: Arc) -> anyhow::Result<()> { while let Some(err) = transport.receive_err().next().await { - log::warn!("context server stderr: {}", err.trim()); + log::debug!("context server stderr: {}", err.trim()); } Ok(()) @@ -331,8 +332,13 @@ impl Client { method: &str, params: impl Serialize, ) -> Result { - self.request_with(method, params, None, Some(REQUEST_TIMEOUT)) - .await + self.request_with( + method, + params, + None, + self.request_timeout.or(Some(DEFAULT_REQUEST_TIMEOUT)), + ) + .await } pub async fn request_with( diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index 34fa29678d5d68f864de7d9df3bef82d4c667f05..b126bb393784664692b5de39fee5ed7f66e9948a 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -34,6 +34,8 @@ pub struct ContextServerCommand { pub path: PathBuf, pub args: Vec, pub env: Option>, + /// Timeout for tool calls in milliseconds. Defaults to 60000 (60 seconds) if not specified. 
+ pub timeout: Option, } impl std::fmt::Debug for ContextServerCommand { @@ -123,6 +125,7 @@ impl ContextServer { executable: Path::new(&command.path).to_path_buf(), args: command.args.clone(), env: command.env.clone(), + timeout: command.timeout, }, working_directory, cx.clone(), @@ -131,13 +134,14 @@ impl ContextServer { client::ContextServerId(self.id.0.clone()), self.id().0, transport.clone(), + None, cx.clone(), )?, }) } async fn initialize(&self, client: Client) -> Result<()> { - log::info!("starting context server {}", self.id); + log::debug!("starting context server {}", self.id); let protocol = crate::protocol::ModelContextProtocol::new(client); let client_info = types::Implementation { name: "Zed".to_string(), diff --git a/crates/context_server/src/listener.rs b/crates/context_server/src/listener.rs index 0e85fb21292739ab0a92d0898fc449a31efe6f29..4e5da2566ee25ee70e1687cf5f0806e19789a824 100644 --- a/crates/context_server/src/listener.rs +++ b/crates/context_server/src/listener.rs @@ -14,6 +14,7 @@ use serde::de::DeserializeOwned; use serde_json::{json, value::RawValue}; use smol::stream::StreamExt; use std::{ + any::TypeId, cell::RefCell, path::{Path, PathBuf}, rc::Rc, @@ -77,7 +78,7 @@ impl McpServer { socket_path, _server_task: server_task, tools, - handlers: handlers, + handlers, }) }) } @@ -87,23 +88,30 @@ impl McpServer { settings.inline_subschemas = true; let mut generator = settings.into_generator(); - let output_schema = generator.root_schema_for::(); - let unit_schema = generator.root_schema_for::(); + let input_schema = generator.root_schema_for::(); + + let description = input_schema + .get("description") + .and_then(|desc| desc.as_str()) + .map(|desc| desc.to_string()); + debug_assert!( + description.is_some(), + "Input schema struct must include a doc comment for the tool description" + ); let registered_tool = RegisteredTool { tool: Tool { name: T::NAME.into(), - description: Some(tool.description().into()), - input_schema: 
generator.root_schema_for::().into(), - output_schema: if output_schema == unit_schema { + description, + input_schema: input_schema.into(), + output_schema: if TypeId::of::() == TypeId::of::<()>() { None } else { - Some(output_schema.into()) + Some(generator.root_schema_for::().into()) }, annotations: Some(tool.annotations()), }, handler: Box::new({ - let tool = tool.clone(); move |input_value, cx| { let input = match input_value { Some(input) => serde_json::from_value(input), @@ -315,12 +323,12 @@ impl McpServer { Self::send_err( request_id, format!("Tool not found: {}", params.name), - &outgoing_tx, + outgoing_tx, ); } } Err(err) => { - Self::send_err(request_id, err.to_string(), &outgoing_tx); + Self::send_err(request_id, err.to_string(), outgoing_tx); } } } @@ -399,8 +407,6 @@ pub trait McpServerTool { const NAME: &'static str; - fn description(&self) -> &'static str; - fn annotations(&self) -> ToolAnnotations { ToolAnnotations { title: None, @@ -418,6 +424,7 @@ pub trait McpServerTool { ) -> impl Future>>; } +#[derive(Debug)] pub struct ToolResponse { pub content: Vec, pub structured_content: T, diff --git a/crates/context_server/src/test.rs b/crates/context_server/src/test.rs index dedf589664215a733b7d6bd5c2273af246863f42..008542ab246bc2d68a62d779e985e5941ac16856 100644 --- a/crates/context_server/src/test.rs +++ b/crates/context_server/src/test.rs @@ -1,6 +1,6 @@ use anyhow::Context as _; use collections::HashMap; -use futures::{Stream, StreamExt as _, lock::Mutex}; +use futures::{FutureExt, Stream, StreamExt as _, future::BoxFuture, lock::Mutex}; use gpui::BackgroundExecutor; use std::{pin::Pin, sync::Arc}; @@ -14,9 +14,12 @@ pub fn create_fake_transport( executor: BackgroundExecutor, ) -> FakeTransport { let name = name.into(); - FakeTransport::new(executor).on_request::(move |_params| { - create_initialize_response(name.clone()) - }) + FakeTransport::new(executor).on_request::( + move |_params| { + let name = name.clone(); + async move { 
create_initialize_response(name.clone()) } + }, + ) } fn create_initialize_response(server_name: String) -> InitializeResponse { @@ -32,8 +35,10 @@ fn create_initialize_response(server_name: String) -> InitializeResponse { } pub struct FakeTransport { - request_handlers: - HashMap<&'static str, Arc serde_json::Value + Send + Sync>>, + request_handlers: HashMap< + &'static str, + Arc BoxFuture<'static, serde_json::Value>>, + >, tx: futures::channel::mpsc::UnboundedSender, rx: Arc>>, executor: BackgroundExecutor, @@ -50,18 +55,25 @@ impl FakeTransport { } } - pub fn on_request( + pub fn on_request( mut self, - handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static, - ) -> Self { + handler: impl 'static + Send + Sync + Fn(T::Params) -> Fut, + ) -> Self + where + T: crate::types::Request, + Fut: 'static + Send + Future, + { self.request_handlers.insert( T::METHOD, Arc::new(move |value| { - let params = value.get("params").expect("Missing parameters").clone(); + let params = value + .get("params") + .cloned() + .unwrap_or(serde_json::Value::Null); let params: T::Params = serde_json::from_value(params).expect("Invalid parameters received"); let response = handler(params); - serde_json::to_value(response).unwrap() + async move { serde_json::to_value(response.await).unwrap() }.boxed() }), ); self @@ -77,7 +89,7 @@ impl Transport for FakeTransport { if let Some(method) = msg.get("method") { let method = method.as_str().expect("Invalid method received"); if let Some(handler) = self.request_handlers.get(method) { - let payload = handler(msg); + let payload = handler(msg).await; let response = serde_json::json!({ "jsonrpc": "2.0", "id": id, diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index 5fa2420a3d40ce04ee97b4f88c1105711dea8793..03aca4f3caf7995091bbc8e049494b324674a9d3 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -691,7 +691,7 @@ impl CallToolResponse { let mut text = 
String::new(); for chunk in &self.content { if let ToolResponseContent::Text { text: chunk } = chunk { - text.push_str(&chunk) + text.push_str(chunk) }; } text @@ -711,6 +711,16 @@ pub enum ToolResponseContent { Resource { resource: ResourceContents }, } +impl ToolResponseContent { + pub fn text(&self) -> Option<&str> { + if let ToolResponseContent::Text { text } = self { + Some(text) + } else { + None + } + } +} + #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ListToolsResponse { diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 166a582c70aa54fb48e291133c65d651cf6fa66f..61b7a4e18e4e679c29e26185735352737983c4d1 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -21,7 +21,7 @@ use language::{ point_from_lsp, point_to_lsp, }; use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName}; -use node_runtime::{NodeRuntime, VersionCheck}; +use node_runtime::{NodeRuntime, VersionStrategy}; use parking_lot::Mutex; use project::DisableAiSettings; use request::StatusNotification; @@ -81,10 +81,7 @@ pub fn init( }; copilot_chat::init(fs.clone(), http.clone(), configuration, cx); - let copilot = cx.new({ - let node_runtime = node_runtime.clone(); - move |cx| Copilot::start(new_server_id, fs, node_runtime, cx) - }); + let copilot = cx.new(move |cx| Copilot::start(new_server_id, fs, node_runtime, cx)); Copilot::set_global(copilot.clone(), cx); cx.observe(&copilot, |copilot, cx| { copilot.update(cx, |copilot, cx| copilot.update_action_visibilities(cx)); @@ -129,7 +126,7 @@ impl CopilotServer { fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> { let server = self.as_running()?; anyhow::ensure!( - matches!(server.sign_in_status, SignInStatus::Authorized { .. 
}), + matches!(server.sign_in_status, SignInStatus::Authorized), "must sign in before using copilot" ); Ok(server) @@ -200,7 +197,7 @@ impl Status { } struct RegisteredBuffer { - uri: lsp::Url, + uri: lsp::Uri, language_id: String, snapshot: BufferSnapshot, snapshot_version: i32, @@ -349,7 +346,11 @@ impl Copilot { this.start_copilot(true, false, cx); cx.observe_global::(move |this, cx| { this.start_copilot(true, false, cx); - this.send_configuration_update(cx); + if let Ok(server) = this.server.as_running() { + notify_did_change_config_to_server(&server.lsp, cx) + .context("copilot setting change: did change configuration") + .log_err(); + } }) .detach(); this @@ -438,43 +439,6 @@ impl Copilot { if env.is_empty() { None } else { Some(env) } } - fn send_configuration_update(&mut self, cx: &mut Context) { - let copilot_settings = all_language_settings(None, cx) - .edit_predictions - .copilot - .clone(); - - let settings = json!({ - "http": { - "proxy": copilot_settings.proxy, - "proxyStrictSSL": !copilot_settings.proxy_no_verify.unwrap_or(false) - }, - "github-enterprise": { - "uri": copilot_settings.enterprise_uri - } - }); - - if let Some(copilot_chat) = copilot_chat::CopilotChat::global(cx) { - copilot_chat.update(cx, |chat, cx| { - chat.set_configuration( - copilot_chat::CopilotChatConfiguration { - enterprise_uri: copilot_settings.enterprise_uri.clone(), - }, - cx, - ); - }); - } - - if let Ok(server) = self.server.as_running() { - server - .lsp - .notify::( - &lsp::DidChangeConfigurationParams { settings }, - ) - .log_err(); - } - } - #[cfg(any(test, feature = "test-support"))] pub fn fake(cx: &mut gpui::TestAppContext) -> (Entity, lsp::FakeLanguageServer) { use fs::FakeFs; @@ -573,6 +537,9 @@ impl Copilot { })? .await?; + this.update(cx, |_, cx| notify_did_change_config_to_server(&server, cx))? 
+ .context("copilot: did change configuration")?; + let status = server .request::(request::CheckStatusParams { local_checks_only: false, @@ -598,8 +565,6 @@ impl Copilot { }); cx.emit(Event::CopilotLanguageServerStarted); this.update_sign_in_status(status, cx); - // Send configuration now that the LSP is fully started - this.send_configuration_update(cx); } Err(error) => { this.server = CopilotServer::Error(error.to_string().into()); @@ -613,12 +578,12 @@ impl Copilot { pub(crate) fn sign_in(&mut self, cx: &mut Context) -> Task> { if let CopilotServer::Running(server) = &mut self.server { let task = match &server.sign_in_status { - SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(), + SignInStatus::Authorized => Task::ready(Ok(())).shared(), SignInStatus::SigningIn { task, .. } => { cx.notify(); task.clone() } - SignInStatus::SignedOut { .. } | SignInStatus::Unauthorized { .. } => { + SignInStatus::SignedOut { .. } | SignInStatus::Unauthorized => { let lsp = server.lsp.clone(); let task = cx .spawn(async move |this, cx| { @@ -640,15 +605,13 @@ impl Copilot { sign_in_status: status, .. }) = &mut this.server - { - if let SignInStatus::SigningIn { + && let SignInStatus::SigningIn { prompt: prompt_flow, .. } = status - { - *prompt_flow = Some(flow.clone()); - cx.notify(); - } + { + *prompt_flow = Some(flow.clone()); + cx.notify(); } })?; let response = lsp @@ -764,7 +727,7 @@ impl Copilot { .. }) = &mut self.server { - if !matches!(status, SignInStatus::Authorized { .. 
}) { + if !matches!(status, SignInStatus::Authorized) { return; } @@ -814,59 +777,58 @@ impl Copilot { event: &language::BufferEvent, cx: &mut Context, ) -> Result<()> { - if let Ok(server) = self.server.as_running() { - if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id()) - { - match event { - language::BufferEvent::Edited => { - drop(registered_buffer.report_changes(&buffer, cx)); - } - language::BufferEvent::Saved => { + if let Ok(server) = self.server.as_running() + && let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id()) + { + match event { + language::BufferEvent::Edited => { + drop(registered_buffer.report_changes(&buffer, cx)); + } + language::BufferEvent::Saved => { + server + .lsp + .notify::( + &lsp::DidSaveTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new( + registered_buffer.uri.clone(), + ), + text: None, + }, + )?; + } + language::BufferEvent::FileHandleChanged + | language::BufferEvent::LanguageChanged => { + let new_language_id = id_for_language(buffer.read(cx).language()); + let Ok(new_uri) = uri_for_buffer(&buffer, cx) else { + return Ok(()); + }; + if new_uri != registered_buffer.uri + || new_language_id != registered_buffer.language_id + { + let old_uri = mem::replace(&mut registered_buffer.uri, new_uri); + registered_buffer.language_id = new_language_id; server .lsp - .notify::( - &lsp::DidSaveTextDocumentParams { - text_document: lsp::TextDocumentIdentifier::new( + .notify::( + &lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(old_uri), + }, + )?; + server + .lsp + .notify::( + &lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem::new( registered_buffer.uri.clone(), + registered_buffer.language_id.clone(), + registered_buffer.snapshot_version, + registered_buffer.snapshot.text(), ), - text: None, }, )?; } - language::BufferEvent::FileHandleChanged - | language::BufferEvent::LanguageChanged => { - let 
new_language_id = id_for_language(buffer.read(cx).language()); - let Ok(new_uri) = uri_for_buffer(&buffer, cx) else { - return Ok(()); - }; - if new_uri != registered_buffer.uri - || new_language_id != registered_buffer.language_id - { - let old_uri = mem::replace(&mut registered_buffer.uri, new_uri); - registered_buffer.language_id = new_language_id; - server - .lsp - .notify::( - &lsp::DidCloseTextDocumentParams { - text_document: lsp::TextDocumentIdentifier::new(old_uri), - }, - )?; - server - .lsp - .notify::( - &lsp::DidOpenTextDocumentParams { - text_document: lsp::TextDocumentItem::new( - registered_buffer.uri.clone(), - registered_buffer.language_id.clone(), - registered_buffer.snapshot_version, - registered_buffer.snapshot.text(), - ), - }, - )?; - } - } - _ => {} } + _ => {} } } @@ -874,17 +836,17 @@ impl Copilot { } fn unregister_buffer(&mut self, buffer: &WeakEntity) { - if let Ok(server) = self.server.as_running() { - if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) { - server - .lsp - .notify::( - &lsp::DidCloseTextDocumentParams { - text_document: lsp::TextDocumentIdentifier::new(buffer.uri), - }, - ) - .ok(); - } + if let Ok(server) = self.server.as_running() + && let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) + { + server + .lsp + .notify::( + &lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new(buffer.uri), + }, + ) + .ok(); } } @@ -1047,8 +1009,8 @@ impl Copilot { CopilotServer::Error(error) => Status::Error(error.clone()), CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => { match sign_in_status { - SignInStatus::Authorized { .. } => Status::Authorized, - SignInStatus::Unauthorized { .. } => Status::Unauthorized, + SignInStatus::Authorized => Status::Authorized, + SignInStatus::Unauthorized => Status::Unauthorized, SignInStatus::SigningIn { prompt, .. 
} => Status::SigningIn { prompt: prompt.clone(), }, @@ -1133,7 +1095,7 @@ impl Copilot { _ => { filter.hide_action_types(&signed_in_actions); filter.hide_action_types(&auth_actions); - filter.show_action_types(no_auth_actions.iter()); + filter.show_action_types(&no_auth_actions); } } } @@ -1146,9 +1108,9 @@ fn id_for_language(language: Option<&Arc>) -> String { .unwrap_or_else(|| "plaintext".to_string()) } -fn uri_for_buffer(buffer: &Entity, cx: &App) -> Result { +fn uri_for_buffer(buffer: &Entity, cx: &App) -> Result { if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) { - lsp::Url::from_file_path(file.abs_path(cx)) + lsp::Uri::from_file_path(file.abs_path(cx)) } else { format!("buffer://{}", buffer.entity_id()) .parse() @@ -1156,6 +1118,41 @@ fn uri_for_buffer(buffer: &Entity, cx: &App) -> Result { } } +fn notify_did_change_config_to_server( + server: &Arc, + cx: &mut Context, +) -> std::result::Result<(), anyhow::Error> { + let copilot_settings = all_language_settings(None, cx) + .edit_predictions + .copilot + .clone(); + + if let Some(copilot_chat) = copilot_chat::CopilotChat::global(cx) { + copilot_chat.update(cx, |chat, cx| { + chat.set_configuration( + copilot_chat::CopilotChatConfiguration { + enterprise_uri: copilot_settings.enterprise_uri.clone(), + }, + cx, + ); + }); + } + + let settings = json!({ + "http": { + "proxy": copilot_settings.proxy, + "proxyStrictSSL": !copilot_settings.proxy_no_verify.unwrap_or(false) + }, + "github-enterprise": { + "uri": copilot_settings.enterprise_uri + } + }); + + server.notify::(&lsp::DidChangeConfigurationParams { + settings, + }) +} + async fn clear_copilot_dir() { remove_matching(paths::copilot_dir(), |_| true).await } @@ -1169,8 +1166,9 @@ async fn get_copilot_lsp(fs: Arc, node_runtime: NodeRuntime) -> anyhow:: const SERVER_PATH: &str = "node_modules/@github/copilot-language-server/dist/language-server.js"; - // pinning it: https://github.com/zed-industries/zed/issues/36093 - const 
PINNED_VERSION: &str = "1.354"; + let latest_version = node_runtime + .npm_package_latest_version(PACKAGE_NAME) + .await?; let server_path = paths::copilot_dir().join(SERVER_PATH); fs.create_dir(paths::copilot_dir()).await?; @@ -1180,13 +1178,12 @@ async fn get_copilot_lsp(fs: Arc, node_runtime: NodeRuntime) -> anyhow:: PACKAGE_NAME, &server_path, paths::copilot_dir(), - &PINNED_VERSION, - VersionCheck::VersionMismatch, + VersionStrategy::Latest(&latest_version), ) .await; if should_install { node_runtime - .npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &PINNED_VERSION)]) + .npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)]) .await?; } @@ -1204,7 +1201,7 @@ mod tests { let (copilot, mut lsp) = Copilot::fake(cx); let buffer_1 = cx.new(|cx| Buffer::local("Hello", cx)); - let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.entity_id().as_u64()) + let buffer_1_uri: lsp::Uri = format!("buffer://{}", buffer_1.entity_id().as_u64()) .parse() .unwrap(); copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx)); @@ -1222,7 +1219,7 @@ mod tests { ); let buffer_2 = cx.new(|cx| Buffer::local("Goodbye", cx)); - let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.entity_id().as_u64()) + let buffer_2_uri: lsp::Uri = format!("buffer://{}", buffer_2.entity_id().as_u64()) .parse() .unwrap(); copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx)); @@ -1273,7 +1270,7 @@ mod tests { text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri), } ); - let buffer_1_uri = lsp::Url::from_file_path(path!("/root/child/buffer-1")).unwrap(); + let buffer_1_uri = lsp::Uri::from_file_path(path!("/root/child/buffer-1")).unwrap(); assert_eq!( lsp.receive_notification::() .await, diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs index 4c91b4fedb790ab3500273ff21aba767cacd28e0..ccd8f09613eec54f2d30b619f142d111bf2a3497 100644 --- a/crates/copilot/src/copilot_chat.rs +++ 
b/crates/copilot/src/copilot_chat.rs @@ -62,12 +62,6 @@ impl CopilotChatConfiguration { } } -// Copilot's base model; defined by Microsoft in premium requests table -// This will be moved to the front of the Copilot model list, and will be used for -// 'fast' requests (e.g. title generation) -// https://docs.github.com/en/copilot/managing-copilot/monitoring-usage-and-entitlements/about-premium-requests -const DEFAULT_MODEL_ID: &str = "gpt-4.1"; - #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(rename_all = "lowercase")] pub enum Role { @@ -101,22 +95,41 @@ where Ok(models) } -#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)] pub struct Model { + billing: ModelBilling, capabilities: ModelCapabilities, id: String, name: String, policy: Option, vendor: ModelVendor, + is_chat_default: bool, + // The model with this value true is selected by VSCode copilot if a premium request limit is + // reached. 
Zed does not currently implement this behaviour + is_chat_fallback: bool, model_picker_enabled: bool, } +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)] +struct ModelBilling { + is_premium: bool, + multiplier: f64, + // List of plans a model is restricted to + // Field is not present if a model is available for all plans + #[serde(default)] + restricted_to: Option>, +} + #[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)] struct ModelCapabilities { family: String, #[serde(default)] limits: ModelLimits, supports: ModelSupportedFeatures, + #[serde(rename = "type")] + model_type: String, + #[serde(default)] + tokenizer: Option, } #[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)] @@ -153,6 +166,11 @@ pub enum ModelVendor { OpenAI, Google, Anthropic, + #[serde(rename = "xAI")] + XAI, + /// Unknown vendor that we don't explicitly support yet + #[serde(other)] + Unknown, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] @@ -201,6 +219,10 @@ impl Model { pub fn supports_parallel_tool_calls(&self) -> bool { self.capabilities.supports.parallel_tool_calls } + + pub fn tokenizer(&self) -> Option<&str> { + self.capabilities.tokenizer.as_deref() + } } #[derive(Serialize, Deserialize)] @@ -484,7 +506,7 @@ impl CopilotChat { }; if this.oauth_token.is_some() { - cx.spawn(async move |this, mut cx| Self::update_models(&this, &mut cx).await) + cx.spawn(async move |this, cx| Self::update_models(&this, cx).await) .detach_and_log_err(cx); } @@ -602,6 +624,7 @@ async fn get_models( .into_iter() .filter(|model| { model.model_picker_enabled + && model.capabilities.model_type.as_str() == "chat" && model .policy .as_ref() @@ -610,9 +633,7 @@ async fn get_models( .dedup_by(|a, b| a.capabilities.family == b.capabilities.family) .collect(); - if let Some(default_model_position) = - models.iter().position(|model| model.id == DEFAULT_MODEL_ID) - { + if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) 
{ let default_model = models.remove(default_model_position); models.insert(0, default_model); } @@ -630,7 +651,9 @@ async fn request_models( .uri(models_url.as_ref()) .header("Authorization", format!("Bearer {}", api_token)) .header("Content-Type", "application/json") - .header("Copilot-Integration-Id", "vscode-chat"); + .header("Copilot-Integration-Id", "vscode-chat") + .header("Editor-Version", "vscode/1.103.2") + .header("x-github-api-version", "2025-05-01"); let request = request_builder.body(AsyncBody::empty())?; @@ -801,6 +824,10 @@ mod tests { let json = r#"{ "data": [ { + "billing": { + "is_premium": false, + "multiplier": 0 + }, "capabilities": { "family": "gpt-4", "limits": { @@ -814,6 +841,8 @@ mod tests { "type": "chat" }, "id": "gpt-4", + "is_chat_default": false, + "is_chat_fallback": false, "model_picker_enabled": false, "name": "GPT 4", "object": "model", @@ -825,6 +854,16 @@ mod tests { "some-unknown-field": 123 }, { + "billing": { + "is_premium": true, + "multiplier": 1, + "restricted_to": [ + "pro", + "pro_plus", + "business", + "enterprise" + ] + }, "capabilities": { "family": "claude-3.7-sonnet", "limits": { @@ -848,6 +887,8 @@ mod tests { "type": "chat" }, "id": "claude-3.7-sonnet", + "is_chat_default": false, + "is_chat_fallback": false, "model_picker_enabled": true, "name": "Claude 3.7 Sonnet", "object": "model", @@ -863,10 +904,51 @@ mod tests { "object": "list" }"#; - let schema: ModelSchema = serde_json::from_str(&json).unwrap(); + let schema: ModelSchema = serde_json::from_str(json).unwrap(); assert_eq!(schema.data.len(), 2); assert_eq!(schema.data[0].id, "gpt-4"); assert_eq!(schema.data[1].id, "claude-3.7-sonnet"); } + + #[test] + fn test_unknown_vendor_resilience() { + let json = r#"{ + "data": [ + { + "billing": { + "is_premium": false, + "multiplier": 1 + }, + "capabilities": { + "family": "future-model", + "limits": { + "max_context_window_tokens": 128000, + "max_output_tokens": 8192, + "max_prompt_tokens": 120000 + }, + "object": 
"model_capabilities", + "supports": { "streaming": true, "tool_calls": true }, + "type": "chat" + }, + "id": "future-model-v1", + "is_chat_default": false, + "is_chat_fallback": false, + "model_picker_enabled": true, + "name": "Future Model v1", + "object": "model", + "preview": false, + "vendor": "SomeNewVendor", + "version": "v1.0" + } + ], + "object": "list" + }"#; + + let schema: ModelSchema = serde_json::from_str(json).unwrap(); + + assert_eq!(schema.data.len(), 1); + assert_eq!(schema.data[0].id, "future-model-v1"); + assert_eq!(schema.data[0].vendor, ModelVendor::Unknown); + } } diff --git a/crates/copilot/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_completion_provider.rs index 2fd6df27b9e15d4247d85edca4d8836c35b23df1..52d75175e5b5ba265bb32c6c15c713e1bd8faecd 100644 --- a/crates/copilot/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_completion_provider.rs @@ -301,6 +301,7 @@ mod tests { init_test(cx, |settings| { settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Disabled, + words_min_length: 0, lsp: true, lsp_fetch_timeout_ms: 0, lsp_insert_mode: LspInsertMode::Insert, @@ -533,6 +534,7 @@ mod tests { init_test(cx, |settings| { settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Disabled, + words_min_length: 0, lsp: true, lsp_fetch_timeout_ms: 0, lsp_insert_mode: LspInsertMode::Insert, @@ -1083,7 +1085,7 @@ mod tests { let replace_range_marker: TextRangeMarker = ('<', '>').into(); let (_, mut marked_ranges) = marked_text_ranges_by( marked_string, - vec![complete_from_marker.clone(), replace_range_marker.clone()], + vec![complete_from_marker, replace_range_marker.clone()], ); let replace_range = diff --git a/crates/copilot/src/request.rs b/crates/copilot/src/request.rs index 0deabe16d15c4a502b278c4a631720094ad18af7..85d6254dc060824a9b2686e8f53090fccb39980e 100644 --- a/crates/copilot/src/request.rs +++ b/crates/copilot/src/request.rs @@ -102,7 +102,7 
@@ pub struct GetCompletionsDocument { pub tab_size: u32, pub indent_size: u32, pub insert_spaces: bool, - pub uri: lsp::Url, + pub uri: lsp::Uri, pub relative_path: String, pub position: lsp::Position, pub version: usize, diff --git a/crates/crashes/Cargo.toml b/crates/crashes/Cargo.toml index afb4936b6370791b133395b6205fb5cffaa17284..9af416cbb0801c68e1a9a85b37a1b80c52da476c 100644 --- a/crates/crashes/Cargo.toml +++ b/crates/crashes/Cargo.toml @@ -6,13 +6,21 @@ edition.workspace = true license = "GPL-3.0-or-later" [dependencies] +bincode.workspace = true crash-handler.workspace = true log.workspace = true minidumper.workspace = true paths.workspace = true release_channel.workspace = true smol.workspace = true +serde.workspace = true +serde_json.workspace = true +system_specs.workspace = true workspace-hack.workspace = true +zstd.workspace = true + +[target.'cfg(target_os = "macos")'.dependencies] +mach2.workspace = true [lints] workspace = true diff --git a/crates/crashes/src/crashes.rs b/crates/crashes/src/crashes.rs index 5b9ae0b54606c7ee9bf3034a097d230e7570f572..f867f6cbdd6d1e0aa1dab8d7e4cd188295ace480 100644 --- a/crates/crashes/src/crashes.rs +++ b/crates/crashes/src/crashes.rs @@ -2,15 +2,19 @@ use crash_handler::CrashHandler; use log::info; use minidumper::{Client, LoopAction, MinidumpBinary}; use release_channel::{RELEASE_CHANNEL, ReleaseChannel}; +use serde::{Deserialize, Serialize}; +#[cfg(target_os = "macos")] +use std::sync::atomic::AtomicU32; use std::{ env, - fs::File, + fs::{self, File}, io, + panic::Location, path::{Path, PathBuf}, process::{self, Command}, sync::{ - LazyLock, OnceLock, + Arc, OnceLock, atomic::{AtomicBool, Ordering}, }, thread, @@ -18,19 +22,20 @@ use std::{ }; // set once the crash handler has initialized and the client has connected to it -pub static CRASH_HANDLER: AtomicBool = AtomicBool::new(false); +pub static CRASH_HANDLER: OnceLock> = OnceLock::new(); // set when the first minidump request is made to avoid generating 
duplicate crash reports pub static REQUESTED_MINIDUMP: AtomicBool = AtomicBool::new(false); -const CRASH_HANDLER_TIMEOUT: Duration = Duration::from_secs(60); +const CRASH_HANDLER_PING_TIMEOUT: Duration = Duration::from_secs(60); +const CRASH_HANDLER_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); -pub static GENERATE_MINIDUMPS: LazyLock = LazyLock::new(|| { - *RELEASE_CHANNEL != ReleaseChannel::Dev || env::var("ZED_GENERATE_MINIDUMPS").is_ok() -}); +#[cfg(target_os = "macos")] +static PANIC_THREAD_ID: AtomicU32 = AtomicU32::new(0); -pub async fn init(id: String) { - if !*GENERATE_MINIDUMPS { +pub async fn init(crash_init: InitCrashHandler) { + if *RELEASE_CHANNEL == ReleaseChannel::Dev && env::var("ZED_GENERATE_MINIDUMPS").is_err() { return; } + let exe = env::current_exe().expect("unable to find ourselves"); let zed_pid = process::id(); // TODO: we should be able to get away with using 1 crash-handler process per machine, @@ -61,9 +66,11 @@ pub async fn init(id: String) { smol::Timer::after(retry_frequency).await; } let client = maybe_client.unwrap(); - client.send_message(1, id).unwrap(); // set session id on the server + client + .send_message(1, serde_json::to_vec(&crash_init).unwrap()) + .unwrap(); - let client = std::sync::Arc::new(client); + let client = Arc::new(client); let handler = crash_handler::CrashHandler::attach(unsafe { let client = client.clone(); crash_handler::make_crash_event(move |crash_context: &crash_handler::CrashContext| { @@ -72,7 +79,9 @@ pub async fn init(id: String) { .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) .is_ok() { - client.send_message(2, "mistakes were made").unwrap(); + #[cfg(target_os = "macos")] + suspend_all_other_threads(); + client.ping().unwrap(); client.request_dump(crash_context).is_ok() } else { @@ -87,7 +96,7 @@ pub async fn init(id: String) { { handler.set_ptracer(Some(server_pid)); } - CRASH_HANDLER.store(true, Ordering::Release); + CRASH_HANDLER.set(client.clone()).ok(); 
std::mem::forget(handler); info!("crash handler registered"); @@ -97,64 +106,181 @@ pub async fn init(id: String) { } } +#[cfg(target_os = "macos")] +unsafe fn suspend_all_other_threads() { + let task = unsafe { mach2::traps::current_task() }; + let mut threads: mach2::mach_types::thread_act_array_t = std::ptr::null_mut(); + let mut count = 0; + unsafe { + mach2::task::task_threads(task, &raw mut threads, &raw mut count); + } + let current = unsafe { mach2::mach_init::mach_thread_self() }; + let panic_thread = PANIC_THREAD_ID.load(Ordering::SeqCst); + for i in 0..count { + let t = unsafe { *threads.add(i as usize) }; + if t != current && t != panic_thread { + unsafe { mach2::thread_act::thread_suspend(t) }; + } + } +} + pub struct CrashServer { - session_id: OnceLock, + initialization_params: OnceLock, + panic_info: OnceLock, + active_gpu: OnceLock, + has_connection: Arc, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct CrashInfo { + pub init: InitCrashHandler, + pub panic: Option, + pub minidump_error: Option, + pub gpus: Vec, + pub active_gpu: Option, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct InitCrashHandler { + pub session_id: String, + pub zed_version: String, + pub release_channel: String, + pub commit_sha: String, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct CrashPanic { + pub message: String, + pub span: String, } impl minidumper::ServerHandler for CrashServer { fn create_minidump_file(&self) -> Result<(File, PathBuf), io::Error> { - let err_message = "Need to send a message with the ID upon starting the crash handler"; + let err_message = "Missing initialization data"; let dump_path = paths::logs_dir() - .join(self.session_id.get().expect(err_message)) + .join( + &self + .initialization_params + .get() + .expect(err_message) + .session_id, + ) .with_extension("dmp"); let file = File::create(&dump_path)?; Ok((file, dump_path)) } fn on_minidump_created(&self, result: Result) -> LoopAction { - 
match result { - Ok(mut md_bin) => { + let minidump_error = match result { + Ok(MinidumpBinary { mut file, path, .. }) => { use io::Write; - let _ = md_bin.file.flush(); - info!("wrote minidump to disk {:?}", md_bin.path); + file.flush().ok(); + // TODO: clean this up once https://github.com/EmbarkStudios/crash-handling/issues/101 is addressed + drop(file); + let original_file = File::open(&path).unwrap(); + let compressed_path = path.with_extension("zstd"); + let compressed_file = File::create(&compressed_path).unwrap(); + zstd::stream::copy_encode(original_file, compressed_file, 0).ok(); + fs::rename(&compressed_path, path).unwrap(); + None } - Err(e) => { - info!("failed to write minidump: {:#}", e); + Err(e) => Some(format!("{e:?}")), + }; + + #[cfg(not(any(target_os = "linux", target_os = "freebsd")))] + let gpus = vec![]; + + #[cfg(any(target_os = "linux", target_os = "freebsd"))] + let gpus = match system_specs::read_gpu_info_from_sys_class_drm() { + Ok(gpus) => gpus, + Err(err) => { + log::warn!("Failed to collect GPU information for crash report: {err}"); + vec![] } - } + }; + + let crash_info = CrashInfo { + init: self + .initialization_params + .get() + .expect("not initialized") + .clone(), + panic: self.panic_info.get().cloned(), + minidump_error, + active_gpu: self.active_gpu.get().cloned(), + gpus, + }; + + let crash_data_path = paths::logs_dir() + .join(&crash_info.init.session_id) + .with_extension("json"); + + fs::write(crash_data_path, serde_json::to_vec(&crash_info).unwrap()).ok(); + LoopAction::Exit } fn on_message(&self, kind: u32, buffer: Vec) { - let message = String::from_utf8(buffer).expect("invalid utf-8"); - info!("kind: {kind}, message: {message}",); - if kind == 1 { - self.session_id - .set(message) - .expect("session id already initialized"); + match kind { + 1 => { + let init_data = + serde_json::from_slice::(&buffer).expect("invalid init data"); + self.initialization_params + .set(init_data) + .expect("already initialized"); + } + 2 
=> { + let panic_data = + serde_json::from_slice::(&buffer).expect("invalid panic data"); + self.panic_info.set(panic_data).expect("already panicked"); + } + 3 => { + let gpu_specs: system_specs::GpuSpecs = + bincode::deserialize(&buffer).expect("gpu specs"); + self.active_gpu + .set(gpu_specs) + .expect("already set active gpu"); + } + _ => { + panic!("invalid message kind"); + } } } - fn on_client_disconnected(&self, clients: usize) -> LoopAction { - info!("client disconnected, {clients} remaining"); - if clients == 0 { - LoopAction::Exit - } else { - LoopAction::Continue - } + fn on_client_disconnected(&self, _clients: usize) -> LoopAction { + LoopAction::Exit } -} -pub fn handle_panic() { - if !*GENERATE_MINIDUMPS { - return; + fn on_client_connected(&self, _clients: usize) -> LoopAction { + self.has_connection.store(true, Ordering::SeqCst); + LoopAction::Continue } +} + +pub fn handle_panic(message: String, span: Option<&Location>) { + let span = span + .map(|loc| format!("{}:{}", loc.file(), loc.line())) + .unwrap_or_default(); + // wait 500ms for the crash handler process to start up // if it's still not there just write panic info and no minidump let retry_frequency = Duration::from_millis(100); for _ in 0..5 { - if CRASH_HANDLER.load(Ordering::Acquire) { + if let Some(client) = CRASH_HANDLER.get() { + client + .send_message( + 2, + serde_json::to_vec(&CrashPanic { message, span }).unwrap(), + ) + .ok(); log::error!("triggering a crash to generate a minidump..."); + + #[cfg(target_os = "macos")] + PANIC_THREAD_ID.store( + unsafe { mach2::mach_init::mach_thread_self() }, + Ordering::SeqCst, + ); + #[cfg(target_os = "linux")] CrashHandler.simulate_signal(crash_handler::Signal::Trap as u32); #[cfg(not(target_os = "linux"))] @@ -170,14 +296,31 @@ pub fn crash_server(socket: &Path) { log::info!("Couldn't create socket, there may already be a running crash server"); return; }; - let ab = AtomicBool::new(false); + + let shutdown = Arc::new(AtomicBool::new(false)); 
+ let has_connection = Arc::new(AtomicBool::new(false)); + + std::thread::spawn({ + let shutdown = shutdown.clone(); + let has_connection = has_connection.clone(); + move || { + std::thread::sleep(CRASH_HANDLER_CONNECT_TIMEOUT); + if !has_connection.load(Ordering::SeqCst) { + shutdown.store(true, Ordering::SeqCst); + } + } + }); + server .run( Box::new(CrashServer { - session_id: OnceLock::new(), + initialization_params: OnceLock::new(), + panic_info: OnceLock::new(), + has_connection, + active_gpu: OnceLock::new(), }), - &ab, - Some(CRASH_HANDLER_TIMEOUT), + &shutdown, + Some(CRASH_HANDLER_PING_TIMEOUT), ) .expect("failed to run server"); } diff --git a/crates/credentials_provider/src/credentials_provider.rs b/crates/credentials_provider/src/credentials_provider.rs index f72fd6c39b12d5d46cfa1d4f3f30900f01471e64..2c8dd6fc812aaeffd6c06c88ee2adceabdbb27a3 100644 --- a/crates/credentials_provider/src/credentials_provider.rs +++ b/crates/credentials_provider/src/credentials_provider.rs @@ -19,7 +19,7 @@ use release_channel::ReleaseChannel; /// Only works in development. Setting this environment variable in other /// release channels is a no-op. static ZED_DEVELOPMENT_USE_KEYCHAIN: LazyLock = LazyLock::new(|| { - std::env::var("ZED_DEVELOPMENT_USE_KEYCHAIN").map_or(false, |value| !value.is_empty()) + std::env::var("ZED_DEVELOPMENT_USE_KEYCHAIN").is_ok_and(|value| !value.is_empty()) }); /// A provider for credentials. 
diff --git a/crates/dap/src/adapters.rs b/crates/dap/src/adapters.rs index 687305ae94da3bc1ddd72e9e9f4594f4f4a19ee4..2cef266677c1314f6fd253b9caf77914050ceb96 100644 --- a/crates/dap/src/adapters.rs +++ b/crates/dap/src/adapters.rs @@ -285,7 +285,7 @@ pub async fn download_adapter_from_github( } if !adapter_path.exists() { - fs.create_dir(&adapter_path.as_path()) + fs.create_dir(adapter_path.as_path()) .await .context("Failed creating adapter path")?; } diff --git a/crates/dap/src/client.rs b/crates/dap/src/client.rs index 7b791450ecba3b09b6571ac84fbebdf92fff57b8..2590bf5c8b0db8e70a7897b8de4bc878187e4daa 100644 --- a/crates/dap/src/client.rs +++ b/crates/dap/src/client.rs @@ -23,7 +23,7 @@ impl SessionId { Self(client_id as u32) } - pub fn to_proto(&self) -> u64 { + pub fn to_proto(self) -> u64 { self.0 as u64 } } diff --git a/crates/dap/src/debugger_settings.rs b/crates/dap/src/debugger_settings.rs index e1176633e5403116c2789161d654912337150e9a..4b841450462f1f59787df584cc4ba48eddf792c1 100644 --- a/crates/dap/src/debugger_settings.rs +++ b/crates/dap/src/debugger_settings.rs @@ -2,9 +2,9 @@ use dap_types::SteppingGranularity; use gpui::{App, Global}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; -#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq, SettingsUi)] #[serde(rename_all = "snake_case")] pub enum DebugPanelDockPosition { Left, @@ -12,12 +12,17 @@ pub enum DebugPanelDockPosition { Right, } -#[derive(Serialize, Deserialize, JsonSchema, Clone, Copy)] +#[derive(Serialize, Deserialize, JsonSchema, Clone, SettingsUi, SettingsKey)] #[serde(default)] +// todo(settings_ui) @ben: I'm pretty sure not having the fields be optional here is a bug, +// it means the defaults will override previously set values if a single key is missing 
+#[settings_ui(group = "Debugger")] +#[settings_key(key = "debugger")] pub struct DebuggerSettings { /// Determines the stepping granularity. /// /// Default: line + #[settings_ui(skip)] pub stepping_granularity: SteppingGranularity, /// Whether the breakpoints should be reused across Zed sessions. /// @@ -60,8 +65,6 @@ impl Default for DebuggerSettings { } impl Settings for DebuggerSettings { - const KEY: Option<&'static str> = Some("debugger"); - type FileContent = Self; fn load(sources: SettingsSources, _: &mut App) -> anyhow::Result { diff --git a/crates/dap_adapters/src/codelldb.rs b/crates/dap_adapters/src/codelldb.rs index 842bb264a8469402fe73747356ab2e616ab08533..25dc875740e8aba87872e8dc93fa8d77062ed545 100644 --- a/crates/dap_adapters/src/codelldb.rs +++ b/crates/dap_adapters/src/codelldb.rs @@ -385,7 +385,7 @@ impl DebugAdapter for CodeLldbDebugAdapter { && let Some(source_languages) = config.get("sourceLanguages").filter(|value| { value .as_array() - .map_or(false, |array| array.iter().all(Value::is_string)) + .is_some_and(|array| array.iter().all(Value::is_string)) }) { let ret = vec![ diff --git a/crates/dap_adapters/src/go.rs b/crates/dap_adapters/src/go.rs index 22d8262b93e36b17e548ae4dcc9bb725da8ca7cb..db8a45ceb49963eab053f3e7728d4cf715e9a40b 100644 --- a/crates/dap_adapters/src/go.rs +++ b/crates/dap_adapters/src/go.rs @@ -36,7 +36,7 @@ impl GoDebugAdapter { delegate: &Arc, ) -> Result { let release = latest_github_release( - &"zed-industries/delve-shim-dap", + "zed-industries/delve-shim-dap", true, false, delegate.http_client(), diff --git a/crates/dap_adapters/src/javascript.rs b/crates/dap_adapters/src/javascript.rs index 2d19921a0f0c979fe53ede5860ac0c4d26b510c3..a8826d563b09925068dd6da1be865f1e17bce0ec 100644 --- a/crates/dap_adapters/src/javascript.rs +++ b/crates/dap_adapters/src/javascript.rs @@ -99,10 +99,10 @@ impl JsDebugAdapter { } } - if let Some(env) = configuration.get("env").cloned() { - if let Ok(env) = serde_json::from_value(env) { 
- envs = env; - } + if let Some(env) = configuration.get("env").cloned() + && let Ok(env) = serde_json::from_value(env) + { + envs = env; } configuration @@ -514,7 +514,7 @@ impl DebugAdapter for JsDebugAdapter { } } - self.get_installed_binary(delegate, &config, user_installed_path, user_args, cx) + self.get_installed_binary(delegate, config, user_installed_path, user_args, cx) .await } diff --git a/crates/dap_adapters/src/python.rs b/crates/dap_adapters/src/python.rs index a2bd934311ec21da13d08d23211e62718ec5bbc5..6781e5cbd62d1abc9abfa58223b0771f26cc0c88 100644 --- a/crates/dap_adapters/src/python.rs +++ b/crates/dap_adapters/src/python.rs @@ -24,6 +24,7 @@ use util::{ResultExt, maybe}; #[derive(Default)] pub(crate) struct PythonDebugAdapter { + base_venv_path: OnceCell, String>>, debugpy_whl_base_path: OnceCell, String>>, } @@ -91,14 +92,12 @@ impl PythonDebugAdapter { }) } - async fn fetch_wheel(delegate: &Arc) -> Result, String> { - let system_python = Self::system_python_name(delegate) - .await - .ok_or_else(|| String::from("Could not find a Python installation"))?; - let command: &OsStr = system_python.as_ref(); + async fn fetch_wheel(&self, delegate: &Arc) -> Result, String> { let download_dir = debug_adapters_dir().join(Self::ADAPTER_NAME).join("wheels"); std::fs::create_dir_all(&download_dir).map_err(|e| e.to_string())?; - let installation_succeeded = util::command::new_smol_command(command) + let system_python = self.base_venv_path(delegate).await?; + + let installation_succeeded = util::command::new_smol_command(system_python.as_ref()) .args([ "-m", "pip", @@ -114,7 +113,7 @@ impl PythonDebugAdapter { .status .success(); if !installation_succeeded { - return Err("debugpy installation failed".into()); + return Err("debugpy installation failed (could not fetch Debugpy's wheel)".into()); } let wheel_path = std::fs::read_dir(&download_dir) @@ -139,7 +138,7 @@ impl PythonDebugAdapter { Ok(Arc::from(wheel_path.path())) } - async fn 
maybe_fetch_new_wheel(delegate: &Arc) { + async fn maybe_fetch_new_wheel(&self, delegate: &Arc) { let latest_release = delegate .http_client() .get( @@ -191,7 +190,7 @@ impl PythonDebugAdapter { ) .await .ok()?; - Self::fetch_wheel(delegate).await.ok()?; + self.fetch_wheel(delegate).await.ok()?; } Some(()) }) @@ -204,7 +203,7 @@ impl PythonDebugAdapter { ) -> Result, String> { self.debugpy_whl_base_path .get_or_init(|| async move { - Self::maybe_fetch_new_wheel(delegate).await; + self.maybe_fetch_new_wheel(delegate).await; Ok(Arc::from( debug_adapters_dir() .join(Self::ADAPTER_NAME) @@ -217,6 +216,46 @@ impl PythonDebugAdapter { .clone() } + async fn base_venv_path(&self, delegate: &Arc) -> Result, String> { + self.base_venv_path + .get_or_init(|| async { + let base_python = Self::system_python_name(delegate) + .await + .ok_or_else(|| String::from("Could not find a Python installation"))?; + + let did_succeed = util::command::new_smol_command(base_python) + .args(["-m", "venv", "zed_base_venv"]) + .current_dir( + paths::debug_adapters_dir().join(Self::DEBUG_ADAPTER_NAME.as_ref()), + ) + .spawn() + .map_err(|e| format!("{e:#?}"))? + .status() + .await + .map_err(|e| format!("{e:#?}"))? 
+ .success(); + + if !did_succeed { + return Err("Failed to create base virtual environment".into()); + } + + const DIR: &str = if cfg!(target_os = "windows") { + "Scripts" + } else { + "bin" + }; + Ok(Arc::from( + paths::debug_adapters_dir() + .join(Self::DEBUG_ADAPTER_NAME.as_ref()) + .join("zed_base_venv") + .join(DIR) + .join("python3") + .as_ref(), + )) + }) + .await + .clone() + } async fn system_python_name(delegate: &Arc) -> Option { const BINARY_NAMES: [&str; 3] = ["python3", "python", "py"]; let mut name = None; @@ -679,7 +718,7 @@ impl DebugAdapter for PythonDebugAdapter { local_path.display() ); return self - .get_installed_binary(delegate, &config, Some(local_path.clone()), user_args, None) + .get_installed_binary(delegate, config, Some(local_path.clone()), user_args, None) .await; } @@ -716,7 +755,7 @@ impl DebugAdapter for PythonDebugAdapter { return self .get_installed_binary( delegate, - &config, + config, None, user_args, Some(toolchain.path.to_string()), @@ -724,7 +763,7 @@ impl DebugAdapter for PythonDebugAdapter { .await; } - self.get_installed_binary(delegate, &config, None, user_args, None) + self.get_installed_binary(delegate, config, None, user_args, None) .await } diff --git a/crates/db/Cargo.toml b/crates/db/Cargo.toml index c53b2988b94dd5b355e132024c2677b61a83d071..de449cd38f77d062eda906cced3e3b697a370d15 100644 --- a/crates/db/Cargo.toml +++ b/crates/db/Cargo.toml @@ -27,6 +27,7 @@ sqlez.workspace = true sqlez_macros.workspace = true util.workspace = true workspace-hack.workspace = true +zed_env_vars.workspace = true [dev-dependencies] gpui = { workspace = true, features = ["test-support"] } diff --git a/crates/db/src/db.rs b/crates/db/src/db.rs index de55212cbadfdd3c66ede66b706a3d120dd765c5..eab2f115d8e5c3db51541544a8dbc95f34713741 100644 --- a/crates/db/src/db.rs +++ b/crates/db/src/db.rs @@ -17,9 +17,10 @@ use sqlez::thread_safe_connection::ThreadSafeConnection; use sqlez_macros::sql; use std::future::Future; use std::path::Path; 
+use std::sync::atomic::AtomicBool; use std::sync::{LazyLock, atomic::Ordering}; -use std::{env, sync::atomic::AtomicBool}; use util::{ResultExt, maybe}; +use zed_env_vars::ZED_STATELESS; const CONNECTION_INITIALIZE_QUERY: &str = sql!( PRAGMA foreign_keys=TRUE; @@ -36,9 +37,6 @@ const FALLBACK_DB_NAME: &str = "FALLBACK_MEMORY_DB"; const DB_FILE_NAME: &str = "db.sqlite"; -pub static ZED_STATELESS: LazyLock = - LazyLock::new(|| env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty())); - pub static ALL_FILE_DB_FAILED: LazyLock = LazyLock::new(|| AtomicBool::new(false)); /// Open or create a database at the given directory path. @@ -74,7 +72,7 @@ pub async fn open_db(db_dir: &Path, scope: &str) -> Threa } async fn open_main_db(db_path: &Path) -> Option { - log::info!("Opening database {}", db_path.display()); + log::trace!("Opening database {}", db_path.display()); ThreadSafeConnection::builder::(db_path.to_string_lossy().as_ref(), true) .with_db_initialization_query(DB_INITIALIZE_QUERY) .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY) @@ -110,11 +108,14 @@ pub async fn open_test_db(db_name: &str) -> ThreadSafeConnection { } /// Implements a basic DB wrapper for a given domain +/// +/// Arguments: +/// - static variable name for connection +/// - type of connection wrapper +/// - dependencies, whose migrations should be run prior to this domain's migrations #[macro_export] -macro_rules! define_connection { - (pub static ref $id:ident: $t:ident<()> = $migrations:expr; $($global:ident)?) => { - pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection); - +macro_rules! static_connection { + ($id:ident, $t:ident, [ $($d:ty),* ] $(, $global:ident)?) => { impl ::std::ops::Deref for $t { type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection; @@ -123,16 +124,6 @@ macro_rules! 
define_connection { } } - impl $crate::sqlez::domain::Domain for $t { - fn name() -> &'static str { - stringify!($t) - } - - fn migrations() -> &'static [&'static str] { - $migrations - } - } - impl $t { #[cfg(any(test, feature = "test-support"))] pub async fn open_test_db(name: &'static str) -> Self { @@ -142,7 +133,8 @@ macro_rules! define_connection { #[cfg(any(test, feature = "test-support"))] pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| { - $t($crate::smol::block_on($crate::open_test_db::<$t>(stringify!($id)))) + #[allow(unused_parens)] + $t($crate::smol::block_on($crate::open_test_db::<($($d,)* $t)>(stringify!($id)))) }); #[cfg(not(any(test, feature = "test-support")))] @@ -153,46 +145,10 @@ macro_rules! define_connection { } else { $crate::RELEASE_CHANNEL.dev_name() }; - $t($crate::smol::block_on($crate::open_db::<$t>(db_dir, scope))) + #[allow(unused_parens)] + $t($crate::smol::block_on($crate::open_db::<($($d,)* $t)>(db_dir, scope))) }); - }; - (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr; $($global:ident)?) => { - pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection); - - impl ::std::ops::Deref for $t { - type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection; - - fn deref(&self) -> &Self::Target { - &self.0 - } - } - - impl $crate::sqlez::domain::Domain for $t { - fn name() -> &'static str { - stringify!($t) - } - - fn migrations() -> &'static [&'static str] { - $migrations - } - } - - #[cfg(any(test, feature = "test-support"))] - pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| { - $t($crate::smol::block_on($crate::open_test_db::<($($d),+, $t)>(stringify!($id)))) - }); - - #[cfg(not(any(test, feature = "test-support")))] - pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| { - let db_dir = $crate::database_dir(); - let scope = if false $(|| stringify!($global) == "global")? 
{ - "global" - } else { - $crate::RELEASE_CHANNEL.dev_name() - }; - $t($crate::smol::block_on($crate::open_db::<($($d),+, $t)>(db_dir, scope))) - }); - }; + } } pub fn write_and_log(cx: &App, db_write: impl FnOnce() -> F + Send + 'static) @@ -219,17 +175,12 @@ mod tests { enum BadDB {} impl Domain for BadDB { - fn name() -> &'static str { - "db_tests" - } - - fn migrations() -> &'static [&'static str] { - &[ - sql!(CREATE TABLE test(value);), - // failure because test already exists - sql!(CREATE TABLE test(value);), - ] - } + const NAME: &str = "db_tests"; + const MIGRATIONS: &[&str] = &[ + sql!(CREATE TABLE test(value);), + // failure because test already exists + sql!(CREATE TABLE test(value);), + ]; } let tempdir = tempfile::Builder::new() @@ -238,7 +189,7 @@ mod tests { .unwrap(); let _bad_db = open_db::( tempdir.path(), - &release_channel::ReleaseChannel::Dev.dev_name(), + release_channel::ReleaseChannel::Dev.dev_name(), ) .await; } @@ -251,25 +202,15 @@ mod tests { enum CorruptedDB {} impl Domain for CorruptedDB { - fn name() -> &'static str { - "db_tests" - } - - fn migrations() -> &'static [&'static str] { - &[sql!(CREATE TABLE test(value);)] - } + const NAME: &str = "db_tests"; + const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)]; } enum GoodDB {} impl Domain for GoodDB { - fn name() -> &'static str { - "db_tests" //Notice same name - } - - fn migrations() -> &'static [&'static str] { - &[sql!(CREATE TABLE test2(value);)] //But different migration - } + const NAME: &str = "db_tests"; //Notice same name + const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)]; } let tempdir = tempfile::Builder::new() @@ -279,7 +220,7 @@ mod tests { { let corrupt_db = open_db::( tempdir.path(), - &release_channel::ReleaseChannel::Dev.dev_name(), + release_channel::ReleaseChannel::Dev.dev_name(), ) .await; assert!(corrupt_db.persistent()); @@ -287,7 +228,7 @@ mod tests { let good_db = open_db::( tempdir.path(), - 
&release_channel::ReleaseChannel::Dev.dev_name(), + release_channel::ReleaseChannel::Dev.dev_name(), ) .await; assert!( @@ -305,25 +246,16 @@ mod tests { enum CorruptedDB {} impl Domain for CorruptedDB { - fn name() -> &'static str { - "db_tests" - } + const NAME: &str = "db_tests"; - fn migrations() -> &'static [&'static str] { - &[sql!(CREATE TABLE test(value);)] - } + const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)]; } enum GoodDB {} impl Domain for GoodDB { - fn name() -> &'static str { - "db_tests" //Notice same name - } - - fn migrations() -> &'static [&'static str] { - &[sql!(CREATE TABLE test2(value);)] //But different migration - } + const NAME: &str = "db_tests"; //Notice same name + const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)]; // But different migration } let tempdir = tempfile::Builder::new() @@ -334,7 +266,7 @@ mod tests { // Setup the bad database let corrupt_db = open_db::( tempdir.path(), - &release_channel::ReleaseChannel::Dev.dev_name(), + release_channel::ReleaseChannel::Dev.dev_name(), ) .await; assert!(corrupt_db.persistent()); @@ -347,7 +279,7 @@ mod tests { let guard = thread::spawn(move || { let good_db = smol::block_on(open_db::( tmp_path.as_path(), - &release_channel::ReleaseChannel::Dev.dev_name(), + release_channel::ReleaseChannel::Dev.dev_name(), )); assert!( good_db.select_row::("SELECT * FROM test2").unwrap()() diff --git a/crates/db/src/kvp.rs b/crates/db/src/kvp.rs index daf0b136fde5bd62411c70033e8bcfcb668a5e06..8ea877b35bfaf57bb258e7e179fa5b71f2b518ea 100644 --- a/crates/db/src/kvp.rs +++ b/crates/db/src/kvp.rs @@ -2,16 +2,26 @@ use gpui::App; use sqlez_macros::sql; use util::ResultExt as _; -use crate::{define_connection, query, write_and_log}; +use crate::{ + query, + sqlez::{domain::Domain, thread_safe_connection::ThreadSafeConnection}, + write_and_log, +}; -define_connection!(pub static ref KEY_VALUE_STORE: KeyValueStore<()> = - &[sql!( +pub struct 
KeyValueStore(crate::sqlez::thread_safe_connection::ThreadSafeConnection); + +impl Domain for KeyValueStore { + const NAME: &str = stringify!(KeyValueStore); + + const MIGRATIONS: &[&str] = &[sql!( CREATE TABLE IF NOT EXISTS kv_store( key TEXT PRIMARY KEY, value TEXT NOT NULL ) STRICT; )]; -); +} + +crate::static_connection!(KEY_VALUE_STORE, KeyValueStore, []); pub trait Dismissable { const KEY: &'static str; @@ -20,7 +30,7 @@ pub trait Dismissable { KEY_VALUE_STORE .read_kvp(Self::KEY) .log_err() - .map_or(false, |s| s.is_some()) + .is_some_and(|s| s.is_some()) } fn set_dismissed(is_dismissed: bool, cx: &mut App) { @@ -91,15 +101,19 @@ mod tests { } } -define_connection!(pub static ref GLOBAL_KEY_VALUE_STORE: GlobalKeyValueStore<()> = - &[sql!( +pub struct GlobalKeyValueStore(ThreadSafeConnection); + +impl Domain for GlobalKeyValueStore { + const NAME: &str = stringify!(GlobalKeyValueStore); + const MIGRATIONS: &[&str] = &[sql!( CREATE TABLE IF NOT EXISTS kv_store( key TEXT PRIMARY KEY, value TEXT NOT NULL ) STRICT; )]; - global -); +} + +crate::static_connection!(GLOBAL_KEY_VALUE_STORE, GlobalKeyValueStore, [], global); impl GlobalKeyValueStore { query! 
{ diff --git a/crates/debugger_tools/src/dap_log.rs b/crates/debugger_tools/src/dap_log.rs index b806381d251c6595a5dd12022dc3d1df8b71739f..c4338c6d0017a215c721c772871647c89227775e 100644 --- a/crates/debugger_tools/src/dap_log.rs +++ b/crates/debugger_tools/src/dap_log.rs @@ -392,7 +392,7 @@ impl LogStore { session.label(), session .adapter_client() - .map_or(false, |client| client.has_adapter_logs()), + .is_some_and(|client| client.has_adapter_logs()), ) }); @@ -485,7 +485,7 @@ impl LogStore { &mut self, id: &LogStoreEntryIdentifier<'_>, ) -> Option<&Vec> { - self.get_debug_adapter_state(&id) + self.get_debug_adapter_state(id) .map(|state| &state.rpc_messages.initialization_sequence) } } @@ -536,11 +536,11 @@ impl Render for DapLogToolbarItemView { }) .unwrap_or_else(|| "No adapter selected".into()), )) - .menu(move |mut window, cx| { + .menu(move |window, cx| { let log_view = log_view.clone(); let menu_rows = menu_rows.clone(); let project = project.clone(); - ContextMenu::build(&mut window, cx, move |mut menu, window, _cx| { + ContextMenu::build(window, cx, move |mut menu, window, _cx| { for row in menu_rows.into_iter() { menu = menu.custom_row(move |_window, _cx| { div() @@ -661,11 +661,11 @@ impl ToolbarItemView for DapLogToolbarItemView { _window: &mut Window, cx: &mut Context, ) -> workspace::ToolbarItemLocation { - if let Some(item) = active_pane_item { - if let Some(log_view) = item.downcast::() { - self.log_view = Some(log_view.clone()); - return workspace::ToolbarItemLocation::PrimaryLeft; - } + if let Some(item) = active_pane_item + && let Some(log_view) = item.downcast::() + { + self.log_view = Some(log_view); + return workspace::ToolbarItemLocation::PrimaryLeft; } self.log_view = None; @@ -1131,7 +1131,7 @@ impl LogStore { project: &WeakEntity, session_id: SessionId, ) -> Vec { - self.projects.get(&project).map_or(vec![], |state| { + self.projects.get(project).map_or(vec![], |state| { state .debug_sessions .get(&session_id) diff --git 
a/crates/debugger_ui/src/attach_modal.rs b/crates/debugger_ui/src/attach_modal.rs index 662a98c82075cd6e936988959c855eadb5138092..8926b3bb6f61e6e612f75777584810fa24b616ba 100644 --- a/crates/debugger_ui/src/attach_modal.rs +++ b/crates/debugger_ui/src/attach_modal.rs @@ -1,8 +1,10 @@ use dap::{DapRegistry, DebugRequest}; use fuzzy::{StringMatch, StringMatchCandidate}; -use gpui::{AppContext, DismissEvent, Entity, EventEmitter, Focusable, Render}; +use gpui::{AppContext, DismissEvent, Entity, EventEmitter, Focusable, Render, Task}; use gpui::{Subscription, WeakEntity}; use picker::{Picker, PickerDelegate}; +use project::Project; +use rpc::proto; use task::ZedDebugConfig; use util::debug_panic; @@ -56,29 +58,28 @@ impl AttachModal { pub fn new( definition: ZedDebugConfig, workspace: WeakEntity, + project: Entity, modal: bool, window: &mut Window, cx: &mut Context, ) -> Self { - let mut processes: Box<[_]> = System::new_all() - .processes() - .values() - .map(|process| { - let name = process.name().to_string_lossy().into_owned(); - Candidate { - name: name.into(), - pid: process.pid().as_u32(), - command: process - .cmd() - .iter() - .map(|s| s.to_string_lossy().to_string()) - .collect::>(), - } - }) - .collect(); - processes.sort_by_key(|k| k.name.clone()); - let processes = processes.into_iter().collect(); - Self::with_processes(workspace, definition, processes, modal, window, cx) + let processes_task = get_processes_for_project(&project, cx); + + let modal = Self::with_processes(workspace, definition, Arc::new([]), modal, window, cx); + + cx.spawn_in(window, async move |this, cx| { + let processes = processes_task.await; + this.update_in(cx, |modal, window, cx| { + modal.picker.update(cx, |picker, cx| { + picker.delegate.candidates = processes; + picker.refresh(window, cx); + }); + })?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + + modal } pub(super) fn with_processes( @@ -288,7 +289,7 @@ impl PickerDelegate for AttachModalDelegate { _window: &mut Window, _: 
&mut Context>, ) -> Option { - let hit = &self.matches[ix]; + let hit = &self.matches.get(ix)?; let candidate = self.candidates.get(hit.candidate_id)?; Some( @@ -332,6 +333,57 @@ impl PickerDelegate for AttachModalDelegate { } } +fn get_processes_for_project(project: &Entity, cx: &mut App) -> Task> { + let project = project.read(cx); + + if let Some(remote_client) = project.remote_client() { + let proto_client = remote_client.read(cx).proto_client(); + cx.spawn(async move |_cx| { + let response = proto_client + .request(proto::GetProcesses { + project_id: proto::REMOTE_SERVER_PROJECT_ID, + }) + .await + .unwrap_or_else(|_| proto::GetProcessesResponse { + processes: Vec::new(), + }); + + let mut processes: Vec = response + .processes + .into_iter() + .map(|p| Candidate { + pid: p.pid, + name: p.name.into(), + command: p.command, + }) + .collect(); + + processes.sort_by_key(|k| k.name.clone()); + Arc::from(processes.into_boxed_slice()) + }) + } else { + let mut processes: Box<[_]> = System::new_all() + .processes() + .values() + .map(|process| { + let name = process.name().to_string_lossy().into_owned(); + Candidate { + name: name.into(), + pid: process.pid().as_u32(), + command: process + .cmd() + .iter() + .map(|s| s.to_string_lossy().to_string()) + .collect::>(), + } + }) + .collect(); + processes.sort_by_key(|k| k.name.clone()); + let processes = processes.into_iter().collect(); + Task::ready(processes) + } +} + #[cfg(any(test, feature = "test-support"))] pub(crate) fn _process_names(modal: &AttachModal, cx: &mut Context) -> Vec { modal.picker.read_with(cx, |picker, _| { diff --git a/crates/debugger_ui/src/debugger_panel.rs b/crates/debugger_ui/src/debugger_panel.rs index 1d44c5c2448afba50f682ea8ae96da8d3104945f..ef714a1f6710f54c5673eac097e7530b3c605b58 100644 --- a/crates/debugger_ui/src/debugger_panel.rs +++ b/crates/debugger_ui/src/debugger_panel.rs @@ -13,11 +13,8 @@ use anyhow::{Context as _, Result, anyhow}; use collections::IndexMap; use 
dap::adapters::DebugAdapterName; use dap::debugger_settings::DebugPanelDockPosition; -use dap::{ - ContinuedEvent, LoadedSourceEvent, ModuleEvent, OutputEvent, StoppedEvent, ThreadEvent, - client::SessionId, debugger_settings::DebuggerSettings, -}; use dap::{DapRegistry, StartDebuggingRequestArguments}; +use dap::{client::SessionId, debugger_settings::DebuggerSettings}; use editor::Editor; use gpui::{ Action, App, AsyncWindowContext, ClipboardItem, Context, DismissEvent, Entity, EntityId, @@ -46,23 +43,6 @@ use workspace::{ }; use zed_actions::ToggleFocus; -pub enum DebugPanelEvent { - Exited(SessionId), - Terminated(SessionId), - Stopped { - client_id: SessionId, - event: StoppedEvent, - go_to_stack_frame: bool, - }, - Thread((SessionId, ThreadEvent)), - Continued((SessionId, ContinuedEvent)), - Output((SessionId, OutputEvent)), - Module((SessionId, ModuleEvent)), - LoadedSource((SessionId, LoadedSourceEvent)), - ClientShutdown(SessionId), - CapabilitiesChanged(SessionId), -} - pub struct DebugPanel { size: Pixels, active_session: Option>, @@ -257,7 +237,7 @@ impl DebugPanel { .as_ref() .map(|entity| entity.downgrade()), task_context: task_context.clone(), - worktree_id: worktree_id, + worktree_id, }); }; running.resolve_scenario( @@ -386,10 +366,10 @@ impl DebugPanel { return; }; - let dap_store_handle = self.project.read(cx).dap_store().clone(); + let dap_store_handle = self.project.read(cx).dap_store(); let label = curr_session.read(cx).label(); let quirks = curr_session.read(cx).quirks(); - let adapter = curr_session.read(cx).adapter().clone(); + let adapter = curr_session.read(cx).adapter(); let binary = curr_session.read(cx).binary().cloned().unwrap(); let task_context = curr_session.read(cx).task_context().clone(); @@ -447,9 +427,9 @@ impl DebugPanel { return; }; - let dap_store_handle = self.project.read(cx).dap_store().clone(); + let dap_store_handle = self.project.read(cx).dap_store(); let label = self.label_for_child_session(&parent_session, request, 
cx); - let adapter = parent_session.read(cx).adapter().clone(); + let adapter = parent_session.read(cx).adapter(); let quirks = parent_session.read(cx).quirks(); let Some(mut binary) = parent_session.read(cx).binary().cloned() else { log::error!("Attempted to start a child-session without a binary"); @@ -530,10 +510,9 @@ impl DebugPanel { .active_session .as_ref() .map(|session| session.entity_id()) + && active_session_id == entity_id { - if active_session_id == entity_id { - this.active_session = this.sessions_with_children.keys().next().cloned(); - } + this.active_session = this.sessions_with_children.keys().next().cloned(); } cx.notify() }) @@ -693,7 +672,7 @@ impl DebugPanel { ) .icon_size(IconSize::Small) .on_click(window.listener_for( - &running_state, + running_state, |this, _, _window, cx| { this.pause_thread(cx); }, @@ -719,7 +698,7 @@ impl DebugPanel { ) .icon_size(IconSize::Small) .on_click(window.listener_for( - &running_state, + running_state, |this, _, _window, cx| this.continue_thread(cx), )) .disabled(thread_status != ThreadStatus::Stopped) @@ -742,7 +721,7 @@ impl DebugPanel { IconButton::new("debug-step-over", IconName::ArrowRight) .icon_size(IconSize::Small) .on_click(window.listener_for( - &running_state, + running_state, |this, _, _window, cx| { this.step_over(cx); }, @@ -768,7 +747,7 @@ impl DebugPanel { ) .icon_size(IconSize::Small) .on_click(window.listener_for( - &running_state, + running_state, |this, _, _window, cx| { this.step_in(cx); }, @@ -791,7 +770,7 @@ impl DebugPanel { IconButton::new("debug-step-out", IconName::ArrowUpRight) .icon_size(IconSize::Small) .on_click(window.listener_for( - &running_state, + running_state, |this, _, _window, cx| { this.step_out(cx); }, @@ -815,7 +794,7 @@ impl DebugPanel { IconButton::new("debug-restart", IconName::RotateCcw) .icon_size(IconSize::Small) .on_click(window.listener_for( - &running_state, + running_state, |this, _, window, cx| { this.rerun_session(window, cx); }, @@ -837,7 +816,7 @@ impl 
DebugPanel { IconButton::new("debug-stop", IconName::Power) .icon_size(IconSize::Small) .on_click(window.listener_for( - &running_state, + running_state, |this, _, _window, cx| { if this.session().read(cx).is_building() { this.session().update(cx, |session, cx| { @@ -892,7 +871,7 @@ impl DebugPanel { ) .icon_size(IconSize::Small) .on_click(window.listener_for( - &running_state, + running_state, |this, _, _, cx| { this.detach_client(cx); }, @@ -933,7 +912,6 @@ impl DebugPanel { .cloned(), |this, running_state| { this.children({ - let running_state = running_state.clone(); let threads = running_state.update(cx, |running_state, cx| { let session = running_state.session(); @@ -1160,7 +1138,7 @@ impl DebugPanel { workspace .project() .read(cx) - .project_path_for_absolute_path(&path, cx) + .project_path_for_absolute_path(path, cx) .context( "Couldn't get project path for .zed/debug.json in active worktree", ) @@ -1302,10 +1280,10 @@ impl DebugPanel { cx: &mut Context<'_, Self>, ) -> Option { let adapter = parent_session.read(cx).adapter(); - if let Some(adapter) = DapRegistry::global(cx).adapter(&adapter) { - if let Some(label) = adapter.label_for_child_session(request) { - return Some(label.into()); - } + if let Some(adapter) = DapRegistry::global(cx).adapter(&adapter) + && let Some(label) = adapter.label_for_child_session(request) + { + return Some(label.into()); } None } @@ -1409,7 +1387,6 @@ async fn register_session_inner( } impl EventEmitter for DebugPanel {} -impl EventEmitter for DebugPanel {} impl Focusable for DebugPanel { fn focus_handle(&self, _: &App) -> FocusHandle { @@ -1646,7 +1623,6 @@ impl Render for DebugPanel { } }) .on_action({ - let this = this.clone(); move |_: &ToggleSessionPicker, window, cx| { this.update(cx, |this, cx| { this.toggle_session_picker(window, cx); diff --git a/crates/debugger_ui/src/debugger_ui.rs b/crates/debugger_ui/src/debugger_ui.rs index 5f5dfd1a1e6a543cdb7a4d87e1b8e9984c4ecba9..689e3cd878b574d31963231df9bcff317ea6d64c 100644 
--- a/crates/debugger_ui/src/debugger_ui.rs +++ b/crates/debugger_ui/src/debugger_ui.rs @@ -85,6 +85,10 @@ actions!( Rerun, /// Toggles expansion of the selected item in the debugger UI. ToggleExpandItem, + /// Toggle the user frame filter in the stack frame list + /// When toggled on, only frames from the user's code are shown + /// When toggled off, all frames are shown + ToggleUserFrames, ] ); @@ -279,6 +283,18 @@ pub fn init(cx: &mut App) { .ok(); } }) + .on_action(move |_: &ToggleUserFrames, _, cx| { + if let Some((thread_status, stack_frame_list)) = active_item + .read_with(cx, |item, cx| { + (item.thread_status(cx), item.stack_frame_list().clone()) + }) + .ok() + { + stack_frame_list.update(cx, |stack_frame_list, cx| { + stack_frame_list.toggle_frame_filter(thread_status, cx); + }) + } + }) }); }) .detach(); @@ -293,9 +309,8 @@ pub fn init(cx: &mut App) { let Some(debug_panel) = workspace.read(cx).panel::(cx) else { return; }; - let Some(active_session) = debug_panel - .clone() - .update(cx, |panel, _| panel.active_session()) + let Some(active_session) = + debug_panel.update(cx, |panel, _| panel.active_session()) else { return; }; diff --git a/crates/debugger_ui/src/dropdown_menus.rs b/crates/debugger_ui/src/dropdown_menus.rs index dca15eb0527cfc78bd137889a1910e6b32abf98c..376a4a41ce7b03cd07f578d85f641a6ddfc4ebe8 100644 --- a/crates/debugger_ui/src/dropdown_menus.rs +++ b/crates/debugger_ui/src/dropdown_menus.rs @@ -1,9 +1,9 @@ -use std::{rc::Rc, time::Duration}; +use std::rc::Rc; use collections::HashMap; -use gpui::{Animation, AnimationExt as _, Entity, Transformation, WeakEntity, percentage}; +use gpui::{Entity, WeakEntity}; use project::debugger::session::{ThreadId, ThreadStatus}; -use ui::{ContextMenu, DropdownMenu, DropdownStyle, Indicator, prelude::*}; +use ui::{CommonAnimationExt, ContextMenu, DropdownMenu, DropdownStyle, Indicator, prelude::*}; use util::{maybe, truncate_and_trailoff}; use crate::{ @@ -113,23 +113,6 @@ impl DebugPanel { } }; 
session_entries.push(root_entry); - - session_entries.extend( - sessions_with_children - .by_ref() - .take_while(|(session, _)| { - session - .read(cx) - .session(cx) - .read(cx) - .parent_id(cx) - .is_some() - }) - .map(|(session, _)| SessionListEntry { - leaf: session.clone(), - ancestors: vec![], - }), - ); } let weak = cx.weak_entity(); @@ -152,11 +135,7 @@ impl DebugPanel { Icon::new(IconName::ArrowCircle) .size(IconSize::Small) .color(Color::Muted) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - ) + .with_rotate_animation(2) .into_any_element() } else { match running_state.thread_status(cx).unwrap_or_default() { @@ -272,10 +251,9 @@ impl DebugPanel { .child(session_entry.label_element(self_depth, cx)) .child( IconButton::new("close-debug-session", IconName::Close) - .visible_on_hover(id.clone()) + .visible_on_hover(id) .icon_size(IconSize::Small) .on_click({ - let weak = weak.clone(); move |_, window, cx| { weak.update(cx, |panel, cx| { panel.close_session(session_entity_id, window, cx); diff --git a/crates/debugger_ui/src/new_process_modal.rs b/crates/debugger_ui/src/new_process_modal.rs index 4ac8e371a15052a00ed962480a9f694a8802007c..eeed36ac1df18d5a0d1b33d6c3567c7df6a4b9c0 100644 --- a/crates/debugger_ui/src/new_process_modal.rs +++ b/crates/debugger_ui/src/new_process_modal.rs @@ -20,7 +20,7 @@ use gpui::{ }; use itertools::Itertools as _; use picker::{Picker, PickerDelegate, highlighted_match_with_paths::HighlightedMatch}; -use project::{DebugScenarioContext, TaskContexts, TaskSourceKind, task_store::TaskStore}; +use project::{DebugScenarioContext, Project, TaskContexts, TaskSourceKind, task_store::TaskStore}; use settings::Settings; use task::{DebugScenario, RevealTarget, ZedDebugConfig}; use theme::ThemeSettings; @@ -88,8 +88,10 @@ impl NewProcessModal { })?; workspace.update_in(cx, |workspace, window, cx| { let workspace_handle = 
workspace.weak_handle(); + let project = workspace.project().clone(); workspace.toggle_modal(window, cx, |window, cx| { - let attach_mode = AttachMode::new(None, workspace_handle.clone(), window, cx); + let attach_mode = + AttachMode::new(None, workspace_handle.clone(), project, window, cx); let debug_picker = cx.new(|cx| { let delegate = @@ -343,10 +345,10 @@ impl NewProcessModal { return; } - if let NewProcessMode::Launch = &self.mode { - if self.configure_mode.read(cx).save_to_debug_json.selected() { - self.save_debug_scenario(window, cx); - } + if let NewProcessMode::Launch = &self.mode + && self.configure_mode.read(cx).save_to_debug_json.selected() + { + self.save_debug_scenario(window, cx); } let Some(debugger) = self.debugger.clone() else { @@ -413,7 +415,7 @@ impl NewProcessModal { let Some(adapter) = self.debugger.as_ref() else { return; }; - let scenario = self.debug_scenario(&adapter, cx); + let scenario = self.debug_scenario(adapter, cx); cx.spawn_in(window, async move |this, cx| { let scenario = scenario.await.context("no scenario to save")?; let worktree_id = task_contexts @@ -659,12 +661,7 @@ impl Render for NewProcessModal { this.mode = NewProcessMode::Attach; if let Some(debugger) = this.debugger.as_ref() { - Self::update_attach_picker( - &this.attach_mode, - &debugger, - window, - cx, - ); + Self::update_attach_picker(&this.attach_mode, debugger, window, cx); } this.mode_focus_handle(cx).focus(window); cx.notify(); @@ -790,7 +787,7 @@ impl RenderOnce for AttachMode { v_flex() .w_full() .track_focus(&self.attach_picker.focus_handle(cx)) - .child(self.attach_picker.clone()) + .child(self.attach_picker) } } @@ -806,12 +803,12 @@ impl ConfigureMode { pub(super) fn new(window: &mut Window, cx: &mut App) -> Entity { let program = cx.new(|cx| Editor::single_line(window, cx)); program.update(cx, |this, cx| { - this.set_placeholder_text("ENV=Zed ~/bin/program --option", cx); + this.set_placeholder_text("ENV=Zed ~/bin/program --option", window, cx); }); let 
cwd = cx.new(|cx| Editor::single_line(window, cx)); cwd.update(cx, |this, cx| { - this.set_placeholder_text("Ex: $ZED_WORKTREE_ROOT", cx); + this.set_placeholder_text("Ex: $ZED_WORKTREE_ROOT", window, cx); }); cx.new(|_| Self { @@ -945,6 +942,7 @@ impl AttachMode { pub(super) fn new( debugger: Option, workspace: WeakEntity, + project: Entity, window: &mut Window, cx: &mut Context, ) -> Entity { @@ -955,7 +953,7 @@ impl AttachMode { stop_on_entry: Some(false), }; let attach_picker = cx.new(|cx| { - let modal = AttachModal::new(definition.clone(), workspace, false, window, cx); + let modal = AttachModal::new(definition.clone(), workspace, project, false, window, cx); window.focus(&modal.focus_handle(cx)); modal @@ -1083,7 +1081,7 @@ impl DebugDelegate { .into_iter() .map(|(scenario, context)| { let (kind, scenario) = - Self::get_scenario_kind(&languages, &dap_registry, scenario); + Self::get_scenario_kind(&languages, dap_registry, scenario); (kind, scenario, Some(context)) }) .chain( @@ -1100,7 +1098,7 @@ impl DebugDelegate { .filter(|(_, scenario)| valid_adapters.contains(&scenario.adapter)) .map(|(kind, scenario)| { let (language, scenario) = - Self::get_scenario_kind(&languages, &dap_registry, scenario); + Self::get_scenario_kind(&languages, dap_registry, scenario); (language.or(Some(kind)), scenario, None) }), ) @@ -1388,14 +1386,28 @@ impl PickerDelegate for DebugDelegate { .border_color(cx.theme().colors().border_variant) .children({ let action = menu::SecondaryConfirm.boxed_clone(); - KeyBinding::for_action(&*action, window, cx).map(|keybind| { - Button::new("edit-debug-task", "Edit in debug.json") - .label_size(LabelSize::Small) - .key_binding(keybind) - .on_click(move |_, window, cx| { - window.dispatch_action(action.boxed_clone(), cx) - }) - }) + if self.matches.is_empty() { + Some( + Button::new("edit-debug-json", "Edit debug.json") + .label_size(LabelSize::Small) + .on_click(cx.listener(|_picker, _, window, cx| { + window.dispatch_action( + 
zed_actions::OpenProjectDebugTasks.boxed_clone(), + cx, + ); + cx.emit(DismissEvent); + })), + ) + } else { + KeyBinding::for_action(&*action, window, cx).map(|keybind| { + Button::new("edit-debug-task", "Edit in debug.json") + .label_size(LabelSize::Small) + .key_binding(keybind) + .on_click(move |_, window, cx| { + window.dispatch_action(action.boxed_clone(), cx) + }) + }) + } }) .map(|this| { if (current_modifiers.alt || self.matches.is_empty()) && !self.prompt.is_empty() { @@ -1434,7 +1446,7 @@ impl PickerDelegate for DebugDelegate { window: &mut Window, cx: &mut Context>, ) -> Option { - let hit = &self.matches[ix]; + let hit = &self.matches.get(ix)?; let highlighted_location = HighlightedMatch { text: hit.string.clone(), diff --git a/crates/debugger_ui/src/persistence.rs b/crates/debugger_ui/src/persistence.rs index 3a0ad7a40e60d4dc28f2086b94a0a43186978542..ab68fea1154182fe266bb150d762f8be0995d733 100644 --- a/crates/debugger_ui/src/persistence.rs +++ b/crates/debugger_ui/src/persistence.rs @@ -256,7 +256,7 @@ pub(crate) fn deserialize_pane_layout( Some(Member::Axis(PaneAxis::load( if should_invert { axis.invert() } else { axis }, members, - flexes.clone(), + flexes, ))) } SerializedPaneLayout::Pane(serialized_pane) => { @@ -270,12 +270,9 @@ pub(crate) fn deserialize_pane_layout( .children .iter() .map(|child| match child { - DebuggerPaneItem::Frames => Box::new(SubView::new( - stack_frame_list.focus_handle(cx), - stack_frame_list.clone().into(), - DebuggerPaneItem::Frames, - cx, - )), + DebuggerPaneItem::Frames => { + Box::new(SubView::stack_frame_list(stack_frame_list.clone(), cx)) + } DebuggerPaneItem::Variables => Box::new(SubView::new( variable_list.focus_handle(cx), variable_list.clone().into(), @@ -341,7 +338,7 @@ impl SerializedPaneLayout { pub(crate) fn in_order(&self) -> Vec { let mut panes = vec![]; - Self::inner_in_order(&self, &mut panes); + Self::inner_in_order(self, &mut panes); panes } diff --git a/crates/debugger_ui/src/session.rs 
b/crates/debugger_ui/src/session.rs index 73cfef78cc6410196441ff974f09b5abe3d86916..40c9bd810f9c5c9691f51f3d38957a98c9f037a2 100644 --- a/crates/debugger_ui/src/session.rs +++ b/crates/debugger_ui/src/session.rs @@ -2,9 +2,7 @@ pub mod running; use crate::{StackTraceView, persistence::SerializedLayout, session::running::DebugTerminal}; use dap::client::SessionId; -use gpui::{ - App, Axis, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, -}; +use gpui::{App, Axis, Entity, EventEmitter, FocusHandle, Focusable, Task, WeakEntity}; use project::debugger::session::Session; use project::worktree_store::WorktreeStore; use project::{Project, debugger::session::SessionQuirks}; @@ -24,13 +22,6 @@ pub struct DebugSession { stack_trace_view: OnceCell>, _worktree_store: WeakEntity, workspace: WeakEntity, - _subscriptions: [Subscription; 1], -} - -#[derive(Debug)] -pub enum DebugPanelItemEvent { - Close, - Stopped { go_to_stack_frame: bool }, } impl DebugSession { @@ -59,9 +50,6 @@ impl DebugSession { let quirks = session.read(cx).quirks(); cx.new(|cx| Self { - _subscriptions: [cx.subscribe(&running_state, |_, _, _, cx| { - cx.notify(); - })], remote_id: None, running_state, quirks, @@ -87,7 +75,7 @@ impl DebugSession { self.stack_trace_view.get_or_init(|| { let stackframe_list = running_state.read(cx).stack_frame_list().clone(); - let stack_frame_view = cx.new(|cx| { + cx.new(|cx| { StackTraceView::new( workspace.clone(), project.clone(), @@ -95,9 +83,7 @@ impl DebugSession { window, cx, ) - }); - - stack_frame_view + }) }) } @@ -135,7 +121,7 @@ impl DebugSession { } } -impl EventEmitter for DebugSession {} +impl EventEmitter<()> for DebugSession {} impl Focusable for DebugSession { fn focus_handle(&self, cx: &App) -> FocusHandle { @@ -144,7 +130,7 @@ impl Focusable for DebugSession { } impl Item for DebugSession { - type Event = DebugPanelItemEvent; + type Event = (); fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { 
"Debugger".into() } diff --git a/crates/debugger_ui/src/session/running.rs b/crates/debugger_ui/src/session/running.rs index c8bee42039c41a8fd2e393bccd25f082aba488e4..a18a186469a0aaaf5f3d061830446f5ba27dec72 100644 --- a/crates/debugger_ui/src/session/running.rs +++ b/crates/debugger_ui/src/session/running.rs @@ -14,7 +14,6 @@ use crate::{ session::running::memory_view::MemoryView, }; -use super::DebugPanelItemEvent; use anyhow::{Context as _, Result, anyhow}; use breakpoint_list::BreakpointList; use collections::{HashMap, IndexMap}; @@ -36,7 +35,6 @@ use module_list::ModuleList; use project::{ DebugScenarioContext, Project, WorktreeId, debugger::session::{self, Session, SessionEvent, SessionStateEvent, ThreadId, ThreadStatus}, - terminals::TerminalKind, }; use rpc::proto::ViewId; use serde_json::Value; @@ -102,7 +100,7 @@ impl Render for RunningState { .find(|pane| pane.read(cx).is_zoomed()); let active = self.panes.panes().into_iter().next(); - let pane = if let Some(ref zoomed_pane) = zoomed_pane { + let pane = if let Some(zoomed_pane) = zoomed_pane { zoomed_pane.update(cx, |pane, cx| pane.render(window, cx).into_any_element()) } else if let Some(active) = active { self.panes @@ -158,6 +156,29 @@ impl SubView { }) } + pub(crate) fn stack_frame_list( + stack_frame_list: Entity, + cx: &mut App, + ) -> Entity { + let weak_list = stack_frame_list.downgrade(); + let this = Self::new( + stack_frame_list.focus_handle(cx), + stack_frame_list.into(), + DebuggerPaneItem::Frames, + cx, + ); + + this.update(cx, |this, _| { + this.with_actions(Box::new(move |_, cx| { + weak_list + .update(cx, |this, _| this.render_control_strip()) + .unwrap_or_else(|_| div().into_any_element()) + })); + }); + + this + } + pub(crate) fn console(console: Entity, cx: &mut App) -> Entity { let weak_console = console.downgrade(); let this = Self::new( @@ -180,7 +201,7 @@ impl SubView { let weak_list = list.downgrade(); let focus_handle = list.focus_handle(cx); let this = Self::new( - 
focus_handle.clone(), + focus_handle, list.into(), DebuggerPaneItem::BreakpointList, cx, @@ -291,7 +312,7 @@ pub(crate) fn new_debugger_pane( let Some(project) = project.upgrade() else { return ControlFlow::Break(()); }; - let this_pane = cx.entity().clone(); + let this_pane = cx.entity(); let item = if tab.pane == this_pane { pane.item_for_index(tab.ix) } else { @@ -358,7 +379,7 @@ pub(crate) fn new_debugger_pane( } }; - let ret = cx.new(move |cx| { + cx.new(move |cx| { let mut pane = Pane::new( workspace.clone(), project.clone(), @@ -414,7 +435,7 @@ pub(crate) fn new_debugger_pane( .and_then(|item| item.downcast::()); let is_hovered = as_subview .as_ref() - .map_or(false, |item| item.read(cx).hovered); + .is_some_and(|item| item.read(cx).hovered); h_flex() .track_focus(&focus_handle) @@ -427,7 +448,6 @@ pub(crate) fn new_debugger_pane( .bg(cx.theme().colors().tab_bar_background) .on_action(|_: &menu::Cancel, window, cx| { if cx.stop_active_drag(window) { - return; } else { cx.propagate(); } @@ -449,7 +469,7 @@ pub(crate) fn new_debugger_pane( .children(pane.items().enumerate().map(|(ix, item)| { let selected = active_pane_item .as_ref() - .map_or(false, |active| active.item_id() == item.item_id()); + .is_some_and(|active| active.item_id() == item.item_id()); let deemphasized = !pane.has_focus(window, cx); let item_ = item.boxed_clone(); div() @@ -502,7 +522,7 @@ pub(crate) fn new_debugger_pane( .on_drag( DraggedTab { item: item.boxed_clone(), - pane: cx.entity().clone(), + pane: cx.entity(), detail: 0, is_active: selected, ix, @@ -563,9 +583,7 @@ pub(crate) fn new_debugger_pane( } }); pane - }); - - ret + }) } pub struct DebugTerminal { @@ -627,7 +645,7 @@ impl RunningState { if s.starts_with("\"$ZED_") && s.ends_with('"') { *s = s[1..s.len() - 1].to_string(); } - if let Some(substituted) = substitute_variables_in_str(&s, context) { + if let Some(substituted) = substitute_variables_in_str(s, context) { *s = substituted; } } @@ -657,7 +675,7 @@ impl RunningState { 
} resolve_path(s); - if let Some(substituted) = substitute_variables_in_str(&s, context) { + if let Some(substituted) = substitute_variables_in_str(s, context) { *s = substituted; } } @@ -919,7 +937,11 @@ impl RunningState { let task_store = project.read(cx).task_store().downgrade(); let weak_project = project.downgrade(); let weak_workspace = workspace.downgrade(); - let is_local = project.read(cx).is_local(); + let remote_shell = project + .read(cx) + .remote_client() + .as_ref() + .and_then(|remote| remote.read(cx).shell()); cx.spawn_in(window, async move |this, cx| { let DebugScenario { @@ -954,7 +976,7 @@ impl RunningState { inventory.read(cx).task_template_by_label( buffer, worktree_id, - &label, + label, cx, ) }) @@ -1003,7 +1025,7 @@ impl RunningState { None }; - let builder = ShellBuilder::new(is_local, &task.resolved.shell); + let builder = ShellBuilder::new(remote_shell.as_deref(), &task.resolved.shell); let command_label = builder.command_label(&task.resolved.command_label); let (command, args) = builder.build(task.resolved.command.clone(), &task.resolved.args); @@ -1016,12 +1038,11 @@ impl RunningState { }; let terminal = project .update(cx, |project, cx| { - project.create_terminal( - TerminalKind::Task(task_with_shell.clone()), + project.create_terminal_task( + task_with_shell.clone(), cx, ) - })? 
- .await?; + })?.await?; let terminal_view = cx.new_window_entity(|window, cx| { TerminalView::new( @@ -1116,9 +1137,8 @@ impl RunningState { }; let session = self.session.read(cx); - let cwd = Some(&request.cwd) - .filter(|cwd| cwd.len() > 0) - .map(PathBuf::from) + let cwd = (!request.cwd.is_empty()) + .then(|| PathBuf::from(&request.cwd)) .or_else(|| session.binary().unwrap().cwd.clone()); let mut envs: HashMap = @@ -1153,7 +1173,7 @@ impl RunningState { } else { None } - } else if args.len() > 0 { + } else if !args.is_empty() { Some(args.remove(0)) } else { None @@ -1166,13 +1186,13 @@ impl RunningState { .filter(|title| !title.is_empty()) .or_else(|| command.clone()) .unwrap_or_else(|| "Debug terminal".to_string()); - let kind = TerminalKind::Task(task::SpawnInTerminal { + let kind = task::SpawnInTerminal { id: task::TaskId("debug".to_string()), full_label: title.clone(), label: title.clone(), - command: command.clone(), + command, args, - command_label: title.clone(), + command_label: title, cwd, env: envs, use_new_terminal: true, @@ -1184,12 +1204,13 @@ impl RunningState { show_summary: false, show_command: false, show_rerun: false, - }); + }; let workspace = self.workspace.clone(); let weak_project = project.downgrade(); - let terminal_task = project.update(cx, |project, cx| project.create_terminal(kind, cx)); + let terminal_task = + project.update(cx, |project, cx| project.create_terminal_task(kind, cx)); let terminal_task = cx.spawn_in(window, async move |_, cx| { let terminal = terminal_task.await?; @@ -1310,7 +1331,7 @@ impl RunningState { let mut pane_item_status = IndexMap::from_iter( DebuggerPaneItem::all() .iter() - .filter(|kind| kind.is_supported(&caps)) + .filter(|kind| kind.is_supported(caps)) .map(|kind| (*kind, false)), ); self.panes.panes().iter().for_each(|pane| { @@ -1371,7 +1392,7 @@ impl RunningState { this.serialize_layout(window, cx); match event { Event::Remove { .. 
} => { - let _did_find_pane = this.panes.remove(&source_pane).is_ok(); + let _did_find_pane = this.panes.remove(source_pane).is_ok(); debug_assert!(_did_find_pane); cx.notify(); } @@ -1759,7 +1780,7 @@ impl RunningState { this.activate_item(0, false, false, window, cx); }); - let rightmost_pane = new_debugger_pane(workspace.clone(), project.clone(), window, cx); + let rightmost_pane = new_debugger_pane(workspace.clone(), project, window, cx); rightmost_pane.update(cx, |this, cx| { this.add_item( Box::new(SubView::new( @@ -1804,8 +1825,6 @@ impl RunningState { } } -impl EventEmitter for RunningState {} - impl Focusable for RunningState { fn focus_handle(&self, _: &App) -> FocusHandle { self.focus_handle.clone() diff --git a/crates/debugger_ui/src/session/running/breakpoint_list.rs b/crates/debugger_ui/src/session/running/breakpoint_list.rs index 38108dbfbcc62e777ea9ee9aa9f1ab1f7d2c2f3d..9fc952a2ea46ac5e5c58c9ddff1f4860447b77b3 100644 --- a/crates/debugger_ui/src/session/running/breakpoint_list.rs +++ b/crates/debugger_ui/src/session/running/breakpoint_list.rs @@ -219,7 +219,7 @@ impl BreakpointList { }); self.input.update(cx, |this, cx| { - this.set_placeholder_text(placeholder, cx); + this.set_placeholder_text(placeholder, window, cx); this.set_read_only(is_exception_breakpoint); this.set_text(active_value.as_deref().unwrap_or(""), window, cx); }); @@ -239,14 +239,12 @@ impl BreakpointList { } fn select_next(&mut self, _: &menu::SelectNext, window: &mut Window, cx: &mut Context) { - if self.strip_mode.is_some() { - if self.input.focus_handle(cx).contains_focused(window, cx) { - cx.propagate(); - return; - } + if self.strip_mode.is_some() && self.input.focus_handle(cx).contains_focused(window, cx) { + cx.propagate(); + return; } let ix = match self.selected_ix { - _ if self.breakpoints.len() == 0 => None, + _ if self.breakpoints.is_empty() => None, None => Some(0), Some(ix) => { if ix == self.breakpoints.len() - 1 { @@ -265,14 +263,12 @@ impl BreakpointList { 
window: &mut Window, cx: &mut Context, ) { - if self.strip_mode.is_some() { - if self.input.focus_handle(cx).contains_focused(window, cx) { - cx.propagate(); - return; - } + if self.strip_mode.is_some() && self.input.focus_handle(cx).contains_focused(window, cx) { + cx.propagate(); + return; } let ix = match self.selected_ix { - _ if self.breakpoints.len() == 0 => None, + _ if self.breakpoints.is_empty() => None, None => Some(self.breakpoints.len() - 1), Some(ix) => { if ix == 0 { @@ -286,13 +282,11 @@ impl BreakpointList { } fn select_first(&mut self, _: &menu::SelectFirst, window: &mut Window, cx: &mut Context) { - if self.strip_mode.is_some() { - if self.input.focus_handle(cx).contains_focused(window, cx) { - cx.propagate(); - return; - } + if self.strip_mode.is_some() && self.input.focus_handle(cx).contains_focused(window, cx) { + cx.propagate(); + return; } - let ix = if self.breakpoints.len() > 0 { + let ix = if !self.breakpoints.is_empty() { Some(0) } else { None @@ -301,13 +295,11 @@ impl BreakpointList { } fn select_last(&mut self, _: &menu::SelectLast, window: &mut Window, cx: &mut Context) { - if self.strip_mode.is_some() { - if self.input.focus_handle(cx).contains_focused(window, cx) { - cx.propagate(); - return; - } + if self.strip_mode.is_some() && self.input.focus_handle(cx).contains_focused(window, cx) { + cx.propagate(); + return; } - let ix = if self.breakpoints.len() > 0 { + let ix = if !self.breakpoints.is_empty() { Some(self.breakpoints.len() - 1) } else { None @@ -337,8 +329,8 @@ impl BreakpointList { let text = self.input.read(cx).text(cx); match mode { - ActiveBreakpointStripMode::Log => match &entry.kind { - BreakpointEntryKind::LineBreakpoint(line_breakpoint) => { + ActiveBreakpointStripMode::Log => { + if let BreakpointEntryKind::LineBreakpoint(line_breakpoint) = &entry.kind { Self::edit_line_breakpoint_inner( &self.breakpoint_store, line_breakpoint.breakpoint.path.clone(), @@ -347,10 +339,9 @@ impl BreakpointList { cx, ); } - _ => {} - 
}, - ActiveBreakpointStripMode::Condition => match &entry.kind { - BreakpointEntryKind::LineBreakpoint(line_breakpoint) => { + } + ActiveBreakpointStripMode::Condition => { + if let BreakpointEntryKind::LineBreakpoint(line_breakpoint) = &entry.kind { Self::edit_line_breakpoint_inner( &self.breakpoint_store, line_breakpoint.breakpoint.path.clone(), @@ -359,10 +350,9 @@ impl BreakpointList { cx, ); } - _ => {} - }, - ActiveBreakpointStripMode::HitCondition => match &entry.kind { - BreakpointEntryKind::LineBreakpoint(line_breakpoint) => { + } + ActiveBreakpointStripMode::HitCondition => { + if let BreakpointEntryKind::LineBreakpoint(line_breakpoint) = &entry.kind { Self::edit_line_breakpoint_inner( &self.breakpoint_store, line_breakpoint.breakpoint.path.clone(), @@ -371,8 +361,7 @@ impl BreakpointList { cx, ); } - _ => {} - }, + } } self.focus_handle.focus(window); } else { @@ -401,11 +390,9 @@ impl BreakpointList { let Some(entry) = self.selected_ix.and_then(|ix| self.breakpoints.get_mut(ix)) else { return; }; - if self.strip_mode.is_some() { - if self.input.focus_handle(cx).contains_focused(window, cx) { - cx.propagate(); - return; - } + if self.strip_mode.is_some() && self.input.focus_handle(cx).contains_focused(window, cx) { + cx.propagate(); + return; } match &mut entry.kind { @@ -436,13 +423,10 @@ impl BreakpointList { return; }; - match &mut entry.kind { - BreakpointEntryKind::LineBreakpoint(line_breakpoint) => { - let path = line_breakpoint.breakpoint.path.clone(); - let row = line_breakpoint.breakpoint.row; - self.edit_line_breakpoint(path, row, BreakpointEditAction::Toggle, cx); - } - _ => {} + if let BreakpointEntryKind::LineBreakpoint(line_breakpoint) = &mut entry.kind { + let path = line_breakpoint.breakpoint.path.clone(); + let row = line_breakpoint.breakpoint.row; + self.edit_line_breakpoint(path, row, BreakpointEditAction::Toggle, cx); } cx.notify(); } @@ -494,7 +478,7 @@ impl BreakpointList { fn toggle_data_breakpoint(&mut self, id: &str, cx: &mut 
Context) { if let Some(session) = &self.session { session.update(cx, |this, cx| { - this.toggle_data_breakpoint(&id, cx); + this.toggle_data_breakpoint(id, cx); }); } } @@ -502,7 +486,7 @@ impl BreakpointList { fn toggle_exception_breakpoint(&mut self, id: &str, cx: &mut Context) { if let Some(session) = &self.session { session.update(cx, |this, cx| { - this.toggle_exception_breakpoint(&id, cx); + this.toggle_exception_breakpoint(id, cx); }); cx.notify(); const EXCEPTION_SERIALIZATION_INTERVAL: Duration = Duration::from_secs(1); @@ -538,7 +522,7 @@ impl BreakpointList { cx.background_executor() .spawn(async move { KEY_VALUE_STORE.write_kvp(key, value?).await }) } else { - return Task::ready(Result::Ok(())); + Task::ready(Result::Ok(())) } } @@ -701,7 +685,6 @@ impl BreakpointList { selection_kind.map(|kind| kind.0) != Some(SelectedBreakpointKind::Source), ) .on_click({ - let focus_handle = focus_handle.clone(); move |_, window, cx| { focus_handle.focus(window); window.dispatch_action(UnsetBreakpoint.boxed_clone(), cx) @@ -977,7 +960,7 @@ impl LineBreakpoint { props, breakpoint: BreakpointEntry { kind: BreakpointEntryKind::LineBreakpoint(self.clone()), - weak: weak, + weak, }, is_selected, focus_handle, @@ -1155,7 +1138,6 @@ impl ExceptionBreakpoint { } }) .on_click({ - let list = list.clone(); move |_, _, cx| { list.update(cx, |this, cx| { this.toggle_exception_breakpoint(&id, cx); @@ -1189,7 +1171,7 @@ impl ExceptionBreakpoint { props, breakpoint: BreakpointEntry { kind: BreakpointEntryKind::ExceptionBreakpoint(self.clone()), - weak: weak, + weak, }, is_selected, focus_handle, diff --git a/crates/debugger_ui/src/session/running/console.rs b/crates/debugger_ui/src/session/running/console.rs index e6308518e4dea66e6ef155e3dbf6ccfa74c18f55..92c5ace8f0128e47db08c6b772376679213ffbe1 100644 --- a/crates/debugger_ui/src/session/running/console.rs +++ b/crates/debugger_ui/src/session/running/console.rs @@ -15,7 +15,7 @@ use gpui::{ use language::{Anchor, Buffer, CodeLabel, 
TextBufferSnapshot, ToOffset}; use menu::{Confirm, SelectNext, SelectPrevious}; use project::{ - Completion, CompletionResponse, + Completion, CompletionDisplayOptions, CompletionResponse, debugger::session::{CompletionsQuery, OutputToken, Session}, lsp_store::CompletionDocumentation, search_history::{SearchHistory, SearchHistoryCursor}, @@ -83,7 +83,7 @@ impl Console { let this = cx.weak_entity(); let query_bar = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text("Evaluate an expression", cx); + editor.set_placeholder_text("Evaluate an expression", window, cx); editor.set_use_autoclose(false); editor.set_show_gutter(false, cx); editor.set_show_wrap_guides(false, cx); @@ -365,7 +365,7 @@ impl Console { Some(ContextMenu::build(window, cx, |context_menu, _, _| { context_menu .when_some(keybinding_target.clone(), |el, keybinding_target| { - el.context(keybinding_target.clone()) + el.context(keybinding_target) }) .action("Watch Expression", WatchExpression.boxed_clone()) })) @@ -611,17 +611,16 @@ impl ConsoleQueryBarCompletionProvider { for variable in console.variable_list.update(cx, |variable_list, cx| { variable_list.completion_variables(cx) }) { - if let Some(evaluate_name) = &variable.evaluate_name { - if variables + if let Some(evaluate_name) = &variable.evaluate_name + && variables .insert(evaluate_name.clone(), variable.value.clone()) .is_none() - { - string_matches.push(StringMatchCandidate { - id: 0, - string: evaluate_name.clone(), - char_bag: evaluate_name.chars().collect(), - }); - } + { + string_matches.push(StringMatchCandidate { + id: 0, + string: evaluate_name.clone(), + char_bag: evaluate_name.chars().collect(), + }); } if variables @@ -686,6 +685,7 @@ impl ConsoleQueryBarCompletionProvider { Ok(vec![project::CompletionResponse { is_incomplete: completions.len() >= LIMIT, + display_options: CompletionDisplayOptions::default(), completions, }]) }) @@ -697,7 +697,7 @@ impl ConsoleQueryBarCompletionProvider { 
new_bytes: &[u8], snapshot: &TextBufferSnapshot, ) -> Range { - let buffer_offset = buffer_position.to_offset(&snapshot); + let buffer_offset = buffer_position.to_offset(snapshot); let buffer_bytes = &buffer_text.as_bytes()[0..buffer_offset]; let mut prefix_len = 0; @@ -798,6 +798,7 @@ impl ConsoleQueryBarCompletionProvider { Ok(vec![project::CompletionResponse { completions, + display_options: CompletionDisplayOptions::default(), is_incomplete: false, }]) }) @@ -977,7 +978,7 @@ mod tests { &cx.buffer_text(), snapshot.anchor_before(buffer_position), replacement.as_bytes(), - &snapshot, + snapshot, ); cx.update_editor(|editor, _, cx| { diff --git a/crates/debugger_ui/src/session/running/loaded_source_list.rs b/crates/debugger_ui/src/session/running/loaded_source_list.rs index 6b376bb892e1ea5aae64a1d5873b91487e65f3c2..921ebd8b5f5bdfe8a3c8a8f7bb1625bd1ffad7fb 100644 --- a/crates/debugger_ui/src/session/running/loaded_source_list.rs +++ b/crates/debugger_ui/src/session/running/loaded_source_list.rs @@ -57,7 +57,7 @@ impl LoadedSourceList { h_flex() .text_ui_xs(cx) .text_color(cx.theme().colors().text_muted) - .when_some(source.path.clone(), |this, path| this.child(path)), + .when_some(source.path, |this, path| this.child(path)), ) .into_any() } diff --git a/crates/debugger_ui/src/session/running/memory_view.rs b/crates/debugger_ui/src/session/running/memory_view.rs index f936d908b157ae2631a20b78bfe9fcea26b47b96..a134b916a2200013bcdc9e03e00028a09227e05a 100644 --- a/crates/debugger_ui/src/session/running/memory_view.rs +++ b/crates/debugger_ui/src/session/running/memory_view.rs @@ -262,7 +262,7 @@ impl MemoryView { cx: &mut Context, ) { use parse_int::parse; - let Ok(as_address) = parse::(&memory_reference) else { + let Ok(as_address) = parse::(memory_reference) else { return; }; let access_size = evaluate_name @@ -428,14 +428,14 @@ impl MemoryView { if !self.is_writing_memory { self.query_editor.update(cx, |this, cx| { this.clear(window, cx); - 
this.set_placeholder_text("Write to Selected Memory Range", cx); + this.set_placeholder_text("Write to Selected Memory Range", window, cx); }); self.is_writing_memory = true; self.query_editor.focus_handle(cx).focus(window); } else { self.query_editor.update(cx, |this, cx| { this.clear(window, cx); - this.set_placeholder_text("Go to Memory Address / Expression", cx); + this.set_placeholder_text("Go to Memory Address / Expression", window, cx); }); self.is_writing_memory = false; } @@ -461,7 +461,7 @@ impl MemoryView { let data_breakpoint_info = this.data_breakpoint_info(context.clone(), None, cx); cx.spawn(async move |this, cx| { if let Some(info) = data_breakpoint_info.await { - let Some(data_id) = info.data_id.clone() else { + let Some(data_id) = info.data_id else { return; }; _ = this.update(cx, |this, cx| { @@ -931,7 +931,7 @@ impl Render for MemoryView { v_flex() .size_full() .on_drag_move(cx.listener(|this, evt, _, _| { - this.handle_memory_drag(&evt); + this.handle_memory_drag(evt); })) .child(self.render_memory(cx).size_full()) .children(self.open_context_menu.as_ref().map(|(menu, position, _)| { diff --git a/crates/debugger_ui/src/session/running/module_list.rs b/crates/debugger_ui/src/session/running/module_list.rs index 74a9fb457a57cf2e70af694ed586af2227ee4a0a..7743cfbdee7bf200ab25aabad4cfc455dc8b3484 100644 --- a/crates/debugger_ui/src/session/running/module_list.rs +++ b/crates/debugger_ui/src/session/running/module_list.rs @@ -157,7 +157,7 @@ impl ModuleList { h_flex() .text_ui_xs(cx) .text_color(cx.theme().colors().text_muted) - .when_some(module.path.clone(), |this, path| this.child(path)), + .when_some(module.path, |this, path| this.child(path)), ) .into_any() } @@ -223,7 +223,7 @@ impl ModuleList { fn select_next(&mut self, _: &menu::SelectNext, _window: &mut Window, cx: &mut Context) { let ix = match self.selected_ix { - _ if self.entries.len() == 0 => None, + _ if self.entries.is_empty() => None, None => Some(0), Some(ix) => { if ix == 
self.entries.len() - 1 { @@ -243,7 +243,7 @@ impl ModuleList { cx: &mut Context, ) { let ix = match self.selected_ix { - _ if self.entries.len() == 0 => None, + _ if self.entries.is_empty() => None, None => Some(self.entries.len() - 1), Some(ix) => { if ix == 0 { @@ -262,7 +262,7 @@ impl ModuleList { _window: &mut Window, cx: &mut Context, ) { - let ix = if self.entries.len() > 0 { + let ix = if !self.entries.is_empty() { Some(0) } else { None @@ -271,7 +271,7 @@ impl ModuleList { } fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context) { - let ix = if self.entries.len() > 0 { + let ix = if !self.entries.is_empty() { Some(self.entries.len() - 1) } else { None diff --git a/crates/debugger_ui/src/session/running/stack_frame_list.rs b/crates/debugger_ui/src/session/running/stack_frame_list.rs index 8b44c231c37dfe9bc3deb9905699e2c80df8897f..e51b8da362a581c96d2872a213a8be32ff31b097 100644 --- a/crates/debugger_ui/src/session/running/stack_frame_list.rs +++ b/crates/debugger_ui/src/session/running/stack_frame_list.rs @@ -4,16 +4,17 @@ use std::time::Duration; use anyhow::{Context as _, Result, anyhow}; use dap::StackFrameId; +use db::kvp::KEY_VALUE_STORE; use gpui::{ - AnyElement, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, ListState, MouseButton, - Stateful, Subscription, Task, WeakEntity, list, + Action, AnyElement, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, ListState, + MouseButton, Stateful, Subscription, Task, WeakEntity, list, }; use util::debug_panic; -use crate::StackTraceView; +use crate::{StackTraceView, ToggleUserFrames}; use language::PointUtf16; use project::debugger::breakpoint_store::ActiveStackFrame; -use project::debugger::session::{Session, SessionEvent, StackFrame}; +use project::debugger::session::{Session, SessionEvent, StackFrame, ThreadStatus}; use project::{ProjectItem, ProjectPath}; use ui::{Scrollbar, ScrollbarState, Tooltip, prelude::*}; use workspace::{ItemHandle, Workspace}; @@ 
-26,6 +27,34 @@ pub enum StackFrameListEvent { BuiltEntries, } +/// Represents the filter applied to the stack frame list +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub(crate) enum StackFrameFilter { + /// Show all frames + All, + /// Show only frames from the user's code + OnlyUserFrames, +} + +impl StackFrameFilter { + fn from_str_or_default(s: impl AsRef) -> Self { + match s.as_ref() { + "user" => StackFrameFilter::OnlyUserFrames, + "all" => StackFrameFilter::All, + _ => StackFrameFilter::All, + } + } +} + +impl From for String { + fn from(filter: StackFrameFilter) -> Self { + match filter { + StackFrameFilter::All => "all".to_string(), + StackFrameFilter::OnlyUserFrames => "user".to_string(), + } + } +} + pub struct StackFrameList { focus_handle: FocusHandle, _subscription: Subscription, @@ -37,6 +66,8 @@ pub struct StackFrameList { opened_stack_frame_id: Option, scrollbar_state: ScrollbarState, list_state: ListState, + list_filter: StackFrameFilter, + filter_entries_indices: Vec, error: Option, _refresh_task: Task<()>, } @@ -73,6 +104,16 @@ impl StackFrameList { let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.)); let scrollbar_state = ScrollbarState::new(list_state.clone()); + let list_filter = KEY_VALUE_STORE + .read_kvp(&format!( + "stack-frame-list-filter-{}", + session.read(cx).adapter().0 + )) + .ok() + .flatten() + .map(StackFrameFilter::from_str_or_default) + .unwrap_or(StackFrameFilter::All); + let mut this = Self { session, workspace, @@ -80,9 +121,11 @@ impl StackFrameList { state, _subscription, entries: Default::default(), + filter_entries_indices: Vec::default(), error: None, selected_ix: None, opened_stack_frame_id: None, + list_filter, list_state, scrollbar_state, _refresh_task: Task::ready(()), @@ -103,7 +146,15 @@ impl StackFrameList { ) -> Vec { self.entries .iter() - .flat_map(|frame| match frame { + .enumerate() + .filter(|(ix, _)| { + self.list_filter == StackFrameFilter::All + || self + .filter_entries_indices + 
.binary_search_by_key(&ix, |ix| ix) + .is_ok() + }) + .flat_map(|(_, frame)| match frame { StackFrameEntry::Normal(frame) => vec![frame.clone()], StackFrameEntry::Label(frame) if show_labels => vec![frame.clone()], StackFrameEntry::Collapsed(frames) if show_collapsed => frames.clone(), @@ -123,11 +174,29 @@ impl StackFrameList { #[cfg(test)] pub(crate) fn dap_stack_frames(&self, cx: &mut App) -> Vec { - self.stack_frames(cx) - .unwrap_or_default() - .into_iter() - .map(|stack_frame| stack_frame.dap.clone()) - .collect() + match self.list_filter { + StackFrameFilter::All => self + .stack_frames(cx) + .unwrap_or_default() + .into_iter() + .map(|stack_frame| stack_frame.dap) + .collect(), + StackFrameFilter::OnlyUserFrames => self + .filter_entries_indices + .iter() + .map(|ix| match &self.entries[*ix] { + StackFrameEntry::Label(label) => label, + StackFrameEntry::Collapsed(_) => panic!("Collapsed tabs should not be visible"), + StackFrameEntry::Normal(frame) => frame, + }) + .cloned() + .collect(), + } + } + + #[cfg(test)] + pub(crate) fn list_filter(&self) -> StackFrameFilter { + self.list_filter } pub fn opened_stack_frame_id(&self) -> Option { @@ -187,12 +256,34 @@ impl StackFrameList { self.entries.clear(); self.selected_ix = None; self.list_state.reset(0); + self.filter_entries_indices.clear(); cx.emit(StackFrameListEvent::BuiltEntries); cx.notify(); return; } }; - for stack_frame in &stack_frames { + + let worktree_prefixes: Vec<_> = self + .workspace + .read_with(cx, |workspace, cx| { + workspace + .visible_worktrees(cx) + .map(|tree| tree.read(cx).abs_path()) + .collect() + }) + .unwrap_or_default(); + + let mut filter_entries_indices = Vec::default(); + for stack_frame in stack_frames.iter() { + let frame_in_visible_worktree = stack_frame.dap.source.as_ref().is_some_and(|source| { + source.path.as_ref().is_some_and(|path| { + worktree_prefixes + .iter() + .filter_map(|tree| tree.to_str()) + .any(|tree| path.starts_with(tree)) + }) + }); + match 
stack_frame.dap.presentation_hint { Some(dap::StackFramePresentationHint::Deemphasize) | Some(dap::StackFramePresentationHint::Subtle) => { @@ -218,15 +309,19 @@ impl StackFrameList { first_stack_frame_with_path.get_or_insert(entries.len()); } entries.push(StackFrameEntry::Normal(stack_frame.dap.clone())); + if frame_in_visible_worktree { + filter_entries_indices.push(entries.len() - 1); + } } } } let collapsed_entries = std::mem::take(&mut collapsed_entries); if !collapsed_entries.is_empty() { - entries.push(StackFrameEntry::Collapsed(collapsed_entries.clone())); + entries.push(StackFrameEntry::Collapsed(collapsed_entries)); } self.entries = entries; + self.filter_entries_indices = filter_entries_indices; if let Some(ix) = first_stack_frame_with_path .or(first_stack_frame) @@ -242,7 +337,14 @@ impl StackFrameList { self.selected_ix = ix; } - self.list_state.reset(self.entries.len()); + match self.list_filter { + StackFrameFilter::All => { + self.list_state.reset(self.entries.len()); + } + StackFrameFilter::OnlyUserFrames => { + self.list_state.reset(self.filter_entries_indices.len()); + } + } cx.emit(StackFrameListEvent::BuiltEntries); cx.notify(); } @@ -418,7 +520,7 @@ impl StackFrameList { let source = stack_frame.source.clone(); let is_selected_frame = Some(ix) == self.selected_ix; - let path = source.clone().and_then(|s| s.path.or(s.name)); + let path = source.and_then(|s| s.path.or(s.name)); let formatted_path = path.map(|path| format!("{}:{}", path, stack_frame.line,)); let formatted_path = formatted_path.map(|path| { Label::new(path) @@ -519,7 +621,16 @@ impl StackFrameList { let entries = std::mem::take(stack_frames) .into_iter() .map(StackFrameEntry::Normal); + // HERE + let entries_len = entries.len(); self.entries.splice(ix..ix + 1, entries); + let (Ok(filtered_indices_start) | Err(filtered_indices_start)) = + self.filter_entries_indices.binary_search(&ix); + + for idx in &mut self.filter_entries_indices[filtered_indices_start..] 
{ + *idx += entries_len - 1; + } + self.selected_ix = Some(ix); self.list_state.reset(self.entries.len()); cx.emit(StackFrameListEvent::BuiltEntries); @@ -572,6 +683,11 @@ impl StackFrameList { } fn render_entry(&self, ix: usize, cx: &mut Context) -> AnyElement { + let ix = match self.list_filter { + StackFrameFilter::All => ix, + StackFrameFilter::OnlyUserFrames => self.filter_entries_indices[ix], + }; + match &self.entries[ix] { StackFrameEntry::Label(stack_frame) => self.render_label_entry(stack_frame, cx), StackFrameEntry::Normal(stack_frame) => self.render_normal_entry(ix, stack_frame, cx), @@ -621,7 +737,7 @@ impl StackFrameList { fn select_next(&mut self, _: &menu::SelectNext, _window: &mut Window, cx: &mut Context) { let ix = match self.selected_ix { - _ if self.entries.len() == 0 => None, + _ if self.entries.is_empty() => None, None => Some(0), Some(ix) => { if ix == self.entries.len() - 1 { @@ -641,7 +757,7 @@ impl StackFrameList { cx: &mut Context, ) { let ix = match self.selected_ix { - _ if self.entries.len() == 0 => None, + _ if self.entries.is_empty() => None, None => Some(self.entries.len() - 1), Some(ix) => { if ix == 0 { @@ -660,7 +776,7 @@ impl StackFrameList { _window: &mut Window, cx: &mut Context, ) { - let ix = if self.entries.len() > 0 { + let ix = if !self.entries.is_empty() { Some(0) } else { None @@ -669,7 +785,7 @@ impl StackFrameList { } fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context) { - let ix = if self.entries.len() > 0 { + let ix = if !self.entries.is_empty() { Some(self.entries.len() - 1) } else { None @@ -702,6 +818,67 @@ impl StackFrameList { self.activate_selected_entry(window, cx); } + pub(crate) fn toggle_frame_filter( + &mut self, + thread_status: Option, + cx: &mut Context, + ) { + self.list_filter = match self.list_filter { + StackFrameFilter::All => StackFrameFilter::OnlyUserFrames, + StackFrameFilter::OnlyUserFrames => StackFrameFilter::All, + }; + + if let Some(database_id) = self 
+ .workspace + .read_with(cx, |workspace, _| workspace.database_id()) + .ok() + .flatten() + { + let database_id: i64 = database_id.into(); + let save_task = KEY_VALUE_STORE.write_kvp( + format!( + "stack-frame-list-filter-{}-{}", + self.session.read(cx).adapter().0, + database_id, + ), + self.list_filter.into(), + ); + cx.background_spawn(save_task).detach(); + } + + if let Some(ThreadStatus::Stopped) = thread_status { + match self.list_filter { + StackFrameFilter::All => { + self.list_state.reset(self.entries.len()); + } + StackFrameFilter::OnlyUserFrames => { + self.list_state.reset(self.filter_entries_indices.len()); + if !self + .selected_ix + .map(|ix| self.filter_entries_indices.contains(&ix)) + .unwrap_or_default() + { + self.selected_ix = None; + } + } + } + + if let Some(ix) = self.selected_ix { + let scroll_to = match self.list_filter { + StackFrameFilter::All => ix, + StackFrameFilter::OnlyUserFrames => self + .filter_entries_indices + .binary_search_by_key(&ix, |ix| *ix) + .expect("This index will always exist"), + }; + self.list_state.scroll_to_reveal_item(scroll_to); + } + + cx.emit(StackFrameListEvent::BuiltEntries); + cx.notify(); + } + } + fn render_list(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { div().p_1().size_full().child( list( @@ -711,6 +888,30 @@ impl StackFrameList { .size_full(), ) } + + pub(crate) fn render_control_strip(&self) -> AnyElement { + let tooltip_title = match self.list_filter { + StackFrameFilter::All => "Show stack frames from your project", + StackFrameFilter::OnlyUserFrames => "Show all stack frames", + }; + + h_flex() + .child( + IconButton::new( + "filter-by-visible-worktree-stack-frame-list", + IconName::ListFilter, + ) + .tooltip(move |window, cx| { + Tooltip::for_action(tooltip_title, &ToggleUserFrames, window, cx) + }) + .toggle_state(self.list_filter == StackFrameFilter::OnlyUserFrames) + .icon_size(IconSize::Small) + .on_click(|_, window, cx| { + 
window.dispatch_action(ToggleUserFrames.boxed_clone(), cx) + }), + ) + .into_any_element() + } } impl Render for StackFrameList { diff --git a/crates/debugger_ui/src/session/running/variable_list.rs b/crates/debugger_ui/src/session/running/variable_list.rs index efbc72e8cfc9099a5d699493898440d17fbf615b..b396f0921e5fdf58959e82db54bb8d558249891c 100644 --- a/crates/debugger_ui/src/session/running/variable_list.rs +++ b/crates/debugger_ui/src/session/running/variable_list.rs @@ -272,7 +272,7 @@ impl VariableList { let mut entries = vec![]; let scopes: Vec<_> = self.session.update(cx, |session, cx| { - session.scopes(stack_frame_id, cx).iter().cloned().collect() + session.scopes(stack_frame_id, cx).to_vec() }); let mut contains_local_scope = false; @@ -291,7 +291,7 @@ impl VariableList { } self.session.update(cx, |session, cx| { - session.variables(scope.variables_reference, cx).len() > 0 + !session.variables(scope.variables_reference, cx).is_empty() }) }) .map(|scope| { @@ -313,7 +313,7 @@ impl VariableList { watcher.variables_reference, watcher.variables_reference, EntryPath::for_watcher(watcher.expression.clone()), - DapEntry::Watcher(watcher.clone()), + DapEntry::Watcher(watcher), ) }) .collect::>(), @@ -947,7 +947,7 @@ impl VariableList { #[track_caller] #[cfg(test)] pub(crate) fn assert_visual_entries(&self, expected: Vec<&str>) { - const INDENT: &'static str = " "; + const INDENT: &str = " "; let entries = &self.entries; let mut visual_entries = Vec::with_capacity(entries.len()); @@ -997,7 +997,7 @@ impl VariableList { DapEntry::Watcher { .. 
} => continue, DapEntry::Variable(dap) => scopes[idx].1.push(dap.clone()), DapEntry::Scope(scope) => { - if scopes.len() > 0 { + if !scopes.is_empty() { idx += 1; } @@ -1289,7 +1289,7 @@ impl VariableList { }), ) .child(self.render_variable_value( - &entry, + entry, &variable_color, watcher.value.to_string(), cx, @@ -1301,8 +1301,6 @@ impl VariableList { IconName::Close, ) .on_click({ - let weak = weak.clone(); - let path = path.clone(); move |_, window, cx| { weak.update(cx, |variable_list, cx| { variable_list.selection = Some(path.clone()); @@ -1470,7 +1468,6 @@ impl VariableList { })) }) .on_secondary_mouse_down(cx.listener({ - let path = path.clone(); let entry = variable.clone(); move |this, event: &MouseDownEvent, window, cx| { this.selection = Some(path.clone()); @@ -1494,7 +1491,7 @@ impl VariableList { }), ) .child(self.render_variable_value( - &variable, + variable, &variable_color, dap.value.clone(), cx, diff --git a/crates/debugger_ui/src/tests/attach_modal.rs b/crates/debugger_ui/src/tests/attach_modal.rs index 906a7a0d4bd76f0451d6b5d5cfa5beff0136c613..80e2b73d5a100bbd21462f0ad80def1997e184de 100644 --- a/crates/debugger_ui/src/tests/attach_modal.rs +++ b/crates/debugger_ui/src/tests/attach_modal.rs @@ -139,7 +139,7 @@ async fn test_show_attach_modal_and_select_process( workspace .update(cx, |_, window, cx| { let names = - attach_modal.update(cx, |modal, cx| attach_modal::_process_names(&modal, cx)); + attach_modal.update(cx, |modal, cx| attach_modal::_process_names(modal, cx)); // Initially all processes are visible. assert_eq!(3, names.len()); attach_modal.update(cx, |this, cx| { @@ -154,7 +154,7 @@ async fn test_show_attach_modal_and_select_process( workspace .update(cx, |_, _, cx| { let names = - attach_modal.update(cx, |modal, cx| attach_modal::_process_names(&modal, cx)); + attach_modal.update(cx, |modal, cx| attach_modal::_process_names(modal, cx)); // Initially all processes are visible. 
assert_eq!(2, names.len()); }) diff --git a/crates/debugger_ui/src/tests/debugger_panel.rs b/crates/debugger_ui/src/tests/debugger_panel.rs index 6180831ea9dccfb3c1ee861daac099e54b2242c3..ab6d5cb9605d5d774187f836130fdae66a8d3404 100644 --- a/crates/debugger_ui/src/tests/debugger_panel.rs +++ b/crates/debugger_ui/src/tests/debugger_panel.rs @@ -1330,7 +1330,6 @@ async fn test_unsetting_breakpoints_on_clear_breakpoint_action( let called_set_breakpoints = Arc::new(AtomicBool::new(false)); client.on_request::({ - let called_set_breakpoints = called_set_breakpoints.clone(); move |_, args| { assert!( args.breakpoints.is_none_or(|bps| bps.is_empty()), @@ -1445,7 +1444,6 @@ async fn test_we_send_arguments_from_user_config( let launch_handler_called = Arc::new(AtomicBool::new(false)); start_debug_session_with(&workspace, cx, debug_definition.clone(), { - let debug_definition = debug_definition.clone(); let launch_handler_called = launch_handler_called.clone(); move |client| { @@ -1783,9 +1781,8 @@ async fn test_debug_adapters_shutdown_on_app_quit( let disconnect_request_received = Arc::new(AtomicBool::new(false)); let disconnect_clone = disconnect_request_received.clone(); - let disconnect_clone_for_handler = disconnect_clone.clone(); client.on_request::(move |_, _| { - disconnect_clone_for_handler.store(true, Ordering::SeqCst); + disconnect_clone.store(true, Ordering::SeqCst); Ok(()) }); diff --git a/crates/debugger_ui/src/tests/new_process_modal.rs b/crates/debugger_ui/src/tests/new_process_modal.rs index d6b0dfa00429f9487eafbe38dca5f072ed547779..bfc445cf67329b7190f8e5b8d353415fb53fcd74 100644 --- a/crates/debugger_ui/src/tests/new_process_modal.rs +++ b/crates/debugger_ui/src/tests/new_process_modal.rs @@ -106,9 +106,7 @@ async fn test_debug_session_substitutes_variables_and_relativizes_paths( ); let expected_other_field = if input_path.contains("$ZED_WORKTREE_ROOT") { - input_path - .replace("$ZED_WORKTREE_ROOT", &path!("/test/worktree/path")) - .to_owned() + 
input_path.replace("$ZED_WORKTREE_ROOT", path!("/test/worktree/path")) } else { input_path.to_string() }; diff --git a/crates/debugger_ui/src/tests/stack_frame_list.rs b/crates/debugger_ui/src/tests/stack_frame_list.rs index 95a6903c14a1cbd5f750d6e11437cb0bf92887c7..a61a31d270c9d599f30185d7da3c825c51bb7898 100644 --- a/crates/debugger_ui/src/tests/stack_frame_list.rs +++ b/crates/debugger_ui/src/tests/stack_frame_list.rs @@ -1,6 +1,6 @@ use crate::{ debugger_panel::DebugPanel, - session::running::stack_frame_list::StackFrameEntry, + session::running::stack_frame_list::{StackFrameEntry, StackFrameFilter}, tests::{active_debug_session_panel, init_test, init_test_workspace, start_debug_session}, }; use dap::{ @@ -752,3 +752,346 @@ async fn test_collapsed_entries(executor: BackgroundExecutor, cx: &mut TestAppCo }); }); } + +#[gpui::test] +async fn test_stack_frame_filter(executor: BackgroundExecutor, cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(executor.clone()); + + let test_file_content = r#" + function main() { + doSomething(); + } + + function doSomething() { + console.log('doing something'); + } + "# + .unindent(); + + fs.insert_tree( + path!("/project"), + json!({ + "src": { + "test.js": test_file_content, + } + }), + ) + .await; + + let project = Project::test(fs, [path!("/project").as_ref()], cx).await; + let workspace = init_test_workspace(&project, cx).await; + let cx = &mut VisualTestContext::from_window(*workspace, cx); + + let session = start_debug_session(&workspace, cx, |_| {}).unwrap(); + let client = session.update(cx, |session, _| session.adapter_client().unwrap()); + + client.on_request::(move |_, _| { + Ok(dap::ThreadsResponse { + threads: vec![dap::Thread { + id: 1, + name: "Thread 1".into(), + }], + }) + }); + + client.on_request::(move |_, _| Ok(dap::ScopesResponse { scopes: vec![] })); + + let stack_frames = vec![ + StackFrame { + id: 1, + name: "main".into(), + source: Some(dap::Source { + name: Some("test.js".into()), + 
path: Some(path!("/project/src/test.js").into()), + source_reference: None, + presentation_hint: None, + origin: None, + sources: None, + adapter_data: None, + checksums: None, + }), + line: 2, + column: 1, + end_line: None, + end_column: None, + can_restart: None, + instruction_pointer_reference: None, + module_id: None, + presentation_hint: None, + }, + StackFrame { + id: 2, + name: "node:internal/modules/cjs/loader".into(), + source: Some(dap::Source { + name: Some("loader.js".into()), + path: Some(path!("/usr/lib/node/internal/modules/cjs/loader.js").into()), + source_reference: None, + presentation_hint: None, + origin: None, + sources: None, + adapter_data: None, + checksums: None, + }), + line: 100, + column: 1, + end_line: None, + end_column: None, + can_restart: None, + instruction_pointer_reference: None, + module_id: None, + presentation_hint: Some(dap::StackFramePresentationHint::Deemphasize), + }, + StackFrame { + id: 3, + name: "node:internal/modules/run_main".into(), + source: Some(dap::Source { + name: Some("run_main.js".into()), + path: Some(path!("/usr/lib/node/internal/modules/run_main.js").into()), + source_reference: None, + presentation_hint: None, + origin: None, + sources: None, + adapter_data: None, + checksums: None, + }), + line: 50, + column: 1, + end_line: None, + end_column: None, + can_restart: None, + instruction_pointer_reference: None, + module_id: None, + presentation_hint: Some(dap::StackFramePresentationHint::Deemphasize), + }, + StackFrame { + id: 4, + name: "node:internal/modules/run_main2".into(), + source: Some(dap::Source { + name: Some("run_main.js".into()), + path: Some(path!("/usr/lib/node/internal/modules/run_main2.js").into()), + source_reference: None, + presentation_hint: None, + origin: None, + sources: None, + adapter_data: None, + checksums: None, + }), + line: 50, + column: 1, + end_line: None, + end_column: None, + can_restart: None, + instruction_pointer_reference: None, + module_id: None, + presentation_hint: 
Some(dap::StackFramePresentationHint::Deemphasize), + }, + StackFrame { + id: 5, + name: "doSomething".into(), + source: Some(dap::Source { + name: Some("test.js".into()), + path: Some(path!("/project/src/test.js").into()), + source_reference: None, + presentation_hint: None, + origin: None, + sources: None, + adapter_data: None, + checksums: None, + }), + line: 3, + column: 1, + end_line: None, + end_column: None, + can_restart: None, + instruction_pointer_reference: None, + module_id: None, + presentation_hint: None, + }, + ]; + + // Store a copy for assertions + let stack_frames_for_assertions = stack_frames.clone(); + + client.on_request::({ + let stack_frames = Arc::new(stack_frames.clone()); + move |_, args| { + assert_eq!(1, args.thread_id); + + Ok(dap::StackTraceResponse { + stack_frames: (*stack_frames).clone(), + total_frames: None, + }) + } + }); + + client + .fake_event(dap::messages::Events::Stopped(dap::StoppedEvent { + reason: dap::StoppedEventReason::Pause, + description: None, + thread_id: Some(1), + preserve_focus_hint: None, + text: None, + all_threads_stopped: None, + hit_breakpoint_ids: None, + })) + .await; + + cx.run_until_parked(); + + // trigger threads to load + active_debug_session_panel(workspace, cx).update(cx, |session, cx| { + session.running_state().update(cx, |running_state, cx| { + running_state + .session() + .update(cx, |session, cx| session.threads(cx)); + }); + }); + + cx.run_until_parked(); + + // select first thread + active_debug_session_panel(workspace, cx).update_in(cx, |session, window, cx| { + session.running_state().update(cx, |running_state, cx| { + running_state.select_current_thread( + &running_state + .session() + .update(cx, |session, cx| session.threads(cx)), + window, + cx, + ); + }); + }); + + cx.run_until_parked(); + + // trigger stack frames to load + active_debug_session_panel(workspace, cx).update(cx, |debug_panel_item, cx| { + let stack_frame_list = debug_panel_item + .running_state() + .update(cx, |state, 
_| state.stack_frame_list().clone()); + + stack_frame_list.update(cx, |stack_frame_list, cx| { + stack_frame_list.dap_stack_frames(cx); + }); + }); + + cx.run_until_parked(); + + let stack_frame_list = + active_debug_session_panel(workspace, cx).update_in(cx, |debug_panel_item, window, cx| { + let stack_frame_list = debug_panel_item + .running_state() + .update(cx, |state, _| state.stack_frame_list().clone()); + + stack_frame_list.update(cx, |stack_frame_list, cx| { + stack_frame_list.build_entries(true, window, cx); + + // Verify we have the expected collapsed structure + assert_eq!( + stack_frame_list.entries(), + &vec![ + StackFrameEntry::Normal(stack_frames_for_assertions[0].clone()), + StackFrameEntry::Collapsed(vec![ + stack_frames_for_assertions[1].clone(), + stack_frames_for_assertions[2].clone(), + stack_frames_for_assertions[3].clone() + ]), + StackFrameEntry::Normal(stack_frames_for_assertions[4].clone()), + ] + ); + }); + + stack_frame_list + }); + + stack_frame_list.update(cx, |stack_frame_list, cx| { + let all_frames = stack_frame_list.flatten_entries(true, false); + assert_eq!(all_frames.len(), 5, "Should see all 5 frames initially"); + + stack_frame_list + .toggle_frame_filter(Some(project::debugger::session::ThreadStatus::Stopped), cx); + assert_eq!( + stack_frame_list.list_filter(), + StackFrameFilter::OnlyUserFrames + ); + }); + + stack_frame_list.update(cx, |stack_frame_list, cx| { + let user_frames = stack_frame_list.dap_stack_frames(cx); + assert_eq!(user_frames.len(), 2, "Should only see 2 user frames"); + assert_eq!(user_frames[0].name, "main"); + assert_eq!(user_frames[1].name, "doSomething"); + + // Toggle back to all frames + stack_frame_list + .toggle_frame_filter(Some(project::debugger::session::ThreadStatus::Stopped), cx); + assert_eq!(stack_frame_list.list_filter(), StackFrameFilter::All); + }); + + stack_frame_list.update(cx, |stack_frame_list, cx| { + let all_frames_again = stack_frame_list.flatten_entries(true, false); + 
assert_eq!( + all_frames_again.len(), + 5, + "Should see all 5 frames after toggling back" + ); + + // Test 3: Verify collapsed entries stay expanded + stack_frame_list.expand_collapsed_entry(1, cx); + assert_eq!( + stack_frame_list.entries(), + &vec![ + StackFrameEntry::Normal(stack_frames_for_assertions[0].clone()), + StackFrameEntry::Normal(stack_frames_for_assertions[1].clone()), + StackFrameEntry::Normal(stack_frames_for_assertions[2].clone()), + StackFrameEntry::Normal(stack_frames_for_assertions[3].clone()), + StackFrameEntry::Normal(stack_frames_for_assertions[4].clone()), + ] + ); + + stack_frame_list + .toggle_frame_filter(Some(project::debugger::session::ThreadStatus::Stopped), cx); + assert_eq!( + stack_frame_list.list_filter(), + StackFrameFilter::OnlyUserFrames + ); + }); + + stack_frame_list.update(cx, |stack_frame_list, cx| { + stack_frame_list + .toggle_frame_filter(Some(project::debugger::session::ThreadStatus::Stopped), cx); + assert_eq!(stack_frame_list.list_filter(), StackFrameFilter::All); + }); + + stack_frame_list.update(cx, |stack_frame_list, cx| { + stack_frame_list + .toggle_frame_filter(Some(project::debugger::session::ThreadStatus::Stopped), cx); + assert_eq!( + stack_frame_list.list_filter(), + StackFrameFilter::OnlyUserFrames + ); + + assert_eq!( + stack_frame_list.dap_stack_frames(cx).as_slice(), + &[ + stack_frames_for_assertions[0].clone(), + stack_frames_for_assertions[4].clone() + ] + ); + + // Verify entries remain expanded + assert_eq!( + stack_frame_list.entries(), + &vec![ + StackFrameEntry::Normal(stack_frames_for_assertions[0].clone()), + StackFrameEntry::Normal(stack_frames_for_assertions[1].clone()), + StackFrameEntry::Normal(stack_frames_for_assertions[2].clone()), + StackFrameEntry::Normal(stack_frames_for_assertions[3].clone()), + StackFrameEntry::Normal(stack_frames_for_assertions[4].clone()), + ], + "Expanded entries should remain expanded after toggling filter" + ); + }); +} diff --git 
a/crates/debugger_ui/src/tests/variable_list.rs b/crates/debugger_ui/src/tests/variable_list.rs index fbbd52964105659c2cae645cec494824069f5529..4cfdae093f6a1464b178c053e629a6ebe6d76d02 100644 --- a/crates/debugger_ui/src/tests/variable_list.rs +++ b/crates/debugger_ui/src/tests/variable_list.rs @@ -1445,11 +1445,8 @@ async fn test_variable_list_only_sends_requests_when_rendering( cx.run_until_parked(); - let running_state = active_debug_session_panel(workspace, cx).update_in(cx, |item, _, _| { - let state = item.running_state().clone(); - - state - }); + let running_state = active_debug_session_panel(workspace, cx) + .update_in(cx, |item, _, _| item.running_state().clone()); client .fake_event(dap::messages::Events::Stopped(dap::StoppedEvent { diff --git a/crates/deepseek/src/deepseek.rs b/crates/deepseek/src/deepseek.rs index c49270febe3b2b3702b808e2219f6e45d7252267..64a1cbe5d96354260c2bf84a43ed70be7336aa7a 100644 --- a/crates/deepseek/src/deepseek.rs +++ b/crates/deepseek/src/deepseek.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use std::convert::TryFrom; -pub const DEEPSEEK_API_URL: &str = "https://api.deepseek.com"; +pub const DEEPSEEK_API_URL: &str = "https://api.deepseek.com/v1"; #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(rename_all = "lowercase")] @@ -96,7 +96,7 @@ impl Model { pub fn max_token_count(&self) -> u64 { match self { - Self::Chat | Self::Reasoner => 64_000, + Self::Chat | Self::Reasoner => 128_000, Self::Custom { max_tokens, .. } => *max_tokens, } } @@ -104,7 +104,7 @@ impl Model { pub fn max_output_tokens(&self) -> Option { match self { Self::Chat => Some(8_192), - Self::Reasoner => Some(8_192), + Self::Reasoner => Some(64_000), Self::Custom { max_output_tokens, .. 
} => *max_output_tokens, @@ -263,12 +263,12 @@ pub async fn stream_completion( api_key: &str, request: Request, ) -> Result>> { - let uri = format!("{api_url}/v1/chat/completions"); + let uri = format!("{api_url}/chat/completions"); let request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)); + .header("Authorization", format!("Bearer {}", api_key.trim())); let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?; diff --git a/crates/diagnostics/Cargo.toml b/crates/diagnostics/Cargo.toml index 53b5792e10e73d1629a104e345965547b6f2b25e..fd678078e8668b8a569c2d0f1627c786987a3cb4 100644 --- a/crates/diagnostics/Cargo.toml +++ b/crates/diagnostics/Cargo.toml @@ -18,7 +18,6 @@ collections.workspace = true component.workspace = true ctor.workspace = true editor.workspace = true -futures.workspace = true gpui.workspace = true indoc.workspace = true language.workspace = true diff --git a/crates/diagnostics/src/buffer_diagnostics.rs b/crates/diagnostics/src/buffer_diagnostics.rs new file mode 100644 index 0000000000000000000000000000000000000000..3a245163822fb19c43d11a93bc48c3d276e4d502 --- /dev/null +++ b/crates/diagnostics/src/buffer_diagnostics.rs @@ -0,0 +1,982 @@ +use crate::{ + DIAGNOSTICS_UPDATE_DELAY, IncludeWarnings, ToggleWarnings, context_range_for_entry, + diagnostic_renderer::{DiagnosticBlock, DiagnosticRenderer}, + toolbar_controls::DiagnosticsToolbarEditor, +}; +use anyhow::Result; +use collections::HashMap; +use editor::{ + Editor, EditorEvent, ExcerptRange, MultiBuffer, PathKey, + display_map::{BlockPlacement, BlockProperties, BlockStyle, CustomBlockId}, + multibuffer_context_lines, +}; +use gpui::{ + AnyElement, App, AppContext, Context, Entity, EntityId, EventEmitter, FocusHandle, Focusable, + InteractiveElement, IntoElement, ParentElement, Render, SharedString, 
Styled, Subscription, + Task, WeakEntity, Window, actions, div, +}; +use language::{Buffer, DiagnosticEntry, Point}; +use project::{ + DiagnosticSummary, Event, Project, ProjectItem, ProjectPath, + project_settings::{DiagnosticSeverity, ProjectSettings}, +}; +use settings::Settings; +use std::{ + any::{Any, TypeId}, + cmp::Ordering, + sync::Arc, +}; +use text::{Anchor, BufferSnapshot, OffsetRangeExt}; +use ui::{Button, ButtonStyle, Icon, IconName, Label, Tooltip, h_flex, prelude::*}; +use util::paths::PathExt; +use workspace::{ + ItemHandle, ItemNavHistory, ToolbarItemLocation, Workspace, + item::{BreadcrumbText, Item, ItemEvent, TabContentParams}, +}; + +actions!( + diagnostics, + [ + /// Opens the project diagnostics view for the currently focused file. + DeployCurrentFile, + ] +); + +/// The `BufferDiagnosticsEditor` is meant to be used when dealing specifically +/// with diagnostics for a single buffer, as only the excerpts of the buffer +/// where diagnostics are available are displayed. +pub(crate) struct BufferDiagnosticsEditor { + pub project: Entity, + focus_handle: FocusHandle, + editor: Entity, + /// The current diagnostic entries in the `BufferDiagnosticsEditor`. Used to + /// allow quick comparison of updated diagnostics, to confirm if anything + /// has changed. + pub(crate) diagnostics: Vec>, + /// The blocks used to display the diagnostics' content in the editor, next + /// to the excerpts where the diagnostic originated. + blocks: Vec, + /// Multibuffer to contain all excerpts that contain diagnostics, which are + /// to be rendered in the editor. + multibuffer: Entity, + /// The buffer for which the editor is displaying diagnostics and excerpts + /// for. + buffer: Option>, + /// The path for which the editor is displaying diagnostics for. + project_path: ProjectPath, + /// Summary of the number of warnings and errors for the path. Used to + /// display the number of warnings and errors in the tab's content. 
+ summary: DiagnosticSummary, + /// Whether to include warnings in the list of diagnostics shown in the + /// editor. + pub(crate) include_warnings: bool, + /// Keeps track of whether there's a background task already running to + /// update the excerpts, in order to avoid firing multiple tasks for this purpose. + pub(crate) update_excerpts_task: Option>>, + /// The project's subscription, responsible for processing events related to + /// diagnostics. + _subscription: Subscription, +} + +impl BufferDiagnosticsEditor { + /// Creates new instance of the `BufferDiagnosticsEditor` which can then be + /// displayed by adding it to a pane. + pub fn new( + project_path: ProjectPath, + project_handle: Entity, + buffer: Option>, + include_warnings: bool, + window: &mut Window, + cx: &mut Context, + ) -> Self { + // Subscribe to project events related to diagnostics so the + // `BufferDiagnosticsEditor` can update its state accordingly. + let project_event_subscription = cx.subscribe_in( + &project_handle, + window, + |buffer_diagnostics_editor, _project, event, window, cx| match event { + Event::DiskBasedDiagnosticsStarted { .. } => { + cx.notify(); + } + Event::DiskBasedDiagnosticsFinished { .. } => { + buffer_diagnostics_editor.update_all_excerpts(window, cx); + } + Event::DiagnosticsUpdated { + paths, + language_server_id, + } => { + // When diagnostics have been updated, the + // `BufferDiagnosticsEditor` should update its state only if + // one of the paths matches its `project_path`, otherwise + // the event should be ignored. + if paths.contains(&buffer_diagnostics_editor.project_path) { + buffer_diagnostics_editor.update_diagnostic_summary(cx); + + if buffer_diagnostics_editor.editor.focus_handle(cx).contains_focused(window, cx) || buffer_diagnostics_editor.focus_handle.contains_focused(window, cx) { + log::debug!("diagnostics updated for server {language_server_id}. recording change"); + } else { + log::debug!("diagnostics updated for server {language_server_id}. 
updating excerpts"); + buffer_diagnostics_editor.update_all_excerpts(window, cx); + } + } + } + _ => {} + }, + ); + + let focus_handle = cx.focus_handle(); + + cx.on_focus_in( + &focus_handle, + window, + |buffer_diagnostics_editor, window, cx| buffer_diagnostics_editor.focus_in(window, cx), + ) + .detach(); + + cx.on_focus_out( + &focus_handle, + window, + |buffer_diagnostics_editor, _event, window, cx| { + buffer_diagnostics_editor.focus_out(window, cx) + }, + ) + .detach(); + + let summary = project_handle + .read(cx) + .diagnostic_summary_for_path(&project_path, cx); + + let multibuffer = cx.new(|cx| MultiBuffer::new(project_handle.read(cx).capability())); + let max_severity = Self::max_diagnostics_severity(include_warnings); + let editor = cx.new(|cx| { + let mut editor = Editor::for_multibuffer( + multibuffer.clone(), + Some(project_handle.clone()), + window, + cx, + ); + editor.set_vertical_scroll_margin(5, cx); + editor.disable_inline_diagnostics(); + editor.set_max_diagnostics_severity(max_severity, cx); + editor.set_all_diagnostics_active(cx); + editor + }); + + // Subscribe to events triggered by the editor in order to correctly + // update the buffer's excerpts. + cx.subscribe_in( + &editor, + window, + |buffer_diagnostics_editor, _editor, event: &EditorEvent, window, cx| { + cx.emit(event.clone()); + + match event { + // If the user tries to focus on the editor but there's actually + // no excerpts for the buffer, focus back on the + // `BufferDiagnosticsEditor` instance. 
+ EditorEvent::Focused => { + if buffer_diagnostics_editor.multibuffer.read(cx).is_empty() { + window.focus(&buffer_diagnostics_editor.focus_handle); + } + } + EditorEvent::Blurred => { + buffer_diagnostics_editor.update_all_excerpts(window, cx) + } + _ => {} + } + }, + ) + .detach(); + + let diagnostics = vec![]; + let update_excerpts_task = None; + let mut buffer_diagnostics_editor = Self { + project: project_handle, + focus_handle, + editor, + diagnostics, + blocks: Default::default(), + multibuffer, + buffer, + project_path, + summary, + include_warnings, + update_excerpts_task, + _subscription: project_event_subscription, + }; + + buffer_diagnostics_editor.update_all_diagnostics(window, cx); + buffer_diagnostics_editor + } + + fn deploy( + workspace: &mut Workspace, + _: &DeployCurrentFile, + window: &mut Window, + cx: &mut Context, + ) { + // Determine the currently opened path by finding the active editor and + // finding the project path for the buffer. + // If there's no active editor with a project path, avoiding deploying + // the buffer diagnostics view. + if let Some(editor) = workspace.active_item_as::(cx) + && let Some(project_path) = editor.project_path(cx) + { + // Check if there's already a `BufferDiagnosticsEditor` tab for this + // same path, and if so, focus on that one instead of creating a new + // one. 
+ let existing_editor = workspace + .items_of_type::(cx) + .find(|editor| editor.read(cx).project_path == project_path); + + if let Some(editor) = existing_editor { + workspace.activate_item(&editor, true, true, window, cx); + } else { + let include_warnings = match cx.try_global::() { + Some(include_warnings) => include_warnings.0, + None => ProjectSettings::get_global(cx).diagnostics.include_warnings, + }; + + let item = cx.new(|cx| { + Self::new( + project_path, + workspace.project().clone(), + editor.read(cx).buffer().read(cx).as_singleton(), + include_warnings, + window, + cx, + ) + }); + + workspace.add_item_to_active_pane(Box::new(item), None, true, window, cx); + } + } + } + + pub fn register( + workspace: &mut Workspace, + _window: Option<&mut Window>, + _: &mut Context, + ) { + workspace.register_action(Self::deploy); + } + + fn update_all_diagnostics(&mut self, window: &mut Window, cx: &mut Context) { + self.update_all_excerpts(window, cx); + } + + fn update_diagnostic_summary(&mut self, cx: &mut Context) { + let project = self.project.read(cx); + + self.summary = project.diagnostic_summary_for_path(&self.project_path, cx); + } + + /// Enqueue an update to the excerpts and diagnostic blocks being shown in + /// the editor. + pub(crate) fn update_all_excerpts(&mut self, window: &mut Window, cx: &mut Context) { + // If there's already a task updating the excerpts, early return and let + // the other task finish. + if self.update_excerpts_task.is_some() { + return; + } + + let buffer = self.buffer.clone(); + + self.update_excerpts_task = Some(cx.spawn_in(window, async move |editor, cx| { + cx.background_executor() + .timer(DIAGNOSTICS_UPDATE_DELAY) + .await; + + if let Some(buffer) = buffer { + editor + .update_in(cx, |editor, window, cx| { + editor.update_excerpts(buffer, window, cx) + })? 
+ .await?; + }; + + let _ = editor.update(cx, |editor, cx| { + editor.update_excerpts_task = None; + cx.notify(); + }); + + Ok(()) + })); + } + + /// Updates the excerpts in the `BufferDiagnosticsEditor` for a single + /// buffer. + fn update_excerpts( + &mut self, + buffer: Entity, + window: &mut Window, + cx: &mut Context, + ) -> Task> { + let was_empty = self.multibuffer.read(cx).is_empty(); + let multibuffer_context = multibuffer_context_lines(cx); + let buffer_snapshot = buffer.read(cx).snapshot(); + let buffer_snapshot_max = buffer_snapshot.max_point(); + let max_severity = Self::max_diagnostics_severity(self.include_warnings) + .into_lsp() + .unwrap_or(lsp::DiagnosticSeverity::WARNING); + + cx.spawn_in(window, async move |buffer_diagnostics_editor, mut cx| { + // Fetch the diagnostics for the whole of the buffer + // (`Point::zero()..buffer_snapshot.max_point()`) so we can confirm + // if the diagnostics changed, if it didn't, early return as there's + // nothing to update. + let diagnostics = buffer_snapshot + .diagnostics_in_range::<_, Anchor>(Point::zero()..buffer_snapshot_max, false) + .collect::>(); + + let unchanged = + buffer_diagnostics_editor.update(cx, |buffer_diagnostics_editor, _cx| { + if buffer_diagnostics_editor + .diagnostics_are_unchanged(&diagnostics, &buffer_snapshot) + { + return true; + } + + buffer_diagnostics_editor.set_diagnostics(&diagnostics); + return false; + })?; + + if unchanged { + return Ok(()); + } + + // Mapping between the Group ID and a vector of DiagnosticEntry. + let mut grouped: HashMap> = HashMap::default(); + for entry in diagnostics { + grouped + .entry(entry.diagnostic.group_id) + .or_default() + .push(DiagnosticEntry { + range: entry.range.to_point(&buffer_snapshot), + diagnostic: entry.diagnostic, + }) + } + + let mut blocks: Vec = Vec::new(); + for (_, group) in grouped { + // If the minimum severity of the group is higher than the + // maximum severity, or it doesn't even have severity, skip this + // group. 
+ if group + .iter() + .map(|d| d.diagnostic.severity) + .min() + .is_none_or(|severity| severity > max_severity) + { + continue; + } + + let diagnostic_blocks = cx.update(|_window, cx| { + DiagnosticRenderer::diagnostic_blocks_for_group( + group, + buffer_snapshot.remote_id(), + Some(Arc::new(buffer_diagnostics_editor.clone())), + cx, + ) + })?; + + // For each of the diagnostic blocks to be displayed in the + // editor, figure out its index in the list of blocks. + // + // The following rules are used to determine the order: + // 1. Blocks with a lower start position should come first. + // 2. If two blocks have the same start position, the one with + // the higher end position should come first. + for diagnostic_block in diagnostic_blocks { + let index = blocks.partition_point(|probe| { + match probe + .initial_range + .start + .cmp(&diagnostic_block.initial_range.start) + { + Ordering::Less => true, + Ordering::Greater => false, + Ordering::Equal => { + probe.initial_range.end > diagnostic_block.initial_range.end + } + } + }); + + blocks.insert(index, diagnostic_block); + } + } + + // Build the excerpt ranges for this specific buffer's diagnostics, + // so those excerpts can later be used to update the excerpts shown + // in the editor. + // This is done by iterating over the list of diagnostic blocks and + // determine what range does the diagnostic block span. 
+ let mut excerpt_ranges: Vec> = Vec::new(); + + for diagnostic_block in blocks.iter() { + let excerpt_range = context_range_for_entry( + diagnostic_block.initial_range.clone(), + multibuffer_context, + buffer_snapshot.clone(), + &mut cx, + ) + .await; + + let index = excerpt_ranges + .binary_search_by(|probe| { + probe + .context + .start + .cmp(&excerpt_range.start) + .then(probe.context.end.cmp(&excerpt_range.end)) + .then( + probe + .primary + .start + .cmp(&diagnostic_block.initial_range.start), + ) + .then(probe.primary.end.cmp(&diagnostic_block.initial_range.end)) + .then(Ordering::Greater) + }) + .unwrap_or_else(|index| index); + + excerpt_ranges.insert( + index, + ExcerptRange { + context: excerpt_range, + primary: diagnostic_block.initial_range.clone(), + }, + ) + } + + // Finally, update the editor's content with the new excerpt ranges + // for this editor, as well as the diagnostic blocks. + buffer_diagnostics_editor.update_in(cx, |buffer_diagnostics_editor, window, cx| { + // Remove the list of `CustomBlockId` from the editor's display + // map, ensuring that if any diagnostics have been solved, the + // associated block stops being shown. 
+ let block_ids = buffer_diagnostics_editor.blocks.clone(); + + buffer_diagnostics_editor.editor.update(cx, |editor, cx| { + editor.display_map.update(cx, |display_map, cx| { + display_map.remove_blocks(block_ids.into_iter().collect(), cx); + }) + }); + + let (anchor_ranges, _) = + buffer_diagnostics_editor + .multibuffer + .update(cx, |multibuffer, cx| { + multibuffer.set_excerpt_ranges_for_path( + PathKey::for_buffer(&buffer, cx), + buffer.clone(), + &buffer_snapshot, + excerpt_ranges, + cx, + ) + }); + + if was_empty { + if let Some(anchor_range) = anchor_ranges.first() { + let range_to_select = anchor_range.start..anchor_range.start; + + buffer_diagnostics_editor.editor.update(cx, |editor, cx| { + editor.change_selections(Default::default(), window, cx, |selection| { + selection.select_anchor_ranges([range_to_select]) + }) + }); + + // If the `BufferDiagnosticsEditor` is currently + // focused, move focus to its editor. + if buffer_diagnostics_editor.focus_handle.is_focused(window) { + buffer_diagnostics_editor + .editor + .read(cx) + .focus_handle(cx) + .focus(window); + } + } + } + + // Cloning the blocks before moving ownership so these can later + // be used to set the block contents for testing purposes. + #[cfg(test)] + let cloned_blocks = blocks.clone(); + + // Build new diagnostic blocks to be added to the editor's + // display map for the new diagnostics. Update the `blocks` + // property before finishing, to ensure the blocks are removed + // on the next execution. 
+ let editor_blocks = + anchor_ranges + .into_iter() + .zip(blocks.into_iter()) + .map(|(anchor, block)| { + let editor = buffer_diagnostics_editor.editor.downgrade(); + + BlockProperties { + placement: BlockPlacement::Near(anchor.start), + height: Some(1), + style: BlockStyle::Flex, + render: Arc::new(move |block_context| { + block.render_block(editor.clone(), block_context) + }), + priority: 1, + } + }); + + let block_ids = buffer_diagnostics_editor.editor.update(cx, |editor, cx| { + editor.display_map.update(cx, |display_map, cx| { + display_map.insert_blocks(editor_blocks, cx) + }) + }); + + // In order to be able to verify which diagnostic blocks are + // rendered in the editor, the `set_block_content_for_tests` + // function must be used, so that the + // `editor::test::editor_content_with_blocks` function can then + // be called to fetch these blocks. + #[cfg(test)] + { + for (block_id, block) in block_ids.iter().zip(cloned_blocks.iter()) { + let markdown = block.markdown.clone(); + editor::test::set_block_content_for_tests( + &buffer_diagnostics_editor.editor, + *block_id, + cx, + move |cx| { + markdown::MarkdownElement::rendered_text( + markdown.clone(), + cx, + editor::hover_popover::diagnostics_markdown_style, + ) + }, + ); + } + } + + buffer_diagnostics_editor.blocks = block_ids; + cx.notify() + }) + }) + } + + fn set_diagnostics(&mut self, diagnostics: &Vec>) { + self.diagnostics = diagnostics.clone(); + } + + fn diagnostics_are_unchanged( + &self, + diagnostics: &Vec>, + snapshot: &BufferSnapshot, + ) -> bool { + if self.diagnostics.len() != diagnostics.len() { + return false; + } + + self.diagnostics + .iter() + .zip(diagnostics.iter()) + .all(|(existing, new)| { + existing.diagnostic.message == new.diagnostic.message + && existing.diagnostic.severity == new.diagnostic.severity + && existing.diagnostic.is_primary == new.diagnostic.is_primary + && existing.range.to_offset(snapshot) == new.range.to_offset(snapshot) + }) + } + + fn focus_in(&mut self, 
window: &mut Window, cx: &mut Context) { + // If the `BufferDiagnosticsEditor` is focused and the multibuffer is + // not empty, focus on the editor instead, which will allow the user to + // start interacting and editing the buffer's contents. + if self.focus_handle.is_focused(window) && !self.multibuffer.read(cx).is_empty() { + self.editor.focus_handle(cx).focus(window) + } + } + + fn focus_out(&mut self, window: &mut Window, cx: &mut Context) { + if !self.focus_handle.is_focused(window) && !self.editor.focus_handle(cx).is_focused(window) + { + self.update_all_excerpts(window, cx); + } + } + + pub fn toggle_warnings( + &mut self, + _: &ToggleWarnings, + window: &mut Window, + cx: &mut Context, + ) { + let include_warnings = !self.include_warnings; + let max_severity = Self::max_diagnostics_severity(include_warnings); + + self.editor.update(cx, |editor, cx| { + editor.set_max_diagnostics_severity(max_severity, cx); + }); + + self.include_warnings = include_warnings; + self.diagnostics.clear(); + self.update_all_diagnostics(window, cx); + } + + fn max_diagnostics_severity(include_warnings: bool) -> DiagnosticSeverity { + match include_warnings { + true => DiagnosticSeverity::Warning, + false => DiagnosticSeverity::Error, + } + } + + #[cfg(test)] + pub fn editor(&self) -> &Entity { + &self.editor + } + + #[cfg(test)] + pub fn summary(&self) -> &DiagnosticSummary { + &self.summary + } +} + +impl Focusable for BufferDiagnosticsEditor { + fn focus_handle(&self, _: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl EventEmitter for BufferDiagnosticsEditor {} + +impl Item for BufferDiagnosticsEditor { + type Event = EditorEvent; + + fn act_as_type<'a>( + &'a self, + type_id: std::any::TypeId, + self_handle: &'a Entity, + _: &'a App, + ) -> Option { + if type_id == TypeId::of::() { + Some(self_handle.to_any()) + } else if type_id == TypeId::of::() { + Some(self.editor.to_any()) + } else { + None + } + } + + fn added_to_workspace( + &mut self, + workspace: 
&mut Workspace, + window: &mut Window, + cx: &mut Context, + ) { + self.editor.update(cx, |editor, cx| { + editor.added_to_workspace(workspace, window, cx) + }); + } + + fn breadcrumb_location(&self, _: &App) -> ToolbarItemLocation { + ToolbarItemLocation::PrimaryLeft + } + + fn breadcrumbs(&self, theme: &theme::Theme, cx: &App) -> Option> { + self.editor.breadcrumbs(theme, cx) + } + + fn can_save(&self, _cx: &App) -> bool { + true + } + + fn clone_on_split( + &self, + _workspace_id: Option, + window: &mut Window, + cx: &mut Context, + ) -> Option> + where + Self: Sized, + { + Some(cx.new(|cx| { + BufferDiagnosticsEditor::new( + self.project_path.clone(), + self.project.clone(), + self.buffer.clone(), + self.include_warnings, + window, + cx, + ) + })) + } + + fn deactivated(&mut self, window: &mut Window, cx: &mut Context) { + self.editor + .update(cx, |editor, cx| editor.deactivated(window, cx)); + } + + fn for_each_project_item(&self, cx: &App, f: &mut dyn FnMut(EntityId, &dyn ProjectItem)) { + self.editor.for_each_project_item(cx, f); + } + + fn has_conflict(&self, cx: &App) -> bool { + self.multibuffer.read(cx).has_conflict(cx) + } + + fn has_deleted_file(&self, cx: &App) -> bool { + self.multibuffer.read(cx).has_deleted_file(cx) + } + + fn is_dirty(&self, cx: &App) -> bool { + self.multibuffer.read(cx).is_dirty(cx) + } + + fn is_singleton(&self, _cx: &App) -> bool { + false + } + + fn navigate( + &mut self, + data: Box, + window: &mut Window, + cx: &mut Context, + ) -> bool { + self.editor + .update(cx, |editor, cx| editor.navigate(data, window, cx)) + } + + fn reload( + &mut self, + project: Entity, + window: &mut Window, + cx: &mut Context, + ) -> Task> { + self.editor.reload(project, window, cx) + } + + fn save( + &mut self, + options: workspace::item::SaveOptions, + project: Entity, + window: &mut Window, + cx: &mut Context, + ) -> Task> { + self.editor.save(options, project, window, cx) + } + + fn save_as( + &mut self, + _project: Entity, + _path: 
ProjectPath, + _window: &mut Window, + _cx: &mut Context, + ) -> Task> { + unreachable!() + } + + fn set_nav_history( + &mut self, + nav_history: ItemNavHistory, + _window: &mut Window, + cx: &mut Context, + ) { + self.editor.update(cx, |editor, _| { + editor.set_nav_history(Some(nav_history)); + }) + } + + // Builds the content to be displayed in the tab. + fn tab_content(&self, params: TabContentParams, _window: &Window, _cx: &App) -> AnyElement { + let error_count = self.summary.error_count; + let warning_count = self.summary.warning_count; + let label = Label::new( + self.project_path + .path + .file_name() + .map(|f| f.to_sanitized_string()) + .unwrap_or_else(|| self.project_path.path.to_sanitized_string()), + ); + + h_flex() + .gap_1() + .child(label) + .when(error_count == 0 && warning_count == 0, |parent| { + parent.child( + h_flex() + .gap_1() + .child(Icon::new(IconName::Check).color(Color::Success)), + ) + }) + .when(error_count > 0, |parent| { + parent.child( + h_flex() + .gap_1() + .child(Icon::new(IconName::XCircle).color(Color::Error)) + .child(Label::new(error_count.to_string()).color(params.text_color())), + ) + }) + .when(warning_count > 0, |parent| { + parent.child( + h_flex() + .gap_1() + .child(Icon::new(IconName::Warning).color(Color::Warning)) + .child(Label::new(warning_count.to_string()).color(params.text_color())), + ) + }) + .into_any_element() + } + + fn tab_content_text(&self, _detail: usize, _app: &App) -> SharedString { + "Buffer Diagnostics".into() + } + + fn tab_tooltip_text(&self, _: &App) -> Option { + Some( + format!( + "Buffer Diagnostics - {}", + self.project_path.path.to_sanitized_string() + ) + .into(), + ) + } + + fn telemetry_event_text(&self) -> Option<&'static str> { + Some("Buffer Diagnostics Opened") + } + + fn to_item_events(event: &EditorEvent, f: impl FnMut(ItemEvent)) { + Editor::to_item_events(event, f) + } +} + +impl Render for BufferDiagnosticsEditor { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> 
impl IntoElement { + let filename = self.project_path.path.to_sanitized_string(); + let error_count = self.summary.error_count; + let warning_count = match self.include_warnings { + true => self.summary.warning_count, + false => 0, + }; + + let child = if error_count + warning_count == 0 { + let label = match warning_count { + 0 => "No problems in", + _ => "No errors in", + }; + + v_flex() + .key_context("EmptyPane") + .size_full() + .gap_1() + .justify_center() + .items_center() + .text_center() + .bg(cx.theme().colors().editor_background) + .child( + div() + .h_flex() + .child(Label::new(label).color(Color::Muted)) + .child( + Button::new("open-file", filename) + .style(ButtonStyle::Transparent) + .tooltip(Tooltip::text("Open File")) + .on_click(cx.listener(|buffer_diagnostics, _, window, cx| { + if let Some(workspace) = window.root::().flatten() { + workspace.update(cx, |workspace, cx| { + workspace + .open_path( + buffer_diagnostics.project_path.clone(), + None, + true, + window, + cx, + ) + .detach_and_log_err(cx); + }) + } + })), + ), + ) + .when(self.summary.warning_count > 0, |div| { + let label = match self.summary.warning_count { + 1 => "Show 1 warning".into(), + warning_count => format!("Show {} warnings", warning_count), + }; + + div.child( + Button::new("diagnostics-show-warning-label", label).on_click(cx.listener( + |buffer_diagnostics_editor, _, window, cx| { + buffer_diagnostics_editor.toggle_warnings( + &Default::default(), + window, + cx, + ); + cx.notify(); + }, + )), + ) + }) + } else { + div().size_full().child(self.editor.clone()) + }; + + div() + .key_context("Diagnostics") + .track_focus(&self.focus_handle(cx)) + .size_full() + .child(child) + } +} + +impl DiagnosticsToolbarEditor for WeakEntity { + fn include_warnings(&self, cx: &App) -> bool { + self.read_with(cx, |buffer_diagnostics_editor, _cx| { + buffer_diagnostics_editor.include_warnings + }) + .unwrap_or(false) + } + + fn has_stale_excerpts(&self, _cx: &App) -> bool { + false + } + + 
fn is_updating(&self, cx: &App) -> bool { + self.read_with(cx, |buffer_diagnostics_editor, cx| { + buffer_diagnostics_editor.update_excerpts_task.is_some() + || buffer_diagnostics_editor + .project + .read(cx) + .language_servers_running_disk_based_diagnostics(cx) + .next() + .is_some() + }) + .unwrap_or(false) + } + + fn stop_updating(&self, cx: &mut App) { + let _ = self.update(cx, |buffer_diagnostics_editor, cx| { + buffer_diagnostics_editor.update_excerpts_task = None; + cx.notify(); + }); + } + + fn refresh_diagnostics(&self, window: &mut Window, cx: &mut App) { + let _ = self.update(cx, |buffer_diagnostics_editor, cx| { + buffer_diagnostics_editor.update_all_excerpts(window, cx); + }); + } + + fn toggle_warnings(&self, window: &mut Window, cx: &mut App) { + let _ = self.update(cx, |buffer_diagnostics_editor, cx| { + buffer_diagnostics_editor.toggle_warnings(&Default::default(), window, cx); + }); + } + + fn get_diagnostics_for_buffer( + &self, + _buffer_id: text::BufferId, + cx: &App, + ) -> Vec> { + self.read_with(cx, |buffer_diagnostics_editor, _cx| { + buffer_diagnostics_editor.diagnostics.clone() + }) + .unwrap_or_default() + } +} diff --git a/crates/diagnostics/src/diagnostic_renderer.rs b/crates/diagnostics/src/diagnostic_renderer.rs index ce7b253702a01e24e7f4a457ac418572e0fa2729..e22065afa5587e25e35e5c65ffec2e18860b6788 100644 --- a/crates/diagnostics/src/diagnostic_renderer.rs +++ b/crates/diagnostics/src/diagnostic_renderer.rs @@ -18,7 +18,7 @@ use ui::{ }; use util::maybe; -use crate::ProjectDiagnosticsEditor; +use crate::toolbar_controls::DiagnosticsToolbarEditor; pub struct DiagnosticRenderer; @@ -26,7 +26,7 @@ impl DiagnosticRenderer { pub fn diagnostic_blocks_for_group( diagnostic_group: Vec>, buffer_id: BufferId, - diagnostics_editor: Option>, + diagnostics_editor: Option>, cx: &mut App, ) -> Vec { let Some(primary_ix) = diagnostic_group @@ -46,7 +46,7 @@ impl DiagnosticRenderer { markdown.push_str(" ("); } if let Some(source) = 
diagnostic.source.as_ref() { - markdown.push_str(&Markdown::escape(&source)); + markdown.push_str(&Markdown::escape(source)); } if diagnostic.source.is_some() && diagnostic.code.is_some() { markdown.push(' '); @@ -130,6 +130,7 @@ impl editor::DiagnosticRenderer for DiagnosticRenderer { cx: &mut App, ) -> Vec> { let blocks = Self::diagnostic_blocks_for_group(diagnostic_group, buffer_id, None, cx); + blocks .into_iter() .map(|block| { @@ -182,7 +183,7 @@ pub(crate) struct DiagnosticBlock { pub(crate) initial_range: Range, pub(crate) severity: DiagnosticSeverity, pub(crate) markdown: Entity, - pub(crate) diagnostics_editor: Option>, + pub(crate) diagnostics_editor: Option>, } impl DiagnosticBlock { @@ -233,7 +234,7 @@ impl DiagnosticBlock { pub fn open_link( editor: &mut Editor, - diagnostics_editor: &Option>, + diagnostics_editor: &Option>, link: SharedString, window: &mut Window, cx: &mut Context, @@ -254,18 +255,10 @@ impl DiagnosticBlock { if let Some(diagnostics_editor) = diagnostics_editor { if let Some(diagnostic) = diagnostics_editor - .read_with(cx, |diagnostics, _| { - diagnostics - .diagnostics - .get(&buffer_id) - .cloned() - .unwrap_or_default() - .into_iter() - .filter(|d| d.diagnostic.group_id == group_id) - .nth(ix) - }) - .ok() - .flatten() + .get_diagnostics_for_buffer(buffer_id, cx) + .into_iter() + .filter(|d| d.diagnostic.group_id == group_id) + .nth(ix) { let multibuffer = editor.buffer().read(cx); let Some(snapshot) = multibuffer @@ -287,26 +280,24 @@ impl DiagnosticBlock { } } } - } else { - if let Some(diagnostic) = editor - .snapshot(window, cx) - .buffer_snapshot - .diagnostic_group(buffer_id, group_id) - .nth(ix) - { - Self::jump_to(editor, diagnostic.range, window, cx) - } + } else if let Some(diagnostic) = editor + .snapshot(window, cx) + .buffer_snapshot + .diagnostic_group(buffer_id, group_id) + .nth(ix) + { + Self::jump_to(editor, diagnostic.range, window, cx) }; } - fn jump_to( + fn jump_to( editor: &mut Editor, - range: Range, + 
range: Range, window: &mut Window, cx: &mut Context, ) { let snapshot = &editor.buffer().read(cx).snapshot(cx); - let range = range.start.to_offset(&snapshot)..range.end.to_offset(&snapshot); + let range = range.start.to_offset(snapshot)..range.end.to_offset(snapshot); editor.unfold_ranges(&[range.start..range.end], true, false, cx); editor.change_selections(Default::default(), window, cx, |s| { diff --git a/crates/diagnostics/src/diagnostics.rs b/crates/diagnostics/src/diagnostics.rs index e7660920da30ddcc088c2bbee6bfb1cf05d51d58..ef4d6ec4395189971da710fd5378f65b19199a16 100644 --- a/crates/diagnostics/src/diagnostics.rs +++ b/crates/diagnostics/src/diagnostics.rs @@ -1,19 +1,21 @@ pub mod items; mod toolbar_controls; +mod buffer_diagnostics; mod diagnostic_renderer; #[cfg(test)] mod diagnostics_tests; use anyhow::Result; +use buffer_diagnostics::BufferDiagnosticsEditor; use collections::{BTreeSet, HashMap}; use diagnostic_renderer::DiagnosticBlock; use editor::{ - DEFAULT_MULTIBUFFER_CONTEXT, Editor, EditorEvent, ExcerptRange, MultiBuffer, PathKey, + Editor, EditorEvent, ExcerptRange, MultiBuffer, PathKey, display_map::{BlockPlacement, BlockProperties, BlockStyle, CustomBlockId}, + multibuffer_context_lines, }; -use futures::future::join_all; use gpui::{ AnyElement, AnyView, App, AsyncApp, Context, Entity, EventEmitter, FocusHandle, Focusable, Global, InteractiveElement, IntoElement, ParentElement, Render, SharedString, Styled, @@ -24,7 +26,6 @@ use language::{ }; use project::{ DiagnosticSummary, Project, ProjectPath, - lsp_store::rust_analyzer_ext::{cancel_flycheck, run_flycheck}, project_settings::{DiagnosticSeverity, ProjectSettings}, }; use settings::Settings; @@ -37,6 +38,7 @@ use std::{ }; use text::{BufferId, OffsetRangeExt}; use theme::ActiveTheme; +use toolbar_controls::DiagnosticsToolbarEditor; pub use toolbar_controls::ToolbarControls; use ui::{Icon, IconName, Label, h_flex, prelude::*}; use util::ResultExt; @@ -65,6 +67,7 @@ impl Global for 
IncludeWarnings {} pub fn init(cx: &mut App) { editor::set_diagnostic_renderer(diagnostic_renderer::DiagnosticRenderer {}, cx); cx.observe_new(ProjectDiagnosticsEditor::register).detach(); + cx.observe_new(BufferDiagnosticsEditor::register).detach(); } pub(crate) struct ProjectDiagnosticsEditor { @@ -79,20 +82,14 @@ pub(crate) struct ProjectDiagnosticsEditor { paths_to_update: BTreeSet, include_warnings: bool, update_excerpts_task: Option>>, - cargo_diagnostics_fetch: CargoDiagnosticsFetchState, diagnostic_summary_update: Task<()>, _subscription: Subscription, } -struct CargoDiagnosticsFetchState { - fetch_task: Option>, - cancel_task: Option>, - diagnostic_sources: Arc>, -} - impl EventEmitter for ProjectDiagnosticsEditor {} const DIAGNOSTICS_UPDATE_DELAY: Duration = Duration::from_millis(50); +const DIAGNOSTICS_SUMMARY_UPDATE_DELAY: Duration = Duration::from_millis(30); impl Render for ProjectDiagnosticsEditor { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { @@ -102,43 +99,44 @@ impl Render for ProjectDiagnosticsEditor { 0 }; - let child = if warning_count + self.summary.error_count == 0 { - let label = if self.summary.warning_count == 0 { - SharedString::new_static("No problems in workspace") + let child = + if warning_count + self.summary.error_count == 0 && self.editor.read(cx).is_empty(cx) { + let label = if self.summary.warning_count == 0 { + SharedString::new_static("No problems in workspace") + } else { + SharedString::new_static("No errors in workspace") + }; + v_flex() + .key_context("EmptyPane") + .size_full() + .gap_1() + .justify_center() + .items_center() + .text_center() + .bg(cx.theme().colors().editor_background) + .child(Label::new(label).color(Color::Muted)) + .when(self.summary.warning_count > 0, |this| { + let plural_suffix = if self.summary.warning_count > 1 { + "s" + } else { + "" + }; + let label = format!( + "Show {} warning{}", + self.summary.warning_count, plural_suffix + ); + this.child( + 
Button::new("diagnostics-show-warning-label", label).on_click( + cx.listener(|this, _, window, cx| { + this.toggle_warnings(&Default::default(), window, cx); + cx.notify(); + }), + ), + ) + }) } else { - SharedString::new_static("No errors in workspace") + div().size_full().child(self.editor.clone()) }; - v_flex() - .key_context("EmptyPane") - .size_full() - .gap_1() - .justify_center() - .items_center() - .text_center() - .bg(cx.theme().colors().editor_background) - .child(Label::new(label).color(Color::Muted)) - .when(self.summary.warning_count > 0, |this| { - let plural_suffix = if self.summary.warning_count > 1 { - "s" - } else { - "" - }; - let label = format!( - "Show {} warning{}", - self.summary.warning_count, plural_suffix - ); - this.child( - Button::new("diagnostics-show-warning-label", label).on_click(cx.listener( - |this, _, window, cx| { - this.toggle_warnings(&Default::default(), window, cx); - cx.notify(); - }, - )), - ) - }) - } else { - div().size_full().child(self.editor.clone()) - }; div() .key_context("Diagnostics") @@ -151,7 +149,7 @@ impl Render for ProjectDiagnosticsEditor { } impl ProjectDiagnosticsEditor { - fn register( + pub fn register( workspace: &mut Workspace, _window: Option<&mut Window>, _: &mut Context, @@ -167,7 +165,7 @@ impl ProjectDiagnosticsEditor { cx: &mut Context, ) -> Self { let project_event_subscription = - cx.subscribe_in(&project_handle, window, |this, project, event, window, cx| match event { + cx.subscribe_in(&project_handle, window, |this, _project, event, window, cx| match event { project::Event::DiskBasedDiagnosticsStarted { .. 
} => { cx.notify(); } @@ -180,13 +178,12 @@ impl ProjectDiagnosticsEditor { paths, } => { this.paths_to_update.extend(paths.clone()); - let project = project.clone(); this.diagnostic_summary_update = cx.spawn(async move |this, cx| { cx.background_executor() - .timer(Duration::from_millis(30)) + .timer(DIAGNOSTICS_SUMMARY_UPDATE_DELAY) .await; this.update(cx, |this, cx| { - this.summary = project.read(cx).diagnostic_summary(false, cx); + this.update_diagnostic_summary(cx); }) .log_err(); }); @@ -241,6 +238,7 @@ impl ProjectDiagnosticsEditor { } } EditorEvent::Blurred => this.update_stale_excerpts(window, cx), + EditorEvent::Saved => this.update_stale_excerpts(window, cx), _ => {} } }, @@ -260,11 +258,7 @@ impl ProjectDiagnosticsEditor { ) }); this.diagnostics.clear(); - this.update_all_diagnostics(false, window, cx); - }) - .detach(); - cx.observe_release(&cx.entity(), |editor, _, cx| { - editor.stop_cargo_diagnostics_fetch(cx); + this.update_all_excerpts(window, cx); }) .detach(); @@ -281,20 +275,15 @@ impl ProjectDiagnosticsEditor { editor, paths_to_update: Default::default(), update_excerpts_task: None, - cargo_diagnostics_fetch: CargoDiagnosticsFetchState { - fetch_task: None, - cancel_task: None, - diagnostic_sources: Arc::new(Vec::new()), - }, diagnostic_summary_update: Task::ready(()), _subscription: project_event_subscription, }; - this.update_all_diagnostics(true, window, cx); + this.update_all_excerpts(window, cx); this } fn update_stale_excerpts(&mut self, window: &mut Window, cx: &mut Context) { - if self.update_excerpts_task.is_some() { + if self.update_excerpts_task.is_some() || self.multibuffer.read(cx).is_dirty(cx) { return; } @@ -341,6 +330,7 @@ impl ProjectDiagnosticsEditor { let is_active = workspace .active_item(cx) .is_some_and(|item| item.item_id() == existing.item_id()); + workspace.activate_item(&existing, true, !is_active, window, cx); } else { let workspace_handle = cx.entity().downgrade(); @@ -373,22 +363,10 @@ impl 
ProjectDiagnosticsEditor { window: &mut Window, cx: &mut Context, ) { - let fetch_cargo_diagnostics = ProjectSettings::get_global(cx) - .diagnostics - .fetch_cargo_diagnostics(); - - if fetch_cargo_diagnostics { - if self.cargo_diagnostics_fetch.fetch_task.is_some() { - self.stop_cargo_diagnostics_fetch(cx); - } else { - self.update_all_diagnostics(false, window, cx); - } + if self.update_excerpts_task.is_some() { + self.update_excerpts_task = None; } else { - if self.update_excerpts_task.is_some() { - self.update_excerpts_task = None; - } else { - self.update_all_diagnostics(false, window, cx); - } + self.update_all_excerpts(window, cx); } cx.notify(); } @@ -406,93 +384,29 @@ impl ProjectDiagnosticsEditor { } } - fn update_all_diagnostics( - &mut self, - first_launch: bool, - window: &mut Window, - cx: &mut Context, - ) { - let cargo_diagnostics_sources = self.cargo_diagnostics_sources(cx); - if cargo_diagnostics_sources.is_empty() { - self.update_all_excerpts(window, cx); - } else if first_launch && !self.summary.is_empty() { - self.update_all_excerpts(window, cx); - } else { - self.fetch_cargo_diagnostics(Arc::new(cargo_diagnostics_sources), cx); - } - } - - fn fetch_cargo_diagnostics( - &mut self, - diagnostics_sources: Arc>, - cx: &mut Context, - ) { - let project = self.project.clone(); - self.cargo_diagnostics_fetch.cancel_task = None; - self.cargo_diagnostics_fetch.fetch_task = None; - self.cargo_diagnostics_fetch.diagnostic_sources = diagnostics_sources.clone(); - if self.cargo_diagnostics_fetch.diagnostic_sources.is_empty() { - return; - } - - self.cargo_diagnostics_fetch.fetch_task = Some(cx.spawn(async move |editor, cx| { - let mut fetch_tasks = Vec::new(); - for buffer_path in diagnostics_sources.iter().cloned() { - if cx - .update(|cx| { - fetch_tasks.push(run_flycheck(project.clone(), buffer_path, cx)); - }) - .is_err() - { - break; - } - } - - let _ = join_all(fetch_tasks).await; - editor - .update(cx, |editor, _| { - 
editor.cargo_diagnostics_fetch.fetch_task = None; - }) - .ok(); - })); - } - - fn stop_cargo_diagnostics_fetch(&mut self, cx: &mut App) { - self.cargo_diagnostics_fetch.fetch_task = None; - let mut cancel_gasks = Vec::new(); - for buffer_path in std::mem::take(&mut self.cargo_diagnostics_fetch.diagnostic_sources) - .iter() - .cloned() - { - cancel_gasks.push(cancel_flycheck(self.project.clone(), buffer_path, cx)); - } - - self.cargo_diagnostics_fetch.cancel_task = Some(cx.background_spawn(async move { - let _ = join_all(cancel_gasks).await; - log::info!("Finished fetching cargo diagnostics"); - })); - } - /// Enqueue an update of all excerpts. Updates all paths that either /// currently have diagnostics or are currently present in this view. fn update_all_excerpts(&mut self, window: &mut Window, cx: &mut Context) { self.project.update(cx, |project, cx| { - let mut paths = project + let mut project_paths = project .diagnostic_summaries(false, cx) - .map(|(path, _, _)| path) + .map(|(project_path, _, _)| project_path) .collect::>(); + self.multibuffer.update(cx, |multibuffer, cx| { for buffer in multibuffer.all_buffers() { if let Some(file) = buffer.read(cx).file() { - paths.insert(ProjectPath { + project_paths.insert(ProjectPath { path: file.path().clone(), worktree_id: file.worktree_id(cx), }); } } }); - self.paths_to_update = paths; + + self.paths_to_update = project_paths; }); + self.update_stale_excerpts(window, cx); } @@ -522,19 +436,21 @@ impl ProjectDiagnosticsEditor { let was_empty = self.multibuffer.read(cx).is_empty(); let buffer_snapshot = buffer.read(cx).snapshot(); let buffer_id = buffer_snapshot.remote_id(); + let max_severity = if self.include_warnings { lsp::DiagnosticSeverity::WARNING } else { lsp::DiagnosticSeverity::ERROR }; - cx.spawn_in(window, async move |this, mut cx| { + cx.spawn_in(window, async move |this, cx| { let diagnostics = buffer_snapshot .diagnostics_in_range::<_, text::Anchor>( Point::zero()..buffer_snapshot.max_point(), false, ) 
.collect::>(); + let unchanged = this.update(cx, |this, _| { if this.diagnostics.get(&buffer_id).is_some_and(|existing| { this.diagnostics_are_unchanged(existing, &diagnostics, &buffer_snapshot) @@ -542,7 +458,7 @@ impl ProjectDiagnosticsEditor { return true; } this.diagnostics.insert(buffer_id, diagnostics.clone()); - return false; + false })?; if unchanged { return Ok(()); @@ -569,7 +485,7 @@ impl ProjectDiagnosticsEditor { crate::diagnostic_renderer::DiagnosticRenderer::diagnostic_blocks_for_group( group, buffer_snapshot.remote_id(), - Some(this.clone()), + Some(Arc::new(this.clone())), cx, ) })?; @@ -590,14 +506,16 @@ impl ProjectDiagnosticsEditor { } let mut excerpt_ranges: Vec> = Vec::new(); + let context_lines = cx.update(|_, cx| multibuffer_context_lines(cx))?; for b in blocks.iter() { let excerpt_range = context_range_for_entry( b.initial_range.clone(), - DEFAULT_MULTIBUFFER_CONTEXT, + context_lines, buffer_snapshot.clone(), - &mut cx, + cx, ) .await; + let i = excerpt_ranges .binary_search_by(|probe| { probe @@ -639,17 +557,15 @@ impl ProjectDiagnosticsEditor { #[cfg(test)] let cloned_blocks = blocks.clone(); - if was_empty { - if let Some(anchor_range) = anchor_ranges.first() { - let range_to_select = anchor_range.start..anchor_range.start; - this.editor.update(cx, |editor, cx| { - editor.change_selections(Default::default(), window, cx, |s| { - s.select_anchor_ranges([range_to_select]); - }) - }); - if this.focus_handle.is_focused(window) { - this.editor.read(cx).focus_handle(cx).focus(window); - } + if was_empty && let Some(anchor_range) = anchor_ranges.first() { + let range_to_select = anchor_range.start..anchor_range.start; + this.editor.update(cx, |editor, cx| { + editor.change_selections(Default::default(), window, cx, |s| { + s.select_anchor_ranges([range_to_select]); + }) + }); + if this.focus_handle.is_focused(window) { + this.editor.read(cx).focus_handle(cx).focus(window); } } @@ -669,6 +585,7 @@ impl ProjectDiagnosticsEditor { priority: 1, } 
}); + let block_ids = this.editor.update(cx, |editor, cx| { editor.display_map.update(cx, |display_map, cx| { display_map.insert_blocks(editor_blocks, cx) @@ -700,28 +617,8 @@ impl ProjectDiagnosticsEditor { }) } - pub fn cargo_diagnostics_sources(&self, cx: &App) -> Vec { - let fetch_cargo_diagnostics = ProjectSettings::get_global(cx) - .diagnostics - .fetch_cargo_diagnostics(); - if !fetch_cargo_diagnostics { - return Vec::new(); - } - self.project - .read(cx) - .worktrees(cx) - .filter_map(|worktree| { - let _cargo_toml_entry = worktree.read(cx).entry_for_path("Cargo.toml")?; - let rust_file_entry = worktree.read(cx).entries(false, 0).find(|entry| { - entry - .path - .extension() - .and_then(|extension| extension.to_str()) - == Some("rs") - })?; - self.project.read(cx).path_for_entry(rust_file_entry.id, cx) - }) - .collect() + fn update_diagnostic_summary(&mut self, cx: &mut Context) { + self.summary = self.project.read(cx).diagnostic_summary(false, cx); } } @@ -931,6 +828,68 @@ impl Item for ProjectDiagnosticsEditor { } } +impl DiagnosticsToolbarEditor for WeakEntity { + fn include_warnings(&self, cx: &App) -> bool { + self.read_with(cx, |project_diagnostics_editor, _cx| { + project_diagnostics_editor.include_warnings + }) + .unwrap_or(false) + } + + fn has_stale_excerpts(&self, cx: &App) -> bool { + self.read_with(cx, |project_diagnostics_editor, _cx| { + !project_diagnostics_editor.paths_to_update.is_empty() + }) + .unwrap_or(false) + } + + fn is_updating(&self, cx: &App) -> bool { + self.read_with(cx, |project_diagnostics_editor, cx| { + project_diagnostics_editor.update_excerpts_task.is_some() + || project_diagnostics_editor + .project + .read(cx) + .language_servers_running_disk_based_diagnostics(cx) + .next() + .is_some() + }) + .unwrap_or(false) + } + + fn stop_updating(&self, cx: &mut App) { + let _ = self.update(cx, |project_diagnostics_editor, cx| { + project_diagnostics_editor.update_excerpts_task = None; + cx.notify(); + }); + } + + fn 
refresh_diagnostics(&self, window: &mut Window, cx: &mut App) { + let _ = self.update(cx, |project_diagnostics_editor, cx| { + project_diagnostics_editor.update_all_excerpts(window, cx); + }); + } + + fn toggle_warnings(&self, window: &mut Window, cx: &mut App) { + let _ = self.update(cx, |project_diagnostics_editor, cx| { + project_diagnostics_editor.toggle_warnings(&Default::default(), window, cx); + }); + } + + fn get_diagnostics_for_buffer( + &self, + buffer_id: text::BufferId, + cx: &App, + ) -> Vec> { + self.read_with(cx, |project_diagnostics_editor, _cx| { + project_diagnostics_editor + .diagnostics + .get(&buffer_id) + .cloned() + .unwrap_or_default() + }) + .unwrap_or_default() + } +} const DIAGNOSTIC_EXPANSION_ROW_LIMIT: u32 = 32; async fn context_range_for_entry( @@ -980,18 +939,16 @@ async fn heuristic_syntactic_expand( // Remove blank lines from start and end if let Some(start_row) = (outline_range.start.row..outline_range.end.row) .find(|row| !snapshot.line_indent_for_row(*row).is_line_blank()) - { - if let Some(end_row) = (outline_range.start.row..outline_range.end.row + 1) + && let Some(end_row) = (outline_range.start.row..outline_range.end.row + 1) .rev() .find(|row| !snapshot.line_indent_for_row(*row).is_line_blank()) - { - let row_count = end_row.saturating_sub(start_row); - if row_count <= max_row_count { - return Some(RangeInclusive::new( - outline_range.start.row, - outline_range.end.row, - )); - } + { + let row_count = end_row.saturating_sub(start_row); + if row_count <= max_row_count { + return Some(RangeInclusive::new( + outline_range.start.row, + outline_range.end.row, + )); } } } diff --git a/crates/diagnostics/src/diagnostics_tests.rs b/crates/diagnostics/src/diagnostics_tests.rs index 8fb223b2cbfcc7db817059dd92bf1ff869846645..a50e20f579e67010819de0fdb7273d4c9912b8b8 100644 --- a/crates/diagnostics/src/diagnostics_tests.rs +++ b/crates/diagnostics/src/diagnostics_tests.rs @@ -24,6 +24,7 @@ use settings::SettingsStore; use std::{ env, 
path::{Path, PathBuf}, + str::FromStr, }; use unindent::Unindent as _; use util::{RandomCharIter, path, post_inc}; @@ -70,7 +71,7 @@ async fn test_diagnostics(cx: &mut TestAppContext) { let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); let cx = &mut VisualTestContext::from_window(*window, cx); let workspace = window.root(cx).unwrap(); - let uri = lsp::Url::from_file_path(path!("/test/main.rs")).unwrap(); + let uri = lsp::Uri::from_file_path(path!("/test/main.rs")).unwrap(); // Create some diagnostics lsp_store.update(cx, |lsp_store, cx| { @@ -167,7 +168,7 @@ async fn test_diagnostics(cx: &mut TestAppContext) { .update_diagnostics( language_server_id, lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/test/consts.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/test/consts.rs")).unwrap(), diagnostics: vec![lsp::Diagnostic { range: lsp::Range::new( lsp::Position::new(0, 15), @@ -243,7 +244,7 @@ async fn test_diagnostics(cx: &mut TestAppContext) { .update_diagnostics( language_server_id, lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/test/consts.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/test/consts.rs")).unwrap(), diagnostics: vec![ lsp::Diagnostic { range: lsp::Range::new( @@ -356,14 +357,14 @@ async fn test_diagnostics_with_folds(cx: &mut TestAppContext) { .update_diagnostics( server_id_1, lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/test/main.js")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/test/main.js")).unwrap(), diagnostics: vec![lsp::Diagnostic { range: lsp::Range::new(lsp::Position::new(4, 0), lsp::Position::new(4, 4)), severity: Some(lsp::DiagnosticSeverity::WARNING), message: "no method `tset`".to_string(), related_information: Some(vec![lsp::DiagnosticRelatedInformation { location: lsp::Location::new( - lsp::Url::from_file_path(path!("/test/main.js")).unwrap(), + 
lsp::Uri::from_file_path(path!("/test/main.js")).unwrap(), lsp::Range::new( lsp::Position::new(0, 9), lsp::Position::new(0, 13), @@ -465,7 +466,7 @@ async fn test_diagnostics_multiple_servers(cx: &mut TestAppContext) { .update_diagnostics( server_id_1, lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/test/main.js")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/test/main.js")).unwrap(), diagnostics: vec![lsp::Diagnostic { range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 1)), severity: Some(lsp::DiagnosticSeverity::WARNING), @@ -509,7 +510,7 @@ async fn test_diagnostics_multiple_servers(cx: &mut TestAppContext) { .update_diagnostics( server_id_2, lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/test/main.js")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/test/main.js")).unwrap(), diagnostics: vec![lsp::Diagnostic { range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 1)), severity: Some(lsp::DiagnosticSeverity::ERROR), @@ -552,7 +553,7 @@ async fn test_diagnostics_multiple_servers(cx: &mut TestAppContext) { .update_diagnostics( server_id_1, lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/test/main.js")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/test/main.js")).unwrap(), diagnostics: vec![lsp::Diagnostic { range: lsp::Range::new(lsp::Position::new(2, 0), lsp::Position::new(2, 1)), severity: Some(lsp::DiagnosticSeverity::WARNING), @@ -571,7 +572,7 @@ async fn test_diagnostics_multiple_servers(cx: &mut TestAppContext) { .update_diagnostics( server_id_2, lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/test/main.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/test/main.rs")).unwrap(), diagnostics: vec![], version: None, }, @@ -608,7 +609,7 @@ async fn test_diagnostics_multiple_servers(cx: &mut TestAppContext) { .update_diagnostics( server_id_2, lsp::PublishDiagnosticsParams { - uri: 
lsp::Url::from_file_path(path!("/test/main.js")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/test/main.js")).unwrap(), diagnostics: vec![lsp::Diagnostic { range: lsp::Range::new(lsp::Position::new(3, 0), lsp::Position::new(3, 1)), severity: Some(lsp::DiagnosticSeverity::WARNING), @@ -681,7 +682,7 @@ async fn test_random_diagnostics_blocks(cx: &mut TestAppContext, mut rng: StdRng Default::default(); for _ in 0..operations { - match rng.gen_range(0..100) { + match rng.random_range(0..100) { // language server completes its diagnostic check 0..=20 if !updated_language_servers.is_empty() => { let server_id = *updated_language_servers.iter().choose(&mut rng).unwrap(); @@ -690,7 +691,7 @@ async fn test_random_diagnostics_blocks(cx: &mut TestAppContext, mut rng: StdRng lsp_store.disk_based_diagnostics_finished(server_id, cx) }); - if rng.gen_bool(0.5) { + if rng.random_bool(0.5) { cx.run_until_parked(); } } @@ -700,7 +701,7 @@ async fn test_random_diagnostics_blocks(cx: &mut TestAppContext, mut rng: StdRng let (path, server_id, diagnostics) = match current_diagnostics.iter_mut().choose(&mut rng) { // update existing set of diagnostics - Some(((path, server_id), diagnostics)) if rng.gen_bool(0.5) => { + Some(((path, server_id), diagnostics)) if rng.random_bool(0.5) => { (path.clone(), *server_id, diagnostics) } @@ -708,13 +709,13 @@ async fn test_random_diagnostics_blocks(cx: &mut TestAppContext, mut rng: StdRng _ => { let path: PathBuf = format!(path!("/test/{}.rs"), post_inc(&mut next_filename)).into(); - let len = rng.gen_range(128..256); + let len = rng.random_range(128..256); let content = RandomCharIter::new(&mut rng).take(len).collect::(); fs.insert_file(&path, content.into_bytes()).await; let server_id = match language_server_ids.iter().choose(&mut rng) { - Some(server_id) if rng.gen_bool(0.5) => *server_id, + Some(server_id) if rng.random_bool(0.5) => *server_id, _ => { let id = LanguageServerId(language_server_ids.len()); language_server_ids.push(id); @@ 
-745,8 +746,8 @@ async fn test_random_diagnostics_blocks(cx: &mut TestAppContext, mut rng: StdRng .update_diagnostics( server_id, lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(&path).unwrap_or_else(|_| { - lsp::Url::parse("file:///test/fallback.rs").unwrap() + uri: lsp::Uri::from_file_path(&path).unwrap_or_else(|_| { + lsp::Uri::from_str("file:///test/fallback.rs").unwrap() }), diagnostics: diagnostics.clone(), version: None, @@ -845,7 +846,7 @@ async fn test_random_diagnostics_with_inlays(cx: &mut TestAppContext, mut rng: S let mut next_inlay_id = 0; for _ in 0..operations { - match rng.gen_range(0..100) { + match rng.random_range(0..100) { // language server completes its diagnostic check 0..=20 if !updated_language_servers.is_empty() => { let server_id = *updated_language_servers.iter().choose(&mut rng).unwrap(); @@ -854,7 +855,7 @@ async fn test_random_diagnostics_with_inlays(cx: &mut TestAppContext, mut rng: S lsp_store.disk_based_diagnostics_finished(server_id, cx) }); - if rng.gen_bool(0.5) { + if rng.random_bool(0.5) { cx.run_until_parked(); } } @@ -862,8 +863,8 @@ async fn test_random_diagnostics_with_inlays(cx: &mut TestAppContext, mut rng: S 21..=50 => mutated_diagnostics.update_in(cx, |diagnostics, window, cx| { diagnostics.editor.update(cx, |editor, cx| { let snapshot = editor.snapshot(window, cx); - if snapshot.buffer_snapshot.len() > 0 { - let position = rng.gen_range(0..snapshot.buffer_snapshot.len()); + if !snapshot.buffer_snapshot.is_empty() { + let position = rng.random_range(0..snapshot.buffer_snapshot.len()); let position = snapshot.buffer_snapshot.clip_offset(position, Bias::Left); log::info!( "adding inlay at {position}/{}: {:?}", @@ -889,7 +890,7 @@ async fn test_random_diagnostics_with_inlays(cx: &mut TestAppContext, mut rng: S let (path, server_id, diagnostics) = match current_diagnostics.iter_mut().choose(&mut rng) { // update existing set of diagnostics - Some(((path, server_id), diagnostics)) if rng.gen_bool(0.5) => { 
+ Some(((path, server_id), diagnostics)) if rng.random_bool(0.5) => { (path.clone(), *server_id, diagnostics) } @@ -897,13 +898,13 @@ async fn test_random_diagnostics_with_inlays(cx: &mut TestAppContext, mut rng: S _ => { let path: PathBuf = format!(path!("/test/{}.rs"), post_inc(&mut next_filename)).into(); - let len = rng.gen_range(128..256); + let len = rng.random_range(128..256); let content = RandomCharIter::new(&mut rng).take(len).collect::(); fs.insert_file(&path, content.into_bytes()).await; let server_id = match language_server_ids.iter().choose(&mut rng) { - Some(server_id) if rng.gen_bool(0.5) => *server_id, + Some(server_id) if rng.random_bool(0.5) => *server_id, _ => { let id = LanguageServerId(language_server_ids.len()); language_server_ids.push(id); @@ -934,8 +935,8 @@ async fn test_random_diagnostics_with_inlays(cx: &mut TestAppContext, mut rng: S .update_diagnostics( server_id, lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(&path).unwrap_or_else(|_| { - lsp::Url::parse("file:///test/fallback.rs").unwrap() + uri: lsp::Uri::from_file_path(&path).unwrap_or_else(|_| { + lsp::Uri::from_str("file:///test/fallback.rs").unwrap() }), diagnostics: diagnostics.clone(), version: None, @@ -971,7 +972,7 @@ async fn active_diagnostics_dismiss_after_invalidation(cx: &mut TestAppContext) let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! 
{" ˇfn func(abc def: i32) -> u32 { @@ -985,7 +986,7 @@ async fn active_diagnostics_dismiss_after_invalidation(cx: &mut TestAppContext) .update_diagnostics( LanguageServerId(0), lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/root/file")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/root/file")).unwrap(), version: None, diagnostics: vec![lsp::Diagnostic { range: lsp::Range::new( @@ -1028,7 +1029,7 @@ async fn active_diagnostics_dismiss_after_invalidation(cx: &mut TestAppContext) .update_diagnostics( LanguageServerId(0), lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/root/file")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/root/file")).unwrap(), version: None, diagnostics: Vec::new(), }, @@ -1065,7 +1066,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! 
{" ˇfn func(abc def: i32) -> u32 { @@ -1078,7 +1079,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { .update_diagnostics( LanguageServerId(0), lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/root/file")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/root/file")).unwrap(), version: None, diagnostics: vec![ lsp::Diagnostic { @@ -1239,14 +1240,14 @@ async fn test_diagnostics_with_links(cx: &mut TestAppContext) { } "}); let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.update(|_, cx| { lsp_store.update(cx, |lsp_store, cx| { lsp_store.update_diagnostics( LanguageServerId(0), lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/root/file")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/root/file")).unwrap(), version: None, diagnostics: vec![lsp::Diagnostic { range: lsp::Range::new(lsp::Position::new(0, 8), lsp::Position::new(0, 12)), @@ -1293,13 +1294,13 @@ async fn test_hover_diagnostic_and_info_popovers(cx: &mut gpui::TestAppContext) fn «test»() { println!(); } "}); let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.update(|_, cx| { lsp_store.update(cx, |lsp_store, cx| { lsp_store.update_diagnostics( LanguageServerId(0), lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/root/dir/file.rs")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/root/dir/file.rs")).unwrap(), version: None, diagnostics: vec![lsp::Diagnostic { range, @@ -1376,7 +1377,7 @@ async fn test_diagnostics_with_code(cx: &mut TestAppContext) { let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); let cx = &mut VisualTestContext::from_window(*window, cx); let workspace = 
window.root(cx).unwrap(); - let uri = lsp::Url::from_file_path(path!("/root/main.js")).unwrap(); + let uri = lsp::Uri::from_file_path(path!("/root/main.js")).unwrap(); // Create diagnostics with code fields lsp_store.update(cx, |lsp_store, cx| { @@ -1450,7 +1451,7 @@ async fn go_to_diagnostic_with_severity(cx: &mut TestAppContext) { let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! {"error warning info hiˇnt"}); @@ -1460,7 +1461,7 @@ async fn go_to_diagnostic_with_severity(cx: &mut TestAppContext) { .update_diagnostics( LanguageServerId(0), lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/root/file")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/root/file")).unwrap(), version: None, diagnostics: vec![ lsp::Diagnostic { @@ -1566,6 +1567,440 @@ async fn go_to_diagnostic_with_severity(cx: &mut TestAppContext) { cx.assert_editor_state(indoc! {"error ˇwarning info hint"}); } +#[gpui::test] +async fn test_buffer_diagnostics(cx: &mut TestAppContext) { + init_test(cx); + + // We'll be creating two different files, both with diagnostics, so we can + // later verify that, since the `BufferDiagnosticsEditor` only shows + // diagnostics for the provided path, the diagnostics for the other file + // will not be shown, contrary to what happens with + // `ProjectDiagnosticsEditor`. 
+ let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/test"), + json!({ + "main.rs": " + fn main() { + let x = vec![]; + let y = vec![]; + a(x); + b(y); + c(y); + d(x); + } + " + .unindent(), + "other.rs": " + fn other() { + let unused = 42; + undefined_function(); + } + " + .unindent(), + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; + let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = &mut VisualTestContext::from_window(*window, cx); + let project_path = project::ProjectPath { + worktree_id: project.read_with(cx, |project, cx| { + project.worktrees(cx).next().unwrap().read(cx).id() + }), + path: Arc::from(Path::new("main.rs")), + }; + let buffer = project + .update(cx, |project, cx| { + project.open_buffer(project_path.clone(), cx) + }) + .await + .ok(); + + // Create the diagnostics for `main.rs`. + let language_server_id = LanguageServerId(0); + let uri = lsp::Uri::from_file_path(path!("/test/main.rs")).unwrap(); + let lsp_store = project.read_with(cx, |project, _| project.lsp_store()); + + lsp_store.update(cx, |lsp_store, cx| { + lsp_store.update_diagnostics(language_server_id, lsp::PublishDiagnosticsParams { + uri: uri.clone(), + diagnostics: vec![ + lsp::Diagnostic{ + range: lsp::Range::new(lsp::Position::new(5, 6), lsp::Position::new(5, 7)), + severity: Some(lsp::DiagnosticSeverity::WARNING), + message: "use of moved value\nvalue used here after move".to_string(), + related_information: Some(vec![ + lsp::DiagnosticRelatedInformation { + location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(2, 8), lsp::Position::new(2, 9))), + message: "move occurs because `y` has type `Vec`, which does not implement the `Copy` trait".to_string() + }, + lsp::DiagnosticRelatedInformation { + location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(4, 6), lsp::Position::new(4, 7))), + message: "value moved 
here".to_string() + }, + ]), + ..Default::default() + }, + lsp::Diagnostic{ + range: lsp::Range::new(lsp::Position::new(6, 6), lsp::Position::new(6, 7)), + severity: Some(lsp::DiagnosticSeverity::ERROR), + message: "use of moved value\nvalue used here after move".to_string(), + related_information: Some(vec![ + lsp::DiagnosticRelatedInformation { + location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(1, 8), lsp::Position::new(1, 9))), + message: "move occurs because `x` has type `Vec`, which does not implement the `Copy` trait".to_string() + }, + lsp::DiagnosticRelatedInformation { + location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(3, 6), lsp::Position::new(3, 7))), + message: "value moved here".to_string() + }, + ]), + ..Default::default() + } + ], + version: None + }, None, DiagnosticSourceKind::Pushed, &[], cx).unwrap(); + + // Create diagnostics for other.rs to ensure that the file and + // diagnostics are not included in `BufferDiagnosticsEditor` when it is + // deployed for main.rs. 
+ lsp_store.update_diagnostics(language_server_id, lsp::PublishDiagnosticsParams { + uri: lsp::Uri::from_file_path(path!("/test/other.rs")).unwrap(), + diagnostics: vec![ + lsp::Diagnostic{ + range: lsp::Range::new(lsp::Position::new(1, 8), lsp::Position::new(1, 14)), + severity: Some(lsp::DiagnosticSeverity::WARNING), + message: "unused variable: `unused`".to_string(), + ..Default::default() + }, + lsp::Diagnostic{ + range: lsp::Range::new(lsp::Position::new(2, 4), lsp::Position::new(2, 22)), + severity: Some(lsp::DiagnosticSeverity::ERROR), + message: "cannot find function `undefined_function` in this scope".to_string(), + ..Default::default() + } + ], + version: None + }, None, DiagnosticSourceKind::Pushed, &[], cx).unwrap(); + }); + + let buffer_diagnostics = window.build_entity(cx, |window, cx| { + BufferDiagnosticsEditor::new( + project_path.clone(), + project.clone(), + buffer, + true, + window, + cx, + ) + }); + let editor = buffer_diagnostics.update(cx, |buffer_diagnostics, _| { + buffer_diagnostics.editor().clone() + }); + + // Since the excerpt updates is handled by a background task, we need to + // wait a little bit to ensure that the buffer diagnostic's editor content + // is rendered. + cx.executor() + .advance_clock(DIAGNOSTICS_UPDATE_DELAY + Duration::from_millis(10)); + + pretty_assertions::assert_eq!( + editor_content_with_blocks(&editor, cx), + indoc::indoc! 
{ + "§ main.rs + § ----- + fn main() { + let x = vec![]; + § move occurs because `x` has type `Vec`, which does not implement + § the `Copy` trait (back) + let y = vec![]; + § move occurs because `y` has type `Vec`, which does not implement + § the `Copy` trait + a(x); § value moved here + b(y); § value moved here + c(y); + § use of moved value + § value used here after move + d(x); + § use of moved value + § value used here after move + § hint: move occurs because `x` has type `Vec`, which does not + § implement the `Copy` trait + }" + } + ); +} + +#[gpui::test] +async fn test_buffer_diagnostics_without_warnings(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/test"), + json!({ + "main.rs": " + fn main() { + let x = vec![]; + let y = vec![]; + a(x); + b(y); + c(y); + d(x); + } + " + .unindent(), + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; + let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = &mut VisualTestContext::from_window(*window, cx); + let project_path = project::ProjectPath { + worktree_id: project.read_with(cx, |project, cx| { + project.worktrees(cx).next().unwrap().read(cx).id() + }), + path: Arc::from(Path::new("main.rs")), + }; + let buffer = project + .update(cx, |project, cx| { + project.open_buffer(project_path.clone(), cx) + }) + .await + .ok(); + + let language_server_id = LanguageServerId(0); + let uri = lsp::Uri::from_file_path(path!("/test/main.rs")).unwrap(); + let lsp_store = project.read_with(cx, |project, _| project.lsp_store()); + + lsp_store.update(cx, |lsp_store, cx| { + lsp_store.update_diagnostics(language_server_id, lsp::PublishDiagnosticsParams { + uri: uri.clone(), + diagnostics: vec![ + lsp::Diagnostic{ + range: lsp::Range::new(lsp::Position::new(5, 6), lsp::Position::new(5, 7)), + severity: Some(lsp::DiagnosticSeverity::WARNING), + message: "use of moved 
value\nvalue used here after move".to_string(), + related_information: Some(vec![ + lsp::DiagnosticRelatedInformation { + location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(2, 8), lsp::Position::new(2, 9))), + message: "move occurs because `y` has type `Vec`, which does not implement the `Copy` trait".to_string() + }, + lsp::DiagnosticRelatedInformation { + location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(4, 6), lsp::Position::new(4, 7))), + message: "value moved here".to_string() + }, + ]), + ..Default::default() + }, + lsp::Diagnostic{ + range: lsp::Range::new(lsp::Position::new(6, 6), lsp::Position::new(6, 7)), + severity: Some(lsp::DiagnosticSeverity::ERROR), + message: "use of moved value\nvalue used here after move".to_string(), + related_information: Some(vec![ + lsp::DiagnosticRelatedInformation { + location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(1, 8), lsp::Position::new(1, 9))), + message: "move occurs because `x` has type `Vec`, which does not implement the `Copy` trait".to_string() + }, + lsp::DiagnosticRelatedInformation { + location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(3, 6), lsp::Position::new(3, 7))), + message: "value moved here".to_string() + }, + ]), + ..Default::default() + } + ], + version: None + }, None, DiagnosticSourceKind::Pushed, &[], cx).unwrap(); + }); + + let include_warnings = false; + let buffer_diagnostics = window.build_entity(cx, |window, cx| { + BufferDiagnosticsEditor::new( + project_path.clone(), + project.clone(), + buffer, + include_warnings, + window, + cx, + ) + }); + + let editor = buffer_diagnostics.update(cx, |buffer_diagnostics, _cx| { + buffer_diagnostics.editor().clone() + }); + + // Since the excerpt updates is handled by a background task, we need to + // wait a little bit to ensure that the buffer diagnostic's editor content + // is rendered. 
+ cx.executor() + .advance_clock(DIAGNOSTICS_UPDATE_DELAY + Duration::from_millis(10)); + + pretty_assertions::assert_eq!( + editor_content_with_blocks(&editor, cx), + indoc::indoc! { + "§ main.rs + § ----- + fn main() { + let x = vec![]; + § move occurs because `x` has type `Vec`, which does not implement + § the `Copy` trait (back) + let y = vec![]; + a(x); § value moved here + b(y); + c(y); + d(x); + § use of moved value + § value used here after move + § hint: move occurs because `x` has type `Vec`, which does not + § implement the `Copy` trait + }" + } + ); +} + +#[gpui::test] +async fn test_buffer_diagnostics_multiple_servers(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/test"), + json!({ + "main.rs": " + fn main() { + let x = vec![]; + let y = vec![]; + a(x); + b(y); + c(y); + d(x); + } + " + .unindent(), + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; + let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = &mut VisualTestContext::from_window(*window, cx); + let project_path = project::ProjectPath { + worktree_id: project.read_with(cx, |project, cx| { + project.worktrees(cx).next().unwrap().read(cx).id() + }), + path: Arc::from(Path::new("main.rs")), + }; + let buffer = project + .update(cx, |project, cx| { + project.open_buffer(project_path.clone(), cx) + }) + .await + .ok(); + + // Create the diagnostics for `main.rs`. + // Two warnings are being created, one for each language server, in order to + // assert that both warnings are rendered in the editor. 
+ let language_server_id_a = LanguageServerId(0); + let language_server_id_b = LanguageServerId(1); + let uri = lsp::Uri::from_file_path(path!("/test/main.rs")).unwrap(); + let lsp_store = project.read_with(cx, |project, _| project.lsp_store()); + + lsp_store.update(cx, |lsp_store, cx| { + lsp_store + .update_diagnostics( + language_server_id_a, + lsp::PublishDiagnosticsParams { + uri: uri.clone(), + diagnostics: vec![lsp::Diagnostic { + range: lsp::Range::new(lsp::Position::new(5, 6), lsp::Position::new(5, 7)), + severity: Some(lsp::DiagnosticSeverity::WARNING), + message: "use of moved value\nvalue used here after move".to_string(), + related_information: None, + ..Default::default() + }], + version: None, + }, + None, + DiagnosticSourceKind::Pushed, + &[], + cx, + ) + .unwrap(); + + lsp_store + .update_diagnostics( + language_server_id_b, + lsp::PublishDiagnosticsParams { + uri: uri.clone(), + diagnostics: vec![lsp::Diagnostic { + range: lsp::Range::new(lsp::Position::new(6, 6), lsp::Position::new(6, 7)), + severity: Some(lsp::DiagnosticSeverity::WARNING), + message: "use of moved value\nvalue used here after move".to_string(), + related_information: None, + ..Default::default() + }], + version: None, + }, + None, + DiagnosticSourceKind::Pushed, + &[], + cx, + ) + .unwrap(); + }); + + let buffer_diagnostics = window.build_entity(cx, |window, cx| { + BufferDiagnosticsEditor::new( + project_path.clone(), + project.clone(), + buffer, + true, + window, + cx, + ) + }); + let editor = buffer_diagnostics.update(cx, |buffer_diagnostics, _| { + buffer_diagnostics.editor().clone() + }); + + // Since the excerpt updates is handled by a background task, we need to + // wait a little bit to ensure that the buffer diagnostic's editor content + // is rendered. + cx.executor() + .advance_clock(DIAGNOSTICS_UPDATE_DELAY + Duration::from_millis(10)); + + pretty_assertions::assert_eq!( + editor_content_with_blocks(&editor, cx), + indoc::indoc! 
{ + "§ main.rs + § ----- + a(x); + b(y); + c(y); + § use of moved value + § value used here after move + d(x); + § use of moved value + § value used here after move + }" + } + ); + + buffer_diagnostics.update(cx, |buffer_diagnostics, _cx| { + assert_eq!( + *buffer_diagnostics.summary(), + DiagnosticSummary { + warning_count: 2, + error_count: 0 + } + ); + }) +} + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { zlog::init_test(); @@ -1588,10 +2023,10 @@ fn randomly_update_diagnostics_for_path( next_id: &mut usize, rng: &mut impl Rng, ) { - let mutation_count = rng.gen_range(1..=3); + let mutation_count = rng.random_range(1..=3); for _ in 0..mutation_count { - if rng.gen_bool(0.3) && !diagnostics.is_empty() { - let idx = rng.gen_range(0..diagnostics.len()); + if rng.random_bool(0.3) && !diagnostics.is_empty() { + let idx = rng.random_range(0..diagnostics.len()); log::info!(" removing diagnostic at index {idx}"); diagnostics.remove(idx); } else { @@ -1600,7 +2035,7 @@ fn randomly_update_diagnostics_for_path( let new_diagnostic = random_lsp_diagnostic(rng, fs, path, unique_id); - let ix = rng.gen_range(0..=diagnostics.len()); + let ix = rng.random_range(0..=diagnostics.len()); log::info!( " inserting {} at index {ix}. 
{},{}..{},{}", new_diagnostic.message, @@ -1637,8 +2072,8 @@ fn random_lsp_diagnostic( let file_content = fs.read_file_sync(path).unwrap(); let file_text = Rope::from(String::from_utf8_lossy(&file_content).as_ref()); - let start = rng.gen_range(0..file_text.len().saturating_add(ERROR_MARGIN)); - let end = rng.gen_range(start..file_text.len().saturating_add(ERROR_MARGIN)); + let start = rng.random_range(0..file_text.len().saturating_add(ERROR_MARGIN)); + let end = rng.random_range(start..file_text.len().saturating_add(ERROR_MARGIN)); let start_point = file_text.offset_to_point_utf16(start); let end_point = file_text.offset_to_point_utf16(end); @@ -1648,7 +2083,7 @@ fn random_lsp_diagnostic( lsp::Position::new(end_point.row, end_point.column), ); - let severity = if rng.gen_bool(0.5) { + let severity = if rng.random_bool(0.5) { Some(lsp::DiagnosticSeverity::ERROR) } else { Some(lsp::DiagnosticSeverity::WARNING) @@ -1656,13 +2091,14 @@ fn random_lsp_diagnostic( let message = format!("diagnostic {unique_id}"); - let related_information = if rng.gen_bool(0.3) { - let info_count = rng.gen_range(1..=3); + let related_information = if rng.random_bool(0.3) { + let info_count = rng.random_range(1..=3); let mut related_info = Vec::with_capacity(info_count); for i in 0..info_count { - let info_start = rng.gen_range(0..file_text.len().saturating_add(ERROR_MARGIN)); - let info_end = rng.gen_range(info_start..file_text.len().saturating_add(ERROR_MARGIN)); + let info_start = rng.random_range(0..file_text.len().saturating_add(ERROR_MARGIN)); + let info_end = + rng.random_range(info_start..file_text.len().saturating_add(ERROR_MARGIN)); let info_start_point = file_text.offset_to_point_utf16(info_start); let info_end_point = file_text.offset_to_point_utf16(info_end); @@ -1673,7 +2109,7 @@ fn random_lsp_diagnostic( ); related_info.push(lsp::DiagnosticRelatedInformation { - location: lsp::Location::new(lsp::Url::from_file_path(path).unwrap(), info_range), + location: 
lsp::Location::new(lsp::Uri::from_file_path(path).unwrap(), info_range), message: format!("related info {i} for diagnostic {unique_id}"), }); } diff --git a/crates/diagnostics/src/items.rs b/crates/diagnostics/src/items.rs index 7ac6d101f315674cec4fd07f4ad2df0830284124..11ee4ece96d0c4646714d808037e7a2789bcdf85 100644 --- a/crates/diagnostics/src/items.rs +++ b/crates/diagnostics/src/items.rs @@ -32,49 +32,38 @@ impl Render for DiagnosticIndicator { } let diagnostic_indicator = match (self.summary.error_count, self.summary.warning_count) { - (0, 0) => h_flex().map(|this| { - this.child( - Icon::new(IconName::Check) - .size(IconSize::Small) - .color(Color::Default), - ) - }), - (0, warning_count) => h_flex() - .gap_1() - .child( - Icon::new(IconName::Warning) - .size(IconSize::Small) - .color(Color::Warning), - ) - .child(Label::new(warning_count.to_string()).size(LabelSize::Small)), - (error_count, 0) => h_flex() - .gap_1() - .child( - Icon::new(IconName::XCircle) - .size(IconSize::Small) - .color(Color::Error), - ) - .child(Label::new(error_count.to_string()).size(LabelSize::Small)), + (0, 0) => h_flex().child( + Icon::new(IconName::Check) + .size(IconSize::Small) + .color(Color::Default), + ), (error_count, warning_count) => h_flex() .gap_1() - .child( - Icon::new(IconName::XCircle) - .size(IconSize::Small) - .color(Color::Error), - ) - .child(Label::new(error_count.to_string()).size(LabelSize::Small)) - .child( - Icon::new(IconName::Warning) - .size(IconSize::Small) - .color(Color::Warning), - ) - .child(Label::new(warning_count.to_string()).size(LabelSize::Small)), + .when(error_count > 0, |this| { + this.child( + Icon::new(IconName::XCircle) + .size(IconSize::Small) + .color(Color::Error), + ) + .child(Label::new(error_count.to_string()).size(LabelSize::Small)) + }) + .when(warning_count > 0, |this| { + this.child( + Icon::new(IconName::Warning) + .size(IconSize::Small) + .color(Color::Warning), + ) + 
.child(Label::new(warning_count.to_string()).size(LabelSize::Small)) + }), }; let status = if let Some(diagnostic) = &self.current_diagnostic { - let message = diagnostic.message.split('\n').next().unwrap().to_string(); + let message = diagnostic + .message + .split_once('\n') + .map_or(&*diagnostic.message, |(first, _)| first); Some( - Button::new("diagnostic_message", message) + Button::new("diagnostic_message", SharedString::new(message)) .label_size(LabelSize::Small) .tooltip(|window, cx| { Tooltip::for_action( diff --git a/crates/diagnostics/src/toolbar_controls.rs b/crates/diagnostics/src/toolbar_controls.rs index e77b80115f2ffe6de512743d3eb00311052d7937..ac7bfb0f692d820f6671b0cbb03849bddab58903 100644 --- a/crates/diagnostics/src/toolbar_controls.rs +++ b/crates/diagnostics/src/toolbar_controls.rs @@ -1,43 +1,56 @@ -use std::sync::Arc; - -use crate::{ProjectDiagnosticsEditor, ToggleDiagnosticsRefresh}; -use gpui::{Context, Entity, EventEmitter, ParentElement, Render, WeakEntity, Window}; +use crate::{BufferDiagnosticsEditor, ProjectDiagnosticsEditor, ToggleDiagnosticsRefresh}; +use gpui::{Context, EventEmitter, ParentElement, Render, Window}; +use language::DiagnosticEntry; +use text::{Anchor, BufferId}; use ui::prelude::*; use ui::{IconButton, IconButtonShape, IconName, Tooltip}; use workspace::{ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, item::ItemHandle}; pub struct ToolbarControls { - editor: Option>, + editor: Option>, +} + +pub(crate) trait DiagnosticsToolbarEditor: Send + Sync { + /// Informs the toolbar whether warnings are included in the diagnostics. + fn include_warnings(&self, cx: &App) -> bool; + /// Toggles whether warning diagnostics should be displayed by the + /// diagnostics editor. + fn toggle_warnings(&self, window: &mut Window, cx: &mut App); + /// Indicates whether any of the excerpts displayed by the diagnostics + /// editor are stale. 
+ fn has_stale_excerpts(&self, cx: &App) -> bool; + /// Indicates whether the diagnostics editor is currently updating the + /// diagnostics. + fn is_updating(&self, cx: &App) -> bool; + /// Requests that the diagnostics editor stop updating the diagnostics. + fn stop_updating(&self, cx: &mut App); + /// Requests that the diagnostics editor updates the displayed diagnostics + /// with the latest information. + fn refresh_diagnostics(&self, window: &mut Window, cx: &mut App); + /// Returns a list of diagnostics for the provided buffer id. + fn get_diagnostics_for_buffer( + &self, + buffer_id: BufferId, + cx: &App, + ) -> Vec>; } impl Render for ToolbarControls { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let mut include_warnings = false; let mut has_stale_excerpts = false; + let mut include_warnings = false; let mut is_updating = false; - let cargo_diagnostics_sources = Arc::new(self.diagnostics().map_or(Vec::new(), |editor| { - editor.read(cx).cargo_diagnostics_sources(cx) - })); - let fetch_cargo_diagnostics = !cargo_diagnostics_sources.is_empty(); - if let Some(editor) = self.diagnostics() { - let diagnostics = editor.read(cx); - include_warnings = diagnostics.include_warnings; - has_stale_excerpts = !diagnostics.paths_to_update.is_empty(); - is_updating = if fetch_cargo_diagnostics { - diagnostics.cargo_diagnostics_fetch.fetch_task.is_some() - } else { - diagnostics.update_excerpts_task.is_some() - || diagnostics - .project - .read(cx) - .language_servers_running_disk_based_diagnostics(cx) - .next() - .is_some() - }; + match &self.editor { + Some(editor) => { + include_warnings = editor.include_warnings(cx); + has_stale_excerpts = editor.has_stale_excerpts(cx); + is_updating = editor.is_updating(cx); + } + None => {} } - let tooltip = if include_warnings { + let warning_tooltip = if include_warnings { "Exclude Warnings" } else { "Include Warnings" @@ -62,12 +75,12 @@ impl Render for ToolbarControls { &ToggleDiagnosticsRefresh, 
)) .on_click(cx.listener(move |toolbar_controls, _, _, cx| { - if let Some(diagnostics) = toolbar_controls.diagnostics() { - diagnostics.update(cx, |diagnostics, cx| { - diagnostics.stop_cargo_diagnostics_fetch(cx); - diagnostics.update_excerpts_task = None; + match toolbar_controls.editor() { + Some(editor) => { + editor.stop_updating(cx); cx.notify(); - }); + } + None => {} } })), ) @@ -76,27 +89,17 @@ impl Render for ToolbarControls { IconButton::new("refresh-diagnostics", IconName::ArrowCircle) .icon_color(Color::Info) .shape(IconButtonShape::Square) - .disabled(!has_stale_excerpts && !fetch_cargo_diagnostics) + .disabled(!has_stale_excerpts) .tooltip(Tooltip::for_action_title( "Refresh diagnostics", &ToggleDiagnosticsRefresh, )) .on_click(cx.listener({ - move |toolbar_controls, _, window, cx| { - if let Some(diagnostics) = toolbar_controls.diagnostics() { - let cargo_diagnostics_sources = - Arc::clone(&cargo_diagnostics_sources); - diagnostics.update(cx, move |diagnostics, cx| { - if fetch_cargo_diagnostics { - diagnostics.fetch_cargo_diagnostics( - cargo_diagnostics_sources, - cx, - ); - } else { - diagnostics.update_all_excerpts(window, cx); - } - }); - } + move |toolbar_controls, _, window, cx| match toolbar_controls + .editor() + { + Some(editor) => editor.refresh_diagnostics(window, cx), + None => {} } })), ) @@ -106,13 +109,10 @@ impl Render for ToolbarControls { IconButton::new("toggle-warnings", IconName::Warning) .icon_color(warning_color) .shape(IconButtonShape::Square) - .tooltip(Tooltip::text(tooltip)) - .on_click(cx.listener(|this, _, window, cx| { - if let Some(editor) = this.diagnostics() { - editor.update(cx, |editor, cx| { - editor.toggle_warnings(&Default::default(), window, cx); - }); - } + .tooltip(Tooltip::text(warning_tooltip)) + .on_click(cx.listener(|this, _, window, cx| match &this.editor { + Some(editor) => editor.toggle_warnings(window, cx), + None => {} })), ) } @@ -129,7 +129,10 @@ impl ToolbarItemView for ToolbarControls { ) -> 
ToolbarItemLocation { if let Some(pane_item) = active_pane_item.as_ref() { if let Some(editor) = pane_item.downcast::() { - self.editor = Some(editor.downgrade()); + self.editor = Some(Box::new(editor.downgrade())); + ToolbarItemLocation::PrimaryRight + } else if let Some(editor) = pane_item.downcast::() { + self.editor = Some(Box::new(editor.downgrade())); ToolbarItemLocation::PrimaryRight } else { ToolbarItemLocation::Hidden @@ -151,7 +154,7 @@ impl ToolbarControls { ToolbarControls { editor: None } } - fn diagnostics(&self) -> Option> { - self.editor.as_ref()?.upgrade() + fn editor(&self) -> Option<&dyn DiagnosticsToolbarEditor> { + self.editor.as_deref() } } diff --git a/crates/docs_preprocessor/src/main.rs b/crates/docs_preprocessor/src/main.rs index 17804b428145ed49d6bb274ab5f13d5b46e5f7f4..c8c3dc54b76085707c0491eab683ff954a483bf9 100644 --- a/crates/docs_preprocessor/src/main.rs +++ b/crates/docs_preprocessor/src/main.rs @@ -19,9 +19,13 @@ static KEYMAP_LINUX: LazyLock = LazyLock::new(|| { load_keymap("keymaps/default-linux.json").expect("Failed to load Linux keymap") }); +static KEYMAP_WINDOWS: LazyLock = LazyLock::new(|| { + load_keymap("keymaps/default-windows.json").expect("Failed to load Windows keymap") +}); + static ALL_ACTIONS: LazyLock> = LazyLock::new(dump_all_gpui_actions); -const FRONT_MATTER_COMMENT: &'static str = ""; +const FRONT_MATTER_COMMENT: &str = ""; fn main() -> Result<()> { zlog::init(); @@ -61,15 +65,13 @@ impl PreprocessorError { for alias in action.deprecated_aliases { if alias == &action_name { return PreprocessorError::DeprecatedActionUsed { - used: action_name.clone(), + used: action_name, should_be: action.name.to_string(), }; } } } - PreprocessorError::ActionNotFound { - action_name: action_name.to_string(), - } + PreprocessorError::ActionNotFound { action_name } } } @@ -101,12 +103,13 @@ fn handle_preprocessing() -> Result<()> { let mut errors = HashSet::::new(); handle_frontmatter(&mut book, &mut errors); + 
template_big_table_of_actions(&mut book); template_and_validate_keybindings(&mut book, &mut errors); template_and_validate_actions(&mut book, &mut errors); if !errors.is_empty() { - const ANSI_RED: &'static str = "\x1b[31m"; - const ANSI_RESET: &'static str = "\x1b[0m"; + const ANSI_RED: &str = "\x1b[31m"; + const ANSI_RESET: &str = "\x1b[0m"; for error in &errors { eprintln!("{ANSI_RED}ERROR{ANSI_RESET}: {}", error); } @@ -129,7 +132,7 @@ fn handle_frontmatter(book: &mut Book, errors: &mut HashSet) let Some((name, value)) = line.split_once(':') else { errors.insert(PreprocessorError::InvalidFrontmatterLine(format!( "{}: {}", - chapter_breadcrumbs(&chapter), + chapter_breadcrumbs(chapter), line ))); continue; @@ -143,11 +146,20 @@ fn handle_frontmatter(book: &mut Book, errors: &mut HashSet) &serde_json::to_string(&metadata).expect("Failed to serialize metadata"), ) }); - match new_content { - Cow::Owned(content) => { - chapter.content = content; - } - Cow::Borrowed(_) => {} + if let Cow::Owned(content) = new_content { + chapter.content = content; + } + }); +} + +fn template_big_table_of_actions(book: &mut Book) { + for_each_chapter_mut(book, |chapter| { + let needle = "{#ACTIONS_TABLE#}"; + if let Some(start) = chapter.content.rfind(needle) { + chapter.content.replace_range( + start..start + needle.len(), + &generate_big_table_of_actions(), + ); } }); } @@ -208,6 +220,7 @@ fn find_binding(os: &str, action: &str) -> Option { let keymap = match os { "macos" => &KEYMAP_MACOS, "linux" | "freebsd" => &KEYMAP_LINUX, + "windows" => &KEYMAP_WINDOWS, _ => unreachable!("Not a valid OS: {}", os), }; @@ -282,6 +295,7 @@ struct ActionDef { name: &'static str, human_name: String, deprecated_aliases: &'static [&'static str], + docs: Option<&'static str>, } fn dump_all_gpui_actions() -> Vec { @@ -290,12 +304,13 @@ fn dump_all_gpui_actions() -> Vec { name: action.name, human_name: command_palette::humanize_action_name(action.name), deprecated_aliases: action.deprecated_aliases, + 
docs: action.documentation, }) .collect::>(); actions.sort_by_key(|a| a.name); - return actions; + actions } fn handle_postprocessing() -> Result<()> { @@ -402,20 +417,20 @@ fn handle_postprocessing() -> Result<()> { path: &'a std::path::PathBuf, root: &'a std::path::PathBuf, ) -> &'a std::path::Path { - &path.strip_prefix(&root).unwrap_or(&path) + path.strip_prefix(&root).unwrap_or(path) } fn extract_title_from_page(contents: &str, pretty_path: &std::path::Path) -> String { let title_tag_contents = &title_regex() - .captures(&contents) + .captures(contents) .with_context(|| format!("Failed to find title in {:?}", pretty_path)) .expect("Page has element")[1]; - let title = title_tag_contents + + title_tag_contents .trim() .strip_suffix("- Zed") .unwrap_or(title_tag_contents) .trim() - .to_string(); - title + .to_string() } } @@ -423,3 +438,54 @@ fn title_regex() -> &'static Regex { static TITLE_REGEX: OnceLock<Regex> = OnceLock::new(); TITLE_REGEX.get_or_init(|| Regex::new(r"<title>\s*(.*?)\s*").unwrap()) } + +fn generate_big_table_of_actions() -> String { + let actions = &*ALL_ACTIONS; + let mut output = String::new(); + + let mut actions_sorted = actions.iter().collect::>(); + actions_sorted.sort_by_key(|a| a.name); + + // Start the definition list with custom styling for better spacing + output.push_str("
\n"); + + for action in actions_sorted.into_iter() { + // Add the humanized action name as the term with margin + output.push_str( + "
", + ); + output.push_str(&action.human_name); + output.push_str("
\n"); + + // Add the definition with keymap name and description + output.push_str("
\n"); + + // Add the description, escaping HTML if needed + if let Some(description) = action.docs { + output.push_str( + &description + .replace("&", "&") + .replace("<", "<") + .replace(">", ">"), + ); + output.push_str("
\n"); + } + output.push_str("Keymap Name: "); + output.push_str(action.name); + output.push_str("
\n"); + if !action.deprecated_aliases.is_empty() { + output.push_str("Deprecated Aliases:"); + for alias in action.deprecated_aliases.iter() { + output.push_str(""); + output.push_str(alias); + output.push_str(", "); + } + } + output.push_str("\n
\n"); + } + + // Close the definition list + output.push_str("
\n"); + + output +} diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index c8502f75de5adac0a1bfdcb8cd8fe4444bb70f84..6b695af1ae0e4807c9aa93af34a5d07de0c15795 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -34,7 +34,7 @@ pub enum DataCollectionState { impl DataCollectionState { pub fn is_supported(&self) -> bool { - !matches!(self, DataCollectionState::Unsupported { .. }) + !matches!(self, DataCollectionState::Unsupported) } pub fn is_enabled(&self) -> bool { @@ -89,9 +89,6 @@ pub trait EditPredictionProvider: 'static + Sized { debounce: bool, cx: &mut Context, ); - fn needs_terms_acceptance(&self, _cx: &App) -> bool { - false - } fn cycle( &mut self, buffer: Entity, @@ -124,7 +121,6 @@ pub trait EditPredictionProviderHandle { fn data_collection_state(&self, cx: &App) -> DataCollectionState; fn usage(&self, cx: &App) -> Option; fn toggle_data_collection(&self, cx: &mut App); - fn needs_terms_acceptance(&self, cx: &App) -> bool; fn is_refreshing(&self, cx: &App) -> bool; fn refresh( &self, @@ -196,10 +192,6 @@ where self.read(cx).is_enabled(buffer, cursor_position, cx) } - fn needs_terms_acceptance(&self, cx: &App) -> bool { - self.read(cx).needs_terms_acceptance(cx) - } - fn is_refreshing(&self, cx: &App) -> bool { self.read(cx).is_refreshing() } diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs index 3d3b43d71bc4a0914ed97dac24a278049f4c52f1..0e3fe8cb1a449e494592d8f517feb26131d89f65 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -127,7 +127,7 @@ impl Render for EditPredictionButton { }), ); } - let this = cx.entity().clone(); + let this = cx.entity(); div().child( PopoverMenu::new("copilot") @@ -168,7 +168,7 @@ impl Render for EditPredictionButton { let 
account_status = agent.account_status.clone(); match account_status { AccountStatus::NeedsActivation { activate_url } => { - SupermavenButtonStatus::NeedsActivation(activate_url.clone()) + SupermavenButtonStatus::NeedsActivation(activate_url) } AccountStatus::Unknown => SupermavenButtonStatus::Initializing, AccountStatus::Ready => SupermavenButtonStatus::Ready, @@ -182,10 +182,10 @@ impl Render for EditPredictionButton { let icon = status.to_icon(); let tooltip_text = status.to_tooltip(); let has_menu = status.has_menu(); - let this = cx.entity().clone(); + let this = cx.entity(); let fs = self.fs.clone(); - return div().child( + div().child( PopoverMenu::new("supermaven") .menu(move |window, cx| match &status { SupermavenButtonStatus::NeedsActivation(activate_url) => { @@ -230,7 +230,7 @@ impl Render for EditPredictionButton { }, ) .with_handle(self.popover_menu_handle.clone()), - ); + ) } EditPredictionProvider::Zed => { @@ -242,13 +242,9 @@ impl Render for EditPredictionButton { IconName::ZedPredictDisabled }; - if zeta::should_show_upsell_modal(&self.user_store, cx) { + if zeta::should_show_upsell_modal() { let tooltip_meta = if self.user_store.read(cx).current_user().is_some() { - if self.user_store.read(cx).has_accepted_terms_of_service() { - "Choose a Plan" - } else { - "Accept the Terms of Service" - } + "Choose a Plan" } else { "Sign In" }; @@ -331,7 +327,7 @@ impl Render for EditPredictionButton { }) }); - let this = cx.entity().clone(); + let this = cx.entity(); let mut popover_menu = PopoverMenu::new("zeta") .menu(move |window, cx| { @@ -343,7 +339,7 @@ impl Render for EditPredictionButton { let is_refreshing = self .edit_prediction_provider .as_ref() - .map_or(false, |provider| provider.is_refreshing(cx)); + .is_some_and(|provider| provider.is_refreshing(cx)); if is_refreshing { popover_menu = popover_menu.trigger( diff --git a/crates/editor/Cargo.toml b/crates/editor/Cargo.toml index 
339f98ae8bd88263f1fea12c535569864faae294..be06cc04dfc7ee3f080e8d995783abb819e95842 100644 --- a/crates/editor/Cargo.toml +++ b/crates/editor/Cargo.toml @@ -94,6 +94,7 @@ zed_actions.workspace = true workspace-hack.workspace = true [dev-dependencies] +criterion.workspace = true ctor.workspace = true gpui = { workspace = true, features = ["test-support"] } language = { workspace = true, features = ["test-support"] } @@ -119,3 +120,12 @@ util = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] } http_client = { workspace = true, features = ["test-support"] } zlog.workspace = true + + +[[bench]] +name = "editor_render" +harness = false + +[[bench]] +name = "display_map" +harness = false diff --git a/crates/editor/benches/display_map.rs b/crates/editor/benches/display_map.rs new file mode 100644 index 0000000000000000000000000000000000000000..919249ad01b87fe5fbabe1b5fe6e563179b41d10 --- /dev/null +++ b/crates/editor/benches/display_map.rs @@ -0,0 +1,102 @@ +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use editor::MultiBuffer; +use gpui::TestDispatcher; +use itertools::Itertools; +use rand::{Rng, SeedableRng, rngs::StdRng}; +use std::num::NonZeroU32; +use text::Bias; +use util::RandomCharIter; + +fn to_tab_point_benchmark(c: &mut Criterion) { + let rng = StdRng::seed_from_u64(1); + let dispatcher = TestDispatcher::new(rng); + let cx = gpui::TestAppContext::build(dispatcher, None); + + let create_tab_map = |length: usize| { + let mut rng = StdRng::seed_from_u64(1); + let text = RandomCharIter::new(&mut rng) + .take(length) + .collect::(); + let buffer = cx.update(|cx| MultiBuffer::build_simple(&text, cx)); + + let buffer_snapshot = cx.read(|cx| buffer.read(cx).snapshot(cx)); + use editor::display_map::*; + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot); + let (_, fold_snapshot) = FoldMap::new(inlay_snapshot.clone()); + let fold_point = fold_snapshot.to_fold_point( + 
inlay_snapshot.to_point(InlayOffset(rng.random_range(0..length))), + Bias::Left, + ); + let (_, snapshot) = TabMap::new(fold_snapshot, NonZeroU32::new(4).unwrap()); + + (length, snapshot, fold_point) + }; + + let inputs = [1024].into_iter().map(create_tab_map).collect_vec(); + + let mut group = c.benchmark_group("To tab point"); + + for (batch_size, snapshot, fold_point) in inputs { + group.bench_with_input( + BenchmarkId::new("to_tab_point", batch_size), + &snapshot, + |bench, snapshot| { + bench.iter(|| { + snapshot.to_tab_point(fold_point); + }); + }, + ); + } + + group.finish(); +} + +fn to_fold_point_benchmark(c: &mut Criterion) { + let rng = StdRng::seed_from_u64(1); + let dispatcher = TestDispatcher::new(rng); + let cx = gpui::TestAppContext::build(dispatcher, None); + + let create_tab_map = |length: usize| { + let mut rng = StdRng::seed_from_u64(1); + let text = RandomCharIter::new(&mut rng) + .take(length) + .collect::(); + let buffer = cx.update(|cx| MultiBuffer::build_simple(&text, cx)); + + let buffer_snapshot = cx.read(|cx| buffer.read(cx).snapshot(cx)); + use editor::display_map::*; + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot); + let (_, fold_snapshot) = FoldMap::new(inlay_snapshot.clone()); + + let fold_point = fold_snapshot.to_fold_point( + inlay_snapshot.to_point(InlayOffset(rng.random_range(0..length))), + Bias::Left, + ); + + let (_, snapshot) = TabMap::new(fold_snapshot, NonZeroU32::new(4).unwrap()); + let tab_point = snapshot.to_tab_point(fold_point); + + (length, snapshot, tab_point) + }; + + let inputs = [1024].into_iter().map(create_tab_map).collect_vec(); + + let mut group = c.benchmark_group("To fold point"); + + for (batch_size, snapshot, tab_point) in inputs { + group.bench_with_input( + BenchmarkId::new("to_fold_point", batch_size), + &snapshot, + |bench, snapshot| { + bench.iter(|| { + snapshot.to_fold_point(tab_point, Bias::Left); + }); + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, 
to_tab_point_benchmark, to_fold_point_benchmark); +criterion_main!(benches); diff --git a/crates/editor/benches/editor_render.rs b/crates/editor/benches/editor_render.rs new file mode 100644 index 0000000000000000000000000000000000000000..0ae1af5537fb62a7658ccd306545503b818c28ae --- /dev/null +++ b/crates/editor/benches/editor_render.rs @@ -0,0 +1,172 @@ +use criterion::{Bencher, BenchmarkId}; +use editor::{ + Editor, EditorMode, MultiBuffer, + actions::{DeleteToPreviousWordStart, SelectAll, SplitSelectionIntoLines}, +}; +use gpui::{AppContext, Focusable as _, TestAppContext, TestDispatcher}; +use project::Project; +use rand::{Rng as _, SeedableRng as _, rngs::StdRng}; +use settings::SettingsStore; +use ui::IntoElement; +use util::RandomCharIter; + +fn editor_input_with_1000_cursors(bencher: &mut Bencher<'_>, cx: &TestAppContext) { + let mut cx = cx.clone(); + let text = String::from_iter(["line:\n"; 1000]); + let buffer = cx.update(|cx| MultiBuffer::build_simple(&text, cx)); + + let cx = cx.add_empty_window(); + let editor = cx.update(|window, cx| { + let editor = cx.new(|cx| { + let mut editor = Editor::new(EditorMode::full(), buffer, None, window, cx); + editor.set_style(editor::EditorStyle::default(), window, cx); + editor.select_all(&SelectAll, window, cx); + editor.split_selection_into_lines( + &SplitSelectionIntoLines { + keep_selections: true, + }, + window, + cx, + ); + editor + }); + window.focus(&editor.focus_handle(cx)); + editor + }); + + bencher.iter(|| { + cx.update(|window, cx| { + editor.update(cx, |editor, cx| { + editor.handle_input("hello world", window, cx); + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: false, + ignore_brackets: false, + }, + window, + cx, + ); + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: false, + ignore_brackets: false, + }, + window, + cx, + ); + }); + }) + }); +} + +fn open_editor_with_one_long_line(bencher: &mut Bencher<'_>, args: 
&(String, TestAppContext)) { + let (text, cx) = args; + let mut cx = cx.clone(); + + bencher.iter(|| { + let buffer = cx.update(|cx| MultiBuffer::build_simple(&text, cx)); + + let cx = cx.add_empty_window(); + let _ = cx.update(|window, cx| { + let editor = cx.new(|cx| { + let mut editor = Editor::new(EditorMode::full(), buffer, None, window, cx); + editor.set_style(editor::EditorStyle::default(), window, cx); + editor + }); + window.focus(&editor.focus_handle(cx)); + editor + }); + }); +} + +fn editor_render(bencher: &mut Bencher<'_>, cx: &TestAppContext) { + let mut cx = cx.clone(); + let buffer = cx.update(|cx| { + let mut rng = StdRng::seed_from_u64(1); + let text_len = rng.random_range(10000..90000); + if rng.random() { + let text = RandomCharIter::new(&mut rng) + .take(text_len) + .collect::(); + MultiBuffer::build_simple(&text, cx) + } else { + MultiBuffer::build_random(&mut rng, cx) + } + }); + + let cx = cx.add_empty_window(); + let editor = cx.update(|window, cx| { + let editor = cx.new(|cx| { + let mut editor = Editor::new(EditorMode::full(), buffer, None, window, cx); + editor.set_style(editor::EditorStyle::default(), window, cx); + editor + }); + window.focus(&editor.focus_handle(cx)); + editor + }); + + bencher.iter(|| { + cx.update(|window, cx| { + // editor.update(cx, |editor, cx| editor.move_down(&MoveDown, window, cx)); + let mut view = editor.clone().into_any_element(); + let _ = view.request_layout(window, cx); + let _ = view.prepaint(window, cx); + view.paint(window, cx); + }); + }) +} + +pub fn benches() { + let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(1)); + let cx = gpui::TestAppContext::build(dispatcher, None); + cx.update(|cx| { + let store = SettingsStore::test(cx); + cx.set_global(store); + assets::Assets.load_test_fonts(cx); + theme::init(theme::LoadThemes::JustBase, cx); + // release_channel::init(SemanticVersion::default(), cx); + client::init_settings(cx); + language::init(cx); + workspace::init_settings(cx); + 
Project::init_settings(cx); + editor::init(cx); + }); + + let mut criterion: criterion::Criterion<_> = + (criterion::Criterion::default()).configure_from_args(); + + // setup app context + let mut group = criterion.benchmark_group("Time to render"); + group.bench_with_input( + BenchmarkId::new("editor_render", "TestAppContext"), + &cx, + editor_render, + ); + + group.finish(); + + let text = String::from_iter(["char"; 1000]); + let mut group = criterion.benchmark_group("Build buffer with one long line"); + group.bench_with_input( + BenchmarkId::new("editor_with_one_long_line", "(String, TestAppContext )"), + &(text, cx.clone()), + open_editor_with_one_long_line, + ); + + group.finish(); + + let mut group = criterion.benchmark_group("multi cursor edits"); + group.bench_with_input( + BenchmarkId::new("editor_input_with_1000_cursors", "TestAppContext"), + &cx, + editor_input_with_1000_cursors, + ); + group.finish(); +} + +fn main() { + benches(); + criterion::Criterion::default() + .configure_from_args() + .final_summary(); +} diff --git a/crates/editor/src/actions.rs b/crates/editor/src/actions.rs index 39433b3c279e101f47ad4b2eed4d180f82a38997..9dac77970a4560d4e9684c925b5ac4878157941f 100644 --- a/crates/editor/src/actions.rs +++ b/crates/editor/src/actions.rs @@ -228,21 +228,38 @@ pub struct ShowCompletions { pub struct HandleInput(pub String); /// Deletes from the cursor to the end of the next word. +/// Stops before the end of the next word, if whitespace sequences of length >= 2 are encountered. #[derive(PartialEq, Clone, Deserialize, Default, JsonSchema, Action)] #[action(namespace = editor)] #[serde(deny_unknown_fields)] pub struct DeleteToNextWordEnd { #[serde(default)] pub ignore_newlines: bool, + // Whether to stop before the end of the next word, if language-defined bracket is encountered. + #[serde(default)] + pub ignore_brackets: bool, } /// Deletes from the cursor to the start of the previous word. 
+/// Stops before the start of the previous word, if whitespace sequences of length >= 2 are encountered. #[derive(PartialEq, Clone, Deserialize, Default, JsonSchema, Action)] #[action(namespace = editor)] #[serde(deny_unknown_fields)] pub struct DeleteToPreviousWordStart { #[serde(default)] pub ignore_newlines: bool, + // Whether to stop before the start of the previous word, if language-defined bracket is encountered. + #[serde(default)] + pub ignore_brackets: bool, +} + +/// Cuts from cursor to end of line. +#[derive(PartialEq, Clone, Deserialize, Default, JsonSchema, Action)] +#[action(namespace = editor)] +#[serde(deny_unknown_fields)] +pub struct CutToEndOfLine { + #[serde(default)] + pub stop_at_newlines: bool, } /// Folds all code blocks at the specified indentation level. @@ -273,6 +290,16 @@ pub enum UuidVersion { V7, } +/// Splits selection into individual lines. +#[derive(PartialEq, Clone, Deserialize, Default, JsonSchema, Action)] +#[action(namespace = editor)] +#[serde(deny_unknown_fields)] +pub struct SplitSelectionIntoLines { + /// Keep the text selected after splitting instead of collapsing to cursors. + #[serde(default)] + pub keep_selections: bool, +} + /// Goes to the next diagnostic in the file. #[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] #[action(namespace = editor)] @@ -394,8 +421,6 @@ actions!( CopyPermalinkToLine, /// Cuts selected text to the clipboard. Cut, - /// Cuts from cursor to end of line. - CutToEndOfLine, /// Deletes the character after the cursor. Delete, /// Deletes the current line. @@ -475,6 +500,10 @@ actions!( GoToTypeDefinition, /// Goes to type definition in a split pane. GoToTypeDefinitionSplit, + /// Goes to the next document highlight. + GoToNextDocumentHighlight, + /// Goes to the previous document highlight. + GoToPreviousDocumentHighlight, /// Scrolls down by half a page. HalfPageDown, /// Scrolls up by half a page. 
@@ -622,6 +651,10 @@ actions!( SelectEnclosingSymbol, /// Selects the next larger syntax node. SelectLargerSyntaxNode, + /// Selects the next syntax node sibling. + SelectNextSyntaxNode, + /// Selects the previous syntax node sibling. + SelectPreviousSyntaxNode, /// Extends selection left. SelectLeft, /// Selects the current line. @@ -672,8 +705,6 @@ actions!( SortLinesCaseInsensitive, /// Sorts selected lines case-sensitively. SortLinesCaseSensitive, - /// Splits selection into individual lines. - SplitSelectionIntoLines, /// Stops the language server for the current file. StopLanguageServer, /// Switches between source and header files. @@ -745,6 +776,8 @@ actions!( UniqueLinesCaseInsensitive, /// Removes duplicate lines (case-sensitive). UniqueLinesCaseSensitive, - UnwrapSyntaxNode + UnwrapSyntaxNode, + /// Wraps selections in tag specified by language. + WrapSelectionsInTag ] ); diff --git a/crates/editor/src/clangd_ext.rs b/crates/editor/src/clangd_ext.rs index 3239fdc653e0e2acdbdaa3396e30c0546ef259cf..c78d4c83c01c49e6b1ff947d3cd53bc887424a16 100644 --- a/crates/editor/src/clangd_ext.rs +++ b/crates/editor/src/clangd_ext.rs @@ -13,7 +13,7 @@ use crate::{Editor, SwitchSourceHeader, element::register_action}; use project::lsp_store::clangd_ext::CLANGD_SERVER_NAME; fn is_c_language(language: &Language) -> bool { - return language.name() == "C++".into() || language.name() == "C".into(); + language.name() == "C++".into() || language.name() == "C".into() } pub fn switch_source_header( @@ -104,6 +104,6 @@ pub fn apply_related_actions(editor: &Entity, window: &mut Window, cx: & .filter_map(|buffer| buffer.read(cx).language()) .any(|language| is_c_language(language)) { - register_action(&editor, window, switch_source_header); + register_action(editor, window, switch_source_header); } } diff --git a/crates/editor/src/code_completion_tests.rs b/crates/editor/src/code_completion_tests.rs index 
fd8db29584d8eb6944ff674dd8bf5d860ce32428..a1d9f04a9c590ef1f20779bf19c2fe0be8905709 100644 --- a/crates/editor/src/code_completion_tests.rs +++ b/crates/editor/src/code_completion_tests.rs @@ -317,7 +317,7 @@ async fn filter_and_sort_matches( let candidates: Arc<[StringMatchCandidate]> = completions .iter() .enumerate() - .map(|(id, completion)| StringMatchCandidate::new(id, &completion.label.filter_text())) + .map(|(id, completion)| StringMatchCandidate::new(id, completion.label.filter_text())) .collect(); let cancel_flag = Arc::new(AtomicBool::new(false)); let background_executor = cx.executor(); @@ -331,5 +331,5 @@ async fn filter_and_sort_matches( background_executor, ) .await; - CompletionsMenu::sort_string_matches(matches, Some(query), snippet_sort_order, &completions) + CompletionsMenu::sort_string_matches(matches, Some(query), snippet_sort_order, completions) } diff --git a/crates/editor/src/code_context_menus.rs b/crates/editor/src/code_context_menus.rs index 4ae2a14ca730dafa7cfecd9e9b3bacbe3f7bc47b..eef4e8643928631d4fb20f044d6f27bbded80a09 100644 --- a/crates/editor/src/code_context_menus.rs +++ b/crates/editor/src/code_context_menus.rs @@ -1,7 +1,9 @@ +use crate::scroll::ScrollAmount; use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{ - AnyElement, Entity, Focusable, FontWeight, ListSizingBehavior, ScrollStrategy, SharedString, - Size, StrikethroughStyle, StyledText, Task, UniformListScrollHandle, div, px, uniform_list, + AnyElement, Entity, Focusable, FontWeight, ListSizingBehavior, ScrollHandle, ScrollStrategy, + SharedString, Size, StrikethroughStyle, StyledText, Task, UniformListScrollHandle, div, px, + uniform_list, }; use itertools::Itertools; use language::CodeLabel; @@ -9,9 +11,9 @@ use language::{Buffer, LanguageName, LanguageRegistry}; use markdown::{Markdown, MarkdownElement}; use multi_buffer::{Anchor, ExcerptId}; use ordered_float::OrderedFloat; -use project::CompletionSource; use project::lsp_store::CompletionDocumentation; use 
project::{CodeAction, Completion, TaskSourceKind}; +use project::{CompletionDisplayOptions, CompletionSource}; use task::DebugScenario; use task::TaskContext; @@ -184,6 +186,20 @@ impl CodeContextMenu { CodeContextMenu::CodeActions(_) => false, } } + + pub fn scroll_aside( + &mut self, + scroll_amount: ScrollAmount, + window: &mut Window, + cx: &mut Context, + ) { + match self { + CodeContextMenu::Completions(completions_menu) => { + completions_menu.scroll_aside(scroll_amount, window, cx) + } + CodeContextMenu::CodeActions(_) => (), + } + } } pub enum ContextMenuOrigin { @@ -207,12 +223,16 @@ pub struct CompletionsMenu { filter_task: Task<()>, cancel_filter: Arc, scroll_handle: UniformListScrollHandle, + // The `ScrollHandle` used on the Markdown documentation rendered on the + // side of the completions menu. + pub scroll_handle_aside: ScrollHandle, resolve_completions: bool, show_completion_documentation: bool, last_rendered_range: Rc>>>, markdown_cache: Rc)>>>, language_registry: Option>, language: Option, + display_options: CompletionDisplayOptions, snippet_sort_order: SnippetSortOrder, } @@ -231,7 +251,7 @@ enum MarkdownCacheKey { pub enum CompletionsMenuSource { Normal, SnippetChoices, - Words, + Words { ignore_threshold: bool }, } // TODO: There should really be a wrapper around fuzzy match tasks that does this. 
@@ -252,6 +272,7 @@ impl CompletionsMenu { is_incomplete: bool, buffer: Entity, completions: Box<[Completion]>, + display_options: CompletionDisplayOptions, snippet_sort_order: SnippetSortOrder, language_registry: Option>, language: Option, @@ -279,11 +300,13 @@ impl CompletionsMenu { filter_task: Task::ready(()), cancel_filter: Arc::new(AtomicBool::new(false)), scroll_handle: UniformListScrollHandle::new(), + scroll_handle_aside: ScrollHandle::new(), resolve_completions: true, last_rendered_range: RefCell::new(None).into(), markdown_cache: RefCell::new(VecDeque::new()).into(), language_registry, language, + display_options, snippet_sort_order, }; @@ -321,7 +344,7 @@ impl CompletionsMenu { let match_candidates = choices .iter() .enumerate() - .map(|(id, completion)| StringMatchCandidate::new(id, &completion)) + .map(|(id, completion)| StringMatchCandidate::new(id, completion)) .collect(); let entries = choices .iter() @@ -348,12 +371,14 @@ impl CompletionsMenu { filter_task: Task::ready(()), cancel_filter: Arc::new(AtomicBool::new(false)), scroll_handle: UniformListScrollHandle::new(), + scroll_handle_aside: ScrollHandle::new(), resolve_completions: false, show_completion_documentation: false, last_rendered_range: RefCell::new(None).into(), markdown_cache: RefCell::new(VecDeque::new()).into(), language_registry: None, language: None, + display_options: CompletionDisplayOptions::default(), snippet_sort_order, } } @@ -514,7 +539,7 @@ impl CompletionsMenu { // Expand the range to resolve more completions than are predicted to be visible, to reduce // jank on navigation. 
let entry_indices = util::expanded_and_wrapped_usize_range( - entry_range.clone(), + entry_range, RESOLVE_BEFORE_ITEMS, RESOLVE_AFTER_ITEMS, entries.len(), @@ -716,6 +741,33 @@ impl CompletionsMenu { cx: &mut Context, ) -> AnyElement { let show_completion_documentation = self.show_completion_documentation; + let widest_completion_ix = if self.display_options.dynamic_width { + let completions = self.completions.borrow(); + let widest_completion_ix = self + .entries + .borrow() + .iter() + .enumerate() + .max_by_key(|(_, mat)| { + let completion = &completions[mat.candidate_id]; + let documentation = &completion.documentation; + + let mut len = completion.label.text.chars().count(); + if let Some(CompletionDocumentation::SingleLine(text)) = documentation { + if show_completion_documentation { + len += text.chars().count(); + } + } + + len + }) + .map(|(ix, _)| ix); + drop(completions); + widest_completion_ix + } else { + None + }; + let selected_item = self.selected_item; let completions = self.completions.clone(); let entries = self.entries.clone(); @@ -842,7 +894,13 @@ impl CompletionsMenu { .max_h(max_height_in_lines as f32 * window.line_height()) .track_scroll(self.scroll_handle.clone()) .with_sizing_behavior(ListSizingBehavior::Infer) - .w(rems(34.)); + .map(|this| { + if self.display_options.dynamic_width { + this.with_width_from_item(widest_completion_ix) + } else { + this.w(rems(34.)) + } + }); Popover::new().child(list).into_any_element() } @@ -911,6 +969,7 @@ impl CompletionsMenu { .max_w(max_size.width) .max_h(max_size.height) .overflow_y_scroll() + .track_scroll(&self.scroll_handle_aside) .occlude(), ) .into_any_element(), @@ -1111,10 +1170,8 @@ impl CompletionsMenu { let query_start_doesnt_match_split_words = query_start_lower .map(|query_char| { !split_words(&string_match.string).any(|word| { - word.chars() - .next() - .and_then(|c| c.to_lowercase().next()) - .map_or(false, |word_char| word_char == query_char) + word.chars().next().and_then(|c| 
c.to_lowercase().next()) + == Some(query_char) }) }) .unwrap_or(false); @@ -1177,6 +1234,23 @@ impl CompletionsMenu { } }); } + + pub fn scroll_aside( + &mut self, + amount: ScrollAmount, + window: &mut Window, + cx: &mut Context, + ) { + let mut offset = self.scroll_handle_aside.offset(); + + offset.y -= amount.pixels( + window.line_height(), + self.scroll_handle_aside.bounds().size.height - px(16.), + ) / 2.0; + + cx.notify(); + self.scroll_handle_aside.set_offset(offset); + } } #[derive(Clone)] @@ -1428,6 +1502,7 @@ impl CodeActionsMenu { this.child( h_flex() .overflow_hidden() + .text_sm() .child( // TASK: It would be good to make lsp_action.title a SharedString to avoid allocating here. action.lsp_action.title().replace("\n", ""), diff --git a/crates/editor/src/display_map.rs b/crates/editor/src/display_map.rs index a16e516a70c9638965585cc5d6a23d8a9f67b639..1acbdab7a6646fe46b9ad9d9cb09c1549d64bb1a 100644 --- a/crates/editor/src/display_map.rs +++ b/crates/editor/src/display_map.rs @@ -37,13 +37,13 @@ pub use block_map::{ use block_map::{BlockRow, BlockSnapshot}; use collections::{HashMap, HashSet}; pub use crease_map::*; +use fold_map::FoldSnapshot; pub use fold_map::{ ChunkRenderer, ChunkRendererContext, ChunkRendererId, Fold, FoldId, FoldPlaceholder, FoldPoint, }; -use fold_map::{FoldMap, FoldSnapshot}; use gpui::{App, Context, Entity, Font, HighlightStyle, LineLayout, Pixels, UnderlineStyle}; pub use inlay_map::Inlay; -use inlay_map::{InlayMap, InlaySnapshot}; +use inlay_map::InlaySnapshot; pub use inlay_map::{InlayOffset, InlayPoint}; pub use invisibles::{is_invisible, replacement}; use language::{ @@ -66,12 +66,14 @@ use std::{ sync::Arc, }; use sum_tree::{Bias, TreeMap}; -use tab_map::{TabMap, TabSnapshot}; +use tab_map::TabSnapshot; use text::{BufferId, LineIndent}; use ui::{SharedString, px}; use unicode_segmentation::UnicodeSegmentation; use wrap_map::{WrapMap, WrapSnapshot}; +pub use crate::display_map::{fold_map::FoldMap, inlay_map::InlayMap, 
tab_map::TabMap}; + #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum FoldStatus { Folded, @@ -703,9 +705,8 @@ impl<'a> HighlightedChunk<'a> { }), ..Default::default() }; - let invisible_style = if let Some(mut style) = style { - style.highlight(invisible_highlight); - style + let invisible_style = if let Some(style) = style { + style.highlight(invisible_highlight) } else { invisible_highlight }; @@ -726,9 +727,8 @@ impl<'a> HighlightedChunk<'a> { }), ..Default::default() }; - let invisible_style = if let Some(mut style) = style { - style.highlight(invisible_highlight); - style + let invisible_style = if let Some(style) = style { + style.highlight(invisible_highlight) } else { invisible_highlight }; @@ -962,62 +962,59 @@ impl DisplaySnapshot { }, ) .flat_map(|chunk| { - let mut highlight_style = chunk + let highlight_style = chunk .syntax_highlight_id .and_then(|id| id.style(&editor_style.syntax)); - if let Some(chunk_highlight) = chunk.highlight_style { - // For color inlays, blend the color with the editor background - let mut processed_highlight = chunk_highlight; - if chunk.is_inlay { - if let Some(inlay_color) = chunk_highlight.color { - // Only blend if the color has transparency (alpha < 1.0) - if inlay_color.a < 1.0 { - let blended_color = editor_style.background.blend(inlay_color); - processed_highlight.color = Some(blended_color); + let chunk_highlight = chunk.highlight_style.map(|chunk_highlight| { + HighlightStyle { + // For color inlays, blend the color with the editor background + // if the color has transparency (alpha < 1.0) + color: chunk_highlight.color.map(|color| { + if chunk.is_inlay && !color.is_opaque() { + editor_style.background.blend(color) + } else { + color } - } + }), + ..chunk_highlight } + }); - if let Some(highlight_style) = highlight_style.as_mut() { - highlight_style.highlight(processed_highlight); - } else { - highlight_style = Some(processed_highlight); - } - } - - let mut diagnostic_highlight = HighlightStyle::default(); - - 
if let Some(severity) = chunk.diagnostic_severity.filter(|severity| { - self.diagnostics_max_severity - .into_lsp() - .map_or(false, |max_severity| severity <= &max_severity) - }) { - if chunk.is_unnecessary { - diagnostic_highlight.fade_out = Some(editor_style.unnecessary_code_fade); - } - if chunk.underline - && editor_style.show_underlines - && !(chunk.is_unnecessary && severity > lsp::DiagnosticSeverity::WARNING) - { - let diagnostic_color = super::diagnostic_style(severity, &editor_style.status); - diagnostic_highlight.underline = Some(UnderlineStyle { - color: Some(diagnostic_color), - thickness: 1.0.into(), - wavy: true, - }); - } - } + let diagnostic_highlight = chunk + .diagnostic_severity + .filter(|severity| { + self.diagnostics_max_severity + .into_lsp() + .is_some_and(|max_severity| severity <= &max_severity) + }) + .map(|severity| HighlightStyle { + fade_out: chunk + .is_unnecessary + .then_some(editor_style.unnecessary_code_fade), + underline: (chunk.underline + && editor_style.show_underlines + && !(chunk.is_unnecessary && severity > lsp::DiagnosticSeverity::WARNING)) + .then(|| { + let diagnostic_color = + super::diagnostic_style(severity, &editor_style.status); + UnderlineStyle { + color: Some(diagnostic_color), + thickness: 1.0.into(), + wavy: true, + } + }), + ..Default::default() + }); - if let Some(highlight_style) = highlight_style.as_mut() { - highlight_style.highlight(diagnostic_highlight); - } else { - highlight_style = Some(diagnostic_highlight); - } + let style = [highlight_style, chunk_highlight, diagnostic_highlight] + .into_iter() + .flatten() + .reduce(|acc, highlight| acc.highlight(highlight)); HighlightedChunk { text: chunk.text, - style: highlight_style, + style, is_tab: chunk.is_tab, is_inlay: chunk.is_inlay, replacement: chunk.renderer.map(ChunkReplacement::Renderer), @@ -1552,15 +1549,15 @@ pub mod tests { .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) .unwrap_or(10); - let mut tab_size = rng.gen_range(1..=4); - 
let buffer_start_excerpt_header_height = rng.gen_range(1..=5); - let excerpt_header_height = rng.gen_range(1..=5); + let mut tab_size = rng.random_range(1..=4); + let buffer_start_excerpt_header_height = rng.random_range(1..=5); + let excerpt_header_height = rng.random_range(1..=5); let font_size = px(14.0); let max_wrap_width = 300.0; - let mut wrap_width = if rng.gen_bool(0.1) { + let mut wrap_width = if rng.random_bool(0.1) { None } else { - Some(px(rng.gen_range(0.0..=max_wrap_width))) + Some(px(rng.random_range(0.0..=max_wrap_width))) }; log::info!("tab size: {}", tab_size); @@ -1571,8 +1568,8 @@ pub mod tests { }); let buffer = cx.update(|cx| { - if rng.r#gen() { - let len = rng.gen_range(0..10); + if rng.random() { + let len = rng.random_range(0..10); let text = util::RandomCharIter::new(&mut rng) .take(len) .collect::(); @@ -1609,12 +1606,12 @@ pub mod tests { log::info!("display text: {:?}", snapshot.text()); for _i in 0..operations { - match rng.gen_range(0..100) { + match rng.random_range(0..100) { 0..=19 => { - wrap_width = if rng.gen_bool(0.2) { + wrap_width = if rng.random_bool(0.2) { None } else { - Some(px(rng.gen_range(0.0..=max_wrap_width))) + Some(px(rng.random_range(0.0..=max_wrap_width))) }; log::info!("setting wrap width to {:?}", wrap_width); map.update(cx, |map, cx| map.set_wrap_width(wrap_width, cx)); @@ -1634,28 +1631,27 @@ pub mod tests { } 30..=44 => { map.update(cx, |map, cx| { - if rng.r#gen() || blocks.is_empty() { + if rng.random() || blocks.is_empty() { let buffer = map.snapshot(cx).buffer_snapshot; - let block_properties = (0..rng.gen_range(1..=1)) + let block_properties = (0..rng.random_range(1..=1)) .map(|_| { - let position = - buffer.anchor_after(buffer.clip_offset( - rng.gen_range(0..=buffer.len()), - Bias::Left, - )); + let position = buffer.anchor_after(buffer.clip_offset( + rng.random_range(0..=buffer.len()), + Bias::Left, + )); - let placement = if rng.r#gen() { + let placement = if rng.random() { 
BlockPlacement::Above(position) } else { BlockPlacement::Below(position) }; - let height = rng.gen_range(1..5); + let height = rng.random_range(1..5); log::info!( "inserting block {:?} with height {}", placement.as_ref().map(|p| p.to_point(&buffer)), height ); - let priority = rng.gen_range(1..100); + let priority = rng.random_range(1..100); BlockProperties { placement, style: BlockStyle::Fixed, @@ -1668,9 +1664,9 @@ pub mod tests { blocks.extend(map.insert_blocks(block_properties, cx)); } else { blocks.shuffle(&mut rng); - let remove_count = rng.gen_range(1..=4.min(blocks.len())); + let remove_count = rng.random_range(1..=4.min(blocks.len())); let block_ids_to_remove = (0..remove_count) - .map(|_| blocks.remove(rng.gen_range(0..blocks.len()))) + .map(|_| blocks.remove(rng.random_range(0..blocks.len()))) .collect(); log::info!("removing block ids {:?}", block_ids_to_remove); map.remove_blocks(block_ids_to_remove, cx); @@ -1679,16 +1675,16 @@ pub mod tests { } 45..=79 => { let mut ranges = Vec::new(); - for _ in 0..rng.gen_range(1..=3) { + for _ in 0..rng.random_range(1..=3) { buffer.read_with(cx, |buffer, cx| { let buffer = buffer.read(cx); - let end = buffer.clip_offset(rng.gen_range(0..=buffer.len()), Right); - let start = buffer.clip_offset(rng.gen_range(0..=end), Left); + let end = buffer.clip_offset(rng.random_range(0..=buffer.len()), Right); + let start = buffer.clip_offset(rng.random_range(0..=end), Left); ranges.push(start..end); }); } - if rng.r#gen() && fold_count > 0 { + if rng.random() && fold_count > 0 { log::info!("unfolding ranges: {:?}", ranges); map.update(cx, |map, cx| { map.unfold_intersecting(ranges, true, cx); @@ -1727,8 +1723,8 @@ pub mod tests { // Line boundaries let buffer = &snapshot.buffer_snapshot; for _ in 0..5 { - let row = rng.gen_range(0..=buffer.max_point().row); - let column = rng.gen_range(0..=buffer.line_len(MultiBufferRow(row))); + let row = rng.random_range(0..=buffer.max_point().row); + let column = 
rng.random_range(0..=buffer.line_len(MultiBufferRow(row))); let point = buffer.clip_point(Point::new(row, column), Left); let (prev_buffer_bound, prev_display_bound) = snapshot.prev_line_boundary(point); @@ -1776,8 +1772,8 @@ pub mod tests { let min_point = snapshot.clip_point(DisplayPoint::new(DisplayRow(0), 0), Left); let max_point = snapshot.clip_point(snapshot.max_point(), Right); for _ in 0..5 { - let row = rng.gen_range(0..=snapshot.max_point().row().0); - let column = rng.gen_range(0..=snapshot.line_len(DisplayRow(row))); + let row = rng.random_range(0..=snapshot.max_point().row().0); + let column = rng.random_range(0..=snapshot.line_len(DisplayRow(row))); let point = snapshot.clip_point(DisplayPoint::new(DisplayRow(row), column), Left); log::info!("Moving from point {:?}", point); @@ -2351,11 +2347,12 @@ pub mod tests { .highlight_style .and_then(|style| style.color) .map_or(black, |color| color.to_rgb()); - if let Some((last_chunk, last_severity, last_color)) = chunks.last_mut() { - if *last_severity == chunk.diagnostic_severity && *last_color == color { - last_chunk.push_str(chunk.text); - continue; - } + if let Some((last_chunk, last_severity, last_color)) = chunks.last_mut() + && *last_severity == chunk.diagnostic_severity + && *last_color == color + { + last_chunk.push_str(chunk.text); + continue; } chunks.push((chunk.text.to_string(), chunk.diagnostic_severity, color)); @@ -2609,7 +2606,7 @@ pub mod tests { ); language.set_theme(&theme); - let (text, highlighted_ranges) = marked_text_ranges(r#"constˇ «a»: B = "c «d»""#, false); + let (text, highlighted_ranges) = marked_text_ranges(r#"constˇ «a»«:» B = "c «d»""#, false); let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language, cx)); cx.condition(&buffer, |buf, _| !buf.is_parsing()).await; @@ -2658,7 +2655,7 @@ pub mod tests { [ ("const ".to_string(), None, None), ("a".to_string(), None, Some(Hsla::blue())), - (":".to_string(), Some(Hsla::red()), None), + (":".to_string(), 
Some(Hsla::red()), Some(Hsla::blue())), (" B = ".to_string(), None, None), ("\"c ".to_string(), Some(Hsla::green()), None), ("d".to_string(), Some(Hsla::green()), Some(Hsla::blue())), @@ -2901,11 +2898,12 @@ pub mod tests { .syntax_highlight_id .and_then(|id| id.style(theme)?.color); let highlight_color = chunk.highlight_style.and_then(|style| style.color); - if let Some((last_chunk, last_syntax_color, last_highlight_color)) = chunks.last_mut() { - if syntax_color == *last_syntax_color && highlight_color == *last_highlight_color { - last_chunk.push_str(chunk.text); - continue; - } + if let Some((last_chunk, last_syntax_color, last_highlight_color)) = chunks.last_mut() + && syntax_color == *last_syntax_color + && highlight_color == *last_highlight_color + { + last_chunk.push_str(chunk.text); + continue; } chunks.push((chunk.text.to_string(), syntax_color, highlight_color)); } diff --git a/crates/editor/src/display_map/block_map.rs b/crates/editor/src/display_map/block_map.rs index c4c9f2004adedcda0a2215aa3f073ef10f5aa78e..03d04e7010248293604d10c2f3e553430e74c9c6 100644 --- a/crates/editor/src/display_map/block_map.rs +++ b/crates/editor/src/display_map/block_map.rs @@ -128,10 +128,10 @@ impl BlockPlacement { } } - fn sort_order(&self) -> u8 { + fn tie_break(&self) -> u8 { match self { - BlockPlacement::Above(_) => 0, - BlockPlacement::Replace(_) => 1, + BlockPlacement::Replace(_) => 0, + BlockPlacement::Above(_) => 1, BlockPlacement::Near(_) => 2, BlockPlacement::Below(_) => 3, } @@ -143,7 +143,7 @@ impl BlockPlacement { self.start() .cmp(other.start(), buffer) .then_with(|| other.end().cmp(self.end(), buffer)) - .then_with(|| self.sort_order().cmp(&other.sort_order())) + .then_with(|| self.tie_break().cmp(&other.tie_break())) } fn to_wrap_row(&self, wrap_snapshot: &WrapSnapshot) -> Option> { @@ -290,7 +290,10 @@ pub enum Block { ExcerptBoundary { excerpt: ExcerptInfo, height: u32, - starts_new_buffer: bool, + }, + BufferHeader { + excerpt: ExcerptInfo, + height: 
u32, }, } @@ -303,27 +306,37 @@ impl Block { .. } => BlockId::ExcerptBoundary(next_excerpt.id), Block::FoldedBuffer { first_excerpt, .. } => BlockId::FoldedBuffer(first_excerpt.id), + Block::BufferHeader { + excerpt: next_excerpt, + .. + } => BlockId::ExcerptBoundary(next_excerpt.id), } } pub fn has_height(&self) -> bool { match self { Block::Custom(block) => block.height.is_some(), - Block::ExcerptBoundary { .. } | Block::FoldedBuffer { .. } => true, + Block::ExcerptBoundary { .. } + | Block::FoldedBuffer { .. } + | Block::BufferHeader { .. } => true, } } pub fn height(&self) -> u32 { match self { Block::Custom(block) => block.height.unwrap_or(0), - Block::ExcerptBoundary { height, .. } | Block::FoldedBuffer { height, .. } => *height, + Block::ExcerptBoundary { height, .. } + | Block::FoldedBuffer { height, .. } + | Block::BufferHeader { height, .. } => *height, } } pub fn style(&self) -> BlockStyle { match self { Block::Custom(block) => block.style, - Block::ExcerptBoundary { .. } | Block::FoldedBuffer { .. } => BlockStyle::Sticky, + Block::ExcerptBoundary { .. } + | Block::FoldedBuffer { .. } + | Block::BufferHeader { .. } => BlockStyle::Sticky, } } @@ -332,6 +345,7 @@ impl Block { Block::Custom(block) => matches!(block.placement, BlockPlacement::Above(_)), Block::FoldedBuffer { .. } => false, Block::ExcerptBoundary { .. } => true, + Block::BufferHeader { .. } => true, } } @@ -340,6 +354,7 @@ impl Block { Block::Custom(block) => matches!(block.placement, BlockPlacement::Near(_)), Block::FoldedBuffer { .. } => false, Block::ExcerptBoundary { .. } => false, + Block::BufferHeader { .. } => false, } } @@ -351,6 +366,7 @@ impl Block { ), Block::FoldedBuffer { .. } => false, Block::ExcerptBoundary { .. } => false, + Block::BufferHeader { .. } => false, } } @@ -359,6 +375,7 @@ impl Block { Block::Custom(block) => matches!(block.placement, BlockPlacement::Replace(_)), Block::FoldedBuffer { .. } => true, Block::ExcerptBoundary { .. } => false, + Block::BufferHeader { .. 
} => false, } } @@ -367,6 +384,7 @@ impl Block { Block::Custom(_) => false, Block::FoldedBuffer { .. } => true, Block::ExcerptBoundary { .. } => true, + Block::BufferHeader { .. } => true, } } @@ -374,9 +392,8 @@ impl Block { match self { Block::Custom(_) => false, Block::FoldedBuffer { .. } => true, - Block::ExcerptBoundary { - starts_new_buffer, .. - } => *starts_new_buffer, + Block::ExcerptBoundary { .. } => false, + Block::BufferHeader { .. } => true, } } } @@ -393,14 +410,14 @@ impl Debug for Block { .field("first_excerpt", &first_excerpt) .field("height", height) .finish(), - Self::ExcerptBoundary { - starts_new_buffer, - excerpt, - height, - } => f + Self::ExcerptBoundary { excerpt, height } => f .debug_struct("ExcerptBoundary") .field("excerpt", excerpt) - .field("starts_new_buffer", starts_new_buffer) + .field("height", height) + .finish(), + Self::BufferHeader { excerpt, height } => f + .debug_struct("BufferHeader") + .field("excerpt", excerpt) .field("height", height) .finish(), } @@ -525,26 +542,22 @@ impl BlockMap { // * Below blocks that end at the start of the edit // However, if we hit a replace block that ends at the start of the edit we want to reconstruct it. 
new_transforms.append(cursor.slice(&old_start, Bias::Left), &()); - if let Some(transform) = cursor.item() { - if transform.summary.input_rows > 0 - && cursor.end() == old_start - && transform - .block - .as_ref() - .map_or(true, |b| !b.is_replacement()) - { - // Preserve the transform (push and next) - new_transforms.push(transform.clone(), &()); - cursor.next(); + if let Some(transform) = cursor.item() + && transform.summary.input_rows > 0 + && cursor.end() == old_start + && transform.block.as_ref().is_none_or(|b| !b.is_replacement()) + { + // Preserve the transform (push and next) + new_transforms.push(transform.clone(), &()); + cursor.next(); - // Preserve below blocks at end of edit - while let Some(transform) = cursor.item() { - if transform.block.as_ref().map_or(false, |b| b.place_below()) { - new_transforms.push(transform.clone(), &()); - cursor.next(); - } else { - break; - } + // Preserve below blocks at end of edit + while let Some(transform) = cursor.item() { + if transform.block.as_ref().is_some_and(|b| b.place_below()) { + new_transforms.push(transform.clone(), &()); + cursor.next(); + } else { + break; } } } @@ -607,7 +620,7 @@ impl BlockMap { // Discard below blocks at the end of the edit. They'll be reconstructed. 
while let Some(transform) = cursor.item() { - if transform.block.as_ref().map_or(false, |b| b.place_below()) { + if transform.block.as_ref().is_some_and(|b| b.place_below()) { cursor.next(); } else { break; @@ -657,22 +670,20 @@ impl BlockMap { .iter() .filter_map(|block| { let placement = block.placement.to_wrap_row(wrap_snapshot)?; - if let BlockPlacement::Above(row) = placement { - if row < new_start { - return None; - } + if let BlockPlacement::Above(row) = placement + && row < new_start + { + return None; } Some((placement, Block::Custom(block.clone()))) }), ); - if buffer.show_headers() { - blocks_in_edit.extend(self.header_and_footer_blocks( - buffer, - (start_bound, end_bound), - wrap_snapshot, - )); - } + blocks_in_edit.extend(self.header_and_footer_blocks( + buffer, + (start_bound, end_bound), + wrap_snapshot, + )); BlockMap::sort_blocks(&mut blocks_in_edit); @@ -775,7 +786,7 @@ impl BlockMap { if self.buffers_with_disabled_headers.contains(&new_buffer_id) { continue; } - if self.folded_buffers.contains(&new_buffer_id) { + if self.folded_buffers.contains(&new_buffer_id) && buffer.show_headers() { let mut last_excerpt_end_row = first_excerpt.end_row; while let Some(next_boundary) = boundaries.peek() { @@ -808,20 +819,24 @@ impl BlockMap { } } - if new_buffer_id.is_some() { + let starts_new_buffer = new_buffer_id.is_some(); + let block = if starts_new_buffer && buffer.show_headers() { height += self.buffer_header_height; - } else { + Block::BufferHeader { + excerpt: excerpt_boundary.next, + height, + } + } else if excerpt_boundary.prev.is_some() { height += self.excerpt_header_height; - } - - return Some(( - BlockPlacement::Above(WrapRow(wrap_row)), Block::ExcerptBoundary { excerpt: excerpt_boundary.next, height, - starts_new_buffer: new_buffer_id.is_some(), - }, - )); + } + } else { + continue; + }; + + return Some((BlockPlacement::Above(WrapRow(wrap_row)), block)); } }) } @@ -832,6 +847,7 @@ impl BlockMap { .start() .cmp(placement_b.start()) .then_with(|| 
placement_b.end().cmp(placement_a.end())) + .then_with(|| placement_a.tie_break().cmp(&placement_b.tie_break())) .then_with(|| { if block_a.is_header() { Ordering::Less @@ -841,18 +857,29 @@ impl BlockMap { Ordering::Equal } }) - .then_with(|| placement_a.sort_order().cmp(&placement_b.sort_order())) .then_with(|| match (block_a, block_b) { ( Block::ExcerptBoundary { excerpt: excerpt_a, .. + } + | Block::BufferHeader { + excerpt: excerpt_a, .. }, Block::ExcerptBoundary { excerpt: excerpt_b, .. + } + | Block::BufferHeader { + excerpt: excerpt_b, .. }, ) => Some(excerpt_a.id).cmp(&Some(excerpt_b.id)), - (Block::ExcerptBoundary { .. }, Block::Custom(_)) => Ordering::Less, - (Block::Custom(_), Block::ExcerptBoundary { .. }) => Ordering::Greater, + ( + Block::ExcerptBoundary { .. } | Block::BufferHeader { .. }, + Block::Custom(_), + ) => Ordering::Less, + ( + Block::Custom(_), + Block::ExcerptBoundary { .. } | Block::BufferHeader { .. }, + ) => Ordering::Greater, (Block::Custom(block_a), Block::Custom(block_b)) => block_a .priority .cmp(&block_b.priority) @@ -977,10 +1004,10 @@ impl BlockMapReader<'_> { break; } - if let Some(BlockId::Custom(id)) = transform.block.as_ref().map(|block| block.id()) { - if id == block_id { - return Some(cursor.start().1); - } + if let Some(BlockId::Custom(id)) = transform.block.as_ref().map(|block| block.id()) + && id == block_id + { + return Some(cursor.start().1); } cursor.next(); } @@ -1299,14 +1326,14 @@ impl BlockSnapshot { let mut input_start = transform_input_start; let mut input_end = transform_input_start; - if let Some(transform) = cursor.item() { - if transform.block.is_none() { - input_start += rows.start - transform_output_start; - input_end += cmp::min( - rows.end - transform_output_start, - transform.summary.input_rows, - ); - } + if let Some(transform) = cursor.item() + && transform.block.is_none() + { + input_start += rows.start - transform_output_start; + input_end += cmp::min( + rows.end - transform_output_start, + 
transform.summary.input_rows, + ); } BlockChunks { @@ -1329,7 +1356,7 @@ impl BlockSnapshot { let Dimensions(output_start, input_start, _) = cursor.start(); let overshoot = if cursor .item() - .map_or(false, |transform| transform.block.is_none()) + .is_some_and(|transform| transform.block.is_none()) { start_row.0 - output_start.0 } else { @@ -1359,7 +1386,7 @@ impl BlockSnapshot { && transform .block .as_ref() - .map_or(false, |block| block.height() > 0)) + .is_some_and(|block| block.height() > 0)) { break; } @@ -1381,7 +1408,9 @@ impl BlockSnapshot { while let Some(transform) = cursor.item() { match &transform.block { - Some(Block::ExcerptBoundary { excerpt, .. }) => { + Some( + Block::ExcerptBoundary { excerpt, .. } | Block::BufferHeader { excerpt, .. }, + ) => { return Some(StickyHeaderExcerpt { excerpt }); } Some(block) if block.is_buffer_header() => return None, @@ -1472,18 +1501,18 @@ impl BlockSnapshot { longest_row_chars = summary.longest_row_chars; } - if let Some(transform) = cursor.item() { - if transform.block.is_none() { - let Dimensions(output_start, input_start, _) = cursor.start(); - let overshoot = range.end.0 - output_start.0; - let wrap_start_row = input_start.0; - let wrap_end_row = input_start.0 + overshoot; - let summary = self - .wrap_snapshot - .text_summary_for_range(wrap_start_row..wrap_end_row); - if summary.longest_row_chars > longest_row_chars { - longest_row = BlockRow(output_start.0 + summary.longest_row); - } + if let Some(transform) = cursor.item() + && transform.block.is_none() + { + let Dimensions(output_start, input_start, _) = cursor.start(); + let overshoot = range.end.0 - output_start.0; + let wrap_start_row = input_start.0; + let wrap_end_row = input_start.0 + overshoot; + let summary = self + .wrap_snapshot + .text_summary_for_range(wrap_start_row..wrap_end_row); + if summary.longest_row_chars > longest_row_chars { + longest_row = BlockRow(output_start.0 + summary.longest_row); } } } @@ -1512,7 +1541,7 @@ impl BlockSnapshot 
{ pub(super) fn is_block_line(&self, row: BlockRow) -> bool { let mut cursor = self.transforms.cursor::>(&()); cursor.seek(&row, Bias::Right); - cursor.item().map_or(false, |t| t.block.is_some()) + cursor.item().is_some_and(|t| t.block.is_some()) } pub(super) fn is_folded_buffer_header(&self, row: BlockRow) -> bool { @@ -1530,11 +1559,11 @@ impl BlockSnapshot { .make_wrap_point(Point::new(row.0, 0), Bias::Left); let mut cursor = self.transforms.cursor::>(&()); cursor.seek(&WrapRow(wrap_point.row()), Bias::Right); - cursor.item().map_or(false, |transform| { + cursor.item().is_some_and(|transform| { transform .block .as_ref() - .map_or(false, |block| block.is_replacement()) + .is_some_and(|block| block.is_replacement()) }) } @@ -1557,12 +1586,11 @@ impl BlockSnapshot { match transform.block.as_ref() { Some(block) => { - if block.is_replacement() { - if ((bias == Bias::Left || search_left) && output_start <= point.0) - || (!search_left && output_start >= point.0) - { - return BlockPoint(output_start); - } + if block.is_replacement() + && (((bias == Bias::Left || search_left) && output_start <= point.0) + || (!search_left && output_start >= point.0)) + { + return BlockPoint(output_start); } } None => { @@ -1655,7 +1683,7 @@ impl BlockChunks<'_> { if transform .block .as_ref() - .map_or(false, |block| block.height() == 0) + .is_some_and(|block| block.height() == 0) { self.transforms.next(); } else { @@ -1666,7 +1694,7 @@ impl BlockChunks<'_> { if self .transforms .item() - .map_or(false, |transform| transform.block.is_none()) + .is_some_and(|transform| transform.block.is_none()) { let start_input_row = self.transforms.start().1.0; let start_output_row = self.transforms.start().0.0; @@ -1709,6 +1737,7 @@ impl<'a> Iterator for BlockChunks<'a> { return Some(Chunk { text: unsafe { std::str::from_utf8_unchecked(&NEWLINES[..line_count as usize]) }, + chars: (1 << line_count) - 1, ..Default::default() }); } @@ -1738,17 +1767,26 @@ impl<'a> Iterator for BlockChunks<'a> { let 
(mut prefix, suffix) = self.input_chunk.text.split_at(prefix_bytes); self.input_chunk.text = suffix; + self.input_chunk.tabs >>= prefix_bytes.saturating_sub(1); + self.input_chunk.chars >>= prefix_bytes.saturating_sub(1); + + let mut tabs = self.input_chunk.tabs; + let mut chars = self.input_chunk.chars; if self.masked { // Not great for multibyte text because to keep cursor math correct we // need to have the same number of bytes in the input as output. - let chars = prefix.chars().count(); - let bullet_len = chars; + let chars_count = prefix.chars().count(); + let bullet_len = chars_count; prefix = &BULLETS[..bullet_len]; + chars = (1 << bullet_len) - 1; + tabs = 0; } let chunk = Chunk { text: prefix, + tabs, + chars, ..self.input_chunk.clone() }; @@ -1776,7 +1814,7 @@ impl Iterator for BlockRows<'_> { if transform .block .as_ref() - .map_or(false, |block| block.height() == 0) + .is_some_and(|block| block.height() == 0) { self.transforms.next(); } else { @@ -1788,7 +1826,7 @@ impl Iterator for BlockRows<'_> { if transform .block .as_ref() - .map_or(true, |block| block.is_replacement()) + .is_none_or(|block| block.is_replacement()) { self.input_rows.seek(self.transforms.start().1.0); } @@ -2161,7 +2199,7 @@ mod tests { } let multi_buffer_snapshot = multi_buffer.read(cx).snapshot(cx); - let (_, inlay_snapshot) = InlayMap::new(multi_buffer_snapshot.clone()); + let (_, inlay_snapshot) = InlayMap::new(multi_buffer_snapshot); let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); let (_, tab_snapshot) = TabMap::new(fold_snapshot, 4.try_into().unwrap()); let (_, wraps_snapshot) = WrapMap::new(tab_snapshot, font, font_size, Some(wrap_width), cx); @@ -2280,7 +2318,7 @@ mod tests { new_heights.insert(block_ids[0], 3); block_map_writer.resize(new_heights); - let snapshot = block_map.read(wraps_snapshot.clone(), Default::default()); + let snapshot = block_map.read(wraps_snapshot, Default::default()); // Same height as before, should remain the same 
assert_eq!(snapshot.text(), "aaa\n\n\n\n\n\nbbb\nccc\nddd\n\n\n"); } @@ -2365,16 +2403,14 @@ mod tests { buffer.edit([(Point::new(2, 0)..Point::new(3, 0), "")], None, cx); buffer.snapshot(cx) }); - let (inlay_snapshot, inlay_edits) = inlay_map.sync( - buffer_snapshot.clone(), - buffer_subscription.consume().into_inner(), - ); + let (inlay_snapshot, inlay_edits) = + inlay_map.sync(buffer_snapshot, buffer_subscription.consume().into_inner()); let (fold_snapshot, fold_edits) = fold_map.read(inlay_snapshot, inlay_edits); let (tab_snapshot, tab_edits) = tab_map.sync(fold_snapshot, fold_edits, tab_size); let (wraps_snapshot, wrap_edits) = wrap_map.update(cx, |wrap_map, cx| { wrap_map.sync(tab_snapshot, tab_edits, cx) }); - let blocks_snapshot = block_map.read(wraps_snapshot.clone(), wrap_edits); + let blocks_snapshot = block_map.read(wraps_snapshot, wrap_edits); assert_eq!(blocks_snapshot.text(), "line1\n\n\n\n\nline5"); let buffer_snapshot = buffer.update(cx, |buffer, cx| { @@ -2459,7 +2495,7 @@ mod tests { // Removing the replace block shows all the hidden blocks again. 
let mut writer = block_map.write(wraps_snapshot.clone(), Default::default()); writer.remove(HashSet::from_iter([replace_block_id])); - let blocks_snapshot = block_map.read(wraps_snapshot.clone(), Default::default()); + let blocks_snapshot = block_map.read(wraps_snapshot, Default::default()); assert_eq!( blocks_snapshot.text(), "\nline1\n\nline2\n\n\nline 2.1\nline2.2\nline 2.3\nline 2.4\n\nline4\n\nline5" @@ -2798,7 +2834,7 @@ mod tests { buffer.read_with(cx, |buffer, cx| { writer.fold_buffers([buffer_id_3], buffer, cx); }); - let blocks_snapshot = block_map.read(wrap_snapshot.clone(), Patch::default()); + let blocks_snapshot = block_map.read(wrap_snapshot, Patch::default()); let blocks = blocks_snapshot .blocks_in_range(0..u32::MAX) .collect::>(); @@ -2851,7 +2887,7 @@ mod tests { assert_eq!(buffer_ids.len(), 1); let buffer_id = buffer_ids[0]; - let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot.clone()); + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot); let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); let (_, tab_snapshot) = TabMap::new(fold_snapshot, 4.try_into().unwrap()); let (_, wrap_snapshot) = @@ -2865,7 +2901,7 @@ mod tests { buffer.read_with(cx, |buffer, cx| { writer.fold_buffers([buffer_id], buffer, cx); }); - let blocks_snapshot = block_map.read(wrap_snapshot.clone(), Patch::default()); + let blocks_snapshot = block_map.read(wrap_snapshot, Patch::default()); let blocks = blocks_snapshot .blocks_in_range(0..u32::MAX) .collect::>(); @@ -2873,12 +2909,7 @@ mod tests { 1, blocks .iter() - .filter(|(_, block)| { - match block { - Block::FoldedBuffer { .. } => true, - _ => false, - } - }) + .filter(|(_, block)| { matches!(block, Block::FoldedBuffer { .. 
}) }) .count(), "Should have one folded block, producing a header of the second buffer" ); @@ -2901,21 +2932,21 @@ mod tests { .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) .unwrap_or(10); - let wrap_width = if rng.gen_bool(0.2) { + let wrap_width = if rng.random_bool(0.2) { None } else { - Some(px(rng.gen_range(0.0..=100.0))) + Some(px(rng.random_range(0.0..=100.0))) }; let tab_size = 1.try_into().unwrap(); let font_size = px(14.0); - let buffer_start_header_height = rng.gen_range(1..=5); - let excerpt_header_height = rng.gen_range(1..=5); + let buffer_start_header_height = rng.random_range(1..=5); + let excerpt_header_height = rng.random_range(1..=5); log::info!("Wrap width: {:?}", wrap_width); log::info!("Excerpt Header Height: {:?}", excerpt_header_height); - let is_singleton = rng.r#gen(); + let is_singleton = rng.random(); let buffer = if is_singleton { - let len = rng.gen_range(0..10); + let len = rng.random_range(0..10); let text = RandomCharIter::new(&mut rng).take(len).collect::(); log::info!("initial singleton buffer text: {:?}", text); cx.update(|cx| MultiBuffer::build_simple(&text, cx)) @@ -2945,30 +2976,30 @@ mod tests { for _ in 0..operations { let mut buffer_edits = Vec::new(); - match rng.gen_range(0..=100) { + match rng.random_range(0..=100) { 0..=19 => { - let wrap_width = if rng.gen_bool(0.2) { + let wrap_width = if rng.random_bool(0.2) { None } else { - Some(px(rng.gen_range(0.0..=100.0))) + Some(px(rng.random_range(0.0..=100.0))) }; log::info!("Setting wrap width to {:?}", wrap_width); wrap_map.update(cx, |map, cx| map.set_wrap_width(wrap_width, cx)); } 20..=39 => { - let block_count = rng.gen_range(1..=5); + let block_count = rng.random_range(1..=5); let block_properties = (0..block_count) .map(|_| { let buffer = cx.update(|cx| buffer.read(cx).read(cx).clone()); let offset = - buffer.clip_offset(rng.gen_range(0..=buffer.len()), Bias::Left); + buffer.clip_offset(rng.random_range(0..=buffer.len()), Bias::Left); let mut min_height 
= 0; - let placement = match rng.gen_range(0..3) { + let placement = match rng.random_range(0..3) { 0 => { min_height = 1; let start = buffer.anchor_after(offset); let end = buffer.anchor_after(buffer.clip_offset( - rng.gen_range(offset..=buffer.len()), + rng.random_range(offset..=buffer.len()), Bias::Left, )); BlockPlacement::Replace(start..=end) @@ -2977,7 +3008,7 @@ mod tests { _ => BlockPlacement::Below(buffer.anchor_after(offset)), }; - let height = rng.gen_range(min_height..5); + let height = rng.random_range(min_height..5); BlockProperties { style: BlockStyle::Fixed, placement, @@ -3019,7 +3050,7 @@ mod tests { } } 40..=59 if !block_map.custom_blocks.is_empty() => { - let block_count = rng.gen_range(1..=4.min(block_map.custom_blocks.len())); + let block_count = rng.random_range(1..=4.min(block_map.custom_blocks.len())); let block_ids_to_remove = block_map .custom_blocks .choose_multiple(&mut rng, block_count) @@ -3074,8 +3105,8 @@ mod tests { let mut folded_count = folded_buffers.len(); let mut unfolded_count = unfolded_buffers.len(); - let fold = !unfolded_buffers.is_empty() && rng.gen_bool(0.5); - let unfold = !folded_buffers.is_empty() && rng.gen_bool(0.5); + let fold = !unfolded_buffers.is_empty() && rng.random_bool(0.5); + let unfold = !folded_buffers.is_empty() && rng.random_bool(0.5); if !fold && !unfold { log::info!( "Noop fold/unfold operation. 
Unfolded buffers: {unfolded_count}, folded buffers: {folded_count}" @@ -3086,7 +3117,7 @@ mod tests { buffer.update(cx, |buffer, cx| { if fold { let buffer_to_fold = - unfolded_buffers[rng.gen_range(0..unfolded_buffers.len())]; + unfolded_buffers[rng.random_range(0..unfolded_buffers.len())]; log::info!("Folding {buffer_to_fold:?}"); let related_excerpts = buffer_snapshot .excerpts() @@ -3112,7 +3143,7 @@ mod tests { } if unfold { let buffer_to_unfold = - folded_buffers[rng.gen_range(0..folded_buffers.len())]; + folded_buffers[rng.random_range(0..folded_buffers.len())]; log::info!("Unfolding {buffer_to_unfold:?}"); unfolded_count += 1; folded_count -= 1; @@ -3125,7 +3156,7 @@ mod tests { } _ => { buffer.update(cx, |buffer, cx| { - let mutation_count = rng.gen_range(1..=5); + let mutation_count = rng.random_range(1..=5); let subscription = buffer.subscribe(); buffer.randomly_mutate(&mut rng, mutation_count, cx); buffer_snapshot = buffer.snapshot(cx); @@ -3195,9 +3226,9 @@ mod tests { // so we special case row 0 to assume a leading '\n'. // // Linehood is the birthright of strings. - let mut input_text_lines = input_text.split('\n').enumerate().peekable(); + let input_text_lines = input_text.split('\n').enumerate().peekable(); let mut block_row = 0; - while let Some((wrap_row, input_line)) = input_text_lines.next() { + for (wrap_row, input_line) in input_text_lines { let wrap_row = wrap_row as u32; let multibuffer_row = wraps_snapshot .to_point(WrapPoint::new(wrap_row, 0), Bias::Left) @@ -3228,34 +3259,32 @@ mod tests { let mut is_in_replace_block = false; if let Some((BlockPlacement::Replace(replace_range), block)) = sorted_blocks_iter.peek() + && wrap_row >= replace_range.start().0 { - if wrap_row >= replace_range.start().0 { - is_in_replace_block = true; + is_in_replace_block = true; - if wrap_row == replace_range.start().0 { - if matches!(block, Block::FoldedBuffer { .. 
}) { - expected_buffer_rows.push(None); - } else { - expected_buffer_rows - .push(input_buffer_rows[multibuffer_row as usize]); - } + if wrap_row == replace_range.start().0 { + if matches!(block, Block::FoldedBuffer { .. }) { + expected_buffer_rows.push(None); + } else { + expected_buffer_rows.push(input_buffer_rows[multibuffer_row as usize]); } + } - if wrap_row == replace_range.end().0 { - expected_block_positions.push((block_row, block.id())); - let text = "\n".repeat((block.height() - 1) as usize); - if block_row > 0 { - expected_text.push('\n'); - } - expected_text.push_str(&text); - - for _ in 1..block.height() { - expected_buffer_rows.push(None); - } - block_row += block.height(); + if wrap_row == replace_range.end().0 { + expected_block_positions.push((block_row, block.id())); + let text = "\n".repeat((block.height() - 1) as usize); + if block_row > 0 { + expected_text.push('\n'); + } + expected_text.push_str(&text); - sorted_blocks_iter.next(); + for _ in 1..block.height() { + expected_buffer_rows.push(None); } + block_row += block.height(); + + sorted_blocks_iter.next(); } } @@ -3312,7 +3341,7 @@ mod tests { ); for start_row in 0..expected_row_count { - let end_row = rng.gen_range(start_row + 1..=expected_row_count); + let end_row = rng.random_range(start_row + 1..=expected_row_count); let mut expected_text = expected_lines[start_row..end_row].join("\n"); if end_row < expected_row_count { expected_text.push('\n'); @@ -3407,8 +3436,8 @@ mod tests { ); for _ in 0..10 { - let end_row = rng.gen_range(1..=expected_lines.len()); - let start_row = rng.gen_range(0..end_row); + let end_row = rng.random_range(1..=expected_lines.len()); + let start_row = rng.random_range(0..end_row); let mut expected_longest_rows_in_range = vec![]; let mut longest_line_len_in_range = 0; @@ -3539,7 +3568,7 @@ mod tests { ..buffer_snapshot.anchor_after(Point::new(1, 0))], false, ); - let blocks_snapshot = block_map.read(wraps_snapshot.clone(), Default::default()); + let 
blocks_snapshot = block_map.read(wraps_snapshot, Default::default()); assert_eq!(blocks_snapshot.text(), "abc\n\ndef\nghi\njkl\nmno"); } diff --git a/crates/editor/src/display_map/custom_highlights.rs b/crates/editor/src/display_map/custom_highlights.rs index ae69e9cf8c710acecc840ef14082c8f9d91d7c03..b7518af59c28dbc95a36d24b36a7eae2862916b6 100644 --- a/crates/editor/src/display_map/custom_highlights.rs +++ b/crates/editor/src/display_map/custom_highlights.rs @@ -25,9 +25,8 @@ pub struct CustomHighlightsChunks<'a> { #[derive(Debug, Copy, Clone, Eq, PartialEq)] struct HighlightEndpoint { offset: usize, - is_start: bool, tag: HighlightKey, - style: HighlightStyle, + style: Option, } impl<'a> CustomHighlightsChunks<'a> { @@ -77,7 +76,7 @@ fn create_highlight_endpoints( let ranges = &text_highlights.1; let start_ix = match ranges.binary_search_by(|probe| { - let cmp = probe.end.cmp(&start, &buffer); + let cmp = probe.end.cmp(&start, buffer); if cmp.is_gt() { cmp::Ordering::Greater } else { @@ -88,21 +87,24 @@ fn create_highlight_endpoints( }; for range in &ranges[start_ix..] 
{ - if range.start.cmp(&end, &buffer).is_ge() { + if range.start.cmp(&end, buffer).is_ge() { break; } + let start = range.start.to_offset(buffer); + let end = range.end.to_offset(buffer); + if start == end { + continue; + } highlight_endpoints.push(HighlightEndpoint { - offset: range.start.to_offset(&buffer), - is_start: true, + offset: start, tag, - style, + style: Some(style), }); highlight_endpoints.push(HighlightEndpoint { - offset: range.end.to_offset(&buffer), - is_start: false, + offset: end, tag, - style, + style: None, }); } } @@ -118,8 +120,8 @@ impl<'a> Iterator for CustomHighlightsChunks<'a> { let mut next_highlight_endpoint = usize::MAX; while let Some(endpoint) = self.highlight_endpoints.peek().copied() { if endpoint.offset <= self.offset { - if endpoint.is_start { - self.active_highlights.insert(endpoint.tag, endpoint.style); + if let Some(style) = endpoint.style { + self.active_highlights.insert(endpoint.tag, style); } else { self.active_highlights.remove(&endpoint.tag); } @@ -132,27 +134,41 @@ impl<'a> Iterator for CustomHighlightsChunks<'a> { let chunk = self .buffer_chunk - .get_or_insert_with(|| self.buffer_chunks.next().unwrap()); + .get_or_insert_with(|| self.buffer_chunks.next().unwrap_or_default()); if chunk.text.is_empty() { - *chunk = self.buffer_chunks.next().unwrap(); + *chunk = self.buffer_chunks.next()?; } - let (prefix, suffix) = chunk - .text - .split_at(chunk.text.len().min(next_highlight_endpoint - self.offset)); + let split_idx = chunk.text.len().min(next_highlight_endpoint - self.offset); + let (prefix, suffix) = chunk.text.split_at(split_idx); + + let (chars, tabs) = if split_idx == 128 { + let output = (chunk.chars, chunk.tabs); + chunk.chars = 0; + chunk.tabs = 0; + output + } else { + let mask = (1 << split_idx) - 1; + let output = (chunk.chars & mask, chunk.tabs & mask); + chunk.chars = chunk.chars >> split_idx; + chunk.tabs = chunk.tabs >> split_idx; + output + }; chunk.text = suffix; self.offset += prefix.len(); let mut 
prefix = Chunk { text: prefix, + chars, + tabs, ..chunk.clone() }; if !self.active_highlights.is_empty() { - let mut highlight_style = HighlightStyle::default(); - for active_highlight in self.active_highlights.values() { - highlight_style.highlight(*active_highlight); - } - prefix.highlight_style = Some(highlight_style); + prefix.highlight_style = self + .active_highlights + .values() + .copied() + .reduce(|acc, active_highlight| acc.highlight(active_highlight)); } Some(prefix) } @@ -168,6 +184,143 @@ impl Ord for HighlightEndpoint { fn cmp(&self, other: &Self) -> cmp::Ordering { self.offset .cmp(&other.offset) - .then_with(|| other.is_start.cmp(&self.is_start)) + .then_with(|| self.style.is_some().cmp(&other.style.is_some())) + } +} + +#[cfg(test)] +mod tests { + use std::{any::TypeId, sync::Arc}; + + use super::*; + use crate::MultiBuffer; + use gpui::App; + use rand::prelude::*; + use util::RandomCharIter; + + #[gpui::test(iterations = 100)] + fn test_random_chunk_bitmaps(cx: &mut App, mut rng: StdRng) { + // Generate random buffer using existing test infrastructure + let len = rng.random_range(10..10000); + let buffer = if rng.random() { + let text = RandomCharIter::new(&mut rng).take(len).collect::(); + MultiBuffer::build_simple(&text, cx) + } else { + MultiBuffer::build_random(&mut rng, cx) + }; + + let buffer_snapshot = buffer.read(cx).snapshot(cx); + + // Create random highlights + let mut highlights = sum_tree::TreeMap::default(); + let highlight_count = rng.random_range(1..10); + + for _i in 0..highlight_count { + let style = HighlightStyle { + color: Some(gpui::Hsla { + h: rng.random::(), + s: rng.random::(), + l: rng.random::(), + a: 1.0, + }), + ..Default::default() + }; + + let mut ranges = Vec::new(); + let range_count = rng.random_range(1..10); + let text = buffer_snapshot.text(); + for _ in 0..range_count { + if buffer_snapshot.len() == 0 { + continue; + } + + let mut start = rng.random_range(0..=buffer_snapshot.len().saturating_sub(10)); + + 
while !text.is_char_boundary(start) { + start = start.saturating_sub(1); + } + + let end_end = buffer_snapshot.len().min(start + 100); + let mut end = rng.random_range(start..=end_end); + while !text.is_char_boundary(end) { + end = end.saturating_sub(1); + } + + if start < end { + start = end; + } + let start_anchor = buffer_snapshot.anchor_before(start); + let end_anchor = buffer_snapshot.anchor_after(end); + ranges.push(start_anchor..end_anchor); + } + + let type_id = TypeId::of::<()>(); // Simple type ID for testing + highlights.insert(HighlightKey::Type(type_id), Arc::new((style, ranges))); + } + + // Get all chunks and verify their bitmaps + let chunks = + CustomHighlightsChunks::new(0..buffer_snapshot.len(), false, None, &buffer_snapshot); + + for chunk in chunks { + let chunk_text = chunk.text; + let chars_bitmap = chunk.chars; + let tabs_bitmap = chunk.tabs; + + // Check empty chunks have empty bitmaps + if chunk_text.is_empty() { + assert_eq!( + chars_bitmap, 0, + "Empty chunk should have empty chars bitmap" + ); + assert_eq!(tabs_bitmap, 0, "Empty chunk should have empty tabs bitmap"); + continue; + } + + // Verify that chunk text doesn't exceed 128 bytes + assert!( + chunk_text.len() <= 128, + "Chunk text length {} exceeds 128 bytes", + chunk_text.len() + ); + + // Verify chars bitmap + let char_indices = chunk_text + .char_indices() + .map(|(i, _)| i) + .collect::>(); + + for byte_idx in 0..chunk_text.len() { + let should_have_bit = char_indices.contains(&byte_idx); + let has_bit = chars_bitmap & (1 << byte_idx) != 0; + + if has_bit != should_have_bit { + eprintln!("Chunk text bytes: {:?}", chunk_text.as_bytes()); + eprintln!("Char indices: {:?}", char_indices); + eprintln!("Chars bitmap: {:#b}", chars_bitmap); + assert_eq!( + has_bit, should_have_bit, + "Chars bitmap mismatch at byte index {} in chunk {:?}. 
Expected bit: {}, Got bit: {}", + byte_idx, chunk_text, should_have_bit, has_bit + ); + } + } + + // Verify tabs bitmap + for (byte_idx, byte) in chunk_text.bytes().enumerate() { + let is_tab = byte == b'\t'; + let has_bit = tabs_bitmap & (1 << byte_idx) != 0; + + if has_bit != is_tab { + eprintln!("Chunk text bytes: {:?}", chunk_text.as_bytes()); + eprintln!("Tabs bitmap: {:#b}", tabs_bitmap); + assert_eq!( + has_bit, is_tab, + "Tabs bitmap mismatch at byte index {} in chunk {:?}. Byte: {:?}, Expected bit: {}, Got bit: {}", + byte_idx, chunk_text, byte as char, is_tab, has_bit + ); + } + } + } } } diff --git a/crates/editor/src/display_map/fold_map.rs b/crates/editor/src/display_map/fold_map.rs index c4e53a0f4361d83429158f106bd81326c8ddb573..405f25219fa6d7bcef03c745aa34fec351d7abd3 100644 --- a/crates/editor/src/display_map/fold_map.rs +++ b/crates/editor/src/display_map/fold_map.rs @@ -289,25 +289,25 @@ impl FoldMapWriter<'_> { let ChunkRendererId::Fold(id) = id else { continue; }; - if let Some(metadata) = self.0.snapshot.fold_metadata_by_id.get(&id).cloned() { - if Some(new_width) != metadata.width { - let buffer_start = metadata.range.start.to_offset(buffer); - let buffer_end = metadata.range.end.to_offset(buffer); - let inlay_range = inlay_snapshot.to_inlay_offset(buffer_start) - ..inlay_snapshot.to_inlay_offset(buffer_end); - edits.push(InlayEdit { - old: inlay_range.clone(), - new: inlay_range.clone(), - }); + if let Some(metadata) = self.0.snapshot.fold_metadata_by_id.get(&id).cloned() + && Some(new_width) != metadata.width + { + let buffer_start = metadata.range.start.to_offset(buffer); + let buffer_end = metadata.range.end.to_offset(buffer); + let inlay_range = inlay_snapshot.to_inlay_offset(buffer_start) + ..inlay_snapshot.to_inlay_offset(buffer_end); + edits.push(InlayEdit { + old: inlay_range.clone(), + new: inlay_range.clone(), + }); - self.0.snapshot.fold_metadata_by_id.insert( - id, - FoldMetadata { - range: metadata.range, - width: 
Some(new_width), - }, - ); - } + self.0.snapshot.fold_metadata_by_id.insert( + id, + FoldMetadata { + range: metadata.range, + width: Some(new_width), + }, + ); } } @@ -320,13 +320,13 @@ impl FoldMapWriter<'_> { /// Decides where the fold indicators should be; also tracks parts of a source file that are currently folded. /// /// See the [`display_map` module documentation](crate::display_map) for more information. -pub(crate) struct FoldMap { +pub struct FoldMap { snapshot: FoldSnapshot, next_fold_id: FoldId, } impl FoldMap { - pub(crate) fn new(inlay_snapshot: InlaySnapshot) -> (Self, FoldSnapshot) { + pub fn new(inlay_snapshot: InlaySnapshot) -> (Self, FoldSnapshot) { let this = Self { snapshot: FoldSnapshot { folds: SumTree::new(&inlay_snapshot.buffer), @@ -360,7 +360,7 @@ impl FoldMap { (self.snapshot.clone(), edits) } - pub fn write( + pub(crate) fn write( &mut self, inlay_snapshot: InlaySnapshot, edits: Vec, @@ -417,18 +417,18 @@ impl FoldMap { cursor.seek(&InlayOffset(0), Bias::Right); while let Some(mut edit) = inlay_edits_iter.next() { - if let Some(item) = cursor.item() { - if !item.is_fold() { - new_transforms.update_last( - |transform| { - if !transform.is_fold() { - transform.summary.add_summary(&item.summary, &()); - cursor.next(); - } - }, - &(), - ); - } + if let Some(item) = cursor.item() + && !item.is_fold() + { + new_transforms.update_last( + |transform| { + if !transform.is_fold() { + transform.summary.add_summary(&item.summary, &()); + cursor.next(); + } + }, + &(), + ); } new_transforms.append(cursor.slice(&edit.old.start, Bias::Left), &()); edit.new.start -= edit.old.start - *cursor.start(); @@ -491,14 +491,14 @@ impl FoldMap { while folds .peek() - .map_or(false, |(_, fold_range)| fold_range.start < edit.new.end) + .is_some_and(|(_, fold_range)| fold_range.start < edit.new.end) { let (fold, mut fold_range) = folds.next().unwrap(); let sum = new_transforms.summary(); assert!(fold_range.start.0 >= sum.input.len); - while 
folds.peek().map_or(false, |(next_fold, next_fold_range)| { + while folds.peek().is_some_and(|(next_fold, next_fold_range)| { next_fold_range.start < fold_range.end || (next_fold_range.start == fold_range.end && fold.placeholder.merge_adjacent @@ -529,6 +529,7 @@ impl FoldMap { }, placeholder: Some(TransformPlaceholder { text: ELLIPSIS, + chars: 1, renderer: ChunkRenderer { id: ChunkRendererId::Fold(fold.id), render: Arc::new(move |cx| { @@ -575,14 +576,14 @@ impl FoldMap { for mut edit in inlay_edits { old_transforms.seek(&edit.old.start, Bias::Left); - if old_transforms.item().map_or(false, |t| t.is_fold()) { + if old_transforms.item().is_some_and(|t| t.is_fold()) { edit.old.start = old_transforms.start().0; } let old_start = old_transforms.start().1.0 + (edit.old.start - old_transforms.start().0).0; old_transforms.seek_forward(&edit.old.end, Bias::Right); - if old_transforms.item().map_or(false, |t| t.is_fold()) { + if old_transforms.item().is_some_and(|t| t.is_fold()) { old_transforms.next(); edit.old.end = old_transforms.start().0; } @@ -590,14 +591,14 @@ impl FoldMap { old_transforms.start().1.0 + (edit.old.end - old_transforms.start().0).0; new_transforms.seek(&edit.new.start, Bias::Left); - if new_transforms.item().map_or(false, |t| t.is_fold()) { + if new_transforms.item().is_some_and(|t| t.is_fold()) { edit.new.start = new_transforms.start().0; } let new_start = new_transforms.start().1.0 + (edit.new.start - new_transforms.start().0).0; new_transforms.seek_forward(&edit.new.end, Bias::Right); - if new_transforms.item().map_or(false, |t| t.is_fold()) { + if new_transforms.item().is_some_and(|t| t.is_fold()) { new_transforms.next(); edit.new.end = new_transforms.start().0; } @@ -709,7 +710,7 @@ impl FoldSnapshot { .transforms .cursor::>(&()); cursor.seek(&point, Bias::Right); - if cursor.item().map_or(false, |t| t.is_fold()) { + if cursor.item().is_some_and(|t| t.is_fold()) { if bias == Bias::Left || point == cursor.start().0 { cursor.start().1 } else { @@ 
-788,7 +789,7 @@ impl FoldSnapshot { let inlay_offset = self.inlay_snapshot.to_inlay_offset(buffer_offset); let mut cursor = self.transforms.cursor::(&()); cursor.seek(&inlay_offset, Bias::Right); - cursor.item().map_or(false, |t| t.placeholder.is_some()) + cursor.item().is_some_and(|t| t.placeholder.is_some()) } pub fn is_line_folded(&self, buffer_row: MultiBufferRow) -> bool { @@ -839,7 +840,7 @@ impl FoldSnapshot { let inlay_end = if transform_cursor .item() - .map_or(true, |transform| transform.is_fold()) + .is_none_or(|transform| transform.is_fold()) { inlay_start } else if range.end < transform_end.0 { @@ -872,6 +873,14 @@ impl FoldSnapshot { .flat_map(|chunk| chunk.text.chars()) } + pub fn chunks_at(&self, start: FoldPoint) -> FoldChunks<'_> { + self.chunks( + start.to_offset(self)..self.len(), + false, + Highlights::default(), + ) + } + #[cfg(test)] pub fn clip_offset(&self, offset: FoldOffset, bias: Bias) -> FoldOffset { if offset > self.len() { @@ -1034,6 +1043,7 @@ struct Transform { #[derive(Clone, Debug)] struct TransformPlaceholder { text: &'static str, + chars: u128, renderer: ChunkRenderer, } @@ -1274,6 +1284,10 @@ pub struct Chunk<'a> { pub is_inlay: bool, /// An optional recipe for how the chunk should be presented. 
pub renderer: Option, + /// Bitmap of tab character locations in chunk + pub tabs: u128, + /// Bitmap of character locations in chunk + pub chars: u128, } #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] @@ -1348,7 +1362,7 @@ impl FoldChunks<'_> { let inlay_end = if self .transform_cursor .item() - .map_or(true, |transform| transform.is_fold()) + .is_none_or(|transform| transform.is_fold()) { inlay_start } else if range.end < transform_end.0 { @@ -1391,6 +1405,7 @@ impl<'a> Iterator for FoldChunks<'a> { self.output_offset.0 += placeholder.text.len(); return Some(Chunk { text: placeholder.text, + chars: placeholder.chars, renderer: Some(placeholder.renderer.clone()), ..Default::default() }); @@ -1429,6 +1444,16 @@ impl<'a> Iterator for FoldChunks<'a> { chunk.text = &chunk.text [(self.inlay_offset - buffer_chunk_start).0..(chunk_end - buffer_chunk_start).0]; + let bit_end = (chunk_end - buffer_chunk_start).0; + let mask = if bit_end >= 128 { + u128::MAX + } else { + (1u128 << bit_end) - 1 + }; + + chunk.tabs = (chunk.tabs >> (self.inlay_offset - buffer_chunk_start).0) & mask; + chunk.chars = (chunk.chars >> (self.inlay_offset - buffer_chunk_start).0) & mask; + if chunk_end == transform_end { self.transform_cursor.next(); } else if chunk_end == buffer_chunk_end { @@ -1439,6 +1464,8 @@ impl<'a> Iterator for FoldChunks<'a> { self.output_offset.0 += chunk.text.len(); return Some(Chunk { text: chunk.text, + tabs: chunk.tabs, + chars: chunk.chars, syntax_highlight_id: chunk.syntax_highlight_id, highlight_style: chunk.highlight_style, diagnostic_severity: chunk.diagnostic_severity, @@ -1463,7 +1490,7 @@ impl FoldOffset { .transforms .cursor::>(&()); cursor.seek(&self, Bias::Right); - let overshoot = if cursor.item().map_or(true, |t| t.is_fold()) { + let overshoot = if cursor.item().is_none_or(|t| t.is_fold()) { Point::new(0, (self.0 - cursor.start().0.0) as u32) } else { let inlay_offset = cursor.start().1.input.len + self.0 - cursor.start().0.0; @@ -1557,7 +1584,7 @@ 
mod tests { let buffer = MultiBuffer::build_simple(&sample_text(5, 6, 'a'), cx); let subscription = buffer.update(cx, |buffer, _| buffer.subscribe()); let buffer_snapshot = buffer.read(cx).snapshot(cx); - let (mut inlay_map, inlay_snapshot) = InlayMap::new(buffer_snapshot.clone()); + let (mut inlay_map, inlay_snapshot) = InlayMap::new(buffer_snapshot); let mut map = FoldMap::new(inlay_snapshot.clone()).0; let (mut writer, _, _) = map.write(inlay_snapshot, vec![]); @@ -1636,7 +1663,7 @@ mod tests { let buffer = MultiBuffer::build_simple("abcdefghijkl", cx); let subscription = buffer.update(cx, |buffer, _| buffer.subscribe()); let buffer_snapshot = buffer.read(cx).snapshot(cx); - let (mut inlay_map, inlay_snapshot) = InlayMap::new(buffer_snapshot.clone()); + let (mut inlay_map, inlay_snapshot) = InlayMap::new(buffer_snapshot); { let mut map = FoldMap::new(inlay_snapshot.clone()).0; @@ -1712,7 +1739,7 @@ mod tests { let buffer = MultiBuffer::build_simple(&sample_text(5, 6, 'a'), cx); let subscription = buffer.update(cx, |buffer, _| buffer.subscribe()); let buffer_snapshot = buffer.read(cx).snapshot(cx); - let (mut inlay_map, inlay_snapshot) = InlayMap::new(buffer_snapshot.clone()); + let (mut inlay_map, inlay_snapshot) = InlayMap::new(buffer_snapshot); let mut map = FoldMap::new(inlay_snapshot.clone()).0; let (mut writer, _, _) = map.write(inlay_snapshot.clone(), vec![]); @@ -1720,7 +1747,7 @@ mod tests { (Point::new(0, 2)..Point::new(2, 2), FoldPlaceholder::test()), (Point::new(3, 1)..Point::new(4, 1), FoldPlaceholder::test()), ]); - let (snapshot, _) = map.read(inlay_snapshot.clone(), vec![]); + let (snapshot, _) = map.read(inlay_snapshot, vec![]); assert_eq!(snapshot.text(), "aa⋯cccc\nd⋯eeeee"); let buffer_snapshot = buffer.update(cx, |buffer, cx| { @@ -1747,7 +1774,7 @@ mod tests { (Point::new(1, 2)..Point::new(3, 2), FoldPlaceholder::test()), (Point::new(3, 1)..Point::new(4, 1), FoldPlaceholder::test()), ]); - let (snapshot, _) = map.read(inlay_snapshot.clone(), 
vec![]); + let (snapshot, _) = map.read(inlay_snapshot, vec![]); let fold_ranges = snapshot .folds_in_range(Point::new(1, 0)..Point::new(1, 3)) .map(|fold| { @@ -1771,9 +1798,9 @@ mod tests { .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) .unwrap_or(10); - let len = rng.gen_range(0..10); + let len = rng.random_range(0..10); let text = RandomCharIter::new(&mut rng).take(len).collect::(); - let buffer = if rng.r#gen() { + let buffer = if rng.random() { MultiBuffer::build_simple(&text, cx) } else { MultiBuffer::build_random(&mut rng, cx) @@ -1782,7 +1809,7 @@ mod tests { let (mut inlay_map, inlay_snapshot) = InlayMap::new(buffer_snapshot.clone()); let mut map = FoldMap::new(inlay_snapshot.clone()).0; - let (mut initial_snapshot, _) = map.read(inlay_snapshot.clone(), vec![]); + let (mut initial_snapshot, _) = map.read(inlay_snapshot, vec![]); let mut snapshot_edits = Vec::new(); let mut next_inlay_id = 0; @@ -1790,7 +1817,7 @@ mod tests { log::info!("text: {:?}", buffer_snapshot.text()); let mut buffer_edits = Vec::new(); let mut inlay_edits = Vec::new(); - match rng.gen_range(0..=100) { + match rng.random_range(0..=100) { 0..=39 => { snapshot_edits.extend(map.randomly_mutate(&mut rng)); } @@ -1800,7 +1827,7 @@ mod tests { } _ => buffer.update(cx, |buffer, cx| { let subscription = buffer.subscribe(); - let edit_count = rng.gen_range(1..=5); + let edit_count = rng.random_range(1..=5); buffer.randomly_mutate(&mut rng, edit_count, cx); buffer_snapshot = buffer.snapshot(cx); let edits = subscription.consume().into_inner(); @@ -1917,10 +1944,14 @@ mod tests { } for _ in 0..5 { - let mut start = snapshot - .clip_offset(FoldOffset(rng.gen_range(0..=snapshot.len().0)), Bias::Left); - let mut end = snapshot - .clip_offset(FoldOffset(rng.gen_range(0..=snapshot.len().0)), Bias::Right); + let mut start = snapshot.clip_offset( + FoldOffset(rng.random_range(0..=snapshot.len().0)), + Bias::Left, + ); + let mut end = snapshot.clip_offset( + 
FoldOffset(rng.random_range(0..=snapshot.len().0)), + Bias::Right, + ); if start > end { mem::swap(&mut start, &mut end); } @@ -1975,8 +2006,8 @@ mod tests { for _ in 0..5 { let end = - buffer_snapshot.clip_offset(rng.gen_range(0..=buffer_snapshot.len()), Right); - let start = buffer_snapshot.clip_offset(rng.gen_range(0..=end), Left); + buffer_snapshot.clip_offset(rng.random_range(0..=buffer_snapshot.len()), Right); + let start = buffer_snapshot.clip_offset(rng.random_range(0..=end), Left); let expected_folds = map .snapshot .folds @@ -2001,10 +2032,10 @@ mod tests { let text = snapshot.text(); for _ in 0..5 { - let start_row = rng.gen_range(0..=snapshot.max_point().row()); - let start_column = rng.gen_range(0..=snapshot.line_len(start_row)); - let end_row = rng.gen_range(0..=snapshot.max_point().row()); - let end_column = rng.gen_range(0..=snapshot.line_len(end_row)); + let start_row = rng.random_range(0..=snapshot.max_point().row()); + let start_column = rng.random_range(0..=snapshot.line_len(start_row)); + let end_row = rng.random_range(0..=snapshot.max_point().row()); + let end_column = rng.random_range(0..=snapshot.line_len(end_row)); let mut start = snapshot.clip_point(FoldPoint::new(start_row, start_column), Bias::Left); let mut end = snapshot.clip_point(FoldPoint::new(end_row, end_column), Bias::Right); @@ -2068,6 +2099,97 @@ mod tests { ); } + #[gpui::test(iterations = 100)] + fn test_random_chunk_bitmaps(cx: &mut gpui::App, mut rng: StdRng) { + init_test(cx); + + // Generate random buffer using existing test infrastructure + let text_len = rng.random_range(0..10000); + let buffer = if rng.random() { + let text = RandomCharIter::new(&mut rng) + .take(text_len) + .collect::(); + MultiBuffer::build_simple(&text, cx) + } else { + MultiBuffer::build_random(&mut rng, cx) + }; + let buffer_snapshot = buffer.read(cx).snapshot(cx); + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot); + let (mut fold_map, _) = FoldMap::new(inlay_snapshot.clone()); + + // 
Perform random mutations + let mutation_count = rng.random_range(1..10); + for _ in 0..mutation_count { + fold_map.randomly_mutate(&mut rng); + } + + let (snapshot, _) = fold_map.read(inlay_snapshot, vec![]); + + // Get all chunks and verify their bitmaps + let chunks = snapshot.chunks( + FoldOffset(0)..FoldOffset(snapshot.len().0), + false, + Highlights::default(), + ); + + for chunk in chunks { + let chunk_text = chunk.text; + let chars_bitmap = chunk.chars; + let tabs_bitmap = chunk.tabs; + + // Check empty chunks have empty bitmaps + if chunk_text.is_empty() { + assert_eq!( + chars_bitmap, 0, + "Empty chunk should have empty chars bitmap" + ); + assert_eq!(tabs_bitmap, 0, "Empty chunk should have empty tabs bitmap"); + continue; + } + + // Verify that chunk text doesn't exceed 128 bytes + assert!( + chunk_text.len() <= 128, + "Chunk text length {} exceeds 128 bytes", + chunk_text.len() + ); + + // Verify chars bitmap + let char_indices = chunk_text + .char_indices() + .map(|(i, _)| i) + .collect::>(); + + for byte_idx in 0..chunk_text.len() { + let should_have_bit = char_indices.contains(&byte_idx); + let has_bit = chars_bitmap & (1 << byte_idx) != 0; + + if has_bit != should_have_bit { + eprintln!("Chunk text bytes: {:?}", chunk_text.as_bytes()); + eprintln!("Char indices: {:?}", char_indices); + eprintln!("Chars bitmap: {:#b}", chars_bitmap); + assert_eq!( + has_bit, should_have_bit, + "Chars bitmap mismatch at byte index {} in chunk {:?}. Expected bit: {}, Got bit: {}", + byte_idx, chunk_text, should_have_bit, has_bit + ); + } + } + + // Verify tabs bitmap + for (byte_idx, byte) in chunk_text.bytes().enumerate() { + let is_tab = byte == b'\t'; + let has_bit = tabs_bitmap & (1 << byte_idx) != 0; + + assert_eq!( + has_bit, is_tab, + "Tabs bitmap mismatch at byte index {} in chunk {:?}. 
Byte: {:?}, Expected bit: {}, Got bit: {}", + byte_idx, chunk_text, byte as char, is_tab, has_bit + ); + } + } + } + fn init_test(cx: &mut gpui::App) { let store = SettingsStore::test(cx); cx.set_global(store); @@ -2109,17 +2231,17 @@ mod tests { rng: &mut impl Rng, ) -> Vec<(FoldSnapshot, Vec)> { let mut snapshot_edits = Vec::new(); - match rng.gen_range(0..=100) { + match rng.random_range(0..=100) { 0..=39 if !self.snapshot.folds.is_empty() => { let inlay_snapshot = self.snapshot.inlay_snapshot.clone(); let buffer = &inlay_snapshot.buffer; let mut to_unfold = Vec::new(); - for _ in 0..rng.gen_range(1..=3) { - let end = buffer.clip_offset(rng.gen_range(0..=buffer.len()), Right); - let start = buffer.clip_offset(rng.gen_range(0..=end), Left); + for _ in 0..rng.random_range(1..=3) { + let end = buffer.clip_offset(rng.random_range(0..=buffer.len()), Right); + let start = buffer.clip_offset(rng.random_range(0..=end), Left); to_unfold.push(start..end); } - let inclusive = rng.r#gen(); + let inclusive = rng.random(); log::info!("unfolding {:?} (inclusive: {})", to_unfold, inclusive); let (mut writer, snapshot, edits) = self.write(inlay_snapshot, vec![]); snapshot_edits.push((snapshot, edits)); @@ -2130,9 +2252,9 @@ mod tests { let inlay_snapshot = self.snapshot.inlay_snapshot.clone(); let buffer = &inlay_snapshot.buffer; let mut to_fold = Vec::new(); - for _ in 0..rng.gen_range(1..=2) { - let end = buffer.clip_offset(rng.gen_range(0..=buffer.len()), Right); - let start = buffer.clip_offset(rng.gen_range(0..=end), Left); + for _ in 0..rng.random_range(1..=2) { + let end = buffer.clip_offset(rng.random_range(0..=buffer.len()), Right); + let start = buffer.clip_offset(rng.random_range(0..=end), Left); to_fold.push((start..end, FoldPlaceholder::test())); } log::info!("folding {:?}", to_fold); diff --git a/crates/editor/src/display_map/inlay_map.rs b/crates/editor/src/display_map/inlay_map.rs index 
b296b3e62a39aa2ec8671676e051e94f5f9622cf..9ceb0897d242f710353c2f7a90992b2a39f40958 100644 --- a/crates/editor/src/display_map/inlay_map.rs +++ b/crates/editor/src/display_map/inlay_map.rs @@ -11,7 +11,7 @@ use std::{ sync::Arc, }; use sum_tree::{Bias, Cursor, Dimensions, SumTree}; -use text::{Patch, Rope}; +use text::{ChunkBitmaps, Patch, Rope}; use ui::{ActiveTheme, IntoElement as _, ParentElement as _, Styled as _, div}; use super::{Highlights, custom_highlights::CustomHighlightsChunks, fold_map::ChunkRendererId}; @@ -48,7 +48,7 @@ pub struct Inlay { impl Inlay { pub fn hint(id: usize, position: Anchor, hint: &project::InlayHint) -> Self { let mut text = hint.text(); - if hint.padding_right && text.chars_at(text.len().saturating_sub(1)).next() != Some(' ') { + if hint.padding_right && text.reversed_chars_at(text.len()).next() != Some(' ') { text.push(" "); } if hint.padding_left && text.chars_at(0).next() != Some(' ') { @@ -245,8 +245,9 @@ pub struct InlayChunks<'a> { transforms: Cursor<'a, Transform, Dimensions>, buffer_chunks: CustomHighlightsChunks<'a>, buffer_chunk: Option>, - inlay_chunks: Option>, - inlay_chunk: Option<&'a str>, + inlay_chunks: Option>, + /// text, char bitmap, tabs bitmap + inlay_chunk: Option>, output_offset: InlayOffset, max_output_offset: InlayOffset, highlight_styles: HighlightStyles, @@ -316,11 +317,25 @@ impl<'a> Iterator for InlayChunks<'a> { let (prefix, suffix) = chunk.text.split_at(split_index); + let (chars, tabs) = if split_index == 128 { + let output = (chunk.chars, chunk.tabs); + chunk.chars = 0; + chunk.tabs = 0; + output + } else { + let mask = (1 << split_index) - 1; + let output = (chunk.chars & mask, chunk.tabs & mask); + chunk.chars = chunk.chars >> split_index; + chunk.tabs = chunk.tabs >> split_index; + output + }; chunk.text = suffix; self.output_offset.0 += prefix.len(); InlayChunk { chunk: Chunk { text: prefix, + chars, + tabs, ..chunk.clone() }, renderer: None, @@ -385,9 +400,9 @@ impl<'a> Iterator for 
InlayChunks<'a> { next_inlay_highlight_endpoint = usize::MAX; } else { next_inlay_highlight_endpoint = range.end - offset_in_inlay.0; - highlight_style - .get_or_insert_with(Default::default) - .highlight(*style); + highlight_style = highlight_style + .map(|highlight| highlight.highlight(*style)) + .or_else(|| Some(*style)); } } else { next_inlay_highlight_endpoint = usize::MAX; @@ -397,9 +412,14 @@ impl<'a> Iterator for InlayChunks<'a> { let start = offset_in_inlay; let end = cmp::min(self.max_output_offset, self.transforms.end().0) - self.transforms.start().0; - inlay.text.chunks_in_range(start.0..end.0) + let chunks = inlay.text.chunks_in_range(start.0..end.0); + text::ChunkWithBitmaps(chunks) }); - let inlay_chunk = self + let ChunkBitmaps { + text: inlay_chunk, + chars, + tabs, + } = self .inlay_chunk .get_or_insert_with(|| inlay_chunks.next().unwrap()); @@ -421,6 +441,20 @@ impl<'a> Iterator for InlayChunks<'a> { let (chunk, remainder) = inlay_chunk.split_at(split_index); *inlay_chunk = remainder; + + let (chars, tabs) = if split_index == 128 { + let output = (*chars, *tabs); + *chars = 0; + *tabs = 0; + output + } else { + let mask = (1 << split_index as u32) - 1; + let output = (*chars & mask, *tabs & mask); + *chars = *chars >> split_index; + *tabs = *tabs >> split_index; + output + }; + if inlay_chunk.is_empty() { self.inlay_chunk = None; } @@ -430,6 +464,8 @@ impl<'a> Iterator for InlayChunks<'a> { InlayChunk { chunk: Chunk { text: chunk, + chars, + tabs, highlight_style, is_inlay: true, ..Chunk::default() @@ -557,11 +593,11 @@ impl InlayMap { let mut buffer_edits_iter = buffer_edits.iter().peekable(); while let Some(buffer_edit) = buffer_edits_iter.next() { new_transforms.append(cursor.slice(&buffer_edit.old.start, Bias::Left), &()); - if let Some(Transform::Isomorphic(transform)) = cursor.item() { - if cursor.end().0 == buffer_edit.old.start { - push_isomorphic(&mut new_transforms, *transform); - cursor.next(); - } + if let 
Some(Transform::Isomorphic(transform)) = cursor.item() + && cursor.end().0 == buffer_edit.old.start + { + push_isomorphic(&mut new_transforms, *transform); + cursor.next(); } // Remove all the inlays and transforms contained by the edit. @@ -625,7 +661,7 @@ impl InlayMap { // we can push its remainder. if buffer_edits_iter .peek() - .map_or(true, |edit| edit.old.start >= cursor.end().0) + .is_none_or(|edit| edit.old.start >= cursor.end().0) { let transform_start = new_transforms.summary().input.len; let transform_end = @@ -719,14 +755,18 @@ impl InlayMap { let mut to_remove = Vec::new(); let mut to_insert = Vec::new(); let snapshot = &mut self.snapshot; - for i in 0..rng.gen_range(1..=5) { - if self.inlays.is_empty() || rng.r#gen() { + for i in 0..rng.random_range(1..=5) { + if self.inlays.is_empty() || rng.random() { let position = snapshot.buffer.random_byte_range(0, rng).start; - let bias = if rng.r#gen() { Bias::Left } else { Bias::Right }; - let len = if rng.gen_bool(0.01) { + let bias = if rng.random() { + Bias::Left + } else { + Bias::Right + }; + let len = if rng.random_bool(0.01) { 0 } else { - rng.gen_range(1..=5) + rng.random_range(1..=5) }; let text = util::RandomCharIter::new(&mut *rng) .filter(|ch| *ch != '\r') @@ -1220,6 +1260,7 @@ mod tests { use std::{any::TypeId, cmp::Reverse, env, sync::Arc}; use sum_tree::TreeMap; use text::Patch; + use util::RandomCharIter; use util::post_inc; #[test] @@ -1305,6 +1346,29 @@ mod tests { ); } + #[gpui::test] + fn test_inlay_hint_padding_with_multibyte_chars() { + assert_eq!( + Inlay::hint( + 0, + Anchor::min(), + &InlayHint { + label: InlayHintLabel::String("🎨".to_string()), + position: text::Anchor::default(), + padding_left: true, + padding_right: true, + tooltip: None, + kind: None, + resolve_state: ResolveState::Resolved, + }, + ) + .text + .to_string(), + " 🎨 ", + "Should pad single emoji correctly" + ); + } + #[gpui::test] fn test_basic_inlays(cx: &mut App) { let buffer = 
MultiBuffer::build_simple("abcdefghi", cx); @@ -1642,8 +1706,8 @@ mod tests { .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) .unwrap_or(10); - let len = rng.gen_range(0..30); - let buffer = if rng.r#gen() { + let len = rng.random_range(0..30); + let buffer = if rng.random() { let text = util::RandomCharIter::new(&mut rng) .take(len) .collect::(); @@ -1660,7 +1724,7 @@ mod tests { let mut prev_inlay_text = inlay_snapshot.text(); let mut buffer_edits = Vec::new(); - match rng.gen_range(0..=100) { + match rng.random_range(0..=100) { 0..=50 => { let (snapshot, edits) = inlay_map.randomly_mutate(&mut next_inlay_id, &mut rng); log::info!("mutated text: {:?}", snapshot.text()); @@ -1668,7 +1732,7 @@ mod tests { } _ => buffer.update(cx, |buffer, cx| { let subscription = buffer.subscribe(); - let edit_count = rng.gen_range(1..=5); + let edit_count = rng.random_range(1..=5); buffer.randomly_mutate(&mut rng, edit_count, cx); buffer_snapshot = buffer.snapshot(cx); let edits = subscription.consume().into_inner(); @@ -1717,7 +1781,7 @@ mod tests { } let mut text_highlights = TextHighlights::default(); - let text_highlight_count = rng.gen_range(0_usize..10); + let text_highlight_count = rng.random_range(0_usize..10); let mut text_highlight_ranges = (0..text_highlight_count) .map(|_| buffer_snapshot.random_byte_range(0, &mut rng)) .collect::>(); @@ -1739,10 +1803,10 @@ mod tests { let mut inlay_highlights = InlayHighlights::default(); if !inlays.is_empty() { - let inlay_highlight_count = rng.gen_range(0..inlays.len()); + let inlay_highlight_count = rng.random_range(0..inlays.len()); let mut inlay_indices = BTreeSet::default(); while inlay_indices.len() < inlay_highlight_count { - inlay_indices.insert(rng.gen_range(0..inlays.len())); + inlay_indices.insert(rng.random_range(0..inlays.len())); } let new_highlights = TreeMap::from_ordered_entries( inlay_indices @@ -1759,8 +1823,8 @@ mod tests { }), n => { let inlay_text = inlay.text.to_string(); - let mut highlight_end = 
rng.gen_range(1..n); - let mut highlight_start = rng.gen_range(0..highlight_end); + let mut highlight_end = rng.random_range(1..n); + let mut highlight_start = rng.random_range(0..highlight_end); while !inlay_text.is_char_boundary(highlight_end) { highlight_end += 1; } @@ -1782,9 +1846,9 @@ mod tests { } for _ in 0..5 { - let mut end = rng.gen_range(0..=inlay_snapshot.len().0); + let mut end = rng.random_range(0..=inlay_snapshot.len().0); end = expected_text.clip_offset(end, Bias::Right); - let mut start = rng.gen_range(0..=end); + let mut start = rng.random_range(0..=end); start = expected_text.clip_offset(start, Bias::Right); let range = InlayOffset(start)..InlayOffset(end); @@ -1939,6 +2003,102 @@ mod tests { } } + #[gpui::test(iterations = 100)] + fn test_random_chunk_bitmaps(cx: &mut gpui::App, mut rng: StdRng) { + init_test(cx); + + // Generate random buffer using existing test infrastructure + let text_len = rng.random_range(0..10000); + let buffer = if rng.random() { + let text = RandomCharIter::new(&mut rng) + .take(text_len) + .collect::(); + MultiBuffer::build_simple(&text, cx) + } else { + MultiBuffer::build_random(&mut rng, cx) + }; + + let buffer_snapshot = buffer.read(cx).snapshot(cx); + let (mut inlay_map, _) = InlayMap::new(buffer_snapshot.clone()); + + // Perform random mutations to add inlays + let mut next_inlay_id = 0; + let mutation_count = rng.random_range(1..10); + for _ in 0..mutation_count { + inlay_map.randomly_mutate(&mut next_inlay_id, &mut rng); + } + + let (snapshot, _) = inlay_map.sync(buffer_snapshot, vec![]); + + // Get all chunks and verify their bitmaps + let chunks = snapshot.chunks( + InlayOffset(0)..InlayOffset(snapshot.len().0), + false, + Highlights::default(), + ); + + for chunk in chunks.into_iter().map(|inlay_chunk| inlay_chunk.chunk) { + let chunk_text = chunk.text; + let chars_bitmap = chunk.chars; + let tabs_bitmap = chunk.tabs; + + // Check empty chunks have empty bitmaps + if chunk_text.is_empty() { + assert_eq!( + 
chars_bitmap, 0, + "Empty chunk should have empty chars bitmap" + ); + assert_eq!(tabs_bitmap, 0, "Empty chunk should have empty tabs bitmap"); + continue; + } + + // Verify that chunk text doesn't exceed 128 bytes + assert!( + chunk_text.len() <= 128, + "Chunk text length {} exceeds 128 bytes", + chunk_text.len() + ); + + // Verify chars bitmap + let char_indices = chunk_text + .char_indices() + .map(|(i, _)| i) + .collect::>(); + + for byte_idx in 0..chunk_text.len() { + let should_have_bit = char_indices.contains(&byte_idx); + let has_bit = chars_bitmap & (1 << byte_idx) != 0; + + if has_bit != should_have_bit { + eprintln!("Chunk text bytes: {:?}", chunk_text.as_bytes()); + eprintln!("Char indices: {:?}", char_indices); + eprintln!("Chars bitmap: {:#b}", chars_bitmap); + assert_eq!( + has_bit, should_have_bit, + "Chars bitmap mismatch at byte index {} in chunk {:?}. Expected bit: {}, Got bit: {}", + byte_idx, chunk_text, should_have_bit, has_bit + ); + } + } + + // Verify tabs bitmap + for (byte_idx, byte) in chunk_text.bytes().enumerate() { + let is_tab = byte == b'\t'; + let has_bit = tabs_bitmap & (1 << byte_idx) != 0; + + if has_bit != is_tab { + eprintln!("Chunk text bytes: {:?}", chunk_text.as_bytes()); + eprintln!("Tabs bitmap: {:#b}", tabs_bitmap); + assert_eq!( + has_bit, is_tab, + "Tabs bitmap mismatch at byte index {} in chunk {:?}. 
Byte: {:?}, Expected bit: {}, Got bit: {}", + byte_idx, chunk_text, byte as char, is_tab, has_bit + ); + } + } + } + } + fn init_test(cx: &mut App) { let store = SettingsStore::test(cx); cx.set_global(store); diff --git a/crates/editor/src/display_map/invisibles.rs b/crates/editor/src/display_map/invisibles.rs index 199986f2a41c82894acc0259be578a28e785a235..0712ddf9e2e53c22081c6fa63ebb4baeced37f78 100644 --- a/crates/editor/src/display_map/invisibles.rs +++ b/crates/editor/src/display_map/invisibles.rs @@ -36,8 +36,8 @@ pub fn is_invisible(c: char) -> bool { } else if c >= '\u{7f}' { c <= '\u{9f}' || (c.is_whitespace() && c != IDEOGRAPHIC_SPACE) - || contains(c, &FORMAT) - || contains(c, &OTHER) + || contains(c, FORMAT) + || contains(c, OTHER) } else { false } @@ -50,7 +50,7 @@ pub fn replacement(c: char) -> Option<&'static str> { Some(C0_SYMBOLS[c as usize]) } else if c == '\x7f' { Some(DEL) - } else if contains(c, &PRESERVE) { + } else if contains(c, PRESERVE) { None } else { Some("\u{2007}") // fixed width space @@ -61,14 +61,14 @@ pub fn replacement(c: char) -> Option<&'static str> { // but could if we tracked state in the classifier. 
const IDEOGRAPHIC_SPACE: char = '\u{3000}'; -const C0_SYMBOLS: &'static [&'static str] = &[ +const C0_SYMBOLS: &[&str] = &[ "␀", "␁", "␂", "␃", "␄", "␅", "␆", "␇", "␈", "␉", "␊", "␋", "␌", "␍", "␎", "␏", "␐", "␑", "␒", "␓", "␔", "␕", "␖", "␗", "␘", "␙", "␚", "␛", "␜", "␝", "␞", "␟", ]; -const DEL: &'static str = "␡"; +const DEL: &str = "␡"; // generated using ucd-generate: ucd-generate general-category --include Format --chars ucd-16.0.0 -pub const FORMAT: &'static [(char, char)] = &[ +pub const FORMAT: &[(char, char)] = &[ ('\u{ad}', '\u{ad}'), ('\u{600}', '\u{605}'), ('\u{61c}', '\u{61c}'), @@ -93,7 +93,7 @@ pub const FORMAT: &'static [(char, char)] = &[ ]; // hand-made base on https://invisible-characters.com (Excluding Cf) -pub const OTHER: &'static [(char, char)] = &[ +pub const OTHER: &[(char, char)] = &[ ('\u{034f}', '\u{034f}'), ('\u{115F}', '\u{1160}'), ('\u{17b4}', '\u{17b5}'), @@ -107,7 +107,7 @@ pub const OTHER: &'static [(char, char)] = &[ ]; // a subset of FORMAT/OTHER that may appear within glyphs -const PRESERVE: &'static [(char, char)] = &[ +const PRESERVE: &[(char, char)] = &[ ('\u{034f}', '\u{034f}'), ('\u{200d}', '\u{200d}'), ('\u{17b4}', '\u{17b5}'), diff --git a/crates/editor/src/display_map/tab_map.rs b/crates/editor/src/display_map/tab_map.rs index eb5d57d48472bdd4b2d4f150c40da05c7e422e19..e42d17123dfce9d0ca8c4faa84eabbaabf5707f4 100644 --- a/crates/editor/src/display_map/tab_map.rs +++ b/crates/editor/src/display_map/tab_map.rs @@ -2,6 +2,7 @@ use super::{ Highlights, fold_map::{self, Chunk, FoldChunks, FoldEdit, FoldPoint, FoldSnapshot}, }; + use language::Point; use multi_buffer::MultiBufferSnapshot; use std::{cmp, mem, num::NonZeroU32, ops::Range}; @@ -72,6 +73,7 @@ impl TabMap { false, Highlights::default(), ) { + // todo(performance use tabs bitmask) for (ix, _) in chunk.text.match_indices('\t') { let offset_from_edit = offset_from_edit + (ix as u32); if first_tab_offset.is_none() { @@ -116,7 +118,7 @@ impl TabMap { state.new.end = 
edit.new.end; Some(None) // Skip this edit, it's merged } else { - let new_state = edit.clone(); + let new_state = edit; let result = Some(Some(state.clone())); // Yield the previous edit **state = new_state; result @@ -230,7 +232,7 @@ impl TabSnapshot { } } - pub fn chunks<'a>( + pub(crate) fn chunks<'a>( &'a self, range: Range, language_aware: bool, @@ -299,21 +301,29 @@ impl TabSnapshot { } pub fn to_tab_point(&self, input: FoldPoint) -> TabPoint { - let chars = self.fold_snapshot.chars_at(FoldPoint::new(input.row(), 0)); - let expanded = self.expand_tabs(chars, input.column()); + let chunks = self.fold_snapshot.chunks_at(FoldPoint::new(input.row(), 0)); + let tab_cursor = TabStopCursor::new(chunks); + let expanded = self.expand_tabs(tab_cursor, input.column()); TabPoint::new(input.row(), expanded) } pub fn to_fold_point(&self, output: TabPoint, bias: Bias) -> (FoldPoint, u32, u32) { - let chars = self.fold_snapshot.chars_at(FoldPoint::new(output.row(), 0)); + let chunks = self + .fold_snapshot + .chunks_at(FoldPoint::new(output.row(), 0)); + + let tab_cursor = TabStopCursor::new(chunks); let expanded = output.column(); let (collapsed, expanded_char_column, to_next_stop) = - self.collapse_tabs(chars, expanded, bias); - ( + self.collapse_tabs(tab_cursor, expanded, bias); + + let result = ( FoldPoint::new(output.row(), collapsed), expanded_char_column, to_next_stop, - ) + ); + + result } pub fn make_tab_point(&self, point: Point, bias: Bias) -> TabPoint { @@ -330,72 +340,90 @@ impl TabSnapshot { .to_buffer_point(inlay_point) } - fn expand_tabs(&self, chars: impl Iterator, column: u32) -> u32 { + fn expand_tabs<'a, I>(&self, mut cursor: TabStopCursor<'a, I>, column: u32) -> u32 + where + I: Iterator>, + { let tab_size = self.tab_size.get(); - let mut expanded_chars = 0; - let mut expanded_bytes = 0; - let mut collapsed_bytes = 0; let end_column = column.min(self.max_expansion_column); - for c in chars { - if collapsed_bytes >= end_column { - break; - } - if c == 
'\t' { - let tab_len = tab_size - expanded_chars % tab_size; - expanded_bytes += tab_len; - expanded_chars += tab_len; - } else { - expanded_bytes += c.len_utf8() as u32; - expanded_chars += 1; - } - collapsed_bytes += c.len_utf8() as u32; + let mut seek_target = end_column; + let mut tab_count = 0; + let mut expanded_tab_len = 0; + + while let Some(tab_stop) = cursor.seek(seek_target) { + let expanded_chars_old = tab_stop.char_offset + expanded_tab_len - tab_count; + let tab_len = tab_size - ((expanded_chars_old - 1) % tab_size); + tab_count += 1; + expanded_tab_len += tab_len; + + seek_target = end_column - cursor.byte_offset; } + + let left_over_char_bytes = if !cursor.is_char_boundary() { + cursor.bytes_until_next_char().unwrap_or(0) as u32 + } else { + 0 + }; + + let collapsed_bytes = cursor.byte_offset() + left_over_char_bytes; + let expanded_bytes = + cursor.byte_offset() + expanded_tab_len - tab_count + left_over_char_bytes; + expanded_bytes + column.saturating_sub(collapsed_bytes) } - fn collapse_tabs( + fn collapse_tabs<'a, I>( &self, - chars: impl Iterator, + mut cursor: TabStopCursor<'a, I>, column: u32, bias: Bias, - ) -> (u32, u32, u32) { + ) -> (u32, u32, u32) + where + I: Iterator>, + { let tab_size = self.tab_size.get(); - - let mut expanded_bytes = 0; - let mut expanded_chars = 0; - let mut collapsed_bytes = 0; - for c in chars { - if expanded_bytes >= column { - break; - } - if collapsed_bytes >= self.max_expansion_column { - break; - } - - if c == '\t' { - let tab_len = tab_size - (expanded_chars % tab_size); - expanded_chars += tab_len; - expanded_bytes += tab_len; - if expanded_bytes > column { - expanded_chars -= expanded_bytes - column; - return match bias { - Bias::Left => (collapsed_bytes, expanded_chars, expanded_bytes - column), - Bias::Right => (collapsed_bytes + 1, expanded_chars, 0), - }; - } + let mut collapsed_column = column; + let mut seek_target = column.min(self.max_expansion_column); + let mut tab_count = 0; + let mut 
expanded_tab_len = 0; + + while let Some(tab_stop) = cursor.seek(seek_target) { + // Calculate how much we want to expand this tab stop (into spaces) + let expanded_chars_old = tab_stop.char_offset + expanded_tab_len - tab_count; + let tab_len = tab_size - ((expanded_chars_old - 1) % tab_size); + // Increment tab count + tab_count += 1; + // The count of how many spaces we've added to this line in place of tab bytes + expanded_tab_len += tab_len; + + // The count of bytes at this point in the iteration while considering tab_count and previous expansions + let expanded_bytes = tab_stop.byte_offset + expanded_tab_len - tab_count; + + // Did we expand past the search target? + if expanded_bytes > column { + let mut expanded_chars = tab_stop.char_offset + expanded_tab_len - tab_count; + // We expanded past the search target, so need to account for the offshoot + expanded_chars -= expanded_bytes - column; + return match bias { + Bias::Left => ( + cursor.byte_offset() - 1, + expanded_chars, + expanded_bytes - column, + ), + Bias::Right => (cursor.byte_offset(), expanded_chars, 0), + }; } else { - expanded_chars += 1; - expanded_bytes += c.len_utf8() as u32; - } - - if expanded_bytes > column && matches!(bias, Bias::Left) { - expanded_chars -= 1; - break; + // otherwise we only want to move the cursor collapse column forward + collapsed_column = collapsed_column - tab_len + 1; + seek_target = (collapsed_column - cursor.byte_offset) + .min(self.max_expansion_column - cursor.byte_offset); } - - collapsed_bytes += c.len_utf8() as u32; } + + let collapsed_bytes = cursor.byte_offset(); + let expanded_bytes = cursor.byte_offset() + expanded_tab_len - tab_count; + let expanded_chars = cursor.char_offset() + expanded_tab_len - tab_count; ( collapsed_bytes + column.saturating_sub(expanded_bytes), expanded_chars, @@ -523,6 +551,7 @@ impl TabChunks<'_> { self.chunk = Chunk { text: &SPACES[0..(to_next_stop as usize)], is_tab: true, + chars: (1u128 << to_next_stop) - 1, 
..Default::default() }; self.inside_leading_tab = to_next_stop > 0; @@ -546,18 +575,37 @@ impl<'a> Iterator for TabChunks<'a> { } } + //todo(improve performance by using tab cursor) for (ix, c) in self.chunk.text.char_indices() { match c { '\t' => { if ix > 0 { let (prefix, suffix) = self.chunk.text.split_at(ix); + + let (chars, tabs) = if ix == 128 { + let output = (self.chunk.chars, self.chunk.tabs); + self.chunk.chars = 0; + self.chunk.tabs = 0; + output + } else { + let mask = (1 << ix) - 1; + let output = (self.chunk.chars & mask, self.chunk.tabs & mask); + self.chunk.chars = self.chunk.chars >> ix; + self.chunk.tabs = self.chunk.tabs >> ix; + output + }; + self.chunk.text = suffix; return Some(Chunk { text: prefix, + chars, + tabs, ..self.chunk.clone() }); } else { self.chunk.text = &self.chunk.text[1..]; + self.chunk.tabs >>= 1; + self.chunk.chars >>= 1; let tab_size = if self.input_column < self.max_expansion_column { self.tab_size.get() } else { @@ -575,6 +623,8 @@ impl<'a> Iterator for TabChunks<'a> { return Some(Chunk { text: &SPACES[..len as usize], is_tab: true, + chars: (1 << len) - 1, + tabs: 0, ..self.chunk.clone() }); } @@ -603,21 +653,270 @@ mod tests { use super::*; use crate::{ MultiBuffer, - display_map::{fold_map::FoldMap, inlay_map::InlayMap}, + display_map::{ + fold_map::{FoldMap, FoldOffset}, + inlay_map::InlayMap, + }, }; use rand::{Rng, prelude::StdRng}; + use util; + + impl TabSnapshot { + fn expected_collapse_tabs( + &self, + chars: impl Iterator, + column: u32, + bias: Bias, + ) -> (u32, u32, u32) { + let tab_size = self.tab_size.get(); + + let mut expanded_bytes = 0; + let mut expanded_chars = 0; + let mut collapsed_bytes = 0; + for c in chars { + if expanded_bytes >= column { + break; + } + if collapsed_bytes >= self.max_expansion_column { + break; + } + + if c == '\t' { + let tab_len = tab_size - (expanded_chars % tab_size); + expanded_chars += tab_len; + expanded_bytes += tab_len; + if expanded_bytes > column { + expanded_chars -= 
expanded_bytes - column; + return match bias { + Bias::Left => { + (collapsed_bytes, expanded_chars, expanded_bytes - column) + } + Bias::Right => (collapsed_bytes + 1, expanded_chars, 0), + }; + } + } else { + expanded_chars += 1; + expanded_bytes += c.len_utf8() as u32; + } + + if expanded_bytes > column && matches!(bias, Bias::Left) { + expanded_chars -= 1; + break; + } + + collapsed_bytes += c.len_utf8() as u32; + } + + ( + collapsed_bytes + column.saturating_sub(expanded_bytes), + expanded_chars, + 0, + ) + } + + pub fn expected_to_tab_point(&self, input: FoldPoint) -> TabPoint { + let chars = self.fold_snapshot.chars_at(FoldPoint::new(input.row(), 0)); + let expanded = self.expected_expand_tabs(chars, input.column()); + TabPoint::new(input.row(), expanded) + } + + fn expected_expand_tabs(&self, chars: impl Iterator, column: u32) -> u32 { + let tab_size = self.tab_size.get(); + + let mut expanded_chars = 0; + let mut expanded_bytes = 0; + let mut collapsed_bytes = 0; + let end_column = column.min(self.max_expansion_column); + for c in chars { + if collapsed_bytes >= end_column { + break; + } + if c == '\t' { + let tab_len = tab_size - expanded_chars % tab_size; + expanded_bytes += tab_len; + expanded_chars += tab_len; + } else { + expanded_bytes += c.len_utf8() as u32; + expanded_chars += 1; + } + collapsed_bytes += c.len_utf8() as u32; + } + + expanded_bytes + column.saturating_sub(collapsed_bytes) + } + + fn expected_to_fold_point(&self, output: TabPoint, bias: Bias) -> (FoldPoint, u32, u32) { + let chars = self.fold_snapshot.chars_at(FoldPoint::new(output.row(), 0)); + let expanded = output.column(); + let (collapsed, expanded_char_column, to_next_stop) = + self.expected_collapse_tabs(chars, expanded, bias); + ( + FoldPoint::new(output.row(), collapsed), + expanded_char_column, + to_next_stop, + ) + } + } #[gpui::test] fn test_expand_tabs(cx: &mut gpui::App) { + let test_values = [ + ("κg🏀 f\nwo🏀❌by🍐❎β🍗c\tβ❎ \ncλ🎉", 17), + (" \twςe", 4), + ("fε", 1), + 
("i❎\t", 3), + ]; let buffer = MultiBuffer::build_simple("", cx); let buffer_snapshot = buffer.read(cx).snapshot(cx); - let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot.clone()); + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot); let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); let (_, tab_snapshot) = TabMap::new(fold_snapshot, 4.try_into().unwrap()); - assert_eq!(tab_snapshot.expand_tabs("\t".chars(), 0), 0); - assert_eq!(tab_snapshot.expand_tabs("\t".chars(), 1), 4); - assert_eq!(tab_snapshot.expand_tabs("\ta".chars(), 2), 5); + for (text, column) in test_values { + let mut tabs = 0u128; + let mut chars = 0u128; + for (idx, c) in text.char_indices() { + if c == '\t' { + tabs |= 1 << idx; + } + chars |= 1 << idx; + } + + let chunks = [Chunk { + text, + tabs, + chars, + ..Default::default() + }]; + + let cursor = TabStopCursor::new(chunks); + + assert_eq!( + tab_snapshot.expected_expand_tabs(text.chars(), column), + tab_snapshot.expand_tabs(cursor, column) + ); + } + } + + #[gpui::test] + fn test_collapse_tabs(cx: &mut gpui::App) { + let input = "A\tBC\tDEF\tG\tHI\tJ\tK\tL\tM"; + + let buffer = MultiBuffer::build_simple(input, cx); + let buffer_snapshot = buffer.read(cx).snapshot(cx); + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot); + let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); + let (_, tab_snapshot) = TabMap::new(fold_snapshot, 4.try_into().unwrap()); + + for (ix, _) in input.char_indices() { + let range = TabPoint::new(0, ix as u32)..tab_snapshot.max_point(); + + assert_eq!( + tab_snapshot.expected_to_fold_point(range.start, Bias::Left), + tab_snapshot.to_fold_point(range.start, Bias::Left), + "Failed with tab_point at column {ix}" + ); + assert_eq!( + tab_snapshot.expected_to_fold_point(range.start, Bias::Right), + tab_snapshot.to_fold_point(range.start, Bias::Right), + "Failed with tab_point at column {ix}" + ); + + assert_eq!( + tab_snapshot.expected_to_fold_point(range.end, Bias::Left), + 
tab_snapshot.to_fold_point(range.end, Bias::Left), + "Failed with tab_point at column {ix}" + ); + assert_eq!( + tab_snapshot.expected_to_fold_point(range.end, Bias::Right), + tab_snapshot.to_fold_point(range.end, Bias::Right), + "Failed with tab_point at column {ix}" + ); + } + } + + #[gpui::test] + fn test_to_fold_point_panic_reproduction(cx: &mut gpui::App) { + // This test reproduces a specific panic where to_fold_point returns incorrect results + let _text = "use macro_rules_attribute::apply;\nuse serde_json::Value;\nuse smol::{\n io::AsyncReadExt,\n process::{Command, Stdio},\n};\nuse smol_macros::main;\nuse std::io;\n\nfn test_random() {\n // Generate a random value\n let random_value = std::time::SystemTime::now()\n .duration_since(std::time::UNIX_EPOCH)\n .unwrap()\n .as_secs()\n % 100;\n\n // Create some complex nested data structures\n let mut vector = Vec::new();\n for i in 0..random_value {\n vector.push(i);\n }\n "; + + let text = "γ\tw⭐\n🍐🍗 \t"; + let buffer = MultiBuffer::build_simple(text, cx); + let buffer_snapshot = buffer.read(cx).snapshot(cx); + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot); + let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); + let (_, tab_snapshot) = TabMap::new(fold_snapshot, 4.try_into().unwrap()); + + // This should panic with the expected vs actual mismatch + let tab_point = TabPoint::new(0, 9); + let result = tab_snapshot.to_fold_point(tab_point, Bias::Left); + let expected = tab_snapshot.expected_to_fold_point(tab_point, Bias::Left); + + assert_eq!(result, expected); + } + + #[gpui::test(iterations = 100)] + fn test_collapse_tabs_random(cx: &mut gpui::App, mut rng: StdRng) { + // Generate random input string with up to 200 characters including tabs + // to stay within the MAX_EXPANSION_COLUMN limit of 256 + let len = rng.random_range(0..=2048); + let tab_size = NonZeroU32::new(rng.random_range(1..=4)).unwrap(); + let mut input = String::with_capacity(len); + + for _ in 0..len { + if rng.random_bool(0.1) { 
+ // 10% chance of inserting a tab + input.push('\t'); + } else { + // 90% chance of inserting a random ASCII character (excluding tab, newline, carriage return) + let ch = loop { + let ascii_code = rng.random_range(32..=126); // printable ASCII range + let ch = ascii_code as u8 as char; + if ch != '\t' { + break ch; + } + }; + input.push(ch); + } + } + + let buffer = MultiBuffer::build_simple(&input, cx); + let buffer_snapshot = buffer.read(cx).snapshot(cx); + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot); + let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); + let (_, mut tab_snapshot) = TabMap::new(fold_snapshot, 4.try_into().unwrap()); + tab_snapshot.max_expansion_column = rng.random_range(0..323); + tab_snapshot.tab_size = tab_size; + + for (ix, _) in input.char_indices() { + let range = TabPoint::new(0, ix as u32)..tab_snapshot.max_point(); + + assert_eq!( + tab_snapshot.expected_to_fold_point(range.start, Bias::Left), + tab_snapshot.to_fold_point(range.start, Bias::Left), + "Failed with input: {}, with idx: {ix}", + input + ); + assert_eq!( + tab_snapshot.expected_to_fold_point(range.start, Bias::Right), + tab_snapshot.to_fold_point(range.start, Bias::Right), + "Failed with input: {}, with idx: {ix}", + input + ); + + assert_eq!( + tab_snapshot.expected_to_fold_point(range.end, Bias::Left), + tab_snapshot.to_fold_point(range.end, Bias::Left), + "Failed with input: {}, with idx: {ix}", + input + ); + assert_eq!( + tab_snapshot.expected_to_fold_point(range.end, Bias::Right), + tab_snapshot.to_fold_point(range.end, Bias::Right), + "Failed with input: {}, with idx: {ix}", + input + ); + } } #[gpui::test] @@ -628,7 +927,7 @@ mod tests { let buffer = MultiBuffer::build_simple(input, cx); let buffer_snapshot = buffer.read(cx).snapshot(cx); - let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot.clone()); + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot); let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); let (_, mut tab_snapshot) = 
TabMap::new(fold_snapshot, 4.try_into().unwrap()); @@ -675,7 +974,7 @@ mod tests { let buffer = MultiBuffer::build_simple(input, cx); let buffer_snapshot = buffer.read(cx).snapshot(cx); - let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot.clone()); + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot); let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); let (_, mut tab_snapshot) = TabMap::new(fold_snapshot, 4.try_into().unwrap()); @@ -689,7 +988,7 @@ mod tests { let buffer = MultiBuffer::build_simple(input, cx); let buffer_snapshot = buffer.read(cx).snapshot(cx); - let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot.clone()); + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot); let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); let (_, tab_snapshot) = TabMap::new(fold_snapshot, 4.try_into().unwrap()); @@ -736,9 +1035,9 @@ mod tests { #[gpui::test(iterations = 100)] fn test_random_tabs(cx: &mut gpui::App, mut rng: StdRng) { - let tab_size = NonZeroU32::new(rng.gen_range(1..=4)).unwrap(); - let len = rng.gen_range(0..30); - let buffer = if rng.r#gen() { + let tab_size = NonZeroU32::new(rng.random_range(1..=4)).unwrap(); + let len = rng.random_range(0..30); + let buffer = if rng.random() { let text = util::RandomCharIter::new(&mut rng) .take(len) .collect::(); @@ -749,7 +1048,7 @@ mod tests { let buffer_snapshot = buffer.read(cx).snapshot(cx); log::info!("Buffer text: {:?}", buffer_snapshot.text()); - let (mut inlay_map, inlay_snapshot) = InlayMap::new(buffer_snapshot.clone()); + let (mut inlay_map, inlay_snapshot) = InlayMap::new(buffer_snapshot); log::info!("InlayMap text: {:?}", inlay_snapshot.text()); let (mut fold_map, _) = FoldMap::new(inlay_snapshot.clone()); fold_map.randomly_mutate(&mut rng); @@ -758,7 +1057,7 @@ mod tests { let (inlay_snapshot, _) = inlay_map.randomly_mutate(&mut 0, &mut rng); log::info!("InlayMap text: {:?}", inlay_snapshot.text()); - let (mut tab_map, _) = TabMap::new(fold_snapshot.clone(), tab_size); + let (mut 
tab_map, _) = TabMap::new(fold_snapshot, tab_size); let tabs_snapshot = tab_map.set_max_expansion_column(32); let text = text::Rope::from(tabs_snapshot.text().as_str()); @@ -769,11 +1068,11 @@ mod tests { ); for _ in 0..5 { - let end_row = rng.gen_range(0..=text.max_point().row); - let end_column = rng.gen_range(0..=text.line_len(end_row)); + let end_row = rng.random_range(0..=text.max_point().row); + let end_column = rng.random_range(0..=text.line_len(end_row)); let mut end = TabPoint(text.clip_point(Point::new(end_row, end_column), Bias::Right)); - let start_row = rng.gen_range(0..=text.max_point().row); - let start_column = rng.gen_range(0..=text.line_len(start_row)); + let start_row = rng.random_range(0..=text.max_point().row); + let start_column = rng.random_range(0..=text.line_len(start_row)); let mut start = TabPoint(text.clip_point(Point::new(start_row, start_column), Bias::Left)); if start > end { @@ -811,4 +1110,479 @@ mod tests { ); } } + + #[gpui::test(iterations = 100)] + fn test_to_tab_point_random(cx: &mut gpui::App, mut rng: StdRng) { + let tab_size = NonZeroU32::new(rng.random_range(1..=16)).unwrap(); + let len = rng.random_range(0..=2000); + + // Generate random text using RandomCharIter + let text = util::RandomCharIter::new(&mut rng) + .take(len) + .collect::(); + + // Create buffer and tab map + let buffer = MultiBuffer::build_simple(&text, cx); + let buffer_snapshot = buffer.read(cx).snapshot(cx); + let (mut inlay_map, inlay_snapshot) = InlayMap::new(buffer_snapshot); + let (mut fold_map, fold_snapshot) = FoldMap::new(inlay_snapshot); + let (mut tab_map, _) = TabMap::new(fold_snapshot, tab_size); + + let mut next_inlay_id = 0; + let (inlay_snapshot, inlay_edits) = inlay_map.randomly_mutate(&mut next_inlay_id, &mut rng); + let (fold_snapshot, fold_edits) = fold_map.read(inlay_snapshot, inlay_edits); + let max_fold_point = fold_snapshot.max_point(); + let (mut tab_snapshot, _) = tab_map.sync(fold_snapshot.clone(), fold_edits, tab_size); + + // 
Test random fold points + for _ in 0..50 { + tab_snapshot.max_expansion_column = rng.random_range(0..=256); + // Generate random fold point + let row = rng.random_range(0..=max_fold_point.row()); + let max_column = if row < max_fold_point.row() { + fold_snapshot.line_len(row) + } else { + max_fold_point.column() + }; + let column = rng.random_range(0..=max_column + 10); + let fold_point = FoldPoint::new(row, column); + + let actual = tab_snapshot.to_tab_point(fold_point); + let expected = tab_snapshot.expected_to_tab_point(fold_point); + + assert_eq!( + actual, expected, + "to_tab_point mismatch for fold_point {:?} in text {:?}", + fold_point, text + ); + } + } + + #[gpui::test] + fn test_tab_stop_cursor_utf8(cx: &mut gpui::App) { + let text = "\tfoo\tbarbarbar\t\tbaz\n"; + let buffer = MultiBuffer::build_simple(text, cx); + let buffer_snapshot = buffer.read(cx).snapshot(cx); + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot); + let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); + let chunks = fold_snapshot.chunks( + FoldOffset(0)..fold_snapshot.len(), + false, + Default::default(), + ); + let mut cursor = TabStopCursor::new(chunks); + assert!(cursor.seek(0).is_none()); + let mut tab_stops = Vec::new(); + + let mut all_tab_stops = Vec::new(); + let mut byte_offset = 0; + for (offset, ch) in buffer.read(cx).snapshot(cx).text().char_indices() { + byte_offset += ch.len_utf8() as u32; + + if ch == '\t' { + all_tab_stops.push(TabStop { + byte_offset, + char_offset: offset as u32 + 1, + }); + } + } + + while let Some(tab_stop) = cursor.seek(u32::MAX) { + tab_stops.push(tab_stop); + } + pretty_assertions::assert_eq!(tab_stops.as_slice(), all_tab_stops.as_slice(),); + + assert_eq!(cursor.byte_offset(), byte_offset); + } + + #[gpui::test] + fn test_tab_stop_with_end_range_utf8(cx: &mut gpui::App) { + let input = "A\tBC\t"; // DEF\tG\tHI\tJ\tK\tL\tM + + let buffer = MultiBuffer::build_simple(input, cx); + let buffer_snapshot = buffer.read(cx).snapshot(cx); + let 
(_, inlay_snapshot) = InlayMap::new(buffer_snapshot); + let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); + + let chunks = fold_snapshot.chunks_at(FoldPoint::new(0, 0)); + let mut cursor = TabStopCursor::new(chunks); + + let mut actual_tab_stops = Vec::new(); + + let mut expected_tab_stops = Vec::new(); + let mut byte_offset = 0; + for (offset, ch) in buffer.read(cx).snapshot(cx).text().char_indices() { + byte_offset += ch.len_utf8() as u32; + + if ch == '\t' { + expected_tab_stops.push(TabStop { + byte_offset, + char_offset: offset as u32 + 1, + }); + } + } + + while let Some(tab_stop) = cursor.seek(u32::MAX) { + actual_tab_stops.push(tab_stop); + } + pretty_assertions::assert_eq!(actual_tab_stops.as_slice(), expected_tab_stops.as_slice(),); + + assert_eq!(cursor.byte_offset(), byte_offset); + } + + #[gpui::test(iterations = 100)] + fn test_tab_stop_cursor_random_utf8(cx: &mut gpui::App, mut rng: StdRng) { + // Generate random input string with up to 512 characters including tabs + let len = rng.random_range(0..=2048); + let mut input = String::with_capacity(len); + + let mut skip_tabs = rng.random_bool(0.10); + for idx in 0..len { + if idx % 128 == 0 { + skip_tabs = rng.random_bool(0.10); + } + + if rng.random_bool(0.15) && !skip_tabs { + input.push('\t'); + } else { + let ch = loop { + let ascii_code = rng.random_range(32..=126); // printable ASCII range + let ch = ascii_code as u8 as char; + if ch != '\t' { + break ch; + } + }; + input.push(ch); + } + } + + // Build the buffer and create cursor + let buffer = MultiBuffer::build_simple(&input, cx); + let buffer_snapshot = buffer.read(cx).snapshot(cx); + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot.clone()); + let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); + + // First, collect all expected tab positions + let mut all_tab_stops = Vec::new(); + let mut byte_offset = 1; + let mut char_offset = 1; + for ch in buffer_snapshot.text().chars() { + if ch == '\t' { + all_tab_stops.push(TabStop { + 
byte_offset, + char_offset, + }); + } + byte_offset += ch.len_utf8() as u32; + char_offset += 1; + } + + // Test with various distances + let distances = vec![1, 5, 10, 50, 100, u32::MAX]; + // let distances = vec![150]; + + for distance in distances { + let chunks = fold_snapshot.chunks_at(FoldPoint::new(0, 0)); + let mut cursor = TabStopCursor::new(chunks); + + let mut found_tab_stops = Vec::new(); + let mut position = distance; + while let Some(tab_stop) = cursor.seek(position) { + found_tab_stops.push(tab_stop); + position = distance - tab_stop.byte_offset; + } + + let expected_found_tab_stops: Vec<_> = all_tab_stops + .iter() + .take_while(|tab_stop| tab_stop.byte_offset <= distance) + .cloned() + .collect(); + + pretty_assertions::assert_eq!( + found_tab_stops, + expected_found_tab_stops, + "TabStopCursor output mismatch for distance {}. Input: {:?}", + distance, + input + ); + + let final_position = cursor.byte_offset(); + if !found_tab_stops.is_empty() { + let last_tab_stop = found_tab_stops.last().unwrap(); + assert!( + final_position >= last_tab_stop.byte_offset, + "Cursor final position {} is before last tab stop {}. 
Input: {:?}", + final_position, + last_tab_stop.byte_offset, + input + ); + } + } + } + + #[gpui::test] + fn test_tab_stop_cursor_utf16(cx: &mut gpui::App) { + let text = "\r\t😁foo\tb😀arbar🤯bar\t\tbaz\n"; + let buffer = MultiBuffer::build_simple(text, cx); + let buffer_snapshot = buffer.read(cx).snapshot(cx); + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot); + let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); + let chunks = fold_snapshot.chunks( + FoldOffset(0)..fold_snapshot.len(), + false, + Default::default(), + ); + let mut cursor = TabStopCursor::new(chunks); + assert!(cursor.seek(0).is_none()); + + let mut expected_tab_stops = Vec::new(); + let mut byte_offset = 0; + for (i, ch) in fold_snapshot.chars_at(FoldPoint::new(0, 0)).enumerate() { + byte_offset += ch.len_utf8() as u32; + + if ch == '\t' { + expected_tab_stops.push(TabStop { + byte_offset, + char_offset: i as u32 + 1, + }); + } + } + + let mut actual_tab_stops = Vec::new(); + while let Some(tab_stop) = cursor.seek(u32::MAX) { + actual_tab_stops.push(tab_stop); + } + + pretty_assertions::assert_eq!(actual_tab_stops.as_slice(), expected_tab_stops.as_slice(),); + + assert_eq!(cursor.byte_offset(), byte_offset); + } + + #[gpui::test(iterations = 100)] + fn test_tab_stop_cursor_random_utf16(cx: &mut gpui::App, mut rng: StdRng) { + // Generate random input string with up to 512 characters including tabs + let len = rng.random_range(0..=2048); + let input = util::RandomCharIter::new(&mut rng) + .take(len) + .collect::(); + + // Build the buffer and create cursor + let buffer = MultiBuffer::build_simple(&input, cx); + let buffer_snapshot = buffer.read(cx).snapshot(cx); + let (_, inlay_snapshot) = InlayMap::new(buffer_snapshot.clone()); + let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); + + // First, collect all expected tab positions + let mut all_tab_stops = Vec::new(); + let mut byte_offset = 0; + for (i, ch) in buffer_snapshot.text().chars().enumerate() { + byte_offset += ch.len_utf8() 
as u32; + if ch == '\t' { + all_tab_stops.push(TabStop { + byte_offset, + char_offset: i as u32 + 1, + }); + } + } + + // Test with various distances + // let distances = vec![1, 5, 10, 50, 100, u32::MAX]; + let distances = vec![150]; + + for distance in distances { + let chunks = fold_snapshot.chunks_at(FoldPoint::new(0, 0)); + let mut cursor = TabStopCursor::new(chunks); + + let mut found_tab_stops = Vec::new(); + let mut position = distance; + while let Some(tab_stop) = cursor.seek(position) { + found_tab_stops.push(tab_stop); + position = distance - tab_stop.byte_offset; + } + + let expected_found_tab_stops: Vec<_> = all_tab_stops + .iter() + .take_while(|tab_stop| tab_stop.byte_offset <= distance) + .cloned() + .collect(); + + pretty_assertions::assert_eq!( + found_tab_stops, + expected_found_tab_stops, + "TabStopCursor output mismatch for distance {}. Input: {:?}", + distance, + input + ); + + let final_position = cursor.byte_offset(); + if !found_tab_stops.is_empty() { + let last_tab_stop = found_tab_stops.last().unwrap(); + assert!( + final_position >= last_tab_stop.byte_offset, + "Cursor final position {} is before last tab stop {}. 
Input: {:?}", + final_position, + last_tab_stop.byte_offset, + input + ); + } + } + } +} + +struct TabStopCursor<'a, I> +where + I: Iterator>, +{ + chunks: I, + byte_offset: u32, + char_offset: u32, + /// Chunk + /// last tab position iterated through + current_chunk: Option<(Chunk<'a>, u32)>, +} + +impl<'a, I> TabStopCursor<'a, I> +where + I: Iterator>, +{ + fn new(chunks: impl IntoIterator, IntoIter = I>) -> Self { + Self { + chunks: chunks.into_iter(), + byte_offset: 0, + char_offset: 0, + current_chunk: None, + } + } + + fn bytes_until_next_char(&self) -> Option { + self.current_chunk.as_ref().and_then(|(chunk, idx)| { + let mut idx = *idx; + let mut diff = 0; + while idx > 0 && chunk.chars & (1 << idx) == 0 { + idx -= 1; + diff += 1; + } + + if chunk.chars & (1 << idx) != 0 { + Some( + (chunk.text[idx as usize..].chars().next()?) + .len_utf8() + .saturating_sub(diff), + ) + } else { + None + } + }) + } + + fn is_char_boundary(&self) -> bool { + self.current_chunk + .as_ref() + .is_some_and(|(chunk, idx)| (chunk.chars & (1 << *idx.min(&127))) != 0) + } + + /// distance: length to move forward while searching for the next tab stop + fn seek(&mut self, distance: u32) -> Option { + if distance == 0 { + return None; + } + + let mut distance_traversed = 0; + + while let Some((mut chunk, chunk_position)) = self + .current_chunk + .take() + .or_else(|| self.chunks.next().zip(Some(0))) + { + if chunk.tabs == 0 { + let chunk_distance = chunk.text.len() as u32 - chunk_position; + if chunk_distance + distance_traversed >= distance { + let overshoot = distance_traversed.abs_diff(distance); + + self.byte_offset += overshoot; + self.char_offset += get_char_offset( + chunk_position..(chunk_position + overshoot).saturating_sub(1).min(127), + chunk.chars, + ); + + self.current_chunk = Some((chunk, chunk_position + overshoot)); + + return None; + } + + self.byte_offset += chunk_distance; + self.char_offset += get_char_offset( + chunk_position..(chunk_position + 
chunk_distance).saturating_sub(1).min(127), + chunk.chars, + ); + distance_traversed += chunk_distance; + continue; + } + let tab_position = chunk.tabs.trailing_zeros() + 1; + + if distance_traversed + tab_position - chunk_position > distance { + let cursor_position = distance_traversed.abs_diff(distance); + + self.char_offset += get_char_offset( + chunk_position..(chunk_position + cursor_position - 1), + chunk.chars, + ); + self.current_chunk = Some((chunk, cursor_position + chunk_position)); + self.byte_offset += cursor_position; + + return None; + } + + self.byte_offset += tab_position - chunk_position; + self.char_offset += get_char_offset(chunk_position..(tab_position - 1), chunk.chars); + + let tabstop = TabStop { + char_offset: self.char_offset, + byte_offset: self.byte_offset, + }; + + chunk.tabs = (chunk.tabs - 1) & chunk.tabs; + + if tab_position as usize != chunk.text.len() { + self.current_chunk = Some((chunk, tab_position)); + } + + return Some(tabstop); + } + + None + } + + fn byte_offset(&self) -> u32 { + self.byte_offset + } + + fn char_offset(&self) -> u32 { + self.char_offset + } +} + +#[inline(always)] +fn get_char_offset(range: Range, bit_map: u128) -> u32 { + // This edge case can happen when we're at chunk position 128 + + if range.start == range.end { + return if (1u128 << range.start) & bit_map == 0 { + 0 + } else { + 1 + }; + } + let end_shift: u128 = 127u128 - range.end.min(127) as u128; + let mut bit_mask = (u128::MAX >> range.start) << range.start; + bit_mask = (bit_mask << end_shift) >> end_shift; + let bit_map = bit_map & bit_mask; + + bit_map.count_ones() +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +struct TabStop { + char_offset: u32, + byte_offset: u32, } diff --git a/crates/editor/src/display_map/wrap_map.rs b/crates/editor/src/display_map/wrap_map.rs index caa4882a6ebbb00aaa1e498e49dfb530153a0e8e..cd354d8229634956651ab74dd384332db0eb219e 100644 --- a/crates/editor/src/display_map/wrap_map.rs +++ 
b/crates/editor/src/display_map/wrap_map.rs @@ -74,10 +74,10 @@ impl WrapRows<'_> { self.transforms .seek(&WrapPoint::new(start_row, 0), Bias::Left); let mut input_row = self.transforms.start().1.row(); - if self.transforms.item().map_or(false, |t| t.is_isomorphic()) { + if self.transforms.item().is_some_and(|t| t.is_isomorphic()) { input_row += start_row - self.transforms.start().0.row(); } - self.soft_wrapped = self.transforms.item().map_or(false, |t| !t.is_isomorphic()); + self.soft_wrapped = self.transforms.item().is_some_and(|t| !t.is_isomorphic()); self.input_buffer_rows.seek(input_row); self.input_buffer_row = self.input_buffer_rows.next().unwrap(); self.output_row = start_row; @@ -249,48 +249,48 @@ impl WrapMap { return; } - if let Some(wrap_width) = self.wrap_width { - if self.background_task.is_none() { - let pending_edits = self.pending_edits.clone(); - let mut snapshot = self.snapshot.clone(); - let text_system = cx.text_system().clone(); - let (font, font_size) = self.font_with_size.clone(); - let update_task = cx.background_spawn(async move { - let mut edits = Patch::default(); - let mut line_wrapper = text_system.line_wrapper(font, font_size); - for (tab_snapshot, tab_edits) in pending_edits { - let wrap_edits = snapshot - .update(tab_snapshot, &tab_edits, wrap_width, &mut line_wrapper) - .await; - edits = edits.compose(&wrap_edits); - } - (snapshot, edits) - }); + if let Some(wrap_width) = self.wrap_width + && self.background_task.is_none() + { + let pending_edits = self.pending_edits.clone(); + let mut snapshot = self.snapshot.clone(); + let text_system = cx.text_system().clone(); + let (font, font_size) = self.font_with_size.clone(); + let update_task = cx.background_spawn(async move { + let mut edits = Patch::default(); + let mut line_wrapper = text_system.line_wrapper(font, font_size); + for (tab_snapshot, tab_edits) in pending_edits { + let wrap_edits = snapshot + .update(tab_snapshot, &tab_edits, wrap_width, &mut line_wrapper) + .await; + 
edits = edits.compose(&wrap_edits); + } + (snapshot, edits) + }); - match cx - .background_executor() - .block_with_timeout(Duration::from_millis(1), update_task) - { - Ok((snapshot, output_edits)) => { - self.snapshot = snapshot; - self.edits_since_sync = self.edits_since_sync.compose(&output_edits); - } - Err(update_task) => { - self.background_task = Some(cx.spawn(async move |this, cx| { - let (snapshot, edits) = update_task.await; - this.update(cx, |this, cx| { - this.snapshot = snapshot; - this.edits_since_sync = this - .edits_since_sync - .compose(mem::take(&mut this.interpolated_edits).invert()) - .compose(&edits); - this.background_task = None; - this.flush_edits(cx); - cx.notify(); - }) - .ok(); - })); - } + match cx + .background_executor() + .block_with_timeout(Duration::from_millis(1), update_task) + { + Ok((snapshot, output_edits)) => { + self.snapshot = snapshot; + self.edits_since_sync = self.edits_since_sync.compose(&output_edits); + } + Err(update_task) => { + self.background_task = Some(cx.spawn(async move |this, cx| { + let (snapshot, edits) = update_task.await; + this.update(cx, |this, cx| { + this.snapshot = snapshot; + this.edits_since_sync = this + .edits_since_sync + .compose(mem::take(&mut this.interpolated_edits).invert()) + .compose(&edits); + this.background_task = None; + this.flush_edits(cx); + cx.notify(); + }) + .ok(); + })); } } } @@ -603,7 +603,7 @@ impl WrapSnapshot { .cursor::>(&()); transforms.seek(&output_start, Bias::Right); let mut input_start = TabPoint(transforms.start().1.0); - if transforms.item().map_or(false, |t| t.is_isomorphic()) { + if transforms.item().is_some_and(|t| t.is_isomorphic()) { input_start.0 += output_start.0 - transforms.start().0.0; } let input_end = self @@ -634,7 +634,7 @@ impl WrapSnapshot { cursor.seek(&WrapPoint::new(row + 1, 0), Bias::Left); if cursor .item() - .map_or(false, |transform| transform.is_isomorphic()) + .is_some_and(|transform| transform.is_isomorphic()) { let overshoot = row - 
cursor.start().0.row(); let tab_row = cursor.start().1.row() + overshoot; @@ -732,10 +732,10 @@ impl WrapSnapshot { .cursor::>(&()); transforms.seek(&WrapPoint::new(start_row, 0), Bias::Left); let mut input_row = transforms.start().1.row(); - if transforms.item().map_or(false, |t| t.is_isomorphic()) { + if transforms.item().is_some_and(|t| t.is_isomorphic()) { input_row += start_row - transforms.start().0.row(); } - let soft_wrapped = transforms.item().map_or(false, |t| !t.is_isomorphic()); + let soft_wrapped = transforms.item().is_some_and(|t| !t.is_isomorphic()); let mut input_buffer_rows = self.tab_snapshot.rows(input_row); let input_buffer_row = input_buffer_rows.next().unwrap(); WrapRows { @@ -754,7 +754,7 @@ impl WrapSnapshot { .cursor::>(&()); cursor.seek(&point, Bias::Right); let mut tab_point = cursor.start().1.0; - if cursor.item().map_or(false, |t| t.is_isomorphic()) { + if cursor.item().is_some_and(|t| t.is_isomorphic()) { tab_point += point.0 - cursor.start().0.0; } TabPoint(tab_point) @@ -780,7 +780,7 @@ impl WrapSnapshot { if bias == Bias::Left { let mut cursor = self.transforms.cursor::(&()); cursor.seek(&point, Bias::Right); - if cursor.item().map_or(false, |t| !t.is_isomorphic()) { + if cursor.item().is_some_and(|t| !t.is_isomorphic()) { point = *cursor.start(); *point.column_mut() -= 1; } @@ -901,7 +901,7 @@ impl WrapChunks<'_> { let output_end = WrapPoint::new(rows.end, 0); self.transforms.seek(&output_start, Bias::Right); let mut input_start = TabPoint(self.transforms.start().1.0); - if self.transforms.item().map_or(false, |t| t.is_isomorphic()) { + if self.transforms.item().is_some_and(|t| t.is_isomorphic()) { input_start.0 += output_start.0 - self.transforms.start().0.0; } let input_end = self @@ -970,9 +970,25 @@ impl<'a> Iterator for WrapChunks<'a> { } let (prefix, suffix) = self.input_chunk.text.split_at(input_len); + + let (chars, tabs) = if input_len == 128 { + let output = (self.input_chunk.chars, self.input_chunk.tabs); + 
self.input_chunk.chars = 0; + self.input_chunk.tabs = 0; + output + } else { + let mask = (1 << input_len) - 1; + let output = (self.input_chunk.chars & mask, self.input_chunk.tabs & mask); + self.input_chunk.chars = self.input_chunk.chars >> input_len; + self.input_chunk.tabs = self.input_chunk.tabs >> input_len; + output + }; + self.input_chunk.text = suffix; Some(Chunk { text: prefix, + chars, + tabs, ..self.input_chunk.clone() }) } @@ -993,7 +1009,7 @@ impl Iterator for WrapRows<'_> { self.output_row += 1; self.transforms .seek_forward(&WrapPoint::new(self.output_row, 0), Bias::Left); - if self.transforms.item().map_or(false, |t| t.is_isomorphic()) { + if self.transforms.item().is_some_and(|t| t.is_isomorphic()) { self.input_buffer_row = self.input_buffer_rows.next().unwrap(); self.soft_wrapped = false; } else { @@ -1065,12 +1081,12 @@ impl sum_tree::Item for Transform { } fn push_isomorphic(transforms: &mut Vec, summary: TextSummary) { - if let Some(last_transform) = transforms.last_mut() { - if last_transform.is_isomorphic() { - last_transform.summary.input += &summary; - last_transform.summary.output += &summary; - return; - } + if let Some(last_transform) = transforms.last_mut() + && last_transform.is_isomorphic() + { + last_transform.summary.input += &summary; + last_transform.summary.output += &summary; + return; } transforms.push(Transform::isomorphic(summary)); } @@ -1215,12 +1231,12 @@ mod tests { .unwrap_or(10); let text_system = cx.read(|cx| cx.text_system().clone()); - let mut wrap_width = if rng.gen_bool(0.1) { + let mut wrap_width = if rng.random_bool(0.1) { None } else { - Some(px(rng.gen_range(0.0..=1000.0))) + Some(px(rng.random_range(0.0..=1000.0))) }; - let tab_size = NonZeroU32::new(rng.gen_range(1..=4)).unwrap(); + let tab_size = NonZeroU32::new(rng.random_range(1..=4)).unwrap(); let font = test_font(); let _font_id = text_system.resolve_font(&font); @@ -1230,10 +1246,10 @@ mod tests { log::info!("Wrap width: {:?}", wrap_width); let buffer 
= cx.update(|cx| { - if rng.r#gen() { + if rng.random() { MultiBuffer::build_random(&mut rng, cx) } else { - let len = rng.gen_range(0..10); + let len = rng.random_range(0..10); let text = util::RandomCharIter::new(&mut rng) .take(len) .collect::(); @@ -1281,12 +1297,12 @@ mod tests { log::info!("{} ==============================================", _i); let mut buffer_edits = Vec::new(); - match rng.gen_range(0..=100) { + match rng.random_range(0..=100) { 0..=19 => { - wrap_width = if rng.gen_bool(0.2) { + wrap_width = if rng.random_bool(0.2) { None } else { - Some(px(rng.gen_range(0.0..=1000.0))) + Some(px(rng.random_range(0.0..=1000.0))) }; log::info!("Setting wrap width to {:?}", wrap_width); wrap_map.update(cx, |map, cx| map.set_wrap_width(wrap_width, cx)); @@ -1317,7 +1333,7 @@ mod tests { _ => { buffer.update(cx, |buffer, cx| { let subscription = buffer.subscribe(); - let edit_count = rng.gen_range(1..=5); + let edit_count = rng.random_range(1..=5); buffer.randomly_mutate(&mut rng, edit_count, cx); buffer_snapshot = buffer.snapshot(cx); buffer_edits.extend(subscription.consume()); @@ -1341,7 +1357,7 @@ mod tests { snapshot.verify_chunks(&mut rng); edits.push((snapshot, wrap_edits)); - if wrap_map.read_with(cx, |map, _| map.is_rewrapping()) && rng.gen_bool(0.4) { + if wrap_map.read_with(cx, |map, _| map.is_rewrapping()) && rng.random_bool(0.4) { log::info!("Waiting for wrapping to finish"); while wrap_map.read_with(cx, |map, _| map.is_rewrapping()) { notifications.next().await.unwrap(); @@ -1461,7 +1477,7 @@ mod tests { } let mut prev_ix = 0; - for boundary in line_wrapper.wrap_line(&[LineFragment::text(&line)], wrap_width) { + for boundary in line_wrapper.wrap_line(&[LineFragment::text(line)], wrap_width) { wrapped_text.push_str(&line[prev_ix..boundary.ix]); wrapped_text.push('\n'); wrapped_text.push_str(&" ".repeat(boundary.next_indent as usize)); @@ -1479,8 +1495,8 @@ mod tests { impl WrapSnapshot { fn verify_chunks(&mut self, rng: &mut impl Rng) { for _ in 
0..5 { - let mut end_row = rng.gen_range(0..=self.max_point().row()); - let start_row = rng.gen_range(0..=end_row); + let mut end_row = rng.random_range(0..=self.max_point().row()); + let start_row = rng.random_range(0..=end_row); end_row += 1; let mut expected_text = self.text_chunks(start_row).collect::(); diff --git a/crates/editor/src/edit_prediction_tests.rs b/crates/editor/src/edit_prediction_tests.rs index 7bf51e45d72f383b4af34cf6ad493792f8e9d351..bba632e81f77ba91927abd1c0e3448a732e1c6f5 100644 --- a/crates/editor/src/edit_prediction_tests.rs +++ b/crates/editor/src/edit_prediction_tests.rs @@ -7,7 +7,9 @@ use std::ops::Range; use text::{Point, ToOffset}; use crate::{ - EditPrediction, editor_tests::init_test, test::editor_test_context::EditorTestContext, + EditPrediction, + editor_tests::{init_test, update_test_language_settings}, + test::editor_test_context::EditorTestContext, }; #[gpui::test] @@ -271,6 +273,44 @@ async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui: }); } +#[gpui::test] +async fn test_edit_predictions_disabled_in_scope(cx: &mut gpui::TestAppContext) { + init_test(cx, |_| {}); + + update_test_language_settings(cx, |settings| { + settings.defaults.edit_predictions_disabled_in = Some(vec!["string".to_string()]); + }); + + let mut cx = EditorTestContext::new(cx).await; + let provider = cx.new(|_| FakeEditPredictionProvider::default()); + assign_editor_completion_provider(provider.clone(), &mut cx); + + let language = languages::language("javascript", tree_sitter_typescript::LANGUAGE_TSX.into()); + cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); + + // Test disabled inside of string + cx.set_state("const x = \"hello ˇworld\";"); + propose_edits(&provider, vec![(17..17, "beautiful ")], &mut cx); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); + cx.editor(|editor, _, _| { + assert!( + editor.active_edit_prediction.is_none(), + "Edit predictions should 
be disabled in string scopes when configured in edit_predictions_disabled_in" + ); + }); + + // Test enabled outside of string + cx.set_state("const x = \"hello world\"; ˇ"); + propose_edits(&provider, vec![(24..24, "// comment")], &mut cx); + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); + cx.editor(|editor, _, _| { + assert!( + editor.active_edit_prediction.is_some(), + "Edit predictions should work outside of disabled scopes" + ); + }); +} + fn assert_editor_active_edit_completion( cx: &mut EditorTestContext, assert: impl FnOnce(MultiBufferSnapshot, &Vec<(Range, String)>), diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index cbee9021ed6b22ce36ea9f0473eacab52329a971..b731006a62990b5b9de75223ca38fbebb684c91c 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -147,23 +147,24 @@ use multi_buffer::{ use parking_lot::Mutex; use persistence::DB; use project::{ - BreakpointWithPosition, CodeAction, Completion, CompletionIntent, CompletionResponse, - CompletionSource, DisableAiSettings, DocumentHighlight, InlayHint, Location, LocationLink, - PrepareRenameResponse, Project, ProjectItem, ProjectPath, ProjectTransaction, TaskSourceKind, - debugger::breakpoint_store::Breakpoint, + BreakpointWithPosition, CodeAction, Completion, CompletionDisplayOptions, CompletionIntent, + CompletionResponse, CompletionSource, DisableAiSettings, DocumentHighlight, InlayHint, + Location, LocationLink, PrepareRenameResponse, Project, ProjectItem, ProjectPath, + ProjectTransaction, TaskSourceKind, debugger::{ breakpoint_store::{ - BreakpointEditAction, BreakpointSessionState, BreakpointState, BreakpointStore, - BreakpointStoreEvent, + Breakpoint, BreakpointEditAction, BreakpointSessionState, BreakpointState, + BreakpointStore, BreakpointStoreEvent, }, session::{Session, SessionEvent}, }, git_store::{GitStoreEvent, RepositoryEvent}, lsp_store::{CompletionDocumentation, FormatTrigger, LspFormatTarget, 
OpenLspBufferHandle}, - project_settings::{DiagnosticSeverity, GoToDiagnosticSeverityFilter}, - project_settings::{GitGutterSetting, ProjectSettings}, + project_settings::{ + DiagnosticSeverity, GitGutterSetting, GoToDiagnosticSeverityFilter, ProjectSettings, + }, }; -use rand::{seq::SliceRandom, thread_rng}; +use rand::seq::SliceRandom; use rpc::{ErrorCode, ErrorExt, proto::PeerId}; use scroll::{Autoscroll, OngoingScroll, ScrollAnchor, ScrollManager, ScrollbarAutoHide}; use selections_collection::{ @@ -189,7 +190,6 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; -use sum_tree::TreeMap; use task::{ResolvedTask, RunnableTag, TaskTemplate, TaskVariables}; use text::{BufferId, FromAnchor, OffsetUtf16, Rope}; use theme::{ @@ -219,7 +219,6 @@ use crate::{ pub const FILE_HEADER_HEIGHT: u32 = 2; pub const MULTI_BUFFER_EXCERPT_HEADER_HEIGHT: u32 = 1; -pub const DEFAULT_MULTIBUFFER_CONTEXT: u32 = 2; const CURSOR_BLINK_INTERVAL: Duration = Duration::from_millis(500); const MAX_LINE_LEN: usize = 1024; const MIN_NAVIGATION_HISTORY_ROW_DELTA: i64 = 10; @@ -227,7 +226,7 @@ const MAX_SELECTION_HISTORY_LEN: usize = 1024; pub(crate) const CURSORS_VISIBLE_FOR: Duration = Duration::from_millis(2000); #[doc(hidden)] pub const CODE_ACTIONS_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(250); -const SELECTION_HIGHLIGHT_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(100); +pub const SELECTION_HIGHLIGHT_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(100); pub(crate) const CODE_ACTION_TIMEOUT: Duration = Duration::from_secs(5); pub(crate) const FORMAT_TIMEOUT: Duration = Duration::from_secs(5); @@ -253,7 +252,6 @@ pub type RenderDiffHunkControlsFn = Arc< enum ReportEditorEvent { Saved { auto_saved: bool }, EditorOpened, - ZetaTosClicked, Closed, } @@ -262,7 +260,6 @@ impl ReportEditorEvent { match self { Self::Saved { .. 
} => "Editor Saved", Self::EditorOpened => "Editor Opened", - Self::ZetaTosClicked => "Edit Prediction Provider ToS Clicked", Self::Closed => "Editor Closed", } } @@ -782,10 +779,7 @@ impl MinimapVisibility { } fn disabled(&self) -> bool { - match *self { - Self::Disabled => true, - _ => false, - } + matches!(*self, Self::Disabled) } fn settings_visibility(&self) -> bool { @@ -942,10 +936,10 @@ impl ChangeList { } pub fn invert_last_group(&mut self) { - if let Some(last) = self.changes.last_mut() { - if let Some(current) = last.current.as_mut() { - mem::swap(&mut last.original, current); - } + if let Some(last) = self.changes.last_mut() + && let Some(current) = last.current.as_mut() + { + mem::swap(&mut last.original, current); } } } @@ -1014,6 +1008,7 @@ pub struct Editor { /// Map of how text in the buffer should be displayed. /// Handles soft wraps, folds, fake inlay text insertions, etc. pub display_map: Entity, + placeholder_display_map: Option>, pub selections: SelectionsCollection, pub scroll_manager: ScrollManager, /// When inline assist editors are linked, they all render cursors because @@ -1036,12 +1031,11 @@ pub struct Editor { inline_diagnostics_update: Task<()>, inline_diagnostics_enabled: bool, diagnostics_enabled: bool, + word_completions_enabled: bool, inline_diagnostics: Vec<(Anchor, InlineDiagnostic)>, soft_wrap_mode_override: Option, hard_wrap: Option, - - // TODO: make this a access method - pub project: Option>, + project: Option>, semantics_provider: Option>, completion_provider: Option>, collaboration_hub: Option>, @@ -1064,11 +1058,10 @@ pub struct Editor { show_breakpoints: Option, show_wrap_guides: Option, show_indent_guides: Option, - placeholder_text: Option>, highlight_order: usize, highlighted_rows: HashMap>, - background_highlights: TreeMap, - gutter_highlights: TreeMap, + background_highlights: HashMap, + gutter_highlights: HashMap, scrollbar_marker_state: ScrollbarMarkerState, active_indent_guides_state: ActiveIndentGuidesState, 
nav_history: Option, @@ -1216,7 +1209,7 @@ pub struct EditorSnapshot { show_breakpoints: Option, git_blame_gutter_max_author_length: Option, pub display_snapshot: DisplaySnapshot, - pub placeholder_text: Option>, + pub placeholder_display_snapshot: Option, is_focused: bool, scroll_anchor: ScrollAnchor, ongoing_scroll: OngoingScroll, @@ -1431,7 +1424,7 @@ impl SelectionHistory { if self .undo_stack .back() - .map_or(true, |e| e.selections != entry.selections) + .is_none_or(|e| e.selections != entry.selections) { self.undo_stack.push_back(entry); if self.undo_stack.len() > MAX_SELECTION_HISTORY_LEN { @@ -1444,7 +1437,7 @@ impl SelectionHistory { if self .redo_stack .back() - .map_or(true, |e| e.selections != entry.selections) + .is_none_or(|e| e.selections != entry.selections) { self.redo_stack.push_back(entry); if self.redo_stack.len() > MAX_SELECTION_HISTORY_LEN { @@ -1802,7 +1795,7 @@ impl Editor { let font_size = style.font_size.to_pixels(window.rem_size()); let editor = cx.entity().downgrade(); let fold_placeholder = FoldPlaceholder { - constrain_width: true, + constrain_width: false, render: Arc::new(move |fold_id, fold_range, cx| { let editor = editor.clone(); div() @@ -1859,118 +1852,166 @@ impl Editor { blink_manager }); - let soft_wrap_mode_override = matches!(mode, EditorMode::SingleLine { .. 
}) - .then(|| language_settings::SoftWrap::None); + let soft_wrap_mode_override = + matches!(mode, EditorMode::SingleLine).then(|| language_settings::SoftWrap::None); let mut project_subscriptions = Vec::new(); - if full_mode { - if let Some(project) = project.as_ref() { - project_subscriptions.push(cx.subscribe_in( - project, - window, - |editor, _, event, window, cx| match event { - project::Event::RefreshCodeLens => { - // we always query lens with actions, without storing them, always refreshing them - } - project::Event::RefreshInlayHints => { - editor - .refresh_inlay_hints(InlayHintRefreshReason::RefreshRequested, cx); - } - project::Event::LanguageServerAdded(..) - | project::Event::LanguageServerRemoved(..) => { - if editor.tasks_update_task.is_none() { - editor.tasks_update_task = - Some(editor.refresh_runnables(window, cx)); - } + if full_mode && let Some(project) = project.as_ref() { + project_subscriptions.push(cx.subscribe_in( + project, + window, + |editor, _, event, window, cx| match event { + project::Event::RefreshCodeLens => { + // we always query lens with actions, without storing them, always refreshing them + } + project::Event::RefreshInlayHints => { + editor.refresh_inlay_hints(InlayHintRefreshReason::RefreshRequested, cx); + } + project::Event::LanguageServerAdded(..) + | project::Event::LanguageServerRemoved(..) 
=> { + if editor.tasks_update_task.is_none() { + editor.tasks_update_task = Some(editor.refresh_runnables(window, cx)); } - project::Event::SnippetEdit(id, snippet_edits) => { - if let Some(buffer) = editor.buffer.read(cx).buffer(*id) { - let focus_handle = editor.focus_handle(cx); - if focus_handle.is_focused(window) { - let snapshot = buffer.read(cx).snapshot(); - for (range, snippet) in snippet_edits { - let editor_range = - language::range_from_lsp(*range).to_offset(&snapshot); - editor - .insert_snippet( - &[editor_range], - snippet.clone(), - window, - cx, - ) - .ok(); - } + } + project::Event::SnippetEdit(id, snippet_edits) => { + if let Some(buffer) = editor.buffer.read(cx).buffer(*id) { + let focus_handle = editor.focus_handle(cx); + if focus_handle.is_focused(window) { + let snapshot = buffer.read(cx).snapshot(); + for (range, snippet) in snippet_edits { + let editor_range = + language::range_from_lsp(*range).to_offset(&snapshot); + editor + .insert_snippet( + &[editor_range], + snippet.clone(), + window, + cx, + ) + .ok(); } } } - project::Event::LanguageServerBufferRegistered { buffer_id, .. } => { - if editor.buffer().read(cx).buffer(*buffer_id).is_some() { - editor.update_lsp_data(false, Some(*buffer_id), window, cx); - } + } + project::Event::LanguageServerBufferRegistered { buffer_id, .. 
} => { + if editor.buffer().read(cx).buffer(*buffer_id).is_some() { + editor.update_lsp_data(false, Some(*buffer_id), window, cx); } - _ => {} - }, - )); - if let Some(task_inventory) = project - .read(cx) - .task_store() - .read(cx) - .task_inventory() - .cloned() - { - project_subscriptions.push(cx.observe_in( - &task_inventory, - window, - |editor, _, window, cx| { - editor.tasks_update_task = Some(editor.refresh_runnables(window, cx)); - }, - )); - }; + } - project_subscriptions.push(cx.subscribe_in( - &project.read(cx).breakpoint_store(), - window, - |editor, _, event, window, cx| match event { - BreakpointStoreEvent::ClearDebugLines => { - editor.clear_row_highlights::(); - editor.refresh_inline_values(cx); - } - BreakpointStoreEvent::SetDebugLine => { - if editor.go_to_active_debug_line(window, cx) { - cx.stop_propagation(); - } + project::Event::EntryRenamed(transaction) => { + let Some(workspace) = editor.workspace() else { + return; + }; + let Some(active_editor) = workspace.read(cx).active_item_as::(cx) + else { + return; + }; + if active_editor.entity_id() == cx.entity_id() { + let edited_buffers_already_open = { + let other_editors: Vec> = workspace + .read(cx) + .panes() + .iter() + .flat_map(|pane| pane.read(cx).items_of_type::()) + .filter(|editor| editor.entity_id() != cx.entity_id()) + .collect(); + + transaction.0.keys().all(|buffer| { + other_editors.iter().any(|editor| { + let multi_buffer = editor.read(cx).buffer(); + multi_buffer.read(cx).is_singleton() + && multi_buffer.read(cx).as_singleton().map_or( + false, + |singleton| { + singleton.entity_id() == buffer.entity_id() + }, + ) + }) + }) + }; - editor.refresh_inline_values(cx); + if !edited_buffers_already_open { + let workspace = workspace.downgrade(); + let transaction = transaction.clone(); + cx.defer_in(window, move |_, window, cx| { + cx.spawn_in(window, async move |editor, cx| { + Self::open_project_transaction( + &editor, + workspace, + transaction, + "Rename".to_string(), + cx, + ) 
+ .await + .ok() + }) + .detach(); + }); + } } - _ => {} + } + + _ => {} + }, + )); + if let Some(task_inventory) = project + .read(cx) + .task_store() + .read(cx) + .task_inventory() + .cloned() + { + project_subscriptions.push(cx.observe_in( + &task_inventory, + window, + |editor, _, window, cx| { + editor.tasks_update_task = Some(editor.refresh_runnables(window, cx)); }, )); - let git_store = project.read(cx).git_store().clone(); - let project = project.clone(); - project_subscriptions.push(cx.subscribe(&git_store, move |this, _, event, cx| { - match event { - GitStoreEvent::RepositoryUpdated( - _, - RepositoryEvent::Updated { - new_instance: true, .. - }, - _, - ) => { - this.load_diff_task = Some( - update_uncommitted_diff_for_buffer( - cx.entity(), - &project, - this.buffer.read(cx).all_buffers(), - this.buffer.clone(), - cx, - ) - .shared(), - ); + }; + + project_subscriptions.push(cx.subscribe_in( + &project.read(cx).breakpoint_store(), + window, + |editor, _, event, window, cx| match event { + BreakpointStoreEvent::ClearDebugLines => { + editor.clear_row_highlights::(); + editor.refresh_inline_values(cx); + } + BreakpointStoreEvent::SetDebugLine => { + if editor.go_to_active_debug_line(window, cx) { + cx.stop_propagation(); } - _ => {} + + editor.refresh_inline_values(cx); } - })); - } + _ => {} + }, + )); + let git_store = project.read(cx).git_store().clone(); + let project = project.clone(); + project_subscriptions.push(cx.subscribe(&git_store, move |this, _, event, cx| { + if let GitStoreEvent::RepositoryUpdated( + _, + RepositoryEvent::Updated { + new_instance: true, .. 
+ }, + _, + ) = event + { + this.load_diff_task = Some( + update_uncommitted_diff_for_buffer( + cx.entity(), + &project, + this.buffer.read(cx).all_buffers(), + this.buffer.clone(), + cx, + ) + .shared(), + ); + } + })); } let buffer_snapshot = buffer.read(cx).snapshot(cx); @@ -1991,14 +2032,12 @@ impl Editor { .detach(); } - let show_indent_guides = if matches!( - mode, - EditorMode::SingleLine { .. } | EditorMode::Minimap { .. } - ) { - Some(false) - } else { - None - }; + let show_indent_guides = + if matches!(mode, EditorMode::SingleLine | EditorMode::Minimap { .. }) { + Some(false) + } else { + None + }; let breakpoint_store = match (&mode, project.as_ref()) { (EditorMode::Full { .. }, Some(project)) => Some(project.read(cx).breakpoint_store()), @@ -2027,6 +2066,7 @@ impl Editor { last_focused_descendant: None, buffer: buffer.clone(), display_map: display_map.clone(), + placeholder_display_map: None, selections, scroll_manager: ScrollManager::new(cx), columnar_selection_state: None, @@ -2058,7 +2098,7 @@ impl Editor { vertical: full_mode, }, minimap_visibility: MinimapVisibility::for_mode(&mode, cx), - offset_content: !matches!(mode, EditorMode::SingleLine { .. 
}), + offset_content: !matches!(mode, EditorMode::SingleLine), show_breadcrumbs: EditorSettings::get_global(cx).toolbar.breadcrumbs, show_gutter: full_mode, show_line_numbers: (!full_mode).then_some(false), @@ -2070,11 +2110,10 @@ impl Editor { show_breakpoints: None, show_wrap_guides: None, show_indent_guides, - placeholder_text: None, highlight_order: 0, highlighted_rows: HashMap::default(), - background_highlights: TreeMap::default(), - gutter_highlights: TreeMap::default(), + background_highlights: HashMap::default(), + gutter_highlights: HashMap::default(), scrollbar_marker_state: ScrollbarMarkerState::default(), active_indent_guides_state: ActiveIndentGuidesState::default(), nav_history: None, @@ -2125,6 +2164,7 @@ impl Editor { }, inline_diagnostics_enabled: full_mode, diagnostics_enabled: full_mode, + word_completions_enabled: full_mode, inline_value_cache: InlineValueCache::new(inlay_hint_settings.show_value_hints), inlay_hint_cache: InlayHintCache::new(inlay_hint_settings), gutter_hovered: false, @@ -2325,15 +2365,15 @@ impl Editor { editor.go_to_active_debug_line(window, cx); - if let Some(buffer) = buffer.read(cx).as_singleton() { - if let Some(project) = editor.project.as_ref() { - let handle = project.update(cx, |project, cx| { - project.register_buffer_with_language_servers(&buffer, cx) - }); - editor - .registered_buffers - .insert(buffer.read(cx).remote_id(), handle); - } + if let Some(buffer) = buffer.read(cx).as_singleton() + && let Some(project) = editor.project() + { + let handle = project.update(cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + }); + editor + .registered_buffers + .insert(buffer.read(cx).remote_id(), handle); } editor.minimap = @@ -2371,6 +2411,34 @@ impl Editor { .is_some_and(|menu| menu.context_menu.focus_handle(cx).is_focused(window)) } + pub fn is_range_selected(&mut self, range: &Range, cx: &mut Context) -> bool { + if self + .selections + .pending + .as_ref() + 
.is_some_and(|pending_selection| { + let snapshot = self.buffer().read(cx).snapshot(cx); + pending_selection + .selection + .range() + .includes(range, &snapshot) + }) + { + return true; + } + + self.selections + .disjoint_in_range::(range.clone(), cx) + .into_iter() + .any(|selection| { + // This is needed to cover a corner case, if we just check for an existing + // selection in the fold range, having a cursor at the start of the fold + // marks it as selected. Non-empty selections don't cause this. + let length = selection.end - selection.start; + length > 0 + }) + } + pub fn key_context(&self, window: &Window, cx: &App) -> KeyContext { self.key_context_internal(self.has_active_edit_prediction(), window, cx) } @@ -2384,7 +2452,7 @@ impl Editor { let mut key_context = KeyContext::new_with_defaults(); key_context.add("Editor"); let mode = match self.mode { - EditorMode::SingleLine { .. } => "single_line", + EditorMode::SingleLine => "single_line", EditorMode::AutoHeight { .. } => "auto_height", EditorMode::Minimap { .. } => "minimap", EditorMode::Full { .. 
} => "full", @@ -2490,9 +2558,7 @@ impl Editor { .context_menu .borrow() .as_ref() - .map_or(false, |context| { - matches!(context, CodeContextMenu::Completions(_)) - }); + .is_some_and(|context| matches!(context, CodeContextMenu::Completions(_))); showing_completions || self.edit_prediction_requires_modifier() @@ -2523,7 +2589,7 @@ impl Editor { || binding .keystrokes() .first() - .map_or(false, |keystroke| keystroke.modifiers.modified()) + .is_some_and(|keystroke| keystroke.modifiers().modified()) })) } @@ -2553,7 +2619,7 @@ impl Editor { cx: &mut Context, ) -> Task>> { let project = workspace.project().clone(); - let create = project.update(cx, |project, cx| project.create_buffer(cx)); + let create = project.update(cx, |project, cx| project.create_buffer(true, cx)); cx.spawn_in(window, async move |workspace, cx| { let buffer = create.await?; @@ -2591,7 +2657,7 @@ impl Editor { cx: &mut Context, ) { let project = workspace.project().clone(); - let create = project.update(cx, |project, cx| project.create_buffer(cx)); + let create = project.update(cx, |project, cx| project.create_buffer(true, cx)); cx.spawn_in(window, async move |workspace, cx| { let buffer = create.await?; @@ -2626,6 +2692,10 @@ impl Editor { &self.buffer } + pub fn project(&self) -> Option<&Entity> { + self.project.as_ref() + } + pub fn workspace(&self) -> Option> { self.workspace.as_ref()?.0.upgrade() } @@ -2658,9 +2728,12 @@ impl Editor { show_breakpoints: self.show_breakpoints, git_blame_gutter_max_author_length, display_snapshot: self.display_map.update(cx, |map, cx| map.snapshot(cx)), + placeholder_display_snapshot: self + .placeholder_display_map + .as_ref() + .map(|display_map| display_map.update(cx, |map, cx| map.snapshot(cx))), scroll_anchor: self.scroll_manager.anchor(), ongoing_scroll: self.scroll_manager.ongoing_scroll(), - placeholder_text: self.placeholder_text.clone(), is_focused: self.focus_handle.is_focused(window), current_line_highlight: self .current_line_highlight @@ -2756,20 
+2829,37 @@ impl Editor { self.refresh_edit_prediction(false, false, window, cx); } - pub fn placeholder_text(&self) -> Option<&str> { - self.placeholder_text.as_deref() + pub fn placeholder_text(&self, cx: &mut App) -> Option { + self.placeholder_display_map + .as_ref() + .map(|display_map| display_map.update(cx, |map, cx| map.snapshot(cx)).text()) } pub fn set_placeholder_text( &mut self, - placeholder_text: impl Into>, + placeholder_text: &str, + window: &mut Window, cx: &mut Context, ) { - let placeholder_text = Some(placeholder_text.into()); - if self.placeholder_text != placeholder_text { - self.placeholder_text = placeholder_text; - cx.notify(); - } + let multibuffer = cx + .new(|cx| MultiBuffer::singleton(cx.new(|cx| Buffer::local(placeholder_text, cx)), cx)); + + let style = window.text_style(); + + self.placeholder_display_map = Some(cx.new(|cx| { + DisplayMap::new( + multibuffer, + style.font(), + style.font_size.to_pixels(window.rem_size()), + None, + FILE_HEADER_HEIGHT, + MULTI_BUFFER_EXCERPT_HEADER_HEIGHT, + Default::default(), + DiagnosticSeverity::Off, + cx, + ) + })); + cx.notify(); } pub fn set_cursor_shape(&mut self, cursor_shape: CursorShape, cx: &mut Context) { @@ -2915,7 +3005,7 @@ impl Editor { return false; }; - scope.override_name().map_or(false, |scope_name| { + scope.override_name().is_some_and(|scope_name| { settings .edit_predictions_disabled_in .iter() @@ -3005,20 +3095,19 @@ impl Editor { } if local { - if let Some(buffer_id) = new_cursor_position.buffer_id { - if !self.registered_buffers.contains_key(&buffer_id) { - if let Some(project) = self.project.as_ref() { - project.update(cx, |project, cx| { - let Some(buffer) = self.buffer.read(cx).buffer(buffer_id) else { - return; - }; - self.registered_buffers.insert( - buffer_id, - project.register_buffer_with_language_servers(&buffer, cx), - ); - }) - } - } + if let Some(buffer_id) = new_cursor_position.buffer_id + && !self.registered_buffers.contains_key(&buffer_id) + && let 
Some(project) = self.project.as_ref() + { + project.update(cx, |project, cx| { + let Some(buffer) = self.buffer.read(cx).buffer(buffer_id) else { + return; + }; + self.registered_buffers.insert( + buffer_id, + project.register_buffer_with_language_servers(&buffer, cx), + ); + }) } let mut context_menu = self.context_menu.borrow_mut(); @@ -3033,28 +3122,28 @@ impl Editor { let completion_position = completion_menu.map(|menu| menu.initial_position); drop(context_menu); - if effects.completions { - if let Some(completion_position) = completion_position { - let start_offset = selection_start.to_offset(buffer); - let position_matches = start_offset == completion_position.to_offset(buffer); - let continue_showing = if position_matches { - if self.snippet_stack.is_empty() { - buffer.char_kind_before(start_offset, true) == Some(CharKind::Word) - } else { - // Snippet choices can be shown even when the cursor is in whitespace. - // Dismissing the menu with actions like backspace is handled by - // invalidation regions. - true - } - } else { - false - }; - - if continue_showing { - self.show_completions(&ShowCompletions { trigger: None }, window, cx); + if effects.completions + && let Some(completion_position) = completion_position + { + let start_offset = selection_start.to_offset(buffer); + let position_matches = start_offset == completion_position.to_offset(buffer); + let continue_showing = if position_matches { + if self.snippet_stack.is_empty() { + buffer.char_kind_before(start_offset, true) == Some(CharKind::Word) } else { - self.hide_context_menu(window, cx); + // Snippet choices can be shown even when the cursor is in whitespace. + // Dismissing the menu with actions like backspace is handled by + // invalidation regions. 
+ true } + } else { + false + }; + + if continue_showing { + self.show_completions(&ShowCompletions { trigger: None }, window, cx); + } else { + self.hide_context_menu(window, cx); } } @@ -3085,30 +3174,27 @@ impl Editor { if selections.len() == 1 { cx.emit(SearchEvent::ActiveMatchChanged) } - if local { - if let Some((_, _, buffer_snapshot)) = buffer.as_singleton() { - let inmemory_selections = selections - .iter() - .map(|s| { - text::ToPoint::to_point(&s.range().start.text_anchor, buffer_snapshot) - ..text::ToPoint::to_point(&s.range().end.text_anchor, buffer_snapshot) - }) - .collect(); - self.update_restoration_data(cx, |data| { - data.selections = inmemory_selections; - }); + if local && let Some((_, _, buffer_snapshot)) = buffer.as_singleton() { + let inmemory_selections = selections + .iter() + .map(|s| { + text::ToPoint::to_point(&s.range().start.text_anchor, buffer_snapshot) + ..text::ToPoint::to_point(&s.range().end.text_anchor, buffer_snapshot) + }) + .collect(); + self.update_restoration_data(cx, |data| { + data.selections = inmemory_selections; + }); - if WorkspaceSettings::get(None, cx).restore_on_startup - != RestoreOnStartupBehavior::None - { - if let Some(workspace_id) = - self.workspace.as_ref().and_then(|workspace| workspace.1) - { - let snapshot = self.buffer().read(cx).snapshot(cx); - let selections = selections.clone(); - let background_executor = cx.background_executor().clone(); - let editor_id = cx.entity().entity_id().as_u64() as ItemId; - self.serialize_selections = cx.background_spawn(async move { + if WorkspaceSettings::get(None, cx).restore_on_startup != RestoreOnStartupBehavior::None + && let Some(workspace_id) = + self.workspace.as_ref().and_then(|workspace| workspace.1) + { + let snapshot = self.buffer().read(cx).snapshot(cx); + let selections = selections.clone(); + let background_executor = cx.background_executor().clone(); + let editor_id = cx.entity().entity_id().as_u64() as ItemId; + self.serialize_selections = 
cx.background_spawn(async move { background_executor.timer(SERIALIZATION_THROTTLE_TIME).await; let db_selections = selections .iter() @@ -3125,8 +3211,6 @@ impl Editor { .with_context(|| format!("persisting editor selections for editor {editor_id}, workspace {workspace_id:?}")) .log_err(); }); - } - } } } @@ -3203,35 +3287,31 @@ impl Editor { selections.select_anchors(other_selections); }); - let other_subscription = - cx.subscribe(&other, |this, other, other_evt, cx| match other_evt { - EditorEvent::SelectionsChanged { local: true } => { - let other_selections = other.read(cx).selections.disjoint.to_vec(); - if other_selections.is_empty() { - return; - } - this.selections.change_with(cx, |selections| { - selections.select_anchors(other_selections); - }); + let other_subscription = cx.subscribe(&other, |this, other, other_evt, cx| { + if let EditorEvent::SelectionsChanged { local: true } = other_evt { + let other_selections = other.read(cx).selections.disjoint.to_vec(); + if other_selections.is_empty() { + return; } - _ => {} - }); + this.selections.change_with(cx, |selections| { + selections.select_anchors(other_selections); + }); + } + }); - let this_subscription = - cx.subscribe_self::(move |this, this_evt, cx| match this_evt { - EditorEvent::SelectionsChanged { local: true } => { - let these_selections = this.selections.disjoint.to_vec(); - if these_selections.is_empty() { - return; - } - other.update(cx, |other_editor, cx| { - other_editor.selections.change_with(cx, |selections| { - selections.select_anchors(these_selections); - }) - }); + let this_subscription = cx.subscribe_self::(move |this, this_evt, cx| { + if let EditorEvent::SelectionsChanged { local: true } = this_evt { + let these_selections = this.selections.disjoint.to_vec(); + if these_selections.is_empty() { + return; } - _ => {} - }); + other.update(cx, |other_editor, cx| { + other_editor.selections.change_with(cx, |selections| { + selections.select_anchors(these_selections); + }) + }); + } + }); 
Subscription::join(other_subscription, this_subscription) } @@ -3312,9 +3392,9 @@ impl Editor { let old_cursor_position = &state.old_cursor_position; - self.selections_did_change(true, &old_cursor_position, state.effects, window, cx); + self.selections_did_change(true, old_cursor_position, state.effects, window, cx); - if self.should_open_signature_help_automatically(&old_cursor_position, cx) { + if self.should_open_signature_help_automatically(old_cursor_position, cx) { self.show_signature_help(&ShowSignatureHelp, window, cx); } } @@ -3734,9 +3814,9 @@ impl Editor { ColumnarSelectionState::FromMouse { selection_tail, display_point, - } => display_point.unwrap_or_else(|| selection_tail.to_display_point(&display_map)), + } => display_point.unwrap_or_else(|| selection_tail.to_display_point(display_map)), ColumnarSelectionState::FromSelection { selection_tail } => { - selection_tail.to_display_point(&display_map) + selection_tail.to_display_point(display_map) } }; @@ -4013,18 +4093,18 @@ impl Editor { let following_text_allows_autoclose = snapshot .chars_at(selection.start) .next() - .map_or(true, |c| scope.should_autoclose_before(c)); + .is_none_or(|c| scope.should_autoclose_before(c)); let preceding_text_allows_autoclose = selection.start.column == 0 - || snapshot.reversed_chars_at(selection.start).next().map_or( - true, - |c| { + || snapshot + .reversed_chars_at(selection.start) + .next() + .is_none_or(|c| { bracket_pair.start != bracket_pair.end || !snapshot .char_classifier_at(selection.start) .is_word(c) - }, - ); + }); let is_closing_quote = if bracket_pair.end == bracket_pair.start && bracket_pair.start.len() == 1 @@ -4124,42 +4204,38 @@ impl Editor { if self.auto_replace_emoji_shortcode && selection.is_empty() && text.as_ref().ends_with(':') - { - if let Some(possible_emoji_short_code) = + && let Some(possible_emoji_short_code) = Self::find_possible_emoji_shortcode_at_position(&snapshot, selection.start) - { - if !possible_emoji_short_code.is_empty() { - if 
let Some(emoji) = emojis::get_by_shortcode(&possible_emoji_short_code) { - let emoji_shortcode_start = Point::new( - selection.start.row, - selection.start.column - possible_emoji_short_code.len() as u32 - 1, - ); + && !possible_emoji_short_code.is_empty() + && let Some(emoji) = emojis::get_by_shortcode(&possible_emoji_short_code) + { + let emoji_shortcode_start = Point::new( + selection.start.row, + selection.start.column - possible_emoji_short_code.len() as u32 - 1, + ); - // Remove shortcode from buffer - edits.push(( - emoji_shortcode_start..selection.start, - "".to_string().into(), - )); - new_selections.push(( - Selection { - id: selection.id, - start: snapshot.anchor_after(emoji_shortcode_start), - end: snapshot.anchor_before(selection.start), - reversed: selection.reversed, - goal: selection.goal, - }, - 0, - )); + // Remove shortcode from buffer + edits.push(( + emoji_shortcode_start..selection.start, + "".to_string().into(), + )); + new_selections.push(( + Selection { + id: selection.id, + start: snapshot.anchor_after(emoji_shortcode_start), + end: snapshot.anchor_before(selection.start), + reversed: selection.reversed, + goal: selection.goal, + }, + 0, + )); - // Insert emoji - let selection_start_anchor = snapshot.anchor_after(selection.start); - new_selections.push((selection.map(|_| selection_start_anchor), 0)); - edits.push((selection.start..selection.end, emoji.to_string().into())); + // Insert emoji + let selection_start_anchor = snapshot.anchor_after(selection.start); + new_selections.push((selection.map(|_| selection_start_anchor), 0)); + edits.push((selection.start..selection.end, emoji.to_string().into())); - continue; - } - } - } + continue; } // If not handling any auto-close operation, then just replace the selected @@ -4169,7 +4245,7 @@ impl Editor { if !self.linked_edit_ranges.is_empty() { let start_anchor = snapshot.anchor_before(selection.start); - let is_word_char = text.chars().next().map_or(true, |char| { + let is_word_char = 
text.chars().next().is_none_or(|char| { let classifier = snapshot .char_classifier_at(start_anchor.to_offset(&snapshot)) .ignore_punctuation(true); @@ -4273,12 +4349,11 @@ impl Editor { |s| s.select(new_selections), ); - if !bracket_inserted { - if let Some(on_type_format_task) = + if !bracket_inserted + && let Some(on_type_format_task) = this.trigger_on_type_formatting(text.to_string(), window, cx) - { - on_type_format_task.detach_and_log_err(cx); - } + { + on_type_format_task.detach_and_log_err(cx); } let editor_settings = EditorSettings::get_global(cx); @@ -4524,7 +4599,7 @@ impl Editor { let mut char_position = 0u32; let mut end_tag_offset = None; - 'outer: for chunk in snapshot.text_for_range(range.clone()) { + 'outer: for chunk in snapshot.text_for_range(range) { if let Some(byte_pos) = chunk.find(&**end_tag) { let chars_before_match = chunk[..byte_pos].chars().count() as u32; @@ -4839,8 +4914,15 @@ impl Editor { }); match completions_source { - Some(CompletionsMenuSource::Words) => { - self.show_word_completions(&ShowWordCompletions, window, cx) + Some(CompletionsMenuSource::Words { .. 
}) => { + self.open_or_update_completions_menu( + Some(CompletionsMenuSource::Words { + ignore_threshold: false, + }), + None, + window, + cx, + ); } Some(CompletionsMenuSource::Normal) | Some(CompletionsMenuSource::SnippetChoices) @@ -4874,11 +4956,7 @@ impl Editor { cx: &mut Context, ) -> bool { let position = self.selections.newest_anchor().head(); - let multibuffer = self.buffer.read(cx); - let Some(buffer) = position - .buffer_id - .and_then(|buffer_id| multibuffer.buffer(buffer_id).clone()) - else { + let Some(buffer) = self.buffer.read(cx).buffer_for_anchor(position, cx) else { return false; }; @@ -5212,7 +5290,7 @@ impl Editor { restrict_to_languages: Option<&HashSet>>, cx: &mut Context, ) -> HashMap, clock::Global, Range)> { - let Some(project) = self.project.as_ref() else { + let Some(project) = self.project() else { return HashMap::default(); }; let project = project.read(cx); @@ -5244,10 +5322,10 @@ impl Editor { } let language = buffer.language()?; - if let Some(restrict_to_languages) = restrict_to_languages { - if !restrict_to_languages.contains(language) { - return None; - } + if let Some(restrict_to_languages) = restrict_to_languages + && !restrict_to_languages.contains(language) + { + return None; } Some(( excerpt_id, @@ -5294,7 +5372,7 @@ impl Editor { return None; } - let project = self.project.as_ref()?; + let project = self.project()?; let position = self.selections.newest_anchor().head(); let (buffer, buffer_position) = self .buffer @@ -5352,7 +5430,14 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - self.open_or_update_completions_menu(Some(CompletionsMenuSource::Words), None, window, cx); + self.open_or_update_completions_menu( + Some(CompletionsMenuSource::Words { + ignore_threshold: true, + }), + None, + window, + cx, + ); } pub fn show_completions( @@ -5401,9 +5486,13 @@ impl Editor { drop(multibuffer_snapshot); + let mut ignore_word_threshold = false; let provider = match requested_source { 
Some(CompletionsMenuSource::Normal) | None => self.completion_provider.clone(), - Some(CompletionsMenuSource::Words) => None, + Some(CompletionsMenuSource::Words { ignore_threshold }) => { + ignore_word_threshold = ignore_threshold; + None + } Some(CompletionsMenuSource::SnippetChoices) => { log::error!("bug: SnippetChoices requested_source is not handled"); None @@ -5412,11 +5501,11 @@ impl Editor { let sort_completions = provider .as_ref() - .map_or(false, |provider| provider.sort_completions()); + .is_some_and(|provider| provider.sort_completions()); let filter_completions = provider .as_ref() - .map_or(true, |provider| provider.filter_completions()); + .is_none_or(|provider| provider.filter_completions()); let trigger_kind = match trigger { Some(trigger) if buffer.read(cx).completion_triggers().contains(trigger) => { @@ -5522,7 +5611,14 @@ impl Editor { let skip_digits = query .as_ref() - .map_or(true, |query| !query.chars().any(|c| c.is_digit(10))); + .is_none_or(|query| !query.chars().any(|c| c.is_digit(10))); + + let omit_word_completions = !self.word_completions_enabled + || (!ignore_word_threshold + && match &query { + Some(query) => query.chars().count() < completion_settings.words_min_length, + None => completion_settings.words_min_length != 0, + }); let (mut words, provider_responses) = match &provider { Some(provider) => { @@ -5535,9 +5631,11 @@ impl Editor { cx, ); - let words = match completion_settings.words { - WordsCompletionMode::Disabled => Task::ready(BTreeMap::default()), - WordsCompletionMode::Enabled | WordsCompletionMode::Fallback => cx + let words = match (omit_word_completions, completion_settings.words) { + (true, _) | (_, WordsCompletionMode::Disabled) => { + Task::ready(BTreeMap::default()) + } + (false, WordsCompletionMode::Enabled | WordsCompletionMode::Fallback) => cx .background_spawn(async move { buffer_snapshot.words_in_range(WordsQuery { fuzzy_contents: None, @@ -5549,16 +5647,20 @@ impl Editor { (words, provider_responses) } - 
None => ( - cx.background_spawn(async move { - buffer_snapshot.words_in_range(WordsQuery { - fuzzy_contents: None, - range: word_search_range, - skip_digits, + None => { + let words = if omit_word_completions { + Task::ready(BTreeMap::default()) + } else { + cx.background_spawn(async move { + buffer_snapshot.words_in_range(WordsQuery { + fuzzy_contents: None, + range: word_search_range, + skip_digits, + }) }) - }), - Task::ready(Ok(Vec::new())), - ), + }; + (words, Task::ready(Ok(Vec::new()))) + } }; let snippet_sort_order = EditorSettings::get_global(cx).snippet_sort_order; @@ -5575,17 +5677,25 @@ impl Editor { // that having one source with `is_incomplete: true` doesn't cause all to be re-queried. let mut completions = Vec::new(); let mut is_incomplete = false; - if let Some(provider_responses) = provider_responses.await.log_err() { - if !provider_responses.is_empty() { - for response in provider_responses { - completions.extend(response.completions); - is_incomplete = is_incomplete || response.is_incomplete; - } - if completion_settings.words == WordsCompletionMode::Fallback { - words = Task::ready(BTreeMap::default()); + let mut display_options: Option = None; + if let Some(provider_responses) = provider_responses.await.log_err() + && !provider_responses.is_empty() + { + for response in provider_responses { + completions.extend(response.completions); + is_incomplete = is_incomplete || response.is_incomplete; + match display_options.as_mut() { + None => { + display_options = Some(response.display_options); + } + Some(options) => options.merge(&response.display_options), } } + if completion_settings.words == WordsCompletionMode::Fallback { + words = Task::ready(BTreeMap::default()); + } } + let display_options = display_options.unwrap_or_default(); let mut words = words.await; if let Some(word_to_exclude) = &word_to_exclude { @@ -5627,6 +5737,7 @@ impl Editor { is_incomplete, buffer.clone(), completions.into(), + display_options, snippet_sort_order, languages, 
language, @@ -5648,34 +5759,31 @@ impl Editor { let Ok(()) = editor.update_in(cx, |editor, window, cx| { // Newer menu already set, so exit. - match editor.context_menu.borrow().as_ref() { - Some(CodeContextMenu::Completions(prev_menu)) => { - if prev_menu.id > id { - return; - } - } - _ => {} + if let Some(CodeContextMenu::Completions(prev_menu)) = + editor.context_menu.borrow().as_ref() + && prev_menu.id > id + { + return; }; // Only valid to take prev_menu because it the new menu is immediately set // below, or the menu is hidden. - match editor.context_menu.borrow_mut().take() { - Some(CodeContextMenu::Completions(prev_menu)) => { - let position_matches = - if prev_menu.initial_position == menu.initial_position { - true - } else { - let snapshot = editor.buffer.read(cx).read(cx); - prev_menu.initial_position.to_offset(&snapshot) - == menu.initial_position.to_offset(&snapshot) - }; - if position_matches { - // Preserve markdown cache before `set_filter_results` because it will - // try to populate the documentation cache. - menu.preserve_markdown_cache(prev_menu); - } + if let Some(CodeContextMenu::Completions(prev_menu)) = + editor.context_menu.borrow_mut().take() + { + let position_matches = + if prev_menu.initial_position == menu.initial_position { + true + } else { + let snapshot = editor.buffer.read(cx).read(cx); + prev_menu.initial_position.to_offset(&snapshot) + == menu.initial_position.to_offset(&snapshot) + }; + if position_matches { + // Preserve markdown cache before `set_filter_results` because it will + // try to populate the documentation cache. 
+ menu.preserve_markdown_cache(prev_menu); } - _ => {} }; menu.set_filter_results(matches, provider, window, cx); @@ -5688,21 +5796,21 @@ impl Editor { editor .update_in(cx, |editor, window, cx| { - if editor.focus_handle.is_focused(window) { - if let Some(menu) = menu { - *editor.context_menu.borrow_mut() = - Some(CodeContextMenu::Completions(menu)); - - crate::hover_popover::hide_hover(editor, cx); - if editor.show_edit_predictions_in_menu() { - editor.update_visible_edit_prediction(window, cx); - } else { - editor.discard_edit_prediction(false, cx); - } + if editor.focus_handle.is_focused(window) + && let Some(menu) = menu + { + *editor.context_menu.borrow_mut() = + Some(CodeContextMenu::Completions(menu)); - cx.notify(); - return; + crate::hover_popover::hide_hover(editor, cx); + if editor.show_edit_predictions_in_menu() { + editor.update_visible_edit_prediction(window, cx); + } else { + editor.discard_edit_prediction(false, cx); } + + cx.notify(); + return; } if editor.completion_tasks.len() <= 1 { @@ -5845,7 +5953,7 @@ impl Editor { multibuffer_anchor.start.to_offset(&snapshot) ..multibuffer_anchor.end.to_offset(&snapshot) }; - if newest_anchor.head().buffer_id != Some(buffer.remote_id()) { + if snapshot.buffer_id_for_anchor(newest_anchor.head()) != Some(buffer.remote_id()) { return None; } @@ -5956,7 +6064,7 @@ impl Editor { let show_new_completions_on_confirm = completion .confirm .as_ref() - .map_or(false, |confirm| confirm(intent, window, cx)); + .is_some_and(|confirm| confirm(intent, window, cx)); if show_new_completions_on_confirm { self.show_completions(&ShowCompletions { trigger: None }, window, cx); } @@ -6049,11 +6157,11 @@ impl Editor { Some(CodeActionSource::Indicator(_)) => Task::ready(Ok(Default::default())), _ => { let mut task_context_task = Task::ready(None); - if let Some(tasks) = &tasks { - if let Some(project) = project { - task_context_task = - Self::build_tasks_context(&project, &buffer, buffer_row, &tasks, cx); - } + if let Some(tasks) 
= &tasks + && let Some(project) = project + { + task_context_task = + Self::build_tasks_context(&project, &buffer, buffer_row, tasks, cx); } cx.spawn_in(window, { @@ -6088,10 +6196,10 @@ impl Editor { let spawn_straight_away = quick_launch && resolved_tasks .as_ref() - .map_or(false, |tasks| tasks.templates.len() == 1) + .is_some_and(|tasks| tasks.templates.len() == 1) && code_actions .as_ref() - .map_or(true, |actions| actions.is_empty()) + .is_none_or(|actions| actions.is_empty()) && debug_scenarios.is_empty(); editor.update_in(cx, |editor, window, cx| { @@ -6118,14 +6226,14 @@ impl Editor { deployed_from, })); cx.notify(); - if spawn_straight_away { - if let Some(task) = editor.confirm_code_action( + if spawn_straight_away + && let Some(task) = editor.confirm_code_action( &ConfirmCodeAction { item_ix: Some(0) }, window, cx, - ) { - return task; - } + ) + { + return task; } Task::ready(Ok(())) @@ -6141,7 +6249,7 @@ impl Editor { cx: &mut App, ) -> Task> { maybe!({ - let project = self.project.as_ref()?; + let project = self.project()?; let dap_store = project.read(cx).dap_store(); let mut scenarios = vec![]; let resolved_tasks = resolved_tasks.as_ref()?; @@ -6166,12 +6274,11 @@ impl Editor { } }); Some(cx.background_spawn(async move { - let scenarios = futures::future::join_all(scenarios) + futures::future::join_all(scenarios) .await .into_iter() .flatten() - .collect::>(); - scenarios + .collect::>() })) }) .unwrap_or_else(|| Task::ready(vec![])) @@ -6269,7 +6376,7 @@ impl Editor { })) } CodeActionsItem::DebugScenario(scenario) => { - let context = actions_menu.actions.context.clone(); + let context = actions_menu.actions.context; workspace.update(cx, |workspace, cx| { dap::send_telemetry(&scenario, TelemetrySpawnLocation::Gutter, cx); @@ -6288,7 +6395,7 @@ impl Editor { } pub async fn open_project_transaction( - this: &WeakEntity, + editor: &WeakEntity, workspace: WeakEntity, transaction: ProjectTransaction, title: String, @@ -6306,27 +6413,26 @@ impl Editor { 
if let Some((buffer, transaction)) = entries.first() { if entries.len() == 1 { - let excerpt = this.update(cx, |editor, cx| { + let excerpt = editor.update(cx, |editor, cx| { editor .buffer() .read(cx) .excerpt_containing(editor.selections.newest_anchor().head(), cx) })?; - if let Some((_, excerpted_buffer, excerpt_range)) = excerpt { - if excerpted_buffer == *buffer { - let all_edits_within_excerpt = buffer.read_with(cx, |buffer, _| { - let excerpt_range = excerpt_range.to_offset(buffer); - buffer - .edited_ranges_for_transaction::(transaction) - .all(|range| { - excerpt_range.start <= range.start - && excerpt_range.end >= range.end - }) - })?; + if let Some((_, excerpted_buffer, excerpt_range)) = excerpt + && excerpted_buffer == *buffer + { + let all_edits_within_excerpt = buffer.read_with(cx, |buffer, _| { + let excerpt_range = excerpt_range.to_offset(buffer); + buffer + .edited_ranges_for_transaction::(transaction) + .all(|range| { + excerpt_range.start <= range.start && excerpt_range.end >= range.end + }) + })?; - if all_edits_within_excerpt { - return Ok(()); - } + if all_edits_within_excerpt { + return Ok(()); } } } @@ -6346,7 +6452,7 @@ impl Editor { PathKey::for_buffer(buffer_handle, cx), buffer_handle.clone(), edited_ranges, - DEFAULT_MULTIBUFFER_CONTEXT, + multibuffer_context_lines(cx), cx, ); @@ -6470,7 +6576,7 @@ impl Editor { fn refresh_code_actions(&mut self, window: &mut Window, cx: &mut Context) -> Option<()> { let newest_selection = self.selections.newest_anchor().clone(); - let newest_selection_adjusted = self.selections.newest_adjusted(cx).clone(); + let newest_selection_adjusted = self.selections.newest_adjusted(cx); let buffer = self.buffer.read(cx); if newest_selection.head().diff_base_anchor.is_some() { return None; @@ -6565,7 +6671,7 @@ impl Editor { buffer_row: Some(point.row), ..Default::default() }; - let Some(blame_entry) = blame + let Some((buffer, blame_entry)) = blame .update(cx, |blame, cx| blame.blame_for_rows(&[row_info], 
cx).next()) .flatten() else { @@ -6575,12 +6681,19 @@ impl Editor { let anchor = self.selections.newest_anchor().head(); let position = self.to_pixel_point(anchor, &snapshot, window); if let (Some(position), Some(last_bounds)) = (position, self.last_bounds) { - self.show_blame_popover(&blame_entry, position + last_bounds.origin, true, cx); + self.show_blame_popover( + buffer, + &blame_entry, + position + last_bounds.origin, + true, + cx, + ); }; } fn show_blame_popover( &mut self, + buffer: BufferId, blame_entry: &BlameEntry, position: gpui::Point, ignore_timeout: bool, @@ -6604,7 +6717,7 @@ impl Editor { return; }; let blame = blame.read(cx); - let details = blame.details_for_entry(&blame_entry); + let details = blame.details_for_entry(buffer, &blame_entry); let markdown = cx.new(|cx| { Markdown::new( details @@ -6704,11 +6817,10 @@ impl Editor { return; } - let buffer_id = cursor_position.buffer_id; let buffer = this.buffer.read(cx); - if !buffer + if buffer .text_anchor_for_position(cursor_position, cx) - .map_or(false, |(buffer, _)| buffer == cursor_buffer) + .is_none_or(|(buffer, _)| buffer != cursor_buffer) { return; } @@ -6717,8 +6829,8 @@ impl Editor { let mut write_ranges = Vec::new(); let mut read_ranges = Vec::new(); for highlight in highlights { - for (excerpt_id, excerpt_range) in - buffer.excerpts_for_buffer(cursor_buffer.read(cx).remote_id(), cx) + let buffer_id = cursor_buffer.read(cx).remote_id(); + for (excerpt_id, excerpt_range) in buffer.excerpts_for_buffer(buffer_id, cx) { let start = highlight .range @@ -6733,12 +6845,12 @@ impl Editor { } let range = Anchor { - buffer_id, + buffer_id: Some(buffer_id), excerpt_id, text_anchor: start, diff_base_anchor: None, }..Anchor { - buffer_id, + buffer_id: Some(buffer_id), excerpt_id, text_anchor: end, diff_base_anchor: None, @@ -6773,7 +6885,7 @@ impl Editor { &mut self, cx: &mut Context, ) -> Option<(String, Range)> { - if matches!(self.mode, EditorMode::SingleLine { .. 
}) { + if matches!(self.mode, EditorMode::SingleLine) { return None; } if !EditorSettings::get_global(cx).selection_highlight { @@ -6834,7 +6946,7 @@ impl Editor { for (buffer_snapshot, search_range, excerpt_id) in buffer_ranges { match_ranges.extend( regex - .search(&buffer_snapshot, Some(search_range.clone())) + .search(buffer_snapshot, Some(search_range.clone())) .await .into_iter() .filter_map(|match_range| { @@ -6958,9 +7070,7 @@ impl Editor { || self .quick_selection_highlight_task .as_ref() - .map_or(true, |(prev_anchor_range, _)| { - prev_anchor_range != &query_range - }) + .is_none_or(|(prev_anchor_range, _)| prev_anchor_range != &query_range) { let multi_buffer_visible_start = self .scroll_manager @@ -6989,9 +7099,7 @@ impl Editor { || self .debounced_selection_highlight_task .as_ref() - .map_or(true, |(prev_anchor_range, _)| { - prev_anchor_range != &query_range - }) + .is_none_or(|(prev_anchor_range, _)| prev_anchor_range != &query_range) { let multi_buffer_start = multi_buffer_snapshot .anchor_before(0) @@ -7126,9 +7234,7 @@ impl Editor { && self .edit_prediction_provider .as_ref() - .map_or(false, |provider| { - provider.provider.show_completions_in_menu() - }); + .is_some_and(|provider| provider.provider.show_completions_in_menu()); let preview_requires_modifier = all_language_settings(file, cx).edit_predictions_mode() == EditPredictionsMode::Subtle; @@ -7176,7 +7282,7 @@ impl Editor { return Some(false); } let provider = self.edit_prediction_provider()?; - if !provider.is_enabled(&buffer, buffer_position, cx) { + if !provider.is_enabled(buffer, buffer_position, cx) { return Some(false); } let buffer = buffer.read(cx); @@ -7637,16 +7743,16 @@ impl Editor { .keystroke() { modifiers_held = modifiers_held - || (&accept_keystroke.modifiers == modifiers - && accept_keystroke.modifiers.modified()); + || (accept_keystroke.modifiers() == modifiers + && accept_keystroke.modifiers().modified()); }; if let Some(accept_partial_keystroke) = self 
.accept_edit_prediction_keybind(true, window, cx) .keystroke() { modifiers_held = modifiers_held - || (&accept_partial_keystroke.modifiers == modifiers - && accept_partial_keystroke.modifiers.modified()); + || (accept_partial_keystroke.modifiers() == modifiers + && accept_partial_keystroke.modifiers().modified()); } if modifiers_held { @@ -7696,6 +7802,11 @@ impl Editor { return None; } + if self.ime_transaction.is_some() { + self.discard_edit_prediction(false, cx); + return None; + } + let selection = self.selections.newest_anchor(); let cursor = selection.head(); let multibuffer = self.buffer.read(cx).snapshot(cx); @@ -7712,7 +7823,7 @@ impl Editor { || self .active_edit_prediction .as_ref() - .map_or(false, |completion| { + .is_some_and(|completion| { let invalidation_range = completion.invalidation_range.to_offset(&multibuffer); let invalidation_range = invalidation_range.start..=invalidation_range.end; !invalidation_range.contains(&offset_selection.head()) @@ -7734,6 +7845,11 @@ impl Editor { self.edit_prediction_settings = self.edit_prediction_settings_at_position(&buffer, cursor_buffer_position, cx); + if let EditPredictionSettings::Disabled = self.edit_prediction_settings { + self.discard_edit_prediction(false, cx); + return None; + }; + self.edit_prediction_indent_conflict = multibuffer.is_line_whitespace_upto(cursor); if self.edit_prediction_indent_conflict { @@ -7741,10 +7857,10 @@ impl Editor { let indents = multibuffer.suggested_indents(cursor_point.row..cursor_point.row + 1, cx); - if let Some((_, indent)) = indents.iter().next() { - if indent.len == cursor_point.column { - self.edit_prediction_indent_conflict = false; - } + if let Some((_, indent)) = indents.iter().next() + && indent.len == cursor_point.column + { + self.edit_prediction_indent_conflict = false; } } @@ -7907,7 +8023,7 @@ impl Editor { let snapshot = self.snapshot(window, cx); let multi_buffer_snapshot = &snapshot.display_snapshot.buffer_snapshot; - let Some(project) = 
self.project.as_ref() else { + let Some(project) = self.project() else { return breakpoint_display_points; }; @@ -7936,7 +8052,7 @@ impl Editor { let multi_buffer_anchor = Anchor::in_buffer(excerpt_id, buffer_snapshot.remote_id(), breakpoint.position); let position = multi_buffer_anchor - .to_point(&multi_buffer_snapshot) + .to_point(multi_buffer_snapshot) .to_display_point(&snapshot); breakpoint_display_points.insert( @@ -8190,8 +8306,6 @@ impl Editor { .icon_color(color) .style(ButtonStyle::Transparent) .on_click(cx.listener({ - let breakpoint = breakpoint.clone(); - move |editor, event: &ClickEvent, window, cx| { let edit_action = if event.modifiers().platform || breakpoint.is_disabled() { BreakpointEditAction::InvertState @@ -8405,7 +8519,7 @@ impl Editor { .context_menu .borrow() .as_ref() - .map_or(false, |menu| menu.visible()) + .is_some_and(|menu| menu.visible()) } pub fn context_menu_origin(&self) -> Option { @@ -8829,7 +8943,7 @@ impl Editor { } let highlighted_edits = if let Some(edit_preview) = edit_preview.as_ref() { - crate::edit_prediction_edit_text(&snapshot, edits, edit_preview, false, cx) + crate::edit_prediction_edit_text(snapshot, edits, edit_preview, false, cx) } else { // Fallback for providers without edit_preview crate::edit_prediction_fallback_text(edits, cx) @@ -8951,9 +9065,8 @@ impl Editor { let end_row = start_row + line_count as u32; visible_row_range.contains(&start_row) && visible_row_range.contains(&end_row) - && cursor_row.map_or(true, |cursor_row| { - !((start_row..end_row).contains(&cursor_row)) - }) + && cursor_row + .is_none_or(|cursor_row| !((start_row..end_row).contains(&cursor_row))) })?; content_origin @@ -8993,7 +9106,7 @@ impl Editor { let is_platform_style_mac = PlatformStyle::platform() == PlatformStyle::Mac; - let modifiers_color = if accept_keystroke.modifiers == window.modifiers() { + let modifiers_color = if *accept_keystroke.modifiers() == window.modifiers() { Color::Accent } else { Color::Muted @@ -9005,19 
+9118,19 @@ impl Editor { .font(theme::ThemeSettings::get_global(cx).buffer_font.clone()) .text_size(TextSize::XSmall.rems(cx)) .child(h_flex().children(ui::render_modifiers( - &accept_keystroke.modifiers, + accept_keystroke.modifiers(), PlatformStyle::platform(), Some(modifiers_color), Some(IconSize::XSmall.rems().into()), true, ))) .when(is_platform_style_mac, |parent| { - parent.child(accept_keystroke.key.clone()) + parent.child(accept_keystroke.key().to_string()) }) .when(!is_platform_style_mac, |parent| { parent.child( Key::new( - util::capitalize(&accept_keystroke.key), + util::capitalize(accept_keystroke.key()), Some(Color::Default), ) .size(Some(IconSize::XSmall.rems().into())), @@ -9120,52 +9233,13 @@ impl Editor { max_width: Pixels, cursor_point: Point, style: &EditorStyle, - accept_keystroke: Option<&gpui::Keystroke>, + accept_keystroke: Option<&gpui::KeybindingKeystroke>, _window: &Window, cx: &mut Context, ) -> Option { let provider = self.edit_prediction_provider.as_ref()?; let provider_icon = Self::get_prediction_provider_icon_name(&self.edit_prediction_provider); - if provider.provider.needs_terms_acceptance(cx) { - return Some( - h_flex() - .min_w(min_width) - .flex_1() - .px_2() - .py_1() - .gap_3() - .elevation_2(cx) - .hover(|style| style.bg(cx.theme().colors().element_hover)) - .id("accept-terms") - .cursor_pointer() - .on_mouse_down(MouseButton::Left, |_, window, _| window.prevent_default()) - .on_click(cx.listener(|this, _event, window, cx| { - cx.stop_propagation(); - this.report_editor_event(ReportEditorEvent::ZetaTosClicked, None, cx); - window.dispatch_action( - zed_actions::OpenZedPredictOnboarding.boxed_clone(), - cx, - ); - })) - .child( - h_flex() - .flex_1() - .gap_2() - .child(Icon::new(provider_icon)) - .child(Label::new("Accept Terms of Service")) - .child(div().w_full()) - .child( - Icon::new(IconName::ArrowUpRight) - .color(Color::Muted) - .size(IconSize::Small), - ) - .into_any_element(), - ) - .into_any(), - ); - } - let 
is_refreshing = provider.provider.is_refreshing(cx); fn pending_completion_container(icon: IconName) -> Div { @@ -9192,7 +9266,7 @@ impl Editor { .child(div().px_1p5().child(match &prediction.completion { EditPrediction::Move { target, snapshot } => { use text::ToPoint as _; - if target.text_anchor.to_point(&snapshot).row > cursor_point.row + if target.text_anchor.to_point(snapshot).row > cursor_point.row { Icon::new(IconName::ZedPredictDown) } else { @@ -9237,7 +9311,7 @@ impl Editor { accept_keystroke.as_ref(), |el, accept_keystroke| { el.child(h_flex().children(ui::render_modifiers( - &accept_keystroke.modifiers, + accept_keystroke.modifiers(), PlatformStyle::platform(), Some(Color::Default), Some(IconSize::XSmall.rems().into()), @@ -9307,7 +9381,7 @@ impl Editor { .child(completion), ) .when_some(accept_keystroke, |el, accept_keystroke| { - if !accept_keystroke.modifiers.modified() { + if !accept_keystroke.modifiers().modified() { return el; } @@ -9326,7 +9400,7 @@ impl Editor { .font(theme::ThemeSettings::get_global(cx).buffer_font.clone()) .when(is_platform_style_mac, |parent| parent.gap_1()) .child(h_flex().children(ui::render_modifiers( - &accept_keystroke.modifiers, + accept_keystroke.modifiers(), PlatformStyle::platform(), Some(if !has_completion { Color::Muted @@ -9394,7 +9468,7 @@ impl Editor { .gap_2() .flex_1() .child( - if target.text_anchor.to_point(&snapshot).row > cursor_point.row { + if target.text_anchor.to_point(snapshot).row > cursor_point.row { Icon::new(IconName::ZedPredictDown) } else { Icon::new(IconName::ZedPredictUp) @@ -9410,14 +9484,14 @@ impl Editor { snapshot, display_mode: _, } => { - let first_edit_row = edits.first()?.0.start.text_anchor.to_point(&snapshot).row; + let first_edit_row = edits.first()?.0.start.text_anchor.to_point(snapshot).row; let (highlighted_edits, has_more_lines) = if let Some(edit_preview) = edit_preview.as_ref() { - crate::edit_prediction_edit_text(&snapshot, &edits, edit_preview, true, cx) + 
crate::edit_prediction_edit_text(snapshot, edits, edit_preview, true, cx) .first_line_preview() } else { - crate::edit_prediction_fallback_text(&edits, cx).first_line_preview() + crate::edit_prediction_fallback_text(edits, cx).first_line_preview() }; let styled_text = gpui::StyledText::new(highlighted_edits.text) @@ -9493,10 +9567,10 @@ impl Editor { let context_menu = self.context_menu.borrow_mut().take(); self.stale_edit_prediction_in_menu.take(); self.update_visible_edit_prediction(window, cx); - if let Some(CodeContextMenu::Completions(_)) = &context_menu { - if let Some(completion_provider) = &self.completion_provider { - completion_provider.selection_changed(None, window, cx); - } + if let Some(CodeContextMenu::Completions(_)) = &context_menu + && let Some(completion_provider) = &self.completion_provider + { + completion_provider.selection_changed(None, window, cx); } context_menu } @@ -9507,17 +9581,21 @@ impl Editor { selection: Range, cx: &mut Context, ) { - let buffer_id = match (&selection.start.buffer_id, &selection.end.buffer_id) { - (Some(a), Some(b)) if a == b => a, - _ => { - log::error!("expected anchor range to have matching buffer IDs"); - return; - } + let Some((_, buffer, _)) = self + .buffer() + .read(cx) + .excerpt_containing(selection.start, cx) + else { + return; }; - let multi_buffer = self.buffer().read(cx); - let Some(buffer) = multi_buffer.buffer(*buffer_id) else { + let Some((_, end_buffer, _)) = self.buffer().read(cx).excerpt_containing(selection.end, cx) + else { return; }; + if buffer != end_buffer { + log::error!("expected anchor range to have matching buffer IDs"); + return; + } let id = post_inc(&mut self.next_completion_id); let snippet_sort_order = EditorSettings::get_global(cx).snippet_sort_order; @@ -9563,7 +9641,7 @@ impl Editor { .tabstops .iter() .map(|tabstop| { - let is_end_tabstop = tabstop.ranges.first().map_or(false, |tabstop| { + let is_end_tabstop = tabstop.ranges.first().is_some_and(|tabstop| { tabstop.is_empty() 
&& tabstop.start == snippet.text.len() as isize }); let mut tabstop_ranges = tabstop @@ -9601,10 +9679,10 @@ impl Editor { s.select_ranges(tabstop.ranges.iter().rev().cloned()); }); - if let Some(choices) = &tabstop.choices { - if let Some(selection) = tabstop.ranges.first() { - self.show_snippet_choices(choices, selection.clone(), cx) - } + if let Some(choices) = &tabstop.choices + && let Some(selection) = tabstop.ranges.first() + { + self.show_snippet_choices(choices, selection.clone(), cx) } // If we're already at the last tabstop and it's at the end of the snippet, @@ -9738,10 +9816,10 @@ impl Editor { s.select_ranges(current_ranges.iter().rev().cloned()) }); - if let Some(choices) = &snippet.choices[snippet.active_index] { - if let Some(selection) = current_ranges.first() { - self.show_snippet_choices(&choices, selection.clone(), cx); - } + if let Some(choices) = &snippet.choices[snippet.active_index] + && let Some(selection) = current_ranges.first() + { + self.show_snippet_choices(choices, selection.clone(), cx); } // If snippet state is not at the last tabstop, push it back on the stack @@ -9763,6 +9841,9 @@ impl Editor { } pub fn backspace(&mut self, _: &Backspace, window: &mut Window, cx: &mut Context) { + if self.read_only(cx) { + return; + } self.hide_mouse_cursor(HideMouseCursorOrigin::TypingAction, cx); self.transact(window, cx, |this, window, cx| { this.select_autoclose_pair(window, cx); @@ -9856,6 +9937,9 @@ impl Editor { } pub fn delete(&mut self, _: &Delete, window: &mut Window, cx: &mut Context) { + if self.read_only(cx) { + return; + } self.hide_mouse_cursor(HideMouseCursorOrigin::TypingAction, cx); self.transact(window, cx, |this, window, cx| { this.change_selections(Default::default(), window, cx, |s| { @@ -10138,10 +10222,10 @@ impl Editor { // Avoid re-outdenting a row that has already been outdented by a // previous selection. 
- if let Some(last_row) = last_outdent { - if last_row == rows.start { - rows.start = rows.start.next_row(); - } + if let Some(last_row) = last_outdent + && last_row == rows.start + { + rows.start = rows.start.next_row(); } let has_multiple_rows = rows.len() > 1; for row in rows.iter_rows() { @@ -10319,11 +10403,11 @@ impl Editor { MultiBufferRow(selection.end.row) }; - if let Some(last_row_range) = row_ranges.last_mut() { - if start <= last_row_range.end { - last_row_range.end = end; - continue; - } + if let Some(last_row_range) = row_ranges.last_mut() + && start <= last_row_range.end + { + last_row_range.end = end; + continue; } row_ranges.push(start..end); } @@ -10425,6 +10509,86 @@ impl Editor { }) } + fn enable_wrap_selections_in_tag(&self, cx: &App) -> bool { + let snapshot = self.buffer.read(cx).snapshot(cx); + for selection in self.selections.disjoint_anchors().iter() { + if snapshot + .language_at(selection.start) + .and_then(|lang| lang.config().wrap_characters.as_ref()) + .is_some() + { + return true; + } + } + false + } + + fn wrap_selections_in_tag( + &mut self, + _: &WrapSelectionsInTag, + window: &mut Window, + cx: &mut Context, + ) { + self.hide_mouse_cursor(HideMouseCursorOrigin::TypingAction, cx); + + let snapshot = self.buffer.read(cx).snapshot(cx); + + let mut edits = Vec::new(); + let mut boundaries = Vec::new(); + + for selection in self.selections.all::(cx).iter() { + let Some(wrap_config) = snapshot + .language_at(selection.start) + .and_then(|lang| lang.config().wrap_characters.clone()) + else { + continue; + }; + + let open_tag = format!("{}{}", wrap_config.start_prefix, wrap_config.start_suffix); + let close_tag = format!("{}{}", wrap_config.end_prefix, wrap_config.end_suffix); + + let start_before = snapshot.anchor_before(selection.start); + let end_after = snapshot.anchor_after(selection.end); + + edits.push((start_before..start_before, open_tag)); + edits.push((end_after..end_after, close_tag)); + + boundaries.push(( + start_before, + 
end_after, + wrap_config.start_prefix.len(), + wrap_config.end_suffix.len(), + )); + } + + if edits.is_empty() { + return; + } + + self.transact(window, cx, |this, window, cx| { + let buffer = this.buffer.update(cx, |buffer, cx| { + buffer.edit(edits, None, cx); + buffer.snapshot(cx) + }); + + let mut new_selections = Vec::with_capacity(boundaries.len() * 2); + for (start_before, end_after, start_prefix_len, end_suffix_len) in + boundaries.into_iter() + { + let open_offset = start_before.to_offset(&buffer) + start_prefix_len; + let close_offset = end_after.to_offset(&buffer).saturating_sub(end_suffix_len); + new_selections.push(open_offset..open_offset); + new_selections.push(close_offset..close_offset); + } + + this.change_selections(Default::default(), window, cx, |s| { + s.select_ranges(new_selections); + }); + + this.request_autoscroll(Autoscroll::fit(), cx); + }); + } + pub fn reload_file(&mut self, _: &ReloadFile, window: &mut Window, cx: &mut Context) { let Some(project) = self.project.clone() else { return; @@ -10501,7 +10665,7 @@ impl Editor { ) { if let Some(working_directory) = self.active_excerpt(cx).and_then(|(_, buffer, _)| { let project_path = buffer.read(cx).project_path(cx)?; - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); let entry = project.entry_for_path(&project_path, cx)?; let parent = match &entry.canonical_path { Some(canonical_path) => canonical_path.to_path_buf(), @@ -10604,16 +10768,12 @@ impl Editor { snapshot: &EditorSnapshot, cx: &mut Context, ) -> Option<(Anchor, Breakpoint)> { - let project = self.project.clone()?; - - let buffer_id = breakpoint_position.buffer_id.or_else(|| { - snapshot - .buffer_snapshot - .buffer_id_for_excerpt(breakpoint_position.excerpt_id) - })?; + let buffer = self + .buffer + .read(cx) + .buffer_for_anchor(breakpoint_position, cx)?; let enclosing_excerpt = breakpoint_position.excerpt_id; - let buffer = project.read(cx).buffer_for_id(buffer_id, cx)?; let 
buffer_snapshot = buffer.read(cx).snapshot(); let row = buffer_snapshot @@ -10625,8 +10785,7 @@ impl Editor { .buffer_snapshot .anchor_after(Point::new(row, line_len)); - let bp = self - .breakpoint_store + self.breakpoint_store .as_ref()? .read_with(cx, |breakpoint_store, cx| { breakpoint_store @@ -10651,8 +10810,7 @@ impl Editor { None } }) - }); - bp + }) } pub fn edit_log_breakpoint( @@ -10688,7 +10846,7 @@ impl Editor { let cursors = self .selections .disjoint_anchors() - .into_iter() + .iter() .map(|selection| { let cursor_position: Point = selection.head().to_point(&snapshot.buffer_snapshot); @@ -10788,21 +10946,11 @@ impl Editor { return; }; - let Some(buffer_id) = breakpoint_position.buffer_id.or_else(|| { - if breakpoint_position == Anchor::min() { - self.buffer() - .read(cx) - .excerpt_buffer_ids() - .into_iter() - .next() - } else { - None - } - }) else { - return; - }; - - let Some(buffer) = self.buffer().read(cx).buffer(buffer_id) else { + let Some(buffer) = self + .buffer + .read(cx) + .buffer_for_anchor(breakpoint_position, cx) + else { return; }; @@ -10865,7 +11013,7 @@ impl Editor { } pub fn shuffle_lines(&mut self, _: &ShuffleLines, window: &mut Window, cx: &mut Context) { - self.manipulate_immutable_lines(window, cx, |lines| lines.shuffle(&mut thread_rng())) + self.manipulate_immutable_lines(window, cx, |lines| lines.shuffle(&mut rand::rng())) } fn manipulate_lines( @@ -11030,7 +11178,7 @@ impl Editor { let mut col = 0; let mut changed = false; - while let Some(ch) = chars.next() { + for ch in chars.by_ref() { match ch { ' ' => { reindented_line.push(' '); @@ -11086,7 +11234,7 @@ impl Editor { let mut first_non_indent_char = None; let mut changed = false; - while let Some(ch) = chars.next() { + for ch in chars.by_ref() { match ch { ' ' => { // Keep track of spaces. 
Append \t when we reach tab_size @@ -11285,14 +11433,17 @@ impl Editor { let mut edits = Vec::new(); let mut selection_adjustment = 0i32; - for selection in self.selections.all::(cx) { + for selection in self.selections.all_adjusted(cx) { let selection_is_empty = selection.is_empty(); let (start, end) = if selection_is_empty { let (word_range, _) = buffer.surrounding_word(selection.start, false); (word_range.start, word_range.end) } else { - (selection.start, selection.end) + ( + buffer.point_to_offset(selection.start), + buffer.point_to_offset(selection.end), + ) }; let text = buffer.text_for_range(start..end).collect::(); @@ -11303,7 +11454,8 @@ impl Editor { start: (start as i32 - selection_adjustment) as usize, end: ((start + text.len()) as i32 - selection_adjustment) as usize, goal: SelectionGoal::None, - ..selection + id: selection.id, + reversed: selection.reversed, }); selection_adjustment += old_length - text.len() as i32; @@ -11694,7 +11846,7 @@ impl Editor { let transpose_start = display_map .buffer_snapshot .clip_offset(transpose_offset.saturating_sub(1), Bias::Left); - if edits.last().map_or(true, |e| e.0.end <= transpose_start) { + if edits.last().is_none_or(|e| e.0.end <= transpose_start) { let transpose_end = display_map .buffer_snapshot .clip_offset(transpose_offset + 1, Bias::Right); @@ -11731,6 +11883,18 @@ impl Editor { let buffer = self.buffer.read(cx).snapshot(cx); let selections = self.selections.all::(cx); + #[derive(Clone, Debug, PartialEq)] + enum CommentFormat { + /// single line comment, with prefix for line + Line(String), + /// single line within a block comment, with prefix for line + BlockLine(String), + /// a single line of a block comment that includes the initial delimiter + BlockCommentWithStart(BlockCommentConfig), + /// a single line of a block comment that includes the ending delimiter + BlockCommentWithEnd(BlockCommentConfig), + } + // Split selections to respect paragraph, indent, and comment prefix boundaries. 
let wrap_ranges = selections.into_iter().flat_map(|selection| { let mut non_blank_rows_iter = (selection.start.row..=selection.end.row) @@ -11747,37 +11911,75 @@ impl Editor { let language_scope = buffer.language_scope_at(selection.head()); let indent_and_prefix_for_row = - |row: u32| -> (IndentSize, Option, Option) { + |row: u32| -> (IndentSize, Option, Option) { let indent = buffer.indent_size_for_line(MultiBufferRow(row)); - let (comment_prefix, rewrap_prefix) = - if let Some(language_scope) = &language_scope { - let indent_end = Point::new(row, indent.len); - let comment_prefix = language_scope + let (comment_prefix, rewrap_prefix) = if let Some(language_scope) = + &language_scope + { + let indent_end = Point::new(row, indent.len); + let line_end = Point::new(row, buffer.line_len(MultiBufferRow(row))); + let line_text_after_indent = buffer + .text_for_range(indent_end..line_end) + .collect::(); + + let is_within_comment_override = buffer + .language_scope_at(indent_end) + .is_some_and(|scope| scope.override_name() == Some("comment")); + let comment_delimiters = if is_within_comment_override { + // we are within a comment syntax node, but we don't + // yet know what kind of comment: block, doc or line + match ( + language_scope.documentation_comment(), + language_scope.block_comment(), + ) { + (Some(config), _) | (_, Some(config)) + if buffer.contains_str_at(indent_end, &config.start) => + { + Some(CommentFormat::BlockCommentWithStart(config.clone())) + } + (Some(config), _) | (_, Some(config)) + if line_text_after_indent.ends_with(config.end.as_ref()) => + { + Some(CommentFormat::BlockCommentWithEnd(config.clone())) + } + (Some(config), _) | (_, Some(config)) + if buffer.contains_str_at(indent_end, &config.prefix) => + { + Some(CommentFormat::BlockLine(config.prefix.to_string())) + } + (_, _) => language_scope + .line_comment_prefixes() + .iter() + .find(|prefix| buffer.contains_str_at(indent_end, prefix)) + .map(|prefix| 
CommentFormat::Line(prefix.to_string())), + } + } else { + // we not in an overridden comment node, but we may + // be within a non-overridden line comment node + language_scope .line_comment_prefixes() .iter() .find(|prefix| buffer.contains_str_at(indent_end, prefix)) - .map(|prefix| prefix.to_string()); - let line_end = Point::new(row, buffer.line_len(MultiBufferRow(row))); - let line_text_after_indent = buffer - .text_for_range(indent_end..line_end) - .collect::(); - let rewrap_prefix = language_scope - .rewrap_prefixes() - .iter() - .find_map(|prefix_regex| { - prefix_regex.find(&line_text_after_indent).map(|mat| { - if mat.start() == 0 { - Some(mat.as_str().to_string()) - } else { - None - } - }) - }) - .flatten(); - (comment_prefix, rewrap_prefix) - } else { - (None, None) + .map(|prefix| CommentFormat::Line(prefix.to_string())) }; + + let rewrap_prefix = language_scope + .rewrap_prefixes() + .iter() + .find_map(|prefix_regex| { + prefix_regex.find(&line_text_after_indent).map(|mat| { + if mat.start() == 0 { + Some(mat.as_str().to_string()) + } else { + None + } + }) + }) + .flatten(); + (comment_delimiters, rewrap_prefix) + } else { + (None, None) + }; (indent, comment_prefix, rewrap_prefix) }; @@ -11788,22 +11990,22 @@ impl Editor { let mut prev_row = first_row; let ( mut current_range_indent, - mut current_range_comment_prefix, + mut current_range_comment_delimiters, mut current_range_rewrap_prefix, ) = indent_and_prefix_for_row(first_row); for row in non_blank_rows_iter.skip(1) { let has_paragraph_break = row > prev_row + 1; - let (row_indent, row_comment_prefix, row_rewrap_prefix) = + let (row_indent, row_comment_delimiters, row_rewrap_prefix) = indent_and_prefix_for_row(row); let has_indent_change = row_indent != current_range_indent; - let has_comment_change = row_comment_prefix != current_range_comment_prefix; + let has_comment_change = row_comment_delimiters != current_range_comment_delimiters; let has_boundary_change = has_comment_change || 
row_rewrap_prefix.is_some() - || (has_indent_change && current_range_comment_prefix.is_some()); + || (has_indent_change && current_range_comment_delimiters.is_some()); if has_paragraph_break || has_boundary_change { ranges.push(( @@ -11811,13 +12013,13 @@ impl Editor { Point::new(current_range_start, 0) ..Point::new(prev_row, buffer.line_len(MultiBufferRow(prev_row))), current_range_indent, - current_range_comment_prefix.clone(), + current_range_comment_delimiters.clone(), current_range_rewrap_prefix.clone(), from_empty_selection, )); current_range_start = row; current_range_indent = row_indent; - current_range_comment_prefix = row_comment_prefix; + current_range_comment_delimiters = row_comment_delimiters; current_range_rewrap_prefix = row_rewrap_prefix; } prev_row = row; @@ -11828,7 +12030,7 @@ impl Editor { Point::new(current_range_start, 0) ..Point::new(prev_row, buffer.line_len(MultiBufferRow(prev_row))), current_range_indent, - current_range_comment_prefix, + current_range_comment_delimiters, current_range_rewrap_prefix, from_empty_selection, )); @@ -11842,7 +12044,7 @@ impl Editor { for ( language_settings, wrap_range, - indent_size, + mut indent_size, comment_prefix, rewrap_prefix, from_empty_selection, @@ -11862,16 +12064,26 @@ impl Editor { let tab_size = language_settings.tab_size; + let (line_prefix, inside_comment) = match &comment_prefix { + Some(CommentFormat::Line(prefix) | CommentFormat::BlockLine(prefix)) => { + (Some(prefix.as_str()), true) + } + Some(CommentFormat::BlockCommentWithEnd(BlockCommentConfig { prefix, .. 
})) => { + (Some(prefix.as_ref()), true) + } + Some(CommentFormat::BlockCommentWithStart(BlockCommentConfig { + start: _, + end: _, + prefix, + tab_size, + })) => { + indent_size.len += tab_size; + (Some(prefix.as_ref()), true) + } + None => (None, false), + }; let indent_prefix = indent_size.chars().collect::(); - let mut line_prefix = indent_prefix.clone(); - let mut inside_comment = false; - if let Some(prefix) = &comment_prefix { - line_prefix.push_str(prefix); - inside_comment = true; - } - if let Some(prefix) = &rewrap_prefix { - line_prefix.push_str(prefix); - } + let line_prefix = format!("{indent_prefix}{}", line_prefix.unwrap_or("")); let allow_rewrap_based_on_language = match language_settings.allow_rewrap { RewrapBehavior::InComments => inside_comment, @@ -11916,6 +12128,8 @@ impl Editor { let start_offset = start.to_offset(&buffer); let end = Point::new(end_row, buffer.line_len(MultiBufferRow(end_row))); let selection_text = buffer.text_for_range(start..end).collect::(); + let mut first_line_delimiter = None; + let mut last_line_delimiter = None; let Some(lines_without_prefixes) = selection_text .lines() .enumerate() @@ -11923,6 +12137,46 @@ impl Editor { let line_trimmed = line.trim_start(); if rewrap_prefix.is_some() && ix > 0 { Ok(line_trimmed) + } else if let Some( + CommentFormat::BlockCommentWithStart(BlockCommentConfig { + start, + prefix, + end, + tab_size, + }) + | CommentFormat::BlockCommentWithEnd(BlockCommentConfig { + start, + prefix, + end, + tab_size, + }), + ) = &comment_prefix + { + let line_trimmed = line_trimmed + .strip_prefix(start.as_ref()) + .map(|s| { + let mut indent_size = indent_size; + indent_size.len -= tab_size; + let indent_prefix: String = indent_size.chars().collect(); + first_line_delimiter = Some((indent_prefix, start)); + s.trim_start() + }) + .unwrap_or(line_trimmed); + let line_trimmed = line_trimmed + .strip_suffix(end.as_ref()) + .map(|s| { + last_line_delimiter = Some(end); + s.trim_end() + }) + 
.unwrap_or(line_trimmed); + let line_trimmed = line_trimmed + .strip_prefix(prefix.as_ref()) + .unwrap_or(line_trimmed); + Ok(line_trimmed) + } else if let Some(CommentFormat::BlockLine(prefix)) = &comment_prefix { + line_trimmed.strip_prefix(prefix).with_context(|| { + format!("line did not start with prefix {prefix:?}: {line:?}") + }) } else { line_trimmed .strip_prefix(&line_prefix.trim_start()) @@ -11949,14 +12203,25 @@ impl Editor { line_prefix.clone() }; - let wrapped_text = wrap_with_prefix( - line_prefix, - subsequent_lines_prefix, - lines_without_prefixes.join("\n"), - wrap_column, - tab_size, - options.preserve_existing_whitespace, - ); + let wrapped_text = { + let mut wrapped_text = wrap_with_prefix( + line_prefix, + subsequent_lines_prefix, + lines_without_prefixes.join("\n"), + wrap_column, + tab_size, + options.preserve_existing_whitespace, + ); + + if let Some((indent, delimiter)) = first_line_delimiter { + wrapped_text = format!("{indent}{delimiter}\n{wrapped_text}"); + } + if let Some(last_line) = last_line_delimiter { + wrapped_text = format!("{wrapped_text}\n{indent_prefix}{last_line}"); + } + + wrapped_text + }; // TODO: should always use char-based diff while still supporting cursor behavior that // matches vim. 
@@ -11984,7 +12249,12 @@ impl Editor { .update(cx, |buffer, cx| buffer.edit(edits, None, cx)); } - pub fn cut_common(&mut self, window: &mut Window, cx: &mut Context) -> ClipboardItem { + pub fn cut_common( + &mut self, + cut_no_selection_line: bool, + window: &mut Window, + cx: &mut Context, + ) -> ClipboardItem { let mut text = String::new(); let buffer = self.buffer.read(cx).snapshot(cx); let mut selections = self.selections.all::(cx); @@ -11993,7 +12263,8 @@ impl Editor { let max_point = buffer.max_point(); let mut is_first = true; for selection in &mut selections { - let is_entire_line = selection.is_empty() || self.selections.line_mode; + let is_entire_line = + (selection.is_empty() && cut_no_selection_line) || self.selections.line_mode; if is_entire_line { selection.start = Point::new(selection.start.row, 0); if !selection.is_empty() && selection.end.column == 0 { @@ -12034,7 +12305,7 @@ impl Editor { pub fn cut(&mut self, _: &Cut, window: &mut Window, cx: &mut Context) { self.hide_mouse_cursor(HideMouseCursorOrigin::TypingAction, cx); - let item = self.cut_common(window, cx); + let item = self.cut_common(true, window, cx); cx.write_to_clipboard(item); } @@ -12043,11 +12314,14 @@ impl Editor { self.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { s.move_with(|snapshot, sel| { if sel.is_empty() { - sel.end = DisplayPoint::new(sel.end.row(), snapshot.line_len(sel.end.row())) + sel.end = DisplayPoint::new(sel.end.row(), snapshot.line_len(sel.end.row())); + } + if sel.is_empty() { + sel.end = DisplayPoint::new(sel.end.row() + 1_u32, 0); } }); }); - let item = self.cut_common(window, cx); + let item = self.cut_common(true, window, cx); cx.set_global(KillRing(item)) } @@ -12254,7 +12528,7 @@ impl Editor { let trigger_in_words = this.show_edit_predictions_in_menu() || !had_active_edit_prediction; - this.trigger_completion_on_input(&text, trigger_in_words, window, cx); + this.trigger_completion_on_input(text, trigger_in_words, window, cx); }); } @@ 
-12608,7 +12882,7 @@ impl Editor { return; } - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -12732,7 +13006,7 @@ impl Editor { return; } - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -12934,11 +13208,17 @@ impl Editor { this.change_selections(Default::default(), window, cx, |s| { s.move_with(|map, selection| { if selection.is_empty() { - let cursor = if action.ignore_newlines { + let mut cursor = if action.ignore_newlines { movement::previous_word_start(map, selection.head()) } else { movement::previous_word_start_or_newline(map, selection.head()) }; + cursor = movement::adjust_greedy_deletion( + map, + selection.head(), + cursor, + action.ignore_brackets, + ); selection.set_head(cursor, SelectionGoal::None); } }); @@ -12959,7 +13239,9 @@ impl Editor { this.change_selections(Default::default(), window, cx, |s| { s.move_with(|map, selection| { if selection.is_empty() { - let cursor = movement::previous_subword_start(map, selection.head()); + let mut cursor = movement::previous_subword_start(map, selection.head()); + cursor = + movement::adjust_greedy_deletion(map, selection.head(), cursor, false); selection.set_head(cursor, SelectionGoal::None); } }); @@ -13035,11 +13317,17 @@ impl Editor { this.change_selections(Default::default(), window, cx, |s| { s.move_with(|map, selection| { if selection.is_empty() { - let cursor = if action.ignore_newlines { + let mut cursor = if action.ignore_newlines { movement::next_word_end(map, selection.head()) } else { movement::next_word_end_or_newline(map, selection.head()) }; + cursor = movement::adjust_greedy_deletion( + map, + selection.head(), + cursor, + action.ignore_brackets, + ); selection.set_head(cursor, SelectionGoal::None); } }); @@ -13059,7 +13347,9 @@ impl Editor { this.change_selections(Default::default(), window, cx, |s| { 
s.move_with(|map, selection| { if selection.is_empty() { - let cursor = movement::next_subword_end(map, selection.head()); + let mut cursor = movement::next_subword_end(map, selection.head()); + cursor = + movement::adjust_greedy_deletion(map, selection.head(), cursor, false); selection.set_head(cursor, SelectionGoal::None); } }); @@ -13193,7 +13483,7 @@ impl Editor { pub fn cut_to_end_of_line( &mut self, - _: &CutToEndOfLine, + action: &CutToEndOfLine, window: &mut Window, cx: &mut Context, ) { @@ -13206,7 +13496,18 @@ impl Editor { window, cx, ); - this.cut(&Cut, window, cx); + if !action.stop_at_newlines { + this.change_selections(Default::default(), window, cx, |s| { + s.move_with(|_, sel| { + if sel.is_empty() { + sel.end = DisplayPoint::new(sel.end.row() + 1_u32, 0); + } + }); + }); + } + this.hide_mouse_cursor(HideMouseCursorOrigin::TypingAction, cx); + let item = this.cut_common(false, window, cx); + cx.write_to_clipboard(item); }); } @@ -13216,7 +13517,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -13237,7 +13538,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -13258,7 +13559,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -13279,7 +13580,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -13300,7 +13601,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if matches!(self.mode, EditorMode::SingleLine { .. 
}) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -13325,7 +13626,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -13350,7 +13651,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -13375,7 +13676,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -13400,7 +13701,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -13421,7 +13722,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -13442,7 +13743,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -13463,7 +13764,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -13484,7 +13785,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -13509,7 +13810,7 @@ impl Editor { } pub fn move_to_end(&mut self, _: &MoveToEnd, window: &mut Window, cx: &mut Context) { - if matches!(self.mode, EditorMode::SingleLine { .. 
}) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -13612,7 +13913,7 @@ impl Editor { pub fn split_selection_into_lines( &mut self, - _: &SplitSelectionIntoLines, + action: &SplitSelectionIntoLines, window: &mut Window, cx: &mut Context, ) { @@ -13629,8 +13930,21 @@ impl Editor { let buffer = self.buffer.read(cx).read(cx); for selection in selections { for row in selection.start.row..selection.end.row { - let cursor = Point::new(row, buffer.line_len(MultiBufferRow(row))); - new_selection_ranges.push(cursor..cursor); + let line_start = Point::new(row, 0); + let line_end = Point::new(row, buffer.line_len(MultiBufferRow(row))); + + if action.keep_selections { + // Keep the selection range for each line + let selection_start = if row == selection.start.row { + selection.start + } else { + line_start + }; + new_selection_ranges.push(selection_start..line_end); + } else { + // Collapse to cursor at end of line + new_selection_ranges.push(line_end..line_end); + } } let is_multiline_selection = selection.start.row != selection.end.row; @@ -13638,7 +13952,16 @@ impl Editor { // so this action feels more ergonomic when paired with other selection operations let should_skip_last = is_multiline_selection && selection.end.column == 0; if !should_skip_last { - new_selection_ranges.push(selection.end..selection.end); + if action.keep_selections { + if is_multiline_selection { + let line_start = Point::new(selection.end.row, 0); + new_selection_ranges.push(line_start..selection.end); + } else { + new_selection_ranges.push(selection.start..selection.end); + } + } else { + new_selection_ranges.push(selection.end..selection.end); + } } } } @@ -14536,7 +14859,7 @@ impl Editor { let advance_downwards = action.advance_downwards && selections_on_single_row && !selections_selecting - && !matches!(this.mode, EditorMode::SingleLine { .. 
}); + && !matches!(this.mode, EditorMode::SingleLine); if advance_downwards { let snapshot = this.buffer.read(cx).snapshot(cx); @@ -14663,9 +14986,13 @@ impl Editor { } let mut new_range = old_range.clone(); - while let Some((_node, containing_range)) = - buffer.syntax_ancestor(new_range.clone()) + while let Some((node, containing_range)) = buffer.syntax_ancestor(new_range.clone()) { + if !node.is_named() { + new_range = node.start_byte()..node.end_byte(); + continue; + } + new_range = match containing_range { MultiOrSingleBufferOffsetRange::Single(_) => break, MultiOrSingleBufferOffsetRange::Multi(range) => range, @@ -14782,15 +15109,18 @@ impl Editor { self.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx); let buffer = self.buffer.read(cx).snapshot(cx); - let old_selections: Box<[_]> = self.selections.all::(cx).into(); + let selections = self + .selections + .all::(cx) + .into_iter() + // subtracting the offset requires sorting + .sorted_by_key(|i| i.start); - let edits = old_selections - .iter() - // only consider the first selection for now - .take(1) - .map(|selection| { + let full_edits = selections + .into_iter() + .filter_map(|selection| { // Only requires two branches once if-let-chains stabilize (#53667) - let selection_range = if !selection.is_empty() { + let child = if !selection.is_empty() { selection.range() } else if let Some((_, ancestor_range)) = buffer.syntax_ancestor(selection.start..selection.end) @@ -14803,49 +15133,151 @@ impl Editor { selection.range() }; - let mut new_range = selection_range.clone(); - while let Some((_, ancestor_range)) = buffer.syntax_ancestor(new_range.clone()) { - new_range = match ancestor_range { + let mut parent = child.clone(); + while let Some((_, ancestor_range)) = buffer.syntax_ancestor(parent.clone()) { + parent = match ancestor_range { MultiOrSingleBufferOffsetRange::Single(range) => range, MultiOrSingleBufferOffsetRange::Multi(range) => range, }; - if new_range.start < selection_range.start - || 
new_range.end > selection_range.end - { + if parent.start < child.start || parent.end > child.end { break; } } - (selection, selection_range, new_range) + if parent == child { + return None; + } + let text = buffer.text_for_range(child).collect::(); + Some((selection.id, parent, text)) }) .collect::>(); - self.transact(window, cx, |editor, window, cx| { - for (_, child, parent) in &edits { - let text = buffer.text_for_range(child.clone()).collect::(); - editor.replace_text_in_range(Some(parent.clone()), &text, window, cx); - } + self.transact(window, cx, |this, window, cx| { + this.buffer.update(cx, |buffer, cx| { + buffer.edit( + full_edits + .iter() + .map(|(_, p, t)| (p.clone(), t.clone())) + .collect::>(), + None, + cx, + ); + }); + this.change_selections(Default::default(), window, cx, |s| { + let mut offset = 0; + let mut selections = vec![]; + for (id, parent, text) in full_edits { + let start = parent.start - offset; + offset += parent.len() - text.len(); + selections.push(Selection { + id, + start, + end: start + text.len(), + reversed: false, + goal: Default::default(), + }); + } + s.select(selections); + }); + }); + } + + pub fn select_next_syntax_node( + &mut self, + _: &SelectNextSyntaxNode, + window: &mut Window, + cx: &mut Context, + ) { + let old_selections: Box<[_]> = self.selections.all::(cx).into(); + if old_selections.is_empty() { + return; + } + + self.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx); + + let buffer = self.buffer.read(cx).snapshot(cx); + let mut selected_sibling = false; + + let new_selections = old_selections + .iter() + .map(|selection| { + let old_range = selection.start..selection.end; + + if let Some(node) = buffer.syntax_next_sibling(old_range) { + let new_range = node.byte_range(); + selected_sibling = true; + Selection { + id: selection.id, + start: new_range.start, + end: new_range.end, + goal: SelectionGoal::None, + reversed: selection.reversed, + } + } else { + selection.clone() + } + }) + .collect::>(); 
- editor.change_selections( + if selected_sibling { + self.change_selections( SelectionEffects::scroll(Autoscroll::fit()), window, cx, |s| { - s.select( - edits - .iter() - .map(|(s, old, new)| Selection { - id: s.id, - start: new.start, - end: new.start + old.len(), - goal: SelectionGoal::None, - reversed: s.reversed, - }) - .collect(), - ); + s.select(new_selections); }, ); - }); + } + } + + pub fn select_prev_syntax_node( + &mut self, + _: &SelectPreviousSyntaxNode, + window: &mut Window, + cx: &mut Context, + ) { + let old_selections: Box<[_]> = self.selections.all::(cx).into(); + if old_selections.is_empty() { + return; + } + + self.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx); + + let buffer = self.buffer.read(cx).snapshot(cx); + let mut selected_sibling = false; + + let new_selections = old_selections + .iter() + .map(|selection| { + let old_range = selection.start..selection.end; + + if let Some(node) = buffer.syntax_prev_sibling(old_range) { + let new_range = node.byte_range(); + selected_sibling = true; + Selection { + id: selection.id, + start: new_range.start, + end: new_range.end, + goal: SelectionGoal::None, + reversed: selection.reversed, + } + } else { + selection.clone() + } + }) + .collect::>(); + + if selected_sibling { + self.change_selections( + SelectionEffects::scroll(Autoscroll::fit()), + window, + cx, + |s| { + s.select(new_selections); + }, + ); + } } fn refresh_runnables(&mut self, window: &mut Window, cx: &mut Context) -> Task<()> { @@ -14853,7 +15285,7 @@ impl Editor { self.clear_tasks(); return Task::ready(()); } - let project = self.project.as_ref().map(Entity::downgrade); + let project = self.project().map(Entity::downgrade); let task_sources = self.lsp_task_sources(cx); let multi_buffer = self.buffer.downgrade(); cx.spawn_in(window, async move |editor, cx| { @@ -14868,10 +15300,7 @@ impl Editor { }; let hide_runnables = project - .update(cx, |project, cx| { - // Do not display any test indicators in non-dev server 
remote projects. - project.is_via_collab() && project.ssh_connection_string(cx).is_none() - }) + .update(cx, |project, _| project.is_via_collab()) .unwrap_or(true); if hide_runnables { return; @@ -15264,17 +15693,15 @@ impl Editor { if direction == ExpandExcerptDirection::Down { let multi_buffer = self.buffer.read(cx); let snapshot = multi_buffer.snapshot(cx); - if let Some(buffer_id) = snapshot.buffer_id_for_excerpt(excerpt) { - if let Some(buffer) = multi_buffer.buffer(buffer_id) { - if let Some(excerpt_range) = snapshot.buffer_range_for_excerpt(excerpt) { - let buffer_snapshot = buffer.read(cx).snapshot(); - let excerpt_end_row = - Point::from_anchor(&excerpt_range.end, &buffer_snapshot).row; - let last_row = buffer_snapshot.max_point().row; - let lines_below = last_row.saturating_sub(excerpt_end_row); - should_scroll_up = lines_below >= lines_to_expand; - } - } + if let Some(buffer_id) = snapshot.buffer_id_for_excerpt(excerpt) + && let Some(buffer) = multi_buffer.buffer(buffer_id) + && let Some(excerpt_range) = snapshot.buffer_range_for_excerpt(excerpt) + { + let buffer_snapshot = buffer.read(cx).snapshot(); + let excerpt_end_row = Point::from_anchor(&excerpt_range.end, &buffer_snapshot).row; + let last_row = buffer_snapshot.max_point().row; + let lines_below = last_row.saturating_sub(excerpt_end_row); + should_scroll_up = lines_below >= lines_to_expand; } } @@ -15359,10 +15786,10 @@ impl Editor { let selection = self.selections.newest::(cx); let mut active_group_id = None; - if let ActiveDiagnostic::Group(active_group) = &self.active_diagnostics { - if active_group.active_range.start.to_offset(&buffer) == selection.start { - active_group_id = Some(active_group.group_id); - } + if let ActiveDiagnostic::Group(active_group) = &self.active_diagnostics + && active_group.active_range.start.to_offset(&buffer) == selection.start + { + active_group_id = Some(active_group.group_id); } fn filtered( @@ -15421,7 +15848,8 @@ impl Editor { return; }; - let Some(buffer_id) = 
buffer.anchor_after(next_diagnostic.range.start).buffer_id else { + let next_diagnostic_start = buffer.anchor_after(next_diagnostic.range.start); + let Some(buffer_id) = buffer.buffer_id_for_anchor(next_diagnostic_start) else { return; }; self.change_selections(Default::default(), window, cx, |s| { @@ -15560,6 +15988,87 @@ impl Editor { } } + pub fn go_to_next_document_highlight( + &mut self, + _: &GoToNextDocumentHighlight, + window: &mut Window, + cx: &mut Context, + ) { + self.go_to_document_highlight_before_or_after_position(Direction::Next, window, cx); + } + + pub fn go_to_prev_document_highlight( + &mut self, + _: &GoToPreviousDocumentHighlight, + window: &mut Window, + cx: &mut Context, + ) { + self.go_to_document_highlight_before_or_after_position(Direction::Prev, window, cx); + } + + pub fn go_to_document_highlight_before_or_after_position( + &mut self, + direction: Direction, + window: &mut Window, + cx: &mut Context, + ) { + self.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx); + let snapshot = self.snapshot(window, cx); + let buffer = &snapshot.buffer_snapshot; + let position = self.selections.newest::(cx).head(); + let anchor_position = buffer.anchor_after(position); + + // Get all document highlights (both read and write) + let mut all_highlights = Vec::new(); + + if let Some((_, read_highlights)) = self + .background_highlights + .get(&HighlightKey::Type(TypeId::of::())) + { + all_highlights.extend(read_highlights.iter()); + } + + if let Some((_, write_highlights)) = self + .background_highlights + .get(&HighlightKey::Type(TypeId::of::())) + { + all_highlights.extend(write_highlights.iter()); + } + + if all_highlights.is_empty() { + return; + } + + // Sort highlights by position + all_highlights.sort_by(|a, b| a.start.cmp(&b.start, buffer)); + + let target_highlight = match direction { + Direction::Next => { + // Find the first highlight after the current position + all_highlights + .iter() + .find(|highlight| 
highlight.start.cmp(&anchor_position, buffer).is_gt()) + } + Direction::Prev => { + // Find the last highlight before the current position + all_highlights + .iter() + .rev() + .find(|highlight| highlight.end.cmp(&anchor_position, buffer).is_lt()) + } + }; + + if let Some(highlight) = target_highlight { + let destination = highlight.start.to_point(buffer); + let autoscroll = Autoscroll::center(); + + self.unfold_ranges(&[destination..destination], false, false, cx); + self.change_selections(SelectionEffects::scroll(autoscroll), window, cx, |s| { + s.select_ranges([destination..destination]); + }); + } + } + fn go_to_line( &mut self, position: Anchor, @@ -15699,7 +16208,9 @@ impl Editor { }; cx.spawn_in(window, async move |editor, cx| { - let definitions = definitions.await?; + let Some(definitions) = definitions.await? else { + return Ok(Navigated::No); + }; let navigated = editor .update_in(cx, |editor, window, cx| { editor.navigate_to_hover_links( @@ -15857,8 +16368,15 @@ impl Editor { .text_for_range(location.range.clone()) .collect::() }) + .filter(|text| !text.contains('\n')) + .unique() + .take(3) .join(", "); - format!("{tab_kind} for {target}") + if target.is_empty() { + tab_kind.to_owned() + } else { + format!("{tab_kind} for {target}") + } }) .context("buffer title")?; @@ -15914,7 +16432,7 @@ impl Editor { if !split && Some(&target.buffer) == editor.buffer.read(cx).as_singleton().as_ref() { - editor.go_to_singleton_buffer_range(range.clone(), window, cx); + editor.go_to_singleton_buffer_range(range, window, cx); } else { window.defer(cx, move |window, cx| { let target_editor: Entity = @@ -15963,38 +16481,24 @@ impl Editor { cx.spawn_in(window, async move |editor, cx| { let location_task = editor.update(cx, |_, cx| { project.update(cx, |project, cx| { - let language_server_name = project - .language_server_statuses(cx) - .find(|(id, _)| server_id == *id) - .map(|(_, status)| status.name.clone()); - language_server_name.map(|language_server_name| { - 
project.open_local_buffer_via_lsp( - lsp_location.uri.clone(), - server_id, - language_server_name, - cx, - ) - }) + project.open_local_buffer_via_lsp(lsp_location.uri.clone(), server_id, cx) }) })?; - let location = match location_task { - Some(task) => Some({ - let target_buffer_handle = task.await.context("open local buffer")?; - let range = target_buffer_handle.read_with(cx, |target_buffer, _| { - let target_start = target_buffer - .clip_point_utf16(point_from_lsp(lsp_location.range.start), Bias::Left); - let target_end = target_buffer - .clip_point_utf16(point_from_lsp(lsp_location.range.end), Bias::Left); - target_buffer.anchor_after(target_start) - ..target_buffer.anchor_before(target_end) - })?; - Location { - buffer: target_buffer_handle, - range, - } - }), - None => None, - }; + let location = Some({ + let target_buffer_handle = location_task.await.context("open local buffer")?; + let range = target_buffer_handle.read_with(cx, |target_buffer, _| { + let target_start = target_buffer + .clip_point_utf16(point_from_lsp(lsp_location.range.start), Bias::Left); + let target_end = target_buffer + .clip_point_utf16(point_from_lsp(lsp_location.range.end), Bias::Left); + target_buffer.anchor_after(target_start) + ..target_buffer.anchor_before(target_end) + })?; + Location { + buffer: target_buffer_handle, + range, + } + }); Ok(location) }) } @@ -16048,7 +16552,9 @@ impl Editor { } }); - let locations = references.await?; + let Some(locations) = references.await? 
else { + return anyhow::Ok(Navigated::No); + }; if locations.is_empty() { return anyhow::Ok(Navigated::No); } @@ -16063,8 +16569,15 @@ impl Editor { .text_for_range(location.range.clone()) .collect::() }) + .filter(|text| !text.contains('\n')) + .unique() + .take(3) .join(", "); - let title = format!("References to {target}"); + let title = if target.is_empty() { + "References".to_owned() + } else { + format!("References to {target}") + }; Self::open_locations_in_multibuffer( workspace, locations, @@ -16122,7 +16635,7 @@ impl Editor { PathKey::for_buffer(&location.buffer, cx), location.buffer.clone(), ranges_for_buffer, - DEFAULT_MULTIBUFFER_CONTEXT, + multibuffer_context_lines(cx), cx, ); ranges.extend(new_ranges) @@ -16179,24 +16692,22 @@ impl Editor { let item_id = item.item_id(); if split { - workspace.split_item(SplitDirection::Right, item.clone(), window, cx); - } else { - if PreviewTabsSettings::get_global(cx).enable_preview_from_code_navigation { - let (preview_item_id, preview_item_idx) = - workspace.active_pane().read_with(cx, |pane, _| { - (pane.preview_item_id(), pane.preview_item_idx()) - }); + workspace.split_item(SplitDirection::Right, item, window, cx); + } else if PreviewTabsSettings::get_global(cx).enable_preview_from_code_navigation { + let (preview_item_id, preview_item_idx) = + workspace.active_pane().read_with(cx, |pane, _| { + (pane.preview_item_id(), pane.preview_item_idx()) + }); - workspace.add_item_to_active_pane(item.clone(), preview_item_idx, true, window, cx); + workspace.add_item_to_active_pane(item, preview_item_idx, true, window, cx); - if let Some(preview_item_id) = preview_item_id { - workspace.active_pane().update(cx, |pane, cx| { - pane.remove_item(preview_item_id, false, false, window, cx); - }); - } - } else { - workspace.add_item_to_active_pane(item.clone(), None, true, window, cx); + if let Some(preview_item_id) = preview_item_id { + workspace.active_pane().update(cx, |pane, cx| { + pane.remove_item(preview_item_id, false, 
false, window, cx); + }); } + } else { + workspace.add_item_to_active_pane(item, None, true, window, cx); } workspace.active_pane().update(cx, |pane, cx| { pane.set_preview_item_id(Some(item_id), cx); @@ -16586,10 +17097,7 @@ impl Editor { .transaction(transaction_id_prev) .map(|t| t.0.clone()) }) - .unwrap_or_else(|| { - log::info!("Failed to determine selections from before format. Falling back to selections when format was initiated"); - self.selections.disjoint_anchors() - }); + .unwrap_or_else(|| self.selections.disjoint_anchors()); let mut timeout = cx.background_executor().timer(FORMAT_TIMEOUT).fuse(); let format = project.update(cx, |project, cx| { @@ -16607,10 +17115,10 @@ impl Editor { buffer .update(cx, |buffer, cx| { - if let Some(transaction) = transaction { - if !buffer.is_singleton() { - buffer.push_transaction(&transaction.0, cx); - } + if let Some(transaction) = transaction + && !buffer.is_singleton() + { + buffer.push_transaction(&transaction.0, cx); } cx.notify(); }) @@ -16676,10 +17184,10 @@ impl Editor { buffer .update(cx, |buffer, cx| { // check if we need this - if let Some(transaction) = transaction { - if !buffer.is_singleton() { - buffer.push_transaction(&transaction.0, cx); - } + if let Some(transaction) = transaction + && !buffer.is_singleton() + { + buffer.push_transaction(&transaction.0, cx); } cx.notify(); }) @@ -16858,6 +17366,10 @@ impl Editor { self.inline_diagnostics.clear(); } + pub fn disable_word_completions(&mut self) { + self.word_completions_enabled = false; + } + pub fn diagnostics_enabled(&self) -> bool { self.diagnostics_enabled && self.mode.is_full() } @@ -17028,7 +17540,7 @@ impl Editor { if !pull_diagnostics_settings.enabled { return None; } - let project = self.project.as_ref()?.downgrade(); + let project = self.project()?.downgrade(); let debounce = Duration::from_millis(pull_diagnostics_settings.debounce_ms); let mut buffers = self.buffer.read(cx).all_buffers(); if let Some(buffer_id) = buffer_id { @@ -17311,12 
+17823,12 @@ impl Editor { } for row in (0..=range.start.row).rev() { - if let Some(crease) = display_map.crease_for_buffer_row(MultiBufferRow(row)) { - if crease.range().end.row >= buffer_start_row { - to_fold.push(crease); - if row <= range.start.row { - break; - } + if let Some(crease) = display_map.crease_for_buffer_row(MultiBufferRow(row)) + && crease.range().end.row >= buffer_start_row + { + to_fold.push(crease); + if row <= range.start.row { + break; } } } @@ -17844,7 +18356,7 @@ impl Editor { ranges: &[Range], snapshot: &MultiBufferSnapshot, ) -> bool { - let mut hunks = self.diff_hunks_in_ranges(ranges, &snapshot); + let mut hunks = self.diff_hunks_in_ranges(ranges, snapshot); hunks.any(|hunk| hunk.status().has_secondary_hunk()) } @@ -17992,7 +18504,7 @@ impl Editor { hunks: impl Iterator, cx: &mut App, ) -> Option<()> { - let project = self.project.as_ref()?; + let project = self.project()?; let buffer = project.read(cx).buffer_for_id(buffer_id, cx)?; let diff = self.buffer.read(cx).diff_for(buffer_id)?; let buffer_snapshot = buffer.read(cx).snapshot(); @@ -18394,12 +18906,7 @@ impl Editor { } /// called by the Element so we know what style we were most recently rendered with. - pub(crate) fn set_style( - &mut self, - style: EditorStyle, - window: &mut Window, - cx: &mut Context, - ) { + pub fn set_style(&mut self, style: EditorStyle, window: &mut Window, cx: &mut Context) { // We intentionally do not inform the display map about the minimap style // so that wrapping is not recalculated and stays consistent for the editor // and its linked minimap. @@ -18423,8 +18930,16 @@ impl Editor { // Called by the element. This method is not designed to be called outside of the editor // element's layout code because it does not notify when rewrapping is computed synchronously. 
pub(crate) fn set_wrap_width(&self, width: Option, cx: &mut App) -> bool { - self.display_map - .update(cx, |map, cx| map.set_wrap_width(width, cx)) + if self.is_empty(cx) { + self.placeholder_display_map + .as_ref() + .map_or(false, |display_map| { + display_map.update(cx, |map, cx| map.set_wrap_width(width, cx)) + }) + } else { + self.display_map + .update(cx, |map, cx| map.set_wrap_width(width, cx)) + } } pub fn set_soft_wrap(&mut self) { @@ -18626,10 +19141,10 @@ impl Editor { pub fn working_directory(&self, cx: &App) -> Option { if let Some(buffer) = self.buffer().read(cx).as_singleton() { - if let Some(file) = buffer.read(cx).file().and_then(|f| f.as_local()) { - if let Some(dir) = file.abs_path(cx).parent() { - return Some(dir.to_owned()); - } + if let Some(file) = buffer.read(cx).file().and_then(|f| f.as_local()) + && let Some(dir) = file.abs_path(cx).parent() + { + return Some(dir.to_owned()); } if let Some(project_path) = buffer.read(cx).project_path(cx) { @@ -18652,7 +19167,7 @@ impl Editor { self.active_excerpt(cx).and_then(|(_, buffer, _)| { let buffer = buffer.read(cx); if let Some(project_path) = buffer.project_path(cx) { - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); project.absolute_path(&project_path, cx) } else { buffer @@ -18665,7 +19180,7 @@ impl Editor { fn target_file_path(&self, cx: &mut Context) -> Option { self.active_excerpt(cx).and_then(|(_, buffer, _)| { let project_path = buffer.read(cx).project_path(cx)?; - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); let entry = project.entry_for_path(&project_path, cx)?; let path = entry.path.to_path_buf(); Some(path) @@ -18689,10 +19204,10 @@ impl Editor { _window: &mut Window, cx: &mut Context, ) { - if let Some(path) = self.target_file_abs_path(cx) { - if let Some(path) = path.to_str() { - cx.write_to_clipboard(ClipboardItem::new_string(path.to_string())); - } + if let Some(path) = 
self.target_file_abs_path(cx) + && let Some(path) = path.to_str() + { + cx.write_to_clipboard(ClipboardItem::new_string(path.to_string())); } } @@ -18702,13 +19217,15 @@ impl Editor { _window: &mut Window, cx: &mut Context, ) { - if let Some(path) = self.target_file_path(cx) { - if let Some(path) = path.to_str() { - cx.write_to_clipboard(ClipboardItem::new_string(path.to_string())); - } + if let Some(path) = self.target_file_path(cx) + && let Some(path) = path.to_str() + { + cx.write_to_clipboard(ClipboardItem::new_string(path.to_string())); } } + /// Returns the project path for the editor's buffer, if any buffer is + /// opened in the editor. pub fn project_path(&self, cx: &App) -> Option { if let Some(buffer) = self.buffer.read(cx).as_singleton() { buffer.read(cx).project_path(cx) @@ -18774,22 +19291,20 @@ impl Editor { _: &mut Window, cx: &mut Context, ) { - if let Some(file) = self.target_file(cx) { - if let Some(file_stem) = file.path().file_stem() { - if let Some(name) = file_stem.to_str() { - cx.write_to_clipboard(ClipboardItem::new_string(name.to_string())); - } - } + if let Some(file) = self.target_file(cx) + && let Some(file_stem) = file.path().file_stem() + && let Some(name) = file_stem.to_str() + { + cx.write_to_clipboard(ClipboardItem::new_string(name.to_string())); } } pub fn copy_file_name(&mut self, _: &CopyFileName, _: &mut Window, cx: &mut Context) { - if let Some(file) = self.target_file(cx) { - if let Some(file_name) = file.path().file_name() { - if let Some(name) = file_name.to_str() { - cx.write_to_clipboard(ClipboardItem::new_string(name.to_string())); - } - } + if let Some(file) = self.target_file(cx) + && let Some(file_name) = file.path().file_name() + && let Some(name) = file_name.to_str() + { + cx.write_to_clipboard(ClipboardItem::new_string(name.to_string())); } } @@ -18836,7 +19351,7 @@ impl Editor { let snapshot = self.snapshot(window, cx); let cursor = self.selections.newest::(cx).head(); let (buffer, point, _) = 
snapshot.buffer_snapshot.point_to_buffer_point(cursor)?; - let blame_entry = blame + let (_, blame_entry) = blame .update(cx, |blame, cx| { blame .blame_for_rows( @@ -18851,7 +19366,7 @@ impl Editor { }) .flatten()?; let renderer = cx.global::().0.clone(); - let repo = blame.read(cx).repository(cx)?; + let repo = blame.read(cx).repository(cx, buffer.remote_id())?; let workspace = self.workspace()?.downgrade(); renderer.open_blame_commit(blame_entry, repo, workspace, window, cx); None @@ -18886,19 +19401,18 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if let Some(project) = self.project.as_ref() { - let Some(buffer) = self.buffer().read(cx).as_singleton() else { - return; - }; - - if buffer.read(cx).file().is_none() { + if let Some(project) = self.project() { + if let Some(buffer) = self.buffer().read(cx).as_singleton() + && buffer.read(cx).file().is_none() + { return; } let focused = self.focus_handle(cx).contains_focused(window, cx); let project = project.clone(); - let blame = cx.new(|cx| GitBlame::new(buffer, project, user_triggered, focused, cx)); + let blame = cx + .new(|cx| GitBlame::new(self.buffer.clone(), project, user_triggered, focused, cx)); self.blame_subscription = Some(cx.observe_in(&blame, window, |_, _, _, cx| cx.notify())); self.blame = Some(blame); @@ -18963,7 +19477,7 @@ impl Editor { fn has_blame_entries(&self, cx: &App) -> bool { self.blame() - .map_or(false, |blame| blame.read(cx).has_generated_entries()) + .is_some_and(|blame| blame.read(cx).has_generated_entries()) } fn newest_selection_head_on_empty_line(&self, cx: &App) -> bool { @@ -18990,19 +19504,16 @@ impl Editor { buffer_ranges.last() }?; - let selection = text::ToPoint::to_point(&range.start, &buffer).row - ..text::ToPoint::to_point(&range.end, &buffer).row; - Some(( - multi_buffer.buffer(buffer.remote_id()).unwrap().clone(), - selection, - )) + let selection = text::ToPoint::to_point(&range.start, buffer).row + ..text::ToPoint::to_point(&range.end, buffer).row; + 
Some((multi_buffer.buffer(buffer.remote_id()).unwrap(), selection)) }); let Some((buffer, selection)) = buffer_and_selection else { return Task::ready(Err(anyhow!("failed to determine buffer and selection"))); }; - let Some(project) = self.project.as_ref() else { + let Some(project) = self.project() else { return Task::ready(Err(anyhow!("editor does not have project"))); }; @@ -19059,10 +19570,10 @@ impl Editor { cx: &mut Context, ) { let selection = self.selections.newest::(cx).start.row + 1; - if let Some(file) = self.target_file(cx) { - if let Some(path) = file.path().to_str() { - cx.write_to_clipboard(ClipboardItem::new_string(format!("{path}:{selection}"))); - } + if let Some(file) = self.target_file(cx) + && let Some(path) = file.path().to_str() + { + cx.write_to_clipboard(ClipboardItem::new_string(format!("{path}:{selection}"))); } } @@ -19166,7 +19677,7 @@ impl Editor { let locations = self .selections .all_anchors(cx) - .into_iter() + .iter() .map(|selection| Location { buffer: buffer.clone(), range: selection.start.text_anchor..selection.end.text_anchor, @@ -19237,7 +19748,7 @@ impl Editor { row_highlights.insert( ix, RowHighlight { - range: range.clone(), + range, index, color, options, @@ -19551,7 +20062,24 @@ impl Editor { let buffer = &snapshot.buffer_snapshot; let start = buffer.anchor_before(0); let end = buffer.anchor_after(buffer.len()); - self.background_highlights_in_range(start..end, &snapshot, cx.theme()) + self.sorted_background_highlights_in_range(start..end, &snapshot, cx.theme()) + } + + #[cfg(any(test, feature = "test-support"))] + pub fn sorted_background_highlights_in_range( + &self, + search_range: Range, + display_snapshot: &DisplaySnapshot, + theme: &Theme, + ) -> Vec<(Range, Hsla)> { + let mut res = self.background_highlights_in_range(search_range, display_snapshot, theme); + res.sort_by(|a, b| { + a.0.start + .cmp(&b.0.start) + .then_with(|| a.0.end.cmp(&b.0.end)) + .then_with(|| a.1.cmp(&b.1)) + }); + res } #[cfg(feature = 
"test-support")] @@ -19613,9 +20141,12 @@ impl Editor { pub fn has_background_highlights(&self) -> bool { self.background_highlights .get(&HighlightKey::Type(TypeId::of::())) - .map_or(false, |(_, highlights)| !highlights.is_empty()) + .is_some_and(|(_, highlights)| !highlights.is_empty()) } + /// Returns all background highlights for a given range. + /// + /// The order of highlights is not deterministic, do sort the ranges if needed for the logic. pub fn background_highlights_in_range( &self, search_range: Range, @@ -19654,84 +20185,6 @@ impl Editor { results } - pub fn background_highlight_row_ranges( - &self, - search_range: Range, - display_snapshot: &DisplaySnapshot, - count: usize, - ) -> Vec> { - let mut results = Vec::new(); - let Some((_, ranges)) = self - .background_highlights - .get(&HighlightKey::Type(TypeId::of::())) - else { - return vec![]; - }; - - let start_ix = match ranges.binary_search_by(|probe| { - let cmp = probe - .end - .cmp(&search_range.start, &display_snapshot.buffer_snapshot); - if cmp.is_gt() { - Ordering::Greater - } else { - Ordering::Less - } - }) { - Ok(i) | Err(i) => i, - }; - let mut push_region = |start: Option, end: Option| { - if let (Some(start_display), Some(end_display)) = (start, end) { - results.push( - start_display.to_display_point(display_snapshot) - ..=end_display.to_display_point(display_snapshot), - ); - } - }; - let mut start_row: Option = None; - let mut end_row: Option = None; - if ranges.len() > count { - return Vec::new(); - } - for range in &ranges[start_ix..] 
{ - if range - .start - .cmp(&search_range.end, &display_snapshot.buffer_snapshot) - .is_ge() - { - break; - } - let end = range.end.to_point(&display_snapshot.buffer_snapshot); - if let Some(current_row) = &end_row { - if end.row == current_row.row { - continue; - } - } - let start = range.start.to_point(&display_snapshot.buffer_snapshot); - if start_row.is_none() { - assert_eq!(end_row, None); - start_row = Some(start); - end_row = Some(end); - continue; - } - if let Some(current_end) = end_row.as_mut() { - if start.row > current_end.row + 1 { - push_region(start_row, end_row); - start_row = Some(start); - end_row = Some(end); - } else { - // Merge two hunks. - *current_end = end; - } - } else { - unreachable!(); - } - } - // We might still have a hunk that was not rendered (if there was a search hit on the last line) - push_region(start_row, end_row); - results - } - pub fn gutter_highlights_in_range( &self, search_range: Range, @@ -19878,11 +20331,8 @@ impl Editor { event: &SessionEvent, cx: &mut Context, ) { - match event { - SessionEvent::InvalidateInlineValue => { - self.refresh_inline_values(cx); - } - _ => {} + if let SessionEvent::InvalidateInlineValue = event { + self.refresh_inline_values(cx); } } @@ -19997,17 +20447,16 @@ impl Editor { if self.has_active_edit_prediction() { self.update_visible_edit_prediction(window, cx); } - if let Some(project) = self.project.as_ref() { - if let Some(edited_buffer) = edited_buffer { - project.update(cx, |project, cx| { - self.registered_buffers - .entry(edited_buffer.read(cx).remote_id()) - .or_insert_with(|| { - project - .register_buffer_with_language_servers(&edited_buffer, cx) - }); - }); - } + if let Some(project) = self.project.as_ref() + && let Some(edited_buffer) = edited_buffer + { + project.update(cx, |project, cx| { + self.registered_buffers + .entry(edited_buffer.read(cx).remote_id()) + .or_insert_with(|| { + project.register_buffer_with_language_servers(edited_buffer, cx) + }); + }); } 
cx.emit(EditorEvent::BufferEdited); cx.emit(SearchEvent::MatchesInvalidated); @@ -20017,10 +20466,10 @@ impl Editor { } if *singleton_buffer_edited { - if let Some(buffer) = edited_buffer { - if buffer.read(cx).file().is_none() { - cx.emit(EditorEvent::TitleChanged); - } + if let Some(buffer) = edited_buffer + && buffer.read(cx).file().is_none() + { + cx.emit(EditorEvent::TitleChanged); } if let Some(project) = &self.project { #[allow(clippy::mutable_key_type)] @@ -20053,7 +20502,7 @@ impl Editor { let (telemetry, is_via_ssh) = { let project = project.read(cx); let telemetry = project.client().telemetry().clone(); - let is_via_ssh = project.is_via_ssh(); + let is_via_ssh = project.is_via_remote_server(); (telemetry, is_via_ssh) }; refresh_linked_ranges(self, window, cx); @@ -20066,17 +20515,17 @@ impl Editor { } => { self.tasks_update_task = Some(self.refresh_runnables(window, cx)); let buffer_id = buffer.read(cx).remote_id(); - if self.buffer.read(cx).diff_for(buffer_id).is_none() { - if let Some(project) = &self.project { - update_uncommitted_diff_for_buffer( - cx.entity(), - project, - [buffer.clone()], - self.buffer.clone(), - cx, - ) - .detach(); - } + if self.buffer.read(cx).diff_for(buffer_id).is_none() + && let Some(project) = &self.project + { + update_uncommitted_diff_for_buffer( + cx.entity(), + project, + [buffer.clone()], + self.buffer.clone(), + cx, + ) + .detach(); } self.update_lsp_data(false, Some(buffer_id), window, cx); cx.emit(EditorEvent::ExcerptsAdded { @@ -20135,7 +20584,6 @@ impl Editor { multi_buffer::Event::FileHandleChanged | multi_buffer::Event::Reloaded | multi_buffer::Event::BufferDiffChanged => cx.emit(EditorEvent::TitleChanged), - multi_buffer::Event::Closed => cx.emit(EditorEvent::Closed), multi_buffer::Event::DiagnosticsUpdated => { self.update_diagnostics_state(window, cx); } @@ -20209,6 +20657,7 @@ impl Editor { ); let old_cursor_shape = self.cursor_shape; + let old_show_breadcrumbs = self.show_breadcrumbs; { let editor_settings 
= EditorSettings::get_global(cx); @@ -20222,6 +20671,10 @@ impl Editor { cx.emit(EditorEvent::CursorShapeChanged); } + if old_show_breadcrumbs != self.show_breadcrumbs { + cx.emit(EditorEvent::BreadcrumbsChanged); + } + let project_settings = ProjectSettings::get_global(cx); self.serialize_dirty_buffers = !self.mode.is_minimap() && project_settings.session.restore_unsaved_buffers; @@ -20419,11 +20872,8 @@ impl Editor { .range_to_buffer_ranges_with_deleted_hunks(selection.range()) { if let Some(anchor) = anchor { - // selection is in a deleted hunk - let Some(buffer_id) = anchor.buffer_id else { - continue; - }; - let Some(buffer_handle) = multi_buffer.buffer(buffer_id) else { + let Some(buffer_handle) = multi_buffer.buffer_for_anchor(anchor, cx) + else { continue; }; let offset = text::ToOffset::to_offset( @@ -20531,7 +20981,7 @@ impl Editor { // For now, don't allow opening excerpts in buffers that aren't backed by // regular project files. fn can_open_excerpts_in_file(file: Option<&Arc>) -> bool { - file.map_or(true, |file| project::File::from_dyn(Some(file)).is_some()) + file.is_none_or(|file| project::File::from_dyn(Some(file)).is_some()) } fn marked_text_ranges(&self, cx: &App) -> Option>> { @@ -20619,7 +21069,7 @@ impl Editor { copilot_enabled, copilot_enabled_for_language, edit_predictions_provider, - is_via_ssh = project.is_via_ssh(), + is_via_ssh = project.is_via_remote_server(), ); } else { telemetry::event!( @@ -20629,7 +21079,7 @@ impl Editor { copilot_enabled, copilot_enabled_for_language, edit_predictions_provider, - is_via_ssh = project.is_via_ssh(), + is_via_ssh = project.is_via_remote_server(), ); }; } @@ -20675,11 +21125,11 @@ impl Editor { let mut chunk_lines = chunk.text.split('\n').peekable(); while let Some(text) = chunk_lines.next() { let mut merged_with_last_token = false; - if let Some(last_token) = line.back_mut() { - if last_token.highlight == highlight { - last_token.text.push_str(text); - merged_with_last_token = true; - } + if let 
Some(last_token) = line.back_mut() + && last_token.highlight == highlight + { + last_token.text.push_str(text); + merged_with_last_token = true; } if !merged_with_last_token { @@ -20869,7 +21319,7 @@ impl Editor { let existing_pending = self .text_highlights::(cx) - .map(|(_, ranges)| ranges.iter().cloned().collect::>()); + .map(|(_, ranges)| ranges.to_vec()); if existing_pending.is_none() && pending.is_empty() { return; } @@ -20984,7 +21434,7 @@ impl Editor { cx: &mut Context, ) { let workspace = self.workspace(); - let project = self.project.as_ref(); + let project = self.project(); let save_tasks = self.buffer().update(cx, |multi_buffer, cx| { let mut tasks = Vec::new(); for (buffer_id, changes) in revert_changes { @@ -21022,7 +21472,7 @@ impl Editor { }; if let Some((workspace, path)) = workspace.as_ref().zip(path) { let Some(task) = cx - .update_window_entity(&workspace, |workspace, window, cx| { + .update_window_entity(workspace, |workspace, window, cx| { workspace .open_path_preview(path, None, false, false, false, window, cx) }) @@ -21074,7 +21524,7 @@ impl Editor { pub fn has_visible_completions_menu(&self) -> bool { !self.edit_prediction_preview_is_active() - && self.context_menu.borrow().as_ref().map_or(false, |menu| { + && self.context_menu.borrow().as_ref().is_some_and(|menu| { menu.visible() && matches!(menu, CodeContextMenu::Completions(_)) }) } @@ -21138,39 +21588,37 @@ impl Editor { { let buffer_snapshot = OnceCell::new(); - if let Some(folds) = DB.get_editor_folds(item_id, workspace_id).log_err() { - if !folds.is_empty() { - let snapshot = - buffer_snapshot.get_or_init(|| self.buffer.read(cx).snapshot(cx)); - self.fold_ranges( - folds - .into_iter() - .map(|(start, end)| { - snapshot.clip_offset(start, Bias::Left) - ..snapshot.clip_offset(end, Bias::Right) - }) - .collect(), - false, - window, - cx, - ); - } - } - - if let Some(selections) = DB.get_editor_selections(item_id, workspace_id).log_err() { - if !selections.is_empty() { - let snapshot = 
- buffer_snapshot.get_or_init(|| self.buffer.read(cx).snapshot(cx)); - // skip adding the initial selection to selection history - self.selection_history.mode = SelectionHistoryMode::Skipping; - self.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { - s.select_ranges(selections.into_iter().map(|(start, end)| { + if let Some(folds) = DB.get_editor_folds(item_id, workspace_id).log_err() + && !folds.is_empty() + { + let snapshot = buffer_snapshot.get_or_init(|| self.buffer.read(cx).snapshot(cx)); + self.fold_ranges( + folds + .into_iter() + .map(|(start, end)| { snapshot.clip_offset(start, Bias::Left) ..snapshot.clip_offset(end, Bias::Right) - })); - }); - self.selection_history.mode = SelectionHistoryMode::Normal; - } + }) + .collect(), + false, + window, + cx, + ); + } + + if let Some(selections) = DB.get_editor_selections(item_id, workspace_id).log_err() + && !selections.is_empty() + { + let snapshot = buffer_snapshot.get_or_init(|| self.buffer.read(cx).snapshot(cx)); + // skip adding the initial selection to selection history + self.selection_history.mode = SelectionHistoryMode::Skipping; + self.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_ranges(selections.into_iter().map(|(start, end)| { + snapshot.clip_offset(start, Bias::Left) + ..snapshot.clip_offset(end, Bias::Right) + })); + }); + self.selection_history.mode = SelectionHistoryMode::Normal; }; } @@ -21212,17 +21660,15 @@ fn process_completion_for_edit( let mut snippet_source = completion.new_text.clone(); let mut previous_point = text::ToPoint::to_point(cursor_position, buffer); previous_point.column = previous_point.column.saturating_sub(1); - if let Some(scope) = buffer_snapshot.language_scope_at(previous_point) { - if scope.prefers_label_for_snippet_in_completion() { - if let Some(label) = completion.label() { - if matches!( - completion.kind(), - Some(CompletionItemKind::FUNCTION) | Some(CompletionItemKind::METHOD) - ) { - snippet_source = label; - } - } 
- } + if let Some(scope) = buffer_snapshot.language_scope_at(previous_point) + && scope.prefers_label_for_snippet_in_completion() + && let Some(label) = completion.label() + && matches!( + completion.kind(), + Some(CompletionItemKind::FUNCTION) | Some(CompletionItemKind::METHOD) + ) + { + snippet_source = label; } match Snippet::parse(&snippet_source).log_err() { Some(parsed_snippet) => (Some(parsed_snippet.clone()), parsed_snippet.text), @@ -21246,14 +21692,14 @@ fn process_completion_for_edit( debug_assert!( insert_range .start - .cmp(&cursor_position, &buffer_snapshot) + .cmp(cursor_position, &buffer_snapshot) .is_le(), "insert_range should start before or at cursor position" ); debug_assert!( replace_range .start - .cmp(&cursor_position, &buffer_snapshot) + .cmp(cursor_position, &buffer_snapshot) .is_le(), "replace_range should start before or at cursor position" ); @@ -21276,10 +21722,10 @@ fn process_completion_for_edit( ); let mut current_needle = text_to_replace.next(); for haystack_ch in completion.label.text.chars() { - if let Some(needle_ch) = current_needle { - if haystack_ch.eq_ignore_ascii_case(&needle_ch) { - current_needle = text_to_replace.next(); - } + if let Some(needle_ch) = current_needle + && haystack_ch.eq_ignore_ascii_case(&needle_ch) + { + current_needle = text_to_replace.next(); } } current_needle.is_none() @@ -21287,7 +21733,7 @@ fn process_completion_for_edit( LspInsertMode::ReplaceSuffix => { if replace_range .end - .cmp(&cursor_position, &buffer_snapshot) + .cmp(cursor_position, &buffer_snapshot) .is_gt() { let range_after_cursor = *cursor_position..replace_range.end; @@ -21323,7 +21769,7 @@ fn process_completion_for_edit( if range_to_replace .end - .cmp(&cursor_position, &buffer_snapshot) + .cmp(cursor_position, &buffer_snapshot) .is_lt() { range_to_replace.end = *cursor_position; @@ -21331,7 +21777,7 @@ fn process_completion_for_edit( CompletionEdit { new_text, - replace_range: range_to_replace.to_offset(&buffer), + replace_range: 
range_to_replace.to_offset(buffer), snippet, } } @@ -21501,9 +21947,9 @@ fn is_grapheme_whitespace(text: &str) -> bool { } fn should_stay_with_preceding_ideograph(text: &str) -> bool { - text.chars().next().map_or(false, |ch| { - matches!(ch, '。' | '、' | ',' | '?' | '!' | ':' | ';' | '…') - }) + text.chars() + .next() + .is_some_and(|ch| matches!(ch, '。' | '、' | ',' | '?' | '!' | ':' | ';' | '…')) } #[derive(PartialEq, Eq, Debug, Clone, Copy)] @@ -21533,20 +21979,20 @@ impl<'a> Iterator for WordBreakingTokenizer<'a> { offset += first_grapheme.len(); grapheme_len += 1; if is_grapheme_ideographic(first_grapheme) && !is_whitespace { - if let Some(grapheme) = iter.peek().copied() { - if should_stay_with_preceding_ideograph(grapheme) { - offset += grapheme.len(); - grapheme_len += 1; - } + if let Some(grapheme) = iter.peek().copied() + && should_stay_with_preceding_ideograph(grapheme) + { + offset += grapheme.len(); + grapheme_len += 1; } } else { let mut words = self.input[offset..].split_word_bound_indices().peekable(); let mut next_word_bound = words.peek().copied(); - if next_word_bound.map_or(false, |(i, _)| i == 0) { + if next_word_bound.is_some_and(|(i, _)| i == 0) { next_word_bound = words.next(); } while let Some(grapheme) = iter.peek().copied() { - if next_word_bound.map_or(false, |(i, _)| i == offset) { + if next_word_bound.is_some_and(|(i, _)| i == offset) { break; }; if is_grapheme_whitespace(grapheme) != is_whitespace @@ -21667,7 +22113,7 @@ fn wrap_with_prefix( let subsequent_lines_prefix_len = char_len_with_expanded_tabs(0, &subsequent_lines_prefix, tab_size); let mut wrapped_text = String::new(); - let mut current_line = first_line_prefix.clone(); + let mut current_line = first_line_prefix; let mut is_first_line = true; let tokenizer = WordBreakingTokenizer::new(&unwrapped_text); @@ -21839,7 +22285,7 @@ pub trait SemanticsProvider { buffer: &Entity, position: text::Anchor, cx: &mut App, - ) -> Option>>; + ) -> Option>>>; fn inline_values( &self, @@ 
-21878,7 +22324,7 @@ pub trait SemanticsProvider { position: text::Anchor, kind: GotoDefinitionKind, cx: &mut App, - ) -> Option>>>; + ) -> Option>>>>; fn range_for_rename( &self, @@ -21991,7 +22437,13 @@ impl CodeActionProvider for Entity { Ok(code_lens_actions .context("code lens fetch")? .into_iter() - .chain(code_actions.context("code action fetch")?) + .flatten() + .chain( + code_actions + .context("code action fetch")? + .into_iter() + .flatten(), + ) .collect()) }) }) @@ -22038,6 +22490,7 @@ fn snippet_completions( if scopes.is_empty() { return Task::ready(Ok(CompletionResponse { completions: vec![], + display_options: CompletionDisplayOptions::default(), is_incomplete: false, })); } @@ -22062,6 +22515,7 @@ fn snippet_completions( if last_word.is_empty() { return Ok(CompletionResponse { completions: vec![], + display_options: CompletionDisplayOptions::default(), is_incomplete: true, }); } @@ -22080,7 +22534,7 @@ fn snippet_completions( snippet .prefix .iter() - .map(move |prefix| StringMatchCandidate::new(ix, &prefix)) + .map(move |prefix| StringMatchCandidate::new(ix, prefix)) }) .collect::>(); @@ -22183,6 +22637,7 @@ fn snippet_completions( Ok(CompletionResponse { completions, + display_options: CompletionDisplayOptions::default(), is_incomplete, }) }) @@ -22286,7 +22741,7 @@ impl SemanticsProvider for Entity { buffer: &Entity, position: text::Anchor, cx: &mut App, - ) -> Option>> { + ) -> Option>>> { Some(self.update(cx, |project, cx| project.hover(buffer, position, cx))) } @@ -22307,12 +22762,12 @@ impl SemanticsProvider for Entity { position: text::Anchor, kind: GotoDefinitionKind, cx: &mut App, - ) -> Option>>> { + ) -> Option>>>> { Some(self.update(cx, |project, cx| match kind { - GotoDefinitionKind::Symbol => project.definitions(&buffer, position, cx), - GotoDefinitionKind::Declaration => project.declarations(&buffer, position, cx), - GotoDefinitionKind::Type => project.type_definitions(&buffer, position, cx), - GotoDefinitionKind::Implementation => 
project.implementations(&buffer, position, cx), + GotoDefinitionKind::Symbol => project.definitions(buffer, position, cx), + GotoDefinitionKind::Declaration => project.declarations(buffer, position, cx), + GotoDefinitionKind::Type => project.type_definitions(buffer, position, cx), + GotoDefinitionKind::Implementation => project.implementations(buffer, position, cx), })) } @@ -22599,8 +23054,10 @@ impl EditorSnapshot { self.is_focused } - pub fn placeholder_text(&self) -> Option<&Arc> { - self.placeholder_text.as_ref() + pub fn placeholder_text(&self) -> Option { + self.placeholder_display_snapshot + .as_ref() + .map(|display_map| display_map.text()) } pub fn scroll_position(&self) -> gpui::Point { @@ -22826,7 +23283,6 @@ pub enum EditorEvent { DirtyChanged, Saved, TitleChanged, - DiffBaseChanged, SelectionsChanged { local: bool, }, @@ -22834,15 +23290,14 @@ pub enum EditorEvent { local: bool, autoscroll: bool, }, - Closed, TransactionUndone { transaction_id: clock::Lamport, }, TransactionBegun { transaction_id: clock::Lamport, }, - Reloaded, CursorShapeChanged, + BreadcrumbsChanged, PushedToNavHistory { anchor: Anchor, is_deactivate: bool, @@ -22862,7 +23317,7 @@ impl Render for Editor { let settings = ThemeSettings::get_global(cx); let mut text_style = match self.mode { - EditorMode::SingleLine { .. } | EditorMode::AutoHeight { .. } => TextStyle { + EditorMode::SingleLine | EditorMode::AutoHeight { .. } => TextStyle { color: cx.theme().colors().editor_foreground, font_family: settings.ui_font.family.clone(), font_features: settings.ui_font.features.clone(), @@ -22888,7 +23343,7 @@ impl Render for Editor { } let background = match self.mode { - EditorMode::SingleLine { .. } => cx.theme().system().transparent, + EditorMode::SingleLine => cx.theme().system().transparent, EditorMode::AutoHeight { .. } => cx.theme().system().transparent, EditorMode::Full { .. } => cx.theme().colors().editor_background, EditorMode::Minimap { .. 
} => cx.theme().colors().editor_background.opacity(0.7), @@ -23377,8 +23832,7 @@ pub fn styled_runs_for_code_label<'a>( } else { return Default::default(); }; - let mut muted_style = style; - muted_style.highlight(fade_out); + let muted_style = style.highlight(fade_out); let mut runs = SmallVec::<[(Range, HighlightStyle); 3]>::new(); if range.start >= label.filter_range.end { @@ -23593,6 +24047,7 @@ impl BreakpointPromptEditor { BreakpointPromptEditAction::Condition => "Condition when a breakpoint is hit. Expressions within {} are interpolated.", BreakpointPromptEditAction::HitCondition => "How many breakpoint hits to ignore", }, + window, cx, ); @@ -23720,7 +24175,7 @@ fn all_edits_insertions_or_deletions( let mut all_deletions = true; for (range, new_text) in edits.iter() { - let range_is_empty = range.to_offset(&snapshot).is_empty(); + let range_is_empty = range.to_offset(snapshot).is_empty(); let text_is_empty = new_text.is_empty(); if range_is_empty != text_is_empty { @@ -23972,3 +24427,10 @@ fn render_diff_hunk_controls( ) .into_any_element() } + +pub fn multibuffer_context_lines(cx: &App) -> u32 { + EditorSettings::try_get(cx) + .map(|settings| settings.excerpt_context_lines) + .unwrap_or(2) + .clamp(1, 32) +} diff --git a/crates/editor/src/editor_settings.rs b/crates/editor/src/editor_settings.rs index 3d132651b846c3654b022b4a3e0aa6f60fa25d04..7f4d024e57c4831aa4c512e6dcb3a9ab35d4f610 100644 --- a/crates/editor/src/editor_settings.rs +++ b/crates/editor/src/editor_settings.rs @@ -6,7 +6,7 @@ use language::CursorShape; use project::project_settings::DiagnosticSeverity; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources, VsCodeSettings}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi, VsCodeSettings}; use util::serde::default_true; /// Imports from the VSCode settings at @@ -17,6 +17,7 @@ pub struct EditorSettings { pub cursor_shape: Option, pub current_line_highlight: 
CurrentLineHighlight, pub selection_highlight: bool, + pub rounded_selection: bool, pub lsp_highlight_debounce: u64, pub hover_popover_enabled: bool, pub hover_popover_delay: u64, @@ -37,6 +38,7 @@ pub struct EditorSettings { pub multi_cursor_modifier: MultiCursorModifier, pub redact_private_values: bool, pub expand_excerpt_lines: u32, + pub excerpt_context_lines: u32, pub middle_click_paste: bool, #[serde(default)] pub double_click_in_multibuffer: DoubleClickInMultibuffer, @@ -55,10 +57,13 @@ pub struct EditorSettings { pub inline_code_actions: bool, pub drag_and_drop_selection: DragAndDropSelection, pub lsp_document_colors: DocumentColorsRenderMode, + pub minimum_contrast_for_highlights: f32, } /// How to render LSP `textDocument/documentColor` colors in the editor. -#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] +#[derive( + Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, JsonSchema, SettingsUi, +)] #[serde(rename_all = "snake_case")] pub enum DocumentColorsRenderMode { /// Do not query and render document colors. @@ -72,7 +77,7 @@ pub enum DocumentColorsRenderMode { Background, } -#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema, SettingsUi)] #[serde(rename_all = "snake_case")] pub enum CurrentLineHighlight { // Don't highlight the current line. @@ -86,7 +91,7 @@ pub enum CurrentLineHighlight { } /// When to populate a new search's query based on the text under the cursor. -#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema, SettingsUi)] #[serde(rename_all = "snake_case")] pub enum SeedQuerySetting { /// Always populate the search query with the word under the cursor. 
@@ -98,7 +103,9 @@ pub enum SeedQuerySetting { } /// What to do when multibuffer is double clicked in some of its excerpts (parts of singleton buffers). -#[derive(Default, Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] +#[derive( + Default, Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema, SettingsUi, +)] #[serde(rename_all = "snake_case")] pub enum DoubleClickInMultibuffer { /// Behave as a regular buffer and select the whole word. @@ -117,7 +124,9 @@ pub struct Jupyter { pub enabled: bool, } -#[derive(Default, Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] +#[derive( + Default, Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema, SettingsUi, +)] #[serde(rename_all = "snake_case")] pub struct JupyterContent { /// Whether the Jupyter feature is enabled. @@ -132,6 +141,10 @@ pub struct StatusBar { /// /// Default: true pub active_language_button: bool, + /// Whether to show the cursor position button in the status bar. + /// + /// Default: true + pub cursor_position_button: bool, } #[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] @@ -285,7 +298,9 @@ pub struct ScrollbarAxes { } /// Whether to allow drag and drop text selection in buffer. -#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +#[derive( + Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq, SettingsUi, +)] pub struct DragAndDropSelection { /// When true, enables drag and drop text selection in buffer. 
/// @@ -325,7 +340,7 @@ pub enum ScrollbarDiagnostics { /// The key to use for adding multiple cursors /// /// Default: alt -#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq, SettingsUi)] #[serde(rename_all = "snake_case")] pub enum MultiCursorModifier { Alt, @@ -336,7 +351,7 @@ pub enum MultiCursorModifier { /// Whether the editor will scroll beyond the last line. /// /// Default: one_page -#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq, SettingsUi)] #[serde(rename_all = "snake_case")] pub enum ScrollBeyondLastLine { /// The editor will not scroll beyond the last line. @@ -350,7 +365,9 @@ pub enum ScrollBeyondLastLine { } /// Default options for buffer and project search items. -#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +#[derive( + Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq, SettingsUi, +)] pub struct SearchSettings { /// Whether to show the project search button in the status bar. #[serde(default = "default_true")] @@ -366,7 +383,9 @@ pub struct SearchSettings { } /// What to do when go to definition yields no results. -#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] +#[derive( + Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, JsonSchema, SettingsUi, +)] #[serde(rename_all = "snake_case")] pub enum GoToDefinitionFallback { /// Disables the fallback. @@ -379,7 +398,9 @@ pub enum GoToDefinitionFallback { /// Determines when the mouse cursor should be hidden in an editor or input box. 
/// /// Default: on_typing_and_movement -#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] +#[derive( + Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, JsonSchema, SettingsUi, +)] #[serde(rename_all = "snake_case")] pub enum HideMouseMode { /// Never hide the mouse cursor @@ -394,7 +415,9 @@ pub enum HideMouseMode { /// Determines how snippets are sorted relative to other completion items. /// /// Default: inline -#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] +#[derive( + Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, JsonSchema, SettingsUi, +)] #[serde(rename_all = "snake_case")] pub enum SnippetSortOrder { /// Place snippets at the top of the completion list @@ -408,7 +431,9 @@ pub enum SnippetSortOrder { None, } -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, SettingsUi, SettingsKey)] +#[settings_ui(group = "Editor")] +#[settings_key(None)] pub struct EditorSettingsContent { /// Whether the cursor blinks in the editor. /// @@ -417,7 +442,7 @@ pub struct EditorSettingsContent { /// Cursor shape for the default editor. /// Can be "bar", "block", "underline", or "hollow". /// - /// Default: None + /// Default: bar pub cursor_shape: Option, /// Determines when the mouse cursor should be hidden in an editor or input box. /// @@ -435,6 +460,10 @@ pub struct EditorSettingsContent { /// /// Default: true pub selection_highlight: Option, + /// Whether the text selection should have rounded corners. + /// + /// Default: true + pub rounded_selection: Option, /// The debounce delay before querying highlights from the language /// server based on the current cursor location. 
/// @@ -511,6 +540,11 @@ pub struct EditorSettingsContent { /// Default: 3 pub expand_excerpt_lines: Option, + /// How many lines of context to provide in multibuffer excerpts by default + /// + /// Default: 2 + pub excerpt_context_lines: Option, + /// Whether to enable middle-click paste on Linux /// /// Default: true @@ -540,6 +574,12 @@ pub struct EditorSettingsContent { /// /// Default: false pub show_signature_help_after_edits: Option, + /// The minimum APCA perceptual contrast to maintain when + /// rendering text over highlight backgrounds in the editor. + /// + /// Values range from 0 to 106. Set to 0 to disable adjustments. + /// Default: 45 + pub minimum_contrast_for_highlights: Option, /// Whether to follow-up empty go to definition responses from the language server or not. /// `FindAllReferences` allows to look up references of the same symbol instead. @@ -579,16 +619,20 @@ pub struct EditorSettingsContent { } // Status bar related settings -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq, SettingsUi)] pub struct StatusBarContent { /// Whether to display the active language button in the status bar. /// /// Default: true pub active_language_button: Option, + /// Whether to show the cursor position button in the status bar. + /// + /// Default: true + pub cursor_position_button: Option, } // Toolbar related settings -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq, SettingsUi)] pub struct ToolbarContent { /// Whether to display breadcrumbs in the editor toolbar. 
/// @@ -614,7 +658,9 @@ pub struct ToolbarContent { } /// Scrollbar related settings -#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Default)] +#[derive( + Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Default, SettingsUi, +)] pub struct ScrollbarContent { /// When to show the scrollbar in the editor. /// @@ -649,7 +695,9 @@ pub struct ScrollbarContent { } /// Minimap related settings -#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema, PartialEq)] +#[derive( + Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema, PartialEq, SettingsUi, +)] pub struct MinimapContent { /// When to show the minimap in the editor. /// @@ -697,7 +745,10 @@ pub struct ScrollbarAxesContent { } /// Gutter related settings -#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +#[derive( + Copy, Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq, SettingsUi, +)] +#[settings_ui(group = "Gutter")] pub struct GutterContent { /// Whether to show line numbers in the gutter. 
/// @@ -728,8 +779,6 @@ impl EditorSettings { } impl Settings for EditorSettings { - const KEY: Option<&'static str> = None; - type FileContent = EditorSettingsContent; fn load(sources: SettingsSources, _: &mut App) -> anyhow::Result { @@ -773,6 +822,7 @@ impl Settings for EditorSettings { "editor.selectionHighlight", &mut current.selection_highlight, ); + vscode.bool_setting("editor.roundedSelection", &mut current.rounded_selection); vscode.bool_setting("editor.hover.enabled", &mut current.hover_popover_enabled); vscode.u64_setting("editor.hover.delay", &mut current.hover_popover_delay); @@ -802,10 +852,8 @@ impl Settings for EditorSettings { if gutter.line_numbers.is_some() { old_gutter.line_numbers = gutter.line_numbers } - } else { - if gutter != GutterContent::default() { - current.gutter = Some(gutter) - } + } else if gutter != GutterContent::default() { + current.gutter = Some(gutter) } if let Some(b) = vscode.read_bool("editor.scrollBeyondLastLine") { current.scroll_beyond_last_line = Some(if b { diff --git a/crates/editor/src/editor_settings_controls.rs b/crates/editor/src/editor_settings_controls.rs index dc5557b05277da972ea36ba43ffdf08a565edda9..91022d94a8843a2e9b7e9c77137d4d2ba57bfa7f 100644 --- a/crates/editor/src/editor_settings_controls.rs +++ b/crates/editor/src/editor_settings_controls.rs @@ -88,7 +88,7 @@ impl RenderOnce for BufferFontFamilyControl { .child(Icon::new(IconName::Font)) .child(DropdownMenu::new( "buffer-font-family", - value.clone(), + value, ContextMenu::build(window, cx, |mut menu, _, cx| { let font_family_cache = FontFamilyCache::global(cx); diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index 44218697032f2f03d6354092a31f3c0d921992e4..4e3c5012e14f7919a8994bb2951c71cd75d0e404 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -57,7 +57,9 @@ use util::{ use workspace::{ CloseActiveItem, CloseAllItems, CloseOtherItems, MoveItemToPaneInDirection, 
NavigationEntry, OpenOptions, ViewId, + invalid_buffer_view::InvalidBufferView, item::{FollowEvent, FollowableItem, Item, ItemHandle, SaveOptions}, + register_project_item, }; #[gpui::test] @@ -74,7 +76,7 @@ fn test_edit_events(cx: &mut TestAppContext) { let editor1 = cx.add_window({ let events = events.clone(); |window, cx| { - let entity = cx.entity().clone(); + let entity = cx.entity(); cx.subscribe_in( &entity, window, @@ -95,7 +97,7 @@ fn test_edit_events(cx: &mut TestAppContext) { let events = events.clone(); |window, cx| { cx.subscribe_in( - &cx.entity().clone(), + &cx.entity(), window, move |_, _, event: &EditorEvent, _, _| match event { EditorEvent::Edited { .. } => events.borrow_mut().push(("editor2", "edited")), @@ -708,7 +710,7 @@ async fn test_navigation_history(cx: &mut TestAppContext) { _ = workspace.update(cx, |_v, window, cx| { cx.new(|cx| { let buffer = MultiBuffer::build_simple(&sample_text(300, 5, 'a'), cx); - let mut editor = build_editor(buffer.clone(), window, cx); + let mut editor = build_editor(buffer, window, cx); let handle = cx.entity(); editor.set_nav_history(Some(pane.read(cx).nav_history_for_item(&handle))); @@ -898,7 +900,7 @@ fn test_fold_action(cx: &mut TestAppContext) { .unindent(), cx, ); - build_editor(buffer.clone(), window, cx) + build_editor(buffer, window, cx) }); _ = editor.update(cx, |editor, window, cx| { @@ -989,7 +991,7 @@ fn test_fold_action_whitespace_sensitive_language(cx: &mut TestAppContext) { .unindent(), cx, ); - build_editor(buffer.clone(), window, cx) + build_editor(buffer, window, cx) }); _ = editor.update(cx, |editor, window, cx| { @@ -1074,7 +1076,7 @@ fn test_fold_action_multiple_line_breaks(cx: &mut TestAppContext) { .unindent(), cx, ); - build_editor(buffer.clone(), window, cx) + build_editor(buffer, window, cx) }); _ = editor.update(cx, |editor, window, cx| { @@ -1173,7 +1175,7 @@ fn test_fold_at_level(cx: &mut TestAppContext) { .unindent(), cx, ); - build_editor(buffer.clone(), window, cx) + 
build_editor(buffer, window, cx) }); _ = editor.update(cx, |editor, window, cx| { @@ -1335,7 +1337,7 @@ fn test_move_cursor_multibyte(cx: &mut TestAppContext) { let editor = cx.add_window(|window, cx| { let buffer = MultiBuffer::build_simple("🟥🟧🟨🟩🟦🟪\nabcde\nαβγδε", cx); - build_editor(buffer.clone(), window, cx) + build_editor(buffer, window, cx) }); assert_eq!('🟥'.len_utf8(), 4); @@ -1452,7 +1454,7 @@ fn test_move_cursor_different_line_lengths(cx: &mut TestAppContext) { let editor = cx.add_window(|window, cx| { let buffer = MultiBuffer::build_simple("ⓐⓑⓒⓓⓔ\nabcd\nαβγ\nabcd\nⓐⓑⓒⓓⓔ\n", cx); - build_editor(buffer.clone(), window, cx) + build_editor(buffer, window, cx) }); _ = editor.update(cx, |editor, window, cx| { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { @@ -2474,154 +2476,488 @@ async fn test_delete_to_beginning_of_line(cx: &mut TestAppContext) { } #[gpui::test] -fn test_delete_to_word_boundary(cx: &mut TestAppContext) { +async fn test_delete_to_word_boundary(cx: &mut TestAppContext) { init_test(cx, |_| {}); - let editor = cx.add_window(|window, cx| { - let buffer = MultiBuffer::build_simple("one two three four", cx); - build_editor(buffer.clone(), window, cx) - }); + let mut cx = EditorTestContext::new(cx).await; - _ = editor.update(cx, |editor, window, cx| { - editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { - s.select_display_ranges([ - // an empty selection - the preceding word fragment is deleted - DisplayPoint::new(DisplayRow(0), 2)..DisplayPoint::new(DisplayRow(0), 2), - // characters selected - they are deleted - DisplayPoint::new(DisplayRow(0), 9)..DisplayPoint::new(DisplayRow(0), 12), - ]) - }); + // For an empty selection, the preceding word fragment is deleted. + // For non-empty selections, only selected characters are deleted. 
+ cx.set_state("onˇe two t«hreˇ»e four"); + cx.update_editor(|editor, window, cx| { editor.delete_to_previous_word_start( &DeleteToPreviousWordStart { ignore_newlines: false, + ignore_brackets: false, }, window, cx, ); - assert_eq!(editor.buffer.read(cx).read(cx).text(), "e two te four"); }); + cx.assert_editor_state("ˇe two tˇe four"); - _ = editor.update(cx, |editor, window, cx| { - editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { - s.select_display_ranges([ - // an empty selection - the following word fragment is deleted - DisplayPoint::new(DisplayRow(0), 3)..DisplayPoint::new(DisplayRow(0), 3), - // characters selected - they are deleted - DisplayPoint::new(DisplayRow(0), 9)..DisplayPoint::new(DisplayRow(0), 10), - ]) - }); + cx.set_state("e tˇwo te «fˇ»our"); + cx.update_editor(|editor, window, cx| { editor.delete_to_next_word_end( &DeleteToNextWordEnd { ignore_newlines: false, + ignore_brackets: false, }, window, cx, ); - assert_eq!(editor.buffer.read(cx).read(cx).text(), "e t te our"); }); + cx.assert_editor_state("e tˇ te ˇour"); } #[gpui::test] -fn test_delete_to_previous_word_start_or_newline(cx: &mut TestAppContext) { +async fn test_delete_whitespaces(cx: &mut TestAppContext) { init_test(cx, |_| {}); - let editor = cx.add_window(|window, cx| { - let buffer = MultiBuffer::build_simple("one\n2\nthree\n4", cx); - build_editor(buffer.clone(), window, cx) + let mut cx = EditorTestContext::new(cx).await; + + cx.set_state("here is some text ˇwith a space"); + cx.update_editor(|editor, window, cx| { + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: false, + ignore_brackets: true, + }, + window, + cx, + ); }); - let del_to_prev_word_start = DeleteToPreviousWordStart { - ignore_newlines: false, - }; - let del_to_prev_word_start_ignore_newlines = DeleteToPreviousWordStart { - ignore_newlines: true, - }; + // Continuous whitespace sequences are removed entirely, words behind them are not affected by 
the deletion action. + cx.assert_editor_state("here is some textˇwith a space"); - _ = editor.update(cx, |editor, window, cx| { - editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { - s.select_display_ranges([ - DisplayPoint::new(DisplayRow(3), 1)..DisplayPoint::new(DisplayRow(3), 1) - ]) - }); - editor.delete_to_previous_word_start(&del_to_prev_word_start, window, cx); - assert_eq!(editor.buffer.read(cx).read(cx).text(), "one\n2\nthree\n"); - editor.delete_to_previous_word_start(&del_to_prev_word_start, window, cx); - assert_eq!(editor.buffer.read(cx).read(cx).text(), "one\n2\nthree"); - editor.delete_to_previous_word_start(&del_to_prev_word_start, window, cx); - assert_eq!(editor.buffer.read(cx).read(cx).text(), "one\n2\n"); - editor.delete_to_previous_word_start(&del_to_prev_word_start, window, cx); - assert_eq!(editor.buffer.read(cx).read(cx).text(), "one\n2"); - editor.delete_to_previous_word_start(&del_to_prev_word_start_ignore_newlines, window, cx); - assert_eq!(editor.buffer.read(cx).read(cx).text(), "one\n"); - editor.delete_to_previous_word_start(&del_to_prev_word_start_ignore_newlines, window, cx); - assert_eq!(editor.buffer.read(cx).read(cx).text(), ""); + cx.set_state("here is some text ˇwith a space"); + cx.update_editor(|editor, window, cx| { + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: false, + ignore_brackets: false, + }, + window, + cx, + ); }); -} + cx.assert_editor_state("here is some textˇwith a space"); -#[gpui::test] -fn test_delete_to_next_word_end_or_newline(cx: &mut TestAppContext) { - init_test(cx, |_| {}); + cx.set_state("here is some textˇ with a space"); + cx.update_editor(|editor, window, cx| { + editor.delete_to_next_word_end( + &DeleteToNextWordEnd { + ignore_newlines: false, + ignore_brackets: true, + }, + window, + cx, + ); + }); + // Same happens in the other direction. 
+ cx.assert_editor_state("here is some textˇwith a space"); - let editor = cx.add_window(|window, cx| { - let buffer = MultiBuffer::build_simple("\none\n two\nthree\n four", cx); - build_editor(buffer.clone(), window, cx) + cx.set_state("here is some textˇ with a space"); + cx.update_editor(|editor, window, cx| { + editor.delete_to_next_word_end( + &DeleteToNextWordEnd { + ignore_newlines: false, + ignore_brackets: false, + }, + window, + cx, + ); }); - let del_to_next_word_end = DeleteToNextWordEnd { - ignore_newlines: false, - }; - let del_to_next_word_end_ignore_newlines = DeleteToNextWordEnd { - ignore_newlines: true, - }; + cx.assert_editor_state("here is some textˇwith a space"); - _ = editor.update(cx, |editor, window, cx| { - editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { - s.select_display_ranges([ - DisplayPoint::new(DisplayRow(0), 0)..DisplayPoint::new(DisplayRow(0), 0) - ]) - }); - editor.delete_to_next_word_end(&del_to_next_word_end, window, cx); - assert_eq!( - editor.buffer.read(cx).read(cx).text(), - "one\n two\nthree\n four" + cx.set_state("here is some textˇ with a space"); + cx.update_editor(|editor, window, cx| { + editor.delete_to_next_word_end( + &DeleteToNextWordEnd { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, ); - editor.delete_to_next_word_end(&del_to_next_word_end, window, cx); - assert_eq!( - editor.buffer.read(cx).read(cx).text(), - "\n two\nthree\n four" + }); + cx.assert_editor_state("here is some textˇwith a space"); + cx.update_editor(|editor, window, cx| { + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, ); - editor.delete_to_next_word_end(&del_to_next_word_end, window, cx); - assert_eq!( - editor.buffer.read(cx).read(cx).text(), - "two\nthree\n four" + }); + cx.assert_editor_state("here is some ˇwith a space"); + cx.update_editor(|editor, window, cx| { + 
editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + // Single whitespaces are removed with the word behind them. + cx.assert_editor_state("here is ˇwith a space"); + cx.update_editor(|editor, window, cx| { + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + cx.assert_editor_state("here ˇwith a space"); + cx.update_editor(|editor, window, cx| { + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + cx.assert_editor_state("ˇwith a space"); + cx.update_editor(|editor, window, cx| { + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + cx.assert_editor_state("ˇwith a space"); + cx.update_editor(|editor, window, cx| { + editor.delete_to_next_word_end( + &DeleteToNextWordEnd { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + // Same happens in the other direction. 
+ cx.assert_editor_state("ˇ a space"); + cx.update_editor(|editor, window, cx| { + editor.delete_to_next_word_end( + &DeleteToNextWordEnd { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + cx.assert_editor_state("ˇ space"); + cx.update_editor(|editor, window, cx| { + editor.delete_to_next_word_end( + &DeleteToNextWordEnd { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + cx.assert_editor_state("ˇ"); + cx.update_editor(|editor, window, cx| { + editor.delete_to_next_word_end( + &DeleteToNextWordEnd { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + cx.assert_editor_state("ˇ"); + cx.update_editor(|editor, window, cx| { + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, ); - editor.delete_to_next_word_end(&del_to_next_word_end, window, cx); - assert_eq!(editor.buffer.read(cx).read(cx).text(), "\nthree\n four"); - editor.delete_to_next_word_end(&del_to_next_word_end_ignore_newlines, window, cx); - assert_eq!(editor.buffer.read(cx).read(cx).text(), "\n four"); - editor.delete_to_next_word_end(&del_to_next_word_end_ignore_newlines, window, cx); - assert_eq!(editor.buffer.read(cx).read(cx).text(), ""); }); + cx.assert_editor_state("ˇ"); } #[gpui::test] -fn test_newline(cx: &mut TestAppContext) { +async fn test_delete_to_bracket(cx: &mut TestAppContext) { init_test(cx, |_| {}); - let editor = cx.add_window(|window, cx| { - let buffer = MultiBuffer::build_simple("aaaa\n bbbb\n", cx); - build_editor(buffer.clone(), window, cx) - }); + let language = Arc::new( + Language::new( + LanguageConfig { + brackets: BracketPairConfig { + pairs: vec![ + BracketPair { + start: "\"".to_string(), + end: "\"".to_string(), + close: true, + surround: true, + newline: false, + }, + BracketPair { + start: "(".to_string(), + end: ")".to_string(), + close: true, + surround: true, + newline: true, + 
}, + ], + ..BracketPairConfig::default() + }, + ..LanguageConfig::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_brackets_query( + r#" + ("(" @open ")" @close) + ("\"" @open "\"" @close) + "#, + ) + .unwrap(), + ); - _ = editor.update(cx, |editor, window, cx| { - editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { - s.select_display_ranges([ - DisplayPoint::new(DisplayRow(0), 2)..DisplayPoint::new(DisplayRow(0), 2), - DisplayPoint::new(DisplayRow(1), 2)..DisplayPoint::new(DisplayRow(1), 2), - DisplayPoint::new(DisplayRow(1), 6)..DisplayPoint::new(DisplayRow(1), 6), - ]) - }); + let mut cx = EditorTestContext::new(cx).await; + cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); - editor.newline(&Newline, window, cx); - assert_eq!(editor.text(cx), "aa\naa\n \n bb\n bb\n"); + cx.set_state(r#"macro!("// ˇCOMMENT");"#); + cx.update_editor(|editor, window, cx| { + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + // Deletion stops before brackets if asked to not ignore them. + cx.assert_editor_state(r#"macro!("ˇCOMMENT");"#); + cx.update_editor(|editor, window, cx| { + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + // Deletion has to remove a single bracket and then stop again. 
+ cx.assert_editor_state(r#"macro!(ˇCOMMENT");"#); + + cx.update_editor(|editor, window, cx| { + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + cx.assert_editor_state(r#"macro!ˇCOMMENT");"#); + + cx.update_editor(|editor, window, cx| { + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + cx.assert_editor_state(r#"ˇCOMMENT");"#); + + cx.update_editor(|editor, window, cx| { + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + cx.assert_editor_state(r#"ˇCOMMENT");"#); + + cx.update_editor(|editor, window, cx| { + editor.delete_to_next_word_end( + &DeleteToNextWordEnd { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + // Brackets on the right are not paired anymore, hence deletion does not stop at them + cx.assert_editor_state(r#"ˇ");"#); + + cx.update_editor(|editor, window, cx| { + editor.delete_to_next_word_end( + &DeleteToNextWordEnd { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + cx.assert_editor_state(r#"ˇ"#); + + cx.update_editor(|editor, window, cx| { + editor.delete_to_next_word_end( + &DeleteToNextWordEnd { + ignore_newlines: true, + ignore_brackets: false, + }, + window, + cx, + ); + }); + cx.assert_editor_state(r#"ˇ"#); + + cx.set_state(r#"macro!("// ˇCOMMENT");"#); + cx.update_editor(|editor, window, cx| { + editor.delete_to_previous_word_start( + &DeleteToPreviousWordStart { + ignore_newlines: true, + ignore_brackets: true, + }, + window, + cx, + ); + }); + cx.assert_editor_state(r#"macroˇCOMMENT");"#); +} + +#[gpui::test] +fn test_delete_to_previous_word_start_or_newline(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let editor = cx.add_window(|window, cx| { + let 
buffer = MultiBuffer::build_simple("one\n2\nthree\n4", cx); + build_editor(buffer, window, cx) + }); + let del_to_prev_word_start = DeleteToPreviousWordStart { + ignore_newlines: false, + ignore_brackets: false, + }; + let del_to_prev_word_start_ignore_newlines = DeleteToPreviousWordStart { + ignore_newlines: true, + ignore_brackets: false, + }; + + _ = editor.update(cx, |editor, window, cx| { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_display_ranges([ + DisplayPoint::new(DisplayRow(3), 1)..DisplayPoint::new(DisplayRow(3), 1) + ]) + }); + editor.delete_to_previous_word_start(&del_to_prev_word_start, window, cx); + assert_eq!(editor.buffer.read(cx).read(cx).text(), "one\n2\nthree\n"); + editor.delete_to_previous_word_start(&del_to_prev_word_start, window, cx); + assert_eq!(editor.buffer.read(cx).read(cx).text(), "one\n2\nthree"); + editor.delete_to_previous_word_start(&del_to_prev_word_start, window, cx); + assert_eq!(editor.buffer.read(cx).read(cx).text(), "one\n2\n"); + editor.delete_to_previous_word_start(&del_to_prev_word_start, window, cx); + assert_eq!(editor.buffer.read(cx).read(cx).text(), "one\n2"); + editor.delete_to_previous_word_start(&del_to_prev_word_start_ignore_newlines, window, cx); + assert_eq!(editor.buffer.read(cx).read(cx).text(), "one\n"); + editor.delete_to_previous_word_start(&del_to_prev_word_start_ignore_newlines, window, cx); + assert_eq!(editor.buffer.read(cx).read(cx).text(), ""); + }); +} + +#[gpui::test] +fn test_delete_to_next_word_end_or_newline(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let editor = cx.add_window(|window, cx| { + let buffer = MultiBuffer::build_simple("\none\n two\nthree\n four", cx); + build_editor(buffer, window, cx) + }); + let del_to_next_word_end = DeleteToNextWordEnd { + ignore_newlines: false, + ignore_brackets: false, + }; + let del_to_next_word_end_ignore_newlines = DeleteToNextWordEnd { + ignore_newlines: true, + ignore_brackets: false, + }; + + _ = 
editor.update(cx, |editor, window, cx| { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_display_ranges([ + DisplayPoint::new(DisplayRow(0), 0)..DisplayPoint::new(DisplayRow(0), 0) + ]) + }); + editor.delete_to_next_word_end(&del_to_next_word_end, window, cx); + assert_eq!( + editor.buffer.read(cx).read(cx).text(), + "one\n two\nthree\n four" + ); + editor.delete_to_next_word_end(&del_to_next_word_end, window, cx); + assert_eq!( + editor.buffer.read(cx).read(cx).text(), + "\n two\nthree\n four" + ); + editor.delete_to_next_word_end(&del_to_next_word_end, window, cx); + assert_eq!( + editor.buffer.read(cx).read(cx).text(), + "two\nthree\n four" + ); + editor.delete_to_next_word_end(&del_to_next_word_end, window, cx); + assert_eq!(editor.buffer.read(cx).read(cx).text(), "\nthree\n four"); + editor.delete_to_next_word_end(&del_to_next_word_end_ignore_newlines, window, cx); + assert_eq!(editor.buffer.read(cx).read(cx).text(), "\n four"); + editor.delete_to_next_word_end(&del_to_next_word_end_ignore_newlines, window, cx); + assert_eq!(editor.buffer.read(cx).read(cx).text(), "four"); + editor.delete_to_next_word_end(&del_to_next_word_end_ignore_newlines, window, cx); + assert_eq!(editor.buffer.read(cx).read(cx).text(), ""); + }); +} + +#[gpui::test] +fn test_newline(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let editor = cx.add_window(|window, cx| { + let buffer = MultiBuffer::build_simple("aaaa\n bbbb\n", cx); + build_editor(buffer, window, cx) + }); + + _ = editor.update(cx, |editor, window, cx| { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_display_ranges([ + DisplayPoint::new(DisplayRow(0), 2)..DisplayPoint::new(DisplayRow(0), 2), + DisplayPoint::new(DisplayRow(1), 2)..DisplayPoint::new(DisplayRow(1), 2), + DisplayPoint::new(DisplayRow(1), 6)..DisplayPoint::new(DisplayRow(1), 6), + ]) + }); + + editor.newline(&Newline, window, cx); + assert_eq!(editor.text(cx), "aa\naa\n \n 
bb\n bb\n"); }); } @@ -2644,7 +2980,7 @@ fn test_newline_with_old_selections(cx: &mut TestAppContext) { .as_str(), cx, ); - let mut editor = build_editor(buffer.clone(), window, cx); + let mut editor = build_editor(buffer, window, cx); editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { s.select_ranges([ Point::new(2, 4)..Point::new(2, 5), @@ -3175,7 +3511,7 @@ fn test_insert_with_old_selections(cx: &mut TestAppContext) { let editor = cx.add_window(|window, cx| { let buffer = MultiBuffer::build_simple("a( X ), b( Y ), c( Z )", cx); - let mut editor = build_editor(buffer.clone(), window, cx); + let mut editor = build_editor(buffer, window, cx); editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { s.select_ranges([3..4, 11..12, 19..20]) }); @@ -4401,6 +4737,129 @@ async fn test_unique_lines_single_selection(cx: &mut TestAppContext) { "}); } +#[gpui::test] +async fn test_wrap_in_tag_single_selection(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + + let js_language = Arc::new(Language::new( + LanguageConfig { + name: "JavaScript".into(), + wrap_characters: Some(language::WrapCharactersConfig { + start_prefix: "<".into(), + start_suffix: ">".into(), + end_prefix: "".into(), + }), + ..LanguageConfig::default() + }, + None, + )); + + cx.update_buffer(|buffer, cx| buffer.set_language(Some(js_language), cx)); + + cx.set_state(indoc! {" + «testˇ» + "}); + cx.update_editor(|e, window, cx| e.wrap_selections_in_tag(&WrapSelectionsInTag, window, cx)); + cx.assert_editor_state(indoc! {" + <«ˇ»>test + "}); + + cx.set_state(indoc! {" + «test + testˇ» + "}); + cx.update_editor(|e, window, cx| e.wrap_selections_in_tag(&WrapSelectionsInTag, window, cx)); + cx.assert_editor_state(indoc! {" + <«ˇ»>test + test + "}); + + cx.set_state(indoc! {" + teˇst + "}); + cx.update_editor(|e, window, cx| e.wrap_selections_in_tag(&WrapSelectionsInTag, window, cx)); + cx.assert_editor_state(indoc! 
{" + te<«ˇ»>st + "}); +} + +#[gpui::test] +async fn test_wrap_in_tag_multi_selection(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + + let js_language = Arc::new(Language::new( + LanguageConfig { + name: "JavaScript".into(), + wrap_characters: Some(language::WrapCharactersConfig { + start_prefix: "<".into(), + start_suffix: ">".into(), + end_prefix: "".into(), + }), + ..LanguageConfig::default() + }, + None, + )); + + cx.update_buffer(|buffer, cx| buffer.set_language(Some(js_language), cx)); + + cx.set_state(indoc! {" + «testˇ» + «testˇ» «testˇ» + «testˇ» + "}); + cx.update_editor(|e, window, cx| e.wrap_selections_in_tag(&WrapSelectionsInTag, window, cx)); + cx.assert_editor_state(indoc! {" + <«ˇ»>test + <«ˇ»>test <«ˇ»>test + <«ˇ»>test + "}); + + cx.set_state(indoc! {" + «test + testˇ» + «test + testˇ» + "}); + cx.update_editor(|e, window, cx| e.wrap_selections_in_tag(&WrapSelectionsInTag, window, cx)); + cx.assert_editor_state(indoc! {" + <«ˇ»>test + test + <«ˇ»>test + test + "}); +} + +#[gpui::test] +async fn test_wrap_in_tag_does_nothing_in_unsupported_languages(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + + let plaintext_language = Arc::new(Language::new( + LanguageConfig { + name: "Plain Text".into(), + ..LanguageConfig::default() + }, + None, + )); + + cx.update_buffer(|buffer, cx| buffer.set_language(Some(plaintext_language), cx)); + + cx.set_state(indoc! {" + «testˇ» + "}); + cx.update_editor(|e, window, cx| e.wrap_selections_in_tag(&WrapSelectionsInTag, window, cx)); + cx.assert_editor_state(indoc! {" + «testˇ» + "}); +} + #[gpui::test] async fn test_manipulate_immutable_lines_with_multi_selection(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -4904,10 +5363,24 @@ async fn test_manipulate_text(cx: &mut TestAppContext) { cx.assert_editor_state(indoc! 
{" «HeLlO, wOrLD!ˇ» "}); -} -#[gpui::test] -fn test_duplicate_line(cx: &mut TestAppContext) { + // Test selections with `line_mode = true`. + cx.update_editor(|editor, _window, _cx| editor.selections.line_mode = true); + cx.set_state(indoc! {" + «The quick brown + fox jumps over + tˇ»he lazy dog + "}); + cx.update_editor(|e, window, cx| e.convert_to_upper_case(&ConvertToUpperCase, window, cx)); + cx.assert_editor_state(indoc! {" + «THE QUICK BROWN + FOX JUMPS OVER + THE LAZY DOGˇ» + "}); +} + +#[gpui::test] +fn test_duplicate_line(cx: &mut TestAppContext) { init_test(cx, |_| {}); let editor = cx.add_window(|window, cx| { @@ -5436,14 +5909,18 @@ async fn test_rewrap(cx: &mut TestAppContext) { }, None, )); - let rust_language = Arc::new(Language::new( - LanguageConfig { - name: "Rust".into(), - line_comments: vec!["// ".into(), "/// ".into()], - ..LanguageConfig::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - )); + let rust_language = Arc::new( + Language::new( + LanguageConfig { + name: "Rust".into(), + line_comments: vec!["// ".into(), "/// ".into()], + ..LanguageConfig::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_override_query("[(line_comment)(block_comment)] @comment.inclusive") + .unwrap(), + ); let plaintext_language = Arc::new(Language::new( LanguageConfig { @@ -5562,7 +6039,7 @@ async fn test_rewrap(cx: &mut TestAppContext) { # ˇThis is a long comment using a pound # sign. 
"}, - python_language.clone(), + python_language, &mut cx, ); @@ -5669,7 +6146,7 @@ async fn test_rewrap(cx: &mut TestAppContext) { also very long and should not merge with the numbered item.ˇ» "}, - markdown_language.clone(), + markdown_language, &mut cx, ); @@ -5700,7 +6177,7 @@ async fn test_rewrap(cx: &mut TestAppContext) { // This is the second long comment block // to be wrapped.ˇ» "}, - rust_language.clone(), + rust_language, &mut cx, ); @@ -5723,7 +6200,7 @@ async fn test_rewrap(cx: &mut TestAppContext) { «\tThis is a very long indented line \tthat will be wrapped.ˇ» "}, - plaintext_language.clone(), + plaintext_language, &mut cx, ); @@ -5759,6 +6236,411 @@ async fn test_rewrap(cx: &mut TestAppContext) { } } +#[gpui::test] +async fn test_rewrap_block_comments(cx: &mut TestAppContext) { + init_test(cx, |settings| { + settings.languages.0.extend([( + "Rust".into(), + LanguageSettingsContent { + allow_rewrap: Some(language_settings::RewrapBehavior::InComments), + preferred_line_length: Some(40), + ..Default::default() + }, + )]) + }); + + let mut cx = EditorTestContext::new(cx).await; + + let rust_lang = Arc::new( + Language::new( + LanguageConfig { + name: "Rust".into(), + line_comments: vec!["// ".into()], + block_comment: Some(BlockCommentConfig { + start: "/*".into(), + end: "*/".into(), + prefix: "* ".into(), + tab_size: 1, + }), + documentation_comment: Some(BlockCommentConfig { + start: "/**".into(), + end: "*/".into(), + prefix: "* ".into(), + tab_size: 1, + }), + + ..LanguageConfig::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_override_query("[(line_comment) (block_comment)] @comment.inclusive") + .unwrap(), + ); + + // regular block comment + assert_rewrap( + indoc! {" + /* + *ˇ Lorem ipsum dolor sit amet, consectetur adipiscing elit. + */ + /*ˇ Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ + "}, + indoc! {" + /* + *ˇ Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. 
+ */ + /* + *ˇ Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ + "}, + rust_lang.clone(), + &mut cx, + ); + + // indent is respected + assert_rewrap( + indoc! {" + {} + /*ˇ Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ + "}, + indoc! {" + {} + /* + *ˇ Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ + "}, + rust_lang.clone(), + &mut cx, + ); + + // short block comments with inline delimiters + assert_rewrap( + indoc! {" + /*ˇ Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ + /*ˇ Lorem ipsum dolor sit amet, consectetur adipiscing elit. + */ + /* + *ˇ Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ + "}, + indoc! {" + /* + *ˇ Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ + /* + *ˇ Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ + /* + *ˇ Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ + "}, + rust_lang.clone(), + &mut cx, + ); + + // multiline block comment with inline start/end delimiters + assert_rewrap( + indoc! {" + /*ˇ Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. */ + "}, + indoc! {" + /* + *ˇ Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ + "}, + rust_lang.clone(), + &mut cx, + ); + + // block comment rewrap still respects paragraph bounds + assert_rewrap( + indoc! {" + /* + *ˇ Lorem ipsum dolor sit amet, consectetur adipiscing elit. + * + * Lorem ipsum dolor sit amet, consectetur adipiscing elit. + */ + "}, + indoc! {" + /* + *ˇ Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + * + * Lorem ipsum dolor sit amet, consectetur adipiscing elit. + */ + "}, + rust_lang.clone(), + &mut cx, + ); + + // documentation comments + assert_rewrap( + indoc! {" + /**ˇ Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ + /** + *ˇ Lorem ipsum dolor sit amet, consectetur adipiscing elit. + */ + "}, + indoc! {" + /** + *ˇ Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. 
+ */ + /** + *ˇ Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ + "}, + rust_lang.clone(), + &mut cx, + ); + + // different, adjacent comments + assert_rewrap( + indoc! {" + /** + *ˇ Lorem ipsum dolor sit amet, consectetur adipiscing elit. + */ + /*ˇ Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ + //ˇ Lorem ipsum dolor sit amet, consectetur adipiscing elit. + "}, + indoc! {" + /** + *ˇ Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ + /* + *ˇ Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ + //ˇ Lorem ipsum dolor sit amet, + // consectetur adipiscing elit. + "}, + rust_lang.clone(), + &mut cx, + ); + + // selection w/ single short block comment + assert_rewrap( + indoc! {" + «/* Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ˇ» + "}, + indoc! {" + «/* + * Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ˇ» + "}, + rust_lang.clone(), + &mut cx, + ); + + // rewrapping a single comment w/ abutting comments + assert_rewrap( + indoc! {" + /* ˇLorem ipsum dolor sit amet, consectetur adipiscing elit. */ + /* Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ + "}, + indoc! {" + /* + * ˇLorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ + /* Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ + "}, + rust_lang.clone(), + &mut cx, + ); + + // selection w/ non-abutting short block comments + assert_rewrap( + indoc! {" + «/* Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ + + /* Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ˇ» + "}, + indoc! {" + «/* + * Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ + + /* + * Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ˇ» + "}, + rust_lang.clone(), + &mut cx, + ); + + // selection of multiline block comments + assert_rewrap( + indoc! {" + «/* Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. */ˇ» + "}, + indoc! 
{" + «/* + * Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ˇ» + "}, + rust_lang.clone(), + &mut cx, + ); + + // partial selection of multiline block comments + assert_rewrap( + indoc! {" + «/* Lorem ipsum dolor sit amet,ˇ» + * consectetur adipiscing elit. */ + /* Lorem ipsum dolor sit amet, + «* consectetur adipiscing elit. */ˇ» + "}, + indoc! {" + «/* + * Lorem ipsum dolor sit amet,ˇ» + * consectetur adipiscing elit. */ + /* Lorem ipsum dolor sit amet, + «* consectetur adipiscing elit. + */ˇ» + "}, + rust_lang.clone(), + &mut cx, + ); + + // selection w/ abutting short block comments + // TODO: should not be combined; should rewrap as 2 comments + assert_rewrap( + indoc! {" + «/* Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ + /* Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ˇ» + "}, + // desired behavior: + // indoc! {" + // «/* + // * Lorem ipsum dolor sit amet, + // * consectetur adipiscing elit. + // */ + // /* + // * Lorem ipsum dolor sit amet, + // * consectetur adipiscing elit. + // */ˇ» + // "}, + // actual behaviour: + indoc! {" + «/* + * Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. Lorem + * ipsum dolor sit amet, consectetur + * adipiscing elit. + */ˇ» + "}, + rust_lang.clone(), + &mut cx, + ); + + // TODO: same as above, but with delimiters on separate line + // assert_rewrap( + // indoc! {" + // «/* Lorem ipsum dolor sit amet, consectetur adipiscing elit. + // */ + // /* + // * Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ˇ» + // "}, + // // desired: + // // indoc! {" + // // «/* + // // * Lorem ipsum dolor sit amet, + // // * consectetur adipiscing elit. + // // */ + // // /* + // // * Lorem ipsum dolor sit amet, + // // * consectetur adipiscing elit. + // // */ˇ» + // // "}, + // // actual: (but with trailing w/s on the empty lines) + // indoc! {" + // «/* + // * Lorem ipsum dolor sit amet, + // * consectetur adipiscing elit. 
+ // * + // */ + // /* + // * + // * Lorem ipsum dolor sit amet, + // * consectetur adipiscing elit. + // */ˇ» + // "}, + // rust_lang.clone(), + // &mut cx, + // ); + + // TODO these are unhandled edge cases; not correct, just documenting known issues + assert_rewrap( + indoc! {" + /* + //ˇ Lorem ipsum dolor sit amet, consectetur adipiscing elit. + */ + /* + //ˇ Lorem ipsum dolor sit amet, consectetur adipiscing elit. */ + /*ˇ Lorem ipsum dolor sit amet */ /* consectetur adipiscing elit. */ + "}, + // desired: + // indoc! {" + // /* + // *ˇ Lorem ipsum dolor sit amet, + // * consectetur adipiscing elit. + // */ + // /* + // *ˇ Lorem ipsum dolor sit amet, + // * consectetur adipiscing elit. + // */ + // /* + // *ˇ Lorem ipsum dolor sit amet + // */ /* consectetur adipiscing elit. */ + // "}, + // actual: + indoc! {" + /* + //ˇ Lorem ipsum dolor sit amet, + // consectetur adipiscing elit. + */ + /* + * //ˇ Lorem ipsum dolor sit amet, + * consectetur adipiscing elit. + */ + /* + *ˇ Lorem ipsum dolor sit amet */ /* + * consectetur adipiscing elit. + */ + "}, + rust_lang, + &mut cx, + ); + + #[track_caller] + fn assert_rewrap( + unwrapped_text: &str, + wrapped_text: &str, + language: Arc, + cx: &mut EditorTestContext, + ) { + cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx)); + cx.set_state(unwrapped_text); + cx.update_editor(|e, window, cx| e.rewrap(&Rewrap, window, cx)); + cx.assert_editor_state(wrapped_text); + } +} + #[gpui::test] async fn test_hard_wrap(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -5849,28 +6731,80 @@ async fn test_hard_wrap(cx: &mut TestAppContext) { } #[gpui::test] -async fn test_clipboard(cx: &mut TestAppContext) { +async fn test_cut_line_ends(cx: &mut TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - cx.set_state("«one✅ ˇ»two «three ˇ»four «five ˇ»six "); + cx.set_state(indoc! 
{" + The quick« brownˇ» + fox jumps overˇ + the lazy dog"}); cx.update_editor(|e, window, cx| e.cut(&Cut, window, cx)); - cx.assert_editor_state("ˇtwo ˇfour ˇsix "); + cx.assert_editor_state(indoc! {" + The quickˇ + ˇthe lazy dog"}); - // Paste with three cursors. Each cursor pastes one slice of the clipboard text. - cx.set_state("two ˇfour ˇsix ˇ"); - cx.update_editor(|e, window, cx| e.paste(&Paste, window, cx)); - cx.assert_editor_state("two one✅ ˇfour three ˇsix five ˇ"); + cx.set_state(indoc! {" + The quick« brownˇ» + fox jumps overˇ + the lazy dog"}); + cx.update_editor(|e, window, cx| e.cut_to_end_of_line(&CutToEndOfLine::default(), window, cx)); + cx.assert_editor_state(indoc! {" + The quickˇ + fox jumps overˇthe lazy dog"}); - // Paste again but with only two cursors. Since the number of cursors doesn't - // match the number of slices in the clipboard, the entire clipboard text - // is pasted at each cursor. - cx.set_state("ˇtwo one✅ four three six five ˇ"); + cx.set_state(indoc! {" + The quick« brownˇ» + fox jumps overˇ + the lazy dog"}); cx.update_editor(|e, window, cx| { - e.handle_input("( ", window, cx); - e.paste(&Paste, window, cx); - e.handle_input(") ", window, cx); + e.cut_to_end_of_line( + &CutToEndOfLine { + stop_at_newlines: true, + }, + window, + cx, + ) + }); + cx.assert_editor_state(indoc! {" + The quickˇ + fox jumps overˇ + the lazy dog"}); + + cx.set_state(indoc! {" + The quick« brownˇ» + fox jumps overˇ + the lazy dog"}); + cx.update_editor(|e, window, cx| e.kill_ring_cut(&KillRingCut, window, cx)); + cx.assert_editor_state(indoc! {" + The quickˇ + fox jumps overˇthe lazy dog"}); +} + +#[gpui::test] +async fn test_clipboard(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + + cx.set_state("«one✅ ˇ»two «three ˇ»four «five ˇ»six "); + cx.update_editor(|e, window, cx| e.cut(&Cut, window, cx)); + cx.assert_editor_state("ˇtwo ˇfour ˇsix "); + + // Paste with three cursors. 
Each cursor pastes one slice of the clipboard text. + cx.set_state("two ˇfour ˇsix ˇ"); + cx.update_editor(|e, window, cx| e.paste(&Paste, window, cx)); + cx.assert_editor_state("two one✅ ˇfour three ˇsix five ˇ"); + + // Paste again but with only two cursors. Since the number of cursors doesn't + // match the number of slices in the clipboard, the entire clipboard text + // is pasted at each cursor. + cx.set_state("ˇtwo one✅ four three six five ˇ"); + cx.update_editor(|e, window, cx| { + e.handle_input("( ", window, cx); + e.paste(&Paste, window, cx); + e.handle_input(") ", window, cx); }); cx.assert_editor_state( &([ @@ -6401,7 +7335,7 @@ async fn test_split_selection_into_lines(cx: &mut TestAppContext) { fn test(cx: &mut EditorTestContext, initial_state: &'static str, expected_state: &'static str) { cx.set_state(initial_state); cx.update_editor(|e, window, cx| { - e.split_selection_into_lines(&SplitSelectionIntoLines, window, cx) + e.split_selection_into_lines(&Default::default(), window, cx) }); cx.assert_editor_state(expected_state); } @@ -6489,7 +7423,7 @@ async fn test_split_selection_into_lines_interacting_with_creases(cx: &mut TestA DisplayPoint::new(DisplayRow(4), 4)..DisplayPoint::new(DisplayRow(4), 4), ]) }); - editor.split_selection_into_lines(&SplitSelectionIntoLines, window, cx); + editor.split_selection_into_lines(&Default::default(), window, cx); assert_eq!( editor.display_text(cx), "aaaaa\nbbbbb\nccc⋯eeee\nfffff\nggggg\n⋯i" @@ -6505,7 +7439,7 @@ async fn test_split_selection_into_lines_interacting_with_creases(cx: &mut TestA DisplayPoint::new(DisplayRow(5), 0)..DisplayPoint::new(DisplayRow(0), 1) ]) }); - editor.split_selection_into_lines(&SplitSelectionIntoLines, window, cx); + editor.split_selection_into_lines(&Default::default(), window, cx); assert_eq!( editor.display_text(cx), "aaaaa\nbbbbb\nccccc\nddddd\neeeee\nfffff\nggggg\nhhhhh\niiiii" @@ -7709,10 +8643,10 @@ async fn test_select_larger_smaller_syntax_node(cx: &mut TestAppContext) { 
assert_text_with_selections( editor, indoc! {r#" - use mod1::mod2::{mod3, mo«ˇ»d4}; + use mod1::mod2::{mod3, moˇd4}; fn fn_1(para«ˇm1: bool, pa»ram2: &str) { - let var1 = "te«ˇ»xt"; + let var1 = "teˇxt"; } "#}, cx, @@ -7727,10 +8661,10 @@ async fn test_select_larger_smaller_syntax_node(cx: &mut TestAppContext) { assert_text_with_selections( editor, indoc! {r#" - use mod1::mod2::{mod3, mo«ˇ»d4}; + use mod1::mod2::{mod3, moˇd4}; fn fn_1(para«ˇm1: bool, pa»ram2: &str) { - let var1 = "te«ˇ»xt"; + let var1 = "teˇxt"; } "#}, cx, @@ -7834,6 +8768,184 @@ async fn test_select_larger_syntax_node_for_cursor_at_end(cx: &mut TestAppContex }); } +#[gpui::test] +async fn test_select_larger_syntax_node_for_cursor_at_symbol(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let language = Arc::new(Language::new( + LanguageConfig { + name: "JavaScript".into(), + ..Default::default() + }, + Some(tree_sitter_typescript::LANGUAGE_TSX.into()), + )); + + let text = r#" + let a = { + key: "value", + }; + "# + .unindent(); + + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language, cx)); + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); + let (editor, cx) = cx.add_window_view(|window, cx| build_editor(buffer, window, cx)); + + editor + .condition::(cx, |editor, cx| !editor.buffer.read(cx).is_parsing(cx)) + .await; + + // Test case 1: Cursor after '{' + editor.update_in(cx, |editor, window, cx| { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_display_ranges([ + DisplayPoint::new(DisplayRow(0), 9)..DisplayPoint::new(DisplayRow(0), 9) + ]); + }); + }); + editor.update(cx, |editor, cx| { + assert_text_with_selections( + editor, + indoc! {r#" + let a = {ˇ + key: "value", + }; + "#}, + cx, + ); + }); + editor.update_in(cx, |editor, window, cx| { + editor.select_larger_syntax_node(&SelectLargerSyntaxNode, window, cx); + }); + editor.update(cx, |editor, cx| { + assert_text_with_selections( + editor, + indoc! 
{r#" + let a = «ˇ{ + key: "value", + }»; + "#}, + cx, + ); + }); + + // Test case 2: Cursor after ':' + editor.update_in(cx, |editor, window, cx| { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_display_ranges([ + DisplayPoint::new(DisplayRow(1), 8)..DisplayPoint::new(DisplayRow(1), 8) + ]); + }); + }); + editor.update(cx, |editor, cx| { + assert_text_with_selections( + editor, + indoc! {r#" + let a = { + key:ˇ "value", + }; + "#}, + cx, + ); + }); + editor.update_in(cx, |editor, window, cx| { + editor.select_larger_syntax_node(&SelectLargerSyntaxNode, window, cx); + }); + editor.update(cx, |editor, cx| { + assert_text_with_selections( + editor, + indoc! {r#" + let a = { + «ˇkey: "value"», + }; + "#}, + cx, + ); + }); + editor.update_in(cx, |editor, window, cx| { + editor.select_larger_syntax_node(&SelectLargerSyntaxNode, window, cx); + }); + editor.update(cx, |editor, cx| { + assert_text_with_selections( + editor, + indoc! {r#" + let a = «ˇ{ + key: "value", + }»; + "#}, + cx, + ); + }); + + // Test case 3: Cursor after ',' + editor.update_in(cx, |editor, window, cx| { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_display_ranges([ + DisplayPoint::new(DisplayRow(1), 17)..DisplayPoint::new(DisplayRow(1), 17) + ]); + }); + }); + editor.update(cx, |editor, cx| { + assert_text_with_selections( + editor, + indoc! {r#" + let a = { + key: "value",ˇ + }; + "#}, + cx, + ); + }); + editor.update_in(cx, |editor, window, cx| { + editor.select_larger_syntax_node(&SelectLargerSyntaxNode, window, cx); + }); + editor.update(cx, |editor, cx| { + assert_text_with_selections( + editor, + indoc! 
{r#" + let a = «ˇ{ + key: "value", + }»; + "#}, + cx, + ); + }); + + // Test case 4: Cursor after ';' + editor.update_in(cx, |editor, window, cx| { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_display_ranges([ + DisplayPoint::new(DisplayRow(2), 2)..DisplayPoint::new(DisplayRow(2), 2) + ]); + }); + }); + editor.update(cx, |editor, cx| { + assert_text_with_selections( + editor, + indoc! {r#" + let a = { + key: "value", + };ˇ + "#}, + cx, + ); + }); + editor.update_in(cx, |editor, window, cx| { + editor.select_larger_syntax_node(&SelectLargerSyntaxNode, window, cx); + }); + editor.update(cx, |editor, cx| { + assert_text_with_selections( + editor, + indoc! {r#" + «ˇlet a = { + key: "value", + }; + »"#}, + cx, + ); + }); +} + #[gpui::test] async fn test_select_larger_smaller_syntax_node_for_string(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -8015,7 +9127,7 @@ async fn test_select_larger_smaller_syntax_node_for_string(cx: &mut TestAppConte } #[gpui::test] -async fn test_unwrap_syntax_node(cx: &mut gpui::TestAppContext) { +async fn test_unwrap_syntax_nodes(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; @@ -8029,21 +9141,12 @@ async fn test_unwrap_syntax_node(cx: &mut gpui::TestAppContext) { buffer.set_language(Some(language), cx); }); - cx.set_state( - &r#" - use mod1::mod2::{«mod3ˇ», mod4}; - "# - .unindent(), - ); + cx.set_state(indoc! { r#"use mod1::{mod2::{«mod3ˇ», mod4}, mod5::{mod6, «mod7ˇ»}};"# }); cx.update_editor(|editor, window, cx| { editor.unwrap_syntax_node(&UnwrapSyntaxNode, window, cx); }); - cx.assert_editor_state( - &r#" - use mod1::mod2::«mod3ˇ»; - "# - .unindent(), - ); + + cx.assert_editor_state(indoc! 
{ r#"use mod1::{mod2::«mod3ˇ», mod5::«mod7ˇ»};"# }); } #[gpui::test] @@ -8214,6 +9317,216 @@ async fn test_autoindent(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_autoindent_disabled(cx: &mut TestAppContext) { + init_test(cx, |settings| settings.defaults.auto_indent = Some(false)); + + let language = Arc::new( + Language::new( + LanguageConfig { + brackets: BracketPairConfig { + pairs: vec![ + BracketPair { + start: "{".to_string(), + end: "}".to_string(), + close: false, + surround: false, + newline: true, + }, + BracketPair { + start: "(".to_string(), + end: ")".to_string(), + close: false, + surround: false, + newline: true, + }, + ], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_indents_query( + r#" + (_ "(" ")" @end) @indent + (_ "{" "}" @end) @indent + "#, + ) + .unwrap(), + ); + + let text = "fn a() {}"; + + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language, cx)); + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); + let (editor, cx) = cx.add_window_view(|window, cx| build_editor(buffer, window, cx)); + editor + .condition::(cx, |editor, cx| !editor.buffer.read(cx).is_parsing(cx)) + .await; + + editor.update_in(cx, |editor, window, cx| { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_ranges([5..5, 8..8, 9..9]) + }); + editor.newline(&Newline, window, cx); + assert_eq!( + editor.text(cx), + indoc!( + " + fn a( + + ) { + + } + " + ) + ); + assert_eq!( + editor.selections.ranges(cx), + &[ + Point::new(1, 0)..Point::new(1, 0), + Point::new(3, 0)..Point::new(3, 0), + Point::new(5, 0)..Point::new(5, 0) + ] + ); + }); +} + +#[gpui::test] +async fn test_autoindent_disabled_with_nested_language(cx: &mut TestAppContext) { + init_test(cx, |settings| { + settings.defaults.auto_indent = Some(true); + settings.languages.0.insert( + "python".into(), + LanguageSettingsContent { + auto_indent: Some(false), + 
..Default::default() + }, + ); + }); + + let mut cx = EditorTestContext::new(cx).await; + + let injected_language = Arc::new( + Language::new( + LanguageConfig { + brackets: BracketPairConfig { + pairs: vec![ + BracketPair { + start: "{".to_string(), + end: "}".to_string(), + close: false, + surround: false, + newline: true, + }, + BracketPair { + start: "(".to_string(), + end: ")".to_string(), + close: true, + surround: false, + newline: true, + }, + ], + ..Default::default() + }, + name: "python".into(), + ..Default::default() + }, + Some(tree_sitter_python::LANGUAGE.into()), + ) + .with_indents_query( + r#" + (_ "(" ")" @end) @indent + (_ "{" "}" @end) @indent + "#, + ) + .unwrap(), + ); + + let language = Arc::new( + Language::new( + LanguageConfig { + brackets: BracketPairConfig { + pairs: vec![ + BracketPair { + start: "{".to_string(), + end: "}".to_string(), + close: false, + surround: false, + newline: true, + }, + BracketPair { + start: "(".to_string(), + end: ")".to_string(), + close: true, + surround: false, + newline: true, + }, + ], + ..Default::default() + }, + name: LanguageName::new("rust"), + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_indents_query( + r#" + (_ "(" ")" @end) @indent + (_ "{" "}" @end) @indent + "#, + ) + .unwrap() + .with_injection_query( + r#" + (macro_invocation + macro: (identifier) @_macro_name + (token_tree) @injection.content + (#set! 
injection.language "python")) + "#, + ) + .unwrap(), + ); + + cx.language_registry().add(injected_language); + cx.language_registry().add(language.clone()); + + cx.update_buffer(|buffer, cx| { + buffer.set_language(Some(language), cx); + }); + + cx.set_state(r#"struct A {ˇ}"#); + + cx.update_editor(|editor, window, cx| { + editor.newline(&Default::default(), window, cx); + }); + + cx.assert_editor_state(indoc!( + "struct A { + ˇ + }" + )); + + cx.set_state(r#"select_biased!(ˇ)"#); + + cx.update_editor(|editor, window, cx| { + editor.newline(&Default::default(), window, cx); + editor.handle_input("def ", window, cx); + editor.handle_input("(", window, cx); + editor.newline(&Default::default(), window, cx); + editor.handle_input("a", window, cx); + }); + + cx.assert_editor_state(indoc!( + "select_biased!( + def ( + aˇ + ) + )" + )); +} + #[gpui::test] async fn test_autoindent_selections(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -8688,7 +10001,7 @@ async fn test_autoclose_with_embedded_language(cx: &mut TestAppContext) { )); cx.language_registry().add(html_language.clone()); - cx.language_registry().add(javascript_language.clone()); + cx.language_registry().add(javascript_language); cx.executor().run_until_parked(); cx.update_buffer(|buffer, cx| { @@ -9432,7 +10745,7 @@ async fn test_snippets(cx: &mut TestAppContext) { .selections .all(cx) .iter() - .map(|s| s.range().clone()) + .map(|s| s.range()) .collect::>(); editor .insert_snippet(&insertion_ranges, snippet, window, cx) @@ -9512,7 +10825,7 @@ async fn test_snippet_indentation(cx: &mut TestAppContext) { .selections .all(cx) .iter() - .map(|s| s.range().clone()) + .map(|s| s.range()) .collect::>(); editor .insert_snippet(&insertion_ranges, snippet, window, cx) @@ -9583,7 +10896,7 @@ async fn test_document_format_during_save(cx: &mut TestAppContext) { move |params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/file.rs")).unwrap() + 
lsp::Uri::from_file_path(path!("/file.rs")).unwrap() ); assert_eq!(params.options.tab_size, 4); Ok(Some(vec![lsp::TextEdit::new( @@ -9626,7 +10939,7 @@ async fn test_document_format_during_save(cx: &mut TestAppContext) { move |params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/file.rs")).unwrap() + lsp::Uri::from_file_path(path!("/file.rs")).unwrap() ); futures::future::pending::<()>().await; unreachable!() @@ -9674,7 +10987,7 @@ async fn test_document_format_during_save(cx: &mut TestAppContext) { .set_request_handler::(move |params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/file.rs")).unwrap() + lsp::Uri::from_file_path(path!("/file.rs")).unwrap() ); assert_eq!(params.options.tab_size, 8); Ok(Some(vec![])) @@ -10222,7 +11535,7 @@ async fn test_range_format_on_save_success(cx: &mut TestAppContext) { .set_request_handler::(move |params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/file.rs")).unwrap() + lsp::Uri::from_file_path(path!("/file.rs")).unwrap() ); assert_eq!(params.options.tab_size, 4); Ok(Some(vec![lsp::TextEdit::new( @@ -10255,7 +11568,7 @@ async fn test_range_format_on_save_timeout(cx: &mut TestAppContext) { move |params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/file.rs")).unwrap() + lsp::Uri::from_file_path(path!("/file.rs")).unwrap() ); futures::future::pending::<()>().await; unreachable!() @@ -10348,7 +11661,7 @@ async fn test_range_format_respects_language_tab_size_override(cx: &mut TestAppC .set_request_handler::(move |params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/file.rs")).unwrap() + lsp::Uri::from_file_path(path!("/file.rs")).unwrap() ); assert_eq!(params.options.tab_size, 8); Ok(Some(Vec::new())) @@ -10435,7 +11748,7 @@ async fn test_document_format_manual_trigger(cx: &mut TestAppContext) { 
.set_request_handler::(move |params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/file.rs")).unwrap() + lsp::Uri::from_file_path(path!("/file.rs")).unwrap() ); assert_eq!(params.options.tab_size, 4); Ok(Some(vec![lsp::TextEdit::new( @@ -10460,7 +11773,7 @@ async fn test_document_format_manual_trigger(cx: &mut TestAppContext) { move |params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/file.rs")).unwrap() + lsp::Uri::from_file_path(path!("/file.rs")).unwrap() ); futures::future::pending::<()>().await; unreachable!() @@ -10556,7 +11869,7 @@ async fn test_multiple_formatters(cx: &mut TestAppContext) { params.context.only, Some(vec!["code-action-1".into(), "code-action-2".into()]) ); - let uri = lsp::Url::from_file_path(path!("/file.rs")).unwrap(); + let uri = lsp::Uri::from_file_path(path!("/file.rs")).unwrap(); Ok(Some(vec![ lsp::CodeActionOrCommand::CodeAction(lsp::CodeAction { kind: Some("code-action-1".into()), @@ -10581,7 +11894,7 @@ async fn test_multiple_formatters(cx: &mut TestAppContext) { kind: Some("code-action-2".into()), edit: Some(lsp::WorkspaceEdit::new( [( - uri.clone(), + uri, vec![lsp::TextEdit::new( lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 0)), "applied-code-action-2-edit\n".to_string(), @@ -10616,7 +11929,7 @@ async fn test_multiple_formatters(cx: &mut TestAppContext) { edit: lsp::WorkspaceEdit { changes: Some( [( - lsp::Url::from_file_path(path!("/file.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/file.rs")).unwrap(), vec![lsp::TextEdit { range: lsp::Range::new( lsp::Position::new(0, 0), @@ -10827,7 +12140,7 @@ async fn test_organize_imports_manual_trigger(cx: &mut TestAppContext) { .set_request_handler::(move |params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/file.ts")).unwrap() + lsp::Uri::from_file_path(path!("/file.ts")).unwrap() ); Ok(Some(vec![lsp::CodeActionOrCommand::CodeAction( 
lsp::CodeAction { @@ -10875,7 +12188,7 @@ async fn test_organize_imports_manual_trigger(cx: &mut TestAppContext) { move |params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/file.ts")).unwrap() + lsp::Uri::from_file_path(path!("/file.ts")).unwrap() ); futures::future::pending::<()>().await; unreachable!() @@ -12036,6 +13349,7 @@ async fn test_completion_mode(cx: &mut TestAppContext) { settings.defaults.completions = Some(CompletionSettings { lsp_insert_mode, words: WordsCompletionMode::Disabled, + words_min_length: 0, lsp: true, lsp_fetch_timeout_ms: 0, }); @@ -12094,6 +13408,7 @@ async fn test_completion_with_mode_specified_by_action(cx: &mut TestAppContext) update_test_language_settings(&mut cx, |settings| { settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Disabled, + words_min_length: 0, // set the opposite here to ensure that the action is overriding the default behavior lsp_insert_mode: LspInsertMode::Insert, lsp: true, @@ -12109,7 +13424,7 @@ async fn test_completion_with_mode_specified_by_action(cx: &mut TestAppContext) let counter = Arc::new(AtomicUsize::new(0)); handle_completion_request_with_insert_and_replace( &mut cx, - &buffer_marked_text, + buffer_marked_text, vec![(completion_text, completion_text)], counter.clone(), ) @@ -12123,13 +13438,14 @@ async fn test_completion_with_mode_specified_by_action(cx: &mut TestAppContext) .confirm_completion_replace(&ConfirmCompletionReplace, window, cx) .unwrap() }); - cx.assert_editor_state(&expected_with_replace_mode); + cx.assert_editor_state(expected_with_replace_mode); handle_resolve_completion_request(&mut cx, None).await; apply_additional_edits.await.unwrap(); update_test_language_settings(&mut cx, |settings| { settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Disabled, + words_min_length: 0, // set the opposite here to ensure that the action is overriding the default behavior lsp_insert_mode: 
LspInsertMode::Replace, lsp: true, @@ -12143,7 +13459,7 @@ async fn test_completion_with_mode_specified_by_action(cx: &mut TestAppContext) }); handle_completion_request_with_insert_and_replace( &mut cx, - &buffer_marked_text, + buffer_marked_text, vec![(completion_text, completion_text)], counter.clone(), ) @@ -12157,7 +13473,7 @@ async fn test_completion_with_mode_specified_by_action(cx: &mut TestAppContext) .confirm_completion_insert(&ConfirmCompletionInsert, window, cx) .unwrap() }); - cx.assert_editor_state(&expected_with_insert_mode); + cx.assert_editor_state(expected_with_insert_mode); handle_resolve_completion_request(&mut cx, None).await; apply_additional_edits.await.unwrap(); } @@ -12871,6 +14187,7 @@ async fn test_word_completion(cx: &mut TestAppContext) { init_test(cx, |language_settings| { language_settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Fallback, + words_min_length: 0, lsp: true, lsp_fetch_timeout_ms: 10, lsp_insert_mode: LspInsertMode::Insert, @@ -12931,7 +14248,7 @@ async fn test_word_completion(cx: &mut TestAppContext) { if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref() { assert_eq!( - completion_menu_entries(&menu), + completion_menu_entries(menu), &["first", "last"], "When LSP server is fast to reply, no fallback word completions are used" ); @@ -12954,7 +14271,7 @@ async fn test_word_completion(cx: &mut TestAppContext) { cx.update_editor(|editor, _, _| { if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref() { - assert_eq!(completion_menu_entries(&menu), &["one", "three", "two"], + assert_eq!(completion_menu_entries(menu), &["one", "three", "two"], "When LSP server is slow, document words can be shown instead, if configured accordingly"); } else { panic!("expected completion menu to be open"); @@ -12967,6 +14284,7 @@ async fn test_word_completions_do_not_duplicate_lsp_ones(cx: &mut TestAppContext init_test(cx, 
|language_settings| { language_settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Enabled, + words_min_length: 0, lsp: true, lsp_fetch_timeout_ms: 0, lsp_insert_mode: LspInsertMode::Insert, @@ -13015,7 +14333,7 @@ async fn test_word_completions_do_not_duplicate_lsp_ones(cx: &mut TestAppContext if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref() { assert_eq!( - completion_menu_entries(&menu), + completion_menu_entries(menu), &["first", "last", "second"], "Word completions that has the same edit as the any of the LSP ones, should not be proposed" ); @@ -13030,6 +14348,7 @@ async fn test_word_completions_continue_on_typing(cx: &mut TestAppContext) { init_test(cx, |language_settings| { language_settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Disabled, + words_min_length: 0, lsp: true, lsp_fetch_timeout_ms: 0, lsp_insert_mode: LspInsertMode::Insert, @@ -13071,7 +14390,7 @@ async fn test_word_completions_continue_on_typing(cx: &mut TestAppContext) { if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref() { assert_eq!( - completion_menu_entries(&menu), + completion_menu_entries(menu), &["first", "last", "second"], "`ShowWordCompletions` action should show word completions" ); @@ -13088,7 +14407,7 @@ async fn test_word_completions_continue_on_typing(cx: &mut TestAppContext) { if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref() { assert_eq!( - completion_menu_entries(&menu), + completion_menu_entries(menu), &["last"], "After showing word completions, further editing should filter them and not query the LSP" ); @@ -13103,6 +14422,7 @@ async fn test_word_completions_usually_skip_digits(cx: &mut TestAppContext) { init_test(cx, |language_settings| { language_settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Fallback, + words_min_length: 0, lsp: false, 
lsp_fetch_timeout_ms: 0, lsp_insert_mode: LspInsertMode::Insert, @@ -13127,7 +14447,7 @@ async fn test_word_completions_usually_skip_digits(cx: &mut TestAppContext) { if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref() { assert_eq!( - completion_menu_entries(&menu), + completion_menu_entries(menu), &["let"], "With no digits in the completion query, no digits should be in the word completions" ); @@ -13152,7 +14472,7 @@ async fn test_word_completions_usually_skip_digits(cx: &mut TestAppContext) { cx.update_editor(|editor, _, _| { if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref() { - assert_eq!(completion_menu_entries(&menu), &["33", "35f32"], "The digit is in the completion query, \ + assert_eq!(completion_menu_entries(menu), &["33", "35f32"], "The digit is in the completion query, \ return matching words with digits (`33`, `35f32`) but exclude query duplicates (`3`)"); } else { panic!("expected completion menu to be open"); @@ -13160,6 +14480,120 @@ async fn test_word_completions_usually_skip_digits(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_word_completions_do_not_show_before_threshold(cx: &mut TestAppContext) { + init_test(cx, |language_settings| { + language_settings.defaults.completions = Some(CompletionSettings { + words: WordsCompletionMode::Enabled, + words_min_length: 3, + lsp: true, + lsp_fetch_timeout_ms: 0, + lsp_insert_mode: LspInsertMode::Insert, + }); + }); + + let mut cx = EditorLspTestContext::new_rust(lsp::ServerCapabilities::default(), cx).await; + cx.set_state(indoc! 
{"ˇ + wow + wowen + wowser + "}); + cx.simulate_keystroke("w"); + cx.executor().run_until_parked(); + cx.update_editor(|editor, _, _| { + if editor.context_menu.borrow_mut().is_some() { + panic!( + "expected completion menu to be hidden, as words completion threshold is not met" + ); + } + }); + + cx.update_editor(|editor, window, cx| { + editor.show_word_completions(&ShowWordCompletions, window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref() + { + assert_eq!(completion_menu_entries(menu), &["wowser", "wowen", "wow"], "Even though the threshold is not met, invoking word completions with an action should provide the completions"); + } else { + panic!("expected completion menu to be open after the word completions are called with an action"); + } + + editor.cancel(&Cancel, window, cx); + }); + cx.update_editor(|editor, _, _| { + if editor.context_menu.borrow_mut().is_some() { + panic!("expected completion menu to be hidden after canceling"); + } + }); + + cx.simulate_keystroke("o"); + cx.executor().run_until_parked(); + cx.update_editor(|editor, _, _| { + if editor.context_menu.borrow_mut().is_some() { + panic!( + "expected completion menu to be hidden, as words completion threshold is not met still" + ); + } + }); + + cx.simulate_keystroke("w"); + cx.executor().run_until_parked(); + cx.update_editor(|editor, _, _| { + if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref() + { + assert_eq!(completion_menu_entries(menu), &["wowen", "wowser"], "After word completion threshold is met, matching words should be shown, excluding the already typed word"); + } else { + panic!("expected completion menu to be open after the word completions threshold is met"); + } + }); +} + +#[gpui::test] +async fn test_word_completions_disabled(cx: &mut TestAppContext) { + init_test(cx, |language_settings| { + 
language_settings.defaults.completions = Some(CompletionSettings { + words: WordsCompletionMode::Enabled, + words_min_length: 0, + lsp: true, + lsp_fetch_timeout_ms: 0, + lsp_insert_mode: LspInsertMode::Insert, + }); + }); + + let mut cx = EditorLspTestContext::new_rust(lsp::ServerCapabilities::default(), cx).await; + cx.update_editor(|editor, _, _| { + editor.disable_word_completions(); + }); + cx.set_state(indoc! {"ˇ + wow + wowen + wowser + "}); + cx.simulate_keystroke("w"); + cx.executor().run_until_parked(); + cx.update_editor(|editor, _, _| { + if editor.context_menu.borrow_mut().is_some() { + panic!( + "expected completion menu to be hidden, as words completion are disabled for this editor" + ); + } + }); + + cx.update_editor(|editor, window, cx| { + editor.show_word_completions(&ShowWordCompletions, window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, _, _| { + if editor.context_menu.borrow_mut().is_some() { + panic!( + "expected completion menu to be hidden even if called for explicitly, as words completion are disabled for this editor" + ); + } + }); +} + fn gen_text_edit(params: &CompletionParams, text: &str) -> Option { let position = || lsp::Position { line: params.text_document_position.position.line, @@ -13389,7 +14823,7 @@ async fn test_completion_page_up_down_keys(cx: &mut TestAppContext) { cx.update_editor(|editor, _, _| { if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref() { - assert_eq!(completion_menu_entries(&menu), &["first", "last"]); + assert_eq!(completion_menu_entries(menu), &["first", "last"]); } else { panic!("expected completion menu to be open"); } @@ -14165,7 +15599,7 @@ async fn test_toggle_block_comment(cx: &mut TestAppContext) { )); cx.language_registry().add(html_language.clone()); - cx.language_registry().add(javascript_language.clone()); + cx.language_registry().add(javascript_language); cx.update_buffer(|buffer, cx| { buffer.set_language(Some(html_language), 
cx); }); @@ -14342,7 +15776,7 @@ fn test_editing_overlapping_excerpts(cx: &mut TestAppContext) { ); let excerpt_ranges = markers.into_iter().map(|marker| { let context = excerpt_ranges.remove(&marker).unwrap()[0].clone(); - ExcerptRange::new(context.clone()) + ExcerptRange::new(context) }); let buffer = cx.new(|cx| Buffer::local(initial_text, cx)); let multibuffer = cx.new(|cx| { @@ -14627,7 +16061,7 @@ fn test_highlighted_ranges(cx: &mut TestAppContext) { let editor = cx.add_window(|window, cx| { let buffer = MultiBuffer::build_simple(&sample_text(16, 8, 'a'), cx); - build_editor(buffer.clone(), window, cx) + build_editor(buffer, window, cx) }); _ = editor.update(cx, |editor, window, cx| { @@ -14661,37 +16095,34 @@ fn test_highlighted_ranges(cx: &mut TestAppContext) { ); let snapshot = editor.snapshot(window, cx); - let mut highlighted_ranges = editor.background_highlights_in_range( + let highlighted_ranges = editor.sorted_background_highlights_in_range( anchor_range(Point::new(3, 4)..Point::new(7, 4)), &snapshot, cx.theme(), ); - // Enforce a consistent ordering based on color without relying on the ordering of the - // highlight's `TypeId` which is non-executor. 
- highlighted_ranges.sort_unstable_by_key(|(_, color)| *color); assert_eq!( highlighted_ranges, &[ ( - DisplayPoint::new(DisplayRow(4), 2)..DisplayPoint::new(DisplayRow(4), 4), - Hsla::red(), + DisplayPoint::new(DisplayRow(3), 2)..DisplayPoint::new(DisplayRow(3), 5), + Hsla::green(), ), ( - DisplayPoint::new(DisplayRow(6), 3)..DisplayPoint::new(DisplayRow(6), 5), + DisplayPoint::new(DisplayRow(4), 2)..DisplayPoint::new(DisplayRow(4), 4), Hsla::red(), ), ( - DisplayPoint::new(DisplayRow(3), 2)..DisplayPoint::new(DisplayRow(3), 5), + DisplayPoint::new(DisplayRow(5), 3)..DisplayPoint::new(DisplayRow(5), 6), Hsla::green(), ), ( - DisplayPoint::new(DisplayRow(5), 3)..DisplayPoint::new(DisplayRow(5), 6), - Hsla::green(), + DisplayPoint::new(DisplayRow(6), 3)..DisplayPoint::new(DisplayRow(6), 5), + Hsla::red(), ), ] ); assert_eq!( - editor.background_highlights_in_range( + editor.sorted_background_highlights_in_range( anchor_range(Point::new(5, 6)..Point::new(6, 4)), &snapshot, cx.theme(), @@ -14712,7 +16143,7 @@ async fn test_following(cx: &mut TestAppContext) { let project = Project::test(fs, ["/file.rs".as_ref()], cx).await; let buffer = project.update(cx, |project, cx| { - let buffer = project.create_local_buffer(&sample_text(16, 8, 'a'), None, cx); + let buffer = project.create_local_buffer(&sample_text(16, 8, 'a'), None, false, cx); cx.new(|cx| MultiBuffer::singleton(buffer, cx)) }); let leader = cx.add_window(|window, cx| build_editor(buffer.clone(), window, cx)); @@ -14964,8 +16395,8 @@ async fn test_following_with_multiple_excerpts(cx: &mut TestAppContext) { let (buffer_1, buffer_2) = project.update(cx, |project, cx| { ( - project.create_local_buffer("abc\ndef\nghi\njkl\n", None, cx), - project.create_local_buffer("mno\npqr\nstu\nvwx\n", None, cx), + project.create_local_buffer("abc\ndef\nghi\njkl\n", None, false, cx), + project.create_local_buffer("mno\npqr\nstu\nvwx\n", None, false, cx), ) }); @@ -15082,7 +16513,7 @@ async fn 
go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! {" ˇfn func(abc def: i32) -> u32 { @@ -15095,7 +16526,7 @@ async fn go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu .update_diagnostics( LanguageServerId(0), lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/root/file")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/root/file")).unwrap(), version: None, diagnostics: vec![ lsp::Diagnostic { @@ -15491,7 +16922,7 @@ async fn test_on_type_formatting_not_triggered(cx: &mut TestAppContext) { |params, _| async move { assert_eq!( params.text_document_position.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); assert_eq!( params.text_document_position.position, @@ -15549,8 +16980,7 @@ async fn test_on_type_formatting_is_applied_after_autoindent(cx: &mut TestAppCon cx.simulate_keystroke("\n"); cx.run_until_parked(); - let buffer_cloned = - cx.multibuffer(|multi_buffer, _| multi_buffer.as_singleton().unwrap().clone()); + let buffer_cloned = cx.multibuffer(|multi_buffer, _| multi_buffer.as_singleton().unwrap()); let mut request = cx.set_request_handler::(move |_, _, mut cx| { let buffer_cloned = buffer_cloned.clone(); @@ -15661,6 +17091,7 @@ async fn test_language_server_restart_due_to_settings_change(cx: &mut TestAppCon "some other init value": false })), enable_lsp_tasks: false, + fetch: None, }, ); }); @@ -15681,6 +17112,7 @@ async fn test_language_server_restart_due_to_settings_change(cx: &mut TestAppCon "anotherInitValue": false })), enable_lsp_tasks: false, + fetch: None, }, ); }); @@ -15701,6 +17133,7 @@ async fn 
test_language_server_restart_due_to_settings_change(cx: &mut TestAppCon "anotherInitValue": false })), enable_lsp_tasks: false, + fetch: None, }, ); }); @@ -15719,6 +17152,7 @@ async fn test_language_server_restart_due_to_settings_change(cx: &mut TestAppCon settings: None, initialization_options: None, enable_lsp_tasks: false, + fetch: None, }, ); }); @@ -16017,7 +17451,7 @@ async fn test_context_menus_hide_hover_popover(cx: &mut gpui::TestAppContext) { edit: Some(lsp::WorkspaceEdit { changes: Some( [( - lsp::Url::from_file_path(path!("/file.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/file.rs")).unwrap(), vec![lsp::TextEdit { range: lsp::Range::new( lsp::Position::new(5, 4), @@ -16492,7 +17926,7 @@ async fn test_completions_in_languages_with_extra_word_characters(cx: &mut TestA if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref() { assert_eq!( - completion_menu_entries(&menu), + completion_menu_entries(menu), &["bg-blue", "bg-red", "bg-yellow"] ); } else { @@ -16505,7 +17939,7 @@ async fn test_completions_in_languages_with_extra_word_characters(cx: &mut TestA cx.update_editor(|editor, _, _| { if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref() { - assert_eq!(completion_menu_entries(&menu), &["bg-blue", "bg-yellow"]); + assert_eq!(completion_menu_entries(menu), &["bg-blue", "bg-yellow"]); } else { panic!("expected completion menu to be open"); } @@ -16519,7 +17953,7 @@ async fn test_completions_in_languages_with_extra_word_characters(cx: &mut TestA cx.update_editor(|editor, _, _| { if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref() { - assert_eq!(completion_menu_entries(&menu), &["bg-yellow"]); + assert_eq!(completion_menu_entries(menu), &["bg-yellow"]); } else { panic!("expected completion menu to be open"); } @@ -17088,7 +18522,7 @@ async fn test_multibuffer_reverts(cx: &mut TestAppContext) { (buffer_2.clone(), base_text_2), 
(buffer_3.clone(), base_text_3), ] { - let diff = cx.new(|cx| BufferDiff::new_with_base_text(&diff_base, &buffer, cx)); + let diff = cx.new(|cx| BufferDiff::new_with_base_text(diff_base, &buffer, cx)); editor .buffer .update(cx, |buffer, cx| buffer.add_diff(diff, cx)); @@ -17709,7 +19143,7 @@ async fn test_toggle_diff_expand_in_multi_buffer(cx: &mut TestAppContext) { (buffer_2.clone(), file_2_old), (buffer_3.clone(), file_3_old), ] { - let diff = cx.new(|cx| BufferDiff::new_with_base_text(&diff_base, &buffer, cx)); + let diff = cx.new(|cx| BufferDiff::new_with_base_text(diff_base, &buffer, cx)); editor .buffer .update(cx, |buffer, cx| buffer.add_diff(diff, cx)); @@ -19254,7 +20688,7 @@ async fn test_adjacent_diff_hunks(executor: BackgroundExecutor, cx: &mut TestApp let buffer_id = hunks[0].buffer_id; hunks .into_iter() - .map(|hunk| Anchor::range_in_buffer(excerpt_id, buffer_id, hunk.buffer_range.clone())) + .map(|hunk| Anchor::range_in_buffer(excerpt_id, buffer_id, hunk.buffer_range)) .collect::>() }); assert_eq!(hunk_ranges.len(), 2); @@ -19345,7 +20779,7 @@ async fn test_adjacent_diff_hunks(executor: BackgroundExecutor, cx: &mut TestApp let buffer_id = hunks[0].buffer_id; hunks .into_iter() - .map(|hunk| Anchor::range_in_buffer(excerpt_id, buffer_id, hunk.buffer_range.clone())) + .map(|hunk| Anchor::range_in_buffer(excerpt_id, buffer_id, hunk.buffer_range)) .collect::>() }); assert_eq!(hunk_ranges.len(), 2); @@ -19411,7 +20845,7 @@ async fn test_toggle_deletion_hunk_at_start_of_file( let buffer_id = hunks[0].buffer_id; hunks .into_iter() - .map(|hunk| Anchor::range_in_buffer(excerpt_id, buffer_id, hunk.buffer_range.clone())) + .map(|hunk| Anchor::range_in_buffer(excerpt_id, buffer_id, hunk.buffer_range)) .collect::>() }); assert_eq!(hunk_ranges.len(), 1); @@ -19434,7 +20868,7 @@ async fn test_toggle_deletion_hunk_at_start_of_file( }); executor.run_until_parked(); - cx.assert_state_with_diff(hunk_expanded.clone()); + cx.assert_state_with_diff(hunk_expanded); } 
#[gpui::test] @@ -19485,7 +20919,7 @@ async fn test_display_diff_hunks(cx: &mut TestAppContext) { PathKey::namespaced(0, buffer.read(cx).file().unwrap().path().clone()), buffer.clone(), vec![text::Anchor::MIN.to_point(&snapshot)..text::Anchor::MAX.to_point(&snapshot)], - DEFAULT_MULTIBUFFER_CONTEXT, + 2, cx, ); } @@ -19634,13 +21068,8 @@ fn test_crease_insertion_and_rendering(cx: &mut TestAppContext) { editor.insert_creases(Some(crease), cx); let snapshot = editor.snapshot(window, cx); - let _div = snapshot.render_crease_toggle( - MultiBufferRow(1), - false, - cx.entity().clone(), - window, - cx, - ); + let _div = + snapshot.render_crease_toggle(MultiBufferRow(1), false, cx.entity(), window, cx); snapshot }) .unwrap(); @@ -20819,7 +22248,7 @@ async fn assert_highlighted_edits( cx.update(|_window, cx| { let highlighted_edits = edit_prediction_edit_text( - &snapshot.as_singleton().unwrap().2, + snapshot.as_singleton().unwrap().2, &edits, &edit_preview, include_deletions, @@ -20835,13 +22264,13 @@ fn assert_breakpoint( path: &Arc, expected: Vec<(u32, Breakpoint)>, ) { - if expected.len() == 0usize { + if expected.is_empty() { assert!(!breakpoints.contains_key(path), "{}", path.display()); } else { let mut breakpoint = breakpoints .get(path) .unwrap() - .into_iter() + .iter() .map(|breakpoint| { ( breakpoint.row, @@ -20870,13 +22299,7 @@ fn add_log_breakpoint_at_cursor( let (anchor, bp) = editor .breakpoints_at_cursors(window, cx) .first() - .and_then(|(anchor, bp)| { - if let Some(bp) = bp { - Some((*anchor, bp.clone())) - } else { - None - } - }) + .and_then(|(anchor, bp)| bp.as_ref().map(|bp| (*anchor, bp.clone()))) .unwrap_or_else(|| { let cursor_position: Point = editor.selections.newest(cx).head(); @@ -20886,7 +22309,7 @@ fn add_log_breakpoint_at_cursor( .buffer_snapshot .anchor_before(Point::new(cursor_position.row, 0)); - (breakpoint_position, Breakpoint::new_log(&log_message)) + (breakpoint_position, Breakpoint::new_log(log_message)) }); 
editor.edit_breakpoint_at_anchor( @@ -20954,7 +22377,7 @@ async fn test_breakpoint_toggling(cx: &mut TestAppContext) { let abs_path = project.read_with(cx, |project, cx| { project .absolute_path(&project_path, cx) - .map(|path_buf| Arc::from(path_buf.to_owned())) + .map(Arc::from) .unwrap() }); @@ -20972,7 +22395,6 @@ async fn test_breakpoint_toggling(cx: &mut TestAppContext) { .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); assert_eq!(1, breakpoints.len()); @@ -20997,7 +22419,6 @@ async fn test_breakpoint_toggling(cx: &mut TestAppContext) { .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); assert_eq!(1, breakpoints.len()); @@ -21019,7 +22440,6 @@ async fn test_breakpoint_toggling(cx: &mut TestAppContext) { .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); assert_eq!(0, breakpoints.len()); @@ -21071,7 +22491,7 @@ async fn test_log_breakpoint_editing(cx: &mut TestAppContext) { let abs_path = project.read_with(cx, |project, cx| { project .absolute_path(&project_path, cx) - .map(|path_buf| Arc::from(path_buf.to_owned())) + .map(Arc::from) .unwrap() }); @@ -21086,7 +22506,6 @@ async fn test_log_breakpoint_editing(cx: &mut TestAppContext) { .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); assert_breakpoint( @@ -21107,7 +22526,6 @@ async fn test_log_breakpoint_editing(cx: &mut TestAppContext) { .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); assert_breakpoint(&breakpoints, &abs_path, vec![]); @@ -21127,7 +22545,6 @@ async fn test_log_breakpoint_editing(cx: &mut TestAppContext) { .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); assert_breakpoint( @@ -21150,7 +22567,6 @@ async fn test_log_breakpoint_editing(cx: &mut TestAppContext) { .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); assert_breakpoint( @@ -21173,7 +22589,6 @@ async fn test_log_breakpoint_editing(cx: &mut TestAppContext) { .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); assert_breakpoint( @@ 
-21246,7 +22661,7 @@ async fn test_breakpoint_enabling_and_disabling(cx: &mut TestAppContext) { let abs_path = project.read_with(cx, |project, cx| { project .absolute_path(&project_path, cx) - .map(|path_buf| Arc::from(path_buf.to_owned())) + .map(Arc::from) .unwrap() }); @@ -21266,7 +22681,6 @@ async fn test_breakpoint_enabling_and_disabling(cx: &mut TestAppContext) { .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); assert_eq!(1, breakpoints.len()); @@ -21298,7 +22712,6 @@ async fn test_breakpoint_enabling_and_disabling(cx: &mut TestAppContext) { .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); let disable_breakpoint = { @@ -21334,7 +22747,6 @@ async fn test_breakpoint_enabling_and_disabling(cx: &mut TestAppContext) { .unwrap() .read(cx) .all_source_breakpoints(cx) - .clone() }); assert_eq!(1, breakpoints.len()); @@ -21707,7 +23119,7 @@ async fn test_apply_code_lens_actions_with_commands(cx: &mut gpui::TestAppContex edit: lsp::WorkspaceEdit { changes: Some( [( - lsp::Url::from_file_path(path!("/dir/a.ts")).unwrap(), + lsp::Uri::from_file_path(path!("/dir/a.ts")).unwrap(), vec![lsp::TextEdit { range: lsp::Range::new( lsp::Position::new(0, 0), @@ -22313,10 +23725,7 @@ async fn test_html_linked_edits_on_completion(cx: &mut TestAppContext) { let closing_range = buffer.anchor_before(Point::new(0, 6))..buffer.anchor_after(Point::new(0, 8)); let mut linked_ranges = HashMap::default(); - linked_ranges.insert( - buffer_id, - vec![(opening_range.clone(), vec![closing_range.clone()])], - ); + linked_ranges.insert(buffer_id, vec![(opening_range, vec![closing_range])]); editor.linked_edit_ranges = LinkedEditingRanges(linked_ranges); }); let mut completion_handle = @@ -22481,7 +23890,7 @@ async fn test_invisible_worktree_servers(cx: &mut TestAppContext) { .await .unwrap(); pane.update_in(cx, |pane, window, cx| { - pane.navigate_backward(window, cx); + pane.navigate_backward(&Default::default(), window, cx); }); cx.run_until_parked(); pane.update(cx, 
|pane, cx| { @@ -23426,7 +24835,7 @@ pub fn handle_completion_request( complete_from_position ); Ok(Some(lsp::CompletionResponse::List(lsp::CompletionList { - is_incomplete: is_incomplete, + is_incomplete, item_defaults: None, items: completions .iter() @@ -23682,7 +25091,7 @@ async fn test_pulling_diagnostics(cx: &mut TestAppContext) { let result_id = Some(new_result_id.to_string()); assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/a/first.rs")).unwrap() + lsp::Uri::from_file_path(path!("/a/first.rs")).unwrap() ); async move { Ok(lsp::DocumentDiagnosticReportResult::Report( @@ -23897,7 +25306,7 @@ async fn test_document_colors(cx: &mut TestAppContext) { async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/a/first.rs")).unwrap() + lsp::Uri::from_file_path(path!("/a/first.rs")).unwrap() ); requests_made.fetch_add(1, atomic::Ordering::Release); Ok(vec![ @@ -24068,7 +25477,7 @@ async fn test_document_colors(cx: &mut TestAppContext) { workspace .update(cx, |workspace, window, cx| { workspace.active_pane().update(cx, |pane, cx| { - pane.navigate_backward(window, cx); + pane.navigate_backward(&Default::default(), window, cx); }) }) .unwrap(); @@ -24116,6 +25525,231 @@ async fn test_newline_replacement_in_single_line(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_non_utf_8_opens(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + cx.update(|cx| { + register_project_item::(cx); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/root1", json!({})).await; + fs.insert_file("/root1/one.pdf", vec![0xff, 0xfe, 0xfd]) + .await; + + let project = Project::test(fs, ["/root1".as_ref()], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let worktree_id = project.update(cx, |project, cx| { + project.worktrees(cx).next().unwrap().read(cx).id() + }); + + let handle = workspace + .update_in(cx, |workspace, window, cx| { + 
let project_path = (worktree_id, "one.pdf"); + workspace.open_path(project_path, None, true, window, cx) + }) + .await + .unwrap(); + + assert_eq!( + handle.to_any().entity_type(), + TypeId::of::() + ); +} + +#[gpui::test] +async fn test_select_next_prev_syntax_node(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let language = Arc::new(Language::new( + LanguageConfig::default(), + Some(tree_sitter_rust::LANGUAGE.into()), + )); + + // Test hierarchical sibling navigation + let text = r#" + fn outer() { + if condition { + let a = 1; + } + let b = 2; + } + + fn another() { + let c = 3; + } + "#; + + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language, cx)); + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); + let (editor, cx) = cx.add_window_view(|window, cx| build_editor(buffer, window, cx)); + + // Wait for parsing to complete + editor + .condition::(cx, |editor, cx| !editor.buffer.read(cx).is_parsing(cx)) + .await; + + editor.update_in(cx, |editor, window, cx| { + // Start by selecting "let a = 1;" inside the if block + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_display_ranges([ + DisplayPoint::new(DisplayRow(3), 16)..DisplayPoint::new(DisplayRow(3), 26) + ]); + }); + + let initial_selection = editor.selections.display_ranges(cx); + assert_eq!(initial_selection.len(), 1, "Should have one selection"); + + // Test select next sibling - should move up levels to find the next sibling + // Since "let a = 1;" has no siblings in the if block, it should move up + // to find "let b = 2;" which is a sibling of the if block + editor.select_next_syntax_node(&SelectNextSyntaxNode, window, cx); + let next_selection = editor.selections.display_ranges(cx); + + // Should have a selection and it should be different from the initial + assert_eq!( + next_selection.len(), + 1, + "Should have one selection after next" + ); + assert_ne!( + next_selection[0], initial_selection[0], + "Next sibling selection 
should be different" + ); + + // Test hierarchical navigation by going to the end of the current function + // and trying to navigate to the next function + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_display_ranges([ + DisplayPoint::new(DisplayRow(5), 12)..DisplayPoint::new(DisplayRow(5), 22) + ]); + }); + + editor.select_next_syntax_node(&SelectNextSyntaxNode, window, cx); + let function_next_selection = editor.selections.display_ranges(cx); + + // Should move to the next function + assert_eq!( + function_next_selection.len(), + 1, + "Should have one selection after function next" + ); + + // Test select previous sibling navigation + editor.select_prev_syntax_node(&SelectPreviousSyntaxNode, window, cx); + let prev_selection = editor.selections.display_ranges(cx); + + // Should have a selection and it should be different + assert_eq!( + prev_selection.len(), + 1, + "Should have one selection after prev" + ); + assert_ne!( + prev_selection[0], function_next_selection[0], + "Previous sibling selection should be different from next" + ); + }); +} + +#[gpui::test] +async fn test_next_prev_document_highlight(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + cx.set_state( + "let ˇvariable = 42; +let another = variable + 1; +let result = variable * 2;", + ); + + // Set up document highlights manually (simulating LSP response) + cx.update_editor(|editor, _window, cx| { + let buffer_snapshot = editor.buffer().read(cx).snapshot(cx); + + // Create highlights for "variable" occurrences + let highlight_ranges = [ + Point::new(0, 4)..Point::new(0, 12), // First "variable" + Point::new(1, 14)..Point::new(1, 22), // Second "variable" + Point::new(2, 13)..Point::new(2, 21), // Third "variable" + ]; + + let anchor_ranges: Vec<_> = highlight_ranges + .iter() + .map(|range| range.clone().to_anchors(&buffer_snapshot)) + .collect(); + + editor.highlight_background::( + &anchor_ranges, + 
|theme| theme.colors().editor_document_highlight_read_background, + cx, + ); + }); + + // Go to next highlight - should move to second "variable" + cx.update_editor(|editor, window, cx| { + editor.go_to_next_document_highlight(&GoToNextDocumentHighlight, window, cx); + }); + cx.assert_editor_state( + "let variable = 42; +let another = ˇvariable + 1; +let result = variable * 2;", + ); + + // Go to next highlight - should move to third "variable" + cx.update_editor(|editor, window, cx| { + editor.go_to_next_document_highlight(&GoToNextDocumentHighlight, window, cx); + }); + cx.assert_editor_state( + "let variable = 42; +let another = variable + 1; +let result = ˇvariable * 2;", + ); + + // Go to next highlight - should stay at third "variable" (no wrap-around) + cx.update_editor(|editor, window, cx| { + editor.go_to_next_document_highlight(&GoToNextDocumentHighlight, window, cx); + }); + cx.assert_editor_state( + "let variable = 42; +let another = variable + 1; +let result = ˇvariable * 2;", + ); + + // Now test going backwards from third position + cx.update_editor(|editor, window, cx| { + editor.go_to_prev_document_highlight(&GoToPreviousDocumentHighlight, window, cx); + }); + cx.assert_editor_state( + "let variable = 42; +let another = ˇvariable + 1; +let result = variable * 2;", + ); + + // Go to previous highlight - should move to first "variable" + cx.update_editor(|editor, window, cx| { + editor.go_to_prev_document_highlight(&GoToPreviousDocumentHighlight, window, cx); + }); + cx.assert_editor_state( + "let ˇvariable = 42; +let another = variable + 1; +let result = variable * 2;", + ); + + // Go to previous highlight - should stay on first "variable" + cx.update_editor(|editor, window, cx| { + editor.go_to_prev_document_highlight(&GoToPreviousDocumentHighlight, window, cx); + }); + cx.assert_editor_state( + "let ˇvariable = 42; +let another = variable + 1; +let result = variable * 2;", + ); +} + #[track_caller] fn extract_color_inlays(editor: &Editor, cx: 
&App) -> Vec { editor diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 8a5c65f994fd0c03a59b939a3362f41f0a1bd205..673557a2c653311cc5f1c36b21c9937f146d79c1 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -40,14 +40,15 @@ use git::{ }; use gpui::{ Action, Along, AnyElement, App, AppContext, AvailableSpace, Axis as ScrollbarAxis, BorderStyle, - Bounds, ClickEvent, ContentMask, Context, Corner, Corners, CursorStyle, DispatchPhase, Edges, - Element, ElementInputHandler, Entity, Focusable as _, FontId, GlobalElementId, Hitbox, - HitboxBehavior, Hsla, InteractiveElement, IntoElement, IsZero, Keystroke, Length, - ModifiersChangedEvent, MouseButton, MouseClickEvent, MouseDownEvent, MouseMoveEvent, - MouseUpEvent, PaintQuad, ParentElement, Pixels, ScrollDelta, ScrollHandle, ScrollWheelEvent, - ShapedLine, SharedString, Size, StatefulInteractiveElement, Style, Styled, TextRun, - TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill, linear_color_stop, - linear_gradient, outline, point, px, quad, relative, size, solid_background, transparent_black, + Bounds, ClickEvent, ClipboardItem, ContentMask, Context, Corner, Corners, CursorStyle, + DispatchPhase, Edges, Element, ElementInputHandler, Entity, Focusable as _, FontId, + GlobalElementId, Hitbox, HitboxBehavior, Hsla, InteractiveElement, IntoElement, IsZero, + KeybindingKeystroke, Length, ModifiersChangedEvent, MouseButton, MouseClickEvent, + MouseDownEvent, MouseMoveEvent, MouseUpEvent, PaintQuad, ParentElement, Pixels, ScrollDelta, + ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString, Size, StatefulInteractiveElement, + Style, Styled, TextRun, TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill, + linear_color_stop, linear_gradient, outline, point, px, quad, relative, size, solid_background, + transparent_black, }; use itertools::Itertools; use language::language_settings::{ @@ -60,7 +61,7 @@ use multi_buffer::{ }; use 
project::{ - ProjectPath, + Entry, ProjectPath, debugger::breakpoint_store::{Breakpoint, BreakpointSessionState}, project_settings::{GitGutterSetting, GitHunkStyleSetting, ProjectSettings}, }; @@ -73,6 +74,7 @@ use std::{ fmt::{self, Write}, iter, mem, ops::{Deref, Range}, + path::{self, Path}, rc::Rc, sync::Arc, time::{Duration, Instant}, @@ -80,11 +82,18 @@ use std::{ use sum_tree::Bias; use text::{BufferId, SelectionGoal}; use theme::{ActiveTheme, Appearance, BufferLineHeight, PlayerColor}; -use ui::{ButtonLike, KeyBinding, POPOVER_Y_PADDING, Tooltip, h_flex, prelude::*}; +use ui::utils::ensure_minimum_contrast; +use ui::{ + ButtonLike, ContextMenu, Indicator, KeyBinding, POPOVER_Y_PADDING, Tooltip, h_flex, prelude::*, + right_click_menu, +}; use unicode_segmentation::UnicodeSegmentation; use util::post_inc; use util::{RangeExt, ResultExt, debug_panic}; -use workspace::{CollaboratorId, Workspace, item::Item, notifications::NotifyTaskExt}; +use workspace::{ + CollaboratorId, ItemSettings, OpenInTerminal, OpenTerminal, RevealInProjectPanel, Workspace, + item::Item, notifications::NotifyTaskExt, +}; /// Determines what kinds of highlights should be applied to a lines background. 
#[derive(Clone, Copy, Default)] @@ -108,6 +117,7 @@ struct SelectionLayout { struct InlineBlameLayout { element: AnyElement, bounds: Bounds, + buffer_id: BufferId, entry: BlameEntry, } @@ -355,6 +365,8 @@ impl EditorElement { register_action(editor, window, Editor::toggle_comments); register_action(editor, window, Editor::select_larger_syntax_node); register_action(editor, window, Editor::select_smaller_syntax_node); + register_action(editor, window, Editor::select_next_syntax_node); + register_action(editor, window, Editor::select_prev_syntax_node); register_action(editor, window, Editor::unwrap_syntax_node); register_action(editor, window, Editor::select_enclosing_symbol); register_action(editor, window, Editor::move_to_enclosing_bracket); @@ -369,6 +381,8 @@ impl EditorElement { register_action(editor, window, Editor::go_to_prev_diagnostic); register_action(editor, window, Editor::go_to_next_hunk); register_action(editor, window, Editor::go_to_prev_hunk); + register_action(editor, window, Editor::go_to_next_document_highlight); + register_action(editor, window, Editor::go_to_prev_document_highlight); register_action(editor, window, |editor, action, window, cx| { editor .go_to_definition(action, window, cx) @@ -577,6 +591,9 @@ impl EditorElement { register_action(editor, window, Editor::edit_log_breakpoint); register_action(editor, window, Editor::enable_breakpoint); register_action(editor, window, Editor::disable_breakpoint); + if editor.read(cx).enable_wrap_selections_in_tag(cx) { + register_action(editor, window, Editor::wrap_selections_in_tag); + } } fn register_key_listeners(&self, window: &mut Window, _: &mut App, layout: &EditorLayout) { @@ -717,7 +734,7 @@ impl EditorElement { ColumnarMode::FromMouse => true, ColumnarMode::FromSelection => false, }, - mode: mode, + mode, goal_column: point_for_position.exact_unclipped.column(), }, window, @@ -910,6 +927,11 @@ impl EditorElement { } else if cfg!(any(target_os = "linux", target_os = "freebsd")) && 
event.button == MouseButton::Middle { + #[allow( + clippy::collapsible_if, + clippy::needless_return, + reason = "The cfg-block below makes this a false positive" + )] if !text_hitbox.is_hovered(window) || editor.read_only(cx) { return; } @@ -1115,26 +1137,24 @@ impl EditorElement { let hovered_diff_hunk_row = if let Some(control_row) = hovered_diff_control { Some(control_row) - } else { - if text_hovered { - let current_row = valid_point.row(); - position_map.display_hunks.iter().find_map(|(hunk, _)| { - if let DisplayDiffHunk::Unfolded { - display_row_range, .. - } = hunk - { - if display_row_range.contains(¤t_row) { - Some(display_row_range.start) - } else { - None - } + } else if text_hovered { + let current_row = valid_point.row(); + position_map.display_hunks.iter().find_map(|(hunk, _)| { + if let DisplayDiffHunk::Unfolded { + display_row_range, .. + } = hunk + { + if display_row_range.contains(¤t_row) { + Some(display_row_range.start) } else { None } - }) - } else { - None - } + } else { + None + } + }) + } else { + None }; if hovered_diff_hunk_row != editor.hovered_diff_hunk_row { @@ -1142,20 +1162,20 @@ impl EditorElement { cx.notify(); } - if let Some((bounds, blame_entry)) = &position_map.inline_blame_bounds { + if let Some((bounds, buffer_id, blame_entry)) = &position_map.inline_blame_bounds { let mouse_over_inline_blame = bounds.contains(&event.position); let mouse_over_popover = editor .inline_blame_popover .as_ref() .and_then(|state| state.popover_bounds) - .map_or(false, |bounds| bounds.contains(&event.position)); + .is_some_and(|bounds| bounds.contains(&event.position)); let keyboard_grace = editor .inline_blame_popover .as_ref() - .map_or(false, |state| state.keyboard_grace); + .is_some_and(|state| state.keyboard_grace); if mouse_over_inline_blame || mouse_over_popover { - editor.show_blame_popover(&blame_entry, event.position, false, cx); + editor.show_blame_popover(*buffer_id, blame_entry, event.position, false, cx); } else if !keyboard_grace { 
editor.hide_blame_popover(cx); } @@ -1179,10 +1199,10 @@ impl EditorElement { let is_visible = editor .gutter_breakpoint_indicator .0 - .map_or(false, |indicator| indicator.is_active); + .is_some_and(|indicator| indicator.is_active); let has_existing_breakpoint = - editor.breakpoint_store.as_ref().map_or(false, |store| { + editor.breakpoint_store.as_ref().is_some_and(|store| { let Some(project) = &editor.project else { return false; }; @@ -1380,29 +1400,27 @@ impl EditorElement { ref drop_cursor, ref hide_drop_cursor, } = editor.selection_drag_state + && !hide_drop_cursor + && (drop_cursor + .start + .cmp(&selection.start, &snapshot.buffer_snapshot) + .eq(&Ordering::Less) + || drop_cursor + .end + .cmp(&selection.end, &snapshot.buffer_snapshot) + .eq(&Ordering::Greater)) { - if !hide_drop_cursor - && (drop_cursor - .start - .cmp(&selection.start, &snapshot.buffer_snapshot) - .eq(&Ordering::Less) - || drop_cursor - .end - .cmp(&selection.end, &snapshot.buffer_snapshot) - .eq(&Ordering::Greater)) - { - let drag_cursor_layout = SelectionLayout::new( - drop_cursor.clone(), - false, - CursorShape::Bar, - &snapshot.display_snapshot, - false, - false, - None, - ); - let absent_color = cx.theme().players().absent(); - selections.push((absent_color, vec![drag_cursor_layout])); - } + let drag_cursor_layout = SelectionLayout::new( + drop_cursor.clone(), + false, + CursorShape::Bar, + &snapshot.display_snapshot, + false, + false, + None, + ); + let absent_color = cx.theme().players().absent(); + selections.push((absent_color, vec![drag_cursor_layout])); } } @@ -1413,19 +1431,15 @@ impl EditorElement { CollaboratorId::PeerId(peer_id) => { if let Some(collaborator) = collaboration_hub.collaborators(cx).get(&peer_id) - { - if let Some(participant_index) = collaboration_hub + && let Some(participant_index) = collaboration_hub .user_participant_indices(cx) .get(&collaborator.user_id) - { - if let Some((local_selection_style, _)) = selections.first_mut() - { - *local_selection_style 
= cx - .theme() - .players() - .color_for_participant(participant_index.0); - } - } + && let Some((local_selection_style, _)) = selections.first_mut() + { + *local_selection_style = cx + .theme() + .players() + .color_for_participant(participant_index.0); } } CollaboratorId::Agent => { @@ -2168,11 +2182,13 @@ impl EditorElement { }; let padding = ProjectSettings::get_global(cx).diagnostics.inline.padding as f32 * em_width; - let min_x = ProjectSettings::get_global(cx) - .diagnostics - .inline - .min_column as f32 - * em_width; + let min_x = self.column_pixels( + ProjectSettings::get_global(cx) + .diagnostics + .inline + .min_column as usize, + window, + ); let mut elements = HashMap::default(); for (row, mut diagnostics) in diagnostics_by_rows { @@ -2213,12 +2229,11 @@ impl EditorElement { cmp::max(padded_line, min_start) }; - let behind_edit_prediction_popover = edit_prediction_popover_origin.as_ref().map_or( - false, - |edit_prediction_popover_origin| { + let behind_edit_prediction_popover = edit_prediction_popover_origin + .as_ref() + .is_some_and(|edit_prediction_popover_origin| { (pos_y..pos_y + line_height).contains(&edit_prediction_popover_origin.y) - }, - ); + }); let opacity = if behind_edit_prediction_popover { 0.5 } else { @@ -2284,9 +2299,7 @@ impl EditorElement { None } }) - .map_or(false, |source| { - matches!(source, CodeActionSource::Indicator(..)) - }); + .is_some_and(|source| matches!(source, CodeActionSource::Indicator(..))); Some(editor.render_inline_code_actions(icon_size, display_point.row(), active, cx)) })?; @@ -2434,20 +2447,19 @@ impl EditorElement { .unwrap_or_default() .padding as f32; - if let Some(edit_prediction) = editor.active_edit_prediction.as_ref() { - match &edit_prediction.completion { - EditPrediction::Edit { - display_mode: EditDisplayMode::TabAccept, - .. 
- } => padding += INLINE_ACCEPT_SUGGESTION_EM_WIDTHS, - _ => {} - } + if let Some(edit_prediction) = editor.active_edit_prediction.as_ref() + && let EditPrediction::Edit { + display_mode: EditDisplayMode::TabAccept, + .. + } = &edit_prediction.completion + { + padding += INLINE_ACCEPT_SUGGESTION_EM_WIDTHS } padding * em_width }; - let entry = blame + let (buffer_id, entry) = blame .update(cx, |blame, cx| { blame.blame_for_rows(&[*row_info], cx).next() }) @@ -2482,13 +2494,22 @@ impl EditorElement { let size = element.layout_as_root(AvailableSpace::min_size(), window, cx); let bounds = Bounds::new(absolute_offset, size); - self.layout_blame_entry_popover(entry.clone(), blame, line_height, text_hitbox, window, cx); + self.layout_blame_entry_popover( + entry.clone(), + blame, + line_height, + text_hitbox, + row_info.buffer_id?, + window, + cx, + ); element.prepaint_as_root(absolute_offset, AvailableSpace::min_size(), window, cx); Some(InlineBlameLayout { element, bounds, + buffer_id, entry, }) } @@ -2499,6 +2520,7 @@ impl EditorElement { blame: Entity, line_height: Pixels, text_hitbox: &Hitbox, + buffer: BufferId, window: &mut Window, cx: &mut App, ) { @@ -2523,6 +2545,7 @@ impl EditorElement { popover_state.markdown, workspace, &blame, + buffer, window, cx, ) @@ -2597,14 +2620,16 @@ impl EditorElement { .into_iter() .enumerate() .flat_map(|(ix, blame_entry)| { + let (buffer_id, blame_entry) = blame_entry?; let mut element = render_blame_entry( ix, &blame, - blame_entry?, + blame_entry, &self.style, &mut last_used_color, self.editor.clone(), workspace.clone(), + buffer_id, blame_renderer.clone(), cx, )?; @@ -2747,7 +2772,10 @@ impl EditorElement { let mut block_offset = 0; let mut found_excerpt_header = false; for (_, block) in snapshot.blocks_in_range(prev_line..row_range.start) { - if matches!(block, Block::ExcerptBoundary { .. }) { + if matches!( + block, + Block::ExcerptBoundary { .. } | Block::BufferHeader { .. 
} + ) { found_excerpt_header = true; break; } @@ -2764,7 +2792,10 @@ impl EditorElement { let mut block_height = 0; let mut found_excerpt_header = false; for (_, block) in snapshot.blocks_in_range(row_range.end..cons_line) { - if matches!(block, Block::ExcerptBoundary { .. }) { + if matches!( + block, + Block::ExcerptBoundary { .. } | Block::BufferHeader { .. } + ) { found_excerpt_header = true; } block_height += block.height(); @@ -2811,7 +2842,7 @@ impl EditorElement { } let row = - MultiBufferRow(DisplayPoint::new(display_row, 0).to_point(&snapshot).row); + MultiBufferRow(DisplayPoint::new(display_row, 0).to_point(snapshot).row); if snapshot.is_line_folded(row) { return None; } @@ -2902,7 +2933,7 @@ impl EditorElement { if multibuffer_row .0 .checked_sub(1) - .map_or(false, |previous_row| { + .is_some_and(|previous_row| { snapshot.is_line_folded(MultiBufferRow(previous_row)) }) { @@ -2975,8 +3006,8 @@ impl EditorElement { .ilog10() + 1; - let elements = buffer_rows - .into_iter() + buffer_rows + .iter() .enumerate() .map(|(ix, row_info)| { let ExpandInfo { @@ -3031,9 +3062,7 @@ impl EditorElement { Some((toggle, origin)) }) - .collect(); - - elements + .collect() } fn calculate_relative_line_numbers( @@ -3133,7 +3162,7 @@ impl EditorElement { let relative_rows = self.calculate_relative_line_numbers(snapshot, &rows, relative_to); let mut line_number = String::new(); let line_numbers = buffer_rows - .into_iter() + .iter() .enumerate() .flat_map(|(ix, row_info)| { let display_row = DisplayRow(rows.start.0 + ix as u32); @@ -3210,7 +3239,7 @@ impl EditorElement { && self.editor.read(cx).is_singleton(cx); if include_fold_statuses { row_infos - .into_iter() + .iter() .enumerate() .map(|(ix, info)| { if info.expand_info.is_some() { @@ -3250,12 +3279,165 @@ impl EditorElement { .collect() } + fn bg_segments_per_row( + rows: Range, + selections: &[(PlayerColor, Vec)], + highlight_ranges: &[(Range, Hsla)], + base_background: Hsla, + ) -> Vec, Hsla)>> { + if rows.start >= 
rows.end { + return Vec::new(); + } + if !base_background.is_opaque() { + // We don't actually know what color is behind this editor. + return Vec::new(); + } + let highlight_iter = highlight_ranges.iter().cloned(); + let selection_iter = selections.iter().flat_map(|(player_color, layouts)| { + let color = player_color.selection; + layouts.iter().filter_map(move |selection_layout| { + if selection_layout.range.start != selection_layout.range.end { + Some((selection_layout.range.clone(), color)) + } else { + None + } + }) + }); + let mut per_row_map = vec![Vec::new(); rows.len()]; + for (range, color) in highlight_iter.chain(selection_iter) { + let covered_rows = if range.end.column() == 0 { + cmp::max(range.start.row(), rows.start)..cmp::min(range.end.row(), rows.end) + } else { + cmp::max(range.start.row(), rows.start) + ..cmp::min(range.end.row().next_row(), rows.end) + }; + for row in covered_rows.iter_rows() { + let seg_start = if row == range.start.row() { + range.start + } else { + DisplayPoint::new(row, 0) + }; + let seg_end = if row == range.end.row() && range.end.column() != 0 { + range.end + } else { + DisplayPoint::new(row, u32::MAX) + }; + let ix = row.minus(rows.start) as usize; + debug_assert!(row >= rows.start && row < rows.end); + debug_assert!(ix < per_row_map.len()); + per_row_map[ix].push((seg_start..seg_end, color)); + } + } + for row_segments in per_row_map.iter_mut() { + if row_segments.is_empty() { + continue; + } + let segments = mem::take(row_segments); + let merged = Self::merge_overlapping_ranges(segments, base_background); + *row_segments = merged; + } + per_row_map + } + + /// Merge overlapping ranges by splitting at all range boundaries and blending colors where + /// multiple ranges overlap. The result contains non-overlapping ranges ordered from left to right. + /// + /// Expects `start.row() == end.row()` for each range. 
+ fn merge_overlapping_ranges( + ranges: Vec<(Range, Hsla)>, + base_background: Hsla, + ) -> Vec<(Range, Hsla)> { + struct Boundary { + pos: DisplayPoint, + is_start: bool, + index: usize, + color: Hsla, + } + + let mut boundaries: SmallVec<[Boundary; 16]> = SmallVec::with_capacity(ranges.len() * 2); + for (index, (range, color)) in ranges.iter().enumerate() { + debug_assert!( + range.start.row() == range.end.row(), + "expects single-row ranges" + ); + if range.start < range.end { + boundaries.push(Boundary { + pos: range.start, + is_start: true, + index, + color: *color, + }); + boundaries.push(Boundary { + pos: range.end, + is_start: false, + index, + color: *color, + }); + } + } + + if boundaries.is_empty() { + return Vec::new(); + } + + boundaries + .sort_unstable_by(|a, b| a.pos.cmp(&b.pos).then_with(|| a.is_start.cmp(&b.is_start))); + + let mut processed_ranges: Vec<(Range, Hsla)> = Vec::new(); + let mut active_ranges: SmallVec<[(usize, Hsla); 8]> = SmallVec::new(); + + let mut i = 0; + let mut start_pos = boundaries[0].pos; + + let boundaries_len = boundaries.len(); + while i < boundaries_len { + let current_boundary_pos = boundaries[i].pos; + if start_pos < current_boundary_pos { + if !active_ranges.is_empty() { + let mut color = base_background; + for &(_, c) in &active_ranges { + color = Hsla::blend(color, c); + } + if let Some((last_range, last_color)) = processed_ranges.last_mut() { + if *last_color == color && last_range.end == start_pos { + last_range.end = current_boundary_pos; + } else { + processed_ranges.push((start_pos..current_boundary_pos, color)); + } + } else { + processed_ranges.push((start_pos..current_boundary_pos, color)); + } + } + } + while i < boundaries_len && boundaries[i].pos == current_boundary_pos { + let active_range = &boundaries[i]; + if active_range.is_start { + let idx = active_range.index; + let pos = active_ranges + .binary_search_by_key(&idx, |(i, _)| *i) + .unwrap_or_else(|p| p); + active_ranges.insert(pos, (idx, 
active_range.color)); + } else { + let idx = active_range.index; + if let Ok(pos) = active_ranges.binary_search_by_key(&idx, |(i, _)| *i) { + active_ranges.remove(pos); + } + } + i += 1; + } + start_pos = current_boundary_pos; + } + + processed_ranges + } + fn layout_lines( rows: Range, snapshot: &EditorSnapshot, style: &EditorStyle, editor_width: Pixels, is_row_soft_wrapped: impl Copy + Fn(usize) -> bool, + bg_segments_per_row: &[Vec<(Range, Hsla)>], window: &mut Window, cx: &mut App, ) -> Vec { @@ -3271,12 +3453,15 @@ impl EditorElement { let placeholder_lines = placeholder_text .as_ref() - .map_or("", AsRef::as_ref) - .split('\n') + .map_or(Vec::new(), |text| text.split('\n').collect::>()); + + let placeholder_line_count = placeholder_lines.len(); + + placeholder_lines + .into_iter() .skip(rows.start.0 as usize) .chain(iter::repeat("")) - .take(rows.len()); - placeholder_lines + .take(cmp::max(rows.len(), placeholder_line_count)) .map(move |line| { let run = TextRun { len: line.len(), @@ -3305,12 +3490,13 @@ impl EditorElement { let chunks = snapshot.highlighted_chunks(rows.clone(), true, style); LineWithInvisibles::from_chunks( chunks, - &style, + style, MAX_LINE_LEN, rows.len(), &snapshot.mode, editor_width, is_row_soft_wrapped, + bg_segments_per_row, window, cx, ) @@ -3386,7 +3572,7 @@ impl EditorElement { let line_ix = align_to.row().0.checked_sub(rows.start.0); x_position = if let Some(layout) = line_ix.and_then(|ix| line_layouts.get(ix as usize)) { - x_and_width(&layout) + x_and_width(layout) } else { x_and_width(&layout_line( align_to.row(), @@ -3452,42 +3638,41 @@ impl EditorElement { .into_any_element() } - Block::ExcerptBoundary { - excerpt, - height, - starts_new_buffer, - .. - } => { + Block::ExcerptBoundary { .. } => { let color = cx.theme().colors().clone(); let mut result = v_flex().id(block_id).w_full(); + result = result.child( + h_flex().relative().child( + div() + .top(line_height / 2.) 
+ .absolute() + .w_full() + .h_px() + .bg(color.border_variant), + ), + ); + + result.into_any() + } + + Block::BufferHeader { excerpt, height } => { + let mut result = v_flex().id(block_id).w_full(); + let jump_data = header_jump_data(snapshot, block_row_start, *height, excerpt); - if *starts_new_buffer { - if sticky_header_excerpt_id != Some(excerpt.id) { - let selected = selected_buffer_ids.contains(&excerpt.buffer_id); + if sticky_header_excerpt_id != Some(excerpt.id) { + let selected = selected_buffer_ids.contains(&excerpt.buffer_id); - result = result.child(div().pr(editor_margins.right).child( - self.render_buffer_header( - excerpt, false, selected, false, jump_data, window, cx, - ), - )); - } else { - result = - result.child(div().h(FILE_HEADER_HEIGHT as f32 * window.line_height())); - } - } else { - result = result.child( - h_flex().relative().child( - div() - .top(line_height / 2.) - .absolute() - .w_full() - .h_px() - .bg(color.border_variant), + result = result.child(div().pr(editor_margins.right).child( + self.render_buffer_header( + excerpt, false, selected, false, jump_data, window, cx, ), - ); - }; + )); + } else { + result = + result.child(div().h(FILE_HEADER_HEIGHT as f32 * window.line_height())); + } result.into_any() } @@ -3511,33 +3696,33 @@ impl EditorElement { let mut x_offset = px(0.); let mut is_block = true; - if let BlockId::Custom(custom_block_id) = block_id { - if block.has_height() { - if block.place_near() { - if let Some((x_target, line_width)) = x_position { - let margin = em_width * 2; - if line_width + final_size.width + margin - < editor_width + editor_margins.gutter.full_width() - && !row_block_types.contains_key(&(row - 1)) - && element_height_in_lines == 1 - { - x_offset = line_width + margin; - row = row - 1; - is_block = false; - element_height_in_lines = 0; - row_block_types.insert(row, is_block); - } else { - let max_offset = editor_width + editor_margins.gutter.full_width() - - final_size.width; - let min_offset = 
(x_target + em_width - final_size.width) - .max(editor_margins.gutter.full_width()); - x_offset = x_target.min(max_offset).max(min_offset); - } - } - }; - if element_height_in_lines != block.height() { - resized_blocks.insert(custom_block_id, element_height_in_lines); + if let BlockId::Custom(custom_block_id) = block_id + && block.has_height() + { + if block.place_near() + && let Some((x_target, line_width)) = x_position + { + let margin = em_width * 2; + if line_width + final_size.width + margin + < editor_width + editor_margins.gutter.full_width() + && !row_block_types.contains_key(&(row - 1)) + && element_height_in_lines == 1 + { + x_offset = line_width + margin; + row = row - 1; + is_block = false; + element_height_in_lines = 0; + row_block_types.insert(row, is_block); + } else { + let max_offset = + editor_width + editor_margins.gutter.full_width() - final_size.width; + let min_offset = (x_target + em_width - final_size.width) + .max(editor_margins.gutter.full_width()); + x_offset = x_target.min(max_offset).max(min_offset); } + }; + if element_height_in_lines != block.height() { + resized_blocks.insert(custom_block_id, element_height_in_lines); } } for i in 0..element_height_in_lines { @@ -3556,11 +3741,10 @@ impl EditorElement { jump_data: JumpData, window: &mut Window, cx: &mut App, - ) -> Div { + ) -> impl IntoElement { let editor = self.editor.read(cx); - let file_status = editor - .buffer - .read(cx) + let multi_buffer = editor.buffer.read(cx); + let file_status = multi_buffer .all_diff_hunks_expanded() .then(|| { editor @@ -3570,6 +3754,17 @@ impl EditorElement { .status_for_buffer_id(for_excerpt.buffer_id, cx) }) .flatten(); + let indicator = multi_buffer + .buffer(for_excerpt.buffer_id) + .and_then(|buffer| { + let buffer = buffer.read(cx); + let indicator_color = match (buffer.has_conflict(), buffer.is_dirty()) { + (true, _) => Some(Color::Warning), + (_, true) => Some(Color::Accent), + (false, false) => None, + }; + 
indicator_color.map(|indicator_color| Indicator::dot().color(indicator_color)) + }); let include_root = editor .project @@ -3577,17 +3772,17 @@ impl EditorElement { .map(|project| project.read(cx).visible_worktrees(cx).count() > 1) .unwrap_or_default(); let can_open_excerpts = Editor::can_open_excerpts_in_file(for_excerpt.buffer.file()); - let path = for_excerpt.buffer.resolve_file_path(cx, include_root); - let filename = path + let relative_path = for_excerpt.buffer.resolve_file_path(cx, include_root); + let filename = relative_path .as_ref() .and_then(|path| Some(path.file_name()?.to_string_lossy().to_string())); - let parent_path = path.as_ref().and_then(|path| { + let parent_path = relative_path.as_ref().and_then(|path| { Some(path.parent()?.to_string_lossy().to_string() + std::path::MAIN_SEPARATOR_STR) }); let focus_handle = editor.focus_handle(cx); let colors = cx.theme().colors(); - div() + let header = div() .p_1() .w_full() .h(FILE_HEADER_HEIGHT as f32 * window.line_height()) @@ -3677,6 +3872,12 @@ impl EditorElement { }) .take(1), ) + .child( + h_flex() + .size(Pixels(12.0)) + .justify_center() + .children(indicator), + ) .child( h_flex() .cursor_pointer() @@ -3687,29 +3888,38 @@ impl EditorElement { .child( h_flex() .gap_2() - .child( - Label::new( - filename - .map(SharedString::from) - .unwrap_or_else(|| "untitled".into()), - ) - .single_line() - .when_some( - file_status, - |el, status| { - el.color(if status.is_conflicted() { - Color::Conflict - } else if status.is_modified() { - Color::Modified - } else if status.is_deleted() { - Color::Disabled - } else { - Color::Created - }) - .when(status.is_deleted(), |el| el.strikethrough()) - }, - ), - ) + .map(|path_header| { + let filename = filename + .map(SharedString::from) + .unwrap_or_else(|| "untitled".into()); + + path_header + .when(ItemSettings::get_global(cx).file_icons, |el| { + let path = path::Path::new(filename.as_str()); + let icon = FileIcons::get_icon(path, cx) + .unwrap_or_default(); + let 
icon = + Icon::from_path(icon).color(Color::Muted); + el.child(icon) + }) + .child(Label::new(filename).single_line().when_some( + file_status, + |el, status| { + el.color(if status.is_conflicted() { + Color::Conflict + } else if status.is_modified() { + Color::Modified + } else if status.is_deleted() { + Color::Disabled + } else { + Color::Created + }) + .when(status.is_deleted(), |el| { + el.strikethrough() + }) + }, + )) + }) .when_some(parent_path, |then, path| { then.child(div().child(path).text_color( if file_status.is_some_and(FileStatus::is_deleted) { @@ -3720,23 +3930,26 @@ impl EditorElement { )) }), ) - .when(can_open_excerpts && is_selected && path.is_some(), |el| { - el.child( - h_flex() - .id("jump-to-file-button") - .gap_2p5() - .child(Label::new("Jump To File")) - .children( - KeyBinding::for_action_in( - &OpenExcerpts, - &focus_handle, - window, - cx, - ) - .map(|binding| binding.into_any_element()), - ), - ) - }) + .when( + can_open_excerpts && is_selected && relative_path.is_some(), + |el| { + el.child( + h_flex() + .id("jump-to-file-button") + .gap_2p5() + .child(Label::new("Jump To File")) + .children( + KeyBinding::for_action_in( + &OpenExcerpts, + &focus_handle, + window, + cx, + ) + .map(|binding| binding.into_any_element()), + ), + ) + }, + ) .on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation()) .on_click(window.listener_for(&self.editor, { move |editor, e: &ClickEvent, window, cx| { @@ -3749,7 +3962,107 @@ impl EditorElement { } })), ), - ) + ); + + let file = for_excerpt.buffer.file().cloned(); + let editor = self.editor.clone(); + right_click_menu("buffer-header-context-menu") + .trigger(move |_, _, _| header) + .menu(move |window, cx| { + let menu_context = focus_handle.clone(); + let editor = editor.clone(); + let file = file.clone(); + ContextMenu::build(window, cx, move |mut menu, window, cx| { + if let Some(file) = file + && let Some(project) = editor.read(cx).project() + && let Some(worktree) = + 
project.read(cx).worktree_for_id(file.worktree_id(cx), cx) + { + let worktree = worktree.read(cx); + let relative_path = file.path(); + let entry_for_path = worktree.entry_for_path(relative_path); + let abs_path = entry_for_path.map(|e| { + e.canonical_path.as_deref().map_or_else( + || worktree.abs_path().join(relative_path), + Path::to_path_buf, + ) + }); + let has_relative_path = worktree.root_entry().is_some_and(Entry::is_dir); + + let parent_abs_path = abs_path + .as_ref() + .and_then(|abs_path| Some(abs_path.parent()?.to_path_buf())); + let relative_path = has_relative_path + .then_some(relative_path) + .map(ToOwned::to_owned); + + let visible_in_project_panel = + relative_path.is_some() && worktree.is_visible(); + let reveal_in_project_panel = entry_for_path + .filter(|_| visible_in_project_panel) + .map(|entry| entry.id); + menu = menu + .when_some(abs_path, |menu, abs_path| { + menu.entry( + "Copy Path", + Some(Box::new(zed_actions::workspace::CopyPath)), + window.handler_for(&editor, move |_, _, cx| { + cx.write_to_clipboard(ClipboardItem::new_string( + abs_path.to_string_lossy().to_string(), + )); + }), + ) + }) + .when_some(relative_path, |menu, relative_path| { + menu.entry( + "Copy Relative Path", + Some(Box::new(zed_actions::workspace::CopyRelativePath)), + window.handler_for(&editor, move |_, _, cx| { + cx.write_to_clipboard(ClipboardItem::new_string( + relative_path.to_string_lossy().to_string(), + )); + }), + ) + }) + .when( + reveal_in_project_panel.is_some() || parent_abs_path.is_some(), + |menu| menu.separator(), + ) + .when_some(reveal_in_project_panel, |menu, entry_id| { + menu.entry( + "Reveal In Project Panel", + Some(Box::new(RevealInProjectPanel::default())), + window.handler_for(&editor, move |editor, _, cx| { + if let Some(project) = &mut editor.project { + project.update(cx, |_, cx| { + cx.emit(project::Event::RevealInProjectPanel( + entry_id, + )) + }); + } + }), + ) + }) + .when_some(parent_abs_path, |menu, parent_abs_path| { + 
menu.entry( + "Open in Terminal", + Some(Box::new(OpenInTerminal)), + window.handler_for(&editor, move |_, window, cx| { + window.dispatch_action( + OpenTerminal { + working_directory: parent_abs_path.clone(), + } + .boxed_clone(), + cx, + ); + }), + ) + }); + } + + menu.context(menu_context) + }) + }) } fn render_blocks( @@ -3787,7 +4100,7 @@ impl EditorElement { for (row, block) in fixed_blocks { let block_id = block.id(); - if focused_block.as_ref().map_or(false, |b| b.id == block_id) { + if focused_block.as_ref().is_some_and(|b| b.id == block_id) { focused_block = None; } @@ -3844,7 +4157,7 @@ impl EditorElement { }; let block_id = block.id(); - if focused_block.as_ref().map_or(false, |b| b.id == block_id) { + if focused_block.as_ref().is_some_and(|b| b.id == block_id) { focused_block = None; } @@ -3885,60 +4198,58 @@ impl EditorElement { } } - if let Some(focused_block) = focused_block { - if let Some(focus_handle) = focused_block.focus_handle.upgrade() { - if focus_handle.is_focused(window) { - if let Some(block) = snapshot.block_for_id(focused_block.id) { - let style = block.style(); - let width = match style { - BlockStyle::Fixed => AvailableSpace::MinContent, - BlockStyle::Flex => AvailableSpace::Definite( - hitbox - .size - .width - .max(fixed_block_max_width) - .max(editor_margins.gutter.width + *scroll_width), - ), - BlockStyle::Sticky => AvailableSpace::Definite(hitbox.size.width), - }; + if let Some(focused_block) = focused_block + && let Some(focus_handle) = focused_block.focus_handle.upgrade() + && focus_handle.is_focused(window) + && let Some(block) = snapshot.block_for_id(focused_block.id) + { + let style = block.style(); + let width = match style { + BlockStyle::Fixed => AvailableSpace::MinContent, + BlockStyle::Flex => AvailableSpace::Definite( + hitbox + .size + .width + .max(fixed_block_max_width) + .max(editor_margins.gutter.width + *scroll_width), + ), + BlockStyle::Sticky => AvailableSpace::Definite(hitbox.size.width), + }; - if let 
Some((element, element_size, _, x_offset)) = self.render_block( - &block, - width, - focused_block.id, - rows.end, - snapshot, - text_x, - &rows, - line_layouts, - editor_margins, - line_height, - em_width, - text_hitbox, - editor_width, - scroll_width, - &mut resized_blocks, - &mut row_block_types, - selections, - selected_buffer_ids, - is_row_soft_wrapped, - sticky_header_excerpt_id, - window, - cx, - ) { - blocks.push(BlockLayout { - id: block.id(), - x_offset, - row: None, - element, - available_space: size(width, element_size.height.into()), - style, - overlaps_gutter: true, - is_buffer_header: block.is_buffer_header(), - }); - } - } - } + if let Some((element, element_size, _, x_offset)) = self.render_block( + &block, + width, + focused_block.id, + rows.end, + snapshot, + text_x, + &rows, + line_layouts, + editor_margins, + line_height, + em_width, + text_hitbox, + editor_width, + scroll_width, + &mut resized_blocks, + &mut row_block_types, + selections, + selected_buffer_ids, + is_row_soft_wrapped, + sticky_header_excerpt_id, + window, + cx, + ) { + blocks.push(BlockLayout { + id: block.id(), + x_offset, + row: None, + element, + available_space: size(width, element_size.height.into()), + style, + overlaps_gutter: true, + is_buffer_header: block.is_buffer_header(), + }); } } @@ -4101,19 +4412,19 @@ impl EditorElement { edit_prediction_popover_visible = true; } - if editor.context_menu_visible() { - if let Some(crate::ContextMenuOrigin::Cursor) = editor.context_menu_origin() { - let (min_height_in_lines, max_height_in_lines) = editor - .context_menu_options - .as_ref() - .map_or((3, 12), |options| { - (options.min_entries_visible, options.max_entries_visible) - }); + if editor.context_menu_visible() + && let Some(crate::ContextMenuOrigin::Cursor) = editor.context_menu_origin() + { + let (min_height_in_lines, max_height_in_lines) = editor + .context_menu_options + .as_ref() + .map_or((3, 12), |options| { + (options.min_entries_visible, 
options.max_entries_visible) + }); - min_menu_height += line_height * min_height_in_lines as f32 + POPOVER_Y_PADDING; - max_menu_height += line_height * max_height_in_lines as f32 + POPOVER_Y_PADDING; - context_menu_visible = true; - } + min_menu_height += line_height * min_height_in_lines as f32 + POPOVER_Y_PADDING; + max_menu_height += line_height * max_height_in_lines as f32 + POPOVER_Y_PADDING; + context_menu_visible = true; } context_menu_placement = editor .context_menu_options @@ -4625,7 +4936,7 @@ impl EditorElement { } }; - let source_included = source_display_point.map_or(true, |source_display_point| { + let source_included = source_display_point.is_none_or(|source_display_point| { visible_range .to_inclusive() .contains(&source_display_point.row()) @@ -4805,7 +5116,7 @@ impl EditorElement { let intersects_menu = |bounds: Bounds| -> bool { context_menu_layout .as_ref() - .map_or(false, |menu| bounds.intersects(&menu.bounds)) + .is_some_and(|menu| bounds.intersects(&menu.bounds)) }; let can_place_above = { @@ -4990,7 +5301,7 @@ impl EditorElement { if active_positions .iter() - .any(|p| p.map_or(false, |p| display_row_range.contains(&p.row()))) + .any(|p| p.is_some_and(|p| display_row_range.contains(&p.row()))) { let y = display_row_range.start.as_f32() * line_height + text_hitbox.bounds.top() @@ -5103,7 +5414,7 @@ impl EditorElement { let intersects_menu = |bounds: Bounds| -> bool { context_menu_layout .as_ref() - .map_or(false, |menu| bounds.intersects(&menu.bounds)) + .is_some_and(|menu| bounds.intersects(&menu.bounds)) }; let final_origin = if popover_bounds_above.is_contained_within(hitbox) @@ -5188,7 +5499,7 @@ impl EditorElement { let mut end_row = start_row.0; while active_rows .peek() - .map_or(false, |(active_row, has_selection)| { + .is_some_and(|(active_row, has_selection)| { active_row.0 == end_row + 1 && has_selection.selection == contains_non_empty_selection.selection }) @@ -5447,9 +5758,9 @@ impl EditorElement { // In singleton buffers, we 
select corresponding lines on the line number click, so use | -like cursor. // In multi buffers, we open file at the line number clicked, so use a pointing hand cursor. if is_singleton { - window.set_cursor_style(CursorStyle::IBeam, &hitbox); + window.set_cursor_style(CursorStyle::IBeam, hitbox); } else { - window.set_cursor_style(CursorStyle::PointingHand, &hitbox); + window.set_cursor_style(CursorStyle::PointingHand, hitbox); } } } @@ -5468,7 +5779,7 @@ impl EditorElement { &layout.position_map.snapshot, line_height, layout.gutter_hitbox.bounds, - &hunk, + hunk, ); Some(( hunk_bounds, @@ -5604,7 +5915,10 @@ impl EditorElement { let end_row_in_current_excerpt = snapshot .blocks_in_range(start_row..end_row) .find_map(|(start_row, block)| { - if matches!(block, Block::ExcerptBoundary { .. }) { + if matches!( + block, + Block::ExcerptBoundary { .. } | Block::BufferHeader { .. } + ) { Some(start_row) } else { None @@ -5659,16 +5973,15 @@ impl EditorElement { cx: &mut App, ) { for (_, hunk_hitbox) in &layout.display_hunks { - if let Some(hunk_hitbox) = hunk_hitbox { - if !self + if let Some(hunk_hitbox) = hunk_hitbox + && !self .editor .read(cx) .buffer() .read(cx) .all_diff_hunks_expanded() - { - window.set_cursor_style(CursorStyle::PointingHand, hunk_hitbox); - } + { + window.set_cursor_style(CursorStyle::PointingHand, hunk_hitbox); } } @@ -5775,7 +6088,7 @@ impl EditorElement { }; self.paint_lines_background(layout, window, cx); - let invisible_display_ranges = self.paint_highlights(layout, window); + let invisible_display_ranges = self.paint_highlights(layout, window, cx); self.paint_document_colors(layout, window); self.paint_lines(&invisible_display_ranges, layout, window, cx); self.paint_redactions(layout, window); @@ -5797,6 +6110,7 @@ impl EditorElement { &mut self, layout: &mut EditorLayout, window: &mut Window, + cx: &mut App, ) -> SmallVec<[Range; 32]> { window.paint_layer(layout.position_map.text_hitbox.bounds, |window| { let mut invisible_display_ranges = 
SmallVec::<[Range; 32]>::new(); @@ -5813,7 +6127,11 @@ impl EditorElement { ); } - let corner_radius = 0.15 * layout.position_map.line_height; + let corner_radius = if EditorSettings::get_global(cx).rounded_selection { + 0.15 * layout.position_map.line_height + } else { + Pixels::ZERO + }; for (player_color, selections) in &layout.selections { for selection in selections.iter() { @@ -5990,10 +6308,10 @@ impl EditorElement { if axis == ScrollbarAxis::Vertical { let fast_markers = - self.collect_fast_scrollbar_markers(layout, &scrollbar_layout, cx); + self.collect_fast_scrollbar_markers(layout, scrollbar_layout, cx); // Refresh slow scrollbar markers in the background. Below, we // paint whatever markers have already been computed. - self.refresh_slow_scrollbar_markers(layout, &scrollbar_layout, window, cx); + self.refresh_slow_scrollbar_markers(layout, scrollbar_layout, window, cx); let markers = self.editor.read(cx).scrollbar_marker_state.markers.clone(); for marker in markers.iter().chain(&fast_markers) { @@ -6027,7 +6345,7 @@ impl EditorElement { if any_scrollbar_dragged { window.set_window_cursor_style(CursorStyle::Arrow); } else { - window.set_cursor_style(CursorStyle::Arrow, &hitbox); + window.set_cursor_style(CursorStyle::Arrow, hitbox); } } }) @@ -6577,25 +6895,23 @@ impl EditorElement { editor.set_scroll_position(position, window, cx); } cx.stop_propagation(); - } else { - if minimap_hitbox.is_hovered(window) { - editor.scroll_manager.set_is_hovering_minimap_thumb( - !event.dragging() - && layout - .thumb_layout - .thumb_bounds - .is_some_and(|bounds| bounds.contains(&event.position)), - cx, - ); + } else if minimap_hitbox.is_hovered(window) { + editor.scroll_manager.set_is_hovering_minimap_thumb( + !event.dragging() + && layout + .thumb_layout + .thumb_bounds + .is_some_and(|bounds| bounds.contains(&event.position)), + cx, + ); - // Stop hover events from propagating to the - // underlying editor if the minimap hitbox is hovered - if !event.dragging() { - 
cx.stop_propagation(); - } - } else { - editor.scroll_manager.hide_minimap_thumb(cx); + // Stop hover events from propagating to the + // underlying editor if the minimap hitbox is hovered + if !event.dragging() { + cx.stop_propagation(); } + } else { + editor.scroll_manager.hide_minimap_thumb(cx); } mouse_position = event.position; }); @@ -6974,9 +7290,7 @@ impl EditorElement { let unstaged_hollow = ProjectSettings::get_global(cx) .git .hunk_style - .map_or(false, |style| { - matches!(style, GitHunkStyleSetting::UnstagedHollow) - }); + .is_some_and(|style| matches!(style, GitHunkStyleSetting::UnstagedHollow)); unstaged == unstaged_hollow } @@ -7020,7 +7334,7 @@ fn header_jump_data( pub struct AcceptEditPredictionBinding(pub(crate) Option); impl AcceptEditPredictionBinding { - pub fn keystroke(&self) -> Option<&Keystroke> { + pub fn keystroke(&self) -> Option<&KeybindingKeystroke> { if let Some(binding) = self.0.as_ref() { match &binding.keystrokes() { [keystroke, ..] => Some(keystroke), @@ -7105,12 +7419,13 @@ fn render_blame_entry_popover( markdown: Entity, workspace: WeakEntity, blame: &Entity, + buffer: BufferId, window: &mut Window, cx: &mut App, ) -> Option { let renderer = cx.global::().0.clone(); let blame = blame.read(cx); - let repository = blame.repository(cx)?.clone(); + let repository = blame.repository(cx, buffer)?; renderer.render_blame_entry_popover( blame_entry, scroll_handle, @@ -7131,6 +7446,7 @@ fn render_blame_entry( last_used_color: &mut Option<(PlayerColor, Oid)>, editor: Entity, workspace: Entity, + buffer: BufferId, renderer: Arc, cx: &mut App, ) -> Option { @@ -7151,8 +7467,8 @@ fn render_blame_entry( last_used_color.replace((sha_color, blame_entry.sha)); let blame = blame.read(cx); - let details = blame.details_for_entry(&blame_entry); - let repository = blame.repository(cx)?; + let details = blame.details_for_entry(buffer, &blame_entry); + let repository = blame.repository(cx, buffer)?; renderer.render_blame_entry( &style.text, 
blame_entry, @@ -7207,6 +7523,7 @@ impl LineWithInvisibles { editor_mode: &EditorMode, text_width: Pixels, is_row_soft_wrapped: impl Copy + Fn(usize) -> bool, + bg_segments_per_row: &[Vec<(Range, Hsla)>], window: &mut Window, cx: &mut App, ) -> Vec { @@ -7222,6 +7539,7 @@ impl LineWithInvisibles { let mut row = 0; let mut line_exceeded_max_len = false; let font_size = text_style.font_size.to_pixels(window.rem_size()); + let min_contrast = EditorSettings::get_global(cx).minimum_contrast_for_highlights; let ellipsis = SharedString::from("⋯"); @@ -7234,10 +7552,16 @@ impl LineWithInvisibles { }]) { if let Some(replacement) = highlighted_chunk.replacement { if !line.is_empty() { + let segments = bg_segments_per_row.get(row).map(|v| &v[..]).unwrap_or(&[]); + let text_runs: &[TextRun] = if segments.is_empty() { + &styles + } else { + &Self::split_runs_by_bg_segments(&styles, segments, min_contrast) + }; let shaped_line = window.text_system().shape_line( line.clone().into(), font_size, - &styles, + text_runs, None, ); width += shaped_line.width; @@ -7315,10 +7639,16 @@ impl LineWithInvisibles { } else { for (ix, mut line_chunk) in highlighted_chunk.text.split('\n').enumerate() { if ix > 0 { + let segments = bg_segments_per_row.get(row).map(|v| &v[..]).unwrap_or(&[]); + let text_runs = if segments.is_empty() { + &styles + } else { + &Self::split_runs_by_bg_segments(&styles, segments, min_contrast) + }; let shaped_line = window.text_system().shape_line( line.clone().into(), font_size, - &styles, + text_runs, None, ); width += shaped_line.width; @@ -7406,6 +7736,81 @@ impl LineWithInvisibles { layouts } + /// Takes text runs and non-overlapping left-to-right background ranges with color. + /// Returns new text runs with adjusted contrast as per background ranges. 
+ fn split_runs_by_bg_segments( + text_runs: &[TextRun], + bg_segments: &[(Range, Hsla)], + min_contrast: f32, + ) -> Vec { + let mut output_runs: Vec = Vec::with_capacity(text_runs.len()); + let mut line_col = 0usize; + let mut segment_ix = 0usize; + + for text_run in text_runs.iter() { + let run_start_col = line_col; + let run_end_col = run_start_col + text_run.len; + while segment_ix < bg_segments.len() + && (bg_segments[segment_ix].0.end.column() as usize) <= run_start_col + { + segment_ix += 1; + } + let mut cursor_col = run_start_col; + let mut local_segment_ix = segment_ix; + while local_segment_ix < bg_segments.len() { + let (range, segment_color) = &bg_segments[local_segment_ix]; + let segment_start_col = range.start.column() as usize; + let segment_end_col = range.end.column() as usize; + if segment_start_col >= run_end_col { + break; + } + if segment_start_col > cursor_col { + let span_len = segment_start_col - cursor_col; + output_runs.push(TextRun { + len: span_len, + font: text_run.font.clone(), + color: text_run.color, + background_color: text_run.background_color, + underline: text_run.underline, + strikethrough: text_run.strikethrough, + }); + cursor_col = segment_start_col; + } + let segment_slice_end_col = segment_end_col.min(run_end_col); + if segment_slice_end_col > cursor_col { + let new_text_color = + ensure_minimum_contrast(text_run.color, *segment_color, min_contrast); + output_runs.push(TextRun { + len: segment_slice_end_col - cursor_col, + font: text_run.font.clone(), + color: new_text_color, + background_color: text_run.background_color, + underline: text_run.underline, + strikethrough: text_run.strikethrough, + }); + cursor_col = segment_slice_end_col; + } + if segment_end_col >= run_end_col { + break; + } + local_segment_ix += 1; + } + if cursor_col < run_end_col { + output_runs.push(TextRun { + len: run_end_col - cursor_col, + font: text_run.font.clone(), + color: text_run.color, + background_color: text_run.background_color, + 
underline: text_run.underline, + strikethrough: text_run.strikethrough, + }); + } + line_col = run_end_col; + segment_ix = local_segment_ix; + } + output_runs + } + fn prepaint( &mut self, line_height: Pixels, @@ -7815,7 +8220,7 @@ impl Element for EditorElement { min_lines, max_lines, } => { - let editor_handle = cx.entity().clone(); + let editor_handle = cx.entity(); let max_line_number_width = self.max_line_number_width(&editor.snapshot(window, cx), window); window.request_measured_layout( @@ -7898,7 +8303,7 @@ impl Element for EditorElement { let (mut snapshot, is_read_only) = self.editor.update(cx, |editor, cx| { (editor.snapshot(window, cx), editor.read_only(cx)) }); - let style = self.style.clone(); + let style = &self.style; let rem_size = window.rem_size(); let font_id = window.text_system().resolve_font(&style.text.font()); @@ -8006,7 +8411,7 @@ impl Element for EditorElement { // The max scroll position for the top of the window let max_scroll_top = if matches!( snapshot.mode, - EditorMode::SingleLine { .. } + EditorMode::SingleLine | EditorMode::AutoHeight { .. 
} | EditorMode::Full { sized_by_content: true, @@ -8073,7 +8478,7 @@ impl Element for EditorElement { let is_row_soft_wrapped = |row: usize| { row_infos .get(row) - .map_or(true, |info| info.buffer_row.is_none()) + .is_none_or(|info| info.buffer_row.is_none()) }; let start_anchor = if start_row == Default::default() { @@ -8319,12 +8724,20 @@ impl Element for EditorElement { cx, ); + let bg_segments_per_row = Self::bg_segments_per_row( + start_row..end_row, + &selections, + &highlighted_ranges, + self.style.background, + ); + let mut line_layouts = Self::layout_lines( start_row..end_row, &snapshot, &self.style, editor_width, is_row_soft_wrapped, + &bg_segments_per_row, window, cx, ); @@ -8358,14 +8771,14 @@ impl Element for EditorElement { return None; } let blame = editor.blame.as_ref()?; - let blame_entry = blame + let (_, blame_entry) = blame .update(cx, |blame, cx| { let row_infos = snapshot.row_infos(snapshot.longest_row()).next()?; blame.blame_for_rows(&[row_infos], cx).next() }) .flatten()?; - let mut element = render_inline_blame_entry(blame_entry, &style, cx)?; + let mut element = render_inline_blame_entry(blame_entry, style, cx)?; let inline_blame_padding = ProjectSettings::get_global(cx) .git .inline_blame @@ -8385,7 +8798,7 @@ impl Element for EditorElement { let longest_line_width = layout_line( snapshot.longest_row(), &snapshot, - &style, + style, editor_width, is_row_soft_wrapped, window, @@ -8543,7 +8956,7 @@ impl Element for EditorElement { scroll_pixel_position, newest_selection_head, editor_width, - &style, + style, window, cx, ) @@ -8561,7 +8974,7 @@ impl Element for EditorElement { end_row, line_height, em_width, - &style, + style, window, cx, ); @@ -8706,7 +9119,7 @@ impl Element for EditorElement { &line_layouts, newest_selection_head, newest_selection_point, - &style, + style, window, cx, ) @@ -8846,11 +9259,21 @@ impl Element for EditorElement { }); let invisible_symbol_font_size = font_size / 2.; + let whitespace_map = &self + .editor + 
.read(cx) + .buffer + .read(cx) + .language_settings(cx) + .whitespace_map; + + let tab_char = whitespace_map.tab(); + let tab_len = tab_char.len(); let tab_invisible = window.text_system().shape_line( - "→".into(), + tab_char, invisible_symbol_font_size, &[TextRun { - len: "→".len(), + len: tab_len, font: self.style.text.font(), color: cx.theme().colors().editor_invisible, background_color: None, @@ -8859,11 +9282,14 @@ impl Element for EditorElement { }], None, ); + + let space_char = whitespace_map.space(); + let space_len = space_char.len(); let space_invisible = window.text_system().shape_line( - "•".into(), + space_char, invisible_symbol_font_size, &[TextRun { - len: "•".len(), + len: space_len, font: self.style.text.font(), color: cx.theme().colors().editor_invisible, background_color: None, @@ -8908,9 +9334,9 @@ impl Element for EditorElement { text_hitbox: text_hitbox.clone(), inline_blame_bounds: inline_blame_layout .as_ref() - .map(|layout| (layout.bounds, layout.entry.clone())), + .map(|layout| (layout.bounds, layout.buffer_id, layout.entry.clone())), display_hunks: display_hunks.clone(), - diff_hunk_control_bounds: diff_hunk_control_bounds.clone(), + diff_hunk_control_bounds, }); self.editor.update(cx, |editor, _| { @@ -9568,7 +9994,7 @@ pub(crate) struct PositionMap { pub snapshot: EditorSnapshot, pub text_hitbox: Hitbox, pub gutter_hitbox: Hitbox, - pub inline_blame_bounds: Option<(Bounds, BlameEntry)>, + pub inline_blame_bounds: Option<(Bounds, BufferId, BlameEntry)>, pub display_hunks: Vec<(DisplayDiffHunk, Option)>, pub diff_hunk_control_bounds: Vec<(DisplayRow, Bounds)>, } @@ -9608,14 +10034,12 @@ impl PointForPosition { false } else if start_row == end_row { candidate_col >= start_col && candidate_col < end_col + } else if candidate_row == start_row { + candidate_col >= start_col + } else if candidate_row == end_row { + candidate_col < end_col } else { - if candidate_row == start_row { - candidate_col >= start_col - } else if candidate_row == 
end_row { - candidate_col < end_col - } else { - true - } + true } } } @@ -9680,12 +10104,13 @@ pub fn layout_line( let chunks = snapshot.highlighted_chunks(row..row + DisplayRow(1), true, style); LineWithInvisibles::from_chunks( chunks, - &style, + style, MAX_LINE_LEN, 1, &snapshot.mode, text_width, is_row_soft_wrapped, + &[], window, cx, ) @@ -9797,7 +10222,7 @@ impl CursorLayout { .px_0p5() .line_height(text_size + px(2.)) .text_color(cursor_name.color) - .child(cursor_name.string.clone()) + .child(cursor_name.string) .into_any_element(); name_element.prepaint_as_root(name_origin, AvailableSpace::min_size(), window, cx); @@ -10050,10 +10475,10 @@ fn compute_auto_height_layout( let overscroll = size(em_width, px(0.)); let editor_width = text_width - gutter_dimensions.margin - overscroll.width - em_width; - if !matches!(editor.soft_wrap_mode(cx), SoftWrap::None) { - if editor.set_wrap_width(Some(editor_width), cx) { - snapshot = editor.snapshot(window, cx); - } + if !matches!(editor.soft_wrap_mode(cx), SoftWrap::None) + && editor.set_wrap_width(Some(editor_width), cx) + { + snapshot = editor.snapshot(window, cx); } let scroll_height = (snapshot.max_point().row().next_row().0 as f32) * line_height; @@ -10085,6 +10510,71 @@ mod tests { use std::num::NonZeroU32; use util::test::sample_text; + #[gpui::test] + async fn test_soft_wrap_editor_width_auto_height_editor(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let window = cx.add_window(|window, cx| { + let buffer = MultiBuffer::build_simple(&"a ".to_string().repeat(100), cx); + let mut editor = Editor::new( + EditorMode::AutoHeight { + min_lines: 1, + max_lines: None, + }, + buffer, + None, + window, + cx, + ); + editor.set_soft_wrap_mode(language_settings::SoftWrap::EditorWidth, cx); + editor + }); + let cx = &mut VisualTestContext::from_window(*window, cx); + let editor = window.root(cx).unwrap(); + let style = cx.update(|_, cx| editor.read(cx).style().unwrap().clone()); + + for x in 1..=100 { + let (_, 
state) = cx.draw( + Default::default(), + size(px(200. + 0.13 * x as f32), px(500.)), + |_, _| EditorElement::new(&editor, style.clone()), + ); + + assert!( + state.position_map.scroll_max.x == 0., + "Soft wrapped editor should have no horizontal scrolling!" + ); + } + } + + #[gpui::test] + async fn test_soft_wrap_editor_width_full_editor(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let window = cx.add_window(|window, cx| { + let buffer = MultiBuffer::build_simple(&"a ".to_string().repeat(100), cx); + let mut editor = Editor::new(EditorMode::full(), buffer, None, window, cx); + editor.set_soft_wrap_mode(language_settings::SoftWrap::EditorWidth, cx); + editor + }); + let cx = &mut VisualTestContext::from_window(*window, cx); + let editor = window.root(cx).unwrap(); + let style = cx.update(|_, cx| editor.read(cx).style().unwrap().clone()); + + for x in 1..=100 { + let (_, state) = cx.draw( + Default::default(), + size(px(200. + 0.13 * x as f32), px(500.)), + |_, _| EditorElement::new(&editor, style.clone()), + ); + + assert!( + state.position_map.scroll_max.x == 0., + "Soft wrapped editor should have no horizontal scrolling!" 
+ ); + } + } + #[gpui::test] fn test_shape_line_numbers(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -10269,7 +10759,7 @@ mod tests { let style = cx.update(|_, cx| editor.read(cx).style().unwrap().clone()); window .update(cx, |editor, window, cx| { - editor.set_placeholder_text("hello", cx); + editor.set_placeholder_text("hello", window, cx); editor.insert_blocks( [BlockProperties { style: BlockStyle::Fixed, @@ -10521,4 +11011,289 @@ mod tests { .cloned() .collect() } + + #[gpui::test] + fn test_merge_overlapping_ranges() { + let base_bg = Hsla::white(); + let color1 = Hsla { + h: 0.0, + s: 0.5, + l: 0.5, + a: 0.5, + }; + let color2 = Hsla { + h: 120.0, + s: 0.5, + l: 0.5, + a: 0.5, + }; + + let display_point = |col| DisplayPoint::new(DisplayRow(0), col); + let cols = |v: &Vec<(Range, Hsla)>| -> Vec<(u32, u32)> { + v.iter() + .map(|(r, _)| (r.start.column(), r.end.column())) + .collect() + }; + + // Test overlapping ranges blend colors + let overlapping = vec![ + (display_point(5)..display_point(15), color1), + (display_point(10)..display_point(20), color2), + ]; + let result = EditorElement::merge_overlapping_ranges(overlapping, base_bg); + assert_eq!(cols(&result), vec![(5, 10), (10, 15), (15, 20)]); + + // Test middle segment should have blended color + let blended = Hsla::blend(Hsla::blend(base_bg, color1), color2); + assert_eq!(result[1].1, blended); + + // Test adjacent same-color ranges merge + let adjacent_same = vec![ + (display_point(5)..display_point(10), color1), + (display_point(10)..display_point(15), color1), + ]; + let result = EditorElement::merge_overlapping_ranges(adjacent_same, base_bg); + assert_eq!(cols(&result), vec![(5, 15)]); + + // Test contained range splits + let contained = vec![ + (display_point(5)..display_point(20), color1), + (display_point(10)..display_point(15), color2), + ]; + let result = EditorElement::merge_overlapping_ranges(contained, base_bg); + assert_eq!(cols(&result), vec![(5, 10), (10, 15), (15, 20)]); + + // 
Test multiple overlaps split at every boundary + let color3 = Hsla { + h: 240.0, + s: 0.5, + l: 0.5, + a: 0.5, + }; + let complex = vec![ + (display_point(5)..display_point(12), color1), + (display_point(8)..display_point(16), color2), + (display_point(10)..display_point(14), color3), + ]; + let result = EditorElement::merge_overlapping_ranges(complex, base_bg); + assert_eq!( + cols(&result), + vec![(5, 8), (8, 10), (10, 12), (12, 14), (14, 16)] + ); + } + + #[gpui::test] + fn test_bg_segments_per_row() { + let base_bg = Hsla::white(); + + // Case A: selection spans three display rows: row 1 [5, end), full row 2, row 3 [0, 7) + { + let selection_color = Hsla { + h: 200.0, + s: 0.5, + l: 0.5, + a: 0.5, + }; + let player_color = PlayerColor { + cursor: selection_color, + background: selection_color, + selection: selection_color, + }; + + let spanning_selection = SelectionLayout { + head: DisplayPoint::new(DisplayRow(3), 7), + cursor_shape: CursorShape::Bar, + is_newest: true, + is_local: true, + range: DisplayPoint::new(DisplayRow(1), 5)..DisplayPoint::new(DisplayRow(3), 7), + active_rows: DisplayRow(1)..DisplayRow(4), + user_name: None, + }; + + let selections = vec![(player_color, vec![spanning_selection])]; + let result = EditorElement::bg_segments_per_row( + DisplayRow(0)..DisplayRow(5), + &selections, + &[], + base_bg, + ); + + assert_eq!(result.len(), 5); + assert!(result[0].is_empty()); + assert_eq!(result[1].len(), 1); + assert_eq!(result[2].len(), 1); + assert_eq!(result[3].len(), 1); + assert!(result[4].is_empty()); + + assert_eq!(result[1][0].0.start, DisplayPoint::new(DisplayRow(1), 5)); + assert_eq!(result[1][0].0.end.row(), DisplayRow(1)); + assert_eq!(result[1][0].0.end.column(), u32::MAX); + assert_eq!(result[2][0].0.start, DisplayPoint::new(DisplayRow(2), 0)); + assert_eq!(result[2][0].0.end.row(), DisplayRow(2)); + assert_eq!(result[2][0].0.end.column(), u32::MAX); + assert_eq!(result[3][0].0.start, DisplayPoint::new(DisplayRow(3), 0)); + 
assert_eq!(result[3][0].0.end, DisplayPoint::new(DisplayRow(3), 7)); + } + + // Case B: selection ends exactly at the start of row 3, excluding row 3 + { + let selection_color = Hsla { + h: 120.0, + s: 0.5, + l: 0.5, + a: 0.5, + }; + let player_color = PlayerColor { + cursor: selection_color, + background: selection_color, + selection: selection_color, + }; + + let selection = SelectionLayout { + head: DisplayPoint::new(DisplayRow(2), 0), + cursor_shape: CursorShape::Bar, + is_newest: true, + is_local: true, + range: DisplayPoint::new(DisplayRow(1), 5)..DisplayPoint::new(DisplayRow(3), 0), + active_rows: DisplayRow(1)..DisplayRow(3), + user_name: None, + }; + + let selections = vec![(player_color, vec![selection])]; + let result = EditorElement::bg_segments_per_row( + DisplayRow(0)..DisplayRow(4), + &selections, + &[], + base_bg, + ); + + assert_eq!(result.len(), 4); + assert!(result[0].is_empty()); + assert_eq!(result[1].len(), 1); + assert_eq!(result[2].len(), 1); + assert!(result[3].is_empty()); + + assert_eq!(result[1][0].0.start, DisplayPoint::new(DisplayRow(1), 5)); + assert_eq!(result[1][0].0.end.row(), DisplayRow(1)); + assert_eq!(result[1][0].0.end.column(), u32::MAX); + assert_eq!(result[2][0].0.start, DisplayPoint::new(DisplayRow(2), 0)); + assert_eq!(result[2][0].0.end.row(), DisplayRow(2)); + assert_eq!(result[2][0].0.end.column(), u32::MAX); + } + } + + #[cfg(test)] + fn generate_test_run(len: usize, color: Hsla) -> TextRun { + TextRun { + len, + font: gpui::font(".SystemUIFont"), + color, + background_color: None, + underline: None, + strikethrough: None, + } + } + + #[gpui::test] + fn test_split_runs_by_bg_segments(cx: &mut gpui::TestAppContext) { + init_test(cx, |_| {}); + + let text_color = Hsla { + h: 210.0, + s: 0.1, + l: 0.4, + a: 1.0, + }; + let bg1 = Hsla { + h: 30.0, + s: 0.6, + l: 0.8, + a: 1.0, + }; + let bg2 = Hsla { + h: 200.0, + s: 0.6, + l: 0.2, + a: 1.0, + }; + let min_contrast = 45.0; + + // Case A: single run; disjoint segments 
inside the run + let runs = vec![generate_test_run(20, text_color)]; + let segs = vec![ + ( + DisplayPoint::new(DisplayRow(0), 5)..DisplayPoint::new(DisplayRow(0), 10), + bg1, + ), + ( + DisplayPoint::new(DisplayRow(0), 12)..DisplayPoint::new(DisplayRow(0), 16), + bg2, + ), + ]; + let out = LineWithInvisibles::split_runs_by_bg_segments(&runs, &segs, min_contrast); + // Expected slices: [0,5) [5,10) [10,12) [12,16) [16,20) + assert_eq!( + out.iter().map(|r| r.len).collect::>(), + vec![5, 5, 2, 4, 4] + ); + assert_eq!(out[0].color, text_color); + assert_eq!( + out[1].color, + ensure_minimum_contrast(text_color, bg1, min_contrast) + ); + assert_eq!(out[2].color, text_color); + assert_eq!( + out[3].color, + ensure_minimum_contrast(text_color, bg2, min_contrast) + ); + assert_eq!(out[4].color, text_color); + + // Case B: multiple runs; segment extends to end of line (u32::MAX) + let runs = vec![ + generate_test_run(8, text_color), + generate_test_run(7, text_color), + ]; + let segs = vec![( + DisplayPoint::new(DisplayRow(0), 6)..DisplayPoint::new(DisplayRow(0), u32::MAX), + bg1, + )]; + let out = LineWithInvisibles::split_runs_by_bg_segments(&runs, &segs, min_contrast); + // Expected slices across runs: [0,6) [6,8) | [0,7) + assert_eq!(out.iter().map(|r| r.len).collect::>(), vec![6, 2, 7]); + let adjusted = ensure_minimum_contrast(text_color, bg1, min_contrast); + assert_eq!(out[0].color, text_color); + assert_eq!(out[1].color, adjusted); + assert_eq!(out[2].color, adjusted); + + // Case C: multi-byte characters + // for text: "Hello 🌍 世界!" + let runs = vec![ + generate_test_run(5, text_color), // "Hello" + generate_test_run(6, text_color), // " 🌍 " + generate_test_run(6, text_color), // "世界" + generate_test_run(1, text_color), // "!" 
+ ]; + // selecting "🌍 世" + let segs = vec![( + DisplayPoint::new(DisplayRow(0), 6)..DisplayPoint::new(DisplayRow(0), 14), + bg1, + )]; + let out = LineWithInvisibles::split_runs_by_bg_segments(&runs, &segs, min_contrast); + // "Hello" | " " | "🌍 " | "世" | "界" | "!" + assert_eq!( + out.iter().map(|r| r.len).collect::>(), + vec![5, 1, 5, 3, 3, 1] + ); + assert_eq!(out[0].color, text_color); // "Hello" + assert_eq!( + out[2].color, + ensure_minimum_contrast(text_color, bg1, min_contrast) + ); // "🌍 " + assert_eq!( + out[3].color, + ensure_minimum_contrast(text_color, bg1, min_contrast) + ); // "世" + assert_eq!(out[4].color, text_color); // "界" + assert_eq!(out[5].color, text_color); // "!" + } } diff --git a/crates/editor/src/git/blame.rs b/crates/editor/src/git/blame.rs index fc350a5a15b4f7b105872e61e5a2401d183c1a6d..51719048ef81cf273bc58e7d810d66d454a04805 100644 --- a/crates/editor/src/git/blame.rs +++ b/crates/editor/src/git/blame.rs @@ -10,16 +10,18 @@ use gpui::{ AnyElement, App, AppContext as _, Context, Entity, Hsla, ScrollHandle, Subscription, Task, TextStyle, WeakEntity, Window, }; -use language::{Bias, Buffer, BufferSnapshot, Edit}; +use itertools::Itertools; +use language::{Bias, BufferSnapshot, Edit}; use markdown::Markdown; -use multi_buffer::RowInfo; +use multi_buffer::{MultiBuffer, RowInfo}; use project::{ - Project, ProjectItem, + Project, ProjectItem as _, git_store::{GitStoreEvent, Repository, RepositoryEvent}, }; use smallvec::SmallVec; use std::{sync::Arc, time::Duration}; use sum_tree::SumTree; +use text::BufferId; use workspace::Workspace; #[derive(Clone, Debug, Default)] @@ -63,16 +65,19 @@ impl<'a> sum_tree::Dimension<'a, GitBlameEntrySummary> for u32 { } } -pub struct GitBlame { - project: Entity, - buffer: Entity, +struct GitBlameBuffer { entries: SumTree, - commit_details: HashMap, buffer_snapshot: BufferSnapshot, buffer_edits: text::Subscription, + commit_details: HashMap, +} + +pub struct GitBlame { + project: Entity, + multi_buffer: 
WeakEntity, + buffers: HashMap, task: Task>, focused: bool, - generated: bool, changed_while_blurred: bool, user_triggered: bool, regenerate_on_edit_task: Task>, @@ -184,47 +189,46 @@ impl gpui::Global for GlobalBlameRenderer {} impl GitBlame { pub fn new( - buffer: Entity, + multi_buffer: Entity, project: Entity, user_triggered: bool, focused: bool, cx: &mut Context, ) -> Self { - let entries = SumTree::from_item( - GitBlameEntry { - rows: buffer.read(cx).max_point().row + 1, - blame: None, + let multi_buffer_subscription = cx.subscribe( + &multi_buffer, + |git_blame, multi_buffer, event, cx| match event { + multi_buffer::Event::DirtyChanged => { + if !multi_buffer.read(cx).is_dirty(cx) { + git_blame.generate(cx); + } + } + multi_buffer::Event::ExcerptsAdded { .. } + | multi_buffer::Event::ExcerptsEdited { .. } => git_blame.regenerate_on_edit(cx), + _ => {} }, - &(), ); - let buffer_subscriptions = cx.subscribe(&buffer, |this, buffer, event, cx| match event { - language::BufferEvent::DirtyChanged => { - if !buffer.read(cx).is_dirty() { - this.generate(cx); - } - } - language::BufferEvent::Edited => { - this.regenerate_on_edit(cx); - } - _ => {} - }); - let project_subscription = cx.subscribe(&project, { - let buffer = buffer.clone(); - - move |this, _, event, cx| match event { - project::Event::WorktreeUpdatedEntries(_, updated) => { - let project_entry_id = buffer.read(cx).entry_id(cx); + let multi_buffer = multi_buffer.downgrade(); + + move |git_blame, _, event, cx| { + if let project::Event::WorktreeUpdatedEntries(_, updated) = event { + let Some(multi_buffer) = multi_buffer.upgrade() else { + return; + }; + let project_entry_id = multi_buffer + .read(cx) + .as_singleton() + .and_then(|it| it.read(cx).entry_id(cx)); if updated .iter() .any(|(_, entry_id, _)| project_entry_id == Some(*entry_id)) { log::debug!("Updated buffers. 
Regenerating blame data...",); - this.generate(cx); + git_blame.generate(cx); } } - _ => {} } }); @@ -240,24 +244,17 @@ impl GitBlame { _ => {} }); - let buffer_snapshot = buffer.read(cx).snapshot(); - let buffer_edits = buffer.update(cx, |buffer, _| buffer.subscribe()); - let mut this = Self { project, - buffer, - buffer_snapshot, - entries, - buffer_edits, + multi_buffer: multi_buffer.downgrade(), + buffers: HashMap::default(), user_triggered, focused, changed_while_blurred: false, - commit_details: HashMap::default(), task: Task::ready(Ok(())), - generated: false, regenerate_on_edit_task: Task::ready(Ok(())), _regenerate_subscriptions: vec![ - buffer_subscriptions, + multi_buffer_subscription, project_subscription, git_store_subscription, ], @@ -266,54 +263,61 @@ impl GitBlame { this } - pub fn repository(&self, cx: &App) -> Option> { + pub fn repository(&self, cx: &App, id: BufferId) -> Option> { self.project .read(cx) .git_store() .read(cx) - .repository_and_path_for_buffer_id(self.buffer.read(cx).remote_id(), cx) + .repository_and_path_for_buffer_id(id, cx) .map(|(repo, _)| repo) } pub fn has_generated_entries(&self) -> bool { - self.generated + !self.buffers.is_empty() } - pub fn details_for_entry(&self, entry: &BlameEntry) -> Option { - self.commit_details.get(&entry.sha).cloned() + pub fn details_for_entry( + &self, + buffer: BufferId, + entry: &BlameEntry, + ) -> Option { + self.buffers + .get(&buffer)? 
+ .commit_details + .get(&entry.sha) + .cloned() } pub fn blame_for_rows<'a>( &'a mut self, rows: &'a [RowInfo], - cx: &App, - ) -> impl 'a + Iterator> { - self.sync(cx); - - let buffer_id = self.buffer_snapshot.remote_id(); - let mut cursor = self.entries.cursor::(&()); - rows.into_iter().map(move |info| { - let row = info - .buffer_row - .filter(|_| info.buffer_id == Some(buffer_id))?; - cursor.seek_forward(&row, Bias::Right); - cursor.item()?.blame.clone() + cx: &'a mut App, + ) -> impl Iterator> + use<'a> { + rows.iter().map(move |info| { + let buffer_id = info.buffer_id?; + self.sync(cx, buffer_id); + + let buffer_row = info.buffer_row?; + let mut cursor = self.buffers.get(&buffer_id)?.entries.cursor::(&()); + cursor.seek_forward(&buffer_row, Bias::Right); + Some((buffer_id, cursor.item()?.blame.clone()?)) }) } - pub fn max_author_length(&mut self, cx: &App) -> usize { - self.sync(cx); - + pub fn max_author_length(&mut self, cx: &mut App) -> usize { let mut max_author_length = 0; - - for entry in self.entries.iter() { - let author_len = entry - .blame - .as_ref() - .and_then(|entry| entry.author.as_ref()) - .map(|author| author.len()); - if let Some(author_len) = author_len { - if author_len > max_author_length { + self.sync_all(cx); + + for buffer in self.buffers.values() { + for entry in buffer.entries.iter() { + let author_len = entry + .blame + .as_ref() + .and_then(|entry| entry.author.as_ref()) + .map(|author| author.len()); + if let Some(author_len) = author_len + && author_len > max_author_length + { max_author_length = author_len; } } @@ -337,22 +341,48 @@ impl GitBlame { } } - fn sync(&mut self, cx: &App) { - let edits = self.buffer_edits.consume(); - let new_snapshot = self.buffer.read(cx).snapshot(); + fn sync_all(&mut self, cx: &mut App) { + let Some(multi_buffer) = self.multi_buffer.upgrade() else { + return; + }; + multi_buffer + .read(cx) + .excerpt_buffer_ids() + .into_iter() + .for_each(|id| self.sync(cx, id)); + } + + fn sync(&mut self, cx: 
&mut App, buffer_id: BufferId) { + let Some(blame_buffer) = self.buffers.get_mut(&buffer_id) else { + return; + }; + let Some(buffer) = self + .multi_buffer + .upgrade() + .and_then(|multi_buffer| multi_buffer.read(cx).buffer(buffer_id)) + else { + return; + }; + let edits = blame_buffer.buffer_edits.consume(); + let new_snapshot = buffer.read(cx).snapshot(); let mut row_edits = edits .into_iter() .map(|edit| { - let old_point_range = self.buffer_snapshot.offset_to_point(edit.old.start) - ..self.buffer_snapshot.offset_to_point(edit.old.end); + let old_point_range = blame_buffer.buffer_snapshot.offset_to_point(edit.old.start) + ..blame_buffer.buffer_snapshot.offset_to_point(edit.old.end); let new_point_range = new_snapshot.offset_to_point(edit.new.start) ..new_snapshot.offset_to_point(edit.new.end); if old_point_range.start.column - == self.buffer_snapshot.line_len(old_point_range.start.row) + == blame_buffer + .buffer_snapshot + .line_len(old_point_range.start.row) && (new_snapshot.chars_at(edit.new.start).next() == Some('\n') - || self.buffer_snapshot.line_len(old_point_range.end.row) == 0) + || blame_buffer + .buffer_snapshot + .line_len(old_point_range.end.row) + == 0) { Edit { old: old_point_range.start.row + 1..old_point_range.end.row + 1, @@ -376,7 +406,7 @@ impl GitBlame { .peekable(); let mut new_entries = SumTree::default(); - let mut cursor = self.entries.cursor::(&()); + let mut cursor = blame_buffer.entries.cursor::(&()); while let Some(mut edit) = row_edits.next() { while let Some(next_edit) = row_edits.peek() { @@ -415,37 +445,47 @@ impl GitBlame { let old_end = cursor.end(); if row_edits .peek() - .map_or(true, |next_edit| next_edit.old.start >= old_end) + .is_none_or(|next_edit| next_edit.old.start >= old_end) + && let Some(entry) = cursor.item() { - if let Some(entry) = cursor.item() { - if old_end > edit.old.end { - new_entries.push( - GitBlameEntry { - rows: cursor.end() - edit.old.end, - blame: entry.blame.clone(), - }, - &(), - ); - } - - 
cursor.next(); + if old_end > edit.old.end { + new_entries.push( + GitBlameEntry { + rows: cursor.end() - edit.old.end, + blame: entry.blame.clone(), + }, + &(), + ); } + + cursor.next(); } } new_entries.append(cursor.suffix(), &()); drop(cursor); - self.buffer_snapshot = new_snapshot; - self.entries = new_entries; + blame_buffer.buffer_snapshot = new_snapshot; + blame_buffer.entries = new_entries; } #[cfg(test)] fn check_invariants(&mut self, cx: &mut Context) { - self.sync(cx); - assert_eq!( - self.entries.summary().rows, - self.buffer.read(cx).max_point().row + 1 - ); + self.sync_all(cx); + for (&id, buffer) in &self.buffers { + assert_eq!( + buffer.entries.summary().rows, + self.multi_buffer + .upgrade() + .unwrap() + .read(cx) + .buffer(id) + .unwrap() + .read(cx) + .max_point() + .row + + 1 + ); + } } fn generate(&mut self, cx: &mut Context) { @@ -453,62 +493,105 @@ impl GitBlame { self.changed_while_blurred = true; return; } - let buffer_edits = self.buffer.update(cx, |buffer, _| buffer.subscribe()); - let snapshot = self.buffer.read(cx).snapshot(); let blame = self.project.update(cx, |project, cx| { - project.blame_buffer(&self.buffer, None, cx) + let Some(multi_buffer) = self.multi_buffer.upgrade() else { + return Vec::new(); + }; + multi_buffer + .read(cx) + .all_buffer_ids() + .into_iter() + .filter_map(|id| { + let buffer = multi_buffer.read(cx).buffer(id)?; + let snapshot = buffer.read(cx).snapshot(); + let buffer_edits = buffer.update(cx, |buffer, _| buffer.subscribe()); + + let blame_buffer = project.blame_buffer(&buffer, None, cx); + Some((id, snapshot, buffer_edits, blame_buffer)) + }) + .collect::>() }); let provider_registry = GitHostingProviderRegistry::default_global(cx); self.task = cx.spawn(async move |this, cx| { - let result = cx + let (result, errors) = cx .background_spawn({ - let snapshot = snapshot.clone(); async move { - let Some(Blame { - entries, - messages, - remote_url, - }) = blame.await? 
- else { - return Ok(None); - }; - - let entries = build_blame_entry_sum_tree(entries, snapshot.max_point().row); - let commit_details = - parse_commit_messages(messages, remote_url, provider_registry).await; - - anyhow::Ok(Some((entries, commit_details))) + let mut res = vec![]; + let mut errors = vec![]; + for (id, snapshot, buffer_edits, blame) in blame { + match blame.await { + Ok(Some(Blame { + entries, + messages, + remote_url, + })) => { + let entries = build_blame_entry_sum_tree( + entries, + snapshot.max_point().row, + ); + let commit_details = parse_commit_messages( + messages, + remote_url, + provider_registry.clone(), + ) + .await; + + res.push(( + id, + snapshot, + buffer_edits, + Some(entries), + commit_details, + )); + } + Ok(None) => { + res.push((id, snapshot, buffer_edits, None, Default::default())) + } + Err(e) => errors.push(e), + } + } + (res, errors) } }) .await; - this.update(cx, |this, cx| match result { - Ok(None) => { - // Nothing to do, e.g. no repository found + this.update(cx, |this, cx| { + this.buffers.clear(); + for (id, snapshot, buffer_edits, entries, commit_details) in result { + let Some(entries) = entries else { + continue; + }; + this.buffers.insert( + id, + GitBlameBuffer { + buffer_edits, + buffer_snapshot: snapshot, + entries, + commit_details, + }, + ); } - Ok(Some((entries, commit_details))) => { - this.buffer_edits = buffer_edits; - this.buffer_snapshot = snapshot; - this.entries = entries; - this.commit_details = commit_details; - this.generated = true; - cx.notify(); + cx.notify(); + if !errors.is_empty() { + this.project.update(cx, |_, cx| { + if this.user_triggered { + log::error!("failed to get git blame data: {errors:?}"); + let notification = errors + .into_iter() + .format_with(",", |e, f| f(&format_args!("{:#}", e))) + .to_string(); + cx.emit(project::Event::Toast { + notification_id: "git-blame".into(), + message: notification, + }); + } else { + // If we weren't triggered by a user, we just log errors in the 
background, instead of sending + // notifications. + log::debug!("failed to get git blame data: {errors:?}"); + } + }) } - Err(error) => this.project.update(cx, |_, cx| { - if this.user_triggered { - log::error!("failed to get git blame data: {error:?}"); - let notification = format!("{:#}", error).trim().to_string(); - cx.emit(project::Event::Toast { - notification_id: "git-blame".into(), - message: notification, - }); - } else { - // If we weren't triggered by a user, we just log errors in the background, instead of sending - // notifications. - log::debug!("failed to get git blame data: {error:?}"); - } - }), }) }); } @@ -522,7 +605,7 @@ impl GitBlame { this.update(cx, |this, cx| { this.generate(cx); }) - }) + }); } } @@ -661,6 +744,9 @@ mod tests { ) .collect::>(), expected + .into_iter() + .map(|it| Some((buffer_id, it?))) + .collect::>() ); } @@ -707,6 +793,7 @@ mod tests { }) .await .unwrap(); + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); let blame = cx.new(|cx| GitBlame::new(buffer.clone(), project.clone(), true, true, cx)); @@ -787,6 +874,7 @@ mod tests { .await .unwrap(); let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id()); + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); let git_blame = cx.new(|cx| GitBlame::new(buffer.clone(), project, false, true, cx)); @@ -808,14 +896,14 @@ mod tests { ) .collect::>(), vec![ - Some(blame_entry("1b1b1b", 0..1)), - Some(blame_entry("0d0d0d", 1..2)), - Some(blame_entry("3a3a3a", 2..3)), + Some((buffer_id, blame_entry("1b1b1b", 0..1))), + Some((buffer_id, blame_entry("0d0d0d", 1..2))), + Some((buffer_id, blame_entry("3a3a3a", 2..3))), None, None, - Some(blame_entry("3a3a3a", 5..6)), - Some(blame_entry("0d0d0d", 6..7)), - Some(blame_entry("3a3a3a", 7..8)), + Some((buffer_id, blame_entry("3a3a3a", 5..6))), + Some((buffer_id, blame_entry("0d0d0d", 6..7))), + Some((buffer_id, blame_entry("3a3a3a", 7..8))), ] ); // Subset of lines @@ -833,8 +921,8 @@ mod tests { ) 
.collect::>(), vec![ - Some(blame_entry("0d0d0d", 1..2)), - Some(blame_entry("3a3a3a", 2..3)), + Some((buffer_id, blame_entry("0d0d0d", 1..2))), + Some((buffer_id, blame_entry("3a3a3a", 2..3))), None ] ); @@ -854,7 +942,7 @@ mod tests { cx ) .collect::>(), - vec![Some(blame_entry("0d0d0d", 1..2)), None, None] + vec![Some((buffer_id, blame_entry("0d0d0d", 1..2))), None, None] ); }); } @@ -897,6 +985,7 @@ mod tests { .await .unwrap(); let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id()); + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); let git_blame = cx.new(|cx| GitBlame::new(buffer.clone(), project, false, true, cx)); @@ -1018,7 +1107,7 @@ mod tests { init_test(cx); let fs = FakeFs::new(cx.executor()); - let buffer_initial_text_len = rng.gen_range(5..15); + let buffer_initial_text_len = rng.random_range(5..15); let mut buffer_initial_text = Rope::from( RandomCharIter::new(&mut rng) .take(buffer_initial_text_len) @@ -1063,13 +1152,14 @@ mod tests { }) .await .unwrap(); + let mbuffer = cx.new(|cx| MultiBuffer::singleton(buffer.clone(), cx)); - let git_blame = cx.new(|cx| GitBlame::new(buffer.clone(), project, false, true, cx)); + let git_blame = cx.new(|cx| GitBlame::new(mbuffer.clone(), project, false, true, cx)); cx.executor().run_until_parked(); git_blame.update(cx, |blame, cx| blame.check_invariants(cx)); for _ in 0..operations { - match rng.gen_range(0..100) { + match rng.random_range(0..100) { 0..=19 => { log::info!("quiescing"); cx.executor().run_until_parked(); @@ -1112,8 +1202,8 @@ mod tests { let mut blame_entries = Vec::new(); for ix in 0..5 { if last_row < max_row { - let row_start = rng.gen_range(last_row..max_row); - let row_end = rng.gen_range(row_start + 1..cmp::min(row_start + 3, max_row) + 1); + let row_start = rng.random_range(last_row..max_row); + let row_end = rng.random_range(row_start + 1..cmp::min(row_start + 3, max_row) + 1); blame_entries.push(blame_entry(&ix.to_string(), row_start..row_end)); last_row = 
row_end; } else { diff --git a/crates/editor/src/highlight_matching_bracket.rs b/crates/editor/src/highlight_matching_bracket.rs index e38197283d4a4e2623ecadb30d90d0363053fdc5..aa4e616924ad6bd47627bfd95e9a5c58587afc25 100644 --- a/crates/editor/src/highlight_matching_bracket.rs +++ b/crates/editor/src/highlight_matching_bracket.rs @@ -1,6 +1,7 @@ use crate::{Editor, RangeToAnchorExt}; -use gpui::{Context, Window}; +use gpui::{Context, HighlightStyle, Window}; use language::CursorShape; +use theme::ActiveTheme; enum MatchingBracketHighlight {} @@ -9,7 +10,7 @@ pub fn refresh_matching_bracket_highlights( window: &mut Window, cx: &mut Context, ) { - editor.clear_background_highlights::(cx); + editor.clear_highlights::(cx); let newest_selection = editor.selections.newest::(cx); // Don't highlight brackets if the selection isn't empty @@ -35,12 +36,19 @@ pub fn refresh_matching_bracket_highlights( .buffer_snapshot .innermost_enclosing_bracket_ranges(head..tail, None) { - editor.highlight_background::( - &[ + editor.highlight_text::( + vec![ opening_range.to_anchors(&snapshot.buffer_snapshot), closing_range.to_anchors(&snapshot.buffer_snapshot), ], - |theme| theme.colors().editor_document_highlight_bracket_background, + HighlightStyle { + background_color: Some( + cx.theme() + .colors() + .editor_document_highlight_bracket_background, + ), + ..Default::default() + }, cx, ) } @@ -104,7 +112,7 @@ mod tests { another_test(1, 2, 3); } "#}); - cx.assert_editor_background_highlights::(indoc! {r#" + cx.assert_editor_text_highlights::(indoc! {r#" pub fn test«(»"Test argument"«)» { another_test(1, 2, 3); } @@ -115,7 +123,7 @@ mod tests { another_test(1, ˇ2, 3); } "#}); - cx.assert_editor_background_highlights::(indoc! {r#" + cx.assert_editor_text_highlights::(indoc! {r#" pub fn test("Test argument") { another_test«(»1, 2, 3«)»; } @@ -126,7 +134,7 @@ mod tests { anotherˇ_test(1, 2, 3); } "#}); - cx.assert_editor_background_highlights::(indoc! 
{r#" + cx.assert_editor_text_highlights::(indoc! {r#" pub fn test("Test argument") «{» another_test(1, 2, 3); «}» @@ -138,7 +146,7 @@ mod tests { another_test(1, 2, 3); } "#}); - cx.assert_editor_background_highlights::(indoc! {r#" + cx.assert_editor_text_highlights::(indoc! {r#" pub fn test("Test argument") { another_test(1, 2, 3); } @@ -150,8 +158,8 @@ mod tests { another_test(1, 2, 3); } "#}); - cx.assert_editor_background_highlights::(indoc! {r#" - pub fn test("Test argument") { + cx.assert_editor_text_highlights::(indoc! {r#" + pub fn test«("Test argument") { another_test(1, 2, 3); } "#}); diff --git a/crates/editor/src/hover_links.rs b/crates/editor/src/hover_links.rs index 02f93e6829a3f7ac08ec7dfa390cd846560bb7d5..ba0b6f88683969aca3818a2795aa6b8454de3bb8 100644 --- a/crates/editor/src/hover_links.rs +++ b/crates/editor/src/hover_links.rs @@ -188,22 +188,26 @@ impl Editor { pub fn scroll_hover( &mut self, - amount: &ScrollAmount, + amount: ScrollAmount, window: &mut Window, cx: &mut Context, ) -> bool { let selection = self.selections.newest_anchor().head(); let snapshot = self.snapshot(window, cx); - let Some(popover) = self.hover_state.info_popovers.iter().find(|popover| { + if let Some(popover) = self.hover_state.info_popovers.iter().find(|popover| { popover .symbol_range .point_within_range(&TriggerPoint::Text(selection), &snapshot) - }) else { - return false; - }; - popover.scroll(amount, window, cx); - true + }) { + popover.scroll(amount, window, cx); + true + } else if let Some(context_menu) = self.context_menu.borrow_mut().as_mut() { + context_menu.scroll_aside(amount, window, cx); + true + } else { + false + } } fn cmd_click_reveal_task( @@ -271,7 +275,7 @@ impl Editor { Task::ready(Ok(Navigated::No)) }; self.select(SelectPhase::End, window, cx); - return navigate_task; + navigate_task } } @@ -321,7 +325,10 @@ pub fn update_inlay_link_and_hover_points( if let Some(cached_hint) = inlay_hint_cache.hint_by_id(excerpt_id, hovered_hint.id) { match 
cached_hint.resolve_state { ResolveState::CanResolve(_, _) => { - if let Some(buffer_id) = previous_valid_anchor.buffer_id { + if let Some(buffer_id) = snapshot + .buffer_snapshot + .buffer_id_for_anchor(previous_valid_anchor) + { inlay_hint_cache.spawn_hint_resolve( buffer_id, excerpt_id, @@ -418,24 +425,22 @@ pub fn update_inlay_link_and_hover_points( } if let Some((language_server_id, location)) = hovered_hint_part.location + && secondary_held + && !editor.has_pending_nonempty_selection() { - if secondary_held - && !editor.has_pending_nonempty_selection() - { - go_to_definition_updated = true; - show_link_definition( - shift_held, - editor, - TriggerPoint::InlayHint( - highlight, - location, - language_server_id, - ), - snapshot, - window, - cx, - ); - } + go_to_definition_updated = true; + show_link_definition( + shift_held, + editor, + TriggerPoint::InlayHint( + highlight, + location, + language_server_id, + ), + snapshot, + window, + cx, + ); } } } @@ -561,7 +566,7 @@ pub fn show_link_definition( provider.definitions(&buffer, buffer_position, preferred_kind, cx) })?; if let Some(task) = task { - task.await.ok().map(|definition_result| { + task.await.ok().flatten().map(|definition_result| { ( definition_result.iter().find_map(|link| { link.origin.as_ref().and_then(|origin| { @@ -657,11 +662,11 @@ pub fn show_link_definition( pub(crate) fn find_url( buffer: &Entity, position: text::Anchor, - mut cx: AsyncWindowContext, + cx: AsyncWindowContext, ) -> Option<(Range, String)> { const LIMIT: usize = 2048; - let Ok(snapshot) = buffer.read_with(&mut cx, |buffer, _| buffer.snapshot()) else { + let Ok(snapshot) = buffer.read_with(&cx, |buffer, _| buffer.snapshot()) else { return None; }; @@ -719,11 +724,11 @@ pub(crate) fn find_url( pub(crate) fn find_url_from_range( buffer: &Entity, range: Range, - mut cx: AsyncWindowContext, + cx: AsyncWindowContext, ) -> Option { const LIMIT: usize = 2048; - let Ok(snapshot) = buffer.read_with(&mut cx, |buffer, _| buffer.snapshot()) 
else { + let Ok(snapshot) = buffer.read_with(&cx, |buffer, _| buffer.snapshot()) else { return None; }; @@ -766,10 +771,11 @@ pub(crate) fn find_url_from_range( let mut finder = LinkFinder::new(); finder.kinds(&[LinkKind::Url]); - if let Some(link) = finder.links(&text).next() { - if link.start() == 0 && link.end() == text.len() { - return Some(link.as_str().to_string()); - } + if let Some(link) = finder.links(&text).next() + && link.start() == 0 + && link.end() == text.len() + { + return Some(link.as_str().to_string()); } None @@ -794,7 +800,7 @@ pub(crate) async fn find_file( ) -> Option { project .update(cx, |project, cx| { - project.resolve_path_in_buffer(&candidate_file_path, buffer, cx) + project.resolve_path_in_buffer(candidate_file_path, buffer, cx) }) .ok()? .await @@ -872,7 +878,7 @@ fn surrounding_filename( .peekable(); while let Some(ch) = forwards.next() { // Skip escaped whitespace - if ch == '\\' && forwards.peek().map_or(false, |ch| ch.is_whitespace()) { + if ch == '\\' && forwards.peek().is_some_and(|ch| ch.is_whitespace()) { token_end += ch.len_utf8(); let whitespace = forwards.next().unwrap(); token_end += whitespace.len_utf8(); diff --git a/crates/editor/src/hover_popover.rs b/crates/editor/src/hover_popover.rs index bda229e34669482549182b2c7abbe2c3efb9a751..6541f76a56e671fb414e28d83adc6b0459e288a8 100644 --- a/crates/editor/src/hover_popover.rs +++ b/crates/editor/src/hover_popover.rs @@ -142,11 +142,11 @@ pub fn hover_at_inlay( .info_popovers .iter() .any(|InfoPopover { symbol_range, .. }| { - if let RangeInEditor::Inlay(range) = symbol_range { - if range == &inlay_hover.range { - // Hover triggered from same location as last time. Don't show again. - return true; - } + if let RangeInEditor::Inlay(range) = symbol_range + && range == &inlay_hover.range + { + // Hover triggered from same location as last time. Don't show again. 
+ return true; } false }) @@ -167,17 +167,16 @@ pub fn hover_at_inlay( let language_registry = project.read_with(cx, |p, _| p.languages().clone())?; let blocks = vec![inlay_hover.tooltip]; - let parsed_content = parse_blocks(&blocks, &language_registry, None, cx).await; + let parsed_content = + parse_blocks(&blocks, Some(&language_registry), None, cx).await; let scroll_handle = ScrollHandle::new(); let subscription = this .update(cx, |_, cx| { - if let Some(parsed_content) = &parsed_content { - Some(cx.observe(parsed_content, |_, _, cx| cx.notify())) - } else { - None - } + parsed_content.as_ref().map(|parsed_content| { + cx.observe(parsed_content, |_, _, cx| cx.notify()) + }) }) .ok() .flatten(); @@ -251,7 +250,9 @@ fn show_hover( let (excerpt_id, _, _) = editor.buffer().read(cx).excerpt_containing(anchor, cx)?; - let language_registry = editor.project.as_ref()?.read(cx).languages().clone(); + let language_registry = editor + .project() + .map(|project| project.read(cx).languages().clone()); let provider = editor.semantics_provider.clone()?; if !ignore_timeout { @@ -267,13 +268,12 @@ fn show_hover( } // Don't request again if the location is the same as the previous request - if let Some(triggered_from) = &editor.hover_state.triggered_from { - if triggered_from + if let Some(triggered_from) = &editor.hover_state.triggered_from + && triggered_from .cmp(&anchor, &snapshot.buffer_snapshot) .is_eq() - { - return None; - } + { + return None; } let hover_popover_delay = EditorSettings::get_global(cx).hover_popover_delay; @@ -428,7 +428,7 @@ fn show_hover( }; let hovers_response = if let Some(hover_request) = hover_request { - hover_request.await + hover_request.await.unwrap_or_default() } else { Vec::new() }; @@ -443,15 +443,14 @@ fn show_hover( text: format!("Unicode character U+{:02X}", invisible as u32), kind: HoverBlockKind::PlainText, }]; - let parsed_content = parse_blocks(&blocks, &language_registry, None, cx).await; + let parsed_content = + parse_blocks(&blocks, 
language_registry.as_ref(), None, cx).await; let scroll_handle = ScrollHandle::new(); let subscription = this .update(cx, |_, cx| { - if let Some(parsed_content) = &parsed_content { - Some(cx.observe(parsed_content, |_, _, cx| cx.notify())) - } else { - None - } + parsed_content.as_ref().map(|parsed_content| { + cx.observe(parsed_content, |_, _, cx| cx.notify()) + }) }) .ok() .flatten(); @@ -493,16 +492,15 @@ fn show_hover( let blocks = hover_result.contents; let language = hover_result.language; - let parsed_content = parse_blocks(&blocks, &language_registry, language, cx).await; + let parsed_content = + parse_blocks(&blocks, language_registry.as_ref(), language, cx).await; let scroll_handle = ScrollHandle::new(); hover_highlights.push(range.clone()); let subscription = this .update(cx, |_, cx| { - if let Some(parsed_content) = &parsed_content { - Some(cx.observe(parsed_content, |_, _, cx| cx.notify())) - } else { - None - } + parsed_content.as_ref().map(|parsed_content| { + cx.observe(parsed_content, |_, _, cx| cx.notify()) + }) }) .ok() .flatten(); @@ -583,7 +581,7 @@ fn same_diagnostic_hover(editor: &Editor, snapshot: &EditorSnapshot, anchor: Anc async fn parse_blocks( blocks: &[HoverBlock], - language_registry: &Arc, + language_registry: Option<&Arc>, language: Option>, cx: &mut AsyncWindowContext, ) -> Option> { @@ -599,18 +597,15 @@ async fn parse_blocks( }) .join("\n\n"); - let rendered_block = cx - .new_window_entity(|_window, cx| { - Markdown::new( - combined_text.into(), - Some(language_registry.clone()), - language.map(|language| language.name()), - cx, - ) - }) - .ok(); - - rendered_block + cx.new_window_entity(|_window, cx| { + Markdown::new( + combined_text.into(), + language_registry.cloned(), + language.map(|language| language.name()), + cx, + ) + }) + .ok() } pub fn hover_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { @@ -622,7 +617,7 @@ pub fn hover_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { let mut base_text_style 
= window.text_style(); base_text_style.refine(&TextStyleRefinement { - font_family: Some(ui_font_family.clone()), + font_family: Some(ui_font_family), font_fallbacks: ui_font_fallbacks, color: Some(cx.theme().colors().editor_foreground), ..Default::default() @@ -671,7 +666,7 @@ pub fn diagnostics_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { let mut base_text_style = window.text_style(); base_text_style.refine(&TextStyleRefinement { - font_family: Some(ui_font_family.clone()), + font_family: Some(ui_font_family), font_fallbacks: ui_font_fallbacks, color: Some(cx.theme().colors().editor_foreground), ..Default::default() @@ -712,59 +707,54 @@ pub fn diagnostics_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { } pub fn open_markdown_url(link: SharedString, window: &mut Window, cx: &mut App) { - if let Ok(uri) = Url::parse(&link) { - if uri.scheme() == "file" { - if let Some(workspace) = window.root::().flatten() { - workspace.update(cx, |workspace, cx| { - let task = workspace.open_abs_path( - PathBuf::from(uri.path()), - OpenOptions { - visible: Some(OpenVisible::None), - ..Default::default() - }, - window, - cx, - ); + if let Ok(uri) = Url::parse(&link) + && uri.scheme() == "file" + && let Some(workspace) = window.root::().flatten() + { + workspace.update(cx, |workspace, cx| { + let task = workspace.open_abs_path( + PathBuf::from(uri.path()), + OpenOptions { + visible: Some(OpenVisible::None), + ..Default::default() + }, + window, + cx, + ); - cx.spawn_in(window, async move |_, cx| { - let item = task.await?; - // Ruby LSP uses URLs with #L1,1-4,4 - // we'll just take the first number and assume it's a line number - let Some(fragment) = uri.fragment() else { - return anyhow::Ok(()); - }; - let mut accum = 0u32; - for c in fragment.chars() { - if c >= '0' && c <= '9' && accum < u32::MAX / 2 { - accum *= 10; - accum += c as u32 - '0' as u32; - } else if accum > 0 { - break; - } - } - if accum == 0 { - return Ok(()); - } - let Some(editor) = 
cx.update(|_, cx| item.act_as::(cx))? else { - return Ok(()); - }; - editor.update_in(cx, |editor, window, cx| { - editor.change_selections( - Default::default(), - window, - cx, - |selections| { - selections.select_ranges([text::Point::new(accum - 1, 0) - ..text::Point::new(accum - 1, 0)]); - }, - ); - }) - }) - .detach_and_log_err(cx); - }); - return; - } - } + cx.spawn_in(window, async move |_, cx| { + let item = task.await?; + // Ruby LSP uses URLs with #L1,1-4,4 + // we'll just take the first number and assume it's a line number + let Some(fragment) = uri.fragment() else { + return anyhow::Ok(()); + }; + let mut accum = 0u32; + for c in fragment.chars() { + if c >= '0' && c <= '9' && accum < u32::MAX / 2 { + accum *= 10; + accum += c as u32 - '0' as u32; + } else if accum > 0 { + break; + } + } + if accum == 0 { + return Ok(()); + } + let Some(editor) = cx.update(|_, cx| item.act_as::(cx))? else { + return Ok(()); + }; + editor.update_in(cx, |editor, window, cx| { + editor.change_selections(Default::default(), window, cx, |selections| { + selections.select_ranges([ + text::Point::new(accum - 1, 0)..text::Point::new(accum - 1, 0) + ]); + }); + }) + }) + .detach_and_log_err(cx); + }); + return; } cx.open_url(&link); } @@ -834,20 +824,19 @@ impl HoverState { pub fn focused(&self, window: &mut Window, cx: &mut Context) -> bool { let mut hover_popover_is_focused = false; for info_popover in &self.info_popovers { - if let Some(markdown_view) = &info_popover.parsed_content { - if markdown_view.focus_handle(cx).is_focused(window) { - hover_popover_is_focused = true; - } + if let Some(markdown_view) = &info_popover.parsed_content + && markdown_view.focus_handle(cx).is_focused(window) + { + hover_popover_is_focused = true; } } - if let Some(diagnostic_popover) = &self.diagnostic_popover { - if diagnostic_popover + if let Some(diagnostic_popover) = &self.diagnostic_popover + && diagnostic_popover .markdown .focus_handle(cx) .is_focused(window) - { - 
hover_popover_is_focused = true; - } + { + hover_popover_is_focused = true; } hover_popover_is_focused } @@ -907,7 +896,7 @@ impl InfoPopover { .into_any_element() } - pub fn scroll(&self, amount: &ScrollAmount, window: &mut Window, cx: &mut Context) { + pub fn scroll(&self, amount: ScrollAmount, window: &mut Window, cx: &mut Context) { let mut current = self.scroll_handle.offset(); current.y -= amount.pixels( window.line_height(), diff --git a/crates/editor/src/indent_guides.rs b/crates/editor/src/indent_guides.rs index f6d51c929a95ac3d256095b627c303a6c49a64a5..23717eeb158cea0f01e6a4efca6d2ff14a8fa824 100644 --- a/crates/editor/src/indent_guides.rs +++ b/crates/editor/src/indent_guides.rs @@ -164,15 +164,15 @@ pub fn indent_guides_in_range( let end_anchor = snapshot.buffer_snapshot.anchor_after(end_offset); let mut fold_ranges = Vec::>::new(); - let mut folds = snapshot.folds_in_range(start_offset..end_offset).peekable(); - while let Some(fold) = folds.next() { + let folds = snapshot.folds_in_range(start_offset..end_offset).peekable(); + for fold in folds { let start = fold.range.start.to_point(&snapshot.buffer_snapshot); let end = fold.range.end.to_point(&snapshot.buffer_snapshot); - if let Some(last_range) = fold_ranges.last_mut() { - if last_range.end >= start { - last_range.end = last_range.end.max(end); - continue; - } + if let Some(last_range) = fold_ranges.last_mut() + && last_range.end >= start + { + last_range.end = last_range.end.max(end); + continue; } fold_ranges.push(start..end); } diff --git a/crates/editor/src/inlay_hint_cache.rs b/crates/editor/src/inlay_hint_cache.rs index 60ad0e5bf6c5672a3ce651793b8f76a82ab4c0ff..c1b0a7640c155fff02f0b778e8996a9b68ea452e 100644 --- a/crates/editor/src/inlay_hint_cache.rs +++ b/crates/editor/src/inlay_hint_cache.rs @@ -475,10 +475,7 @@ impl InlayHintCache { let excerpt_cached_hints = excerpt_cached_hints.read(); let mut excerpt_cache = excerpt_cached_hints.ordered_hints.iter().fuse().peekable(); 
shown_excerpt_hints_to_remove.retain(|(shown_anchor, shown_hint_id)| { - let Some(buffer) = shown_anchor - .buffer_id - .and_then(|buffer_id| multi_buffer.buffer(buffer_id)) - else { + let Some(buffer) = multi_buffer.buffer_for_anchor(*shown_anchor, cx) else { return false; }; let buffer_snapshot = buffer.read(cx).snapshot(); @@ -498,16 +495,14 @@ impl InlayHintCache { cmp::Ordering::Less | cmp::Ordering::Equal => { if !old_kinds.contains(&cached_hint.kind) && new_kinds.contains(&cached_hint.kind) - { - if let Some(anchor) = multi_buffer_snapshot + && let Some(anchor) = multi_buffer_snapshot .anchor_in_excerpt(*excerpt_id, cached_hint.position) - { - to_insert.push(Inlay::hint( - cached_hint_id.id(), - anchor, - cached_hint, - )); - } + { + to_insert.push(Inlay::hint( + cached_hint_id.id(), + anchor, + cached_hint, + )); } excerpt_cache.next(); } @@ -522,16 +517,16 @@ impl InlayHintCache { for cached_hint_id in excerpt_cache { let maybe_missed_cached_hint = &excerpt_cached_hints.hints_by_id[cached_hint_id]; let cached_hint_kind = maybe_missed_cached_hint.kind; - if !old_kinds.contains(&cached_hint_kind) && new_kinds.contains(&cached_hint_kind) { - if let Some(anchor) = multi_buffer_snapshot + if !old_kinds.contains(&cached_hint_kind) + && new_kinds.contains(&cached_hint_kind) + && let Some(anchor) = multi_buffer_snapshot .anchor_in_excerpt(*excerpt_id, maybe_missed_cached_hint.position) - { - to_insert.push(Inlay::hint( - cached_hint_id.id(), - anchor, - maybe_missed_cached_hint, - )); - } + { + to_insert.push(Inlay::hint( + cached_hint_id.id(), + anchor, + maybe_missed_cached_hint, + )); } } } @@ -620,44 +615,44 @@ impl InlayHintCache { ) { if let Some(excerpt_hints) = self.hints.get(&excerpt_id) { let mut guard = excerpt_hints.write(); - if let Some(cached_hint) = guard.hints_by_id.get_mut(&id) { - if let ResolveState::CanResolve(server_id, _) = &cached_hint.resolve_state { - let hint_to_resolve = cached_hint.clone(); - let server_id = *server_id; - 
cached_hint.resolve_state = ResolveState::Resolving; - drop(guard); - cx.spawn_in(window, async move |editor, cx| { - let resolved_hint_task = editor.update(cx, |editor, cx| { - let buffer = editor.buffer().read(cx).buffer(buffer_id)?; - editor.semantics_provider.as_ref()?.resolve_inlay_hint( - hint_to_resolve, - buffer, - server_id, - cx, - ) - })?; - if let Some(resolved_hint_task) = resolved_hint_task { - let mut resolved_hint = - resolved_hint_task.await.context("hint resolve task")?; - editor.read_with(cx, |editor, _| { - if let Some(excerpt_hints) = - editor.inlay_hint_cache.hints.get(&excerpt_id) + if let Some(cached_hint) = guard.hints_by_id.get_mut(&id) + && let ResolveState::CanResolve(server_id, _) = &cached_hint.resolve_state + { + let hint_to_resolve = cached_hint.clone(); + let server_id = *server_id; + cached_hint.resolve_state = ResolveState::Resolving; + drop(guard); + cx.spawn_in(window, async move |editor, cx| { + let resolved_hint_task = editor.update(cx, |editor, cx| { + let buffer = editor.buffer().read(cx).buffer(buffer_id)?; + editor.semantics_provider.as_ref()?.resolve_inlay_hint( + hint_to_resolve, + buffer, + server_id, + cx, + ) + })?; + if let Some(resolved_hint_task) = resolved_hint_task { + let mut resolved_hint = + resolved_hint_task.await.context("hint resolve task")?; + editor.read_with(cx, |editor, _| { + if let Some(excerpt_hints) = + editor.inlay_hint_cache.hints.get(&excerpt_id) + { + let mut guard = excerpt_hints.write(); + if let Some(cached_hint) = guard.hints_by_id.get_mut(&id) + && cached_hint.resolve_state == ResolveState::Resolving { - let mut guard = excerpt_hints.write(); - if let Some(cached_hint) = guard.hints_by_id.get_mut(&id) { - if cached_hint.resolve_state == ResolveState::Resolving { - resolved_hint.resolve_state = ResolveState::Resolved; - *cached_hint = resolved_hint; - } - } + resolved_hint.resolve_state = ResolveState::Resolved; + *cached_hint = resolved_hint; } - })?; - } + } + })?; + } - anyhow::Ok(()) - 
}) - .detach_and_log_err(cx); - } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); } } } @@ -990,8 +985,8 @@ fn fetch_and_update_hints( let buffer = editor.buffer().read(cx).buffer(query.buffer_id)?; - if !editor.registered_buffers.contains_key(&query.buffer_id) { - if let Some(project) = editor.project.as_ref() { + if !editor.registered_buffers.contains_key(&query.buffer_id) + && let Some(project) = editor.project.as_ref() { project.update(cx, |project, cx| { editor.registered_buffers.insert( query.buffer_id, @@ -999,7 +994,6 @@ fn fetch_and_update_hints( ); }) } - } editor .semantics_provider @@ -1240,14 +1234,12 @@ fn apply_hint_update( .inlay_hint_cache .allowed_hint_kinds .contains(&new_hint.kind) - { - if let Some(new_hint_position) = + && let Some(new_hint_position) = multi_buffer_snapshot.anchor_in_excerpt(query.excerpt_id, new_hint.position) - { - splice - .to_insert - .push(Inlay::hint(new_inlay_id, new_hint_position, &new_hint)); - } + { + splice + .to_insert + .push(Inlay::hint(new_inlay_id, new_hint_position, &new_hint)); } let new_id = InlayId::Hint(new_inlay_id); cached_excerpt_hints.hints_by_id.insert(new_id, new_hint); @@ -1347,7 +1339,7 @@ pub mod tests { let i = task_lsp_request_count.fetch_add(1, Ordering::Release) + 1; assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(file_with_hints).unwrap(), + lsp::Uri::from_file_path(file_with_hints).unwrap(), ); Ok(Some(vec![lsp::InlayHint { position: lsp::Position::new(0, i), @@ -1457,7 +1449,7 @@ pub mod tests { async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(file_with_hints).unwrap(), + lsp::Uri::from_file_path(file_with_hints).unwrap(), ); let current_call_id = Arc::clone(&task_lsp_request_count).fetch_add(1, Ordering::SeqCst); @@ -1602,7 +1594,7 @@ pub mod tests { "Rust" => { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")) + lsp::Uri::from_file_path(path!("/a/main.rs")) .unwrap(), ); rs_lsp_request_count.fetch_add(1, 
Ordering::Release) @@ -1611,7 +1603,7 @@ pub mod tests { "Markdown" => { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/a/other.md")) + lsp::Uri::from_file_path(path!("/a/other.md")) .unwrap(), ); md_lsp_request_count.fetch_add(1, Ordering::Release) @@ -1797,7 +1789,7 @@ pub mod tests { async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(file_with_hints).unwrap(), + lsp::Uri::from_file_path(file_with_hints).unwrap(), ); Ok(Some(vec![ lsp::InlayHint { @@ -2135,7 +2127,7 @@ pub mod tests { let i = lsp_request_count.fetch_add(1, Ordering::SeqCst) + 1; assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(file_with_hints).unwrap(), + lsp::Uri::from_file_path(file_with_hints).unwrap(), ); Ok(Some(vec![lsp::InlayHint { position: lsp::Position::new(0, i), @@ -2298,7 +2290,7 @@ pub mod tests { async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); task_lsp_request_ranges.lock().push(params.range); @@ -2641,11 +2633,11 @@ pub mod tests { let task_editor_edited = Arc::clone(&closure_editor_edited); async move { let hint_text = if params.text_document.uri - == lsp::Url::from_file_path(path!("/a/main.rs")).unwrap() + == lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap() { "main hint" } else if params.text_document.uri - == lsp::Url::from_file_path(path!("/a/other.rs")).unwrap() + == lsp::Uri::from_file_path(path!("/a/other.rs")).unwrap() { "other hint" } else { @@ -2952,11 +2944,11 @@ pub mod tests { let task_editor_edited = Arc::clone(&closure_editor_edited); async move { let hint_text = if params.text_document.uri - == lsp::Url::from_file_path(path!("/a/main.rs")).unwrap() + == lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap() { "main hint" } else if params.text_document.uri - == lsp::Url::from_file_path(path!("/a/other.rs")).unwrap() + == 
lsp::Uri::from_file_path(path!("/a/other.rs")).unwrap() { "other hint" } else { @@ -3124,7 +3116,7 @@ pub mod tests { async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); let query_start = params.range.start; Ok(Some(vec![lsp::InlayHint { @@ -3196,7 +3188,7 @@ pub mod tests { async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(file_with_hints).unwrap(), + lsp::Uri::from_file_path(file_with_hints).unwrap(), ); let i = lsp_request_count.fetch_add(1, Ordering::SeqCst) + 1; @@ -3359,7 +3351,7 @@ pub mod tests { move |params, _| async move { assert_eq!( params.text_document.uri, - lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + lsp::Uri::from_file_path(path!("/a/main.rs")).unwrap(), ); Ok(Some( serde_json::from_value(json!([ diff --git a/crates/editor/src/items.rs b/crates/editor/src/items.rs index 1da82c605d5cbf0555f864efc50dac97f323f777..253d0c27518107dc1cad3733cefbfef5bc12b807 100644 --- a/crates/editor/src/items.rs +++ b/crates/editor/src/items.rs @@ -42,6 +42,7 @@ use ui::{IconDecorationKind, prelude::*}; use util::{ResultExt, TryFutureExt, paths::PathExt}; use workspace::{ CollaboratorId, ItemId, ItemNavHistory, ToolbarItemLocation, ViewId, Workspace, WorkspaceId, + invalid_buffer_view::InvalidBufferView, item::{FollowableItem, Item, ItemEvent, ProjectItem, SaveOptions}, searchable::{Direction, SearchEvent, SearchableItem, SearchableItemHandle}, }; @@ -103,9 +104,9 @@ impl FollowableItem for Editor { multibuffer = MultiBuffer::new(project.read(cx).capability()); let mut sorted_excerpts = state.excerpts.clone(); sorted_excerpts.sort_by_key(|e| e.id); - let mut sorted_excerpts = sorted_excerpts.into_iter().peekable(); + let sorted_excerpts = sorted_excerpts.into_iter().peekable(); - while let Some(excerpt) = sorted_excerpts.next() { + for excerpt in sorted_excerpts { let Ok(buffer_id) = 
BufferId::new(excerpt.buffer_id) else { continue; }; @@ -201,7 +202,7 @@ impl FollowableItem for Editor { if buffer .as_singleton() .and_then(|buffer| buffer.read(cx).file()) - .map_or(false, |file| file.is_private()) + .is_some_and(|file| file.is_private()) { return None; } @@ -293,7 +294,7 @@ impl FollowableItem for Editor { EditorEvent::ExcerptsRemoved { ids, .. } => { update .deleted_excerpts - .extend(ids.iter().map(ExcerptId::to_proto)); + .extend(ids.iter().copied().map(ExcerptId::to_proto)); true } EditorEvent::ScrollPositionChanged { autoscroll, .. } if !autoscroll => { @@ -524,8 +525,8 @@ fn serialize_selection( ) -> proto::Selection { proto::Selection { id: selection.id as u64, - start: Some(serialize_anchor(&selection.start, &buffer)), - end: Some(serialize_anchor(&selection.end, &buffer)), + start: Some(serialize_anchor(&selection.start, buffer)), + end: Some(serialize_anchor(&selection.end, buffer)), reversed: selection.reversed, } } @@ -650,10 +651,15 @@ impl Item for Editor { if let Some(path) = path_for_buffer(&self.buffer, detail, true, cx) { path.to_string_lossy().to_string().into() } else { - "untitled".into() + // Use the same logic as the displayed title for consistency + self.buffer.read(cx).title(cx).to_string().into() } } + fn suggested_filename(&self, cx: &App) -> SharedString { + self.buffer.read(cx).title(cx).to_string().into() + } + fn tab_icon(&self, _: &Window, cx: &App) -> Option { ItemSettings::get_global(cx) .file_icons @@ -674,7 +680,7 @@ impl Item for Editor { let buffer = buffer.read(cx); let path = buffer.project_path(cx)?; let buffer_id = buffer.remote_id(); - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); let entry = project.entry_for_path(&path, cx)?; let (repo, repo_path) = project .git_store() @@ -711,7 +717,7 @@ impl Item for Editor { .read(cx) .as_singleton() .and_then(|buffer| buffer.read(cx).file()) - .map_or(false, |file| file.disk_state() == DiskState::Deleted); + 
.is_some_and(|file| file.disk_state() == DiskState::Deleted); h_flex() .gap_2() @@ -770,12 +776,6 @@ impl Item for Editor { self.nav_history = Some(history); } - fn discarded(&self, _project: Entity, _: &mut Window, cx: &mut Context) { - for buffer in self.buffer().clone().read(cx).all_buffers() { - buffer.update(cx, |buffer, cx| buffer.discarded(cx)) - } - } - fn on_removed(&self, cx: &App) { self.report_editor_event(ReportEditorEvent::Closed, None, cx); } @@ -926,10 +926,10 @@ impl Item for Editor { })?; buffer .update(cx, |buffer, cx| { - if let Some(transaction) = transaction { - if !buffer.is_singleton() { - buffer.push_transaction(&transaction.0, cx); - } + if let Some(transaction) = transaction + && !buffer.is_singleton() + { + buffer.push_transaction(&transaction.0, cx); } }) .ok(); @@ -1005,24 +1005,18 @@ impl Item for Editor { ) { self.workspace = Some((workspace.weak_handle(), workspace.database_id())); if let Some(workspace) = &workspace.weak_handle().upgrade() { - cx.subscribe( - &workspace, - |editor, _, event: &workspace::Event, _cx| match event { - workspace::Event::ModalOpened => { - editor.mouse_context_menu.take(); - editor.inline_blame_popover.take(); - } - _ => {} - }, - ) + cx.subscribe(workspace, |editor, _, event: &workspace::Event, _cx| { + if let workspace::Event::ModalOpened = event { + editor.mouse_context_menu.take(); + editor.inline_blame_popover.take(); + } + }) .detach(); } } fn to_item_events(event: &EditorEvent, mut f: impl FnMut(ItemEvent)) { match event { - EditorEvent::Closed => f(ItemEvent::CloseItem), - EditorEvent::Saved | EditorEvent::TitleChanged => { f(ItemEvent::UpdateTab); f(ItemEvent::UpdateBreadcrumbs); @@ -1036,6 +1030,10 @@ impl Item for Editor { f(ItemEvent::UpdateBreadcrumbs); } + EditorEvent::BreadcrumbsChanged => { + f(ItemEvent::UpdateBreadcrumbs); + } + EditorEvent::DirtyChanged => { f(ItemEvent::UpdateTab); } @@ -1132,7 +1130,7 @@ impl SerializableItem for Editor { // First create the empty buffer let buffer = 
project - .update(cx, |project, cx| project.create_buffer(cx))? + .update(cx, |project, cx| project.create_buffer(true, cx))? .await?; // Then set the text so that the dirty bit is set correctly @@ -1240,7 +1238,7 @@ impl SerializableItem for Editor { .. } => window.spawn(cx, async move |cx| { let buffer = project - .update(cx, |project, cx| project.create_buffer(cx))? + .update(cx, |project, cx| project.create_buffer(true, cx))? .await?; cx.update(|window, cx| { @@ -1288,7 +1286,7 @@ impl SerializableItem for Editor { project .read(cx) .worktree_for_id(worktree_id, cx) - .and_then(|worktree| worktree.read(cx).absolutize(&file.path()).ok()) + .and_then(|worktree| worktree.read(cx).absolutize(file.path()).ok()) .or_else(|| { let full_path = file.full_path(cx); let project_path = project.read(cx).find_project_path(&full_path, cx)?; @@ -1366,40 +1364,47 @@ impl ProjectItem for Editor { let mut editor = Self::for_buffer(buffer.clone(), Some(project), window, cx); if let Some((excerpt_id, buffer_id, snapshot)) = editor.buffer().read(cx).snapshot(cx).as_singleton() + && WorkspaceSettings::get(None, cx).restore_on_file_reopen + && let Some(restoration_data) = Self::project_item_kind() + .and_then(|kind| pane.as_ref()?.project_item_restoration_data.get(&kind)) + .and_then(|data| data.downcast_ref::()) + .and_then(|data| { + let file = project::File::from_dyn(buffer.read(cx).file())?; + data.entries.get(&file.abs_path(cx)) + }) { - if WorkspaceSettings::get(None, cx).restore_on_file_reopen { - if let Some(restoration_data) = Self::project_item_kind() - .and_then(|kind| pane.as_ref()?.project_item_restoration_data.get(&kind)) - .and_then(|data| data.downcast_ref::()) - .and_then(|data| { - let file = project::File::from_dyn(buffer.read(cx).file())?; - data.entries.get(&file.abs_path(cx)) - }) - { - editor.fold_ranges( - clip_ranges(&restoration_data.folds, &snapshot), - false, - window, - cx, - ); - if !restoration_data.selections.is_empty() { - 
editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { - s.select_ranges(clip_ranges(&restoration_data.selections, &snapshot)); - }); - } - let (top_row, offset) = restoration_data.scroll_position; - let anchor = Anchor::in_buffer( - *excerpt_id, - buffer_id, - snapshot.anchor_before(Point::new(top_row, 0)), - ); - editor.set_scroll_anchor(ScrollAnchor { anchor, offset }, window, cx); - } + editor.fold_ranges( + clip_ranges(&restoration_data.folds, snapshot), + false, + window, + cx, + ); + if !restoration_data.selections.is_empty() { + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_ranges(clip_ranges(&restoration_data.selections, snapshot)); + }); } + let (top_row, offset) = restoration_data.scroll_position; + let anchor = Anchor::in_buffer( + *excerpt_id, + buffer_id, + snapshot.anchor_before(Point::new(top_row, 0)), + ); + editor.set_scroll_anchor(ScrollAnchor { anchor, offset }, window, cx); } editor } + + fn for_broken_project_item( + abs_path: &Path, + is_local: bool, + e: &anyhow::Error, + window: &mut Window, + cx: &mut App, + ) -> Option { + Some(InvalidBufferView::new(abs_path, is_local, e, window, cx)) + } } fn clip_ranges<'a>( diff --git a/crates/editor/src/jsx_tag_auto_close.rs b/crates/editor/src/jsx_tag_auto_close.rs index 95a792583953e02a77e592ea957b752f0f8042bb..e6c518beae3ecf3741b5f74be6087628f5231c8c 100644 --- a/crates/editor/src/jsx_tag_auto_close.rs +++ b/crates/editor/src/jsx_tag_auto_close.rs @@ -37,7 +37,7 @@ pub(crate) fn should_auto_close( let text = buffer .text_for_range(edited_range.clone()) .collect::(); - let edited_range = edited_range.to_offset(&buffer); + let edited_range = edited_range.to_offset(buffer); if !text.ends_with(">") { continue; } @@ -51,12 +51,11 @@ pub(crate) fn should_auto_close( continue; }; let mut jsx_open_tag_node = node; - if node.grammar_name() != config.open_tag_node_name { - if let Some(parent) = node.parent() { - if parent.grammar_name() == 
config.open_tag_node_name { - jsx_open_tag_node = parent; - } - } + if node.grammar_name() != config.open_tag_node_name + && let Some(parent) = node.parent() + && parent.grammar_name() == config.open_tag_node_name + { + jsx_open_tag_node = parent; } if jsx_open_tag_node.grammar_name() != config.open_tag_node_name { continue; @@ -87,9 +86,9 @@ pub(crate) fn should_auto_close( }); } if to_auto_edit.is_empty() { - return None; + None } else { - return Some(to_auto_edit); + Some(to_auto_edit) } } @@ -182,12 +181,12 @@ pub(crate) fn generate_auto_close_edits( */ { let tag_node_name_equals = |node: &Node, name: &str| { - let is_empty = name.len() == 0; + let is_empty = name.is_empty(); if let Some(node_name) = node.named_child(TS_NODE_TAG_NAME_CHILD_INDEX) { let range = node_name.byte_range(); return buffer.text_for_range(range).equals_str(name); } - return is_empty; + is_empty }; let tree_root_node = { @@ -208,7 +207,7 @@ pub(crate) fn generate_auto_close_edits( cur = descendant; } - assert!(ancestors.len() > 0); + assert!(!ancestors.is_empty()); let mut tree_root_node = open_tag; @@ -228,7 +227,7 @@ pub(crate) fn generate_auto_close_edits( let has_open_tag_with_same_tag_name = ancestor .named_child(0) .filter(|n| n.kind() == config.open_tag_node_name) - .map_or(false, |element_open_tag_node| { + .is_some_and(|element_open_tag_node| { tag_node_name_equals(&element_open_tag_node, &tag_name) }); if has_open_tag_with_same_tag_name { @@ -264,8 +263,7 @@ pub(crate) fn generate_auto_close_edits( } let is_after_open_tag = |node: &Node| { - return node.start_byte() < open_tag.start_byte() - && node.end_byte() < open_tag.start_byte(); + node.start_byte() < open_tag.start_byte() && node.end_byte() < open_tag.start_byte() }; // perf: use cursor for more efficient traversal @@ -284,10 +282,8 @@ pub(crate) fn generate_auto_close_edits( unclosed_open_tag_count -= 1; } } else if has_erroneous_close_tag && kind == erroneous_close_tag_node_name { - if tag_node_name_equals(&node, 
&tag_name) { - if !is_after_open_tag(&node) { - unclosed_open_tag_count -= 1; - } + if tag_node_name_equals(&node, &tag_name) && !is_after_open_tag(&node) { + unclosed_open_tag_count -= 1; } } else if kind == config.jsx_element_node_name { // perf: filter only open,close,element,erroneous nodes @@ -304,7 +300,7 @@ pub(crate) fn generate_auto_close_edits( let edit_range = edit_anchor..edit_anchor; edits.push((edit_range, format!("", tag_name))); } - return Ok(edits); + Ok(edits) } pub(crate) fn refresh_enabled_in_any_buffer( @@ -370,7 +366,7 @@ pub(crate) fn construct_initial_buffer_versions_map< initial_buffer_versions.insert(buffer_id, buffer_version); } } - return initial_buffer_versions; + initial_buffer_versions } pub(crate) fn handle_from( @@ -458,12 +454,9 @@ pub(crate) fn handle_from( let ensure_no_edits_since_start = || -> Option<()> { let has_edits_since_start = this .read_with(cx, |this, cx| { - this.buffer - .read(cx) - .buffer(buffer_id) - .map_or(true, |buffer| { - buffer.read(cx).has_edits_since(&buffer_version_initial) - }) + this.buffer.read(cx).buffer(buffer_id).is_none_or(|buffer| { + buffer.read(cx).has_edits_since(&buffer_version_initial) + }) }) .ok()?; @@ -514,7 +507,7 @@ pub(crate) fn handle_from( { let selections = this - .read_with(cx, |this, _| this.selections.disjoint_anchors().clone()) + .read_with(cx, |this, _| this.selections.disjoint_anchors()) .ok()?; for selection in selections.iter() { let Some(selection_buffer_offset_head) = @@ -815,10 +808,7 @@ mod jsx_tag_autoclose_tests { ); buf }); - let buffer_c = cx.new(|cx| { - let buf = language::Buffer::local("(cx) .into_iter() .any(|s| !s.is_empty()); - let has_git_repo = anchor.buffer_id.is_some_and(|buffer_id| { - project - .read(cx) - .git_store() - .read(cx) - .repository_and_path_for_buffer_id(buffer_id, cx) - .is_some() - }); + let has_git_repo = buffer + .buffer_id_for_anchor(anchor) + .is_some_and(|buffer_id| { + project + .read(cx) + .git_store() + .read(cx) + 
.repository_and_path_for_buffer_id(buffer_id, cx) + .is_some() + }); let evaluate_selection = window.is_action_available(&EvaluateSelectedText, cx); let run_to_cursor = window.is_action_available(&RunToCursor, cx); diff --git a/crates/editor/src/movement.rs b/crates/editor/src/movement.rs index fdda0e82bca6a85b25042ad7e8a662ff2fdae49d..4bd353a2873431d8102dfc15dea9a74ac2b2c241 100644 --- a/crates/editor/src/movement.rs +++ b/crates/editor/src/movement.rs @@ -4,7 +4,7 @@ use super::{Bias, DisplayPoint, DisplaySnapshot, SelectionGoal, ToDisplayPoint}; use crate::{DisplayRow, EditorStyle, ToOffset, ToPoint, scroll::ScrollAnchor}; use gpui::{Pixels, WindowTextSystem}; -use language::Point; +use language::{CharClassifier, Point}; use multi_buffer::{MultiBufferRow, MultiBufferSnapshot}; use serde::Deserialize; use workspace::searchable::Direction; @@ -289,12 +289,114 @@ pub fn previous_word_start_or_newline(map: &DisplaySnapshot, point: DisplayPoint let classifier = map.buffer_snapshot.char_classifier_at(raw_point); find_preceding_boundary_display_point(map, point, FindRange::MultiLine, |left, right| { - (classifier.kind(left) != classifier.kind(right) && !right.is_whitespace()) + (classifier.kind(left) != classifier.kind(right) && !classifier.is_whitespace(right)) || left == '\n' || right == '\n' }) } +/// Text movements are too greedy, making deletions too greedy too. 
+/// Makes deletions more ergonomic by potentially reducing the deletion range based on its text contents: +/// * whitespace sequences with length >= 2 stop the deletion after removal (despite movement jumping over the word behind the whitespaces) +/// * brackets stop the deletion after removal (despite movement currently not accounting for these and jumping over) +pub fn adjust_greedy_deletion( + map: &DisplaySnapshot, + delete_from: DisplayPoint, + delete_until: DisplayPoint, + ignore_brackets: bool, +) -> DisplayPoint { + if delete_from == delete_until { + return delete_until; + } + let is_backward = delete_from > delete_until; + let delete_range = if is_backward { + map.display_point_to_point(delete_until, Bias::Left) + .to_offset(&map.buffer_snapshot) + ..map + .display_point_to_point(delete_from, Bias::Right) + .to_offset(&map.buffer_snapshot) + } else { + map.display_point_to_point(delete_from, Bias::Left) + .to_offset(&map.buffer_snapshot) + ..map + .display_point_to_point(delete_until, Bias::Right) + .to_offset(&map.buffer_snapshot) + }; + + let trimmed_delete_range = if ignore_brackets { + delete_range + } else { + let brackets_in_delete_range = map + .buffer_snapshot + .bracket_ranges(delete_range.clone()) + .into_iter() + .flatten() + .flat_map(|(left_bracket, right_bracket)| { + [ + left_bracket.start, + left_bracket.end, + right_bracket.start, + right_bracket.end, + ] + }) + .filter(|&bracket| delete_range.start < bracket && bracket < delete_range.end); + let closest_bracket = if is_backward { + brackets_in_delete_range.max() + } else { + brackets_in_delete_range.min() + }; + + if is_backward { + closest_bracket.unwrap_or(delete_range.start)..delete_range.end + } else { + delete_range.start..closest_bracket.unwrap_or(delete_range.end) + } + }; + + let mut whitespace_sequences = Vec::new(); + let mut current_offset = trimmed_delete_range.start; + let mut whitespace_sequence_length = 0; + let mut whitespace_sequence_start = 0; + for ch in map + 
.buffer_snapshot + .text_for_range(trimmed_delete_range.clone()) + .flat_map(str::chars) + { + if ch.is_whitespace() { + if whitespace_sequence_length == 0 { + whitespace_sequence_start = current_offset; + } + whitespace_sequence_length += 1; + } else { + if whitespace_sequence_length >= 2 { + whitespace_sequences.push((whitespace_sequence_start, current_offset)); + } + whitespace_sequence_start = 0; + whitespace_sequence_length = 0; + } + current_offset += ch.len_utf8(); + } + if whitespace_sequence_length >= 2 { + whitespace_sequences.push((whitespace_sequence_start, current_offset)); + } + + let closest_whitespace_end = if is_backward { + whitespace_sequences.last().map(|&(start, _)| start) + } else { + whitespace_sequences.first().map(|&(_, end)| end) + }; + + closest_whitespace_end + .unwrap_or_else(|| { + if is_backward { + trimmed_delete_range.start + } else { + trimmed_delete_range.end + } + }) + .to_display_point(map) +} + /// Returns a position of the previous subword boundary, where a subword is defined as a run of /// word characters of the same "subkind" - where subcharacter kinds are '_' character, /// lowerspace characters and uppercase characters. 
@@ -303,15 +405,18 @@ pub fn previous_subword_start(map: &DisplaySnapshot, point: DisplayPoint) -> Dis let classifier = map.buffer_snapshot.char_classifier_at(raw_point); find_preceding_boundary_display_point(map, point, FindRange::MultiLine, |left, right| { - let is_word_start = - classifier.kind(left) != classifier.kind(right) && !right.is_whitespace(); - let is_subword_start = classifier.is_word('-') && left == '-' && right != '-' - || left == '_' && right != '_' - || left.is_lowercase() && right.is_uppercase(); - is_word_start || is_subword_start || left == '\n' + is_subword_start(left, right, &classifier) || left == '\n' }) } +pub fn is_subword_start(left: char, right: char, classifier: &CharClassifier) -> bool { + let is_word_start = classifier.kind(left) != classifier.kind(right) && !right.is_whitespace(); + let is_subword_start = classifier.is_word('-') && left == '-' && right != '-' + || left == '_' && right != '_' + || left.is_lowercase() && right.is_uppercase(); + is_word_start || is_subword_start +} + /// Returns a position of the next word boundary, where a word character is defined as either /// uppercase letter, lowercase letter, '_' character or language-specific word character (like '-' in CSS). 
pub fn next_word_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint { @@ -361,15 +466,19 @@ pub fn next_subword_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPo let classifier = map.buffer_snapshot.char_classifier_at(raw_point); find_boundary(map, point, FindRange::MultiLine, |left, right| { - let is_word_end = - (classifier.kind(left) != classifier.kind(right)) && !classifier.is_whitespace(left); - let is_subword_end = classifier.is_word('-') && left != '-' && right == '-' - || left != '_' && right == '_' - || left.is_lowercase() && right.is_uppercase(); - is_word_end || is_subword_end || right == '\n' + is_subword_end(left, right, &classifier) || right == '\n' }) } +pub fn is_subword_end(left: char, right: char, classifier: &CharClassifier) -> bool { + let is_word_end = + (classifier.kind(left) != classifier.kind(right)) && !classifier.is_whitespace(left); + let is_subword_end = classifier.is_word('-') && left != '-' && right == '-' + || left != '_' && right == '_' + || left.is_lowercase() && right.is_uppercase(); + is_word_end || is_subword_end +} + /// Returns a position of the start of the current paragraph, where a paragraph /// is defined as a run of non-blank lines. 
pub fn start_of_paragraph( @@ -439,17 +548,17 @@ pub fn start_of_excerpt( }; match direction { Direction::Prev => { - let mut start = excerpt.start_anchor().to_display_point(&map); + let mut start = excerpt.start_anchor().to_display_point(map); if start >= display_point && start.row() > DisplayRow(0) { let Some(excerpt) = map.buffer_snapshot.excerpt_before(excerpt.id()) else { return display_point; }; - start = excerpt.start_anchor().to_display_point(&map); + start = excerpt.start_anchor().to_display_point(map); } start } Direction::Next => { - let mut end = excerpt.end_anchor().to_display_point(&map); + let mut end = excerpt.end_anchor().to_display_point(map); *end.row_mut() += 1; map.clip_point(end, Bias::Right) } @@ -467,7 +576,7 @@ pub fn end_of_excerpt( }; match direction { Direction::Prev => { - let mut start = excerpt.start_anchor().to_display_point(&map); + let mut start = excerpt.start_anchor().to_display_point(map); if start.row() > DisplayRow(0) { *start.row_mut() -= 1; } @@ -476,7 +585,7 @@ pub fn end_of_excerpt( start } Direction::Next => { - let mut end = excerpt.end_anchor().to_display_point(&map); + let mut end = excerpt.end_anchor().to_display_point(map); *end.column_mut() = 0; if end <= display_point { *end.row_mut() += 1; @@ -485,7 +594,7 @@ pub fn end_of_excerpt( else { return display_point; }; - end = excerpt.end_anchor().to_display_point(&map); + end = excerpt.end_anchor().to_display_point(map); *end.column_mut() = 0; } end @@ -510,10 +619,10 @@ pub fn find_preceding_boundary_point( if find_range == FindRange::SingleLine && ch == '\n' { break; } - if let Some(prev_ch) = prev_ch { - if is_boundary(ch, prev_ch) { - break; - } + if let Some(prev_ch) = prev_ch + && is_boundary(ch, prev_ch) + { + break; } offset -= ch.len_utf8(); @@ -562,13 +671,13 @@ pub fn find_boundary_point( if find_range == FindRange::SingleLine && ch == '\n' { break; } - if let Some(prev_ch) = prev_ch { - if is_boundary(prev_ch, ch) { - if return_point_before_boundary { - 
return map.clip_point(prev_offset.to_display_point(map), Bias::Right); - } else { - break; - } + if let Some(prev_ch) = prev_ch + && is_boundary(prev_ch, ch) + { + if return_point_before_boundary { + return map.clip_point(prev_offset.to_display_point(map), Bias::Right); + } else { + break; } } prev_offset = offset; @@ -603,13 +712,13 @@ pub fn find_preceding_boundary_trail( // Find the boundary let start_offset = offset; for ch in forward { - if let Some(prev_ch) = prev_ch { - if is_boundary(prev_ch, ch) { - if start_offset == offset { - trail_offset = Some(offset); - } else { - break; - } + if let Some(prev_ch) = prev_ch + && is_boundary(prev_ch, ch) + { + if start_offset == offset { + trail_offset = Some(offset); + } else { + break; } } offset -= ch.len_utf8(); @@ -651,13 +760,13 @@ pub fn find_boundary_trail( // Find the boundary let start_offset = offset; for ch in forward { - if let Some(prev_ch) = prev_ch { - if is_boundary(prev_ch, ch) { - if start_offset == offset { - trail_offset = Some(offset); - } else { - break; - } + if let Some(prev_ch) = prev_ch + && is_boundary(prev_ch, ch) + { + if start_offset == offset { + trail_offset = Some(offset); + } else { + break; } } offset += ch.len_utf8(); diff --git a/crates/editor/src/persistence.rs b/crates/editor/src/persistence.rs index 88fde539479b3159a2fbcb7e3b0473d4b4b91e76..ec7c149b4e107600c35e70ef3dffcdb2e8f8bcb7 100644 --- a/crates/editor/src/persistence.rs +++ b/crates/editor/src/persistence.rs @@ -1,13 +1,17 @@ use anyhow::Result; -use db::sqlez::bindable::{Bind, Column, StaticColumnCount}; -use db::sqlez::statement::Statement; +use db::{ + query, + sqlez::{ + bindable::{Bind, Column, StaticColumnCount}, + domain::Domain, + statement::Statement, + }, + sqlez_macros::sql, +}; use fs::MTime; use itertools::Itertools as _; use std::path::PathBuf; -use db::sqlez_macros::sql; -use db::{define_connection, query}; - use workspace::{ItemId, WorkspaceDb, WorkspaceId}; #[derive(Clone, Debug, PartialEq, Default)] @@ 
-83,7 +87,11 @@ impl Column for SerializedEditor { } } -define_connection!( +pub struct EditorDb(db::sqlez::thread_safe_connection::ThreadSafeConnection); + +impl Domain for EditorDb { + const NAME: &str = stringify!(EditorDb); + // Current schema shape using pseudo-rust syntax: // editors( // item_id: usize, @@ -113,7 +121,8 @@ define_connection!( // start: usize, // end: usize, // ) - pub static ref DB: EditorDb = &[ + + const MIGRATIONS: &[&str] = &[ sql! ( CREATE TABLE editors( item_id INTEGER NOT NULL, @@ -189,7 +198,9 @@ define_connection!( ) STRICT; ), ]; -); +} + +db::static_connection!(DB, EditorDb, [WorkspaceDb]); // https://www.sqlite.org/limits.html // > <..> the maximum value of a host parameter number is SQLITE_MAX_VARIABLE_NUMBER, diff --git a/crates/editor/src/proposed_changes_editor.rs b/crates/editor/src/proposed_changes_editor.rs index 1ead45b3de89c0705510f8afc55ecf6176a4d7a2..2d4710a8d44a023f0c3206ad0c327a34c36fdac4 100644 --- a/crates/editor/src/proposed_changes_editor.rs +++ b/crates/editor/src/proposed_changes_editor.rs @@ -241,24 +241,13 @@ impl ProposedChangesEditor { event: &BufferEvent, _cx: &mut Context, ) { - match event { - BufferEvent::Operation { .. } => { - self.recalculate_diffs_tx - .unbounded_send(RecalculateDiff { - buffer, - debounce: true, - }) - .ok(); - } - // BufferEvent::DiffBaseChanged => { - // self.recalculate_diffs_tx - // .unbounded_send(RecalculateDiff { - // buffer, - // debounce: false, - // }) - // .ok(); - // } - _ => (), + if let BufferEvent::Operation { .. 
} = event { + self.recalculate_diffs_tx + .unbounded_send(RecalculateDiff { + buffer, + debounce: true, + }) + .ok(); } } } @@ -442,7 +431,7 @@ impl SemanticsProvider for BranchBufferSemanticsProvider { buffer: &Entity, position: text::Anchor, cx: &mut App, - ) -> Option>> { + ) -> Option>>> { let buffer = self.to_base(buffer, &[position], cx)?; self.0.hover(&buffer, position, cx) } @@ -478,7 +467,7 @@ impl SemanticsProvider for BranchBufferSemanticsProvider { } fn supports_inlay_hints(&self, buffer: &Entity, cx: &mut App) -> bool { - if let Some(buffer) = self.to_base(&buffer, &[], cx) { + if let Some(buffer) = self.to_base(buffer, &[], cx) { self.0.supports_inlay_hints(&buffer, cx) } else { false @@ -491,7 +480,7 @@ impl SemanticsProvider for BranchBufferSemanticsProvider { position: text::Anchor, cx: &mut App, ) -> Option>>> { - let buffer = self.to_base(&buffer, &[position], cx)?; + let buffer = self.to_base(buffer, &[position], cx)?; self.0.document_highlights(&buffer, position, cx) } @@ -501,8 +490,8 @@ impl SemanticsProvider for BranchBufferSemanticsProvider { position: text::Anchor, kind: crate::GotoDefinitionKind, cx: &mut App, - ) -> Option>>> { - let buffer = self.to_base(&buffer, &[position], cx)?; + ) -> Option>>>> { + let buffer = self.to_base(buffer, &[position], cx)?; self.0.definitions(&buffer, position, kind, cx) } diff --git a/crates/editor/src/rust_analyzer_ext.rs b/crates/editor/src/rust_analyzer_ext.rs index 2b8150de67050ccced22100bfedd02be44f63907..f4059ca03d2ad70106aa958b4fe0c545cb4988ea 100644 --- a/crates/editor/src/rust_analyzer_ext.rs +++ b/crates/editor/src/rust_analyzer_ext.rs @@ -26,6 +26,17 @@ fn is_rust_language(language: &Language) -> bool { } pub fn apply_related_actions(editor: &Entity, window: &mut Window, cx: &mut App) { + if editor.read(cx).project().is_some_and(|project| { + project + .read(cx) + .language_server_statuses(cx) + .any(|(_, status)| status.name == RUST_ANALYZER_NAME) + }) { + register_action(editor, window, 
cancel_flycheck_action); + register_action(editor, window, run_flycheck_action); + register_action(editor, window, clear_flycheck_action); + } + if editor .read(cx) .buffer() @@ -35,12 +46,9 @@ pub fn apply_related_actions(editor: &Entity, window: &mut Window, cx: & .filter_map(|buffer| buffer.read(cx).language()) .any(|language| is_rust_language(language)) { - register_action(&editor, window, go_to_parent_module); - register_action(&editor, window, expand_macro_recursively); - register_action(&editor, window, open_docs); - register_action(&editor, window, cancel_flycheck_action); - register_action(&editor, window, run_flycheck_action); - register_action(&editor, window, clear_flycheck_action); + register_action(editor, window, go_to_parent_module); + register_action(editor, window, expand_macro_recursively); + register_action(editor, window, open_docs); } } @@ -192,7 +200,7 @@ pub fn expand_macro_recursively( } let buffer = project - .update(cx, |project, cx| project.create_buffer(cx))? + .update(cx, |project, cx| project.create_buffer(false, cx))? .await?; workspace.update_in(cx, |workspace, window, cx| { buffer.update(cx, |buffer, cx| { @@ -285,11 +293,11 @@ pub fn open_docs(editor: &mut Editor, _: &OpenDocs, window: &mut Window, cx: &mu workspace.update(cx, |_workspace, cx| { // Check if the local document exists, otherwise fallback to the online document. // Open with the default browser. 
- if let Some(local_url) = docs_urls.local { - if fs::metadata(Path::new(&local_url[8..])).is_ok() { - cx.open_url(&local_url); - return; - } + if let Some(local_url) = docs_urls.local + && fs::metadata(Path::new(&local_url[8..])).is_ok() + { + cx.open_url(&local_url); + return; } if let Some(web_url) = docs_urls.web { @@ -309,7 +317,7 @@ fn cancel_flycheck_action( let Some(project) = &editor.project else { return; }; - let Some(buffer_id) = editor + let buffer_id = editor .selections .disjoint_anchors() .iter() @@ -321,10 +329,7 @@ fn cancel_flycheck_action( .read(cx) .entry_id(cx)?; project.path_for_entry(entry_id, cx) - }) - else { - return; - }; + }); cancel_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx); } @@ -337,7 +342,7 @@ fn run_flycheck_action( let Some(project) = &editor.project else { return; }; - let Some(buffer_id) = editor + let buffer_id = editor .selections .disjoint_anchors() .iter() @@ -349,10 +354,7 @@ fn run_flycheck_action( .read(cx) .entry_id(cx)?; project.path_for_entry(entry_id, cx) - }) - else { - return; - }; + }); run_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx); } @@ -365,7 +367,7 @@ fn clear_flycheck_action( let Some(project) = &editor.project else { return; }; - let Some(buffer_id) = editor + let buffer_id = editor .selections .disjoint_anchors() .iter() @@ -377,9 +379,6 @@ fn clear_flycheck_action( .read(cx) .entry_id(cx)?; project.path_for_entry(entry_id, cx) - }) - else { - return; - }; + }); clear_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx); } diff --git a/crates/editor/src/scroll.rs b/crates/editor/src/scroll.rs index 08ff23f8f70be4e512826c2793a2d95e2aee1690..82314486187db99c2ba5c104faa42828dad57cdb 100644 --- a/crates/editor/src/scroll.rs +++ b/crates/editor/src/scroll.rs @@ -675,7 +675,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if matches!(self.mode, EditorMode::SingleLine { .. 
}) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } @@ -703,20 +703,20 @@ impl Editor { if matches!( settings.defaults.soft_wrap, SoftWrap::PreferredLineLength | SoftWrap::Bounded - ) { - if (settings.defaults.preferred_line_length as f32) < visible_column_count { - visible_column_count = settings.defaults.preferred_line_length as f32; - } + ) && (settings.defaults.preferred_line_length as f32) < visible_column_count + { + visible_column_count = settings.defaults.preferred_line_length as f32; } // If the scroll position is currently at the left edge of the document // (x == 0.0) and the intent is to scroll right, the gutter's margin // should first be added to the current position, otherwise the cursor // will end at the column position minus the margin, which looks off. - if current_position.x == 0.0 && amount.columns(visible_column_count) > 0. { - if let Some(last_position_map) = &self.last_position_map { - current_position.x += self.gutter_dimensions.margin / last_position_map.em_advance; - } + if current_position.x == 0.0 + && amount.columns(visible_column_count) > 0. 
+ && let Some(last_position_map) = &self.last_position_map + { + current_position.x += self.gutter_dimensions.margin / last_position_map.em_advance; } let new_position = current_position + point( @@ -749,12 +749,10 @@ impl Editor { if let (Some(visible_lines), Some(visible_columns)) = (self.visible_line_count(), self.visible_column_count()) + && newest_head.row() <= DisplayRow(screen_top.row().0 + visible_lines as u32) + && newest_head.column() <= screen_top.column() + visible_columns as u32 { - if newest_head.row() <= DisplayRow(screen_top.row().0 + visible_lines as u32) - && newest_head.column() <= screen_top.column() + visible_columns as u32 - { - return Ordering::Equal; - } + return Ordering::Equal; } Ordering::Greater diff --git a/crates/editor/src/scroll/actions.rs b/crates/editor/src/scroll/actions.rs index 72827b2fee48c424a632018b5f66015cd058ed79..f8104665f904e08466c72f3c410e58cb941c6b6f 100644 --- a/crates/editor/src/scroll/actions.rs +++ b/crates/editor/src/scroll/actions.rs @@ -16,7 +16,7 @@ impl Editor { return; } - if matches!(self.mode, EditorMode::SingleLine { .. }) { + if matches!(self.mode, EditorMode::SingleLine) { cx.propagate(); return; } diff --git a/crates/editor/src/scroll/autoscroll.rs b/crates/editor/src/scroll/autoscroll.rs index 88d3b52d764d15280c8ed03dd87f42b8c32d0911..057d622903ed12b4d996759cd93dc76f2ba9ee8d 100644 --- a/crates/editor/src/scroll/autoscroll.rs +++ b/crates/editor/src/scroll/autoscroll.rs @@ -116,12 +116,12 @@ impl Editor { let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); let mut scroll_position = self.scroll_manager.scroll_position(&display_map); let original_y = scroll_position.y; - if let Some(last_bounds) = self.expect_bounds_change.take() { - if scroll_position.y != 0. { - scroll_position.y += (bounds.top() - last_bounds.top()) / line_height; - if scroll_position.y < 0. { - scroll_position.y = 0.; - } + if let Some(last_bounds) = self.expect_bounds_change.take() + && scroll_position.y != 0. 
+ { + scroll_position.y += (bounds.top() - last_bounds.top()) / line_height; + if scroll_position.y < 0. { + scroll_position.y = 0.; } } if scroll_position.y > max_scroll_top { diff --git a/crates/editor/src/scroll/scroll_amount.rs b/crates/editor/src/scroll/scroll_amount.rs index b2af4f8e4fbce899c6aee317402ee1365cee8600..43f1aa128548597ee07cbb297ab5aaf0e8f79b6e 100644 --- a/crates/editor/src/scroll/scroll_amount.rs +++ b/crates/editor/src/scroll/scroll_amount.rs @@ -15,7 +15,7 @@ impl ScrollDirection { } } -#[derive(Debug, Clone, PartialEq, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Deserialize)] pub enum ScrollAmount { // Scroll N lines (positive is towards the end of the document) Line(f32), @@ -67,10 +67,7 @@ impl ScrollAmount { } pub fn is_full_page(&self) -> bool { - match self { - ScrollAmount::Page(count) if count.abs() == 1.0 => true, - _ => false, - } + matches!(self, ScrollAmount::Page(count) if count.abs() == 1.0) } pub fn direction(&self) -> ScrollDirection { diff --git a/crates/editor/src/selections_collection.rs b/crates/editor/src/selections_collection.rs index 73c5f1c076e510b2aeb7d648b7ce066b65f9094c..0a02390b641e1020aff8d9cf0167b44485baf489 100644 --- a/crates/editor/src/selections_collection.rs +++ b/crates/editor/src/selections_collection.rs @@ -119,8 +119,8 @@ impl SelectionsCollection { cx: &mut App, ) -> Option> { let map = self.display_map(cx); - let selection = resolve_selections(self.pending_anchor().as_ref(), &map).next(); - selection + + resolve_selections(self.pending_anchor().as_ref(), &map).next() } pub(crate) fn pending_mode(&self) -> Option { @@ -276,18 +276,18 @@ impl SelectionsCollection { cx: &mut App, ) -> Selection { let map = self.display_map(cx); - let selection = resolve_selections([self.newest_anchor()], &map) + + resolve_selections([self.newest_anchor()], &map) .next() - .unwrap(); - selection + .unwrap() } pub fn newest_display(&self, cx: &mut App) -> Selection { let map = self.display_map(cx); - let selection 
= resolve_selections_display([self.newest_anchor()], &map) + + resolve_selections_display([self.newest_anchor()], &map) .next() - .unwrap(); - selection + .unwrap() } pub fn oldest_anchor(&self) -> &Selection { @@ -303,10 +303,10 @@ impl SelectionsCollection { cx: &mut App, ) -> Selection { let map = self.display_map(cx); - let selection = resolve_selections([self.oldest_anchor()], &map) + + resolve_selections([self.oldest_anchor()], &map) .next() - .unwrap(); - selection + .unwrap() } pub fn first_anchor(&self) -> Selection { diff --git a/crates/editor/src/signature_help.rs b/crates/editor/src/signature_help.rs index e9f8d2dbd33f71e224ae1c868dab80a7c4bb467a..cb21f35d7ed7556cf09f9e566286a10f8317ca6c 100644 --- a/crates/editor/src/signature_help.rs +++ b/crates/editor/src/signature_help.rs @@ -169,7 +169,7 @@ impl Editor { else { return; }; - let Some(lsp_store) = self.project.as_ref().map(|p| p.read(cx).lsp_store()) else { + let Some(lsp_store) = self.project().map(|p| p.read(cx).lsp_store()) else { return; }; let task = lsp_store.update(cx, |lsp_store, cx| { @@ -182,7 +182,9 @@ impl Editor { let signature_help = task.await; editor .update(cx, |editor, cx| { - let Some(mut signature_help) = signature_help.into_iter().next() else { + let Some(mut signature_help) = + signature_help.unwrap_or_default().into_iter().next() + else { editor .signature_help_state .hide(SignatureHelpHiddenBy::AutoClose); @@ -196,7 +198,7 @@ impl Editor { .highlight_text(&text, 0..signature.label.len()) .into_iter() .flat_map(|(range, highlight_id)| { - Some((range, highlight_id.style(&cx.theme().syntax())?)) + Some((range, highlight_id.style(cx.theme().syntax())?)) }); signature.highlights = combine_highlights(signature.highlights.clone(), highlights) diff --git a/crates/editor/src/tasks.rs b/crates/editor/src/tasks.rs index 0d497e4cac779a65b7a6593d3b82f786d10321ce..8be2a3a2e14d7b815d2ca3496adc6f70ec16055e 100644 --- a/crates/editor/src/tasks.rs +++ b/crates/editor/src/tasks.rs @@ -89,7 
+89,7 @@ impl Editor { .lsp_task_source()?; if lsp_settings .get(&lsp_tasks_source) - .map_or(true, |s| s.enable_lsp_tasks) + .is_none_or(|s| s.enable_lsp_tasks) { let buffer_id = buffer.read(cx).remote_id(); Some((lsp_tasks_source, buffer_id)) diff --git a/crates/editor/src/test.rs b/crates/editor/src/test.rs index f328945dbe6ae961d3fcb1ef5c80055b6adb0afb..03e99b9fff9a89fcac28605fe6bf7a08b23f8f02 100644 --- a/crates/editor/src/test.rs +++ b/crates/editor/src/test.rs @@ -20,7 +20,7 @@ use multi_buffer::ToPoint; use pretty_assertions::assert_eq; use project::{Project, project_settings::DiagnosticSeverity}; use ui::{App, BorrowAppContext, px}; -use util::test::{marked_text_offsets, marked_text_ranges}; +use util::test::{generate_marked_text, marked_text_offsets, marked_text_ranges}; #[cfg(test)] #[ctor::ctor] @@ -104,13 +104,14 @@ pub fn assert_text_with_selections( marked_text: &str, cx: &mut Context, ) { - let (unmarked_text, text_ranges) = marked_text_ranges(marked_text, true); + let (unmarked_text, _text_ranges) = marked_text_ranges(marked_text, true); assert_eq!(editor.text(cx), unmarked_text, "text doesn't match"); - assert_eq!( - editor.selections.ranges(cx), - text_ranges, - "selections don't match", + let actual = generate_marked_text( + &editor.text(cx), + &editor.selections.ranges(cx), + marked_text.contains("«"), ); + assert_eq!(actual, marked_text, "Selections don't match"); } // RA thinks this is dead code even though it is used in a whole lot of tests @@ -184,12 +185,12 @@ pub fn editor_content_with_blocks(editor: &Entity, cx: &mut VisualTestCo for (row, block) in blocks { match block { Block::Custom(custom_block) => { - if let BlockPlacement::Near(x) = &custom_block.placement { - if snapshot.intersects_fold(x.to_point(&snapshot.buffer_snapshot)) { - continue; - } + if let BlockPlacement::Near(x) = &custom_block.placement + && snapshot.intersects_fold(x.to_point(&snapshot.buffer_snapshot)) + { + continue; }; - let content = 
block_content_for_tests(&editor, custom_block.id, cx) + let content = block_content_for_tests(editor, custom_block.id, cx) .expect("block content not found"); // 2: "related info 1 for diagnostic 0" if let Some(height) = custom_block.height { @@ -230,26 +231,23 @@ pub fn editor_content_with_blocks(editor: &Entity, cx: &mut VisualTestCo lines[row as usize].push_str("§ -----"); } } - Block::ExcerptBoundary { - excerpt, - height, - starts_new_buffer, - } => { - if starts_new_buffer { - lines[row.0 as usize].push_str(&cx.update(|_, cx| { - format!( - "§ {}", - excerpt - .buffer - .file() - .unwrap() - .file_name(cx) - .to_string_lossy() - ) - })); - } else { - lines[row.0 as usize].push_str("§ -----") + Block::ExcerptBoundary { height, .. } => { + for row in row.0..row.0 + height { + lines[row as usize].push_str("§ -----"); } + } + Block::BufferHeader { excerpt, height } => { + lines[row.0 as usize].push_str(&cx.update(|_, cx| { + format!( + "§ {}", + excerpt + .buffer + .file() + .unwrap() + .file_name(cx) + .to_string_lossy() + ) + })); for row in row.0 + 1..row.0 + height { lines[row as usize].push_str("§ -----"); } diff --git a/crates/editor/src/test/editor_lsp_test_context.rs b/crates/editor/src/test/editor_lsp_test_context.rs index c59786b1eb387835a21e2c155efaf6acefd4ff4a..79935340358662350dbbc640d96f5d60ec8aaf6b 100644 --- a/crates/editor/src/test/editor_lsp_test_context.rs +++ b/crates/editor/src/test/editor_lsp_test_context.rs @@ -29,7 +29,7 @@ pub struct EditorLspTestContext { pub cx: EditorTestContext, pub lsp: lsp::FakeLanguageServer, pub workspace: Entity, - pub buffer_lsp_url: lsp::Url, + pub buffer_lsp_url: lsp::Uri, } pub(crate) fn rust_lang() -> Arc { @@ -189,7 +189,7 @@ impl EditorLspTestContext { }, lsp, workspace, - buffer_lsp_url: lsp::Url::from_file_path(root.join("dir").join(file_name)).unwrap(), + buffer_lsp_url: lsp::Uri::from_file_path(root.join("dir").join(file_name)).unwrap(), } } @@ -300,6 +300,7 @@ impl EditorLspTestContext { 
self.to_lsp_range(ranges[0].clone()) } + #[expect(clippy::wrong_self_convention, reason = "This is test code")] pub fn to_lsp_range(&mut self, range: Range) -> lsp::Range { let snapshot = self.update_editor(|editor, window, cx| editor.snapshot(window, cx)); let start_point = range.start.to_point(&snapshot.buffer_snapshot); @@ -326,6 +327,7 @@ impl EditorLspTestContext { }) } + #[expect(clippy::wrong_self_convention, reason = "This is test code")] pub fn to_lsp(&mut self, offset: usize) -> lsp::Position { let snapshot = self.update_editor(|editor, window, cx| editor.snapshot(window, cx)); let point = offset.to_point(&snapshot.buffer_snapshot); @@ -356,7 +358,7 @@ impl EditorLspTestContext { where T: 'static + request::Request, T::Params: 'static + Send, - F: 'static + Send + FnMut(lsp::Url, T::Params, gpui::AsyncApp) -> Fut, + F: 'static + Send + FnMut(lsp::Uri, T::Params, gpui::AsyncApp) -> Fut, Fut: 'static + Future>, { let url = self.buffer_lsp_url.clone(); diff --git a/crates/editor/src/test/editor_test_context.rs b/crates/editor/src/test/editor_test_context.rs index bdf73da5fbfd5d4c29826859790493fbb8494239..8c54c265edf7a19af9d17e982a5f4cb6a0079cc3 100644 --- a/crates/editor/src/test/editor_test_context.rs +++ b/crates/editor/src/test/editor_test_context.rs @@ -119,13 +119,7 @@ impl EditorTestContext { for excerpt in excerpts.into_iter() { let (text, ranges) = marked_text_ranges(excerpt, false); let buffer = cx.new(|cx| Buffer::local(text, cx)); - multibuffer.push_excerpts( - buffer, - ranges - .into_iter() - .map(|range| ExcerptRange::new(range.clone())), - cx, - ); + multibuffer.push_excerpts(buffer, ranges.into_iter().map(ExcerptRange::new), cx); } multibuffer }); @@ -297,9 +291,8 @@ impl EditorTestContext { pub fn set_head_text(&mut self, diff_base: &str) { self.cx.run_until_parked(); - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| 
editor.project().unwrap().read(cx).fs().as_fake()); let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone()); fs.set_head_for_repo( &Self::root_path().join(".git"), @@ -311,18 +304,16 @@ impl EditorTestContext { pub fn clear_index_text(&mut self) { self.cx.run_until_parked(); - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); fs.set_index_for_repo(&Self::root_path().join(".git"), &[]); self.cx.run_until_parked(); } pub fn set_index_text(&mut self, diff_base: &str) { self.cx.run_until_parked(); - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone()); fs.set_index_for_repo( &Self::root_path().join(".git"), @@ -333,9 +324,8 @@ impl EditorTestContext { #[track_caller] pub fn assert_index_text(&mut self, expected: Option<&str>) { - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone()); let mut found = None; fs.with_git_state(&Self::root_path().join(".git"), false, |git_state| { @@ -430,7 +420,7 @@ impl EditorTestContext { if expected_text == "[FOLDED]\n" { assert!(is_folded, "excerpt {} should be folded", ix); let is_selected = selections.iter().any(|s| s.head().excerpt_id == excerpt_id); - if expected_selections.len() > 0 { + if !expected_selections.is_empty() { assert!( is_selected, "excerpt {ix} should be selected. 
got {:?}", diff --git a/crates/eval/src/assertions.rs b/crates/eval/src/assertions.rs index 489e4aa22ecdc6633a0002238a2287ca0a5105f0..01fac186d33a8b5b156121acf924d37c90c64679 100644 --- a/crates/eval/src/assertions.rs +++ b/crates/eval/src/assertions.rs @@ -54,7 +54,7 @@ impl AssertionsReport { pub fn passed_count(&self) -> usize { self.ran .iter() - .filter(|a| a.result.as_ref().map_or(false, |result| result.passed)) + .filter(|a| a.result.as_ref().is_ok_and(|result| result.passed)) .count() } diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index 6558222d89769f329ce50c238ad145e5d6aebc0f..9e0504abca479483b4e5f49c41eec1f6ba3834f1 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -103,7 +103,7 @@ fn main() { let languages: HashSet = args.languages.into_iter().collect(); let http_client = Arc::new(ReqwestClient::new()); - let app = Application::headless().with_http_client(http_client.clone()); + let app = Application::headless().with_http_client(http_client); let all_threads = examples::all(&examples_dir); app.run(move |cx| { @@ -112,7 +112,7 @@ fn main() { let telemetry = app_state.client.telemetry(); telemetry.start(system_id, installation_id, session_id, cx); - let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").map_or(false, |value| value == "1") + let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").is_ok_and(|value| value == "1") && telemetry.has_checksum_seed(); if enable_telemetry { println!("Telemetry enabled"); @@ -167,15 +167,14 @@ fn main() { continue; } - if let Some(language) = meta.language_server { - if !languages.contains(&language.file_extension) { + if let Some(language) = meta.language_server + && !languages.contains(&language.file_extension) { panic!( "Eval for {:?} could not be run because no language server was found for extension {:?}", meta.name, language.file_extension ); } - } // TODO: This creates a worktree per repetition. 
Ideally these examples should // either be run sequentially on the same worktree, or reuse worktrees when there @@ -417,11 +416,7 @@ pub fn init(cx: &mut App) -> Arc { language::init(cx); debug_adapter_extension::init(extension_host_proxy.clone(), cx); - language_extension::init( - LspAccess::Noop, - extension_host_proxy.clone(), - languages.clone(), - ); + language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone()); language_model::init(client.clone(), cx); language_models::init(user_store.clone(), client.clone(), cx); languages::init(languages.clone(), node_runtime.clone(), cx); @@ -520,7 +515,7 @@ async fn judge_example( enable_telemetry: bool, cx: &AsyncApp, ) -> JudgeOutput { - let judge_output = example.judge(model.clone(), &run_output, cx).await; + let judge_output = example.judge(model.clone(), run_output, cx).await; if enable_telemetry { telemetry::event!( @@ -531,7 +526,7 @@ async fn judge_example( example_name = example.name.clone(), example_repetition = example.repetition, diff_evaluation = judge_output.diff.clone(), - thread_evaluation = judge_output.thread.clone(), + thread_evaluation = judge_output.thread, tool_metrics = run_output.tool_metrics, response_count = run_output.response_count, token_usage = run_output.token_usage, @@ -711,7 +706,7 @@ fn print_report( println!("Average thread score: {average_thread_score}%"); } - println!(""); + println!(); print_h2("CUMULATIVE TOOL METRICS"); println!("{}", cumulative_tool_metrics); diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 23c8814916da2df4016c4196d7767b748da54280..457b62e98ca4cabf83fb379cbaa70f07957ac6b7 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -64,7 +64,7 @@ impl ExampleMetadata { self.url .split('/') .next_back() - .unwrap_or(&"") + .unwrap_or("") .trim_end_matches(".git") .into() } @@ -255,7 +255,7 @@ impl ExampleContext { thread.update(cx, |thread, _cx| { if let Some(tool_use) = pending_tool_use { let mut 
tool_metrics = tool_metrics.lock().unwrap(); - if let Some(tool_result) = thread.tool_result(&tool_use_id) { + if let Some(tool_result) = thread.tool_result(tool_use_id) { let message = if tool_result.is_error { format!("✖︎ {}", tool_use.name) } else { @@ -335,7 +335,7 @@ impl ExampleContext { for message in thread.messages().skip(message_count_before) { messages.push(Message { _role: message.role, - text: message.to_string(), + text: message.to_message_content(), tool_use: thread .tool_uses_for_message(message.id, cx) .into_iter() diff --git a/crates/eval/src/examples/add_arg_to_trait_method.rs b/crates/eval/src/examples/add_arg_to_trait_method.rs index 9c538f926059eb3998eb725168905d148dccdc9d..084f12bc6263da030d313c362cc3d051dfdb8ea8 100644 --- a/crates/eval/src/examples/add_arg_to_trait_method.rs +++ b/crates/eval/src/examples/add_arg_to_trait_method.rs @@ -70,10 +70,10 @@ impl Example for AddArgToTraitMethod { let path_str = format!("crates/assistant_tools/src/{}.rs", tool_name); let edits = edits.get(Path::new(&path_str)); - let ignored = edits.map_or(false, |edits| { + let ignored = edits.is_some_and(|edits| { edits.has_added_line(" _window: Option,\n") }); - let uningored = edits.map_or(false, |edits| { + let uningored = edits.is_some_and(|edits| { edits.has_added_line(" window: Option,\n") }); @@ -89,7 +89,7 @@ impl Example for AddArgToTraitMethod { let batch_tool_edits = edits.get(Path::new("crates/assistant_tools/src/batch_tool.rs")); cx.assert( - batch_tool_edits.map_or(false, |edits| { + batch_tool_edits.is_some_and(|edits| { edits.has_added_line(" window: Option,\n") }), "Argument: batch_tool", diff --git a/crates/eval/src/explorer.rs b/crates/eval/src/explorer.rs index ee1dfa95c3840af42bdd134be1110bd2483c97aa..3326070cea4e860210f8ba7e0038fec2f3404c30 100644 --- a/crates/eval/src/explorer.rs +++ b/crates/eval/src/explorer.rs @@ -46,27 +46,25 @@ fn find_target_files_recursive( max_depth, found_files, )?; - } else if path.is_file() { - if let 
Some(filename_osstr) = path.file_name() { - if let Some(filename_str) = filename_osstr.to_str() { - if filename_str == target_filename { - found_files.push(path); - } - } - } + } else if path.is_file() + && let Some(filename_osstr) = path.file_name() + && let Some(filename_str) = filename_osstr.to_str() + && filename_str == target_filename + { + found_files.push(path); } } Ok(()) } pub fn generate_explorer_html(input_paths: &[PathBuf], output_path: &PathBuf) -> Result { - if let Some(parent) = output_path.parent() { - if !parent.exists() { - fs::create_dir_all(parent).context(format!( - "Failed to create output directory: {}", - parent.display() - ))?; - } + if let Some(parent) = output_path.parent() + && !parent.exists() + { + fs::create_dir_all(parent).context(format!( + "Failed to create output directory: {}", + parent.display() + ))?; } let template_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("src/explorer.html"); diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 0f2b4c18eade06060f9002615b6b995d9bfdde0d..c6e4e0b6ec683b63b90920861f3cd023069666e6 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -90,11 +90,8 @@ impl ExampleInstance { worktrees_dir: &Path, repetition: usize, ) -> Self { - let name = thread.meta().name.to_string(); - let run_directory = run_dir - .join(&name) - .join(repetition.to_string()) - .to_path_buf(); + let name = thread.meta().name; + let run_directory = run_dir.join(&name).join(repetition.to_string()); let repo_path = repo_path_for_url(repos_dir, &thread.meta().url); @@ -376,11 +373,10 @@ impl ExampleInstance { ); let result = this.thread.conversation(&mut example_cx).await; - if let Err(err) = result { - if !err.is::() { + if let Err(err) = result + && !err.is::() { return Err(err); } - } println!("{}Stopped", this.log_prefix); @@ -459,8 +455,8 @@ impl ExampleInstance { let mut output_file = File::create(self.run_directory.join("judge.md")).expect("failed to create 
judge.md"); - let diff_task = self.judge_diff(model.clone(), &run_output, cx); - let thread_task = self.judge_thread(model.clone(), &run_output, cx); + let diff_task = self.judge_diff(model.clone(), run_output, cx); + let thread_task = self.judge_thread(model.clone(), run_output, cx); let (diff_result, thread_result) = futures::join!(diff_task, thread_task); @@ -661,7 +657,7 @@ pub fn wait_for_lang_server( .update(cx, |buffer, cx| { lsp_store.update(cx, |lsp_store, cx| { lsp_store - .language_servers_for_local_buffer(&buffer, cx) + .language_servers_for_local_buffer(buffer, cx) .next() .is_some() }) @@ -679,8 +675,8 @@ pub fn wait_for_lang_server( [ cx.subscribe(&lsp_store, { let log_prefix = log_prefix.clone(); - move |_, event, _| match event { - project::LspStoreEvent::LanguageServerUpdate { + move |_, event, _| { + if let project::LspStoreEvent::LanguageServerUpdate { message: client::proto::update_language_server::Variant::WorkProgress( LspWorkProgress { @@ -689,11 +685,13 @@ pub fn wait_for_lang_server( }, ), .. 
- } => println!("{}⟲ {message}", log_prefix), - _ => {} + } = event + { + println!("{}⟲ {message}", log_prefix) + } } }), - cx.subscribe(&project, { + cx.subscribe(project, { let buffer = buffer.clone(); move |project, event, cx| match event { project::Event::LanguageServerAdded(_, _, _) => { @@ -771,7 +769,7 @@ pub async fn query_lsp_diagnostics( } fn parse_assertion_result(response: &str) -> Result { - let analysis = get_tag("analysis", response)?.to_string(); + let analysis = get_tag("analysis", response)?; let passed = match get_tag("passed", response)?.to_lowercase().as_str() { "true" => true, "false" => false, @@ -838,7 +836,7 @@ fn messages_to_markdown<'a>(message_iter: impl IntoIterator) for segment in &message.segments { match segment { MessageSegment::Text(text) => { - messages.push_str(&text); + messages.push_str(text); messages.push_str("\n\n"); } MessageSegment::Thinking { text, signature } => { @@ -846,7 +844,7 @@ fn messages_to_markdown<'a>(message_iter: impl IntoIterator) if let Some(sig) = signature { messages.push_str(&format!("Signature: {}\n\n", sig)); } - messages.push_str(&text); + messages.push_str(text); messages.push_str("\n"); } MessageSegment::RedactedThinking(items) => { @@ -878,7 +876,7 @@ pub async fn send_language_model_request( request: LanguageModelRequest, cx: &AsyncApp, ) -> anyhow::Result { - match model.stream_completion_text(request, &cx).await { + match model.stream_completion_text(request, cx).await { Ok(mut stream) => { let mut full_response = String::new(); while let Some(chunk_result) = stream.stream.next().await { @@ -915,9 +913,9 @@ impl RequestMarkdown { for tool in &request.tools { write!(&mut tools, "# {}\n\n", tool.name).unwrap(); write!(&mut tools, "{}\n\n", tool.description).unwrap(); - write!( + writeln!( &mut tools, - "{}\n", + "{}", MarkdownCodeBlock { tag: "json", text: &format!("{:#}", tool.input_schema) @@ -1191,7 +1189,7 @@ mod test { output.analysis, Some("The model did a good job but there were still 
compilations errors.".into()) ); - assert_eq!(output.passed, true); + assert!(output.passed); let response = r#" Text around ignored @@ -1211,6 +1209,6 @@ mod test { output.analysis, Some("Failed to compile:\n- Error 1\n- Error 2".into()) ); - assert_eq!(output.passed, false); + assert!(!output.passed); } } diff --git a/crates/extension/src/extension.rs b/crates/extension/src/extension.rs index 35f7f419383cb9f3c6cc518663ad818735eab80e..6af793253bce2d122a5361f6b83f33cb39d45253 100644 --- a/crates/extension/src/extension.rs +++ b/crates/extension/src/extension.rs @@ -178,16 +178,15 @@ pub fn parse_wasm_extension_version( for part in wasmparser::Parser::new(0).parse_all(wasm_bytes) { if let wasmparser::Payload::CustomSection(s) = part.context("error parsing wasm extension")? + && s.name() == "zed:api-version" { - if s.name() == "zed:api-version" { - version = parse_wasm_extension_version_custom_section(s.data()); - if version.is_none() { - bail!( - "extension {} has invalid zed:api-version section: {:?}", - extension_id, - s.data() - ); - } + version = parse_wasm_extension_version_custom_section(s.data()); + if version.is_none() { + bail!( + "extension {} has invalid zed:api-version section: {:?}", + extension_id, + s.data() + ); } } } diff --git a/crates/extension/src/extension_builder.rs b/crates/extension/src/extension_builder.rs index 621ba9250c12f8edd4ab49bbdef13bc976a239dd..3a3026f19c1961a6f4ac4c7fe5ac217ef6855cea 100644 --- a/crates/extension/src/extension_builder.rs +++ b/crates/extension/src/extension_builder.rs @@ -401,7 +401,7 @@ impl ExtensionBuilder { let mut clang_path = wasi_sdk_dir.clone(); clang_path.extend(["bin", &format!("clang{}", env::consts::EXE_SUFFIX)]); - if fs::metadata(&clang_path).map_or(false, |metadata| metadata.is_file()) { + if fs::metadata(&clang_path).is_ok_and(|metadata| metadata.is_file()) { return Ok(clang_path); } @@ -452,7 +452,7 @@ impl ExtensionBuilder { let mut output = Vec::new(); let mut stack = Vec::new(); - for payload in 
Parser::new(0).parse_all(&input) { + for payload in Parser::new(0).parse_all(input) { let payload = payload?; // Track nesting depth, so that we don't mess with inner producer sections: @@ -484,14 +484,10 @@ impl ExtensionBuilder { _ => {} } - match &payload { - CustomSection(c) => { - if strip_custom_section(c.name()) { - continue; - } - } - - _ => {} + if let CustomSection(c) = &payload + && strip_custom_section(c.name()) + { + continue; } if let Some((id, range)) = payload.as_section() { RawSection { diff --git a/crates/extension/src/extension_events.rs b/crates/extension/src/extension_events.rs index b151b3f412ea523a1c5b97dea210adf68e5bea89..94f3277b05b76aa93717458c57d0280a15b8435f 100644 --- a/crates/extension/src/extension_events.rs +++ b/crates/extension/src/extension_events.rs @@ -19,9 +19,8 @@ pub struct ExtensionEvents; impl ExtensionEvents { /// Returns the global [`ExtensionEvents`]. pub fn try_global(cx: &App) -> Option> { - return cx - .try_global::() - .map(|g| g.0.clone()); + cx.try_global::() + .map(|g| g.0.clone()) } fn new(_cx: &mut Context) -> Self { diff --git a/crates/extension/src/extension_host_proxy.rs b/crates/extension/src/extension_host_proxy.rs index 917739759f2ab0dcbfe012b1d774a8c9f11ca96b..6a24e3ba3f496bd0f0b89d61e9125b29ecae0204 100644 --- a/crates/extension/src/extension_host_proxy.rs +++ b/crates/extension/src/extension_host_proxy.rs @@ -28,7 +28,6 @@ pub struct ExtensionHostProxy { snippet_proxy: RwLock>>, slash_command_proxy: RwLock>>, context_server_proxy: RwLock>>, - indexed_docs_provider_proxy: RwLock>>, debug_adapter_provider_proxy: RwLock>>, } @@ -54,7 +53,6 @@ impl ExtensionHostProxy { snippet_proxy: RwLock::default(), slash_command_proxy: RwLock::default(), context_server_proxy: RwLock::default(), - indexed_docs_provider_proxy: RwLock::default(), debug_adapter_provider_proxy: RwLock::default(), } } @@ -87,14 +85,6 @@ impl ExtensionHostProxy { self.context_server_proxy.write().replace(Arc::new(proxy)); } - pub fn 
register_indexed_docs_provider_proxy( - &self, - proxy: impl ExtensionIndexedDocsProviderProxy, - ) { - self.indexed_docs_provider_proxy - .write() - .replace(Arc::new(proxy)); - } pub fn register_debug_adapter_proxy(&self, proxy: impl ExtensionDebugAdapterProviderProxy) { self.debug_adapter_provider_proxy .write() @@ -408,30 +398,6 @@ impl ExtensionContextServerProxy for ExtensionHostProxy { } } -pub trait ExtensionIndexedDocsProviderProxy: Send + Sync + 'static { - fn register_indexed_docs_provider(&self, extension: Arc, provider_id: Arc); - - fn unregister_indexed_docs_provider(&self, provider_id: Arc); -} - -impl ExtensionIndexedDocsProviderProxy for ExtensionHostProxy { - fn register_indexed_docs_provider(&self, extension: Arc, provider_id: Arc) { - let Some(proxy) = self.indexed_docs_provider_proxy.read().clone() else { - return; - }; - - proxy.register_indexed_docs_provider(extension, provider_id) - } - - fn unregister_indexed_docs_provider(&self, provider_id: Arc) { - let Some(proxy) = self.indexed_docs_provider_proxy.read().clone() else { - return; - }; - - proxy.unregister_indexed_docs_provider(provider_id) - } -} - pub trait ExtensionDebugAdapterProviderProxy: Send + Sync + 'static { fn register_debug_adapter( &self, diff --git a/crates/extension/src/extension_manifest.rs b/crates/extension/src/extension_manifest.rs index 5852b3e3fc32601e8d9527e02d593e02cd66f3c6..f5296198b06ffeeb83dd21be35d27be6b4387294 100644 --- a/crates/extension/src/extension_manifest.rs +++ b/crates/extension/src/extension_manifest.rs @@ -84,8 +84,6 @@ pub struct ExtensionManifest { #[serde(default)] pub slash_commands: BTreeMap, SlashCommandManifestEntry>, #[serde(default)] - pub indexed_docs_providers: BTreeMap, IndexedDocsProviderEntry>, - #[serde(default)] pub snippets: Option, #[serde(default)] pub capabilities: Vec, @@ -195,9 +193,6 @@ pub struct SlashCommandManifestEntry { pub requires_argument: bool, } -#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)] -pub 
struct IndexedDocsProviderEntry {} - #[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)] pub struct DebugAdapterManifestEntry { pub schema_path: Option, @@ -271,7 +266,6 @@ fn manifest_from_old_manifest( language_servers: Default::default(), context_servers: BTreeMap::default(), slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), snippets: None, capabilities: Vec::new(), debug_adapters: Default::default(), @@ -304,7 +298,6 @@ mod tests { language_servers: BTreeMap::default(), context_servers: BTreeMap::default(), slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), snippets: None, capabilities: vec![], debug_adapters: Default::default(), diff --git a/crates/extension_api/src/extension_api.rs b/crates/extension_api/src/extension_api.rs index aacc5d8795202e8d84c043a881933eabefae36bd..72327179ee08550994854d8b95a190ac94d84cea 100644 --- a/crates/extension_api/src/extension_api.rs +++ b/crates/extension_api/src/extension_api.rs @@ -232,10 +232,10 @@ pub trait Extension: Send + Sync { /// /// To work through a real-world example, take a `cargo run` task and a hypothetical `cargo` locator: /// 1. We may need to modify the task; in this case, it is problematic that `cargo run` spawns a binary. We should turn `cargo run` into a debug scenario with - /// `cargo build` task. This is the decision we make at `dap_locator_create_scenario` scope. + /// `cargo build` task. This is the decision we make at `dap_locator_create_scenario` scope. /// 2. Then, after the build task finishes, we will run `run_dap_locator` of the locator that produced the build task to find the program to be debugged. This function - /// should give us a debugger-agnostic configuration for launching a debug target (that we end up resolving with [`Extension::dap_config_to_scenario`]). It's almost as if the user - /// found the artifact path by themselves. 
+ /// should give us a debugger-agnostic configuration for launching a debug target (that we end up resolving with [`Extension::dap_config_to_scenario`]). It's almost as if the user + /// found the artifact path by themselves. /// /// Note that you're not obliged to use build tasks with locators. Specifically, it is sufficient to provide a debug configuration directly in the return value of /// `dap_locator_create_scenario` if you're able to do that. Make sure to not fill out `build` field in that case, as that will prevent Zed from running second phase of resolution in such case. diff --git a/crates/extension_cli/src/main.rs b/crates/extension_cli/src/main.rs index ab4a9cddb0fa13421677772d1c07c1a8d9234d76..d6c0501efdacff2a9eaf542695ed44325908ea56 100644 --- a/crates/extension_cli/src/main.rs +++ b/crates/extension_cli/src/main.rs @@ -144,10 +144,6 @@ fn extension_provides(manifest: &ExtensionManifest) -> BTreeSet ExtensionManifest { .collect(), context_servers: BTreeMap::default(), slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), snippets: None, capabilities: vec![ExtensionCapability::ProcessExec( extension::ProcessExecCapability { diff --git a/crates/extension_host/src/capability_granter.rs b/crates/extension_host/src/capability_granter.rs index c77e5ecba15b5e10caa331d3b6ee3976b899ed21..5491967e080fc4d12a52f0360dab1896b77e19d3 100644 --- a/crates/extension_host/src/capability_granter.rs +++ b/crates/extension_host/src/capability_granter.rs @@ -108,7 +108,6 @@ mod tests { language_servers: BTreeMap::default(), context_servers: BTreeMap::default(), slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), snippets: None, capabilities: vec![], debug_adapters: Default::default(), @@ -146,7 +145,7 @@ mod tests { command: "*".to_string(), args: vec!["**".to_string()], })], - manifest.clone(), + manifest, ); assert!(granter.grant_exec("ls", &["-la"]).is_ok()); } diff --git 
a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs index 67baf4e692a26294cf673d2769b4b647d73811b9..b114ad9f4c526f9c270681c55626455531becc2f 100644 --- a/crates/extension_host/src/extension_host.rs +++ b/crates/extension_host/src/extension_host.rs @@ -16,9 +16,9 @@ pub use extension::ExtensionManifest; use extension::extension_builder::{CompileExtensionOptions, ExtensionBuilder}; use extension::{ ExtensionContextServerProxy, ExtensionDebugAdapterProviderProxy, ExtensionEvents, - ExtensionGrammarProxy, ExtensionHostProxy, ExtensionIndexedDocsProviderProxy, - ExtensionLanguageProxy, ExtensionLanguageServerProxy, ExtensionSlashCommandProxy, - ExtensionSnippetProxy, ExtensionThemeProxy, + ExtensionGrammarProxy, ExtensionHostProxy, ExtensionLanguageProxy, + ExtensionLanguageServerProxy, ExtensionSlashCommandProxy, ExtensionSnippetProxy, + ExtensionThemeProxy, }; use fs::{Fs, RemoveOptions}; use futures::future::join_all; @@ -43,7 +43,7 @@ use language::{ use node_runtime::NodeRuntime; use project::ContextProviderWithTasks; use release_channel::ReleaseChannel; -use remote::SshRemoteClient; +use remote::{RemoteClient, RemoteConnectionOptions}; use semantic_version::SemanticVersion; use serde::{Deserialize, Serialize}; use settings::Settings; @@ -93,10 +93,9 @@ pub fn is_version_compatible( .wasm_api_version .as_ref() .and_then(|wasm_api_version| SemanticVersion::from_str(wasm_api_version).ok()) + && !is_supported_wasm_api_version(release_channel, wasm_api_version) { - if !is_supported_wasm_api_version(release_channel, wasm_api_version) { - return false; - } + return false; } true @@ -118,7 +117,7 @@ pub struct ExtensionStore { pub wasm_host: Arc, pub wasm_extensions: Vec<(Arc, WasmExtension)>, pub tasks: Vec>, - pub ssh_clients: HashMap>, + pub remote_clients: HashMap>, pub ssh_registered_tx: UnboundedSender<()>, } @@ -271,7 +270,7 @@ impl ExtensionStore { reload_tx, tasks: Vec::new(), - ssh_clients: HashMap::default(), + 
remote_clients: HashMap::default(), ssh_registered_tx: connection_registered_tx, }; @@ -292,19 +291,17 @@ impl ExtensionStore { // it must be asynchronously rebuilt. let mut extension_index = ExtensionIndex::default(); let mut extension_index_needs_rebuild = true; - if let Ok(index_content) = index_content { - if let Some(index) = serde_json::from_str(&index_content).log_err() { - extension_index = index; - if let (Ok(Some(index_metadata)), Ok(Some(extensions_metadata))) = - (index_metadata, extensions_metadata) - { - if index_metadata - .mtime - .bad_is_greater_than(extensions_metadata.mtime) - { - extension_index_needs_rebuild = false; - } - } + if let Ok(index_content) = index_content + && let Some(index) = serde_json::from_str(&index_content).log_err() + { + extension_index = index; + if let (Ok(Some(index_metadata)), Ok(Some(extensions_metadata))) = + (index_metadata, extensions_metadata) + && index_metadata + .mtime + .bad_is_greater_than(extensions_metadata.mtime) + { + extension_index_needs_rebuild = false; } } @@ -392,10 +389,9 @@ impl ExtensionStore { if let Some(path::Component::Normal(extension_dir_name)) = event_path.components().next() + && let Some(extension_id) = extension_dir_name.to_str() { - if let Some(extension_id) = extension_dir_name.to_str() { - reload_tx.unbounded_send(Some(extension_id.into())).ok(); - } + reload_tx.unbounded_send(Some(extension_id.into())).ok(); } } } @@ -566,12 +562,12 @@ impl ExtensionStore { extensions .into_iter() .filter(|extension| { - this.extension_index.extensions.get(&extension.id).map_or( - true, - |installed_extension| { + this.extension_index + .extensions + .get(&extension.id) + .is_none_or(|installed_extension| { installed_extension.manifest.version != extension.manifest.version - }, - ) + }) }) .collect() }) @@ -763,8 +759,8 @@ impl ExtensionStore { if let ExtensionOperation::Install = operation { this.update( cx, |this, cx| { cx.emit(Event::ExtensionInstalled(extension_id.clone())); - if let Some(events) 
= ExtensionEvents::try_global(cx) { - if let Some(manifest) = this.extension_manifest_for_id(&extension_id) { + if let Some(events) = ExtensionEvents::try_global(cx) + && let Some(manifest) = this.extension_manifest_for_id(&extension_id) { events.update(cx, |this, cx| { this.emit( extension::Event::ExtensionInstalled(manifest.clone()), @@ -772,7 +768,6 @@ impl ExtensionStore { ) }); } - } }) .ok(); } @@ -912,12 +907,12 @@ impl ExtensionStore { extension_store.update(cx, |_, cx| { cx.emit(Event::ExtensionUninstalled(extension_id.clone())); - if let Some(events) = ExtensionEvents::try_global(cx) { - if let Some(manifest) = extension_manifest { - events.update(cx, |this, cx| { - this.emit(extension::Event::ExtensionUninstalled(manifest.clone()), cx) - }); - } + if let Some(events) = ExtensionEvents::try_global(cx) + && let Some(manifest) = extension_manifest + { + events.update(cx, |this, cx| { + this.emit(extension::Event::ExtensionUninstalled(manifest.clone()), cx) + }); } })?; @@ -997,12 +992,12 @@ impl ExtensionStore { this.update(cx, |this, cx| this.reload(None, cx))?.await; this.update(cx, |this, cx| { cx.emit(Event::ExtensionInstalled(extension_id.clone())); - if let Some(events) = ExtensionEvents::try_global(cx) { - if let Some(manifest) = this.extension_manifest_for_id(&extension_id) { - events.update(cx, |this, cx| { - this.emit(extension::Event::ExtensionInstalled(manifest.clone()), cx) - }); - } + if let Some(events) = ExtensionEvents::try_global(cx) + && let Some(manifest) = this.extension_manifest_for_id(&extension_id) + { + events.update(cx, |this, cx| { + this.emit(extension::Event::ExtensionInstalled(manifest.clone()), cx) + }); } })?; @@ -1180,22 +1175,18 @@ impl ExtensionStore { } } - for (server_id, _) in &extension.manifest.context_servers { + for server_id in extension.manifest.context_servers.keys() { self.proxy.unregister_context_server(server_id.clone(), cx); } - for (adapter, _) in &extension.manifest.debug_adapters { + for adapter in 
extension.manifest.debug_adapters.keys() { self.proxy.unregister_debug_adapter(adapter.clone()); } - for (locator, _) in &extension.manifest.debug_locators { + for locator in extension.manifest.debug_locators.keys() { self.proxy.unregister_debug_locator(locator.clone()); } - for (command_name, _) in &extension.manifest.slash_commands { + for command_name in extension.manifest.slash_commands.keys() { self.proxy.unregister_slash_command(command_name.clone()); } - for (provider_id, _) in &extension.manifest.indexed_docs_providers { - self.proxy - .unregister_indexed_docs_provider(provider_id.clone()); - } } self.wasm_extensions @@ -1279,6 +1270,7 @@ impl ExtensionStore { queries, context_provider, toolchain_provider: None, + manifest_name: None, }) }), ); @@ -1344,7 +1336,7 @@ impl ExtensionStore { &extension_path, &extension.manifest, wasm_host.clone(), - &cx, + cx, ) .await .with_context(|| format!("Loading extension from {extension_path:?}")); @@ -1394,16 +1386,11 @@ impl ExtensionStore { ); } - for (id, _context_server_entry) in &manifest.context_servers { + for id in manifest.context_servers.keys() { this.proxy .register_context_server(extension.clone(), id.clone(), cx); } - for (provider_id, _provider) in &manifest.indexed_docs_providers { - this.proxy - .register_indexed_docs_provider(extension.clone(), provider_id.clone()); - } - for (debug_adapter, meta) in &manifest.debug_adapters { let mut path = root_dir.clone(); path.push(Path::new(manifest.id.as_ref())); @@ -1464,7 +1451,7 @@ impl ExtensionStore { if extension_dir .file_name() - .map_or(false, |file_name| file_name == ".DS_Store") + .is_some_and(|file_name| file_name == ".DS_Store") { continue; } @@ -1688,9 +1675,8 @@ impl ExtensionStore { let schema_path = &extension::build_debug_adapter_schema_path(adapter_name, meta); if fs.is_file(&src_dir.join(schema_path)).await { - match schema_path.parent() { - Some(parent) => fs.create_dir(&tmp_dir.join(parent)).await?, - None => {} + if let Some(parent) = 
schema_path.parent() { + fs.create_dir(&tmp_dir.join(parent)).await? } fs.copy_file( &src_dir.join(schema_path), @@ -1707,7 +1693,7 @@ impl ExtensionStore { async fn sync_extensions_over_ssh( this: &WeakEntity, - client: WeakEntity, + client: WeakEntity, cx: &mut AsyncApp, ) -> Result<()> { let extensions = this.update(cx, |this, _cx| { @@ -1779,12 +1765,12 @@ impl ExtensionStore { pub async fn update_ssh_clients(this: &WeakEntity, cx: &mut AsyncApp) -> Result<()> { let clients = this.update(cx, |this, _cx| { - this.ssh_clients.retain(|_k, v| v.upgrade().is_some()); - this.ssh_clients.values().cloned().collect::>() + this.remote_clients.retain(|_k, v| v.upgrade().is_some()); + this.remote_clients.values().cloned().collect::>() })?; for client in clients { - Self::sync_extensions_over_ssh(&this, client, cx) + Self::sync_extensions_over_ssh(this, client, cx) .await .log_err(); } @@ -1792,17 +1778,16 @@ impl ExtensionStore { anyhow::Ok(()) } - pub fn register_ssh_client(&mut self, client: Entity, cx: &mut Context) { - let connection_options = client.read(cx).connection_options(); - let ssh_url = connection_options.ssh_url(); + pub fn register_remote_client(&mut self, client: Entity, cx: &mut Context) { + let options = client.read(cx).connection_options(); - if let Some(existing_client) = self.ssh_clients.get(&ssh_url) { - if existing_client.upgrade().is_some() { - return; - } + if let Some(existing_client) = self.remote_clients.get(&options) + && existing_client.upgrade().is_some() + { + return; } - self.ssh_clients.insert(ssh_url, client.downgrade()); + self.remote_clients.insert(options, client.downgrade()); self.ssh_registered_tx.unbounded_send(()).ok(); } } diff --git a/crates/extension_host/src/extension_settings.rs b/crates/extension_host/src/extension_settings.rs index cfa67990b09de9fda5bf0e26229a9b1b1410de46..fa5a613c55a76a0b5660b114d49acc17fcf79120 100644 --- a/crates/extension_host/src/extension_settings.rs +++ 
b/crates/extension_host/src/extension_settings.rs @@ -3,10 +3,11 @@ use collections::HashMap; use gpui::App; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; use std::sync::Arc; -#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)] +#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema, SettingsUi, SettingsKey)] +#[settings_key(None)] pub struct ExtensionSettings { /// The extensions that should be automatically installed by Zed. /// @@ -38,8 +39,6 @@ impl ExtensionSettings { } impl Settings for ExtensionSettings { - const KEY: Option<&'static str> = None; - type FileContent = Self; fn load(sources: SettingsSources, _cx: &mut App) -> Result { diff --git a/crates/extension_host/src/extension_store_test.rs b/crates/extension_host/src/extension_store_test.rs index c31774c20d3e94f829e8de5d6ca822228735ca18..347a610439c98ae020a7ebf190dd9e1a603df5a1 100644 --- a/crates/extension_host/src/extension_store_test.rs +++ b/crates/extension_host/src/extension_store_test.rs @@ -160,7 +160,6 @@ async fn test_extension_store(cx: &mut TestAppContext) { language_servers: BTreeMap::default(), context_servers: BTreeMap::default(), slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), snippets: None, capabilities: Vec::new(), debug_adapters: Default::default(), @@ -191,7 +190,6 @@ async fn test_extension_store(cx: &mut TestAppContext) { language_servers: BTreeMap::default(), context_servers: BTreeMap::default(), slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), snippets: None, capabilities: Vec::new(), debug_adapters: Default::default(), @@ -371,7 +369,6 @@ async fn test_extension_store(cx: &mut TestAppContext) { language_servers: BTreeMap::default(), context_servers: BTreeMap::default(), slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), 
snippets: None, capabilities: Vec::new(), debug_adapters: Default::default(), diff --git a/crates/extension_host/src/headless_host.rs b/crates/extension_host/src/headless_host.rs index adc9638c2998eb1f122df5137577ca7e0cf4c975..a6305118cd3355f69a42914ec86bb5edcfc74810 100644 --- a/crates/extension_host/src/headless_host.rs +++ b/crates/extension_host/src/headless_host.rs @@ -163,6 +163,7 @@ impl HeadlessExtensionStore { queries: LanguageQueries::default(), context_provider: None, toolchain_provider: None, + manifest_name: None, }) }), ); @@ -174,7 +175,7 @@ impl HeadlessExtensionStore { } let wasm_extension: Arc = - Arc::new(WasmExtension::load(&extension_dir, &manifest, wasm_host.clone(), &cx).await?); + Arc::new(WasmExtension::load(&extension_dir, &manifest, wasm_host.clone(), cx).await?); for (language_server_id, language_server_config) in &manifest.language_servers { for language in language_server_config.languages() { diff --git a/crates/extension_host/src/wasm_host.rs b/crates/extension_host/src/wasm_host.rs index d990b670f49221aca2f0af901293c70d341cf029..c5bc21fc1c44659b845b7616aa1714a0872f90f3 100644 --- a/crates/extension_host/src/wasm_host.rs +++ b/crates/extension_host/src/wasm_host.rs @@ -532,7 +532,7 @@ fn wasm_engine(executor: &BackgroundExecutor) -> wasmtime::Engine { // `Future::poll`. const EPOCH_INTERVAL: Duration = Duration::from_millis(100); let mut timer = Timer::interval(EPOCH_INTERVAL); - while let Some(_) = timer.next().await { + while (timer.next().await).is_some() { // Exit the loop and thread once the engine is dropped. let Some(engine) = engine_ref.upgrade() else { break; @@ -701,16 +701,15 @@ pub fn parse_wasm_extension_version( for part in wasmparser::Parser::new(0).parse_all(wasm_bytes) { if let wasmparser::Payload::CustomSection(s) = part.context("error parsing wasm extension")? 
+ && s.name() == "zed:api-version" { - if s.name() == "zed:api-version" { - version = parse_wasm_extension_version_custom_section(s.data()); - if version.is_none() { - bail!( - "extension {} has invalid zed:api-version section: {:?}", - extension_id, - s.data() - ); - } + version = parse_wasm_extension_version_custom_section(s.data()); + if version.is_none() { + bail!( + "extension {} has invalid zed:api-version section: {:?}", + extension_id, + s.data() + ); } } } diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs index 767b9033ade3c81c6ac149363676513c72996b7e..84794d5386eda1517808d181eb259a3264f7b82d 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs @@ -938,7 +938,7 @@ impl ExtensionImports for WasmState { binary: settings.binary.map(|binary| settings::CommandSettings { path: binary.path, arguments: binary.arguments, - env: binary.env, + env: binary.env.map(|env| env.into_iter().collect()), }), settings: settings.settings, initialization_options: settings.initialization_options, diff --git a/crates/extensions_ui/src/components/feature_upsell.rs b/crates/extensions_ui/src/components/feature_upsell.rs index 573b0b992d343e04b74531ffeb8579f28c92620c..0515dd46d30ce9f7e87331f99542940c3efa837a 100644 --- a/crates/extensions_ui/src/components/feature_upsell.rs +++ b/crates/extensions_ui/src/components/feature_upsell.rs @@ -61,7 +61,6 @@ impl RenderOnce for FeatureUpsell { .icon_size(IconSize::Small) .icon_position(IconPosition::End) .on_click({ - let docs_url = docs_url.clone(); move |_event, _window, cx| { telemetry::event!( "Documentation Viewed", diff --git a/crates/extensions_ui/src/extension_version_selector.rs b/crates/extensions_ui/src/extension_version_selector.rs index aaf5d5e8eb8308f3833e2638f1e0e72186f3d983..fe7a419fbe8001b99d5c3ebf16dfc38cda3fc713 100644 --- 
a/crates/extensions_ui/src/extension_version_selector.rs +++ b/crates/extensions_ui/src/extension_version_selector.rs @@ -207,8 +207,8 @@ impl PickerDelegate for ExtensionVersionSelectorDelegate { _: &mut Window, cx: &mut Context>, ) -> Option { - let version_match = &self.matches[ix]; - let extension_version = &self.extension_versions[version_match.candidate_id]; + let version_match = &self.matches.get(ix)?; + let extension_version = &self.extension_versions.get(version_match.candidate_id)?; let is_version_compatible = extension_host::is_version_compatible(ReleaseChannel::global(cx), extension_version); diff --git a/crates/extensions_ui/src/extensions_ui.rs b/crates/extensions_ui/src/extensions_ui.rs index fe3e94f5c20dc1a78ae01defc24e290c18a1a3e6..82ee54174567987d00478815f6a4eefd94333202 100644 --- a/crates/extensions_ui/src/extensions_ui.rs +++ b/crates/extensions_ui/src/extensions_ui.rs @@ -116,6 +116,7 @@ pub fn init(cx: &mut App) { files: false, directories: true, multiple: false, + prompt: None, }, DirectoryLister::Local( workspace.project().clone(), @@ -326,7 +327,7 @@ impl ExtensionsPage { let query_editor = cx.new(|cx| { let mut input = Editor::single_line(window, cx); - input.set_placeholder_text("Search extensions...", cx); + input.set_placeholder_text("Search extensions...", window, cx); if let Some(id) = focus_extension_id { input.set_text(format!("id:{id}"), window, cx); } @@ -693,7 +694,7 @@ impl ExtensionsPage { cx.open_url(&repository_url); } })) - .tooltip(Tooltip::text(repository_url.clone())) + .tooltip(Tooltip::text(repository_url)) })), ) } @@ -703,7 +704,7 @@ impl ExtensionsPage { extension: &ExtensionMetadata, cx: &mut Context, ) -> ExtensionCard { - let this = cx.entity().clone(); + let this = cx.entity(); let status = Self::extension_status(&extension.id, cx); let has_dev_extension = Self::dev_extension_exists(&extension.id, cx); @@ -826,7 +827,7 @@ impl ExtensionsPage { cx.open_url(&repository_url); } })) - 
.tooltip(Tooltip::text(repository_url.clone())), + .tooltip(Tooltip::text(repository_url)), ) .child( PopoverMenu::new(SharedString::from(format!( @@ -862,7 +863,7 @@ impl ExtensionsPage { window: &mut Window, cx: &mut App, ) -> Entity { - let context_menu = ContextMenu::build(window, cx, |context_menu, window, _| { + ContextMenu::build(window, cx, |context_menu, window, _| { context_menu .entry( "Install Another Version...", @@ -886,9 +887,7 @@ impl ExtensionsPage { cx.write_to_clipboard(ClipboardItem::new_string(authors.join(", "))); } }) - }); - - context_menu + }) } fn show_extension_version_list( @@ -1030,15 +1029,14 @@ impl ExtensionsPage { .read(cx) .extension_manifest_for_id(&extension_id) .cloned() + && let Some(events) = extension::ExtensionEvents::try_global(cx) { - if let Some(events) = extension::ExtensionEvents::try_global(cx) { - events.update(cx, |this, cx| { - this.emit( - extension::Event::ConfigureExtensionRequested(manifest), - cx, - ) - }); - } + events.update(cx, |this, cx| { + this.emit( + extension::Event::ConfigureExtensionRequested(manifest), + cx, + ) + }); } } }) @@ -1347,7 +1345,7 @@ impl ExtensionsPage { this.update_settings::( selection, cx, - |setting, value| *setting = Some(value), + |setting, value| setting.vim_mode = Some(value), ); }), )), diff --git a/crates/feature_flags/src/feature_flags.rs b/crates/feature_flags/src/feature_flags.rs index ef357adf35997bfb7560f1e1849ef69e780cd1f9..4afeb32235114ea6d2e29042e9d4e465043c19da 100644 --- a/crates/feature_flags/src/feature_flags.rs +++ b/crates/feature_flags/src/feature_flags.rs @@ -14,7 +14,7 @@ struct FeatureFlags { } pub static ZED_DISABLE_STAFF: LazyLock = LazyLock::new(|| { - std::env::var("ZED_DISABLE_STAFF").map_or(false, |value| !value.is_empty() && value != "0") + std::env::var("ZED_DISABLE_STAFF").is_ok_and(|value| !value.is_empty() && value != "0") }); impl FeatureFlags { @@ -23,7 +23,7 @@ impl FeatureFlags { return true; } - if self.staff && T::enabled_for_staff() { + if 
(cfg!(debug_assertions) || self.staff) && !*ZED_DISABLE_STAFF && T::enabled_for_staff() { return true; } @@ -66,9 +66,10 @@ impl FeatureFlag for LlmClosedBetaFeatureFlag { const NAME: &'static str = "llm-closed-beta"; } -pub struct ZedProFeatureFlag {} -impl FeatureFlag for ZedProFeatureFlag { - const NAME: &'static str = "zed-pro"; +pub struct BillingV2FeatureFlag {} + +impl FeatureFlag for BillingV2FeatureFlag { + const NAME: &'static str = "billing-v2"; } pub struct NotebookFeatureFlag; @@ -77,14 +78,6 @@ impl FeatureFlag for NotebookFeatureFlag { const NAME: &'static str = "notebooks"; } -pub struct ThreadAutoCaptureFeatureFlag {} -impl FeatureFlag for ThreadAutoCaptureFeatureFlag { - const NAME: &'static str = "thread-auto-capture"; - - fn enabled_for_staff() -> bool { - false - } -} pub struct PanicFeatureFlag; impl FeatureFlag for PanicFeatureFlag { @@ -97,10 +90,29 @@ impl FeatureFlag for JjUiFeatureFlag { const NAME: &'static str = "jj-ui"; } -pub struct AcpFeatureFlag; +pub struct GeminiAndNativeFeatureFlag; + +impl FeatureFlag for GeminiAndNativeFeatureFlag { + // This was previously called "acp". + // + // We renamed it because existing builds used it to enable the Claude Code + // integration too, and we'd like to turn Gemini/Native on in new builds + // without enabling Claude Code in old builds. 
+ const NAME: &'static str = "gemini-and-native"; + + fn enabled_for_all() -> bool { + true + } +} + +pub struct ClaudeCodeFeatureFlag; + +impl FeatureFlag for ClaudeCodeFeatureFlag { + const NAME: &'static str = "claude-code"; -impl FeatureFlag for AcpFeatureFlag { - const NAME: &'static str = "acp"; + fn enabled_for_all() -> bool { + true + } } pub trait FeatureFlagViewExt { @@ -198,7 +210,10 @@ impl FeatureFlagAppExt for App { fn has_flag(&self) -> bool { self.try_global::() .map(|flags| flags.has_flag::()) - .unwrap_or(false) + .unwrap_or_else(|| { + (cfg!(debug_assertions) && T::enabled_for_staff() && !*ZED_DISABLE_STAFF) + || T::enabled_for_all() + }) } fn is_staff(&self) -> bool { diff --git a/crates/feedback/Cargo.toml b/crates/feedback/Cargo.toml index 3a2c1fd7131ef7b5d7b07b8ec036fff4f1bba621..db872f7a15035c5012d42680c2d812d3486c6a89 100644 --- a/crates/feedback/Cargo.toml +++ b/crates/feedback/Cargo.toml @@ -15,13 +15,9 @@ path = "src/feedback.rs" test-support = [] [dependencies] -client.workspace = true gpui.workspace = true -human_bytes = "0.4.1" menu.workspace = true -release_channel.workspace = true -serde.workspace = true -sysinfo.workspace = true +system_specs.workspace = true ui.workspace = true urlencoding.workspace = true util.workspace = true diff --git a/crates/feedback/src/feedback.rs b/crates/feedback/src/feedback.rs index 40c2707d34c9f5ab50bdb51c8b82183be2106285..3822dd7ba38ac8131df4f391b8b0a5c05978fe8d 100644 --- a/crates/feedback/src/feedback.rs +++ b/crates/feedback/src/feedback.rs @@ -1,18 +1,14 @@ use gpui::{App, ClipboardItem, PromptLevel, actions}; -use system_specs::SystemSpecs; +use system_specs::{CopySystemSpecsIntoClipboard, SystemSpecs}; use util::ResultExt; use workspace::Workspace; use zed_actions::feedback::FileBugReport; pub mod feedback_modal; -pub mod system_specs; - actions!( zed, [ - /// Copies system specifications to the clipboard for bug reports. 
- CopySystemSpecsIntoClipboard, /// Opens email client to send feedback to Zed support. EmailZed, /// Opens the Zed repository on GitHub. diff --git a/crates/file_finder/src/file_finder.rs b/crates/file_finder/src/file_finder.rs index c6997ccdc0c89be67442e9ac2b16f61512feb141..53cf0552f22a59c31ab2422d86eb8cbb76145908 100644 --- a/crates/file_finder/src/file_finder.rs +++ b/crates/file_finder/src/file_finder.rs @@ -209,11 +209,11 @@ impl FileFinder { let Some(init_modifiers) = self.init_modifiers.take() else { return; }; - if self.picker.read(cx).delegate.has_changed_selected_index { - if !event.modified() || !init_modifiers.is_subset_of(&event) { - self.init_modifiers = None; - window.dispatch_action(menu::Confirm.boxed_clone(), cx); - } + if self.picker.read(cx).delegate.has_changed_selected_index + && (!event.modified() || !init_modifiers.is_subset_of(event)) + { + self.init_modifiers = None; + window.dispatch_action(menu::Confirm.boxed_clone(), cx); } } @@ -267,10 +267,9 @@ impl FileFinder { ) { self.picker.update(cx, |picker, cx| { picker.delegate.include_ignored = match picker.delegate.include_ignored { - Some(true) => match FileFinderSettings::get_global(cx).include_ignored { - Some(_) => Some(false), - None => None, - }, + Some(true) => FileFinderSettings::get_global(cx) + .include_ignored + .map(|_| false), Some(false) => Some(true), None => Some(true), }; @@ -323,27 +322,27 @@ impl FileFinder { ) { self.picker.update(cx, |picker, cx| { let delegate = &mut picker.delegate; - if let Some(workspace) = delegate.workspace.upgrade() { - if let Some(m) = delegate.matches.get(delegate.selected_index()) { - let path = match &m { - Match::History { path, .. } => { - let worktree_id = path.project.worktree_id; - ProjectPath { - worktree_id, - path: Arc::clone(&path.project.path), - } + if let Some(workspace) = delegate.workspace.upgrade() + && let Some(m) = delegate.matches.get(delegate.selected_index()) + { + let path = match &m { + Match::History { path, .. 
} => { + let worktree_id = path.project.worktree_id; + ProjectPath { + worktree_id, + path: Arc::clone(&path.project.path), } - Match::Search(m) => ProjectPath { - worktree_id: WorktreeId::from_usize(m.0.worktree_id), - path: m.0.path.clone(), - }, - Match::CreateNew(p) => p.clone(), - }; - let open_task = workspace.update(cx, move |workspace, cx| { - workspace.split_path_preview(path, false, Some(split_direction), window, cx) - }); - open_task.detach_and_log_err(cx); - } + } + Match::Search(m) => ProjectPath { + worktree_id: WorktreeId::from_usize(m.0.worktree_id), + path: m.0.path.clone(), + }, + Match::CreateNew(p) => p.clone(), + }; + let open_task = workspace.update(cx, move |workspace, cx| { + workspace.split_path_preview(path, false, Some(split_direction), window, cx) + }); + open_task.detach_and_log_err(cx); } }) } @@ -497,7 +496,7 @@ impl Match { fn panel_match(&self) -> Option<&ProjectPanelOrdMatch> { match self { Match::History { panel_match, .. } => panel_match.as_ref(), - Match::Search(panel_match) => Some(&panel_match), + Match::Search(panel_match) => Some(panel_match), Match::CreateNew(_) => None, } } @@ -537,7 +536,7 @@ impl Matches { self.matches.binary_search_by(|m| { // `reverse()` since if cmp_matches(a, b) == Ordering::Greater, then a is better than b. // And we want the better entries go first. - Self::cmp_matches(self.separate_history, currently_opened, &m, &entry).reverse() + Self::cmp_matches(self.separate_history, currently_opened, m, entry).reverse() }) } } @@ -675,17 +674,17 @@ impl Matches { let path_str = panel_match.0.path.to_string_lossy(); let filename_str = filename.to_string_lossy(); - if let Some(filename_pos) = path_str.rfind(&*filename_str) { - if panel_match.0.positions[0] >= filename_pos { - let mut prev_position = panel_match.0.positions[0]; - for p in &panel_match.0.positions[1..] 
{ - if *p != prev_position + 1 { - return false; - } - prev_position = *p; + if let Some(filename_pos) = path_str.rfind(&*filename_str) + && panel_match.0.positions[0] >= filename_pos + { + let mut prev_position = panel_match.0.positions[0]; + for p in &panel_match.0.positions[1..] { + if *p != prev_position + 1 { + return false; } - return true; + prev_position = *p; } + return true; } } @@ -878,9 +877,7 @@ impl FileFinderDelegate { PathMatchCandidateSet { snapshot: worktree.snapshot(), include_ignored: self.include_ignored.unwrap_or_else(|| { - worktree - .root_entry() - .map_or(false, |entry| entry.is_ignored) + worktree.root_entry().is_some_and(|entry| entry.is_ignored) }), include_root_name, candidates: project::Candidates::Files, @@ -1045,10 +1042,10 @@ impl FileFinderDelegate { ) } else { let mut path = Arc::clone(project_relative_path); - if project_relative_path.as_ref() == Path::new("") { - if let Some(absolute_path) = &entry_path.absolute { - path = Arc::from(absolute_path.as_path()); - } + if project_relative_path.as_ref() == Path::new("") + && let Some(absolute_path) = &entry_path.absolute + { + path = Arc::from(absolute_path.as_path()); } let mut path_match = PathMatch { @@ -1078,23 +1075,21 @@ impl FileFinderDelegate { ), }; - if file_name_positions.is_empty() { - if let Some(user_home_path) = std::env::var("HOME").ok() { - let user_home_path = user_home_path.trim(); - if !user_home_path.is_empty() { - if (&full_path).starts_with(user_home_path) { - full_path.replace_range(0..user_home_path.len(), "~"); - full_path_positions.retain_mut(|pos| { - if *pos >= user_home_path.len() { - *pos -= user_home_path.len(); - *pos += 1; - true - } else { - false - } - }) + if file_name_positions.is_empty() + && let Some(user_home_path) = std::env::var("HOME").ok() + { + let user_home_path = user_home_path.trim(); + if !user_home_path.is_empty() && full_path.starts_with(user_home_path) { + full_path.replace_range(0..user_home_path.len(), "~"); + 
full_path_positions.retain_mut(|pos| { + if *pos >= user_home_path.len() { + *pos -= user_home_path.len(); + *pos += 1; + true + } else { + false } - } + }) } } @@ -1242,14 +1237,13 @@ impl FileFinderDelegate { /// Skips first history match (that is displayed topmost) if it's currently opened. fn calculate_selected_index(&self, cx: &mut Context>) -> usize { - if FileFinderSettings::get_global(cx).skip_focus_for_active_in_search { - if let Some(Match::History { path, .. }) = self.matches.get(0) { - if Some(path) == self.currently_opened_path.as_ref() { - let elements_after_first = self.matches.len() - 1; - if elements_after_first > 0 { - return 1; - } - } + if FileFinderSettings::get_global(cx).skip_focus_for_active_in_search + && let Some(Match::History { path, .. }) = self.matches.get(0) + && Some(path) == self.currently_opened_path.as_ref() + { + let elements_after_first = self.matches.len() - 1; + if elements_after_first > 0 { + return 1; } } @@ -1310,10 +1304,10 @@ impl PickerDelegate for FileFinderDelegate { .enumerate() .find(|(_, m)| !matches!(m, Match::History { .. 
})) .map(|(i, _)| i); - if let Some(first_non_history_index) = first_non_history_index { - if first_non_history_index > 0 { - return vec![first_non_history_index - 1]; - } + if let Some(first_non_history_index) = first_non_history_index + && first_non_history_index > 0 + { + return vec![first_non_history_index - 1]; } } Vec::new() @@ -1387,7 +1381,7 @@ impl PickerDelegate for FileFinderDelegate { project .worktree_for_id(history_item.project.worktree_id, cx) .is_some() - || ((project.is_local() || project.is_via_ssh()) + || ((project.is_local() || project.is_via_remote_server()) && history_item.absolute.is_some()) }), self.currently_opened_path.as_ref(), @@ -1402,18 +1396,21 @@ impl PickerDelegate for FileFinderDelegate { cx.notify(); Task::ready(()) } else { - let path_position = PathWithPosition::parse_str(&raw_query); + let path_position = PathWithPosition::parse_str(raw_query); #[cfg(windows)] let raw_query = raw_query.trim().to_owned().replace("/", "\\"); #[cfg(not(windows))] - let raw_query = raw_query.trim().to_owned(); + let raw_query = raw_query.trim(); - let file_query_end = if path_position.path.to_str().unwrap_or(&raw_query) == raw_query { + let raw_query = raw_query.trim_end_matches(':').to_owned(); + let path = path_position.path.to_str(); + let path_trimmed = path.unwrap_or(&raw_query).trim_end_matches(':'); + let file_query_end = if path_trimmed == raw_query { None } else { // Safe to unwrap as we won't get here when the unwrap in if fails - Some(path_position.path.to_str().unwrap().len()) + Some(path.unwrap().len()) }; let query = FileSearchQuery { @@ -1436,69 +1433,101 @@ impl PickerDelegate for FileFinderDelegate { window: &mut Window, cx: &mut Context>, ) { - if let Some(m) = self.matches.get(self.selected_index()) { - if let Some(workspace) = self.workspace.upgrade() { - let open_task = workspace.update(cx, |workspace, cx| { - let split_or_open = - |workspace: &mut Workspace, - project_path, - window: &mut Window, - cx: &mut Context| { - let 
allow_preview = - PreviewTabsSettings::get_global(cx).enable_preview_from_file_finder; - if secondary { - workspace.split_path_preview( - project_path, - allow_preview, - None, - window, - cx, - ) - } else { - workspace.open_path_preview( - project_path, - None, - true, - allow_preview, - true, - window, - cx, - ) - } - }; - match &m { - Match::CreateNew(project_path) => { - // Create a new file with the given filename - if secondary { - workspace.split_path_preview( - project_path.clone(), - false, - None, - window, - cx, - ) - } else { - workspace.open_path_preview( - project_path.clone(), - None, - true, - false, - true, - window, - cx, - ) - } + if let Some(m) = self.matches.get(self.selected_index()) + && let Some(workspace) = self.workspace.upgrade() + { + let open_task = workspace.update(cx, |workspace, cx| { + let split_or_open = + |workspace: &mut Workspace, + project_path, + window: &mut Window, + cx: &mut Context| { + let allow_preview = + PreviewTabsSettings::get_global(cx).enable_preview_from_file_finder; + if secondary { + workspace.split_path_preview( + project_path, + allow_preview, + None, + window, + cx, + ) + } else { + workspace.open_path_preview( + project_path, + None, + true, + allow_preview, + true, + window, + cx, + ) } + }; + match &m { + Match::CreateNew(project_path) => { + // Create a new file with the given filename + if secondary { + workspace.split_path_preview( + project_path.clone(), + false, + None, + window, + cx, + ) + } else { + workspace.open_path_preview( + project_path.clone(), + None, + true, + false, + true, + window, + cx, + ) + } + } - Match::History { path, .. } => { - let worktree_id = path.project.worktree_id; - if workspace - .project() - .read(cx) - .worktree_for_id(worktree_id, cx) - .is_some() - { - split_or_open( + Match::History { path, .. 
} => { + let worktree_id = path.project.worktree_id; + if workspace + .project() + .read(cx) + .worktree_for_id(worktree_id, cx) + .is_some() + { + split_or_open( + workspace, + ProjectPath { + worktree_id, + path: Arc::clone(&path.project.path), + }, + window, + cx, + ) + } else { + match path.absolute.as_ref() { + Some(abs_path) => { + if secondary { + workspace.split_abs_path( + abs_path.to_path_buf(), + false, + window, + cx, + ) + } else { + workspace.open_abs_path( + abs_path.to_path_buf(), + OpenOptions { + visible: Some(OpenVisible::None), + ..Default::default() + }, + window, + cx, + ) + } + } + None => split_or_open( workspace, ProjectPath { worktree_id, @@ -1506,88 +1535,52 @@ impl PickerDelegate for FileFinderDelegate { }, window, cx, - ) - } else { - match path.absolute.as_ref() { - Some(abs_path) => { - if secondary { - workspace.split_abs_path( - abs_path.to_path_buf(), - false, - window, - cx, - ) - } else { - workspace.open_abs_path( - abs_path.to_path_buf(), - OpenOptions { - visible: Some(OpenVisible::None), - ..Default::default() - }, - window, - cx, - ) - } - } - None => split_or_open( - workspace, - ProjectPath { - worktree_id, - path: Arc::clone(&path.project.path), - }, - window, - cx, - ), - } + ), } } - Match::Search(m) => split_or_open( - workspace, - ProjectPath { - worktree_id: WorktreeId::from_usize(m.0.worktree_id), - path: m.0.path.clone(), - }, - window, - cx, - ), } - }); + Match::Search(m) => split_or_open( + workspace, + ProjectPath { + worktree_id: WorktreeId::from_usize(m.0.worktree_id), + path: m.0.path.clone(), + }, + window, + cx, + ), + } + }); - let row = self - .latest_search_query - .as_ref() - .and_then(|query| query.path_position.row) - .map(|row| row.saturating_sub(1)); - let col = self - .latest_search_query - .as_ref() - .and_then(|query| query.path_position.column) - .unwrap_or(0) - .saturating_sub(1); - let finder = self.file_finder.clone(); - - cx.spawn_in(window, async move |_, cx| { - let item = 
open_task.await.notify_async_err(cx)?; - if let Some(row) = row { - if let Some(active_editor) = item.downcast::() { - active_editor - .downgrade() - .update_in(cx, |editor, window, cx| { - editor.go_to_singleton_buffer_point( - Point::new(row, col), - window, - cx, - ); - }) - .log_err(); - } - } - finder.update(cx, |_, cx| cx.emit(DismissEvent)).ok()?; + let row = self + .latest_search_query + .as_ref() + .and_then(|query| query.path_position.row) + .map(|row| row.saturating_sub(1)); + let col = self + .latest_search_query + .as_ref() + .and_then(|query| query.path_position.column) + .unwrap_or(0) + .saturating_sub(1); + let finder = self.file_finder.clone(); + + cx.spawn_in(window, async move |_, cx| { + let item = open_task.await.notify_async_err(cx)?; + if let Some(row) = row + && let Some(active_editor) = item.downcast::() + { + active_editor + .downgrade() + .update_in(cx, |editor, window, cx| { + editor.go_to_singleton_buffer_point(Point::new(row, col), window, cx); + }) + .log_err(); + } + finder.update(cx, |_, cx| cx.emit(DismissEvent)).ok()?; - Some(()) - }) - .detach(); - } + Some(()) + }) + .detach(); } } @@ -1606,10 +1599,7 @@ impl PickerDelegate for FileFinderDelegate { ) -> Option { let settings = FileFinderSettings::get_global(cx); - let path_match = self - .matches - .get(ix) - .expect("Invalid matches state: no element for index {ix}"); + let path_match = self.matches.get(ix)?; let history_icon = match &path_match { Match::History { .. 
} => Icon::new(IconName::HistoryRerun) @@ -1759,7 +1749,7 @@ impl PickerDelegate for FileFinderDelegate { Some(ContextMenu::build(window, cx, { let focus_handle = focus_handle.clone(); move |menu, _, _| { - menu.context(focus_handle.clone()) + menu.context(focus_handle) .action( "Split Left", pane::SplitLeft.boxed_clone(), diff --git a/crates/file_finder/src/file_finder_settings.rs b/crates/file_finder/src/file_finder_settings.rs index 350e1de3b36c9073d137993ce4fbc50aa43bb36e..6a6b98b8ea3e1c7e7f0e3cc0385fdd7f413b659f 100644 --- a/crates/file_finder/src/file_finder_settings.rs +++ b/crates/file_finder/src/file_finder_settings.rs @@ -1,7 +1,7 @@ use anyhow::Result; use schemars::JsonSchema; use serde_derive::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; #[derive(Deserialize, Debug, Clone, Copy, PartialEq)] pub struct FileFinderSettings { @@ -11,7 +11,8 @@ pub struct FileFinderSettings { pub include_ignored: Option, } -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug, SettingsUi, SettingsKey)] +#[settings_key(key = "file_finder")] pub struct FileFinderSettingsContent { /// Whether to show file icons in the file finder. 
/// @@ -42,8 +43,6 @@ pub struct FileFinderSettingsContent { } impl Settings for FileFinderSettings { - const KEY: Option<&'static str> = Some("file_finder"); - type FileContent = FileFinderSettingsContent; fn load(sources: SettingsSources, _: &mut gpui::App) -> Result { diff --git a/crates/file_finder/src/file_finder_tests.rs b/crates/file_finder/src/file_finder_tests.rs index db259ccef854b1d3c5c4fae3bc9ebad08e398891..cd0f203d6a300b4039df74a646bf0a9d56818347 100644 --- a/crates/file_finder/src/file_finder_tests.rs +++ b/crates/file_finder/src/file_finder_tests.rs @@ -218,6 +218,7 @@ async fn test_matching_paths(cx: &mut TestAppContext) { " ndan ", " band ", "a bandana", + "bandana:", ] { picker .update_in(cx, |picker, window, cx| { @@ -252,6 +253,53 @@ async fn test_matching_paths(cx: &mut TestAppContext) { } } +#[gpui::test] +async fn test_matching_paths_with_colon(cx: &mut TestAppContext) { + let app_state = init_test(cx); + app_state + .fs + .as_fake() + .insert_tree( + path!("/root"), + json!({ + "a": { + "foo:bar.rs": "", + "foo.rs": "", + } + }), + ) + .await; + + let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], cx).await; + + let (picker, _, cx) = build_find_picker(project, cx); + + // 'foo:' matches both files + cx.simulate_input("foo:"); + picker.update(cx, |picker, _| { + assert_eq!(picker.delegate.matches.len(), 3); + assert_match_at_position(picker, 0, "foo.rs"); + assert_match_at_position(picker, 1, "foo:bar.rs"); + }); + + // 'foo:b' matches one of the files + cx.simulate_input("b"); + picker.update(cx, |picker, _| { + assert_eq!(picker.delegate.matches.len(), 2); + assert_match_at_position(picker, 0, "foo:bar.rs"); + }); + + cx.dispatch_action(editor::actions::Backspace); + + // 'foo:1' matches both files, specifying which row to jump to + cx.simulate_input("1"); + picker.update(cx, |picker, _| { + assert_eq!(picker.delegate.matches.len(), 3); + assert_match_at_position(picker, 0, "foo.rs"); + 
assert_match_at_position(picker, 1, "foo:bar.rs"); + }); +} + #[gpui::test] async fn test_unicode_paths(cx: &mut TestAppContext) { let app_state = init_test(cx); @@ -1614,7 +1662,7 @@ async fn test_select_current_open_file_when_no_history(cx: &mut gpui::TestAppCon let picker = open_file_picker(&workspace, cx); picker.update(cx, |finder, _| { - assert_match_selection(&finder, 0, "1_qw"); + assert_match_selection(finder, 0, "1_qw"); }); } @@ -2623,7 +2671,7 @@ async fn open_queried_buffer( workspace: &Entity, cx: &mut gpui::VisualTestContext, ) -> Vec { - let picker = open_file_picker(&workspace, cx); + let picker = open_file_picker(workspace, cx); cx.simulate_input(input); let history_items = picker.update(cx, |finder, _| { diff --git a/crates/file_finder/src/open_path_prompt.rs b/crates/file_finder/src/open_path_prompt.rs index 68ba7a78b52fee42588b732d7a6a3c582a80061f..ab00f943b811a941f7339853a79f81dfea5275eb 100644 --- a/crates/file_finder/src/open_path_prompt.rs +++ b/crates/file_finder/src/open_path_prompt.rs @@ -1,7 +1,7 @@ use crate::file_finder_settings::FileFinderSettings; use file_icons::FileIcons; use futures::channel::oneshot; -use fuzzy::{StringMatch, StringMatchCandidate}; +use fuzzy::{CharBag, StringMatch, StringMatchCandidate}; use gpui::{HighlightStyle, StyledText, Task}; use picker::{Picker, PickerDelegate}; use project::{DirectoryItem, DirectoryLister}; @@ -23,7 +23,6 @@ use workspace::Workspace; pub(crate) struct OpenPathPrompt; -#[derive(Debug)] pub struct OpenPathDelegate { tx: Option>>>, lister: DirectoryLister, @@ -35,6 +34,9 @@ pub struct OpenPathDelegate { prompt_root: String, path_style: PathStyle, replace_prompt: Task<()>, + render_footer: + Arc>) -> Option + 'static>, + hidden_entries: bool, } impl OpenPathDelegate { @@ -60,9 +62,25 @@ impl OpenPathDelegate { }, path_style, replace_prompt: Task::ready(()), + render_footer: Arc::new(|_, _| None), + hidden_entries: false, } } + pub fn with_footer( + mut self, + footer: Arc< + dyn Fn(&mut 
Window, &mut Context>) -> Option + 'static, + >, + ) -> Self { + self.render_footer = footer; + self + } + + pub fn show_hidden(mut self) -> Self { + self.hidden_entries = true; + self + } fn get_entry(&self, selected_match_index: usize) -> Option { match &self.directory_state { DirectoryState::List { entries, .. } => { @@ -75,16 +93,16 @@ impl OpenPathDelegate { .. } => { let mut i = selected_match_index; - if let Some(user_input) = user_input { - if !user_input.exists || !user_input.is_dir { - if i == 0 { - return Some(CandidateInfo { - path: user_input.file.clone(), - is_dir: false, - }); - } else { - i -= 1; - } + if let Some(user_input) = user_input + && (!user_input.exists || !user_input.is_dir) + { + if i == 0 { + return Some(CandidateInfo { + path: user_input.file.clone(), + is_dir: false, + }); + } else { + i -= 1; } } let id = self.string_matches.get(i)?.candidate_id; @@ -112,7 +130,7 @@ impl OpenPathDelegate { entries, .. } => user_input - .into_iter() + .iter() .filter(|user_input| !user_input.exists || !user_input.is_dir) .map(|user_input| user_input.file.string.clone()) .chain(self.string_matches.iter().filter_map(|string_match| { @@ -125,6 +143,13 @@ impl OpenPathDelegate { DirectoryState::None { .. 
} => Vec::new(), } } + + fn current_dir(&self) -> &'static str { + match self.path_style { + PathStyle::Posix => "./", + PathStyle::Windows => ".\\", + } + } } #[derive(Debug)] @@ -233,6 +258,7 @@ impl PickerDelegate for OpenPathDelegate { cx: &mut Context>, ) -> Task<()> { let lister = &self.lister; + let input_is_empty = query.is_empty(); let (dir, suffix) = get_dir_and_suffix(query, self.path_style); let query = match &self.directory_state { @@ -261,8 +287,9 @@ impl PickerDelegate for OpenPathDelegate { self.cancel_flag.store(true, atomic::Ordering::Release); self.cancel_flag = Arc::new(AtomicBool::new(false)); let cancel_flag = self.cancel_flag.clone(); - + let hidden_entries = self.hidden_entries; let parent_path_is_root = self.prompt_root == dir; + let current_dir = self.current_dir(); cx.spawn_in(window, async move |this, cx| { if let Some(query) = query { let paths = query.await; @@ -353,10 +380,39 @@ impl PickerDelegate for OpenPathDelegate { return; }; - if !suffix.starts_with('.') { - new_entries.retain(|entry| !entry.path.string.starts_with('.')); + let mut max_id = 0; + if !suffix.starts_with('.') && !hidden_entries { + new_entries.retain(|entry| { + max_id = max_id.max(entry.path.id); + !entry.path.string.starts_with('.') + }); } + if suffix.is_empty() { + let should_prepend_with_current_dir = this + .read_with(cx, |picker, _| { + !input_is_empty + && match &picker.delegate.directory_state { + DirectoryState::List { error, .. } => error.is_none(), + DirectoryState::Create { .. } => false, + DirectoryState::None { .. 
} => false, + } + }) + .unwrap_or(false); + if should_prepend_with_current_dir { + new_entries.insert( + 0, + CandidateInfo { + path: StringMatchCandidate { + id: max_id + 1, + string: current_dir.to_string(), + char_bag: CharBag::from(current_dir), + }, + is_dir: true, + }, + ); + } + this.update(cx, |this, cx| { this.delegate.selected_index = 0; this.delegate.string_matches = new_entries @@ -485,6 +541,10 @@ impl PickerDelegate for OpenPathDelegate { _: &mut Context>, ) -> Option { let candidate = self.get_entry(self.selected_index)?; + if candidate.path.string.is_empty() || candidate.path.string == self.current_dir() { + return None; + } + let path_style = self.path_style; Some( maybe!({ @@ -629,15 +689,21 @@ impl PickerDelegate for OpenPathDelegate { DirectoryState::None { .. } => Vec::new(), }; + let is_current_dir_candidate = candidate.path.string == self.current_dir(); + let file_icon = maybe!({ if !settings.file_icons { return None; } let icon = if candidate.is_dir { - FileIcons::get_folder_icon(false, cx)? + if is_current_dir_candidate { + return Some(Icon::new(IconName::ReplyArrowRight).color(Color::Muted)); + } else { + FileIcons::get_folder_icon(false, cx)? + } } else { let path = path::Path::new(&candidate.path.string); - FileIcons::get_icon(&path, cx)? + FileIcons::get_icon(path, cx)? 
}; Some(Icon::from_path(icon).color(Color::Muted)) }); @@ -652,8 +718,10 @@ impl PickerDelegate for OpenPathDelegate { .child(HighlightedLabel::new( if parent_path == &self.prompt_root { format!("{}{}", self.prompt_root, candidate.path.string) + } else if is_current_dir_candidate { + "open this directory".to_string() } else { - candidate.path.string.clone() + candidate.path.string }, match_positions, )), @@ -684,7 +752,7 @@ impl PickerDelegate for OpenPathDelegate { }; StyledText::new(label) .with_default_highlights( - &window.text_style().clone(), + &window.text_style(), vec![( delta..delta + label_len, HighlightStyle::color(Color::Conflict.color(cx)), @@ -694,7 +762,7 @@ impl PickerDelegate for OpenPathDelegate { } else { StyledText::new(format!("{label} (create)")) .with_default_highlights( - &window.text_style().clone(), + &window.text_style(), vec![( delta..delta + label_len, HighlightStyle::color(Color::Created.color(cx)), @@ -728,10 +796,18 @@ impl PickerDelegate for OpenPathDelegate { .child(LabelLike::new().child(label_with_highlights)), ) } - DirectoryState::None { .. } => return None, + DirectoryState::None { .. } => None, } } + fn render_footer( + &self, + window: &mut Window, + cx: &mut Context>, + ) -> Option { + (self.render_footer)(window, cx) + } + fn no_matches_text(&self, _window: &mut Window, _cx: &mut App) -> Option { Some(match &self.directory_state { DirectoryState::Create { .. 
} => SharedString::from("Type a path…"), @@ -747,6 +823,17 @@ impl PickerDelegate for OpenPathDelegate { fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { Arc::from(format!("[directory{MAIN_SEPARATOR_STR}]filename.ext")) } + + fn separators_after_indices(&self) -> Vec { + let Some(m) = self.string_matches.first() else { + return Vec::new(); + }; + if m.string == self.current_dir() { + vec![0] + } else { + Vec::new() + } + } } fn path_candidates( diff --git a/crates/file_finder/src/open_path_prompt_tests.rs b/crates/file_finder/src/open_path_prompt_tests.rs index a69ac6992dc280fd6537b16087302c2fbb9f8f4c..fd7cc1c6c612d28b1cc8f2352f6dbb0a254e7e98 100644 --- a/crates/file_finder/src/open_path_prompt_tests.rs +++ b/crates/file_finder/src/open_path_prompt_tests.rs @@ -39,16 +39,24 @@ async fn test_open_path_prompt(cx: &mut TestAppContext) { let (picker, cx) = build_open_path_prompt(project, false, PathStyle::current(), cx); + insert_query(path!("sadjaoislkdjasldj"), &picker, cx).await; + assert_eq!(collect_match_candidates(&picker, cx), Vec::::new()); + let query = path!("/root"); insert_query(query, &picker, cx).await; assert_eq!(collect_match_candidates(&picker, cx), vec!["root"]); + #[cfg(not(windows))] + let expected_separator = "./"; + #[cfg(windows)] + let expected_separator = ".\\"; + // If the query ends with a slash, the picker should show the contents of the directory. let query = path!("/root/"); insert_query(query, &picker, cx).await; assert_eq!( collect_match_candidates(&picker, cx), - vec!["a1", "a2", "a3", "dir1", "dir2"] + vec![expected_separator, "a1", "a2", "a3", "dir1", "dir2"] ); // Show candidates for the query "a". 
@@ -72,7 +80,7 @@ async fn test_open_path_prompt(cx: &mut TestAppContext) { insert_query(query, &picker, cx).await; assert_eq!( collect_match_candidates(&picker, cx), - vec!["c", "d1", "d2", "d3", "dir3", "dir4"] + vec![expected_separator, "c", "d1", "d2", "d3", "dir3", "dir4"] ); // Show candidates for the query "d". @@ -116,71 +124,86 @@ async fn test_open_path_prompt_completion(cx: &mut TestAppContext) { // Confirm completion for the query "/root", since it's a directory, it should add a trailing slash. let query = path!("/root"); insert_query(query, &picker, cx).await; - assert_eq!(confirm_completion(query, 0, &picker, cx), path!("/root/")); + assert_eq!( + confirm_completion(query, 0, &picker, cx).unwrap(), + path!("/root/") + ); // Confirm completion for the query "/root/", selecting the first candidate "a", since it's a file, it should not add a trailing slash. let query = path!("/root/"); insert_query(query, &picker, cx).await; - assert_eq!(confirm_completion(query, 0, &picker, cx), path!("/root/a")); + assert_eq!( + confirm_completion(query, 0, &picker, cx), + None, + "First entry is `./` and when we confirm completion, it is tabbed below" + ); + assert_eq!( + confirm_completion(query, 1, &picker, cx).unwrap(), + path!("/root/a"), + "Second entry is the first entry of a directory that we want to be completed" + ); // Confirm completion for the query "/root/", selecting the second candidate "dir1", since it's a directory, it should add a trailing slash. 
let query = path!("/root/"); insert_query(query, &picker, cx).await; assert_eq!( - confirm_completion(query, 1, &picker, cx), + confirm_completion(query, 2, &picker, cx).unwrap(), path!("/root/dir1/") ); let query = path!("/root/a"); insert_query(query, &picker, cx).await; - assert_eq!(confirm_completion(query, 0, &picker, cx), path!("/root/a")); + assert_eq!( + confirm_completion(query, 0, &picker, cx).unwrap(), + path!("/root/a") + ); let query = path!("/root/d"); insert_query(query, &picker, cx).await; assert_eq!( - confirm_completion(query, 1, &picker, cx), + confirm_completion(query, 1, &picker, cx).unwrap(), path!("/root/dir2/") ); let query = path!("/root/dir2"); insert_query(query, &picker, cx).await; assert_eq!( - confirm_completion(query, 0, &picker, cx), + confirm_completion(query, 0, &picker, cx).unwrap(), path!("/root/dir2/") ); let query = path!("/root/dir2/"); insert_query(query, &picker, cx).await; assert_eq!( - confirm_completion(query, 0, &picker, cx), + confirm_completion(query, 1, &picker, cx).unwrap(), path!("/root/dir2/c") ); let query = path!("/root/dir2/"); insert_query(query, &picker, cx).await; assert_eq!( - confirm_completion(query, 2, &picker, cx), + confirm_completion(query, 3, &picker, cx).unwrap(), path!("/root/dir2/dir3/") ); let query = path!("/root/dir2/d"); insert_query(query, &picker, cx).await; assert_eq!( - confirm_completion(query, 0, &picker, cx), + confirm_completion(query, 0, &picker, cx).unwrap(), path!("/root/dir2/d") ); let query = path!("/root/dir2/d"); insert_query(query, &picker, cx).await; assert_eq!( - confirm_completion(query, 1, &picker, cx), + confirm_completion(query, 1, &picker, cx).unwrap(), path!("/root/dir2/dir3/") ); let query = path!("/root/dir2/di"); insert_query(query, &picker, cx).await; assert_eq!( - confirm_completion(query, 1, &picker, cx), + confirm_completion(query, 1, &picker, cx).unwrap(), path!("/root/dir2/dir4/") ); } @@ -211,42 +234,63 @@ async fn test_open_path_prompt_on_windows(cx: &mut 
TestAppContext) { insert_query(query, &picker, cx).await; assert_eq!( collect_match_candidates(&picker, cx), - vec!["a", "dir1", "dir2"] + vec![".\\", "a", "dir1", "dir2"] + ); + assert_eq!( + confirm_completion(query, 0, &picker, cx), + None, + "First entry is `.\\` and when we confirm completion, it is tabbed below" + ); + assert_eq!( + confirm_completion(query, 1, &picker, cx).unwrap(), + "C:/root/a", + "Second entry is the first entry of a directory that we want to be completed" ); - assert_eq!(confirm_completion(query, 0, &picker, cx), "C:/root/a"); let query = "C:\\root/"; insert_query(query, &picker, cx).await; assert_eq!( collect_match_candidates(&picker, cx), - vec!["a", "dir1", "dir2"] + vec![".\\", "a", "dir1", "dir2"] + ); + assert_eq!( + confirm_completion(query, 1, &picker, cx).unwrap(), + "C:\\root/a" ); - assert_eq!(confirm_completion(query, 0, &picker, cx), "C:\\root/a"); let query = "C:\\root\\"; insert_query(query, &picker, cx).await; assert_eq!( collect_match_candidates(&picker, cx), - vec!["a", "dir1", "dir2"] + vec![".\\", "a", "dir1", "dir2"] + ); + assert_eq!( + confirm_completion(query, 1, &picker, cx).unwrap(), + "C:\\root\\a" ); - assert_eq!(confirm_completion(query, 0, &picker, cx), "C:\\root\\a"); // Confirm completion for the query "C:/root/d", selecting the second candidate "dir2", since it's a directory, it should add a trailing slash. 
let query = "C:/root/d"; insert_query(query, &picker, cx).await; assert_eq!(collect_match_candidates(&picker, cx), vec!["dir1", "dir2"]); - assert_eq!(confirm_completion(query, 1, &picker, cx), "C:/root/dir2\\"); + assert_eq!( + confirm_completion(query, 1, &picker, cx).unwrap(), + "C:/root/dir2\\" + ); let query = "C:\\root/d"; insert_query(query, &picker, cx).await; assert_eq!(collect_match_candidates(&picker, cx), vec!["dir1", "dir2"]); - assert_eq!(confirm_completion(query, 0, &picker, cx), "C:\\root/dir1\\"); + assert_eq!( + confirm_completion(query, 0, &picker, cx).unwrap(), + "C:\\root/dir1\\" + ); let query = "C:\\root\\d"; insert_query(query, &picker, cx).await; assert_eq!(collect_match_candidates(&picker, cx), vec!["dir1", "dir2"]); assert_eq!( - confirm_completion(query, 0, &picker, cx), + confirm_completion(query, 0, &picker, cx).unwrap(), "C:\\root\\dir1\\" ); } @@ -276,20 +320,29 @@ async fn test_open_path_prompt_on_windows_with_remote(cx: &mut TestAppContext) { insert_query(query, &picker, cx).await; assert_eq!( collect_match_candidates(&picker, cx), - vec!["a", "dir1", "dir2"] + vec!["./", "a", "dir1", "dir2"] + ); + assert_eq!( + confirm_completion(query, 1, &picker, cx).unwrap(), + "/root/a" ); - assert_eq!(confirm_completion(query, 0, &picker, cx), "/root/a"); // Confirm completion for the query "/root/d", selecting the second candidate "dir2", since it's a directory, it should add a trailing slash. 
let query = "/root/d"; insert_query(query, &picker, cx).await; assert_eq!(collect_match_candidates(&picker, cx), vec!["dir1", "dir2"]); - assert_eq!(confirm_completion(query, 1, &picker, cx), "/root/dir2/"); + assert_eq!( + confirm_completion(query, 1, &picker, cx).unwrap(), + "/root/dir2/" + ); let query = "/root/d"; insert_query(query, &picker, cx).await; assert_eq!(collect_match_candidates(&picker, cx), vec!["dir1", "dir2"]); - assert_eq!(confirm_completion(query, 0, &picker, cx), "/root/dir1/"); + assert_eq!( + confirm_completion(query, 0, &picker, cx).unwrap(), + "/root/dir1/" + ); } #[gpui::test] @@ -396,15 +449,13 @@ fn confirm_completion( select: usize, picker: &Entity>, cx: &mut VisualTestContext, -) -> String { - picker - .update_in(cx, |f, window, cx| { - if f.delegate.selected_index() != select { - f.delegate.set_selected_index(select, window, cx); - } - f.delegate.confirm_completion(query.to_string(), window, cx) - }) - .unwrap() +) -> Option { + picker.update_in(cx, |f, window, cx| { + if f.delegate.selected_index() != select { + f.delegate.set_selected_index(select, window, cx); + } + f.delegate.confirm_completion(query.to_string(), window, cx) + }) } fn collect_match_candidates( diff --git a/crates/file_icons/src/file_icons.rs b/crates/file_icons/src/file_icons.rs index 82a8e05d8571b04ec177c9944a765778684fe2a4..42c00fb12d5e9f0fbb1662eb0941ed70d94382b5 100644 --- a/crates/file_icons/src/file_icons.rs +++ b/crates/file_icons/src/file_icons.rs @@ -72,7 +72,7 @@ impl FileIcons { return maybe_path; } } - return this.get_icon_for_type("default", cx); + this.get_icon_for_type("default", cx) } fn default_icon_theme(cx: &App) -> Option> { diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index f7f6e160fd33e2ab05dc18082529822fd0b1b5cd..37a6ff109d8a3f16a0c2cf5a0bd924180ecbede8 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -345,7 +345,7 @@ impl GitRepository for FakeGitRepository { fn 
create_branch(&self, name: String) -> BoxFuture<'_, Result<()>> { self.with_state_async(true, move |state| { - state.branches.insert(name.to_owned()); + state.branches.insert(name); Ok(()) }) } @@ -603,9 +603,9 @@ mod tests { assert_eq!( fs.files_with_contents(Path::new("")), [ - (Path::new("/bar/baz").into(), b"qux".into()), - (Path::new("/foo/a").into(), b"lorem".into()), - (Path::new("/foo/b").into(), b"ipsum".into()) + (Path::new(path!("/bar/baz")).into(), b"qux".into()), + (Path::new(path!("/foo/a")).into(), b"lorem".into()), + (Path::new(path!("/foo/b")).into(), b"ipsum".into()) ] ); } diff --git a/crates/fs/src/fs.rs b/crates/fs/src/fs.rs index 22bfdbcd66ee0b3193ef51e3ec461dfe225fa8f0..98c8dc9054984c49732bec57a9604a14ceb5ee72 100644 --- a/crates/fs/src/fs.rs +++ b/crates/fs/src/fs.rs @@ -20,6 +20,9 @@ use std::os::fd::{AsFd, AsRawFd}; #[cfg(unix)] use std::os::unix::fs::{FileTypeExt, MetadataExt}; +#[cfg(any(target_os = "macos", target_os = "freebsd"))] +use std::mem::MaybeUninit; + use async_tar::Archive; use futures::{AsyncRead, Stream, StreamExt, future::BoxFuture}; use git::repository::{GitRepository, RealGitRepository}; @@ -261,14 +264,15 @@ impl FileHandle for std::fs::File { }; let fd = self.as_fd(); - let mut path_buf: [libc::c_char; libc::PATH_MAX as usize] = [0; libc::PATH_MAX as usize]; + let mut path_buf = MaybeUninit::<[u8; libc::PATH_MAX as usize]>::uninit(); let result = unsafe { libc::fcntl(fd.as_raw_fd(), libc::F_GETPATH, path_buf.as_mut_ptr()) }; if result == -1 { anyhow::bail!("fcntl returned -1".to_string()); } - let c_str = unsafe { CStr::from_ptr(path_buf.as_ptr()) }; + // SAFETY: `fcntl` will initialize the path buffer. 
+ let c_str = unsafe { CStr::from_ptr(path_buf.as_ptr().cast()) }; let path = PathBuf::from(OsStr::from_bytes(c_str.to_bytes())); Ok(path) } @@ -296,15 +300,16 @@ impl FileHandle for std::fs::File { }; let fd = self.as_fd(); - let mut kif: libc::kinfo_file = unsafe { std::mem::zeroed() }; + let mut kif = MaybeUninit::::uninit(); kif.kf_structsize = libc::KINFO_FILE_SIZE; - let result = unsafe { libc::fcntl(fd.as_raw_fd(), libc::F_KINFO, &mut kif) }; + let result = unsafe { libc::fcntl(fd.as_raw_fd(), libc::F_KINFO, kif.as_mut_ptr()) }; if result == -1 { anyhow::bail!("fcntl returned -1".to_string()); } - let c_str = unsafe { CStr::from_ptr(kif.kf_path.as_ptr()) }; + // SAFETY: `fcntl` will initialize the kif. + let c_str = unsafe { CStr::from_ptr(kif.assume_init().kf_path.as_ptr()) }; let path = PathBuf::from(OsStr::from_bytes(c_str.to_bytes())); Ok(path) } @@ -420,18 +425,19 @@ impl Fs for RealFs { async fn remove_file(&self, path: &Path, options: RemoveOptions) -> Result<()> { #[cfg(windows)] - if let Ok(Some(metadata)) = self.metadata(path).await { - if metadata.is_symlink && metadata.is_dir { - self.remove_dir( - path, - RemoveOptions { - recursive: false, - ignore_if_not_exists: true, - }, - ) - .await?; - return Ok(()); - } + if let Ok(Some(metadata)) = self.metadata(path).await + && metadata.is_symlink + && metadata.is_dir + { + self.remove_dir( + path, + RemoveOptions { + recursive: false, + ignore_if_not_exists: true, + }, + ) + .await?; + return Ok(()); } match smol::fs::remove_file(path).await { @@ -467,11 +473,11 @@ impl Fs for RealFs { #[cfg(any(target_os = "linux", target_os = "freebsd"))] async fn trash_file(&self, path: &Path, _options: RemoveOptions) -> Result<()> { - if let Ok(Some(metadata)) = self.metadata(path).await { - if metadata.is_symlink { - // TODO: trash_file does not support trashing symlinks yet - https://github.com/bilelmoussaoui/ashpd/issues/255 - return self.remove_file(path, RemoveOptions::default()).await; - } + if let 
Ok(Some(metadata)) = self.metadata(path).await + && metadata.is_symlink + { + // TODO: trash_file does not support trashing symlinks yet - https://github.com/bilelmoussaoui/ashpd/issues/255 + return self.remove_file(path, RemoveOptions::default()).await; } let file = smol::fs::File::open(path).await?; match trash::trash_file(&file.as_fd()).await { @@ -494,7 +500,8 @@ impl Fs for RealFs { }; // todo(windows) // When new version of `windows-rs` release, make this operation `async` - let path = SanitizedPath::from(path.canonicalize()?); + let path = path.canonicalize()?; + let path = SanitizedPath::new(&path); let path_string = path.to_string(); let file = StorageFile::GetFileFromPathAsync(&HSTRING::from(path_string))?.get()?; file.DeleteAsync(StorageDeleteOption::Default)?.get()?; @@ -521,7 +528,8 @@ impl Fs for RealFs { // todo(windows) // When new version of `windows-rs` release, make this operation `async` - let path = SanitizedPath::from(path.canonicalize()?); + let path = path.canonicalize()?; + let path = SanitizedPath::new(&path); let path_string = path.to_string(); let folder = StorageFolder::GetFolderFromPathAsync(&HSTRING::from(path_string))?.get()?; folder.DeleteAsync(StorageDeleteOption::Default)?.get()?; @@ -624,13 +632,13 @@ impl Fs for RealFs { async fn is_file(&self, path: &Path) -> bool { smol::fs::metadata(path) .await - .map_or(false, |metadata| metadata.is_file()) + .is_ok_and(|metadata| metadata.is_file()) } async fn is_dir(&self, path: &Path) -> bool { smol::fs::metadata(path) .await - .map_or(false, |metadata| metadata.is_dir()) + .is_ok_and(|metadata| metadata.is_dir()) } async fn metadata(&self, path: &Path) -> Result> { @@ -766,24 +774,23 @@ impl Fs for RealFs { let pending_paths: Arc>> = Default::default(); let watcher = Arc::new(fs_watcher::FsWatcher::new(tx, pending_paths.clone())); - if watcher.add(path).is_err() { - // If the path doesn't exist yet (e.g. settings.json), watch the parent dir to learn when it's created. 
- if let Some(parent) = path.parent() { - if let Err(e) = watcher.add(parent) { - log::warn!("Failed to watch: {e}"); - } - } + // If the path doesn't exist yet (e.g. settings.json), watch the parent dir to learn when it's created. + if watcher.add(path).is_err() + && let Some(parent) = path.parent() + && let Err(e) = watcher.add(parent) + { + log::warn!("Failed to watch: {e}"); } // Check if path is a symlink and follow the target parent - if let Some(mut target) = self.read_link(&path).await.ok() { + if let Some(mut target) = self.read_link(path).await.ok() { // Check if symlink target is relative path, if so make it absolute - if target.is_relative() { - if let Some(parent) = path.parent() { - target = parent.join(target); - if let Ok(canonical) = self.canonicalize(&target).await { - target = SanitizedPath::from(canonical).as_path().to_path_buf(); - } + if target.is_relative() + && let Some(parent) = path.parent() + { + target = parent.join(target); + if let Ok(canonical) = self.canonicalize(&target).await { + target = SanitizedPath::new(&canonical).as_path().to_path_buf(); } } watcher.add(&target).ok(); @@ -1068,13 +1075,13 @@ impl FakeFsState { let current_entry = *entry_stack.last()?; if let FakeFsEntry::Dir { entries, .. } = current_entry { let entry = entries.get(name.to_str().unwrap())?; - if path_components.peek().is_some() || follow_symlink { - if let FakeFsEntry::Symlink { target, .. } = entry { - let mut target = target.clone(); - target.extend(path_components); - path = target; - continue 'outer; - } + if (path_components.peek().is_some() || follow_symlink) + && let FakeFsEntry::Symlink { target, .. 
} = entry + { + let mut target = target.clone(); + target.extend(path_components); + path = target; + continue 'outer; } entry_stack.push(entry); canonical_path = canonical_path.join(name); @@ -1101,7 +1108,9 @@ impl FakeFsState { ) -> Option<(&mut FakeFsEntry, PathBuf)> { let canonical_path = self.canonicalize(target, follow_symlink)?; - let mut components = canonical_path.components(); + let mut components = canonical_path + .components() + .skip_while(|component| matches!(component, Component::Prefix(_))); let Some(Component::RootDir) = components.next() else { panic!( "the path {:?} was not canonicalized properly {:?}", @@ -1566,10 +1575,10 @@ impl FakeFs { pub fn insert_branches(&self, dot_git: &Path, branches: &[&str]) { self.with_git_state(dot_git, true, |state| { - if let Some(first) = branches.first() { - if state.current_branch_name.is_none() { - state.current_branch_name = Some(first.to_string()) - } + if let Some(first) = branches.first() + && state.current_branch_name.is_none() + { + state.current_branch_name = Some(first.to_string()) } state .branches @@ -1677,7 +1686,7 @@ impl FakeFs { /// by mutating the head, index, and unmerged state. pub fn set_status_for_repo(&self, dot_git: &Path, statuses: &[(&Path, FileStatus)]) { let workdir_path = dot_git.parent().unwrap(); - let workdir_contents = self.files_with_contents(&workdir_path); + let workdir_contents = self.files_with_contents(workdir_path); self.with_git_state(dot_git, true, |state| { state.index_contents.clear(); state.head_contents.clear(); @@ -1958,7 +1967,7 @@ impl FileHandle for FakeHandle { }; if state.try_entry(&target, false).is_some() { - return Ok(target.clone()); + return Ok(target); } anyhow::bail!("fake fd target not found") } @@ -2244,7 +2253,7 @@ impl Fs for FakeFs { async fn open_handle(&self, path: &Path) -> Result> { self.simulate_random_delay().await; let mut state = self.state.lock(); - let inode = match state.entry(&path)? { + let inode = match state.entry(path)? 
{ FakeFsEntry::File { inode, .. } => *inode, FakeFsEntry::Dir { inode, .. } => *inode, _ => unreachable!(), @@ -2254,7 +2263,7 @@ impl Fs for FakeFs { async fn load(&self, path: &Path) -> Result { let content = self.load_internal(path).await?; - Ok(String::from_utf8(content.clone())?) + Ok(String::from_utf8(content)?) } async fn load_bytes(&self, path: &Path) -> Result> { @@ -2410,19 +2419,18 @@ impl Fs for FakeFs { tx, original_path: path.to_owned(), fs_state: self.state.clone(), - prefixes: Mutex::new(vec![path.to_owned()]), + prefixes: Mutex::new(vec![path]), }); ( Box::pin(futures::StreamExt::filter(rx, { let watcher = watcher.clone(); move |events| { let result = events.iter().any(|evt_path| { - let result = watcher + watcher .prefixes .lock() .iter() - .any(|prefix| evt_path.path.starts_with(prefix)); - result + .any(|prefix| evt_path.path.starts_with(prefix)) }); let executor = executor.clone(); async move { diff --git a/crates/fs/src/fs_watcher.rs b/crates/fs/src/fs_watcher.rs index a5ce21294fc65e609428ad95fafb43fe578bc698..07374b7f40455f09cf52d31ddd1a1f64ab6abcd3 100644 --- a/crates/fs/src/fs_watcher.rs +++ b/crates/fs/src/fs_watcher.rs @@ -42,7 +42,7 @@ impl Drop for FsWatcher { impl Watcher for FsWatcher { fn add(&self, path: &std::path::Path) -> anyhow::Result<()> { - let root_path = SanitizedPath::from(path); + let root_path = SanitizedPath::new_arc(path); let tx = self.tx.clone(); let pending_paths = self.pending_path_events.clone(); @@ -70,7 +70,7 @@ impl Watcher for FsWatcher { .paths .iter() .filter_map(|event_path| { - let event_path = SanitizedPath::from(event_path); + let event_path = SanitizedPath::new(event_path); event_path.starts_with(&root_path).then(|| PathEvent { path: event_path.as_path().to_path_buf(), kind, @@ -159,7 +159,7 @@ impl GlobalWatcher { path: path.clone(), }; state.watchers.insert(id, registration_state); - *state.path_registrations.entry(path.clone()).or_insert(0) += 1; + *state.path_registrations.entry(path).or_insert(0) 
+= 1; Ok(id) } diff --git a/crates/fs/src/mac_watcher.rs b/crates/fs/src/mac_watcher.rs index aa75ad31d9beadada32b62ed4d21a612631d31c3..7bd176639f1dccef2da4c4ae8dcb317d0be602cb 100644 --- a/crates/fs/src/mac_watcher.rs +++ b/crates/fs/src/mac_watcher.rs @@ -41,10 +41,9 @@ impl Watcher for MacWatcher { if let Some((watched_path, _)) = handles .range::((Bound::Unbounded, Bound::Included(path))) .next_back() + && path.starts_with(watched_path) { - if path.starts_with(watched_path) { - return Ok(()); - } + return Ok(()); } let (stream, handle) = EventStream::new(&[path], self.latency); diff --git a/crates/fsevent/src/fsevent.rs b/crates/fsevent/src/fsevent.rs index 81ca0a4114253fc38b5d120d1c37dfc9233f7fd1..c97ab5f35d1b1e8463e895da7a309dc7ef3be998 100644 --- a/crates/fsevent/src/fsevent.rs +++ b/crates/fsevent/src/fsevent.rs @@ -178,40 +178,39 @@ impl EventStream { flags.contains(StreamFlags::USER_DROPPED) || flags.contains(StreamFlags::KERNEL_DROPPED) }) + && let Some(last_valid_event_id) = state.last_valid_event_id.take() { - if let Some(last_valid_event_id) = state.last_valid_event_id.take() { - fs::FSEventStreamStop(state.stream); - fs::FSEventStreamInvalidate(state.stream); - fs::FSEventStreamRelease(state.stream); - - let stream_context = fs::FSEventStreamContext { - version: 0, - info, - retain: None, - release: None, - copy_description: None, - }; - let stream = fs::FSEventStreamCreate( - cf::kCFAllocatorDefault, - Self::trampoline, - &stream_context, - state.paths, - last_valid_event_id, - state.latency.as_secs_f64(), - fs::kFSEventStreamCreateFlagFileEvents - | fs::kFSEventStreamCreateFlagNoDefer - | fs::kFSEventStreamCreateFlagWatchRoot, - ); - - state.stream = stream; - fs::FSEventStreamScheduleWithRunLoop( - state.stream, - cf::CFRunLoopGetCurrent(), - cf::kCFRunLoopDefaultMode, - ); - fs::FSEventStreamStart(state.stream); - stream_restarted = true; - } + fs::FSEventStreamStop(state.stream); + fs::FSEventStreamInvalidate(state.stream); + 
fs::FSEventStreamRelease(state.stream); + + let stream_context = fs::FSEventStreamContext { + version: 0, + info, + retain: None, + release: None, + copy_description: None, + }; + let stream = fs::FSEventStreamCreate( + cf::kCFAllocatorDefault, + Self::trampoline, + &stream_context, + state.paths, + last_valid_event_id, + state.latency.as_secs_f64(), + fs::kFSEventStreamCreateFlagFileEvents + | fs::kFSEventStreamCreateFlagNoDefer + | fs::kFSEventStreamCreateFlagWatchRoot, + ); + + state.stream = stream; + fs::FSEventStreamScheduleWithRunLoop( + state.stream, + cf::CFRunLoopGetCurrent(), + cf::kCFRunLoopDefaultMode, + ); + fs::FSEventStreamStart(state.stream); + stream_restarted = true; } if !stream_restarted { diff --git a/crates/git/src/blame.rs b/crates/git/src/blame.rs index 6f12681ea08956b53d9ce298593ce08f0e2a74a9..24b2c44218120b1237fb42e04edc9b6784356c57 100644 --- a/crates/git/src/blame.rs +++ b/crates/git/src/blame.rs @@ -289,14 +289,12 @@ fn parse_git_blame(output: &str) -> Result> { } }; - if done { - if let Some(entry) = current_entry.take() { - index.insert(entry.sha, entries.len()); + if done && let Some(entry) = current_entry.take() { + index.insert(entry.sha, entries.len()); - // We only want annotations that have a commit. - if !entry.sha.is_zero() { - entries.push(entry); - } + // We only want annotations that have a commit. 
+ if !entry.sha.is_zero() { + entries.push(entry); } } } diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index 49578f42d54827d91ebbfc1e2dfb3211d55655eb..be6d5b4f03e4685b94b6f637c04d432c4acc4332 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -176,6 +176,7 @@ pub struct CommitSummary { pub subject: SharedString, /// This is a unix timestamp pub commit_timestamp: i64, + pub author_name: SharedString, pub has_parent: bool, } @@ -295,10 +296,8 @@ impl GitExcludeOverride { pub async fn restore_original(&mut self) -> Result<()> { if let Some(ref original) = self.original_excludes { smol::fs::write(&self.git_exclude_path, original).await?; - } else { - if self.git_exclude_path.exists() { - smol::fs::remove_file(&self.git_exclude_path).await?; - } + } else if self.git_exclude_path.exists() { + smol::fs::remove_file(&self.git_exclude_path).await?; } self.added_excludes = None; @@ -885,7 +884,7 @@ impl GitRepository for RealGitRepository { let output = new_smol_command(&git_binary_path) .current_dir(&working_directory) .envs(env.iter()) - .args(["update-index", "--add", "--cacheinfo", "100644", &sha]) + .args(["update-index", "--add", "--cacheinfo", "100644", sha]) .arg(path.to_unix_style()) .output() .await?; @@ -945,7 +944,7 @@ impl GitRepository for RealGitRepository { .context("no stdin for git cat-file subprocess")?; let mut stdin = BufWriter::new(stdin); for rev in &revs { - write!(&mut stdin, "{rev}\n")?; + writeln!(&mut stdin, "{rev}")?; } stdin.flush()?; drop(stdin); @@ -986,7 +985,7 @@ impl GitRepository for RealGitRepository { Ok(working_directory) => working_directory, Err(e) => return Task::ready(Err(e)), }; - let args = git_status_args(&path_prefixes); + let args = git_status_args(path_prefixes); log::debug!("Checking for git status in {path_prefixes:?}"); self.executor.spawn(async move { let output = new_std_command(&git_binary_path) @@ -1016,6 +1015,7 @@ impl GitRepository for RealGitRepository { 
"%(upstream)", "%(upstream:track)", "%(committerdate:unix)", + "%(authorname)", "%(contents:subject)", ] .join("%00"); @@ -1507,12 +1507,11 @@ impl GitRepository for RealGitRepository { let mut remote_branches = vec![]; let mut add_if_matching = async |remote_head: &str| { - if let Ok(merge_base) = git_cmd(&["merge-base", &head, remote_head]).await { - if merge_base.trim() == head { - if let Some(s) = remote_head.strip_prefix("refs/remotes/") { - remote_branches.push(s.to_owned().into()); - } - } + if let Ok(merge_base) = git_cmd(&["merge-base", &head, remote_head]).await + && merge_base.trim() == head + && let Some(s) = remote_head.strip_prefix("refs/remotes/") + { + remote_branches.push(s.to_owned().into()); } }; @@ -1634,10 +1633,9 @@ impl GitRepository for RealGitRepository { Err(error) => { if let Some(GitBinaryCommandError { status, .. }) = error.downcast_ref::() + && status.code() == Some(1) { - if status.code() == Some(1) { - return Ok(false); - } + return Ok(false); } Err(error) @@ -1926,23 +1924,13 @@ impl GitBinary { } #[derive(Error, Debug)] -#[error("Git command failed: {}", .stderr.trim().if_empty(.stdout.trim()))] +#[error("Git command failed:\n{stdout}{stderr}\n")] struct GitBinaryCommandError { stdout: String, stderr: String, status: ExitStatus, } -trait StringExt { - fn if_empty<'a>(&'a self, fallback: &'a str) -> &'a str; -} - -impl StringExt for str { - fn if_empty<'a>(&'a self, fallback: &'a str) -> &'a str { - if self.is_empty() { fallback } else { self } - } -} - async fn run_git_command( env: Arc>, ask_pass: AskPassDelegate, @@ -2121,6 +2109,7 @@ fn parse_branch_input(input: &str) -> Result> { let upstream_name = fields.next().context("no upstream")?.to_string(); let upstream_tracking = parse_upstream_track(fields.next().context("no upstream:track")?)?; let commiterdate = fields.next().context("no committerdate")?.parse::()?; + let author_name = fields.next().context("no authorname")?.to_string().into(); let subject: SharedString = fields 
.next() .context("no contents:subject")? @@ -2129,11 +2118,12 @@ fn parse_branch_input(input: &str) -> Result> { branches.push(Branch { is_head: is_current_branch, - ref_name: ref_name, + ref_name, most_recent_commit: Some(CommitSummary { sha: head_sha, subject, commit_timestamp: commiterdate, + author_name: author_name, has_parent: !parent_sha.is_empty(), }), upstream: if upstream_name.is_empty() { @@ -2151,7 +2141,7 @@ fn parse_branch_input(input: &str) -> Result> { } fn parse_upstream_track(upstream_track: &str) -> Result { - if upstream_track == "" { + if upstream_track.is_empty() { return Ok(UpstreamTracking::Tracked(UpstreamTrackingStatus { ahead: 0, behind: 0, @@ -2444,9 +2434,9 @@ mod tests { fn test_branches_parsing() { // suppress "help: octal escapes are not supported, `\0` is always null" #[allow(clippy::octal_escapes)] - let input = "*\0060964da10574cd9bf06463a53bf6e0769c5c45e\0\0refs/heads/zed-patches\0refs/remotes/origin/zed-patches\0\01733187470\0generated protobuf\n"; + let input = "*\0060964da10574cd9bf06463a53bf6e0769c5c45e\0\0refs/heads/zed-patches\0refs/remotes/origin/zed-patches\0\01733187470\0John Doe\0generated protobuf\n"; assert_eq!( - parse_branch_input(&input).unwrap(), + parse_branch_input(input).unwrap(), vec![Branch { is_head: true, ref_name: "refs/heads/zed-patches".into(), @@ -2461,6 +2451,7 @@ mod tests { sha: "060964da10574cd9bf06463a53bf6e0769c5c45e".into(), subject: "generated protobuf".into(), commit_timestamp: 1733187470, + author_name: SharedString::new("John Doe"), has_parent: false, }) }] diff --git a/crates/git/src/status.rs b/crates/git/src/status.rs index 6158b5179838c2b3bd36fb91f2aa9e2286c52ca1..71ca14c5b2c4b82ae7dc21e832a2a07c55de8fc3 100644 --- a/crates/git/src/status.rs +++ b/crates/git/src/status.rs @@ -153,17 +153,11 @@ impl FileStatus { } pub fn is_conflicted(self) -> bool { - match self { - FileStatus::Unmerged { .. } => true, - _ => false, - } + matches!(self, FileStatus::Unmerged { .. 
}) } pub fn is_ignored(self) -> bool { - match self { - FileStatus::Ignored => true, - _ => false, - } + matches!(self, FileStatus::Ignored) } pub fn has_changes(&self) -> bool { @@ -176,40 +170,31 @@ impl FileStatus { pub fn is_modified(self) -> bool { match self { - FileStatus::Tracked(tracked) => match (tracked.index_status, tracked.worktree_status) { - (StatusCode::Modified, _) | (_, StatusCode::Modified) => true, - _ => false, - }, + FileStatus::Tracked(tracked) => matches!( + (tracked.index_status, tracked.worktree_status), + (StatusCode::Modified, _) | (_, StatusCode::Modified) + ), _ => false, } } pub fn is_created(self) -> bool { match self { - FileStatus::Tracked(tracked) => match (tracked.index_status, tracked.worktree_status) { - (StatusCode::Added, _) | (_, StatusCode::Added) => true, - _ => false, - }, + FileStatus::Tracked(tracked) => matches!( + (tracked.index_status, tracked.worktree_status), + (StatusCode::Added, _) | (_, StatusCode::Added) + ), FileStatus::Untracked => true, _ => false, } } pub fn is_deleted(self) -> bool { - match self { - FileStatus::Tracked(tracked) => match (tracked.index_status, tracked.worktree_status) { - (StatusCode::Deleted, _) | (_, StatusCode::Deleted) => true, - _ => false, - }, - _ => false, - } + matches!(self, FileStatus::Tracked(tracked) if matches!((tracked.index_status, tracked.worktree_status), (StatusCode::Deleted, _) | (_, StatusCode::Deleted))) } pub fn is_untracked(self) -> bool { - match self { - FileStatus::Untracked => true, - _ => false, - } + matches!(self, FileStatus::Untracked) } pub fn summary(self) -> GitSummary { @@ -468,7 +453,7 @@ impl FromStr for GitStatus { Some((path, status)) }) .collect::>(); - entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(&b)); + entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b)); // When a file exists in HEAD, is deleted in the index, and exists again in the working copy, // git produces two lines for it, one reading `D ` (deleted in index, unmodified in working copy) 
// and the other reading `??` (untracked). Merge these two into the equivalent of `DA`. diff --git a/crates/git_hosting_providers/src/git_hosting_providers.rs b/crates/git_hosting_providers/src/git_hosting_providers.rs index b31412ed4a46b0dc2695ae0229638fad409de13c..1d88c47f2e26fc9ad4e27b1e36351198c4365caf 100644 --- a/crates/git_hosting_providers/src/git_hosting_providers.rs +++ b/crates/git_hosting_providers/src/git_hosting_providers.rs @@ -49,13 +49,13 @@ pub fn register_additional_providers( pub fn get_host_from_git_remote_url(remote_url: &str) -> Result { maybe!({ - if let Some(remote_url) = remote_url.strip_prefix("git@") { - if let Some((host, _)) = remote_url.trim_start_matches("git@").split_once(':') { - return Some(host.to_string()); - } + if let Some(remote_url) = remote_url.strip_prefix("git@") + && let Some((host, _)) = remote_url.trim_start_matches("git@").split_once(':') + { + return Some(host.to_string()); } - Url::parse(&remote_url) + Url::parse(remote_url) .ok() .and_then(|remote_url| remote_url.host_str().map(|host| host.to_string())) }) diff --git a/crates/git_hosting_providers/src/providers/chromium.rs b/crates/git_hosting_providers/src/providers/chromium.rs index b68c629ec7faaf9e37316cd0f7fb4f297b55f502..5d940fb496be6fde2778272abe640987e3b2a4af 100644 --- a/crates/git_hosting_providers/src/providers/chromium.rs +++ b/crates/git_hosting_providers/src/providers/chromium.rs @@ -292,7 +292,7 @@ mod tests { assert_eq!( Chromium - .extract_pull_request(&remote, &message) + .extract_pull_request(&remote, message) .unwrap() .url .as_str(), diff --git a/crates/git_hosting_providers/src/providers/github.rs b/crates/git_hosting_providers/src/providers/github.rs index 30f8d058a7c46798209685930518f4b040dbe714..4475afeb495f41e89273ce0336d830db4cc869cf 100644 --- a/crates/git_hosting_providers/src/providers/github.rs +++ b/crates/git_hosting_providers/src/providers/github.rs @@ -474,7 +474,7 @@ mod tests { assert_eq!( github - .extract_pull_request(&remote, 
&message) + .extract_pull_request(&remote, message) .unwrap() .url .as_str(), @@ -488,6 +488,6 @@ mod tests { See the original PR, this is a fix. "# }; - assert_eq!(github.extract_pull_request(&remote, &message), None); + assert_eq!(github.extract_pull_request(&remote, message), None); } } diff --git a/crates/git_hosting_providers/src/settings.rs b/crates/git_hosting_providers/src/settings.rs index 91179fea392bc38cfc2a513bfc391dd3eec6137d..3249981db91015479bab728484341519db357683 100644 --- a/crates/git_hosting_providers/src/settings.rs +++ b/crates/git_hosting_providers/src/settings.rs @@ -5,7 +5,7 @@ use git::GitHostingProviderRegistry; use gpui::App; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsStore}; +use settings::{Settings, SettingsKey, SettingsStore, SettingsUi}; use url::Url; use util::ResultExt as _; @@ -78,7 +78,8 @@ pub struct GitHostingProviderConfig { pub name: String, } -#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema, SettingsUi, SettingsKey)] +#[settings_key(None)] pub struct GitHostingProviderSettings { /// The list of custom Git hosting providers. 
#[serde(default)] @@ -86,8 +87,6 @@ pub struct GitHostingProviderSettings { } impl Settings for GitHostingProviderSettings { - const KEY: Option<&'static str> = None; - type FileContent = Self; fn load(sources: settings::SettingsSources, _: &mut App) -> Result { diff --git a/crates/git_ui/src/blame_ui.rs b/crates/git_ui/src/blame_ui.rs index f910de7bbe461ea8edf7addda4acad1b712a6f60..ad5823c1674353f2e0531e5f71e1420fe464bfe6 100644 --- a/crates/git_ui/src/blame_ui.rs +++ b/crates/git_ui/src/blame_ui.rs @@ -90,6 +90,11 @@ impl BlameRenderer for GitBlameRenderer { sha: blame_entry.sha.to_string().into(), subject: blame_entry.summary.clone().unwrap_or_default().into(), commit_timestamp: blame_entry.committer_time.unwrap_or_default(), + author_name: blame_entry + .committer_name + .clone() + .unwrap_or_default() + .into(), has_parent: true, }, repository.downgrade(), @@ -172,7 +177,7 @@ impl BlameRenderer for GitBlameRenderer { .clone() .unwrap_or("".to_string()) .into(), - author_email: blame.author_mail.clone().unwrap_or("".to_string()).into(), + author_email: blame.author_mail.unwrap_or("".to_string()).into(), message: details, }; @@ -186,7 +191,7 @@ impl BlameRenderer for GitBlameRenderer { .get(0..8) .map(|sha| sha.to_string().into()) .unwrap_or_else(|| commit_details.sha.clone()); - let full_sha = commit_details.sha.to_string().clone(); + let full_sha = commit_details.sha.to_string(); let absolute_timestamp = format_local_timestamp( commit_details.commit_time, OffsetDateTime::now_utc(), @@ -229,6 +234,7 @@ impl BlameRenderer for GitBlameRenderer { .into() }), commit_timestamp: commit_details.commit_time.unix_timestamp(), + author_name: commit_details.author_name.clone(), has_parent: false, }; @@ -374,10 +380,11 @@ impl BlameRenderer for GitBlameRenderer { sha: blame_entry.sha.to_string().into(), subject: blame_entry.summary.clone().unwrap_or_default().into(), commit_timestamp: blame_entry.committer_time.unwrap_or_default(), + author_name: 
blame_entry.committer_name.unwrap_or_default().into(), has_parent: true, }, repository.downgrade(), - workspace.clone(), + workspace, window, cx, ) diff --git a/crates/git_ui/src/branch_picker.rs b/crates/git_ui/src/branch_picker.rs index a49b5656d7015ad6325d3a0d1232bc9a01a888c8..02ba79ce08b200622690a649eadf530cfc0ce4b2 100644 --- a/crates/git_ui/src/branch_picker.rs +++ b/crates/git_ui/src/branch_picker.rs @@ -10,6 +10,8 @@ use gpui::{ }; use picker::{Picker, PickerDelegate, PickerEditorPosition}; use project::git_store::Repository; +use project::project_settings::ProjectSettings; +use settings::Settings; use std::sync::Arc; use time::OffsetDateTime; use time_format::format_local_timestamp; @@ -48,7 +50,7 @@ pub fn open( window: &mut Window, cx: &mut Context, ) { - let repository = workspace.project().read(cx).active_repository(cx).clone(); + let repository = workspace.project().read(cx).active_repository(cx); let style = BranchListStyle::Modal; workspace.toggle_modal(window, cx, |window, cx| { BranchList::new(repository, style, rems(34.), window, cx) @@ -122,10 +124,13 @@ impl BranchList { all_branches.retain(|branch| !remote_upstreams.contains(&branch.ref_name)); all_branches.sort_by_key(|branch| { - branch - .most_recent_commit - .as_ref() - .map(|commit| 0 - commit.commit_timestamp) + ( + !branch.is_head, // Current branch (is_head=true) comes first + branch + .most_recent_commit + .as_ref() + .map(|commit| 0 - commit.commit_timestamp), + ) }); all_branches @@ -144,7 +149,7 @@ impl BranchList { }) .detach_and_log_err(cx); - let delegate = BranchListDelegate::new(repository.clone(), style); + let delegate = BranchListDelegate::new(repository, style); let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); let _subscription = cx.subscribe(&picker, |_, _, _, cx| { @@ -353,7 +358,6 @@ impl PickerDelegate for BranchListDelegate { }; picker .update(cx, |picker, _| { - #[allow(clippy::nonminimal_bool)] if !query.is_empty() && !matches .first() @@ -472,9 
+476,9 @@ impl PickerDelegate for BranchListDelegate { _window: &mut Window, cx: &mut Context>, ) -> Option { - let entry = &self.matches[ix]; + let entry = &self.matches.get(ix)?; - let (commit_time, subject) = entry + let (commit_time, author_name, subject) = entry .branch .most_recent_commit .as_ref() @@ -487,9 +491,10 @@ impl PickerDelegate for BranchListDelegate { OffsetDateTime::now_utc(), time_format::TimestampFormat::Relative, ); - (Some(formatted_time), Some(subject)) + let author = commit.author_name.clone(); + (Some(formatted_time), Some(author), Some(subject)) }) - .unwrap_or_else(|| (None, None)); + .unwrap_or_else(|| (None, None, None)); let icon = if let Some(default_branch) = self.default_branch.clone() && entry.is_new @@ -570,7 +575,19 @@ impl PickerDelegate for BranchListDelegate { "based off the current branch".to_string() } } else { - subject.unwrap_or("no commits found".into()).to_string() + let show_author_name = ProjectSettings::get_global(cx) + .git + .branch_picker + .unwrap_or_default() + .show_author_name; + + subject.map_or("no commits found".into(), |subject| { + if show_author_name && author_name.is_some() { + format!("{} • {}", author_name.unwrap(), subject) + } else { + subject.to_string() + } + }) }; Label::new(message) .size(LabelSize::Small) diff --git a/crates/git_ui/src/commit_modal.rs b/crates/git_ui/src/commit_modal.rs index 5e7430ebc693458e6df9a41513138ae993b9097c..a2f84726543af50312dc24d0fcd9e0486b51d9c5 100644 --- a/crates/git_ui/src/commit_modal.rs +++ b/crates/git_ui/src/commit_modal.rs @@ -35,7 +35,7 @@ impl ModalContainerProperties { // Calculate width based on character width let mut modal_width = 460.0; - let style = window.text_style().clone(); + let style = window.text_style(); let font_id = window.text_system().resolve_font(&style.font()); let font_size = style.font_size.to_pixels(window.rem_size()); @@ -135,11 +135,10 @@ impl CommitModal { .as_ref() .and_then(|repo| repo.read(cx).head_commit.as_ref()) .is_some() + 
&& !git_panel.amend_pending() { - if !git_panel.amend_pending() { - git_panel.set_amend_pending(true, cx); - git_panel.load_last_commit_message_if_empty(cx); - } + git_panel.set_amend_pending(true, cx); + git_panel.load_last_commit_message_if_empty(cx); } } ForceMode::Commit => { @@ -180,7 +179,7 @@ impl CommitModal { let commit_editor = git_panel.update(cx, |git_panel, cx| { git_panel.set_modal_open(true, cx); - let buffer = git_panel.commit_message_buffer(cx).clone(); + let buffer = git_panel.commit_message_buffer(cx); let panel_editor = git_panel.commit_editor.clone(); let project = git_panel.project.clone(); @@ -195,12 +194,12 @@ impl CommitModal { let commit_message = commit_editor.read(cx).text(cx); - if let Some(suggested_commit_message) = suggested_commit_message { - if commit_message.is_empty() { - commit_editor.update(cx, |editor, cx| { - editor.set_placeholder_text(suggested_commit_message, cx); - }); - } + if let Some(suggested_commit_message) = suggested_commit_message + && commit_message.is_empty() + { + commit_editor.update(cx, |editor, cx| { + editor.set_placeholder_text(&suggested_commit_message, window, cx); + }); } let focus_handle = commit_editor.focus_handle(cx); @@ -286,7 +285,7 @@ impl CommitModal { Some(ContextMenu::build(window, cx, |context_menu, _, _| { context_menu .when_some(keybinding_target.clone(), |el, keybinding_target| { - el.context(keybinding_target.clone()) + el.context(keybinding_target) }) .when(has_previous_commit, |this| { this.toggleable_entry( @@ -392,15 +391,9 @@ impl CommitModal { }); let focus_handle = self.focus_handle(cx); - let close_kb_hint = - if let Some(close_kb) = ui::KeyBinding::for_action(&menu::Cancel, window, cx) { - Some( - KeybindingHint::new(close_kb, cx.theme().colors().editor_background) - .suffix("Cancel"), - ) - } else { - None - }; + let close_kb_hint = ui::KeyBinding::for_action(&menu::Cancel, window, cx).map(|close_kb| { + KeybindingHint::new(close_kb, 
cx.theme().colors().editor_background).suffix("Cancel") + }); h_flex() .group("commit_editor_footer") @@ -483,7 +476,7 @@ impl CommitModal { }), self.render_git_commit_menu( ElementId::Name(format!("split-button-right-{}", commit_label).into()), - Some(focus_handle.clone()), + Some(focus_handle), ) .into_any_element(), )), diff --git a/crates/git_ui/src/commit_tooltip.rs b/crates/git_ui/src/commit_tooltip.rs index 00ab911610c92dff4094452c55662447c4291259..809274beae369736ddeb4bb6e3ddc2cca2b15262 100644 --- a/crates/git_ui/src/commit_tooltip.rs +++ b/crates/git_ui/src/commit_tooltip.rs @@ -181,7 +181,7 @@ impl Render for CommitTooltip { .get(0..8) .map(|sha| sha.to_string().into()) .unwrap_or_else(|| self.commit.sha.clone()); - let full_sha = self.commit.sha.to_string().clone(); + let full_sha = self.commit.sha.to_string(); let absolute_timestamp = format_local_timestamp( self.commit.commit_time, OffsetDateTime::now_utc(), @@ -229,6 +229,7 @@ impl Render for CommitTooltip { .into() }), commit_timestamp: self.commit.commit_time.unix_timestamp(), + author_name: self.commit.author_name.clone(), has_parent: false, }; diff --git a/crates/git_ui/src/commit_view.rs b/crates/git_ui/src/commit_view.rs index c8c237fe90f12f2ac4ead04e0f2f0b4955f8bc1c..ac51cee8e42567a607891dd242c2bf103ae7fc0e 100644 --- a/crates/git_ui/src/commit_view.rs +++ b/crates/git_ui/src/commit_view.rs @@ -1,6 +1,6 @@ use anyhow::{Context as _, Result}; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; -use editor::{Editor, EditorEvent, MultiBuffer, SelectionEffects}; +use editor::{Editor, EditorEvent, MultiBuffer, SelectionEffects, multibuffer_context_lines}; use git::repository::{CommitDetails, CommitDiff, CommitSummary, RepoPath}; use gpui::{ AnyElement, AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, @@ -88,11 +88,10 @@ impl CommitView { let ix = pane.items().position(|item| { let commit_view = item.downcast::(); commit_view - .map_or(false, |view| view.read(cx).commit.sha == 
commit.sha) + .is_some_and(|view| view.read(cx).commit.sha == commit.sha) }); if let Some(ix) = ix { pane.activate_item(ix, true, true, window, cx); - return; } else { pane.add_item(Box::new(commit_view), true, true, None, window, cx); } @@ -160,7 +159,7 @@ impl CommitView { }); } - cx.spawn(async move |this, mut cx| { + cx.spawn(async move |this, cx| { for file in commit_diff.files { let is_deleted = file.new_text.is_none(); let new_text = file.new_text.unwrap_or_default(); @@ -179,9 +178,9 @@ impl CommitView { worktree_id, }) as Arc; - let buffer = build_buffer(new_text, file, &language_registry, &mut cx).await?; + let buffer = build_buffer(new_text, file, &language_registry, cx).await?; let buffer_diff = - build_buffer_diff(old_text, &buffer, &language_registry, &mut cx).await?; + build_buffer_diff(old_text, &buffer, &language_registry, cx).await?; this.update(cx, |this, cx| { this.multibuffer.update(cx, |multibuffer, cx| { @@ -196,7 +195,7 @@ impl CommitView { PathKey::namespaced(FILE_NAMESPACE, path), buffer, diff_hunk_ranges, - editor::DEFAULT_MULTIBUFFER_CONTEXT, + multibuffer_context_lines(cx), cx, ); multibuffer.add_diff(buffer_diff, cx); diff --git a/crates/git_ui/src/conflict_view.rs b/crates/git_ui/src/conflict_view.rs index 0bbb9411be9ef8a8e6b73d11cc4d01126570741f..ee1b82920d7621f6e5b1d4ab9a9b44e151fbf82a 100644 --- a/crates/git_ui/src/conflict_view.rs +++ b/crates/git_ui/src/conflict_view.rs @@ -55,7 +55,7 @@ pub fn register_editor(editor: &mut Editor, buffer: Entity, cx: &mu buffers: Default::default(), }); - let buffers = buffer.read(cx).all_buffers().clone(); + let buffers = buffer.read(cx).all_buffers(); for buffer in buffers { buffer_added(editor, buffer, cx); } @@ -112,7 +112,7 @@ fn excerpt_for_buffer_updated( } fn buffer_added(editor: &mut Editor, buffer: Entity, cx: &mut Context) { - let Some(project) = &editor.project else { + let Some(project) = editor.project() else { return; }; let git_store = project.read(cx).git_store().clone(); @@ 
-129,7 +129,7 @@ fn buffer_added(editor: &mut Editor, buffer: Entity, cx: &mut Context( where T: IntoEnumIterator + VariantNames + 'static, { - let rx = window.prompt(PromptLevel::Info, msg, detail, &T::VARIANTS, cx); + let rx = window.prompt(PromptLevel::Info, msg, detail, T::VARIANTS, cx); cx.spawn(async move |_| Ok(T::iter().nth(rx.await?).unwrap())) } @@ -119,6 +120,7 @@ struct GitMenuState { has_staged_changes: bool, has_unstaged_changes: bool, has_new_changes: bool, + sort_by_path: bool, } fn git_panel_context_menu( @@ -160,6 +162,16 @@ fn git_panel_context_menu( "Trash Untracked Files", TrashUntrackedFiles.boxed_clone(), ) + .separator() + .entry( + if state.sort_by_path { + "Sort by Status" + } else { + "Sort by Path" + }, + Some(Box::new(ToggleSortByPath)), + move |window, cx| window.dispatch_action(Box::new(ToggleSortByPath), cx), + ) }) } @@ -351,6 +363,7 @@ pub struct GitPanel { pending: Vec, pending_commit: Option>, amend_pending: bool, + original_commit_message: Option, signoff_enabled: bool, pending_serialization: Task<()>, pub(crate) project: Entity, @@ -388,9 +401,6 @@ pub(crate) fn commit_message_editor( window: &mut Window, cx: &mut Context, ) -> Editor { - project.update(cx, |this, cx| { - this.mark_buffer_as_non_searchable(commit_message_buffer.read(cx).remote_id(), cx); - }); let buffer = cx.new(|cx| MultiBuffer::singleton(commit_message_buffer, cx)); let max_lines = if in_panel { MAX_PANEL_EDITOR_LINES } else { 18 }; let mut commit_editor = Editor::new( @@ -410,7 +420,7 @@ pub(crate) fn commit_message_editor( commit_editor.set_show_wrap_guides(false, cx); commit_editor.set_show_indent_guides(false, cx); let placeholder = placeholder.unwrap_or("Enter commit message".into()); - commit_editor.set_placeholder_text(placeholder, cx); + commit_editor.set_placeholder_text(&placeholder, window, cx); commit_editor } @@ -426,7 +436,7 @@ impl GitPanel { let git_store = project.read(cx).git_store().clone(); let active_repository = 
project.read(cx).active_repository(cx); - let git_panel = cx.new(|cx| { + cx.new(|cx| { let focus_handle = cx.focus_handle(); cx.on_focus(&focus_handle, window, Self::focus_in).detach(); cx.on_focus_out(&focus_handle, window, |this, _, window, cx| { @@ -435,10 +445,10 @@ impl GitPanel { .detach(); let mut was_sort_by_path = GitPanelSettings::get_global(cx).sort_by_path; - cx.observe_global::(move |this, cx| { + cx.observe_global_in::(window, move |this, window, cx| { let is_sort_by_path = GitPanelSettings::get_global(cx).sort_by_path; if is_sort_by_path != was_sort_by_path { - this.update_visible_entries(cx); + this.update_visible_entries(window, cx); } was_sort_by_path = is_sort_by_path }) @@ -535,6 +545,7 @@ impl GitPanel { pending: Vec::new(), pending_commit: None, amend_pending: false, + original_commit_message: None, signoff_enabled: false, pending_serialization: Task::ready(()), single_staged_entry: None, @@ -563,9 +574,7 @@ impl GitPanel { this.schedule_update(false, window, cx); this - }); - - git_panel + }) } fn hide_scrollbars(&mut self, window: &mut Window, cx: &mut Context) { @@ -652,14 +661,14 @@ impl GitPanel { if GitPanelSettings::get_global(cx).sort_by_path { return self .entries - .binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(&path)) + .binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(path)) .ok(); } if self.conflicted_count > 0 { let conflicted_start = 1; if let Ok(ix) = self.entries[conflicted_start..conflicted_start + self.conflicted_count] - .binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(&path)) + .binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(path)) { return Some(conflicted_start + ix); } @@ -671,7 +680,7 @@ impl GitPanel { 0 } + 1; if let Ok(ix) = self.entries[tracked_start..tracked_start + self.tracked_count] - .binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(&path)) + .binary_search_by(|entry| 
entry.status_entry().unwrap().repo_path.cmp(path)) { return Some(tracked_start + ix); } @@ -687,7 +696,7 @@ impl GitPanel { 0 } + 1; if let Ok(ix) = self.entries[untracked_start..untracked_start + self.new_count] - .binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(&path)) + .binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(path)) { return Some(untracked_start + ix); } @@ -775,7 +784,7 @@ impl GitPanel { if window .focused(cx) - .map_or(false, |focused| self.focus_handle == focused) + .is_some_and(|focused| self.focus_handle == focused) { dispatch_context.add("menu"); dispatch_context.add("ChangesList"); @@ -894,9 +903,7 @@ impl GitPanel { let have_entries = self .active_repository .as_ref() - .map_or(false, |active_repository| { - active_repository.read(cx).status_summary().count > 0 - }); + .is_some_and(|active_repository| active_repository.read(cx).status_summary().count > 0); if have_entries && self.selected_entry.is_none() { self.selected_entry = Some(1); self.scroll_to_selected_entry(cx); @@ -926,19 +933,17 @@ impl GitPanel { let workspace = self.workspace.upgrade()?; let git_repo = self.active_repository.as_ref()?; - if let Some(project_diff) = workspace.read(cx).active_item_as::(cx) { - if let Some(project_path) = project_diff.read(cx).active_path(cx) { - if Some(&entry.repo_path) - == git_repo - .read(cx) - .project_path_to_repo_path(&project_path, cx) - .as_ref() - { - project_diff.focus_handle(cx).focus(window); - project_diff.update(cx, |project_diff, cx| project_diff.autoscroll(cx)); - return None; - } - } + if let Some(project_diff) = workspace.read(cx).active_item_as::(cx) + && let Some(project_path) = project_diff.read(cx).active_path(cx) + && Some(&entry.repo_path) + == git_repo + .read(cx) + .project_path_to_repo_path(&project_path, cx) + .as_ref() + { + project_diff.focus_handle(cx).focus(window); + project_diff.update(cx, |project_diff, cx| project_diff.autoscroll(cx)); + return None; }; self.workspace @@ 
-1048,7 +1053,7 @@ impl GitPanel { let filename = path.path.file_name()?.to_string_lossy(); if !entry.status.is_created() { - self.perform_checkout(vec![entry.clone()], cx); + self.perform_checkout(vec![entry.clone()], window, cx); } else { let prompt = prompt(&format!("Trash {}?", filename), None, window, cx); cx.spawn_in(window, async move |_, cx| { @@ -1077,7 +1082,12 @@ impl GitPanel { }); } - fn perform_checkout(&mut self, entries: Vec, cx: &mut Context) { + fn perform_checkout( + &mut self, + entries: Vec, + window: &mut Window, + cx: &mut Context, + ) { let workspace = self.workspace.clone(); let Some(active_repository) = self.active_repository.clone() else { return; @@ -1090,7 +1100,7 @@ impl GitPanel { entries: entries.clone(), finished: false, }); - self.update_visible_entries(cx); + self.update_visible_entries(window, cx); let task = cx.spawn(async move |_, cx| { let tasks: Vec<_> = workspace.update(cx, |workspace, cx| { workspace.project().update(cx, |project, cx| { @@ -1137,16 +1147,16 @@ impl GitPanel { Ok(()) }); - cx.spawn(async move |this, cx| { + cx.spawn_in(window, async move |this, cx| { let result = task.await; - this.update(cx, |this, cx| { + this.update_in(cx, |this, window, cx| { for pending in this.pending.iter_mut() { if pending.op_id == op_id { pending.finished = true; if result.is_err() { pending.target_status = TargetStatus::Unchanged; - this.update_visible_entries(cx); + this.update_visible_entries(window, cx); } break; } @@ -1202,16 +1212,13 @@ impl GitPanel { window, cx, ); - cx.spawn(async move |this, cx| match prompt.await { - Ok(RestoreCancel::RestoreTrackedFiles) => { - this.update(cx, |this, cx| { - this.perform_checkout(entries, cx); + cx.spawn_in(window, async move |this, cx| { + if let Ok(RestoreCancel::RestoreTrackedFiles) = prompt.await { + this.update_in(cx, |this, window, cx| { + this.perform_checkout(entries, window, cx); }) .ok(); } - _ => { - return; - } }) .detach(); } @@ -1341,10 +1348,10 @@ impl GitPanel { .iter() 
.filter_map(|entry| entry.status_entry()) .filter(|status_entry| { - section.contains(&status_entry, repository) + section.contains(status_entry, repository) && status_entry.staging.as_bool() != Some(goal_staged_state) }) - .map(|status_entry| status_entry.clone()) + .cloned() .collect::>(); (goal_staged_state, entries) @@ -1476,7 +1483,6 @@ impl GitPanel { .read(cx) .as_singleton() .unwrap() - .clone() } fn toggle_staged_for_selected( @@ -1642,13 +1648,12 @@ impl GitPanel { fn has_commit_message(&self, cx: &mut Context) -> bool { let text = self.commit_editor.read(cx).text(cx); if !text.trim().is_empty() { - return true; + true } else if text.is_empty() { - return self - .suggest_commit_message(cx) - .is_some_and(|text| !text.trim().is_empty()); + self.suggest_commit_message(cx) + .is_some_and(|text| !text.trim().is_empty()) } else { - return false; + false } } @@ -1727,6 +1732,7 @@ impl GitPanel { Ok(()) => { this.commit_editor .update(cx, |editor, cx| editor.clear(window, cx)); + this.original_commit_message = None; } Err(e) => this.show_error_toast("commit", e, cx), } @@ -1737,7 +1743,7 @@ impl GitPanel { self.pending_commit = Some(task); } - fn uncommit(&mut self, window: &mut Window, cx: &mut Context) { + pub(crate) fn uncommit(&mut self, window: &mut Window, cx: &mut Context) { let Some(repo) = self.active_repository.clone() else { return; }; @@ -1833,7 +1839,9 @@ impl GitPanel { let git_status_entry = if let Some(staged_entry) = &self.single_staged_entry { Some(staged_entry) - } else if let Some(single_tracked_entry) = &self.single_tracked_entry { + } else if self.total_staged_count() == 0 + && let Some(single_tracked_entry) = &self.single_tracked_entry + { Some(single_tracked_entry) } else { None @@ -1869,13 +1877,17 @@ impl GitPanel { /// Generates a commit message using an LLM. 
pub fn generate_commit_message(&mut self, cx: &mut Context) { - if !self.can_commit() || DisableAiSettings::get_global(cx).disable_ai { + if !self.can_commit() + || DisableAiSettings::get_global(cx).disable_ai + || !agent_settings::AgentSettings::get_global(cx).enabled + { return; } - let model = match current_language_model(cx) { - Some(value) => value, - None => return, + let Some(ConfiguredModel { provider, model }) = + LanguageModelRegistry::read_global(cx).commit_message_model() + else { + return; }; let Some(repo) = self.active_repository.as_ref() else { @@ -1900,6 +1912,16 @@ impl GitPanel { this.generate_commit_message_task.take(); }); + if let Some(task) = cx.update(|cx| { + if !provider.is_authenticated(cx) { + Some(provider.authenticate(cx)) + } else { + None + } + })? { + task.await.log_err(); + }; + let mut diff_text = match diff.await { Ok(result) => match result { Ok(text) => text, @@ -1950,7 +1972,7 @@ impl GitPanel { thinking_allowed: false, }; - let stream = model.stream_completion_text(request, &cx); + let stream = model.stream_completion_text(request, cx); match stream.await { Ok(mut messages) => { if !text_empty { @@ -2086,6 +2108,7 @@ impl GitPanel { files: false, directories: true, multiple: false, + prompt: Some("Select as Repository Destination".into()), }); let workspace = self.workspace.clone(); @@ -2183,7 +2206,7 @@ impl GitPanel { let worktree = if worktrees.len() == 1 { Task::ready(Some(worktrees.first().unwrap().clone())) - } else if worktrees.len() == 0 { + } else if worktrees.is_empty() { let result = window.prompt( PromptLevel::Warning, "Unable to initialize a git repository", @@ -2511,10 +2534,11 @@ impl GitPanel { new_co_authors.push((name.clone(), email.clone())) } } - if !project.is_local() && !project.is_read_only(cx) { - if let Some(local_committer) = self.local_committer(room, cx) { - new_co_authors.push(local_committer); - } + if !project.is_local() + && !project.is_read_only(cx) + && let Some(local_committer) = 
self.local_committer(room, cx) + { + new_co_authors.push(local_committer); } new_co_authors } @@ -2541,6 +2565,24 @@ impl GitPanel { cx.notify(); } + fn toggle_sort_by_path( + &mut self, + _: &ToggleSortByPath, + _: &mut Window, + cx: &mut Context, + ) { + let current_setting = GitPanelSettings::get_global(cx).sort_by_path; + if let Some(workspace) = self.workspace.upgrade() { + let workspace = workspace.read(cx); + let fs = workspace.app_state().fs.clone(); + cx.update_global::(|store, _cx| { + store.update_settings_file::(fs, move |settings, _cx| { + settings.sort_by_path = Some(!current_setting); + }); + }); + } + } + fn fill_co_authors(&mut self, message: &mut String, cx: &mut Context) { const CO_AUTHOR_PREFIX: &str = "Co-authored-by: "; @@ -2605,7 +2647,7 @@ impl GitPanel { if clear_pending { git_panel.clear_pending(); } - git_panel.update_visible_entries(cx); + git_panel.update_visible_entries(window, cx); git_panel.update_scrollbar_properties(window, cx); }) .ok(); @@ -2658,7 +2700,7 @@ impl GitPanel { self.pending.retain(|v| !v.finished) } - fn update_visible_entries(&mut self, cx: &mut Context) { + fn update_visible_entries(&mut self, window: &mut Window, cx: &mut Context) { let bulk_staging = self.bulk_staging.take(); let last_staged_path_prev_index = bulk_staging .as_ref() @@ -2753,35 +2795,34 @@ impl GitPanel { for pending in self.pending.iter() { if pending.target_status == TargetStatus::Staged { pending_staged_count += pending.entries.len(); - last_pending_staged = pending.entries.iter().next().cloned(); + last_pending_staged = pending.entries.first().cloned(); } - if let Some(single_staged) = &single_staged_entry { - if pending + if let Some(single_staged) = &single_staged_entry + && pending .entries .iter() .any(|entry| entry.repo_path == single_staged.repo_path) - { - pending_status_for_single_staged = Some(pending.target_status); - } + { + pending_status_for_single_staged = Some(pending.target_status); } } - if conflict_entries.len() == 0 && 
staged_count == 1 && pending_staged_count == 0 { + if conflict_entries.is_empty() && staged_count == 1 && pending_staged_count == 0 { match pending_status_for_single_staged { Some(TargetStatus::Staged) | None => { self.single_staged_entry = single_staged_entry; } _ => {} } - } else if conflict_entries.len() == 0 && pending_staged_count == 1 { + } else if conflict_entries.is_empty() && pending_staged_count == 1 { self.single_staged_entry = last_pending_staged; } - if conflict_entries.len() == 0 && changed_entries.len() == 1 { + if conflict_entries.is_empty() && changed_entries.len() == 1 { self.single_tracked_entry = changed_entries.first().cloned(); } - if conflict_entries.len() > 0 { + if !conflict_entries.is_empty() { self.entries.push(GitListEntry::Header(GitHeaderEntry { header: Section::Conflict, })); @@ -2789,7 +2830,7 @@ impl GitPanel { .extend(conflict_entries.into_iter().map(GitListEntry::Status)); } - if changed_entries.len() > 0 { + if !changed_entries.is_empty() { if !sort_by_path { self.entries.push(GitListEntry::Header(GitHeaderEntry { header: Section::Tracked, @@ -2798,7 +2839,7 @@ impl GitPanel { self.entries .extend(changed_entries.into_iter().map(GitListEntry::Status)); } - if new_entries.len() > 0 { + if !new_entries.is_empty() { self.entries.push(GitListEntry::Header(GitHeaderEntry { header: Section::New, })); @@ -2834,7 +2875,7 @@ impl GitPanel { let placeholder_text = suggested_commit_message.unwrap_or("Enter commit message".into()); self.commit_editor.update(cx, |editor, cx| { - editor.set_placeholder_text(Arc::from(placeholder_text), cx) + editor.set_placeholder_text(&placeholder_text, window, cx) }); cx.notify(); @@ -2937,8 +2978,7 @@ impl GitPanel { .matches(git::repository::REMOTE_CANCELLED_BY_USER) .next() .is_some() - { - return; // Hide the cancelled by user message + { // Hide the cancelled by user message } else { workspace.update(cx, |workspace, cx| { let workspace_weak = cx.weak_entity(); @@ -2992,9 +3032,7 @@ impl GitPanel { let 
status_toast = StatusToast::new(message, cx, move |this, _cx| { use remote_output::SuccessStyle::*; match style { - Toast { .. } => { - this.icon(ToastIcon::new(IconName::GitBranchAlt).color(Color::Muted)) - } + Toast => this.icon(ToastIcon::new(IconName::GitBranchAlt).color(Color::Muted)), ToastWithLog { output } => this .icon(ToastIcon::new(IconName::GitBranchAlt).color(Color::Muted)) .action("View Log", move |window, cx| { @@ -3079,6 +3117,7 @@ impl GitPanel { has_staged_changes, has_unstaged_changes, has_new_changes, + sort_by_path: GitPanelSettings::get_global(cx).sort_by_path, }, window, cx, @@ -3091,32 +3130,37 @@ impl GitPanel { &self, cx: &Context, ) -> Option { - current_language_model(cx).is_some().then(|| { - if self.generate_commit_message_task.is_some() { - return h_flex() + if !agent_settings::AgentSettings::get_global(cx).enabled + || DisableAiSettings::get_global(cx).disable_ai + || LanguageModelRegistry::read_global(cx) + .commit_message_model() + .is_none() + { + return None; + } + + if self.generate_commit_message_task.is_some() { + return Some( + h_flex() .gap_1() .child( Icon::new(IconName::ArrowCircle) .size(IconSize::XSmall) .color(Color::Info) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| { - icon.transform(Transformation::rotate(percentage(delta))) - }, - ), + .with_rotate_animation(2), ) .child( Label::new("Generating Commit...") .size(LabelSize::Small) .color(Color::Muted), ) - .into_any_element(); - } + .into_any_element(), + ); + } - let can_commit = self.can_commit(); - let editor_focus_handle = self.commit_editor.focus_handle(cx); + let can_commit = self.can_commit(); + let editor_focus_handle = self.commit_editor.focus_handle(cx); + Some( IconButton::new("generate-commit-message", IconName::AiEdit) .shape(ui::IconButtonShape::Square) .icon_color(Color::Muted) @@ -3137,8 +3181,8 @@ impl GitPanel { .on_click(cx.listener(move |this, _event, _window, cx| { 
this.generate_commit_message(cx); })) - .into_any_element() - }) + .into_any_element(), + ) } pub(crate) fn render_co_authors(&self, cx: &Context) -> Option { @@ -3215,7 +3259,7 @@ impl GitPanel { Some(ContextMenu::build(window, cx, |context_menu, _, _| { context_menu .when_some(keybinding_target.clone(), |el, keybinding_target| { - el.context(keybinding_target.clone()) + el.context(keybinding_target) }) .when(has_previous_commit, |this| { this.toggleable_entry( @@ -3271,12 +3315,10 @@ impl GitPanel { } else { "Amend Tracked" } + } else if self.has_staged_changes() { + "Commit" } else { - if self.has_staged_changes() { - "Commit" - } else { - "Commit Tracked" - } + "Commit Tracked" } } @@ -3397,7 +3439,7 @@ impl GitPanel { let enable_coauthors = self.render_co_authors(cx); let editor_focus_handle = self.commit_editor.focus_handle(cx); - let expand_tooltip_focus_handle = editor_focus_handle.clone(); + let expand_tooltip_focus_handle = editor_focus_handle; let branch = active_repository.read(cx).branch.clone(); let head_commit = active_repository.read(cx).head_commit.clone(); @@ -3410,7 +3452,7 @@ impl GitPanel { * MAX_PANEL_EDITOR_LINES + gap; - let git_panel = cx.entity().clone(); + let git_panel = cx.entity(); let display_name = SharedString::from(Arc::from( active_repository .read(cx) @@ -3426,7 +3468,7 @@ impl GitPanel { display_name, branch, head_commit, - Some(git_panel.clone()), + Some(git_panel), )) .child( panel_editor_container(window, cx) @@ -3577,7 +3619,7 @@ impl GitPanel { }), self.render_git_commit_menu( ElementId::Name(format!("split-button-right-{}", title).into()), - Some(commit_tooltip_focus_handle.clone()), + Some(commit_tooltip_focus_handle), cx, ) .into_any_element(), @@ -3643,7 +3685,7 @@ impl GitPanel { CommitView::open( commit.clone(), repo.clone(), - workspace.clone().clone(), + workspace.clone(), window, cx, ); @@ -4128,6 +4170,7 @@ impl GitPanel { has_staged_changes: self.has_staged_changes(), has_unstaged_changes: 
self.has_unstaged_changes(), has_new_changes: self.new_count > 0, + sort_by_path: GitPanelSettings::get_global(cx).sort_by_path, }, window, cx, @@ -4351,7 +4394,7 @@ impl GitPanel { } }) .child( - self.entry_label(display_name.clone(), label_color) + self.entry_label(display_name, label_color) .when(status.is_deleted(), |this| this.strikethrough()), ), ) @@ -4367,6 +4410,22 @@ impl GitPanel { } pub fn set_amend_pending(&mut self, value: bool, cx: &mut Context) { + if value && !self.amend_pending { + let current_message = self.commit_message_buffer(cx).read(cx).text(); + self.original_commit_message = if current_message.trim().is_empty() { + None + } else { + Some(current_message) + }; + } else if !value && self.amend_pending { + let message = self.original_commit_message.take().unwrap_or_default(); + self.commit_message_buffer(cx).update(cx, |buffer, cx| { + let start = buffer.anchor_before(0); + let end = buffer.anchor_after(buffer.len()); + buffer.edit([(start..end, message)], None, cx); + }); + } + self.amend_pending = value; self.serialize(cx); cx.notify(); @@ -4472,24 +4531,10 @@ impl GitPanel { } } -fn current_language_model(cx: &Context<'_, GitPanel>) -> Option> { - let is_enabled = agent_settings::AgentSettings::get_global(cx).enabled - && !DisableAiSettings::get_global(cx).disable_ai; - - is_enabled - .then(|| { - let ConfiguredModel { provider, model } = - LanguageModelRegistry::read_global(cx).commit_message_model()?; - - provider.is_authenticated(cx).then(|| model) - }) - .flatten() -} - impl Render for GitPanel { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let project = self.project.read(cx); - let has_entries = self.entries.len() > 0; + let has_entries = !self.entries.is_empty(); let room = self .workspace .upgrade() @@ -4497,7 +4542,7 @@ impl Render for GitPanel { let has_write_access = self.has_write_access(cx); - let has_co_authors = room.map_or(false, |room| { + let has_co_authors = room.is_some_and(|room| { 
self.load_local_committer(cx); let room = room.read(cx); room.remote_participants() @@ -4539,6 +4584,7 @@ impl Render for GitPanel { .when(has_write_access && has_co_authors, |git_panel| { git_panel.on_action(cx.listener(Self::toggle_fill_co_authors)) }) + .on_action(cx.listener(Self::toggle_sort_by_path)) .on_hover(cx.listener(move |this, hovered, window, cx| { if *hovered { this.horizontal_scrollbar.show(cx); @@ -4617,7 +4663,7 @@ impl editor::Addon for GitPanelAddon { git_panel .read(cx) - .render_buffer_header_controls(&git_panel, &file, window, cx) + .render_buffer_header_controls(&git_panel, file, window, cx) } } @@ -4700,7 +4746,7 @@ impl GitPanelMessageTooltip { author_email: details.author_email.clone(), commit_time: OffsetDateTime::from_unix_timestamp(details.commit_timestamp)?, message: Some(ParsedCommitMessage { - message: details.message.clone(), + message: details.message, ..Default::default() }), }; @@ -4813,12 +4859,10 @@ impl RenderOnce for PanelRepoFooter { // ideally, show the whole branch and repo names but // when we can't, use a budget to allocate space between the two - let (repo_display_len, branch_display_len) = if branch_actual_len + repo_actual_len - <= LABEL_CHARACTER_BUDGET - { - (repo_actual_len, branch_actual_len) - } else { - if branch_actual_len <= MAX_BRANCH_LEN { + let (repo_display_len, branch_display_len) = + if branch_actual_len + repo_actual_len <= LABEL_CHARACTER_BUDGET { + (repo_actual_len, branch_actual_len) + } else if branch_actual_len <= MAX_BRANCH_LEN { let repo_space = (LABEL_CHARACTER_BUDGET - branch_actual_len).min(MAX_REPO_LEN); (repo_space, branch_actual_len) } else if repo_actual_len <= MAX_REPO_LEN { @@ -4826,8 +4870,7 @@ impl RenderOnce for PanelRepoFooter { (repo_actual_len, branch_space) } else { (MAX_REPO_LEN, MAX_BRANCH_LEN) - } - }; + }; let truncated_repo_name = if repo_actual_len <= repo_display_len { active_repo_name.to_string() @@ -4836,7 +4879,7 @@ impl RenderOnce for PanelRepoFooter { }; let 
truncated_branch_name = if branch_actual_len <= branch_display_len { - branch_name.to_string() + branch_name } else { util::truncate_and_trailoff(branch_name.trim_ascii(), branch_display_len) }; @@ -4849,7 +4892,7 @@ impl RenderOnce for PanelRepoFooter { let repo_selector = PopoverMenu::new("repository-switcher") .menu({ - let project = project.clone(); + let project = project; move |window, cx| { let project = project.clone()?; Some(cx.new(|cx| RepositorySelector::new(project, rems(16.), window, cx))) @@ -4979,6 +5022,7 @@ impl Component for PanelRepoFooter { sha: "abc123".into(), subject: "Modify stuff".into(), commit_timestamp: 1710932954, + author_name: "John Doe".into(), has_parent: true, }), } @@ -4996,6 +5040,7 @@ impl Component for PanelRepoFooter { sha: "abc123".into(), subject: "Modify stuff".into(), commit_timestamp: 1710932954, + author_name: "John Doe".into(), has_parent: true, }), } @@ -5020,10 +5065,7 @@ impl Component for PanelRepoFooter { div() .w(example_width) .overflow_hidden() - .child(PanelRepoFooter::new_preview( - active_repository(1).clone(), - None, - )) + .child(PanelRepoFooter::new_preview(active_repository(1), None)) .into_any_element(), ), single_example( @@ -5032,7 +5074,7 @@ impl Component for PanelRepoFooter { .w(example_width) .overflow_hidden() .child(PanelRepoFooter::new_preview( - active_repository(2).clone(), + active_repository(2), Some(branch(unknown_upstream)), )) .into_any_element(), @@ -5043,7 +5085,7 @@ impl Component for PanelRepoFooter { .w(example_width) .overflow_hidden() .child(PanelRepoFooter::new_preview( - active_repository(3).clone(), + active_repository(3), Some(branch(no_remote_upstream)), )) .into_any_element(), @@ -5054,7 +5096,7 @@ impl Component for PanelRepoFooter { .w(example_width) .overflow_hidden() .child(PanelRepoFooter::new_preview( - active_repository(4).clone(), + active_repository(4), Some(branch(not_ahead_or_behind_upstream)), )) .into_any_element(), @@ -5065,7 +5107,7 @@ impl Component for 
PanelRepoFooter { .w(example_width) .overflow_hidden() .child(PanelRepoFooter::new_preview( - active_repository(5).clone(), + active_repository(5), Some(branch(behind_upstream)), )) .into_any_element(), @@ -5076,7 +5118,7 @@ impl Component for PanelRepoFooter { .w(example_width) .overflow_hidden() .child(PanelRepoFooter::new_preview( - active_repository(6).clone(), + active_repository(6), Some(branch(ahead_of_upstream)), )) .into_any_element(), @@ -5087,7 +5129,7 @@ impl Component for PanelRepoFooter { .w(example_width) .overflow_hidden() .child(PanelRepoFooter::new_preview( - active_repository(7).clone(), + active_repository(7), Some(branch(ahead_and_behind_upstream)), )) .into_any_element(), @@ -5258,7 +5300,7 @@ mod tests { project .read(cx) .worktrees(cx) - .nth(0) + .next() .unwrap() .read(cx) .as_local() @@ -5383,7 +5425,7 @@ mod tests { project .read(cx) .worktrees(cx) - .nth(0) + .next() .unwrap() .read(cx) .as_local() @@ -5434,7 +5476,7 @@ mod tests { project .read(cx) .worktrees(cx) - .nth(0) + .next() .unwrap() .read(cx) .as_local() @@ -5483,7 +5525,7 @@ mod tests { project .read(cx) .worktrees(cx) - .nth(0) + .next() .unwrap() .read(cx) .as_local() @@ -5518,4 +5560,73 @@ mod tests { ], ); } + + #[gpui::test] + async fn test_amend_commit_message_handling(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + "/root", + json!({ + "project": { + ".git": {}, + "src": { + "main.rs": "fn main() {}" + } + } + }), + ) + .await; + + fs.set_status_for_repo( + Path::new(path!("/root/project/.git")), + &[(Path::new("src/main.rs"), StatusCode::Modified.worktree())], + ); + + let project = Project::test(fs.clone(), [Path::new(path!("/root/project"))], cx).await; + let workspace = + cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = &mut VisualTestContext::from_window(*workspace, cx); + + let panel = workspace.update(cx, GitPanel::new).unwrap(); + + // Test: User has 
commit message, enables amend (saves message), then disables (restores message) + panel.update(cx, |panel, cx| { + panel.commit_message_buffer(cx).update(cx, |buffer, cx| { + let start = buffer.anchor_before(0); + let end = buffer.anchor_after(buffer.len()); + buffer.edit([(start..end, "Initial commit message")], None, cx); + }); + + panel.set_amend_pending(true, cx); + assert!(panel.original_commit_message.is_some()); + + panel.set_amend_pending(false, cx); + let current_message = panel.commit_message_buffer(cx).read(cx).text(); + assert_eq!(current_message, "Initial commit message"); + assert!(panel.original_commit_message.is_none()); + }); + + // Test: User has empty commit message, enables amend, then disables (clears message) + panel.update(cx, |panel, cx| { + panel.commit_message_buffer(cx).update(cx, |buffer, cx| { + let start = buffer.anchor_before(0); + let end = buffer.anchor_after(buffer.len()); + buffer.edit([(start..end, "")], None, cx); + }); + + panel.set_amend_pending(true, cx); + assert!(panel.original_commit_message.is_none()); + + panel.commit_message_buffer(cx).update(cx, |buffer, cx| { + let start = buffer.anchor_before(0); + let end = buffer.anchor_after(buffer.len()); + buffer.edit([(start..end, "Previous commit message")], None, cx); + }); + + panel.set_amend_pending(false, cx); + let current_message = panel.commit_message_buffer(cx).read(cx).text(); + assert_eq!(current_message, ""); + }); + } } diff --git a/crates/git_ui/src/git_panel_settings.rs b/crates/git_ui/src/git_panel_settings.rs index b6891c7d256794b5b457669a20b17e6e41e4fd23..be207314acd82446566dffd2eb58339974f177ff 100644 --- a/crates/git_ui/src/git_panel_settings.rs +++ b/crates/git_ui/src/git_panel_settings.rs @@ -2,7 +2,7 @@ use editor::ShowScrollbar; use gpui::Pixels; use schemars::JsonSchema; use serde_derive::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; use 
workspace::dock::DockPosition; #[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] @@ -36,7 +36,8 @@ pub enum StatusStyle { LabelColor, } -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug, SettingsUi, SettingsKey)] +#[settings_key(key = "git_panel")] pub struct GitPanelSettingsContent { /// Whether to show the panel button in the status bar. /// @@ -90,8 +91,6 @@ pub struct GitPanelSettings { } impl Settings for GitPanelSettings { - const KEY: Option<&'static str> = Some("git_panel"); - type FileContent = GitPanelSettingsContent; fn load( diff --git a/crates/git_ui/src/git_ui.rs b/crates/git_ui/src/git_ui.rs index 7cb5c3818ce9118205febb92ead3fddcd367f3b1..a51382a744ad5a8a795ec9c6727d8bce3838a0ec 100644 --- a/crates/git_ui/src/git_ui.rs +++ b/crates/git_ui/src/git_ui.rs @@ -3,7 +3,7 @@ use std::any::Any; use ::settings::Settings; use command_palette_hooks::CommandPaletteFilter; use commit_modal::CommitModal; -use editor::{Editor, EditorElement, EditorStyle, actions::DiffClipboardWithSelectionData}; +use editor::{Editor, actions::DiffClipboardWithSelectionData}; use ui::{ Headline, HeadlineSize, Icon, IconName, IconSize, IntoElement, ParentElement, Render, Styled, StyledExt, div, h_flex, rems, v_flex, @@ -18,13 +18,12 @@ use git::{ use git_panel_settings::GitPanelSettings; use gpui::{ Action, App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, SharedString, - TextStyle, Window, actions, + Window, actions, }; use menu::{Cancel, Confirm}; use onboarding::GitOnboardingModal; use project::git_store::Repository; use project_diff::ProjectDiff; -use theme::ThemeSettings; use ui::prelude::*; use workspace::{ModalView, Workspace, notifications::DetachAndPromptErr}; use zed_actions; @@ -158,6 +157,14 @@ pub fn init(cx: &mut App) { panel.unstage_all(action, window, cx); }); }); + workspace.register_action(|workspace, _: &git::Uncommit, 
window, cx| { + let Some(panel) = workspace.panel::(cx) else { + return; + }; + panel.update(cx, |panel, cx| { + panel.uncommit(window, cx); + }) + }); CommandPaletteFilter::update_global(cx, |filter, _cx| { filter.hide_action_types(&[ zed_actions::OpenGitIntegrationOnboarding.type_id(), @@ -373,12 +380,12 @@ fn render_remote_button( } (0, 0) => None, (ahead, 0) => Some(remote_button::render_push_button( - keybinding_target.clone(), + keybinding_target, id, ahead, )), (ahead, behind) => Some(remote_button::render_pull_button( - keybinding_target.clone(), + keybinding_target, id, ahead, behind, @@ -553,16 +560,9 @@ mod remote_button { let command = command.into(); if let Some(handle) = focus_handle { - Tooltip::with_meta_in( - label.clone(), - Some(action), - command.clone(), - &handle, - window, - cx, - ) + Tooltip::with_meta_in(label, Some(action), command, &handle, window, cx) } else { - Tooltip::with_meta(label.clone(), Some(action), command.clone(), window, cx) + Tooltip::with_meta(label, Some(action), command, window, cx) } } @@ -585,7 +585,7 @@ mod remote_button { Some(ContextMenu::build(window, cx, |context_menu, _, _| { context_menu .when_some(keybinding_target.clone(), |el, keybinding_target| { - el.context(keybinding_target.clone()) + el.context(keybinding_target) }) .action("Fetch", git::Fetch.boxed_clone()) .action("Fetch From", git::FetchFrom.boxed_clone()) @@ -764,7 +764,7 @@ impl GitCloneModal { pub fn show(panel: Entity, window: &mut Window, cx: &mut Context) -> Self { let repo_input = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text("Enter repository", cx); + editor.set_placeholder_text("Enter repository URL…", window, cx); editor }); let focus_handle = repo_input.focus_handle(cx); @@ -777,46 +777,6 @@ impl GitCloneModal { focus_handle, } } - - fn render_editor(&self, window: &Window, cx: &App) -> impl IntoElement { - let settings = ThemeSettings::get_global(cx); - let theme = cx.theme(); - - let 
text_style = TextStyle { - color: cx.theme().colors().text, - font_family: settings.buffer_font.family.clone(), - font_features: settings.buffer_font.features.clone(), - font_size: settings.buffer_font_size(cx).into(), - font_weight: settings.buffer_font.weight, - line_height: relative(settings.buffer_line_height.value()), - background_color: Some(theme.colors().editor_background), - ..Default::default() - }; - - let element = EditorElement::new( - &self.repo_input, - EditorStyle { - background: theme.colors().editor_background, - local_player: theme.players().local(), - text: text_style, - ..Default::default() - }, - ); - - div() - .rounded_md() - .p_1() - .border_1() - .border_color(theme.colors().border_variant) - .when( - self.repo_input - .focus_handle(cx) - .contains_focused(window, cx), - |this| this.border_color(theme.colors().border_focused), - ) - .child(element) - .bg(theme.colors().editor_background) - } } impl Focusable for GitCloneModal { @@ -826,12 +786,42 @@ impl Focusable for GitCloneModal { } impl Render for GitCloneModal { - fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { div() - .size_full() - .w(rems(34.)) .elevation_3(cx) - .child(self.render_editor(window, cx)) + .w(rems(34.)) + .flex_1() + .overflow_hidden() + .child( + div() + .w_full() + .p_2() + .border_b_1() + .border_color(cx.theme().colors().border_variant) + .child(self.repo_input.clone()), + ) + .child( + h_flex() + .w_full() + .p_2() + .gap_0p5() + .rounded_b_sm() + .bg(cx.theme().colors().editor_background) + .child( + Label::new("Clone a repository from GitHub or other sources.") + .color(Color::Muted) + .size(LabelSize::Small), + ) + .child( + Button::new("learn-more", "Learn More") + .label_size(LabelSize::Small) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .on_click(|_, _, cx| { + cx.open_url("https://github.com/git-guides/git-clone"); + }), 
+ ), + ) .on_action(cx.listener(|_, _: &menu::Cancel, _, cx| { cx.emit(DismissEvent); })) diff --git a/crates/git_ui/src/picker_prompt.rs b/crates/git_ui/src/picker_prompt.rs index 4077e0f3623e0925a87824e252a77755c78721ea..9997b0590cedfeab7cad6a7c52bce63f10657a80 100644 --- a/crates/git_ui/src/picker_prompt.rs +++ b/crates/git_ui/src/picker_prompt.rs @@ -152,7 +152,7 @@ impl PickerDelegate for PickerPromptDelegate { .all_options .iter() .enumerate() - .map(|(ix, option)| StringMatchCandidate::new(ix, &option)) + .map(|(ix, option)| StringMatchCandidate::new(ix, option)) .collect::>() }); let Some(candidates) = candidates.log_err() else { @@ -216,7 +216,7 @@ impl PickerDelegate for PickerPromptDelegate { _window: &mut Window, _cx: &mut Context>, ) -> Option { - let hit = &self.matches[ix]; + let hit = &self.matches.get(ix)?; let shortened_option = util::truncate_and_trailoff(&hit.string, self.max_match_length); Some( diff --git a/crates/git_ui/src/project_diff.rs b/crates/git_ui/src/project_diff.rs index d6a4e27286af1bb38dcd1acc488bce9da1813a42..9d0a575247427ec5fe674b342d0f2660e40e2299 100644 --- a/crates/git_ui/src/project_diff.rs +++ b/crates/git_ui/src/project_diff.rs @@ -10,6 +10,7 @@ use collections::HashSet; use editor::{ Editor, EditorEvent, SelectionEffects, actions::{GoToHunk, GoToPreviousHunk}, + multibuffer_context_lines, scroll::Autoscroll, }; use futures::StreamExt; @@ -242,7 +243,7 @@ impl ProjectDiff { TRACKED_NAMESPACE }; - let path_key = PathKey::namespaced(namespace, entry.repo_path.0.clone()); + let path_key = PathKey::namespaced(namespace, entry.repo_path.0); self.move_to_path(path_key, window, cx) } @@ -280,7 +281,7 @@ impl ProjectDiff { fn button_states(&self, cx: &App) -> ButtonStates { let editor = self.editor.read(cx); let snapshot = self.multibuffer.read(cx).snapshot(cx); - let prev_next = snapshot.diff_hunks().skip(1).next().is_some(); + let prev_next = snapshot.diff_hunks().nth(1).is_some(); let mut selection = true; let mut ranges = 
editor @@ -329,14 +330,14 @@ impl ProjectDiff { }) .ok(); - return ButtonStates { + ButtonStates { stage: has_unstaged_hunks, unstage: has_staged_hunks, prev_next, selection, stage_all, unstage_all, - }; + } } fn handle_editor_event( @@ -346,27 +347,24 @@ impl ProjectDiff { window: &mut Window, cx: &mut Context, ) { - match event { - EditorEvent::SelectionsChanged { local: true } => { - let Some(project_path) = self.active_path(cx) else { - return; - }; - self.workspace - .update(cx, |workspace, cx| { - if let Some(git_panel) = workspace.panel::(cx) { - git_panel.update(cx, |git_panel, cx| { - git_panel.select_entry_by_path(project_path, window, cx) - }) - } - }) - .ok(); - } - _ => {} + if let EditorEvent::SelectionsChanged { local: true } = event { + let Some(project_path) = self.active_path(cx) else { + return; + }; + self.workspace + .update(cx, |workspace, cx| { + if let Some(git_panel) = workspace.panel::(cx) { + git_panel.update(cx, |git_panel, cx| { + git_panel.select_entry_by_path(project_path, window, cx) + }) + } + }) + .ok(); } - if editor.focus_handle(cx).contains_focused(window, cx) { - if self.multibuffer.read(cx).is_empty() { - self.focus_handle.focus(window) - } + if editor.focus_handle(cx).contains_focused(window, cx) + && self.multibuffer.read(cx).is_empty() + { + self.focus_handle.focus(window) } } @@ -451,10 +449,10 @@ impl ProjectDiff { let diff = diff.read(cx); let diff_hunk_ranges = diff .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx) - .map(|diff_hunk| diff_hunk.buffer_range.clone()); + .map(|diff_hunk| diff_hunk.buffer_range); let conflicts = conflict_addon .conflict_set(snapshot.remote_id()) - .map(|conflict_set| conflict_set.read(cx).snapshot().conflicts.clone()) + .map(|conflict_set| conflict_set.read(cx).snapshot().conflicts) .unwrap_or_default(); let conflicts = conflicts.iter().map(|conflict| conflict.range.clone()); @@ -468,7 +466,7 @@ impl ProjectDiff { path_key.clone(), buffer, excerpt_ranges, - 
editor::DEFAULT_MULTIBUFFER_CONTEXT, + multibuffer_context_lines(cx), cx, ); (was_empty, is_newly_added) @@ -513,7 +511,7 @@ impl ProjectDiff { mut recv: postage::watch::Receiver<()>, cx: &mut AsyncWindowContext, ) -> Result<()> { - while let Some(_) = recv.next().await { + while (recv.next().await).is_some() { let buffers_to_load = this.update(cx, |this, cx| this.load_buffers(cx))?; for buffer_to_load in buffers_to_load { if let Some(buffer) = buffer_to_load.await.log_err() { @@ -740,7 +738,7 @@ impl Render for ProjectDiff { } else { None }; - let keybinding_focus_handle = self.focus_handle(cx).clone(); + let keybinding_focus_handle = self.focus_handle(cx); el.child( v_flex() .gap_1() @@ -1073,8 +1071,7 @@ pub struct ProjectDiffEmptyState { impl RenderOnce for ProjectDiffEmptyState { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { let status_against_remote = |ahead_by: usize, behind_by: usize| -> bool { - match self.current_branch { - Some(Branch { + matches!(self.current_branch, Some(Branch { upstream: Some(Upstream { tracking: @@ -1084,9 +1081,7 @@ impl RenderOnce for ProjectDiffEmptyState { .. }), .. 
- }) if (ahead > 0) == (ahead_by > 0) && (behind > 0) == (behind_by > 0) => true, - _ => false, - } + }) if (ahead > 0) == (ahead_by > 0) && (behind > 0) == (behind_by > 0)) }; let change_count = |current_branch: &Branch| -> (usize, usize) { @@ -1173,7 +1168,7 @@ impl RenderOnce for ProjectDiffEmptyState { .child(Label::new("No Changes").color(Color::Muted)) } else { this.when_some(self.current_branch.as_ref(), |this, branch| { - this.child(has_branch_container(&branch)) + this.child(has_branch_container(branch)) }) } }), @@ -1225,6 +1220,7 @@ mod preview { sha: "abc123".into(), subject: "Modify stuff".into(), commit_timestamp: 1710932954, + author_name: "John Doe".into(), has_parent: true, }), } @@ -1332,14 +1328,14 @@ fn merge_anchor_ranges<'a>( loop { if let Some(left_range) = left .peek() - .filter(|range| range.start.cmp(&next_range.end, &snapshot).is_le()) + .filter(|range| range.start.cmp(&next_range.end, snapshot).is_le()) .cloned() { left.next(); next_range.end = left_range.end; } else if let Some(right_range) = right .peek() - .filter(|range| range.start.cmp(&next_range.end, &snapshot).is_le()) + .filter(|range| range.start.cmp(&next_range.end, snapshot).is_le()) .cloned() { right.next(); diff --git a/crates/git_ui/src/text_diff_view.rs b/crates/git_ui/src/text_diff_view.rs index 005c1e18b40727f42df81437c7038f4e5a7ef905..ebf32d1b994814fa277201176b555efed5e85e66 100644 --- a/crates/git_ui/src/text_diff_view.rs +++ b/crates/git_ui/src/text_diff_view.rs @@ -48,7 +48,7 @@ impl TextDiffView { let selection_data = source_editor.update(cx, |editor, cx| { let multibuffer = editor.buffer().read(cx); - let source_buffer = multibuffer.as_singleton()?.clone(); + let source_buffer = multibuffer.as_singleton()?; let selections = editor.selections.all::(cx); let buffer_snapshot = source_buffer.read(cx); let first_selection = selections.first()?; @@ -207,7 +207,7 @@ impl TextDiffView { path: Some(format!("Clipboard ↔ {selection_location_path}").into()), 
buffer_changes_tx, _recalculate_diff_task: cx.spawn(async move |_, cx| { - while let Ok(_) = buffer_changes_rx.recv().await { + while buffer_changes_rx.recv().await.is_ok() { loop { let mut timer = cx .background_executor() @@ -259,7 +259,7 @@ async fn update_diff_buffer( let source_buffer_snapshot = source_buffer.read_with(cx, |buffer, _| buffer.snapshot())?; let base_buffer_snapshot = clipboard_buffer.read_with(cx, |buffer, _| buffer.snapshot())?; - let base_text = base_buffer_snapshot.text().to_string(); + let base_text = base_buffer_snapshot.text(); let diff_snapshot = cx .update(|cx| { @@ -686,7 +686,7 @@ mod tests { let project = Project::test(fs, [project_root.as_ref()], cx).await; - let (workspace, mut cx) = + let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); let buffer = project @@ -725,7 +725,7 @@ mod tests { assert_state_with_diff( &diff_view.read_with(cx, |diff_view, _| diff_view.diff_editor.clone()), - &mut cx, + cx, expected_diff, ); diff --git a/crates/go_to_line/src/cursor_position.rs b/crates/go_to_line/src/cursor_position.rs index 29064eb29cb986187b9d86046fd3d78cd2f63451..6af8c79fe9cc4ed0be0d7cb466753fa939355eec 100644 --- a/crates/go_to_line/src/cursor_position.rs +++ b/crates/go_to_line/src/cursor_position.rs @@ -1,8 +1,8 @@ -use editor::{Editor, MultiBufferSnapshot}; +use editor::{Editor, EditorSettings, MultiBufferSnapshot}; use gpui::{App, Entity, FocusHandle, Focusable, Subscription, Task, WeakEntity}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; use std::{fmt::Write, num::NonZeroU32, time::Duration}; use text::{Point, Selection}; use ui::{ @@ -95,10 +95,8 @@ impl CursorPosition { .ok() .unwrap_or(true); - if !is_singleton { - if let Some(debounce) = debounce { - cx.background_executor().timer(debounce).await; - } + if !is_singleton && let Some(debounce) = 
debounce { + cx.background_executor().timer(debounce).await; } editor @@ -108,7 +106,7 @@ impl CursorPosition { cursor_position.selected_count.selections = editor.selections.count(); match editor.mode() { editor::EditorMode::AutoHeight { .. } - | editor::EditorMode::SingleLine { .. } + | editor::EditorMode::SingleLine | editor::EditorMode::Minimap { .. } => { cursor_position.position = None; cursor_position.context = None; @@ -131,7 +129,7 @@ impl CursorPosition { cursor_position.selected_count.lines += 1; } } - if last_selection.as_ref().map_or(true, |last_selection| { + if last_selection.as_ref().is_none_or(|last_selection| { selection.id > last_selection.id }) { last_selection = Some(selection); @@ -209,6 +207,13 @@ impl CursorPosition { impl Render for CursorPosition { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + if !EditorSettings::get_global(cx) + .status_bar + .cursor_position_button + { + return div(); + } + div().when_some(self.position, |el, position| { let mut text = format!( "{}{FILE_ROW_COLUMN_DELIMITER}{}", @@ -227,13 +232,11 @@ impl Render for CursorPosition { if let Some(editor) = workspace .active_item(cx) .and_then(|item| item.act_as::(cx)) + && let Some((_, buffer, _)) = editor.read(cx).active_excerpt(cx) { - if let Some((_, buffer, _)) = editor.read(cx).active_excerpt(cx) - { - workspace.toggle_modal(window, cx, |window, cx| { - crate::GoToLine::new(editor, buffer, window, cx) - }) - } + workspace.toggle_modal(window, cx, |window, cx| { + crate::GoToLine::new(editor, buffer, window, cx) + }) } }); } @@ -298,14 +301,13 @@ pub(crate) enum LineIndicatorFormat { Long, } -#[derive(Clone, Copy, Default, JsonSchema, Deserialize, Serialize)] +#[derive(Clone, Copy, Default, JsonSchema, Deserialize, Serialize, SettingsUi, SettingsKey)] #[serde(transparent)] +#[settings_key(key = "line_indicator_format")] pub(crate) struct LineIndicatorFormatContent(LineIndicatorFormat); impl Settings for LineIndicatorFormat { - 
const KEY: Option<&'static str> = Some("line_indicator_format"); - - type FileContent = Option; + type FileContent = LineIndicatorFormatContent; fn load(sources: SettingsSources, _: &mut App) -> anyhow::Result { let format = [ @@ -314,8 +316,8 @@ impl Settings for LineIndicatorFormat { sources.user, ] .into_iter() - .find_map(|value| value.copied().flatten()) - .unwrap_or(sources.default.ok_or_else(Self::missing_default)?); + .find_map(|value| value.copied()) + .unwrap_or(*sources.default); Ok(format.0) } diff --git a/crates/go_to_line/src/go_to_line.rs b/crates/go_to_line/src/go_to_line.rs index 1ac933e316bcde24384139c851a8bedb63388611..9b573d7071b64c7470e81079e7be5a5f048fc5eb 100644 --- a/crates/go_to_line/src/go_to_line.rs +++ b/crates/go_to_line/src/go_to_line.rs @@ -103,17 +103,20 @@ impl GoToLine { return; }; editor.update(cx, |editor, cx| { - if let Some(placeholder_text) = editor.placeholder_text() { - if editor.text(cx).is_empty() { - let placeholder_text = placeholder_text.to_string(); - editor.set_text(placeholder_text, window, cx); - } + if let Some(placeholder_text) = editor.placeholder_text(cx) + && editor.text(cx).is_empty() + { + editor.set_text(placeholder_text, window, cx); } }); } }) .detach(); - editor.set_placeholder_text(format!("{line}{FILE_ROW_COLUMN_DELIMITER}{column}"), cx); + editor.set_placeholder_text( + &format!("{line}{FILE_ROW_COLUMN_DELIMITER}{column}"), + window, + cx, + ); editor }); let line_editor_change = cx.subscribe_in(&line_editor, window, Self::on_line_editor_event); @@ -157,7 +160,7 @@ impl GoToLine { self.prev_scroll_position.take(); cx.emit(DismissEvent) } - editor::EditorEvent::BufferEdited { .. 
} => self.highlight_current_line(cx), + editor::EditorEvent::BufferEdited => self.highlight_current_line(cx), _ => {} } } @@ -691,11 +694,11 @@ mod tests { let go_to_line_view = open_go_to_line_view(workspace, cx); go_to_line_view.update(cx, |go_to_line_view, cx| { assert_eq!( - go_to_line_view - .line_editor - .read(cx) - .placeholder_text() - .expect("No placeholder text"), + go_to_line_view.line_editor.update(cx, |line_editor, cx| { + line_editor + .placeholder_text(cx) + .expect("No placeholder text") + }), format!( "{}:{}", expected_placeholder.line, expected_placeholder.character @@ -712,7 +715,7 @@ mod tests { ) -> Entity { cx.dispatch_action(editor::actions::ToggleGoToLine); workspace.update(cx, |workspace, cx| { - workspace.active_modal::(cx).unwrap().clone() + workspace.active_modal::(cx).unwrap() }) } diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index dfa51d024c46c2acc15744cd366c2fd723d59046..92fd53189327fabccdc1472ac0fa2a20dc665646 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -13,6 +13,7 @@ pub async fn stream_generate_content( api_key: &str, mut request: GenerateContentRequest, ) -> Result>> { + let api_key = api_key.trim(); validate_generate_content_request(&request)?; // The `model` field is emptied as it is provided as a path parameter. 
@@ -106,10 +107,9 @@ pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Re .contents .iter() .find(|content| content.role == Role::User) + && user_content.parts.is_empty() { - if user_content.parts.is_empty() { - bail!("User content must contain at least one part"); - } + bail!("User content must contain at least one part"); } Ok(()) @@ -267,7 +267,7 @@ pub struct CitationMetadata { pub struct PromptFeedback { #[serde(skip_serializing_if = "Option::is_none")] pub block_reason: Option, - pub safety_ratings: Vec, + pub safety_ratings: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub block_reason_message: Option, } @@ -478,10 +478,10 @@ impl<'de> Deserialize<'de> for ModelName { model_id: id.to_string(), }) } else { - return Err(serde::de::Error::custom(format!( + Err(serde::de::Error::custom(format!( "Expected model name to begin with {}, got: {}", MODEL_NAME_PREFIX, string - ))); + ))) } } } diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index d720dfb2a16ac7e8ac9ecae71d30a52d62a2c257..dd91eb4d4ee408b5381701f8ef5f4dae13344994 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -12,13 +12,13 @@ license = "Apache-2.0" workspace = true [features] -default = ["http_client", "font-kit", "wayland", "x11", "windows-manifest"] +default = ["font-kit", "wayland", "x11", "windows-manifest"] test-support = [ "leak-detection", "collections/test-support", "rand", "util/test-support", - "http_client?/test-support", + "http_client/test-support", "wayland", "x11", ] @@ -91,7 +91,7 @@ derive_more.workspace = true etagere = "0.2" futures.workspace = true gpui_macros.workspace = true -http_client = { optional = true, workspace = true } +http_client.workspace = true image.workspace = true inventory.workspace = true itertools.workspace = true @@ -119,6 +119,7 @@ serde_json.workspace = true slotmap = "1.0.6" smallvec.workspace = true smol.workspace = true +stacksafe.workspace = true strum.workspace = true 
sum_tree.workspace = true taffy = "=0.9.0" @@ -209,7 +210,7 @@ xkbcommon = { version = "0.8.0", features = [ "wayland", "x11", ], optional = true } -xim = { git = "https://github.com/XDeme1/xim-rs", rev = "d50d461764c2213655cd9cf65a0ea94c70d3c4fd", features = [ +xim = { git = "https://github.com/zed-industries/xim-rs", rev = "c0a70c1bd2ce197364216e5e818a2cb3adb99a8d" , features = [ "x11rb-xcb", "x11rb-client", ], optional = true } diff --git a/crates/gpui/README.md b/crates/gpui/README.md index 9faab7b6801873f531f8138e375cdad73fc23dc4..672d83e8ff0d72f641598d1a0b2c69a98650d45c 100644 --- a/crates/gpui/README.md +++ b/crates/gpui/README.md @@ -23,7 +23,7 @@ On macOS, GPUI uses Metal for rendering. In order to use Metal, you need to do t - Install [Xcode](https://apps.apple.com/us/app/xcode/id497799835?mt=12) from the macOS App Store, or from the [Apple Developer](https://developer.apple.com/download/all/) website. Note this requires a developer account. -> Ensure you launch XCode after installing, and install the macOS components, which is the default option. +> Ensure you launch Xcode after installing, and install the macOS components, which is the default option. If you are on macOS 26 (Tahoe) you will need to use `--features gpui/runtime_shaders` or add the feature in the root `Cargo.toml` - Install [Xcode command line tools](https://developer.apple.com/xcode/resources/) diff --git a/crates/gpui/build.rs b/crates/gpui/build.rs index 93a1c15c41dd173a35ffc0adf06af6c449809890..0040046f90554ecd3cb4c010faf18e7b98b62159 100644 --- a/crates/gpui/build.rs +++ b/crates/gpui/build.rs @@ -327,10 +327,10 @@ mod windows { /// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler. 
fn find_fxc_compiler() -> String { // Check environment variable - if let Ok(path) = std::env::var("GPUI_FXC_PATH") { - if Path::new(&path).exists() { - return path; - } + if let Ok(path) = std::env::var("GPUI_FXC_PATH") + && Path::new(&path).exists() + { + return path; } // Try to find in PATH @@ -338,11 +338,10 @@ mod windows { if let Ok(output) = std::process::Command::new("where.exe") .arg("fxc.exe") .output() + && output.status.success() { - if output.status.success() { - let path = String::from_utf8_lossy(&output.stdout); - return path.trim().to_string(); - } + let path = String::from_utf8_lossy(&output.stdout); + return path.trim().to_string(); } // Check the default path @@ -374,7 +373,7 @@ mod windows { shader_path, "vs_4_1", ); - generate_rust_binding(&const_name, &output_file, &rust_binding_path); + generate_rust_binding(&const_name, &output_file, rust_binding_path); // Compile fragment shader let output_file = format!("{}/{}_ps.h", out_dir, module); @@ -387,7 +386,7 @@ mod windows { shader_path, "ps_4_1", ); - generate_rust_binding(&const_name, &output_file, &rust_binding_path); + generate_rust_binding(&const_name, &output_file, rust_binding_path); } fn compile_shader_impl( diff --git a/crates/gpui/examples/data_table.rs b/crates/gpui/examples/data_table.rs index 5e82b08839de5f3b98ec3267b22a3bb8586fa02c..10e22828a8e8f5c8778cbcb06a087d4bdfac3adc 100644 --- a/crates/gpui/examples/data_table.rs +++ b/crates/gpui/examples/data_table.rs @@ -38,58 +38,58 @@ pub struct Quote { impl Quote { pub fn random() -> Self { use rand::Rng; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); // simulate a base price in a realistic range - let prev_close = rng.gen_range(100.0..200.0); - let change = rng.gen_range(-5.0..5.0); + let prev_close = rng.random_range(100.0..200.0); + let change = rng.random_range(-5.0..5.0); let last_done = prev_close + change; - let open = prev_close + rng.gen_range(-3.0..3.0); - let high = (prev_close + 
rng.gen_range::(0.0..10.0)).max(open); - let low = (prev_close - rng.gen_range::(0.0..10.0)).min(open); - let timestamp = Duration::from_secs(rng.gen_range(0..86400)); - let volume = rng.gen_range(1_000_000..100_000_000); + let open = prev_close + rng.random_range(-3.0..3.0); + let high = (prev_close + rng.random_range::(0.0..10.0)).max(open); + let low = (prev_close - rng.random_range::(0.0..10.0)).min(open); + let timestamp = Duration::from_secs(rng.random_range(0..86400)); + let volume = rng.random_range(1_000_000..100_000_000); let turnover = last_done * volume as f64; let symbol = { let mut ticker = String::new(); - if rng.gen_bool(0.5) { + if rng.random_bool(0.5) { ticker.push_str(&format!( "{:03}.{}", - rng.gen_range(100..1000), - rng.gen_range(0..10) + rng.random_range(100..1000), + rng.random_range(0..10) )); } else { ticker.push_str(&format!( "{}{}", - rng.gen_range('A'..='Z'), - rng.gen_range('A'..='Z') + rng.random_range('A'..='Z'), + rng.random_range('A'..='Z') )); } - ticker.push_str(&format!(".{}", rng.gen_range('A'..='Z'))); + ticker.push_str(&format!(".{}", rng.random_range('A'..='Z'))); ticker }; let name = format!( "{} {} - #{}", symbol, - rng.gen_range(1..100), - rng.gen_range(10000..100000) + rng.random_range(1..100), + rng.random_range(10000..100000) ); - let ttm = rng.gen_range(0.0..10.0); - let market_cap = rng.gen_range(1_000_000.0..10_000_000.0); - let float_cap = market_cap + rng.gen_range(1_000.0..10_000.0); - let shares = rng.gen_range(100.0..1000.0); + let ttm = rng.random_range(0.0..10.0); + let market_cap = rng.random_range(1_000_000.0..10_000_000.0); + let float_cap = market_cap + rng.random_range(1_000.0..10_000.0); + let shares = rng.random_range(100.0..1000.0); let pb = market_cap / shares; let pe = market_cap / shares; let eps = market_cap / shares; - let dividend = rng.gen_range(0.0..10.0); - let dividend_yield = rng.gen_range(0.0..10.0); - let dividend_per_share = rng.gen_range(0.0..10.0); + let dividend = 
rng.random_range(0.0..10.0); + let dividend_yield = rng.random_range(0.0..10.0); + let dividend_per_share = rng.random_range(0.0..10.0); let dividend_date = SharedString::new(format!( "{}-{}-{}", - rng.gen_range(2000..2023), - rng.gen_range(1..12), - rng.gen_range(1..28) + rng.random_range(2000..2023), + rng.random_range(1..12), + rng.random_range(1..28) )); - let dividend_payment = rng.gen_range(0.0..10.0); + let dividend_payment = rng.random_range(0.0..10.0); Self { name: name.into(), diff --git a/crates/gpui/examples/image/image.rs b/crates/gpui/examples/image/image.rs index bd1708e8c453656b2b7047b428f3dc63409eddec..34a510f76db396a91a225dffe21fcec986a62e20 100644 --- a/crates/gpui/examples/image/image.rs +++ b/crates/gpui/examples/image/image.rs @@ -75,65 +75,71 @@ impl Render for ImageShowcase { fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { div() .id("main") + .bg(gpui::white()) .overflow_y_scroll() .p_5() .size_full() - .flex() - .flex_col() - .justify_center() - .items_center() - .gap_8() - .bg(rgb(0xffffff)) .child( div() .flex() - .flex_row() + .flex_col() .justify_center() .items_center() .gap_8() - .child(ImageContainer::new( - "Image loaded from a local file", - self.local_resource.clone(), - )) - .child(ImageContainer::new( - "Image loaded from a remote resource", - self.remote_resource.clone(), + .child(img( + "https://github.com/zed-industries/zed/actions/workflows/ci.yml/badge.svg", )) - .child(ImageContainer::new( - "Image loaded from an asset", - self.asset_resource.clone(), - )), - ) - .child( - div() - .flex() - .flex_row() - .gap_8() .child( div() - .flex_col() - .child("Auto Width") - .child(img("https://picsum.photos/800/400").h(px(180.))), + .flex() + .flex_row() + .justify_center() + .items_center() + .gap_8() + .child(ImageContainer::new( + "Image loaded from a local file", + self.local_resource.clone(), + )) + .child(ImageContainer::new( + "Image loaded from a remote resource", + 
self.remote_resource.clone(), + )) + .child(ImageContainer::new( + "Image loaded from an asset", + self.asset_resource.clone(), + )), + ) + .child( + div() + .flex() + .flex_row() + .gap_8() + .child( + div() + .flex_col() + .child("Auto Width") + .child(img("https://picsum.photos/800/400").h(px(180.))), + ) + .child( + div() + .flex_col() + .child("Auto Height") + .child(img("https://picsum.photos/800/400").w(px(180.))), + ), ) .child( div() + .flex() .flex_col() - .child("Auto Height") - .child(img("https://picsum.photos/800/400").w(px(180.))), + .justify_center() + .items_center() + .w_full() + .border_1() + .border_color(rgb(0xC0C0C0)) + .child("image with max width 100%") + .child(img("https://picsum.photos/800/400").max_w_full()), ), ) - .child( - div() - .flex() - .flex_col() - .justify_center() - .items_center() - .w_full() - .border_1() - .border_color(rgb(0xC0C0C0)) - .child("image with max width 100%") - .child(img("https://picsum.photos/800/400").max_w_full()), - ) } } diff --git a/crates/gpui/examples/input.rs b/crates/gpui/examples/input.rs index 52a5b08b967927ef709dffc1e21c4075e4cdc5df..37115feaa551a787562e7299c9d44bcc97b5fca3 100644 --- a/crates/gpui/examples/input.rs +++ b/crates/gpui/examples/input.rs @@ -137,14 +137,14 @@ impl TextInput { fn copy(&mut self, _: &Copy, _: &mut Window, cx: &mut Context) { if !self.selected_range.is_empty() { cx.write_to_clipboard(ClipboardItem::new_string( - (&self.content[self.selected_range.clone()]).to_string(), + self.content[self.selected_range.clone()].to_string(), )); } } fn cut(&mut self, _: &Cut, window: &mut Window, cx: &mut Context) { if !self.selected_range.is_empty() { cx.write_to_clipboard(ClipboardItem::new_string( - (&self.content[self.selected_range.clone()]).to_string(), + self.content[self.selected_range.clone()].to_string(), )); self.replace_text_in_range(None, "", window, cx) } @@ -446,7 +446,7 @@ impl Element for TextElement { let (display_text, text_color) = if content.is_empty() { 
(input.placeholder.clone(), hsla(0., 0., 0., 0.2)) } else { - (content.clone(), style.color) + (content, style.color) }; let run = TextRun { @@ -474,7 +474,7 @@ impl Element for TextElement { }, TextRun { len: display_text.len() - marked_range.end, - ..run.clone() + ..run }, ] .into_iter() @@ -549,10 +549,10 @@ impl Element for TextElement { line.paint(bounds.origin, window.line_height(), window, cx) .unwrap(); - if focus_handle.is_focused(window) { - if let Some(cursor) = prepaint.cursor.take() { - window.paint_quad(cursor); - } + if focus_handle.is_focused(window) + && let Some(cursor) = prepaint.cursor.take() + { + window.paint_quad(cursor); } self.input.update(cx, |input, _cx| { @@ -595,9 +595,7 @@ impl Render for TextInput { .w_full() .p(px(4.)) .bg(white()) - .child(TextElement { - input: cx.entity().clone(), - }), + .child(TextElement { input: cx.entity() }), ) } } diff --git a/crates/gpui/examples/text.rs b/crates/gpui/examples/text.rs index 1166bb279541c80eb8686b59c85724b4068895ed..66e9cff0aa9773d99b412ea7249ca64ea103b138 100644 --- a/crates/gpui/examples/text.rs +++ b/crates/gpui/examples/text.rs @@ -155,7 +155,7 @@ impl RenderOnce for Specimen { .text_size(px(font_size * scale)) .line_height(relative(line_height)) .p(px(10.0)) - .child(self.string.clone()) + .child(self.string) } } diff --git a/crates/gpui/examples/window.rs b/crates/gpui/examples/window.rs index 30f3697b223d6d85a9db573eb3659e9689af60a5..4445f24e4ec0f2809109964fd34610cad1299e90 100644 --- a/crates/gpui/examples/window.rs +++ b/crates/gpui/examples/window.rs @@ -152,6 +152,36 @@ impl Render for WindowDemo { ) .unwrap(); })) + .child(button("Unresizable", move |_, cx| { + cx.open_window( + WindowOptions { + is_resizable: false, + window_bounds: Some(window_bounds), + ..Default::default() + }, + |_, cx| { + cx.new(|_| SubWindow { + custom_titlebar: false, + }) + }, + ) + .unwrap(); + })) + .child(button("Unminimizable", move |_, cx| { + cx.open_window( + WindowOptions { + is_minimizable: 
false, + window_bounds: Some(window_bounds), + ..Default::default() + }, + |_, cx| { + cx.new(|_| SubWindow { + custom_titlebar: false, + }) + }, + ) + .unwrap(); + })) .child(button("Hide Application", |window, cx| { cx.hide(); diff --git a/crates/gpui/examples/window_positioning.rs b/crates/gpui/examples/window_positioning.rs index 0f0bb8ac288d7117867df9b12a104e4272903378..ca6cd535d67aa8b2e700b2d0bc632056e928e0e7 100644 --- a/crates/gpui/examples/window_positioning.rs +++ b/crates/gpui/examples/window_positioning.rs @@ -62,6 +62,8 @@ fn build_window_options(display_id: DisplayId, bounds: Bounds) -> Window app_id: None, window_min_size: None, window_decorations: None, + tabbing_identifier: None, + ..Default::default() } } diff --git a/crates/gpui/src/action.rs b/crates/gpui/src/action.rs index b179076cd5f0da826ca0d5da5e2a5a41cbb5e806..0b824fec34aee7abcb2dbba285265c79b6851d16 100644 --- a/crates/gpui/src/action.rs +++ b/crates/gpui/src/action.rs @@ -73,18 +73,18 @@ macro_rules! actions { /// - `name = "ActionName"` overrides the action's name. This must not contain `::`. /// /// - `no_json` causes the `build` method to always error and `action_json_schema` to return `None`, -/// and allows actions not implement `serde::Serialize` and `schemars::JsonSchema`. +/// and allows actions not implement `serde::Serialize` and `schemars::JsonSchema`. /// /// - `no_register` skips registering the action. This is useful for implementing the `Action` trait -/// while not supporting invocation by name or JSON deserialization. +/// while not supporting invocation by name or JSON deserialization. /// /// - `deprecated_aliases = ["editor::SomeAction"]` specifies deprecated old names for the action. -/// These action names should *not* correspond to any actions that are registered. These old names -/// can then still be used to refer to invoke this action. In Zed, the keymap JSON schema will -/// accept these old names and provide warnings. 
+/// These action names should *not* correspond to any actions that are registered. These old names +/// can then still be used to refer to invoke this action. In Zed, the keymap JSON schema will +/// accept these old names and provide warnings. /// /// - `deprecated = "Message about why this action is deprecation"` specifies a deprecation message. -/// In Zed, the keymap JSON schema will cause this to be displayed as a warning. +/// In Zed, the keymap JSON schema will cause this to be displayed as a warning. /// /// # Manual Implementation /// diff --git a/crates/gpui/src/app.rs b/crates/gpui/src/app.rs index 5f6d25250375561ebb41adc9191ebb5c37480ba3..8b0b404d1dffbf8a27de1f29437ce9cc2ba63f0f 100644 --- a/crates/gpui/src/app.rs +++ b/crates/gpui/src/app.rs @@ -7,7 +7,7 @@ use std::{ path::{Path, PathBuf}, rc::{Rc, Weak}, sync::{Arc, atomic::Ordering::SeqCst}, - time::Duration, + time::{Duration, Instant}, }; use anyhow::{Context as _, Result, anyhow}; @@ -17,6 +17,7 @@ use futures::{ channel::oneshot, future::{LocalBoxFuture, Shared}, }; +use itertools::Itertools; use parking_lot::RwLock; use slotmap::SlotMap; @@ -37,10 +38,10 @@ use crate::{ AssetSource, BackgroundExecutor, Bounds, ClipboardItem, CursorStyle, DispatchPhase, DisplayId, EventEmitter, FocusHandle, FocusMap, ForegroundExecutor, Global, KeyBinding, KeyContext, Keymap, Keystroke, LayoutId, Menu, MenuItem, OwnedMenu, PathPromptOptions, Pixels, Platform, - PlatformDisplay, PlatformKeyboardLayout, Point, PromptBuilder, PromptButton, PromptHandle, - PromptLevel, Render, RenderImage, RenderablePromptHandle, Reservation, ScreenCaptureSource, - SubscriberSet, Subscription, SvgRenderer, Task, TextSystem, Window, WindowAppearance, - WindowHandle, WindowId, WindowInvalidator, + PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, Point, PromptBuilder, + PromptButton, PromptHandle, PromptLevel, Render, RenderImage, RenderablePromptHandle, + Reservation, ScreenCaptureSource, SharedString, SubscriberSet, 
Subscription, SvgRenderer, Task, + TextSystem, Window, WindowAppearance, WindowHandle, WindowId, WindowInvalidator, colors::{Colors, GlobalColors}, current_platform, hash, init_app_menus, }; @@ -237,6 +238,303 @@ type WindowClosedHandler = Box; type ReleaseListener = Box; type NewEntityListener = Box, &mut App) + 'static>; +#[doc(hidden)] +#[derive(Clone, PartialEq, Eq)] +pub struct SystemWindowTab { + pub id: WindowId, + pub title: SharedString, + pub handle: AnyWindowHandle, + pub last_active_at: Instant, +} + +impl SystemWindowTab { + /// Create a new instance of the window tab. + pub fn new(title: SharedString, handle: AnyWindowHandle) -> Self { + Self { + id: handle.id, + title, + handle, + last_active_at: Instant::now(), + } + } +} + +/// A controller for managing window tabs. +#[derive(Default)] +pub struct SystemWindowTabController { + visible: Option, + tab_groups: FxHashMap>, +} + +impl Global for SystemWindowTabController {} + +impl SystemWindowTabController { + /// Create a new instance of the window tab controller. + pub fn new() -> Self { + Self { + visible: None, + tab_groups: FxHashMap::default(), + } + } + + /// Initialize the global window tab controller. + pub fn init(cx: &mut App) { + cx.set_global(SystemWindowTabController::new()); + } + + /// Get all tab groups. + pub fn tab_groups(&self) -> &FxHashMap> { + &self.tab_groups + } + + /// Get the next tab group window handle. 
+ pub fn get_next_tab_group_window(cx: &mut App, id: WindowId) -> Option<&AnyWindowHandle> { + let controller = cx.global::(); + let current_group = controller + .tab_groups + .iter() + .find_map(|(group, tabs)| tabs.iter().find(|tab| tab.id == id).map(|_| group)); + + let current_group = current_group?; + let mut group_ids: Vec<_> = controller.tab_groups.keys().collect(); + let idx = group_ids.iter().position(|g| *g == current_group)?; + let next_idx = (idx + 1) % group_ids.len(); + + controller + .tab_groups + .get(group_ids[next_idx]) + .and_then(|tabs| { + tabs.iter() + .max_by_key(|tab| tab.last_active_at) + .or_else(|| tabs.first()) + .map(|tab| &tab.handle) + }) + } + + /// Get the previous tab group window handle. + pub fn get_prev_tab_group_window(cx: &mut App, id: WindowId) -> Option<&AnyWindowHandle> { + let controller = cx.global::(); + let current_group = controller + .tab_groups + .iter() + .find_map(|(group, tabs)| tabs.iter().find(|tab| tab.id == id).map(|_| group)); + + let current_group = current_group?; + let mut group_ids: Vec<_> = controller.tab_groups.keys().collect(); + let idx = group_ids.iter().position(|g| *g == current_group)?; + let prev_idx = if idx == 0 { + group_ids.len() - 1 + } else { + idx - 1 + }; + + controller + .tab_groups + .get(group_ids[prev_idx]) + .and_then(|tabs| { + tabs.iter() + .max_by_key(|tab| tab.last_active_at) + .or_else(|| tabs.first()) + .map(|tab| &tab.handle) + }) + } + + /// Get all tabs in the same window. + pub fn tabs(&self, id: WindowId) -> Option<&Vec> { + let tab_group = self + .tab_groups + .iter() + .find_map(|(group, tabs)| tabs.iter().find(|tab| tab.id == id).map(|_| *group)); + + if let Some(tab_group) = tab_group { + self.tab_groups.get(&tab_group) + } else { + None + } + } + + /// Initialize the visibility of the system window tab controller. 
+ pub fn init_visible(cx: &mut App, visible: bool) { + let mut controller = cx.global_mut::(); + if controller.visible.is_none() { + controller.visible = Some(visible); + } + } + + /// Get the visibility of the system window tab controller. + pub fn is_visible(&self) -> bool { + self.visible.unwrap_or(false) + } + + /// Set the visibility of the system window tab controller. + pub fn set_visible(cx: &mut App, visible: bool) { + let mut controller = cx.global_mut::(); + controller.visible = Some(visible); + } + + /// Update the last active of a window. + pub fn update_last_active(cx: &mut App, id: WindowId) { + let mut controller = cx.global_mut::(); + for windows in controller.tab_groups.values_mut() { + for tab in windows.iter_mut() { + if tab.id == id { + tab.last_active_at = Instant::now(); + } + } + } + } + + /// Update the position of a tab within its group. + pub fn update_tab_position(cx: &mut App, id: WindowId, ix: usize) { + let mut controller = cx.global_mut::(); + for (_, windows) in controller.tab_groups.iter_mut() { + if let Some(current_pos) = windows.iter().position(|tab| tab.id == id) { + if ix < windows.len() && current_pos != ix { + let window_tab = windows.remove(current_pos); + windows.insert(ix, window_tab); + } + break; + } + } + } + + /// Update the title of a tab. + pub fn update_tab_title(cx: &mut App, id: WindowId, title: SharedString) { + let controller = cx.global::(); + let tab = controller + .tab_groups + .values() + .flat_map(|windows| windows.iter()) + .find(|tab| tab.id == id); + + if tab.map_or(true, |t| t.title == title) { + return; + } + + let mut controller = cx.global_mut::(); + for windows in controller.tab_groups.values_mut() { + for tab in windows.iter_mut() { + if tab.id == id { + tab.title = title.clone(); + } + } + } + } + + /// Insert a tab into a tab group. 
+ pub fn add_tab(cx: &mut App, id: WindowId, tabs: Vec) { + let mut controller = cx.global_mut::(); + let Some(tab) = tabs.clone().into_iter().find(|tab| tab.id == id) else { + return; + }; + + let mut expected_tab_ids: Vec<_> = tabs + .iter() + .filter(|tab| tab.id != id) + .map(|tab| tab.id) + .sorted() + .collect(); + + let mut tab_group_id = None; + for (group_id, group_tabs) in &controller.tab_groups { + let tab_ids: Vec<_> = group_tabs.iter().map(|tab| tab.id).sorted().collect(); + if tab_ids == expected_tab_ids { + tab_group_id = Some(*group_id); + break; + } + } + + if let Some(tab_group_id) = tab_group_id { + if let Some(tabs) = controller.tab_groups.get_mut(&tab_group_id) { + tabs.push(tab); + } + } else { + let new_group_id = controller.tab_groups.len(); + controller.tab_groups.insert(new_group_id, tabs); + } + } + + /// Remove a tab from a tab group. + pub fn remove_tab(cx: &mut App, id: WindowId) -> Option { + let mut controller = cx.global_mut::(); + let mut removed_tab = None; + + controller.tab_groups.retain(|_, tabs| { + if let Some(pos) = tabs.iter().position(|tab| tab.id == id) { + removed_tab = Some(tabs.remove(pos)); + } + !tabs.is_empty() + }); + + removed_tab + } + + /// Move a tab to a new tab group. + pub fn move_tab_to_new_window(cx: &mut App, id: WindowId) { + let mut removed_tab = Self::remove_tab(cx, id); + let mut controller = cx.global_mut::(); + + if let Some(tab) = removed_tab { + let new_group_id = controller.tab_groups.keys().max().map_or(0, |k| k + 1); + controller.tab_groups.insert(new_group_id, vec![tab]); + } + } + + /// Merge all tab groups into a single group. 
+ pub fn merge_all_windows(cx: &mut App, id: WindowId) { + let mut controller = cx.global_mut::(); + let Some(initial_tabs) = controller.tabs(id) else { + return; + }; + + let mut all_tabs = initial_tabs.clone(); + for tabs in controller.tab_groups.values() { + all_tabs.extend( + tabs.iter() + .filter(|tab| !initial_tabs.contains(tab)) + .cloned(), + ); + } + + controller.tab_groups.clear(); + controller.tab_groups.insert(0, all_tabs); + } + + /// Selects the next tab in the tab group in the trailing direction. + pub fn select_next_tab(cx: &mut App, id: WindowId) { + let mut controller = cx.global_mut::(); + let Some(tabs) = controller.tabs(id) else { + return; + }; + + let current_index = tabs.iter().position(|tab| tab.id == id).unwrap(); + let next_index = (current_index + 1) % tabs.len(); + + let _ = &tabs[next_index].handle.update(cx, |_, window, _| { + window.activate_window(); + }); + } + + /// Selects the previous tab in the tab group in the leading direction. + pub fn select_previous_tab(cx: &mut App, id: WindowId) { + let mut controller = cx.global_mut::(); + let Some(tabs) = controller.tabs(id) else { + return; + }; + + let current_index = tabs.iter().position(|tab| tab.id == id).unwrap(); + let previous_index = if current_index == 0 { + tabs.len() - 1 + } else { + current_index - 1 + }; + + let _ = &tabs[previous_index].handle.update(cx, |_, window, _| { + window.activate_window(); + }); + } +} + /// Contains the state of the full application, and passed as a reference to a variety of callbacks. /// Other [Context] derefs to this type. /// You need a reference to an `App` to access the state of a [Entity]. 
@@ -263,6 +561,7 @@ pub struct App { pub(crate) focus_handles: Arc, pub(crate) keymap: Rc>, pub(crate) keyboard_layout: Box, + pub(crate) keyboard_mapper: Rc, pub(crate) global_action_listeners: FxHashMap>>, pending_effects: VecDeque, @@ -312,6 +611,7 @@ impl App { let text_system = Arc::new(TextSystem::new(platform.text_system())); let entities = EntityMap::new(); let keyboard_layout = platform.keyboard_layout(); + let keyboard_mapper = platform.keyboard_mapper(); let app = Rc::new_cyclic(|this| AppCell { app: RefCell::new(App { @@ -337,6 +637,7 @@ impl App { focus_handles: Arc::new(RwLock::new(SlotMap::with_key())), keymap: Rc::new(RefCell::new(Keymap::default())), keyboard_layout, + keyboard_mapper, global_action_listeners: FxHashMap::default(), pending_effects: VecDeque::new(), pending_notifications: FxHashSet::default(), @@ -368,7 +669,8 @@ impl App { }), }); - init_app_menus(platform.as_ref(), &mut app.borrow_mut()); + init_app_menus(platform.as_ref(), &app.borrow()); + SystemWindowTabController::init(&mut app.borrow_mut()); platform.on_keyboard_layout_change(Box::new({ let app = Rc::downgrade(&app); @@ -376,6 +678,7 @@ impl App { if let Some(app) = app.upgrade() { let cx = &mut app.borrow_mut(); cx.keyboard_layout = cx.platform.keyboard_layout(); + cx.keyboard_mapper = cx.platform.keyboard_mapper(); cx.keyboard_layout_observers .clone() .retain(&(), move |callback| (callback)(cx)); @@ -424,6 +727,11 @@ impl App { self.keyboard_layout.as_ref() } + /// Get the current keyboard mapper. 
+ pub fn keyboard_mapper(&self) -> &Rc { + &self.keyboard_mapper + } + /// Invokes a handler when the current keyboard layout changes pub fn on_keyboard_layout_change(&self, mut callback: F) -> Subscription where @@ -816,8 +1124,9 @@ impl App { pub fn prompt_for_new_path( &self, directory: &Path, + suggested_name: Option<&str>, ) -> oneshot::Receiver>> { - self.platform.prompt_for_new_path(directory) + self.platform.prompt_for_new_path(directory, suggested_name) } /// Reveals the specified path at the platform level, such as in Finder on macOS. @@ -1049,12 +1358,7 @@ impl App { F: FnOnce(AnyView, &mut Window, &mut App) -> T, { self.update(|cx| { - let mut window = cx - .windows - .get_mut(id) - .context("window not found")? - .take() - .context("window not found")?; + let mut window = cx.windows.get_mut(id)?.take()?; let root_view = window.root.clone().unwrap(); @@ -1071,15 +1375,14 @@ impl App { true }); } else { - cx.windows - .get_mut(id) - .context("window not found")? - .replace(window); + cx.windows.get_mut(id)?.replace(window); } - Ok(result) + Some(result) }) + .context("window not found") } + /// Creates an `AsyncApp`, which can be cloned and has a static lifetime /// so it can be held across `await` points. pub fn to_async(&self) -> AsyncApp { @@ -1309,7 +1612,7 @@ impl App { T: 'static, { let window_handle = window.handle; - self.observe_release(&handle, move |entity, cx| { + self.observe_release(handle, move |entity, cx| { let _ = window_handle.update(cx, |_, window, cx| on_release(entity, window, cx)); }) } @@ -1331,7 +1634,7 @@ impl App { } inner( - &mut self.keystroke_observers, + &self.keystroke_observers, Box::new(move |event, window, cx| { f(event, window, cx); true @@ -1357,7 +1660,7 @@ impl App { } inner( - &mut self.keystroke_interceptors, + &self.keystroke_interceptors, Box::new(move |event, window, cx| { f(event, window, cx); true @@ -1515,12 +1818,11 @@ impl App { /// the bindings in the element tree, and any global action listeners. 
pub fn is_action_available(&mut self, action: &dyn Action) -> bool { let mut action_available = false; - if let Some(window) = self.active_window() { - if let Ok(window_action_available) = + if let Some(window) = self.active_window() + && let Ok(window_action_available) = window.update(self, |_, window, cx| window.is_action_available(action, cx)) - { - action_available = window_action_available; - } + { + action_available = window_action_available; } action_available @@ -1605,27 +1907,26 @@ impl App { .insert(action.as_any().type_id(), global_listeners); } - if self.propagate_event { - if let Some(mut global_listeners) = self + if self.propagate_event + && let Some(mut global_listeners) = self .global_action_listeners .remove(&action.as_any().type_id()) - { - for listener in global_listeners.iter().rev() { - listener(action.as_any(), DispatchPhase::Bubble, self); - if !self.propagate_event { - break; - } + { + for listener in global_listeners.iter().rev() { + listener(action.as_any(), DispatchPhase::Bubble, self); + if !self.propagate_event { + break; } + } - global_listeners.extend( - self.global_action_listeners - .remove(&action.as_any().type_id()) - .unwrap_or_default(), - ); - + global_listeners.extend( self.global_action_listeners - .insert(action.as_any().type_id(), global_listeners); - } + .remove(&action.as_any().type_id()) + .unwrap_or_default(), + ); + + self.global_action_listeners + .insert(action.as_any().type_id(), global_listeners); } } @@ -1708,8 +2009,8 @@ impl App { .unwrap_or_else(|| { is_first = true; let future = A::load(source.clone(), self); - let task = self.background_executor().spawn(future).shared(); - task + + self.background_executor().spawn(future).shared() }); self.loading_assets.insert(asset_id, Box::new(task.clone())); @@ -1916,7 +2217,7 @@ impl AppContext for App { G: Global, { let mut g = self.global::(); - callback(&g, self) + callback(g, self) } } @@ -2006,7 +2307,7 @@ pub struct AnyDrag { } /// Contains state associated with a 
tooltip. You'll only need this struct if you're implementing -/// tooltip behavior on a custom element. Otherwise, use [Div::tooltip]. +/// tooltip behavior on a custom element. Otherwise, use [Div::tooltip](crate::Interactivity::tooltip). #[derive(Clone)] pub struct AnyTooltip { /// The view used to display the tooltip diff --git a/crates/gpui/src/app/async_context.rs b/crates/gpui/src/app/async_context.rs index d9d21c024461cab68d62d685a40b61c9c74d46dd..cfe7a5a75c258d09194c7d77a117208161713c6f 100644 --- a/crates/gpui/src/app/async_context.rs +++ b/crates/gpui/src/app/async_context.rs @@ -218,7 +218,24 @@ impl AsyncApp { Some(read(app.try_global()?, &app)) } - /// A convenience method for [App::update_global] + /// Reads the global state of the specified type, passing it to the given callback. + /// A default value is assigned if a global of this type has not yet been assigned. + /// + /// # Errors + /// If the app has ben dropped this returns an error. + pub fn try_read_default_global( + &self, + read: impl FnOnce(&G, &App) -> R, + ) -> Result { + let app = self.app.upgrade().context("app was released")?; + let mut app = app.borrow_mut(); + app.update(|cx| { + cx.default_global::(); + }); + Ok(read(app.try_global().context("app was released")?, &app)) + } + + /// A convenience method for [`App::update_global`](BorrowAppContext::update_global) /// for updating the global state of the specified type. pub fn update_global( &self, @@ -293,7 +310,7 @@ impl AsyncWindowContext { .update(self, |_, window, cx| read(cx.global(), window, cx)) } - /// A convenience method for [`App::update_global`]. + /// A convenience method for [`App::update_global`](BorrowAppContext::update_global). /// for updating the global state of the specified type. 
pub fn update_global( &mut self, @@ -465,7 +482,7 @@ impl VisualContext for AsyncWindowContext { V: Focusable, { self.window.update(self, |_, window, cx| { - view.read(cx).focus_handle(cx).clone().focus(window); + view.read(cx).focus_handle(cx).focus(window); }) } } diff --git a/crates/gpui/src/app/context.rs b/crates/gpui/src/app/context.rs index 68c41592b3872addef6381d9f5e8a3f611611bd0..1112878a66b07c133031086c3b14aa8427617bea 100644 --- a/crates/gpui/src/app/context.rs +++ b/crates/gpui/src/app/context.rs @@ -472,7 +472,7 @@ impl<'a, T: 'static> Context<'a, T> { let view = self.weak_entity(); inner( - &mut self.keystroke_observers, + &self.keystroke_observers, Box::new(move |event, window, cx| { if let Some(view) = view.upgrade() { view.update(cx, |view, cx| f(view, event, window, cx)); @@ -610,16 +610,16 @@ impl<'a, T: 'static> Context<'a, T> { let (subscription, activate) = window.new_focus_listener(Box::new(move |event, window, cx| { view.update(cx, |view, cx| { - if let Some(blurred_id) = event.previous_focus_path.last().copied() { - if event.is_focus_out(focus_id) { - let event = FocusOutEvent { - blurred: WeakFocusHandle { - id: blurred_id, - handles: Arc::downgrade(&cx.focus_handles), - }, - }; - listener(view, event, window, cx) - } + if let Some(blurred_id) = event.previous_focus_path.last().copied() + && event.is_focus_out(focus_id) + { + let event = FocusOutEvent { + blurred: WeakFocusHandle { + id: blurred_id, + handles: Arc::downgrade(&cx.focus_handles), + }, + }; + listener(view, event, window, cx) } }) .is_ok() diff --git a/crates/gpui/src/app/entity_map.rs b/crates/gpui/src/app/entity_map.rs index fccb417caa70c7526a0f15a307d74caeabcdab77..ea52b46d9fce958f8cb6e878581fb988c146c43b 100644 --- a/crates/gpui/src/app/entity_map.rs +++ b/crates/gpui/src/app/entity_map.rs @@ -231,14 +231,15 @@ impl AnyEntity { Self { entity_id: id, entity_type, - entity_map: entity_map.clone(), #[cfg(any(test, feature = "leak-detection"))] handle_id: entity_map + 
.clone() .upgrade() .unwrap() .write() .leak_detector .handle_created(id), + entity_map, } } @@ -661,7 +662,7 @@ pub struct WeakEntity { impl std::fmt::Debug for WeakEntity { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct(&type_name::()) + f.debug_struct(type_name::()) .field("entity_id", &self.any_entity.entity_id) .field("entity_type", &type_name::()) .finish() @@ -786,7 +787,7 @@ impl PartialOrd for WeakEntity { #[cfg(any(test, feature = "leak-detection"))] static LEAK_BACKTRACE: std::sync::LazyLock = - std::sync::LazyLock::new(|| std::env::var("LEAK_BACKTRACE").map_or(false, |b| !b.is_empty())); + std::sync::LazyLock::new(|| std::env::var("LEAK_BACKTRACE").is_ok_and(|b| !b.is_empty())); #[cfg(any(test, feature = "leak-detection"))] #[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)] diff --git a/crates/gpui/src/app/test_context.rs b/crates/gpui/src/app/test_context.rs index 35e60326714f049faeaac54e8d979a91f9d97bbc..b3d342b09bf1dceb27413d3ec24fbcc0d2f541e9 100644 --- a/crates/gpui/src/app/test_context.rs +++ b/crates/gpui/src/app/test_context.rs @@ -134,7 +134,7 @@ impl TestAppContext { app: App::new_app(platform.clone(), asset_source, http_client), background_executor, foreground_executor, - dispatcher: dispatcher.clone(), + dispatcher, test_platform: platform, text_system, fn_name, @@ -144,7 +144,7 @@ impl TestAppContext { /// Create a single TestAppContext, for non-multi-client tests pub fn single() -> Self { - let dispatcher = TestDispatcher::new(StdRng::from_entropy()); + let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0)); Self::build(dispatcher, None) } @@ -192,6 +192,7 @@ impl TestAppContext { &self.foreground_executor } + #[expect(clippy::wrong_self_convention)] fn new(&mut self, build_entity: impl FnOnce(&mut Context) -> T) -> Entity { let mut cx = self.app.borrow_mut(); cx.new(build_entity) @@ -219,7 +220,7 @@ impl TestAppContext { let mut cx = self.app.borrow_mut(); // Some tests rely on the window size 
matching the bounds of the test display - let bounds = Bounds::maximized(None, &mut cx); + let bounds = Bounds::maximized(None, &cx); cx.open_window( WindowOptions { window_bounds: Some(WindowBounds::Windowed(bounds)), @@ -233,7 +234,7 @@ impl TestAppContext { /// Adds a new window with no content. pub fn add_empty_window(&mut self) -> &mut VisualTestContext { let mut cx = self.app.borrow_mut(); - let bounds = Bounds::maximized(None, &mut cx); + let bounds = Bounds::maximized(None, &cx); let window = cx .open_window( WindowOptions { @@ -244,7 +245,7 @@ impl TestAppContext { ) .unwrap(); drop(cx); - let cx = VisualTestContext::from_window(*window.deref(), self).as_mut(); + let cx = VisualTestContext::from_window(*window.deref(), self).into_mut(); cx.run_until_parked(); cx } @@ -261,7 +262,7 @@ impl TestAppContext { V: 'static + Render, { let mut cx = self.app.borrow_mut(); - let bounds = Bounds::maximized(None, &mut cx); + let bounds = Bounds::maximized(None, &cx); let window = cx .open_window( WindowOptions { @@ -273,7 +274,7 @@ impl TestAppContext { .unwrap(); drop(cx); let view = window.root(self).unwrap(); - let cx = VisualTestContext::from_window(*window.deref(), self).as_mut(); + let cx = VisualTestContext::from_window(*window.deref(), self).into_mut(); cx.run_until_parked(); // it might be nice to try and cleanup these at the end of each test. @@ -338,7 +339,7 @@ impl TestAppContext { /// Returns all windows open in the test. pub fn windows(&self) -> Vec { - self.app.borrow().windows().clone() + self.app.borrow().windows() } /// Run the given task on the main thread. 
@@ -585,7 +586,7 @@ impl Entity { cx.executor().advance_clock(advance_clock_by); async move { - let notification = crate::util::timeout(duration, rx.recv()) + let notification = crate::util::smol_timeout(duration, rx.recv()) .await .expect("next notification timed out"); drop(subscription); @@ -618,7 +619,7 @@ impl Entity { } }), cx.subscribe(self, { - let mut tx = tx.clone(); + let mut tx = tx; move |_, _: &Evt, _| { tx.blocking_send(()).ok(); } @@ -629,7 +630,7 @@ impl Entity { let handle = self.downgrade(); async move { - crate::util::timeout(Duration::from_secs(1), async move { + crate::util::smol_timeout(Duration::from_secs(1), async move { loop { { let cx = cx.borrow(); @@ -882,7 +883,7 @@ impl VisualTestContext { /// Get an &mut VisualTestContext (which is mostly what you need to pass to other methods). /// This method internally retains the VisualTestContext until the end of the test. - pub fn as_mut(self) -> &'static mut Self { + pub fn into_mut(self) -> &'static mut Self { let ptr = Box::into_raw(Box::new(self)); // safety: on_quit will be called after the test has finished. // the executor will ensure that all tasks related to the test have stopped. 
@@ -1025,7 +1026,7 @@ impl VisualContext for VisualTestContext { fn focus(&mut self, view: &Entity) -> Self::Result<()> { self.window .update(&mut self.cx, |_, window, cx| { - view.read(cx).focus_handle(cx).clone().focus(window) + view.read(cx).focus_handle(cx).focus(window) }) .unwrap() } diff --git a/crates/gpui/src/arena.rs b/crates/gpui/src/arena.rs index ee72d0e96425816220094f4cbff86315153afb74..a0d0c23987472de46d5b23129adb5a4ec8ee00cb 100644 --- a/crates/gpui/src/arena.rs +++ b/crates/gpui/src/arena.rs @@ -1,8 +1,9 @@ use std::{ alloc::{self, handle_alloc_error}, cell::Cell, + num::NonZeroUsize, ops::{Deref, DerefMut}, - ptr, + ptr::{self, NonNull}, rc::Rc, }; @@ -30,23 +31,23 @@ impl Drop for Chunk { fn drop(&mut self) { unsafe { let chunk_size = self.end.offset_from_unsigned(self.start); - // this never fails as it succeeded during allocation - let layout = alloc::Layout::from_size_align(chunk_size, 1).unwrap(); + // SAFETY: This succeeded during allocation. + let layout = alloc::Layout::from_size_align_unchecked(chunk_size, 1); alloc::dealloc(self.start, layout); } } } impl Chunk { - fn new(chunk_size: usize) -> Self { + fn new(chunk_size: NonZeroUsize) -> Self { unsafe { // this only fails if chunk_size is unreasonably huge - let layout = alloc::Layout::from_size_align(chunk_size, 1).unwrap(); + let layout = alloc::Layout::from_size_align(chunk_size.get(), 1).unwrap(); let start = alloc::alloc(layout); if start.is_null() { handle_alloc_error(layout); } - let end = start.add(chunk_size); + let end = start.add(chunk_size.get()); Self { start, end, @@ -55,14 +56,14 @@ impl Chunk { } } - fn allocate(&mut self, layout: alloc::Layout) -> Option<*mut u8> { + fn allocate(&mut self, layout: alloc::Layout) -> Option> { unsafe { let aligned = self.offset.add(self.offset.align_offset(layout.align())); let next = aligned.add(layout.size()); if next <= self.end { self.offset = next; - Some(aligned) + NonNull::new(aligned) } else { None } @@ -79,7 +80,7 @@ pub struct 
Arena { elements: Vec, valid: Rc>, current_chunk_index: usize, - chunk_size: usize, + chunk_size: NonZeroUsize, } impl Drop for Arena { @@ -90,7 +91,7 @@ impl Drop for Arena { impl Arena { pub fn new(chunk_size: usize) -> Self { - assert!(chunk_size > 0); + let chunk_size = NonZeroUsize::try_from(chunk_size).unwrap(); Self { chunks: vec![Chunk::new(chunk_size)], elements: Vec::new(), @@ -101,7 +102,7 @@ impl Arena { } pub fn capacity(&self) -> usize { - self.chunks.len() * self.chunk_size + self.chunks.len() * self.chunk_size.get() } pub fn clear(&mut self) { @@ -136,20 +137,20 @@ impl Arena { let layout = alloc::Layout::new::(); let mut current_chunk = &mut self.chunks[self.current_chunk_index]; let ptr = if let Some(ptr) = current_chunk.allocate(layout) { - ptr + ptr.as_ptr() } else { self.current_chunk_index += 1; if self.current_chunk_index >= self.chunks.len() { self.chunks.push(Chunk::new(self.chunk_size)); assert_eq!(self.current_chunk_index, self.chunks.len() - 1); - log::info!( + log::trace!( "increased element arena capacity to {}kb", self.capacity() / 1024, ); } current_chunk = &mut self.chunks[self.current_chunk_index]; if let Some(ptr) = current_chunk.allocate(layout) { - ptr + ptr.as_ptr() } else { panic!( "Arena chunk_size of {} is too small to allocate {} bytes", diff --git a/crates/gpui/src/assets.rs b/crates/gpui/src/assets.rs index 70a07c11e9239c048f9eaede8cae31a79acf779c..8930b58f8d4fc0423b7d6f41755189a03d8b8b84 100644 --- a/crates/gpui/src/assets.rs +++ b/crates/gpui/src/assets.rs @@ -1,4 +1,4 @@ -use crate::{DevicePixels, Result, SharedString, Size, size}; +use crate::{DevicePixels, Pixels, Result, SharedString, Size, size}; use smallvec::SmallVec; use image::{Delay, Frame}; @@ -42,6 +42,8 @@ pub(crate) struct RenderImageParams { pub struct RenderImage { /// The ID associated with this image pub id: ImageId, + /// The scale factor of this image on render. 
+ pub(crate) scale_factor: f32, data: SmallVec<[Frame; 1]>, } @@ -60,6 +62,7 @@ impl RenderImage { Self { id: ImageId(NEXT_ID.fetch_add(1, SeqCst)), + scale_factor: 1.0, data: data.into(), } } @@ -77,6 +80,12 @@ impl RenderImage { size(width.into(), height.into()) } + /// Get the size of this image, in pixels for display, adjusted for the scale factor. + pub(crate) fn render_size(&self, frame_index: usize) -> Size { + self.size(frame_index) + .map(|v| (v.0 as f32 / self.scale_factor).into()) + } + /// Get the delay of this frame from the previous pub fn delay(&self, frame_index: usize) -> Delay { self.data[frame_index].delay() diff --git a/crates/gpui/src/bounds_tree.rs b/crates/gpui/src/bounds_tree.rs index 03f83b95035489bd86201c4d64c15f5a12ed50ea..a96bfe55b9ff431a96da7bf42692288264eb184c 100644 --- a/crates/gpui/src/bounds_tree.rs +++ b/crates/gpui/src/bounds_tree.rs @@ -309,12 +309,12 @@ mod tests { let mut expected_quads: Vec<(Bounds, u32)> = Vec::new(); // Insert a random number of random AABBs into the tree. - let num_bounds = rng.gen_range(1..=max_bounds); + let num_bounds = rng.random_range(1..=max_bounds); for _ in 0..num_bounds { - let min_x: f32 = rng.gen_range(-100.0..100.0); - let min_y: f32 = rng.gen_range(-100.0..100.0); - let width: f32 = rng.gen_range(0.0..50.0); - let height: f32 = rng.gen_range(0.0..50.0); + let min_x: f32 = rng.random_range(-100.0..100.0); + let min_y: f32 = rng.random_range(-100.0..100.0); + let width: f32 = rng.random_range(0.0..50.0); + let height: f32 = rng.random_range(0.0..50.0); let bounds = Bounds { origin: Point { x: min_x, y: min_y }, size: Size { width, height }, diff --git a/crates/gpui/src/color.rs b/crates/gpui/src/color.rs index 639c84c10144310b14a94c2a22b84957b8b09524..93c69744a6f9f3cc74e7696e9edf49001587376b 100644 --- a/crates/gpui/src/color.rs +++ b/crates/gpui/src/color.rs @@ -473,6 +473,11 @@ impl Hsla { self.a == 0.0 } + /// Returns true if the HSLA color is fully opaque, false otherwise. 
+ pub fn is_opaque(&self) -> bool { + self.a == 1.0 + } + /// Blends `other` on top of `self` based on `other`'s alpha value. The resulting color is a combination of `self`'s and `other`'s colors. /// /// If `other`'s alpha value is 1.0 or greater, `other` color is fully opaque, thus `other` is returned as the output color. @@ -905,9 +910,9 @@ mod tests { assert_eq!(background.solid, color); assert_eq!(background.opacity(0.5).solid, color.opacity(0.5)); - assert_eq!(background.is_transparent(), false); + assert!(!background.is_transparent()); background.solid = hsla(0.0, 0.0, 0.0, 0.0); - assert_eq!(background.is_transparent(), true); + assert!(background.is_transparent()); } #[test] @@ -921,7 +926,7 @@ mod tests { assert_eq!(background.opacity(0.5).colors[0], from.opacity(0.5)); assert_eq!(background.opacity(0.5).colors[1], to.opacity(0.5)); - assert_eq!(background.is_transparent(), false); - assert_eq!(background.opacity(0.0).is_transparent(), true); + assert!(!background.is_transparent()); + assert!(background.opacity(0.0).is_transparent()); } } diff --git a/crates/gpui/src/colors.rs b/crates/gpui/src/colors.rs index 5e14c1238addbb02b0c6a02942aae05b703583ea..ef11ef57fdb363dae3f910db2e540e3de02fe453 100644 --- a/crates/gpui/src/colors.rs +++ b/crates/gpui/src/colors.rs @@ -88,9 +88,9 @@ impl Deref for GlobalColors { impl Global for GlobalColors {} -/// Implement this trait to allow global [Color] access via `cx.default_colors()`. +/// Implement this trait to allow global [Colors] access via `cx.default_colors()`. pub trait DefaultColors { - /// Returns the default [`gpui::Colors`] + /// Returns the default [`Colors`] fn default_colors(&self) -> &Arc; } diff --git a/crates/gpui/src/element.rs b/crates/gpui/src/element.rs index e5f49c7be141a3620e52599bcc2b151acc1f7319..a3fc6269f33d8726b55f8e8be4aadb52109a7606 100644 --- a/crates/gpui/src/element.rs +++ b/crates/gpui/src/element.rs @@ -14,13 +14,13 @@ //! 
tree and any callbacks they have registered with GPUI are dropped and the process repeats. //! //! But some state is too simple and voluminous to store in every view that needs it, e.g. -//! whether a hover has been started or not. For this, GPUI provides the [`Element::State`], associated type. +//! whether a hover has been started or not. For this, GPUI provides the [`Element::PrepaintState`], associated type. //! //! # Implementing your own elements //! //! Elements are intended to be the low level, imperative API to GPUI. They are responsible for upholding, //! or breaking, GPUI's features as they deem necessary. As an example, most GPUI elements are expected -//! to stay in the bounds that their parent element gives them. But with [`WindowContext::break_content_mask`], +//! to stay in the bounds that their parent element gives them. But with [`Window::with_content_mask`], //! you can ignore this restriction and paint anywhere inside of the window's bounds. This is useful for overlays //! and popups and anything else that shows up 'on top' of other elements. //! With great power, comes great responsibility. 
@@ -603,10 +603,8 @@ impl AnyElement { self.0.prepaint(window, cx); - if !focus_assigned { - if let Some(focus_id) = window.next_frame.focus { - return FocusHandle::for_id(focus_id, &cx.focus_handles); - } + if !focus_assigned && let Some(focus_id) = window.next_frame.focus { + return FocusHandle::for_id(focus_id, &cx.focus_handles); } None diff --git a/crates/gpui/src/elements/animation.rs b/crates/gpui/src/elements/animation.rs index 11dd19e260c20e49b87e05137771be73a3f816ea..e72fb00456d14dec74ffc56e040511c189af1d18 100644 --- a/crates/gpui/src/elements/animation.rs +++ b/crates/gpui/src/elements/animation.rs @@ -87,7 +87,7 @@ pub trait AnimationExt { } } -impl AnimationExt for E {} +impl AnimationExt for E {} /// A GPUI element that applies an animation to another element pub struct AnimationElement { diff --git a/crates/gpui/src/elements/div.rs b/crates/gpui/src/elements/div.rs index 09afbff929b99bb927d365621ea0550c28dcedf8..443bcb14bbec7c5fac39fdd0f5e5d621d84df610 100644 --- a/crates/gpui/src/elements/div.rs +++ b/crates/gpui/src/elements/div.rs @@ -27,6 +27,7 @@ use crate::{ use collections::HashMap; use refineable::Refineable; use smallvec::SmallVec; +use stacksafe::{StackSafe, stacksafe}; use std::{ any::{Any, TypeId}, cell::RefCell, @@ -285,21 +286,20 @@ impl Interactivity { { self.mouse_move_listeners .push(Box::new(move |event, phase, hitbox, window, cx| { - if phase == DispatchPhase::Capture { - if let Some(drag) = &cx.active_drag { - if drag.value.as_ref().type_id() == TypeId::of::() { - (listener)( - &DragMoveEvent { - event: event.clone(), - bounds: hitbox.bounds, - drag: PhantomData, - dragged_item: Arc::clone(&drag.value), - }, - window, - cx, - ); - } - } + if phase == DispatchPhase::Capture + && let Some(drag) = &cx.active_drag + && drag.value.as_ref().type_id() == TypeId::of::() + { + (listener)( + &DragMoveEvent { + event: event.clone(), + bounds: hitbox.bounds, + drag: PhantomData, + dragged_item: Arc::clone(&drag.value), + }, + window, + cx, + 
); } })); } @@ -533,7 +533,7 @@ impl Interactivity { } /// Use the given callback to construct a new tooltip view when the mouse hovers over this element. - /// The imperative API equivalent to [`InteractiveElement::tooltip`] + /// The imperative API equivalent to [`StatefulInteractiveElement::tooltip`] pub fn tooltip(&mut self, build_tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) where Self: Sized, @@ -550,7 +550,7 @@ impl Interactivity { /// Use the given callback to construct a new tooltip view when the mouse hovers over this element. /// The tooltip itself is also hoverable and won't disappear when the user moves the mouse into - /// the tooltip. The imperative API equivalent to [`InteractiveElement::hoverable_tooltip`] + /// the tooltip. The imperative API equivalent to [`StatefulInteractiveElement::hoverable_tooltip`] pub fn hoverable_tooltip( &mut self, build_tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static, @@ -676,7 +676,7 @@ pub trait InteractiveElement: Sized { #[cfg(any(test, feature = "test-support"))] /// Set a key that can be used to look up this element's bounds - /// in the [`VisualTestContext::debug_bounds`] map + /// in the [`crate::VisualTestContext::debug_bounds`] map /// This is a noop in release builds fn debug_selector(mut self, f: impl FnOnce() -> String) -> Self { self.interactivity().debug_selector = Some(f()); @@ -685,7 +685,7 @@ pub trait InteractiveElement: Sized { #[cfg(not(any(test, feature = "test-support")))] /// Set a key that can be used to look up this element's bounds - /// in the [`VisualTestContext::debug_bounds`] map + /// in the [`crate::VisualTestContext::debug_bounds`] map /// This is a noop in release builds #[inline] fn debug_selector(self, _: impl FnOnce() -> String) -> Self { @@ -1087,7 +1087,7 @@ pub trait StatefulInteractiveElement: InteractiveElement { /// On drag initiation, this callback will be used to create a new view to render the dragged value for a /// drag and drop operation. 
This API should also be used as the equivalent of 'on drag start' with - /// the [`Self::on_drag_move`] API. + /// the [`InteractiveElement::on_drag_move`] API. /// The callback also has access to the offset of triggering click from the origin of parent element. /// The fluent API equivalent to [`Interactivity::on_drag`] /// @@ -1195,7 +1195,7 @@ pub fn div() -> Div { /// A [`Div`] element, the all-in-one element for building complex UIs in GPUI pub struct Div { interactivity: Interactivity, - children: SmallVec<[AnyElement; 2]>, + children: SmallVec<[StackSafe; 2]>, prepaint_listener: Option>, &mut Window, &mut App) + 'static>>, image_cache: Option>, } @@ -1256,7 +1256,8 @@ impl InteractiveElement for Div { impl ParentElement for Div { fn extend(&mut self, elements: impl IntoIterator) { - self.children.extend(elements) + self.children + .extend(elements.into_iter().map(StackSafe::new)) } } @@ -1272,6 +1273,7 @@ impl Element for Div { self.interactivity.source_location() } + #[stacksafe] fn request_layout( &mut self, global_id: Option<&GlobalElementId>, @@ -1307,6 +1309,7 @@ impl Element for Div { (layout_id, DivFrameState { child_layout_ids }) } + #[stacksafe] fn prepaint( &mut self, global_id: Option<&GlobalElementId>, @@ -1376,6 +1379,7 @@ impl Element for Div { ) } + #[stacksafe] fn paint( &mut self, global_id: Option<&GlobalElementId>, @@ -1509,15 +1513,14 @@ impl Interactivity { let mut element_state = element_state.map(|element_state| element_state.unwrap_or_default()); - if let Some(element_state) = element_state.as_ref() { - if cx.has_active_drag() { - if let Some(pending_mouse_down) = element_state.pending_mouse_down.as_ref() - { - *pending_mouse_down.borrow_mut() = None; - } - if let Some(clicked_state) = element_state.clicked_state.as_ref() { - *clicked_state.borrow_mut() = ElementClickedState::default(); - } + if let Some(element_state) = element_state.as_ref() + && cx.has_active_drag() + { + if let Some(pending_mouse_down) = 
element_state.pending_mouse_down.as_ref() { + *pending_mouse_down.borrow_mut() = None; + } + if let Some(clicked_state) = element_state.clicked_state.as_ref() { + *clicked_state.borrow_mut() = ElementClickedState::default(); } } @@ -1525,35 +1528,35 @@ impl Interactivity { // If there's an explicit focus handle we're tracking, use that. Otherwise // create a new handle and store it in the element state, which lives for as // as frames contain an element with this id. - if self.focusable && self.tracked_focus_handle.is_none() { - if let Some(element_state) = element_state.as_mut() { - let mut handle = element_state - .focus_handle - .get_or_insert_with(|| cx.focus_handle()) - .clone() - .tab_stop(false); - - if let Some(index) = self.tab_index { - handle = handle.tab_index(index).tab_stop(true); - } - - self.tracked_focus_handle = Some(handle); + if self.focusable + && self.tracked_focus_handle.is_none() + && let Some(element_state) = element_state.as_mut() + { + let mut handle = element_state + .focus_handle + .get_or_insert_with(|| cx.focus_handle()) + .clone() + .tab_stop(false); + + if let Some(index) = self.tab_index { + handle = handle.tab_index(index).tab_stop(true); } + + self.tracked_focus_handle = Some(handle); } if let Some(scroll_handle) = self.tracked_scroll_handle.as_ref() { self.scroll_offset = Some(scroll_handle.0.borrow().offset.clone()); - } else if self.base_style.overflow.x == Some(Overflow::Scroll) - || self.base_style.overflow.y == Some(Overflow::Scroll) + } else if (self.base_style.overflow.x == Some(Overflow::Scroll) + || self.base_style.overflow.y == Some(Overflow::Scroll)) + && let Some(element_state) = element_state.as_mut() { - if let Some(element_state) = element_state.as_mut() { - self.scroll_offset = Some( - element_state - .scroll_offset - .get_or_insert_with(Rc::default) - .clone(), - ); - } + self.scroll_offset = Some( + element_state + .scroll_offset + .get_or_insert_with(Rc::default) + .clone(), + ); } let style = 
self.compute_style_internal(None, element_state.as_mut(), window, cx); @@ -2026,26 +2029,27 @@ impl Interactivity { let hitbox = hitbox.clone(); window.on_mouse_event({ move |_: &MouseUpEvent, phase, window, cx| { - if let Some(drag) = &cx.active_drag { - if phase == DispatchPhase::Bubble && hitbox.is_hovered(window) { - let drag_state_type = drag.value.as_ref().type_id(); - for (drop_state_type, listener) in &drop_listeners { - if *drop_state_type == drag_state_type { - let drag = cx - .active_drag - .take() - .expect("checked for type drag state type above"); - - let mut can_drop = true; - if let Some(predicate) = &can_drop_predicate { - can_drop = predicate(drag.value.as_ref(), window, cx); - } + if let Some(drag) = &cx.active_drag + && phase == DispatchPhase::Bubble + && hitbox.is_hovered(window) + { + let drag_state_type = drag.value.as_ref().type_id(); + for (drop_state_type, listener) in &drop_listeners { + if *drop_state_type == drag_state_type { + let drag = cx + .active_drag + .take() + .expect("checked for type drag state type above"); + + let mut can_drop = true; + if let Some(predicate) = &can_drop_predicate { + can_drop = predicate(drag.value.as_ref(), window, cx); + } - if can_drop { - listener(drag.value.as_ref(), window, cx); - window.refresh(); - cx.stop_propagation(); - } + if can_drop { + listener(drag.value.as_ref(), window, cx); + window.refresh(); + cx.stop_propagation(); } } } @@ -2089,31 +2093,24 @@ impl Interactivity { } let mut pending_mouse_down = pending_mouse_down.borrow_mut(); - if let Some(mouse_down) = pending_mouse_down.clone() { - if !cx.has_active_drag() - && (event.position - mouse_down.position).magnitude() - > DRAG_THRESHOLD - { - if let Some((drag_value, drag_listener)) = drag_listener.take() { - *clicked_state.borrow_mut() = ElementClickedState::default(); - let cursor_offset = event.position - hitbox.origin; - let drag = (drag_listener)( - drag_value.as_ref(), - cursor_offset, - window, - cx, - ); - cx.active_drag = 
Some(AnyDrag { - view: drag, - value: drag_value, - cursor_offset, - cursor_style: drag_cursor_style, - }); - pending_mouse_down.take(); - window.refresh(); - cx.stop_propagation(); - } - } + if let Some(mouse_down) = pending_mouse_down.clone() + && !cx.has_active_drag() + && (event.position - mouse_down.position).magnitude() > DRAG_THRESHOLD + && let Some((drag_value, drag_listener)) = drag_listener.take() + { + *clicked_state.borrow_mut() = ElementClickedState::default(); + let cursor_offset = event.position - hitbox.origin; + let drag = + (drag_listener)(drag_value.as_ref(), cursor_offset, window, cx); + cx.active_drag = Some(AnyDrag { + view: drag, + value: drag_value, + cursor_offset, + cursor_style: drag_cursor_style, + }); + pending_mouse_down.take(); + window.refresh(); + cx.stop_propagation(); } } }); @@ -2277,7 +2274,7 @@ impl Interactivity { window.on_mouse_event(move |_: &MouseDownEvent, phase, window, _cx| { if phase == DispatchPhase::Bubble && !window.default_prevented() { let group_hovered = active_group_hitbox - .map_or(false, |group_hitbox_id| group_hitbox_id.is_hovered(window)); + .is_some_and(|group_hitbox_id| group_hitbox_id.is_hovered(window)); let element_hovered = hitbox.is_hovered(window); if group_hovered || element_hovered { *active_state.borrow_mut() = ElementClickedState { @@ -2423,33 +2420,32 @@ impl Interactivity { style.refine(&self.base_style); if let Some(focus_handle) = self.tracked_focus_handle.as_ref() { - if let Some(in_focus_style) = self.in_focus_style.as_ref() { - if focus_handle.within_focused(window, cx) { - style.refine(in_focus_style); - } + if let Some(in_focus_style) = self.in_focus_style.as_ref() + && focus_handle.within_focused(window, cx) + { + style.refine(in_focus_style); } - if let Some(focus_style) = self.focus_style.as_ref() { - if focus_handle.is_focused(window) { - style.refine(focus_style); - } + if let Some(focus_style) = self.focus_style.as_ref() + && focus_handle.is_focused(window) + { + 
style.refine(focus_style); } } if let Some(hitbox) = hitbox { if !cx.has_active_drag() { - if let Some(group_hover) = self.group_hover_style.as_ref() { - if let Some(group_hitbox_id) = GroupHitboxes::get(&group_hover.group, cx) { - if group_hitbox_id.is_hovered(window) { - style.refine(&group_hover.style); - } - } + if let Some(group_hover) = self.group_hover_style.as_ref() + && let Some(group_hitbox_id) = GroupHitboxes::get(&group_hover.group, cx) + && group_hitbox_id.is_hovered(window) + { + style.refine(&group_hover.style); } - if let Some(hover_style) = self.hover_style.as_ref() { - if hitbox.is_hovered(window) { - style.refine(hover_style); - } + if let Some(hover_style) = self.hover_style.as_ref() + && hitbox.is_hovered(window) + { + style.refine(hover_style); } } @@ -2463,12 +2459,10 @@ impl Interactivity { for (state_type, group_drag_style) in &self.group_drag_over_styles { if let Some(group_hitbox_id) = GroupHitboxes::get(&group_drag_style.group, cx) + && *state_type == drag.value.as_ref().type_id() + && group_hitbox_id.is_hovered(window) { - if *state_type == drag.value.as_ref().type_id() - && group_hitbox_id.is_hovered(window) - { - style.refine(&group_drag_style.style); - } + style.refine(&group_drag_style.style); } } @@ -2490,16 +2484,16 @@ impl Interactivity { .clicked_state .get_or_insert_with(Default::default) .borrow(); - if clicked_state.group { - if let Some(group) = self.group_active_style.as_ref() { - style.refine(&group.style) - } + if clicked_state.group + && let Some(group) = self.group_active_style.as_ref() + { + style.refine(&group.style) } - if let Some(active_style) = self.active_style.as_ref() { - if clicked_state.element { - style.refine(active_style) - } + if let Some(active_style) = self.active_style.as_ref() + && clicked_state.element + { + style.refine(active_style) } } @@ -2620,7 +2614,7 @@ pub(crate) fn register_tooltip_mouse_handlers( window.on_mouse_event({ let active_tooltip = active_tooltip.clone(); move |_: &MouseDownEvent, 
_phase, window: &mut Window, _cx| { - if !tooltip_id.map_or(false, |tooltip_id| tooltip_id.is_hovered(window)) { + if !tooltip_id.is_some_and(|tooltip_id| tooltip_id.is_hovered(window)) { clear_active_tooltip_if_not_hoverable(&active_tooltip, window); } } @@ -2629,7 +2623,7 @@ pub(crate) fn register_tooltip_mouse_handlers( window.on_mouse_event({ let active_tooltip = active_tooltip.clone(); move |_: &ScrollWheelEvent, _phase, window: &mut Window, _cx| { - if !tooltip_id.map_or(false, |tooltip_id| tooltip_id.is_hovered(window)) { + if !tooltip_id.is_some_and(|tooltip_id| tooltip_id.is_hovered(window)) { clear_active_tooltip_if_not_hoverable(&active_tooltip, window); } } @@ -2785,7 +2779,7 @@ fn handle_tooltip_check_visible_and_update( match action { Action::None => {} - Action::Hide => clear_active_tooltip(&active_tooltip, window), + Action::Hide => clear_active_tooltip(active_tooltip, window), Action::ScheduleHide(tooltip) => { let delayed_hide_task = window.spawn(cx, { let active_tooltip = active_tooltip.clone(); diff --git a/crates/gpui/src/elements/image_cache.rs b/crates/gpui/src/elements/image_cache.rs index e7bdeaf9eb4d26913718a9b235cee4fcb0ca85ff..ee1436134a30f70e7015ab1c86f60733e60e9164 100644 --- a/crates/gpui/src/elements/image_cache.rs +++ b/crates/gpui/src/elements/image_cache.rs @@ -64,7 +64,7 @@ mod any_image_cache { cx: &mut App, ) -> Option, ImageCacheError>> { let image_cache = image_cache.clone().downcast::().unwrap(); - return image_cache.update(cx, |image_cache, cx| image_cache.load(resource, window, cx)); + image_cache.update(cx, |image_cache, cx| image_cache.load(resource, window, cx)) } } @@ -297,10 +297,10 @@ impl RetainAllImageCache { /// Remove the image from the cache by the given source. 
pub fn remove(&mut self, source: &Resource, window: &mut Window, cx: &mut App) { let hash = hash(source); - if let Some(mut item) = self.0.remove(&hash) { - if let Some(Ok(image)) = item.get() { - cx.drop_image(image, Some(window)); - } + if let Some(mut item) = self.0.remove(&hash) + && let Some(Ok(image)) = item.get() + { + cx.drop_image(image, Some(window)); } } diff --git a/crates/gpui/src/elements/img.rs b/crates/gpui/src/elements/img.rs index 993b319b697ece386ad8af6d6164c1b85bf3a1c7..40d1b5e44981b7cfd0de92ddbb10f2f715008c70 100644 --- a/crates/gpui/src/elements/img.rs +++ b/crates/gpui/src/elements/img.rs @@ -332,20 +332,18 @@ impl Element for Img { state.started_loading = None; } - let image_size = data.size(frame_index); - style.aspect_ratio = - Some(image_size.width.0 as f32 / image_size.height.0 as f32); + let image_size = data.render_size(frame_index); + style.aspect_ratio = Some(image_size.width / image_size.height); if let Length::Auto = style.size.width { style.size.width = match style.size.height { Length::Definite(DefiniteLength::Absolute( AbsoluteLength::Pixels(height), )) => Length::Definite( - px(image_size.width.0 as f32 * height.0 - / image_size.height.0 as f32) - .into(), + px(image_size.width.0 * height.0 / image_size.height.0) + .into(), ), - _ => Length::Definite(px(image_size.width.0 as f32).into()), + _ => Length::Definite(image_size.width.into()), }; } @@ -354,11 +352,10 @@ impl Element for Img { Length::Definite(DefiniteLength::Absolute( AbsoluteLength::Pixels(width), )) => Length::Definite( - px(image_size.height.0 as f32 * width.0 - / image_size.width.0 as f32) - .into(), + px(image_size.height.0 * width.0 / image_size.width.0) + .into(), ), - _ => Length::Definite(px(image_size.height.0 as f32).into()), + _ => Length::Definite(image_size.height.into()), }; } @@ -379,13 +376,12 @@ impl Element for Img { None => { if let Some(state) = &mut state { if let Some((started_loading, _)) = state.started_loading { - if 
started_loading.elapsed() > LOADING_DELAY { - if let Some(loading) = self.style.loading.as_ref() { - let mut element = loading(); - replacement_id = - Some(element.request_layout(window, cx)); - layout_state.replacement = Some(element); - } + if started_loading.elapsed() > LOADING_DELAY + && let Some(loading) = self.style.loading.as_ref() + { + let mut element = loading(); + replacement_id = Some(element.request_layout(window, cx)); + layout_state.replacement = Some(element); } } else { let current_view = window.current_view(); @@ -476,7 +472,7 @@ impl Element for Img { .paint_image( new_bounds, corner_radii, - data.clone(), + data, layout_state.frame_index, self.style.grayscale, ) @@ -702,7 +698,9 @@ impl Asset for ImageAssetLoader { swap_rgba_pa_to_bgra(pixel); } - RenderImage::new(SmallVec::from_elem(Frame::new(buffer), 1)) + let mut image = RenderImage::new(SmallVec::from_elem(Frame::new(buffer), 1)); + image.scale_factor = SMOOTH_SVG_SCALE_FACTOR; + image }; Ok(Arc::new(data)) diff --git a/crates/gpui/src/elements/list.rs b/crates/gpui/src/elements/list.rs index 39f38bdc69d6a5d4c9ce8c7c349707e906124cca..ed4ca64e83513531b9176f05c4c00b0af71aea74 100644 --- a/crates/gpui/src/elements/list.rs +++ b/crates/gpui/src/elements/list.rs @@ -5,7 +5,7 @@ //! In order to minimize re-renders, this element's state is stored intrusively //! on your own views, so that your code can coordinate directly with the list element's cached state. //! -//! If all of your elements are the same height, see [`UniformList`] for a simpler API +//! If all of your elements are the same height, see [`crate::UniformList`] for a simpler API use crate::{ AnyElement, App, AvailableSpace, Bounds, ContentMask, DispatchPhase, Edges, Element, EntityId, @@ -235,7 +235,7 @@ impl ListState { } /// Register with the list state that the items in `old_range` have been replaced - /// by new items. As opposed to [`splice`], this method allows an iterator of optional focus handles + /// by new items. 
As opposed to [`Self::splice`], this method allows an iterator of optional focus handles /// to be supplied to properly integrate with items in the list that can be focused. If a focused item /// is scrolled out of view, the list will continue to render it to allow keyboard interaction. pub fn splice_focusable( @@ -732,46 +732,44 @@ impl StateInner { item.element.prepaint_at(item_origin, window, cx); }); - if let Some(autoscroll_bounds) = window.take_autoscroll() { - if autoscroll { - if autoscroll_bounds.top() < bounds.top() { - return Err(ListOffset { - item_ix: item.index, - offset_in_item: autoscroll_bounds.top() - item_origin.y, - }); - } else if autoscroll_bounds.bottom() > bounds.bottom() { - let mut cursor = self.items.cursor::(&()); - cursor.seek(&Count(item.index), Bias::Right); - let mut height = bounds.size.height - padding.top - padding.bottom; - - // Account for the height of the element down until the autoscroll bottom. - height -= autoscroll_bounds.bottom() - item_origin.y; - - // Keep decreasing the scroll top until we fill all the available space. 
- while height > Pixels::ZERO { - cursor.prev(); - let Some(item) = cursor.item() else { break }; - - let size = item.size().unwrap_or_else(|| { - let mut item = render_item(cursor.start().0, window, cx); - let item_available_size = size( - bounds.size.width.into(), - AvailableSpace::MinContent, - ); - item.layout_as_root(item_available_size, window, cx) - }); - height -= size.height; - } - - return Err(ListOffset { - item_ix: cursor.start().0, - offset_in_item: if height < Pixels::ZERO { - -height - } else { - Pixels::ZERO - }, + if let Some(autoscroll_bounds) = window.take_autoscroll() + && autoscroll + { + if autoscroll_bounds.top() < bounds.top() { + return Err(ListOffset { + item_ix: item.index, + offset_in_item: autoscroll_bounds.top() - item_origin.y, + }); + } else if autoscroll_bounds.bottom() > bounds.bottom() { + let mut cursor = self.items.cursor::(&()); + cursor.seek(&Count(item.index), Bias::Right); + let mut height = bounds.size.height - padding.top - padding.bottom; + + // Account for the height of the element down until the autoscroll bottom. + height -= autoscroll_bounds.bottom() - item_origin.y; + + // Keep decreasing the scroll top until we fill all the available space. 
+ while height > Pixels::ZERO { + cursor.prev(); + let Some(item) = cursor.item() else { break }; + + let size = item.size().unwrap_or_else(|| { + let mut item = render_item(cursor.start().0, window, cx); + let item_available_size = + size(bounds.size.width.into(), AvailableSpace::MinContent); + item.layout_as_root(item_available_size, window, cx) }); + height -= size.height; } + + return Err(ListOffset { + item_ix: cursor.start().0, + offset_in_item: if height < Pixels::ZERO { + -height + } else { + Pixels::ZERO + }, + }); } } @@ -940,9 +938,10 @@ impl Element for List { let hitbox = window.insert_hitbox(bounds, HitboxBehavior::Normal); // If the width of the list has changed, invalidate all cached item heights - if state.last_layout_bounds.map_or(true, |last_bounds| { - last_bounds.size.width != bounds.size.width - }) { + if state + .last_layout_bounds + .is_none_or(|last_bounds| last_bounds.size.width != bounds.size.width) + { let new_items = SumTree::from_iter( state.items.iter().map(|item| ListItem::Unmeasured { focus_handle: item.focus_handle(), diff --git a/crates/gpui/src/elements/text.rs b/crates/gpui/src/elements/text.rs index 014f617e2cfc74755908368f57060aeaeb38aa74..b5e071279623611685ea744e38b072284e764e2a 100644 --- a/crates/gpui/src/elements/text.rs +++ b/crates/gpui/src/elements/text.rs @@ -326,7 +326,7 @@ impl TextLayout { vec![text_style.to_run(text.len())] }; - let layout_id = window.request_measured_layout(Default::default(), { + window.request_measured_layout(Default::default(), { let element_state = self.clone(); move |known_dimensions, available_space, window, cx| { @@ -356,12 +356,11 @@ impl TextLayout { (None, "".into()) }; - if let Some(text_layout) = element_state.0.borrow().as_ref() { - if text_layout.size.is_some() - && (wrap_width.is_none() || wrap_width == text_layout.wrap_width) - { - return text_layout.size.unwrap(); - } + if let Some(text_layout) = element_state.0.borrow().as_ref() + && text_layout.size.is_some() + && 
(wrap_width.is_none() || wrap_width == text_layout.wrap_width) + { + return text_layout.size.unwrap(); } let mut line_wrapper = cx.text_system().line_wrapper(text_style.font(), font_size); @@ -417,9 +416,7 @@ impl TextLayout { size } - }); - - layout_id + }) } fn prepaint(&self, bounds: Bounds, text: &str) { @@ -763,14 +760,13 @@ impl Element for InteractiveText { let mut interactive_state = interactive_state.unwrap_or_default(); if let Some(click_listener) = self.click_listener.take() { let mouse_position = window.mouse_position(); - if let Ok(ix) = text_layout.index_for_position(mouse_position) { - if self + if let Ok(ix) = text_layout.index_for_position(mouse_position) + && self .clickable_ranges .iter() .any(|range| range.contains(&ix)) - { - window.set_cursor_style(crate::CursorStyle::PointingHand, hitbox) - } + { + window.set_cursor_style(crate::CursorStyle::PointingHand, hitbox) } let text_layout = text_layout.clone(); @@ -803,13 +799,13 @@ impl Element for InteractiveText { } else { let hitbox = hitbox.clone(); window.on_mouse_event(move |event: &MouseDownEvent, phase, window, _| { - if phase == DispatchPhase::Bubble && hitbox.is_hovered(window) { - if let Ok(mouse_down_index) = + if phase == DispatchPhase::Bubble + && hitbox.is_hovered(window) + && let Ok(mouse_down_index) = text_layout.index_for_position(event.position) - { - mouse_down.set(Some(mouse_down_index)); - window.refresh(); - } + { + mouse_down.set(Some(mouse_down_index)); + window.refresh(); } }); } diff --git a/crates/gpui/src/executor.rs b/crates/gpui/src/executor.rs index 273a3ea503bad26de075a8eb1c6cec01d23f453b..0b28dd030baff6bc95ede07e50e358660a9c1353 100644 --- a/crates/gpui/src/executor.rs +++ b/crates/gpui/src/executor.rs @@ -391,7 +391,7 @@ impl BackgroundExecutor { } /// in tests, run all tasks that are ready to run. If after doing so - /// the test still has outstanding tasks, this will panic. (See also `allow_parking`) + /// the test still has outstanding tasks, this will panic. 
(See also [`Self::allow_parking`]) #[cfg(any(test, feature = "test-support"))] pub fn run_until_parked(&self) { self.dispatcher.as_test().unwrap().run_until_parked() @@ -405,7 +405,7 @@ impl BackgroundExecutor { self.dispatcher.as_test().unwrap().allow_parking(); } - /// undoes the effect of [`allow_parking`]. + /// undoes the effect of [`Self::allow_parking`]. #[cfg(any(test, feature = "test-support"))] pub fn forbid_parking(&self) { self.dispatcher.as_test().unwrap().forbid_parking(); @@ -480,7 +480,7 @@ impl ForegroundExecutor { /// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics. /// /// Copy-modified from: -/// https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405 +/// #[track_caller] fn spawn_local_with_source_location( future: Fut, diff --git a/crates/gpui/src/geometry.rs b/crates/gpui/src/geometry.rs index 2de3e23ff716d179bb4e2b55c80650d2b010c38e..87cabc8cd9f446fddcb5c98dafbdf956eb1efea2 100644 --- a/crates/gpui/src/geometry.rs +++ b/crates/gpui/src/geometry.rs @@ -1046,7 +1046,7 @@ where size: self.size.clone() + size( amount.left.clone() + amount.right.clone(), - amount.top.clone() + amount.bottom.clone(), + amount.top.clone() + amount.bottom, ), } } @@ -1159,10 +1159,10 @@ where /// Computes the space available within outer bounds. 
pub fn space_within(&self, outer: &Self) -> Edges { Edges { - top: self.top().clone() - outer.top().clone(), - right: outer.right().clone() - self.right().clone(), - bottom: outer.bottom().clone() - self.bottom().clone(), - left: self.left().clone() - outer.left().clone(), + top: self.top() - outer.top(), + right: outer.right() - self.right(), + bottom: outer.bottom() - self.bottom(), + left: self.left() - outer.left(), } } } @@ -1641,7 +1641,7 @@ impl Bounds { } /// Convert the bounds from logical pixels to physical pixels - pub fn to_device_pixels(&self, factor: f32) -> Bounds { + pub fn to_device_pixels(self, factor: f32) -> Bounds { Bounds { origin: point( DevicePixels((self.origin.x.0 * factor).round() as i32), @@ -1712,7 +1712,7 @@ where top: self.top.clone() * rhs.top, right: self.right.clone() * rhs.right, bottom: self.bottom.clone() * rhs.bottom, - left: self.left.clone() * rhs.left, + left: self.left * rhs.left, } } } @@ -1957,7 +1957,7 @@ impl Edges { /// assert_eq!(edges_in_pixels.bottom, px(32.0)); // 2 rems /// assert_eq!(edges_in_pixels.left, px(50.0)); // 25% of parent width /// ``` - pub fn to_pixels(&self, parent_size: Size, rem_size: Pixels) -> Edges { + pub fn to_pixels(self, parent_size: Size, rem_size: Pixels) -> Edges { Edges { top: self.top.to_pixels(parent_size.height, rem_size), right: self.right.to_pixels(parent_size.width, rem_size), @@ -2027,7 +2027,7 @@ impl Edges { /// assert_eq!(edges_in_pixels.bottom, px(20.0)); // Already in pixels /// assert_eq!(edges_in_pixels.left, px(32.0)); // 2 rems converted to pixels /// ``` - pub fn to_pixels(&self, rem_size: Pixels) -> Edges { + pub fn to_pixels(self, rem_size: Pixels) -> Edges { Edges { top: self.top.to_pixels(rem_size), right: self.right.to_pixels(rem_size), @@ -2272,7 +2272,7 @@ impl Corners { /// assert_eq!(corners_in_pixels.bottom_right, Pixels(30.0)); /// assert_eq!(corners_in_pixels.bottom_left, Pixels(32.0)); // 2 rems converted to pixels /// ``` - pub fn to_pixels(&self, 
rem_size: Pixels) -> Corners { + pub fn to_pixels(self, rem_size: Pixels) -> Corners { Corners { top_left: self.top_left.to_pixels(rem_size), top_right: self.top_right.to_pixels(rem_size), @@ -2411,7 +2411,7 @@ where top_left: self.top_left.clone() * rhs.top_left, top_right: self.top_right.clone() * rhs.top_right, bottom_right: self.bottom_right.clone() * rhs.bottom_right, - bottom_left: self.bottom_left.clone() * rhs.bottom_left, + bottom_left: self.bottom_left * rhs.bottom_left, } } } @@ -2858,7 +2858,7 @@ impl DevicePixels { /// let total_bytes = pixels.to_bytes(bytes_per_pixel); /// assert_eq!(total_bytes, 40); // 10 pixels * 4 bytes/pixel = 40 bytes /// ``` - pub fn to_bytes(&self, bytes_per_pixel: u8) -> u32 { + pub fn to_bytes(self, bytes_per_pixel: u8) -> u32 { self.0 as u32 * bytes_per_pixel as u32 } } @@ -3073,8 +3073,8 @@ pub struct Rems(pub f32); impl Rems { /// Convert this Rem value to pixels. - pub fn to_pixels(&self, rem_size: Pixels) -> Pixels { - *self * rem_size + pub fn to_pixels(self, rem_size: Pixels) -> Pixels { + self * rem_size } } @@ -3168,9 +3168,9 @@ impl AbsoluteLength { /// assert_eq!(length_in_pixels.to_pixels(rem_size), Pixels(42.0)); /// assert_eq!(length_in_rems.to_pixels(rem_size), Pixels(32.0)); /// ``` - pub fn to_pixels(&self, rem_size: Pixels) -> Pixels { + pub fn to_pixels(self, rem_size: Pixels) -> Pixels { match self { - AbsoluteLength::Pixels(pixels) => *pixels, + AbsoluteLength::Pixels(pixels) => pixels, AbsoluteLength::Rems(rems) => rems.to_pixels(rem_size), } } @@ -3184,10 +3184,10 @@ impl AbsoluteLength { /// # Returns /// /// Returns the `AbsoluteLength` as `Pixels`. 
- pub fn to_rems(&self, rem_size: Pixels) -> Rems { + pub fn to_rems(self, rem_size: Pixels) -> Rems { match self { AbsoluteLength::Pixels(pixels) => Rems(pixels.0 / rem_size.0), - AbsoluteLength::Rems(rems) => *rems, + AbsoluteLength::Rems(rems) => rems, } } } @@ -3315,12 +3315,12 @@ impl DefiniteLength { /// assert_eq!(length_in_rems.to_pixels(base_size, rem_size), Pixels(32.0)); /// assert_eq!(length_as_fraction.to_pixels(base_size, rem_size), Pixels(50.0)); /// ``` - pub fn to_pixels(&self, base_size: AbsoluteLength, rem_size: Pixels) -> Pixels { + pub fn to_pixels(self, base_size: AbsoluteLength, rem_size: Pixels) -> Pixels { match self { DefiniteLength::Absolute(size) => size.to_pixels(rem_size), DefiniteLength::Fraction(fraction) => match base_size { - AbsoluteLength::Pixels(px) => px * *fraction, - AbsoluteLength::Rems(rems) => rems * rem_size * *fraction, + AbsoluteLength::Pixels(px) => px * fraction, + AbsoluteLength::Rems(rems) => rems * rem_size * fraction, }, } } diff --git a/crates/gpui/src/gpui.rs b/crates/gpui/src/gpui.rs index 09799eb910f0eeece17fd9975c3c13f6accd2df6..3c4ee41c16ab7cfc5e42007291e330282b330ecb 100644 --- a/crates/gpui/src/gpui.rs +++ b/crates/gpui/src/gpui.rs @@ -24,7 +24,7 @@ //! - State management and communication with [`Entity`]'s. Whenever you need to store application state //! that communicates between different parts of your application, you'll want to use GPUI's //! entities. Entities are owned by GPUI and are only accessible through an owned smart pointer -//! similar to an [`std::rc::Rc`]. See the [`app::context`] module for more information. +//! similar to an [`std::rc::Rc`]. See [`app::Context`] for more information. //! //! - High level, declarative UI with views. All UI in GPUI starts with a view. A view is simply //! a [`Entity`] that can be rendered, by implementing the [`Render`] trait. At the start of each frame, GPUI @@ -37,7 +37,7 @@ //! 
provide a nice wrapper around an imperative API that provides as much flexibility and control as //! you need. Elements have total control over how they and their child elements are rendered and //! can be used for making efficient views into large lists, implement custom layouting for a code editor, -//! and anything else you can think of. See the [`element`] module for more information. +//! and anything else you can think of. See the [`elements`] module for more information. //! //! Each of these registers has one or more corresponding contexts that can be accessed from all GPUI services. //! This context is your main interface to GPUI, and is used extensively throughout the framework. @@ -51,9 +51,9 @@ //! Use this for implementing keyboard shortcuts, such as cmd-q (See `action` module for more information). //! - Platform services, such as `quit the app` or `open a URL` are available as methods on the [`app::App`]. //! - An async executor that is integrated with the platform's event loop. See the [`executor`] module for more information., -//! - The [`gpui::test`](test) macro provides a convenient way to write tests for your GPUI applications. Tests also have their -//! own kind of context, a [`TestAppContext`] which provides ways of simulating common platform input. See [`app::test_context`] -//! and [`test`] modules for more details. +//! - The [`gpui::test`](macro@test) macro provides a convenient way to write tests for your GPUI applications. Tests also have their +//! own kind of context, a [`TestAppContext`] which provides ways of simulating common platform input. See [`TestAppContext`] +//! and [`mod@test`] modules for more details. //! //! Currently, the best way to learn about these APIs is to read the Zed source code, ask us about it at a fireside hack, or drop //! a question in the [Zed Discord](https://zed.dev/community-links). 
We're working on improving the documentation, creating more examples, @@ -117,7 +117,7 @@ pub mod private { mod seal { /// A mechanism for restricting implementations of a trait to only those in GPUI. - /// See: https://predr.ag/blog/definitive-guide-to-sealed-traits-in-rust/ + /// See: pub trait Sealed {} } @@ -157,7 +157,7 @@ pub use taffy::{AvailableSpace, LayoutId}; #[cfg(any(test, feature = "test-support"))] pub use test::*; pub use text_system::*; -pub use util::arc_cow::ArcCow; +pub use util::{FutureExt, Timeout, arc_cow::ArcCow}; pub use view::*; pub use window::*; @@ -172,6 +172,10 @@ pub trait AppContext { type Result; /// Create a new entity in the app context. + #[expect( + clippy::wrong_self_convention, + reason = "`App::new` is an ubiquitous function for creating entities" + )] fn new( &mut self, build_entity: impl FnOnce(&mut Context) -> T, @@ -348,7 +352,7 @@ impl Flatten for Result { } /// Information about the GPU GPUI is running on. -#[derive(Default, Debug)] +#[derive(Default, Debug, serde::Serialize, serde::Deserialize, Clone)] pub struct GpuSpecs { /// Whether the GPU is really a fake (like `llvmpipe`) running on the CPU. pub is_software_emulated: bool, diff --git a/crates/gpui/src/input.rs b/crates/gpui/src/input.rs index 4acd7f90c1273a1eb51b1be2ccc672a79e6f7710..dc36ef9e16feedf31c01cd38327fd12729f894b3 100644 --- a/crates/gpui/src/input.rs +++ b/crates/gpui/src/input.rs @@ -72,7 +72,7 @@ pub trait EntityInputHandler: 'static + Sized { ) -> Option; } -/// The canonical implementation of [`PlatformInputHandler`]. Call [`Window::handle_input`] +/// The canonical implementation of [`crate::PlatformInputHandler`]. Call [`Window::handle_input`] /// with an instance during your element's paint. 
pub struct ElementInputHandler { view: Entity, diff --git a/crates/gpui/src/inspector.rs b/crates/gpui/src/inspector.rs index 23c46edcc11ed36cfbe3ad110dc296af3e129784..9f86576a599845bb9e09760e8001333b9dea745d 100644 --- a/crates/gpui/src/inspector.rs +++ b/crates/gpui/src/inspector.rs @@ -164,7 +164,7 @@ mod conditional { if let Some(render_inspector) = cx .inspector_element_registry .renderers_by_type_id - .remove(&type_id) + .remove(type_id) { let mut element = (render_inspector)( active_element.id.clone(), diff --git a/crates/gpui/src/key_dispatch.rs b/crates/gpui/src/key_dispatch.rs index c3f5d186030bf2a38c2345724666ca38003cd484..63cfa680c0d158811a92cddfda390ff82fb7db5c 100644 --- a/crates/gpui/src/key_dispatch.rs +++ b/crates/gpui/src/key_dispatch.rs @@ -408,7 +408,7 @@ impl DispatchTree { keymap .bindings_for_action(action) .filter(|binding| { - Self::binding_matches_predicate_and_not_shadowed(&keymap, &binding, context_stack) + Self::binding_matches_predicate_and_not_shadowed(&keymap, binding, context_stack) }) .cloned() .collect() @@ -426,7 +426,7 @@ impl DispatchTree { .bindings_for_action(action) .rev() .find(|binding| { - Self::binding_matches_predicate_and_not_shadowed(&keymap, &binding, context_stack) + Self::binding_matches_predicate_and_not_shadowed(&keymap, binding, context_stack) }) .cloned() } @@ -458,7 +458,7 @@ impl DispatchTree { .keymap .borrow() .bindings_for_input(input, &context_stack); - return (bindings, partial, context_stack); + (bindings, partial, context_stack) } /// dispatch_key processes the keystroke @@ -552,7 +552,7 @@ impl DispatchTree { let mut current_node_id = Some(target); while let Some(node_id) = current_node_id { dispatch_path.push(node_id); - current_node_id = self.nodes[node_id.0].parent; + current_node_id = self.nodes.get(node_id.0).and_then(|node| node.parent); } dispatch_path.reverse(); // Reverse the path so it goes from the root to the focused node. 
dispatch_path @@ -639,10 +639,7 @@ mod tests { } fn partial_eq(&self, action: &dyn Action) -> bool { - action - .as_any() - .downcast_ref::() - .map_or(false, |a| self == a) + action.as_any().downcast_ref::() == Some(self) } fn boxed_clone(&self) -> std::boxed::Box { diff --git a/crates/gpui/src/keymap.rs b/crates/gpui/src/keymap.rs index 83d7479a04423d249a2be69c69756211eb9d485d..12f082eb60799bdf9a0cdfaf7d546fa2bdf13e04 100644 --- a/crates/gpui/src/keymap.rs +++ b/crates/gpui/src/keymap.rs @@ -4,7 +4,7 @@ mod context; pub use binding::*; pub use context::*; -use crate::{Action, Keystroke, is_no_action}; +use crate::{Action, AsKeystroke, Keystroke, is_no_action}; use collections::{HashMap, HashSet}; use smallvec::SmallVec; use std::any::TypeId; @@ -141,14 +141,14 @@ impl Keymap { /// only. pub fn bindings_for_input( &self, - input: &[Keystroke], + input: &[impl AsKeystroke], context_stack: &[KeyContext], ) -> (SmallVec<[KeyBinding; 1]>, bool) { let mut matched_bindings = SmallVec::<[(usize, BindingIndex, &KeyBinding); 1]>::new(); let mut pending_bindings = SmallVec::<[(BindingIndex, &KeyBinding); 1]>::new(); for (ix, binding) in self.bindings().enumerate().rev() { - let Some(depth) = self.binding_enabled(binding, &context_stack) else { + let Some(depth) = self.binding_enabled(binding, context_stack) else { continue; }; let Some(pending) = binding.match_keystrokes(input) else { @@ -192,7 +192,6 @@ impl Keymap { (bindings, !pending.is_empty()) } - /// Check if the given binding is enabled, given a certain key context. /// Returns the deepest depth at which the binding matches, or None if it doesn't match. 
fn binding_enabled(&self, binding: &KeyBinding, contexts: &[KeyContext]) -> Option { @@ -264,7 +263,7 @@ mod tests { ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); let (result, pending) = keymap.bindings_for_input( &[Keystroke::parse("ctrl-a").unwrap()], @@ -290,7 +289,7 @@ mod tests { ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); // binding is only enabled in a specific context assert!( @@ -344,7 +343,7 @@ mod tests { ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); let space = || Keystroke::parse("space").unwrap(); let w = || Keystroke::parse("w").unwrap(); @@ -364,29 +363,29 @@ mod tests { // Ensure `space` results in pending input on the workspace, but not editor let space_workspace = keymap.bindings_for_input(&[space()], &workspace_context()); assert!(space_workspace.0.is_empty()); - assert_eq!(space_workspace.1, true); + assert!(space_workspace.1); let space_editor = keymap.bindings_for_input(&[space()], &editor_workspace_context()); assert!(space_editor.0.is_empty()); - assert_eq!(space_editor.1, false); + assert!(!space_editor.1); // Ensure `space w` results in pending input on the workspace, but not editor let space_w_workspace = keymap.bindings_for_input(&space_w, &workspace_context()); assert!(space_w_workspace.0.is_empty()); - assert_eq!(space_w_workspace.1, true); + assert!(space_w_workspace.1); let space_w_editor = keymap.bindings_for_input(&space_w, &editor_workspace_context()); assert!(space_w_editor.0.is_empty()); - assert_eq!(space_w_editor.1, false); + assert!(!space_w_editor.1); // Ensure `space w w` results in the binding in the workspace, but not in the editor let space_w_w_workspace = keymap.bindings_for_input(&space_w_w, &workspace_context()); assert!(!space_w_w_workspace.0.is_empty()); - assert_eq!(space_w_w_workspace.1, false); + 
assert!(!space_w_w_workspace.1); let space_w_w_editor = keymap.bindings_for_input(&space_w_w, &editor_workspace_context()); assert!(space_w_w_editor.0.is_empty()); - assert_eq!(space_w_w_editor.1, false); + assert!(!space_w_w_editor.1); // Now test what happens if we have another binding defined AFTER the NoAction // that should result in pending @@ -396,11 +395,11 @@ mod tests { KeyBinding::new("space w x", ActionAlpha {}, Some("editor")), ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); let space_editor = keymap.bindings_for_input(&[space()], &editor_workspace_context()); assert!(space_editor.0.is_empty()); - assert_eq!(space_editor.1, true); + assert!(space_editor.1); // Now test what happens if we have another binding defined BEFORE the NoAction // that should result in pending @@ -410,11 +409,11 @@ mod tests { KeyBinding::new("space w w", NoAction {}, Some("editor")), ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); let space_editor = keymap.bindings_for_input(&[space()], &editor_workspace_context()); assert!(space_editor.0.is_empty()); - assert_eq!(space_editor.1, true); + assert!(space_editor.1); // Now test what happens if we have another binding defined at a higher context // that should result in pending @@ -424,11 +423,11 @@ mod tests { KeyBinding::new("space w w", NoAction {}, Some("editor")), ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); let space_editor = keymap.bindings_for_input(&[space()], &editor_workspace_context()); assert!(space_editor.0.is_empty()); - assert_eq!(space_editor.1, true); + assert!(space_editor.1); } #[test] @@ -439,7 +438,7 @@ mod tests { ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); // Ensure `space` results in pending input on the workspace, but not editor let (result, 
pending) = keymap.bindings_for_input( @@ -447,7 +446,7 @@ mod tests { &[KeyContext::parse("editor").unwrap()], ); assert!(result.is_empty()); - assert_eq!(pending, true); + assert!(pending); let bindings = [ KeyBinding::new("ctrl-w left", ActionAlpha {}, Some("editor")), @@ -455,7 +454,7 @@ mod tests { ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); // Ensure `space` results in pending input on the workspace, but not editor let (result, pending) = keymap.bindings_for_input( @@ -463,7 +462,7 @@ mod tests { &[KeyContext::parse("editor").unwrap()], ); assert_eq!(result.len(), 1); - assert_eq!(pending, false); + assert!(!pending); } #[test] @@ -474,7 +473,7 @@ mod tests { ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); // Ensure `space` results in pending input on the workspace, but not editor let (result, pending) = keymap.bindings_for_input( @@ -482,7 +481,7 @@ mod tests { &[KeyContext::parse("editor").unwrap()], ); assert!(result.is_empty()); - assert_eq!(pending, false); + assert!(!pending); } #[test] @@ -494,7 +493,7 @@ mod tests { ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); // Ensure `space` results in pending input on the workspace, but not editor let (result, pending) = keymap.bindings_for_input( @@ -505,7 +504,7 @@ mod tests { ], ); assert_eq!(result.len(), 1); - assert_eq!(pending, false); + assert!(!pending); } #[test] @@ -516,7 +515,7 @@ mod tests { ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); // Ensure `space` results in pending input on the workspace, but not editor let (result, pending) = keymap.bindings_for_input( @@ -527,7 +526,7 @@ mod tests { ], ); assert_eq!(result.len(), 0); - assert_eq!(pending, false); + assert!(!pending); } #[test] @@ -537,7 +536,7 @@ mod tests { 
KeyBinding::new("ctrl-x 0", ActionAlpha, Some("Workspace")), ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); let matched = keymap.bindings_for_input( &[Keystroke::parse("ctrl-x")].map(Result::unwrap), @@ -560,7 +559,7 @@ mod tests { KeyBinding::new("ctrl-x 0", NoAction, Some("Workspace")), ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); let matched = keymap.bindings_for_input( &[Keystroke::parse("ctrl-x")].map(Result::unwrap), @@ -579,7 +578,7 @@ mod tests { KeyBinding::new("ctrl-x 0", NoAction, Some("vim_mode == normal")), ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); let matched = keymap.bindings_for_input( &[Keystroke::parse("ctrl-x")].map(Result::unwrap), @@ -602,7 +601,7 @@ mod tests { KeyBinding::new("ctrl-x", ActionBeta, Some("vim_mode == normal")), ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); let matched = keymap.bindings_for_input( &[Keystroke::parse("ctrl-x")].map(Result::unwrap), @@ -629,7 +628,7 @@ mod tests { ]; let mut keymap = Keymap::default(); - keymap.add_bindings(bindings.clone()); + keymap.add_bindings(bindings); assert_bindings(&keymap, &ActionAlpha {}, &["ctrl-a"]); assert_bindings(&keymap, &ActionBeta {}, &[]); @@ -639,7 +638,7 @@ mod tests { fn assert_bindings(keymap: &Keymap, action: &dyn Action, expected: &[&str]) { let actual = keymap .bindings_for_action(action) - .map(|binding| binding.keystrokes[0].unparse()) + .map(|binding| binding.keystrokes[0].inner().unparse()) .collect::>(); assert_eq!(actual, expected, "{:?}", action); } diff --git a/crates/gpui/src/keymap/binding.rs b/crates/gpui/src/keymap/binding.rs index 1d3f612c5bef76d75cb1bd8ee9d9c686190c3fd7..fc4b32941b85f4cdea31aaba7198d3e7043ee481 100644 --- a/crates/gpui/src/keymap/binding.rs +++ 
b/crates/gpui/src/keymap/binding.rs @@ -1,14 +1,15 @@ use std::rc::Rc; -use collections::HashMap; - -use crate::{Action, InvalidKeystrokeError, KeyBindingContextPredicate, Keystroke, SharedString}; +use crate::{ + Action, AsKeystroke, DummyKeyboardMapper, InvalidKeystrokeError, KeyBindingContextPredicate, + KeybindingKeystroke, Keystroke, PlatformKeyboardMapper, SharedString, +}; use smallvec::SmallVec; /// A keybinding and its associated metadata, from the keymap. pub struct KeyBinding { pub(crate) action: Box, - pub(crate) keystrokes: SmallVec<[Keystroke; 2]>, + pub(crate) keystrokes: SmallVec<[KeybindingKeystroke; 2]>, pub(crate) context_predicate: Option>, pub(crate) meta: Option, /// The json input string used when building the keybinding, if any @@ -30,12 +31,17 @@ impl Clone for KeyBinding { impl KeyBinding { /// Construct a new keybinding from the given data. Panics on parse error. pub fn new(keystrokes: &str, action: A, context: Option<&str>) -> Self { - let context_predicate = if let Some(context) = context { - Some(KeyBindingContextPredicate::parse(context).unwrap().into()) - } else { - None - }; - Self::load(keystrokes, Box::new(action), context_predicate, None, None).unwrap() + let context_predicate = + context.map(|context| KeyBindingContextPredicate::parse(context).unwrap().into()); + Self::load( + keystrokes, + Box::new(action), + context_predicate, + false, + None, + &DummyKeyboardMapper, + ) + .unwrap() } /// Load a keybinding from the given raw data. 
@@ -43,24 +49,22 @@ impl KeyBinding { keystrokes: &str, action: Box, context_predicate: Option>, - key_equivalents: Option<&HashMap>, + use_key_equivalents: bool, action_input: Option, + keyboard_mapper: &dyn PlatformKeyboardMapper, ) -> std::result::Result { - let mut keystrokes: SmallVec<[Keystroke; 2]> = keystrokes + let keystrokes: SmallVec<[KeybindingKeystroke; 2]> = keystrokes .split_whitespace() - .map(Keystroke::parse) + .map(|source| { + let keystroke = Keystroke::parse(source)?; + Ok(KeybindingKeystroke::new_with_mapper( + keystroke, + use_key_equivalents, + keyboard_mapper, + )) + }) .collect::>()?; - if let Some(equivalents) = key_equivalents { - for keystroke in keystrokes.iter_mut() { - if keystroke.key.chars().count() == 1 { - if let Some(key) = equivalents.get(&keystroke.key.chars().next().unwrap()) { - keystroke.key = key.to_string(); - } - } - } - } - Ok(Self { keystrokes, action, @@ -82,13 +86,13 @@ impl KeyBinding { } /// Check if the given keystrokes match this binding. 
- pub fn match_keystrokes(&self, typed: &[Keystroke]) -> Option { + pub fn match_keystrokes(&self, typed: &[impl AsKeystroke]) -> Option { if self.keystrokes.len() < typed.len() { return None; } for (target, typed) in self.keystrokes.iter().zip(typed.iter()) { - if !typed.should_match(target) { + if !typed.as_keystroke().should_match(target) { return None; } } @@ -97,7 +101,7 @@ impl KeyBinding { } /// Get the keystrokes associated with this binding - pub fn keystrokes(&self) -> &[Keystroke] { + pub fn keystrokes(&self) -> &[KeybindingKeystroke] { self.keystrokes.as_slice() } diff --git a/crates/gpui/src/keymap/context.rs b/crates/gpui/src/keymap/context.rs index 281035fe97614dd810f1057c8094b2c698984166..960bd1752fe8c1527b9c593658e429af4cd61029 100644 --- a/crates/gpui/src/keymap/context.rs +++ b/crates/gpui/src/keymap/context.rs @@ -287,7 +287,7 @@ impl KeyBindingContextPredicate { return false; } } - return true; + true } // Workspace > Pane > Editor // @@ -305,7 +305,7 @@ impl KeyBindingContextPredicate { return true; } } - return false; + false } Self::And(left, right) => { left.eval_inner(contexts, all_contexts) && right.eval_inner(contexts, all_contexts) @@ -668,11 +668,7 @@ mod tests { let contexts = vec![other_context.clone(), child_context.clone()]; assert!(!predicate.eval(&contexts)); - let contexts = vec![ - parent_context.clone(), - other_context.clone(), - child_context.clone(), - ]; + let contexts = vec![parent_context.clone(), other_context, child_context.clone()]; assert!(predicate.eval(&contexts)); assert!(!predicate.eval(&[])); @@ -681,7 +677,7 @@ mod tests { let zany_predicate = KeyBindingContextPredicate::parse("child > child").unwrap(); assert!(!zany_predicate.eval(slice::from_ref(&child_context))); - assert!(zany_predicate.eval(&[child_context.clone(), child_context.clone()])); + assert!(zany_predicate.eval(&[child_context.clone(), child_context])); } #[test] @@ -718,7 +714,7 @@ mod tests { let not_descendant = 
KeyBindingContextPredicate::parse("parent > !child").unwrap(); assert!(!not_descendant.eval(slice::from_ref(&parent_context))); assert!(!not_descendant.eval(slice::from_ref(&child_context))); - assert!(!not_descendant.eval(&[parent_context.clone(), child_context.clone()])); + assert!(!not_descendant.eval(&[parent_context, child_context])); let double_not = KeyBindingContextPredicate::parse("!!editor").unwrap(); assert!(double_not.eval(slice::from_ref(&editor_context))); diff --git a/crates/gpui/src/path_builder.rs b/crates/gpui/src/path_builder.rs index 6c8cfddd523c4d56c81ebcbbf1437b5cc418d73c..40a6e71e0a1738adf1ed261183d2340682826992 100644 --- a/crates/gpui/src/path_builder.rs +++ b/crates/gpui/src/path_builder.rs @@ -278,7 +278,7 @@ impl PathBuilder { options: &StrokeOptions, ) -> Result, Error> { let path = if let Some(dash_array) = dash_array { - let measurements = lyon::algorithms::measure::PathMeasurements::from_path(&path, 0.01); + let measurements = lyon::algorithms::measure::PathMeasurements::from_path(path, 0.01); let mut sampler = measurements .create_sampler(path, lyon::algorithms::measure::SampleType::Normalized); let mut builder = lyon::path::Path::builder(); @@ -318,7 +318,7 @@ impl PathBuilder { Ok(Self::build_path(buf)) } - /// Builds a [`Path`] from a [`lyon::VertexBuffers`]. + /// Builds a [`Path`] from a [`lyon::tessellation::VertexBuffers`]. 
pub fn build_path(buf: VertexBuffers) -> Path { if buf.vertices.is_empty() { return Path::new(Point::default()); diff --git a/crates/gpui/src/platform.rs b/crates/gpui/src/platform.rs index b495d70dfdd3594a27ed3c1793e7e0ac4e7e0b4a..444b60ac154424c423c3cd6a827b22cd7024694f 100644 --- a/crates/gpui/src/platform.rs +++ b/crates/gpui/src/platform.rs @@ -39,8 +39,8 @@ use crate::{ Action, AnyWindowHandle, App, AsyncWindowContext, BackgroundExecutor, Bounds, DEFAULT_WINDOW_SIZE, DevicePixels, DispatchEventResult, Font, FontId, FontMetrics, FontRun, ForegroundExecutor, GlyphId, GpuSpecs, ImageSource, Keymap, LineLayout, Pixels, PlatformInput, - Point, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams, ScaledPixels, Scene, - ShapedGlyph, ShapedRun, SharedString, Size, SvgRenderer, SvgSize, Task, TaskLabel, Window, + Point, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams, Scene, ShapedGlyph, + ShapedRun, SharedString, Size, SvgRenderer, SvgSize, SystemWindowTab, Task, TaskLabel, Window, WindowControlArea, hash, point, px, size, }; use anyhow::Result; @@ -220,14 +220,17 @@ pub(crate) trait Platform: 'static { &self, options: PathPromptOptions, ) -> oneshot::Receiver>>>; - fn prompt_for_new_path(&self, directory: &Path) -> oneshot::Receiver>>; + fn prompt_for_new_path( + &self, + directory: &Path, + suggested_name: Option<&str>, + ) -> oneshot::Receiver>>; fn can_select_mixed_files_and_dirs(&self) -> bool; fn reveal_path(&self, path: &Path); fn open_with_system(&self, path: &Path); fn on_quit(&self, callback: Box); fn on_reopen(&self, callback: Box); - fn on_keyboard_layout_change(&self, callback: Box); fn set_menus(&self, menus: Vec, keymap: &Keymap); fn get_menus(&self) -> Option> { @@ -247,7 +250,6 @@ pub(crate) trait Platform: 'static { fn on_app_menu_action(&self, callback: Box); fn on_will_open_app_menu(&self, callback: Box); fn on_validate_app_menu_command(&self, callback: Box bool>); - fn keyboard_layout(&self) -> Box; fn 
compositor_name(&self) -> &'static str { "" @@ -268,6 +270,10 @@ pub(crate) trait Platform: 'static { fn write_credentials(&self, url: &str, username: &str, password: &[u8]) -> Task>; fn read_credentials(&self, url: &str) -> Task)>>>; fn delete_credentials(&self, url: &str) -> Task>; + + fn keyboard_layout(&self) -> Box; + fn keyboard_mapper(&self) -> Rc; + fn on_keyboard_layout_change(&self, callback: Box); } /// A handle to a platform's display, e.g. a monitor or laptop screen. @@ -496,9 +502,27 @@ pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle { fn sprite_atlas(&self) -> Arc; // macOS specific methods + fn get_title(&self) -> String { + String::new() + } + fn tabbed_windows(&self) -> Option> { + None + } + fn tab_bar_visible(&self) -> bool { + false + } fn set_edited(&mut self, _edited: bool) {} fn show_character_palette(&self) {} fn titlebar_double_click(&self) {} + fn on_move_tab_to_new_window(&self, _callback: Box) {} + fn on_merge_all_windows(&self, _callback: Box) {} + fn on_select_previous_tab(&self, _callback: Box) {} + fn on_select_next_tab(&self, _callback: Box) {} + fn on_toggle_tab_bar(&self, _callback: Box) {} + fn merge_all_windows(&self) {} + fn move_tab_to_new_window(&self) {} + fn toggle_window_tab_overview(&self) {} + fn set_tabbing_identifier(&self, _identifier: Option) {} #[cfg(target_os = "windows")] fn get_raw_handle(&self) -> windows::HWND; @@ -524,7 +548,7 @@ pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle { fn set_client_inset(&self, _inset: Pixels) {} fn gpu_specs(&self) -> Option; - fn update_ime_position(&self, _bounds: Bounds); + fn update_ime_position(&self, _bounds: Bounds); #[cfg(any(test, feature = "test-support"))] fn as_test(&mut self) -> Option<&mut TestWindow> { @@ -588,7 +612,7 @@ impl PlatformTextSystem for NoopTextSystem { } fn font_id(&self, _descriptor: &Font) -> Result { - return Ok(FontId(1)); + Ok(FontId(1)) } fn font_metrics(&self, _font_id: FontId) -> FontMetrics { @@ -669,7 
+693,7 @@ impl PlatformTextSystem for NoopTextSystem { } } let mut runs = Vec::default(); - if glyphs.len() > 0 { + if !glyphs.is_empty() { runs.push(ShapedRun { font_id: FontId(0), glyphs, @@ -1085,6 +1109,12 @@ pub struct WindowOptions { /// Whether the window should be movable by the user pub is_movable: bool, + /// Whether the window should be resizable by the user + pub is_resizable: bool, + + /// Whether the window should be minimized by the user + pub is_minimizable: bool, + /// The display to create the window on, if this is None, /// the window will be created on the main display pub display_id: Option, @@ -1101,6 +1131,9 @@ pub struct WindowOptions { /// Whether to use client or server side decorations. Wayland only /// Note that this may be ignored. pub window_decorations: Option, + + /// Tab group name, allows opening the window as a native tab on macOS 10.12+. Windows with the same tabbing identifier will be grouped together. + pub tabbing_identifier: Option, } /// The variables that can be configured when creating a new window @@ -1127,6 +1160,14 @@ pub(crate) struct WindowParams { #[cfg_attr(any(target_os = "linux", target_os = "freebsd"), allow(dead_code))] pub is_movable: bool, + /// Whether the window should be resizable by the user + #[cfg_attr(any(target_os = "linux", target_os = "freebsd"), allow(dead_code))] + pub is_resizable: bool, + + /// Whether the window should be minimized by the user + #[cfg_attr(any(target_os = "linux", target_os = "freebsd"), allow(dead_code))] + pub is_minimizable: bool, + #[cfg_attr( any(target_os = "linux", target_os = "freebsd", target_os = "windows"), allow(dead_code) @@ -1140,6 +1181,8 @@ pub(crate) struct WindowParams { pub display_id: Option, pub window_min_size: Option>, + #[cfg(target_os = "macos")] + pub tabbing_identifier: Option, } /// Represents the status of how a window should be opened. 
@@ -1185,11 +1228,14 @@ impl Default for WindowOptions { show: true, kind: WindowKind::Normal, is_movable: true, + is_resizable: true, + is_minimizable: true, display_id: None, window_background: WindowBackgroundAppearance::default(), app_id: None, window_min_size: None, window_decorations: None, + tabbing_identifier: None, } } } @@ -1274,7 +1320,7 @@ pub enum WindowBackgroundAppearance { } /// The options that can be configured for a file dialog prompt -#[derive(Copy, Clone, Debug)] +#[derive(Clone, Debug)] pub struct PathPromptOptions { /// Should the prompt allow files to be selected? pub files: bool, @@ -1282,6 +1328,8 @@ pub struct PathPromptOptions { pub directories: bool, /// Should the prompt allow multiple files to be selected? pub multiple: bool, + /// The prompt to show to a user when selecting a path + pub prompt: Option, } /// What kind of prompt styling to show @@ -1502,7 +1550,7 @@ impl ClipboardItem { for entry in self.entries.iter() { if let ClipboardEntry::String(ClipboardString { text, metadata: _ }) = entry { - answer.push_str(&text); + answer.push_str(text); any_entries = true; } } diff --git a/crates/gpui/src/platform/blade/blade_context.rs b/crates/gpui/src/platform/blade/blade_context.rs index 48872f16198a4ed2d1fc8c2a0b1cbce3eb0de477..12c68a1e70188d3ed2ab425b5abc1bac0dfe3a19 100644 --- a/crates/gpui/src/platform/blade/blade_context.rs +++ b/crates/gpui/src/platform/blade/blade_context.rs @@ -49,7 +49,7 @@ fn parse_pci_id(id: &str) -> anyhow::Result { "Expected a 4 digit PCI ID in hexadecimal format" ); - return u32::from_str_radix(id, 16).context("parsing PCI ID as hex"); + u32::from_str_radix(id, 16).context("parsing PCI ID as hex") } #[cfg(test)] diff --git a/crates/gpui/src/platform/blade/blade_renderer.rs b/crates/gpui/src/platform/blade/blade_renderer.rs index 46d3c16c72a9c10c0e686aff425fcc236c253ce7..1f60920bcc928c97c1f2b2c06e22ed235217c87e 100644 --- a/crates/gpui/src/platform/blade/blade_renderer.rs +++ 
b/crates/gpui/src/platform/blade/blade_renderer.rs @@ -371,7 +371,7 @@ impl BladeRenderer { .or_else(|| { [4, 2, 1] .into_iter() - .find(|count| context.gpu.supports_texture_sample_count(*count)) + .find(|&n| (context.gpu.capabilities().sample_count_mask & n) != 0) }) .unwrap_or(1); let pipelines = BladePipelines::new(&context.gpu, surface.info(), path_sample_count); @@ -434,24 +434,24 @@ impl BladeRenderer { } fn wait_for_gpu(&mut self) { - if let Some(last_sp) = self.last_sync_point.take() { - if !self.gpu.wait_for(&last_sp, MAX_FRAME_TIME_MS) { - log::error!("GPU hung"); - #[cfg(target_os = "linux")] - if self.gpu.device_information().driver_name == "radv" { - log::error!( - "there's a known bug with amdgpu/radv, try setting ZED_PATH_SAMPLE_COUNT=0 as a workaround" - ); - log::error!( - "if that helps you're running into https://github.com/zed-industries/zed/issues/26143" - ); - } + if let Some(last_sp) = self.last_sync_point.take() + && !self.gpu.wait_for(&last_sp, MAX_FRAME_TIME_MS) + { + log::error!("GPU hung"); + #[cfg(target_os = "linux")] + if self.gpu.device_information().driver_name == "radv" { log::error!( - "your device information is: {:?}", - self.gpu.device_information() + "there's a known bug with amdgpu/radv, try setting ZED_PATH_SAMPLE_COUNT=0 as a workaround" + ); + log::error!( + "if that helps you're running into https://github.com/zed-industries/zed/issues/26143" ); - while !self.gpu.wait_for(&last_sp, MAX_FRAME_TIME_MS) {} } + log::error!( + "your device information is: {:?}", + self.gpu.device_information() + ); + while !self.gpu.wait_for(&last_sp, MAX_FRAME_TIME_MS) {} } } diff --git a/crates/gpui/src/platform/keyboard.rs b/crates/gpui/src/platform/keyboard.rs index e28d7815200800b7e3950c6819e6ef3fc42f0306..10b8620258ecffd41e8018fc539c47812df0fe05 100644 --- a/crates/gpui/src/platform/keyboard.rs +++ b/crates/gpui/src/platform/keyboard.rs @@ -1,3 +1,7 @@ +use collections::HashMap; + +use crate::{KeybindingKeystroke, Keystroke}; + /// A 
trait for platform-specific keyboard layouts pub trait PlatformKeyboardLayout { /// Get the keyboard layout ID, which should be unique to the layout @@ -5,3 +9,33 @@ pub trait PlatformKeyboardLayout { /// Get the keyboard layout display name fn name(&self) -> &str; } + +/// A trait for platform-specific keyboard mappings +pub trait PlatformKeyboardMapper { + /// Map a key equivalent to its platform-specific representation + fn map_key_equivalent( + &self, + keystroke: Keystroke, + use_key_equivalents: bool, + ) -> KeybindingKeystroke; + /// Get the key equivalents for the current keyboard layout, + /// only used on macOS + fn get_key_equivalents(&self) -> Option<&HashMap>; +} + +/// A dummy implementation of the platform keyboard mapper +pub struct DummyKeyboardMapper; + +impl PlatformKeyboardMapper for DummyKeyboardMapper { + fn map_key_equivalent( + &self, + keystroke: Keystroke, + _use_key_equivalents: bool, + ) -> KeybindingKeystroke { + KeybindingKeystroke::from_keystroke(keystroke) + } + + fn get_key_equivalents(&self) -> Option<&HashMap> { + None + } +} diff --git a/crates/gpui/src/platform/keystroke.rs b/crates/gpui/src/platform/keystroke.rs index 24601eefd6de450622247caaca5ff680c60a3257..4a2bfc785e3eb7e13a845bb67b4524255affb3bb 100644 --- a/crates/gpui/src/platform/keystroke.rs +++ b/crates/gpui/src/platform/keystroke.rs @@ -5,6 +5,14 @@ use std::{ fmt::{Display, Write}, }; +use crate::PlatformKeyboardMapper; + +/// This is a helper trait so that we can simplify the implementation of some functions +pub trait AsKeystroke { + /// Returns the GPUI representation of the keystroke. + fn as_keystroke(&self) -> &Keystroke; +} + /// A keystroke and associated metadata generated by the platform #[derive(Clone, Debug, Eq, PartialEq, Default, Deserialize, Hash)] pub struct Keystroke { @@ -24,6 +32,19 @@ pub struct Keystroke { pub key_char: Option, } +/// Represents a keystroke that can be used in keybindings and displayed to the user. 
+#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct KeybindingKeystroke { + /// The GPUI representation of the keystroke. + inner: Keystroke, + /// The modifiers to display. + #[cfg(target_os = "windows")] + display_modifiers: Modifiers, + /// The key to display. + #[cfg(target_os = "windows")] + display_key: String, +} + /// Error type for `Keystroke::parse`. This is used instead of `anyhow::Error` so that Zed can use /// markdown to display it. #[derive(Debug)] @@ -58,7 +79,7 @@ impl Keystroke { /// /// This method assumes that `self` was typed and `target' is in the keymap, and checks /// both possibilities for self against the target. - pub fn should_match(&self, target: &Keystroke) -> bool { + pub fn should_match(&self, target: &KeybindingKeystroke) -> bool { #[cfg(not(target_os = "windows"))] if let Some(key_char) = self .key_char @@ -71,7 +92,7 @@ impl Keystroke { ..Default::default() }; - if &target.key == key_char && target.modifiers == ime_modifiers { + if &target.inner.key == key_char && target.inner.modifiers == ime_modifiers { return true; } } @@ -83,12 +104,12 @@ impl Keystroke { .filter(|key_char| key_char != &&self.key) { // On Windows, if key_char is set, then the typed keystroke produced the key_char - if &target.key == key_char && target.modifiers == Modifiers::none() { + if &target.inner.key == key_char && target.inner.modifiers == Modifiers::none() { return true; } } - target.modifiers == self.modifiers && target.key == self.key + target.inner.modifiers == self.modifiers && target.inner.key == self.key } /// key syntax is: @@ -200,31 +221,7 @@ impl Keystroke { /// Produces a representation of this key that Parse can understand. 
pub fn unparse(&self) -> String { - let mut str = String::new(); - if self.modifiers.function { - str.push_str("fn-"); - } - if self.modifiers.control { - str.push_str("ctrl-"); - } - if self.modifiers.alt { - str.push_str("alt-"); - } - if self.modifiers.platform { - #[cfg(target_os = "macos")] - str.push_str("cmd-"); - - #[cfg(any(target_os = "linux", target_os = "freebsd"))] - str.push_str("super-"); - - #[cfg(target_os = "windows")] - str.push_str("win-"); - } - if self.modifiers.shift { - str.push_str("shift-"); - } - str.push_str(&self.key); - str + unparse(&self.modifiers, &self.key) } /// Returns true if this keystroke left @@ -266,6 +263,117 @@ impl Keystroke { } } +impl KeybindingKeystroke { + #[cfg(target_os = "windows")] + pub(crate) fn new(inner: Keystroke, display_modifiers: Modifiers, display_key: String) -> Self { + KeybindingKeystroke { + inner, + display_modifiers, + display_key, + } + } + + /// Create a new keybinding keystroke from the given keystroke using the given keyboard mapper. + pub fn new_with_mapper( + inner: Keystroke, + use_key_equivalents: bool, + keyboard_mapper: &dyn PlatformKeyboardMapper, + ) -> Self { + keyboard_mapper.map_key_equivalent(inner, use_key_equivalents) + } + + /// Create a new keybinding keystroke from the given keystroke, without any platform-specific mapping. + pub fn from_keystroke(keystroke: Keystroke) -> Self { + #[cfg(target_os = "windows")] + { + let key = keystroke.key.clone(); + let modifiers = keystroke.modifiers; + KeybindingKeystroke { + inner: keystroke, + display_modifiers: modifiers, + display_key: key, + } + } + #[cfg(not(target_os = "windows"))] + { + KeybindingKeystroke { inner: keystroke } + } + } + + /// Returns the GPUI representation of the keystroke. + pub fn inner(&self) -> &Keystroke { + &self.inner + } + + /// Returns the modifiers. 
+ /// + /// Platform-specific behavior: + /// - On macOS and Linux, this modifiers is the same as `inner.modifiers`, which is the GPUI representation of the keystroke. + /// - On Windows, this modifiers is the display modifiers, for example, a `ctrl-@` keystroke will have `inner.modifiers` as + /// `Modifiers::control()` and `display_modifiers` as `Modifiers::control_shift()`. + pub fn modifiers(&self) -> &Modifiers { + #[cfg(target_os = "windows")] + { + &self.display_modifiers + } + #[cfg(not(target_os = "windows"))] + { + &self.inner.modifiers + } + } + + /// Returns the key. + /// + /// Platform-specific behavior: + /// - On macOS and Linux, this key is the same as `inner.key`, which is the GPUI representation of the keystroke. + /// - On Windows, this key is the display key, for example, a `ctrl-@` keystroke will have `inner.key` as `@` and `display_key` as `2`. + pub fn key(&self) -> &str { + #[cfg(target_os = "windows")] + { + &self.display_key + } + #[cfg(not(target_os = "windows"))] + { + &self.inner.key + } + } + + /// Sets the modifiers. On Windows this modifies both `inner.modifiers` and `display_modifiers`. + pub fn set_modifiers(&mut self, modifiers: Modifiers) { + self.inner.modifiers = modifiers; + #[cfg(target_os = "windows")] + { + self.display_modifiers = modifiers; + } + } + + /// Sets the key. On Windows this modifies both `inner.key` and `display_key`. + pub fn set_key(&mut self, key: String) { + #[cfg(target_os = "windows")] + { + self.display_key = key.clone(); + } + self.inner.key = key; + } + + /// Produces a representation of this key that Parse can understand. 
+ pub fn unparse(&self) -> String { + #[cfg(target_os = "windows")] + { + unparse(&self.display_modifiers, &self.display_key) + } + #[cfg(not(target_os = "windows"))] + { + unparse(&self.inner.modifiers, &self.inner.key) + } + } + + /// Removes the key_char + pub fn remove_key_char(&mut self) { + self.inner.key_char = None; + } +} + fn is_printable_key(key: &str) -> bool { !matches!( key, @@ -322,65 +430,15 @@ fn is_printable_key(key: &str) -> bool { impl std::fmt::Display for Keystroke { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.modifiers.control { - #[cfg(target_os = "macos")] - f.write_char('^')?; - - #[cfg(not(target_os = "macos"))] - write!(f, "ctrl-")?; - } - if self.modifiers.alt { - #[cfg(target_os = "macos")] - f.write_char('⌥')?; - - #[cfg(not(target_os = "macos"))] - write!(f, "alt-")?; - } - if self.modifiers.platform { - #[cfg(target_os = "macos")] - f.write_char('⌘')?; - - #[cfg(any(target_os = "linux", target_os = "freebsd"))] - f.write_char('❖')?; - - #[cfg(target_os = "windows")] - f.write_char('⊞')?; - } - if self.modifiers.shift { - #[cfg(target_os = "macos")] - f.write_char('⇧')?; + display_modifiers(&self.modifiers, f)?; + display_key(&self.key, f) + } +} - #[cfg(not(target_os = "macos"))] - write!(f, "shift-")?; - } - let key = match self.key.as_str() { - #[cfg(target_os = "macos")] - "backspace" => '⌫', - #[cfg(target_os = "macos")] - "up" => '↑', - #[cfg(target_os = "macos")] - "down" => '↓', - #[cfg(target_os = "macos")] - "left" => '←', - #[cfg(target_os = "macos")] - "right" => '→', - #[cfg(target_os = "macos")] - "tab" => '⇥', - #[cfg(target_os = "macos")] - "escape" => '⎋', - #[cfg(target_os = "macos")] - "shift" => '⇧', - #[cfg(target_os = "macos")] - "control" => '⌃', - #[cfg(target_os = "macos")] - "alt" => '⌥', - #[cfg(target_os = "macos")] - "platform" => '⌘', - - key if key.len() == 1 => key.chars().next().unwrap().to_ascii_uppercase(), - key => return f.write_str(key), - }; - f.write_char(key) 
+impl std::fmt::Display for KeybindingKeystroke { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + display_modifiers(self.modifiers(), f)?; + display_key(self.key(), f) } } @@ -600,3 +658,110 @@ pub struct Capslock { #[serde(default)] pub on: bool, } + +impl AsKeystroke for Keystroke { + fn as_keystroke(&self) -> &Keystroke { + self + } +} + +impl AsKeystroke for KeybindingKeystroke { + fn as_keystroke(&self) -> &Keystroke { + &self.inner + } +} + +fn display_modifiers(modifiers: &Modifiers, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if modifiers.control { + #[cfg(target_os = "macos")] + f.write_char('^')?; + + #[cfg(not(target_os = "macos"))] + write!(f, "ctrl-")?; + } + if modifiers.alt { + #[cfg(target_os = "macos")] + f.write_char('⌥')?; + + #[cfg(not(target_os = "macos"))] + write!(f, "alt-")?; + } + if modifiers.platform { + #[cfg(target_os = "macos")] + f.write_char('⌘')?; + + #[cfg(any(target_os = "linux", target_os = "freebsd"))] + f.write_char('❖')?; + + #[cfg(target_os = "windows")] + f.write_char('⊞')?; + } + if modifiers.shift { + #[cfg(target_os = "macos")] + f.write_char('⇧')?; + + #[cfg(not(target_os = "macos"))] + write!(f, "shift-")?; + } + Ok(()) +} + +fn display_key(key: &str, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let key = match key { + #[cfg(target_os = "macos")] + "backspace" => '⌫', + #[cfg(target_os = "macos")] + "up" => '↑', + #[cfg(target_os = "macos")] + "down" => '↓', + #[cfg(target_os = "macos")] + "left" => '←', + #[cfg(target_os = "macos")] + "right" => '→', + #[cfg(target_os = "macos")] + "tab" => '⇥', + #[cfg(target_os = "macos")] + "escape" => '⎋', + #[cfg(target_os = "macos")] + "shift" => '⇧', + #[cfg(target_os = "macos")] + "control" => '⌃', + #[cfg(target_os = "macos")] + "alt" => '⌥', + #[cfg(target_os = "macos")] + "platform" => '⌘', + + key if key.len() == 1 => key.chars().next().unwrap().to_ascii_uppercase(), + key => return f.write_str(key), + }; + f.write_char(key) 
+} + +#[inline] +fn unparse(modifiers: &Modifiers, key: &str) -> String { + let mut result = String::new(); + if modifiers.function { + result.push_str("fn-"); + } + if modifiers.control { + result.push_str("ctrl-"); + } + if modifiers.alt { + result.push_str("alt-"); + } + if modifiers.platform { + #[cfg(target_os = "macos")] + result.push_str("cmd-"); + + #[cfg(any(target_os = "linux", target_os = "freebsd"))] + result.push_str("super-"); + + #[cfg(target_os = "windows")] + result.push_str("win-"); + } + if modifiers.shift { + result.push_str("shift-"); + } + result.push_str(&key); + result +} diff --git a/crates/gpui/src/platform/linux/platform.rs b/crates/gpui/src/platform/linux/platform.rs index fe6a36baa854856eb961a020ab35a7bd0195d465..196e5b65d04125ca90c588212c140d3a63345c2e 100644 --- a/crates/gpui/src/platform/linux/platform.rs +++ b/crates/gpui/src/platform/linux/platform.rs @@ -25,8 +25,8 @@ use xkbcommon::xkb::{self, Keycode, Keysym, State}; use crate::{ Action, AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DisplayId, ForegroundExecutor, Keymap, LinuxDispatcher, Menu, MenuItem, OwnedMenu, PathPromptOptions, - Pixels, Platform, PlatformDisplay, PlatformKeyboardLayout, PlatformTextSystem, PlatformWindow, - Point, Result, Task, WindowAppearance, WindowParams, px, + Pixels, Platform, PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, + PlatformTextSystem, PlatformWindow, Point, Result, Task, WindowAppearance, WindowParams, px, }; #[cfg(any(feature = "wayland", feature = "x11"))] @@ -108,13 +108,13 @@ impl LinuxCommon { let callbacks = PlatformHandlers::default(); - let dispatcher = Arc::new(LinuxDispatcher::new(main_sender.clone())); + let dispatcher = Arc::new(LinuxDispatcher::new(main_sender)); let background_executor = BackgroundExecutor::new(dispatcher.clone()); let common = LinuxCommon { background_executor, - foreground_executor: ForegroundExecutor::new(dispatcher.clone()), + foreground_executor: 
ForegroundExecutor::new(dispatcher), text_system, appearance: WindowAppearance::Light, auto_hide_scrollbars: false, @@ -144,6 +144,10 @@ impl Platform for P { self.keyboard_layout() } + fn keyboard_mapper(&self) -> Rc { + Rc::new(crate::DummyKeyboardMapper) + } + fn on_keyboard_layout_change(&self, callback: Box) { self.with_common(|common| common.callbacks.keyboard_layout_change = Some(callback)); } @@ -294,6 +298,7 @@ impl Platform for P { let request = match ashpd::desktop::file_chooser::OpenFileRequest::default() .modal(true) .title(title) + .accept_label(options.prompt.as_ref().map(crate::SharedString::as_str)) .multiple(options.multiple) .directory(options.directories) .send() @@ -327,26 +332,35 @@ impl Platform for P { done_rx } - fn prompt_for_new_path(&self, directory: &Path) -> oneshot::Receiver>> { + fn prompt_for_new_path( + &self, + directory: &Path, + suggested_name: Option<&str>, + ) -> oneshot::Receiver>> { let (done_tx, done_rx) = oneshot::channel(); #[cfg(not(any(feature = "wayland", feature = "x11")))] - let _ = (done_tx.send(Ok(None)), directory); + let _ = (done_tx.send(Ok(None)), directory, suggested_name); #[cfg(any(feature = "wayland", feature = "x11"))] self.foreground_executor() .spawn({ let directory = directory.to_owned(); + let suggested_name = suggested_name.map(|s| s.to_owned()); async move { - let request = match ashpd::desktop::file_chooser::SaveFileRequest::default() - .modal(true) - .title("Save File") - .current_folder(directory) - .expect("pathbuf should not be nul terminated") - .send() - .await - { + let mut request_builder = + ashpd::desktop::file_chooser::SaveFileRequest::default() + .modal(true) + .title("Save File") + .current_folder(directory) + .expect("pathbuf should not be nul terminated"); + + if let Some(suggested_name) = suggested_name { + request_builder = request_builder.current_name(suggested_name.as_str()); + } + + let request = match request_builder.send().await { Ok(request) => request, Err(err) => { let 
result = match err { @@ -431,7 +445,7 @@ impl Platform for P { fn app_path(&self) -> Result { // get the path of the executable of the current process let app_path = env::current_exe()?; - return Ok(app_path); + Ok(app_path) } fn set_menus(&self, menus: Vec, _keymap: &Keymap) { @@ -632,7 +646,7 @@ pub(super) fn get_xkb_compose_state(cx: &xkb::Context) -> Option = None; for locale in locales { if let Ok(table) = - xkb::compose::Table::new_from_locale(&cx, &locale, xkb::compose::COMPILE_NO_FLAGS) + xkb::compose::Table::new_from_locale(cx, &locale, xkb::compose::COMPILE_NO_FLAGS) { state = Some(xkb::compose::State::new( &table, @@ -657,7 +671,7 @@ pub(super) const DEFAULT_CURSOR_ICON_NAME: &str = "left_ptr"; impl CursorStyle { #[cfg(any(feature = "wayland", feature = "x11"))] - pub(super) fn to_icon_names(&self) -> &'static [&'static str] { + pub(super) fn to_icon_names(self) -> &'static [&'static str] { // Based on cursor names from chromium: // https://github.com/chromium/chromium/blob/d3069cf9c973dc3627fa75f64085c6a86c8f41bf/ui/base/cursor/cursor_factory.cc#L113 match self { @@ -834,6 +848,7 @@ impl crate::Keystroke { Keysym::Down => "down".to_owned(), Keysym::Home => "home".to_owned(), Keysym::End => "end".to_owned(), + Keysym::Insert => "insert".to_owned(), _ => { let name = xkb::keysym_get_name(key_sym).to_lowercase(); @@ -980,21 +995,18 @@ mod tests { #[test] fn test_is_within_click_distance() { let zero = Point::new(px(0.0), px(0.0)); - assert_eq!( - is_within_click_distance(zero, Point::new(px(5.0), px(5.0))), - true - ); - assert_eq!( - is_within_click_distance(zero, Point::new(px(-4.9), px(5.0))), - true - ); - assert_eq!( - is_within_click_distance(Point::new(px(3.0), px(2.0)), Point::new(px(-2.0), px(-2.0))), - true - ); - assert_eq!( - is_within_click_distance(zero, Point::new(px(5.0), px(5.1))), - false - ); + assert!(is_within_click_distance(zero, Point::new(px(5.0), px(5.0)))); + assert!(is_within_click_distance( + zero, + Point::new(px(-4.9), 
px(5.0)) + )); + assert!(is_within_click_distance( + Point::new(px(3.0), px(2.0)), + Point::new(px(-2.0), px(-2.0)) + )); + assert!(!is_within_click_distance( + zero, + Point::new(px(5.0), px(5.1)) + ),); } } diff --git a/crates/gpui/src/platform/linux/wayland.rs b/crates/gpui/src/platform/linux/wayland.rs index cf73832b11fb1baad08bf5ee3142e461876fbe92..487bc9f38c927609100a238ac4726c2aab3b87b0 100644 --- a/crates/gpui/src/platform/linux/wayland.rs +++ b/crates/gpui/src/platform/linux/wayland.rs @@ -12,7 +12,7 @@ use wayland_protocols::wp::cursor_shape::v1::client::wp_cursor_shape_device_v1:: use crate::CursorStyle; impl CursorStyle { - pub(super) fn to_shape(&self) -> Shape { + pub(super) fn to_shape(self) -> Shape { match self { CursorStyle::Arrow => Shape::Default, CursorStyle::IBeam => Shape::Text, diff --git a/crates/gpui/src/platform/linux/wayland/client.rs b/crates/gpui/src/platform/linux/wayland/client.rs index 72e4477ecf697a9f6443dffb80e0637202d3b848..8596bddc8dd821426982d618f661d6da621096bb 100644 --- a/crates/gpui/src/platform/linux/wayland/client.rs +++ b/crates/gpui/src/platform/linux/wayland/client.rs @@ -75,8 +75,8 @@ use crate::{ FileDropEvent, ForegroundExecutor, KeyDownEvent, KeyUpEvent, Keystroke, LinuxCommon, LinuxKeyboardLayout, Modifiers, ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseExitEvent, MouseMoveEvent, MouseUpEvent, NavigationDirection, Pixels, PlatformDisplay, - PlatformInput, PlatformKeyboardLayout, Point, SCROLL_LINES, ScaledPixels, ScrollDelta, - ScrollWheelEvent, Size, TouchPhase, WindowParams, point, px, size, + PlatformInput, PlatformKeyboardLayout, Point, SCROLL_LINES, ScrollDelta, ScrollWheelEvent, + Size, TouchPhase, WindowParams, point, px, size, }; use crate::{ SharedString, @@ -323,7 +323,7 @@ impl WaylandClientStatePtr { } } - pub fn update_ime_position(&self, bounds: Bounds) { + pub fn update_ime_position(&self, bounds: Bounds) { let client = self.get_client(); let mut state = client.borrow_mut(); if 
state.composing || state.text_input.is_none() || state.pre_edit_text.is_some() { @@ -359,13 +359,13 @@ impl WaylandClientStatePtr { } changed }; - if changed { - if let Some(mut callback) = state.common.callbacks.keyboard_layout_change.take() { - drop(state); - callback(); - state = client.borrow_mut(); - state.common.callbacks.keyboard_layout_change = Some(callback); - } + + if changed && let Some(mut callback) = state.common.callbacks.keyboard_layout_change.take() + { + drop(state); + callback(); + state = client.borrow_mut(); + state.common.callbacks.keyboard_layout_change = Some(callback); } } @@ -373,15 +373,15 @@ impl WaylandClientStatePtr { let mut client = self.get_client(); let mut state = client.borrow_mut(); let closed_window = state.windows.remove(surface_id).unwrap(); - if let Some(window) = state.mouse_focused_window.take() { - if !window.ptr_eq(&closed_window) { - state.mouse_focused_window = Some(window); - } + if let Some(window) = state.mouse_focused_window.take() + && !window.ptr_eq(&closed_window) + { + state.mouse_focused_window = Some(window); } - if let Some(window) = state.keyboard_focused_window.take() { - if !window.ptr_eq(&closed_window) { - state.keyboard_focused_window = Some(window); - } + if let Some(window) = state.keyboard_focused_window.take() + && !window.ptr_eq(&closed_window) + { + state.keyboard_focused_window = Some(window); } if state.windows.is_empty() { state.common.signal.stop(); @@ -528,7 +528,7 @@ impl WaylandClient { client.common.appearance = appearance; - for (_, window) in &mut client.windows { + for window in client.windows.values_mut() { window.set_appearance(appearance); } } @@ -710,9 +710,7 @@ impl LinuxClient for WaylandClient { fn set_cursor_style(&self, style: CursorStyle) { let mut state = self.0.borrow_mut(); - let need_update = state - .cursor_style - .map_or(true, |current_style| current_style != style); + let need_update = state.cursor_style != Some(style); if need_update { let serial = 
state.serial_tracker.get(SerialKind::MouseEnter); @@ -951,11 +949,8 @@ impl Dispatch for WaylandClientStatePtr { }; drop(state); - match event { - wl_callback::Event::Done { .. } => { - window.frame(); - } - _ => {} + if let wl_callback::Event::Done { .. } = event { + window.frame(); } } } @@ -1145,7 +1140,7 @@ impl Dispatch for WaylandClientStatePtr { .globals .text_input_manager .as_ref() - .map(|text_input_manager| text_input_manager.get_text_input(&seat, qh, ())); + .map(|text_input_manager| text_input_manager.get_text_input(seat, qh, ())); if let Some(wl_keyboard) = &state.wl_keyboard { wl_keyboard.release(); @@ -1285,7 +1280,6 @@ impl Dispatch for WaylandClientStatePtr { let Some(focused_window) = focused_window else { return; }; - let focused_window = focused_window.clone(); let keymap_state = state.keymap_state.as_ref().unwrap(); let keycode = Keycode::from(key + MIN_KEYCODE); @@ -1294,7 +1288,7 @@ impl Dispatch for WaylandClientStatePtr { match key_state { wl_keyboard::KeyState::Pressed if !keysym.is_modifier_key() => { let mut keystroke = - Keystroke::from_xkb(&keymap_state, state.modifiers, keycode); + Keystroke::from_xkb(keymap_state, state.modifiers, keycode); if let Some(mut compose) = state.compose_state.take() { compose.feed(keysym); match compose.status() { @@ -1538,12 +1532,9 @@ impl Dispatch for WaylandClientStatePtr { cursor_shape_device.set_shape(serial, style.to_shape()); } else { let scale = window.primary_output_scale(); - state.cursor.set_icon( - &wl_pointer, - serial, - style.to_icon_names(), - scale, - ); + state + .cursor + .set_icon(wl_pointer, serial, style.to_icon_names(), scale); } } drop(state); @@ -1580,7 +1571,7 @@ impl Dispatch for WaylandClientStatePtr { if state .keyboard_focused_window .as_ref() - .map_or(false, |keyboard_window| window.ptr_eq(&keyboard_window)) + .is_some_and(|keyboard_window| window.ptr_eq(keyboard_window)) { state.enter_token = None; } @@ -1787,17 +1778,17 @@ impl Dispatch for WaylandClientStatePtr { 
drop(state); window.handle_input(input); } - } else if let Some(discrete) = discrete { - if let Some(window) = state.mouse_focused_window.clone() { - let input = PlatformInput::ScrollWheel(ScrollWheelEvent { - position: state.mouse_location.unwrap(), - delta: ScrollDelta::Lines(discrete), - modifiers: state.modifiers, - touch_phase: TouchPhase::Moved, - }); - drop(state); - window.handle_input(input); - } + } else if let Some(discrete) = discrete + && let Some(window) = state.mouse_focused_window.clone() + { + let input = PlatformInput::ScrollWheel(ScrollWheelEvent { + position: state.mouse_location.unwrap(), + delta: ScrollDelta::Lines(discrete), + modifiers: state.modifiers, + touch_phase: TouchPhase::Moved, + }); + drop(state); + window.handle_input(input); } } } @@ -2019,25 +2010,22 @@ impl Dispatch for WaylandClientStatePtr { let client = this.get_client(); let mut state = client.borrow_mut(); - match event { - wl_data_offer::Event::Offer { mime_type } => { - // Drag and drop - if mime_type == FILE_LIST_MIME_TYPE { - let serial = state.serial_tracker.get(SerialKind::DataDevice); - let mime_type = mime_type.clone(); - data_offer.accept(serial, Some(mime_type)); - } + if let wl_data_offer::Event::Offer { mime_type } = event { + // Drag and drop + if mime_type == FILE_LIST_MIME_TYPE { + let serial = state.serial_tracker.get(SerialKind::DataDevice); + let mime_type = mime_type.clone(); + data_offer.accept(serial, Some(mime_type)); + } - // Clipboard - if let Some(offer) = state - .data_offers - .iter_mut() - .find(|wrapper| wrapper.inner.id() == data_offer.id()) - { - offer.add_mime_type(mime_type); - } + // Clipboard + if let Some(offer) = state + .data_offers + .iter_mut() + .find(|wrapper| wrapper.inner.id() == data_offer.id()) + { + offer.add_mime_type(mime_type); } - _ => {} } } } @@ -2118,13 +2106,10 @@ impl Dispatch let client = this.get_client(); let mut state = client.borrow_mut(); - match event { - zwp_primary_selection_offer_v1::Event::Offer { mime_type 
} => { - if let Some(offer) = state.primary_data_offer.as_mut() { - offer.add_mime_type(mime_type); - } - } - _ => {} + if let zwp_primary_selection_offer_v1::Event::Offer { mime_type } = event + && let Some(offer) = state.primary_data_offer.as_mut() + { + offer.add_mime_type(mime_type); } } } diff --git a/crates/gpui/src/platform/linux/wayland/cursor.rs b/crates/gpui/src/platform/linux/wayland/cursor.rs index 2a24d0e1ba347fb718da126120bc809c65d93b33..c7c9139dea795701e459387a309b1817e2f60971 100644 --- a/crates/gpui/src/platform/linux/wayland/cursor.rs +++ b/crates/gpui/src/platform/linux/wayland/cursor.rs @@ -45,10 +45,11 @@ impl Cursor { } fn set_theme_internal(&mut self, theme_name: Option) { - if let Some(loaded_theme) = self.loaded_theme.as_ref() { - if loaded_theme.name == theme_name && loaded_theme.scaled_size == self.scaled_size { - return; - } + if let Some(loaded_theme) = self.loaded_theme.as_ref() + && loaded_theme.name == theme_name + && loaded_theme.scaled_size == self.scaled_size + { + return; } let result = if let Some(theme_name) = theme_name.as_ref() { CursorTheme::load_from_name( @@ -66,7 +67,7 @@ impl Cursor { { self.loaded_theme = Some(LoadedTheme { theme, - name: theme_name.map(|name| name.to_string()), + name: theme_name, scaled_size: self.scaled_size, }); } @@ -144,7 +145,7 @@ impl Cursor { hot_y as i32 / scale, ); - self.surface.attach(Some(&buffer), 0, 0); + self.surface.attach(Some(buffer), 0, 0); self.surface.damage(0, 0, width as i32, height as i32); self.surface.commit(); } diff --git a/crates/gpui/src/platform/linux/wayland/window.rs b/crates/gpui/src/platform/linux/wayland/window.rs index 2b2207e22c86fc25e6387581bb92b9c304f4bc9d..76dd89c940c615d726af1cf5922be226d91dfd41 100644 --- a/crates/gpui/src/platform/linux/wayland/window.rs +++ b/crates/gpui/src/platform/linux/wayland/window.rs @@ -25,9 +25,8 @@ use crate::scene::Scene; use crate::{ AnyWindowHandle, Bounds, Decorations, Globals, GpuSpecs, Modifiers, Output, Pixels, 
PlatformDisplay, PlatformInput, Point, PromptButton, PromptLevel, RequestFrameOptions, - ResizeEdge, ScaledPixels, Size, Tiling, WaylandClientStatePtr, WindowAppearance, - WindowBackgroundAppearance, WindowBounds, WindowControlArea, WindowControls, WindowDecorations, - WindowParams, px, size, + ResizeEdge, Size, Tiling, WaylandClientStatePtr, WindowAppearance, WindowBackgroundAppearance, + WindowBounds, WindowControlArea, WindowControls, WindowDecorations, WindowParams, px, size, }; use crate::{ Capslock, @@ -355,85 +354,82 @@ impl WaylandWindowStatePtr { } pub fn handle_xdg_surface_event(&self, event: xdg_surface::Event) { - match event { - xdg_surface::Event::Configure { serial } => { - { - let mut state = self.state.borrow_mut(); - if let Some(window_controls) = state.in_progress_window_controls.take() { - state.window_controls = window_controls; - - drop(state); - let mut callbacks = self.callbacks.borrow_mut(); - if let Some(appearance_changed) = callbacks.appearance_changed.as_mut() { - appearance_changed(); - } + if let xdg_surface::Event::Configure { serial } = event { + { + let mut state = self.state.borrow_mut(); + if let Some(window_controls) = state.in_progress_window_controls.take() { + state.window_controls = window_controls; + + drop(state); + let mut callbacks = self.callbacks.borrow_mut(); + if let Some(appearance_changed) = callbacks.appearance_changed.as_mut() { + appearance_changed(); } } - { - let mut state = self.state.borrow_mut(); - - if let Some(mut configure) = state.in_progress_configure.take() { - let got_unmaximized = state.maximized && !configure.maximized; - state.fullscreen = configure.fullscreen; - state.maximized = configure.maximized; - state.tiling = configure.tiling; - // Limit interactive resizes to once per vblank - if configure.resizing && state.resize_throttle { - return; - } else if configure.resizing { - state.resize_throttle = true; - } - if !configure.fullscreen && !configure.maximized { - configure.size = if 
got_unmaximized { - Some(state.window_bounds.size) - } else { - compute_outer_size(state.inset(), configure.size, state.tiling) - }; - if let Some(size) = configure.size { - state.window_bounds = Bounds { - origin: Point::default(), - size, - }; - } - } - drop(state); + } + { + let mut state = self.state.borrow_mut(); + + if let Some(mut configure) = state.in_progress_configure.take() { + let got_unmaximized = state.maximized && !configure.maximized; + state.fullscreen = configure.fullscreen; + state.maximized = configure.maximized; + state.tiling = configure.tiling; + // Limit interactive resizes to once per vblank + if configure.resizing && state.resize_throttle { + return; + } else if configure.resizing { + state.resize_throttle = true; + } + if !configure.fullscreen && !configure.maximized { + configure.size = if got_unmaximized { + Some(state.window_bounds.size) + } else { + compute_outer_size(state.inset(), configure.size, state.tiling) + }; if let Some(size) = configure.size { - self.resize(size); + state.window_bounds = Bounds { + origin: Point::default(), + size, + }; } } - } - let mut state = self.state.borrow_mut(); - state.xdg_surface.ack_configure(serial); - - let window_geometry = inset_by_tiling( - state.bounds.map_origin(|_| px(0.0)), - state.inset(), - state.tiling, - ) - .map(|v| v.0 as i32) - .map_size(|v| if v <= 0 { 1 } else { v }); - - state.xdg_surface.set_window_geometry( - window_geometry.origin.x, - window_geometry.origin.y, - window_geometry.size.width, - window_geometry.size.height, - ); - - let request_frame_callback = !state.acknowledged_first_configure; - if request_frame_callback { - state.acknowledged_first_configure = true; drop(state); - self.frame(); + if let Some(size) = configure.size { + self.resize(size); + } } } - _ => {} + let mut state = self.state.borrow_mut(); + state.xdg_surface.ack_configure(serial); + + let window_geometry = inset_by_tiling( + state.bounds.map_origin(|_| px(0.0)), + state.inset(), + state.tiling, + ) 
+ .map(|v| v.0 as i32) + .map_size(|v| if v <= 0 { 1 } else { v }); + + state.xdg_surface.set_window_geometry( + window_geometry.origin.x, + window_geometry.origin.y, + window_geometry.size.width, + window_geometry.size.height, + ); + + let request_frame_callback = !state.acknowledged_first_configure; + if request_frame_callback { + state.acknowledged_first_configure = true; + drop(state); + self.frame(); + } } } pub fn handle_toplevel_decoration_event(&self, event: zxdg_toplevel_decoration_v1::Event) { - match event { - zxdg_toplevel_decoration_v1::Event::Configure { mode } => match mode { + if let zxdg_toplevel_decoration_v1::Event::Configure { mode } = event { + match mode { WEnum::Value(zxdg_toplevel_decoration_v1::Mode::ServerSide) => { self.state.borrow_mut().decorations = WindowDecorations::Server; if let Some(mut appearance_changed) = @@ -457,17 +453,13 @@ impl WaylandWindowStatePtr { WEnum::Unknown(v) => { log::warn!("Unknown decoration mode: {}", v); } - }, - _ => {} + } } } pub fn handle_fractional_scale_event(&self, event: wp_fractional_scale_v1::Event) { - match event { - wp_fractional_scale_v1::Event::PreferredScale { scale } => { - self.rescale(scale as f32 / 120.0); - } - _ => {} + if let wp_fractional_scale_v1::Event::PreferredScale { scale } = event { + self.rescale(scale as f32 / 120.0); } } @@ -669,8 +661,8 @@ impl WaylandWindowStatePtr { pub fn set_size_and_scale(&self, size: Option>, scale: Option) { let (size, scale) = { let mut state = self.state.borrow_mut(); - if size.map_or(true, |size| size == state.bounds.size) - && scale.map_or(true, |scale| scale == state.scale) + if size.is_none_or(|size| size == state.bounds.size) + && scale.is_none_or(|scale| scale == state.scale) { return; } @@ -713,21 +705,20 @@ impl WaylandWindowStatePtr { } pub fn handle_input(&self, input: PlatformInput) { - if let Some(ref mut fun) = self.callbacks.borrow_mut().input { - if !fun(input.clone()).propagate { - return; - } + if let Some(ref mut fun) = 
self.callbacks.borrow_mut().input + && !fun(input.clone()).propagate + { + return; } - if let PlatformInput::KeyDown(event) = input { - if event.keystroke.modifiers.is_subset_of(&Modifiers::shift()) { - if let Some(key_char) = &event.keystroke.key_char { - let mut state = self.state.borrow_mut(); - if let Some(mut input_handler) = state.input_handler.take() { - drop(state); - input_handler.replace_text_in_range(None, key_char); - self.state.borrow_mut().input_handler = Some(input_handler); - } - } + if let PlatformInput::KeyDown(event) = input + && event.keystroke.modifiers.is_subset_of(&Modifiers::shift()) + && let Some(key_char) = &event.keystroke.key_char + { + let mut state = self.state.borrow_mut(); + if let Some(mut input_handler) = state.input_handler.take() { + drop(state); + input_handler.replace_text_in_range(None, key_char); + self.state.borrow_mut().input_handler = Some(input_handler); } } } @@ -1086,7 +1077,7 @@ impl PlatformWindow for WaylandWindow { } } - fn update_ime_position(&self, bounds: Bounds) { + fn update_ime_position(&self, bounds: Bounds) { let state = self.borrow(); state.client.update_ime_position(bounds); } @@ -1147,7 +1138,7 @@ fn update_window(mut state: RefMut) { } impl WindowDecorations { - fn to_xdg(&self) -> zxdg_toplevel_decoration_v1::Mode { + fn to_xdg(self) -> zxdg_toplevel_decoration_v1::Mode { match self { WindowDecorations::Client => zxdg_toplevel_decoration_v1::Mode::ClientSide, WindowDecorations::Server => zxdg_toplevel_decoration_v1::Mode::ServerSide, @@ -1156,7 +1147,7 @@ impl WindowDecorations { } impl ResizeEdge { - fn to_xdg(&self) -> xdg_toplevel::ResizeEdge { + fn to_xdg(self) -> xdg_toplevel::ResizeEdge { match self { ResizeEdge::Top => xdg_toplevel::ResizeEdge::Top, ResizeEdge::TopRight => xdg_toplevel::ResizeEdge::TopRight, diff --git a/crates/gpui/src/platform/linux/x11/client.rs b/crates/gpui/src/platform/linux/x11/client.rs index 
573e4addf75b90d50e7f453555507462280fb3d4..42c59701d3ee644b99bc8bb58002b429265c1a45 100644 --- a/crates/gpui/src/platform/linux/x11/client.rs +++ b/crates/gpui/src/platform/linux/x11/client.rs @@ -62,8 +62,7 @@ use crate::{ AnyWindowHandle, Bounds, ClipboardItem, CursorStyle, DisplayId, FileDropEvent, Keystroke, LinuxKeyboardLayout, Modifiers, ModifiersChangedEvent, MouseButton, Pixels, Platform, PlatformDisplay, PlatformInput, PlatformKeyboardLayout, Point, RequestFrameOptions, - ScaledPixels, ScrollDelta, Size, TouchPhase, WindowParams, X11Window, - modifiers_from_xinput_info, point, px, + ScrollDelta, Size, TouchPhase, WindowParams, X11Window, modifiers_from_xinput_info, point, px, }; /// Value for DeviceId parameters which selects all devices. @@ -232,15 +231,12 @@ impl X11ClientStatePtr { }; let mut state = client.0.borrow_mut(); - if let Some(window_ref) = state.windows.remove(&x_window) { - match window_ref.refresh_state { - Some(RefreshState::PeriodicRefresh { - event_loop_token, .. - }) => { - state.loop_handle.remove(event_loop_token); - } - _ => {} - } + if let Some(window_ref) = state.windows.remove(&x_window) + && let Some(RefreshState::PeriodicRefresh { + event_loop_token, .. 
+ }) = window_ref.refresh_state + { + state.loop_handle.remove(event_loop_token); } if state.mouse_focused_window == Some(x_window) { state.mouse_focused_window = None; @@ -255,7 +251,7 @@ impl X11ClientStatePtr { } } - pub fn update_ime_position(&self, bounds: Bounds) { + pub fn update_ime_position(&self, bounds: Bounds) { let Some(client) = self.get_client() else { return; }; @@ -273,6 +269,7 @@ impl X11ClientStatePtr { state.ximc = Some(ximc); return; }; + let scaled_bounds = bounds.scale(state.scale_factor); let ic_attributes = ximc .build_ic_attributes() .push( @@ -285,8 +282,8 @@ impl X11ClientStatePtr { b.push( xim::AttributeName::SpotLocation, xim::Point { - x: u32::from(bounds.origin.x + bounds.size.width) as i16, - y: u32::from(bounds.origin.y + bounds.size.height) as i16, + x: u32::from(scaled_bounds.origin.x + scaled_bounds.size.width) as i16, + y: u32::from(scaled_bounds.origin.y + scaled_bounds.size.height) as i16, }, ); }) @@ -459,7 +456,7 @@ impl X11Client { move |event, _, client| match event { XDPEvent::WindowAppearance(appearance) => { client.with_common(|common| common.appearance = appearance); - for (_, window) in &mut client.0.borrow_mut().windows { + for window in client.0.borrow_mut().windows.values_mut() { window.window.set_appearance(appearance); } } @@ -565,10 +562,10 @@ impl X11Client { events.push(last_keymap_change_event); } - if let Some(last_press) = last_key_press.as_ref() { - if last_press.detail == key_press.detail { - continue; - } + if let Some(last_press) = last_key_press.as_ref() + && last_press.detail == key_press.detail + { + continue; } if let Some(Event::KeyRelease(key_release)) = @@ -642,13 +639,7 @@ impl X11Client { let xim_connected = xim_handler.connected; drop(state); - let xim_filtered = match ximc.filter_event(&event, &mut xim_handler) { - Ok(handled) => handled, - Err(err) => { - log::error!("XIMClientError: {}", err); - false - } - }; + let xim_filtered = ximc.filter_event(&event, &mut xim_handler); let 
xim_callback_event = xim_handler.last_callback_event.take(); let mut state = self.0.borrow_mut(); @@ -659,14 +650,28 @@ impl X11Client { self.handle_xim_callback_event(event); } - if xim_filtered { - continue; - } - - if xim_connected { - self.xim_handle_event(event); - } else { - self.handle_event(event); + match xim_filtered { + Ok(handled) => { + if handled { + continue; + } + if xim_connected { + self.xim_handle_event(event); + } else { + self.handle_event(event); + } + } + Err(err) => { + // this might happen when xim server crashes on one of the events + // we do lose 1-2 keys when crash happens since there is no reliable way to get that info + // luckily, x11 sends us window not found error when xim server crashes upon further key press + // hence we fall back to handle_event + log::error!("XIMClientError: {}", err); + let mut state = self.0.borrow_mut(); + state.take_xim(); + drop(state); + self.handle_event(event); + } } } } @@ -698,14 +703,14 @@ impl X11Client { state.xim_handler = Some(xim_handler); return; }; - if let Some(area) = window.get_ime_area() { + if let Some(scaled_area) = window.get_ime_area() { ic_attributes = ic_attributes.nested_list(xim::AttributeName::PreeditAttributes, |b| { b.push( xim::AttributeName::SpotLocation, xim::Point { - x: u32::from(area.origin.x + area.size.width) as i16, - y: u32::from(area.origin.y + area.size.height) as i16, + x: u32::from(scaled_area.origin.x + scaled_area.size.width) as i16, + y: u32::from(scaled_area.origin.y + scaled_area.size.height) as i16, }, ); }); @@ -868,22 +873,19 @@ impl X11Client { let Some(reply) = reply else { return Some(()); }; - match str::from_utf8(&reply.value) { - Ok(file_list) => { - let paths: SmallVec<[_; 2]> = file_list - .lines() - .filter_map(|path| Url::parse(path).log_err()) - .filter_map(|url| url.to_file_path().log_err()) - .collect(); - let input = PlatformInput::FileDrop(FileDropEvent::Entered { - position: state.xdnd_state.position, - paths: crate::ExternalPaths(paths), - 
}); - drop(state); - window.handle_input(input); - self.0.borrow_mut().xdnd_state.retrieved = true; - } - Err(_) => {} + if let Ok(file_list) = str::from_utf8(&reply.value) { + let paths: SmallVec<[_; 2]> = file_list + .lines() + .filter_map(|path| Url::parse(path).log_err()) + .filter_map(|url| url.to_file_path().log_err()) + .collect(); + let input = PlatformInput::FileDrop(FileDropEvent::Entered { + position: state.xdnd_state.position, + paths: crate::ExternalPaths(paths), + }); + drop(state); + window.handle_input(input); + self.0.borrow_mut().xdnd_state.retrieved = true; } } Event::ConfigureNotify(event) => { @@ -1204,7 +1206,7 @@ impl X11Client { state = self.0.borrow_mut(); if let Some(mut pointer) = state.pointer_device_states.get_mut(&event.sourceid) { - let scroll_delta = get_scroll_delta_and_update_state(&mut pointer, &event); + let scroll_delta = get_scroll_delta_and_update_state(pointer, &event); drop(state); if let Some(scroll_delta) = scroll_delta { window.handle_input(PlatformInput::ScrollWheel(make_scroll_wheel_event( @@ -1263,7 +1265,7 @@ impl X11Client { Event::XinputDeviceChanged(event) => { let mut state = self.0.borrow_mut(); if let Some(mut pointer) = state.pointer_device_states.get_mut(&event.sourceid) { - reset_pointer_device_scroll_positions(&mut pointer); + reset_pointer_device_scroll_positions(pointer); } } _ => {} @@ -1327,7 +1329,7 @@ impl X11Client { state.composing = false; drop(state); if let Some(mut keystroke) = keystroke { - keystroke.key_char = Some(text.clone()); + keystroke.key_char = Some(text); window.handle_input(PlatformInput::KeyDown(crate::KeyDownEvent { keystroke, is_held: false, @@ -1349,7 +1351,7 @@ impl X11Client { drop(state); window.handle_ime_preedit(text); - if let Some(area) = window.get_ime_area() { + if let Some(scaled_area) = window.get_ime_area() { let ic_attributes = ximc .build_ic_attributes() .push( @@ -1362,8 +1364,8 @@ impl X11Client { b.push( xim::AttributeName::SpotLocation, xim::Point { - x: 
u32::from(area.origin.x + area.size.width) as i16, - y: u32::from(area.origin.y + area.size.height) as i16, + x: u32::from(scaled_area.origin.x + scaled_area.size.width) as i16, + y: u32::from(scaled_area.origin.y + scaled_area.size.height) as i16, }, ); }) @@ -1578,11 +1580,11 @@ impl LinuxClient for X11Client { fn read_from_primary(&self) -> Option { let state = self.0.borrow_mut(); - return state + state .clipboard .get_any(clipboard::ClipboardKind::Primary) .context("X11: Failed to read from clipboard (primary)") - .log_with_level(log::Level::Debug); + .log_with_level(log::Level::Debug) } fn read_from_clipboard(&self) -> Option { @@ -1595,11 +1597,11 @@ impl LinuxClient for X11Client { { return state.clipboard_item.clone(); } - return state + state .clipboard .get_any(clipboard::ClipboardKind::Clipboard) .context("X11: Failed to read from clipboard (clipboard)") - .log_with_level(log::Level::Debug); + .log_with_level(log::Level::Debug) } fn run(&self) { @@ -2002,12 +2004,12 @@ fn check_gtk_frame_extents_supported( } fn xdnd_is_atom_supported(atom: u32, atoms: &XcbAtoms) -> bool { - return atom == atoms.TEXT + atom == atoms.TEXT || atom == atoms.STRING || atom == atoms.UTF8_STRING || atom == atoms.TEXT_PLAIN || atom == atoms.TEXT_PLAIN_UTF8 - || atom == atoms.TextUriList; + || atom == atoms.TextUriList } fn xdnd_get_supported_atom( @@ -2027,16 +2029,15 @@ fn xdnd_get_supported_atom( ), ) .log_with_level(Level::Warn) + && let Some(atoms) = reply.value32() { - if let Some(atoms) = reply.value32() { - for atom in atoms { - if xdnd_is_atom_supported(atom, &supported_atoms) { - return atom; - } + for atom in atoms { + if xdnd_is_atom_supported(atom, supported_atoms) { + return atom; } } } - return 0; + 0 } fn xdnd_send_finished( @@ -2107,7 +2108,7 @@ fn current_pointer_device_states( .classes .iter() .filter_map(|class| class.data.as_scroll()) - .map(|class| *class) + .copied() .rev() .collect::>(); let old_state = scroll_values_to_preserve.get(&info.deviceid); @@ 
-2137,7 +2138,7 @@ fn current_pointer_device_states( if pointer_device_states.is_empty() { log::error!("Found no xinput mouse pointers."); } - return Some(pointer_device_states); + Some(pointer_device_states) } /// Returns true if the device is a pointer device. Does not include pointer device groups. @@ -2403,11 +2404,13 @@ fn legacy_get_randr_scale_factor(connection: &XCBConnection, root: u32) -> Optio let mut crtc_infos: HashMap = HashMap::default(); let mut valid_outputs: HashSet = HashSet::new(); for (crtc, cookie) in crtc_cookies { - if let Ok(reply) = cookie.reply() { - if reply.width > 0 && reply.height > 0 && !reply.outputs.is_empty() { - crtc_infos.insert(crtc, reply.clone()); - valid_outputs.extend(&reply.outputs); - } + if let Ok(reply) = cookie.reply() + && reply.width > 0 + && reply.height > 0 + && !reply.outputs.is_empty() + { + crtc_infos.insert(crtc, reply.clone()); + valid_outputs.extend(&reply.outputs); } } diff --git a/crates/gpui/src/platform/linux/x11/clipboard.rs b/crates/gpui/src/platform/linux/x11/clipboard.rs index 5d42eadaaf04e0ad7811b980e6d31b4bca935139..a6f96d38c4254da5a2f92261700126962c16e91c 100644 --- a/crates/gpui/src/platform/linux/x11/clipboard.rs +++ b/crates/gpui/src/platform/linux/x11/clipboard.rs @@ -1078,11 +1078,11 @@ impl Clipboard { } else { String::from_utf8(result.bytes).map_err(|_| Error::ConversionFailure)? }; - return Ok(ClipboardItem::new_string(text)); + Ok(ClipboardItem::new_string(text)) } pub fn is_owner(&self, selection: ClipboardKind) -> bool { - return self.inner.is_owner(selection).unwrap_or(false); + self.inner.is_owner(selection).unwrap_or(false) } } @@ -1120,25 +1120,25 @@ impl Drop for Clipboard { log::error!("Failed to flush the clipboard window. 
Error: {}", e); return; } - if let Some(global_cb) = global_cb { - if let Err(e) = global_cb.server_handle.join() { - // Let's try extracting the error message - let message; - if let Some(msg) = e.downcast_ref::<&'static str>() { - message = Some((*msg).to_string()); - } else if let Some(msg) = e.downcast_ref::() { - message = Some(msg.clone()); - } else { - message = None; - } - if let Some(message) = message { - log::error!( - "The clipboard server thread panicked. Panic message: '{}'", - message, - ); - } else { - log::error!("The clipboard server thread panicked."); - } + if let Some(global_cb) = global_cb + && let Err(e) = global_cb.server_handle.join() + { + // Let's try extracting the error message + let message; + if let Some(msg) = e.downcast_ref::<&'static str>() { + message = Some((*msg).to_string()); + } else if let Some(msg) = e.downcast_ref::() { + message = Some(msg.clone()); + } else { + message = None; + } + if let Some(message) = message { + log::error!( + "The clipboard server thread panicked. Panic message: '{}'", + message, + ); + } else { + log::error!("The clipboard server thread panicked."); } } } diff --git a/crates/gpui/src/platform/linux/x11/event.rs b/crates/gpui/src/platform/linux/x11/event.rs index cd4cef24a33f33aaa2f2e685089eb1a2368719e2..17bcc908d3a6bdd48f16a8f5db69f08290b9444f 100644 --- a/crates/gpui/src/platform/linux/x11/event.rs +++ b/crates/gpui/src/platform/linux/x11/event.rs @@ -73,8 +73,8 @@ pub(crate) fn get_valuator_axis_index( // valuator present in this event's axisvalues. Axisvalues is ordered from // lowest valuator number to highest, so counting bits before the 1 bit for // this valuator yields the index in axisvalues. 
- if bit_is_set_in_vec(&valuator_mask, valuator_number) { - Some(popcount_upto_bit_index(&valuator_mask, valuator_number) as usize) + if bit_is_set_in_vec(valuator_mask, valuator_number) { + Some(popcount_upto_bit_index(valuator_mask, valuator_number) as usize) } else { None } @@ -104,7 +104,7 @@ fn bit_is_set_in_vec(bit_vec: &Vec, bit_index: u16) -> bool { let array_index = bit_index as usize / 32; bit_vec .get(array_index) - .map_or(false, |bits| bit_is_set(*bits, bit_index % 32)) + .is_some_and(|bits| bit_is_set(*bits, bit_index % 32)) } fn bit_is_set(bits: u32, bit_index: u16) -> bool { diff --git a/crates/gpui/src/platform/linux/x11/window.rs b/crates/gpui/src/platform/linux/x11/window.rs index 1a3c323c35129b9ea56595b7f81775de4b036454..79a43837252f7dc702b43176d2f06172a3acec18 100644 --- a/crates/gpui/src/platform/linux/x11/window.rs +++ b/crates/gpui/src/platform/linux/x11/window.rs @@ -95,7 +95,7 @@ fn query_render_extent( } impl ResizeEdge { - fn to_moveresize(&self) -> u32 { + fn to_moveresize(self) -> u32 { match self { ResizeEdge::TopLeft => 0, ResizeEdge::Top => 1, @@ -397,7 +397,7 @@ impl X11WindowState { .display_id .map_or(x_main_screen_index, |did| did.0 as usize); - let visual_set = find_visuals(&xcb, x_screen_index); + let visual_set = find_visuals(xcb, x_screen_index); let visual = match visual_set.transparent { Some(visual) => visual, @@ -515,19 +515,19 @@ impl X11WindowState { xcb.configure_window(x_window, &xproto::ConfigureWindowAux::new().x(x).y(y)), )?; } - if let Some(titlebar) = params.titlebar { - if let Some(title) = titlebar.title { - check_reply( - || "X11 ChangeProperty8 on window title failed.", - xcb.change_property8( - xproto::PropMode::REPLACE, - x_window, - xproto::AtomEnum::WM_NAME, - xproto::AtomEnum::STRING, - title.as_bytes(), - ), - )?; - } + if let Some(titlebar) = params.titlebar + && let Some(title) = titlebar.title + { + check_reply( + || "X11 ChangeProperty8 on window title failed.", + xcb.change_property8( + 
xproto::PropMode::REPLACE, + x_window, + xproto::AtomEnum::WM_NAME, + xproto::AtomEnum::STRING, + title.as_bytes(), + ), + )?; } if params.kind == WindowKind::PopUp { check_reply( @@ -604,7 +604,7 @@ impl X11WindowState { ), )?; - xcb_flush(&xcb); + xcb_flush(xcb); let renderer = { let raw_window = RawWindow { @@ -664,7 +664,7 @@ impl X11WindowState { || "X11 DestroyWindow failed while cleaning it up after setup failure.", xcb.destroy_window(x_window), )?; - xcb_flush(&xcb); + xcb_flush(xcb); } setup_result @@ -956,10 +956,10 @@ impl X11WindowStatePtr { } pub fn handle_input(&self, input: PlatformInput) { - if let Some(ref mut fun) = self.callbacks.borrow_mut().input { - if !fun(input.clone()).propagate { - return; - } + if let Some(ref mut fun) = self.callbacks.borrow_mut().input + && !fun(input.clone()).propagate + { + return; } if let PlatformInput::KeyDown(event) = input { // only allow shift modifier when inserting text @@ -1019,8 +1019,9 @@ impl X11WindowStatePtr { } } - pub fn get_ime_area(&self) -> Option> { + pub fn get_ime_area(&self) -> Option> { let mut state = self.state.borrow_mut(); + let scale_factor = state.scale_factor; let mut bounds: Option> = None; if let Some(mut input_handler) = state.input_handler.take() { drop(state); @@ -1030,7 +1031,7 @@ impl X11WindowStatePtr { let mut state = self.state.borrow_mut(); state.input_handler = Some(input_handler); }; - bounds + bounds.map(|b| b.scale(scale_factor)) } pub fn set_bounds(&self, bounds: Bounds) -> anyhow::Result<()> { @@ -1068,15 +1069,14 @@ impl X11WindowStatePtr { } let mut callbacks = self.callbacks.borrow_mut(); - if let Some((content_size, scale_factor)) = resize_args { - if let Some(ref mut fun) = callbacks.resize { - fun(content_size, scale_factor) - } + if let Some((content_size, scale_factor)) = resize_args + && let Some(ref mut fun) = callbacks.resize + { + fun(content_size, scale_factor) } - if !is_resize { - if let Some(ref mut fun) = callbacks.moved { - fun(); - } + + if !is_resize 
&& let Some(ref mut fun) = callbacks.moved { + fun(); } Ok(()) @@ -1619,7 +1619,7 @@ impl PlatformWindow for X11Window { } } - fn update_ime_position(&self, bounds: Bounds) { + fn update_ime_position(&self, bounds: Bounds) { let mut state = self.0.state.borrow_mut(); let client = state.client.clone(); drop(state); diff --git a/crates/gpui/src/platform/mac/events.rs b/crates/gpui/src/platform/mac/events.rs index 0dc361b9dcfdb0980561037484cf51b84dc251e8..938db4b76205ee43eb979995c240b8d96e2aa57a 100644 --- a/crates/gpui/src/platform/mac/events.rs +++ b/crates/gpui/src/platform/mac/events.rs @@ -311,9 +311,8 @@ unsafe fn parse_keystroke(native_event: id) -> Keystroke { let mut shift = modifiers.contains(NSEventModifierFlags::NSShiftKeyMask); let command = modifiers.contains(NSEventModifierFlags::NSCommandKeyMask); let function = modifiers.contains(NSEventModifierFlags::NSFunctionKeyMask) - && first_char.map_or(true, |ch| { - !(NSUpArrowFunctionKey..=NSModeSwitchFunctionKey).contains(&ch) - }); + && first_char + .is_none_or(|ch| !(NSUpArrowFunctionKey..=NSModeSwitchFunctionKey).contains(&ch)); #[allow(non_upper_case_globals)] let key = match first_char { @@ -427,7 +426,7 @@ unsafe fn parse_keystroke(native_event: id) -> Keystroke { key_char = Some(chars_for_modified_key(native_event.keyCode(), mods)); } - let mut key = if shift + if shift && chars_ignoring_modifiers .chars() .all(|c| c.is_ascii_lowercase()) @@ -438,9 +437,7 @@ unsafe fn parse_keystroke(native_event: id) -> Keystroke { chars_with_shift } else { chars_ignoring_modifiers - }; - - key + } } }; diff --git a/crates/gpui/src/platform/mac/keyboard.rs b/crates/gpui/src/platform/mac/keyboard.rs index a9f6af3edb584157b72b0df25f6389472410883b..14097312468cbb732b46f004dbb0970c26f6e821 100644 --- a/crates/gpui/src/platform/mac/keyboard.rs +++ b/crates/gpui/src/platform/mac/keyboard.rs @@ -1,8 +1,9 @@ +use collections::HashMap; use std::ffi::{CStr, c_void}; use objc::{msg_send, runtime::Object, sel, sel_impl}; -use 
crate::PlatformKeyboardLayout; +use crate::{KeybindingKeystroke, Keystroke, PlatformKeyboardLayout, PlatformKeyboardMapper}; use super::{ TISCopyCurrentKeyboardLayoutInputSource, TISGetInputSourceProperty, kTISPropertyInputSourceID, @@ -14,6 +15,10 @@ pub(crate) struct MacKeyboardLayout { name: String, } +pub(crate) struct MacKeyboardMapper { + key_equivalents: Option>, +} + impl PlatformKeyboardLayout for MacKeyboardLayout { fn id(&self) -> &str { &self.id @@ -24,6 +29,27 @@ impl PlatformKeyboardLayout for MacKeyboardLayout { } } +impl PlatformKeyboardMapper for MacKeyboardMapper { + fn map_key_equivalent( + &self, + mut keystroke: Keystroke, + use_key_equivalents: bool, + ) -> KeybindingKeystroke { + if use_key_equivalents && let Some(key_equivalents) = &self.key_equivalents { + if keystroke.key.chars().count() == 1 + && let Some(key) = key_equivalents.get(&keystroke.key.chars().next().unwrap()) + { + keystroke.key = key.to_string(); + } + } + KeybindingKeystroke::from_keystroke(keystroke) + } + + fn get_key_equivalents(&self) -> Option<&HashMap> { + self.key_equivalents.as_ref() + } +} + impl MacKeyboardLayout { pub(crate) fn new() -> Self { unsafe { @@ -47,3 +73,1428 @@ impl MacKeyboardLayout { } } } + +impl MacKeyboardMapper { + pub(crate) fn new(layout_id: &str) -> Self { + let key_equivalents = get_key_equivalents(layout_id); + + Self { key_equivalents } + } +} + +// On some keyboards (e.g. German QWERTZ) it is not possible to type the full ASCII range +// without using option. This means that some of our built in keyboard shortcuts do not work +// for those users. +// +// The way macOS solves this problem is to move shortcuts around so that they are all reachable, +// even if the mnemonic changes. https://developer.apple.com/documentation/swiftui/keyboardshortcut/localization-swift.struct +// +// For example, cmd-> is the "switch window" shortcut because the > key is right above tab. 
+// To ensure this doesn't cause problems for shortcuts defined for a QWERTY layout, Apple moves +// any shortcuts defined as cmd-> to cmd-:. Coincidentally this is also the same keyboard position +// as cmd-> on a QWERTY layout. +// +// Another example is cmd-[ and cmd-], as they cannot be typed without option, those keys are remapped to cmd-ö +// and cmd-ä. These shortcuts are not in the same position as a QWERTY keyboard, because on a QWERTZ keyboard +// the + key is in the way; and shortcuts bound to cmd-+ are still typed as cmd-+ on either keyboard (though the +// specific key moves) +// +// As far as I can tell, there's no way to query the mappings Apple uses except by rendering a menu with every +// possible key combination, and inspecting the UI to see what it rendered. So that's what we did... +// +// These mappings were generated by running https://github.com/ConradIrwin/keyboard-inspector, tidying up the +// output to remove languages with no mappings and other oddities, and converting it to a less verbose representation with: +// jq -s 'map(to_entries | map({key: .key, value: [(.value | to_entries | map(.key) | join("")), (.value | to_entries | map(.value) | join(""))]}) | from_entries) | add' +// From there I used multi-cursor to produce this match statement. 
+fn get_key_equivalents(layout_id: &str) -> Option> { + let mappings: &[(char, char)] = match layout_id { + "com.apple.keylayout.ABC-AZERTY" => &[ + ('!', '1'), + ('"', '%'), + ('#', '3'), + ('$', '4'), + ('%', '5'), + ('&', '7'), + ('(', '9'), + (')', '0'), + ('*', '8'), + ('.', ';'), + ('/', ':'), + ('0', 'à'), + ('1', '&'), + ('2', 'é'), + ('3', '"'), + ('4', '\''), + ('5', '('), + ('6', '§'), + ('7', 'è'), + ('8', '!'), + ('9', 'ç'), + (':', '°'), + (';', ')'), + ('<', '.'), + ('>', '/'), + ('@', '2'), + ('[', '^'), + ('\'', 'ù'), + ('\\', '`'), + (']', '$'), + ('^', '6'), + ('`', '<'), + ('{', '¨'), + ('|', '£'), + ('}', '*'), + ('~', '>'), + ], + "com.apple.keylayout.ABC-QWERTZ" => &[ + ('"', '`'), + ('#', '§'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', 'ß'), + (':', 'Ü'), + (';', 'ü'), + ('<', ';'), + ('=', '*'), + ('>', ':'), + ('@', '"'), + ('[', 'ö'), + ('\'', '´'), + ('\\', '#'), + (']', 'ä'), + ('^', '&'), + ('`', '<'), + ('{', 'Ö'), + ('|', '\''), + ('}', 'Ä'), + ('~', '>'), + ], + "com.apple.keylayout.Albanian" => &[ + ('"', '\''), + (':', 'Ç'), + (';', 'ç'), + ('<', ';'), + ('>', ':'), + ('@', '"'), + ('\'', '@'), + ('\\', 'ë'), + ('`', '<'), + ('|', 'Ë'), + ('~', '>'), + ], + "com.apple.keylayout.Austrian" => &[ + ('"', '`'), + ('#', '§'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', 'ß'), + (':', 'Ü'), + (';', 'ü'), + ('<', ';'), + ('=', '*'), + ('>', ':'), + ('@', '"'), + ('[', 'ö'), + ('\'', '´'), + ('\\', '#'), + (']', 'ä'), + ('^', '&'), + ('`', '<'), + ('{', 'Ö'), + ('|', '\''), + ('}', 'Ä'), + ('~', '>'), + ], + "com.apple.keylayout.Azeri" => &[ + ('"', 'Ə'), + (',', 'ç'), + ('.', 'ş'), + ('/', '.'), + (':', 'I'), + (';', 'ı'), + ('<', 'Ç'), + ('>', 'Ş'), + ('?', ','), + ('W', 'Ü'), + ('[', 'ö'), + ('\'', 'ə'), + (']', 'ğ'), + ('w', 'ü'), + ('{', 'Ö'), + ('|', '/'), + ('}', 'Ğ'), + ], + "com.apple.keylayout.Belgian" => &[ + ('!', '1'), + ('"', '%'), + ('#', '3'), + ('$', '4'), + ('%', '5'), + ('&', 
'7'), + ('(', '9'), + (')', '0'), + ('*', '8'), + ('.', ';'), + ('/', ':'), + ('0', 'à'), + ('1', '&'), + ('2', 'é'), + ('3', '"'), + ('4', '\''), + ('5', '('), + ('6', '§'), + ('7', 'è'), + ('8', '!'), + ('9', 'ç'), + (':', '°'), + (';', ')'), + ('<', '.'), + ('>', '/'), + ('@', '2'), + ('[', '^'), + ('\'', 'ù'), + ('\\', '`'), + (']', '$'), + ('^', '6'), + ('`', '<'), + ('{', '¨'), + ('|', '£'), + ('}', '*'), + ('~', '>'), + ], + "com.apple.keylayout.Brazilian-ABNT2" => &[ + ('"', '`'), + ('/', 'ç'), + ('?', 'Ç'), + ('\'', '´'), + ('\\', '~'), + ('^', '¨'), + ('`', '\''), + ('|', '^'), + ('~', '"'), + ], + "com.apple.keylayout.Brazilian-Pro" => &[('^', 'ˆ'), ('~', '˜')], + "com.apple.keylayout.British" => &[('#', '£')], + "com.apple.keylayout.Canadian-CSA" => &[ + ('"', 'È'), + ('/', 'é'), + ('<', '\''), + ('>', '"'), + ('?', 'É'), + ('[', '^'), + ('\'', 'è'), + ('\\', 'à'), + (']', 'ç'), + ('`', 'ù'), + ('{', '¨'), + ('|', 'À'), + ('}', 'Ç'), + ('~', 'Ù'), + ], + "com.apple.keylayout.Croatian" => &[ + ('"', 'Ć'), + ('&', '\''), + ('(', ')'), + (')', '='), + ('*', '('), + (':', 'Č'), + (';', 'č'), + ('<', ';'), + ('=', '*'), + ('>', ':'), + ('@', '"'), + ('[', 'š'), + ('\'', 'ć'), + ('\\', 'ž'), + (']', 'đ'), + ('^', '&'), + ('`', '<'), + ('{', 'Š'), + ('|', 'Ž'), + ('}', 'Đ'), + ('~', '>'), + ], + "com.apple.keylayout.Croatian-PC" => &[ + ('"', 'Ć'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '\''), + (':', 'Č'), + (';', 'č'), + ('<', ';'), + ('=', '*'), + ('>', ':'), + ('@', '"'), + ('[', 'š'), + ('\'', 'ć'), + ('\\', 'ž'), + (']', 'đ'), + ('^', '&'), + ('`', '<'), + ('{', 'Š'), + ('|', 'Ž'), + ('}', 'Đ'), + ('~', '>'), + ], + "com.apple.keylayout.Czech" => &[ + ('!', '1'), + ('"', '!'), + ('#', '3'), + ('$', '4'), + ('%', '5'), + ('&', '7'), + ('(', '9'), + (')', '0'), + ('*', '8'), + ('+', '%'), + ('/', '\''), + ('0', 'é'), + ('1', '+'), + ('2', 'ě'), + ('3', 'š'), + ('4', 'č'), + ('5', 'ř'), + ('6', 'ž'), + ('7', 'ý'), + ('8', 'á'), + 
('9', 'í'), + (':', '"'), + (';', 'ů'), + ('<', '?'), + ('>', ':'), + ('?', 'ˇ'), + ('@', '2'), + ('[', 'ú'), + ('\'', '§'), + (']', ')'), + ('^', '6'), + ('`', '¨'), + ('{', 'Ú'), + ('}', '('), + ('~', '`'), + ], + "com.apple.keylayout.Czech-QWERTY" => &[ + ('!', '1'), + ('"', '!'), + ('#', '3'), + ('$', '4'), + ('%', '5'), + ('&', '7'), + ('(', '9'), + (')', '0'), + ('*', '8'), + ('+', '%'), + ('/', '\''), + ('0', 'é'), + ('1', '+'), + ('2', 'ě'), + ('3', 'š'), + ('4', 'č'), + ('5', 'ř'), + ('6', 'ž'), + ('7', 'ý'), + ('8', 'á'), + ('9', 'í'), + (':', '"'), + (';', 'ů'), + ('<', '?'), + ('>', ':'), + ('?', 'ˇ'), + ('@', '2'), + ('[', 'ú'), + ('\'', '§'), + (']', ')'), + ('^', '6'), + ('`', '¨'), + ('{', 'Ú'), + ('}', '('), + ('~', '`'), + ], + "com.apple.keylayout.Danish" => &[ + ('"', '^'), + ('$', '€'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '´'), + (':', 'Å'), + (';', 'å'), + ('<', ';'), + ('=', '`'), + ('>', ':'), + ('@', '"'), + ('[', 'æ'), + ('\'', '¨'), + ('\\', '\''), + (']', 'ø'), + ('^', '&'), + ('`', '<'), + ('{', 'Æ'), + ('|', '*'), + ('}', 'Ø'), + ('~', '>'), + ], + "com.apple.keylayout.Faroese" => &[ + ('"', 'Ø'), + ('$', '€'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '´'), + (':', 'Æ'), + (';', 'æ'), + ('<', ';'), + ('=', '`'), + ('>', ':'), + ('@', '"'), + ('[', 'å'), + ('\'', 'ø'), + ('\\', '\''), + (']', 'ð'), + ('^', '&'), + ('`', '<'), + ('{', 'Å'), + ('|', '*'), + ('}', 'Ð'), + ('~', '>'), + ], + "com.apple.keylayout.Finnish" => &[ + ('"', '^'), + ('$', '€'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '´'), + (':', 'Å'), + (';', 'å'), + ('<', ';'), + ('=', '`'), + ('>', ':'), + ('@', '"'), + ('[', 'ö'), + ('\'', '¨'), + ('\\', '\''), + (']', 'ä'), + ('^', '&'), + ('`', '<'), + ('{', 'Ö'), + ('|', '*'), + ('}', 'Ä'), + ('~', '>'), + ], + "com.apple.keylayout.FinnishExtended" => &[ + ('"', 'ˆ'), + ('$', '€'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', 
'´'), + (':', 'Å'), + (';', 'å'), + ('<', ';'), + ('=', '`'), + ('>', ':'), + ('@', '"'), + ('[', 'ö'), + ('\'', '¨'), + ('\\', '\''), + (']', 'ä'), + ('^', '&'), + ('`', '<'), + ('{', 'Ö'), + ('|', '*'), + ('}', 'Ä'), + ('~', '>'), + ], + "com.apple.keylayout.FinnishSami-PC" => &[ + ('"', 'ˆ'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '´'), + (':', 'Å'), + (';', 'å'), + ('<', ';'), + ('=', '`'), + ('>', ':'), + ('@', '"'), + ('[', 'ö'), + ('\'', '¨'), + ('\\', '@'), + (']', 'ä'), + ('^', '&'), + ('`', '<'), + ('{', 'Ö'), + ('|', '*'), + ('}', 'Ä'), + ('~', '>'), + ], + "com.apple.keylayout.French" => &[ + ('!', '1'), + ('"', '%'), + ('#', '3'), + ('$', '4'), + ('%', '5'), + ('&', '7'), + ('(', '9'), + (')', '0'), + ('*', '8'), + ('.', ';'), + ('/', ':'), + ('0', 'à'), + ('1', '&'), + ('2', 'é'), + ('3', '"'), + ('4', '\''), + ('5', '('), + ('6', '§'), + ('7', 'è'), + ('8', '!'), + ('9', 'ç'), + (':', '°'), + (';', ')'), + ('<', '.'), + ('>', '/'), + ('@', '2'), + ('[', '^'), + ('\'', 'ù'), + ('\\', '`'), + (']', '$'), + ('^', '6'), + ('`', '<'), + ('{', '¨'), + ('|', '£'), + ('}', '*'), + ('~', '>'), + ], + "com.apple.keylayout.French-PC" => &[ + ('!', '1'), + ('"', '%'), + ('#', '3'), + ('$', '4'), + ('%', '5'), + ('&', '7'), + ('(', '9'), + (')', '0'), + ('*', '8'), + ('-', ')'), + ('.', ';'), + ('/', ':'), + ('0', 'à'), + ('1', '&'), + ('2', 'é'), + ('3', '"'), + ('4', '\''), + ('5', '('), + ('6', '-'), + ('7', 'è'), + ('8', '_'), + ('9', 'ç'), + (':', '§'), + (';', '!'), + ('<', '.'), + ('>', '/'), + ('@', '2'), + ('[', '^'), + ('\'', 'ù'), + ('\\', '*'), + (']', '$'), + ('^', '6'), + ('_', '°'), + ('`', '<'), + ('{', '¨'), + ('|', 'μ'), + ('}', '£'), + ('~', '>'), + ], + "com.apple.keylayout.French-numerical" => &[ + ('!', '1'), + ('"', '%'), + ('#', '3'), + ('$', '4'), + ('%', '5'), + ('&', '7'), + ('(', '9'), + (')', '0'), + ('*', '8'), + ('.', ';'), + ('/', ':'), + ('0', 'à'), + ('1', '&'), + ('2', 'é'), + ('3', '"'), + ('4', '\''), 
+ ('5', '('), + ('6', '§'), + ('7', 'è'), + ('8', '!'), + ('9', 'ç'), + (':', '°'), + (';', ')'), + ('<', '.'), + ('>', '/'), + ('@', '2'), + ('[', '^'), + ('\'', 'ù'), + ('\\', '`'), + (']', '$'), + ('^', '6'), + ('`', '<'), + ('{', '¨'), + ('|', '£'), + ('}', '*'), + ('~', '>'), + ], + "com.apple.keylayout.German" => &[ + ('"', '`'), + ('#', '§'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', 'ß'), + (':', 'Ü'), + (';', 'ü'), + ('<', ';'), + ('=', '*'), + ('>', ':'), + ('@', '"'), + ('[', 'ö'), + ('\'', '´'), + ('\\', '#'), + (']', 'ä'), + ('^', '&'), + ('`', '<'), + ('{', 'Ö'), + ('|', '\''), + ('}', 'Ä'), + ('~', '>'), + ], + "com.apple.keylayout.German-DIN-2137" => &[ + ('"', '`'), + ('#', '§'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', 'ß'), + (':', 'Ü'), + (';', 'ü'), + ('<', ';'), + ('=', '*'), + ('>', ':'), + ('@', '"'), + ('[', 'ö'), + ('\'', '´'), + ('\\', '#'), + (']', 'ä'), + ('^', '&'), + ('`', '<'), + ('{', 'Ö'), + ('|', '\''), + ('}', 'Ä'), + ('~', '>'), + ], + "com.apple.keylayout.Hawaiian" => &[('\'', 'ʻ')], + "com.apple.keylayout.Hungarian" => &[ + ('!', '\''), + ('"', 'Á'), + ('#', '+'), + ('$', '!'), + ('&', '='), + ('(', ')'), + (')', 'Ö'), + ('*', '('), + ('+', 'Ó'), + ('/', 'ü'), + ('0', 'ö'), + (':', 'É'), + (';', 'é'), + ('<', 'Ü'), + ('=', 'ó'), + ('>', ':'), + ('@', '"'), + ('[', 'ő'), + ('\'', 'á'), + ('\\', 'ű'), + (']', 'ú'), + ('^', '/'), + ('`', 'í'), + ('{', 'Ő'), + ('|', 'Ű'), + ('}', 'Ú'), + ('~', 'Í'), + ], + "com.apple.keylayout.Hungarian-QWERTY" => &[ + ('!', '\''), + ('"', 'Á'), + ('#', '+'), + ('$', '!'), + ('&', '='), + ('(', ')'), + (')', 'Ö'), + ('*', '('), + ('+', 'Ó'), + ('/', 'ü'), + ('0', 'ö'), + (':', 'É'), + (';', 'é'), + ('<', 'Ü'), + ('=', 'ó'), + ('>', ':'), + ('@', '"'), + ('[', 'ő'), + ('\'', 'á'), + ('\\', 'ű'), + (']', 'ú'), + ('^', '/'), + ('`', 'í'), + ('{', 'Ő'), + ('|', 'Ű'), + ('}', 'Ú'), + ('~', 'Í'), + ], + "com.apple.keylayout.Icelandic" => &[ + ('"', 'Ö'), + 
('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '\''), + (':', 'Ð'), + (';', 'ð'), + ('<', ';'), + ('=', '*'), + ('>', ':'), + ('@', '"'), + ('[', 'æ'), + ('\'', 'ö'), + ('\\', 'þ'), + (']', '´'), + ('^', '&'), + ('`', '<'), + ('{', 'Æ'), + ('|', 'Þ'), + ('}', '´'), + ('~', '>'), + ], + "com.apple.keylayout.Irish" => &[('#', '£')], + "com.apple.keylayout.IrishExtended" => &[('#', '£')], + "com.apple.keylayout.Italian" => &[ + ('!', '1'), + ('"', '%'), + ('#', '3'), + ('$', '4'), + ('%', '5'), + ('&', '7'), + ('(', '9'), + (')', '0'), + ('*', '8'), + (',', ';'), + ('.', ':'), + ('/', ','), + ('0', 'é'), + ('1', '&'), + ('2', '"'), + ('3', '\''), + ('4', '('), + ('5', 'ç'), + ('6', 'è'), + ('7', ')'), + ('8', '£'), + ('9', 'à'), + (':', '!'), + (';', 'ò'), + ('<', '.'), + ('>', '/'), + ('@', '2'), + ('[', 'ì'), + ('\'', 'ù'), + ('\\', '§'), + (']', '$'), + ('^', '6'), + ('`', '<'), + ('{', '^'), + ('|', '°'), + ('}', '*'), + ('~', '>'), + ], + "com.apple.keylayout.Italian-Pro" => &[ + ('"', '^'), + ('#', '£'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '\''), + (':', 'é'), + (';', 'è'), + ('<', ';'), + ('=', '*'), + ('>', ':'), + ('@', '"'), + ('[', 'ò'), + ('\'', 'ì'), + ('\\', 'ù'), + (']', 'à'), + ('^', '&'), + ('`', '<'), + ('{', 'ç'), + ('|', '§'), + ('}', '°'), + ('~', '>'), + ], + "com.apple.keylayout.LatinAmerican" => &[ + ('"', '¨'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '\''), + (':', 'Ñ'), + (';', 'ñ'), + ('<', ';'), + ('=', '*'), + ('>', ':'), + ('@', '"'), + ('[', '{'), + ('\'', '´'), + ('\\', '¿'), + (']', '}'), + ('^', '&'), + ('`', '<'), + ('{', '['), + ('|', '¡'), + ('}', ']'), + ('~', '>'), + ], + "com.apple.keylayout.Lithuanian" => &[ + ('!', 'Ą'), + ('#', 'Ę'), + ('$', 'Ė'), + ('%', 'Į'), + ('&', 'Ų'), + ('*', 'Ū'), + ('+', 'Ž'), + ('1', 'ą'), + ('2', 'č'), + ('3', 'ę'), + ('4', 'ė'), + ('5', 'į'), + ('6', 'š'), + ('7', 'ų'), + ('8', 'ū'), + ('=', 'ž'), + ('@', 'Č'), + ('^', 'Š'), + ], + 
"com.apple.keylayout.Maltese" => &[ + ('#', '£'), + ('[', 'ġ'), + (']', 'ħ'), + ('`', 'ż'), + ('{', 'Ġ'), + ('}', 'Ħ'), + ('~', 'Ż'), + ], + "com.apple.keylayout.NorthernSami" => &[ + ('"', 'Ŋ'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '´'), + (':', 'Å'), + (';', 'å'), + ('<', ';'), + ('=', '`'), + ('>', ':'), + ('@', '"'), + ('Q', 'Á'), + ('W', 'Š'), + ('X', 'Č'), + ('[', 'ø'), + ('\'', 'ŋ'), + ('\\', 'đ'), + (']', 'æ'), + ('^', '&'), + ('`', 'ž'), + ('q', 'á'), + ('w', 'š'), + ('x', 'č'), + ('{', 'Ø'), + ('|', 'Đ'), + ('}', 'Æ'), + ('~', 'Ž'), + ], + "com.apple.keylayout.Norwegian" => &[ + ('"', '^'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '´'), + (':', 'Å'), + (';', 'å'), + ('<', ';'), + ('=', '`'), + ('>', ':'), + ('@', '"'), + ('[', 'ø'), + ('\'', '¨'), + ('\\', '@'), + (']', 'æ'), + ('^', '&'), + ('`', '<'), + ('{', 'Ø'), + ('|', '*'), + ('}', 'Æ'), + ('~', '>'), + ], + "com.apple.keylayout.NorwegianExtended" => &[ + ('"', 'ˆ'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '´'), + (':', 'Å'), + (';', 'å'), + ('<', ';'), + ('=', '`'), + ('>', ':'), + ('@', '"'), + ('[', 'ø'), + ('\\', '@'), + (']', 'æ'), + ('`', '<'), + ('}', 'Æ'), + ('~', '>'), + ], + "com.apple.keylayout.NorwegianSami-PC" => &[ + ('"', 'ˆ'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '´'), + (':', 'Å'), + (';', 'å'), + ('<', ';'), + ('=', '`'), + ('>', ':'), + ('@', '"'), + ('[', 'ø'), + ('\'', '¨'), + ('\\', '@'), + (']', 'æ'), + ('^', '&'), + ('`', '<'), + ('{', 'Ø'), + ('|', '*'), + ('}', 'Æ'), + ('~', '>'), + ], + "com.apple.keylayout.Polish" => &[ + ('!', '§'), + ('"', 'ę'), + ('#', '!'), + ('$', '?'), + ('%', '+'), + ('&', ':'), + ('(', '/'), + (')', '"'), + ('*', '_'), + ('+', ']'), + (',', '.'), + ('.', ','), + ('/', 'ż'), + (':', 'Ł'), + (';', 'ł'), + ('<', 'ś'), + ('=', '['), + ('>', 'ń'), + ('?', 'Ż'), + ('@', '%'), + ('[', 'ó'), + ('\'', 'ą'), + ('\\', ';'), + (']', '('), + ('^', '='), + 
('_', 'ć'), + ('`', '<'), + ('{', 'ź'), + ('|', '$'), + ('}', ')'), + ('~', '>'), + ], + "com.apple.keylayout.Portuguese" => &[ + ('"', '`'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '\''), + (':', 'ª'), + (';', 'º'), + ('<', ';'), + ('=', '*'), + ('>', ':'), + ('@', '"'), + ('[', 'ç'), + ('\'', '´'), + (']', '~'), + ('^', '&'), + ('`', '<'), + ('{', 'Ç'), + ('}', '^'), + ('~', '>'), + ], + "com.apple.keylayout.Sami-PC" => &[ + ('"', 'Ŋ'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '´'), + (':', 'Å'), + (';', 'å'), + ('<', ';'), + ('=', '`'), + ('>', ':'), + ('@', '"'), + ('Q', 'Á'), + ('W', 'Š'), + ('X', 'Č'), + ('[', 'ø'), + ('\'', 'ŋ'), + ('\\', 'đ'), + (']', 'æ'), + ('^', '&'), + ('`', 'ž'), + ('q', 'á'), + ('w', 'š'), + ('x', 'č'), + ('{', 'Ø'), + ('|', 'Đ'), + ('}', 'Æ'), + ('~', 'Ž'), + ], + "com.apple.keylayout.Serbian-Latin" => &[ + ('"', 'Ć'), + ('&', '\''), + ('(', ')'), + (')', '='), + ('*', '('), + (':', 'Č'), + (';', 'č'), + ('<', ';'), + ('=', '*'), + ('>', ':'), + ('@', '"'), + ('[', 'š'), + ('\'', 'ć'), + ('\\', 'ž'), + (']', 'đ'), + ('^', '&'), + ('`', '<'), + ('{', 'Š'), + ('|', 'Ž'), + ('}', 'Đ'), + ('~', '>'), + ], + "com.apple.keylayout.Slovak" => &[ + ('!', '1'), + ('"', '!'), + ('#', '3'), + ('$', '4'), + ('%', '5'), + ('&', '7'), + ('(', '9'), + (')', '0'), + ('*', '8'), + ('+', '%'), + ('/', '\''), + ('0', 'é'), + ('1', '+'), + ('2', 'ľ'), + ('3', 'š'), + ('4', 'č'), + ('5', 'ť'), + ('6', 'ž'), + ('7', 'ý'), + ('8', 'á'), + ('9', 'í'), + (':', '"'), + (';', 'ô'), + ('<', '?'), + ('>', ':'), + ('?', 'ˇ'), + ('@', '2'), + ('[', 'ú'), + ('\'', '§'), + (']', 'ä'), + ('^', '6'), + ('`', 'ň'), + ('{', 'Ú'), + ('}', 'Ä'), + ('~', 'Ň'), + ], + "com.apple.keylayout.Slovak-QWERTY" => &[ + ('!', '1'), + ('"', '!'), + ('#', '3'), + ('$', '4'), + ('%', '5'), + ('&', '7'), + ('(', '9'), + (')', '0'), + ('*', '8'), + ('+', '%'), + ('/', '\''), + ('0', 'é'), + ('1', '+'), + ('2', 'ľ'), + ('3', 'š'), + ('4', 
'č'), + ('5', 'ť'), + ('6', 'ž'), + ('7', 'ý'), + ('8', 'á'), + ('9', 'í'), + (':', '"'), + (';', 'ô'), + ('<', '?'), + ('>', ':'), + ('?', 'ˇ'), + ('@', '2'), + ('[', 'ú'), + ('\'', '§'), + (']', 'ä'), + ('^', '6'), + ('`', 'ň'), + ('{', 'Ú'), + ('}', 'Ä'), + ('~', 'Ň'), + ], + "com.apple.keylayout.Slovenian" => &[ + ('"', 'Ć'), + ('&', '\''), + ('(', ')'), + (')', '='), + ('*', '('), + (':', 'Č'), + (';', 'č'), + ('<', ';'), + ('=', '*'), + ('>', ':'), + ('@', '"'), + ('[', 'š'), + ('\'', 'ć'), + ('\\', 'ž'), + (']', 'đ'), + ('^', '&'), + ('`', '<'), + ('{', 'Š'), + ('|', 'Ž'), + ('}', 'Đ'), + ('~', '>'), + ], + "com.apple.keylayout.Spanish" => &[ + ('!', '¡'), + ('"', '¨'), + ('.', 'ç'), + ('/', '.'), + (':', 'º'), + (';', '´'), + ('<', '¿'), + ('>', 'Ç'), + ('@', '!'), + ('[', 'ñ'), + ('\'', '`'), + ('\\', '\''), + (']', ';'), + ('^', '/'), + ('`', '<'), + ('{', 'Ñ'), + ('|', '"'), + ('}', ':'), + ('~', '>'), + ], + "com.apple.keylayout.Spanish-ISO" => &[ + ('"', '¨'), + ('#', '·'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('.', 'ç'), + ('/', '.'), + (':', 'º'), + (';', '´'), + ('<', '¿'), + ('>', 'Ç'), + ('@', '"'), + ('[', 'ñ'), + ('\'', '`'), + ('\\', '\''), + (']', ';'), + ('^', '&'), + ('`', '<'), + ('{', 'Ñ'), + ('|', '"'), + ('}', '`'), + ('~', '>'), + ], + "com.apple.keylayout.Swedish" => &[ + ('"', '^'), + ('$', '€'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '´'), + (':', 'Å'), + (';', 'å'), + ('<', ';'), + ('=', '`'), + ('>', ':'), + ('@', '"'), + ('[', 'ö'), + ('\'', '¨'), + ('\\', '\''), + (']', 'ä'), + ('^', '&'), + ('`', '<'), + ('{', 'Ö'), + ('|', '*'), + ('}', 'Ä'), + ('~', '>'), + ], + "com.apple.keylayout.Swedish-Pro" => &[ + ('"', '^'), + ('$', '€'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '´'), + (':', 'Å'), + (';', 'å'), + ('<', ';'), + ('=', '`'), + ('>', ':'), + ('@', '"'), + ('[', 'ö'), + ('\'', '¨'), + ('\\', '\''), + (']', 'ä'), + ('^', '&'), + ('`', '<'), + ('{', 'Ö'), + 
('|', '*'), + ('}', 'Ä'), + ('~', '>'), + ], + "com.apple.keylayout.SwedishSami-PC" => &[ + ('"', 'ˆ'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('/', '´'), + (':', 'Å'), + (';', 'å'), + ('<', ';'), + ('=', '`'), + ('>', ':'), + ('@', '"'), + ('[', 'ö'), + ('\'', '¨'), + ('\\', '@'), + (']', 'ä'), + ('^', '&'), + ('`', '<'), + ('{', 'Ö'), + ('|', '*'), + ('}', 'Ä'), + ('~', '>'), + ], + "com.apple.keylayout.SwissFrench" => &[ + ('!', '+'), + ('"', '`'), + ('#', '*'), + ('$', 'ç'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('+', '!'), + ('/', '\''), + (':', 'ü'), + (';', 'è'), + ('<', ';'), + ('=', '¨'), + ('>', ':'), + ('@', '"'), + ('[', 'é'), + ('\'', '^'), + ('\\', '$'), + (']', 'à'), + ('^', '&'), + ('`', '<'), + ('{', 'ö'), + ('|', '£'), + ('}', 'ä'), + ('~', '>'), + ], + "com.apple.keylayout.SwissGerman" => &[ + ('!', '+'), + ('"', '`'), + ('#', '*'), + ('$', 'ç'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('+', '!'), + ('/', '\''), + (':', 'è'), + (';', 'ü'), + ('<', ';'), + ('=', '¨'), + ('>', ':'), + ('@', '"'), + ('[', 'ö'), + ('\'', '^'), + ('\\', '$'), + (']', 'ä'), + ('^', '&'), + ('`', '<'), + ('{', 'é'), + ('|', '£'), + ('}', 'à'), + ('~', '>'), + ], + "com.apple.keylayout.Turkish" => &[ + ('"', '-'), + ('#', '"'), + ('$', '\''), + ('%', '('), + ('&', ')'), + ('(', '%'), + (')', ':'), + ('*', '_'), + (',', 'ö'), + ('-', 'ş'), + ('.', 'ç'), + ('/', '.'), + (':', '$'), + ('<', 'Ö'), + ('>', 'Ç'), + ('@', '*'), + ('[', 'ğ'), + ('\'', ','), + ('\\', 'ü'), + (']', 'ı'), + ('^', '/'), + ('_', 'Ş'), + ('`', '<'), + ('{', 'Ğ'), + ('|', 'Ü'), + ('}', 'I'), + ('~', '>'), + ], + "com.apple.keylayout.Turkish-QWERTY-PC" => &[ + ('"', 'I'), + ('#', '^'), + ('$', '+'), + ('&', '/'), + ('(', ')'), + (')', '='), + ('*', '('), + ('+', ':'), + (',', 'ö'), + ('.', 'ç'), + ('/', '*'), + (':', 'Ş'), + (';', 'ş'), + ('<', 'Ö'), + ('=', '.'), + ('>', 'Ç'), + ('@', '\''), + ('[', 'ğ'), + ('\'', 'ı'), + ('\\', ','), + (']', 
'ü'), + ('^', '&'), + ('`', '<'), + ('{', 'Ğ'), + ('|', ';'), + ('}', 'Ü'), + ('~', '>'), + ], + "com.apple.keylayout.Turkish-Standard" => &[ + ('"', 'Ş'), + ('#', '^'), + ('&', '\''), + ('(', ')'), + (')', '='), + ('*', '('), + (',', '.'), + ('.', ','), + (':', 'Ç'), + (';', 'ç'), + ('<', ':'), + ('=', '*'), + ('>', ';'), + ('@', '"'), + ('[', 'ğ'), + ('\'', 'ş'), + ('\\', 'ü'), + (']', 'ı'), + ('^', '&'), + ('`', 'ö'), + ('{', 'Ğ'), + ('|', 'Ü'), + ('}', 'I'), + ('~', 'Ö'), + ], + "com.apple.keylayout.Turkmen" => &[ + ('C', 'Ç'), + ('Q', 'Ä'), + ('V', 'Ý'), + ('X', 'Ü'), + ('[', 'ň'), + ('\\', 'ş'), + (']', 'ö'), + ('^', '№'), + ('`', 'ž'), + ('c', 'ç'), + ('q', 'ä'), + ('v', 'ý'), + ('x', 'ü'), + ('{', 'Ň'), + ('|', 'Ş'), + ('}', 'Ö'), + ('~', 'Ž'), + ], + "com.apple.keylayout.USInternational-PC" => &[('^', 'ˆ'), ('~', '˜')], + "com.apple.keylayout.Welsh" => &[('#', '£')], + + _ => return None, + }; + + Some(HashMap::from_iter(mappings.iter().cloned())) +} diff --git a/crates/gpui/src/platform/mac/metal_renderer.rs b/crates/gpui/src/platform/mac/metal_renderer.rs index 629654014d5a15632c5992d9347cab3ee1fd28d9..9e5d6ec5ff02c74b4f0acfada8eee3d002bfd06b 100644 --- a/crates/gpui/src/platform/mac/metal_renderer.rs +++ b/crates/gpui/src/platform/mac/metal_renderer.rs @@ -314,6 +314,15 @@ impl MetalRenderer { } fn update_path_intermediate_textures(&mut self, size: Size) { + // We are uncertain when this happens, but sometimes size can be 0 here. Most likely before + // the layout pass on window creation. Zero-sized texture creation causes SIGABRT. 
+ // https://github.com/zed-industries/zed/issues/36229 + if size.width.0 <= 0 || size.height.0 <= 0 { + self.path_intermediate_texture = None; + self.path_intermediate_msaa_texture = None; + return; + } + let texture_descriptor = metal::TextureDescriptor::new(); texture_descriptor.set_width(size.width.0 as u64); texture_descriptor.set_height(size.height.0 as u64); @@ -323,7 +332,7 @@ impl MetalRenderer { self.path_intermediate_texture = Some(self.device.new_texture(&texture_descriptor)); if self.path_sample_count > 1 { - let mut msaa_descriptor = texture_descriptor.clone(); + let mut msaa_descriptor = texture_descriptor; msaa_descriptor.set_texture_type(metal::MTLTextureType::D2Multisample); msaa_descriptor.set_storage_mode(metal::MTLStorageMode::Private); msaa_descriptor.set_sample_count(self.path_sample_count as _); @@ -436,14 +445,14 @@ impl MetalRenderer { instance_buffer, &mut instance_offset, viewport_size, - &command_encoder, + command_encoder, ), PrimitiveBatch::Quads(quads) => self.draw_quads( quads, instance_buffer, &mut instance_offset, viewport_size, - &command_encoder, + command_encoder, ), PrimitiveBatch::Paths(paths) => { command_encoder.end_encoding(); @@ -471,7 +480,7 @@ impl MetalRenderer { instance_buffer, &mut instance_offset, viewport_size, - &command_encoder, + command_encoder, ) } else { false @@ -482,7 +491,7 @@ impl MetalRenderer { instance_buffer, &mut instance_offset, viewport_size, - &command_encoder, + command_encoder, ), PrimitiveBatch::MonochromeSprites { texture_id, @@ -493,7 +502,7 @@ impl MetalRenderer { instance_buffer, &mut instance_offset, viewport_size, - &command_encoder, + command_encoder, ), PrimitiveBatch::PolychromeSprites { texture_id, @@ -504,14 +513,14 @@ impl MetalRenderer { instance_buffer, &mut instance_offset, viewport_size, - &command_encoder, + command_encoder, ), PrimitiveBatch::Surfaces(surfaces) => self.draw_surfaces( surfaces, instance_buffer, &mut instance_offset, viewport_size, - &command_encoder, + 
command_encoder, ), }; if !ok { @@ -754,7 +763,7 @@ impl MetalRenderer { viewport_size: Size, command_encoder: &metal::RenderCommandEncoderRef, ) -> bool { - let Some(ref first_path) = paths.first() else { + let Some(first_path) = paths.first() else { return true; }; diff --git a/crates/gpui/src/platform/mac/open_type.rs b/crates/gpui/src/platform/mac/open_type.rs index 2ae5e8f87ab78a70e423a4645c96e69f098828a6..37a29559fdfbc284ffd1021cc6c2c6ed717ca228 100644 --- a/crates/gpui/src/platform/mac/open_type.rs +++ b/crates/gpui/src/platform/mac/open_type.rs @@ -35,14 +35,14 @@ pub fn apply_features_and_fallbacks( unsafe { let mut keys = vec![kCTFontFeatureSettingsAttribute]; let mut values = vec![generate_feature_array(features)]; - if let Some(fallbacks) = fallbacks { - if !fallbacks.fallback_list().is_empty() { - keys.push(kCTFontCascadeListAttribute); - values.push(generate_fallback_array( - fallbacks, - font.native_font().as_concrete_TypeRef(), - )); - } + if let Some(fallbacks) = fallbacks + && !fallbacks.fallback_list().is_empty() + { + keys.push(kCTFontCascadeListAttribute); + values.push(generate_fallback_array( + fallbacks, + font.native_font().as_concrete_TypeRef(), + )); } let attrs = CFDictionaryCreate( kCFAllocatorDefault, diff --git a/crates/gpui/src/platform/mac/platform.rs b/crates/gpui/src/platform/mac/platform.rs index c5731317994c2c81e456556815a0e47842a6c642..dea04d89a06acac526a8b033681829fdc1e148fd 100644 --- a/crates/gpui/src/platform/mac/platform.rs +++ b/crates/gpui/src/platform/mac/platform.rs @@ -1,5 +1,5 @@ use super::{ - BoolExt, MacKeyboardLayout, + BoolExt, MacKeyboardLayout, MacKeyboardMapper, attributed_string::{NSAttributedString, NSMutableAttributedString}, events::key_to_native, renderer, @@ -8,8 +8,9 @@ use crate::{ Action, AnyWindowHandle, BackgroundExecutor, ClipboardEntry, ClipboardItem, ClipboardString, CursorStyle, ForegroundExecutor, Image, ImageFormat, KeyContext, Keymap, MacDispatcher, MacDisplay, MacWindow, Menu, MenuItem, 
OsMenu, OwnedMenu, PathPromptOptions, Platform, - PlatformDisplay, PlatformKeyboardLayout, PlatformTextSystem, PlatformWindow, Result, - SemanticVersion, SystemMenuType, Task, WindowAppearance, WindowParams, hash, + PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, PlatformTextSystem, + PlatformWindow, Result, SemanticVersion, SystemMenuType, Task, WindowAppearance, WindowParams, + hash, }; use anyhow::{Context as _, anyhow}; use block::ConcreteBlock; @@ -171,6 +172,7 @@ pub(crate) struct MacPlatformState { finish_launching: Option>, dock_menu: Option, menus: Option>, + keyboard_mapper: Rc, } impl Default for MacPlatform { @@ -189,6 +191,9 @@ impl MacPlatform { #[cfg(not(feature = "font-kit"))] let text_system = Arc::new(crate::NoopTextSystem::new()); + let keyboard_layout = MacKeyboardLayout::new(); + let keyboard_mapper = Rc::new(MacKeyboardMapper::new(keyboard_layout.id())); + Self(Mutex::new(MacPlatformState { headless, text_system, @@ -209,6 +214,7 @@ impl MacPlatform { dock_menu: None, on_keyboard_layout_change: None, menus: None, + keyboard_mapper, })) } @@ -348,19 +354,19 @@ impl MacPlatform { let mut mask = NSEventModifierFlags::empty(); for (modifier, flag) in &[ ( - keystroke.modifiers.platform, + keystroke.modifiers().platform, NSEventModifierFlags::NSCommandKeyMask, ), ( - keystroke.modifiers.control, + keystroke.modifiers().control, NSEventModifierFlags::NSControlKeyMask, ), ( - keystroke.modifiers.alt, + keystroke.modifiers().alt, NSEventModifierFlags::NSAlternateKeyMask, ), ( - keystroke.modifiers.shift, + keystroke.modifiers().shift, NSEventModifierFlags::NSShiftKeyMask, ), ] { @@ -371,9 +377,9 @@ impl MacPlatform { item = NSMenuItem::alloc(nil) .initWithTitle_action_keyEquivalent_( - ns_string(&name), + ns_string(name), selector, - ns_string(key_to_native(&keystroke.key).as_ref()), + ns_string(key_to_native(keystroke.key()).as_ref()), ) .autorelease(); if Self::os_version() >= SemanticVersion::new(12, 0, 0) { @@ -383,7 +389,7 @@ 
impl MacPlatform { } else { item = NSMenuItem::alloc(nil) .initWithTitle_action_keyEquivalent_( - ns_string(&name), + ns_string(name), selector, ns_string(""), ) @@ -392,7 +398,7 @@ impl MacPlatform { } else { item = NSMenuItem::alloc(nil) .initWithTitle_action_keyEquivalent_( - ns_string(&name), + ns_string(name), selector, ns_string(""), ) @@ -412,7 +418,7 @@ impl MacPlatform { submenu.addItem_(Self::create_menu_item(item, delegate, actions, keymap)); } item.setSubmenu_(submenu); - item.setTitle_(ns_string(&name)); + item.setTitle_(ns_string(name)); item } MenuItem::SystemMenu(OsMenu { name, menu_type }) => { @@ -420,7 +426,7 @@ impl MacPlatform { let submenu = NSMenu::new(nil).autorelease(); submenu.setDelegate_(delegate); item.setSubmenu_(submenu); - item.setTitle_(ns_string(&name)); + item.setTitle_(ns_string(name)); match menu_type { SystemMenuType::Services => { @@ -705,6 +711,7 @@ impl Platform for MacPlatform { panel.setCanChooseDirectories_(options.directories.to_objc()); panel.setCanChooseFiles_(options.files.to_objc()); panel.setAllowsMultipleSelection_(options.multiple.to_objc()); + panel.setCanCreateDirectories(true.to_objc()); panel.setResolvesAliases_(false.to_objc()); let done_tx = Cell::new(Some(done_tx)); @@ -714,10 +721,10 @@ impl Platform for MacPlatform { let urls = panel.URLs(); for i in 0..urls.count() { let url = urls.objectAtIndex(i); - if url.isFileURL() == YES { - if let Ok(path) = ns_url_to_path(url) { - result.push(path) - } + if url.isFileURL() == YES + && let Ok(path) = ns_url_to_path(url) + { + result.push(path) } } Some(result) @@ -730,6 +737,11 @@ impl Platform for MacPlatform { } }); let block = block.copy(); + + if let Some(prompt) = options.prompt { + let _: () = msg_send![panel, setPrompt: ns_string(&prompt)]; + } + let _: () = msg_send![panel, beginWithCompletionHandler: block]; } }) @@ -737,8 +749,13 @@ impl Platform for MacPlatform { done_rx } - fn prompt_for_new_path(&self, directory: &Path) -> oneshot::Receiver>> { + fn 
prompt_for_new_path( + &self, + directory: &Path, + suggested_name: Option<&str>, + ) -> oneshot::Receiver>> { let directory = directory.to_owned(); + let suggested_name = suggested_name.map(|s| s.to_owned()); let (done_tx, done_rx) = oneshot::channel(); self.foreground_executor() .spawn(async move { @@ -748,6 +765,11 @@ impl Platform for MacPlatform { let url = NSURL::fileURLWithPath_isDirectory_(nil, path, true.to_objc()); panel.setDirectoryURL(url); + if let Some(suggested_name) = suggested_name { + let name_string = ns_string(&suggested_name); + let _: () = msg_send![panel, setNameFieldStringValue: name_string]; + } + let done_tx = Cell::new(Some(done_tx)); let block = ConcreteBlock::new(move |response: NSModalResponse| { let mut result = None; @@ -770,17 +792,18 @@ impl Platform for MacPlatform { // This is conditional on OS version because I'd like to get rid of it, so that // you can manually create a file called `a.sql.s`. That said it seems better // to break that use-case than breaking `a.sql`. 
- if chunks.len() == 3 && chunks[1].starts_with(chunks[2]) { - if Self::os_version() >= SemanticVersion::new(15, 0, 0) { - let new_filename = OsStr::from_bytes( - &filename.as_bytes() - [..chunks[0].len() + 1 + chunks[1].len()], - ) - .to_owned(); - result.set_file_name(&new_filename); - } + if chunks.len() == 3 + && chunks[1].starts_with(chunks[2]) + && Self::os_version() >= SemanticVersion::new(15, 0, 0) + { + let new_filename = OsStr::from_bytes( + &filename.as_bytes() + [..chunks[0].len() + 1 + chunks[1].len()], + ) + .to_owned(); + result.set_file_name(&new_filename); } - return result; + result }) } } @@ -865,6 +888,10 @@ impl Platform for MacPlatform { Box::new(MacKeyboardLayout::new()) } + fn keyboard_mapper(&self) -> Rc { + self.0.lock().keyboard_mapper.clone() + } + fn app_path(&self) -> Result { unsafe { let bundle: id = NSBundle::mainBundle(); @@ -1376,6 +1403,8 @@ extern "C" fn will_terminate(this: &mut Object, _: Sel, _: id) { extern "C" fn on_keyboard_layout_change(this: &mut Object, _: Sel, _: id) { let platform = unsafe { get_mac_platform(this) }; let mut lock = platform.0.lock(); + let keyboard_layout = MacKeyboardLayout::new(); + lock.keyboard_mapper = Rc::new(MacKeyboardMapper::new(keyboard_layout.id())); if let Some(mut callback) = lock.on_keyboard_layout_change.take() { drop(lock); callback(); diff --git a/crates/gpui/src/platform/mac/text_system.rs b/crates/gpui/src/platform/mac/text_system.rs index 849925c72772b70162f09fc680c0be2d6510878a..9144b2a23a40bd527e1441cf71adcc2562c33f3c 100644 --- a/crates/gpui/src/platform/mac/text_system.rs +++ b/crates/gpui/src/platform/mac/text_system.rs @@ -16,7 +16,7 @@ use core_foundation::{ use core_graphics::{ base::{CGGlyph, kCGImageAlphaPremultipliedLast}, color_space::CGColorSpace, - context::CGContext, + context::{CGContext, CGTextDrawingMode}, display::CGPoint, }; use core_text::{ @@ -319,7 +319,7 @@ impl MacTextSystemState { fn is_emoji(&self, font_id: FontId) -> bool { 
self.postscript_names_by_font_id .get(&font_id) - .map_or(false, |postscript_name| { + .is_some_and(|postscript_name| { postscript_name == "AppleColorEmoji" || postscript_name == ".AppleColorEmojiUI" }) } @@ -396,6 +396,12 @@ impl MacTextSystemState { let subpixel_shift = params .subpixel_variant .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32); + cx.set_allows_font_smoothing(true); + cx.set_should_smooth_fonts(true); + cx.set_text_drawing_mode(CGTextDrawingMode::CGTextFill); + cx.set_gray_fill_color(0.0, 1.0); + cx.set_allows_antialiasing(true); + cx.set_should_antialias(true); cx.set_allows_font_subpixel_positioning(true); cx.set_should_subpixel_position_fonts(true); cx.set_allows_font_subpixel_quantization(false); diff --git a/crates/gpui/src/platform/mac/window.rs b/crates/gpui/src/platform/mac/window.rs index aedf131909a6956e9a4501b107c81ce242b80a49..1230a704062ba835bceb5db5d2ecf05b688e34df 100644 --- a/crates/gpui/src/platform/mac/window.rs +++ b/crates/gpui/src/platform/mac/window.rs @@ -4,8 +4,9 @@ use crate::{ ForegroundExecutor, KeyDownEvent, Keystroke, Modifiers, ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels, PlatformAtlas, PlatformDisplay, PlatformInput, PlatformWindow, Point, PromptButton, PromptLevel, RequestFrameOptions, - ScaledPixels, Size, Timer, WindowAppearance, WindowBackgroundAppearance, WindowBounds, - WindowControlArea, WindowKind, WindowParams, platform::PlatformInputHandler, point, px, size, + SharedString, Size, SystemWindowTab, Timer, WindowAppearance, WindowBackgroundAppearance, + WindowBounds, WindowControlArea, WindowKind, WindowParams, dispatch_get_main_queue, + dispatch_sys::dispatch_async_f, platform::PlatformInputHandler, point, px, size, }; use block::ConcreteBlock; use cocoa::{ @@ -24,6 +25,7 @@ use cocoa::{ NSUserDefaults, }, }; + use core_graphics::display::{CGDirectDisplayID, CGPoint, CGRect}; use ctor::ctor; use futures::channel::oneshot; @@ -82,6 +84,12 @@ type NSDragOperation = 
NSUInteger; const NSDragOperationNone: NSDragOperation = 0; #[allow(non_upper_case_globals)] const NSDragOperationCopy: NSDragOperation = 1; +#[derive(PartialEq)] +pub enum UserTabbingPreference { + Never, + Always, + InFullScreen, +} #[link(name = "CoreGraphics", kind = "framework")] unsafe extern "C" { @@ -343,6 +351,36 @@ unsafe fn build_window_class(name: &'static str, superclass: &Class) -> *const C conclude_drag_operation as extern "C" fn(&Object, Sel, id), ); + decl.add_method( + sel!(addTitlebarAccessoryViewController:), + add_titlebar_accessory_view_controller as extern "C" fn(&Object, Sel, id), + ); + + decl.add_method( + sel!(moveTabToNewWindow:), + move_tab_to_new_window as extern "C" fn(&Object, Sel, id), + ); + + decl.add_method( + sel!(mergeAllWindows:), + merge_all_windows as extern "C" fn(&Object, Sel, id), + ); + + decl.add_method( + sel!(selectNextTab:), + select_next_tab as extern "C" fn(&Object, Sel, id), + ); + + decl.add_method( + sel!(selectPreviousTab:), + select_previous_tab as extern "C" fn(&Object, Sel, id), + ); + + decl.add_method( + sel!(toggleTabBar:), + toggle_tab_bar as extern "C" fn(&Object, Sel, id), + ); + decl.register() } } @@ -375,6 +413,11 @@ struct MacWindowState { // Whether the next left-mouse click is also the focusing click. 
first_mouse: bool, fullscreen_restore_bounds: Bounds, + move_tab_to_new_window_callback: Option>, + merge_all_windows_callback: Option>, + select_next_tab_callback: Option>, + select_previous_tab_callback: Option>, + toggle_tab_bar_callback: Option>, } impl MacWindowState { @@ -530,10 +573,13 @@ impl MacWindow { titlebar, kind, is_movable, + is_resizable, + is_minimizable, focus, show, display_id, window_min_size, + tabbing_identifier, }: WindowParams, executor: ForegroundExecutor, renderer_context: renderer::Context, @@ -541,14 +587,25 @@ impl MacWindow { unsafe { let pool = NSAutoreleasePool::new(nil); - let () = msg_send![class!(NSWindow), setAllowsAutomaticWindowTabbing: NO]; + let allows_automatic_window_tabbing = tabbing_identifier.is_some(); + if allows_automatic_window_tabbing { + let () = msg_send![class!(NSWindow), setAllowsAutomaticWindowTabbing: YES]; + } else { + let () = msg_send![class!(NSWindow), setAllowsAutomaticWindowTabbing: NO]; + } let mut style_mask; if let Some(titlebar) = titlebar.as_ref() { - style_mask = NSWindowStyleMask::NSClosableWindowMask - | NSWindowStyleMask::NSMiniaturizableWindowMask - | NSWindowStyleMask::NSResizableWindowMask - | NSWindowStyleMask::NSTitledWindowMask; + style_mask = + NSWindowStyleMask::NSClosableWindowMask | NSWindowStyleMask::NSTitledWindowMask; + + if is_resizable { + style_mask |= NSWindowStyleMask::NSResizableWindowMask; + } + + if is_minimizable { + style_mask |= NSWindowStyleMask::NSMiniaturizableWindowMask; + } if titlebar.appears_transparent { style_mask |= NSWindowStyleMask::NSFullSizeContentViewWindowMask; @@ -653,13 +710,18 @@ impl MacWindow { .and_then(|titlebar| titlebar.traffic_light_position), transparent_titlebar: titlebar .as_ref() - .map_or(true, |titlebar| titlebar.appears_transparent), + .is_none_or(|titlebar| titlebar.appears_transparent), previous_modifiers_changed_event: None, keystroke_for_do_command: None, do_command_handled: None, external_files_dragged: false, first_mouse: false, 
fullscreen_restore_bounds: Bounds::default(), + move_tab_to_new_window_callback: None, + merge_all_windows_callback: None, + select_next_tab_callback: None, + select_previous_tab_callback: None, + toggle_tab_bar_callback: None, }))); (*native_window).set_ivar( @@ -688,7 +750,7 @@ impl MacWindow { }); } - if titlebar.map_or(true, |titlebar| titlebar.appears_transparent) { + if titlebar.is_none_or(|titlebar| titlebar.appears_transparent) { native_window.setTitlebarAppearsTransparent_(YES); native_window.setTitleVisibility_(NSWindowTitleVisibility::NSWindowTitleHidden); } @@ -714,6 +776,13 @@ impl MacWindow { WindowKind::Normal => { native_window.setLevel_(NSNormalWindowLevel); native_window.setAcceptsMouseMovedEvents_(YES); + + if let Some(tabbing_identifier) = tabbing_identifier { + let tabbing_id = NSString::alloc(nil).init_str(tabbing_identifier.as_str()); + let _: () = msg_send![native_window, setTabbingIdentifier: tabbing_id]; + } else { + let _: () = msg_send![native_window, setTabbingIdentifier:nil]; + } } WindowKind::PopUp => { // Use a tracking area to allow receiving MouseMoved events even when @@ -742,6 +811,38 @@ impl MacWindow { } } + let app = NSApplication::sharedApplication(nil); + let main_window: id = msg_send![app, mainWindow]; + if allows_automatic_window_tabbing + && !main_window.is_null() + && main_window != native_window + { + let main_window_is_fullscreen = main_window + .styleMask() + .contains(NSWindowStyleMask::NSFullScreenWindowMask); + let user_tabbing_preference = Self::get_user_tabbing_preference() + .unwrap_or(UserTabbingPreference::InFullScreen); + let should_add_as_tab = user_tabbing_preference == UserTabbingPreference::Always + || user_tabbing_preference == UserTabbingPreference::InFullScreen + && main_window_is_fullscreen; + + if should_add_as_tab { + let main_window_can_tab: BOOL = + msg_send![main_window, respondsToSelector: sel!(addTabbedWindow:ordered:)]; + let main_window_visible: BOOL = msg_send![main_window, isVisible]; + + 
if main_window_can_tab == YES && main_window_visible == YES { + let _: () = msg_send![main_window, addTabbedWindow: native_window ordered: NSWindowOrderingMode::NSWindowAbove]; + + // Ensure the window is visible immediately after adding the tab, since the tab bar is updated with a new entry at this point. + // Note: Calling orderFront here can break fullscreen mode (makes fullscreen windows exit fullscreen), so only do this if the main window is not fullscreen. + if !main_window_is_fullscreen { + let _: () = msg_send![native_window, orderFront: nil]; + } + } + } + } + if focus && show { native_window.makeKeyAndOrderFront_(nil); } else if show { @@ -796,6 +897,33 @@ impl MacWindow { window_handles } } + + pub fn get_user_tabbing_preference() -> Option { + unsafe { + let defaults: id = NSUserDefaults::standardUserDefaults(); + let domain = NSString::alloc(nil).init_str("NSGlobalDomain"); + let key = NSString::alloc(nil).init_str("AppleWindowTabbingMode"); + + let dict: id = msg_send![defaults, persistentDomainForName: domain]; + let value: id = if !dict.is_null() { + msg_send![dict, objectForKey: key] + } else { + nil + }; + + let value_str = if !value.is_null() { + CStr::from_ptr(NSString::UTF8String(value)).to_string_lossy() + } else { + "".into() + }; + + match value_str.as_ref() { + "manual" => Some(UserTabbingPreference::Never), + "always" => Some(UserTabbingPreference::Always), + _ => Some(UserTabbingPreference::InFullScreen), + } + } + } } impl Drop for MacWindow { @@ -851,6 +979,65 @@ impl PlatformWindow for MacWindow { .detach(); } + fn merge_all_windows(&self) { + let native_window = self.0.lock().native_window; + unsafe extern "C" fn merge_windows_async(context: *mut std::ffi::c_void) { + let native_window = context as id; + let _: () = msg_send![native_window, mergeAllWindows:nil]; + } + + unsafe { + dispatch_async_f( + dispatch_get_main_queue(), + native_window as *mut std::ffi::c_void, + Some(merge_windows_async), + ); + } + } + + fn 
move_tab_to_new_window(&self) { + let native_window = self.0.lock().native_window; + unsafe extern "C" fn move_tab_async(context: *mut std::ffi::c_void) { + let native_window = context as id; + let _: () = msg_send![native_window, moveTabToNewWindow:nil]; + let _: () = msg_send![native_window, makeKeyAndOrderFront: nil]; + } + + unsafe { + dispatch_async_f( + dispatch_get_main_queue(), + native_window as *mut std::ffi::c_void, + Some(move_tab_async), + ); + } + } + + fn toggle_window_tab_overview(&self) { + let native_window = self.0.lock().native_window; + unsafe { + let _: () = msg_send![native_window, toggleTabOverview:nil]; + } + } + + fn set_tabbing_identifier(&self, tabbing_identifier: Option) { + let native_window = self.0.lock().native_window; + unsafe { + let allows_automatic_window_tabbing = tabbing_identifier.is_some(); + if allows_automatic_window_tabbing { + let () = msg_send![class!(NSWindow), setAllowsAutomaticWindowTabbing: YES]; + } else { + let () = msg_send![class!(NSWindow), setAllowsAutomaticWindowTabbing: NO]; + } + + if let Some(tabbing_identifier) = tabbing_identifier { + let tabbing_id = NSString::alloc(nil).init_str(tabbing_identifier.as_str()); + let _: () = msg_send![native_window, setTabbingIdentifier: tabbing_id]; + } else { + let _: () = msg_send![native_window, setTabbingIdentifier:nil]; + } + } + } + fn scale_factor(&self) -> f32 { self.0.as_ref().lock().scale_factor() } @@ -1051,6 +1238,17 @@ impl PlatformWindow for MacWindow { } } + fn get_title(&self) -> String { + unsafe { + let title: id = msg_send![self.0.lock().native_window, title]; + if title.is_null() { + "".to_string() + } else { + title.to_str().to_string() + } + } + } + fn set_app_id(&mut self, _app_id: &str) {} fn set_background_appearance(&self, background_appearance: WindowBackgroundAppearance) { @@ -1090,7 +1288,7 @@ impl PlatformWindow for MacWindow { NSView::removeFromSuperview(blur_view); this.blurred_view = None; } - } else if this.blurred_view == None { + } 
else if this.blurred_view.is_none() { let content_view = this.native_window.contentView(); let frame = NSView::bounds(content_view); let mut blur_view: id = msg_send![BLURRED_VIEW_CLASS, alloc]; @@ -1212,6 +1410,62 @@ impl PlatformWindow for MacWindow { self.0.lock().appearance_changed_callback = Some(callback); } + fn tabbed_windows(&self) -> Option> { + unsafe { + let windows: id = msg_send![self.0.lock().native_window, tabbedWindows]; + if windows.is_null() { + return None; + } + + let count: NSUInteger = msg_send![windows, count]; + let mut result = Vec::new(); + for i in 0..count { + let window: id = msg_send![windows, objectAtIndex:i]; + if msg_send![window, isKindOfClass: WINDOW_CLASS] { + let handle = get_window_state(&*window).lock().handle; + let title: id = msg_send![window, title]; + let title = SharedString::from(title.to_str().to_string()); + + result.push(SystemWindowTab::new(title, handle)); + } + } + + Some(result) + } + } + + fn tab_bar_visible(&self) -> bool { + unsafe { + let tab_group: id = msg_send![self.0.lock().native_window, tabGroup]; + if tab_group.is_null() { + false + } else { + let tab_bar_visible: BOOL = msg_send![tab_group, isTabBarVisible]; + tab_bar_visible == YES + } + } + } + + fn on_move_tab_to_new_window(&self, callback: Box) { + self.0.as_ref().lock().move_tab_to_new_window_callback = Some(callback); + } + + fn on_merge_all_windows(&self, callback: Box) { + self.0.as_ref().lock().merge_all_windows_callback = Some(callback); + } + + fn on_select_next_tab(&self, callback: Box) { + self.0.as_ref().lock().select_next_tab_callback = Some(callback); + } + + fn on_select_previous_tab(&self, callback: Box) { + self.0.as_ref().lock().select_previous_tab_callback = Some(callback); + } + + fn on_toggle_tab_bar(&self, callback: Box) { + self.0.as_ref().lock().toggle_tab_bar_callback = Some(callback); + } + fn draw(&self, scene: &crate::Scene) { let mut this = self.0.lock(); this.renderer.draw(scene); @@ -1225,7 +1479,7 @@ impl 
PlatformWindow for MacWindow { None } - fn update_ime_position(&self, _bounds: Bounds) { + fn update_ime_position(&self, _bounds: Bounds) { let executor = self.0.lock().executor.clone(); executor .spawn(async move { @@ -1478,18 +1732,18 @@ extern "C" fn handle_key_event(this: &Object, native_event: id, key_equivalent: return YES; } - if key_down_event.is_held { - if let Some(key_char) = key_down_event.keystroke.key_char.as_ref() { - let handled = with_input_handler(&this, |input_handler| { - if !input_handler.apple_press_and_hold_enabled() { - input_handler.replace_text_in_range(None, &key_char); - return YES; - } - NO - }); - if handled == Some(YES) { + if key_down_event.is_held + && let Some(key_char) = key_down_event.keystroke.key_char.as_ref() + { + let handled = with_input_handler(this, |input_handler| { + if !input_handler.apple_press_and_hold_enabled() { + input_handler.replace_text_in_range(None, key_char); return YES; } + NO + }); + if handled == Some(YES) { + return YES; } } @@ -1624,10 +1878,10 @@ extern "C" fn handle_view_event(this: &Object, _: Sel, native_event: id) { modifiers: prev_modifiers, capslock: prev_capslock, })) = &lock.previous_modifiers_changed_event + && prev_modifiers == modifiers + && prev_capslock == capslock { - if prev_modifiers == modifiers && prev_capslock == capslock { - return; - } + return; } lock.previous_modifiers_changed_event = Some(event.clone()); @@ -1653,6 +1907,7 @@ extern "C" fn window_did_change_occlusion_state(this: &Object, _: Sel, _: id) { .occlusionState() .contains(NSWindowOcclusionState::NSWindowOcclusionStateVisible) { + lock.move_traffic_light(); lock.start_display_link(); } else { lock.stop_display_link(); @@ -1714,7 +1969,7 @@ extern "C" fn window_did_change_screen(this: &Object, _: Sel, _: id) { extern "C" fn window_did_change_key_status(this: &Object, selector: Sel, _: id) { let window_state = unsafe { get_window_state(this) }; - let lock = window_state.lock(); + let mut lock = window_state.lock(); let 
is_active = unsafe { lock.native_window.isKeyWindow() == YES }; // When opening a pop-up while the application isn't active, Cocoa sends a spurious @@ -1735,9 +1990,34 @@ extern "C" fn window_did_change_key_status(this: &Object, selector: Sel, _: id) let executor = lock.executor.clone(); drop(lock); + + // If window is becoming active, trigger immediate synchronous frame request. + if selector == sel!(windowDidBecomeKey:) && is_active { + let window_state = unsafe { get_window_state(this) }; + let mut lock = window_state.lock(); + + if let Some(mut callback) = lock.request_frame_callback.take() { + #[cfg(not(feature = "macos-blade"))] + lock.renderer.set_presents_with_transaction(true); + lock.stop_display_link(); + drop(lock); + callback(Default::default()); + + let mut lock = window_state.lock(); + lock.request_frame_callback = Some(callback); + #[cfg(not(feature = "macos-blade"))] + lock.renderer.set_presents_with_transaction(false); + lock.start_display_link(); + } + } + executor .spawn(async move { let mut lock = window_state.as_ref().lock(); + if is_active { + lock.move_traffic_light(); + } + if let Some(mut callback) = lock.activate_callback.take() { drop(lock); callback(is_active); @@ -1949,7 +2229,7 @@ extern "C" fn insert_text(this: &Object, _: Sel, text: id, replacement_range: NS let text = text.to_str(); let replacement_range = replacement_range.to_range(); with_input_handler(this, |input_handler| { - input_handler.replace_text_in_range(replacement_range, &text) + input_handler.replace_text_in_range(replacement_range, text) }); } } @@ -1973,7 +2253,7 @@ extern "C" fn set_marked_text( let replacement_range = replacement_range.to_range(); let text = text.to_str(); with_input_handler(this, |input_handler| { - input_handler.replace_and_mark_text_in_range(replacement_range, &text, selected_range) + input_handler.replace_and_mark_text_in_range(replacement_range, text, selected_range) }); } } @@ -1995,10 +2275,10 @@ extern "C" fn 
attributed_substring_for_proposed_range( let mut adjusted: Option> = None; let selected_text = input_handler.text_for_range(range.clone(), &mut adjusted)?; - if let Some(adjusted) = adjusted { - if adjusted != range { - unsafe { (actual_range as *mut NSRange).write(NSRange::from(adjusted)) }; - } + if let Some(adjusted) = adjusted + && adjusted != range + { + unsafe { (actual_range as *mut NSRange).write(NSRange::from(adjusted)) }; } unsafe { let string: id = msg_send![class!(NSAttributedString), alloc]; @@ -2063,8 +2343,8 @@ fn screen_point_to_gpui_point(this: &Object, position: NSPoint) -> Point let frame = get_frame(this); let window_x = position.x - frame.origin.x; let window_y = frame.size.height - (position.y - frame.origin.y); - let position = point(px(window_x as f32), px(window_y as f32)); - position + + point(px(window_x as f32), px(window_y as f32)) } extern "C" fn dragging_entered(this: &Object, _: Sel, dragging_info: id) -> NSDragOperation { @@ -2073,11 +2353,10 @@ extern "C" fn dragging_entered(this: &Object, _: Sel, dragging_info: id) -> NSDr let paths = external_paths_from_event(dragging_info); if let Some(event) = paths.map(|paths| PlatformInput::FileDrop(FileDropEvent::Entered { position, paths })) + && send_new_event(&window_state, event) { - if send_new_event(&window_state, event) { - window_state.lock().external_files_dragged = true; - return NSDragOperationCopy; - } + window_state.lock().external_files_dragged = true; + return NSDragOperationCopy; } NSDragOperationNone } @@ -2274,3 +2553,80 @@ unsafe fn remove_layer_background(layer: id) { } } } + +extern "C" fn add_titlebar_accessory_view_controller(this: &Object, _: Sel, view_controller: id) { + unsafe { + let _: () = msg_send![super(this, class!(NSWindow)), addTitlebarAccessoryViewController: view_controller]; + + // Hide the native tab bar and set its height to 0, since we render our own. 
+ let accessory_view: id = msg_send![view_controller, view]; + let _: () = msg_send![accessory_view, setHidden: YES]; + let mut frame: NSRect = msg_send![accessory_view, frame]; + frame.size.height = 0.0; + let _: () = msg_send![accessory_view, setFrame: frame]; + } +} + +extern "C" fn move_tab_to_new_window(this: &Object, _: Sel, _: id) { + unsafe { + let _: () = msg_send![super(this, class!(NSWindow)), moveTabToNewWindow:nil]; + + let window_state = get_window_state(this); + let mut lock = window_state.as_ref().lock(); + if let Some(mut callback) = lock.move_tab_to_new_window_callback.take() { + drop(lock); + callback(); + window_state.lock().move_tab_to_new_window_callback = Some(callback); + } + } +} + +extern "C" fn merge_all_windows(this: &Object, _: Sel, _: id) { + unsafe { + let _: () = msg_send![super(this, class!(NSWindow)), mergeAllWindows:nil]; + + let window_state = get_window_state(this); + let mut lock = window_state.as_ref().lock(); + if let Some(mut callback) = lock.merge_all_windows_callback.take() { + drop(lock); + callback(); + window_state.lock().merge_all_windows_callback = Some(callback); + } + } +} + +extern "C" fn select_next_tab(this: &Object, _sel: Sel, _id: id) { + let window_state = unsafe { get_window_state(this) }; + let mut lock = window_state.as_ref().lock(); + if let Some(mut callback) = lock.select_next_tab_callback.take() { + drop(lock); + callback(); + window_state.lock().select_next_tab_callback = Some(callback); + } +} + +extern "C" fn select_previous_tab(this: &Object, _sel: Sel, _id: id) { + let window_state = unsafe { get_window_state(this) }; + let mut lock = window_state.as_ref().lock(); + if let Some(mut callback) = lock.select_previous_tab_callback.take() { + drop(lock); + callback(); + window_state.lock().select_previous_tab_callback = Some(callback); + } +} + +extern "C" fn toggle_tab_bar(this: &Object, _sel: Sel, _id: id) { + unsafe { + let _: () = msg_send![super(this, class!(NSWindow)), toggleTabBar:nil]; + + let 
window_state = get_window_state(this); + let mut lock = window_state.as_ref().lock(); + lock.move_traffic_light(); + + if let Some(mut callback) = lock.toggle_tab_bar_callback.take() { + drop(lock); + callback(); + window_state.lock().toggle_tab_bar_callback = Some(callback); + } + } +} diff --git a/crates/gpui/src/platform/scap_screen_capture.rs b/crates/gpui/src/platform/scap_screen_capture.rs index 32041b655fdc20b046717291c623dcb5c4d5146c..d6d19cd8102d58ceaa9bc87bff348eaeda9adfef 100644 --- a/crates/gpui/src/platform/scap_screen_capture.rs +++ b/crates/gpui/src/platform/scap_screen_capture.rs @@ -228,7 +228,7 @@ fn run_capture( display, size, })); - if let Err(_) = stream_send_result { + if stream_send_result.is_err() { return; } while !cancel_stream.load(std::sync::atomic::Ordering::SeqCst) { diff --git a/crates/gpui/src/platform/test/dispatcher.rs b/crates/gpui/src/platform/test/dispatcher.rs index 16edabfa4bfee9c66dcf6ed8abc5eeb7957a7fa0..e19710effda9299c6eb72e8c4acc2f615ac077ee 100644 --- a/crates/gpui/src/platform/test/dispatcher.rs +++ b/crates/gpui/src/platform/test/dispatcher.rs @@ -78,11 +78,11 @@ impl TestDispatcher { let state = self.state.lock(); let next_due_time = state.delayed.first().map(|(time, _)| *time); drop(state); - if let Some(due_time) = next_due_time { - if due_time <= new_now { - self.state.lock().time = due_time; - continue; - } + if let Some(due_time) = next_due_time + && due_time <= new_now + { + self.state.lock().time = due_time; + continue; } break; } @@ -118,7 +118,7 @@ impl TestDispatcher { } YieldNow { - count: self.state.lock().random.gen_range(0..10), + count: self.state.lock().random.random_range(0..10), } } @@ -151,11 +151,11 @@ impl TestDispatcher { if deprioritized_background_len == 0 { return false; } - let ix = state.random.gen_range(0..deprioritized_background_len); + let ix = state.random.random_range(0..deprioritized_background_len); main_thread = false; runnable = state.deprioritized_background.swap_remove(ix); } 
else { - main_thread = state.random.gen_ratio( + main_thread = state.random.random_ratio( foreground_len as u32, (foreground_len + background_len) as u32, ); @@ -170,7 +170,7 @@ impl TestDispatcher { .pop_front() .unwrap(); } else { - let ix = state.random.gen_range(0..background_len); + let ix = state.random.random_range(0..background_len); runnable = state.background.swap_remove(ix); }; }; @@ -241,7 +241,7 @@ impl TestDispatcher { pub fn gen_block_on_ticks(&self) -> usize { let mut lock = self.state.lock(); let block_on_ticks = lock.block_on_ticks.clone(); - lock.random.gen_range(block_on_ticks) + lock.random.random_range(block_on_ticks) } } @@ -270,9 +270,7 @@ impl PlatformDispatcher for TestDispatcher { fn dispatch(&self, runnable: Runnable, label: Option) { { let mut state = self.state.lock(); - if label.map_or(false, |label| { - state.deprioritized_task_labels.contains(&label) - }) { + if label.is_some_and(|label| state.deprioritized_task_labels.contains(&label)) { state.deprioritized_background.push(runnable); } else { state.background.push(runnable); diff --git a/crates/gpui/src/platform/test/platform.rs b/crates/gpui/src/platform/test/platform.rs index a26b65576cc49e290494762eed597d5bd8d0af26..15b909199fbd53b974e6a140f3223641dc0ac6ae 100644 --- a/crates/gpui/src/platform/test/platform.rs +++ b/crates/gpui/src/platform/test/platform.rs @@ -1,8 +1,9 @@ use crate::{ AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DevicePixels, - ForegroundExecutor, Keymap, NoopTextSystem, Platform, PlatformDisplay, PlatformKeyboardLayout, - PlatformTextSystem, PromptButton, ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream, - SourceMetadata, Task, TestDisplay, TestWindow, WindowAppearance, WindowParams, size, + DummyKeyboardMapper, ForegroundExecutor, Keymap, NoopTextSystem, Platform, PlatformDisplay, + PlatformKeyboardLayout, PlatformKeyboardMapper, PlatformTextSystem, PromptButton, + ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream, 
SourceMetadata, Task, + TestDisplay, TestWindow, WindowAppearance, WindowParams, size, }; use anyhow::Result; use collections::VecDeque; @@ -187,24 +188,24 @@ impl TestPlatform { .push_back(TestPrompt { msg: msg.to_string(), detail: detail.map(|s| s.to_string()), - answers: answers.clone(), + answers, tx, }); rx } pub(crate) fn set_active_window(&self, window: Option) { - let executor = self.foreground_executor().clone(); + let executor = self.foreground_executor(); let previous_window = self.active_window.borrow_mut().take(); self.active_window.borrow_mut().clone_from(&window); executor .spawn(async move { if let Some(previous_window) = previous_window { - if let Some(window) = window.as_ref() { - if Rc::ptr_eq(&previous_window.0, &window.0) { - return; - } + if let Some(window) = window.as_ref() + && Rc::ptr_eq(&previous_window.0, &window.0) + { + return; } previous_window.simulate_active_status_change(false); } @@ -237,6 +238,10 @@ impl Platform for TestPlatform { Box::new(TestKeyboardLayout) } + fn keyboard_mapper(&self) -> Rc { + Rc::new(DummyKeyboardMapper) + } + fn on_keyboard_layout_change(&self, _: Box) {} fn run(&self, _on_finish_launching: Box) { @@ -336,6 +341,7 @@ impl Platform for TestPlatform { fn prompt_for_new_path( &self, directory: &std::path::Path, + _suggested_name: Option<&str>, ) -> oneshot::Receiver>> { let (tx, rx) = oneshot::channel(); self.background_executor() diff --git a/crates/gpui/src/platform/test/window.rs b/crates/gpui/src/platform/test/window.rs index e15bd7aeecec5932eb6386bd47d168eda906dd63..9e87f4504ddd61e34b645ea69ea394c4940f9d55 100644 --- a/crates/gpui/src/platform/test/window.rs +++ b/crates/gpui/src/platform/test/window.rs @@ -1,8 +1,8 @@ use crate::{ AnyWindowHandle, AtlasKey, AtlasTextureId, AtlasTile, Bounds, DispatchEventResult, GpuSpecs, Pixels, PlatformAtlas, PlatformDisplay, PlatformInput, PlatformInputHandler, PlatformWindow, - Point, PromptButton, RequestFrameOptions, ScaledPixels, Size, TestPlatform, TileId, - 
WindowAppearance, WindowBackgroundAppearance, WindowBounds, WindowControlArea, WindowParams, + Point, PromptButton, RequestFrameOptions, Size, TestPlatform, TileId, WindowAppearance, + WindowBackgroundAppearance, WindowBounds, WindowControlArea, WindowParams, }; use collections::HashMap; use parking_lot::Mutex; @@ -289,7 +289,7 @@ impl PlatformWindow for TestWindow { unimplemented!() } - fn update_ime_position(&self, _bounds: Bounds) {} + fn update_ime_position(&self, _bounds: Bounds) {} fn gpu_specs(&self) -> Option { None diff --git a/crates/gpui/src/platform/windows.rs b/crates/gpui/src/platform/windows.rs index 77e0ca41bf8b394dc8bdd75e521aab3ba63dce2c..9cd1a7d05f4bcc6aa097db5dad64bdbc502575fc 100644 --- a/crates/gpui/src/platform/windows.rs +++ b/crates/gpui/src/platform/windows.rs @@ -2,6 +2,7 @@ mod clipboard; mod destination_list; mod direct_write; mod directx_atlas; +mod directx_devices; mod directx_renderer; mod dispatcher; mod display; @@ -18,6 +19,7 @@ pub(crate) use clipboard::*; pub(crate) use destination_list::*; pub(crate) use direct_write::*; pub(crate) use directx_atlas::*; +pub(crate) use directx_devices::*; pub(crate) use directx_renderer::*; pub(crate) use dispatcher::*; pub(crate) use display::*; diff --git a/crates/gpui/src/platform/windows/alpha_correction.hlsl b/crates/gpui/src/platform/windows/alpha_correction.hlsl new file mode 100644 index 0000000000000000000000000000000000000000..dc8d0b5dc52e9ef1484bfdf776161b5d5d8ce1b9 --- /dev/null +++ b/crates/gpui/src/platform/windows/alpha_correction.hlsl @@ -0,0 +1,28 @@ +float color_brightness(float3 color) { + // REC. 
601 luminance coefficients for perceived brightness + return dot(color, float3(0.30f, 0.59f, 0.11f)); +} + +float light_on_dark_contrast(float enhancedContrast, float3 color) { + float brightness = color_brightness(color); + float multiplier = saturate(4.0f * (0.75f - brightness)); + return enhancedContrast * multiplier; +} + +float enhance_contrast(float alpha, float k) { + return alpha * (k + 1.0f) / (alpha * k + 1.0f); +} + +float apply_alpha_correction(float a, float b, float4 g) { + float brightness_adjustment = g.x * b + g.y; + float correction = brightness_adjustment * a + (g.z * b + g.w); + return a + a * (1.0f - a) * correction; +} + +float apply_contrast_and_gamma_correction(float sample, float3 color, float enhanced_contrast_factor, float4 gamma_ratios) { + float enhanced_contrast = light_on_dark_contrast(enhanced_contrast_factor, color); + float brightness = color_brightness(color); + + float contrasted = enhance_contrast(sample, enhanced_contrast); + return apply_alpha_correction(contrasted, brightness, gamma_ratios); +} diff --git a/crates/gpui/src/platform/windows/color_text_raster.hlsl b/crates/gpui/src/platform/windows/color_text_raster.hlsl index ccc5fa26f00d57f2b69e85965a66b6ecea98a833..2fbc156ba5ea9e443366558d10d0b8791c2eb488 100644 --- a/crates/gpui/src/platform/windows/color_text_raster.hlsl +++ b/crates/gpui/src/platform/windows/color_text_raster.hlsl @@ -1,3 +1,5 @@ +#include "alpha_correction.hlsl" + struct RasterVertexOutput { float4 position : SV_Position; float2 texcoord : TEXCOORD0; @@ -23,17 +25,20 @@ struct Bounds { int2 size; }; -Texture2D t_layer : register(t0); +Texture2D t_layer : register(t0); SamplerState s_layer : register(s0); cbuffer GlyphLayerTextureParams : register(b0) { Bounds bounds; float4 run_color; + float4 gamma_ratios; + float grayscale_enhanced_contrast; + float3 _pad; }; float4 emoji_rasterization_fragment(PixelInput input): SV_Target { - float3 sampled = t_layer.Sample(s_layer, input.texcoord.xy).rgb; - float 
alpha = (sampled.r + sampled.g + sampled.b) / 3; - - return float4(run_color.rgb, alpha); + float sample = t_layer.Sample(s_layer, input.texcoord.xy).r; + float alpha_corrected = apply_contrast_and_gamma_correction(sample, run_color.rgb, grayscale_enhanced_contrast, gamma_ratios); + float alpha = alpha_corrected * run_color.a; + return float4(run_color.rgb * alpha, alpha); } diff --git a/crates/gpui/src/platform/windows/direct_write.rs b/crates/gpui/src/platform/windows/direct_write.rs index 75cb50243b9c8ec845e256f4095cdedc40d2eea2..df3161bf079a8eb0cb04908e586f5d344519821e 100644 --- a/crates/gpui/src/platform/windows/direct_write.rs +++ b/crates/gpui/src/platform/windows/direct_write.rs @@ -1,7 +1,7 @@ use std::{borrow::Cow, sync::Arc}; use ::util::ResultExt; -use anyhow::Result; +use anyhow::{Context, Result}; use collections::HashMap; use itertools::Itertools; use parking_lot::{RwLock, RwLockUpgradableReadGuard}; @@ -10,12 +10,8 @@ use windows::{ Foundation::*, Globalization::GetUserDefaultLocaleName, Graphics::{ - Direct3D::D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, - Direct3D11::*, - DirectWrite::*, - Dxgi::Common::*, - Gdi::{IsRectEmpty, LOGFONTW}, - Imaging::*, + Direct3D::D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, Direct3D11::*, DirectWrite::*, + Dxgi::Common::*, Gdi::LOGFONTW, }, System::SystemServices::LOCALE_NAME_MAX_LENGTH, UI::WindowsAndMessaging::*, @@ -40,12 +36,10 @@ pub(crate) struct DirectWriteTextSystem(RwLock); struct DirectWriteComponent { locale: String, factory: IDWriteFactory5, - bitmap_factory: AgileReference, in_memory_loader: IDWriteInMemoryFontFileLoader, builder: IDWriteFontSetBuilder1, text_renderer: Arc, - render_params: IDWriteRenderingParams3, gpu_state: GPUState, } @@ -76,11 +70,10 @@ struct FontIdentifier { } impl DirectWriteComponent { - pub fn new(bitmap_factory: &IWICImagingFactory, gpu_context: &DirectXDevices) -> Result { + pub fn new(directx_devices: &DirectXDevices) -> Result { // todo: ideally this would not be a large unsafe block 
but smaller isolated ones for easier auditing unsafe { let factory: IDWriteFactory5 = DWriteCreateFactory(DWRITE_FACTORY_TYPE_SHARED)?; - let bitmap_factory = AgileReference::new(bitmap_factory)?; // The `IDWriteInMemoryFontFileLoader` here is supported starting from // Windows 10 Creators Update, which consequently requires the entire // `DirectWriteTextSystem` to run on `win10 1703`+. @@ -92,36 +85,14 @@ impl DirectWriteComponent { let locale = String::from_utf16_lossy(&locale_vec); let text_renderer = Arc::new(TextRendererWrapper::new(&locale)); - let render_params = { - let default_params: IDWriteRenderingParams3 = - factory.CreateRenderingParams()?.cast()?; - let gamma = default_params.GetGamma(); - let enhanced_contrast = default_params.GetEnhancedContrast(); - let gray_contrast = default_params.GetGrayscaleEnhancedContrast(); - let cleartype_level = default_params.GetClearTypeLevel(); - let grid_fit_mode = default_params.GetGridFitMode(); - - factory.CreateCustomRenderingParams( - gamma, - enhanced_contrast, - gray_contrast, - cleartype_level, - DWRITE_PIXEL_GEOMETRY_RGB, - DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, - grid_fit_mode, - )? 
- }; - - let gpu_state = GPUState::new(gpu_context)?; + let gpu_state = GPUState::new(directx_devices)?; Ok(DirectWriteComponent { locale, factory, - bitmap_factory, in_memory_loader, builder, text_renderer, - render_params, gpu_state, }) } @@ -129,9 +100,9 @@ impl DirectWriteComponent { } impl GPUState { - fn new(gpu_context: &DirectXDevices) -> Result { - let device = gpu_context.device.clone(); - let device_context = gpu_context.device_context.clone(); + fn new(directx_devices: &DirectXDevices) -> Result { + let device = directx_devices.device.clone(); + let device_context = directx_devices.device_context.clone(); let blend_state = { let mut blend_state = None; @@ -141,10 +112,10 @@ impl GPUState { RenderTarget: [ D3D11_RENDER_TARGET_BLEND_DESC { BlendEnable: true.into(), - SrcBlend: D3D11_BLEND_SRC_ALPHA, + SrcBlend: D3D11_BLEND_ONE, DestBlend: D3D11_BLEND_INV_SRC_ALPHA, BlendOp: D3D11_BLEND_OP_ADD, - SrcBlendAlpha: D3D11_BLEND_SRC_ALPHA, + SrcBlendAlpha: D3D11_BLEND_ONE, DestBlendAlpha: D3D11_BLEND_INV_SRC_ALPHA, BlendOpAlpha: D3D11_BLEND_OP_ADD, RenderTargetWriteMask: D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8, @@ -212,11 +183,8 @@ impl GPUState { } impl DirectWriteTextSystem { - pub(crate) fn new( - gpu_context: &DirectXDevices, - bitmap_factory: &IWICImagingFactory, - ) -> Result { - let components = DirectWriteComponent::new(bitmap_factory, gpu_context)?; + pub(crate) fn new(directx_devices: &DirectXDevices) -> Result { + let components = DirectWriteComponent::new(directx_devices)?; let system_font_collection = unsafe { let mut result = std::mem::zeroed(); components @@ -242,6 +210,10 @@ impl DirectWriteTextSystem { font_id_by_identifier: HashMap::default(), }))) } + + pub(crate) fn handle_gpu_lost(&self, directx_devices: &DirectXDevices) { + self.0.write().handle_gpu_lost(directx_devices); + } } impl PlatformTextSystem for DirectWriteTextSystem { @@ -762,18 +734,22 @@ impl DirectWriteState { unsafe { font.font_face.GetRecommendedRenderingMode( 
params.font_size.0, - // The dpi here seems that it has the same effect with `Some(&transform)` - 1.0, - 1.0, + // Using 96 as scale is applied by the transform + 96.0, + 96.0, Some(&transform), false, DWRITE_OUTLINE_THRESHOLD_ANTIALIASED, DWRITE_MEASURING_MODE_NATURAL, - &self.components.render_params, + None, &mut rendering_mode, &mut grid_fit_mode, )?; } + let rendering_mode = match rendering_mode { + DWRITE_RENDERING_MODE1_OUTLINE => DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, + m => m, + }; let glyph_analysis = unsafe { self.components.factory.CreateGlyphRunAnalysis( @@ -782,8 +758,7 @@ impl DirectWriteState { rendering_mode, DWRITE_MEASURING_MODE_NATURAL, grid_fit_mode, - // We're using cleartype not grayscale for monochrome is because it provides better quality - DWRITE_TEXT_ANTIALIAS_MODE_CLEARTYPE, + DWRITE_TEXT_ANTIALIAS_MODE_GRAYSCALE, baseline_origin_x, baseline_origin_y, ) @@ -794,10 +769,14 @@ impl DirectWriteState { fn raster_bounds(&self, params: &RenderGlyphParams) -> Result> { let glyph_analysis = self.create_glyph_run_analysis(params)?; - let bounds = unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_CLEARTYPE_3x1)? }; - // Some glyphs cannot be drawn with ClearType, such as bitmap fonts. In that case - // GetAlphaTextureBounds() supposedly returns an empty RECT, but I haven't tested that yet. - if !unsafe { IsRectEmpty(&bounds) }.as_bool() { + let bounds = unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_ALIASED_1x1)? }; + + if bounds.right < bounds.left { + Ok(Bounds { + origin: point(0.into(), 0.into()), + size: size(0.into(), 0.into()), + }) + } else { Ok(Bounds { origin: point(bounds.left.into(), bounds.top.into()), size: size( @@ -805,25 +784,6 @@ impl DirectWriteState { (bounds.bottom - bounds.top).into(), ), }) - } else { - // If it's empty, retry with grayscale AA. - let bounds = - unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_ALIASED_1x1)? 
}; - - if bounds.right < bounds.left { - Ok(Bounds { - origin: point(0.into(), 0.into()), - size: size(0.into(), 0.into()), - }) - } else { - Ok(Bounds { - origin: point(bounds.left.into(), bounds.top.into()), - size: size( - (bounds.right - bounds.left).into(), - (bounds.bottom - bounds.top).into(), - ), - }) - } } } @@ -850,7 +810,7 @@ impl DirectWriteState { } let bitmap_data = if params.is_emoji { - if let Ok(color) = self.rasterize_color(¶ms, glyph_bounds) { + if let Ok(color) = self.rasterize_color(params, glyph_bounds) { color } else { let monochrome = self.rasterize_monochrome(params, glyph_bounds)?; @@ -872,13 +832,12 @@ impl DirectWriteState { glyph_bounds: Bounds, ) -> Result> { let mut bitmap_data = - vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize * 3]; + vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize]; let glyph_analysis = self.create_glyph_run_analysis(params)?; unsafe { glyph_analysis.CreateAlphaTexture( - // We're using cleartype not grayscale for monochrome is because it provides better quality - DWRITE_TEXTURE_CLEARTYPE_3x1, + DWRITE_TEXTURE_ALIASED_1x1, &RECT { left: glyph_bounds.origin.x.0, top: glyph_bounds.origin.y.0, @@ -889,30 +848,6 @@ impl DirectWriteState { )?; } - let bitmap_factory = self.components.bitmap_factory.resolve()?; - let bitmap = unsafe { - bitmap_factory.CreateBitmapFromMemory( - glyph_bounds.size.width.0 as u32, - glyph_bounds.size.height.0 as u32, - &GUID_WICPixelFormat24bppRGB, - glyph_bounds.size.width.0 as u32 * 3, - &bitmap_data, - ) - }?; - - let grayscale_bitmap = - unsafe { WICConvertBitmapSource(&GUID_WICPixelFormat8bppGray, &bitmap) }?; - - let mut bitmap_data = - vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize]; - unsafe { - grayscale_bitmap.CopyPixels( - std::ptr::null() as _, - glyph_bounds.size.width.0 as u32, - &mut bitmap_data, - ) - }?; - Ok(bitmap_data) } @@ -981,25 +916,24 @@ impl DirectWriteState { 
DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, DWRITE_MEASURING_MODE_NATURAL, DWRITE_GRID_FIT_MODE_DEFAULT, - DWRITE_TEXT_ANTIALIAS_MODE_CLEARTYPE, + DWRITE_TEXT_ANTIALIAS_MODE_GRAYSCALE, baseline_origin_x, baseline_origin_y, ) }?; let color_bounds = - unsafe { color_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_CLEARTYPE_3x1) }?; + unsafe { color_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_ALIASED_1x1) }?; let color_size = size( color_bounds.right - color_bounds.left, color_bounds.bottom - color_bounds.top, ); if color_size.width > 0 && color_size.height > 0 { - let mut alpha_data = - vec![0u8; (color_size.width * color_size.height * 3) as usize]; + let mut alpha_data = vec![0u8; (color_size.width * color_size.height) as usize]; unsafe { color_analysis.CreateAlphaTexture( - DWRITE_TEXTURE_CLEARTYPE_3x1, + DWRITE_TEXTURE_ALIASED_1x1, &color_bounds, &mut alpha_data, ) @@ -1015,10 +949,6 @@ impl DirectWriteState { } }; let bounds = bounds(point(color_bounds.left, color_bounds.top), color_size); - let alpha_data = alpha_data - .chunks_exact(3) - .flat_map(|chunk| [chunk[0], chunk[1], chunk[2], 255]) - .collect::>(); glyph_layers.push(GlyphLayerTexture::new( &self.components.gpu_state, run_color, @@ -1135,10 +1065,18 @@ impl DirectWriteState { unsafe { device_context.PSSetSamplers(0, Some(&gpu_state.sampler)) }; unsafe { device_context.OMSetBlendState(&gpu_state.blend_state, None, 0xffffffff) }; + let crate::FontInfo { + gamma_ratios, + grayscale_enhanced_contrast, + } = DirectXRenderer::get_font_info(); + for layer in glyph_layers { let params = GlyphLayerTextureParams { run_color: layer.run_color, bounds: layer.bounds, + gamma_ratios: *gamma_ratios, + grayscale_enhanced_contrast: *grayscale_enhanced_contrast, + _pad: [0f32; 3], }; unsafe { let mut dest = std::mem::zeroed(); @@ -1202,6 +1140,20 @@ impl DirectWriteState { }; } + // Convert from premultiplied to straight alpha + for chunk in rasterized.chunks_exact_mut(4) { + let b = chunk[0] as f32; + let g = chunk[1] as 
f32; + let r = chunk[2] as f32; + let a = chunk[3] as f32; + if a > 0.0 { + let inv_a = 255.0 / a; + chunk[0] = (b * inv_a).clamp(0.0, 255.0) as u8; + chunk[1] = (g * inv_a).clamp(0.0, 255.0) as u8; + chunk[2] = (r * inv_a).clamp(0.0, 255.0) as u8; + } + } + Ok(rasterized) } @@ -1263,6 +1215,20 @@ impl DirectWriteState { )); result } + + fn handle_gpu_lost(&mut self, directx_devices: &DirectXDevices) { + try_to_recover_from_device_lost( + || GPUState::new(directx_devices).context("Recreating GPU state for DirectWrite"), + |gpu_state| self.components.gpu_state = gpu_state, + || { + log::error!( + "Failed to recreate GPU state for DirectWrite after multiple attempts." + ); + // Do something here? + // At this point, the device loss is considered unrecoverable. + }, + ); + } } impl Drop for DirectWriteState { @@ -1298,14 +1264,14 @@ impl GlyphLayerTexture { Height: texture_size.height as u32, MipLevels: 1, ArraySize: 1, - Format: DXGI_FORMAT_R8G8B8A8_UNORM, + Format: DXGI_FORMAT_R8_UNORM, SampleDesc: DXGI_SAMPLE_DESC { Count: 1, Quality: 0, }, Usage: D3D11_USAGE_DEFAULT, BindFlags: D3D11_BIND_SHADER_RESOURCE.0 as u32, - CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + CPUAccessFlags: 0, MiscFlags: 0, }; @@ -1334,7 +1300,7 @@ impl GlyphLayerTexture { 0, None, alpha_data.as_ptr() as _, - (texture_size.width * 4) as u32, + texture_size.width as u32, 0, ) }; @@ -1352,6 +1318,9 @@ impl GlyphLayerTexture { struct GlyphLayerTextureParams { bounds: Bounds, run_color: Rgba, + gamma_ratios: [f32; 4], + grayscale_enhanced_contrast: f32, + _pad: [f32; 3], } struct TextRendererWrapper(pub IDWriteTextRenderer); @@ -1784,7 +1753,7 @@ fn apply_font_features( } unsafe { - direct_write_features.AddFontFeature(make_direct_write_feature(&tag, *value))?; + direct_write_features.AddFontFeature(make_direct_write_feature(tag, *value))?; } } unsafe { diff --git a/crates/gpui/src/platform/windows/directx_atlas.rs b/crates/gpui/src/platform/windows/directx_atlas.rs index 
6bced4c11d922ed2c514b9a70fe7e582d7b15a6b..38c22a41bf9d32cf43f585050390b75602a6bf42 100644 --- a/crates/gpui/src/platform/windows/directx_atlas.rs +++ b/crates/gpui/src/platform/windows/directx_atlas.rs @@ -3,9 +3,8 @@ use etagere::BucketedAtlasAllocator; use parking_lot::Mutex; use windows::Win32::Graphics::{ Direct3D11::{ - D3D11_BIND_SHADER_RESOURCE, D3D11_BOX, D3D11_CPU_ACCESS_WRITE, D3D11_TEXTURE2D_DESC, - D3D11_USAGE_DEFAULT, ID3D11Device, ID3D11DeviceContext, ID3D11ShaderResourceView, - ID3D11Texture2D, + D3D11_BIND_SHADER_RESOURCE, D3D11_BOX, D3D11_TEXTURE2D_DESC, D3D11_USAGE_DEFAULT, + ID3D11Device, ID3D11DeviceContext, ID3D11ShaderResourceView, ID3D11Texture2D, }, Dxgi::Common::*, }; @@ -189,7 +188,7 @@ impl DirectXAtlasState { }, Usage: D3D11_USAGE_DEFAULT, BindFlags: bind_flag.0 as u32, - CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + CPUAccessFlags: 0, MiscFlags: 0, }; let mut texture: Option = None; diff --git a/crates/gpui/src/platform/windows/directx_devices.rs b/crates/gpui/src/platform/windows/directx_devices.rs new file mode 100644 index 0000000000000000000000000000000000000000..4fa4db827492faffaa0d8912b1f37b52f8cfc88f --- /dev/null +++ b/crates/gpui/src/platform/windows/directx_devices.rs @@ -0,0 +1,199 @@ +use anyhow::{Context, Result}; +use util::ResultExt; +use windows::Win32::{ + Foundation::HMODULE, + Graphics::{ + Direct3D::{ + D3D_DRIVER_TYPE_UNKNOWN, D3D_FEATURE_LEVEL, D3D_FEATURE_LEVEL_10_1, + D3D_FEATURE_LEVEL_11_0, D3D_FEATURE_LEVEL_11_1, + }, + Direct3D11::{ + D3D11_CREATE_DEVICE_BGRA_SUPPORT, D3D11_CREATE_DEVICE_DEBUG, + D3D11_FEATURE_D3D10_X_HARDWARE_OPTIONS, D3D11_FEATURE_DATA_D3D10_X_HARDWARE_OPTIONS, + D3D11_SDK_VERSION, D3D11CreateDevice, ID3D11Device, ID3D11DeviceContext, + }, + Dxgi::{ + CreateDXGIFactory2, DXGI_CREATE_FACTORY_DEBUG, DXGI_CREATE_FACTORY_FLAGS, + DXGI_GPU_PREFERENCE_MINIMUM_POWER, IDXGIAdapter1, IDXGIFactory6, + }, + }, +}; + +pub(crate) fn try_to_recover_from_device_lost( + mut f: impl FnMut() -> Result, + 
on_success: impl FnOnce(T), + on_error: impl FnOnce(), +) { + let result = (0..5).find_map(|i| { + if i > 0 { + // Add a small delay before retrying + std::thread::sleep(std::time::Duration::from_millis(100)); + } + f().log_err() + }); + + if let Some(result) = result { + on_success(result); + } else { + on_error(); + } +} + +#[derive(Clone)] +pub(crate) struct DirectXDevices { + pub(crate) adapter: IDXGIAdapter1, + pub(crate) dxgi_factory: IDXGIFactory6, + pub(crate) device: ID3D11Device, + pub(crate) device_context: ID3D11DeviceContext, +} + +impl DirectXDevices { + pub(crate) fn new() -> Result { + let debug_layer_available = check_debug_layer_available(); + let dxgi_factory = + get_dxgi_factory(debug_layer_available).context("Creating DXGI factory")?; + let adapter = + get_adapter(&dxgi_factory, debug_layer_available).context("Getting DXGI adapter")?; + let (device, device_context) = { + let mut context: Option = None; + let mut feature_level = D3D_FEATURE_LEVEL::default(); + let device = get_device( + &adapter, + Some(&mut context), + Some(&mut feature_level), + debug_layer_available, + ) + .context("Creating Direct3D device")?; + match feature_level { + D3D_FEATURE_LEVEL_11_1 => { + log::info!("Created device with Direct3D 11.1 feature level.") + } + D3D_FEATURE_LEVEL_11_0 => { + log::info!("Created device with Direct3D 11.0 feature level.") + } + D3D_FEATURE_LEVEL_10_1 => { + log::info!("Created device with Direct3D 10.1 feature level.") + } + _ => unreachable!(), + } + (device, context.unwrap()) + }; + + Ok(Self { + adapter, + dxgi_factory, + device, + device_context, + }) + } +} + +#[inline] +fn check_debug_layer_available() -> bool { + #[cfg(debug_assertions)] + { + use windows::Win32::Graphics::Dxgi::{DXGIGetDebugInterface1, IDXGIInfoQueue}; + + unsafe { DXGIGetDebugInterface1::(0) } + .log_err() + .is_some() + } + #[cfg(not(debug_assertions))] + { + false + } +} + +#[inline] +fn get_dxgi_factory(debug_layer_available: bool) -> Result { + let 
factory_flag = if debug_layer_available { + DXGI_CREATE_FACTORY_DEBUG + } else { + #[cfg(debug_assertions)] + log::warn!( + "Failed to get DXGI debug interface. DirectX debugging features will be disabled." + ); + DXGI_CREATE_FACTORY_FLAGS::default() + }; + unsafe { Ok(CreateDXGIFactory2(factory_flag)?) } +} + +#[inline] +fn get_adapter(dxgi_factory: &IDXGIFactory6, debug_layer_available: bool) -> Result { + for adapter_index in 0.. { + let adapter: IDXGIAdapter1 = unsafe { + dxgi_factory + .EnumAdapterByGpuPreference(adapter_index, DXGI_GPU_PREFERENCE_MINIMUM_POWER) + }?; + if let Ok(desc) = unsafe { adapter.GetDesc1() } { + let gpu_name = String::from_utf16_lossy(&desc.Description) + .trim_matches(char::from(0)) + .to_string(); + log::info!("Using GPU: {}", gpu_name); + } + // Check to see whether the adapter supports Direct3D 11, but don't + // create the actual device yet. + if get_device(&adapter, None, None, debug_layer_available) + .log_err() + .is_some() + { + return Ok(adapter); + } + } + + unreachable!() +} + +#[inline] +fn get_device( + adapter: &IDXGIAdapter1, + context: Option<*mut Option>, + feature_level: Option<*mut D3D_FEATURE_LEVEL>, + debug_layer_available: bool, +) -> Result { + let mut device: Option = None; + let device_flags = if debug_layer_available { + D3D11_CREATE_DEVICE_BGRA_SUPPORT | D3D11_CREATE_DEVICE_DEBUG + } else { + D3D11_CREATE_DEVICE_BGRA_SUPPORT + }; + unsafe { + D3D11CreateDevice( + adapter, + D3D_DRIVER_TYPE_UNKNOWN, + HMODULE::default(), + device_flags, + // 4x MSAA is required for Direct3D Feature Level 10.1 or better + Some(&[ + D3D_FEATURE_LEVEL_11_1, + D3D_FEATURE_LEVEL_11_0, + D3D_FEATURE_LEVEL_10_1, + ]), + D3D11_SDK_VERSION, + Some(&mut device), + feature_level, + context, + )?; + } + let device = device.unwrap(); + let mut data = D3D11_FEATURE_DATA_D3D10_X_HARDWARE_OPTIONS::default(); + unsafe { + device + .CheckFeatureSupport( + D3D11_FEATURE_D3D10_X_HARDWARE_OPTIONS, + &mut data as *mut _ as _, + 
std::mem::size_of::() as u32, + ) + .context("Checking GPU device feature support")?; + } + if data + .ComputeShaders_Plus_RawAndStructuredBuffers_Via_Shader_4_x + .as_bool() + { + Ok(device) + } else { + Err(anyhow::anyhow!( + "Required feature StructuredBuffer is not supported by GPU/driver" + )) + } +} diff --git a/crates/gpui/src/platform/windows/directx_renderer.rs b/crates/gpui/src/platform/windows/directx_renderer.rs index 4e72ded5341479c2d861c441fc3c43d5fee7056c..2baa237cdaa196da225070c241232fc6af0f0ff4 100644 --- a/crates/gpui/src/platform/windows/directx_renderer.rs +++ b/crates/gpui/src/platform/windows/directx_renderer.rs @@ -1,14 +1,18 @@ -use std::{mem::ManuallyDrop, sync::Arc}; +use std::{ + mem::ManuallyDrop, + sync::{Arc, OnceLock}, +}; use ::util::ResultExt; use anyhow::{Context, Result}; use windows::{ Win32::{ - Foundation::{HMODULE, HWND}, + Foundation::HWND, Graphics::{ Direct3D::*, Direct3D11::*, DirectComposition::*, + DirectWrite::*, Dxgi::{Common::*, *}, }, }, @@ -27,21 +31,27 @@ const RENDER_TARGET_FORMAT: DXGI_FORMAT = DXGI_FORMAT_B8G8R8A8_UNORM; // This configuration is used for MSAA rendering on paths only, and it's guaranteed to be supported by DirectX 11. 
const PATH_MULTISAMPLE_COUNT: u32 = 4; +pub(crate) struct FontInfo { + pub gamma_ratios: [f32; 4], + pub grayscale_enhanced_contrast: f32, +} + pub(crate) struct DirectXRenderer { hwnd: HWND, atlas: Arc, - devices: ManuallyDrop, + devices: ManuallyDrop, resources: ManuallyDrop, globals: DirectXGlobalElements, pipelines: DirectXRenderPipelines, direct_composition: Option, + font_info: &'static FontInfo, } /// Direct3D objects #[derive(Clone)] -pub(crate) struct DirectXDevices { - adapter: IDXGIAdapter1, - dxgi_factory: IDXGIFactory6, +pub(crate) struct DirectXRendererDevices { + pub(crate) adapter: IDXGIAdapter1, + pub(crate) dxgi_factory: IDXGIFactory6, pub(crate) device: ID3D11Device, pub(crate) device_context: ID3D11DeviceContext, dxgi_device: Option, @@ -86,39 +96,17 @@ struct DirectComposition { comp_visual: IDCompositionVisual, } -impl DirectXDevices { - pub(crate) fn new(disable_direct_composition: bool) -> Result> { - let debug_layer_available = check_debug_layer_available(); - let dxgi_factory = - get_dxgi_factory(debug_layer_available).context("Creating DXGI factory")?; - let adapter = - get_adapter(&dxgi_factory, debug_layer_available).context("Getting DXGI adapter")?; - let (device, device_context) = { - let mut device: Option = None; - let mut context: Option = None; - let mut feature_level = D3D_FEATURE_LEVEL::default(); - get_device( - &adapter, - Some(&mut device), - Some(&mut context), - Some(&mut feature_level), - debug_layer_available, - ) - .context("Creating Direct3D device")?; - match feature_level { - D3D_FEATURE_LEVEL_11_1 => { - log::info!("Created device with Direct3D 11.1 feature level.") - } - D3D_FEATURE_LEVEL_11_0 => { - log::info!("Created device with Direct3D 11.0 feature level.") - } - D3D_FEATURE_LEVEL_10_1 => { - log::info!("Created device with Direct3D 10.1 feature level.") - } - _ => unreachable!(), - } - (device.unwrap(), context.unwrap()) - }; +impl DirectXRendererDevices { + pub(crate) fn new( + directx_devices: 
&DirectXDevices, + disable_direct_composition: bool, + ) -> Result> { + let DirectXDevices { + adapter, + dxgi_factory, + device, + device_context, + } = directx_devices; let dxgi_device = if disable_direct_composition { None } else { @@ -126,23 +114,27 @@ impl DirectXDevices { }; Ok(ManuallyDrop::new(Self { - adapter, - dxgi_factory, + adapter: adapter.clone(), + dxgi_factory: dxgi_factory.clone(), + device: device.clone(), + device_context: device_context.clone(), dxgi_device, - device, - device_context, })) } } impl DirectXRenderer { - pub(crate) fn new(hwnd: HWND, disable_direct_composition: bool) -> Result { + pub(crate) fn new( + hwnd: HWND, + directx_devices: &DirectXDevices, + disable_direct_composition: bool, + ) -> Result { if disable_direct_composition { log::info!("Direct Composition is disabled."); } - let devices = - DirectXDevices::new(disable_direct_composition).context("Creating DirectX devices")?; + let devices = DirectXRendererDevices::new(directx_devices, disable_direct_composition) + .context("Creating DirectX devices")?; let atlas = Arc::new(DirectXAtlas::new(&devices.device, &devices.device_context)); let resources = DirectXResources::new(&devices, 1, 1, hwnd, disable_direct_composition) @@ -171,6 +163,7 @@ impl DirectXRenderer { globals, pipelines, direct_composition, + font_info: Self::get_font_info(), }) } @@ -183,10 +176,12 @@ impl DirectXRenderer { &self.devices.device_context, self.globals.global_params_buffer[0].as_ref().unwrap(), &[GlobalParams { + gamma_ratios: self.font_info.gamma_ratios, viewport_size: [ self.resources.viewport[0].Width, self.resources.viewport[0].Height, ], + grayscale_enhanced_contrast: self.font_info.grayscale_enhanced_contrast, _pad: 0, }], )?; @@ -205,28 +200,30 @@ impl DirectXRenderer { Ok(()) } + #[inline] fn present(&mut self) -> Result<()> { - unsafe { - let result = self.resources.swap_chain.Present(0, DXGI_PRESENT(0)); - // Presenting the swap chain can fail if the DirectX device was removed or reset. 
- if result == DXGI_ERROR_DEVICE_REMOVED || result == DXGI_ERROR_DEVICE_RESET { - let reason = self.devices.device.GetDeviceRemovedReason(); + let result = unsafe { self.resources.swap_chain.Present(0, DXGI_PRESENT(0)) }; + result.ok().context("Presenting swap chain failed") + } + + pub(crate) fn handle_device_lost(&mut self, directx_devices: &DirectXDevices) { + try_to_recover_from_device_lost( + || { + self.handle_device_lost_impl(directx_devices) + .context("DirectXRenderer handling device lost") + }, + |_| {}, + || { log::error!( - "DirectX device removed or reset when drawing. Reason: {:?}", - reason + "DirectXRenderer failed to recover from device lost after multiple attempts" ); - self.handle_device_lost()?; - } else { - result.ok()?; - } - } - Ok(()) + // Do something here? + // At this point, the device loss is considered unrecoverable. + }, + ); } - fn handle_device_lost(&mut self) -> Result<()> { - // Here we wait a bit to ensure the the system has time to recover from the device lost state. - // If we don't wait, the final drawing result will be blank. 
- std::thread::sleep(std::time::Duration::from_millis(300)); + fn handle_device_lost_impl(&mut self, directx_devices: &DirectXDevices) -> Result<()> { let disable_direct_composition = self.direct_composition.is_none(); unsafe { @@ -249,7 +246,7 @@ impl DirectXRenderer { ManuallyDrop::drop(&mut self.devices); } - let devices = DirectXDevices::new(disable_direct_composition) + let devices = DirectXRendererDevices::new(directx_devices, disable_direct_composition) .context("Recreating DirectX devices")?; let resources = DirectXResources::new( &devices, @@ -324,49 +321,39 @@ impl DirectXRenderer { if self.resources.width == width && self.resources.height == height { return Ok(()); } + self.resources.width = width; + self.resources.height = height; + + // Clear the render target before resizing + unsafe { self.devices.device_context.OMSetRenderTargets(None, None) }; + unsafe { ManuallyDrop::drop(&mut self.resources.render_target) }; + drop(self.resources.render_target_view[0].take().unwrap()); + + // Resizing the swap chain requires a call to the underlying DXGI adapter, which can return the device removed error. + // The app might have moved to a monitor that's attached to a different graphics device. + // When a graphics device is removed or reset, the desktop resolution often changes, resulting in a window size change. + // But here we just return the error, because we are handling device lost scenarios elsewhere. unsafe { - // Clear the render target before resizing - self.devices.device_context.OMSetRenderTargets(None, None); - ManuallyDrop::drop(&mut self.resources.render_target); - drop(self.resources.render_target_view[0].take().unwrap()); - - let result = self.resources.swap_chain.ResizeBuffers( - BUFFER_COUNT as u32, - width, - height, - RENDER_TARGET_FORMAT, - DXGI_SWAP_CHAIN_FLAG(0), - ); - // Resizing the swap chain requires a call to the underlying DXGI adapter, which can return the device removed error. 
- // The app might have moved to a monitor that's attached to a different graphics device. - // When a graphics device is removed or reset, the desktop resolution often changes, resulting in a window size change. - match result { - Ok(_) => {} - Err(e) => { - if e.code() == DXGI_ERROR_DEVICE_REMOVED || e.code() == DXGI_ERROR_DEVICE_RESET - { - let reason = self.devices.device.GetDeviceRemovedReason(); - log::error!( - "DirectX device removed or reset when resizing. Reason: {:?}", - reason - ); - self.resources.width = width; - self.resources.height = height; - self.handle_device_lost()?; - return Ok(()); - } else { - log::error!("Failed to resize swap chain: {:?}", e); - return Err(e.into()); - } - } - } - self.resources - .recreate_resources(&self.devices, width, height)?; + .swap_chain + .ResizeBuffers( + BUFFER_COUNT as u32, + width, + height, + RENDER_TARGET_FORMAT, + DXGI_SWAP_CHAIN_FLAG(0), + ) + .context("Failed to resize swap chain")?; + } + + self.resources + .recreate_resources(&self.devices, width, height)?; + unsafe { self.devices .device_context .OMSetRenderTargets(Some(&self.resources.render_target_view), None); } + Ok(()) } @@ -617,11 +604,57 @@ impl DirectXRenderer { driver_info: driver_version, }) } + + pub(crate) fn get_font_info() -> &'static FontInfo { + static CACHED_FONT_INFO: OnceLock = OnceLock::new(); + CACHED_FONT_INFO.get_or_init(|| unsafe { + let factory: IDWriteFactory5 = DWriteCreateFactory(DWRITE_FACTORY_TYPE_SHARED).unwrap(); + let render_params: IDWriteRenderingParams1 = + factory.CreateRenderingParams().unwrap().cast().unwrap(); + FontInfo { + gamma_ratios: Self::get_gamma_ratios(render_params.GetGamma()), + grayscale_enhanced_contrast: render_params.GetGrayscaleEnhancedContrast(), + } + }) + } + + // Gamma ratios for brightening/darkening edges for better contrast + // https://github.com/microsoft/terminal/blob/1283c0f5b99a2961673249fa77c6b986efb5086c/src/renderer/atlas/dwrite.cpp#L50 + fn get_gamma_ratios(gamma: f32) -> [f32; 4] 
{ + const GAMMA_INCORRECT_TARGET_RATIOS: [[f32; 4]; 13] = [ + [0.0000 / 4.0, 0.0000 / 4.0, 0.0000 / 4.0, 0.0000 / 4.0], // gamma = 1.0 + [0.0166 / 4.0, -0.0807 / 4.0, 0.2227 / 4.0, -0.0751 / 4.0], // gamma = 1.1 + [0.0350 / 4.0, -0.1760 / 4.0, 0.4325 / 4.0, -0.1370 / 4.0], // gamma = 1.2 + [0.0543 / 4.0, -0.2821 / 4.0, 0.6302 / 4.0, -0.1876 / 4.0], // gamma = 1.3 + [0.0739 / 4.0, -0.3963 / 4.0, 0.8167 / 4.0, -0.2287 / 4.0], // gamma = 1.4 + [0.0933 / 4.0, -0.5161 / 4.0, 0.9926 / 4.0, -0.2616 / 4.0], // gamma = 1.5 + [0.1121 / 4.0, -0.6395 / 4.0, 1.1588 / 4.0, -0.2877 / 4.0], // gamma = 1.6 + [0.1300 / 4.0, -0.7649 / 4.0, 1.3159 / 4.0, -0.3080 / 4.0], // gamma = 1.7 + [0.1469 / 4.0, -0.8911 / 4.0, 1.4644 / 4.0, -0.3234 / 4.0], // gamma = 1.8 + [0.1627 / 4.0, -1.0170 / 4.0, 1.6051 / 4.0, -0.3347 / 4.0], // gamma = 1.9 + [0.1773 / 4.0, -1.1420 / 4.0, 1.7385 / 4.0, -0.3426 / 4.0], // gamma = 2.0 + [0.1908 / 4.0, -1.2652 / 4.0, 1.8650 / 4.0, -0.3476 / 4.0], // gamma = 2.1 + [0.2031 / 4.0, -1.3864 / 4.0, 1.9851 / 4.0, -0.3501 / 4.0], // gamma = 2.2 + ]; + + const NORM13: f32 = ((0x10000 as f64) / (255.0 * 255.0) * 4.0) as f32; + const NORM24: f32 = ((0x100 as f64) / (255.0) * 4.0) as f32; + + let index = ((gamma * 10.0).round() as usize).clamp(10, 22) - 10; + let ratios = GAMMA_INCORRECT_TARGET_RATIOS[index]; + + [ + ratios[0] * NORM13, + ratios[1] * NORM24, + ratios[2] * NORM13, + ratios[3] * NORM24, + ] + } } impl DirectXResources { pub fn new( - devices: &DirectXDevices, + devices: &DirectXRendererDevices, width: u32, height: u32, hwnd: HWND, @@ -666,7 +699,7 @@ impl DirectXResources { #[inline] fn recreate_resources( &mut self, - devices: &DirectXDevices, + devices: &DirectXRendererDevices, width: u32, height: u32, ) -> Result<()> { @@ -686,8 +719,6 @@ impl DirectXResources { self.path_intermediate_msaa_view = path_intermediate_msaa_view; self.path_intermediate_srv = path_intermediate_srv; self.viewport = viewport; - self.width = width; - self.height = height; Ok(()) 
} } @@ -758,7 +789,7 @@ impl DirectXRenderPipelines { impl DirectComposition { pub fn new(dxgi_device: &IDXGIDevice, hwnd: HWND) -> Result { - let comp_device = get_comp_device(&dxgi_device)?; + let comp_device = get_comp_device(dxgi_device)?; let comp_target = unsafe { comp_device.CreateTargetForHwnd(hwnd, true) }?; let comp_visual = unsafe { comp_device.CreateVisual() }?; @@ -822,8 +853,10 @@ impl DirectXGlobalElements { #[derive(Debug, Default)] #[repr(C)] struct GlobalParams { + gamma_ratios: [f32; 4], viewport_size: [f32; 2], - _pad: u64, + grayscale_enhanced_contrast: f32, + _pad: u32, } struct PipelineState { @@ -980,92 +1013,6 @@ impl Drop for DirectXResources { } } -#[inline] -fn check_debug_layer_available() -> bool { - #[cfg(debug_assertions)] - { - unsafe { DXGIGetDebugInterface1::(0) } - .log_err() - .is_some() - } - #[cfg(not(debug_assertions))] - { - false - } -} - -#[inline] -fn get_dxgi_factory(debug_layer_available: bool) -> Result { - let factory_flag = if debug_layer_available { - DXGI_CREATE_FACTORY_DEBUG - } else { - #[cfg(debug_assertions)] - log::warn!( - "Failed to get DXGI debug interface. DirectX debugging features will be disabled." - ); - DXGI_CREATE_FACTORY_FLAGS::default() - }; - unsafe { Ok(CreateDXGIFactory2(factory_flag)?) } -} - -fn get_adapter(dxgi_factory: &IDXGIFactory6, debug_layer_available: bool) -> Result { - for adapter_index in 0.. { - let adapter: IDXGIAdapter1 = unsafe { - dxgi_factory - .EnumAdapterByGpuPreference(adapter_index, DXGI_GPU_PREFERENCE_MINIMUM_POWER) - }?; - if let Ok(desc) = unsafe { adapter.GetDesc1() } { - let gpu_name = String::from_utf16_lossy(&desc.Description) - .trim_matches(char::from(0)) - .to_string(); - log::info!("Using GPU: {}", gpu_name); - } - // Check to see whether the adapter supports Direct3D 11, but don't - // create the actual device yet. 
- if get_device(&adapter, None, None, None, debug_layer_available) - .log_err() - .is_some() - { - return Ok(adapter); - } - } - - unreachable!() -} - -fn get_device( - adapter: &IDXGIAdapter1, - device: Option<*mut Option>, - context: Option<*mut Option>, - feature_level: Option<*mut D3D_FEATURE_LEVEL>, - debug_layer_available: bool, -) -> Result<()> { - let device_flags = if debug_layer_available { - D3D11_CREATE_DEVICE_BGRA_SUPPORT | D3D11_CREATE_DEVICE_DEBUG - } else { - D3D11_CREATE_DEVICE_BGRA_SUPPORT - }; - unsafe { - D3D11CreateDevice( - adapter, - D3D_DRIVER_TYPE_UNKNOWN, - HMODULE::default(), - device_flags, - // 4x MSAA is required for Direct3D Feature Level 10.1 or better - Some(&[ - D3D_FEATURE_LEVEL_11_1, - D3D_FEATURE_LEVEL_11_0, - D3D_FEATURE_LEVEL_10_1, - ]), - D3D11_SDK_VERSION, - device, - feature_level, - context, - )?; - } - Ok(()) -} - #[inline] fn get_comp_device(dxgi_device: &IDXGIDevice) -> Result { Ok(unsafe { DCompositionCreateDevice(dxgi_device)? }) @@ -1130,7 +1077,7 @@ fn create_swap_chain( #[inline] fn create_resources( - devices: &DirectXDevices, + devices: &DirectXRendererDevices, swap_chain: &IDXGISwapChain1, width: u32, height: u32, @@ -1144,7 +1091,7 @@ fn create_resources( [D3D11_VIEWPORT; 1], )> { let (render_target, render_target_view) = - create_render_target_and_its_view(&swap_chain, &devices.device)?; + create_render_target_and_its_view(swap_chain, &devices.device)?; let (path_intermediate_texture, path_intermediate_srv) = create_path_intermediate_texture(&devices.device, width, height)?; let (path_intermediate_msaa_texture, path_intermediate_msaa_view) = @@ -1544,6 +1491,10 @@ pub(crate) mod shader_resources { #[cfg(debug_assertions)] pub(super) fn build_shader_blob(entry: ShaderModule, target: ShaderTarget) -> Result { unsafe { + use windows::Win32::Graphics::{ + Direct3D::ID3DInclude, Hlsl::D3D_COMPILE_STANDARD_FILE_INCLUDE, + }; + let shader_name = if matches!(entry, ShaderModule::EmojiRasterization) { 
"color_text_raster.hlsl" } else { @@ -1572,10 +1523,15 @@ pub(crate) mod shader_resources { let entry_point = PCSTR::from_raw(entry.as_ptr()); let target_cstr = PCSTR::from_raw(target.as_ptr()); + // really dirty trick because winapi bindings are unhappy otherwise + let include_handler = &std::mem::transmute::( + D3D_COMPILE_STANDARD_FILE_INCLUDE as usize, + ); + let ret = D3DCompileFromFile( &HSTRING::from(shader_path.to_str().unwrap()), None, - None, + include_handler, entry_point, target_cstr, D3DCOMPILE_DEBUG | D3DCOMPILE_SKIP_OPTIMIZATION, @@ -1760,7 +1716,7 @@ mod amd { anyhow::bail!("Failed to initialize AMD AGS, error code: {}", result); } - // Vulkan acctually returns this as the driver version + // Vulkan actually returns this as the driver version let software_version = if !gpu_info.radeon_software_version.is_null() { std::ffi::CStr::from_ptr(gpu_info.radeon_software_version) .to_string_lossy() diff --git a/crates/gpui/src/platform/windows/dispatcher.rs b/crates/gpui/src/platform/windows/dispatcher.rs index e5b9c020d511b478779dc1affb3927018f8f7b3f..3707a69047cf53cf68a40b3711e135f77dff8be3 100644 --- a/crates/gpui/src/platform/windows/dispatcher.rs +++ b/crates/gpui/src/platform/windows/dispatcher.rs @@ -9,41 +9,42 @@ use parking::Parker; use parking_lot::Mutex; use util::ResultExt; use windows::{ - Foundation::TimeSpan, System::Threading::{ - ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemOptions, - WorkItemPriority, + ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority, }, Win32::{ Foundation::{LPARAM, WPARAM}, - UI::WindowsAndMessaging::PostThreadMessageW, + UI::WindowsAndMessaging::PostMessageW, }, }; -use crate::{PlatformDispatcher, TaskLabel, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD}; +use crate::{ + HWND, PlatformDispatcher, SafeHwnd, TaskLabel, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD, +}; pub(crate) struct WindowsDispatcher { main_sender: Sender, parker: Mutex, main_thread_id: ThreadId, - 
main_thread_id_win32: u32, + platform_window_handle: SafeHwnd, validation_number: usize, } impl WindowsDispatcher { pub(crate) fn new( main_sender: Sender, - main_thread_id_win32: u32, + platform_window_handle: HWND, validation_number: usize, ) -> Self { let parker = Mutex::new(Parker::new()); let main_thread_id = current().id(); + let platform_window_handle = platform_window_handle.into(); WindowsDispatcher { main_sender, parker, main_thread_id, - main_thread_id_win32, + platform_window_handle, validation_number, } } @@ -56,12 +57,7 @@ impl WindowsDispatcher { Ok(()) }) }; - ThreadPool::RunWithPriorityAndOptionsAsync( - &handler, - WorkItemPriority::High, - WorkItemOptions::TimeSliced, - ) - .log_err(); + ThreadPool::RunWithPriorityAsync(&handler, WorkItemPriority::High).log_err(); } fn dispatch_on_threadpool_after(&self, runnable: Runnable, duration: Duration) { @@ -72,12 +68,7 @@ impl WindowsDispatcher { Ok(()) }) }; - let delay = TimeSpan { - // A time period expressed in 100-nanosecond units. 
- // 10,000,000 ticks per second - Duration: (duration.as_nanos() / 100) as i64, - }; - ThreadPoolTimer::CreateTimer(&handler, delay).log_err(); + ThreadPoolTimer::CreateTimer(&handler, duration.into()).log_err(); } } @@ -96,8 +87,8 @@ impl PlatformDispatcher for WindowsDispatcher { fn dispatch_on_main_thread(&self, runnable: Runnable) { match self.main_sender.send(runnable) { Ok(_) => unsafe { - PostThreadMessageW( - self.main_thread_id_win32, + PostMessageW( + Some(self.platform_window_handle.as_raw()), WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD, WPARAM(self.validation_number), LPARAM(0), diff --git a/crates/gpui/src/platform/windows/events.rs b/crates/gpui/src/platform/windows/events.rs index 4ab257d27a69fc5fed458655150e1c09c3ebbba8..c1e2040d377da814b261682eb93321fda4ebdb2d 100644 --- a/crates/gpui/src/platform/windows/events.rs +++ b/crates/gpui/src/platform/windows/events.rs @@ -24,6 +24,8 @@ pub(crate) const WM_GPUI_CLOSE_ONE_WINDOW: u32 = WM_USER + 2; pub(crate) const WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD: u32 = WM_USER + 3; pub(crate) const WM_GPUI_DOCK_MENU_ACTION: u32 = WM_USER + 4; pub(crate) const WM_GPUI_FORCE_UPDATE_WINDOW: u32 = WM_USER + 5; +pub(crate) const WM_GPUI_KEYBOARD_LAYOUT_CHANGED: u32 = WM_USER + 6; +pub(crate) const WM_GPUI_GPU_DEVICE_LOST: u32 = WM_USER + 7; const SIZE_MOVE_LOOP_TIMER_ID: usize = 1; const AUTO_HIDE_TASKBAR_THICKNESS_PX: i32 = 1; @@ -39,7 +41,6 @@ impl WindowsWindowInner { let handled = match msg { WM_ACTIVATE => self.handle_activate_msg(wparam), WM_CREATE => self.handle_create_msg(handle), - WM_DEVICECHANGE => self.handle_device_change_msg(handle, wparam), WM_MOVE => self.handle_move_msg(handle, lparam), WM_SIZE => self.handle_size_msg(wparam, lparam), WM_GETMINMAXINFO => self.handle_get_min_max_info_msg(lparam), @@ -99,9 +100,11 @@ impl WindowsWindowInner { WM_IME_COMPOSITION => self.handle_ime_composition(handle, lparam), WM_SETCURSOR => self.handle_set_cursor(handle, lparam), WM_SETTINGCHANGE => 
self.handle_system_settings_changed(handle, wparam, lparam), - WM_INPUTLANGCHANGE => self.handle_input_language_changed(lparam), + WM_INPUTLANGCHANGE => self.handle_input_language_changed(), + WM_SHOWWINDOW => self.handle_window_visibility_changed(handle, wparam), WM_GPUI_CURSOR_STYLE_CHANGED => self.handle_cursor_changed(lparam), WM_GPUI_FORCE_UPDATE_WINDOW => self.draw_window(handle, true), + WM_GPUI_GPU_DEVICE_LOST => self.handle_device_lost(lparam), _ => None, }; if let Some(n) = handled { @@ -263,8 +266,8 @@ impl WindowsWindowInner { callback(); } unsafe { - PostThreadMessageW( - self.main_thread_id_win32, + PostMessageW( + Some(self.platform_window_handle), WM_GPUI_CLOSE_ONE_WINDOW, WPARAM(self.validation_number), LPARAM(handle.0 as isize), @@ -700,29 +703,28 @@ impl WindowsWindowInner { // Fix auto hide taskbar not showing. This solution is based on the approach // used by Chrome. However, it may result in one row of pixels being obscured // in our client area. But as Chrome says, "there seems to be no better solution." - if is_maximized { - if let Some(ref taskbar_position) = self + if is_maximized + && let Some(ref taskbar_position) = self .state .borrow() .system_settings .auto_hide_taskbar_position - { - // Fot the auto-hide taskbar, adjust in by 1 pixel on taskbar edge, - // so the window isn't treated as a "fullscreen app", which would cause - // the taskbar to disappear. 
- match taskbar_position { - AutoHideTaskbarPosition::Left => { - requested_client_rect[0].left += AUTO_HIDE_TASKBAR_THICKNESS_PX - } - AutoHideTaskbarPosition::Top => { - requested_client_rect[0].top += AUTO_HIDE_TASKBAR_THICKNESS_PX - } - AutoHideTaskbarPosition::Right => { - requested_client_rect[0].right -= AUTO_HIDE_TASKBAR_THICKNESS_PX - } - AutoHideTaskbarPosition::Bottom => { - requested_client_rect[0].bottom -= AUTO_HIDE_TASKBAR_THICKNESS_PX - } + { + // For the auto-hide taskbar, adjust in by 1 pixel on taskbar edge, + // so the window isn't treated as a "fullscreen app", which would cause + // the taskbar to disappear. + match taskbar_position { + AutoHideTaskbarPosition::Left => { + requested_client_rect[0].left += AUTO_HIDE_TASKBAR_THICKNESS_PX + } + AutoHideTaskbarPosition::Top => { + requested_client_rect[0].top += AUTO_HIDE_TASKBAR_THICKNESS_PX + } + AutoHideTaskbarPosition::Right => { + requested_client_rect[0].right -= AUTO_HIDE_TASKBAR_THICKNESS_PX + } + AutoHideTaskbarPosition::Bottom => { + requested_client_rect[0].bottom -= AUTO_HIDE_TASKBAR_THICKNESS_PX } } } @@ -956,7 +958,7 @@ impl WindowsWindowInner { click_count, first_mouse: false, }); - let result = func(input.clone()); + let result = func(input); let handled = !result.propagate || result.default_prevented; self.state.borrow_mut().callbacks.input = Some(func); @@ -1124,62 +1126,54 @@ impl WindowsWindowInner { // lParam is a pointer to a string that indicates the area containing the system parameter // that was changed. 
let parameter = PCWSTR::from_raw(lparam.0 as _); - if unsafe { !parameter.is_null() && !parameter.is_empty() } { - if let Some(parameter_string) = unsafe { parameter.to_string() }.log_err() { - log::info!("System settings changed: {}", parameter_string); - match parameter_string.as_str() { - "ImmersiveColorSet" => { - let new_appearance = system_appearance() - .context( - "unable to get system appearance when handling ImmersiveColorSet", - ) - .log_err()?; - let mut lock = self.state.borrow_mut(); - if new_appearance != lock.appearance { - lock.appearance = new_appearance; - let mut callback = lock.callbacks.appearance_changed.take()?; - drop(lock); - callback(); - self.state.borrow_mut().callbacks.appearance_changed = Some(callback); - configure_dwm_dark_mode(handle, new_appearance); - } - } - _ => {} + if unsafe { !parameter.is_null() && !parameter.is_empty() } + && let Some(parameter_string) = unsafe { parameter.to_string() }.log_err() + { + log::info!("System settings changed: {}", parameter_string); + if parameter_string.as_str() == "ImmersiveColorSet" { + let new_appearance = system_appearance() + .context("unable to get system appearance when handling ImmersiveColorSet") + .log_err()?; + let mut lock = self.state.borrow_mut(); + if new_appearance != lock.appearance { + lock.appearance = new_appearance; + let mut callback = lock.callbacks.appearance_changed.take()?; + drop(lock); + callback(); + self.state.borrow_mut().callbacks.appearance_changed = Some(callback); + configure_dwm_dark_mode(handle, new_appearance); } } } Some(0) } - fn handle_input_language_changed(&self, lparam: LPARAM) -> Option { - let thread = self.main_thread_id_win32; - let validation = self.validation_number; + fn handle_input_language_changed(&self) -> Option { unsafe { - PostThreadMessageW(thread, WM_INPUTLANGCHANGE, WPARAM(validation), lparam).log_err(); + PostMessageW( + Some(self.platform_window_handle), + WM_GPUI_KEYBOARD_LAYOUT_CHANGED, + WPARAM(self.validation_number), + 
LPARAM(0), + ) + .log_err(); } Some(0) } - fn handle_device_change_msg(&self, handle: HWND, wparam: WPARAM) -> Option { - if wparam.0 == DBT_DEVNODES_CHANGED as usize { - // The reason for sending this message is to actually trigger a redraw of the window. - unsafe { - PostMessageW( - Some(handle), - WM_GPUI_FORCE_UPDATE_WINDOW, - WPARAM(0), - LPARAM(0), - ) - .log_err(); - } - // If the GPU device is lost, this redraw will take care of recreating the device context. - // The WM_GPUI_FORCE_UPDATE_WINDOW message will take care of redrawing the window, after - // the device context has been recreated. - self.draw_window(handle, true) - } else { - // Other device change messages are not handled. - None + fn handle_window_visibility_changed(&self, handle: HWND, wparam: WPARAM) -> Option { + if wparam.0 == 1 { + self.draw_window(handle, false); } + None + } + + fn handle_device_lost(&self, lparam: LPARAM) -> Option { + let mut lock = self.state.borrow_mut(); + let devices = lparam.0 as *const DirectXDevices; + let devices = unsafe { &*devices }; + lock.renderer.handle_device_lost(&devices); + Some(0) } #[inline] @@ -1464,7 +1458,7 @@ pub(crate) fn current_modifiers() -> Modifiers { #[inline] pub(crate) fn current_capslock() -> Capslock { let on = unsafe { GetKeyState(VK_CAPITAL.0 as i32) & 1 } > 0; - Capslock { on: on } + Capslock { on } } fn get_client_area_insets( diff --git a/crates/gpui/src/platform/windows/keyboard.rs b/crates/gpui/src/platform/windows/keyboard.rs index 371feb70c25ab593ce612c7a90381a4cffdeff7d..259ebaebff794d4ed7203420c8c66188998c5fa4 100644 --- a/crates/gpui/src/platform/windows/keyboard.rs +++ b/crates/gpui/src/platform/windows/keyboard.rs @@ -1,22 +1,31 @@ use anyhow::Result; +use collections::HashMap; use windows::Win32::UI::{ Input::KeyboardAndMouse::{ - GetKeyboardLayoutNameW, MAPVK_VK_TO_CHAR, MapVirtualKeyW, ToUnicode, VIRTUAL_KEY, VK_0, - VK_1, VK_2, VK_3, VK_4, VK_5, VK_6, VK_7, VK_8, VK_9, VK_ABNT_C1, VK_CONTROL, VK_MENU, - VK_OEM_1, 
VK_OEM_2, VK_OEM_3, VK_OEM_4, VK_OEM_5, VK_OEM_6, VK_OEM_7, VK_OEM_8, VK_OEM_102, - VK_OEM_COMMA, VK_OEM_MINUS, VK_OEM_PERIOD, VK_OEM_PLUS, VK_SHIFT, + GetKeyboardLayoutNameW, MAPVK_VK_TO_CHAR, MAPVK_VK_TO_VSC, MapVirtualKeyW, ToUnicode, + VIRTUAL_KEY, VK_0, VK_1, VK_2, VK_3, VK_4, VK_5, VK_6, VK_7, VK_8, VK_9, VK_ABNT_C1, + VK_CONTROL, VK_MENU, VK_OEM_1, VK_OEM_2, VK_OEM_3, VK_OEM_4, VK_OEM_5, VK_OEM_6, VK_OEM_7, + VK_OEM_8, VK_OEM_102, VK_OEM_COMMA, VK_OEM_MINUS, VK_OEM_PERIOD, VK_OEM_PLUS, VK_SHIFT, }, WindowsAndMessaging::KL_NAMELENGTH, }; use windows_core::HSTRING; -use crate::{Modifiers, PlatformKeyboardLayout}; +use crate::{ + KeybindingKeystroke, Keystroke, Modifiers, PlatformKeyboardLayout, PlatformKeyboardMapper, +}; pub(crate) struct WindowsKeyboardLayout { id: String, name: String, } +pub(crate) struct WindowsKeyboardMapper { + key_to_vkey: HashMap, + vkey_to_key: HashMap, + vkey_to_shifted: HashMap, +} + impl PlatformKeyboardLayout for WindowsKeyboardLayout { fn id(&self) -> &str { &self.id @@ -27,6 +36,61 @@ impl PlatformKeyboardLayout for WindowsKeyboardLayout { } } +impl PlatformKeyboardMapper for WindowsKeyboardMapper { + fn map_key_equivalent( + &self, + mut keystroke: Keystroke, + use_key_equivalents: bool, + ) -> KeybindingKeystroke { + let Some((vkey, shifted_key)) = self.get_vkey_from_key(&keystroke.key, use_key_equivalents) + else { + return KeybindingKeystroke::from_keystroke(keystroke); + }; + if shifted_key && keystroke.modifiers.shift { + log::warn!( + "Keystroke '{}' has both shift and a shifted key, this is likely a bug", + keystroke.key + ); + } + + let shift = shifted_key || keystroke.modifiers.shift; + keystroke.modifiers.shift = false; + + let Some(key) = self.vkey_to_key.get(&vkey).cloned() else { + log::error!( + "Failed to map key equivalent '{:?}' to a valid key", + keystroke + ); + return KeybindingKeystroke::from_keystroke(keystroke); + }; + + keystroke.key = if shift { + let Some(shifted_key) = 
self.vkey_to_shifted.get(&vkey).cloned() else { + log::error!( + "Failed to map keystroke {:?} with virtual key '{:?}' to a shifted key", + keystroke, + vkey + ); + return KeybindingKeystroke::from_keystroke(keystroke); + }; + shifted_key + } else { + key.clone() + }; + + let modifiers = Modifiers { + shift, + ..keystroke.modifiers + }; + + KeybindingKeystroke::new(keystroke, modifiers, key) + } + + fn get_key_equivalents(&self) -> Option<&HashMap> { + None + } +} + impl WindowsKeyboardLayout { pub(crate) fn new() -> Result { let mut buffer = [0u16; KL_NAMELENGTH as usize]; @@ -48,6 +112,41 @@ impl WindowsKeyboardLayout { } } +impl WindowsKeyboardMapper { + pub(crate) fn new() -> Self { + let mut key_to_vkey = HashMap::default(); + let mut vkey_to_key = HashMap::default(); + let mut vkey_to_shifted = HashMap::default(); + for vkey in CANDIDATE_VKEYS { + if let Some(key) = get_key_from_vkey(*vkey) { + key_to_vkey.insert(key.clone(), (vkey.0, false)); + vkey_to_key.insert(vkey.0, key); + } + let scan_code = unsafe { MapVirtualKeyW(vkey.0 as u32, MAPVK_VK_TO_VSC) }; + if scan_code == 0 { + continue; + } + if let Some(shifted_key) = get_shifted_key(*vkey, scan_code) { + key_to_vkey.insert(shifted_key.clone(), (vkey.0, true)); + vkey_to_shifted.insert(vkey.0, shifted_key); + } + } + Self { + key_to_vkey, + vkey_to_key, + vkey_to_shifted, + } + } + + fn get_vkey_from_key(&self, key: &str, use_key_equivalents: bool) -> Option<(u16, bool)> { + if use_key_equivalents { + get_vkey_from_key_with_us_layout(key) + } else { + self.key_to_vkey.get(key).cloned() + } + } +} + pub(crate) fn get_keystroke_key( vkey: VIRTUAL_KEY, scan_code: u32, @@ -140,3 +239,134 @@ pub(crate) fn generate_key_char( _ => None, } } + +fn get_vkey_from_key_with_us_layout(key: &str) -> Option<(u16, bool)> { + match key { + // ` => VK_OEM_3 + "`" => Some((VK_OEM_3.0, false)), + "~" => Some((VK_OEM_3.0, true)), + "1" => Some((VK_1.0, false)), + "!" 
=> Some((VK_1.0, true)), + "2" => Some((VK_2.0, false)), + "@" => Some((VK_2.0, true)), + "3" => Some((VK_3.0, false)), + "#" => Some((VK_3.0, true)), + "4" => Some((VK_4.0, false)), + "$" => Some((VK_4.0, true)), + "5" => Some((VK_5.0, false)), + "%" => Some((VK_5.0, true)), + "6" => Some((VK_6.0, false)), + "^" => Some((VK_6.0, true)), + "7" => Some((VK_7.0, false)), + "&" => Some((VK_7.0, true)), + "8" => Some((VK_8.0, false)), + "*" => Some((VK_8.0, true)), + "9" => Some((VK_9.0, false)), + "(" => Some((VK_9.0, true)), + "0" => Some((VK_0.0, false)), + ")" => Some((VK_0.0, true)), + "-" => Some((VK_OEM_MINUS.0, false)), + "_" => Some((VK_OEM_MINUS.0, true)), + "=" => Some((VK_OEM_PLUS.0, false)), + "+" => Some((VK_OEM_PLUS.0, true)), + "[" => Some((VK_OEM_4.0, false)), + "{" => Some((VK_OEM_4.0, true)), + "]" => Some((VK_OEM_6.0, false)), + "}" => Some((VK_OEM_6.0, true)), + "\\" => Some((VK_OEM_5.0, false)), + "|" => Some((VK_OEM_5.0, true)), + ";" => Some((VK_OEM_1.0, false)), + ":" => Some((VK_OEM_1.0, true)), + "'" => Some((VK_OEM_7.0, false)), + "\"" => Some((VK_OEM_7.0, true)), + "," => Some((VK_OEM_COMMA.0, false)), + "<" => Some((VK_OEM_COMMA.0, true)), + "." => Some((VK_OEM_PERIOD.0, false)), + ">" => Some((VK_OEM_PERIOD.0, true)), + "/" => Some((VK_OEM_2.0, false)), + "?" 
=> Some((VK_OEM_2.0, true)), + _ => None, + } +} + +const CANDIDATE_VKEYS: &[VIRTUAL_KEY] = &[ + VK_OEM_3, + VK_OEM_MINUS, + VK_OEM_PLUS, + VK_OEM_4, + VK_OEM_5, + VK_OEM_6, + VK_OEM_1, + VK_OEM_7, + VK_OEM_COMMA, + VK_OEM_PERIOD, + VK_OEM_2, + VK_OEM_102, + VK_OEM_8, + VK_ABNT_C1, + VK_0, + VK_1, + VK_2, + VK_3, + VK_4, + VK_5, + VK_6, + VK_7, + VK_8, + VK_9, +]; + +#[cfg(test)] +mod tests { + use crate::{Keystroke, Modifiers, PlatformKeyboardMapper, WindowsKeyboardMapper}; + + #[test] + fn test_keyboard_mapper() { + let mapper = WindowsKeyboardMapper::new(); + + // Normal case + let keystroke = Keystroke { + modifiers: Modifiers::control(), + key: "a".to_string(), + key_char: None, + }; + let mapped = mapper.map_key_equivalent(keystroke.clone(), true); + assert_eq!(*mapped.inner(), keystroke); + assert_eq!(mapped.key(), "a"); + assert_eq!(*mapped.modifiers(), Modifiers::control()); + + // Shifted case, ctrl-$ + let keystroke = Keystroke { + modifiers: Modifiers::control(), + key: "$".to_string(), + key_char: None, + }; + let mapped = mapper.map_key_equivalent(keystroke.clone(), true); + assert_eq!(*mapped.inner(), keystroke); + assert_eq!(mapped.key(), "4"); + assert_eq!(*mapped.modifiers(), Modifiers::control_shift()); + + // Shifted case, but shift is true + let keystroke = Keystroke { + modifiers: Modifiers::control_shift(), + key: "$".to_string(), + key_char: None, + }; + let mapped = mapper.map_key_equivalent(keystroke, true); + assert_eq!(mapped.inner().modifiers, Modifiers::control()); + assert_eq!(mapped.key(), "4"); + assert_eq!(*mapped.modifiers(), Modifiers::control_shift()); + + // Windows style + let keystroke = Keystroke { + modifiers: Modifiers::control_shift(), + key: "4".to_string(), + key_char: None, + }; + let mapped = mapper.map_key_equivalent(keystroke, true); + assert_eq!(mapped.inner().modifiers, Modifiers::control()); + assert_eq!(mapped.inner().key, "$"); + assert_eq!(mapped.key(), "4"); + assert_eq!(*mapped.modifiers(), 
Modifiers::control_shift()); + } +} diff --git a/crates/gpui/src/platform/windows/platform.rs b/crates/gpui/src/platform/windows/platform.rs index bbde655b80517198e4f604edde2e560e19b57ff2..4d0e6ea56f7d90f303f6634de1239a6a4542429a 100644 --- a/crates/gpui/src/platform/windows/platform.rs +++ b/crates/gpui/src/platform/windows/platform.rs @@ -1,8 +1,9 @@ use std::{ cell::RefCell, + ffi::OsStr, mem::ManuallyDrop, path::{Path, PathBuf}, - rc::Rc, + rc::{Rc, Weak}, sync::Arc, }; @@ -17,12 +18,9 @@ use windows::{ UI::ViewManagement::UISettings, Win32::{ Foundation::*, - Graphics::{ - Gdi::*, - Imaging::{CLSID_WICImagingFactory, IWICImagingFactory}, - }, + Graphics::{Direct3D11::ID3D11Device, Gdi::*}, Security::Credentials::*, - System::{Com::*, LibraryLoader::*, Ole::*, SystemInformation::*, Threading::*}, + System::{Com::*, LibraryLoader::*, Ole::*, SystemInformation::*}, UI::{Input::KeyboardAndMouse::*, Shell::*, WindowsAndMessaging::*}, }, core::*, @@ -31,28 +29,34 @@ use windows::{ use crate::*; pub(crate) struct WindowsPlatform { - state: RefCell, + inner: Rc, raw_window_handles: Arc>>, // The below members will never change throughout the entire lifecycle of the app. icon: HICON, - main_receiver: flume::Receiver, background_executor: BackgroundExecutor, foreground_executor: ForegroundExecutor, text_system: Arc, windows_version: WindowsVersion, - bitmap_factory: ManuallyDrop, drop_target_helper: IDropTargetHelper, - validation_number: usize, - main_thread_id_win32: u32, + handle: HWND, disable_direct_composition: bool, } +struct WindowsPlatformInner { + state: RefCell, + raw_window_handles: std::sync::Weak>>, + // The below members will never change throughout the entire lifecycle of the app. + validation_number: usize, + main_receiver: flume::Receiver, +} + pub(crate) struct WindowsPlatformState { callbacks: PlatformCallbacks, menus: Vec, jump_list: JumpList, // NOTE: standard cursor handles don't need to close. 
pub(crate) current_cursor: Option, + directx_devices: ManuallyDrop, } #[derive(Default)] @@ -67,15 +71,17 @@ struct PlatformCallbacks { } impl WindowsPlatformState { - fn new() -> Self { + fn new(directx_devices: DirectXDevices) -> Self { let callbacks = PlatformCallbacks::default(); let jump_list = JumpList::new(); let current_cursor = load_cursor(CursorStyle::Arrow); + let directx_devices = ManuallyDrop::new(directx_devices); Self { callbacks, jump_list, current_cursor, + directx_devices, menus: Vec::new(), } } @@ -86,51 +92,72 @@ impl WindowsPlatform { unsafe { OleInitialize(None).context("unable to initialize Windows OLE")?; } + let directx_devices = DirectXDevices::new().context("Creating DirectX devices")?; let (main_sender, main_receiver) = flume::unbounded::(); - let main_thread_id_win32 = unsafe { GetCurrentThreadId() }; - let validation_number = rand::random::(); + let validation_number = if usize::BITS == 64 { + rand::random::() as usize + } else { + rand::random::() as usize + }; + let raw_window_handles = Arc::new(RwLock::new(SmallVec::new())); + let text_system = Arc::new( + DirectWriteTextSystem::new(&directx_devices) + .context("Error creating DirectWriteTextSystem")?, + ); + register_platform_window_class(); + let mut context = PlatformWindowCreateContext { + inner: None, + raw_window_handles: Arc::downgrade(&raw_window_handles), + validation_number, + main_receiver: Some(main_receiver), + directx_devices: Some(directx_devices), + }; + let result = unsafe { + CreateWindowExW( + WINDOW_EX_STYLE(0), + PLATFORM_WINDOW_CLASS_NAME, + None, + WINDOW_STYLE(0), + 0, + 0, + 0, + 0, + Some(HWND_MESSAGE), + None, + None, + Some(&context as *const _ as *const _), + ) + }; + let inner = context.inner.take().unwrap()?; + let handle = result?; let dispatcher = Arc::new(WindowsDispatcher::new( main_sender, - main_thread_id_win32, + handle, validation_number, )); let disable_direct_composition = std::env::var(DISABLE_DIRECT_COMPOSITION) .is_ok_and(|value| value == 
"true" || value == "1"); let background_executor = BackgroundExecutor::new(dispatcher.clone()); let foreground_executor = ForegroundExecutor::new(dispatcher); - let directx_devices = DirectXDevices::new(disable_direct_composition) - .context("Unable to init directx devices.")?; - let bitmap_factory = ManuallyDrop::new(unsafe { - CoCreateInstance(&CLSID_WICImagingFactory, None, CLSCTX_INPROC_SERVER) - .context("Error creating bitmap factory.")? - }); - let text_system = Arc::new( - DirectWriteTextSystem::new(&directx_devices, &bitmap_factory) - .context("Error creating DirectWriteTextSystem")?, - ); + let drop_target_helper: IDropTargetHelper = unsafe { CoCreateInstance(&CLSID_DragDropHelper, None, CLSCTX_INPROC_SERVER) .context("Error creating drop target helper.")? }; let icon = load_icon().unwrap_or_default(); - let state = RefCell::new(WindowsPlatformState::new()); - let raw_window_handles = Arc::new(RwLock::new(SmallVec::new())); let windows_version = WindowsVersion::new().context("Error retrieve windows version")?; Ok(Self { - state, + inner, + handle, raw_window_handles, icon, - main_receiver, background_executor, foreground_executor, text_system, disable_direct_composition, windows_version, - bitmap_factory, drop_target_helper, - validation_number, - main_thread_id_win32, }) } @@ -152,119 +179,21 @@ impl WindowsPlatform { }); } - fn close_one_window(&self, target_window: HWND) -> bool { - let mut lock = self.raw_window_handles.write(); - let index = lock - .iter() - .position(|handle| handle.as_raw() == target_window) - .unwrap(); - lock.remove(index); - - lock.is_empty() - } - - #[inline] - fn run_foreground_task(&self) { - for runnable in self.main_receiver.drain() { - runnable.run(); - } - } - fn generate_creation_info(&self) -> WindowCreationInfo { WindowCreationInfo { icon: self.icon, executor: self.foreground_executor.clone(), - current_cursor: self.state.borrow().current_cursor, + current_cursor: self.inner.state.borrow().current_cursor, 
windows_version: self.windows_version, drop_target_helper: self.drop_target_helper.clone(), - validation_number: self.validation_number, - main_receiver: self.main_receiver.clone(), - main_thread_id_win32: self.main_thread_id_win32, + validation_number: self.inner.validation_number, + main_receiver: self.inner.main_receiver.clone(), + platform_window_handle: self.handle, disable_direct_composition: self.disable_direct_composition, + directx_devices: (*self.inner.state.borrow().directx_devices).clone(), } } - fn handle_dock_action_event(&self, action_idx: usize) { - let mut lock = self.state.borrow_mut(); - if let Some(mut callback) = lock.callbacks.app_menu_action.take() { - let Some(action) = lock - .jump_list - .dock_menus - .get(action_idx) - .map(|dock_menu| dock_menu.action.boxed_clone()) - else { - lock.callbacks.app_menu_action = Some(callback); - log::error!("Dock menu for index {action_idx} not found"); - return; - }; - drop(lock); - callback(&*action); - self.state.borrow_mut().callbacks.app_menu_action = Some(callback); - } - } - - fn handle_input_lang_change(&self) { - let mut lock = self.state.borrow_mut(); - if let Some(mut callback) = lock.callbacks.keyboard_layout_change.take() { - drop(lock); - callback(); - self.state - .borrow_mut() - .callbacks - .keyboard_layout_change - .get_or_insert(callback); - } - } - - // Returns if the app should quit. - fn handle_events(&self) { - let mut msg = MSG::default(); - unsafe { - while GetMessageW(&mut msg, None, 0, 0).as_bool() { - match msg.message { - WM_QUIT => return, - WM_INPUTLANGCHANGE - | WM_GPUI_CLOSE_ONE_WINDOW - | WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD - | WM_GPUI_DOCK_MENU_ACTION => { - if self.handle_gpui_evnets(msg.message, msg.wParam, msg.lParam, &msg) { - return; - } - } - _ => { - DispatchMessageW(&msg); - } - } - } - } - } - - // Returns true if the app should quit. 
- fn handle_gpui_evnets( - &self, - message: u32, - wparam: WPARAM, - lparam: LPARAM, - msg: *const MSG, - ) -> bool { - if wparam.0 != self.validation_number { - unsafe { DispatchMessageW(msg) }; - return false; - } - match message { - WM_GPUI_CLOSE_ONE_WINDOW => { - if self.close_one_window(HWND(lparam.0 as _)) { - return true; - } - } - WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD => self.run_foreground_task(), - WM_GPUI_DOCK_MENU_ACTION => self.handle_dock_action_event(lparam.0 as _), - WM_INPUTLANGCHANGE => self.handle_input_lang_change(), - _ => unreachable!(), - } - false - } - fn set_dock_menus(&self, menus: Vec) { let mut actions = Vec::new(); menus.into_iter().for_each(|menu| { @@ -272,7 +201,7 @@ impl WindowsPlatform { actions.push(dock_menu); } }); - let mut lock = self.state.borrow_mut(); + let mut lock = self.inner.state.borrow_mut(); lock.jump_list.dock_menus = actions; update_jump_list(&lock.jump_list).log_err(); } @@ -288,7 +217,7 @@ impl WindowsPlatform { actions.push(dock_menu); } }); - let mut lock = self.state.borrow_mut(); + let mut lock = self.inner.state.borrow_mut(); lock.jump_list.dock_menus = actions; lock.jump_list.recent_workspaces = entries; update_jump_list(&lock.jump_list) @@ -309,19 +238,30 @@ impl WindowsPlatform { } fn begin_vsync_thread(&self) { + let mut directx_device = (*self.inner.state.borrow().directx_devices).clone(); + let platform_window: SafeHwnd = self.handle.into(); + let validation_number = self.inner.validation_number; let all_windows = Arc::downgrade(&self.raw_window_handles); + let text_system = Arc::downgrade(&self.text_system); std::thread::spawn(move || { let vsync_provider = VSyncProvider::new(); loop { vsync_provider.wait_for_vsync(); + if check_device_lost(&directx_device.device) { + handle_gpu_device_lost( + &mut directx_device, + platform_window.as_raw(), + validation_number, + &all_windows, + &text_system, + ); + } let Some(all_windows) = all_windows.upgrade() else { break; }; for hwnd in 
all_windows.read().iter() { unsafe { - RedrawWindow(Some(hwnd.as_raw()), None, None, RDW_INVALIDATE) - .ok() - .log_err(); + let _ = RedrawWindow(Some(hwnd.as_raw()), None, None, RDW_INVALIDATE); } } } @@ -350,16 +290,30 @@ impl Platform for WindowsPlatform { ) } + fn keyboard_mapper(&self) -> Rc { + Rc::new(WindowsKeyboardMapper::new()) + } + fn on_keyboard_layout_change(&self, callback: Box) { - self.state.borrow_mut().callbacks.keyboard_layout_change = Some(callback); + self.inner + .state + .borrow_mut() + .callbacks + .keyboard_layout_change = Some(callback); } fn run(&self, on_finish_launching: Box) { on_finish_launching(); self.begin_vsync_thread(); - self.handle_events(); - if let Some(ref mut callback) = self.state.borrow_mut().callbacks.quit { + let mut msg = MSG::default(); + unsafe { + while GetMessageW(&mut msg, None, 0, 0).as_bool() { + DispatchMessageW(&msg); + } + } + + if let Some(ref mut callback) = self.inner.state.borrow_mut().callbacks.quit { callback(); } } @@ -460,19 +414,21 @@ impl Platform for WindowsPlatform { } fn open_url(&self, url: &str) { + if url.is_empty() { + return; + } let url_string = url.to_string(); self.background_executor() .spawn(async move { - if url_string.is_empty() { - return; - } - open_target(url_string.as_str()); + open_target(&url_string) + .with_context(|| format!("Opening url: {}", url_string)) + .log_err(); }) .detach(); } fn on_open_urls(&self, callback: Box)>) { - self.state.borrow_mut().callbacks.open_urls = Some(callback); + self.inner.state.borrow_mut().callbacks.open_urls = Some(callback); } fn prompt_for_paths( @@ -490,13 +446,18 @@ impl Platform for WindowsPlatform { rx } - fn prompt_for_new_path(&self, directory: &Path) -> Receiver>> { + fn prompt_for_new_path( + &self, + directory: &Path, + suggested_name: Option<&str>, + ) -> Receiver>> { let directory = directory.to_owned(); + let suggested_name = suggested_name.map(|s| s.to_owned()); let (tx, rx) = oneshot::channel(); let window = 
self.find_current_active_window(); self.foreground_executor() .spawn(async move { - let _ = tx.send(file_save_dialog(directory, window)); + let _ = tx.send(file_save_dialog(directory, suggested_name, window)); }) .detach(); @@ -509,55 +470,47 @@ impl Platform for WindowsPlatform { } fn reveal_path(&self, path: &Path) { - let Ok(file_full_path) = path.canonicalize() else { - log::error!("unable to parse file path"); + if path.as_os_str().is_empty() { return; - }; + } + let path = path.to_path_buf(); self.background_executor() .spawn(async move { - let Some(path) = file_full_path.to_str() else { - return; - }; - if path.is_empty() { - return; - } - open_target_in_explorer(path); + open_target_in_explorer(&path) + .with_context(|| format!("Revealing path {} in explorer", path.display())) + .log_err(); }) .detach(); } fn open_with_system(&self, path: &Path) { - let Ok(full_path) = path.canonicalize() else { - log::error!("unable to parse file full path: {}", path.display()); + if path.as_os_str().is_empty() { return; - }; + } + let path = path.to_path_buf(); self.background_executor() .spawn(async move { - let Some(full_path_str) = full_path.to_str() else { - return; - }; - if full_path_str.is_empty() { - return; - }; - open_target(full_path_str); + open_target(&path) + .with_context(|| format!("Opening {} with system", path.display())) + .log_err(); }) .detach(); } fn on_quit(&self, callback: Box) { - self.state.borrow_mut().callbacks.quit = Some(callback); + self.inner.state.borrow_mut().callbacks.quit = Some(callback); } fn on_reopen(&self, callback: Box) { - self.state.borrow_mut().callbacks.reopen = Some(callback); + self.inner.state.borrow_mut().callbacks.reopen = Some(callback); } fn set_menus(&self, menus: Vec, _keymap: &Keymap) { - self.state.borrow_mut().menus = menus.into_iter().map(|menu| menu.owned()).collect(); + self.inner.state.borrow_mut().menus = menus.into_iter().map(|menu| menu.owned()).collect(); } fn get_menus(&self) -> Option> { - 
Some(self.state.borrow().menus.clone()) + Some(self.inner.state.borrow().menus.clone()) } fn set_dock_menu(&self, menus: Vec, _keymap: &Keymap) { @@ -565,15 +518,19 @@ impl Platform for WindowsPlatform { } fn on_app_menu_action(&self, callback: Box) { - self.state.borrow_mut().callbacks.app_menu_action = Some(callback); + self.inner.state.borrow_mut().callbacks.app_menu_action = Some(callback); } fn on_will_open_app_menu(&self, callback: Box) { - self.state.borrow_mut().callbacks.will_open_app_menu = Some(callback); + self.inner.state.borrow_mut().callbacks.will_open_app_menu = Some(callback); } fn on_validate_app_menu_command(&self, callback: Box bool>) { - self.state.borrow_mut().callbacks.validate_app_menu_command = Some(callback); + self.inner + .state + .borrow_mut() + .callbacks + .validate_app_menu_command = Some(callback); } fn app_path(&self) -> Result { @@ -587,7 +544,7 @@ impl Platform for WindowsPlatform { fn set_cursor_style(&self, style: CursorStyle) { let hcursor = load_cursor(style); - let mut lock = self.state.borrow_mut(); + let mut lock = self.inner.state.borrow_mut(); if lock.current_cursor.map(|c| c.0) != hcursor.map(|c| c.0) { self.post_message( WM_GPUI_CURSOR_STYLE_CHANGED, @@ -690,10 +647,10 @@ impl Platform for WindowsPlatform { fn perform_dock_menu_action(&self, action: usize) { unsafe { - PostThreadMessageW( - self.main_thread_id_win32, + PostMessageW( + Some(self.handle), WM_GPUI_DOCK_MENU_ACTION, - WPARAM(self.validation_number), + WPARAM(self.inner.validation_number), LPARAM(action as isize), ) .log_err(); @@ -709,15 +666,147 @@ impl Platform for WindowsPlatform { } } +impl WindowsPlatformInner { + fn new(context: &mut PlatformWindowCreateContext) -> Result> { + let state = RefCell::new(WindowsPlatformState::new( + context.directx_devices.take().unwrap(), + )); + Ok(Rc::new(Self { + state, + raw_window_handles: context.raw_window_handles.clone(), + validation_number: context.validation_number, + main_receiver: 
context.main_receiver.take().unwrap(), + })) + } + + fn handle_msg( + self: &Rc, + handle: HWND, + msg: u32, + wparam: WPARAM, + lparam: LPARAM, + ) -> LRESULT { + let handled = match msg { + WM_GPUI_CLOSE_ONE_WINDOW + | WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD + | WM_GPUI_DOCK_MENU_ACTION + | WM_GPUI_KEYBOARD_LAYOUT_CHANGED + | WM_GPUI_GPU_DEVICE_LOST => self.handle_gpui_events(msg, wparam, lparam), + _ => None, + }; + if let Some(result) = handled { + LRESULT(result) + } else { + unsafe { DefWindowProcW(handle, msg, wparam, lparam) } + } + } + + fn handle_gpui_events(&self, message: u32, wparam: WPARAM, lparam: LPARAM) -> Option { + if wparam.0 != self.validation_number { + log::error!("Wrong validation number while processing message: {message}"); + return None; + } + match message { + WM_GPUI_CLOSE_ONE_WINDOW => { + if self.close_one_window(HWND(lparam.0 as _)) { + unsafe { PostQuitMessage(0) }; + } + Some(0) + } + WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD => self.run_foreground_task(), + WM_GPUI_DOCK_MENU_ACTION => self.handle_dock_action_event(lparam.0 as _), + WM_GPUI_KEYBOARD_LAYOUT_CHANGED => self.handle_keyboard_layout_change(), + WM_GPUI_GPU_DEVICE_LOST => self.handle_device_lost(lparam), + _ => unreachable!(), + } + } + + fn close_one_window(&self, target_window: HWND) -> bool { + let Some(all_windows) = self.raw_window_handles.upgrade() else { + log::error!("Failed to upgrade raw window handles"); + return false; + }; + let mut lock = all_windows.write(); + let index = lock + .iter() + .position(|handle| handle.as_raw() == target_window) + .unwrap(); + lock.remove(index); + + lock.is_empty() + } + + #[inline] + fn run_foreground_task(&self) -> Option { + for runnable in self.main_receiver.drain() { + runnable.run(); + } + Some(0) + } + + fn handle_dock_action_event(&self, action_idx: usize) -> Option { + let mut lock = self.state.borrow_mut(); + let mut callback = lock.callbacks.app_menu_action.take()?; + let Some(action) = lock + .jump_list + .dock_menus + 
.get(action_idx) + .map(|dock_menu| dock_menu.action.boxed_clone()) + else { + lock.callbacks.app_menu_action = Some(callback); + log::error!("Dock menu for index {action_idx} not found"); + return Some(1); + }; + drop(lock); + callback(&*action); + self.state.borrow_mut().callbacks.app_menu_action = Some(callback); + Some(0) + } + + fn handle_keyboard_layout_change(&self) -> Option { + let mut callback = self + .state + .borrow_mut() + .callbacks + .keyboard_layout_change + .take()?; + callback(); + self.state.borrow_mut().callbacks.keyboard_layout_change = Some(callback); + Some(0) + } + + fn handle_device_lost(&self, lparam: LPARAM) -> Option { + let mut lock = self.state.borrow_mut(); + let directx_devices = lparam.0 as *const DirectXDevices; + let directx_devices = unsafe { &*directx_devices }; + unsafe { + ManuallyDrop::drop(&mut lock.directx_devices); + } + lock.directx_devices = ManuallyDrop::new(directx_devices.clone()); + + Some(0) + } +} + impl Drop for WindowsPlatform { fn drop(&mut self) { unsafe { - ManuallyDrop::drop(&mut self.bitmap_factory); + DestroyWindow(self.handle) + .context("Destroying platform window") + .log_err(); OleUninitialize(); } } } +impl Drop for WindowsPlatformState { + fn drop(&mut self) { + unsafe { + ManuallyDrop::drop(&mut self.directx_devices); + } + } +} + pub(crate) struct WindowCreationInfo { pub(crate) icon: HICON, pub(crate) executor: ForegroundExecutor, @@ -726,43 +815,80 @@ pub(crate) struct WindowCreationInfo { pub(crate) drop_target_helper: IDropTargetHelper, pub(crate) validation_number: usize, pub(crate) main_receiver: flume::Receiver, - pub(crate) main_thread_id_win32: u32, + pub(crate) platform_window_handle: HWND, pub(crate) disable_direct_composition: bool, + pub(crate) directx_devices: DirectXDevices, } -fn open_target(target: &str) { - unsafe { - let ret = ShellExecuteW( +struct PlatformWindowCreateContext { + inner: Option>>, + raw_window_handles: std::sync::Weak>>, + validation_number: usize, + 
main_receiver: Option>, + directx_devices: Option, +} + +fn open_target(target: impl AsRef) -> Result<()> { + let target = target.as_ref(); + let ret = unsafe { + ShellExecuteW( None, windows::core::w!("open"), &HSTRING::from(target), None, None, SW_SHOWDEFAULT, - ); - if ret.0 as isize <= 32 { - log::error!("Unable to open target: {}", std::io::Error::last_os_error()); - } + ) + }; + if ret.0 as isize <= 32 { + Err(anyhow::anyhow!( + "Unable to open target: {}", + std::io::Error::last_os_error() + )) + } else { + Ok(()) } } -fn open_target_in_explorer(target: &str) { +fn open_target_in_explorer(target: &Path) -> Result<()> { + let dir = target.parent().context("No parent folder found")?; + let desktop = unsafe { SHGetDesktopFolder()? }; + + let mut dir_item = std::ptr::null_mut(); unsafe { - let ret = ShellExecuteW( + desktop.ParseDisplayName( + HWND::default(), None, - windows::core::w!("open"), - windows::core::w!("explorer.exe"), - &HSTRING::from(format!("/select,{}", target).as_str()), + &HSTRING::from(dir), None, - SW_SHOWDEFAULT, - ); - if ret.0 as isize <= 32 { - log::error!( - "Unable to open target in explorer: {}", - std::io::Error::last_os_error() - ); - } + &mut dir_item, + std::ptr::null_mut(), + )?; } + + let mut file_item = std::ptr::null_mut(); + unsafe { + desktop.ParseDisplayName( + HWND::default(), + None, + &HSTRING::from(target), + None, + &mut file_item, + std::ptr::null_mut(), + )?; + } + + let highlight = [file_item as *const _]; + unsafe { SHOpenFolderAndSelectItems(dir_item as _, Some(&highlight), 0) }.or_else(|err| { + if err.code().0 == ERROR_FILE_NOT_FOUND.0 as i32 { + // On some systems, the above call mysteriously fails with "file not + // found" even though the file is there. In these cases, ShellExecute() + // seems to work as a fallback (although it won't select the file). 
+ open_target(dir).context("Opening target parent folder") + } else { + Err(anyhow::anyhow!("Can not open target path: {}", err)) + } + }) } fn file_open_dialog( @@ -782,6 +908,12 @@ fn file_open_dialog( unsafe { folder_dialog.SetOptions(dialog_options)?; + + if let Some(prompt) = options.prompt { + let prompt: &str = &prompt; + folder_dialog.SetOkButtonLabel(&HSTRING::from(prompt))?; + } + if folder_dialog.Show(window).is_err() { // User cancelled return Ok(None); @@ -804,17 +936,26 @@ fn file_open_dialog( Ok(Some(paths)) } -fn file_save_dialog(directory: PathBuf, window: Option) -> Result> { +fn file_save_dialog( + directory: PathBuf, + suggested_name: Option, + window: Option, +) -> Result> { let dialog: IFileSaveDialog = unsafe { CoCreateInstance(&FileSaveDialog, None, CLSCTX_ALL)? }; - if !directory.to_string_lossy().is_empty() { - if let Some(full_path) = directory.canonicalize().log_err() { - let full_path = SanitizedPath::from(full_path); - let full_path_string = full_path.to_string(); - let path_item: IShellItem = - unsafe { SHCreateItemFromParsingName(&HSTRING::from(full_path_string), None)? }; - unsafe { dialog.SetFolder(&path_item).log_err() }; - } + if !directory.to_string_lossy().is_empty() + && let Some(full_path) = directory.canonicalize().log_err() + { + let full_path = SanitizedPath::new(&full_path); + let full_path_string = full_path.to_string(); + let path_item: IShellItem = + unsafe { SHCreateItemFromParsingName(&HSTRING::from(full_path_string), None)? }; + unsafe { dialog.SetFolder(&path_item).log_err() }; + } + + if let Some(suggested_name) = suggested_name { + unsafe { dialog.SetFileName(&HSTRING::from(suggested_name)).log_err() }; } + unsafe { dialog.SetFileTypes(&[Common::COMDLG_FILTERSPEC { pszName: windows::core::w!("All files"), @@ -857,6 +998,135 @@ fn should_auto_hide_scrollbars() -> Result { Ok(ui_settings.AutoHideScrollBars()?) 
} +fn check_device_lost(device: &ID3D11Device) -> bool { + let device_state = unsafe { device.GetDeviceRemovedReason() }; + match device_state { + Ok(_) => false, + Err(err) => { + log::error!("DirectX device lost detected: {:?}", err); + true + } + } +} + +fn handle_gpu_device_lost( + directx_devices: &mut DirectXDevices, + platform_window: HWND, + validation_number: usize, + all_windows: &std::sync::Weak>>, + text_system: &std::sync::Weak, +) { + // Here we wait a bit to ensure the the system has time to recover from the device lost state. + // If we don't wait, the final drawing result will be blank. + std::thread::sleep(std::time::Duration::from_millis(350)); + + try_to_recover_from_device_lost( + || { + DirectXDevices::new() + .context("Failed to recreate new DirectX devices after device lost") + }, + |new_devices| *directx_devices = new_devices, + || { + log::error!("Failed to recover DirectX devices after multiple attempts."); + // Do something here? + // At this point, the device loss is considered unrecoverable. 
+ // std::process::exit(1); + }, + ); + log::info!("DirectX devices successfully recreated."); + + unsafe { + SendMessageW( + platform_window, + WM_GPUI_GPU_DEVICE_LOST, + Some(WPARAM(validation_number)), + Some(LPARAM(directx_devices as *const _ as _)), + ); + } + + if let Some(text_system) = text_system.upgrade() { + text_system.handle_gpu_lost(&directx_devices); + } + if let Some(all_windows) = all_windows.upgrade() { + for window in all_windows.read().iter() { + unsafe { + SendMessageW( + window.as_raw(), + WM_GPUI_GPU_DEVICE_LOST, + Some(WPARAM(validation_number)), + Some(LPARAM(directx_devices as *const _ as _)), + ); + } + } + std::thread::sleep(std::time::Duration::from_millis(200)); + for window in all_windows.read().iter() { + unsafe { + SendMessageW( + window.as_raw(), + WM_GPUI_FORCE_UPDATE_WINDOW, + Some(WPARAM(validation_number)), + None, + ); + } + } + } +} + +const PLATFORM_WINDOW_CLASS_NAME: PCWSTR = w!("Zed::PlatformWindow"); + +fn register_platform_window_class() { + let wc = WNDCLASSW { + lpfnWndProc: Some(window_procedure), + lpszClassName: PCWSTR(PLATFORM_WINDOW_CLASS_NAME.as_ptr()), + ..Default::default() + }; + unsafe { RegisterClassW(&wc) }; +} + +unsafe extern "system" fn window_procedure( + hwnd: HWND, + msg: u32, + wparam: WPARAM, + lparam: LPARAM, +) -> LRESULT { + if msg == WM_NCCREATE { + let params = lparam.0 as *const CREATESTRUCTW; + let params = unsafe { &*params }; + let creation_context = params.lpCreateParams as *mut PlatformWindowCreateContext; + let creation_context = unsafe { &mut *creation_context }; + return match WindowsPlatformInner::new(creation_context) { + Ok(inner) => { + let weak = Box::new(Rc::downgrade(&inner)); + unsafe { set_window_long(hwnd, GWLP_USERDATA, Box::into_raw(weak) as isize) }; + creation_context.inner = Some(Ok(inner)); + unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) } + } + Err(error) => { + creation_context.inner = Some(Err(error)); + LRESULT(0) + } + }; + } + + let ptr = unsafe { 
get_window_long(hwnd, GWLP_USERDATA) } as *mut Weak; + if ptr.is_null() { + return unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) }; + } + let inner = unsafe { &*ptr }; + let result = if let Some(inner) = inner.upgrade() { + inner.handle_msg(hwnd, msg, wparam, lparam) + } else { + unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) } + }; + + if msg == WM_NCDESTROY { + unsafe { set_window_long(hwnd, GWLP_USERDATA, 0) }; + unsafe { drop(Box::from_raw(ptr)) }; + } + + result +} + #[cfg(test)] mod tests { use crate::{ClipboardItem, read_from_clipboard, write_to_clipboard}; diff --git a/crates/gpui/src/platform/windows/shaders.hlsl b/crates/gpui/src/platform/windows/shaders.hlsl index 6fabe859e3fe6de58c438642455964e135258860..2cef54ae6166e313795eb42210b5f07c1bc378fc 100644 --- a/crates/gpui/src/platform/windows/shaders.hlsl +++ b/crates/gpui/src/platform/windows/shaders.hlsl @@ -1,6 +1,10 @@ +#include "alpha_correction.hlsl" + cbuffer GlobalParams: register(b0) { + float4 gamma_ratios; float2 global_viewport_size; - uint2 _pad; + float grayscale_enhanced_contrast; + uint _pad; }; Texture2D t_sprite: register(t0); @@ -1098,7 +1102,8 @@ MonochromeSpriteVertexOutput monochrome_sprite_vertex(uint vertex_id: SV_VertexI float4 monochrome_sprite_fragment(MonochromeSpriteFragmentInput input): SV_Target { float sample = t_sprite.Sample(s_sprite, input.tile_position).r; - return float4(input.color.rgb, input.color.a * sample); + float alpha_corrected = apply_contrast_and_gamma_correction(sample, input.color.rgb, grayscale_enhanced_contrast, gamma_ratios); + return float4(input.color.rgb, input.color.a * alpha_corrected); } /* diff --git a/crates/gpui/src/platform/windows/vsync.rs b/crates/gpui/src/platform/windows/vsync.rs index 6d09b0960f11cefce007413066620b3b332e1ae9..5cbcb8e99e2741c4b37cad4d550e290c4cab869f 100644 --- a/crates/gpui/src/platform/windows/vsync.rs +++ b/crates/gpui/src/platform/windows/vsync.rs @@ -94,7 +94,7 @@ impl VSyncProvider { // DwmFlush and 
DCompositionWaitForCompositorClock returns very early // instead of waiting until vblank when the monitor goes to sleep or is // unplugged (nothing to present due to desktop occlusion). We use 1ms as - // a threshhold for the duration of the wait functions and fallback to + // a threshold for the duration of the wait functions and fallback to // Sleep() if it returns before that. This could happen during normal // operation for the first call after the vsync thread becomes non-idle, // but it shouldn't happen often. diff --git a/crates/gpui/src/platform/windows/window.rs b/crates/gpui/src/platform/windows/window.rs index 32a6da23915d1e2bdf61c662e364119e9c6a8c64..aa907c8d734973fc4fc795b6d8ebf7654d1b40de 100644 --- a/crates/gpui/src/platform/windows/window.rs +++ b/crates/gpui/src/platform/windows/window.rs @@ -73,12 +73,13 @@ pub(crate) struct WindowsWindowInner { pub(crate) windows_version: WindowsVersion, pub(crate) validation_number: usize, pub(crate) main_receiver: flume::Receiver, - pub(crate) main_thread_id_win32: u32, + pub(crate) platform_window_handle: HWND, } impl WindowsWindowState { fn new( hwnd: HWND, + directx_devices: &DirectXDevices, window_params: &CREATESTRUCTW, current_cursor: Option, display: WindowsDisplay, @@ -104,7 +105,7 @@ impl WindowsWindowState { }; let border_offset = WindowBorderOffset::default(); let restore_from_minimized = None; - let renderer = DirectXRenderer::new(hwnd, disable_direct_composition) + let renderer = DirectXRenderer::new(hwnd, directx_devices, disable_direct_composition) .context("Creating DirectX renderer")?; let callbacks = Callbacks::default(); let input_handler = None; @@ -205,9 +206,10 @@ impl WindowsWindowState { } impl WindowsWindowInner { - fn new(context: &WindowCreateContext, hwnd: HWND, cs: &CREATESTRUCTW) -> Result> { + fn new(context: &mut WindowCreateContext, hwnd: HWND, cs: &CREATESTRUCTW) -> Result> { let state = RefCell::new(WindowsWindowState::new( hwnd, + &context.directx_devices, cs, 
context.current_cursor, context.display, @@ -228,7 +230,7 @@ impl WindowsWindowInner { windows_version: context.windows_version, validation_number: context.validation_number, main_receiver: context.main_receiver.clone(), - main_thread_id_win32: context.main_thread_id_win32, + platform_window_handle: context.platform_window_handle, })) } @@ -342,9 +344,10 @@ struct WindowCreateContext { drop_target_helper: IDropTargetHelper, validation_number: usize, main_receiver: flume::Receiver, - main_thread_id_win32: u32, + platform_window_handle: HWND, appearance: WindowAppearance, disable_direct_composition: bool, + directx_devices: DirectXDevices, } impl WindowsWindow { @@ -361,8 +364,9 @@ impl WindowsWindow { drop_target_helper, validation_number, main_receiver, - main_thread_id_win32, + platform_window_handle, disable_direct_composition, + directx_devices, } = creation_info; register_window_class(icon); let hide_title_bar = params @@ -382,10 +386,17 @@ impl WindowsWindow { let (mut dwexstyle, dwstyle) = if params.kind == WindowKind::PopUp { (WS_EX_TOOLWINDOW, WINDOW_STYLE(0x0)) } else { - ( - WS_EX_APPWINDOW, - WS_THICKFRAME | WS_SYSMENU | WS_MAXIMIZEBOX | WS_MINIMIZEBOX, - ) + let mut dwstyle = WS_SYSMENU; + + if params.is_resizable { + dwstyle |= WS_THICKFRAME | WS_MAXIMIZEBOX; + } + + if params.is_minimizable { + dwstyle |= WS_MINIMIZEBOX; + } + + (WS_EX_APPWINDOW, dwstyle) }; if !disable_direct_composition { dwexstyle |= WS_EX_NOREDIRECTIONBITMAP; @@ -412,9 +423,10 @@ impl WindowsWindow { drop_target_helper, validation_number, main_receiver, - main_thread_id_win32, + platform_window_handle, appearance, disable_direct_composition, + directx_devices, }; let creation_result = unsafe { CreateWindowExW( @@ -592,10 +604,7 @@ impl PlatformWindow for WindowsWindow { ) -> Option> { let (done_tx, done_rx) = oneshot::channel(); let msg = msg.to_string(); - let detail_string = match detail { - Some(info) => Some(info.to_string()), - None => None, - }; + let detail_string = 
detail.map(|detail| detail.to_string()); let handle = self.0.hwnd; let answers = answers.to_vec(); self.0 @@ -830,7 +839,7 @@ impl PlatformWindow for WindowsWindow { self.0.state.borrow().renderer.gpu_specs().log_err() } - fn update_ime_position(&self, _bounds: Bounds) { + fn update_ime_position(&self, _bounds: Bounds) { // There is no such thing on Windows. } } diff --git a/crates/gpui/src/shared_string.rs b/crates/gpui/src/shared_string.rs index c325f98cd243121264875d7a9452308772d49e86..350184d350aec8c5995fe7d2f0856f1fe1cfea0f 100644 --- a/crates/gpui/src/shared_string.rs +++ b/crates/gpui/src/shared_string.rs @@ -23,6 +23,11 @@ impl SharedString { pub fn new(str: impl Into>) -> Self { SharedString(ArcCow::Owned(str.into())) } + + /// Get a &str from the underlying string. + pub fn as_str(&self) -> &str { + &self.0 + } } impl JsonSchema for SharedString { @@ -103,7 +108,7 @@ impl From for Arc { fn from(val: SharedString) -> Self { match val.0 { ArcCow::Borrowed(borrowed) => Arc::from(borrowed), - ArcCow::Owned(owned) => owned.clone(), + ArcCow::Owned(owned) => owned, } } } diff --git a/crates/gpui/src/style.rs b/crates/gpui/src/style.rs index 09985722efa1b23b3b42f4d168549c651dd6bd26..53ca6508bc94e96033d0929a674ce59e8206ba04 100644 --- a/crates/gpui/src/style.rs +++ b/crates/gpui/src/style.rs @@ -573,7 +573,7 @@ impl Style { if self .border_color - .map_or(false, |color| !color.is_transparent()) + .is_some_and(|color| !color.is_transparent()) { min.x += self.border_widths.left.to_pixels(rem_size); max.x -= self.border_widths.right.to_pixels(rem_size); @@ -633,7 +633,7 @@ impl Style { window.paint_shadows(bounds, corner_radii, &self.box_shadow); let background_color = self.background.as_ref().and_then(Fill::color); - if background_color.map_or(false, |color| !color.is_transparent()) { + if background_color.is_some_and(|color| !color.is_transparent()) { let mut border_color = match background_color { Some(color) => match color.tag { BackgroundTag::Solid => 
color.solid, @@ -729,7 +729,7 @@ impl Style { fn is_border_visible(&self) -> bool { self.border_color - .map_or(false, |color| !color.is_transparent()) + .is_some_and(|color| !color.is_transparent()) && self.border_widths.any(|length| !length.is_zero()) } } @@ -886,43 +886,32 @@ impl HighlightStyle { } /// Blend this highlight style with another. /// Non-continuous properties, like font_weight and font_style, are overwritten. - pub fn highlight(&mut self, other: HighlightStyle) { - match (self.color, other.color) { - (Some(self_color), Some(other_color)) => { - self.color = Some(Hsla::blend(other_color, self_color)); - } - (None, Some(other_color)) => { - self.color = Some(other_color); - } - _ => {} - } - - if other.font_weight.is_some() { - self.font_weight = other.font_weight; - } - - if other.font_style.is_some() { - self.font_style = other.font_style; - } - - if other.background_color.is_some() { - self.background_color = other.background_color; - } - - if other.underline.is_some() { - self.underline = other.underline; - } - - if other.strikethrough.is_some() { - self.strikethrough = other.strikethrough; - } - - match (other.fade_out, self.fade_out) { - (Some(source_fade), None) => self.fade_out = Some(source_fade), - (Some(source_fade), Some(dest_fade)) => { - self.fade_out = Some((dest_fade * (1. 
+ source_fade)).clamp(0., 1.)); - } - _ => {} + #[must_use] + pub fn highlight(self, other: HighlightStyle) -> Self { + Self { + color: other + .color + .map(|other_color| { + if let Some(color) = self.color { + color.blend(other_color) + } else { + other_color + } + }) + .or(self.color), + font_weight: other.font_weight.or(self.font_weight), + font_style: other.font_style.or(self.font_style), + background_color: other.background_color.or(self.background_color), + underline: other.underline.or(self.underline), + strikethrough: other.strikethrough.or(self.strikethrough), + fade_out: other + .fade_out + .map(|source_fade| { + self.fade_out + .map(|dest_fade| (dest_fade * (1. + source_fade)).clamp(0., 1.)) + .unwrap_or(source_fade) + }) + .or(self.fade_out), } } } @@ -987,10 +976,11 @@ pub fn combine_highlights( while let Some((endpoint_ix, highlight_id, is_start)) = endpoints.peek() { let prev_index = mem::replace(&mut ix, *endpoint_ix); if ix > prev_index && !active_styles.is_empty() { - let mut current_style = HighlightStyle::default(); - for highlight_id in &active_styles { - current_style.highlight(highlights[*highlight_id]); - } + let current_style = active_styles + .iter() + .fold(HighlightStyle::default(), |acc, highlight_id| { + acc.highlight(highlights[*highlight_id]) + }); return Some((prev_index..ix, current_style)); } @@ -1306,10 +1296,95 @@ impl From for taffy::style::Position { #[cfg(test)] mod tests { - use crate::{blue, green, red, yellow}; + use crate::{blue, green, px, red, yellow}; use super::*; + #[test] + fn test_basic_highlight_style_combination() { + let style_a = HighlightStyle::default(); + let style_b = HighlightStyle::default(); + let style_a = style_a.highlight(style_b); + assert_eq!( + style_a, + HighlightStyle::default(), + "Combining empty styles should not produce a non-empty style." 
+ ); + + let mut style_b = HighlightStyle { + color: Some(red()), + strikethrough: Some(StrikethroughStyle { + thickness: px(2.), + color: Some(blue()), + }), + fade_out: Some(0.), + font_style: Some(FontStyle::Italic), + font_weight: Some(FontWeight(300.)), + background_color: Some(yellow()), + underline: Some(UnderlineStyle { + thickness: px(2.), + color: Some(red()), + wavy: true, + }), + }; + let expected_style = style_b; + + let style_a = style_a.highlight(style_b); + assert_eq!( + style_a, expected_style, + "Blending an empty style with another style should return the other style" + ); + + let style_b = style_b.highlight(Default::default()); + assert_eq!( + style_b, expected_style, + "Blending a style with an empty style should not change the style." + ); + + let mut style_c = expected_style; + + let style_d = HighlightStyle { + color: Some(blue().alpha(0.7)), + strikethrough: Some(StrikethroughStyle { + thickness: px(4.), + color: Some(crate::red()), + }), + fade_out: Some(0.), + font_style: Some(FontStyle::Oblique), + font_weight: Some(FontWeight(800.)), + background_color: Some(green()), + underline: Some(UnderlineStyle { + thickness: px(4.), + color: None, + wavy: false, + }), + }; + + let expected_style = HighlightStyle { + color: Some(red().blend(blue().alpha(0.7))), + strikethrough: Some(StrikethroughStyle { + thickness: px(4.), + color: Some(red()), + }), + // TODO this does not seem right + fade_out: Some(0.), + font_style: Some(FontStyle::Oblique), + font_weight: Some(FontWeight(800.)), + background_color: Some(green()), + underline: Some(UnderlineStyle { + thickness: px(4.), + color: None, + wavy: false, + }), + }; + + let style_c = style_c.highlight(style_d); + assert_eq!( + style_c, expected_style, + "Blending styles should blend properties where possible and override all others" + ); + } + #[test] fn test_combine_highlights() { assert_eq!( @@ -1337,14 +1412,14 @@ mod tests { ( 1..2, HighlightStyle { - color: Some(green()), + color: Some(blue()), 
..Default::default() } ), ( 2..3, HighlightStyle { - color: Some(green()), + color: Some(blue()), font_style: Some(FontStyle::Italic), ..Default::default() } diff --git a/crates/gpui/src/subscription.rs b/crates/gpui/src/subscription.rs index a584f1a45f82094ce9b867bc5f43805c48f93ebe..bd869f8d32cdfc81917fc2287b7dc62fac7d727d 100644 --- a/crates/gpui/src/subscription.rs +++ b/crates/gpui/src/subscription.rs @@ -201,3 +201,9 @@ impl Drop for Subscription { } } } + +impl std::fmt::Debug for Subscription { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Subscription").finish() + } +} diff --git a/crates/gpui/src/tab_stop.rs b/crates/gpui/src/tab_stop.rs index 7dde42efed8a138de3a29657683d95c60e27dda0..c4d2fda6e9a9c3e0adfb2d02cf5c372869d42751 100644 --- a/crates/gpui/src/tab_stop.rs +++ b/crates/gpui/src/tab_stop.rs @@ -45,27 +45,18 @@ impl TabHandles { }) .unwrap_or_default(); - if let Some(next_handle) = self.handles.get(next_ix) { - Some(next_handle.clone()) - } else { - None - } + self.handles.get(next_ix).cloned() } pub(crate) fn prev(&self, focused_id: Option<&FocusId>) -> Option { let ix = self.current_index(focused_id).unwrap_or_default(); - let prev_ix; - if ix == 0 { - prev_ix = self.handles.len().saturating_sub(1); + let prev_ix = if ix == 0 { + self.handles.len().saturating_sub(1) } else { - prev_ix = ix.saturating_sub(1); - } + ix.saturating_sub(1) + }; - if let Some(prev_handle) = self.handles.get(prev_ix) { - Some(prev_handle.clone()) - } else { - None - } + self.handles.get(prev_ix).cloned() } } @@ -90,7 +81,7 @@ mod tests { ]; for handle in focus_handles.iter() { - tab.insert(&handle); + tab.insert(handle); } assert_eq!( tab.handles diff --git a/crates/gpui/src/taffy.rs b/crates/gpui/src/taffy.rs index ee21ecd8c4a4b5b4c4a3853b56af9fe210bc5481..58386ad1f5031e1427baad05c4db075df1b2d761 100644 --- a/crates/gpui/src/taffy.rs +++ b/crates/gpui/src/taffy.rs @@ -3,6 +3,7 @@ use crate::{ }; use collections::{FxHashMap, 
FxHashSet}; use smallvec::SmallVec; +use stacksafe::{StackSafe, stacksafe}; use std::{fmt::Debug, ops::Range}; use taffy::{ TaffyTree, TraversePartialTree as _, @@ -11,8 +12,15 @@ use taffy::{ tree::NodeId, }; -type NodeMeasureFn = Box< - dyn FnMut(Size>, Size, &mut Window, &mut App) -> Size, +type NodeMeasureFn = StackSafe< + Box< + dyn FnMut( + Size>, + Size, + &mut Window, + &mut App, + ) -> Size, + >, >; struct NodeContext { @@ -50,23 +58,21 @@ impl TaffyLayoutEngine { children: &[LayoutId], ) -> LayoutId { let taffy_style = style.to_taffy(rem_size); - let layout_id = if children.is_empty() { + + if children.is_empty() { self.taffy .new_leaf(taffy_style) .expect(EXPECT_MESSAGE) .into() } else { - let parent_id = self - .taffy + self.taffy // This is safe because LayoutId is repr(transparent) to taffy::tree::NodeId. .new_with_children(taffy_style, unsafe { std::mem::transmute::<&[LayoutId], &[taffy::NodeId]>(children) }) .expect(EXPECT_MESSAGE) - .into(); - parent_id - }; - layout_id + .into() + } } pub fn request_measured_layout( @@ -83,17 +89,15 @@ impl TaffyLayoutEngine { ) -> LayoutId { let taffy_style = style.to_taffy(rem_size); - let layout_id = self - .taffy + self.taffy .new_leaf_with_context( taffy_style, NodeContext { - measure: Box::new(measure), + measure: StackSafe::new(Box::new(measure)), }, ) .expect(EXPECT_MESSAGE) - .into(); - layout_id + .into() } // Used to understand performance @@ -143,6 +147,7 @@ impl TaffyLayoutEngine { Ok(edges) } + #[stacksafe] pub fn compute_layout( &mut self, id: LayoutId, @@ -159,7 +164,6 @@ impl TaffyLayoutEngine { // for (a, b) in self.get_edges(id)? 
{ // println!("N{} --> N{}", u64::from(a), u64::from(b)); // } - // println!(""); // if !self.computed_layouts.insert(id) { diff --git a/crates/gpui/src/test.rs b/crates/gpui/src/test.rs index 4794fd002e28595a5d165ff3ac5876ea31c8ce20..5ae72d2be1688893374e16a55445558b5bc33040 100644 --- a/crates/gpui/src/test.rs +++ b/crates/gpui/src/test.rs @@ -64,6 +64,9 @@ pub fn run_test( if attempt < max_retries { println!("attempt {} failed, retrying", attempt); attempt += 1; + // The panic payload might itself trigger an unwind on drop: + // https://doc.rust-lang.org/std/panic/fn.catch_unwind.html#notes + std::mem::forget(error); } else { if is_multiple_runs { eprintln!("failing seed: {}", seed); diff --git a/crates/gpui/src/text_system.rs b/crates/gpui/src/text_system.rs index b48c3a29350ab3c770c5eca765f7019ae1afd8f3..4d4087f45d4093c239218f96f015d153fa77dc10 100644 --- a/crates/gpui/src/text_system.rs +++ b/crates/gpui/src/text_system.rs @@ -351,7 +351,7 @@ impl WindowTextSystem { /// /// Note that this method can only shape a single line of text. It will panic /// if the text contains newlines. If you need to shape multiple lines of text, - /// use `TextLayout::shape_text` instead. + /// use [`Self::shape_text`] instead. 
pub fn shape_line( &self, text: SharedString, @@ -366,15 +366,14 @@ impl WindowTextSystem { let mut decoration_runs = SmallVec::<[DecorationRun; 32]>::new(); for run in runs { - if let Some(last_run) = decoration_runs.last_mut() { - if last_run.color == run.color - && last_run.underline == run.underline - && last_run.strikethrough == run.strikethrough - && last_run.background_color == run.background_color - { - last_run.len += run.len as u32; - continue; - } + if let Some(last_run) = decoration_runs.last_mut() + && last_run.color == run.color + && last_run.underline == run.underline + && last_run.strikethrough == run.strikethrough + && last_run.background_color == run.background_color + { + last_run.len += run.len as u32; + continue; } decoration_runs.push(DecorationRun { len: run.len as u32, @@ -436,7 +435,7 @@ impl WindowTextSystem { }); } - if decoration_runs.last().map_or(false, |last_run| { + if decoration_runs.last().is_some_and(|last_run| { last_run.color == run.color && last_run.underline == run.underline && last_run.strikethrough == run.strikethrough @@ -492,14 +491,14 @@ impl WindowTextSystem { let mut split_lines = text.split('\n'); let mut processed = false; - if let Some(first_line) = split_lines.next() { - if let Some(second_line) = split_lines.next() { - processed = true; - process_line(first_line.to_string().into()); - process_line(second_line.to_string().into()); - for line_text in split_lines { - process_line(line_text.to_string().into()); - } + if let Some(first_line) = split_lines.next() + && let Some(second_line) = split_lines.next() + { + processed = true; + process_line(first_line.to_string().into()); + process_line(second_line.to_string().into()); + for line_text in split_lines { + process_line(line_text.to_string().into()); } } @@ -518,7 +517,7 @@ impl WindowTextSystem { /// Layout the given line of text, at the given font_size. /// Subsets of the line can be styled independently with the `runs` parameter. 
- /// Generally, you should prefer to use `TextLayout::shape_line` instead, which + /// Generally, you should prefer to use [`Self::shape_line`] instead, which /// can be painted directly. pub fn layout_line( &self, @@ -534,11 +533,11 @@ impl WindowTextSystem { let mut font_runs = self.font_runs_pool.lock().pop().unwrap_or_default(); for run in runs.iter() { let font_id = self.resolve_font(&run.font); - if let Some(last_run) = font_runs.last_mut() { - if last_run.font_id == font_id { - last_run.len += run.len; - continue; - } + if let Some(last_run) = font_runs.last_mut() + && last_run.font_id == font_id + { + last_run.len += run.len; + continue; } font_runs.push(FontRun { len: run.len, @@ -669,7 +668,7 @@ impl Display for FontStyle { } } -/// A styled run of text, for use in [`TextLayout`]. +/// A styled run of text, for use in [`crate::TextLayout`]. #[derive(Clone, Debug, PartialEq, Eq)] pub struct TextRun { /// A number of utf8 bytes @@ -695,7 +694,7 @@ impl TextRun { } } -/// An identifier for a specific glyph, as returned by [`TextSystem::layout_line`]. +/// An identifier for a specific glyph, as returned by [`WindowTextSystem::layout_line`]. 
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] #[repr(C)] pub struct GlyphId(pub(crate) u32); diff --git a/crates/gpui/src/text_system/line.rs b/crates/gpui/src/text_system/line.rs index 3813393d81deaff4ed9adb1d96e204b75953233f..8d559f981581858990fa545b8e0ba65bccdf80a8 100644 --- a/crates/gpui/src/text_system/line.rs +++ b/crates/gpui/src/text_system/line.rs @@ -292,10 +292,10 @@ fn paint_line( } if let Some(style_run) = style_run { - if let Some((_, underline_style)) = &mut current_underline { - if style_run.underline.as_ref() != Some(underline_style) { - finished_underline = current_underline.take(); - } + if let Some((_, underline_style)) = &mut current_underline + && style_run.underline.as_ref() != Some(underline_style) + { + finished_underline = current_underline.take(); } if let Some(run_underline) = style_run.underline.as_ref() { current_underline.get_or_insert(( @@ -310,10 +310,10 @@ fn paint_line( }, )); } - if let Some((_, strikethrough_style)) = &mut current_strikethrough { - if style_run.strikethrough.as_ref() != Some(strikethrough_style) { - finished_strikethrough = current_strikethrough.take(); - } + if let Some((_, strikethrough_style)) = &mut current_strikethrough + && style_run.strikethrough.as_ref() != Some(strikethrough_style) + { + finished_strikethrough = current_strikethrough.take(); } if let Some(run_strikethrough) = style_run.strikethrough.as_ref() { current_strikethrough.get_or_insert(( @@ -509,10 +509,10 @@ fn paint_line_background( } if let Some(style_run) = style_run { - if let Some((_, background_color)) = &mut current_background { - if style_run.background_color.as_ref() != Some(background_color) { - finished_background = current_background.take(); - } + if let Some((_, background_color)) = &mut current_background + && style_run.background_color.as_ref() != Some(background_color) + { + finished_background = current_background.take(); } if let Some(run_background) = style_run.background_color { current_background.get_or_insert(( 
diff --git a/crates/gpui/src/text_system/line_layout.rs b/crates/gpui/src/text_system/line_layout.rs index 9c2dd7f0871e5b67bd15d3a419c1c03496e2afaa..43694702a82566b8f84199dcfc4ff996da93588e 100644 --- a/crates/gpui/src/text_system/line_layout.rs +++ b/crates/gpui/src/text_system/line_layout.rs @@ -185,10 +185,10 @@ impl LineLayout { if width > wrap_width && boundary > last_boundary { // When used line_clamp, we should limit the number of lines. - if let Some(max_lines) = max_lines { - if boundaries.len() >= max_lines - 1 { - break; - } + if let Some(max_lines) = max_lines + && boundaries.len() >= max_lines - 1 + { + break; } if let Some(last_candidate_ix) = last_candidate_ix.take() { diff --git a/crates/gpui/src/text_system/line_wrapper.rs b/crates/gpui/src/text_system/line_wrapper.rs index 648d714c89765d09623f154ef55ddd44d9716028..d499d78551a5e0e268b575496bbdac5ddf59369c 100644 --- a/crates/gpui/src/text_system/line_wrapper.rs +++ b/crates/gpui/src/text_system/line_wrapper.rs @@ -44,7 +44,7 @@ impl LineWrapper { let mut prev_c = '\0'; let mut index = 0; let mut candidates = fragments - .into_iter() + .iter() .flat_map(move |fragment| fragment.wrap_boundary_candidates()) .peekable(); iter::from_fn(move || { @@ -181,7 +181,7 @@ impl LineWrapper { matches!(c, '\u{0400}'..='\u{04FF}') || // Some other known special characters that should be treated as word characters, // e.g. `a-b`, `var_name`, `I'm`, '@mention`, `#hashtag`, `100%`, `3.1415`, `2^3`, `a~b`, etc. - matches!(c, '-' | '_' | '.' | '\'' | '$' | '%' | '@' | '#' | '^' | '~' | ',') || + matches!(c, '-' | '_' | '.' | '\'' | '$' | '%' | '@' | '#' | '^' | '~' | ',' | '!' | ';' | '*') || // Characters that used in URL, e.g. `https://github.com/zed-industries/zed?a=1&b=2` for better wrapping a long URL. matches!(c, '/' | ':' | '?' | '&' | '=') || // `⋯` character is special used in Zed, to keep this at the end of the line. 
diff --git a/crates/gpui/src/util.rs b/crates/gpui/src/util.rs index 5e92335fdc86e331d3a469c4384043fd9799b00a..3704784a954f14b8317202e227ffb1b17092d70d 100644 --- a/crates/gpui/src/util.rs +++ b/crates/gpui/src/util.rs @@ -1,13 +1,11 @@ -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering::SeqCst; -#[cfg(any(test, feature = "test-support"))] -use std::time::Duration; - -#[cfg(any(test, feature = "test-support"))] -use futures::Future; - -#[cfg(any(test, feature = "test-support"))] -use smol::future::FutureExt; +use crate::{BackgroundExecutor, Task}; +use std::{ + future::Future, + pin::Pin, + sync::atomic::{AtomicUsize, Ordering::SeqCst}, + task, + time::Duration, +}; pub use util::*; @@ -60,18 +58,63 @@ pub trait FluentBuilder { where Self: Sized, { - self.map(|this| { - if let Some(_) = option { - this - } else { - then(this) - } - }) + self.map(|this| if option.is_some() { this } else { then(this) }) + } +} + +/// Extensions for Future types that provide additional combinators and utilities. +pub trait FutureExt { + /// Requires a Future to complete before the specified duration has elapsed. + /// Similar to tokio::timeout. 
+ fn with_timeout(self, timeout: Duration, executor: &BackgroundExecutor) -> WithTimeout + where + Self: Sized; +} + +impl FutureExt for T { + fn with_timeout(self, timeout: Duration, executor: &BackgroundExecutor) -> WithTimeout + where + Self: Sized, + { + WithTimeout { + future: self, + timer: executor.timer(timeout), + } + } +} + +pub struct WithTimeout { + future: T, + timer: Task<()>, +} + +#[derive(Debug, thiserror::Error)] +#[error("Timed out before future resolved")] +/// Error returned by with_timeout when the timeout duration elapsed before the future resolved +pub struct Timeout; + +impl Future for WithTimeout { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll { + // SAFETY: the fields of Timeout are private and we never move the future ourselves + // And its already pinned since we are being polled (all futures need to be pinned to be polled) + let this = unsafe { &raw mut *self.get_unchecked_mut() }; + let future = unsafe { Pin::new_unchecked(&mut (*this).future) }; + let timer = unsafe { Pin::new_unchecked(&mut (*this).timer) }; + + if let task::Poll::Ready(output) = future.poll(cx) { + task::Poll::Ready(Ok(output)) + } else if timer.poll(cx).is_ready() { + task::Poll::Ready(Err(Timeout)) + } else { + task::Poll::Pending + } } } #[cfg(any(test, feature = "test-support"))] -pub async fn timeout(timeout: Duration, f: F) -> Result +pub async fn smol_timeout(timeout: Duration, f: F) -> Result where F: Future, { @@ -80,7 +123,7 @@ where Err(()) }; let future = async move { Ok(f.await) }; - timer.race(future).await + smol::future::FutureExt::race(timer, future).await } /// Increment the given atomic counter if it is not zero. 
diff --git a/crates/gpui/src/view.rs b/crates/gpui/src/view.rs index f461e2f7d01a1dc2cdc93cda4f5854c8e958feaf..217971792ee978307a19f7e40374cb337e38a625 100644 --- a/crates/gpui/src/view.rs +++ b/crates/gpui/src/view.rs @@ -205,22 +205,21 @@ impl Element for AnyView { let content_mask = window.content_mask(); let text_style = window.text_style(); - if let Some(mut element_state) = element_state { - if element_state.cache_key.bounds == bounds - && element_state.cache_key.content_mask == content_mask - && element_state.cache_key.text_style == text_style - && !window.dirty_views.contains(&self.entity_id()) - && !window.refreshing - { - let prepaint_start = window.prepaint_index(); - window.reuse_prepaint(element_state.prepaint_range.clone()); - cx.entities - .extend_accessed(&element_state.accessed_entities); - let prepaint_end = window.prepaint_index(); - element_state.prepaint_range = prepaint_start..prepaint_end; - - return (None, element_state); - } + if let Some(mut element_state) = element_state + && element_state.cache_key.bounds == bounds + && element_state.cache_key.content_mask == content_mask + && element_state.cache_key.text_style == text_style + && !window.dirty_views.contains(&self.entity_id()) + && !window.refreshing + { + let prepaint_start = window.prepaint_index(); + window.reuse_prepaint(element_state.prepaint_range.clone()); + cx.entities + .extend_accessed(&element_state.accessed_entities); + let prepaint_end = window.prepaint_index(); + element_state.prepaint_range = prepaint_start..prepaint_end; + + return (None, element_state); } let refreshing = mem::replace(&mut window.refreshing, true); diff --git a/crates/gpui/src/window.rs b/crates/gpui/src/window.rs index c0ffd34a0d99f78f9388927ba7857ebb9661baa1..c230a4c7aa58af3d89c71b4093cd7c6b8816252a 100644 --- a/crates/gpui/src/window.rs +++ b/crates/gpui/src/window.rs @@ -12,11 +12,11 @@ use crate::{ PlatformInputHandler, PlatformWindow, Point, PolychromeSprite, PromptButton, PromptLevel, Quad, 
Render, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams, Replay, ResizeEdge, SMOOTH_SVG_SCALE_FACTOR, SUBPIXEL_VARIANTS, ScaledPixels, Scene, Shadow, SharedString, Size, - StrikethroughStyle, Style, SubscriberSet, Subscription, TabHandles, TaffyLayoutEngine, Task, - TextStyle, TextStyleRefinement, TransformationMatrix, Underline, UnderlineStyle, - WindowAppearance, WindowBackgroundAppearance, WindowBounds, WindowControls, WindowDecorations, - WindowOptions, WindowParams, WindowTextSystem, point, prelude::*, px, rems, size, - transparent_black, + StrikethroughStyle, Style, SubscriberSet, Subscription, SystemWindowTab, + SystemWindowTabController, TabHandles, TaffyLayoutEngine, Task, TextStyle, TextStyleRefinement, + TransformationMatrix, Underline, UnderlineStyle, WindowAppearance, WindowBackgroundAppearance, + WindowBounds, WindowControls, WindowDecorations, WindowOptions, WindowParams, WindowTextSystem, + point, prelude::*, px, rems, size, transparent_black, }; use anyhow::{Context as _, Result, anyhow}; use collections::{FxHashMap, FxHashSet}; @@ -243,14 +243,14 @@ impl FocusId { pub fn contains_focused(&self, window: &Window, cx: &App) -> bool { window .focused(cx) - .map_or(false, |focused| self.contains(focused.id, window)) + .is_some_and(|focused| self.contains(focused.id, window)) } /// Obtains whether the element associated with this handle is contained within the /// focused element or is itself focused. pub fn within_focused(&self, window: &Window, cx: &App) -> bool { let focused = window.focused(cx); - focused.map_or(false, |focused| focused.id.contains(*self, window)) + focused.is_some_and(|focused| focused.id.contains(*self, window)) } /// Obtains whether this handle contains the given handle in the most recently rendered frame. @@ -504,7 +504,7 @@ impl HitboxId { return true; } } - return false; + false } /// Checks if the hitbox with this ID contains the mouse and should handle scroll events. 
@@ -585,7 +585,7 @@ pub enum HitboxBehavior { /// if phase == DispatchPhase::Capture && hitbox.is_hovered(window) { /// cx.stop_propagation(); /// } - /// } + /// }) /// ``` /// /// This has effects beyond event handling - any use of hitbox checking, such as hover @@ -605,11 +605,11 @@ pub enum HitboxBehavior { /// bubble-phase handler for every mouse event type **except** `ScrollWheelEvent`: /// /// ``` - /// window.on_mouse_event(move |_: &EveryMouseEventTypeExceptScroll, phase, window, _cx| { + /// window.on_mouse_event(move |_: &EveryMouseEventTypeExceptScroll, phase, window, cx| { /// if phase == DispatchPhase::Bubble && hitbox.should_handle_scroll(window) { /// cx.stop_propagation(); /// } - /// } + /// }) /// ``` /// /// See the documentation of [`Hitbox::is_hovered`] for details of why `ScrollWheelEvent` is @@ -634,7 +634,7 @@ impl TooltipId { window .tooltip_bounds .as_ref() - .map_or(false, |tooltip_bounds| { + .is_some_and(|tooltip_bounds| { tooltip_bounds.id == *self && tooltip_bounds.bounds.contains(&window.mouse_position()) }) @@ -939,11 +939,15 @@ impl Window { show, kind, is_movable, + is_resizable, + is_minimizable, display_id, window_background, app_id, window_min_size, window_decorations, + #[cfg_attr(not(target_os = "macos"), allow(unused_variables))] + tabbing_identifier, } = options; let bounds = window_bounds @@ -956,12 +960,23 @@ impl Window { titlebar, kind, is_movable, + is_resizable, + is_minimizable, focus, show, display_id, window_min_size, + #[cfg(target_os = "macos")] + tabbing_identifier, }, )?; + + let tab_bar_visible = platform_window.tab_bar_visible(); + SystemWindowTabController::init_visible(cx, tab_bar_visible); + if let Some(tabs) = platform_window.tabbed_windows() { + SystemWindowTabController::add_tab(cx, handle.window_id(), tabs); + } + let display_id = platform_window.display().map(|display| display.id()); let sprite_atlas = platform_window.sprite_atlas(); let mouse_position = platform_window.mouse_position(); @@ -991,9 
+1006,13 @@ impl Window { } platform_window.on_close(Box::new({ + let window_id = handle.window_id(); let mut cx = cx.to_async(); move || { let _ = handle.update(&mut cx, |_, window, _| window.remove_window()); + let _ = cx.update(|cx| { + SystemWindowTabController::remove_tab(cx, window_id); + }); } })); platform_window.on_request_frame(Box::new({ @@ -1082,7 +1101,11 @@ impl Window { .activation_observers .clone() .retain(&(), |callback| callback(window, cx)); + + window.bounds_changed(cx); window.refresh(); + + SystemWindowTabController::update_last_active(cx, window.handle.id); }) .log_err(); } @@ -1123,6 +1146,57 @@ impl Window { .unwrap_or(None) }) }); + platform_window.on_move_tab_to_new_window({ + let mut cx = cx.to_async(); + Box::new(move || { + handle + .update(&mut cx, |_, _window, cx| { + SystemWindowTabController::move_tab_to_new_window(cx, handle.window_id()); + }) + .log_err(); + }) + }); + platform_window.on_merge_all_windows({ + let mut cx = cx.to_async(); + Box::new(move || { + handle + .update(&mut cx, |_, _window, cx| { + SystemWindowTabController::merge_all_windows(cx, handle.window_id()); + }) + .log_err(); + }) + }); + platform_window.on_select_next_tab({ + let mut cx = cx.to_async(); + Box::new(move || { + handle + .update(&mut cx, |_, _window, cx| { + SystemWindowTabController::select_next_tab(cx, handle.window_id()); + }) + .log_err(); + }) + }); + platform_window.on_select_previous_tab({ + let mut cx = cx.to_async(); + Box::new(move || { + handle + .update(&mut cx, |_, _window, cx| { + SystemWindowTabController::select_previous_tab(cx, handle.window_id()) + }) + .log_err(); + }) + }); + platform_window.on_toggle_tab_bar({ + let mut cx = cx.to_async(); + Box::new(move || { + handle + .update(&mut cx, |_, window, cx| { + let tab_bar_visible = window.platform_window.tab_bar_visible(); + SystemWindowTabController::set_visible(cx, tab_bar_visible); + }) + .log_err(); + }) + }); if let Some(app_id) = app_id { 
platform_window.set_app_id(&app_id); @@ -1835,7 +1909,7 @@ impl Window { } /// Produces a new frame and assigns it to `rendered_frame`. To actually show - /// the contents of the new [Scene], use [present]. + /// the contents of the new [`Scene`], use [`Self::present`]. #[profiling::function] pub fn draw(&mut self, cx: &mut App) -> ArenaClearNeeded { self.invalidate_entities(); @@ -2377,7 +2451,7 @@ impl Window { /// Perform prepaint on child elements in a "retryable" manner, so that any side effects /// of prepaints can be discarded before prepainting again. This is used to support autoscroll /// where we need to prepaint children to detect the autoscroll bounds, then adjust the - /// element offset and prepaint again. See [`List`] for an example. This method should only be + /// element offset and prepaint again. See [`crate::List`] for an example. This method should only be /// called during the prepaint phase of element drawing. pub fn transact(&mut self, f: impl FnOnce(&mut Self) -> Result) -> Result { self.invalidator.debug_assert_prepaint(); @@ -2402,9 +2476,9 @@ impl Window { result } - /// When you call this method during [`prepaint`], containing elements will attempt to + /// When you call this method during [`Element::prepaint`], containing elements will attempt to /// scroll to cause the specified bounds to become visible. When they decide to autoscroll, they will call - /// [`prepaint`] again with a new set of bounds. See [`List`] for an example of an element + /// [`Element::prepaint`] again with a new set of bounds. See [`crate::List`] for an example of an element /// that supports this method being called on the elements it contains. This method should only be /// called during the prepaint phase of element drawing. 
pub fn request_autoscroll(&mut self, bounds: Bounds) { @@ -2412,8 +2486,8 @@ impl Window { self.requested_autoscroll = Some(bounds); } - /// This method can be called from a containing element such as [`List`] to support the autoscroll behavior - /// described in [`request_autoscroll`]. + /// This method can be called from a containing element such as [`crate::List`] to support the autoscroll behavior + /// described in [`Self::request_autoscroll`]. pub fn take_autoscroll(&mut self) -> Option> { self.invalidator.debug_assert_prepaint(); self.requested_autoscroll.take() @@ -2453,7 +2527,7 @@ impl Window { /// time. pub fn get_asset(&mut self, source: &A::Source, cx: &mut App) -> Option { let (task, _) = cx.fetch_asset::(source); - task.clone().now_or_never() + task.now_or_never() } /// Obtain the current element offset. This method should only be called during the /// prepaint phase of element drawing. @@ -2504,7 +2578,7 @@ impl Window { &mut self, key: impl Into, cx: &mut App, - init: impl FnOnce(&mut Self, &mut App) -> S, + init: impl FnOnce(&mut Self, &mut Context) -> S, ) -> Entity { let current_view = self.current_view(); self.with_global_id(key.into(), |global_id, window| { @@ -2537,7 +2611,7 @@ impl Window { pub fn use_state( &mut self, cx: &mut App, - init: impl FnOnce(&mut Self, &mut App) -> S, + init: impl FnOnce(&mut Self, &mut Context) -> S, ) -> Entity { self.use_keyed_state( ElementId::CodeLocation(*core::panic::Location::caller()), @@ -2741,7 +2815,7 @@ impl Window { /// Paint one or more quads into the scene for the next frame at the current stacking context. /// Quads are colored rectangular regions with an optional background, border, and corner radius. - /// see [`fill`](crate::fill), [`outline`](crate::outline), and [`quad`](crate::quad) to construct this type. + /// see [`fill`], [`outline`], and [`quad`] to construct this type. /// /// This method should only be called as part of the paint phase of element drawing. 
/// @@ -3044,7 +3118,7 @@ impl Window { let tile = self .sprite_atlas - .get_or_insert_with(¶ms.clone().into(), &mut || { + .get_or_insert_with(¶ms.into(), &mut || { Ok(Some(( data.size(frame_index), Cow::Borrowed( @@ -3401,16 +3475,16 @@ impl Window { let focus_id = handle.id; let (subscription, activate) = self.new_focus_listener(Box::new(move |event, window, cx| { - if let Some(blurred_id) = event.previous_focus_path.last().copied() { - if event.is_focus_out(focus_id) { - let event = FocusOutEvent { - blurred: WeakFocusHandle { - id: blurred_id, - handles: Arc::downgrade(&cx.focus_handles), - }, - }; - listener(event, window, cx) - } + if let Some(blurred_id) = event.previous_focus_path.last().copied() + && event.is_focus_out(focus_id) + { + let event = FocusOutEvent { + blurred: WeakFocusHandle { + id: blurred_id, + handles: Arc::downgrade(&cx.focus_handles), + }, + }; + listener(event, window, cx) } true })); @@ -3444,12 +3518,12 @@ impl Window { return true; } - if let Some(input) = keystroke.key_char { - if let Some(mut input_handler) = self.platform_window.take_input_handler() { - input_handler.dispatch_input(&input, self, cx); - self.platform_window.set_input_handler(input_handler); - return true; - } + if let Some(input) = keystroke.key_char + && let Some(mut input_handler) = self.platform_window.take_input_handler() + { + input_handler.dispatch_input(&input, self, cx); + self.platform_window.set_input_handler(input_handler); + return true; } false @@ -3731,7 +3805,7 @@ impl Window { self.dispatch_keystroke_observers( event, Some(binding.action), - match_result.context_stack.clone(), + match_result.context_stack, cx, ); self.pending_input_changed(cx); @@ -3864,11 +3938,11 @@ impl Window { if !cx.propagate_event { continue 'replay; } - if let Some(input) = replay.keystroke.key_char.as_ref().cloned() { - if let Some(mut input_handler) = self.platform_window.take_input_handler() { - input_handler.dispatch_input(&input, self, cx); - 
self.platform_window.set_input_handler(input_handler) - } + if let Some(input) = replay.keystroke.key_char.as_ref().cloned() + && let Some(mut input_handler) = self.platform_window.take_input_handler() + { + input_handler.dispatch_input(&input, self, cx); + self.platform_window.set_input_handler(input_handler) } } } @@ -4022,9 +4096,7 @@ impl Window { self.on_next_frame(|window, cx| { if let Some(mut input_handler) = window.platform_window.take_input_handler() { if let Some(bounds) = input_handler.selected_bounds(window, cx) { - window - .platform_window - .update_ime_position(bounds.scale(window.scale_factor())); + window.platform_window.update_ime_position(bounds); } window.platform_window.set_input_handler(input_handler); } @@ -4275,11 +4347,54 @@ impl Window { } /// Perform titlebar double-click action. - /// This is MacOS specific. + /// This is macOS specific. pub fn titlebar_double_click(&self) { self.platform_window.titlebar_double_click(); } + /// Gets the window's title at the platform level. + /// This is macOS specific. + pub fn window_title(&self) -> String { + self.platform_window.get_title() + } + + /// Returns a list of all tabbed windows and their titles. + /// This is macOS specific. + pub fn tabbed_windows(&self) -> Option> { + self.platform_window.tabbed_windows() + } + + /// Returns the tab bar visibility. + /// This is macOS specific. + pub fn tab_bar_visible(&self) -> bool { + self.platform_window.tab_bar_visible() + } + + /// Merges all open windows into a single tabbed window. + /// This is macOS specific. + pub fn merge_all_windows(&self) { + self.platform_window.merge_all_windows() + } + + /// Moves the tab to a new containing window. + /// This is macOS specific. + pub fn move_tab_to_new_window(&self) { + self.platform_window.move_tab_to_new_window() + } + + /// Shows or hides the window tab overview. + /// This is macOS specific. 
+ pub fn toggle_window_tab_overview(&self) { + self.platform_window.toggle_window_tab_overview() + } + + /// Sets the tabbing identifier for the window. + /// This is macOS specific. + pub fn set_tabbing_identifier(&self, tabbing_identifier: Option) { + self.platform_window + .set_tabbing_identifier(tabbing_identifier) + } + /// Toggles the inspector mode on this window. #[cfg(any(feature = "inspector", debug_assertions))] pub fn toggle_inspector(&mut self, cx: &mut App) { @@ -4309,15 +4424,15 @@ impl Window { cx: &mut App, f: impl FnOnce(&mut Option, &mut Self) -> R, ) -> R { - if let Some(inspector_id) = _inspector_id { - if let Some(inspector) = &self.inspector { - let inspector = inspector.clone(); - let active_element_id = inspector.read(cx).active_element_id(); - if Some(inspector_id) == active_element_id { - return inspector.update(cx, |inspector, _cx| { - inspector.with_active_element_state(self, f) - }); - } + if let Some(inspector_id) = _inspector_id + && let Some(inspector) = &self.inspector + { + let inspector = inspector.clone(); + let active_element_id = inspector.read(cx).active_element_id(); + if Some(inspector_id) == active_element_id { + return inspector.update(cx, |inspector, _cx| { + inspector.with_active_element_state(self, f) + }); } } f(&mut None, self) @@ -4389,15 +4504,13 @@ impl Window { if let Some(inspector) = self.inspector.as_ref() { let inspector = inspector.read(cx); if let Some((hitbox_id, _)) = self.hovered_inspector_hitbox(inspector, &self.next_frame) - { - if let Some(hitbox) = self + && let Some(hitbox) = self .next_frame .hitboxes .iter() .find(|hitbox| hitbox.id == hitbox_id) - { - self.paint_quad(crate::fill(hitbox.bounds, crate::rgba(0x61afef4d))); - } + { + self.paint_quad(crate::fill(hitbox.bounds, crate::rgba(0x61afef4d))); } } } @@ -4444,7 +4557,7 @@ impl Window { if let Some((_, inspector_id)) = self.hovered_inspector_hitbox(inspector, &self.rendered_frame) { - inspector.set_active_element_id(inspector_id.clone(), 
self); + inspector.set_active_element_id(inspector_id, self); } } }); @@ -4468,7 +4581,14 @@ impl Window { } } } - return None; + None + } + + /// For testing: set the current modifier keys state. + /// This does not generate any events. + #[cfg(any(test, feature = "test-support"))] + pub fn set_modifiers(&mut self, modifiers: Modifiers) { + self.modifiers = modifiers; } } @@ -4585,7 +4705,7 @@ impl WindowHandle { where C: AppContext, { - cx.read_window(self, |root_view, _cx| root_view.clone()) + cx.read_window(self, |root_view, _cx| root_view) } /// Check if this window is 'active'. @@ -4699,7 +4819,7 @@ impl HasDisplayHandle for Window { } } -/// An identifier for an [`Element`](crate::Element). +/// An identifier for an [`Element`]. /// /// Can be constructed with a string, a number, or both, as well /// as other internal representations. diff --git a/crates/gpui_macros/src/derive_action.rs b/crates/gpui_macros/src/derive_action.rs index 9c7f97371d86eecc29dc16902ba9e392d53b8660..4e6c6277e452189657b4725b4027780a54cfed1d 100644 --- a/crates/gpui_macros/src/derive_action.rs +++ b/crates/gpui_macros/src/derive_action.rs @@ -16,6 +16,13 @@ pub(crate) fn derive_action(input: TokenStream) -> TokenStream { let mut deprecated = None; let mut doc_str: Option = None; + /* + * + * #[action()] + * Struct Foo { + * bar: bool // is bar considered an attribute + } + */ for attr in &input.attrs { if attr.path().is_ident("action") { attr.parse_nested_meta(|meta| { diff --git a/crates/gpui_macros/src/derive_inspector_reflection.rs b/crates/gpui_macros/src/derive_inspector_reflection.rs index fa22f95f9a1c274d193a6985a84bf3cdecfcc17f..9c1cb503a87e5f726ba27d1868a6c053b36c6731 100644 --- a/crates/gpui_macros/src/derive_inspector_reflection.rs +++ b/crates/gpui_macros/src/derive_inspector_reflection.rs @@ -160,16 +160,14 @@ fn extract_doc_comment(attrs: &[Attribute]) -> Option { let mut doc_lines = Vec::new(); for attr in attrs { - if attr.path().is_ident("doc") { - if let 
Meta::NameValue(meta) = &attr.meta { - if let Expr::Lit(expr_lit) = &meta.value { - if let Lit::Str(lit_str) = &expr_lit.lit { - let line = lit_str.value(); - let line = line.strip_prefix(' ').unwrap_or(&line); - doc_lines.push(line.to_string()); - } - } - } + if attr.path().is_ident("doc") + && let Meta::NameValue(meta) = &attr.meta + && let Expr::Lit(expr_lit) = &meta.value + && let Lit::Str(lit_str) = &expr_lit.lit + { + let line = lit_str.value(); + let line = line.strip_prefix(' ').unwrap_or(&line); + doc_lines.push(line.to_string()); } } @@ -191,7 +189,7 @@ fn extract_cfg_attributes(attrs: &[Attribute]) -> Vec { fn is_called_from_gpui_crate(_span: Span) -> bool { // Check if we're being called from within the gpui crate by examining the call site // This is a heuristic approach - we check if the current crate name is "gpui" - std::env::var("CARGO_PKG_NAME").map_or(false, |name| name == "gpui") + std::env::var("CARGO_PKG_NAME").is_ok_and(|name| name == "gpui") } struct MacroExpander; diff --git a/crates/gpui_macros/src/gpui_macros.rs b/crates/gpui_macros/src/gpui_macros.rs index 3a58af67052d06f108b4b9c87d52fc358405466e..0f1365be77ec221d9061f588f84ff6acab3c32ab 100644 --- a/crates/gpui_macros/src/gpui_macros.rs +++ b/crates/gpui_macros/src/gpui_macros.rs @@ -172,7 +172,7 @@ pub fn box_shadow_style_methods(input: TokenStream) -> TokenStream { /// - `#[gpui::test(iterations = 5)]` runs five times, providing as seed the values in the range `0..5`. /// - `#[gpui::test(retries = 3)]` runs up to four times if it fails to try and make it pass. /// - `#[gpui::test(on_failure = "crate::test::report_failure")]` will call the specified function after the -/// tests fail so that you can write out more detail about the failure. +/// tests fail so that you can write out more detail about the failure. /// /// You can combine `iterations = ...` with `seeds(...)`: /// - `#[gpui::test(iterations = 5, seed = 10)]` is equivalent to `#[gpui::test(seeds(0, 1, 2, 3, 4, 10))]`. 
diff --git a/crates/gpui_macros/src/test.rs b/crates/gpui_macros/src/test.rs index adb27f42ea2c7689d19290d17b78299ff149fdd2..648d3499edb0a8f13031092e37d761368363af08 100644 --- a/crates/gpui_macros/src/test.rs +++ b/crates/gpui_macros/src/test.rs @@ -73,7 +73,7 @@ impl Parse for Args { (Meta::NameValue(meta), "seed") => { seeds = vec![parse_usize_from_expr(&meta.value)? as u64] } - (Meta::List(list), "seeds") => seeds = parse_u64_array(&list)?, + (Meta::List(list), "seeds") => seeds = parse_u64_array(list)?, (Meta::Path(_), _) => { return Err(syn::Error::new(meta.span(), "invalid path argument")); } @@ -86,7 +86,7 @@ impl Parse for Args { Ok(Args { seeds, max_retries, - max_iterations: max_iterations, + max_iterations, on_failure_fn_name, }) } @@ -152,28 +152,28 @@ fn generate_test_function( } _ => {} } - } else if let Type::Reference(ty) = &*arg.ty { - if let Type::Path(ty) = &*ty.elem { - let last_segment = ty.path.segments.last(); - if let Some("TestAppContext") = - last_segment.map(|s| s.ident.to_string()).as_deref() - { - let cx_varname = format_ident!("cx_{}", ix); - cx_vars.extend(quote!( - let mut #cx_varname = gpui::TestAppContext::build( - dispatcher.clone(), - Some(stringify!(#outer_fn_name)), - ); - )); - cx_teardowns.extend(quote!( - dispatcher.run_until_parked(); - #cx_varname.executor().forbid_parking(); - #cx_varname.quit(); - dispatcher.run_until_parked(); - )); - inner_fn_args.extend(quote!(&mut #cx_varname,)); - continue; - } + } else if let Type::Reference(ty) = &*arg.ty + && let Type::Path(ty) = &*ty.elem + { + let last_segment = ty.path.segments.last(); + if let Some("TestAppContext") = + last_segment.map(|s| s.ident.to_string()).as_deref() + { + let cx_varname = format_ident!("cx_{}", ix); + cx_vars.extend(quote!( + let mut #cx_varname = gpui::TestAppContext::build( + dispatcher.clone(), + Some(stringify!(#outer_fn_name)), + ); + )); + cx_teardowns.extend(quote!( + dispatcher.run_until_parked(); + #cx_varname.executor().forbid_parking(); + 
#cx_varname.quit(); + dispatcher.run_until_parked(); + )); + inner_fn_args.extend(quote!(&mut #cx_varname,)); + continue; } } } @@ -215,48 +215,48 @@ fn generate_test_function( inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),)); continue; } - } else if let Type::Reference(ty) = &*arg.ty { - if let Type::Path(ty) = &*ty.elem { - let last_segment = ty.path.segments.last(); - match last_segment.map(|s| s.ident.to_string()).as_deref() { - Some("App") => { - let cx_varname = format_ident!("cx_{}", ix); - let cx_varname_lock = format_ident!("cx_{}_lock", ix); - cx_vars.extend(quote!( - let mut #cx_varname = gpui::TestAppContext::build( - dispatcher.clone(), - Some(stringify!(#outer_fn_name)) - ); - let mut #cx_varname_lock = #cx_varname.app.borrow_mut(); - )); - inner_fn_args.extend(quote!(&mut #cx_varname_lock,)); - cx_teardowns.extend(quote!( + } else if let Type::Reference(ty) = &*arg.ty + && let Type::Path(ty) = &*ty.elem + { + let last_segment = ty.path.segments.last(); + match last_segment.map(|s| s.ident.to_string()).as_deref() { + Some("App") => { + let cx_varname = format_ident!("cx_{}", ix); + let cx_varname_lock = format_ident!("cx_{}_lock", ix); + cx_vars.extend(quote!( + let mut #cx_varname = gpui::TestAppContext::build( + dispatcher.clone(), + Some(stringify!(#outer_fn_name)) + ); + let mut #cx_varname_lock = #cx_varname.app.borrow_mut(); + )); + inner_fn_args.extend(quote!(&mut #cx_varname_lock,)); + cx_teardowns.extend(quote!( drop(#cx_varname_lock); dispatcher.run_until_parked(); #cx_varname.update(|cx| { cx.background_executor().forbid_parking(); cx.quit(); }); dispatcher.run_until_parked(); )); - continue; - } - Some("TestAppContext") => { - let cx_varname = format_ident!("cx_{}", ix); - cx_vars.extend(quote!( - let mut #cx_varname = gpui::TestAppContext::build( - dispatcher.clone(), - Some(stringify!(#outer_fn_name)) - ); - )); - cx_teardowns.extend(quote!( - dispatcher.run_until_parked(); - #cx_varname.executor().forbid_parking(); 
- #cx_varname.quit(); - dispatcher.run_until_parked(); - )); - inner_fn_args.extend(quote!(&mut #cx_varname,)); - continue; - } - _ => {} + continue; + } + Some("TestAppContext") => { + let cx_varname = format_ident!("cx_{}", ix); + cx_vars.extend(quote!( + let mut #cx_varname = gpui::TestAppContext::build( + dispatcher.clone(), + Some(stringify!(#outer_fn_name)) + ); + )); + cx_teardowns.extend(quote!( + dispatcher.run_until_parked(); + #cx_varname.executor().forbid_parking(); + #cx_varname.quit(); + dispatcher.run_until_parked(); + )); + inner_fn_args.extend(quote!(&mut #cx_varname,)); + continue; } + _ => {} } } } diff --git a/crates/gpui_macros/tests/derive_inspector_reflection.rs b/crates/gpui_macros/tests/derive_inspector_reflection.rs index 522c0a62c469cd181c44c465547a8c19c4d04f69..a0adcb7801e55d7272191a1e4e831b2c9c6b115c 100644 --- a/crates/gpui_macros/tests/derive_inspector_reflection.rs +++ b/crates/gpui_macros/tests/derive_inspector_reflection.rs @@ -34,13 +34,6 @@ trait Transform: Clone { /// Adds one to the value fn add_one(self) -> Self; - - /// cfg attributes are respected - #[cfg(all())] - fn cfg_included(self) -> Self; - - #[cfg(any())] - fn cfg_omitted(self) -> Self; } #[derive(Debug, Clone, PartialEq)] @@ -70,10 +63,6 @@ impl Transform for Number { fn add_one(self) -> Self { Number(self.0 + 1) } - - fn cfg_included(self) -> Self { - Number(self.0) - } } #[test] @@ -83,14 +72,13 @@ fn test_derive_inspector_reflection() { // Get all methods that match the pattern fn(self) -> Self or fn(mut self) -> Self let methods = methods::(); - assert_eq!(methods.len(), 6); + assert_eq!(methods.len(), 5); let method_names: Vec<_> = methods.iter().map(|m| m.name).collect(); assert!(method_names.contains(&"double")); assert!(method_names.contains(&"triple")); assert!(method_names.contains(&"increment")); assert!(method_names.contains(&"quadruple")); assert!(method_names.contains(&"add_one")); - assert!(method_names.contains(&"cfg_included")); // Invoke methods by 
name let num = Number(5); @@ -106,9 +94,7 @@ fn test_derive_inspector_reflection() { .invoke(num.clone()); assert_eq!(incremented, Number(6)); - let quadrupled = find_method::("quadruple") - .unwrap() - .invoke(num.clone()); + let quadrupled = find_method::("quadruple").unwrap().invoke(num); assert_eq!(quadrupled, Number(20)); // Try to invoke a non-existent method diff --git a/crates/gpui_tokio/Cargo.toml b/crates/gpui_tokio/Cargo.toml index 46d5eafd5adceadadf5fbd942d104ee4249aa941..2d4abf40631a2f011306d7216a5f96864ccdb0da 100644 --- a/crates/gpui_tokio/Cargo.toml +++ b/crates/gpui_tokio/Cargo.toml @@ -13,6 +13,7 @@ path = "src/gpui_tokio.rs" doctest = false [dependencies] +anyhow.workspace = true util.workspace = true gpui.workspace = true tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } diff --git a/crates/gpui_tokio/src/gpui_tokio.rs b/crates/gpui_tokio/src/gpui_tokio.rs index fffe18a616d9b597c9f5ed25b68df2911c8f3886..8384f2a88ec82b96c0490913019b701cdf01239c 100644 --- a/crates/gpui_tokio/src/gpui_tokio.rs +++ b/crates/gpui_tokio/src/gpui_tokio.rs @@ -52,6 +52,28 @@ impl Tokio { }) } + /// Spawns the given future on Tokio's thread pool, and returns it via a GPUI task + /// Note that the Tokio task will be cancelled if the GPUI task is dropped + pub fn spawn_result(cx: &C, f: Fut) -> C::Result>> + where + C: AppContext, + Fut: Future> + Send + 'static, + R: Send + 'static, + { + cx.read_global(|tokio: &GlobalTokio, cx| { + let join_handle = tokio.runtime.spawn(f); + let abort_handle = join_handle.abort_handle(); + let cancel = defer(move || { + abort_handle.abort(); + }); + cx.background_spawn(async move { + let result = join_handle.await?; + drop(cancel); + result + }) + }) + } + pub fn handle(cx: &App) -> tokio::runtime::Handle { GlobalTokio::global(cx).runtime.handle().clone() } diff --git a/crates/html_to_markdown/src/markdown.rs b/crates/html_to_markdown/src/markdown.rs index 
b9ffbac79c6b6af64222e6447392aa3a75440dda..bb3b3563bcdff8692c80b1b79e7c94d4184bf1cb 100644 --- a/crates/html_to_markdown/src/markdown.rs +++ b/crates/html_to_markdown/src/markdown.rs @@ -34,15 +34,14 @@ impl HandleTag for ParagraphHandler { tag: &HtmlElement, writer: &mut MarkdownWriter, ) -> StartTagOutcome { - if tag.is_inline() && writer.is_inside("p") { - if let Some(parent) = writer.current_element_stack().iter().last() { - if !(parent.is_inline() - || writer.markdown.ends_with(' ') - || writer.markdown.ends_with('\n')) - { - writer.push_str(" "); - } - } + if tag.is_inline() + && writer.is_inside("p") + && let Some(parent) = writer.current_element_stack().iter().last() + && !(parent.is_inline() + || writer.markdown.ends_with(' ') + || writer.markdown.ends_with('\n')) + { + writer.push_str(" "); } if tag.tag() == "p" { diff --git a/crates/http_client/src/async_body.rs b/crates/http_client/src/async_body.rs index 473849f3cdca785a802590a60cce922c9ee0b5f9..6b99a54a7d941c290f2680bc2a599bc63251e24b 100644 --- a/crates/http_client/src/async_body.rs +++ b/crates/http_client/src/async_body.rs @@ -40,7 +40,7 @@ impl AsyncBody { } pub fn from_bytes(bytes: Bytes) -> Self { - Self(Inner::Bytes(Cursor::new(bytes.clone()))) + Self(Inner::Bytes(Cursor::new(bytes))) } } diff --git a/crates/http_client/src/github.rs b/crates/http_client/src/github.rs index 89309ff344c2a64127ee8b2603d10d029a82f6bf..32efed8e727330d3ac1c2fb6d8ea5d57fdd66dd4 100644 --- a/crates/http_client/src/github.rs +++ b/crates/http_client/src/github.rs @@ -77,10 +77,10 @@ pub async fn latest_github_release( .find(|release| release.pre_release == pre_release) .context("finding a prerelease")?; release.assets.iter_mut().for_each(|asset| { - if let Some(digest) = &mut asset.digest { - if let Some(stripped) = digest.strip_prefix("sha256:") { - *digest = stripped.to_owned(); - } + if let Some(digest) = &mut asset.digest + && let Some(stripped) = digest.strip_prefix("sha256:") + { + *digest = stripped.to_owned(); } 
}); Ok(release) diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index a7f75b0962561ac713e57f9ad26cb64ed82f8003..62468573ed29687c0436e98a0174baa515b0ee3d 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -435,8 +435,7 @@ impl HttpClient for FakeHttpClient { &self, req: Request, ) -> BoxFuture<'static, anyhow::Result>> { - let future = (self.handler.lock().as_ref().unwrap())(req); - future + ((self.handler.lock().as_ref().unwrap())(req)) as _ } fn user_agent(&self) -> Option<&HeaderValue> { diff --git a/crates/icons/README.md b/crates/icons/README.md index 71bc5c85459604243207c68686da5662ceabeddc..e340a00277db558b4bb13b212d53188c0c8fbe5a 100644 --- a/crates/icons/README.md +++ b/crates/icons/README.md @@ -6,7 +6,7 @@ Icons are a big part of Zed, and they're how we convey hundreds of actions witho When introducing a new icon, it's important to ensure consistency with the existing set, which follows these guidelines: 1. The SVG view box should be 16x16. -2. For outlined icons, use a 1.5px stroke width. +2. For outlined icons, use a 1.2px stroke width. 3. Not all icons are mathematically aligned; there's quite a bit of optical adjustment. However, try to keep the icon within an internal 12x12 bounding box as much as possible while ensuring proper visibility. 4. Use the `filled` and `outlined` terminology when introducing icons that will have these two variants. 5. Icons that are deeply contextual may have the feature context as their name prefix. For example, `ToolWeb`, `ReplPlay`, `DebugStepInto`, etc. 
diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index f5c2a83fec36fcf67647011bb1b123d8df3f8d02..f3609f7ea8706f33eb07eaaf456731e14c85555a 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -34,6 +34,7 @@ pub enum IconName { ArrowRightLeft, ArrowUp, ArrowUpRight, + Attach, AudioOff, AudioOn, Backspace, @@ -140,10 +141,12 @@ pub enum IconName { Image, Indicator, Info, + Json, Keyboard, Library, LineHeight, ListCollapse, + ListFilter, ListTodo, ListTree, ListX, @@ -154,6 +157,7 @@ pub enum IconName { Maximize, Menu, MenuAlt, + MenuAltTemp, Mic, MicMute, Minimize, @@ -162,6 +166,7 @@ pub enum IconName { PageDown, PageUp, Pencil, + PencilUnavailable, Person, Pin, PlayOutlined, @@ -211,6 +216,7 @@ pub enum IconName { Tab, Terminal, TerminalAlt, + TerminalGhost, TextSnippet, TextThread, Thread, @@ -244,6 +250,8 @@ pub enum IconName { Warning, WholeWord, XCircle, + XCircleFilled, + ZedAgent, ZedAssistant, ZedBurnMode, ZedBurnModeOn, diff --git a/crates/image_viewer/src/image_viewer.rs b/crates/image_viewer/src/image_viewer.rs index b96557b391f5941283b67b7b798ee177ab383cb2..2dca57424b86e2221acc271efac19cdf39a3f79f 100644 --- a/crates/image_viewer/src/image_viewer.rs +++ b/crates/image_viewer/src/image_viewer.rs @@ -401,12 +401,19 @@ pub fn init(cx: &mut App) { mod persistence { use std::path::PathBuf; - use db::{define_connection, query, sqlez_macros::sql}; + use db::{ + query, + sqlez::{domain::Domain, thread_safe_connection::ThreadSafeConnection}, + sqlez_macros::sql, + }; use workspace::{ItemId, WorkspaceDb, WorkspaceId}; - define_connection! 
{ - pub static ref IMAGE_VIEWER: ImageViewerDb = - &[sql!( + pub struct ImageViewerDb(ThreadSafeConnection); + + impl Domain for ImageViewerDb { + const NAME: &str = stringify!(ImageViewerDb); + + const MIGRATIONS: &[&str] = &[sql!( CREATE TABLE image_viewers ( workspace_id INTEGER, item_id INTEGER UNIQUE, @@ -417,9 +424,11 @@ mod persistence { FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) ON DELETE CASCADE ) STRICT; - )]; + )]; } + db::static_connection!(IMAGE_VIEWER, ImageViewerDb, [WorkspaceDb]); + impl ImageViewerDb { query! { pub async fn save_image_path( diff --git a/crates/image_viewer/src/image_viewer_settings.rs b/crates/image_viewer/src/image_viewer_settings.rs index 1dcf99c0afcb3f69f48e2e1a82351852a4bf1c22..510de69b522fbb07cb8eedba43edfe3a95e4a591 100644 --- a/crates/image_viewer/src/image_viewer_settings.rs +++ b/crates/image_viewer/src/image_viewer_settings.rs @@ -1,10 +1,11 @@ use gpui::App; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; /// The settings for the image viewer. -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, Default)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, Default, SettingsUi, SettingsKey)] +#[settings_key(key = "image_viewer")] pub struct ImageViewerSettings { /// The unit to use for displaying image file sizes. 
/// @@ -24,8 +25,6 @@ pub enum ImageFileSizeUnit { } impl Settings for ImageViewerSettings { - const KEY: Option<&'static str> = Some("image_viewer"); - type FileContent = Self; fn load(sources: SettingsSources, _: &mut App) -> anyhow::Result { diff --git a/crates/indexed_docs/Cargo.toml b/crates/indexed_docs/Cargo.toml deleted file mode 100644 index eb269ad939b59394f12ccceba941585f6dec3ca7..0000000000000000000000000000000000000000 --- a/crates/indexed_docs/Cargo.toml +++ /dev/null @@ -1,38 +0,0 @@ -[package] -name = "indexed_docs" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/indexed_docs.rs" - -[dependencies] -anyhow.workspace = true -async-trait.workspace = true -cargo_metadata.workspace = true -collections.workspace = true -derive_more.workspace = true -extension.workspace = true -fs.workspace = true -futures.workspace = true -fuzzy.workspace = true -gpui.workspace = true -heed.workspace = true -html_to_markdown.workspace = true -http_client.workspace = true -indexmap.workspace = true -parking_lot.workspace = true -paths.workspace = true -serde.workspace = true -strum.workspace = true -util.workspace = true -workspace-hack.workspace = true - -[dev-dependencies] -indoc.workspace = true -pretty_assertions.workspace = true diff --git a/crates/indexed_docs/src/extension_indexed_docs_provider.rs b/crates/indexed_docs/src/extension_indexed_docs_provider.rs deleted file mode 100644 index c77ea4066d2b46f0377d7f64c7d514fd3a95872d..0000000000000000000000000000000000000000 --- a/crates/indexed_docs/src/extension_indexed_docs_provider.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::path::PathBuf; -use std::sync::Arc; - -use anyhow::Result; -use async_trait::async_trait; -use extension::{Extension, ExtensionHostProxy, ExtensionIndexedDocsProviderProxy}; -use gpui::App; - -use crate::{ - IndexedDocsDatabase, IndexedDocsProvider, IndexedDocsRegistry, PackageName, ProviderId, 
-}; - -pub fn init(cx: &mut App) { - let proxy = ExtensionHostProxy::default_global(cx); - proxy.register_indexed_docs_provider_proxy(IndexedDocsRegistryProxy { - indexed_docs_registry: IndexedDocsRegistry::global(cx), - }); -} - -struct IndexedDocsRegistryProxy { - indexed_docs_registry: Arc, -} - -impl ExtensionIndexedDocsProviderProxy for IndexedDocsRegistryProxy { - fn register_indexed_docs_provider(&self, extension: Arc, provider_id: Arc) { - self.indexed_docs_registry - .register_provider(Box::new(ExtensionIndexedDocsProvider::new( - extension, - ProviderId(provider_id), - ))); - } - - fn unregister_indexed_docs_provider(&self, provider_id: Arc) { - self.indexed_docs_registry - .unregister_provider(&ProviderId(provider_id)); - } -} - -pub struct ExtensionIndexedDocsProvider { - extension: Arc, - id: ProviderId, -} - -impl ExtensionIndexedDocsProvider { - pub fn new(extension: Arc, id: ProviderId) -> Self { - Self { extension, id } - } -} - -#[async_trait] -impl IndexedDocsProvider for ExtensionIndexedDocsProvider { - fn id(&self) -> ProviderId { - self.id.clone() - } - - fn database_path(&self) -> PathBuf { - let mut database_path = PathBuf::from(self.extension.work_dir().as_ref()); - database_path.push("docs"); - database_path.push(format!("{}.0.mdb", self.id)); - - database_path - } - - async fn suggest_packages(&self) -> Result> { - let packages = self - .extension - .suggest_docs_packages(self.id.0.clone()) - .await?; - - Ok(packages - .into_iter() - .map(|package| PackageName::from(package.as_str())) - .collect()) - } - - async fn index(&self, package: PackageName, database: Arc) -> Result<()> { - self.extension - .index_docs(self.id.0.clone(), package.as_ref().into(), database) - .await - } -} diff --git a/crates/indexed_docs/src/indexed_docs.rs b/crates/indexed_docs/src/indexed_docs.rs deleted file mode 100644 index 97538329d4d6265c7587209c679f4e3fa041ce35..0000000000000000000000000000000000000000 --- a/crates/indexed_docs/src/indexed_docs.rs +++ 
/dev/null @@ -1,16 +0,0 @@ -mod extension_indexed_docs_provider; -mod providers; -mod registry; -mod store; - -use gpui::App; - -pub use crate::extension_indexed_docs_provider::ExtensionIndexedDocsProvider; -pub use crate::providers::rustdoc::*; -pub use crate::registry::*; -pub use crate::store::*; - -pub fn init(cx: &mut App) { - IndexedDocsRegistry::init_global(cx); - extension_indexed_docs_provider::init(cx); -} diff --git a/crates/indexed_docs/src/providers.rs b/crates/indexed_docs/src/providers.rs deleted file mode 100644 index c6505a2ab667724c83d3f7bed3c1ca16a1423bc5..0000000000000000000000000000000000000000 --- a/crates/indexed_docs/src/providers.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod rustdoc; diff --git a/crates/indexed_docs/src/providers/rustdoc.rs b/crates/indexed_docs/src/providers/rustdoc.rs deleted file mode 100644 index ac6dc3a10bb3f70f7329b399287124e0417dc0f4..0000000000000000000000000000000000000000 --- a/crates/indexed_docs/src/providers/rustdoc.rs +++ /dev/null @@ -1,291 +0,0 @@ -mod item; -mod to_markdown; - -use cargo_metadata::MetadataCommand; -use futures::future::BoxFuture; -pub use item::*; -use parking_lot::RwLock; -pub use to_markdown::convert_rustdoc_to_markdown; - -use std::collections::BTreeSet; -use std::path::PathBuf; -use std::sync::{Arc, LazyLock}; -use std::time::{Duration, Instant}; - -use anyhow::{Context as _, Result, bail}; -use async_trait::async_trait; -use collections::{HashSet, VecDeque}; -use fs::Fs; -use futures::{AsyncReadExt, FutureExt}; -use http_client::{AsyncBody, HttpClient, HttpClientWithUrl}; - -use crate::{IndexedDocsDatabase, IndexedDocsProvider, PackageName, ProviderId}; - -#[derive(Debug)] -struct RustdocItemWithHistory { - pub item: RustdocItem, - #[cfg(debug_assertions)] - pub history: Vec, -} - -pub struct LocalRustdocProvider { - fs: Arc, - cargo_workspace_root: PathBuf, -} - -impl LocalRustdocProvider { - pub fn id() -> ProviderId { - ProviderId("rustdoc".into()) - } - - pub fn new(fs: Arc, 
cargo_workspace_root: PathBuf) -> Self { - Self { - fs, - cargo_workspace_root, - } - } -} - -#[async_trait] -impl IndexedDocsProvider for LocalRustdocProvider { - fn id(&self) -> ProviderId { - Self::id() - } - - fn database_path(&self) -> PathBuf { - paths::data_dir().join("docs/rust/rustdoc-db.1.mdb") - } - - async fn suggest_packages(&self) -> Result> { - static WORKSPACE_CRATES: LazyLock, Instant)>>> = - LazyLock::new(|| RwLock::new(None)); - - if let Some((crates, fetched_at)) = &*WORKSPACE_CRATES.read() { - if fetched_at.elapsed() < Duration::from_secs(300) { - return Ok(crates.iter().cloned().collect()); - } - } - - let workspace = MetadataCommand::new() - .manifest_path(self.cargo_workspace_root.join("Cargo.toml")) - .exec() - .context("failed to load cargo metadata")?; - - let workspace_crates = workspace - .packages - .into_iter() - .map(|package| PackageName::from(package.name.as_str())) - .collect::>(); - - *WORKSPACE_CRATES.write() = Some((workspace_crates.clone(), Instant::now())); - - Ok(workspace_crates.into_iter().collect()) - } - - async fn index(&self, package: PackageName, database: Arc) -> Result<()> { - index_rustdoc(package, database, { - move |crate_name, item| { - let fs = self.fs.clone(); - let cargo_workspace_root = self.cargo_workspace_root.clone(); - let crate_name = crate_name.clone(); - let item = item.cloned(); - async move { - let target_doc_path = cargo_workspace_root.join("target/doc"); - let mut local_cargo_doc_path = target_doc_path.join(crate_name.as_ref().replace('-', "_")); - - if !fs.is_dir(&local_cargo_doc_path).await { - let cargo_doc_exists_at_all = fs.is_dir(&target_doc_path).await; - if cargo_doc_exists_at_all { - bail!( - "no docs directory for '{crate_name}'. if this is a valid crate name, try running `cargo doc`" - ); - } else { - bail!("no cargo doc directory. 
run `cargo doc`"); - } - } - - if let Some(item) = item { - local_cargo_doc_path.push(item.url_path()); - } else { - local_cargo_doc_path.push("index.html"); - } - - let Ok(contents) = fs.load(&local_cargo_doc_path).await else { - return Ok(None); - }; - - Ok(Some(contents)) - } - .boxed() - } - }) - .await - } -} - -pub struct DocsDotRsProvider { - http_client: Arc, -} - -impl DocsDotRsProvider { - pub fn id() -> ProviderId { - ProviderId("docs-rs".into()) - } - - pub fn new(http_client: Arc) -> Self { - Self { http_client } - } -} - -#[async_trait] -impl IndexedDocsProvider for DocsDotRsProvider { - fn id(&self) -> ProviderId { - Self::id() - } - - fn database_path(&self) -> PathBuf { - paths::data_dir().join("docs/rust/docs-rs-db.1.mdb") - } - - async fn suggest_packages(&self) -> Result> { - static POPULAR_CRATES: LazyLock> = LazyLock::new(|| { - include_str!("./rustdoc/popular_crates.txt") - .lines() - .filter(|line| !line.starts_with('#')) - .map(|line| PackageName::from(line.trim())) - .collect() - }); - - Ok(POPULAR_CRATES.clone()) - } - - async fn index(&self, package: PackageName, database: Arc) -> Result<()> { - index_rustdoc(package, database, { - move |crate_name, item| { - let http_client = self.http_client.clone(); - let crate_name = crate_name.clone(); - let item = item.cloned(); - async move { - let version = "latest"; - let path = format!( - "{crate_name}/{version}/{crate_name}{item_path}", - item_path = item - .map(|item| format!("/{}", item.url_path())) - .unwrap_or_default() - ); - - let mut response = http_client - .get( - &format!("https://docs.rs/{path}"), - AsyncBody::default(), - true, - ) - .await?; - - let mut body = Vec::new(); - response - .body_mut() - .read_to_end(&mut body) - .await - .context("error reading docs.rs response body")?; - - if response.status().is_client_error() { - let text = String::from_utf8_lossy(body.as_slice()); - bail!( - "status error {}, response: {text:?}", - response.status().as_u16() - ); - } - - 
Ok(Some(String::from_utf8(body)?)) - } - .boxed() - } - }) - .await - } -} - -async fn index_rustdoc( - package: PackageName, - database: Arc, - fetch_page: impl Fn( - &PackageName, - Option<&RustdocItem>, - ) -> BoxFuture<'static, Result>> - + Send - + Sync, -) -> Result<()> { - let Some(package_root_content) = fetch_page(&package, None).await? else { - return Ok(()); - }; - - let (crate_root_markdown, items) = - convert_rustdoc_to_markdown(package_root_content.as_bytes())?; - - database - .insert(package.to_string(), crate_root_markdown) - .await?; - - let mut seen_items = HashSet::from_iter(items.clone()); - let mut items_to_visit: VecDeque = - VecDeque::from_iter(items.into_iter().map(|item| RustdocItemWithHistory { - item, - #[cfg(debug_assertions)] - history: Vec::new(), - })); - - while let Some(item_with_history) = items_to_visit.pop_front() { - let item = &item_with_history.item; - - let Some(result) = fetch_page(&package, Some(item)).await.with_context(|| { - #[cfg(debug_assertions)] - { - format!( - "failed to fetch {item:?}: {history:?}", - history = item_with_history.history - ) - } - - #[cfg(not(debug_assertions))] - { - format!("failed to fetch {item:?}") - } - })? 
- else { - continue; - }; - - let (markdown, referenced_items) = convert_rustdoc_to_markdown(result.as_bytes())?; - - database - .insert(format!("{package}::{}", item.display()), markdown) - .await?; - - let parent_item = item; - for mut item in referenced_items { - if seen_items.contains(&item) { - continue; - } - - seen_items.insert(item.clone()); - - item.path.extend(parent_item.path.clone()); - if parent_item.kind == RustdocItemKind::Mod { - item.path.push(parent_item.name.clone()); - } - - items_to_visit.push_back(RustdocItemWithHistory { - #[cfg(debug_assertions)] - history: { - let mut history = item_with_history.history.clone(); - history.push(item.url_path()); - history - }, - item, - }); - } - } - - Ok(()) -} diff --git a/crates/indexed_docs/src/providers/rustdoc/item.rs b/crates/indexed_docs/src/providers/rustdoc/item.rs deleted file mode 100644 index 7d9023ef3e1bcda298e7c1aaabcf9205f333ee23..0000000000000000000000000000000000000000 --- a/crates/indexed_docs/src/providers/rustdoc/item.rs +++ /dev/null @@ -1,82 +0,0 @@ -use std::sync::Arc; - -use serde::{Deserialize, Serialize}; -use strum::EnumIter; - -#[derive( - Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize, EnumIter, -)] -#[serde(rename_all = "snake_case")] -pub enum RustdocItemKind { - Mod, - Macro, - Struct, - Enum, - Constant, - Trait, - Function, - TypeAlias, - AttributeMacro, - DeriveMacro, -} - -impl RustdocItemKind { - pub(crate) const fn class(&self) -> &'static str { - match self { - Self::Mod => "mod", - Self::Macro => "macro", - Self::Struct => "struct", - Self::Enum => "enum", - Self::Constant => "constant", - Self::Trait => "trait", - Self::Function => "fn", - Self::TypeAlias => "type", - Self::AttributeMacro => "attr", - Self::DeriveMacro => "derive", - } - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)] -pub struct RustdocItem { - pub kind: RustdocItemKind, - /// The item path, up until the name of the item. 
- pub path: Vec>, - /// The name of the item. - pub name: Arc, -} - -impl RustdocItem { - pub fn display(&self) -> String { - let mut path_segments = self.path.clone(); - path_segments.push(self.name.clone()); - - path_segments.join("::") - } - - pub fn url_path(&self) -> String { - let name = &self.name; - let mut path_components = self.path.clone(); - - match self.kind { - RustdocItemKind::Mod => { - path_components.push(name.clone()); - path_components.push("index.html".into()); - } - RustdocItemKind::Macro - | RustdocItemKind::Struct - | RustdocItemKind::Enum - | RustdocItemKind::Constant - | RustdocItemKind::Trait - | RustdocItemKind::Function - | RustdocItemKind::TypeAlias - | RustdocItemKind::AttributeMacro - | RustdocItemKind::DeriveMacro => { - path_components - .push(format!("{kind}.{name}.html", kind = self.kind.class()).into()); - } - } - - path_components.join("/") - } -} diff --git a/crates/indexed_docs/src/providers/rustdoc/popular_crates.txt b/crates/indexed_docs/src/providers/rustdoc/popular_crates.txt deleted file mode 100644 index ce2c3d51d834ecca05ebdde23c697b19e356a478..0000000000000000000000000000000000000000 --- a/crates/indexed_docs/src/providers/rustdoc/popular_crates.txt +++ /dev/null @@ -1,252 +0,0 @@ -# A list of the most popular Rust crates. -# Sourced from https://lib.rs/std. 
-serde -serde_json -syn -clap -thiserror -rand -log -tokio -anyhow -regex -quote -proc-macro2 -base64 -itertools -chrono -lazy_static -once_cell -libc -reqwest -futures -bitflags -tracing -url -bytes -toml -tempfile -uuid -indexmap -env_logger -num-traits -async-trait -sha2 -hex -tracing-subscriber -http -parking_lot -cfg-if -futures-util -cc -hashbrown -rayon -hyper -getrandom -semver -strum -flate2 -tokio-util -smallvec -criterion -paste -heck -rand_core -nom -rustls -nix -glob -time -byteorder -strum_macros -serde_yaml -wasm-bindgen -ahash -either -num_cpus -rand_chacha -prost -percent-encoding -pin-project-lite -tokio-stream -bincode -walkdir -bindgen -axum -windows-sys -futures-core -ring -digest -num-bigint -rustls-pemfile -serde_with -crossbeam-channel -tokio-rustls -hmac -fastrand -dirs -zeroize -socket2 -pin-project -tower -derive_more -memchr -toml_edit -static_assertions -pretty_assertions -js-sys -convert_case -unicode-width -pkg-config -itoa -colored -rustc-hash -darling -mime -web-sys -image -bytemuck -which -sha1 -dashmap -arrayvec -fnv -tonic -humantime -libloading -winapi -rustc_version -http-body -indoc -num -home -serde_urlencoded -http-body-util -unicode-segmentation -num-integer -webpki-roots -phf -futures-channel -indicatif -petgraph -ordered-float -strsim -zstd -console -encoding_rs -wasm-bindgen-futures -urlencoding -subtle -crc32fast -slab -rustix -predicates -spin -hyper-rustls -backtrace -rustversion -mio -scopeguard -proc-macro-error -hyper-util -ryu -prost-types -textwrap -memmap2 -zip -zerocopy -generic-array -tar -pyo3 -async-stream -quick-xml -memoffset -csv -crossterm -windows -num_enum -tokio-tungstenite -crossbeam-utils -async-channel -lru -aes -futures-lite -tracing-core -prettyplease -httparse -serde_bytes -tracing-log -tower-service -cargo_metadata -pest -mime_guess -tower-http -data-encoding -native-tls -prost-build -proptest -derivative -serial_test -libm -half -futures-io -bitvec -rustls-native-certs -ureq -object -anstyle 
-tonic-build -form_urlencoded -num-derive -pest_derive -schemars -proc-macro-crate -rstest -futures-executor -assert_cmd -termcolor -serde_repr -ctrlc -sha3 -clap_complete -flume -mockall -ipnet -aho-corasick -atty -signal-hook -async-std -filetime -num-complex -opentelemetry -cmake -arc-swap -derive_builder -async-recursion -dyn-clone -bumpalo -fs_extra -git2 -sysinfo -shlex -instant -approx -rmp-serde -rand_distr -rustls-pki-types -maplit -sqlx -blake3 -hyper-tls -dotenvy -jsonwebtoken -openssl-sys -crossbeam -camino -winreg -config -rsa -bit-vec -chrono-tz -async-lock -bstr diff --git a/crates/indexed_docs/src/providers/rustdoc/to_markdown.rs b/crates/indexed_docs/src/providers/rustdoc/to_markdown.rs deleted file mode 100644 index 87e3863728c5822c19edf4b1dc79283d95b40481..0000000000000000000000000000000000000000 --- a/crates/indexed_docs/src/providers/rustdoc/to_markdown.rs +++ /dev/null @@ -1,618 +0,0 @@ -use std::cell::RefCell; -use std::io::Read; -use std::rc::Rc; - -use anyhow::Result; -use html_to_markdown::markdown::{ - HeadingHandler, ListHandler, ParagraphHandler, StyledTextHandler, TableHandler, -}; -use html_to_markdown::{ - HandleTag, HandlerOutcome, HtmlElement, MarkdownWriter, StartTagOutcome, TagHandler, - convert_html_to_markdown, -}; -use indexmap::IndexSet; -use strum::IntoEnumIterator; - -use crate::{RustdocItem, RustdocItemKind}; - -/// Converts the provided rustdoc HTML to Markdown. 
-pub fn convert_rustdoc_to_markdown(html: impl Read) -> Result<(String, Vec)> { - let item_collector = Rc::new(RefCell::new(RustdocItemCollector::new())); - - let mut handlers: Vec = vec![ - Rc::new(RefCell::new(ParagraphHandler)), - Rc::new(RefCell::new(HeadingHandler)), - Rc::new(RefCell::new(ListHandler)), - Rc::new(RefCell::new(TableHandler::new())), - Rc::new(RefCell::new(StyledTextHandler)), - Rc::new(RefCell::new(RustdocChromeRemover)), - Rc::new(RefCell::new(RustdocHeadingHandler)), - Rc::new(RefCell::new(RustdocCodeHandler)), - Rc::new(RefCell::new(RustdocItemHandler)), - item_collector.clone(), - ]; - - let markdown = convert_html_to_markdown(html, &mut handlers)?; - - let items = item_collector - .borrow() - .items - .iter() - .cloned() - .collect::>(); - - Ok((markdown, items)) -} - -pub struct RustdocHeadingHandler; - -impl HandleTag for RustdocHeadingHandler { - fn should_handle(&self, _tag: &str) -> bool { - // We're only handling text, so we don't need to visit any tags. 
- false - } - - fn handle_text(&mut self, text: &str, writer: &mut MarkdownWriter) -> HandlerOutcome { - if writer.is_inside("h1") - || writer.is_inside("h2") - || writer.is_inside("h3") - || writer.is_inside("h4") - || writer.is_inside("h5") - || writer.is_inside("h6") - { - let text = text - .trim_matches(|char| char == '\n' || char == '\r') - .replace('\n', " "); - writer.push_str(&text); - - return HandlerOutcome::Handled; - } - - HandlerOutcome::NoOp - } -} - -pub struct RustdocCodeHandler; - -impl HandleTag for RustdocCodeHandler { - fn should_handle(&self, tag: &str) -> bool { - matches!(tag, "pre" | "code") - } - - fn handle_tag_start( - &mut self, - tag: &HtmlElement, - writer: &mut MarkdownWriter, - ) -> StartTagOutcome { - match tag.tag() { - "code" => { - if !writer.is_inside("pre") { - writer.push_str("`"); - } - } - "pre" => { - let classes = tag.classes(); - let is_rust = classes.iter().any(|class| class == "rust"); - let language = is_rust - .then_some("rs") - .or_else(|| { - classes.iter().find_map(|class| { - if let Some((_, language)) = class.split_once("language-") { - Some(language.trim()) - } else { - None - } - }) - }) - .unwrap_or(""); - - writer.push_str(&format!("\n\n```{language}\n")); - } - _ => {} - } - - StartTagOutcome::Continue - } - - fn handle_tag_end(&mut self, tag: &HtmlElement, writer: &mut MarkdownWriter) { - match tag.tag() { - "code" => { - if !writer.is_inside("pre") { - writer.push_str("`"); - } - } - "pre" => writer.push_str("\n```\n"), - _ => {} - } - } - - fn handle_text(&mut self, text: &str, writer: &mut MarkdownWriter) -> HandlerOutcome { - if writer.is_inside("pre") { - writer.push_str(text); - return HandlerOutcome::Handled; - } - - HandlerOutcome::NoOp - } -} - -const RUSTDOC_ITEM_NAME_CLASS: &str = "item-name"; - -pub struct RustdocItemHandler; - -impl RustdocItemHandler { - /// Returns whether we're currently inside of an `.item-name` element, which - /// rustdoc uses to display Rust items in a list. 
- fn is_inside_item_name(writer: &MarkdownWriter) -> bool { - writer - .current_element_stack() - .iter() - .any(|element| element.has_class(RUSTDOC_ITEM_NAME_CLASS)) - } -} - -impl HandleTag for RustdocItemHandler { - fn should_handle(&self, tag: &str) -> bool { - matches!(tag, "div" | "span") - } - - fn handle_tag_start( - &mut self, - tag: &HtmlElement, - writer: &mut MarkdownWriter, - ) -> StartTagOutcome { - match tag.tag() { - "div" | "span" => { - if Self::is_inside_item_name(writer) && tag.has_class("stab") { - writer.push_str(" ["); - } - } - _ => {} - } - - StartTagOutcome::Continue - } - - fn handle_tag_end(&mut self, tag: &HtmlElement, writer: &mut MarkdownWriter) { - match tag.tag() { - "div" | "span" => { - if tag.has_class(RUSTDOC_ITEM_NAME_CLASS) { - writer.push_str(": "); - } - - if Self::is_inside_item_name(writer) && tag.has_class("stab") { - writer.push_str("]"); - } - } - _ => {} - } - } - - fn handle_text(&mut self, text: &str, writer: &mut MarkdownWriter) -> HandlerOutcome { - if Self::is_inside_item_name(writer) - && !writer.is_inside("span") - && !writer.is_inside("code") - { - writer.push_str(&format!("`{text}`")); - return HandlerOutcome::Handled; - } - - HandlerOutcome::NoOp - } -} - -pub struct RustdocChromeRemover; - -impl HandleTag for RustdocChromeRemover { - fn should_handle(&self, tag: &str) -> bool { - matches!( - tag, - "head" | "script" | "nav" | "summary" | "button" | "a" | "div" | "span" - ) - } - - fn handle_tag_start( - &mut self, - tag: &HtmlElement, - _writer: &mut MarkdownWriter, - ) -> StartTagOutcome { - match tag.tag() { - "head" | "script" | "nav" => return StartTagOutcome::Skip, - "summary" => { - if tag.has_class("hideme") { - return StartTagOutcome::Skip; - } - } - "button" => { - if tag.attr("id").as_deref() == Some("copy-path") { - return StartTagOutcome::Skip; - } - } - "a" => { - if tag.has_any_classes(&["anchor", "doc-anchor", "src"]) { - return StartTagOutcome::Skip; - } - } - "div" | "span" => { - if 
tag.has_any_classes(&["nav-container", "sidebar-elems", "out-of-band"]) { - return StartTagOutcome::Skip; - } - } - - _ => {} - } - - StartTagOutcome::Continue - } -} - -pub struct RustdocItemCollector { - pub items: IndexSet, -} - -impl RustdocItemCollector { - pub fn new() -> Self { - Self { - items: IndexSet::new(), - } - } - - fn parse_item(tag: &HtmlElement) -> Option { - if tag.tag() != "a" { - return None; - } - - let href = tag.attr("href")?; - if href.starts_with('#') || href.starts_with("https://") || href.starts_with("../") { - return None; - } - - for kind in RustdocItemKind::iter() { - if tag.has_class(kind.class()) { - let mut parts = href.trim_end_matches("/index.html").split('/'); - - if let Some(last_component) = parts.next_back() { - let last_component = match last_component.split_once('#') { - Some((component, _fragment)) => component, - None => last_component, - }; - - let name = last_component - .trim_start_matches(&format!("{}.", kind.class())) - .trim_end_matches(".html"); - - return Some(RustdocItem { - kind, - name: name.into(), - path: parts.map(Into::into).collect(), - }); - } - } - } - - None - } -} - -impl HandleTag for RustdocItemCollector { - fn should_handle(&self, tag: &str) -> bool { - tag == "a" - } - - fn handle_tag_start( - &mut self, - tag: &HtmlElement, - writer: &mut MarkdownWriter, - ) -> StartTagOutcome { - if tag.tag() == "a" { - let is_reexport = writer.current_element_stack().iter().any(|element| { - if let Some(id) = element.attr("id") { - id.starts_with("reexport.") || id.starts_with("method.") - } else { - false - } - }); - - if !is_reexport { - if let Some(item) = Self::parse_item(tag) { - self.items.insert(item); - } - } - } - - StartTagOutcome::Continue - } -} - -#[cfg(test)] -mod tests { - use html_to_markdown::{TagHandler, convert_html_to_markdown}; - use indoc::indoc; - use pretty_assertions::assert_eq; - - use super::*; - - fn rustdoc_handlers() -> Vec { - vec![ - Rc::new(RefCell::new(ParagraphHandler)), - 
Rc::new(RefCell::new(HeadingHandler)), - Rc::new(RefCell::new(ListHandler)), - Rc::new(RefCell::new(TableHandler::new())), - Rc::new(RefCell::new(StyledTextHandler)), - Rc::new(RefCell::new(RustdocChromeRemover)), - Rc::new(RefCell::new(RustdocHeadingHandler)), - Rc::new(RefCell::new(RustdocCodeHandler)), - Rc::new(RefCell::new(RustdocItemHandler)), - ] - } - - #[test] - fn test_main_heading_buttons_get_removed() { - let html = indoc! {r##" - - "##}; - let expected = indoc! {" - # Crate serde - "} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } - - #[test] - fn test_single_paragraph() { - let html = indoc! {r#" -

In particular, the last point is what sets axum apart from other frameworks. - axum doesn’t have its own middleware system but instead uses - tower::Service. This means axum gets timeouts, tracing, compression, - authorization, and more, for free. It also enables you to share middleware with - applications written using hyper or tonic.

- "#}; - let expected = indoc! {" - In particular, the last point is what sets `axum` apart from other frameworks. `axum` doesn’t have its own middleware system but instead uses `tower::Service`. This means `axum` gets timeouts, tracing, compression, authorization, and more, for free. It also enables you to share middleware with applications written using `hyper` or `tonic`. - "} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } - - #[test] - fn test_multiple_paragraphs() { - let html = indoc! {r##" -

§Serde

-

Serde is a framework for serializing and deserializing Rust data - structures efficiently and generically.

-

The Serde ecosystem consists of data structures that know how to serialize - and deserialize themselves along with data formats that know how to - serialize and deserialize other things. Serde provides the layer by which - these two groups interact with each other, allowing any supported data - structure to be serialized and deserialized using any supported data format.

-

See the Serde website https://serde.rs/ for additional documentation and - usage examples.

-

§Design

-

Where many other languages rely on runtime reflection for serializing data, - Serde is instead built on Rust’s powerful trait system. A data structure - that knows how to serialize and deserialize itself is one that implements - Serde’s Serialize and Deserialize traits (or uses Serde’s derive - attribute to automatically generate implementations at compile time). This - avoids any overhead of reflection or runtime type information. In fact in - many situations the interaction between data structure and data format can - be completely optimized away by the Rust compiler, leaving Serde - serialization to perform the same speed as a handwritten serializer for the - specific selection of data structure and data format.

- "##}; - let expected = indoc! {" - ## Serde - - Serde is a framework for _**ser**_ializing and _**de**_serializing Rust data structures efficiently and generically. - - The Serde ecosystem consists of data structures that know how to serialize and deserialize themselves along with data formats that know how to serialize and deserialize other things. Serde provides the layer by which these two groups interact with each other, allowing any supported data structure to be serialized and deserialized using any supported data format. - - See the Serde website https://serde.rs/ for additional documentation and usage examples. - - ### Design - - Where many other languages rely on runtime reflection for serializing data, Serde is instead built on Rust’s powerful trait system. A data structure that knows how to serialize and deserialize itself is one that implements Serde’s `Serialize` and `Deserialize` traits (or uses Serde’s derive attribute to automatically generate implementations at compile time). This avoids any overhead of reflection or runtime type information. In fact in many situations the interaction between data structure and data format can be completely optimized away by the Rust compiler, leaving Serde serialization to perform the same speed as a handwritten serializer for the specific selection of data structure and data format. - "} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } - - #[test] - fn test_styled_text() { - let html = indoc! {r#" -

This text is bolded.

-

This text is italicized.

- "#}; - let expected = indoc! {" - This text is **bolded**. - - This text is _italicized_. - "} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } - - #[test] - fn test_rust_code_block() { - let html = indoc! {r#" -
use axum::extract::{Path, Query, Json};
-            use std::collections::HashMap;
-
-            // `Path` gives you the path parameters and deserializes them.
-            async fn path(Path(user_id): Path<u32>) {}
-
-            // `Query` gives you the query parameters and deserializes them.
-            async fn query(Query(params): Query<HashMap<String, String>>) {}
-
-            // Buffer the request body and deserialize it as JSON into a
-            // `serde_json::Value`. `Json` supports any type that implements
-            // `serde::Deserialize`.
-            async fn json(Json(payload): Json<serde_json::Value>) {}
- "#}; - let expected = indoc! {" - ```rs - use axum::extract::{Path, Query, Json}; - use std::collections::HashMap; - - // `Path` gives you the path parameters and deserializes them. - async fn path(Path(user_id): Path) {} - - // `Query` gives you the query parameters and deserializes them. - async fn query(Query(params): Query>) {} - - // Buffer the request body and deserialize it as JSON into a - // `serde_json::Value`. `Json` supports any type that implements - // `serde::Deserialize`. - async fn json(Json(payload): Json) {} - ``` - "} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } - - #[test] - fn test_toml_code_block() { - let html = indoc! {r##" -

§Required dependencies

-

To use axum there are a few dependencies you have to pull in as well:

-
[dependencies]
-            axum = "<latest-version>"
-            tokio = { version = "<latest-version>", features = ["full"] }
-            tower = "<latest-version>"
-            
- "##}; - let expected = indoc! {r#" - ## Required dependencies - - To use axum there are a few dependencies you have to pull in as well: - - ```toml - [dependencies] - axum = "" - tokio = { version = "", features = ["full"] } - tower = "" - - ``` - "#} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } - - #[test] - fn test_item_table() { - let html = indoc! {r##" -

Structs§

-
    -
  • Errors that can happen when using axum.
  • -
  • Extractor and response for extensions.
  • -
  • Formform
    URL encoded extractor and response.
  • -
  • Jsonjson
    JSON Extractor / Response.
  • -
  • The router type for composing handlers and services.
-

Functions§

-
    -
  • servetokio and (http1 or http2)
    Serve the service with the supplied listener.
  • -
- "##}; - let expected = indoc! {r#" - ## Structs - - - `Error`: Errors that can happen when using axum. - - `Extension`: Extractor and response for extensions. - - `Form` [`form`]: URL encoded extractor and response. - - `Json` [`json`]: JSON Extractor / Response. - - `Router`: The router type for composing handlers and services. - - ## Functions - - - `serve` [`tokio` and (`http1` or `http2`)]: Serve the service with the supplied listener. - "#} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } - - #[test] - fn test_table() { - let html = indoc! {r##" -

§Feature flags

-

axum uses a set of feature flags to reduce the amount of compiled and - optional dependencies.

-

The following optional features are available:

-
- - - - - - - - - - - - - -
NameDescriptionDefault?
http1Enables hyper’s http1 featureYes
http2Enables hyper’s http2 featureNo
jsonEnables the Json type and some similar convenience functionalityYes
macrosEnables optional utility macrosNo
matched-pathEnables capturing of every request’s router path and the MatchedPath extractorYes
multipartEnables parsing multipart/form-data requests with MultipartNo
original-uriEnables capturing of every request’s original URI and the OriginalUri extractorYes
tokioEnables tokio as a dependency and axum::serve, SSE and extract::connect_info types.Yes
tower-logEnables tower’s log featureYes
tracingLog rejections from built-in extractorsYes
wsEnables WebSockets support via extract::wsNo
formEnables the Form extractorYes
queryEnables the Query extractorYes
- "##}; - let expected = indoc! {r#" - ## Feature flags - - axum uses a set of feature flags to reduce the amount of compiled and optional dependencies. - - The following optional features are available: - - | Name | Description | Default? | - | --- | --- | --- | - | `http1` | Enables hyper’s `http1` feature | Yes | - | `http2` | Enables hyper’s `http2` feature | No | - | `json` | Enables the `Json` type and some similar convenience functionality | Yes | - | `macros` | Enables optional utility macros | No | - | `matched-path` | Enables capturing of every request’s router path and the `MatchedPath` extractor | Yes | - | `multipart` | Enables parsing `multipart/form-data` requests with `Multipart` | No | - | `original-uri` | Enables capturing of every request’s original URI and the `OriginalUri` extractor | Yes | - | `tokio` | Enables `tokio` as a dependency and `axum::serve`, `SSE` and `extract::connect_info` types. | Yes | - | `tower-log` | Enables `tower`’s `log` feature | Yes | - | `tracing` | Log rejections from built-in extractors | Yes | - | `ws` | Enables WebSockets support via `extract::ws` | No | - | `form` | Enables the `Form` extractor | Yes | - | `query` | Enables the `Query` extractor | Yes | - "#} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } -} diff --git a/crates/indexed_docs/src/registry.rs b/crates/indexed_docs/src/registry.rs deleted file mode 100644 index 6757cd9c1a1a324e43765df618e01397a8b85708..0000000000000000000000000000000000000000 --- a/crates/indexed_docs/src/registry.rs +++ /dev/null @@ -1,62 +0,0 @@ -use std::sync::Arc; - -use collections::HashMap; -use gpui::{App, BackgroundExecutor, Global, ReadGlobal, UpdateGlobal}; -use parking_lot::RwLock; - -use crate::{IndexedDocsProvider, IndexedDocsStore, ProviderId}; - -struct GlobalIndexedDocsRegistry(Arc); - -impl Global for GlobalIndexedDocsRegistry {} - -pub struct IndexedDocsRegistry { - executor: 
BackgroundExecutor, - stores_by_provider: RwLock>>, -} - -impl IndexedDocsRegistry { - pub fn global(cx: &App) -> Arc { - GlobalIndexedDocsRegistry::global(cx).0.clone() - } - - pub(crate) fn init_global(cx: &mut App) { - GlobalIndexedDocsRegistry::set_global( - cx, - GlobalIndexedDocsRegistry(Arc::new(Self::new(cx.background_executor().clone()))), - ); - } - - pub fn new(executor: BackgroundExecutor) -> Self { - Self { - executor, - stores_by_provider: RwLock::new(HashMap::default()), - } - } - - pub fn list_providers(&self) -> Vec { - self.stores_by_provider - .read() - .keys() - .cloned() - .collect::>() - } - - pub fn register_provider( - &self, - provider: Box, - ) { - self.stores_by_provider.write().insert( - provider.id(), - Arc::new(IndexedDocsStore::new(provider, self.executor.clone())), - ); - } - - pub fn unregister_provider(&self, provider_id: &ProviderId) { - self.stores_by_provider.write().remove(provider_id); - } - - pub fn get_provider_store(&self, provider_id: ProviderId) -> Option> { - self.stores_by_provider.read().get(&provider_id).cloned() - } -} diff --git a/crates/indexed_docs/src/store.rs b/crates/indexed_docs/src/store.rs deleted file mode 100644 index 1407078efaf122acfc07e56894d3fd803605a42a..0000000000000000000000000000000000000000 --- a/crates/indexed_docs/src/store.rs +++ /dev/null @@ -1,346 +0,0 @@ -use std::path::PathBuf; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; - -use anyhow::{Context as _, Result, anyhow}; -use async_trait::async_trait; -use collections::HashMap; -use derive_more::{Deref, Display}; -use futures::FutureExt; -use futures::future::{self, BoxFuture, Shared}; -use fuzzy::StringMatchCandidate; -use gpui::{App, BackgroundExecutor, Task}; -use heed::Database; -use heed::types::SerdeBincode; -use parking_lot::RwLock; -use serde::{Deserialize, Serialize}; -use util::ResultExt; - -use crate::IndexedDocsRegistry; - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Deref, Display)] -pub struct 
ProviderId(pub Arc); - -/// The name of a package. -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Deref, Display)] -pub struct PackageName(Arc); - -impl From<&str> for PackageName { - fn from(value: &str) -> Self { - Self(value.into()) - } -} - -#[async_trait] -pub trait IndexedDocsProvider { - /// Returns the ID of this provider. - fn id(&self) -> ProviderId; - - /// Returns the path to the database for this provider. - fn database_path(&self) -> PathBuf; - - /// Returns a list of packages as suggestions to be included in the search - /// results. - /// - /// This can be used to provide completions for known packages (e.g., from the - /// local project or a registry) before a package has been indexed. - async fn suggest_packages(&self) -> Result>; - - /// Indexes the package with the given name. - async fn index(&self, package: PackageName, database: Arc) -> Result<()>; -} - -/// A store for indexed docs. -pub struct IndexedDocsStore { - executor: BackgroundExecutor, - database_future: - Shared, Arc>>>, - provider: Box, - indexing_tasks_by_package: - RwLock>>>>>, - latest_errors_by_package: RwLock>>, -} - -impl IndexedDocsStore { - pub fn try_global(provider: ProviderId, cx: &App) -> Result> { - let registry = IndexedDocsRegistry::global(cx); - registry - .get_provider_store(provider.clone()) - .with_context(|| format!("no indexed docs store found for {provider}")) - } - - pub fn new( - provider: Box, - executor: BackgroundExecutor, - ) -> Self { - let database_future = executor - .spawn({ - let executor = executor.clone(); - let database_path = provider.database_path(); - async move { IndexedDocsDatabase::new(database_path, executor) } - }) - .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new))) - .boxed() - .shared(); - - Self { - executor, - database_future, - provider, - indexing_tasks_by_package: RwLock::new(HashMap::default()), - latest_errors_by_package: RwLock::new(HashMap::default()), - } - } - - pub fn 
latest_error_for_package(&self, package: &PackageName) -> Option> { - self.latest_errors_by_package.read().get(package).cloned() - } - - /// Returns whether the package with the given name is currently being indexed. - pub fn is_indexing(&self, package: &PackageName) -> bool { - self.indexing_tasks_by_package.read().contains_key(package) - } - - pub async fn load(&self, key: String) -> Result { - self.database_future - .clone() - .await - .map_err(|err| anyhow!(err))? - .load(key) - .await - } - - pub async fn load_many_by_prefix(&self, prefix: String) -> Result> { - self.database_future - .clone() - .await - .map_err(|err| anyhow!(err))? - .load_many_by_prefix(prefix) - .await - } - - /// Returns whether any entries exist with the given prefix. - pub async fn any_with_prefix(&self, prefix: String) -> Result { - self.database_future - .clone() - .await - .map_err(|err| anyhow!(err))? - .any_with_prefix(prefix) - .await - } - - pub fn suggest_packages(self: Arc) -> Task>> { - let this = self.clone(); - self.executor - .spawn(async move { this.provider.suggest_packages().await }) - } - - pub fn index( - self: Arc, - package: PackageName, - ) -> Shared>>> { - if let Some(existing_task) = self.indexing_tasks_by_package.read().get(&package) { - return existing_task.clone(); - } - - let indexing_task = self - .executor - .spawn({ - let this = self.clone(); - let package = package.clone(); - async move { - let _finally = util::defer({ - let this = this.clone(); - let package = package.clone(); - move || { - this.indexing_tasks_by_package.write().remove(&package); - } - }); - - let index_task = { - let package = package.clone(); - async { - let database = this - .database_future - .clone() - .await - .map_err(|err| anyhow!(err))?; - this.provider.index(package, database).await - } - }; - - let result = index_task.await.map_err(Arc::new); - match &result { - Ok(_) => { - this.latest_errors_by_package.write().remove(&package); - } - Err(err) => { - 
this.latest_errors_by_package - .write() - .insert(package, err.to_string().into()); - } - } - - result - } - }) - .shared(); - - self.indexing_tasks_by_package - .write() - .insert(package, indexing_task.clone()); - - indexing_task - } - - pub fn search(&self, query: String) -> Task> { - let executor = self.executor.clone(); - let database_future = self.database_future.clone(); - self.executor.spawn(async move { - let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else { - return Vec::new(); - }; - - let Some(items) = database.keys().await.log_err() else { - return Vec::new(); - }; - - let candidates = items - .iter() - .enumerate() - .map(|(ix, item_path)| StringMatchCandidate::new(ix, &item_path)) - .collect::>(); - - let matches = fuzzy::match_strings( - &candidates, - &query, - false, - true, - 100, - &AtomicBool::default(), - executor, - ) - .await; - - matches - .into_iter() - .map(|mat| items[mat.candidate_id].clone()) - .collect() - }) - } -} - -#[derive(Debug, PartialEq, Eq, Clone, Display, Serialize, Deserialize)] -pub struct MarkdownDocs(pub String); - -pub struct IndexedDocsDatabase { - executor: BackgroundExecutor, - env: heed::Env, - entries: Database, SerdeBincode>, -} - -impl IndexedDocsDatabase { - pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result { - std::fs::create_dir_all(&path)?; - - const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024; - let env = unsafe { - heed::EnvOpenOptions::new() - .map_size(ONE_GB_IN_BYTES) - .max_dbs(1) - .open(path)? 
- }; - - let mut txn = env.write_txn()?; - let entries = env.create_database(&mut txn, Some("rustdoc_entries"))?; - txn.commit()?; - - Ok(Self { - executor, - env, - entries, - }) - } - - pub fn keys(&self) -> Task>> { - let env = self.env.clone(); - let entries = self.entries; - - self.executor.spawn(async move { - let txn = env.read_txn()?; - let mut iter = entries.iter(&txn)?; - let mut keys = Vec::new(); - while let Some((key, _value)) = iter.next().transpose()? { - keys.push(key); - } - - Ok(keys) - }) - } - - pub fn load(&self, key: String) -> Task> { - let env = self.env.clone(); - let entries = self.entries; - - self.executor.spawn(async move { - let txn = env.read_txn()?; - entries - .get(&txn, &key)? - .with_context(|| format!("no docs found for {key}")) - }) - } - - pub fn load_many_by_prefix(&self, prefix: String) -> Task>> { - let env = self.env.clone(); - let entries = self.entries; - - self.executor.spawn(async move { - let txn = env.read_txn()?; - let results = entries - .iter(&txn)? - .filter_map(|entry| { - let (key, value) = entry.ok()?; - if key.starts_with(&prefix) { - Some((key, value)) - } else { - None - } - }) - .collect::>(); - - Ok(results) - }) - } - - /// Returns whether any entries exist with the given prefix. - pub fn any_with_prefix(&self, prefix: String) -> Task> { - let env = self.env.clone(); - let entries = self.entries; - - self.executor.spawn(async move { - let txn = env.read_txn()?; - let any = entries - .iter(&txn)? 
- .any(|entry| entry.map_or(false, |(key, _value)| key.starts_with(&prefix))); - Ok(any) - }) - } - - pub fn insert(&self, key: String, docs: String) -> Task> { - let env = self.env.clone(); - let entries = self.entries; - - self.executor.spawn(async move { - let mut txn = env.write_txn()?; - entries.put(&mut txn, &key, &MarkdownDocs(docs))?; - txn.commit()?; - Ok(()) - }) - } -} - -impl extension::KeyValueStoreDelegate for IndexedDocsDatabase { - fn insert(&self, key: String, docs: String) -> Task> { - IndexedDocsDatabase::insert(&self, key, docs) - } -} diff --git a/crates/inspector_ui/Cargo.toml b/crates/inspector_ui/Cargo.toml index 8e55a8a477e5346bd12ec594b36ac04e197dfc8e..cefe888974da2c9d164ad97079441ddec2d7fdff 100644 --- a/crates/inspector_ui/Cargo.toml +++ b/crates/inspector_ui/Cargo.toml @@ -24,6 +24,7 @@ serde_json_lenient.workspace = true theme.workspace = true ui.workspace = true util.workspace = true +util_macros.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true diff --git a/crates/inspector_ui/src/div_inspector.rs b/crates/inspector_ui/src/div_inspector.rs index bd395aa01bca42ce923073ee6f80472abc7820eb..fa8b76517f0125e7319f035b41996e445451510a 100644 --- a/crates/inspector_ui/src/div_inspector.rs +++ b/crates/inspector_ui/src/div_inspector.rs @@ -14,7 +14,10 @@ use language::{ DiagnosticSeverity, LanguageServerId, Point, ToOffset as _, ToPoint as _, }; use project::lsp_store::CompletionDocumentation; -use project::{Completion, CompletionResponse, CompletionSource, Project, ProjectPath}; +use project::{ + Completion, CompletionDisplayOptions, CompletionResponse, CompletionSource, Project, + ProjectPath, +}; use std::fmt::Write as _; use std::ops::Range; use std::path::Path; @@ -25,7 +28,7 @@ use util::split_str_with_ranges; /// Path used for unsaved buffer that contains style json. To support the json language server, this /// matches the name used in the generated schemas. 
-const ZED_INSPECTOR_STYLE_JSON: &str = "/zed-inspector-style.json"; +const ZED_INSPECTOR_STYLE_JSON: &str = util_macros::path!("/zed-inspector-style.json"); pub(crate) struct DivInspector { state: State, @@ -93,8 +96,8 @@ impl DivInspector { Ok((json_style_buffer, rust_style_buffer)) => { this.update_in(cx, |this, window, cx| { this.state = State::BuffersLoaded { - json_style_buffer: json_style_buffer, - rust_style_buffer: rust_style_buffer, + json_style_buffer, + rust_style_buffer, }; // Initialize editors immediately instead of waiting for @@ -200,8 +203,8 @@ impl DivInspector { cx.subscribe_in(&json_style_editor, window, { let id = id.clone(); let rust_style_buffer = rust_style_buffer.clone(); - move |this, editor, event: &EditorEvent, window, cx| match event { - EditorEvent::BufferEdited => { + move |this, editor, event: &EditorEvent, window, cx| { + if event == &EditorEvent::BufferEdited { let style_json = editor.read(cx).text(cx); match serde_json_lenient::from_str_lenient::(&style_json) { Ok(new_style) => { @@ -243,7 +246,6 @@ impl DivInspector { Err(err) => this.json_style_error = Some(err.to_string().into()), } } - _ => {} } }) .detach(); @@ -251,11 +253,10 @@ impl DivInspector { cx.subscribe(&rust_style_editor, { let json_style_buffer = json_style_buffer.clone(); let rust_style_buffer = rust_style_buffer.clone(); - move |this, _editor, event: &EditorEvent, cx| match event { - EditorEvent::BufferEdited => { + move |this, _editor, event: &EditorEvent, cx| { + if let EditorEvent::BufferEdited = event { this.update_json_style_from_rust(&json_style_buffer, &rust_style_buffer, cx); } - _ => {} } }) .detach(); @@ -271,23 +272,19 @@ impl DivInspector { } fn reset_style(&mut self, cx: &mut App) { - match &self.state { - State::Ready { - rust_style_buffer, - json_style_buffer, - .. 
- } => { - if let Err(err) = self.reset_style_editors( - &rust_style_buffer.clone(), - &json_style_buffer.clone(), - cx, - ) { - self.json_style_error = Some(format!("{err}").into()); - } else { - self.json_style_error = None; - } + if let State::Ready { + rust_style_buffer, + json_style_buffer, + .. + } = &self.state + { + if let Err(err) = + self.reset_style_editors(&rust_style_buffer.clone(), &json_style_buffer.clone(), cx) + { + self.json_style_error = Some(format!("{err}").into()); + } else { + self.json_style_error = None; } - _ => {} } } @@ -395,11 +392,11 @@ impl DivInspector { .zip(self.rust_completion_replace_range.as_ref()) { let before_text = snapshot - .text_for_range(0..completion_range.start.to_offset(&snapshot)) + .text_for_range(0..completion_range.start.to_offset(snapshot)) .collect::(); let after_text = snapshot .text_for_range( - completion_range.end.to_offset(&snapshot) + completion_range.end.to_offset(snapshot) ..snapshot.clip_offset(usize::MAX, Bias::Left), ) .collect::(); @@ -670,6 +667,7 @@ impl CompletionProvider for RustStyleCompletionProvider { confirm: None, }) .collect(), + display_options: CompletionDisplayOptions::default(), is_incomplete: false, }])) } @@ -702,10 +700,10 @@ impl CompletionProvider for RustStyleCompletionProvider { } fn completion_replace_range(snapshot: &BufferSnapshot, anchor: &Anchor) -> Option> { - let point = anchor.to_point(&snapshot); - let offset = point.to_offset(&snapshot); - let line_start = Point::new(point.row, 0).to_offset(&snapshot); - let line_end = Point::new(point.row, snapshot.line_len(point.row)).to_offset(&snapshot); + let point = anchor.to_point(snapshot); + let offset = point.to_offset(snapshot); + let line_start = Point::new(point.row, 0).to_offset(snapshot); + let line_end = Point::new(point.row, snapshot.line_len(point.row)).to_offset(snapshot); let mut lines = snapshot.text_for_range(line_start..line_end).lines(); let line = lines.next()?; diff --git a/crates/install_cli/src/install_cli.rs 
b/crates/install_cli/src/install_cli.rs index 12c094448b8362c8d638ac62da5838544b4fcc6d..281069020af37c3de6bf0df4465c495353ad82e9 100644 --- a/crates/install_cli/src/install_cli.rs +++ b/crates/install_cli/src/install_cli.rs @@ -1,112 +1,7 @@ -use anyhow::{Context as _, Result}; -use client::ZED_URL_SCHEME; -use gpui::{AppContext as _, AsyncApp, Context, PromptLevel, Window, actions}; -use release_channel::ReleaseChannel; -use std::ops::Deref; -use std::path::{Path, PathBuf}; -use util::ResultExt; -use workspace::notifications::{DetachAndPromptErr, NotificationId}; -use workspace::{Toast, Workspace}; +#[cfg(not(target_os = "windows"))] +mod install_cli_binary; +mod register_zed_scheme; -actions!( - cli, - [ - /// Installs the Zed CLI tool to the system PATH. - Install, - /// Registers the zed:// URL scheme handler. - RegisterZedScheme - ] -); - -async fn install_script(cx: &AsyncApp) -> Result { - let cli_path = cx.update(|cx| cx.path_for_auxiliary_executable("cli"))??; - let link_path = Path::new("/usr/local/bin/zed"); - let bin_dir_path = link_path.parent().unwrap(); - - // Don't re-create symlink if it points to the same CLI binary. - if smol::fs::read_link(link_path).await.ok().as_ref() == Some(&cli_path) { - return Ok(link_path.into()); - } - - // If the symlink is not there or is outdated, first try replacing it - // without escalating. - smol::fs::remove_file(link_path).await.log_err(); - // todo("windows") - #[cfg(not(windows))] - { - if smol::fs::unix::symlink(&cli_path, link_path) - .await - .log_err() - .is_some() - { - return Ok(link_path.into()); - } - } - - // The symlink could not be created, so use osascript with admin privileges - // to create it. 
- let status = smol::process::Command::new("/usr/bin/osascript") - .args([ - "-e", - &format!( - "do shell script \" \ - mkdir -p \'{}\' && \ - ln -sf \'{}\' \'{}\' \ - \" with administrator privileges", - bin_dir_path.to_string_lossy(), - cli_path.to_string_lossy(), - link_path.to_string_lossy(), - ), - ]) - .stdout(smol::process::Stdio::inherit()) - .stderr(smol::process::Stdio::inherit()) - .output() - .await? - .status; - anyhow::ensure!(status.success(), "error running osascript"); - Ok(link_path.into()) -} - -pub async fn register_zed_scheme(cx: &AsyncApp) -> anyhow::Result<()> { - cx.update(|cx| cx.register_url_scheme(ZED_URL_SCHEME))? - .await -} - -pub fn install_cli(window: &mut Window, cx: &mut Context) { - const LINUX_PROMPT_DETAIL: &str = "If you installed Zed from our official release add ~/.local/bin to your PATH.\n\nIf you installed Zed from a different source like your package manager, then you may need to create an alias/symlink manually.\n\nDepending on your package manager, the CLI might be named zeditor, zedit, zed-editor or something else."; - - cx.spawn_in(window, async move |workspace, cx| { - if cfg!(any(target_os = "linux", target_os = "freebsd")) { - let prompt = cx.prompt( - PromptLevel::Warning, - "CLI should already be installed", - Some(LINUX_PROMPT_DETAIL), - &["Ok"], - ); - cx.background_spawn(prompt).detach(); - return Ok(()); - } - let path = install_script(cx.deref()) - .await - .context("error creating CLI symlink")?; - - workspace.update_in(cx, |workspace, _, cx| { - struct InstalledZedCli; - - workspace.show_toast( - Toast::new( - NotificationId::unique::(), - format!( - "Installed `zed` to {}. 
You can launch {} from your terminal.", - path.to_string_lossy(), - ReleaseChannel::global(cx).display_name() - ), - ), - cx, - ) - })?; - register_zed_scheme(&cx).await.log_err(); - Ok(()) - }) - .detach_and_prompt_err("Error installing zed cli", window, cx, |_, _, _| None); -} +#[cfg(not(target_os = "windows"))] +pub use install_cli_binary::{InstallCliBinary, install_cli_binary}; +pub use register_zed_scheme::{RegisterZedScheme, register_zed_scheme}; diff --git a/crates/install_cli/src/install_cli_binary.rs b/crates/install_cli/src/install_cli_binary.rs new file mode 100644 index 0000000000000000000000000000000000000000..414bdabc7090be4372ff984949809839bbd3ee05 --- /dev/null +++ b/crates/install_cli/src/install_cli_binary.rs @@ -0,0 +1,101 @@ +use super::register_zed_scheme; +use anyhow::{Context as _, Result}; +use gpui::{AppContext as _, AsyncApp, Context, PromptLevel, Window, actions}; +use release_channel::ReleaseChannel; +use std::ops::Deref; +use std::path::{Path, PathBuf}; +use util::ResultExt; +use workspace::notifications::{DetachAndPromptErr, NotificationId}; +use workspace::{Toast, Workspace}; + +actions!( + cli, + [ + /// Installs the Zed CLI tool to the system PATH. + InstallCliBinary, + ] +); + +async fn install_script(cx: &AsyncApp) -> Result { + let cli_path = cx.update(|cx| cx.path_for_auxiliary_executable("cli"))??; + let link_path = Path::new("/usr/local/bin/zed"); + let bin_dir_path = link_path.parent().unwrap(); + + // Don't re-create symlink if it points to the same CLI binary. + if smol::fs::read_link(link_path).await.ok().as_ref() == Some(&cli_path) { + return Ok(link_path.into()); + } + + // If the symlink is not there or is outdated, first try replacing it + // without escalating. 
+ smol::fs::remove_file(link_path).await.log_err(); + if smol::fs::unix::symlink(&cli_path, link_path) + .await + .log_err() + .is_some() + { + return Ok(link_path.into()); + } + + // The symlink could not be created, so use osascript with admin privileges + // to create it. + let status = smol::process::Command::new("/usr/bin/osascript") + .args([ + "-e", + &format!( + "do shell script \" \ + mkdir -p \'{}\' && \ + ln -sf \'{}\' \'{}\' \ + \" with administrator privileges", + bin_dir_path.to_string_lossy(), + cli_path.to_string_lossy(), + link_path.to_string_lossy(), + ), + ]) + .stdout(smol::process::Stdio::inherit()) + .stderr(smol::process::Stdio::inherit()) + .output() + .await? + .status; + anyhow::ensure!(status.success(), "error running osascript"); + Ok(link_path.into()) +} + +pub fn install_cli_binary(window: &mut Window, cx: &mut Context) { + const LINUX_PROMPT_DETAIL: &str = "If you installed Zed from our official release add ~/.local/bin to your PATH.\n\nIf you installed Zed from a different source like your package manager, then you may need to create an alias/symlink manually.\n\nDepending on your package manager, the CLI might be named zeditor, zedit, zed-editor or something else."; + + cx.spawn_in(window, async move |workspace, cx| { + if cfg!(any(target_os = "linux", target_os = "freebsd")) { + let prompt = cx.prompt( + PromptLevel::Warning, + "CLI should already be installed", + Some(LINUX_PROMPT_DETAIL), + &["Ok"], + ); + cx.background_spawn(prompt).detach(); + return Ok(()); + } + let path = install_script(cx.deref()) + .await + .context("error creating CLI symlink")?; + + workspace.update_in(cx, |workspace, _, cx| { + struct InstalledZedCli; + + workspace.show_toast( + Toast::new( + NotificationId::unique::(), + format!( + "Installed `zed` to {}. 
You can launch {} from your terminal.", + path.to_string_lossy(), + ReleaseChannel::global(cx).display_name() + ), + ), + cx, + ) + })?; + register_zed_scheme(cx).await.log_err(); + Ok(()) + }) + .detach_and_prompt_err("Error installing zed cli", window, cx, |_, _, _| None); +} diff --git a/crates/install_cli/src/register_zed_scheme.rs b/crates/install_cli/src/register_zed_scheme.rs new file mode 100644 index 0000000000000000000000000000000000000000..819287c5d0bcd15e531e21b417c7e5d4a4b4ece5 --- /dev/null +++ b/crates/install_cli/src/register_zed_scheme.rs @@ -0,0 +1,15 @@ +use client::ZED_URL_SCHEME; +use gpui::{AsyncApp, actions}; + +actions!( + cli, + [ + /// Registers the zed:// URL scheme handler. + RegisterZedScheme + ] +); + +pub async fn register_zed_scheme(cx: &AsyncApp) -> anyhow::Result<()> { + cx.update(|cx| cx.register_url_scheme(ZED_URL_SCHEME))? + .await +} diff --git a/crates/jj/src/jj_repository.rs b/crates/jj/src/jj_repository.rs index 93ae79eb90992a8fc71804788325683eae800cb4..afbe54c99dcb40a039e8f7cc87c14dc393ebac3a 100644 --- a/crates/jj/src/jj_repository.rs +++ b/crates/jj/src/jj_repository.rs @@ -50,16 +50,13 @@ impl RealJujutsuRepository { impl JujutsuRepository for RealJujutsuRepository { fn list_bookmarks(&self) -> Vec { - let bookmarks = self - .repository + self.repository .view() .bookmarks() .map(|(ref_name, _target)| Bookmark { ref_name: ref_name.as_str().to_string().into(), }) - .collect(); - - bookmarks + .collect() } } diff --git a/crates/jj/src/jj_store.rs b/crates/jj/src/jj_store.rs index a10f06fad48a3867ce6e19ffb5fc721c931ae6e4..2d2d958d7f964cdfc7723827fb2241e50d172697 100644 --- a/crates/jj/src/jj_store.rs +++ b/crates/jj/src/jj_store.rs @@ -16,7 +16,7 @@ pub struct JujutsuStore { impl JujutsuStore { pub fn init_global(cx: &mut App) { - let Some(repository) = RealJujutsuRepository::new(&Path::new(".")).ok() else { + let Some(repository) = RealJujutsuRepository::new(Path::new(".")).ok() else { return; }; diff --git 
a/crates/jj_ui/src/bookmark_picker.rs b/crates/jj_ui/src/bookmark_picker.rs index f6121fb9fc4cf40eaee2fa0f759e34e67d60429d..95c23e73f56059c92397222a132700351c710147 100644 --- a/crates/jj_ui/src/bookmark_picker.rs +++ b/crates/jj_ui/src/bookmark_picker.rs @@ -182,7 +182,7 @@ impl PickerDelegate for BookmarkPickerDelegate { _window: &mut Window, _cx: &mut Context>, ) -> Option { - let entry = &self.matches[ix]; + let entry = &self.matches.get(ix)?; Some( ListItem::new(ix) diff --git a/crates/journal/src/journal.rs b/crates/journal/src/journal.rs index 0335a746cd23eb2654dac7f8960a649aa3c269ff..5cdfa6c1df034deaf06e1c99ea99415757b84c29 100644 --- a/crates/journal/src/journal.rs +++ b/crates/journal/src/journal.rs @@ -5,7 +5,7 @@ use editor::{Editor, SelectionEffects}; use gpui::{App, AppContext as _, Context, Window, actions}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; use std::{ fs::OpenOptions, path::{Path, PathBuf}, @@ -22,7 +22,8 @@ actions!( ); /// Settings specific to journaling -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, SettingsUi, SettingsKey)] +#[settings_key(key = "journal")] pub struct JournalSettings { /// The path of the directory where journal entries are stored. 
/// @@ -52,8 +53,6 @@ pub enum HourFormat { } impl settings::Settings for JournalSettings { - const KEY: Option<&'static str> = Some("journal"); - type FileContent = Self; fn load(sources: SettingsSources, _: &mut App) -> Result { @@ -123,7 +122,7 @@ pub fn new_journal_entry(workspace: &Workspace, window: &mut Window, cx: &mut Ap } let app_state = workspace.app_state().clone(); - let view_snapshot = workspace.weak_handle().clone(); + let view_snapshot = workspace.weak_handle(); window .spawn(cx, async move |cx| { @@ -170,23 +169,23 @@ pub fn new_journal_entry(workspace: &Workspace, window: &mut Window, cx: &mut Ap .await }; - if let Some(Some(Ok(item))) = opened.first() { - if let Some(editor) = item.downcast::().map(|editor| editor.downgrade()) { - editor.update_in(cx, |editor, window, cx| { - let len = editor.buffer().read(cx).len(cx); - editor.change_selections( - SelectionEffects::scroll(Autoscroll::center()), - window, - cx, - |s| s.select_ranges([len..len]), - ); - if len > 0 { - editor.insert("\n\n", window, cx); - } - editor.insert(&entry_heading, window, cx); + if let Some(Some(Ok(item))) = opened.first() + && let Some(editor) = item.downcast::().map(|editor| editor.downgrade()) + { + editor.update_in(cx, |editor, window, cx| { + let len = editor.buffer().read(cx).len(cx); + editor.change_selections( + SelectionEffects::scroll(Autoscroll::center()), + window, + cx, + |s| s.select_ranges([len..len]), + ); + if len > 0 { editor.insert("\n\n", window, cx); - })?; - } + } + editor.insert(&entry_heading, window, cx); + editor.insert("\n\n", window, cx); + })?; } anyhow::Ok(()) @@ -195,11 +194,9 @@ pub fn new_journal_entry(workspace: &Workspace, window: &mut Window, cx: &mut Ap } fn journal_dir(path: &str) -> Option { - let expanded_journal_dir = shellexpand::full(path) //TODO handle this better + shellexpand::full(path) //TODO handle this better .ok() - .map(|dir| Path::new(&dir.to_string()).to_path_buf().join("journal")); - - expanded_journal_dir + .map(|dir| 
Path::new(&dir.to_string()).to_path_buf().join("journal")) } fn heading_entry(now: NaiveTime, hour_format: &Option) -> String { diff --git a/crates/keymap_editor/Cargo.toml b/crates/keymap_editor/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..ae3af21239f22a8d01ec9e792a3ab0daed6080bb --- /dev/null +++ b/crates/keymap_editor/Cargo.toml @@ -0,0 +1,53 @@ +[package] +name = "keymap_editor" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/keymap_editor.rs" + +[dependencies] +anyhow.workspace = true +collections.workspace = true +command_palette.workspace = true +component.workspace = true +db.workspace = true +editor.workspace = true +fs.workspace = true +fuzzy.workspace = true +gpui.workspace = true +itertools.workspace = true +language.workspace = true +log.workspace = true +menu.workspace = true +notifications.workspace = true +paths.workspace = true +project.workspace = true +search.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +telemetry.workspace = true +tempfile.workspace = true +theme.workspace = true +tree-sitter-json.workspace = true +tree-sitter-rust.workspace = true +ui.workspace = true +ui_input.workspace = true +util.workspace = true +vim.workspace = true +workspace-hack.workspace = true +workspace.workspace = true +zed_actions.workspace = true + +[dev-dependencies] +db = {"workspace"= true, "features" = ["test-support"]} +fs = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } +workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/semantic_index/LICENSE-GPL b/crates/keymap_editor/LICENSE-GPL similarity index 100% rename from crates/semantic_index/LICENSE-GPL rename to crates/keymap_editor/LICENSE-GPL diff --git 
a/crates/settings_ui/src/keybindings.rs b/crates/keymap_editor/src/keymap_editor.rs similarity index 77% rename from crates/settings_ui/src/keybindings.rs rename to crates/keymap_editor/src/keymap_editor.rs index a62c669488415daa689b755d3d970d6364da6dc0..7aa8a0c284576c131e472ea2fb5daba4d6ed9c23 100644 --- a/crates/settings_ui/src/keybindings.rs +++ b/crates/keymap_editor/src/keymap_editor.rs @@ -5,6 +5,8 @@ use std::{ time::Duration, }; +mod ui_components; + use anyhow::{Context as _, anyhow}; use collections::{HashMap, HashSet}; use editor::{CompletionProvider, Editor, EditorEvent}; @@ -12,18 +14,20 @@ use fs::Fs; use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{ Action, AppContext as _, AsyncApp, Axis, ClickEvent, Context, DismissEvent, Entity, - EventEmitter, FocusHandle, Focusable, Global, IsZero, KeyContext, Keystroke, MouseButton, - Point, ScrollStrategy, ScrollWheelEvent, Stateful, StyledText, Subscription, Task, - TextStyleRefinement, WeakEntity, actions, anchored, deferred, div, + EventEmitter, FocusHandle, Focusable, Global, IsZero, + KeyBindingContextPredicate::{And, Descendant, Equal, Identifier, Not, NotEqual, Or}, + KeyContext, KeybindingKeystroke, MouseButton, PlatformKeyboardMapper, Point, ScrollStrategy, + ScrollWheelEvent, Stateful, StyledText, Subscription, Task, TextStyleRefinement, WeakEntity, + actions, anchored, deferred, div, }; use language::{Language, LanguageConfig, ToOffset as _}; use notifications::status_toast::{StatusToast, ToastIcon}; -use project::Project; +use project::{CompletionDisplayOptions, Project}; use settings::{BaseKeymap, KeybindSource, KeymapFile, Settings as _, SettingsAssets}; use ui::{ ActiveTheme as _, App, Banner, BorrowAppContext, ContextMenu, IconButtonShape, Indicator, Modal, ModalFooter, ModalHeader, ParentElement as _, Render, Section, SharedString, - Styled as _, Tooltip, Window, prelude::*, + Styled as _, Tooltip, Window, prelude::*, right_click_menu, }; use ui_input::SingleLineInput; use 
util::ResultExt; @@ -32,8 +36,11 @@ use workspace::{ register_serializable_item, }; +pub use ui_components::*; +use zed_actions::OpenKeymapEditor; + use crate::{ - keybindings::persistence::KEYBINDING_EDITORS, + persistence::KEYBINDING_EDITORS, ui_components::{ keystroke_input::{ClearKeystrokes, KeystrokeInput, StartRecording, StopRecording}, table::{ColumnWidths, ResizeBehavior, Table, TableInteractionState}, @@ -42,14 +49,6 @@ use crate::{ const NO_ACTION_ARGUMENTS_TEXT: SharedString = SharedString::new_static(""); -actions!( - zed, - [ - /// Opens the keymap editor. - OpenKeymapEditor - ] -); - actions!( keymap_editor, [ @@ -172,7 +171,7 @@ impl FilterState { #[derive(Debug, Default, PartialEq, Eq, Clone, Hash)] struct ActionMapping { - keystrokes: Vec, + keystrokes: Vec, context: Option, } @@ -182,15 +181,6 @@ struct KeybindConflict { remaining_conflict_amount: usize, } -impl KeybindConflict { - fn from_iter<'a>(mut indices: impl Iterator) -> Option { - indices.next().map(|origin| Self { - first_conflict_index: origin.index, - remaining_conflict_amount: indices.count(), - }) - } -} - #[derive(Clone, Copy, PartialEq)] struct ConflictOrigin { override_source: KeybindSource, @@ -238,13 +228,21 @@ impl ConflictOrigin { #[derive(Default)] struct ConflictState { conflicts: Vec>, - keybind_mapping: HashMap>, + keybind_mapping: ConflictKeybindMapping, has_user_conflicts: bool, } +type ConflictKeybindMapping = HashMap< + Vec, + Vec<( + Option, + Vec, + )>, +>; + impl ConflictState { fn new(key_bindings: &[ProcessedBinding]) -> Self { - let mut action_keybind_mapping: HashMap<_, Vec> = HashMap::default(); + let mut action_keybind_mapping = ConflictKeybindMapping::default(); let mut largest_index = 0; for (index, binding) in key_bindings @@ -252,29 +250,48 @@ impl ConflictState { .enumerate() .flat_map(|(index, binding)| Some(index).zip(binding.keybind_information())) { - action_keybind_mapping - .entry(binding.get_action_mapping()) - .or_default() - 
.push(ConflictOrigin::new(binding.source, index)); + let mapping = binding.get_action_mapping(); + let predicate = mapping + .context + .and_then(|ctx| gpui::KeyBindingContextPredicate::parse(&ctx).ok()); + let entry = action_keybind_mapping + .entry(mapping.keystrokes) + .or_default(); + let origin = ConflictOrigin::new(binding.source, index); + if let Some((_, origins)) = + entry + .iter_mut() + .find(|(other_predicate, _)| match (&predicate, other_predicate) { + (None, None) => true, + (Some(a), Some(b)) => normalized_ctx_eq(a, b), + _ => false, + }) + { + origins.push(origin); + } else { + entry.push((predicate, vec![origin])); + } largest_index = index; } let mut conflicts = vec![None; largest_index + 1]; let mut has_user_conflicts = false; - for indices in action_keybind_mapping.values_mut() { - indices.sort_unstable_by_key(|origin| origin.override_source); - let Some((fst, snd)) = indices.get(0).zip(indices.get(1)) else { - continue; - }; + for entries in action_keybind_mapping.values_mut() { + for (_, indices) in entries.iter_mut() { + indices.sort_unstable_by_key(|origin| origin.override_source); + let Some((fst, snd)) = indices.get(0).zip(indices.get(1)) else { + continue; + }; - for origin in indices.iter() { - conflicts[origin.index] = - origin.get_conflict_with(if origin == fst { &snd } else { &fst }) - } + for origin in indices.iter() { + conflicts[origin.index] = + origin.get_conflict_with(if origin == fst { snd } else { fst }) + } - has_user_conflicts |= fst.override_source == KeybindSource::User - && snd.override_source == KeybindSource::User; + has_user_conflicts |= fst.override_source == KeybindSource::User + && snd.override_source == KeybindSource::User; + } } Self { @@ -289,15 +306,34 @@ impl ConflictState { action_mapping: &ActionMapping, keybind_idx: Option, ) -> Option { - self.keybind_mapping - .get(action_mapping) - .and_then(|indices| { - KeybindConflict::from_iter( - indices + let ActionMapping { + keystrokes, + context, + } = 
action_mapping; + let predicate = context + .as_deref() + .and_then(|ctx| gpui::KeyBindingContextPredicate::parse(&ctx).ok()); + self.keybind_mapping.get(keystrokes).and_then(|entries| { + entries + .iter() + .find_map(|(other_predicate, indices)| { + match (&predicate, other_predicate) { + (None, None) => true, + (Some(pred), Some(other)) => normalized_ctx_eq(pred, other), + _ => false, + } + .then_some(indices) + }) + .and_then(|indices| { + let mut indices = indices .iter() - .filter(|&conflict| Some(conflict.index) != keybind_idx), - ) - }) + .filter(|&conflict| Some(conflict.index) != keybind_idx); + indices.next().map(|origin| KeybindConflict { + first_conflict_index: origin.index, + remaining_conflict_amount: indices.count(), + }) + }) + }) } fn conflict_for_idx(&self, idx: usize) -> Option { @@ -375,12 +411,14 @@ impl Focusable for KeymapEditor { } } /// Helper function to check if two keystroke sequences match exactly -fn keystrokes_match_exactly(keystrokes1: &[Keystroke], keystrokes2: &[Keystroke]) -> bool { +fn keystrokes_match_exactly( + keystrokes1: &[KeybindingKeystroke], + keystrokes2: &[KeybindingKeystroke], +) -> bool { keystrokes1.len() == keystrokes2.len() - && keystrokes1 - .iter() - .zip(keystrokes2) - .all(|(k1, k2)| k1.key == k2.key && k1.modifiers == k2.modifiers) + && keystrokes1.iter().zip(keystrokes2).all(|(k1, k2)| { + k1.inner().key == k2.inner().key && k1.inner().modifiers == k2.inner().modifiers + }) } impl KeymapEditor { @@ -397,7 +435,7 @@ impl KeymapEditor { let filter_editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text("Filter action names…", cx); + editor.set_placeholder_text("Filter action names…", window, cx); editor }); @@ -470,15 +508,9 @@ impl KeymapEditor { self.filter_editor.read(cx).text(cx) } - fn current_keystroke_query(&self, cx: &App) -> Vec { + fn current_keystroke_query(&self, cx: &App) -> Vec { match self.search_mode { - SearchMode::KeyStroke { .. 
} => self - .keystroke_editor - .read(cx) - .keystrokes() - .iter() - .cloned() - .collect(), + SearchMode::KeyStroke { .. } => self.keystroke_editor.read(cx).keystrokes().to_vec(), SearchMode::Normal => Default::default(), } } @@ -497,7 +529,7 @@ impl KeymapEditor { let keystroke_query = keystroke_query .into_iter() - .map(|keystroke| keystroke.unparse()) + .map(|keystroke| keystroke.inner().unparse()) .collect::>() .join(" "); @@ -521,7 +553,7 @@ impl KeymapEditor { async fn update_matches( this: WeakEntity, action_query: String, - keystroke_query: Vec, + keystroke_query: Vec, cx: &mut AsyncApp, ) -> anyhow::Result<()> { let action_query = command_palette::normalize_action_query(&action_query); @@ -559,7 +591,7 @@ impl KeymapEditor { if exact_match { keystrokes_match_exactly(&keystroke_query, keystrokes) } else if keystroke_query.len() > keystrokes.len() { - return false; + false } else { for keystroke_offset in 0..keystrokes.len() { let mut found_count = 0; @@ -570,16 +602,15 @@ impl KeymapEditor { { let query = &keystroke_query[query_cursor]; let keystroke = &keystrokes[keystroke_cursor]; - let matches = - query.modifiers.is_subset_of(&keystroke.modifiers) - && ((query.key.is_empty() - || query.key == keystroke.key) - && query - .key_char - .as_ref() - .map_or(true, |q_kc| { - q_kc == &keystroke.key - })); + let matches = query + .inner() + .modifiers + .is_subset_of(&keystroke.inner().modifiers) + && ((query.inner().key.is_empty() + || query.inner().key == keystroke.inner().key) + && query.inner().key_char.as_ref().is_none_or( + |q_kc| q_kc == &keystroke.inner().key, + )); if matches { found_count += 1; query_cursor += 1; @@ -591,7 +622,7 @@ impl KeymapEditor { return true; } } - return false; + false } }) }); @@ -630,8 +661,7 @@ impl KeymapEditor { let key_bindings_ptr = cx.key_bindings(); let lock = key_bindings_ptr.borrow(); let key_bindings = lock.bindings(); - let mut unmapped_action_names = - 
HashSet::from_iter(cx.all_action_names().into_iter().copied()); + let mut unmapped_action_names = HashSet::from_iter(cx.all_action_names().iter().copied()); let action_documentation = cx.action_documentation(); let mut generator = KeymapFile::action_schema_generator(); let actions_with_schemas = HashSet::from_iter( @@ -649,7 +679,7 @@ impl KeymapEditor { .map(KeybindSource::from_meta) .unwrap_or(KeybindSource::Unknown); - let keystroke_text = ui::text_for_keystrokes(key_binding.keystrokes(), cx); + let keystroke_text = ui::text_for_keybinding_keystrokes(key_binding.keystrokes(), cx); let ui_key_binding = ui::KeyBinding::new_from_gpui(key_binding.clone(), cx) .vim_mode(source == KeybindSource::Vim); @@ -673,8 +703,8 @@ impl KeymapEditor { action_name, action_arguments, &actions_with_schemas, - &action_documentation, - &humanized_action_names, + action_documentation, + humanized_action_names, ); let index = processed_bindings.len(); @@ -696,8 +726,8 @@ impl KeymapEditor { action_name, None, &actions_with_schemas, - &action_documentation, - &humanized_action_names, + action_documentation, + humanized_action_names, ); let string_match_candidate = StringMatchCandidate::new(index, &action_information.humanized_name); @@ -1173,8 +1203,11 @@ impl KeymapEditor { .read(cx) .get_scrollbar_offset(Axis::Vertical), )); - cx.spawn(async move |_, _| remove_keybinding(to_remove, &fs, tab_size).await) - .detach_and_notify_err(window, cx); + let keyboard_mapper = cx.keyboard_mapper().clone(); + cx.spawn(async move |_, _| { + remove_keybinding(to_remove, &fs, tab_size, keyboard_mapper.as_ref()).await + }) + .detach_and_notify_err(window, cx); } fn copy_context_to_clipboard( @@ -1192,8 +1225,8 @@ impl KeymapEditor { return; }; - telemetry::event!("Keybinding Context Copied", context = context.clone()); - cx.write_to_clipboard(gpui::ClipboardItem::new_string(context.clone())); + telemetry::event!("Keybinding Context Copied", context = context); + 
cx.write_to_clipboard(gpui::ClipboardItem::new_string(context)); } fn copy_action_to_clipboard( @@ -1209,8 +1242,8 @@ impl KeymapEditor { return; }; - telemetry::event!("Keybinding Action Copied", action = action.clone()); - cx.write_to_clipboard(gpui::ClipboardItem::new_string(action.clone())); + telemetry::event!("Keybinding Action Copied", action = action); + cx.write_to_clipboard(gpui::ClipboardItem::new_string(action)); } fn toggle_conflict_filter( @@ -1298,7 +1331,7 @@ struct HumanizedActionNameCache { impl HumanizedActionNameCache { fn new(cx: &App) -> Self { - let cache = HashMap::from_iter(cx.all_action_names().into_iter().map(|&action_name| { + let cache = HashMap::from_iter(cx.all_action_names().iter().map(|&action_name| { ( action_name, command_palette::humanize_action_name(action_name).into(), @@ -1393,7 +1426,7 @@ impl ProcessedBinding { .map(|keybind| keybind.get_action_mapping()) } - fn keystrokes(&self) -> Option<&[Keystroke]> { + fn keystrokes(&self) -> Option<&[KeybindingKeystroke]> { self.ui_key_binding() .map(|binding| binding.keystrokes.as_slice()) } @@ -1474,7 +1507,7 @@ impl RenderOnce for KeybindContextString { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { match self { KeybindContextString::Global => { - muted_styled_text(KeybindContextString::GLOBAL.clone(), cx).into_any_element() + muted_styled_text(KeybindContextString::GLOBAL, cx).into_any_element() } KeybindContextString::Local(name, language) => { SyntaxHighlightedText::new(name, language).into_any_element() @@ -1550,73 +1583,139 @@ impl Render for KeymapEditor { .py_1() .border_1() .border_color(theme.colors().border) - .rounded_lg() + .rounded_md() .child(self.filter_editor.clone()), ) .child( - IconButton::new( - "KeymapEditorToggleFiltersIcon", - IconName::Keyboard, - ) - .shape(ui::IconButtonShape::Square) - .tooltip({ - let focus_handle = focus_handle.clone(); - - move |window, cx| { - Tooltip::for_action_in( - "Search by Keystroke", - 
&ToggleKeystrokeSearch, - &focus_handle.clone(), - window, - cx, + h_flex() + .gap_1() + .min_w_64() + .child( + IconButton::new( + "KeymapEditorToggleFiltersIcon", + IconName::Keyboard, ) - } - }) - .toggle_state(matches!( - self.search_mode, - SearchMode::KeyStroke { .. } - )) - .on_click(|_, window, cx| { - window.dispatch_action(ToggleKeystrokeSearch.boxed_clone(), cx); - }), - ) - .child( - IconButton::new("KeymapEditorConflictIcon", IconName::Warning) - .shape(ui::IconButtonShape::Square) - .when( - self.keybinding_conflict_state.any_user_binding_conflicts(), - |this| { - this.indicator(Indicator::dot().color(Color::Warning)) - }, + .icon_size(IconSize::Small) + .tooltip({ + let focus_handle = focus_handle.clone(); + + move |window, cx| { + Tooltip::for_action_in( + "Search by Keystroke", + &ToggleKeystrokeSearch, + &focus_handle.clone(), + window, + cx, + ) + } + }) + .toggle_state(matches!( + self.search_mode, + SearchMode::KeyStroke { .. } + )) + .on_click(|_, window, cx| { + window.dispatch_action( + ToggleKeystrokeSearch.boxed_clone(), + cx, + ); + }), ) - .tooltip({ - let filter_state = self.filter_state; - let focus_handle = focus_handle.clone(); - - move |window, cx| { - Tooltip::for_action_in( - match filter_state { - FilterState::All => "Show Conflicts", - FilterState::Conflicts => "Hide Conflicts", + .child( + IconButton::new("KeymapEditorConflictIcon", IconName::Warning) + .icon_size(IconSize::Small) + .when( + self.keybinding_conflict_state + .any_user_binding_conflicts(), + |this| { + this.indicator( + Indicator::dot().color(Color::Warning), + ) }, - &ToggleConflictFilter, - &focus_handle.clone(), - window, - cx, ) - } - }) - .selected_icon_color(Color::Warning) - .toggle_state(matches!( - self.filter_state, - FilterState::Conflicts - )) - .on_click(|_, window, cx| { - window.dispatch_action( - ToggleConflictFilter.boxed_clone(), - cx, - ); - }), + .tooltip({ + let filter_state = self.filter_state; + let focus_handle = focus_handle.clone(); + + 
move |window, cx| { + Tooltip::for_action_in( + match filter_state { + FilterState::All => "Show Conflicts", + FilterState::Conflicts => { + "Hide Conflicts" + } + }, + &ToggleConflictFilter, + &focus_handle.clone(), + window, + cx, + ) + } + }) + .selected_icon_color(Color::Warning) + .toggle_state(matches!( + self.filter_state, + FilterState::Conflicts + )) + .on_click(|_, window, cx| { + window.dispatch_action( + ToggleConflictFilter.boxed_clone(), + cx, + ); + }), + ) + .child( + div() + .ml_1() + .pl_2() + .border_l_1() + .border_color(cx.theme().colors().border_variant) + .child( + right_click_menu("open-keymap-menu") + .menu(|window, cx| { + ContextMenu::build(window, cx, |menu, _, _| { + menu.header("Open Keymap JSON") + .action( + "User", + zed_actions::OpenKeymap.boxed_clone(), + ) + .action( + "Zed Default", + zed_actions::OpenDefaultKeymap + .boxed_clone(), + ) + .action( + "Vim Default", + vim::OpenDefaultKeymap.boxed_clone(), + ) + }) + }) + .anchor(gpui::Corner::TopLeft) + .trigger(|open, _, _| { + IconButton::new( + "OpenKeymapJsonButton", + IconName::Json, + ) + .icon_size(IconSize::Small) + .when(!open, |this| { + this.tooltip(move |window, cx| { + Tooltip::with_meta( + "Open keymap.json", + Some(&zed_actions::OpenKeymap), + "Right click to view more options", + window, + cx, + ) + }) + }) + .on_click(|_, window, cx| { + window.dispatch_action( + zed_actions::OpenKeymap.boxed_clone(), + cx, + ); + }) + }), + ), + ) ), ) .when_some( @@ -1627,48 +1726,42 @@ impl Render for KeymapEditor { |this, exact_match| { this.child( h_flex() - .map(|this| { - if self - .keybinding_conflict_state - .any_user_binding_conflicts() - { - this.pr(rems_from_px(54.)) - } else { - this.pr_7() - } - }) .gap_2() .child(self.keystroke_editor.clone()) .child( - IconButton::new( - "keystrokes-exact-match", - IconName::CaseSensitive, - ) - .tooltip({ - let keystroke_focus_handle = - self.keystroke_editor.read(cx).focus_handle(cx); - - move |window, cx| { - 
Tooltip::for_action_in( - "Toggle Exact Match Mode", - &ToggleExactKeystrokeMatching, - &keystroke_focus_handle, - window, - cx, + h_flex() + .min_w_64() + .child( + IconButton::new( + "keystrokes-exact-match", + IconName::CaseSensitive, ) - } - }) - .shape(IconButtonShape::Square) - .toggle_state(exact_match) - .on_click( - cx.listener(|_, _, window, cx| { - window.dispatch_action( - ToggleExactKeystrokeMatching.boxed_clone(), - cx, - ); - }), - ), - ), + .tooltip({ + let keystroke_focus_handle = + self.keystroke_editor.read(cx).focus_handle(cx); + + move |window, cx| { + Tooltip::for_action_in( + "Toggle Exact Match Mode", + &ToggleExactKeystrokeMatching, + &keystroke_focus_handle, + window, + cx, + ) + } + }) + .shape(IconButtonShape::Square) + .toggle_state(exact_match) + .on_click( + cx.listener(|_, _, window, cx| { + window.dispatch_action( + ToggleExactKeystrokeMatching.boxed_clone(), + cx, + ); + }), + ), + ), + ) ) }, ), @@ -1731,7 +1824,7 @@ impl Render for KeymapEditor { } else { const NULL: SharedString = SharedString::new_static(""); - muted_styled_text(NULL.clone(), cx) + muted_styled_text(NULL, cx) .into_any_element() } }) @@ -1839,18 +1932,15 @@ impl Render for KeymapEditor { mouse_down_event: &gpui::MouseDownEvent, window, cx| { - match mouse_down_event.button { - MouseButton::Right => { - this.select_index( - row_index, None, window, cx, - ); - this.create_context_menu( - mouse_down_event.position, - window, - cx, - ); - } - _ => {} + if mouse_down_event.button == MouseButton::Right { + this.select_index( + row_index, None, window, cx, + ); + this.create_context_menu( + mouse_down_event.position, + window, + cx, + ); } }, )) @@ -1994,21 +2084,21 @@ impl RenderOnce for SyntaxHighlightedText { #[derive(PartialEq)] struct InputError { - severity: ui::Severity, + severity: Severity, content: SharedString, } impl InputError { fn warning(message: impl Into) -> Self { Self { - severity: ui::Severity::Warning, + severity: Severity::Warning, content: 
message.into(), } } fn error(message: anyhow::Error) -> Self { Self { - severity: ui::Severity::Error, + severity: Severity::Error, content: message.to_string().into(), } } @@ -2135,9 +2225,11 @@ impl KeybindingEditorModal { } fn set_error(&mut self, error: InputError, cx: &mut Context) -> bool { - if self.error.as_ref().is_some_and(|old_error| { - old_error.severity == ui::Severity::Warning && *old_error == error - }) { + if self + .error + .as_ref() + .is_some_and(|old_error| old_error.severity == Severity::Warning && *old_error == error) + { false } else { self.error = Some(error); @@ -2150,7 +2242,8 @@ impl KeybindingEditorModal { let action_arguments = self .action_arguments_editor .as_ref() - .map(|editor| editor.read(cx).editor.read(cx).text(cx)); + .map(|arguments_editor| arguments_editor.read(cx).editor.read(cx).text(cx)) + .filter(|args| !args.is_empty()); let value = action_arguments .as_ref() @@ -2159,12 +2252,12 @@ impl KeybindingEditorModal { }) .transpose()?; - cx.build_action(&self.editing_keybind.action().name, value) + cx.build_action(self.editing_keybind.action().name, value) .context("Failed to validate action arguments")?; Ok(action_arguments) } - fn validate_keystrokes(&self, cx: &App) -> anyhow::Result> { + fn validate_keystrokes(&self, cx: &App) -> anyhow::Result> { let new_keystrokes = self .keybind_editor .read_with(cx, |editor, _| editor.keystrokes().to_vec()); @@ -2193,12 +2286,10 @@ impl KeybindingEditorModal { let fs = self.fs.clone(); let tab_size = cx.global::().json_tab_size(); - let new_keystrokes = self - .validate_keystrokes(cx) - .map_err(InputError::error)? 
- .into_iter() - .map(remove_key_char) - .collect::>(); + let mut new_keystrokes = self.validate_keystrokes(cx).map_err(InputError::error)?; + new_keystrokes + .iter_mut() + .for_each(|ks| ks.remove_key_char()); let new_context = self.validate_context(cx).map_err(InputError::error)?; let new_action_args = self @@ -2260,58 +2351,60 @@ impl KeybindingEditorModal { }).unwrap_or(Ok(()))?; let create = self.creating; - - let status_toast = StatusToast::new( - format!( - "Saved edits to the {} action.", - &self.editing_keybind.action().humanized_name - ), - cx, - move |this, _cx| { - this.icon(ToastIcon::new(IconName::Check).color(Color::Success)) - .dismiss_button(true) - // .action("Undo", f) todo: wire the undo functionality - }, - ); - - self.workspace - .update(cx, |workspace, cx| { - workspace.toggle_status_toast(status_toast, cx); - }) - .log_err(); + let keyboard_mapper = cx.keyboard_mapper().clone(); cx.spawn(async move |this, cx| { let action_name = existing_keybind.action().name; + let humanized_action_name = existing_keybind.action().humanized_name.clone(); - if let Err(err) = save_keybinding_update( + match save_keybinding_update( create, existing_keybind, &action_mapping, new_action_args.as_deref(), &fs, tab_size, + keyboard_mapper.as_ref(), ) .await { - this.update(cx, |this, cx| { - this.set_error(InputError::error(err), cx); - }) - .log_err(); - } else { - this.update(cx, |this, cx| { - this.keymap_editor.update(cx, |keymap, cx| { - keymap.previous_edit = Some(PreviousEdit::Keybinding { - action_mapping, - action_name, - fallback: keymap - .table_interaction_state - .read(cx) - .get_scrollbar_offset(Axis::Vertical), - }) - }); - cx.emit(DismissEvent); - }) - .ok(); + Ok(_) => { + this.update(cx, |this, cx| { + this.keymap_editor.update(cx, |keymap, cx| { + keymap.previous_edit = Some(PreviousEdit::Keybinding { + action_mapping, + action_name, + fallback: keymap + .table_interaction_state + .read(cx) + .get_scrollbar_offset(Axis::Vertical), + }); + let 
status_toast = StatusToast::new( + format!("Saved edits to the {} action.", humanized_action_name), + cx, + move |this, _cx| { + this.icon(ToastIcon::new(IconName::Check).color(Color::Success)) + .dismiss_button(true) + // .action("Undo", f) todo: wire the undo functionality + }, + ); + + this.workspace + .update(cx, |workspace, cx| { + workspace.toggle_status_toast(status_toast, cx); + }) + .log_err(); + }); + cx.emit(DismissEvent); + }) + .ok(); + } + Err(err) => { + this.update(cx, |this, cx| { + this.set_error(InputError::error(err), cx); + }) + .log_err(); + } } }) .detach(); @@ -2389,14 +2482,6 @@ impl KeybindingEditorModal { } } -fn remove_key_char(Keystroke { modifiers, key, .. }: Keystroke) -> Keystroke { - Keystroke { - modifiers, - key, - ..Default::default() - } -} - impl Render for KeybindingEditorModal { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { let theme = cx.theme().colors(); @@ -2691,7 +2776,7 @@ impl ActionArgumentsEditor { }) .ok(); } - return result; + result }) .detach_and_log_err(cx); Self { @@ -2712,7 +2797,7 @@ impl ActionArgumentsEditor { editor.set_text(arguments, window, cx); } else { // TODO: default value from schema? 
- editor.set_placeholder_text("Action Arguments", cx); + editor.set_placeholder_text("Action Arguments", window, cx); } } @@ -2794,7 +2879,7 @@ impl Render for ActionArgumentsEditor { self.editor .update(cx, |editor, _| editor.set_text_style_refinement(text_style)); - return v_flex().w_full().child( + v_flex().w_full().child( h_flex() .min_h_8() .min_w_48() @@ -2807,7 +2892,7 @@ impl Render for ActionArgumentsEditor { .border_color(border_color) .track_focus(&self.focus_handle) .child(self.editor.clone()), - ); + ) } } @@ -2834,11 +2919,8 @@ impl CompletionProvider for KeyContextCompletionProvider { break; } } - let start_anchor = buffer.anchor_before( - buffer_position - .to_offset(&buffer) - .saturating_sub(count_back), - ); + let start_anchor = + buffer.anchor_before(buffer_position.to_offset(buffer).saturating_sub(count_back)); let replace_range = start_anchor..buffer_position; gpui::Task::ready(Ok(vec![project::CompletionResponse { completions: self @@ -2855,6 +2937,7 @@ impl CompletionProvider for KeyContextCompletionProvider { confirm: None, }) .collect(), + display_options: CompletionDisplayOptions::default(), is_incomplete: false, }])) } @@ -2868,9 +2951,9 @@ impl CompletionProvider for KeyContextCompletionProvider { _menu_is_open: bool, _cx: &mut Context, ) -> bool { - text.chars().last().map_or(false, |last_char| { - last_char.is_ascii_alphanumeric() || last_char == '_' - }) + text.chars() + .last() + .is_some_and(|last_char| last_char.is_ascii_alphanumeric() || last_char == '_') } } @@ -2889,7 +2972,7 @@ async fn load_json_language(workspace: WeakEntity, cx: &mut AsyncApp) Some(task) => task.await.context("Failed to load JSON language").log_err(), None => None, }; - return json_language.unwrap_or_else(|| { + json_language.unwrap_or_else(|| { Arc::new(Language::new( LanguageConfig { name: "JSON".into(), @@ -2897,7 +2980,7 @@ async fn load_json_language(workspace: WeakEntity, cx: &mut AsyncApp) }, Some(tree_sitter_json::LANGUAGE.into()), )) - }); + }) } 
async fn load_keybind_context_language( @@ -2921,7 +3004,7 @@ async fn load_keybind_context_language( .log_err(), None => None, }; - return language.unwrap_or_else(|| { + language.unwrap_or_else(|| { Arc::new(Language::new( LanguageConfig { name: "Zed Keybind Context".into(), @@ -2929,7 +3012,7 @@ async fn load_keybind_context_language( }, Some(tree_sitter_rust::LANGUAGE.into()), )) - }); + }) } async fn save_keybinding_update( @@ -2939,6 +3022,7 @@ async fn save_keybinding_update( new_args: Option<&str>, fs: &Arc, tab_size: usize, + keyboard_mapper: &dyn PlatformKeyboardMapper, ) -> anyhow::Result<()> { let keymap_contents = settings::KeymapFile::load_keymap_file(fs) .await @@ -2955,14 +3039,14 @@ async fn save_keybinding_update( let target = settings::KeybindUpdateTarget { context: existing_context, keystrokes: existing_keystrokes, - action_name: &existing.action().name, + action_name: existing.action().name, action_arguments: existing_args, }; let source = settings::KeybindUpdateTarget { context: action_mapping.context.as_ref().map(|a| &***a), keystrokes: &action_mapping.keystrokes, - action_name: &existing.action().name, + action_name: existing.action().name, action_arguments: new_args, }; @@ -2981,9 +3065,13 @@ async fn save_keybinding_update( let (new_keybinding, removed_keybinding, source) = operation.generate_telemetry(); - let updated_keymap_contents = - settings::KeymapFile::update_keybinding(operation, keymap_contents, tab_size) - .context("Failed to update keybinding")?; + let updated_keymap_contents = settings::KeymapFile::update_keybinding( + operation, + keymap_contents, + tab_size, + keyboard_mapper, + ) + .map_err(|err| anyhow::anyhow!("Could not save updated keybinding: {}", err))?; fs.write( paths::keymap_file().as_path(), updated_keymap_contents.as_bytes(), @@ -3004,6 +3092,7 @@ async fn remove_keybinding( existing: ProcessedBinding, fs: &Arc, tab_size: usize, + keyboard_mapper: &dyn PlatformKeyboardMapper, ) -> anyhow::Result<()> { let 
Some(keystrokes) = existing.keystrokes() else { anyhow::bail!("Cannot remove a keybinding that does not exist"); @@ -3016,7 +3105,7 @@ async fn remove_keybinding( target: settings::KeybindUpdateTarget { context: existing.context().and_then(KeybindContextString::local_str), keystrokes, - action_name: &existing.action().name, + action_name: existing.action().name, action_arguments: existing .action() .arguments @@ -3027,9 +3116,13 @@ async fn remove_keybinding( }; let (new_keybinding, removed_keybinding, source) = operation.generate_telemetry(); - let updated_keymap_contents = - settings::KeymapFile::update_keybinding(operation, keymap_contents, tab_size) - .context("Failed to update keybinding")?; + let updated_keymap_contents = settings::KeymapFile::update_keybinding( + operation, + keymap_contents, + tab_size, + keyboard_mapper, + ) + .context("Failed to update keybinding")?; fs.write( paths::keymap_file().as_path(), updated_keymap_contents.as_bytes(), @@ -3075,29 +3168,29 @@ fn collect_contexts_from_assets() -> Vec { queue.push(root_context); while let Some(context) = queue.pop() { match context { - gpui::KeyBindingContextPredicate::Identifier(ident) => { + Identifier(ident) => { contexts.insert(ident); } - gpui::KeyBindingContextPredicate::Equal(ident_a, ident_b) => { + Equal(ident_a, ident_b) => { contexts.insert(ident_a); contexts.insert(ident_b); } - gpui::KeyBindingContextPredicate::NotEqual(ident_a, ident_b) => { + NotEqual(ident_a, ident_b) => { contexts.insert(ident_a); contexts.insert(ident_b); } - gpui::KeyBindingContextPredicate::Descendant(ctx_a, ctx_b) => { + Descendant(ctx_a, ctx_b) => { queue.push(*ctx_a); queue.push(*ctx_b); } - gpui::KeyBindingContextPredicate::Not(ctx) => { + Not(ctx) => { queue.push(*ctx); } - gpui::KeyBindingContextPredicate::And(ctx_a, ctx_b) => { + And(ctx_a, ctx_b) => { queue.push(*ctx_a); queue.push(*ctx_b); } - gpui::KeyBindingContextPredicate::Or(ctx_a, ctx_b) => { + Or(ctx_a, ctx_b) => { queue.push(*ctx_a); 
queue.push(*ctx_b); } @@ -3109,7 +3202,128 @@ fn collect_contexts_from_assets() -> Vec { let mut contexts = contexts.into_iter().collect::>(); contexts.sort(); - return contexts; + contexts +} + +fn normalized_ctx_eq( + a: &gpui::KeyBindingContextPredicate, + b: &gpui::KeyBindingContextPredicate, +) -> bool { + use gpui::KeyBindingContextPredicate::*; + return match (a, b) { + (Identifier(_), Identifier(_)) => a == b, + (Equal(a_left, a_right), Equal(b_left, b_right)) => { + (a_left == b_left && a_right == b_right) || (a_left == b_right && a_right == b_left) + } + (NotEqual(a_left, a_right), NotEqual(b_left, b_right)) => { + (a_left == b_left && a_right == b_right) || (a_left == b_right && a_right == b_left) + } + (Descendant(a_parent, a_child), Descendant(b_parent, b_child)) => { + normalized_ctx_eq(a_parent, b_parent) && normalized_ctx_eq(a_child, b_child) + } + (Not(a_expr), Not(b_expr)) => normalized_ctx_eq(a_expr, b_expr), + // Handle double negation: !(!a) == a + (Not(a_expr), b) if matches!(a_expr.as_ref(), Not(_)) => { + let Not(a_inner) = a_expr.as_ref() else { + unreachable!(); + }; + normalized_ctx_eq(b, a_inner) + } + (a, Not(b_expr)) if matches!(b_expr.as_ref(), Not(_)) => { + let Not(b_inner) = b_expr.as_ref() else { + unreachable!(); + }; + normalized_ctx_eq(a, b_inner) + } + (And(a_left, a_right), And(b_left, b_right)) + if matches!(a_left.as_ref(), And(_, _)) + || matches!(a_right.as_ref(), And(_, _)) + || matches!(b_left.as_ref(), And(_, _)) + || matches!(b_right.as_ref(), And(_, _)) => + { + let mut a_operands = Vec::new(); + flatten_and(a, &mut a_operands); + let mut b_operands = Vec::new(); + flatten_and(b, &mut b_operands); + compare_operand_sets(&a_operands, &b_operands) + } + (And(a_left, a_right), And(b_left, b_right)) => { + (normalized_ctx_eq(a_left, b_left) && normalized_ctx_eq(a_right, b_right)) + || (normalized_ctx_eq(a_left, b_right) && normalized_ctx_eq(a_right, b_left)) + } + (Or(a_left, a_right), Or(b_left, b_right)) + if 
matches!(a_left.as_ref(), Or(_, _)) + || matches!(a_right.as_ref(), Or(_, _)) + || matches!(b_left.as_ref(), Or(_, _)) + || matches!(b_right.as_ref(), Or(_, _)) => + { + let mut a_operands = Vec::new(); + flatten_or(a, &mut a_operands); + let mut b_operands = Vec::new(); + flatten_or(b, &mut b_operands); + compare_operand_sets(&a_operands, &b_operands) + } + (Or(a_left, a_right), Or(b_left, b_right)) => { + (normalized_ctx_eq(a_left, b_left) && normalized_ctx_eq(a_right, b_right)) + || (normalized_ctx_eq(a_left, b_right) && normalized_ctx_eq(a_right, b_left)) + } + _ => false, + }; + + fn flatten_and<'a>( + pred: &'a gpui::KeyBindingContextPredicate, + operands: &mut Vec<&'a gpui::KeyBindingContextPredicate>, + ) { + use gpui::KeyBindingContextPredicate::*; + match pred { + And(left, right) => { + flatten_and(left, operands); + flatten_and(right, operands); + } + _ => operands.push(pred), + } + } + + fn flatten_or<'a>( + pred: &'a gpui::KeyBindingContextPredicate, + operands: &mut Vec<&'a gpui::KeyBindingContextPredicate>, + ) { + use gpui::KeyBindingContextPredicate::*; + match pred { + Or(left, right) => { + flatten_or(left, operands); + flatten_or(right, operands); + } + _ => operands.push(pred), + } + } + + fn compare_operand_sets( + a: &[&gpui::KeyBindingContextPredicate], + b: &[&gpui::KeyBindingContextPredicate], + ) -> bool { + if a.len() != b.len() { + return false; + } + + // For each operand in a, find a matching operand in b + let mut b_matched = vec![false; b.len()]; + for a_operand in a { + let mut found = false; + for (b_idx, b_operand) in b.iter().enumerate() { + if !b_matched[b_idx] && normalized_ctx_eq(a_operand, b_operand) { + b_matched[b_idx] = true; + found = true; + break; + } + } + if !found { + return false; + } + } + + true + } } impl SerializableItem for KeymapEditor { @@ -3174,12 +3388,15 @@ impl SerializableItem for KeymapEditor { } mod persistence { - use db::{define_connection, query, sqlez_macros::sql}; + use db::{query, 
sqlez::domain::Domain, sqlez_macros::sql}; use workspace::WorkspaceDb; - define_connection! { - pub static ref KEYBINDING_EDITORS: KeybindingEditorDb = - &[sql!( + pub struct KeybindingEditorDb(db::sqlez::thread_safe_connection::ThreadSafeConnection); + + impl Domain for KeybindingEditorDb { + const NAME: &str = stringify!(KeybindingEditorDb); + + const MIGRATIONS: &[&str] = &[sql!( CREATE TABLE keybinding_editors ( workspace_id INTEGER, item_id INTEGER UNIQUE, @@ -3188,9 +3405,11 @@ mod persistence { FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) ON DELETE CASCADE ) STRICT; - )]; + )]; } + db::static_connection!(KEYBINDING_EDITORS, KeybindingEditorDb, [WorkspaceDb]); + impl KeybindingEditorDb { query! { pub async fn save_keybinding_editor( @@ -3214,3 +3433,152 @@ mod persistence { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn normalized_ctx_cmp() { + #[track_caller] + fn cmp(a: &str, b: &str) -> bool { + let a = gpui::KeyBindingContextPredicate::parse(a) + .expect("Failed to parse keybinding context a"); + let b = gpui::KeyBindingContextPredicate::parse(b) + .expect("Failed to parse keybinding context b"); + normalized_ctx_eq(&a, &b) + } + + // Basic equality - identical expressions + assert!(cmp("a && b", "a && b")); + assert!(cmp("a || b", "a || b")); + assert!(cmp("a == b", "a == b")); + assert!(cmp("a != b", "a != b")); + assert!(cmp("a > b", "a > b")); + assert!(cmp("!a", "!a")); + + // AND operator - associative/commutative + assert!(cmp("a && b", "b && a")); + assert!(cmp("a && b && c", "c && b && a")); + assert!(cmp("a && b && c", "b && a && c")); + assert!(cmp("a && b && c && d", "d && c && b && a")); + + // OR operator - associative/commutative + assert!(cmp("a || b", "b || a")); + assert!(cmp("a || b || c", "c || b || a")); + assert!(cmp("a || b || c", "b || a || c")); + assert!(cmp("a || b || c || d", "d || c || b || a")); + + // Equality operator - associative/commutative + assert!(cmp("a == b", "b == a")); + 
assert!(cmp("x == y", "y == x")); + + // Inequality operator - associative/commutative + assert!(cmp("a != b", "b != a")); + assert!(cmp("x != y", "y != x")); + + // Complex nested expressions with associative operators + assert!(cmp("(a && b) || c", "c || (a && b)")); + assert!(cmp("(a && b) || c", "c || (b && a)")); + assert!(cmp("(a || b) && c", "c && (a || b)")); + assert!(cmp("(a || b) && c", "c && (b || a)")); + assert!(cmp("(a && b) || (c && d)", "(c && d) || (a && b)")); + assert!(cmp("(a && b) || (c && d)", "(d && c) || (b && a)")); + + // Multiple levels of nesting + assert!(cmp("((a && b) || c) && d", "d && ((a && b) || c)")); + assert!(cmp("((a && b) || c) && d", "d && (c || (b && a))")); + assert!(cmp("a && (b || (c && d))", "(b || (c && d)) && a")); + assert!(cmp("a && (b || (c && d))", "(b || (d && c)) && a")); + + // Negation with associative operators + assert!(cmp("!a && b", "b && !a")); + assert!(cmp("!a || b", "b || !a")); + assert!(cmp("!(a && b) || c", "c || !(a && b)")); + assert!(cmp("!(a && b) || c", "c || !(b && a)")); + + // Descendant operator (>) - NOT associative/commutative + assert!(cmp("a > b", "a > b")); + assert!(!cmp("a > b", "b > a")); + assert!(!cmp("a > b > c", "c > b > a")); + assert!(!cmp("a > b > c", "a > c > b")); + + // Mixed operators with descendant + assert!(cmp("(a > b) && c", "c && (a > b)")); + assert!(!cmp("(a > b) && c", "c && (b > a)")); + assert!(cmp("(a > b) || (c > d)", "(c > d) || (a > b)")); + assert!(!cmp("(a > b) || (c > d)", "(b > a) || (d > c)")); + + // Negative cases - different operators + assert!(!cmp("a && b", "a || b")); + assert!(!cmp("a == b", "a != b")); + assert!(!cmp("a && b", "a > b")); + assert!(!cmp("a || b", "a > b")); + assert!(!cmp("a == b", "a && b")); + assert!(!cmp("a != b", "a || b")); + + // Negative cases - different operands + assert!(!cmp("a && b", "a && c")); + assert!(!cmp("a && b", "c && d")); + assert!(!cmp("a || b", "a || c")); + assert!(!cmp("a || b", "c || d")); + 
assert!(!cmp("a == b", "a == c")); + assert!(!cmp("a != b", "a != c")); + assert!(!cmp("a > b", "a > c")); + assert!(!cmp("a > b", "c > b")); + + // Negative cases - with negation + assert!(!cmp("!a", "a")); + assert!(!cmp("!a && b", "a && b")); + assert!(!cmp("!(a && b)", "a && b")); + assert!(!cmp("!a || b", "a || b")); + assert!(!cmp("!(a || b)", "a || b")); + + // Negative cases - complex expressions + assert!(!cmp("(a && b) || c", "(a || b) && c")); + assert!(!cmp("a && (b || c)", "a || (b && c)")); + assert!(!cmp("(a && b) || (c && d)", "(a || b) && (c || d)")); + assert!(!cmp("a > b && c", "a && b > c")); + + // Edge cases - multiple same operands + assert!(cmp("a && a", "a && a")); + assert!(cmp("a || a", "a || a")); + assert!(cmp("a && a && b", "b && a && a")); + assert!(cmp("a || a || b", "b || a || a")); + + // Edge cases - deeply nested + assert!(cmp( + "((a && b) || (c && d)) && ((e || f) && g)", + "((e || f) && g) && ((c && d) || (a && b))" + )); + assert!(cmp( + "((a && b) || (c && d)) && ((e || f) && g)", + "(g && (f || e)) && ((d && c) || (b && a))" + )); + + // Edge cases - repeated patterns + assert!(cmp("(a && b) || (a && b)", "(b && a) || (b && a)")); + assert!(cmp("(a || b) && (a || b)", "(b || a) && (b || a)")); + + // Negative cases - subtle differences + assert!(!cmp("a && b && c", "a && b")); + assert!(!cmp("a || b || c", "a || b")); + assert!(!cmp("(a && b) || c", "a && (b || c)")); + + // a > b > c is not the same as a > c, should not be equal + assert!(!cmp("a > b > c", "a > c")); + + // Double negation with complex expressions + assert!(cmp("!(!(a && b))", "a && b")); + assert!(cmp("!(!(a || b))", "a || b")); + assert!(cmp("!(!(a > b))", "a > b")); + assert!(cmp("!(!a) && b", "a && b")); + assert!(cmp("!(!a) || b", "a || b")); + assert!(cmp("!(!(a && b)) || c", "(a && b) || c")); + assert!(cmp("!(!(a && b)) || c", "(b && a) || c")); + assert!(cmp("!(!a)", "a")); + assert!(cmp("a", "!(!a)")); + assert!(cmp("!(!(!a))", "!a")); + 
assert!(cmp("!(!(!(!a)))", "a")); + } +} diff --git a/crates/settings_ui/src/ui_components/keystroke_input.rs b/crates/keymap_editor/src/ui_components/keystroke_input.rs similarity index 85% rename from crates/settings_ui/src/ui_components/keystroke_input.rs rename to crates/keymap_editor/src/ui_components/keystroke_input.rs index f23d80931c4e3e509731836bd1230df1f64e4423..e264df3b62bc3c5c78acc38ed906e81837dfbf94 100644 --- a/crates/settings_ui/src/ui_components/keystroke_input.rs +++ b/crates/keymap_editor/src/ui_components/keystroke_input.rs @@ -1,6 +1,6 @@ use gpui::{ Animation, AnimationExt, Context, EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, - Keystroke, Modifiers, ModifiersChangedEvent, Subscription, Task, actions, + KeybindingKeystroke, Keystroke, Modifiers, ModifiersChangedEvent, Subscription, Task, actions, }; use ui::{ ActiveTheme as _, Color, IconButton, IconButtonShape, IconName, IconSize, Label, LabelSize, @@ -19,7 +19,7 @@ actions!( ] ); -const KEY_CONTEXT_VALUE: &'static str = "KeystrokeInput"; +const KEY_CONTEXT_VALUE: &str = "KeystrokeInput"; const CLOSE_KEYSTROKE_CAPTURE_END_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(300); @@ -42,8 +42,8 @@ impl PartialEq for CloseKeystrokeResult { } pub struct KeystrokeInput { - keystrokes: Vec, - placeholder_keystrokes: Option>, + keystrokes: Vec, + placeholder_keystrokes: Option>, outer_focus_handle: FocusHandle, inner_focus_handle: FocusHandle, intercept_subscription: Option, @@ -70,7 +70,7 @@ impl KeystrokeInput { const KEYSTROKE_COUNT_MAX: usize = 3; pub fn new( - placeholder_keystrokes: Option>, + placeholder_keystrokes: Option>, window: &mut Window, cx: &mut Context, ) -> Self { @@ -97,7 +97,7 @@ impl KeystrokeInput { } } - pub fn set_keystrokes(&mut self, keystrokes: Vec, cx: &mut Context) { + pub fn set_keystrokes(&mut self, keystrokes: Vec, cx: &mut Context) { self.keystrokes = keystrokes; self.keystrokes_changed(cx); } @@ -106,7 +106,7 @@ impl KeystrokeInput { 
self.search = search; } - pub fn keystrokes(&self) -> &[Keystroke] { + pub fn keystrokes(&self) -> &[KeybindingKeystroke] { if let Some(placeholders) = self.placeholder_keystrokes.as_ref() && self.keystrokes.is_empty() { @@ -116,19 +116,19 @@ impl KeystrokeInput { && self .keystrokes .last() - .map_or(false, |last| last.key.is_empty()) + .is_some_and(|last| last.key().is_empty()) { return &self.keystrokes[..self.keystrokes.len() - 1]; } - return &self.keystrokes; + &self.keystrokes } - fn dummy(modifiers: Modifiers) -> Keystroke { - return Keystroke { + fn dummy(modifiers: Modifiers) -> KeybindingKeystroke { + KeybindingKeystroke::from_keystroke(Keystroke { modifiers, key: "".to_string(), key_char: None, - }; + }) } fn keystrokes_changed(&self, cx: &mut Context) { @@ -182,7 +182,7 @@ impl KeystrokeInput { fn end_close_keystrokes_capture(&mut self) -> Option { self.close_keystrokes.take(); self.clear_close_keystrokes_timer.take(); - return self.close_keystrokes_start.take(); + self.close_keystrokes_start.take() } fn handle_possible_close_keystroke( @@ -233,7 +233,7 @@ impl KeystrokeInput { return CloseKeystrokeResult::Partial; } self.end_close_keystrokes_capture(); - return CloseKeystrokeResult::None; + CloseKeystrokeResult::None } fn on_modifiers_changed( @@ -254,7 +254,7 @@ impl KeystrokeInput { self.keystrokes_changed(cx); if let Some(last) = self.keystrokes.last_mut() - && last.key.is_empty() + && last.key().is_empty() && keystrokes_len <= Self::KEYSTROKE_COUNT_MAX { if !self.search && !event.modifiers.modified() { @@ -263,13 +263,14 @@ impl KeystrokeInput { } if self.search { if self.previous_modifiers.modified() { - last.modifiers |= event.modifiers; + let modifiers = *last.modifiers() | event.modifiers; + last.set_modifiers(modifiers); } else { self.keystrokes.push(Self::dummy(event.modifiers)); } self.previous_modifiers |= event.modifiers; } else { - last.modifiers = event.modifiers; + last.set_modifiers(event.modifiers); return; } } else if keystrokes_len < 
Self::KEYSTROKE_COUNT_MAX { @@ -297,14 +298,15 @@ impl KeystrokeInput { return; } - let mut keystroke = keystroke.clone(); + let keystroke = KeybindingKeystroke::new_with_mapper( + keystroke.clone(), + false, + cx.keyboard_mapper().as_ref(), + ); if let Some(last) = self.keystrokes.last() - && last.key.is_empty() + && last.key().is_empty() && (!self.search || self.previous_modifiers.modified()) { - let key = keystroke.key.clone(); - keystroke = last.clone(); - keystroke.key = key; self.keystrokes.pop(); } @@ -320,15 +322,19 @@ impl KeystrokeInput { return; } - self.keystrokes.push(keystroke.clone()); + self.keystrokes.push(keystroke); self.keystrokes_changed(cx); + // The reason we use the real modifiers from the window instead of the keystroke's modifiers + // is that for keystrokes like `ctrl-$` the modifiers reported by keystroke is `ctrl` which + // is wrong, it should be `ctrl-shift`. The window's modifiers are always correct. + let real_modifiers = window.modifiers(); if self.search { - self.previous_modifiers = keystroke.modifiers; + self.previous_modifiers = real_modifiers; return; } - if self.keystrokes.len() < Self::KEYSTROKE_COUNT_MAX && keystroke.modifiers.modified() { - self.keystrokes.push(Self::dummy(keystroke.modifiers)); + if self.keystrokes.len() < Self::KEYSTROKE_COUNT_MAX && real_modifiers.modified() { + self.keystrokes.push(Self::dummy(real_modifiers)); } } @@ -364,7 +370,7 @@ impl KeystrokeInput { &self.keystrokes }; keystrokes.iter().map(move |keystroke| { - h_flex().children(ui::render_keystroke( + h_flex().children(ui::render_keybinding_keystroke( keystroke, Some(Color::Default), Some(rems(0.875).into()), @@ -437,7 +443,7 @@ impl KeystrokeInput { // is a much more reliable check, as the intercept keystroke handlers are installed // on focus of the inner focus handle, thereby ensuring our recording state does // not get de-synced - return self.inner_focus_handle.is_focused(window); + self.inner_focus_handle.is_focused(window) } } @@ -455,7 
+461,7 @@ impl Render for KeystrokeInput { let is_focused = self.outer_focus_handle.contains_focused(window, cx); let is_recording = self.is_recording(window); - let horizontal_padding = rems_from_px(64.); + let width = rems_from_px(64.); let recording_bg_color = colors .editor_background @@ -522,6 +528,9 @@ impl Render for KeystrokeInput { h_flex() .id("keystroke-input") .track_focus(&self.outer_focus_handle) + .key_context(Self::key_context()) + .on_action(cx.listener(Self::start_recording)) + .on_action(cx.listener(Self::clear_keystrokes)) .py_2() .px_3() .gap_2() @@ -529,7 +538,7 @@ impl Render for KeystrokeInput { .w_full() .flex_1() .justify_between() - .rounded_sm() + .rounded_md() .overflow_hidden() .map(|this| { if is_recording { @@ -539,16 +548,16 @@ impl Render for KeystrokeInput { } }) .border_1() - .border_color(colors.border_variant) - .when(is_focused, |parent| { - parent.border_color(colors.border_focused) + .map(|this| { + if is_focused { + this.border_color(colors.border_focused) + } else { + this.border_color(colors.border_variant) + } }) - .key_context(Self::key_context()) - .on_action(cx.listener(Self::start_recording)) - .on_action(cx.listener(Self::clear_keystrokes)) .child( h_flex() - .w(horizontal_padding) + .w(width) .gap_0p5() .justify_start() .flex_none() @@ -567,14 +576,13 @@ impl Render for KeystrokeInput { .id("keystroke-input-inner") .track_focus(&self.inner_focus_handle) .on_modifiers_changed(cx.listener(Self::on_modifiers_changed)) - .size_full() .when(!self.search, |this| { this.focus(|mut style| { style.border_color = Some(colors.border_focused); style }) }) - .w_full() + .size_full() .min_w_0() .justify_center() .flex_wrap() @@ -583,7 +591,7 @@ impl Render for KeystrokeInput { ) .child( h_flex() - .w(horizontal_padding) + .w(width) .gap_0p5() .justify_end() .flex_none() @@ -635,9 +643,7 @@ impl Render for KeystrokeInput { "Clear Keystrokes", &ClearKeystrokes, )) - .when(!is_recording || !is_focused, |this| { - 
this.icon_color(Color::Muted) - }) + .when(!is_focused, |this| this.icon_color(Color::Muted)) .on_click(cx.listener(|this, _event, window, cx| { this.clear_keystrokes(&ClearKeystrokes, window, cx); })), @@ -706,8 +712,11 @@ mod tests { // Combine current modifiers with keystroke modifiers keystroke.modifiers |= self.current_modifiers; + let real_modifiers = keystroke.modifiers; + keystroke = to_gpui_keystroke(keystroke); self.update_input(|input, window, cx| { + window.set_modifiers(real_modifiers); input.handle_keystroke(&keystroke, window, cx); }); @@ -735,6 +744,7 @@ mod tests { }; self.update_input(|input, window, cx| { + window.set_modifiers(new_modifiers); input.on_modifiers_changed(&event, window, cx); }); @@ -809,9 +819,13 @@ mod tests { /// Verifies that the keystrokes match the expected strings #[track_caller] pub fn expect_keystrokes(&mut self, expected: &[&str]) -> &mut Self { - let actual = self - .input - .read_with(&mut self.cx, |input, _| input.keystrokes.clone()); + let actual: Vec = self.input.read_with(&self.cx, |input, _| { + input + .keystrokes + .iter() + .map(|keystroke| keystroke.inner().clone()) + .collect() + }); Self::expect_keystrokes_equal(&actual, expected); self } @@ -820,7 +834,7 @@ mod tests { pub fn expect_close_keystrokes(&mut self, expected: &[&str]) -> &mut Self { let actual = self .input - .read_with(&mut self.cx, |input, _| input.close_keystrokes.clone()) + .read_with(&self.cx, |input, _| input.close_keystrokes.clone()) .unwrap_or_default(); Self::expect_keystrokes_equal(&actual, expected); self @@ -934,12 +948,106 @@ mod tests { let change_tracker = KeystrokeUpdateTracker::new(self.input.clone(), &mut self.cx); let result = self.input.update_in(&mut self.cx, cb); KeystrokeUpdateTracker::finish(change_tracker, &self.cx); - return result; + result } } + /// For GPUI, when you press `ctrl-shift-2`, it produces `ctrl-@` without the shift modifier. 
+ fn to_gpui_keystroke(mut keystroke: Keystroke) -> Keystroke { + if keystroke.modifiers.shift { + match keystroke.key.as_str() { + "`" => { + keystroke.key = "~".into(); + keystroke.modifiers.shift = false; + } + "1" => { + keystroke.key = "!".into(); + keystroke.modifiers.shift = false; + } + "2" => { + keystroke.key = "@".into(); + keystroke.modifiers.shift = false; + } + "3" => { + keystroke.key = "#".into(); + keystroke.modifiers.shift = false; + } + "4" => { + keystroke.key = "$".into(); + keystroke.modifiers.shift = false; + } + "5" => { + keystroke.key = "%".into(); + keystroke.modifiers.shift = false; + } + "6" => { + keystroke.key = "^".into(); + keystroke.modifiers.shift = false; + } + "7" => { + keystroke.key = "&".into(); + keystroke.modifiers.shift = false; + } + "8" => { + keystroke.key = "*".into(); + keystroke.modifiers.shift = false; + } + "9" => { + keystroke.key = "(".into(); + keystroke.modifiers.shift = false; + } + "0" => { + keystroke.key = ")".into(); + keystroke.modifiers.shift = false; + } + "-" => { + keystroke.key = "_".into(); + keystroke.modifiers.shift = false; + } + "=" => { + keystroke.key = "+".into(); + keystroke.modifiers.shift = false; + } + "[" => { + keystroke.key = "{".into(); + keystroke.modifiers.shift = false; + } + "]" => { + keystroke.key = "}".into(); + keystroke.modifiers.shift = false; + } + "\\" => { + keystroke.key = "|".into(); + keystroke.modifiers.shift = false; + } + ";" => { + keystroke.key = ":".into(); + keystroke.modifiers.shift = false; + } + "'" => { + keystroke.key = "\"".into(); + keystroke.modifiers.shift = false; + } + "," => { + keystroke.key = "<".into(); + keystroke.modifiers.shift = false; + } + "." 
=> { + keystroke.key = ">".into(); + keystroke.modifiers.shift = false; + } + "/" => { + keystroke.key = "?".into(); + keystroke.modifiers.shift = false; + } + _ => {} + } + } + keystroke + } + struct KeystrokeUpdateTracker { - initial_keystrokes: Vec, + initial_keystrokes: Vec, _subscription: Subscription, input: Entity, received_keystrokes_updated: bool, @@ -983,8 +1091,8 @@ mod tests { ); } - fn keystrokes_str(ks: &[Keystroke]) -> String { - ks.iter().map(|ks| ks.unparse()).join(" ") + fn keystrokes_str(ks: &[KeybindingKeystroke]) -> String { + ks.iter().map(|ks| ks.inner().unparse()).join(" ") } } } @@ -1041,7 +1149,15 @@ mod tests { .send_events(&["+cmd", "shift-f", "-cmd"]) // In search mode, when completing a modifier-only keystroke with a key, // only the original modifiers are preserved, not the keystroke's modifiers - .expect_keystrokes(&["cmd-f"]); + // + // Update: + // This behavior was changed to preserve all modifiers in search mode, this is now reflected in the expected keystrokes. + // Specifically, considering the sequence: `+cmd +shift -shift 2`, we expect it to produce the same result as `+cmd +shift 2` + // which is `cmd-@`. But in the case of `+cmd +shift -shift 2`, the keystroke we receive is `cmd-2`, which means that + // we need to dynamically map the key from `2` to `@` when the shift modifier is not present, which is not possible. + // Therefore, we now preserve all modifiers in search mode to ensure consistent behavior. + // And also, VSCode seems to preserve all modifiers in search mode as well. 
+ .expect_keystrokes(&["cmd-shift-f"]); } #[gpui::test] @@ -1218,7 +1334,7 @@ mod tests { .await .with_search_mode(true) .send_events(&["+ctrl", "+shift", "-shift", "a", "-ctrl"]) - .expect_keystrokes(&["ctrl-shift-a"]); + .expect_keystrokes(&["ctrl-a"]); } #[gpui::test] @@ -1326,7 +1442,7 @@ mod tests { .await .with_search_mode(true) .send_events(&["+ctrl+alt", "-ctrl", "j"]) - .expect_keystrokes(&["ctrl-alt-j"]); + .expect_keystrokes(&["alt-j"]); } #[gpui::test] @@ -1348,11 +1464,11 @@ mod tests { .send_events(&["+ctrl+alt", "-ctrl", "+shift"]) .expect_keystrokes(&["ctrl-shift-alt-"]) .send_keystroke("j") - .expect_keystrokes(&["ctrl-shift-alt-j"]) + .expect_keystrokes(&["shift-alt-j"]) .send_keystroke("i") - .expect_keystrokes(&["ctrl-shift-alt-j", "shift-alt-i"]) + .expect_keystrokes(&["shift-alt-j", "shift-alt-i"]) .send_events(&["-shift-alt", "+cmd"]) - .expect_keystrokes(&["ctrl-shift-alt-j", "shift-alt-i", "cmd-"]); + .expect_keystrokes(&["shift-alt-j", "shift-alt-i", "cmd-"]); } #[gpui::test] @@ -1385,4 +1501,13 @@ mod tests { .send_events(&["+ctrl", "-ctrl", "+alt", "-alt", "+shift", "-shift"]) .expect_empty(); } + + #[gpui::test] + async fn test_not_search_shifted_keys(cx: &mut TestAppContext) { + init_test(cx) + .await + .with_search_mode(false) + .send_events(&["+ctrl", "+shift", "4", "-all"]) + .expect_keystrokes(&["ctrl-$"]); + } } diff --git a/crates/settings_ui/src/ui_components/mod.rs b/crates/keymap_editor/src/ui_components/mod.rs similarity index 100% rename from crates/settings_ui/src/ui_components/mod.rs rename to crates/keymap_editor/src/ui_components/mod.rs diff --git a/crates/settings_ui/src/ui_components/table.rs b/crates/keymap_editor/src/ui_components/table.rs similarity index 99% rename from crates/settings_ui/src/ui_components/table.rs rename to crates/keymap_editor/src/ui_components/table.rs index 2b3e815f369a96235c8628935df433737d58b0ce..9d7bb0736061181eda93d072640f87b5946a2675 100644 --- 
a/crates/settings_ui/src/ui_components/table.rs +++ b/crates/keymap_editor/src/ui_components/table.rs @@ -213,7 +213,7 @@ impl TableInteractionState { let mut column_ix = 0; let resizable_columns_slice = *resizable_columns; - let mut resizable_columns = resizable_columns.into_iter(); + let mut resizable_columns = resizable_columns.iter(); let dividers = intersperse_with(spacers, || { window.with_id(column_ix, |window| { @@ -343,7 +343,7 @@ impl TableInteractionState { .on_any_mouse_down(|_, _, cx| { cx.stop_propagation(); }) - .on_scroll_wheel(Self::listener(&this, |_, _, _, cx| { + .on_scroll_wheel(Self::listener(this, |_, _, _, cx| { cx.notify(); })) .children(Scrollbar::vertical( @@ -731,7 +731,7 @@ impl ColumnWidths { } widths[col_idx] = widths[col_idx] + (diff - diff_remaining); - return diff_remaining; + diff_remaining } } @@ -801,7 +801,7 @@ impl Table { ) -> Self { self.rows = TableContents::UniformList(UniformListData { element_id: id.into(), - row_count: row_count, + row_count, render_item_fn: Box::new(render_item_fn), }); self diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index 83517accc239ecf9d2196f124fc5695a8545ef17..51e6b6d1e032aa5e786e9117b96aff8adaba638f 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -27,12 +27,13 @@ use gpui::{ App, AppContext as _, Context, Entity, EventEmitter, HighlightStyle, SharedString, StyledText, Task, TaskLabel, TextStyle, }; + use lsp::{LanguageServerId, NumberOrString}; use parking_lot::Mutex; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::Value; -use settings::WorktreeId; +use settings::{SettingsUi, WorktreeId}; use smallvec::SmallVec; use smol::future::yield_now; use std::{ @@ -173,7 +174,9 @@ pub enum IndentKind { } /// The shape of a selection cursor. 
-#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] +#[derive( + Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, JsonSchema, SettingsUi, +)] #[serde(rename_all = "snake_case")] pub enum CursorShape { /// A vertical bar @@ -202,7 +205,7 @@ pub struct Diagnostic { pub source: Option, /// A machine-readable code that identifies this diagnostic. pub code: Option, - pub code_description: Option, + pub code_description: Option, /// Whether this diagnostic is a hint, warning, or error. pub severity: DiagnosticSeverity, /// The human-readable message associated with this diagnostic. @@ -282,6 +285,14 @@ pub enum Operation { /// The language server ID. server_id: LanguageServerId, }, + + /// An update to the line ending type of this buffer. + UpdateLineEnding { + /// The line ending type. + line_ending: LineEnding, + /// The buffer's lamport timestamp. + lamport_timestamp: clock::Lamport, + }, } /// An event that occurs in a buffer. @@ -313,10 +324,6 @@ pub enum BufferEvent { DiagnosticsUpdated, /// The buffer gained or lost editing capabilities. CapabilityChanged, - /// The buffer was explicitly requested to close. - Closed, - /// The buffer was discarded when closing. - Discarded, } /// The file associated with a buffer. @@ -494,6 +501,10 @@ pub struct Chunk<'a> { pub is_unnecessary: bool, /// Whether this chunk of text was originally a tab character. pub is_tab: bool, + /// A bitset of which characters are tabs in this string. + pub tabs: u128, + /// Bitmap of character indices in this chunk + pub chars: u128, /// Whether this chunk of text was originally a tab character. pub is_inlay: bool, /// Whether to underline the corresponding text range in the editor. 
@@ -629,13 +640,13 @@ impl HighlightedTextBuilder { self.text.push_str(chunk.text); let end = self.text.len(); - if let Some(mut highlight_style) = chunk + if let Some(highlight_style) = chunk .syntax_highlight_id .and_then(|id| id.style(syntax_theme)) { - if let Some(override_style) = override_style { - highlight_style.highlight(override_style); - } + let highlight_style = override_style.map_or(highlight_style, |override_style| { + highlight_style.highlight(override_style) + }); self.highlights.push((start..end, highlight_style)); } else if let Some(override_style) = override_style { self.highlights.push((start..end, override_style)); @@ -716,7 +727,7 @@ impl EditPreview { &self.applied_edits_snapshot, &self.syntax_snapshot, None, - &syntax_theme, + syntax_theme, ); } @@ -727,7 +738,7 @@ impl EditPreview { ¤t_snapshot.text, ¤t_snapshot.syntax, Some(deletion_highlight_style), - &syntax_theme, + syntax_theme, ); } @@ -737,7 +748,7 @@ impl EditPreview { &self.applied_edits_snapshot, &self.syntax_snapshot, Some(insertion_highlight_style), - &syntax_theme, + syntax_theme, ); } @@ -749,7 +760,7 @@ impl EditPreview { &self.applied_edits_snapshot, &self.syntax_snapshot, None, - &syntax_theme, + syntax_theme, ); highlighted_text.build() @@ -974,8 +985,6 @@ impl Buffer { TextBuffer::new_normalized(0, buffer_id, Default::default(), text).snapshot(); let mut syntax = SyntaxMap::new(&text).snapshot(); if let Some(language) = language.clone() { - let text = text.clone(); - let language = language.clone(); let language_registry = language_registry.clone(); syntax.reparse(&text, language_registry, language); } @@ -1020,9 +1029,6 @@ impl Buffer { let text = TextBuffer::new_normalized(0, buffer_id, Default::default(), text).snapshot(); let mut syntax = SyntaxMap::new(&text).snapshot(); if let Some(language) = language.clone() { - let text = text.clone(); - let language = language.clone(); - let language_registry = language_registry.clone(); syntax.reparse(&text, language_registry, 
language); } BufferSnapshot { @@ -1128,7 +1134,7 @@ impl Buffer { } else { ranges.as_slice() } - .into_iter() + .iter() .peekable(); let mut edits = Vec::new(); @@ -1158,13 +1164,12 @@ impl Buffer { base_buffer.edit(edits, None, cx) }); - if let Some(operation) = operation { - if let Some(BufferBranchState { + if let Some(operation) = operation + && let Some(BufferBranchState { merged_operations, .. }) = &mut self.branch_state - { - merged_operations.push(operation); - } + { + merged_operations.push(operation); } } @@ -1185,11 +1190,11 @@ impl Buffer { }; let mut operation_to_undo = None; - if let Operation::Buffer(text::Operation::Edit(operation)) = &operation { - if let Ok(ix) = merged_operations.binary_search(&operation.timestamp) { - merged_operations.remove(ix); - operation_to_undo = Some(operation.timestamp); - } + if let Operation::Buffer(text::Operation::Edit(operation)) = &operation + && let Ok(ix) = merged_operations.binary_search(&operation.timestamp) + { + merged_operations.remove(ix); + operation_to_undo = Some(operation.timestamp); } self.apply_ops([operation.clone()], cx); @@ -1248,10 +1253,27 @@ impl Buffer { self.syntax_map.lock().language_registry() } + /// Assign the line ending type to the buffer. + pub fn set_line_ending(&mut self, line_ending: LineEnding, cx: &mut Context) { + self.text.set_line_ending(line_ending); + + let lamport_timestamp = self.text.lamport_clock.tick(); + self.send_operation( + Operation::UpdateLineEnding { + line_ending, + lamport_timestamp, + }, + true, + cx, + ); + } + /// Assign the buffer a new [`Capability`]. pub fn set_capability(&mut self, capability: Capability, cx: &mut Context) { - self.capability = capability; - cx.emit(BufferEvent::CapabilityChanged) + if self.capability != capability { + self.capability = capability; + cx.emit(BufferEvent::CapabilityChanged) + } } /// This method is called to signal that the buffer has been saved. 
@@ -1271,12 +1293,6 @@ impl Buffer { cx.notify(); } - /// This method is called to signal that the buffer has been discarded. - pub fn discarded(&self, cx: &mut Context) { - cx.emit(BufferEvent::Discarded); - cx.notify(); - } - /// Reloads the contents of the buffer from disk. pub fn reload(&mut self, cx: &Context) -> oneshot::Receiver> { let (tx, rx) = futures::channel::oneshot::channel(); @@ -1396,7 +1412,8 @@ impl Buffer { is_first = false; return true; } - let any_sub_ranges_contain_range = layer + + layer .included_sub_ranges .map(|sub_ranges| { sub_ranges.iter().any(|sub_range| { @@ -1405,9 +1422,7 @@ impl Buffer { !is_before_start && !is_after_end }) }) - .unwrap_or(true); - let result = any_sub_ranges_contain_range; - return result; + .unwrap_or(true) }) .last() .map(|info| info.language.clone()) @@ -1424,10 +1439,10 @@ impl Buffer { .map(|info| info.language.clone()) .collect(); - if languages.is_empty() { - if let Some(buffer_language) = self.language() { - languages.push(buffer_language.clone()); - } + if languages.is_empty() + && let Some(buffer_language) = self.language() + { + languages.push(buffer_language.clone()); } languages @@ -1521,12 +1536,12 @@ impl Buffer { let new_syntax_map = parse_task.await; this.update(cx, move |this, cx| { let grammar_changed = - this.language.as_ref().map_or(true, |current_language| { + this.language.as_ref().is_none_or(|current_language| { !Arc::ptr_eq(&language, current_language) }); let language_registry_changed = new_syntax_map .contains_unknown_injections() - && language_registry.map_or(false, |registry| { + && language_registry.is_some_and(|registry| { registry.version() != new_syntax_map.language_registry_version() }); let parse_again = language_registry_changed @@ -1571,15 +1586,26 @@ impl Buffer { diagnostics: diagnostics.iter().cloned().collect(), lamport_timestamp, }; + self.apply_diagnostic_update(server_id, diagnostics, lamport_timestamp, cx); self.send_operation(op, true, cx); } - pub fn 
get_diagnostics(&self, server_id: LanguageServerId) -> Option<&DiagnosticSet> { - let Ok(idx) = self.diagnostics.binary_search_by_key(&server_id, |v| v.0) else { - return None; - }; - Some(&self.diagnostics[idx].1) + pub fn buffer_diagnostics( + &self, + for_server: Option, + ) -> Vec<&DiagnosticEntry> { + match for_server { + Some(server_id) => match self.diagnostics.binary_search_by_key(&server_id, |v| v.0) { + Ok(idx) => self.diagnostics[idx].1.iter().collect(), + Err(_) => Vec::new(), + }, + None => self + .diagnostics + .iter() + .flat_map(|(_, diagnostic_set)| diagnostic_set.iter()) + .collect(), + } } fn request_autoindent(&mut self, cx: &mut Context) { @@ -1719,8 +1745,7 @@ impl Buffer { }) .with_delta(suggestion.delta, language_indent_size); - if old_suggestions.get(&new_row).map_or( - true, + if old_suggestions.get(&new_row).is_none_or( |(old_indentation, was_within_error)| { suggested_indent != *old_indentation && (!suggestion.within_error || *was_within_error) @@ -2014,7 +2039,7 @@ impl Buffer { fn was_changed(&mut self) { self.change_bits.retain(|change_bit| { - change_bit.upgrade().map_or(false, |bit| { + change_bit.upgrade().is_some_and(|bit| { bit.replace(true); true }) @@ -2191,7 +2216,7 @@ impl Buffer { if self .remote_selections .get(&self.text.replica_id()) - .map_or(true, |set| !set.selections.is_empty()) + .is_none_or(|set| !set.selections.is_empty()) { self.set_active_selections(Arc::default(), false, Default::default(), cx); } @@ -2208,7 +2233,7 @@ impl Buffer { self.remote_selections.insert( AGENT_REPLICA_ID, SelectionSet { - selections: selections.clone(), + selections, lamport_timestamp, line_mode, cursor_shape, @@ -2270,13 +2295,11 @@ impl Buffer { } let new_text = new_text.into(); if !new_text.is_empty() || !range.is_empty() { - if let Some((prev_range, prev_text)) = edits.last_mut() { - if prev_range.end >= range.start { - prev_range.end = cmp::max(prev_range.end, range.end); - *prev_text = format!("{prev_text}{new_text}").into(); - } 
else { - edits.push((range, new_text)); - } + if let Some((prev_range, prev_text)) = edits.last_mut() + && prev_range.end >= range.start + { + prev_range.end = cmp::max(prev_range.end, range.end); + *prev_text = format!("{prev_text}{new_text}").into(); } else { edits.push((range, new_text)); } @@ -2296,10 +2319,27 @@ impl Buffer { if let Some((before_edit, mode)) = autoindent_request { let mut delta = 0isize; - let entries = edits + let mut previous_setting = None; + let entries: Vec<_> = edits .into_iter() .enumerate() .zip(&edit_operation.as_edit().unwrap().new_text) + .filter(|((_, (range, _)), _)| { + let language = before_edit.language_at(range.start); + let language_id = language.map(|l| l.id()); + if let Some((cached_language_id, auto_indent)) = previous_setting + && cached_language_id == language_id + { + auto_indent + } else { + // The auto-indent setting is not present in editorconfigs, hence + // we can avoid passing the file here. + let auto_indent = + language_settings(language.map(|l| l.name()), None, cx).auto_indent; + previous_setting = Some((language_id, auto_indent)); + auto_indent + } + }) .map(|((ix, (range, _)), new_text)| { let new_text_length = new_text.len(); let old_start = range.start.to_point(&before_edit); @@ -2373,12 +2413,14 @@ impl Buffer { }) .collect(); - self.autoindent_requests.push(Arc::new(AutoindentRequest { - before_edit, - entries, - is_block_mode: matches!(mode, AutoindentMode::Block { .. }), - ignore_empty_lines: false, - })); + if !entries.is_empty() { + self.autoindent_requests.push(Arc::new(AutoindentRequest { + before_edit, + entries, + is_block_mode: matches!(mode, AutoindentMode::Block { .. }), + ignore_empty_lines: false, + })); + } } self.end_transaction(cx); @@ -2543,7 +2585,7 @@ impl Buffer { Operation::UpdateSelections { selections, .. } => selections .iter() .all(|s| self.can_resolve(&s.start) && self.can_resolve(&s.end)), - Operation::UpdateCompletionTriggers { .. 
} => true, + Operation::UpdateCompletionTriggers { .. } | Operation::UpdateLineEnding { .. } => true, } } @@ -2571,10 +2613,10 @@ impl Buffer { line_mode, cursor_shape, } => { - if let Some(set) = self.remote_selections.get(&lamport_timestamp.replica_id) { - if set.lamport_timestamp > lamport_timestamp { - return; - } + if let Some(set) = self.remote_selections.get(&lamport_timestamp.replica_id) + && set.lamport_timestamp > lamport_timestamp + { + return; } self.remote_selections.insert( @@ -2600,7 +2642,7 @@ impl Buffer { self.completion_triggers = self .completion_triggers_per_language_server .values() - .flat_map(|triggers| triggers.into_iter().cloned()) + .flat_map(|triggers| triggers.iter().cloned()) .collect(); } else { self.completion_triggers_per_language_server @@ -2609,6 +2651,13 @@ impl Buffer { } self.text.lamport_clock.observe(lamport_timestamp); } + Operation::UpdateLineEnding { + line_ending, + lamport_timestamp, + } => { + self.text.set_line_ending(line_ending); + self.text.lamport_clock.observe(lamport_timestamp); + } } } @@ -2760,7 +2809,7 @@ impl Buffer { self.completion_triggers = self .completion_triggers_per_language_server .values() - .flat_map(|triggers| triggers.into_iter().cloned()) + .flat_map(|triggers| triggers.iter().cloned()) .collect(); } else { self.completion_triggers_per_language_server @@ -2822,18 +2871,18 @@ impl Buffer { let mut edits: Vec<(Range, String)> = Vec::new(); let mut last_end = None; for _ in 0..old_range_count { - if last_end.map_or(false, |last_end| last_end >= self.len()) { + if last_end.is_some_and(|last_end| last_end >= self.len()) { break; } let new_start = last_end.map_or(0, |last_end| last_end + 1); let mut range = self.random_byte_range(new_start, rng); - if rng.gen_bool(0.2) { + if rng.random_bool(0.2) { mem::swap(&mut range.start, &mut range.end); } last_end = Some(range.end); - let new_text_len = rng.gen_range(0..10); + let new_text_len = rng.random_range(0..10); let mut new_text: String = 
RandomCharIter::new(&mut *rng).take(new_text_len).collect(); new_text = new_text.to_uppercase(); @@ -2991,9 +3040,9 @@ impl BufferSnapshot { } let mut error_ranges = Vec::>::new(); - let mut matches = self.syntax.matches(range.clone(), &self.text, |grammar| { - grammar.error_query.as_ref() - }); + let mut matches = self + .syntax + .matches(range, &self.text, |grammar| grammar.error_query.as_ref()); while let Some(mat) = matches.peek() { let node = mat.captures[0].node; let start = Point::from_ts_point(node.start_position()); @@ -3042,14 +3091,14 @@ impl BufferSnapshot { if config .decrease_indent_pattern .as_ref() - .map_or(false, |regex| regex.is_match(line)) + .is_some_and(|regex| regex.is_match(line)) { indent_change_rows.push((row, Ordering::Less)); } if config .increase_indent_pattern .as_ref() - .map_or(false, |regex| regex.is_match(line)) + .is_some_and(|regex| regex.is_match(line)) { indent_change_rows.push((row + 1, Ordering::Greater)); } @@ -3065,7 +3114,7 @@ impl BufferSnapshot { } } for rule in &config.decrease_indent_patterns { - if rule.pattern.as_ref().map_or(false, |r| r.is_match(line)) { + if rule.pattern.as_ref().is_some_and(|r| r.is_match(line)) { let row_start_column = self.indent_size_for_line(row).len; let basis_row = rule .valid_after @@ -3278,8 +3327,7 @@ impl BufferSnapshot { range: Range, ) -> Option> { let range = range.to_offset(self); - return self - .syntax + self.syntax .layers_for_range(range, &self.text, false) .max_by(|a, b| { if a.depth != b.depth { @@ -3289,7 +3337,7 @@ impl BufferSnapshot { } else { a.node().end_byte().cmp(&b.node().end_byte()).reverse() } - }); + }) } /// Returns the main [`Language`]. 
@@ -3347,9 +3395,8 @@ impl BufferSnapshot { } } - if let Some(range) = range { - if smallest_range_and_depth.as_ref().map_or( - true, + if let Some(range) = range + && smallest_range_and_depth.as_ref().is_none_or( |(smallest_range, smallest_range_depth)| { if layer.depth > *smallest_range_depth { true @@ -3359,13 +3406,13 @@ impl BufferSnapshot { false } }, - ) { - smallest_range_and_depth = Some((range, layer.depth)); - scope = Some(LanguageScope { - language: layer.language.clone(), - override_id: layer.override_id(offset, &self.text), - }); - } + ) + { + smallest_range_and_depth = Some((range, layer.depth)); + scope = Some(LanguageScope { + language: layer.language.clone(), + override_id: layer.override_id(offset, &self.text), + }); } } @@ -3417,46 +3464,66 @@ impl BufferSnapshot { } /// Returns the closest syntax node enclosing the given range. + /// Positions a tree cursor at the leaf node that contains or touches the given range. + /// This is shared logic used by syntax navigation methods. + fn position_cursor_at_range(cursor: &mut tree_sitter::TreeCursor, range: &Range) { + // Descend to the first leaf that touches the start of the range. + // + // If the range is non-empty and the current node ends exactly at the start, + // move to the next sibling to find a node that extends beyond the start. + // + // If the range is empty and the current node starts after the range position, + // move to the previous sibling to find the node that contains the position. + while cursor.goto_first_child_for_byte(range.start).is_some() { + if !range.is_empty() && cursor.node().end_byte() == range.start { + cursor.goto_next_sibling(); + } + if range.is_empty() && cursor.node().start_byte() > range.start { + cursor.goto_previous_sibling(); + } + } + } + + /// Moves the cursor to find a node that contains the given range. + /// Returns true if such a node is found, false otherwise. + /// This is shared logic used by syntax navigation methods. 
+ fn find_containing_node( + cursor: &mut tree_sitter::TreeCursor, + range: &Range, + strict: bool, + ) -> bool { + loop { + let node_range = cursor.node().byte_range(); + + if node_range.start <= range.start + && node_range.end >= range.end + && (!strict || node_range.len() > range.len()) + { + return true; + } + if !cursor.goto_parent() { + return false; + } + } + } + pub fn syntax_ancestor<'a, T: ToOffset>( &'a self, range: Range, ) -> Option> { let range = range.start.to_offset(self)..range.end.to_offset(self); let mut result: Option> = None; - 'outer: for layer in self + for layer in self .syntax .layers_for_range(range.clone(), &self.text, true) { let mut cursor = layer.node().walk(); - // Descend to the first leaf that touches the start of the range. - // - // If the range is non-empty and the current node ends exactly at the start, - // move to the next sibling to find a node that extends beyond the start. - // - // If the range is empty and the current node starts after the range position, - // move to the previous sibling to find the node that contains the position. - while cursor.goto_first_child_for_byte(range.start).is_some() { - if !range.is_empty() && cursor.node().end_byte() == range.start { - cursor.goto_next_sibling(); - } - if range.is_empty() && cursor.node().start_byte() > range.start { - cursor.goto_previous_sibling(); - } - } + Self::position_cursor_at_range(&mut cursor, &range); // Ascend to the smallest ancestor that strictly contains the range. 
- loop { - let node_range = cursor.node().byte_range(); - if node_range.start <= range.start - && node_range.end >= range.end - && node_range.len() > range.len() - { - break; - } - if !cursor.goto_parent() { - continue 'outer; - } + if !Self::find_containing_node(&mut cursor, &range, true) { + continue; } let left_node = cursor.node(); @@ -3481,19 +3548,125 @@ impl BufferSnapshot { // If there is a candidate node on both sides of the (empty) range, then // decide between the two by favoring a named node over an anonymous token. // If both nodes are the same in that regard, favor the right one. - if let Some(right_node) = right_node { - if right_node.is_named() || !left_node.is_named() { - layer_result = right_node; + if let Some(right_node) = right_node + && (right_node.is_named() || !left_node.is_named()) + { + layer_result = right_node; + } + } + + if let Some(previous_result) = &result + && previous_result.byte_range().len() < layer_result.byte_range().len() + { + continue; + } + result = Some(layer_result); + } + + result + } + + /// Find the previous sibling syntax node at the given range. + /// + /// This function locates the syntax node that precedes the node containing + /// the given range. It searches hierarchically by: + /// 1. Finding the node that contains the given range + /// 2. Looking for the previous sibling at the same tree level + /// 3. If no sibling is found, moving up to parent levels and searching for siblings + /// + /// Returns `None` if there is no previous sibling at any ancestor level. 
+ pub fn syntax_prev_sibling<'a, T: ToOffset>( + &'a self, + range: Range, + ) -> Option> { + let range = range.start.to_offset(self)..range.end.to_offset(self); + let mut result: Option> = None; + + for layer in self + .syntax + .layers_for_range(range.clone(), &self.text, true) + { + let mut cursor = layer.node().walk(); + + Self::position_cursor_at_range(&mut cursor, &range); + + // Find the node that contains the range + if !Self::find_containing_node(&mut cursor, &range, false) { + continue; + } + + // Look for the previous sibling, moving up ancestor levels if needed + loop { + if cursor.goto_previous_sibling() { + let layer_result = cursor.node(); + + if let Some(previous_result) = &result { + if previous_result.byte_range().end < layer_result.byte_range().end { + continue; + } } + result = Some(layer_result); + break; + } + + // No sibling found at this level, try moving up to parent + if !cursor.goto_parent() { + break; } } + } - if let Some(previous_result) = &result { - if previous_result.byte_range().len() < layer_result.byte_range().len() { - continue; + result + } + + /// Find the next sibling syntax node at the given range. + /// + /// This function locates the syntax node that follows the node containing + /// the given range. It searches hierarchically by: + /// 1. Finding the node that contains the given range + /// 2. Looking for the next sibling at the same tree level + /// 3. If no sibling is found, moving up to parent levels and searching for siblings + /// + /// Returns `None` if there is no next sibling at any ancestor level. 
+ pub fn syntax_next_sibling<'a, T: ToOffset>( + &'a self, + range: Range, + ) -> Option> { + let range = range.start.to_offset(self)..range.end.to_offset(self); + let mut result: Option> = None; + + for layer in self + .syntax + .layers_for_range(range.clone(), &self.text, true) + { + let mut cursor = layer.node().walk(); + + Self::position_cursor_at_range(&mut cursor, &range); + + // Find the node that contains the range + if !Self::find_containing_node(&mut cursor, &range, false) { + continue; + } + + // Look for the next sibling, moving up ancestor levels if needed + loop { + if cursor.goto_next_sibling() { + let layer_result = cursor.node(); + + if let Some(previous_result) = &result { + if previous_result.byte_range().start > layer_result.byte_range().start { + continue; + } + } + result = Some(layer_result); + break; + } + + // No sibling found at this level, try moving up to parent + if !cursor.goto_parent() { + break; } } - result = Some(layer_result); } result @@ -3526,16 +3699,15 @@ impl BufferSnapshot { } } - return Some(cursor.node()); + Some(cursor.node()) } /// Returns the outline for the buffer. /// /// This method allows passing an optional [`SyntaxTheme`] to /// syntax-highlight the returned symbols. - pub fn outline(&self, theme: Option<&SyntaxTheme>) -> Option> { - self.outline_items_containing(0..self.len(), true, theme) - .map(Outline::new) + pub fn outline(&self, theme: Option<&SyntaxTheme>) -> Outline { + Outline::new(self.outline_items_containing(0..self.len(), true, theme)) } /// Returns all the symbols that contain the given position. 
@@ -3546,20 +3718,20 @@ impl BufferSnapshot { &self, position: T, theme: Option<&SyntaxTheme>, - ) -> Option>> { + ) -> Vec> { let position = position.to_offset(self); let mut items = self.outline_items_containing( position.saturating_sub(1)..self.len().min(position + 1), false, theme, - )?; + ); let mut prev_depth = None; items.retain(|item| { - let result = prev_depth.map_or(true, |prev_depth| item.depth > prev_depth); + let result = prev_depth.is_none_or(|prev_depth| item.depth > prev_depth); prev_depth = Some(item.depth); result }); - Some(items) + items } pub fn outline_range_containing(&self, range: Range) -> Option> { @@ -3609,21 +3781,19 @@ impl BufferSnapshot { range: Range, include_extra_context: bool, theme: Option<&SyntaxTheme>, - ) -> Option>> { + ) -> Vec> { let range = range.to_offset(self); let mut matches = self.syntax.matches(range.clone(), &self.text, |grammar| { grammar.outline_config.as_ref().map(|c| &c.query) }); - let configs = matches - .grammars() - .iter() - .map(|g| g.outline_config.as_ref().unwrap()) - .collect::>(); let mut items = Vec::new(); let mut annotation_row_ranges: Vec> = Vec::new(); while let Some(mat) = matches.peek() { - let config = &configs[mat.grammar_index]; + let config = matches.grammars()[mat.grammar_index] + .outline_config + .as_ref() + .unwrap(); if let Some(item) = self.next_outline_item(config, &mat, &range, include_extra_context, theme) { @@ -3702,7 +3872,7 @@ impl BufferSnapshot { item_ends_stack.push(item.range.end); } - Some(anchor_items) + anchor_items } fn next_outline_item( @@ -4062,11 +4232,11 @@ impl BufferSnapshot { // Get the ranges of the innermost pair of brackets. 
let mut result: Option<(Range, Range)> = None; - for pair in self.enclosing_bracket_ranges(range.clone()) { - if let Some(range_filter) = range_filter { - if !range_filter(pair.open_range.clone(), pair.close_range.clone()) { - continue; - } + for pair in self.enclosing_bracket_ranges(range) { + if let Some(range_filter) = range_filter + && !range_filter(pair.open_range.clone(), pair.close_range.clone()) + { + continue; } let len = pair.close_range.end - pair.open_range.start; @@ -4235,7 +4405,7 @@ impl BufferSnapshot { .map(|(range, name)| { ( name.to_string(), - self.text_for_range(range.clone()).collect::(), + self.text_for_range(range).collect::(), ) }) .collect(); @@ -4432,7 +4602,7 @@ impl BufferSnapshot { pub fn words_in_range(&self, query: WordsQuery) -> BTreeMap> { let query_str = query.fuzzy_contents; - if query_str.map_or(false, |query| query.is_empty()) { + if query_str.is_some_and(|query| query.is_empty()) { return BTreeMap::default(); } @@ -4456,27 +4626,26 @@ impl BufferSnapshot { current_word_start_ix = Some(ix); } - if let Some(query_chars) = &query_chars { - if query_ix < query_len { - if c.to_lowercase().eq(query_chars[query_ix].to_lowercase()) { - query_ix += 1; - } - } + if let Some(query_chars) = &query_chars + && query_ix < query_len + && c.to_lowercase().eq(query_chars[query_ix].to_lowercase()) + { + query_ix += 1; } continue; - } else if let Some(word_start) = current_word_start_ix.take() { - if query_ix == query_len { - let word_range = self.anchor_before(word_start)..self.anchor_after(ix); - let mut word_text = self.text_for_range(word_start..ix).peekable(); - let first_char = word_text - .peek() - .and_then(|first_chunk| first_chunk.chars().next()); - // Skip empty and "words" starting with digits as a heuristic to reduce useless completions - if !query.skip_digits - || first_char.map_or(true, |first_char| !first_char.is_digit(10)) - { - words.insert(word_text.collect(), word_range); - } + } else if let Some(word_start) = 
current_word_start_ix.take() + && query_ix == query_len + { + let word_range = self.anchor_before(word_start)..self.anchor_after(ix); + let mut word_text = self.text_for_range(word_start..ix).peekable(); + let first_char = word_text + .peek() + .and_then(|first_chunk| first_chunk.chars().next()); + // Skip empty and "words" starting with digits as a heuristic to reduce useless completions + if !query.skip_digits + || first_char.is_none_or(|first_char| !first_char.is_digit(10)) + { + words.insert(word_text.collect(), word_range); } } query_ix = 0; @@ -4589,17 +4758,17 @@ impl<'a> BufferChunks<'a> { highlights .stack .retain(|(end_offset, _)| *end_offset > range.start); - if let Some(capture) = &highlights.next_capture { - if range.start >= capture.node.start_byte() { - let next_capture_end = capture.node.end_byte(); - if range.start < next_capture_end { - highlights.stack.push(( - next_capture_end, - highlights.highlight_maps[capture.grammar_index].get(capture.index), - )); - } - highlights.next_capture.take(); + if let Some(capture) = &highlights.next_capture + && range.start >= capture.node.start_byte() + { + let next_capture_end = capture.node.end_byte(); + if range.start < next_capture_end { + highlights.stack.push(( + next_capture_end, + highlights.highlight_maps[capture.grammar_index].get(capture.index), + )); } + highlights.next_capture.take(); } } else if let Some(snapshot) = self.buffer_snapshot { let (captures, highlight_maps) = snapshot.get_highlights(self.range.clone()); @@ -4624,33 +4793,33 @@ impl<'a> BufferChunks<'a> { } fn initialize_diagnostic_endpoints(&mut self) { - if let Some(diagnostics) = self.diagnostic_endpoints.as_mut() { - if let Some(buffer) = self.buffer_snapshot { - let mut diagnostic_endpoints = Vec::new(); - for entry in buffer.diagnostics_in_range::<_, usize>(self.range.clone(), false) { - diagnostic_endpoints.push(DiagnosticEndpoint { - offset: entry.range.start, - is_start: true, - severity: entry.diagnostic.severity, - 
is_unnecessary: entry.diagnostic.is_unnecessary, - underline: entry.diagnostic.underline, - }); - diagnostic_endpoints.push(DiagnosticEndpoint { - offset: entry.range.end, - is_start: false, - severity: entry.diagnostic.severity, - is_unnecessary: entry.diagnostic.is_unnecessary, - underline: entry.diagnostic.underline, - }); - } - diagnostic_endpoints - .sort_unstable_by_key(|endpoint| (endpoint.offset, !endpoint.is_start)); - *diagnostics = diagnostic_endpoints.into_iter().peekable(); - self.hint_depth = 0; - self.error_depth = 0; - self.warning_depth = 0; - self.information_depth = 0; + if let Some(diagnostics) = self.diagnostic_endpoints.as_mut() + && let Some(buffer) = self.buffer_snapshot + { + let mut diagnostic_endpoints = Vec::new(); + for entry in buffer.diagnostics_in_range::<_, usize>(self.range.clone(), false) { + diagnostic_endpoints.push(DiagnosticEndpoint { + offset: entry.range.start, + is_start: true, + severity: entry.diagnostic.severity, + is_unnecessary: entry.diagnostic.is_unnecessary, + underline: entry.diagnostic.underline, + }); + diagnostic_endpoints.push(DiagnosticEndpoint { + offset: entry.range.end, + is_start: false, + severity: entry.diagnostic.severity, + is_unnecessary: entry.diagnostic.is_unnecessary, + underline: entry.diagnostic.underline, + }); } + diagnostic_endpoints + .sort_unstable_by_key(|endpoint| (endpoint.offset, !endpoint.is_start)); + *diagnostics = diagnostic_endpoints.into_iter().peekable(); + self.hint_depth = 0; + self.error_depth = 0; + self.warning_depth = 0; + self.information_depth = 0; } } @@ -4755,21 +4924,36 @@ impl<'a> Iterator for BufferChunks<'a> { } self.diagnostic_endpoints = diagnostic_endpoints; - if let Some(chunk) = self.chunks.peek() { + if let Some(ChunkBitmaps { + text: chunk, + chars: chars_map, + tabs, + }) = self.chunks.peek_tabs() + { let chunk_start = self.range.start; let mut chunk_end = (self.chunks.offset() + chunk.len()) .min(next_capture_start) .min(next_diagnostic_endpoint); let mut 
highlight_id = None; - if let Some(highlights) = self.highlights.as_ref() { - if let Some((parent_capture_end, parent_highlight_id)) = highlights.stack.last() { - chunk_end = chunk_end.min(*parent_capture_end); - highlight_id = Some(*parent_highlight_id); - } + if let Some(highlights) = self.highlights.as_ref() + && let Some((parent_capture_end, parent_highlight_id)) = highlights.stack.last() + { + chunk_end = chunk_end.min(*parent_capture_end); + highlight_id = Some(*parent_highlight_id); } let slice = &chunk[chunk_start - self.chunks.offset()..chunk_end - self.chunks.offset()]; + let bit_end = chunk_end - self.chunks.offset(); + + let mask = if bit_end >= 128 { + u128::MAX + } else { + (1u128 << bit_end) - 1 + }; + let tabs = (tabs >> (chunk_start - self.chunks.offset())) & mask; + let chars_map = (chars_map >> (chunk_start - self.chunks.offset())) & mask; + self.range.start = chunk_end; if self.range.start == self.chunks.offset() + chunk.len() { self.chunks.next().unwrap(); @@ -4781,6 +4965,8 @@ impl<'a> Iterator for BufferChunks<'a> { underline: self.underline, diagnostic_severity: self.current_diagnostic_severity(), is_unnecessary: self.current_code_is_unnecessary(), + tabs, + chars: chars_map, ..Chunk::default() }) } else { @@ -4803,6 +4989,9 @@ impl operation_queue::Operation for Operation { } | Operation::UpdateCompletionTriggers { lamport_timestamp, .. + } + | Operation::UpdateLineEnding { + lamport_timestamp, .. 
} => *lamport_timestamp, } } @@ -4959,11 +5148,12 @@ pub(crate) fn contiguous_ranges( std::iter::from_fn(move || { loop { if let Some(value) = values.next() { - if let Some(range) = &mut current_range { - if value == range.end && range.len() < max_len { - range.end += 1; - continue; - } + if let Some(range) = &mut current_range + && value == range.end + && range.len() < max_len + { + range.end += 1; + continue; } let prev_range = current_range.clone(); @@ -5031,10 +5221,10 @@ impl CharClassifier { } else { scope.word_characters() }; - if let Some(characters) = characters { - if characters.contains(&c) { - return CharKind::Word; - } + if let Some(characters) = characters + && characters.contains(&c) + { + return CharKind::Word; } } diff --git a/crates/language/src/buffer_tests.rs b/crates/language/src/buffer_tests.rs index 2e2df7e658596daaca3b338ef830794fd0d3bef8..fcd93390c891f1d65b2f424a5bc70cd7f23c7912 100644 --- a/crates/language/src/buffer_tests.rs +++ b/crates/language/src/buffer_tests.rs @@ -67,6 +67,78 @@ fn test_line_endings(cx: &mut gpui::App) { }); } +#[gpui::test] +fn test_set_line_ending(cx: &mut TestAppContext) { + let base = cx.new(|cx| Buffer::local("one\ntwo\nthree\n", cx)); + let base_replica = cx.new(|cx| { + Buffer::from_proto(1, Capability::ReadWrite, base.read(cx).to_proto(cx), None).unwrap() + }); + base.update(cx, |_buffer, cx| { + cx.subscribe(&base_replica, |this, _, event, cx| { + if let BufferEvent::Operation { + operation, + is_local: true, + } = event + { + this.apply_ops([operation.clone()], cx); + } + }) + .detach(); + }); + base_replica.update(cx, |_buffer, cx| { + cx.subscribe(&base, |this, _, event, cx| { + if let BufferEvent::Operation { + operation, + is_local: true, + } = event + { + this.apply_ops([operation.clone()], cx); + } + }) + .detach(); + }); + + // Base + base_replica.read_with(cx, |buffer, _| { + assert_eq!(buffer.line_ending(), LineEnding::Unix); + }); + base.update(cx, |buffer, cx| { + 
assert_eq!(buffer.line_ending(), LineEnding::Unix); + buffer.set_line_ending(LineEnding::Windows, cx); + assert_eq!(buffer.line_ending(), LineEnding::Windows); + }); + base_replica.read_with(cx, |buffer, _| { + assert_eq!(buffer.line_ending(), LineEnding::Windows); + }); + base.update(cx, |buffer, cx| { + buffer.set_line_ending(LineEnding::Unix, cx); + assert_eq!(buffer.line_ending(), LineEnding::Unix); + }); + base_replica.read_with(cx, |buffer, _| { + assert_eq!(buffer.line_ending(), LineEnding::Unix); + }); + + // Replica + base.read_with(cx, |buffer, _| { + assert_eq!(buffer.line_ending(), LineEnding::Unix); + }); + base_replica.update(cx, |buffer, cx| { + assert_eq!(buffer.line_ending(), LineEnding::Unix); + buffer.set_line_ending(LineEnding::Windows, cx); + assert_eq!(buffer.line_ending(), LineEnding::Windows); + }); + base.read_with(cx, |buffer, _| { + assert_eq!(buffer.line_ending(), LineEnding::Windows); + }); + base_replica.update(cx, |buffer, cx| { + buffer.set_line_ending(LineEnding::Unix, cx); + assert_eq!(buffer.line_ending(), LineEnding::Unix); + }); + base.read_with(cx, |buffer, _| { + assert_eq!(buffer.line_ending(), LineEnding::Unix); + }); +} + #[gpui::test] fn test_select_language(cx: &mut App) { init_settings(cx, |_| {}); @@ -707,9 +779,7 @@ async fn test_outline(cx: &mut gpui::TestAppContext) { .unindent(); let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let outline = buffer - .update(cx, |buffer, _| buffer.snapshot().outline(None)) - .unwrap(); + let outline = buffer.update(cx, |buffer, _| buffer.snapshot().outline(None)); assert_eq!( outline @@ -791,9 +861,7 @@ async fn test_outline_nodes_with_newlines(cx: &mut gpui::TestAppContext) { .unindent(); let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let outline = buffer - .update(cx, |buffer, _| buffer.snapshot().outline(None)) - .unwrap(); + let outline = buffer.update(cx, |buffer, _| 
buffer.snapshot().outline(None)); assert_eq!( outline @@ -830,7 +898,7 @@ async fn test_outline_with_extra_context(cx: &mut gpui::TestAppContext) { let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); // extra context nodes are included in the outline. - let outline = snapshot.outline(None).unwrap(); + let outline = snapshot.outline(None); assert_eq!( outline .items @@ -841,7 +909,7 @@ async fn test_outline_with_extra_context(cx: &mut gpui::TestAppContext) { ); // extra context nodes do not appear in breadcrumbs. - let symbols = snapshot.symbols_containing(3, None).unwrap(); + let symbols = snapshot.symbols_containing(3, None); assert_eq!( symbols .iter() @@ -873,9 +941,7 @@ fn test_outline_annotations(cx: &mut App) { .unindent(); let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let outline = buffer - .update(cx, |buffer, _| buffer.snapshot().outline(None)) - .unwrap(); + let outline = buffer.update(cx, |buffer, _| buffer.snapshot().outline(None)); assert_eq!( outline @@ -979,7 +1045,6 @@ async fn test_symbols_containing(cx: &mut gpui::TestAppContext) { ) -> Vec<(String, Range)> { snapshot .symbols_containing(position, None) - .unwrap() .into_iter() .map(|item| { ( @@ -1744,7 +1809,7 @@ fn test_autoindent_block_mode(cx: &mut App) { buffer.edit( [(Point::new(2, 8)..Point::new(2, 8), inserted_text)], Some(AutoindentMode::Block { - original_indent_columns: original_indent_columns.clone(), + original_indent_columns, }), cx, ); @@ -1790,9 +1855,9 @@ fn test_autoindent_block_mode_with_newline(cx: &mut App) { "# .unindent(); buffer.edit( - [(Point::new(2, 0)..Point::new(2, 0), inserted_text.clone())], + [(Point::new(2, 0)..Point::new(2, 0), inserted_text)], Some(AutoindentMode::Block { - original_indent_columns: original_indent_columns.clone(), + original_indent_columns, }), cx, ); @@ -1843,7 +1908,7 @@ fn test_autoindent_block_mode_without_original_indent_columns(cx: &mut App) { buffer.edit( [(Point::new(2, 
0)..Point::new(2, 0), inserted_text)], Some(AutoindentMode::Block { - original_indent_columns: original_indent_columns.clone(), + original_indent_columns, }), cx, ); @@ -2030,7 +2095,7 @@ fn test_autoindent_with_injected_languages(cx: &mut App) { let language_registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); language_registry.add(html_language.clone()); - language_registry.add(javascript_language.clone()); + language_registry.add(javascript_language); cx.new(|cx| { let (text, ranges) = marked_text_ranges( @@ -3013,7 +3078,7 @@ fn test_random_collaboration(cx: &mut App, mut rng: StdRng) { .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) .unwrap_or(10); - let base_text_len = rng.gen_range(0..10); + let base_text_len = rng.random_range(0..10); let base_text = RandomCharIter::new(&mut rng) .take(base_text_len) .collect::(); @@ -3022,7 +3087,7 @@ fn test_random_collaboration(cx: &mut App, mut rng: StdRng) { let network = Arc::new(Mutex::new(Network::new(rng.clone()))); let base_buffer = cx.new(|cx| Buffer::local(base_text.as_str(), cx)); - for i in 0..rng.gen_range(min_peers..=max_peers) { + for i in 0..rng.random_range(min_peers..=max_peers) { let buffer = cx.new(|cx| { let state = base_buffer.read(cx).to_proto(cx); let ops = cx @@ -3035,7 +3100,7 @@ fn test_random_collaboration(cx: &mut App, mut rng: StdRng) { .map(|op| proto::deserialize_operation(op).unwrap()), cx, ); - buffer.set_group_interval(Duration::from_millis(rng.gen_range(0..=200))); + buffer.set_group_interval(Duration::from_millis(rng.random_range(0..=200))); let network = network.clone(); cx.subscribe(&cx.entity(), move |buffer, _, event, _| { if let BufferEvent::Operation { @@ -3066,11 +3131,11 @@ fn test_random_collaboration(cx: &mut App, mut rng: StdRng) { let mut next_diagnostic_id = 0; let mut active_selections = BTreeMap::default(); loop { - let replica_index = rng.gen_range(0..replica_ids.len()); + let replica_index = rng.random_range(0..replica_ids.len()); 
let replica_id = replica_ids[replica_index]; let buffer = &mut buffers[replica_index]; let mut new_buffer = None; - match rng.gen_range(0..100) { + match rng.random_range(0..100) { 0..=29 if mutation_count != 0 => { buffer.update(cx, |buffer, cx| { buffer.start_transaction_at(now); @@ -3082,13 +3147,13 @@ fn test_random_collaboration(cx: &mut App, mut rng: StdRng) { } 30..=39 if mutation_count != 0 => { buffer.update(cx, |buffer, cx| { - if rng.gen_bool(0.2) { + if rng.random_bool(0.2) { log::info!("peer {} clearing active selections", replica_id); active_selections.remove(&replica_id); buffer.remove_active_selections(cx); } else { let mut selections = Vec::new(); - for id in 0..rng.gen_range(1..=5) { + for id in 0..rng.random_range(1..=5) { let range = buffer.random_byte_range(0, &mut rng); selections.push(Selection { id, @@ -3111,7 +3176,7 @@ fn test_random_collaboration(cx: &mut App, mut rng: StdRng) { mutation_count -= 1; } 40..=49 if mutation_count != 0 && replica_id == 0 => { - let entry_count = rng.gen_range(1..=5); + let entry_count = rng.random_range(1..=5); buffer.update(cx, |buffer, cx| { let diagnostics = DiagnosticSet::new( (0..entry_count).map(|_| { @@ -3166,7 +3231,7 @@ fn test_random_collaboration(cx: &mut App, mut rng: StdRng) { new_buffer.replica_id(), new_buffer.text() ); - new_buffer.set_group_interval(Duration::from_millis(rng.gen_range(0..=200))); + new_buffer.set_group_interval(Duration::from_millis(rng.random_range(0..=200))); let network = network.clone(); cx.subscribe(&cx.entity(), move |buffer, _, event, _| { if let BufferEvent::Operation { @@ -3238,7 +3303,7 @@ fn test_random_collaboration(cx: &mut App, mut rng: StdRng) { _ => {} } - now += Duration::from_millis(rng.gen_range(0..=200)); + now += Duration::from_millis(rng.random_range(0..=200)); buffers.extend(new_buffer); for buffer in &buffers { @@ -3320,23 +3385,23 @@ fn test_trailing_whitespace_ranges(mut rng: StdRng) { // Generate a random multi-line string containing // some lines 
with trailing whitespace. let mut text = String::new(); - for _ in 0..rng.gen_range(0..16) { - for _ in 0..rng.gen_range(0..36) { - text.push(match rng.gen_range(0..10) { + for _ in 0..rng.random_range(0..16) { + for _ in 0..rng.random_range(0..36) { + text.push(match rng.random_range(0..10) { 0..=1 => ' ', 3 => '\t', - _ => rng.gen_range('a'..='z'), + _ => rng.random_range('a'..='z'), }); } text.push('\n'); } - match rng.gen_range(0..10) { + match rng.random_range(0..10) { // sometimes remove the last newline 0..=1 => drop(text.pop()), // // sometimes add extra newlines - 2..=3 => text.push_str(&"\n".repeat(rng.gen_range(1..5))), + 2..=3 => text.push_str(&"\n".repeat(rng.random_range(1..5))), _ => {} } @@ -3787,3 +3852,80 @@ fn init_settings(cx: &mut App, f: fn(&mut AllLanguageSettingsContent)) { settings.update_user_settings::(cx, f); }); } + +#[gpui::test(iterations = 100)] +fn test_random_chunk_bitmaps(cx: &mut App, mut rng: StdRng) { + use util::RandomCharIter; + + // Generate random text + let len = rng.random_range(0..10000); + let text = RandomCharIter::new(&mut rng).take(len).collect::(); + + let buffer = cx.new(|cx| Buffer::local(text, cx)); + let snapshot = buffer.read(cx).snapshot(); + + // Get all chunks and verify their bitmaps + let chunks = snapshot.chunks(0..snapshot.len(), false); + + for chunk in chunks { + let chunk_text = chunk.text; + let chars_bitmap = chunk.chars; + let tabs_bitmap = chunk.tabs; + + // Check empty chunks have empty bitmaps + if chunk_text.is_empty() { + assert_eq!( + chars_bitmap, 0, + "Empty chunk should have empty chars bitmap" + ); + assert_eq!(tabs_bitmap, 0, "Empty chunk should have empty tabs bitmap"); + continue; + } + + // Verify that chunk text doesn't exceed 128 bytes + assert!( + chunk_text.len() <= 128, + "Chunk text length {} exceeds 128 bytes", + chunk_text.len() + ); + + // Verify chars bitmap + let char_indices = chunk_text + .char_indices() + .map(|(i, _)| i) + .collect::>(); + + for byte_idx in 
0..chunk_text.len() { + let should_have_bit = char_indices.contains(&byte_idx); + let has_bit = chars_bitmap & (1 << byte_idx) != 0; + + if has_bit != should_have_bit { + eprintln!("Chunk text bytes: {:?}", chunk_text.as_bytes()); + eprintln!("Char indices: {:?}", char_indices); + eprintln!("Chars bitmap: {:#b}", chars_bitmap); + } + + assert_eq!( + has_bit, should_have_bit, + "Chars bitmap mismatch at byte index {} in chunk {:?}. Expected bit: {}, Got bit: {}", + byte_idx, chunk_text, should_have_bit, has_bit + ); + } + + // Verify tabs bitmap + for (byte_idx, byte) in chunk_text.bytes().enumerate() { + let is_tab = byte == b'\t'; + let has_bit = tabs_bitmap & (1 << byte_idx) != 0; + + if has_bit != is_tab { + eprintln!("Chunk text bytes: {:?}", chunk_text.as_bytes()); + eprintln!("Tabs bitmap: {:#b}", tabs_bitmap); + assert_eq!( + has_bit, is_tab, + "Tabs bitmap mismatch at byte index {} in chunk {:?}. Byte: {:?}, Expected bit: {}, Got bit: {}", + byte_idx, chunk_text, byte as char, is_tab, has_bit + ); + } + } + } +} diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index b9933dfcec36f1e8c5cb31271668a25b60020c8a..256f6d45734ec068f1e038fe0d07049bb732e34b 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -44,6 +44,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer, de}; use serde_json::Value; use settings::WorktreeId; use smol::future::FutureExt as _; +use std::num::NonZeroU32; use std::{ any::Any, ffi::OsStr, @@ -59,7 +60,6 @@ use std::{ atomic::{AtomicU64, AtomicUsize, Ordering::SeqCst}, }, }; -use std::{num::NonZeroU32, sync::OnceLock}; use syntax_map::{QueryCursorHandle, SyntaxSnapshot}; use task::RunnableTag; pub use task_context::{ContextLocation, ContextProvider, RunnableRange}; @@ -67,7 +67,10 @@ pub use text_diff::{ DiffOptions, apply_diff_patch, line_diff, text_diff, text_diff_with_options, unified_diff, }; use theme::SyntaxTheme; -pub use toolchain::{LanguageToolchainStore, 
Toolchain, ToolchainList, ToolchainLister}; +pub use toolchain::{ + LanguageToolchainStore, LocalLanguageToolchainStore, Toolchain, ToolchainList, ToolchainLister, + ToolchainMetadata, ToolchainScope, +}; use tree_sitter::{self, Query, QueryCursor, WasmStore, wasmtime}; use util::serde::default_true; @@ -119,8 +122,8 @@ where func(cursor.deref_mut()) } -static NEXT_LANGUAGE_ID: LazyLock = LazyLock::new(Default::default); -static NEXT_GRAMMAR_ID: LazyLock = LazyLock::new(Default::default); +static NEXT_LANGUAGE_ID: AtomicUsize = AtomicUsize::new(0); +static NEXT_GRAMMAR_ID: AtomicUsize = AtomicUsize::new(0); static WASM_ENGINE: LazyLock = LazyLock::new(|| { wasmtime::Engine::new(&wasmtime::Config::new()).expect("Failed to create Wasmtime engine") }); @@ -165,7 +168,6 @@ pub struct CachedLspAdapter { pub adapter: Arc, pub reinstall_attempt_count: AtomicU64, cached_binary: futures::lock::Mutex>, - manifest_name: OnceLock>, } impl Debug for CachedLspAdapter { @@ -201,18 +203,17 @@ impl CachedLspAdapter { adapter, cached_binary: Default::default(), reinstall_attempt_count: AtomicU64::new(0), - manifest_name: Default::default(), }) } pub fn name(&self) -> LanguageServerName { - self.adapter.name().clone() + self.adapter.name() } pub async fn get_language_server_command( self: Arc, delegate: Arc, - toolchains: Arc, + toolchains: Option, binary_options: LanguageServerBinaryOptions, cx: &mut AsyncApp, ) -> Result { @@ -281,21 +282,6 @@ impl CachedLspAdapter { .cloned() .unwrap_or_else(|| language_name.lsp_id()) } - - pub fn manifest_name(&self) -> Option { - self.manifest_name - .get_or_init(|| self.adapter.manifest_name()) - .clone() - } -} - -/// Determines what gets sent out as a workspace folders content -#[derive(Clone, Copy, Debug, PartialEq)] -pub enum WorkspaceFoldersContent { - /// Send out a single entry with the root of the workspace. - WorktreeRoot, - /// Send out a list of subproject roots. 
- SubprojectRoots, } /// [`LspAdapterDelegate`] allows [`LspAdapter]` implementations to interface with the application @@ -327,7 +313,7 @@ pub trait LspAdapter: 'static + Send + Sync { fn get_language_server_command<'a>( self: Arc, delegate: Arc, - toolchains: Arc, + toolchains: Option, binary_options: LanguageServerBinaryOptions, mut cached_binary: futures::lock::MutexGuard<'a, Option>, cx: &'a mut AsyncApp, @@ -344,9 +330,9 @@ pub trait LspAdapter: 'static + Send + Sync { // We only want to cache when we fall back to the global one, // because we don't want to download and overwrite our global one // for each worktree we might have open. - if binary_options.allow_path_lookup { - if let Some(binary) = self.check_if_user_installed(delegate.as_ref(), toolchains, cx).await { - log::info!( + if binary_options.allow_path_lookup + && let Some(binary) = self.check_if_user_installed(delegate.as_ref(), toolchains, cx).await { + log::debug!( "found user-installed language server for {}. path: {:?}, arguments: {:?}", self.name().0, binary.path, @@ -354,7 +340,6 @@ pub trait LspAdapter: 'static + Send + Sync { ); return Ok(binary); } - } anyhow::ensure!(binary_options.allow_binary_download, "downloading language servers disabled"); @@ -402,7 +387,7 @@ pub trait LspAdapter: 'static + Send + Sync { async fn check_if_user_installed( &self, _: &dyn LspAdapterDelegate, - _: Arc, + _: Option, _: &AsyncApp, ) -> Option { None @@ -411,6 +396,7 @@ pub trait LspAdapter: 'static + Send + Sync { async fn fetch_latest_server_version( &self, delegate: &dyn LspAdapterDelegate, + cx: &AsyncApp, ) -> Result>; fn will_fetch_server( @@ -535,7 +521,7 @@ pub trait LspAdapter: 'static + Send + Sync { self: Arc, _: &dyn Fs, _: &Arc, - _: Arc, + _: Option, _cx: &mut AsyncApp, ) -> Result { Ok(serde_json::json!({})) @@ -555,7 +541,6 @@ pub trait LspAdapter: 'static + Send + Sync { _target_language_server_id: LanguageServerName, _: &dyn Fs, _: &Arc, - _: Arc, _cx: &mut AsyncApp, ) -> Result> { 
Ok(None) @@ -587,17 +572,6 @@ pub trait LspAdapter: 'static + Send + Sync { Ok(original) } - /// Determines whether a language server supports workspace folders. - /// - /// And does not trip over itself in the process. - fn workspace_folders_content(&self) -> WorkspaceFoldersContent { - WorkspaceFoldersContent::SubprojectRoots - } - - fn manifest_name(&self) -> Option { - None - } - /// Method only implemented by the default JSON language server adapter. /// Used to provide dynamic reloading of the JSON schemas used to /// provide autocompletion and diagnostics in Zed setting and keybind @@ -616,6 +590,11 @@ pub trait LspAdapter: 'static + Send + Sync { "Not implemented for this adapter. This method should only be called on the default JSON language server adapter" ); } + + /// True for the extension adapter and false otherwise. + fn is_extension(&self) -> bool { + false + } } async fn try_fetch_server_binary( @@ -629,18 +608,18 @@ async fn try_fetch_server_binary } let name = adapter.name(); - log::info!("fetching latest version of language server {:?}", name.0); + log::debug!("fetching latest version of language server {:?}", name.0); delegate.update_status(name.clone(), BinaryStatus::CheckingForUpdate); let latest_version = adapter - .fetch_latest_server_version(delegate.as_ref()) + .fetch_latest_server_version(delegate.as_ref(), cx) .await?; if let Some(binary) = adapter .check_if_version_installed(latest_version.as_ref(), &container_dir, delegate.as_ref()) .await { - log::info!("language server {:?} is already installed", name.0); + log::debug!("language server {:?} is already installed", name.0); delegate.update_status(name.clone(), BinaryStatus::None); Ok(binary) } else { @@ -748,6 +727,9 @@ pub struct LanguageConfig { /// How to soft-wrap long lines of text. #[serde(default)] pub soft_wrap: Option, + /// When set, selections can be wrapped using prefix/suffix pairs on both sides. 
+ #[serde(default)] + pub wrap_characters: Option, /// The name of a Prettier parser that will be used for this language when no file path is available. /// If there's a parser name in the language settings, that will be used instead. #[serde(default)] @@ -951,6 +933,7 @@ impl Default for LanguageConfig { hard_tabs: None, tab_size: None, soft_wrap: None, + wrap_characters: None, prettier_parser_name: None, hidden: false, jsx_tag_auto_close: None, @@ -960,6 +943,18 @@ impl Default for LanguageConfig { } } +#[derive(Clone, Debug, Deserialize, JsonSchema)] +pub struct WrapCharactersConfig { + /// Opening token split into a prefix and suffix. The first caret goes + /// after the prefix (i.e., between prefix and suffix). + pub start_prefix: String, + pub start_suffix: String, + /// Closing token split into a prefix and suffix. The second caret goes + /// after the prefix (i.e., between prefix and suffix). + pub end_prefix: String, + pub end_suffix: String, +} + fn auto_indent_using_last_non_empty_line_default() -> bool { true } @@ -991,11 +986,11 @@ where fn deserialize_regex_vec<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { let sources = Vec::::deserialize(d)?; - let mut regexes = Vec::new(); - for source in sources { - regexes.push(regex::Regex::new(&source).map_err(de::Error::custom)?); - } - Ok(regexes) + sources + .into_iter() + .map(|source| regex::Regex::new(&source)) + .collect::>() + .map_err(de::Error::custom) } fn regex_vec_json_schema(_: &mut SchemaGenerator) -> schemars::Schema { @@ -1061,12 +1056,10 @@ impl<'de> Deserialize<'de> for BracketPairConfig { D: Deserializer<'de>, { let result = Vec::::deserialize(deserializer)?; - let mut brackets = Vec::with_capacity(result.len()); - let mut disabled_scopes_by_bracket_ix = Vec::with_capacity(result.len()); - for entry in result { - brackets.push(entry.bracket_pair); - disabled_scopes_by_bracket_ix.push(entry.not_in); - } + let (brackets, disabled_scopes_by_bracket_ix) = result + .into_iter() + 
.map(|entry| (entry.bracket_pair, entry.not_in)) + .unzip(); Ok(BracketPairConfig { pairs: brackets, @@ -1108,6 +1101,7 @@ pub struct Language { pub(crate) grammar: Option>, pub(crate) context_provider: Option>, pub(crate) toolchain: Option>, + pub(crate) manifest_name: Option, } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] @@ -1263,6 +1257,7 @@ struct InjectionPatternConfig { combined: bool, } +#[derive(Debug)] struct BracketsConfig { query: Query, open_capture_ix: u32, @@ -1318,6 +1313,7 @@ impl Language { }), context_provider: None, toolchain: None, + manifest_name: None, } } @@ -1331,6 +1327,10 @@ impl Language { self } + pub fn with_manifest(mut self, name: Option) -> Self { + self.manifest_name = name; + self + } pub fn with_queries(mut self, queries: LanguageQueries) -> Result { if let Some(query) = queries.highlights { self = self @@ -1400,16 +1400,14 @@ impl Language { let grammar = self.grammar_mut().context("cannot mutate grammar")?; let query = Query::new(&grammar.ts_language, source)?; - let mut extra_captures = Vec::with_capacity(query.capture_names().len()); - - for name in query.capture_names().iter() { - let kind = if *name == "run" { - RunnableCapture::Run - } else { - RunnableCapture::Named(name.to_string().into()) - }; - extra_captures.push(kind); - } + let extra_captures: Vec<_> = query + .capture_names() + .iter() + .map(|&name| match name { + "run" => RunnableCapture::Run, + name => RunnableCapture::Named(name.to_string().into()), + }) + .collect(); grammar.runnable_config = Some(RunnableConfig { extra_captures, @@ -1539,9 +1537,8 @@ impl Language { .map(|ix| { let mut config = BracketsPatternConfig::default(); for setting in query.property_settings(ix) { - match setting.key.as_ref() { - "newline.only" => config.newline_only = true, - _ => {} + if setting.key.as_ref() == "newline.only" { + config.newline_only = true } } config @@ -1764,6 +1761,9 @@ impl Language { pub fn name(&self) -> LanguageName { 
self.config.name.clone() } + pub fn manifest(&self) -> Option<&ManifestName> { + self.manifest_name.as_ref() + } pub fn code_fence_block_name(&self) -> Arc { self.config @@ -1798,10 +1798,10 @@ impl Language { BufferChunks::new(text, range, Some((captures, highlight_maps)), false, None) { let end_offset = offset + chunk.text.len(); - if let Some(highlight_id) = chunk.syntax_highlight_id { - if !highlight_id.is_default() { - result.push((offset..end_offset, highlight_id)); - } + if let Some(highlight_id) = chunk.syntax_highlight_id + && !highlight_id.is_default() + { + result.push((offset..end_offset, highlight_id)); } offset = end_offset; } @@ -1818,11 +1818,11 @@ impl Language { } pub fn set_theme(&self, theme: &SyntaxTheme) { - if let Some(grammar) = self.grammar.as_ref() { - if let Some(highlights_query) = &grammar.highlights_query { - *grammar.highlight_map.lock() = - HighlightMap::new(highlights_query.capture_names(), theme); - } + if let Some(grammar) = self.grammar.as_ref() + && let Some(highlights_query) = &grammar.highlights_query + { + *grammar.highlight_map.lock() = + HighlightMap::new(highlights_query.capture_names(), theme); } } @@ -1852,7 +1852,7 @@ impl Language { impl LanguageScope { pub fn path_suffixes(&self) -> &[String] { - &self.language.path_suffixes() + self.language.path_suffixes() } pub fn language_name(&self) -> LanguageName { @@ -1942,11 +1942,11 @@ impl LanguageScope { .enumerate() .map(move |(ix, bracket)| { let mut is_enabled = true; - if let Some(next_disabled_ix) = disabled_ids.first() { - if ix == *next_disabled_ix as usize { - disabled_ids = &disabled_ids[1..]; - is_enabled = false; - } + if let Some(next_disabled_ix) = disabled_ids.first() + && ix == *next_disabled_ix as usize + { + disabled_ids = &disabled_ids[1..]; + is_enabled = false; } (bracket, is_enabled) }) @@ -2209,7 +2209,7 @@ impl LspAdapter for FakeLspAdapter { async fn check_if_user_installed( &self, _: &dyn LspAdapterDelegate, - _: Arc, + _: Option, _: &AsyncApp, ) 
-> Option { Some(self.language_server_binary.clone()) @@ -2218,7 +2218,7 @@ impl LspAdapter for FakeLspAdapter { fn get_language_server_command<'a>( self: Arc, _: Arc, - _: Arc, + _: Option, _: LanguageServerBinaryOptions, _: futures::lock::MutexGuard<'a, Option>, _: &'a mut AsyncApp, @@ -2229,6 +2229,7 @@ impl LspAdapter for FakeLspAdapter { async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { unreachable!(); } @@ -2274,6 +2275,10 @@ impl LspAdapter for FakeLspAdapter { let label_for_completion = self.label_for_completion.as_ref()?; label_for_completion(item, language) } + + fn is_extension(&self) -> bool { + false + } } fn get_capture_indices(query: &Query, captures: &mut [(&str, &mut Option)]) { diff --git a/crates/language/src/language_registry.rs b/crates/language/src/language_registry.rs index ea988e8098ec2a795e8c0a386b4e162ecd5c89ca..5d9d5529c145a8769d142a1f943b6ae00aeaaeb8 100644 --- a/crates/language/src/language_registry.rs +++ b/crates/language/src/language_registry.rs @@ -1,6 +1,6 @@ use crate::{ CachedLspAdapter, File, Language, LanguageConfig, LanguageId, LanguageMatcher, - LanguageServerName, LspAdapter, PLAIN_TEXT, ToolchainLister, + LanguageServerName, LspAdapter, ManifestName, PLAIN_TEXT, ToolchainLister, language_settings::{ AllLanguageSettingsContent, LanguageSettingsContent, all_language_settings, }, @@ -49,7 +49,7 @@ impl LanguageName { pub fn from_proto(s: String) -> Self { Self(SharedString::from(s)) } - pub fn to_proto(self) -> String { + pub fn to_proto(&self) -> String { self.0.to_string() } pub fn lsp_id(&self) -> String { @@ -172,6 +172,7 @@ pub struct AvailableLanguage { hidden: bool, load: Arc Result + 'static + Send + Sync>, loaded: bool, + manifest_name: Option, } impl AvailableLanguage { @@ -259,6 +260,7 @@ pub struct LoadedLanguage { pub queries: LanguageQueries, pub context_provider: Option>, pub toolchain_provider: Option>, + pub manifest_name: Option, } impl LanguageRegistry { 
@@ -349,12 +351,14 @@ impl LanguageRegistry { config.grammar.clone(), config.matcher.clone(), config.hidden, + None, Arc::new(move || { Ok(LoadedLanguage { config: config.clone(), queries: Default::default(), toolchain_provider: None, context_provider: None, + manifest_name: None, }) }), ) @@ -370,14 +374,23 @@ impl LanguageRegistry { pub fn register_available_lsp_adapter( &self, name: LanguageServerName, - load: impl Fn() -> Arc + 'static + Send + Sync, + adapter: Arc, ) { - self.state.write().available_lsp_adapters.insert( + let mut state = self.state.write(); + + if adapter.is_extension() + && let Some(existing_adapter) = state.all_lsp_adapters.get(&name) + && !existing_adapter.adapter.is_extension() + { + log::warn!( + "not registering extension-provided language server {name:?}, since a builtin language server exists with that name", + ); + return; + } + + state.available_lsp_adapters.insert( name, - Arc::new(move || { - let lsp_adapter = load(); - CachedLspAdapter::new(lsp_adapter) - }), + Arc::new(move || CachedLspAdapter::new(adapter.clone())), ); } @@ -392,13 +405,21 @@ impl LanguageRegistry { Some(load_lsp_adapter()) } - pub fn register_lsp_adapter( - &self, - language_name: LanguageName, - adapter: Arc, - ) -> Arc { - let cached = CachedLspAdapter::new(adapter); + pub fn register_lsp_adapter(&self, language_name: LanguageName, adapter: Arc) { let mut state = self.state.write(); + + if adapter.is_extension() + && let Some(existing_adapter) = state.all_lsp_adapters.get(&adapter.name()) + && !existing_adapter.adapter.is_extension() + { + log::warn!( + "not registering extension-provided language server {:?} for language {language_name:?}, since a builtin language server exists with that name", + adapter.name(), + ); + return; + } + + let cached = CachedLspAdapter::new(adapter); state .lsp_adapters .entry(language_name) @@ -407,8 +428,6 @@ impl LanguageRegistry { state .all_lsp_adapters .insert(cached.name.clone(), cached.clone()); - - cached } /// Register 
a fake language server and adapter @@ -428,7 +447,7 @@ impl LanguageRegistry { let mut state = self.state.write(); state .lsp_adapters - .entry(language_name.clone()) + .entry(language_name) .or_default() .push(adapter.clone()); state.all_lsp_adapters.insert(adapter.name(), adapter); @@ -450,7 +469,7 @@ impl LanguageRegistry { let cached_adapter = CachedLspAdapter::new(Arc::new(adapter)); state .lsp_adapters - .entry(language_name.clone()) + .entry(language_name) .or_default() .push(cached_adapter.clone()); state @@ -487,6 +506,7 @@ impl LanguageRegistry { grammar_name: Option>, matcher: LanguageMatcher, hidden: bool, + manifest_name: Option, load: Arc Result + 'static + Send + Sync>, ) { let state = &mut *self.state.write(); @@ -496,6 +516,7 @@ impl LanguageRegistry { existing_language.grammar = grammar_name; existing_language.matcher = matcher; existing_language.load = load; + existing_language.manifest_name = manifest_name; return; } } @@ -508,6 +529,7 @@ impl LanguageRegistry { load, hidden, loaded: false, + manifest_name, }); state.version += 1; state.reload_count += 1; @@ -575,6 +597,7 @@ impl LanguageRegistry { grammar: language.config.grammar.clone(), matcher: language.config.matcher.clone(), hidden: language.config.hidden, + manifest_name: None, load: Arc::new(|| Err(anyhow!("already loaded"))), loaded: true, }); @@ -765,7 +788,7 @@ impl LanguageRegistry { }; let content_matches = || { - config.first_line_pattern.as_ref().map_or(false, |pattern| { + config.first_line_pattern.as_ref().is_some_and(|pattern| { content .as_ref() .is_some_and(|content| pattern.is_match(content)) @@ -914,10 +937,12 @@ impl LanguageRegistry { Language::new_with_id(id, loaded_language.config, grammar) .with_context_provider(loaded_language.context_provider) .with_toolchain_lister(loaded_language.toolchain_provider) + .with_manifest(loaded_language.manifest_name) .with_queries(loaded_language.queries) } else { Ok(Language::new_with_id(id, loaded_language.config, None) 
.with_context_provider(loaded_language.context_provider) + .with_manifest(loaded_language.manifest_name) .with_toolchain_lister(loaded_language.toolchain_provider)) } } @@ -1092,7 +1117,7 @@ impl LanguageRegistry { use gpui::AppContext as _; let mut state = self.state.write(); - let fake_entry = state.fake_server_entries.get_mut(&name)?; + let fake_entry = state.fake_server_entries.get_mut(name)?; let (server, mut fake_server) = lsp::FakeLanguageServer::new( server_id, binary, @@ -1157,8 +1182,7 @@ impl LanguageRegistryState { soft_wrap: language.config.soft_wrap, auto_indent_on_paste: language.config.auto_indent_on_paste, ..Default::default() - } - .clone(), + }, ); self.languages.push(language); self.version += 1; diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs index 1aae0b2f7e23cc87cdd2f13e55805b566a20b5bb..af9e6edbfa4ed2ef44d7a5789069a83b7db829c7 100644 --- a/crates/language/src/language_settings.rs +++ b/crates/language/src/language_settings.rs @@ -5,10 +5,10 @@ use anyhow::Result; use collections::{FxHashMap, HashMap, HashSet}; use ec4rs::{ Properties as EditorconfigProperties, - property::{FinalNewline, IndentSize, IndentStyle, TabWidth, TrimTrailingWs}, + property::{FinalNewline, IndentSize, IndentStyle, MaxLineLen, TabWidth, TrimTrailingWs}, }; use globset::{Glob, GlobMatcher, GlobSet, GlobSetBuilder}; -use gpui::{App, Modifiers}; +use gpui::{App, Modifiers, SharedString}; use itertools::{Either, Itertools}; use schemars::{JsonSchema, json_schema}; use serde::{ @@ -17,7 +17,8 @@ use serde::{ }; use settings::{ - ParameterizedJsonSchema, Settings, SettingsLocation, SettingsSources, SettingsStore, + ParameterizedJsonSchema, Settings, SettingsKey, SettingsLocation, SettingsSources, + SettingsStore, SettingsUi, }; use shellexpand; use std::{borrow::Cow, num::NonZeroU32, path::Path, slice, sync::Arc}; @@ -122,6 +123,8 @@ pub struct LanguageSettings { pub edit_predictions_disabled_in: Vec, /// Whether to show tabs 
and spaces in the editor. pub show_whitespaces: ShowWhitespaceSetting, + /// Visible characters used to render whitespace when show_whitespaces is enabled. + pub whitespace_map: WhitespaceMap, /// Whether to start a new line with a comment when a previous line is a comment as well. pub extend_comment_on_newline: bool, /// Inlay hint related settings. @@ -133,6 +136,8 @@ pub struct LanguageSettings { /// Whether to use additional LSP queries to format (and amend) the code after /// every "trigger" symbol input, defined by LSP server capabilities. pub use_on_type_format: bool, + /// Whether indentation should be adjusted based on the context whilst typing. + pub auto_indent: bool, /// Whether indentation of pasted content should be adjusted based on the context. pub auto_indent_on_paste: bool, /// Controls how the editor handles the autoclosed characters. @@ -185,8 +190,8 @@ impl LanguageSettings { let rest = available_language_servers .iter() .filter(|&available_language_server| { - !disabled_language_servers.contains(&available_language_server) - && !enabled_language_servers.contains(&available_language_server) + !disabled_language_servers.contains(available_language_server) + && !enabled_language_servers.contains(available_language_server) }) .cloned() .collect::>(); @@ -197,7 +202,7 @@ impl LanguageSettings { if language_server.0.as_ref() == Self::REST_OF_LANGUAGE_SERVERS { rest.clone() } else { - vec![language_server.clone()] + vec![language_server] } }) .collect::>() @@ -205,7 +210,9 @@ impl LanguageSettings { } /// The provider that supplies edit predictions. 
-#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, JsonSchema)] +#[derive( + Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, JsonSchema, SettingsUi, +)] #[serde(rename_all = "snake_case")] pub enum EditPredictionProvider { None, @@ -228,13 +235,14 @@ impl EditPredictionProvider { /// The settings for edit predictions, such as [GitHub Copilot](https://github.com/features/copilot) /// or [Supermaven](https://supermaven.com). -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, SettingsUi)] pub struct EditPredictionSettings { /// The provider that supplies edit predictions. pub provider: EditPredictionProvider, /// A list of globs representing files that edit predictions should be disabled for. /// This list adds to a pre-existing, sensible default set of globs. /// Any additional ones you add are combined with them. + #[settings_ui(skip)] pub disabled_globs: Vec, /// Configures how edit predictions are displayed in the buffer. pub mode: EditPredictionsMode, @@ -251,7 +259,7 @@ impl EditPredictionSettings { !self.disabled_globs.iter().any(|glob| { if glob.is_absolute { file.as_local() - .map_or(false, |local| glob.matcher.is_match(local.abs_path(cx))) + .is_some_and(|local| glob.matcher.is_match(local.abs_path(cx))) } else { glob.matcher.is_match(file.path()) } @@ -266,7 +274,9 @@ pub struct DisabledGlob { } /// The mode in which edit predictions should be displayed. -#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, JsonSchema)] +#[derive( + Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, JsonSchema, SettingsUi, +)] #[serde(rename_all = "snake_case")] pub enum EditPredictionsMode { /// If provider supports it, display inline when holding modifier key (e.g., alt). 
@@ -279,18 +289,24 @@ pub enum EditPredictionsMode { Eager, } -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, SettingsUi)] pub struct CopilotSettings { /// HTTP/HTTPS proxy to use for Copilot. + #[settings_ui(skip)] pub proxy: Option, /// Disable certificate verification for proxy (not recommended). pub proxy_no_verify: Option, /// Enterprise URI for Copilot. + #[settings_ui(skip)] pub enterprise_uri: Option, } /// The settings for all languages. -#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema, SettingsUi, SettingsKey, +)] +#[settings_key(None)] +#[settings_ui(group = "Default Language Settings")] pub struct AllLanguageSettingsContent { /// The settings for enabling/disabling features. #[serde(default)] @@ -307,6 +323,7 @@ pub struct AllLanguageSettingsContent { /// Settings for associating file extensions and filenames /// with languages. #[serde(default)] + #[settings_ui(skip)] pub file_types: HashMap, Vec>, } @@ -315,6 +332,37 @@ pub struct AllLanguageSettingsContent { #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema)] pub struct LanguageToSettingsMap(pub HashMap); +impl SettingsUi for LanguageToSettingsMap { + fn settings_ui_item() -> settings::SettingsUiItem { + settings::SettingsUiItem::DynamicMap(settings::SettingsUiItemDynamicMap { + item: LanguageSettingsContent::settings_ui_item, + defaults_path: &[], + determine_items: |settings_value, cx| { + use settings::SettingsUiEntryMetaData; + + // todo(settings_ui): We should be using a global LanguageRegistry, but it's not implemented yet + _ = cx; + + let Some(settings_language_map) = settings_value.as_object() else { + return Vec::new(); + }; + let mut languages = Vec::with_capacity(settings_language_map.len()); + + for language_name in settings_language_map.keys().map(gpui::SharedString::from) { + languages.push(SettingsUiEntryMetaData { + title: 
language_name.clone(), + path: language_name, + // todo(settings_ui): Implement documentation for each language + // ideally based on the language's official docs from extension or builtin info + documentation: None, + }); + } + return languages; + }, + }) + } +} + inventory::submit! { ParameterizedJsonSchema { add_and_get_ref: |generator, params, _cx| { @@ -339,7 +387,7 @@ inventory::submit! { } /// Controls how completions are processed for this language. -#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema, SettingsUi)] #[serde(rename_all = "snake_case")] pub struct CompletionSettings { /// Controls how words are completed. @@ -348,6 +396,12 @@ pub struct CompletionSettings { /// Default: `fallback` #[serde(default = "default_words_completion_mode")] pub words: WordsCompletionMode, + /// How many characters has to be in the completions query to automatically show the words-based completions. + /// Before that value, it's still possible to trigger the words-based completion manually with the corresponding editor command. + /// + /// Default: 3 + #[serde(default = "default_3")] + pub words_min_length: usize, /// Whether to fetch LSP completions or not. /// /// Default: true @@ -357,7 +411,7 @@ pub struct CompletionSettings { /// When set to 0, waits indefinitely. /// /// Default: 0 - #[serde(default = "default_lsp_fetch_timeout_ms")] + #[serde(default)] pub lsp_fetch_timeout_ms: u64, /// Controls how LSP completions are inserted. /// @@ -403,17 +457,19 @@ fn default_lsp_insert_mode() -> LspInsertMode { LspInsertMode::ReplaceSuffix } -fn default_lsp_fetch_timeout_ms() -> u64 { - 0 +fn default_3() -> usize { + 3 } /// The settings for a particular language. 
-#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema)] +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema, SettingsUi)] +#[settings_ui(group = "Default")] pub struct LanguageSettingsContent { /// How many columns a tab should occupy. /// /// Default: 4 #[serde(default)] + #[settings_ui(skip)] pub tab_size: Option, /// Whether to indent lines using tab characters, as opposed to multiple /// spaces. @@ -444,6 +500,7 @@ pub struct LanguageSettingsContent { /// /// Default: [] #[serde(default)] + #[settings_ui(skip)] pub wrap_guides: Option>, /// Indent guide related settings. #[serde(default)] @@ -469,6 +526,7 @@ pub struct LanguageSettingsContent { /// /// Default: auto #[serde(default)] + #[settings_ui(skip)] pub formatter: Option, /// Zed's Prettier integration settings. /// Allows to enable/disable formatting with Prettier @@ -494,6 +552,7 @@ pub struct LanguageSettingsContent { /// /// Default: ["..."] #[serde(default)] + #[settings_ui(skip)] pub language_servers: Option>, /// Controls where the `editor::Rewrap` action is allowed for this language. /// @@ -516,10 +575,16 @@ pub struct LanguageSettingsContent { /// /// Default: [] #[serde(default)] + #[settings_ui(skip)] pub edit_predictions_disabled_in: Option>, /// Whether to show tabs and spaces in the editor. #[serde(default)] pub show_whitespaces: Option, + /// Visible characters used to render whitespace when show_whitespaces is enabled. + /// + /// Default: "•" for spaces, "→" for tabs. + #[serde(default)] + pub whitespace_map: Option, /// Whether to start a new line with a comment when a previous line is a comment as well. /// /// Default: true @@ -555,12 +620,17 @@ pub struct LanguageSettingsContent { /// These are not run if formatting is off. /// /// Default: {} (or {"source.organizeImports": true} for Go). 
+ #[settings_ui(skip)] pub code_actions_on_format: Option>, /// Whether to perform linked edits of associated ranges, if the language server supports it. /// For example, when editing opening tag, the contents of the closing tag will be edited as well. /// /// Default: true pub linked_edits: Option, + /// Whether indentation should be adjusted based on the context whilst typing. + /// + /// Default: true + pub auto_indent: Option, /// Whether indentation of pasted content should be adjusted based on the context. /// /// Default: true @@ -584,11 +654,14 @@ pub struct LanguageSettingsContent { /// Preferred debuggers for this language. /// /// Default: [] + #[settings_ui(skip)] pub debuggers: Option>, } /// The behavior of `editor::Rewrap`. -#[derive(Debug, PartialEq, Clone, Copy, Default, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Clone, Copy, Default, Serialize, Deserialize, JsonSchema, SettingsUi, +)] #[serde(rename_all = "snake_case")] pub enum RewrapBehavior { /// Only rewrap within comments. @@ -601,12 +674,13 @@ pub enum RewrapBehavior { } /// The contents of the edit prediction settings. -#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq)] +#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq, SettingsUi)] pub struct EditPredictionSettingsContent { /// A list of globs representing files that edit predictions should be disabled for. /// This list adds to a pre-existing, sensible default set of globs. /// Any additional ones you add are combined with them. #[serde(default)] + #[settings_ui(skip)] pub disabled_globs: Option>, /// The mode used to display edit predictions in the buffer. /// Provider support required. 
@@ -621,12 +695,13 @@ pub struct EditPredictionSettingsContent { pub enabled_in_text_threads: bool, } -#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq)] +#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq, SettingsUi)] pub struct CopilotSettingsContent { /// HTTP/HTTPS proxy to use for Copilot. /// /// Default: none #[serde(default)] + #[settings_ui(skip)] pub proxy: Option, /// Disable certificate verification for the proxy (not recommended). /// @@ -637,19 +712,21 @@ pub struct CopilotSettingsContent { /// /// Default: none #[serde(default)] + #[settings_ui(skip)] pub enterprise_uri: Option, } /// The settings for enabling/disabling features. -#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize, JsonSchema)] +#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize, JsonSchema, SettingsUi)] #[serde(rename_all = "snake_case")] +#[settings_ui(group = "Features")] pub struct FeaturesContent { /// Determines which edit prediction provider to use. pub edit_prediction_provider: Option, } /// Controls the soft-wrapping behavior in the editor. -#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema, SettingsUi)] #[serde(rename_all = "snake_case")] pub enum SoftWrap { /// Prefer a single line generally, unless an overly long line is encountered. @@ -666,7 +743,7 @@ pub enum SoftWrap { } /// Controls the behavior of formatting files when they are saved. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, SettingsUi)] pub enum FormatOnSave { /// Files should be formatted on save. On, @@ -765,7 +842,7 @@ impl<'de> Deserialize<'de> for FormatOnSave { } /// Controls how whitespace should be displayedin the editor. 
-#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema, SettingsUi)] #[serde(rename_all = "snake_case")] pub enum ShowWhitespaceSetting { /// Draw whitespace only for the selected text. @@ -785,8 +862,30 @@ pub enum ShowWhitespaceSetting { Trailing, } +#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq, SettingsUi)] +pub struct WhitespaceMap { + #[serde(default)] + pub space: Option, + #[serde(default)] + pub tab: Option, +} + +impl WhitespaceMap { + pub fn space(&self) -> SharedString { + self.space + .as_ref() + .map_or_else(|| SharedString::from("•"), |s| SharedString::from(s)) + } + + pub fn tab(&self) -> SharedString { + self.tab + .as_ref() + .map_or_else(|| SharedString::from("→"), |s| SharedString::from(s)) + } +} + /// Controls which formatter should be used when formatting code. -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Clone, Debug, Default, PartialEq, Eq, SettingsUi)] pub enum SelectedFormatter { /// Format files using Zed's Prettier integration (if applicable), /// or falling back to formatting via language server. @@ -882,11 +981,17 @@ impl<'de> Deserialize<'de> for SelectedFormatter { } /// Controls which formatters should be used when formatting code. -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema, SettingsUi)] #[serde(untagged)] pub enum FormatterList { Single(Formatter), - Vec(Vec), + Vec(#[settings_ui(skip)] Vec), +} + +impl Default for FormatterList { + fn default() -> Self { + Self::Single(Formatter::default()) + } } impl AsRef<[Formatter]> for FormatterList { @@ -899,26 +1004,34 @@ impl AsRef<[Formatter]> for FormatterList { } /// Controls which formatter should be used when formatting code. If there are multiple formatters, they are executed in the order of declaration. 
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] +#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema, SettingsUi)] #[serde(rename_all = "snake_case")] pub enum Formatter { /// Format code using the current language server. - LanguageServer { name: Option }, + LanguageServer { + #[settings_ui(skip)] + name: Option, + }, /// Format code using Zed's Prettier integration. + #[default] Prettier, /// Format code using an external command. External { /// The external program to run. + #[settings_ui(skip)] command: Arc, /// The arguments to pass to the program. + #[settings_ui(skip)] arguments: Option>, }, /// Files should be formatted using code actions executed by language servers. - CodeActions(HashMap), + CodeActions(#[settings_ui(skip)] HashMap), } /// The settings for indent guides. -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[derive( + Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, SettingsUi, +)] pub struct IndentGuideSettings { /// Whether to display indent guides in the editor. /// @@ -980,7 +1093,7 @@ pub enum IndentGuideBackgroundColoring { } /// The settings for inlay hints. -#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq, SettingsUi)] pub struct InlayHintSettings { /// Global switch to toggle hints on and off. /// @@ -1047,7 +1160,7 @@ fn scroll_debounce_ms() -> u64 { } /// The task settings for a particular language. -#[derive(Debug, Clone, Deserialize, PartialEq, Serialize, JsonSchema)] +#[derive(Debug, Clone, Deserialize, PartialEq, Serialize, JsonSchema, SettingsUi)] pub struct LanguageTaskConfig { /// Extra task variables to set for a particular language. 
#[serde(default)] @@ -1125,6 +1238,10 @@ impl AllLanguageSettings { } fn merge_with_editorconfig(settings: &mut LanguageSettings, cfg: &EditorconfigProperties) { + let preferred_line_length = cfg.get::().ok().and_then(|v| match v { + MaxLineLen::Value(u) => Some(u as u32), + MaxLineLen::Off => None, + }); let tab_size = cfg.get::().ok().and_then(|v| match v { IndentSize::Value(u) => NonZeroU32::new(u as u32), IndentSize::UseTabWidth => cfg.get::().ok().and_then(|w| match w { @@ -1152,6 +1269,7 @@ fn merge_with_editorconfig(settings: &mut LanguageSettings, cfg: &EditorconfigPr *target = value; } } + merge(&mut settings.preferred_line_length, preferred_line_length); merge(&mut settings.tab_size, tab_size); merge(&mut settings.hard_tabs, hard_tabs); merge( @@ -1196,8 +1314,6 @@ impl InlayHintKind { } impl settings::Settings for AllLanguageSettings { - const KEY: Option<&'static str> = None; - type FileContent = AllLanguageSettingsContent; fn load(sources: SettingsSources, _: &mut App) -> Result { @@ -1457,6 +1573,7 @@ impl settings::Settings for AllLanguageSettings { } else { d.completions = Some(CompletionSettings { words: mode, + words_min_length: 3, lsp: true, lsp_fetch_timeout_ms: 0, lsp_insert_mode: LspInsertMode::ReplaceSuffix, @@ -1517,6 +1634,7 @@ fn merge_settings(settings: &mut LanguageSettings, src: &LanguageSettingsContent merge(&mut settings.use_autoclose, src.use_autoclose); merge(&mut settings.use_auto_surround, src.use_auto_surround); merge(&mut settings.use_on_type_format, src.use_on_type_format); + merge(&mut settings.auto_indent, src.auto_indent); merge(&mut settings.auto_indent_on_paste, src.auto_indent_on_paste); merge( &mut settings.always_treat_brackets_as_autoclosed, @@ -1566,6 +1684,7 @@ fn merge_settings(settings: &mut LanguageSettings, src: &LanguageSettingsContent src.edit_predictions_disabled_in.clone(), ); merge(&mut settings.show_whitespaces, src.show_whitespaces); + merge(&mut settings.whitespace_map, src.whitespace_map.clone()); merge( 
&mut settings.extend_comment_on_newline, src.extend_comment_on_newline, @@ -1585,7 +1704,7 @@ fn merge_settings(settings: &mut LanguageSettings, src: &LanguageSettingsContent /// Allows to enable/disable formatting with Prettier /// and configure default Prettier, used when no project-level Prettier installation is found. /// Prettier formatting is disabled by default. -#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, SettingsUi)] pub struct PrettierSettings { /// Enables or disables formatting with Prettier for a given language. #[serde(default)] @@ -1598,15 +1717,17 @@ pub struct PrettierSettings { /// Forces Prettier integration to use specific plugins when formatting files with the language. /// The default Prettier will be installed with these plugins. #[serde(default)] + #[settings_ui(skip)] pub plugins: HashSet, /// Default Prettier options, in the format as in package.json section for Prettier. /// If project installs Prettier via its package.json, these options will be ignored. #[serde(flatten)] + #[settings_ui(skip)] pub options: HashMap, } -#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, SettingsUi)] pub struct JsxTagAutoCloseSettings { /// Enables or disables auto-closing of JSX tags. 
#[serde(default)] @@ -1786,7 +1907,7 @@ mod tests { assert!(!settings.enabled_for_file(&dot_env_file, &cx)); // Test tilde expansion - let home = shellexpand::tilde("~").into_owned().to_string(); + let home = shellexpand::tilde("~").into_owned(); let home_file = make_test_file(&[&home, "test.rs"]); let settings = build_settings(&["~/test.rs"]); assert!(!settings.enabled_for_file(&home_file, &cx)); diff --git a/crates/language/src/manifest.rs b/crates/language/src/manifest.rs index 37505fec3b233c2ecd7e2ac7807a7ade6a9b3d4a..3ca0ddf71da20f69d5d6440189d4a656bfbe7c9d 100644 --- a/crates/language/src/manifest.rs +++ b/crates/language/src/manifest.rs @@ -12,6 +12,12 @@ impl Borrow for ManifestName { } } +impl Borrow for ManifestName { + fn borrow(&self) -> &str { + &self.0 + } +} + impl From for ManifestName { fn from(value: SharedString) -> Self { Self(value) diff --git a/crates/language/src/proto.rs b/crates/language/src/proto.rs index 18f6bb8709c707af9dd19223cac30d6728eda160..bc85b10859632fc3e2cf61c663b7159a023f4f3a 100644 --- a/crates/language/src/proto.rs +++ b/crates/language/src/proto.rs @@ -86,10 +86,19 @@ pub fn serialize_operation(operation: &crate::Operation) -> proto::Operation { proto::operation::UpdateCompletionTriggers { replica_id: lamport_timestamp.replica_id as u32, lamport_timestamp: lamport_timestamp.value, - triggers: triggers.iter().cloned().collect(), + triggers: triggers.clone(), language_server_id: server_id.to_proto(), }, ), + + crate::Operation::UpdateLineEnding { + line_ending, + lamport_timestamp, + } => proto::operation::Variant::UpdateLineEnding(proto::operation::UpdateLineEnding { + replica_id: lamport_timestamp.replica_id as u32, + lamport_timestamp: lamport_timestamp.value, + line_ending: serialize_line_ending(*line_ending) as i32, + }), }), } } @@ -341,6 +350,18 @@ pub fn deserialize_operation(message: proto::Operation) -> Result { + crate::Operation::UpdateLineEnding { + lamport_timestamp: clock::Lamport { + replica_id: 
message.replica_id as ReplicaId, + value: message.lamport_timestamp, + }, + line_ending: deserialize_line_ending( + proto::LineEnding::from_i32(message.line_ending) + .context("missing line_ending")?, + ), + } + } }, ) } @@ -385,12 +406,10 @@ pub fn deserialize_undo_map_entry( /// Deserializes selections from the RPC representation. pub fn deserialize_selections(selections: Vec) -> Arc<[Selection]> { - Arc::from( - selections - .into_iter() - .filter_map(deserialize_selection) - .collect::>(), - ) + selections + .into_iter() + .filter_map(deserialize_selection) + .collect() } /// Deserializes a [`Selection`] from the RPC representation. @@ -433,7 +452,7 @@ pub fn deserialize_diagnostics( code: diagnostic.code.map(lsp::NumberOrString::from_string), code_description: diagnostic .code_description - .and_then(|s| lsp::Url::parse(&s).ok()), + .and_then(|s| lsp::Uri::from_str(&s).ok()), is_primary: diagnostic.is_primary, is_disk_based: diagnostic.is_disk_based, is_unnecessary: diagnostic.is_unnecessary, @@ -498,6 +517,10 @@ pub fn lamport_timestamp_for_operation(operation: &proto::Operation) -> Option { + replica_id = op.replica_id; + value = op.lamport_timestamp; + } } Some(clock::Lamport { diff --git a/crates/language/src/syntax_map.rs b/crates/language/src/syntax_map.rs index c56ffed0663a9419419201f902f3db8311acb9bd..38aad007fe16c655a3802bd70c9b709cbe83ea68 100644 --- a/crates/language/src/syntax_map.rs +++ b/crates/language/src/syntax_map.rs @@ -414,42 +414,42 @@ impl SyntaxSnapshot { .collect::>(); self.reparse_with_ranges(text, root_language.clone(), edit_ranges, registry.as_ref()); - if let Some(registry) = registry { - if registry.version() != self.language_registry_version { - let mut resolved_injection_ranges = Vec::new(); - let mut cursor = self - .layers - .filter::<_, ()>(text, |summary| summary.contains_unknown_injections); - cursor.next(); - while let Some(layer) = cursor.item() { - let SyntaxLayerContent::Pending { language_name } = &layer.content else { 
- unreachable!() - }; - if registry - .language_for_name_or_extension(language_name) - .now_or_never() - .and_then(|language| language.ok()) - .is_some() - { - let range = layer.range.to_offset(text); - log::trace!("reparse range {range:?} for language {language_name:?}"); - resolved_injection_ranges.push(range); - } - - cursor.next(); - } - drop(cursor); - - if !resolved_injection_ranges.is_empty() { - self.reparse_with_ranges( - text, - root_language, - resolved_injection_ranges, - Some(®istry), - ); + if let Some(registry) = registry + && registry.version() != self.language_registry_version + { + let mut resolved_injection_ranges = Vec::new(); + let mut cursor = self + .layers + .filter::<_, ()>(text, |summary| summary.contains_unknown_injections); + cursor.next(); + while let Some(layer) = cursor.item() { + let SyntaxLayerContent::Pending { language_name } = &layer.content else { + unreachable!() + }; + if registry + .language_for_name_or_extension(language_name) + .now_or_never() + .and_then(|language| language.ok()) + .is_some() + { + let range = layer.range.to_offset(text); + log::trace!("reparse range {range:?} for language {language_name:?}"); + resolved_injection_ranges.push(range); } - self.language_registry_version = registry.version(); + + cursor.next(); } + drop(cursor); + + if !resolved_injection_ranges.is_empty() { + self.reparse_with_ranges( + text, + root_language, + resolved_injection_ranges, + Some(®istry), + ); + } + self.language_registry_version = registry.version(); } self.update_count += 1; @@ -832,7 +832,7 @@ impl SyntaxSnapshot { query: fn(&Grammar) -> Option<&Query>, ) -> SyntaxMapCaptures<'a> { SyntaxMapCaptures::new( - range.clone(), + range, text, [SyntaxLayer { language, @@ -1065,10 +1065,10 @@ impl<'a> SyntaxMapCaptures<'a> { pub fn set_byte_range(&mut self, range: Range) { for layer in &mut self.layers { layer.captures.set_byte_range(range.clone()); - if let Some(capture) = &layer.next_capture { - if capture.node.end_byte() > 
range.start { - continue; - } + if let Some(capture) = &layer.next_capture + && capture.node.end_byte() > range.start + { + continue; } layer.advance(); } @@ -1277,11 +1277,11 @@ fn join_ranges( (None, None) => break, }; - if let Some(last) = result.last_mut() { - if range.start <= last.end { - last.end = last.end.max(range.end); - continue; - } + if let Some(last) = result.last_mut() + && range.start <= last.end + { + last.end = last.end.max(range.end); + continue; } result.push(range); } @@ -1297,7 +1297,7 @@ fn parse_text( ) -> anyhow::Result { with_parser(|parser| { let mut chunks = text.chunks_in_range(start_byte..text.len()); - parser.set_included_ranges(&ranges)?; + parser.set_included_ranges(ranges)?; parser.set_language(&grammar.ts_language)?; parser .parse_with_options( @@ -1330,14 +1330,13 @@ fn get_injections( // if there currently no matches for that injection. combined_injection_ranges.clear(); for pattern in &config.patterns { - if let (Some(language_name), true) = (pattern.language.as_ref(), pattern.combined) { - if let Some(language) = language_registry + if let (Some(language_name), true) = (pattern.language.as_ref(), pattern.combined) + && let Some(language) = language_registry .language_for_name_or_extension(language_name) .now_or_never() .and_then(|language| language.ok()) - { - combined_injection_ranges.insert(language.id, (language, Vec::new())); - } + { + combined_injection_ranges.insert(language.id, (language, Vec::new())); } } @@ -1357,10 +1356,11 @@ fn get_injections( content_ranges.first().unwrap().start_byte..content_ranges.last().unwrap().end_byte; // Avoid duplicate matches if two changed ranges intersect the same injection. 
- if let Some((prev_pattern_ix, prev_range)) = &prev_match { - if mat.pattern_index == *prev_pattern_ix && content_range == *prev_range { - continue; - } + if let Some((prev_pattern_ix, prev_range)) = &prev_match + && mat.pattern_index == *prev_pattern_ix + && content_range == *prev_range + { + continue; } prev_match = Some((mat.pattern_index, content_range.clone())); @@ -1630,10 +1630,8 @@ impl<'a> SyntaxLayer<'a> { if offset < range.start || offset > range.end { continue; } - } else { - if offset <= range.start || offset >= range.end { - continue; - } + } else if offset <= range.start || offset >= range.end { + continue; } if let Some((_, smallest_range)) = &smallest_match { diff --git a/crates/language/src/syntax_map/syntax_map_tests.rs b/crates/language/src/syntax_map/syntax_map_tests.rs index d576c95cd58eb823a7f8bdfdc42be9ba6a743410..622731b7814ce16bfcc026b6723e80d5ba4dda7a 100644 --- a/crates/language/src/syntax_map/syntax_map_tests.rs +++ b/crates/language/src/syntax_map/syntax_map_tests.rs @@ -58,8 +58,7 @@ fn test_splice_included_ranges() { assert_eq!(change, 0..1); // does not create overlapping ranges - let (new_ranges, change) = - splice_included_ranges(ranges.clone(), &[0..18], &[ts_range(20..32)]); + let (new_ranges, change) = splice_included_ranges(ranges, &[0..18], &[ts_range(20..32)]); assert_eq!( new_ranges, &[ts_range(20..32), ts_range(50..60), ts_range(80..90)] @@ -104,7 +103,7 @@ fn test_syntax_map_layers_for_range(cx: &mut App) { ); let mut syntax_map = SyntaxMap::new(&buffer); - syntax_map.set_language_registry(registry.clone()); + syntax_map.set_language_registry(registry); syntax_map.reparse(language.clone(), &buffer); assert_layers_for_range( @@ -165,7 +164,7 @@ fn test_syntax_map_layers_for_range(cx: &mut App) { // Put the vec! macro back, adding back the syntactic layer. 
buffer.undo(); syntax_map.interpolate(&buffer); - syntax_map.reparse(language.clone(), &buffer); + syntax_map.reparse(language, &buffer); assert_layers_for_range( &syntax_map, @@ -252,8 +251,8 @@ fn test_dynamic_language_injection(cx: &mut App) { assert!(syntax_map.contains_unknown_injections()); registry.add(Arc::new(html_lang())); - syntax_map.reparse(markdown.clone(), &buffer); - syntax_map.reparse(markdown_inline.clone(), &buffer); + syntax_map.reparse(markdown, &buffer); + syntax_map.reparse(markdown_inline, &buffer); assert_layers_for_range( &syntax_map, &buffer, @@ -862,7 +861,7 @@ fn test_syntax_map_languages_loading_with_erb(cx: &mut App) { log::info!("editing"); buffer.edit_via_marked_text(&text); syntax_map.interpolate(&buffer); - syntax_map.reparse(language.clone(), &buffer); + syntax_map.reparse(language, &buffer); assert_capture_ranges( &syntax_map, @@ -986,7 +985,7 @@ fn test_random_edits( syntax_map.reparse(language.clone(), &buffer); let mut reference_syntax_map = SyntaxMap::new(&buffer); - reference_syntax_map.set_language_registry(registry.clone()); + reference_syntax_map.set_language_registry(registry); log::info!("initial text:\n{}", buffer.text()); diff --git a/crates/language/src/text_diff.rs b/crates/language/src/text_diff.rs index f9221f571afb1baa0ba0b824922e799fcec01c88..11d8a070d213852f0a98078f2ed8c76c9cced47b 100644 --- a/crates/language/src/text_diff.rs +++ b/crates/language/src/text_diff.rs @@ -88,11 +88,11 @@ pub fn text_diff_with_options( let new_offset = new_byte_range.start; hunk_input.clear(); hunk_input.update_before(tokenize( - &old_text[old_byte_range.clone()], + &old_text[old_byte_range], options.language_scope.clone(), )); hunk_input.update_after(tokenize( - &new_text[new_byte_range.clone()], + &new_text[new_byte_range], options.language_scope.clone(), )); diff_internal(&hunk_input, |old_byte_range, new_byte_range, _, _| { @@ -103,7 +103,7 @@ pub fn text_diff_with_options( let replacement_text = if new_byte_range.is_empty() { 
empty.clone() } else { - new_text[new_byte_range.clone()].into() + new_text[new_byte_range].into() }; edits.push((old_byte_range, replacement_text)); }); @@ -111,9 +111,9 @@ pub fn text_diff_with_options( let replacement_text = if new_byte_range.is_empty() { empty.clone() } else { - new_text[new_byte_range.clone()].into() + new_text[new_byte_range].into() }; - edits.push((old_byte_range.clone(), replacement_text)); + edits.push((old_byte_range, replacement_text)); } }, ); @@ -154,19 +154,19 @@ fn diff_internal( input, |old_tokens: Range, new_tokens: Range| { old_offset += token_len( - &input, + input, &input.before[old_token_ix as usize..old_tokens.start as usize], ); new_offset += token_len( - &input, + input, &input.after[new_token_ix as usize..new_tokens.start as usize], ); let old_len = token_len( - &input, + input, &input.before[old_tokens.start as usize..old_tokens.end as usize], ); let new_len = token_len( - &input, + input, &input.after[new_tokens.start as usize..new_tokens.end as usize], ); let old_byte_range = old_offset..old_offset + old_len; @@ -186,14 +186,14 @@ fn tokenize(text: &str, language_scope: Option) -> impl Iterator< let mut prev = None; let mut start_ix = 0; iter::from_fn(move || { - while let Some((ix, c)) = chars.next() { + for (ix, c) in chars.by_ref() { let mut token = None; let kind = classifier.kind(c); - if let Some((prev_char, prev_kind)) = prev { - if kind != prev_kind || (kind == CharKind::Punctuation && c != prev_char) { - token = Some(&text[start_ix..ix]); - start_ix = ix; - } + if let Some((prev_char, prev_kind)) = prev + && (kind != prev_kind || (kind == CharKind::Punctuation && c != prev_char)) + { + token = Some(&text[start_ix..ix]); + start_ix = ix; } prev = Some((c, kind)); if token.is_some() { diff --git a/crates/language/src/toolchain.rs b/crates/language/src/toolchain.rs index 1f4b038f68e5fcf1ed5c499d543fa92ba3c2de94..2cc86881fbd515317d4d6f5949e82eb3da63a1bb 100644 --- a/crates/language/src/toolchain.rs +++ 
b/crates/language/src/toolchain.rs @@ -11,13 +11,15 @@ use std::{ use async_trait::async_trait; use collections::HashMap; +use fs::Fs; use gpui::{AsyncApp, SharedString}; use settings::WorktreeId; +use task::ShellKind; use crate::{LanguageName, ManifestName}; /// Represents a single toolchain. -#[derive(Clone, Debug)] +#[derive(Clone, Eq, Debug)] pub struct Toolchain { /// User-facing label pub name: SharedString, @@ -27,30 +29,104 @@ pub struct Toolchain { pub as_json: serde_json::Value, } +/// Declares a scope of a toolchain added by user. +/// +/// When the user adds a toolchain, we give them an option to see that toolchain in: +/// - All of their projects +/// - A project they're currently in. +/// - Only in the subproject they're currently in. +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +pub enum ToolchainScope { + Subproject(WorktreeId, Arc), + Project, + /// Available in all projects on this box. It wouldn't make sense to show suggestions across machines. + Global, +} + +impl ToolchainScope { + pub fn label(&self) -> &'static str { + match self { + ToolchainScope::Subproject(_, _) => "Subproject", + ToolchainScope::Project => "Project", + ToolchainScope::Global => "Global", + } + } + + pub fn description(&self) -> &'static str { + match self { + ToolchainScope::Subproject(_, _) => { + "Available only in the subproject you're currently in." 
+ } + ToolchainScope::Project => "Available in all locations in your current project.", + ToolchainScope::Global => "Available in all of your projects on this machine.", + } + } +} + +impl std::hash::Hash for Toolchain { + fn hash(&self, state: &mut H) { + let Self { + name, + path, + language_name, + as_json: _, + } = self; + name.hash(state); + path.hash(state); + language_name.hash(state); + } +} + impl PartialEq for Toolchain { fn eq(&self, other: &Self) -> bool { + let Self { + name, + path, + language_name, + as_json: _, + } = self; // Do not use as_json for comparisons; it shouldn't impact equality, as it's not user-surfaced. // Thus, there could be multiple entries that look the same in the UI. - (&self.name, &self.path, &self.language_name).eq(&( - &other.name, - &other.path, - &other.language_name, - )) + (name, path, language_name).eq(&(&other.name, &other.path, &other.language_name)) } } #[async_trait] -pub trait ToolchainLister: Send + Sync { +pub trait ToolchainLister: Send + Sync + 'static { + /// List all available toolchains for a given path. async fn list( &self, worktree_root: PathBuf, - subroot_relative_path: Option>, + subroot_relative_path: Arc, project_env: Option>, ) -> ToolchainList; - // Returns a term which we should use in UI to refer to a toolchain. - fn term(&self) -> SharedString; - /// Returns the name of the manifest file for this toolchain. - fn manifest_name(&self) -> ManifestName; + + /// Given a user-created toolchain, resolve lister-specific details. + /// Put another way: fill in the details of the toolchain so the user does not have to. + async fn resolve( + &self, + path: PathBuf, + project_env: Option>, + ) -> anyhow::Result; + + async fn activation_script( + &self, + toolchain: &Toolchain, + shell: ShellKind, + fs: &dyn Fs, + ) -> Vec; + /// Returns various "static" bits of information about this toolchain lister. This function should be pure. 
+ fn meta(&self) -> ToolchainMetadata; +} + +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct ToolchainMetadata { + /// Returns a term which we should use in UI to refer to toolchains produced by a given `[ToolchainLister]`. + pub term: SharedString, + /// A user-facing placeholder describing the semantic meaning of a path to a new toolchain. + pub new_toolchain_placeholder: SharedString, + /// The name of the manifest file for this toolchain. + pub manifest_name: ManifestName, } #[async_trait(?Send)] @@ -64,8 +140,31 @@ pub trait LanguageToolchainStore: Send + Sync + 'static { ) -> Option; } +pub trait LocalLanguageToolchainStore: Send + Sync + 'static { + fn active_toolchain( + self: Arc, + worktree_id: WorktreeId, + relative_path: &Arc, + language_name: LanguageName, + cx: &mut AsyncApp, + ) -> Option; +} + +#[async_trait(?Send)] +impl LanguageToolchainStore for T { + async fn active_toolchain( + self: Arc, + worktree_id: WorktreeId, + relative_path: Arc, + language_name: LanguageName, + cx: &mut AsyncApp, + ) -> Option { + self.active_toolchain(worktree_id, &relative_path, language_name, cx) + } +} + type DefaultIndex = usize; -#[derive(Default, Clone)] +#[derive(Default, Clone, Debug)] pub struct ToolchainList { pub toolchains: Vec, pub default: Option, diff --git a/crates/language_extension/src/extension_lsp_adapter.rs b/crates/language_extension/src/extension_lsp_adapter.rs index 98b6fd4b5a2ef6e7f1b5adbc54dcecd0707b60ff..c1bc058a344e02fd4830c9db89684579a9e7e045 100644 --- a/crates/language_extension/src/extension_lsp_adapter.rs +++ b/crates/language_extension/src/extension_lsp_adapter.rs @@ -12,8 +12,8 @@ use fs::Fs; use futures::{Future, FutureExt, future::join_all}; use gpui::{App, AppContext, AsyncApp, Task}; use language::{ - BinaryStatus, CodeLabel, HighlightId, Language, LanguageName, LanguageToolchainStore, - LspAdapter, LspAdapterDelegate, + BinaryStatus, CodeLabel, HighlightId, Language, LanguageName, LspAdapter, LspAdapterDelegate, + Toolchain, }; 
use lsp::{ CodeActionKind, LanguageServerBinary, LanguageServerBinaryOptions, LanguageServerName, @@ -159,7 +159,7 @@ impl LspAdapter for ExtensionLspAdapter { fn get_language_server_command<'a>( self: Arc, delegate: Arc, - _: Arc, + _: Option, _: LanguageServerBinaryOptions, _: futures::lock::MutexGuard<'a, Option>, _: &'a mut AsyncApp, @@ -204,6 +204,7 @@ impl LspAdapter for ExtensionLspAdapter { async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { unreachable!("get_language_server_command is overridden") } @@ -288,7 +289,7 @@ impl LspAdapter for ExtensionLspAdapter { self: Arc, _: &dyn Fs, delegate: &Arc, - _: Arc, + _: Option, _cx: &mut AsyncApp, ) -> Result { let delegate = Arc::new(WorktreeDelegateAdapter(delegate.clone())) as _; @@ -336,7 +337,7 @@ impl LspAdapter for ExtensionLspAdapter { target_language_server_id: LanguageServerName, _: &dyn Fs, delegate: &Arc, - _: Arc, + _cx: &mut AsyncApp, ) -> Result> { let delegate = Arc::new(WorktreeDelegateAdapter(delegate.clone())) as _; @@ -397,6 +398,10 @@ impl LspAdapter for ExtensionLspAdapter { Ok(labels_from_extension(labels, language)) } + + fn is_extension(&self) -> bool { + true + } } fn labels_from_extension( diff --git a/crates/language_extension/src/language_extension.rs b/crates/language_extension/src/language_extension.rs index 1915eae2d18fe5fb96dbb0dcca614f8a4f41bb81..510f870ce8afbda090817e0ce515d4c5c2e3c63b 100644 --- a/crates/language_extension/src/language_extension.rs +++ b/crates/language_extension/src/language_extension.rs @@ -52,7 +52,7 @@ impl ExtensionLanguageProxy for LanguageServerRegistryProxy { load: Arc Result + Send + Sync + 'static>, ) { self.language_registry - .register_language(language, grammar, matcher, hidden, load); + .register_language(language, grammar, matcher, hidden, None, load); } fn remove_languages( @@ -61,6 +61,6 @@ impl ExtensionLanguageProxy for LanguageServerRegistryProxy { grammars_to_remove: &[Arc], ) { 
self.language_registry - .remove_languages(&languages_to_remove, &grammars_to_remove); + .remove_languages(languages_to_remove, grammars_to_remove); } } diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index f9920623b5ea3bff79535f92753fae0b723f850f..d4513f617b0d9f79e960c6cec6ca1a5dd806cea6 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -17,6 +17,7 @@ test-support = [] [dependencies] anthropic = { workspace = true, features = ["schemars"] } +open_router.workspace = true anyhow.workspace = true base64.workspace = true client.workspace = true diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index a9c7d5c0343295ff02d9d693f2cdbe3d92f1e07d..b06a475f9385012e5b88466c80fbb14e0ed744ac 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -1,14 +1,19 @@ use crate::{ - AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, - LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, - LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelToolChoice, + AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, LanguageModelCompletionError, + LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, + LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, + LanguageModelRequest, LanguageModelToolChoice, }; -use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; +use anyhow::anyhow; +use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, App, AsyncApp, Entity, Task, Window}; use http_client::Result; use parking_lot::Mutex; -use std::sync::Arc; +use smol::stream::StreamExt; +use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering::SeqCst}, +}; #[derive(Clone)] pub 
struct FakeLanguageModelProvider { @@ -62,7 +67,12 @@ impl LanguageModelProvider for FakeLanguageModelProvider { Task::ready(Ok(())) } - fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: ConfigurationViewTargetAgent, + _window: &mut Window, + _: &mut App, + ) -> AnyView { unimplemented!() } @@ -95,9 +105,12 @@ pub struct FakeLanguageModel { current_completion_txs: Mutex< Vec<( LanguageModelRequest, - mpsc::UnboundedSender, + mpsc::UnboundedSender< + Result, + >, )>, >, + forbid_requests: AtomicBool, } impl Default for FakeLanguageModel { @@ -106,11 +119,20 @@ impl Default for FakeLanguageModel { provider_id: LanguageModelProviderId::from("fake".to_string()), provider_name: LanguageModelProviderName::from("Fake".to_string()), current_completion_txs: Mutex::new(Vec::new()), + forbid_requests: AtomicBool::new(false), } } } impl FakeLanguageModel { + pub fn allow_requests(&self) { + self.forbid_requests.store(false, SeqCst); + } + + pub fn forbid_requests(&self) { + self.forbid_requests.store(true, SeqCst); + } + pub fn pending_completions(&self) -> Vec { self.current_completion_txs .lock() @@ -145,7 +167,21 @@ impl FakeLanguageModel { .find(|(req, _)| req == request) .map(|(_, tx)| tx) .unwrap(); - tx.unbounded_send(event.into()).unwrap(); + tx.unbounded_send(Ok(event.into())).unwrap(); + } + + pub fn send_completion_stream_error( + &self, + request: &LanguageModelRequest, + error: impl Into, + ) { + let current_completion_txs = self.current_completion_txs.lock(); + let tx = current_completion_txs + .iter() + .find(|(req, _)| req == request) + .map(|(_, tx)| tx) + .unwrap(); + tx.unbounded_send(Err(error.into())).unwrap(); } pub fn end_completion_stream(&self, request: &LanguageModelRequest) { @@ -165,6 +201,13 @@ impl FakeLanguageModel { self.send_completion_stream_event(self.pending_completions().last().unwrap(), event); } + pub fn send_last_completion_stream_error( + &self, + error: impl 
Into, + ) { + self.send_completion_stream_error(self.pending_completions().last().unwrap(), error); + } + pub fn end_last_completion_stream(&self) { self.end_completion_stream(self.pending_completions().last().unwrap()); } @@ -222,9 +265,18 @@ impl LanguageModel for FakeLanguageModel { LanguageModelCompletionError, >, > { - let (tx, rx) = mpsc::unbounded(); - self.current_completion_txs.lock().push((request, tx)); - async move { Ok(rx.map(Ok).boxed()) }.boxed() + if self.forbid_requests.load(SeqCst) { + async move { + Err(LanguageModelCompletionError::Other(anyhow!( + "requests are forbidden" + ))) + } + .boxed() + } else { + let (tx, rx) = mpsc::unbounded(); + self.current_completion_txs.lock().push((request, tx)); + async move { Ok(rx.boxed()) }.boxed() + } } fn as_fake(&self) -> &Self { diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 1637d2de8a3c14b910ea345c03a4eb5db13df28d..fac302104fd9a4da82f5a383d5cd86b64fde4731 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -14,9 +14,10 @@ use client::Client; use cloud_llm_client::{CompletionMode, CompletionRequestStatus}; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; -use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; +use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window}; use http_client::{StatusCode, http}; use icons::IconName; +use open_router::OpenRouterError; use parking_lot::Mutex; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; @@ -54,7 +55,7 @@ pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName = pub fn init(client: Arc, cx: &mut App) { init_settings(cx); - RefreshLlmTokenListener::register(client.clone(), cx); + RefreshLlmTokenListener::register(client, cx); } pub fn init_settings(cx: &mut App) { @@ -300,7 +301,7 @@ impl From for LanguageModelCompletionError { }, 
AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded { provider, - retry_after: retry_after, + retry_after, }, AnthropicError::ApiError(api_error) => api_error.into(), } @@ -347,6 +348,72 @@ impl From for LanguageModelCompletionError { } } +impl From for LanguageModelCompletionError { + fn from(error: OpenRouterError) -> Self { + let provider = LanguageModelProviderName::new("OpenRouter"); + match error { + OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, + OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, + OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error }, + OpenRouterError::DeserializeResponse(error) => { + Self::DeserializeResponse { provider, error } + } + OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error }, + OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded { + provider, + retry_after: Some(retry_after), + }, + OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded { + provider, + retry_after, + }, + OpenRouterError::ApiError(api_error) => api_error.into(), + } + } +} + +impl From for LanguageModelCompletionError { + fn from(error: open_router::ApiError) -> Self { + use open_router::ApiErrorCode::*; + let provider = LanguageModelProviderName::new("OpenRouter"); + match error.code { + InvalidRequestError => Self::BadRequestFormat { + provider, + message: error.message, + }, + AuthenticationError => Self::AuthenticationError { + provider, + message: error.message, + }, + PaymentRequiredError => Self::AuthenticationError { + provider, + message: format!("Payment required: {}", error.message), + }, + PermissionError => Self::PermissionError { + provider, + message: error.message, + }, + RequestTimedOut => Self::HttpResponseError { + provider, + status_code: StatusCode::REQUEST_TIMEOUT, + message: error.message, + }, + RateLimitError => Self::RateLimitExceeded { + 
provider, + retry_after: None, + }, + ApiError => Self::ApiInternalServerError { + provider, + message: error.message, + }, + OverloadedError => Self::ServerOverloaded { + provider, + retry_after: None, + }, + } + } +} + /// Indicates the format used to define the input schema for a language model tool. #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] pub enum LanguageModelToolSchemaFormat { @@ -538,7 +605,7 @@ pub trait LanguageModel: Send + Sync { if let Some(first_event) = events.next().await { match first_event { Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => { - message_id = Some(id.clone()); + message_id = Some(id); } Ok(LanguageModelCompletionEvent::Text(text)) => { first_item_text = Some(text); @@ -634,20 +701,22 @@ pub trait LanguageModelProvider: 'static { } fn is_authenticated(&self, cx: &App) -> bool; fn authenticate(&self, cx: &mut App) -> Task>; - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView; - fn must_accept_terms(&self, _cx: &App) -> bool { - false - } - fn render_accept_terms( + fn configuration_view( &self, - _view: LanguageModelProviderTosView, - _cx: &mut App, - ) -> Option { - None - } + target_agent: ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView; fn reset_credentials(&self, cx: &mut App) -> Task>; } +#[derive(Default, Clone)] +pub enum ConfigurationViewTargetAgent { + #[default] + ZedAgent, + Other(SharedString), +} + #[derive(PartialEq, Eq)] pub enum LanguageModelProviderTosView { /// When there are some past interactions in the Agent Panel. 
diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 3b4c1fa269020d1bf17d98cbb67251902536dafc..e25ed0de50c4ddf03ff539dbce728dbc20def9b5 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use anyhow::Result; use client::Client; use cloud_api_types::websocket_protocol::MessageToClient; -use cloud_llm_client::Plan; +use cloud_llm_client::{Plan, PlanV1}; use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _}; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; @@ -29,19 +29,34 @@ pub struct ModelRequestLimitReachedError { impl fmt::Display for ModelRequestLimitReachedError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let message = match self.plan { - Plan::ZedFree => "Model request limit reached. Upgrade to Zed Pro for more requests.", - Plan::ZedPro => { + Plan::V1(PlanV1::ZedFree) => { + "Model request limit reached. Upgrade to Zed Pro for more requests." + } + Plan::V1(PlanV1::ZedPro) => { "Model request limit reached. Upgrade to usage-based billing for more requests." } - Plan::ZedProTrial => { + Plan::V1(PlanV1::ZedProTrial) => { "Model request limit reached. Upgrade to Zed Pro for more requests." } + Plan::V2(_) => "Model request limit reached.", }; write!(f, "{message}") } } +#[derive(Error, Debug)] +pub struct ToolUseLimitReachedError; + +impl fmt::Display for ToolUseLimitReachedError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Consecutive tool use limit reached. Enable Burn Mode for unlimited tool use." 
+ ) + } +} + #[derive(Clone, Default)] pub struct LlmApiToken(Arc>>); @@ -70,7 +85,7 @@ impl LlmApiToken { let response = client.cloud_client().create_llm_token(system_id).await?; *lock = Some(response.token.0.clone()); - Ok(response.token.0.clone()) + Ok(response.token.0) } } diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 7cf071808a2c0d95bf9aa5a41eaa260cff533d57..bab258bca1728ac45f5ef5c0397149b93f0d6031 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -21,13 +21,10 @@ impl Global for GlobalLanguageModelRegistry {} pub enum ConfigurationError { #[error("Configure at least one LLM provider to start using the panel.")] NoProvider, - #[error("LLM Provider is not configured or does not support the configured model.")] + #[error("LLM provider is not configured or does not support the configured model.")] ModelNotFound, #[error("{} LLM provider is not configured.", .0.name().0)] ProviderNotAuthenticated(Arc), - #[error("Using the {} LLM provider requires accepting the Terms of Service.", - .0.name().0)] - ProviderPendingTermsAcceptance(Arc), } impl std::fmt::Debug for ConfigurationError { @@ -38,9 +35,6 @@ impl std::fmt::Debug for ConfigurationError { Self::ProviderNotAuthenticated(provider) => { write!(f, "ProviderNotAuthenticated({})", provider.id()) } - Self::ProviderPendingTermsAcceptance(provider) => { - write!(f, "ProviderPendingTermsAcceptance({})", provider.id()) - } } } } @@ -107,7 +101,7 @@ pub enum Event { InlineAssistantModelChanged, CommitMessageModelChanged, ThreadSummaryModelChanged, - ProviderStateChanged, + ProviderStateChanged(LanguageModelProviderId), AddedProvider(LanguageModelProviderId), RemovedProvider(LanguageModelProviderId), } @@ -148,8 +142,11 @@ impl LanguageModelRegistry { ) { let id = provider.id(); - let subscription = provider.subscribe(cx, |_, cx| { - cx.emit(Event::ProviderStateChanged); + let subscription = provider.subscribe(cx, { + let 
id = id.clone(); + move |_, cx| { + cx.emit(Event::ProviderStateChanged(id.clone())); + } }); if let Some(subscription) = subscription { subscription.detach(); @@ -197,12 +194,6 @@ impl LanguageModelRegistry { return Some(ConfigurationError::ProviderNotAuthenticated(model.provider)); } - if model.provider.must_accept_terms(cx) { - return Some(ConfigurationError::ProviderPendingTermsAcceptance( - model.provider, - )); - } - None } @@ -217,6 +208,7 @@ impl LanguageModelRegistry { ) -> impl Iterator> + 'a { self.providers .values() + .filter(|provider| provider.is_authenticated(cx)) .flat_map(|provider| provider.provided_models(cx)) } diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index edce3d03b7063b383e51d88d4de7dc52ace0d04c..1182e0f7a8f1952a62832970ca63f3684eea5b17 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -220,42 +220,39 @@ impl<'de> Deserialize<'de> for LanguageModelToolResultContent { // Accept wrapped text format: { "type": "text", "text": "..." } if let (Some(type_value), Some(text_value)) = - (get_field(&obj, "type"), get_field(&obj, "text")) + (get_field(obj, "type"), get_field(obj, "text")) + && let Some(type_str) = type_value.as_str() + && type_str.to_lowercase() == "text" + && let Some(text) = text_value.as_str() { - if let Some(type_str) = type_value.as_str() { - if type_str.to_lowercase() == "text" { - if let Some(text) = text_value.as_str() { - return Ok(Self::Text(Arc::from(text))); - } - } - } + return Ok(Self::Text(Arc::from(text))); } // Check for wrapped Text variant: { "text": "..." 
} - if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") { - if obj.len() == 1 { - // Only one field, and it's "text" (case-insensitive) - if let Some(text) = value.as_str() { - return Ok(Self::Text(Arc::from(text))); - } + if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") + && obj.len() == 1 + { + // Only one field, and it's "text" (case-insensitive) + if let Some(text) = value.as_str() { + return Ok(Self::Text(Arc::from(text))); } } // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } } - if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") { - if obj.len() == 1 { - // Only one field, and it's "image" (case-insensitive) - // Try to parse the nested image object - if let Some(image_obj) = value.as_object() { - if let Some(image) = LanguageModelImage::from_json(image_obj) { - return Ok(Self::Image(image)); - } - } + if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") + && obj.len() == 1 + { + // Only one field, and it's "image" (case-insensitive) + // Try to parse the nested image object + if let Some(image_obj) = value.as_object() + && let Some(image) = LanguageModelImage::from_json(image_obj) + { + return Ok(Self::Image(image)); } } // Try as direct Image (object with "source" and "size" fields) - if let Some(image) = LanguageModelImage::from_json(&obj) { + if let Some(image) = LanguageModelImage::from_json(obj) { return Ok(Self::Image(image)); } } @@ -272,7 +269,7 @@ impl<'de> Deserialize<'de> for LanguageModelToolResultContent { impl LanguageModelToolResultContent { pub fn to_str(&self) -> Option<&str> { match self { - Self::Text(text) => Some(&text), + Self::Text(text) => Some(text), Self::Image(_) => None, } } diff --git a/crates/language_model/src/role.rs b/crates/language_model/src/role.rs index 953dfa6fdff91c61a3a444076fd768f260b882c5..4b47ef36dd564e5950ce7d42a7e4f9263f3998b7 100644 --- 
a/crates/language_model/src/role.rs +++ b/crates/language_model/src/role.rs @@ -19,7 +19,7 @@ impl Role { } } - pub fn to_proto(&self) -> proto::LanguageModelRole { + pub fn to_proto(self) -> proto::LanguageModelRole { match self { Role::User => proto::LanguageModelRole::LanguageModelUser, Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant, diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 18e6f47ed0591256591df578f98dcaf988ed6444..738b72b0c9a6dbb7c9606cc72707b27e66abf09c 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -104,7 +104,7 @@ fn register_language_model_providers( cx: &mut Context, ) { registry.register_provider( - CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx), + CloudLanguageModelProvider::new(user_store, client.clone(), cx), cx, ); diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index ef21e85f711e41722d4ac421ba1d0a89b422b6a6..ca7763e2c5cda3c07c5cb51389cb3173a55865e2 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -15,11 +15,11 @@ use gpui::{ }; use http_client::HttpClient; use language_model::{ - AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, - LanguageModelCompletionError, LanguageModelId, LanguageModelName, LanguageModelProvider, - LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent, - RateLimiter, Role, + AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, + LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId, + LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, 
LanguageModelToolChoice, + LanguageModelToolResultContent, MessageContent, RateLimiter, Role, }; use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use schemars::JsonSchema; @@ -114,7 +114,7 @@ impl State { .clone(); cx.spawn(async move |this, cx| { credentials_provider - .delete_credentials(&api_url, &cx) + .delete_credentials(&api_url, cx) .await .ok(); this.update(cx, |this, cx| { @@ -133,7 +133,7 @@ impl State { .clone(); cx.spawn(async move |this, cx| { credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) .await .ok(); @@ -153,29 +153,14 @@ impl State { return Task::ready(Ok(())); } - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .anthropic - .api_url - .clone(); + let key = AnthropicLanguageModelProvider::api_key(cx); cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, &cx) - .await? 
- .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; + let key = key.await?; this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; + this.api_key = Some(key.key); + this.api_key_from_env = key.from_env; cx.notify(); })?; @@ -184,6 +169,11 @@ impl State { } } +pub struct ApiKey { + pub key: String, + pub from_env: bool, +} + impl AnthropicLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { let state = cx.new(|cx| State { @@ -206,6 +196,33 @@ impl AnthropicLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + pub fn api_key(cx: &mut App) -> Task> { + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .anthropic + .api_url + .clone(); + + if let Ok(key) = std::env::var(ANTHROPIC_API_KEY_VAR) { + Task::ready(Ok(ApiKey { + key, + from_env: true, + })) + } else { + cx.spawn(async move |cx| { + let (_, api_key) = credentials_provider + .read_credentials(&api_url, cx) + .await? 
+ .ok_or(AuthenticateError::CredentialsNotFound)?; + + Ok(ApiKey { + key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, + from_env: false, + }) + }) + } + } } impl LanguageModelProviderState for AnthropicLanguageModelProvider { @@ -299,8 +316,13 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { - cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + fn configuration_view( + &self, + target_agent: ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx)) .into() } @@ -402,14 +424,21 @@ impl AnthropicModel { return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed(); }; + let beta_headers = self.model.beta_headers(); + async move { let Some(api_key) = api_key else { return Err(LanguageModelCompletionError::NoApiKey { provider: PROVIDER_NAME, }); }; - let request = - anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let request = anthropic::stream_completion( + http_client.as_ref(), + &api_url, + &api_key, + request, + beta_headers, + ); request.await.map_err(Into::into) } .boxed() @@ -532,7 +561,7 @@ pub fn into_anthropic( .into_iter() .filter_map(|content| match content { MessageContent::Text(text) => { - let text = if text.chars().last().map_or(false, |c| c.is_whitespace()) { + let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) { text.trim_end().to_string() } else { text @@ -611,11 +640,11 @@ pub fn into_anthropic( Role::Assistant => anthropic::Role::Assistant, Role::System => unreachable!("System role should never occur here"), }; - if let Some(last_message) = new_messages.last_mut() { - if last_message.role == anthropic_role { - last_message.content.extend(anthropic_message_content); - 
continue; - } + if let Some(last_message) = new_messages.last_mut() + && last_message.role == anthropic_role + { + last_message.content.extend(anthropic_message_content); + continue; } // Mark the last segment of the message as cached @@ -791,7 +820,7 @@ impl AnthropicEventMapper { ))]; } } - return vec![]; + vec![] } }, Event::ContentBlockStop { index } => { @@ -902,12 +931,18 @@ struct ConfigurationView { api_key_editor: Entity, state: gpui::Entity, load_credentials_task: Option>, + target_agent: ConfigurationViewTargetAgent, } impl ConfigurationView { const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; - fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { + fn new( + state: gpui::Entity, + target_agent: ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut Context, + ) -> Self { cx.observe(&state, |_, _, cx| { cx.notify(); }) @@ -934,11 +969,12 @@ impl ConfigurationView { Self { api_key_editor: cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, cx); + editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, window, cx); editor }), state, load_credentials_task, + target_agent, } } @@ -1012,7 +1048,10 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's agent with Anthropic, you need to add an API key. Follow these steps:")) + .child(Label::new(format!("To use {}, you need to add an API key. 
Follow these steps:", match &self.target_agent { + ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic".into(), + ConfigurationViewTargetAgent::Other(agent) => agent.clone(), + }))) .child( List::new() .child( @@ -1023,7 +1062,7 @@ impl Render for ConfigurationView { ) ) .child( - InstructionListItem::text_only("Paste your API key below and hit enter to start using the assistant") + InstructionListItem::text_only("Paste your API key below and hit enter to start using the agent") ) ) .child( diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 6df96c5c566aac6f23af837491292cc89a56c74a..49a976d5b18d2c7a2ca3162c632f53706b385cb0 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -150,7 +150,7 @@ impl State { let credentials_provider = ::global(cx); cx.spawn(async move |this, cx| { credentials_provider - .delete_credentials(AMAZON_AWS_URL, &cx) + .delete_credentials(AMAZON_AWS_URL, cx) .await .log_err(); this.update(cx, |this, cx| { @@ -174,7 +174,7 @@ impl State { AMAZON_AWS_URL, "Bearer", &serde_json::to_vec(&credentials)?, - &cx, + cx, ) .await?; this.update(cx, |this, cx| { @@ -206,7 +206,7 @@ impl State { (credentials, true) } else { let (_, credentials) = credentials_provider - .read_credentials(AMAZON_AWS_URL, &cx) + .read_credentials(AMAZON_AWS_URL, cx) .await? 
.ok_or_else(|| AuthenticateError::CredentialsNotFound)?; ( @@ -348,7 +348,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } @@ -407,10 +412,10 @@ impl BedrockModel { .region(Region::new(region)) .timeout_config(TimeoutConfig::disabled()); - if let Some(endpoint_url) = endpoint { - if !endpoint_url.is_empty() { - config_builder = config_builder.endpoint_url(endpoint_url); - } + if let Some(endpoint_url) = endpoint + && !endpoint_url.is_empty() + { + config_builder = config_builder.endpoint_url(endpoint_url); } match auth_method { @@ -460,7 +465,7 @@ impl BedrockModel { Result>>, > { let Ok(runtime_client) = self - .get_or_init_client(&cx) + .get_or_init_client(cx) .cloned() .context("Bedrock client not initialized") else { @@ -723,11 +728,11 @@ pub fn into_bedrock( Role::Assistant => bedrock::BedrockRole::Assistant, Role::System => unreachable!("System role should never occur here"), }; - if let Some(last_message) = new_messages.last_mut() { - if last_message.role == bedrock_role { - last_message.content.extend(bedrock_message_content); - continue; - } + if let Some(last_message) = new_messages.last_mut() + && last_message.role == bedrock_role + { + last_message.content.extend(bedrock_message_content); + continue; } new_messages.push( BedrockMessage::builder() @@ -912,7 +917,7 @@ pub fn map_to_language_model_completion_events( Some(ContentBlockDelta::ReasoningContent(thinking)) => match thinking { ReasoningContentBlockDelta::Text(thoughts) => { Some(Ok(LanguageModelCompletionEvent::Thinking { - text: thoughts.clone(), + text: thoughts, signature: None, })) } @@ -963,7 +968,7 @@ pub fn 
map_to_language_model_completion_events( id: tool_use.id.into(), name: tool_use.name.into(), is_input_complete: true, - raw_input: tool_use.input_json.clone(), + raw_input: tool_use.input_json, input, }, )) @@ -1048,22 +1053,22 @@ impl ConfigurationView { Self { access_key_id_editor: cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text(Self::PLACEHOLDER_ACCESS_KEY_ID_TEXT, cx); + editor.set_placeholder_text(Self::PLACEHOLDER_ACCESS_KEY_ID_TEXT, window, cx); editor }), secret_access_key_editor: cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text(Self::PLACEHOLDER_SECRET_ACCESS_KEY_TEXT, cx); + editor.set_placeholder_text(Self::PLACEHOLDER_SECRET_ACCESS_KEY_TEXT, window, cx); editor }), session_token_editor: cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text(Self::PLACEHOLDER_SESSION_TOKEN_TEXT, cx); + editor.set_placeholder_text(Self::PLACEHOLDER_SESSION_TOKEN_TEXT, window, cx); editor }), region_editor: cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text(Self::PLACEHOLDER_REGION, cx); + editor.set_placeholder_text(Self::PLACEHOLDER_REGION, window, cx); editor }), state, @@ -1081,21 +1086,18 @@ impl ConfigurationView { .access_key_id_editor .read(cx) .text(cx) - .to_string() .trim() .to_string(); let secret_access_key = self .secret_access_key_editor .read(cx) .text(cx) - .to_string() .trim() .to_string(); let session_token = self .session_token_editor .read(cx) .text(cx) - .to_string() .trim() .to_string(); let session_token = if session_token.is_empty() { @@ -1103,13 +1105,7 @@ impl ConfigurationView { } else { Some(session_token) }; - let region = self - .region_editor - .read(cx) - .text(cx) - .to_string() - .trim() - .to_string(); + let region = self.region_editor.read(cx).text(cx).trim().to_string(); let region = if region.is_empty() { "us-east-1".to_string() } else { diff --git 
a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index ff8048040e6c7967d52df79b5837a502844998a8..421e34e658fb604e21fdfc4af52b3c4c7874fd70 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -7,8 +7,9 @@ use cloud_llm_client::{ CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan, - SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, - TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME, + PlanV1, PlanV2, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, + SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, TOOL_USE_LIMIT_REACHED_HEADER_NAME, + ZED_VERSION_HEADER_NAME, }; use futures::{ AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream, @@ -23,9 +24,9 @@ use language_model::{ AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, - LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, - ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, + LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, + LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError, + PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, }; use release_channel::AppVersion; use schemars::JsonSchema; @@ -118,7 +119,6 @@ pub struct State { llm_api_token: LlmApiToken, user_store: Entity, status: client::Status, - accept_terms_of_service_task: Option>>, models: Vec>, 
default_model: Option>, default_fast_model: Option>, @@ -140,9 +140,8 @@ impl State { Self { client: client.clone(), llm_api_token: LlmApiToken::default(), - user_store: user_store.clone(), + user_store, status, - accept_terms_of_service_task: None, models: Vec::new(), default_model: None, default_fast_model: None, @@ -193,28 +192,10 @@ impl State { fn authenticate(&self, cx: &mut Context) -> Task> { let client = self.client.clone(); cx.spawn(async move |state, cx| { - client.sign_in_with_optional_connect(true, &cx).await?; + client.sign_in_with_optional_connect(true, cx).await?; state.update(cx, |_, cx| cx.notify()) }) } - - fn has_accepted_terms_of_service(&self, cx: &App) -> bool { - self.user_store.read(cx).has_accepted_terms_of_service() - } - - fn accept_terms_of_service(&mut self, cx: &mut Context) { - let user_store = self.user_store.clone(); - self.accept_terms_of_service_task = Some(cx.spawn(async move |this, cx| { - let _ = user_store - .update(cx, |store, cx| store.accept_terms_of_service(cx))? - .await; - this.update(cx, |this, cx| { - this.accept_terms_of_service_task = None; - cx.notify() - }) - })); - } - fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context) { let mut models = Vec::new(); @@ -270,7 +251,7 @@ impl State { if response.status().is_success() { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; - return Ok(serde_json::from_str(&body)?); + Ok(serde_json::from_str(&body)?) 
} else { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; @@ -307,7 +288,7 @@ impl CloudLanguageModelProvider { Self { client, - state: state.clone(), + state, _maintain_client_status: maintain_client_status, } } @@ -320,7 +301,7 @@ impl CloudLanguageModelProvider { Arc::new(CloudLanguageModel { id: LanguageModelId(SharedString::from(model.id.0.clone())), model, - llm_api_token: llm_api_token.clone(), + llm_api_token, client: self.client.clone(), request_limiter: RateLimiter::new(4), }) @@ -384,40 +365,21 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn is_authenticated(&self, cx: &App) -> bool { let state = self.state.read(cx); - !state.is_signed_out(cx) && state.has_accepted_terms_of_service(cx) + !state.is_signed_out(cx) } fn authenticate(&self, _cx: &mut App) -> Task> { Task::ready(Ok(())) } - fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView { - cx.new(|_| ConfigurationView::new(self.state.clone())) - .into() - } - - fn must_accept_terms(&self, cx: &App) -> bool { - !self.state.read(cx).has_accepted_terms_of_service(cx) - } - - fn render_accept_terms( + fn configuration_view( &self, - view: LanguageModelProviderTosView, + _target_agent: language_model::ConfigurationViewTargetAgent, + _: &mut Window, cx: &mut App, - ) -> Option { - let state = self.state.read(cx); - if state.has_accepted_terms_of_service(cx) { - return None; - } - Some( - render_accept_terms(view, state.accept_terms_of_service_task.is_some(), { - let state = self.state.clone(); - move |_window, cx| { - state.update(cx, |state, cx| state.accept_terms_of_service(cx)); - } - }) - .into_any_element(), - ) + ) -> AnyView { + cx.new(|_| ConfigurationView::new(self.state.clone())) + .into() } fn reset_credentials(&self, _cx: &mut App) -> Task> { @@ -425,83 +387,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider { } } -fn render_accept_terms( - view_kind: LanguageModelProviderTosView, - 
accept_terms_of_service_in_progress: bool, - accept_terms_callback: impl Fn(&mut Window, &mut App) + 'static, -) -> impl IntoElement { - let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart); - let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadEmptyState); - - let terms_button = Button::new("terms_of_service", "Terms of Service") - .style(ButtonStyle::Subtle) - .icon(IconName::ArrowUpRight) - .icon_color(Color::Muted) - .icon_size(IconSize::Small) - .when(thread_empty_state, |this| this.label_size(LabelSize::Small)) - .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service")); - - let button_container = h_flex().child( - Button::new("accept_terms", "I accept the Terms of Service") - .when(!thread_empty_state, |this| { - this.full_width() - .style(ButtonStyle::Tinted(TintColor::Accent)) - .icon(IconName::Check) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - }) - .when(thread_empty_state, |this| { - this.style(ButtonStyle::Tinted(TintColor::Warning)) - .label_size(LabelSize::Small) - }) - .disabled(accept_terms_of_service_in_progress) - .on_click(move |_, window, cx| (accept_terms_callback)(window, cx)), - ); - - if thread_empty_state { - h_flex() - .w_full() - .flex_wrap() - .justify_between() - .child( - h_flex() - .child( - Label::new("To start using Zed AI, please read and accept the") - .size(LabelSize::Small), - ) - .child(terms_button), - ) - .child(button_container) - } else { - v_flex() - .w_full() - .gap_2() - .child( - h_flex() - .flex_wrap() - .when(thread_fresh_start, |this| this.justify_center()) - .child(Label::new( - "To start using Zed AI, please read and accept the", - )) - .child(terms_button), - ) - .child({ - match view_kind { - LanguageModelProviderTosView::TextThreadPopup => { - button_container.w_full().justify_end() - } - LanguageModelProviderTosView::Configuration => { - button_container.w_full().justify_start() - } - 
LanguageModelProviderTosView::ThreadFreshStart => { - button_container.w_full().justify_center() - } - LanguageModelProviderTosView::ThreadEmptyState => div().w_0(), - } - }) - } -} - pub struct CloudLanguageModel { id: LanguageModelId, model: Arc, @@ -592,15 +477,14 @@ impl CloudLanguageModel { .headers() .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME) .and_then(|resource| resource.to_str().ok()) - { - if let Some(plan) = response + && let Some(plan) = response .headers() .get(CURRENT_PLAN_HEADER_NAME) .and_then(|plan| plan.to_str().ok()) - .and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok()) - { - return Err(anyhow!(ModelRequestLimitReachedError { plan })); - } + .and_then(|plan| cloud_llm_client::PlanV1::from_str(plan).ok()) + .map(Plan::V1) + { + return Err(anyhow!(ModelRequestLimitReachedError { plan })); } } else if status == StatusCode::PAYMENT_REQUIRED { return Err(anyhow!(PaymentRequiredError)); @@ -657,29 +541,29 @@ where impl From for LanguageModelCompletionError { fn from(error: ApiError) -> Self { - if let Ok(cloud_error) = serde_json::from_str::(&error.body) { - if cloud_error.code.starts_with("upstream_http_") { - let status = if let Some(status) = cloud_error.upstream_status { - status - } else if cloud_error.code.ends_with("_error") { - error.status - } else { - // If there's a status code in the code string (e.g. "upstream_http_429") - // then use that; otherwise, see if the JSON contains a status code. - cloud_error - .code - .strip_prefix("upstream_http_") - .and_then(|code_str| code_str.parse::().ok()) - .and_then(|code| StatusCode::from_u16(code).ok()) - .unwrap_or(error.status) - }; + if let Ok(cloud_error) = serde_json::from_str::(&error.body) + && cloud_error.code.starts_with("upstream_http_") + { + let status = if let Some(status) = cloud_error.upstream_status { + status + } else if cloud_error.code.ends_with("_error") { + error.status + } else { + // If there's a status code in the code string (e.g. 
"upstream_http_429") + // then use that; otherwise, see if the JSON contains a status code. + cloud_error + .code + .strip_prefix("upstream_http_") + .and_then(|code_str| code_str.parse::().ok()) + .and_then(|code| StatusCode::from_u16(code).ok()) + .unwrap_or(error.status) + }; - return LanguageModelCompletionError::UpstreamProviderError { - message: cloud_error.message, - status, - retry_after: cloud_error.retry_after.map(Duration::from_secs_f64), - }; - } + return LanguageModelCompletionError::UpstreamProviderError { + message: cloud_error.message, + status, + retry_after: cloud_error.retry_after.map(Duration::from_secs_f64), + }; } let retry_after = None; @@ -941,6 +825,7 @@ impl LanguageModel for CloudLanguageModel { request, model.id(), model.supports_parallel_tool_calls(), + model.supports_prompt_cache_key(), None, None, ); @@ -1103,10 +988,7 @@ struct ZedAiConfiguration { plan: Option, subscription_period: Option<(DateTime, DateTime)>, eligible_for_trial: bool, - has_accepted_terms_of_service: bool, account_too_young: bool, - accept_terms_of_service_in_progress: bool, - accept_terms_of_service_callback: Arc, sign_in_callback: Arc, } @@ -1114,15 +996,17 @@ impl RenderOnce for ZedAiConfiguration { fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { let young_account_banner = YoungAccountBanner; - let is_pro = self.plan == Some(Plan::ZedPro); + let is_pro = self.plan.is_some_and(|plan| { + matches!(plan, Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro)) + }); let subscription_text = match (self.plan, self.subscription_period) { - (Some(Plan::ZedPro), Some(_)) => { + (Some(Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro)), Some(_)) => { "You have access to Zed's hosted models through your Pro subscription." } - (Some(Plan::ZedProTrial), Some(_)) => { + (Some(Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial)), Some(_)) => { "You have access to Zed's hosted models through your Pro trial." 
} - (Some(Plan::ZedFree), Some(_)) => { + (Some(Plan::V1(PlanV1::ZedFree) | Plan::V2(PlanV2::ZedFree)), Some(_)) => { "You have basic access to Zed's hosted models through the Free plan." } _ => { @@ -1172,58 +1056,30 @@ impl RenderOnce for ZedAiConfiguration { ); } - v_flex() - .gap_2() - .w_full() - .when(!self.has_accepted_terms_of_service, |this| { - this.child(render_accept_terms( - LanguageModelProviderTosView::Configuration, - self.accept_terms_of_service_in_progress, - { - let callback = self.accept_terms_of_service_callback.clone(); - move |window, cx| (callback)(window, cx) - }, - )) - }) - .map(|this| { - if self.has_accepted_terms_of_service && self.account_too_young { - this.child(young_account_banner).child( - Button::new("upgrade", "Upgrade to Pro") - .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent)) - .full_width() - .on_click(|_, _, cx| { - cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)) - }), - ) - } else if self.has_accepted_terms_of_service { - this.text_sm() - .child(subscription_text) - .child(manage_subscription_buttons) - } else { - this - } - }) - .when(self.has_accepted_terms_of_service, |this| this) + v_flex().gap_2().w_full().map(|this| { + if self.account_too_young { + this.child(young_account_banner).child( + Button::new("upgrade", "Upgrade to Pro") + .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent)) + .full_width() + .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))), + ) + } else { + this.text_sm() + .child(subscription_text) + .child(manage_subscription_buttons) + } + }) } } struct ConfigurationView { state: Entity, - accept_terms_of_service_callback: Arc, sign_in_callback: Arc, } impl ConfigurationView { fn new(state: Entity) -> Self { - let accept_terms_of_service_callback = Arc::new({ - let state = state.clone(); - move |_window: &mut Window, cx: &mut App| { - state.update(cx, |state, cx| { - state.accept_terms_of_service(cx); - }); - } - }); - let sign_in_callback = Arc::new({ let state = 
state.clone(); move |_window: &mut Window, cx: &mut App| { @@ -1235,7 +1091,6 @@ impl ConfigurationView { Self { state, - accept_terms_of_service_callback, sign_in_callback, } } @@ -1251,10 +1106,7 @@ impl Render for ConfigurationView { plan: user_store.plan(), subscription_period: user_store.subscription_period(), eligible_for_trial: user_store.trial_started_at().is_none(), - has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx), account_too_young: user_store.account_too_young(), - accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(), - accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(), sign_in_callback: self.sign_in_callback.clone(), } } @@ -1279,7 +1131,6 @@ impl Component for ZedAiConfiguration { plan: Option, eligible_for_trial: bool, account_too_young: bool, - has_accepted_terms_of_service: bool, ) -> AnyElement { ZedAiConfiguration { is_connected, @@ -1288,10 +1139,7 @@ impl Component for ZedAiConfiguration { .is_some() .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))), eligible_for_trial, - has_accepted_terms_of_service, account_too_young, - accept_terms_of_service_in_progress: false, - accept_terms_of_service_callback: Arc::new(|_, _| {}), sign_in_callback: Arc::new(|_, _| {}), } .into_any_element() @@ -1302,33 +1150,30 @@ impl Component for ZedAiConfiguration { .p_4() .gap_4() .children(vec![ - single_example( - "Not connected", - configuration(false, None, false, false, true), - ), + single_example("Not connected", configuration(false, None, false, false)), single_example( "Accept Terms of Service", - configuration(true, None, true, false, false), + configuration(true, None, true, false), ), single_example( "No Plan - Not eligible for trial", - configuration(true, None, false, false, true), + configuration(true, None, false, false), ), single_example( "No Plan - Eligible for trial", - configuration(true, None, true, false, true), + configuration(true, None, true, 
false), ), single_example( "Free Plan", - configuration(true, Some(Plan::ZedFree), true, false, true), + configuration(true, Some(Plan::V1(PlanV1::ZedFree)), true, false), ), single_example( "Zed Pro Trial Plan", - configuration(true, Some(Plan::ZedProTrial), true, false, true), + configuration(true, Some(Plan::V1(PlanV1::ZedProTrial)), true, false), ), single_example( "Zed Pro Plan", - configuration(true, Some(Plan::ZedPro), true, false, true), + configuration(true, Some(Plan::V1(PlanV1::ZedPro)), true, false), ), ]) .into_any_element(), diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 73f73a9a313c764d45adfd14910efd801a472f1c..b7ece55fed70beae543b9bd55e7635fa6a3fc04d 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -14,10 +14,7 @@ use copilot::{Copilot, Status}; use futures::future::BoxFuture; use futures::stream::BoxStream; use futures::{FutureExt, Stream, StreamExt}; -use gpui::{ - Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task, - Transformation, percentage, svg, -}; +use gpui::{Action, AnyView, App, AsyncApp, Entity, Render, Subscription, Task, svg}; use language::language_settings::all_language_settings; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -28,14 +25,9 @@ use language_model::{ StopReason, TokenUsage, }; use settings::SettingsStore; -use std::time::Duration; -use ui::prelude::*; +use ui::{CommonAnimationExt, prelude::*}; use util::debug_panic; -use super::anthropic::count_anthropic_tokens; -use super::google::count_google_tokens; -use super::open_ai::count_open_ai_tokens; - const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("GitHub Copilot Chat"); @@ -176,7 +168,12 @@ 
impl LanguageModelProvider for CopilotChatLanguageModelProvider { Task::ready(Err(err.into())) } - fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + _: &mut Window, + cx: &mut App, + ) -> AnyView { let state = self.state.clone(); cx.new(|cx| ConfigurationView::new(state, cx)).into() } @@ -188,6 +185,25 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { } } +fn collect_tiktoken_messages( + request: LanguageModelRequest, +) -> Vec { + request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>() +} + pub struct CopilotChatLanguageModel { model: CopilotChatModel, request_limiter: RateLimiter, @@ -223,7 +239,9 @@ impl LanguageModel for CopilotChatLanguageModel { ModelVendor::OpenAI | ModelVendor::Anthropic => { LanguageModelToolSchemaFormat::JsonSchema } - ModelVendor::Google => LanguageModelToolSchemaFormat::JsonSchemaSubset, + ModelVendor::Google | ModelVendor::XAI | ModelVendor::Unknown => { + LanguageModelToolSchemaFormat::JsonSchemaSubset + } } } @@ -248,14 +266,20 @@ impl LanguageModel for CopilotChatLanguageModel { request: LanguageModelRequest, cx: &App, ) -> BoxFuture<'static, Result> { - match self.model.vendor() { - ModelVendor::Anthropic => count_anthropic_tokens(request, cx), - ModelVendor::Google => count_google_tokens(request, cx), - ModelVendor::OpenAI => { - let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default(); - count_open_ai_tokens(request, model, cx) - } - } + let model = self.model.clone(); + cx.background_spawn(async move { + let messages = collect_tiktoken_messages(request); + // Copilot uses OpenAI tiktoken tokenizer 
for all it's model irrespective of the underlying provider(vendor). + let tokenizer_model = match model.tokenizer() { + Some("o200k_base") => "gpt-4o", + Some("cl100k_base") => "gpt-4", + _ => "gpt-4o", + }; + + tiktoken_rs::num_tokens_from_messages(tokenizer_model, &messages) + .map(|tokens| tokens as u64) + }) + .boxed() } fn stream_completion( @@ -470,7 +494,6 @@ fn into_copilot_chat( } } - let mut tool_called = false; let mut messages: Vec = Vec::new(); for message in request_messages { match message.role { @@ -540,7 +563,6 @@ fn into_copilot_chat( let mut tool_calls = Vec::new(); for content in &message.content { if let MessageContent::ToolUse(tool_use) = content { - tool_called = true; tool_calls.push(ToolCall { id: tool_use.id.to_string(), content: copilot::copilot_chat::ToolCallContent::Function { @@ -585,7 +607,7 @@ fn into_copilot_chat( } } - let mut tools = request + let tools = request .tools .iter() .map(|tool| Tool::Function { @@ -597,22 +619,6 @@ fn into_copilot_chat( }) .collect::>(); - // The API will return a Bad Request (with no error message) when tools - // were used previously in the conversation but no tools are provided as - // part of this request. Inserting a dummy tool seems to circumvent this - // error. 
- if tool_called && tools.is_empty() { - tools.push(Tool::Function { - function: copilot::copilot_chat::Function { - name: "noop".to_string(), - description: "No operation".to_string(), - parameters: serde_json::json!({ - "type": "object" - }), - }, - }); - } - Ok(CopilotChatRequest { intent: true, n: 1, @@ -677,11 +683,7 @@ impl Render for ConfigurationView { }), ) } else { - let loading_icon = Icon::new(IconName::ArrowCircle).with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(4)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - ); + let loading_icon = Icon::new(IconName::ArrowCircle).with_rotate_animation(4); const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider."; diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index a568ef4034193b5b1078d2ec4907d18fb0762efa..82bf067cd475fe031630767da9e4302afa4d78ec 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -77,7 +77,7 @@ impl State { .clone(); cx.spawn(async move |this, cx| { credentials_provider - .delete_credentials(&api_url, &cx) + .delete_credentials(&api_url, cx) .await .log_err(); this.update(cx, |this, cx| { @@ -96,7 +96,7 @@ impl State { .clone(); cx.spawn(async move |this, cx| { credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) .await?; this.update(cx, |this, cx| { this.api_key = Some(api_key); @@ -120,7 +120,7 @@ impl State { (api_key, true) } else { let (_, api_key) = credentials_provider - .read_credentials(&api_url, &cx) + .read_credentials(&api_url, cx) .await? 
.ok_or(AuthenticateError::CredentialsNotFound)?; ( @@ -229,7 +229,12 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } @@ -570,7 +575,7 @@ impl ConfigurationView { fn new(state: Entity, window: &mut Window, cx: &mut Context) -> Self { let api_key_editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text("sk-00000000000000000000000000000000", cx); + editor.set_placeholder_text("sk-00000000000000000000000000000000", window, cx); editor }); diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index b287e8181a2ac5d04650d799a0cd9b23d51749c2..939cf0ca60d92d713b90a5d62e8ec7f6dac7ec46 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -12,9 +12,9 @@ use gpui::{ }; use http_client::HttpClient; use language_model::{ - AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent, - LanguageModelToolChoice, LanguageModelToolSchemaFormat, LanguageModelToolUse, - LanguageModelToolUseId, MessageContent, StopReason, + AuthenticateError, ConfigurationViewTargetAgent, LanguageModelCompletionError, + LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat, + LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason, }; use language_model::{ LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, @@ -37,6 +37,8 @@ use util::ResultExt; use crate::AllLanguageModelSettings; use crate::ui::InstructionListItem; +use super::anthropic::ApiKey; + const PROVIDER_ID: 
LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME; @@ -110,7 +112,7 @@ impl State { .clone(); cx.spawn(async move |this, cx| { credentials_provider - .delete_credentials(&api_url, &cx) + .delete_credentials(&api_url, cx) .await .log_err(); this.update(cx, |this, cx| { @@ -129,7 +131,7 @@ impl State { .clone(); cx.spawn(async move |this, cx| { credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) .await?; this.update(cx, |this, cx| { this.api_key = Some(api_key); @@ -156,7 +158,7 @@ impl State { (api_key, true) } else { let (_, api_key) = credentials_provider - .read_credentials(&api_url, &cx) + .read_credentials(&api_url, cx) .await? .ok_or(AuthenticateError::CredentialsNotFound)?; ( @@ -198,6 +200,33 @@ impl GoogleLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + pub fn api_key(cx: &mut App) -> Task> { + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .google + .api_url + .clone(); + + if let Ok(key) = std::env::var(GEMINI_API_KEY_VAR) { + Task::ready(Ok(ApiKey { + key, + from_env: true, + })) + } else { + cx.spawn(async move |cx| { + let (_, api_key) = credentials_provider + .read_credentials(&api_url, cx) + .await? 
+ .ok_or(AuthenticateError::CredentialsNotFound)?; + + Ok(ApiKey { + key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, + from_env: false, + }) + }) + } + } } impl LanguageModelProviderState for GoogleLanguageModelProvider { @@ -277,8 +306,13 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { - cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + fn configuration_view( + &self, + target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx)) .into() } @@ -382,7 +416,7 @@ impl LanguageModel for GoogleLanguageModel { cx: &App, ) -> BoxFuture<'static, Result> { let model_id = self.model.request_id().to_string(); - let request = into_google(request, model_id.clone(), self.model.mode()); + let request = into_google(request, model_id, self.model.mode()); let http_client = self.http_client.clone(); let api_key = self.state.read(cx).api_key.clone(); @@ -525,7 +559,7 @@ pub fn into_google( let system_instructions = if request .messages .first() - .map_or(false, |msg| matches!(msg.role, Role::System)) + .is_some_and(|msg| matches!(msg.role, Role::System)) { let message = request.messages.remove(0); Some(SystemInstruction { @@ -572,7 +606,7 @@ pub fn into_google( top_k: None, }), safety_settings: None, - tools: (request.tools.len() > 0).then(|| { + tools: (!request.tools.is_empty()).then(|| { vec![google_ai::Tool { function_declarations: request .tools @@ -771,11 +805,17 @@ fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage { struct ConfigurationView { api_key_editor: Entity, state: gpui::Entity, + target_agent: language_model::ConfigurationViewTargetAgent, load_credentials_task: Option>, } impl ConfigurationView { - fn new(state: 
gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { + fn new( + state: gpui::Entity, + target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut Context, + ) -> Self { cx.observe(&state, |_, _, cx| { cx.notify(); }) @@ -802,9 +842,10 @@ impl ConfigurationView { Self { api_key_editor: cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text("AIzaSy...", cx); + editor.set_placeholder_text("AIzaSy...", window, cx); editor }), + target_agent, state, load_credentials_task, } @@ -880,7 +921,10 @@ impl Render for ConfigurationView { v_flex() .size_full() .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's agent with Google AI, you need to add an API key. Follow these steps:")) + .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent { + ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(), + ConfigurationViewTargetAgent::Other(agent) => agent.clone(), + }))) .child( List::new() .child(InstructionListItem::new( diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 36a32ab941ec65eb790a59ba3a7ed4fe3e6eb575..80b28a396b958ab20de3faa0a0f6919c57011e5c 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -210,7 +210,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider { .map(|model| { Arc::new(LmStudioLanguageModel { id: LanguageModelId::from(model.name.clone()), - model: model.clone(), + model, http_client: self.http_client.clone(), request_limiter: RateLimiter::new(4), }) as Arc @@ -226,7 +226,12 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: 
language_model::ConfigurationViewTargetAgent, + _window: &mut Window, + cx: &mut App, + ) -> AnyView { let state = self.state.clone(); cx.new(|cx| ConfigurationView::new(state, cx)).into() } diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 4a0d740334e38b9f5ea512344161fe5ca3f8db71..c9824bf89ea7a919f4517f492a5091a2cda7b43b 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -76,7 +76,7 @@ impl State { .clone(); cx.spawn(async move |this, cx| { credentials_provider - .delete_credentials(&api_url, &cx) + .delete_credentials(&api_url, cx) .await .log_err(); this.update(cx, |this, cx| { @@ -95,7 +95,7 @@ impl State { .clone(); cx.spawn(async move |this, cx| { credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) .await?; this.update(cx, |this, cx| { this.api_key = Some(api_key); @@ -119,7 +119,7 @@ impl State { (api_key, true) } else { let (_, api_key) = credentials_provider - .read_credentials(&api_url, &cx) + .read_credentials(&api_url, cx) .await? 
.ok_or(AuthenticateError::CredentialsNotFound)?; ( @@ -243,7 +243,12 @@ impl LanguageModelProvider for MistralLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } @@ -739,7 +744,7 @@ impl ConfigurationView { fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { let api_key_editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text("0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2", cx); + editor.set_placeholder_text("0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2", window, cx); editor }); diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index 0c2b1107b18cf72f70e46c195e7c61bfae607285..a80cacfc4a02521af74b32c34cc3360e9665a7d9 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -11,8 +11,8 @@ use language_model::{ LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage, }; use ollama::{ - ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool, - OllamaToolCall, get_models, show_model, stream_chat_completion, + ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionCall, + OllamaFunctionTool, OllamaToolCall, get_models, show_model, stream_chat_completion, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -237,7 +237,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { .map(|model| { Arc::new(OllamaLanguageModel { id: LanguageModelId::from(model.name.clone()), - model: model.clone(), + model, http_client: self.http_client.clone(), request_limiter: 
RateLimiter::new(4), }) as Arc @@ -255,7 +255,12 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { let state = self.state.clone(); cx.new(|cx| ConfigurationView::new(state, window, cx)) .into() @@ -277,59 +282,85 @@ impl OllamaLanguageModel { fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest { let supports_vision = self.model.supports_vision.unwrap_or(false); - ChatRequest { - model: self.model.name.clone(), - messages: request - .messages - .into_iter() - .map(|msg| { - let images = if supports_vision { - msg.content - .iter() - .filter_map(|content| match content { - MessageContent::Image(image) => Some(image.source.to_string()), - _ => None, - }) - .collect::>() - } else { - vec![] - }; - - match msg.role { - Role::User => ChatMessage::User { + let mut messages = Vec::with_capacity(request.messages.len()); + + for mut msg in request.messages.into_iter() { + let images = if supports_vision { + msg.content + .iter() + .filter_map(|content| match content { + MessageContent::Image(image) => Some(image.source.to_string()), + _ => None, + }) + .collect::>() + } else { + vec![] + }; + + match msg.role { + Role::User => { + for tool_result in msg + .content + .extract_if(.., |x| matches!(x, MessageContent::ToolResult(..))) + { + match tool_result { + MessageContent::ToolResult(tool_result) => { + messages.push(ChatMessage::Tool { + tool_name: tool_result.tool_name.to_string(), + content: tool_result.content.to_str().unwrap_or("").to_string(), + }) + } + _ => unreachable!("Only tool result should be extracted"), + } + } + if !msg.content.is_empty() { + messages.push(ChatMessage::User { content: msg.string_contents(), images: if images.is_empty() { 
None } else { Some(images) }, - }, - Role::Assistant => { - let content = msg.string_contents(); - let thinking = - msg.content.into_iter().find_map(|content| match content { - MessageContent::Thinking { text, .. } if !text.is_empty() => { - Some(text) - } - _ => None, - }); - ChatMessage::Assistant { - content, - tool_calls: None, - images: if images.is_empty() { - None - } else { - Some(images) - }, - thinking, + }) + } + } + Role::Assistant => { + let content = msg.string_contents(); + let mut thinking = None; + let mut tool_calls = Vec::new(); + for content in msg.content.into_iter() { + match content { + MessageContent::Thinking { text, .. } if !text.is_empty() => { + thinking = Some(text) + } + MessageContent::ToolUse(tool_use) => { + tool_calls.push(OllamaToolCall::Function(OllamaFunctionCall { + name: tool_use.name.to_string(), + arguments: tool_use.input, + })); } + _ => (), } - Role::System => ChatMessage::System { - content: msg.string_contents(), - }, } - }) - .collect(), + messages.push(ChatMessage::Assistant { + content, + tool_calls: Some(tool_calls), + images: if images.is_empty() { + None + } else { + Some(images) + }, + thinking, + }) + } + Role::System => messages.push(ChatMessage::System { + content: msg.string_contents(), + }), + } + } + ChatRequest { + model: self.model.name.clone(), + messages, keep_alive: self.model.keep_alive.clone().unwrap_or_default(), stream: true, options: Some(ChatOptions { @@ -342,7 +373,11 @@ impl OllamaLanguageModel { .model .supports_thinking .map(|supports_thinking| supports_thinking && request.thinking_allowed), - tools: request.tools.into_iter().map(tool_into_ollama).collect(), + tools: if self.model.supports_tools.unwrap_or(false) { + request.tools.into_iter().map(tool_into_ollama).collect() + } else { + vec![] + }, } } } @@ -474,6 +509,9 @@ fn map_to_language_model_completion_events( ChatMessage::System { content } => { events.push(Ok(LanguageModelCompletionEvent::Text(content))); } + ChatMessage::Tool { 
content, .. } => { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } ChatMessage::Assistant { content, tool_calls, diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 725027b2a73d4b54303b17f795d0e93526169575..fca1cf977cb5e3b32dc6f2335fb0d9188979bc9f 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -56,13 +56,13 @@ pub struct OpenAiLanguageModelProvider { pub struct State { api_key: Option, api_key_from_env: bool, + last_api_url: String, _subscription: Subscription, } const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY"; impl State { - // fn is_authenticated(&self) -> bool { self.api_key.is_some() } @@ -75,7 +75,7 @@ impl State { .clone(); cx.spawn(async move |this, cx| { credentials_provider - .delete_credentials(&api_url, &cx) + .delete_credentials(&api_url, cx) .await .log_err(); this.update(cx, |this, cx| { @@ -94,7 +94,7 @@ impl State { .clone(); cx.spawn(async move |this, cx| { credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) .await .log_err(); this.update(cx, |this, cx| { @@ -104,11 +104,7 @@ impl State { }) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - + fn get_api_key(&self, cx: &mut Context) -> Task> { let credentials_provider = ::global(cx); let api_url = AllLanguageModelSettings::get_global(cx) .openai @@ -119,7 +115,7 @@ impl State { (api_key, true) } else { let (_, api_key) = credentials_provider - .read_credentials(&api_url, &cx) + .read_credentials(&api_url, cx) .await? 
.ok_or(AuthenticateError::CredentialsNotFound)?; ( @@ -136,14 +132,52 @@ impl State { Ok(()) }) } + + fn authenticate(&self, cx: &mut Context) -> Task> { + if self.is_authenticated() { + return Task::ready(Ok(())); + } + + self.get_api_key(cx) + } } impl OpenAiLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { + let initial_api_url = AllLanguageModelSettings::get_global(cx) + .openai + .api_url + .clone(); + let state = cx.new(|cx| State { api_key: None, api_key_from_env: false, - _subscription: cx.observe_global::(|_this: &mut State, cx| { + last_api_url: initial_api_url.clone(), + _subscription: cx.observe_global::(|this: &mut State, cx| { + let current_api_url = AllLanguageModelSettings::get_global(cx) + .openai + .api_url + .clone(); + + if this.last_api_url != current_api_url { + this.last_api_url = current_api_url; + if !this.api_key_from_env { + this.api_key = None; + let spawn_task = cx.spawn(async move |handle, cx| { + if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) { + if let Err(_) = task.await { + handle + .update(cx, |this, _| { + this.api_key = None; + this.api_key_from_env = false; + }) + .ok(); + } + } + }); + spawn_task.detach(); + } + } cx.notify(); }), }); @@ -233,7 +267,12 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } @@ -370,6 +409,7 @@ impl LanguageModel for OpenAiLanguageModel { request, self.model.id(), self.model.supports_parallel_tool_calls(), + self.model.supports_prompt_cache_key(), self.max_output_tokens(), self.model.reasoning_effort(), ); @@ -386,6 +426,7 @@ pub fn into_open_ai( request: 
LanguageModelRequest, model_id: &str, supports_parallel_tool_calls: bool, + supports_prompt_cache_key: bool, max_output_tokens: Option, reasoning_effort: Option, ) -> open_ai::Request { @@ -397,7 +438,7 @@ pub fn into_open_ai( match content { MessageContent::Text(text) | MessageContent::Thinking { text, .. } => { add_message_content_part( - open_ai::MessagePart::Text { text: text }, + open_ai::MessagePart::Text { text }, message.role, &mut messages, ) @@ -477,7 +518,11 @@ pub fn into_open_ai( } else { None }, - prompt_cache_key: request.thread_id, + prompt_cache_key: if supports_prompt_cache_key { + request.thread_id + } else { + None + }, tools: request .tools .into_iter() @@ -575,7 +620,9 @@ impl OpenAiEventMapper { }; if let Some(content) = choice.delta.content.clone() { - events.push(Ok(LanguageModelCompletionEvent::Text(content))); + if !content.is_empty() { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } } if let Some(tool_calls) = choice.delta.tool_calls.as_ref() { diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index 6e912765cdabeadbaab743904c723b556502703f..4ebb11a07b66ec7054ca65437ec887a415fa3f5c 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -9,7 +9,7 @@ use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelToolChoice, RateLimiter, + LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, }; use menu; use open_ai::{ResponseStreamEvent, stream_completion}; @@ -38,6 +38,27 @@ pub struct AvailableModel { pub max_tokens: u64, pub max_output_tokens: Option, pub max_completion_tokens: Option, + #[serde(default)] + pub 
capabilities: ModelCapabilities, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct ModelCapabilities { + pub tools: bool, + pub images: bool, + pub parallel_tool_calls: bool, + pub prompt_cache_key: bool, +} + +impl Default for ModelCapabilities { + fn default() -> Self { + Self { + tools: true, + images: false, + parallel_tool_calls: false, + prompt_cache_key: false, + } + } } pub struct OpenAiCompatibleLanguageModelProvider { @@ -66,7 +87,7 @@ impl State { let api_url = self.settings.api_url.clone(); cx.spawn(async move |this, cx| { credentials_provider - .delete_credentials(&api_url, &cx) + .delete_credentials(&api_url, cx) .await .log_err(); this.update(cx, |this, cx| { @@ -82,7 +103,7 @@ impl State { let api_url = self.settings.api_url.clone(); cx.spawn(async move |this, cx| { credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) .await .log_err(); this.update(cx, |this, cx| { @@ -92,11 +113,7 @@ impl State { }) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - + fn get_api_key(&self, cx: &mut Context) -> Task> { let credentials_provider = ::global(cx); let env_var_name = self.env_var_name.clone(); let api_url = self.settings.api_url.clone(); @@ -105,7 +122,7 @@ impl State { (api_key, true) } else { let (_, api_key) = credentials_provider - .read_credentials(&api_url, &cx) + .read_credentials(&api_url, cx) .await? 
.ok_or(AuthenticateError::CredentialsNotFound)?; ( @@ -122,6 +139,14 @@ impl State { Ok(()) }) } + + fn authenticate(&self, cx: &mut Context) -> Task> { + if self.is_authenticated() { + return Task::ready(Ok(())); + } + + self.get_api_key(cx) + } } impl OpenAiCompatibleLanguageModelProvider { @@ -139,11 +164,27 @@ impl OpenAiCompatibleLanguageModelProvider { api_key: None, api_key_from_env: false, _subscription: cx.observe_global::(|this: &mut State, cx| { - let Some(settings) = resolve_settings(&this.id, cx) else { + let Some(settings) = resolve_settings(&this.id, cx).cloned() else { return; }; - if &this.settings != settings { - this.settings = settings.clone(); + if &this.settings != &settings { + if settings.api_url != this.settings.api_url && !this.api_key_from_env { + let spawn_task = cx.spawn(async move |handle, cx| { + if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) { + if let Err(_) = task.await { + handle + .update(cx, |this, _| { + this.api_key = None; + this.api_key_from_env = false; + }) + .ok(); + } + } + }); + spawn_task.detach(); + } + + this.settings = settings; cx.notify(); } }), @@ -222,7 +263,12 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } @@ -293,17 +339,21 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { } fn supports_tools(&self) -> bool { - true + self.model.capabilities.tools + } + + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + LanguageModelToolSchemaFormat::JsonSchemaSubset } fn supports_images(&self) -> bool { - false + self.model.capabilities.images } fn supports_tool_choice(&self, choice: 
LanguageModelToolChoice) -> bool { match choice { - LanguageModelToolChoice::Auto => true, - LanguageModelToolChoice::Any => true, + LanguageModelToolChoice::Auto => self.model.capabilities.tools, + LanguageModelToolChoice::Any => self.model.capabilities.tools, LanguageModelToolChoice::None => true, } } @@ -358,7 +408,8 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { let request = into_open_ai( request, &self.model.name, - true, + self.model.capabilities.parallel_tool_calls, + self.model.capabilities.prompt_cache_key, self.max_output_tokens(), None, ); diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 3a492086f16e1f9b53a196b7bb2e9817a3cac0e7..698e9d23cc74c56b00daa48359b663ff034c1abe 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -92,7 +92,7 @@ pub struct State { api_key_from_env: bool, http_client: Arc, available_models: Vec, - fetch_models_task: Option>>, + fetch_models_task: Option>>, settings: OpenRouterSettings, _subscription: Subscription, } @@ -112,7 +112,7 @@ impl State { .clone(); cx.spawn(async move |this, cx| { credentials_provider - .delete_credentials(&api_url, &cx) + .delete_credentials(&api_url, cx) .await .log_err(); this.update(cx, |this, cx| { @@ -131,7 +131,7 @@ impl State { .clone(); cx.spawn(async move |this, cx| { credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) .await .log_err(); this.update(cx, |this, cx| { @@ -152,20 +152,21 @@ impl State { .open_router .api_url .clone(); + cx.spawn(async move |this, cx| { let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) { (api_key, true) } else { let (_, api_key) = credentials_provider - .read_credentials(&api_url, &cx) + .read_credentials(&api_url, cx) .await? 
.ok_or(AuthenticateError::CredentialsNotFound)?; ( - String::from_utf8(api_key) - .context(format!("invalid {} API key", PROVIDER_NAME))?, + String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, false, ) }; + this.update(cx, |this, cx| { this.api_key = Some(api_key); this.api_key_from_env = from_env; @@ -177,18 +178,35 @@ impl State { }) } - fn fetch_models(&mut self, cx: &mut Context) -> Task> { + fn fetch_models( + &mut self, + cx: &mut Context, + ) -> Task> { let settings = &AllLanguageModelSettings::get_global(cx).open_router; let http_client = self.http_client.clone(); let api_url = settings.api_url.clone(); - + let Some(api_key) = self.api_key.clone() else { + return Task::ready(Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + })); + }; cx.spawn(async move |this, cx| { - let models = list_models(http_client.as_ref(), &api_url).await?; + let models = list_models(http_client.as_ref(), &api_url, &api_key) + .await + .map_err(|e| { + LanguageModelCompletionError::Other(anyhow::anyhow!( + "OpenRouter error: {:?}", + e + )) + })?; this.update(cx, |this, cx| { this.available_models = models; cx.notify(); }) + .map_err(|e| LanguageModelCompletionError::Other(e))?; + + Ok(()) }) } @@ -306,7 +324,12 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } @@ -329,27 +352,37 @@ impl OpenRouterLanguageModel { &self, request: open_router::Request, cx: &AsyncApp, - ) -> BoxFuture<'static, Result>>> - { + ) -> BoxFuture< + 'static, + Result< + futures::stream::BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, + > { let http_client = 
self.http_client.clone(); let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { let settings = &AllLanguageModelSettings::get_global(cx).open_router; (state.api_key.clone(), settings.api_url.clone()) }) else { - return futures::future::ready(Err(anyhow!( - "App state dropped: Unable to read API key or API URL from the application state" - ))) + return futures::future::ready(Err(LanguageModelCompletionError::Other(anyhow!( + "App state dropped" + )))) .boxed(); }; - let future = self.request_limiter.stream(async move { - let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenRouter API Key"))?; + async move { + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request); - let response = request.await?; - Ok(response) - }); - - async move { Ok(future.await?.boxed()) }.boxed() + request.await.map_err(Into::into) + } + .boxed() } } @@ -376,7 +409,7 @@ impl LanguageModel for OpenRouterLanguageModel { fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { let model_id = self.model.id().trim().to_lowercase(); - if model_id.contains("gemini") || model_id.contains("grok-4") { + if model_id.contains("gemini") || model_id.contains("grok") { LanguageModelToolSchemaFormat::JsonSchemaSubset } else { LanguageModelToolSchemaFormat::JsonSchema @@ -430,12 +463,12 @@ impl LanguageModel for OpenRouterLanguageModel { >, > { let request = into_open_router(request, &self.model, self.max_output_tokens()); - let completions = self.stream_completion(request, cx); - async move { - let mapper = OpenRouterEventMapper::new(); - Ok(mapper.map_stream(completions.await?).boxed()) - } - .boxed() + let request = self.stream_completion(request, cx); + let future = self.request_limiter.stream(async move { + let response = request.await?; + Ok(OpenRouterEventMapper::new().map_stream(response)) + }); + async move { 
Ok(future.await?.boxed()) }.boxed() } } @@ -603,13 +636,17 @@ impl OpenRouterEventMapper { pub fn map_stream( mut self, - events: Pin>>>, + events: Pin< + Box< + dyn Send + Stream>, + >, + >, ) -> impl Stream> { events.flat_map(move |event| { futures::stream::iter(match event { Ok(event) => self.map_event(event), - Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], + Err(error) => vec![Err(error.into())], }) }) } @@ -750,8 +787,11 @@ impl ConfigurationView { fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { let api_key_editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor - .set_placeholder_text("sk_or_000000000000000000000000000000000000000000000000", cx); + editor.set_placeholder_text( + "sk_or_000000000000000000000000000000000000000000000000", + window, + cx, + ); editor }); diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index 57a89ba4aabc63e981949ceb22e3f91de9ec3957..84f3175d1e5493fd55cafd2ea9c4a0604d2a97b4 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -71,7 +71,7 @@ impl State { }; cx.spawn(async move |this, cx| { credentials_provider - .delete_credentials(&api_url, &cx) + .delete_credentials(&api_url, cx) .await .log_err(); this.update(cx, |this, cx| { @@ -92,7 +92,7 @@ impl State { }; cx.spawn(async move |this, cx| { credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) .await .log_err(); this.update(cx, |this, cx| { @@ -119,7 +119,7 @@ impl State { (api_key, true) } else { let (_, api_key) = credentials_provider - .read_credentials(&api_url, &cx) + .read_credentials(&api_url, cx) .await? 
.ok_or(AuthenticateError::CredentialsNotFound)?; ( @@ -230,7 +230,12 @@ impl LanguageModelProvider for VercelLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } @@ -355,6 +360,7 @@ impl LanguageModel for VercelLanguageModel { request, self.model.id(), self.model.supports_parallel_tool_calls(), + self.model.supports_prompt_cache_key(), self.max_output_tokens(), None, ); diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index 5e7190ea961d1ec4781e31d9d1a19d9673afc9c0..bb17f22c7f3fdbb0296b1e0bb290fbce9a979ddf 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -71,7 +71,7 @@ impl State { }; cx.spawn(async move |this, cx| { credentials_provider - .delete_credentials(&api_url, &cx) + .delete_credentials(&api_url, cx) .await .log_err(); this.update(cx, |this, cx| { @@ -92,7 +92,7 @@ impl State { }; cx.spawn(async move |this, cx| { credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) .await .log_err(); this.update(cx, |this, cx| { @@ -119,7 +119,7 @@ impl State { (api_key, true) } else { let (_, api_key) = credentials_provider - .read_credentials(&api_url, &cx) + .read_credentials(&api_url, cx) .await? 
.ok_or(AuthenticateError::CredentialsNotFound)?; ( @@ -230,7 +230,12 @@ impl LanguageModelProvider for XAiLanguageModelProvider { self.state.update(cx, |state, cx| state.authenticate(cx)) } - fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } @@ -314,7 +319,7 @@ impl LanguageModel for XAiLanguageModel { } fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { let model_id = self.model.id().trim().to_lowercase(); - if model_id.eq(x_ai::Model::Grok4.id()) { + if model_id.eq(x_ai::Model::Grok4.id()) || model_id.eq(x_ai::Model::GrokCodeFast1.id()) { LanguageModelToolSchemaFormat::JsonSchemaSubset } else { LanguageModelToolSchemaFormat::JsonSchema @@ -359,6 +364,7 @@ impl LanguageModel for XAiLanguageModel { request, self.model.id(), self.model.supports_parallel_tool_calls(), + self.model.supports_prompt_cache_key(), self.max_output_tokens(), None, ); diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index b163585aa7b745447381aa62f710e8c5dbdf469c..cfe66c91a36d4da562cba84363f79bd1d5b4e1ce 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -5,7 +5,7 @@ use collections::HashMap; use gpui::App; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; use crate::provider::{ self, @@ -46,7 +46,10 @@ pub struct AllLanguageModelSettings { pub zed_dot_dev: ZedDotDevSettings, } -#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +#[derive( + Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, SettingsUi, SettingsKey, +)] +#[settings_key(key = "language_models")] pub 
struct AllLanguageModelSettingsContent { pub anthropic: Option, pub bedrock: Option, @@ -145,8 +148,6 @@ pub struct OpenRouterSettingsContent { } impl settings::Settings for AllLanguageModelSettings { - const KEY: Option<&'static str> = Some("language_models"); - const PRESERVED_KEYS: Option<&'static [&'static str]> = Some(&["version"]); type FileContent = AllLanguageModelSettingsContent; diff --git a/crates/language_models/src/ui/instruction_list_item.rs b/crates/language_models/src/ui/instruction_list_item.rs index 3dee97aff6ca78f97f0e4386e9518f5a5d1f29e0..bdb5fbe242ee902dc98a37addfaa0f103ef9ad20 100644 --- a/crates/language_models/src/ui/instruction_list_item.rs +++ b/crates/language_models/src/ui/instruction_list_item.rs @@ -37,7 +37,7 @@ impl IntoElement for InstructionListItem { let item_content = if let (Some(button_label), Some(button_link)) = (self.button_label, self.button_link) { - let link = button_link.clone(); + let link = button_link; let unique_id = SharedString::from(format!("{}-button", self.label)); h_flex() diff --git a/crates/language_onboarding/Cargo.toml b/crates/language_onboarding/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..a437adf1191a3b76fbd828dacaa60b75b1f7df28 --- /dev/null +++ b/crates/language_onboarding/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "language_onboarding" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/python.rs" + +[features] +default = [] + +[dependencies] +db.workspace = true +editor.workspace = true +gpui.workspace = true +project.workspace = true +ui.workspace = true +workspace.workspace = true +workspace-hack.workspace = true + +# Uncomment other workspace dependencies as needed +# assistant.workspace = true +# client.workspace = true +# project.workspace = true +# settings.workspace = true diff --git a/crates/language_onboarding/LICENSE-GPL 
b/crates/language_onboarding/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/language_onboarding/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/language_onboarding/src/python.rs b/crates/language_onboarding/src/python.rs new file mode 100644 index 0000000000000000000000000000000000000000..6b83b841e0488d67014cc090b6c741035e544e04 --- /dev/null +++ b/crates/language_onboarding/src/python.rs @@ -0,0 +1,95 @@ +use db::kvp::Dismissable; +use editor::Editor; +use gpui::{Context, EventEmitter, Subscription}; +use ui::{Banner, FluentBuilder as _, prelude::*}; +use workspace::{ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace}; + +pub struct BasedPyrightBanner { + dismissed: bool, + have_basedpyright: bool, + _subscriptions: [Subscription; 1], +} + +impl Dismissable for BasedPyrightBanner { + const KEY: &str = "basedpyright-banner"; +} + +impl BasedPyrightBanner { + pub fn new(workspace: &Workspace, cx: &mut Context) -> Self { + let subscription = cx.subscribe(workspace.project(), |this, _, event, _| { + if let project::Event::LanguageServerAdded(_, name, _) = event + && name == "basedpyright" + { + this.have_basedpyright = true; + } + }); + let dismissed = Self::dismissed(); + Self { + dismissed, + have_basedpyright: false, + _subscriptions: [subscription], + } + } +} + +impl EventEmitter for BasedPyrightBanner {} + +impl Render for BasedPyrightBanner { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + div() + .id("basedpyright-banner") + .when(!self.dismissed && self.have_basedpyright, |el| { + el.child( + Banner::new() + .child( + v_flex() + .gap_0p5() + .child(Label::new("Basedpyright is now the only default language server for Python").mt_0p5()) + .child(Label::new("We have disabled PyRight and pylsp by default. 
They can be re-enabled in your settings.").size(LabelSize::Small).color(Color::Muted)) + ) + .action_slot( + h_flex() + .gap_0p5() + .child( + Button::new("learn-more", "Learn More") + .icon(IconName::ArrowUpRight) + .label_size(LabelSize::Small) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(|_, _, cx| { + cx.open_url("https://zed.dev/docs/languages/python") + }), + ) + .child(IconButton::new("dismiss", IconName::Close).icon_size(IconSize::Small).on_click( + cx.listener(|this, _, _, cx| { + this.dismissed = true; + Self::set_dismissed(true, cx); + cx.notify(); + }), + )) + ) + .into_any_element(), + ) + }) + } +} + +impl ToolbarItemView for BasedPyrightBanner { + fn set_active_pane_item( + &mut self, + active_pane_item: Option<&dyn workspace::ItemHandle>, + _window: &mut ui::Window, + cx: &mut Context, + ) -> ToolbarItemLocation { + if let Some(item) = active_pane_item + && let Some(editor) = item.act_as::(cx) + && let Some(path) = editor.update(cx, |editor, cx| editor.target_file_abs_path(cx)) + && let Some(file_name) = path.file_name() + && file_name.as_encoded_bytes().ends_with(".py".as_bytes()) + { + return ToolbarItemLocation::Secondary; + } + + ToolbarItemLocation::Hidden + } +} diff --git a/crates/language_selector/src/active_buffer_language.rs b/crates/language_selector/src/active_buffer_language.rs index c5c5eceab54f2c34f4b1e2aae1b04f85fc5d9ab6..56924c4cd2d54c64436a5ccaa7dabfe4c53ff0ec 100644 --- a/crates/language_selector/src/active_buffer_language.rs +++ b/crates/language_selector/src/active_buffer_language.rs @@ -28,10 +28,10 @@ impl ActiveBufferLanguage { self.active_language = Some(None); let editor = editor.read(cx); - if let Some((_, buffer, _)) = editor.active_excerpt(cx) { - if let Some(language) = buffer.read(cx).language() { - self.active_language = Some(Some(language.name())); - } + if let Some((_, buffer, _)) = editor.active_excerpt(cx) + && let Some(language) = buffer.read(cx).language() + { + self.active_language = 
Some(Some(language.name())); } cx.notify(); diff --git a/crates/language_selector/src/language_selector.rs b/crates/language_selector/src/language_selector.rs index f6e2d75015560582b30453767b1a3b30f7cce82e..991cce50baf82b2604e510a0eeb2eac4af1578dd 100644 --- a/crates/language_selector/src/language_selector.rs +++ b/crates/language_selector/src/language_selector.rs @@ -283,7 +283,7 @@ impl PickerDelegate for LanguageSelectorDelegate { _: &mut Window, cx: &mut Context>, ) -> Option { - let mat = &self.matches[ix]; + let mat = &self.matches.get(ix)?; let (label, language_icon) = self.language_data_for_match(mat, cx); Some( ListItem::new(ix) diff --git a/crates/language_tools/Cargo.toml b/crates/language_tools/Cargo.toml index 5aa914311a6eccc1cb68efa37e878ad12249d6fd..bbac900cded75e9ca680a1813734f57423ce0ee9 100644 --- a/crates/language_tools/Cargo.toml +++ b/crates/language_tools/Cargo.toml @@ -16,6 +16,7 @@ doctest = false anyhow.workspace = true client.workspace = true collections.workspace = true +command_palette_hooks.workspace = true copilot.workspace = true editor.workspace = true futures.workspace = true @@ -24,6 +25,7 @@ itertools.workspace = true language.workspace = true lsp.workspace = true project.workspace = true +proto.workspace = true serde_json.workspace = true settings.workspace = true theme.workspace = true diff --git a/crates/language_tools/src/key_context_view.rs b/crates/language_tools/src/key_context_view.rs index 88131781ec3af336d3ae793cf1820e5bcf731605..4140713544ed2b22413f909ac45989de8df4e706 100644 --- a/crates/language_tools/src/key_context_view.rs +++ b/crates/language_tools/src/key_context_view.rs @@ -4,7 +4,6 @@ use gpui::{ }; use itertools::Itertools; use serde_json::json; -use settings::get_key_equivalents; use ui::{Button, ButtonStyle}; use ui::{ ButtonCommon, Clickable, Context, FluentBuilder, InteractiveElement, Label, LabelCommon, @@ -71,12 +70,10 @@ impl KeyContextView { } else { None } + } else if this.action_matches(&e.action, 
binding.action()) { + Some(true) } else { - if this.action_matches(&e.action, binding.action()) { - Some(true) - } else { - Some(false) - } + Some(false) }; let predicate = if let Some(predicate) = binding.predicate() { format!("{}", predicate) @@ -98,9 +95,7 @@ impl KeyContextView { cx.notify(); }); let sub2 = cx.observe_pending_input(window, |this, window, cx| { - this.pending_keystrokes = window - .pending_input_keystrokes() - .map(|k| k.iter().cloned().collect()); + this.pending_keystrokes = window.pending_input_keystrokes().map(|k| k.to_vec()); if this.pending_keystrokes.is_some() { this.last_keystrokes.take(); } @@ -173,7 +168,8 @@ impl Item for KeyContextView { impl Render for KeyContextView { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl ui::IntoElement { use itertools::Itertools; - let key_equivalents = get_key_equivalents(cx.keyboard_layout().id()); + + let key_equivalents = cx.keyboard_mapper().get_key_equivalents(); v_flex() .id("key-context-view") .overflow_scroll() diff --git a/crates/language_tools/src/language_tools.rs b/crates/language_tools/src/language_tools.rs index cbf5756875f723b52fabbfe877c32265dd6f0aef..aa1672806417493c0c5a877a28fc7906f3da6ff8 100644 --- a/crates/language_tools/src/language_tools.rs +++ b/crates/language_tools/src/language_tools.rs @@ -1,20 +1,20 @@ mod key_context_view; -mod lsp_log; -pub mod lsp_tool; +pub mod lsp_button; +pub mod lsp_log_view; mod syntax_tree_view; #[cfg(test)] -mod lsp_log_tests; +mod lsp_log_view_tests; use gpui::{App, AppContext, Entity}; -pub use lsp_log::{LogStore, LspLogToolbarItemView, LspLogView}; +pub use lsp_log_view::LspLogView; pub use syntax_tree_view::{SyntaxTreeToolbarItemView, SyntaxTreeView}; use ui::{Context, Window}; use workspace::{Item, ItemHandle, SplitDirection, Workspace}; pub fn init(cx: &mut App) { - lsp_log::init(cx); + lsp_log_view::init(false, cx); syntax_tree_view::init(cx); key_context_view::init(cx); } diff --git a/crates/language_tools/src/lsp_tool.rs 
b/crates/language_tools/src/lsp_button.rs similarity index 90% rename from crates/language_tools/src/lsp_tool.rs rename to crates/language_tools/src/lsp_button.rs index 50547253a92b8c23d0530326faf916e56363dcd9..59beceff98ff2544aa22accc470b4e497b88c6ca 100644 --- a/crates/language_tools/src/lsp_tool.rs +++ b/crates/language_tools/src/lsp_button.rs @@ -11,7 +11,10 @@ use editor::{Editor, EditorEvent}; use gpui::{Corner, Entity, Subscription, Task, WeakEntity, actions}; use language::{BinaryStatus, BufferId, ServerHealth}; use lsp::{LanguageServerId, LanguageServerName, LanguageServerSelector}; -use project::{LspStore, LspStoreEvent, Worktree, project_settings::ProjectSettings}; +use project::{ + LspStore, LspStoreEvent, Worktree, lsp_store::log_store::GlobalLogStore, + project_settings::ProjectSettings, +}; use settings::{Settings as _, SettingsStore}; use ui::{ Context, ContextMenu, ContextMenuEntry, ContextMenuItem, DocumentationAside, DocumentationSide, @@ -20,7 +23,7 @@ use ui::{ use workspace::{StatusItemView, Workspace}; -use crate::lsp_log::GlobalLogStore; +use crate::lsp_log_view; actions!( lsp_tool, @@ -30,7 +33,7 @@ actions!( ] ); -pub struct LspTool { +pub struct LspButton { server_state: Entity, popover_menu_handle: PopoverMenuHandle, lsp_menu: Option>, @@ -121,9 +124,8 @@ impl LanguageServerState { menu = menu.align_popover_bottom(); let lsp_logs = cx .try_global::() - .and_then(|lsp_logs| lsp_logs.0.upgrade()); - let lsp_store = self.lsp_store.upgrade(); - let Some((lsp_logs, lsp_store)) = lsp_logs.zip(lsp_store) else { + .map(|lsp_logs| lsp_logs.0.clone()); + let Some(lsp_logs) = lsp_logs else { return menu; }; @@ -210,10 +212,11 @@ impl LanguageServerState { }; let server_selector = server_info.server_selector(); - // TODO currently, Zed remote does not work well with the LSP logs - // https://github.com/zed-industries/zed/issues/28557 - let has_logs = lsp_store.read(cx).as_local().is_some() - && lsp_logs.read(cx).has_server_logs(&server_selector); + 
let is_remote = self + .lsp_store + .update(cx, |lsp_store, _| lsp_store.as_remote().is_some()) + .unwrap_or(false); + let has_logs = is_remote || lsp_logs.read(cx).has_server_logs(&server_selector); let status_color = server_info .binary_status @@ -241,10 +244,10 @@ impl LanguageServerState { .as_ref() .or_else(|| server_info.binary_status.as_ref()?.message.as_ref()) .cloned(); - let hover_label = if has_logs { - Some("View Logs") - } else if message.is_some() { + let hover_label = if message.is_some() { Some("View Message") + } else if has_logs { + Some("View Logs") } else { None }; @@ -288,21 +291,12 @@ impl LanguageServerState { let server_name = server_info.name.clone(); let workspace = self.workspace.clone(); move |window, cx| { - if has_logs { - lsp_logs.update(cx, |lsp_logs, cx| { - lsp_logs.open_server_trace( - workspace.clone(), - server_selector.clone(), - window, - cx, - ); - }); - } else if let Some(message) = &message { + if let Some(message) = &message { let Some(create_buffer) = workspace .update(cx, |workspace, cx| { workspace .project() - .update(cx, |project, cx| project.create_buffer(cx)) + .update(cx, |project, cx| project.create_buffer(false, cx)) }) .ok() else { @@ -347,9 +341,16 @@ impl LanguageServerState { anyhow::Ok(()) }) .detach(); + } else if has_logs { + lsp_log_view::open_server_trace( + &lsp_logs, + workspace.clone(), + server_selector.clone(), + window, + cx, + ); } else { cx.propagate(); - return; } } }, @@ -511,7 +512,7 @@ impl ServerData<'_> { } } -impl LspTool { +impl LspButton { pub fn new( workspace: &Workspace, popover_menu_handle: PopoverMenuHandle, @@ -519,38 +520,59 @@ impl LspTool { cx: &mut Context, ) -> Self { let settings_subscription = - cx.observe_global_in::(window, move |lsp_tool, window, cx| { + cx.observe_global_in::(window, move |lsp_button, window, cx| { if ProjectSettings::get_global(cx).global_lsp_settings.button { - if lsp_tool.lsp_menu.is_none() { - lsp_tool.refresh_lsp_menu(true, window, cx); - return; + 
if lsp_button.lsp_menu.is_none() { + lsp_button.refresh_lsp_menu(true, window, cx); } - } else if lsp_tool.lsp_menu.take().is_some() { + } else if lsp_button.lsp_menu.take().is_some() { cx.notify(); } }); let lsp_store = workspace.project().read(cx).lsp_store(); + let mut language_servers = LanguageServers::default(); + for (_, status) in lsp_store.read(cx).language_server_statuses() { + language_servers.binary_statuses.insert( + status.name.clone(), + LanguageServerBinaryStatus { + status: BinaryStatus::None, + message: None, + }, + ); + } + let lsp_store_subscription = - cx.subscribe_in(&lsp_store, window, |lsp_tool, _, e, window, cx| { - lsp_tool.on_lsp_store_event(e, window, cx) + cx.subscribe_in(&lsp_store, window, |lsp_button, _, e, window, cx| { + lsp_button.on_lsp_store_event(e, window, cx) }); - let state = cx.new(|_| LanguageServerState { + let server_state = cx.new(|_| LanguageServerState { workspace: workspace.weak_handle(), items: Vec::new(), lsp_store: lsp_store.downgrade(), active_editor: None, - language_servers: LanguageServers::default(), + language_servers, }); - Self { - server_state: state, + let mut lsp_button = Self { + server_state, popover_menu_handle, lsp_menu: None, lsp_menu_refresh: Task::ready(()), _subscriptions: vec![settings_subscription, lsp_store_subscription], + }; + if !lsp_button + .server_state + .read(cx) + .language_servers + .binary_statuses + .is_empty() + { + lsp_button.refresh_lsp_menu(true, window, cx); } + + lsp_button } fn on_lsp_store_event( @@ -710,6 +732,25 @@ impl LspTool { } } } + state + .lsp_store + .update(cx, |lsp_store, cx| { + for (server_id, status) in lsp_store.language_server_statuses() { + if let Some(worktree) = status.worktree.and_then(|worktree_id| { + lsp_store + .worktree_store() + .read(cx) + .worktree_for_id(worktree_id, cx) + }) { + server_ids_to_worktrees.insert(server_id, worktree.clone()); + server_names_to_worktrees + .entry(status.name.clone()) + .or_default() + .insert((worktree, 
server_id)); + } + } + }) + .ok(); let mut servers_per_worktree = BTreeMap::>::new(); let mut servers_without_worktree = Vec::::new(); @@ -854,18 +895,18 @@ impl LspTool { ) { if create_if_empty || self.lsp_menu.is_some() { let state = self.server_state.clone(); - self.lsp_menu_refresh = cx.spawn_in(window, async move |lsp_tool, cx| { + self.lsp_menu_refresh = cx.spawn_in(window, async move |lsp_button, cx| { cx.background_executor() .timer(Duration::from_millis(30)) .await; - lsp_tool - .update_in(cx, |lsp_tool, window, cx| { - lsp_tool.regenerate_items(cx); + lsp_button + .update_in(cx, |lsp_button, window, cx| { + lsp_button.regenerate_items(cx); let menu = ContextMenu::build(window, cx, |menu, _, cx| { state.update(cx, |state, cx| state.fill_menu(menu, cx)) }); - lsp_tool.lsp_menu = Some(menu.clone()); - lsp_tool.popover_menu_handle.refresh_menu( + lsp_button.lsp_menu = Some(menu.clone()); + lsp_button.popover_menu_handle.refresh_menu( window, cx, Rc::new(move |_, _| Some(menu.clone())), @@ -878,7 +919,7 @@ impl LspTool { } } -impl StatusItemView for LspTool { +impl StatusItemView for LspButton { fn set_active_pane_item( &mut self, active_pane_item: Option<&dyn workspace::ItemHandle>, @@ -901,9 +942,9 @@ impl StatusItemView for LspTool { let _editor_subscription = cx.subscribe_in( &editor, window, - |lsp_tool, _, e: &EditorEvent, window, cx| match e { + |lsp_button, _, e: &EditorEvent, window, cx| match e { EditorEvent::ExcerptsAdded { buffer, .. } => { - let updated = lsp_tool.server_state.update(cx, |state, cx| { + let updated = lsp_button.server_state.update(cx, |state, cx| { if let Some(active_editor) = state.active_editor.as_mut() { let buffer_id = buffer.read(cx).remote_id(); active_editor.editor_buffers.insert(buffer_id) @@ -912,13 +953,13 @@ impl StatusItemView for LspTool { } }); if updated { - lsp_tool.refresh_lsp_menu(false, window, cx); + lsp_button.refresh_lsp_menu(false, window, cx); } } EditorEvent::ExcerptsRemoved { removed_buffer_ids, .. 
} => { - let removed = lsp_tool.server_state.update(cx, |state, _| { + let removed = lsp_button.server_state.update(cx, |state, _| { let mut removed = false; if let Some(active_editor) = state.active_editor.as_mut() { for id in removed_buffer_ids { @@ -932,7 +973,7 @@ impl StatusItemView for LspTool { removed }); if removed { - lsp_tool.refresh_lsp_menu(false, window, cx); + lsp_button.refresh_lsp_menu(false, window, cx); } } _ => {} @@ -962,7 +1003,7 @@ impl StatusItemView for LspTool { } } -impl Render for LspTool { +impl Render for LspButton { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl ui::IntoElement { if self.server_state.read(cx).language_servers.is_empty() || self.lsp_menu.is_none() { return div(); @@ -1007,11 +1048,11 @@ impl Render for LspTool { (None, "All Servers Operational") }; - let lsp_tool = cx.entity().clone(); + let lsp_button = cx.entity(); div().child( PopoverMenu::new("lsp-tool") - .menu(move |_, cx| lsp_tool.read(cx).lsp_menu.clone()) + .menu(move |_, cx| lsp_button.read(cx).lsp_menu.clone()) .anchor(Corner::BottomLeft) .with_handle(self.popover_menu_handle.clone()) .trigger_with_tooltip( diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log_view.rs similarity index 61% rename from crates/language_tools/src/lsp_log.rs rename to crates/language_tools/src/lsp_log_view.rs index 606f3a3f0e5f91b5fb8856cabce240d094f3cf49..fb63ab9a99147328c4987bd80b698ef4a477f013 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log_view.rs @@ -1,20 +1,25 @@ -use collections::{HashMap, VecDeque}; +use collections::VecDeque; use copilot::Copilot; use editor::{Editor, EditorEvent, actions::MoveToEnd, scroll::Autoscroll}; -use futures::{StreamExt, channel::mpsc}; use gpui::{ - AnyView, App, Context, Corner, Entity, EventEmitter, FocusHandle, Focusable, Global, - IntoElement, ParentElement, Render, Styled, Subscription, WeakEntity, Window, actions, div, + AnyView, App, Context, 
Corner, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, + ParentElement, Render, Styled, Subscription, WeakEntity, Window, actions, div, }; use itertools::Itertools; use language::{LanguageServerId, language_settings::SoftWrap}; use lsp::{ - IoKind, LanguageServer, LanguageServerName, LanguageServerSelector, MessageType, + LanguageServer, LanguageServerBinary, LanguageServerName, LanguageServerSelector, MessageType, SetTraceParams, TraceValue, notification::SetTrace, }; -use project::{Project, WorktreeId, search::SearchQuery}; +use project::{ + Project, + lsp_store::log_store::{self, Event, LanguageServerKind, LogKind, LogStore, Message}, + search::SearchQuery, +}; +use proto::toggle_lsp_logs::LogType; use std::{any::TypeId, borrow::Cow, sync::Arc}; use ui::{Button, Checkbox, ContextMenu, Label, PopoverMenu, ToggleState, prelude::*}; +use util::ResultExt as _; use workspace::{ SplitDirection, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace, WorkspaceId, item::{Item, ItemHandle}, @@ -23,132 +28,53 @@ use workspace::{ use crate::get_or_create_tool; -const SEND_LINE: &str = "\n// Send:"; -const RECEIVE_LINE: &str = "\n// Receive:"; -const MAX_STORED_LOG_ENTRIES: usize = 2000; - -pub struct LogStore { - projects: HashMap, ProjectState>, - language_servers: HashMap, - copilot_log_subscription: Option, - _copilot_subscription: Option, - io_tx: mpsc::UnboundedSender<(LanguageServerId, IoKind, String)>, -} - -struct ProjectState { - _subscriptions: [gpui::Subscription; 2], -} - -trait Message: AsRef { - type Level: Copy + std::fmt::Debug; - fn should_include(&self, _: Self::Level) -> bool { - true - } -} - -pub(super) struct LogMessage { - message: String, - typ: MessageType, -} - -impl AsRef for LogMessage { - fn as_ref(&self) -> &str { - &self.message - } -} - -impl Message for LogMessage { - type Level = MessageType; - - fn should_include(&self, level: Self::Level) -> bool { - match (self.typ, level) { - (MessageType::ERROR, _) => true, - 
(_, MessageType::ERROR) => false, - (MessageType::WARNING, _) => true, - (_, MessageType::WARNING) => false, - (MessageType::INFO, _) => true, - (_, MessageType::INFO) => false, - _ => true, - } - } -} - -pub(super) struct TraceMessage { - message: String, -} - -impl AsRef for TraceMessage { - fn as_ref(&self) -> &str { - &self.message - } -} - -impl Message for TraceMessage { - type Level = (); -} - -struct RpcMessage { - message: String, -} - -impl AsRef for RpcMessage { - fn as_ref(&self) -> &str { - &self.message - } -} - -impl Message for RpcMessage { - type Level = (); -} - -pub(super) struct LanguageServerState { - name: Option, - worktree_id: Option, - kind: LanguageServerKind, - log_messages: VecDeque, - trace_messages: VecDeque, - rpc_state: Option, - trace_level: TraceValue, - log_level: MessageType, - io_logs_subscription: Option, -} - -#[derive(PartialEq, Clone)] -pub enum LanguageServerKind { - Local { project: WeakEntity }, - Remote { project: WeakEntity }, - Global, -} - -impl LanguageServerKind { - fn is_remote(&self) -> bool { - matches!(self, LanguageServerKind::Remote { .. }) - } -} - -impl std::fmt::Debug for LanguageServerKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - LanguageServerKind::Local { .. } => write!(f, "LanguageServerKind::Local"), - LanguageServerKind::Remote { .. } => write!(f, "LanguageServerKind::Remote"), - LanguageServerKind::Global => write!(f, "LanguageServerKind::Global"), - } - } -} - -impl LanguageServerKind { - fn project(&self) -> Option<&WeakEntity> { - match self { - Self::Local { project } => Some(project), - Self::Remote { project } => Some(project), - Self::Global { .. 
} => None, - } - } -} - -struct LanguageServerRpcState { - rpc_messages: VecDeque, - last_message_kind: Option, +pub fn open_server_trace( + log_store: &Entity, + workspace: WeakEntity, + server: LanguageServerSelector, + window: &mut Window, + cx: &mut App, +) { + log_store.update(cx, |_, cx| { + cx.spawn_in(window, async move |log_store, cx| { + let Some(log_store) = log_store.upgrade() else { + return; + }; + workspace + .update_in(cx, |workspace, window, cx| { + let project = workspace.project().clone(); + let tool_log_store = log_store.clone(); + let log_view = get_or_create_tool( + workspace, + SplitDirection::Right, + window, + cx, + move |window, cx| LspLogView::new(project, tool_log_store, window, cx), + ); + log_view.update(cx, |log_view, cx| { + let server_id = match server { + LanguageServerSelector::Id(id) => Some(id), + LanguageServerSelector::Name(name) => { + log_store.read(cx).language_servers.iter().find_map( + |(id, state)| { + if state.name.as_ref() == Some(&name) { + Some(*id) + } else { + None + } + }, + ) + } + }; + if let Some(server_id) = server_id { + log_view.show_rpc_trace_for_server(server_id, window, cx); + } + }); + }) + .ok(); + }) + .detach(); + }) } pub struct LspLogView { @@ -167,32 +93,6 @@ pub struct LspLogToolbarItemView { _log_view_subscription: Option, } -#[derive(Copy, Clone, PartialEq, Eq)] -enum MessageKind { - Send, - Receive, -} - -#[derive(Clone, Copy, Debug, Default, PartialEq)] -pub enum LogKind { - Rpc, - Trace, - #[default] - Logs, - ServerInfo, -} - -impl LogKind { - fn label(&self) -> &'static str { - match self { - LogKind::Rpc => RPC_MESSAGES, - LogKind::Trace => SERVER_TRACE, - LogKind::Logs => SERVER_LOGS, - LogKind::ServerInfo => SERVER_INFO, - } - } -} - #[derive(Clone, Debug, PartialEq)] pub(crate) struct LogMenuItem { pub server_id: LanguageServerId, @@ -212,505 +112,68 @@ actions!( ] ); -pub(super) struct GlobalLogStore(pub WeakEntity); - -impl Global for GlobalLogStore {} +pub fn init(on_headless_host: 
bool, cx: &mut App) { + let log_store = log_store::init(on_headless_host, cx); -pub fn init(cx: &mut App) { - let log_store = cx.new(LogStore::new); - cx.set_global(GlobalLogStore(log_store.downgrade())); - - cx.observe_new(move |workspace: &mut Workspace, _, cx| { - let project = workspace.project(); - if project.read(cx).is_local() || project.read(cx).is_via_ssh() { - log_store.update(cx, |store, cx| { - store.add_project(project, cx); - }); - } - - let log_store = log_store.clone(); - workspace.register_action(move |workspace, _: &OpenLanguageServerLogs, window, cx| { - let project = workspace.project().read(cx); - if project.is_local() || project.is_via_ssh() { - let project = workspace.project().clone(); - let log_store = log_store.clone(); - get_or_create_tool( - workspace, - SplitDirection::Right, - window, - cx, - move |window, cx| LspLogView::new(project, log_store, window, cx), - ); - } - }); - }) - .detach(); -} - -impl LogStore { - pub fn new(cx: &mut Context) -> Self { - let (io_tx, mut io_rx) = mpsc::unbounded(); - - let copilot_subscription = Copilot::global(cx).map(|copilot| { + log_store.update(cx, |_, cx| { + Copilot::global(cx).map(|copilot| { let copilot = &copilot; - cx.subscribe(copilot, |this, copilot, edit_prediction_event, cx| { - if let copilot::Event::CopilotLanguageServerStarted = edit_prediction_event { - if let Some(server) = copilot.read(cx).language_server() { - let server_id = server.server_id(); - let weak_this = cx.weak_entity(); - this.copilot_log_subscription = - Some(server.on_notification::( - move |params, cx| { - weak_this - .update(cx, |this, cx| { - this.add_language_server_log( - server_id, - MessageType::LOG, - ¶ms.message, - cx, - ); - }) - .ok(); - }, - )); - let name = LanguageServerName::new_static("copilot"); - this.add_language_server( - LanguageServerKind::Global, - server.server_id(), - Some(name), - None, - Some(server.clone()), - cx, - ); - } + cx.subscribe(copilot, |log_store, copilot, edit_prediction_event, 
cx| { + if let copilot::Event::CopilotLanguageServerStarted = edit_prediction_event + && let Some(server) = copilot.read(cx).language_server() + { + let server_id = server.server_id(); + let weak_lsp_store = cx.weak_entity(); + log_store.copilot_log_subscription = + Some(server.on_notification::( + move |params, cx| { + weak_lsp_store + .update(cx, |lsp_store, cx| { + lsp_store.add_language_server_log( + server_id, + MessageType::LOG, + ¶ms.message, + cx, + ); + }) + .ok(); + }, + )); + + let name = LanguageServerName::new_static("copilot"); + log_store.add_language_server( + LanguageServerKind::Global, + server.server_id(), + Some(name), + None, + Some(server.clone()), + cx, + ); } }) - }); - - let this = Self { - copilot_log_subscription: None, - _copilot_subscription: copilot_subscription, - projects: HashMap::default(), - language_servers: HashMap::default(), - io_tx, - }; - - cx.spawn(async move |this, cx| { - while let Some((server_id, io_kind, message)) = io_rx.next().await { - if let Some(this) = this.upgrade() { - this.update(cx, |this, cx| { - this.on_io(server_id, io_kind, &message, cx); - })?; - } - } - anyhow::Ok(()) + .detach(); }) - .detach_and_log_err(cx); - this - } + }); - pub fn add_project(&mut self, project: &Entity, cx: &mut Context) { - let weak_project = project.downgrade(); - self.projects.insert( - project.downgrade(), - ProjectState { - _subscriptions: [ - cx.observe_release(project, move |this, _, _| { - this.projects.remove(&weak_project); - this.language_servers - .retain(|_, state| state.kind.project() != Some(&weak_project)); - }), - cx.subscribe(project, |this, project, event, cx| { - let server_kind = if project.read(cx).is_via_ssh() { - LanguageServerKind::Remote { - project: project.downgrade(), - } - } else { - LanguageServerKind::Local { - project: project.downgrade(), - } - }; - - match event { - project::Event::LanguageServerAdded(id, name, worktree_id) => { - this.add_language_server( - server_kind, - *id, - 
Some(name.clone()), - *worktree_id, - project - .read(cx) - .lsp_store() - .read(cx) - .language_server_for_id(*id), - cx, - ); - } - project::Event::LanguageServerRemoved(id) => { - this.remove_language_server(*id, cx); - } - project::Event::LanguageServerLog(id, typ, message) => { - this.add_language_server(server_kind, *id, None, None, None, cx); - match typ { - project::LanguageServerLogType::Log(typ) => { - this.add_language_server_log(*id, *typ, message, cx); - } - project::LanguageServerLogType::Trace(_) => { - this.add_language_server_trace(*id, message, cx); - } - } - } - _ => {} - } - }), - ], - }, - ); - } - - pub(super) fn get_language_server_state( - &mut self, - id: LanguageServerId, - ) -> Option<&mut LanguageServerState> { - self.language_servers.get_mut(&id) - } - - fn add_language_server( - &mut self, - kind: LanguageServerKind, - server_id: LanguageServerId, - name: Option, - worktree_id: Option, - server: Option>, - cx: &mut Context, - ) -> Option<&mut LanguageServerState> { - let server_state = self.language_servers.entry(server_id).or_insert_with(|| { - cx.notify(); - LanguageServerState { - name: None, - worktree_id: None, - kind, - rpc_state: None, - log_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), - trace_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), - trace_level: TraceValue::Off, - log_level: MessageType::LOG, - io_logs_subscription: None, - } + cx.observe_new(move |workspace: &mut Workspace, _, cx| { + log_store.update(cx, |store, cx| { + store.add_project(workspace.project(), cx); }); - if let Some(name) = name { - server_state.name = Some(name); - } - if let Some(worktree_id) = worktree_id { - server_state.worktree_id = Some(worktree_id); - } - - if let Some(server) = server - .clone() - .filter(|_| server_state.io_logs_subscription.is_none()) - { - let io_tx = self.io_tx.clone(); - let server_id = server.server_id(); - server_state.io_logs_subscription = Some(server.on_io(move |io_kind, message| { - io_tx - 
.unbounded_send((server_id, io_kind, message.to_string())) - .ok(); - })); - } - - Some(server_state) - } - - fn add_language_server_log( - &mut self, - id: LanguageServerId, - typ: MessageType, - message: &str, - cx: &mut Context, - ) -> Option<()> { - let language_server_state = self.get_language_server_state(id)?; - - let log_lines = &mut language_server_state.log_messages; - Self::add_language_server_message( - log_lines, - id, - LogMessage { - message: message.trim_end().to_string(), - typ, - }, - language_server_state.log_level, - LogKind::Logs, - cx, - ); - Some(()) - } - - fn add_language_server_trace( - &mut self, - id: LanguageServerId, - message: &str, - cx: &mut Context, - ) -> Option<()> { - let language_server_state = self.get_language_server_state(id)?; - - let log_lines = &mut language_server_state.trace_messages; - Self::add_language_server_message( - log_lines, - id, - TraceMessage { - message: message.trim().to_string(), - }, - (), - LogKind::Trace, - cx, - ); - Some(()) - } - - fn add_language_server_message( - log_lines: &mut VecDeque, - id: LanguageServerId, - message: T, - current_severity: ::Level, - kind: LogKind, - cx: &mut Context, - ) { - while log_lines.len() + 1 >= MAX_STORED_LOG_ENTRIES { - log_lines.pop_front(); - } - let text = message.as_ref().to_string(); - let visible = message.should_include(current_severity); - log_lines.push_back(message); - - if visible { - cx.emit(Event::NewServerLogEntry { id, kind, text }); - cx.notify(); - } - } - - fn remove_language_server(&mut self, id: LanguageServerId, cx: &mut Context) { - self.language_servers.remove(&id); - cx.notify(); - } - - pub(super) fn server_logs(&self, server_id: LanguageServerId) -> Option<&VecDeque> { - Some(&self.language_servers.get(&server_id)?.log_messages) - } - - pub(super) fn server_trace( - &self, - server_id: LanguageServerId, - ) -> Option<&VecDeque> { - Some(&self.language_servers.get(&server_id)?.trace_messages) - } - - fn server_ids_for_project<'a>( - &'a 
self, - lookup_project: &'a WeakEntity, - ) -> impl Iterator + 'a { - self.language_servers - .iter() - .filter_map(move |(id, state)| match &state.kind { - LanguageServerKind::Local { project } | LanguageServerKind::Remote { project } => { - if project == lookup_project { - Some(*id) - } else { - None - } - } - LanguageServerKind::Global => Some(*id), - }) - } - - fn enable_rpc_trace_for_language_server( - &mut self, - server_id: LanguageServerId, - ) -> Option<&mut LanguageServerRpcState> { - let rpc_state = self - .language_servers - .get_mut(&server_id)? - .rpc_state - .get_or_insert_with(|| LanguageServerRpcState { - rpc_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), - last_message_kind: None, - }); - Some(rpc_state) - } - - pub fn disable_rpc_trace_for_language_server( - &mut self, - server_id: LanguageServerId, - ) -> Option<()> { - self.language_servers.get_mut(&server_id)?.rpc_state.take(); - Some(()) - } - - pub fn has_server_logs(&self, server: &LanguageServerSelector) -> bool { - match server { - LanguageServerSelector::Id(id) => self.language_servers.contains_key(id), - LanguageServerSelector::Name(name) => self - .language_servers - .iter() - .any(|(_, state)| state.name.as_ref() == Some(name)), - } - } - - pub fn open_server_log( - &mut self, - workspace: WeakEntity, - server: LanguageServerSelector, - window: &mut Window, - cx: &mut Context, - ) { - cx.spawn_in(window, async move |log_store, cx| { - let Some(log_store) = log_store.upgrade() else { - return; - }; - workspace - .update_in(cx, |workspace, window, cx| { - let project = workspace.project().clone(); - let tool_log_store = log_store.clone(); - let log_view = get_or_create_tool( - workspace, - SplitDirection::Right, - window, - cx, - move |window, cx| LspLogView::new(project, tool_log_store, window, cx), - ); - log_view.update(cx, |log_view, cx| { - let server_id = match server { - LanguageServerSelector::Id(id) => Some(id), - LanguageServerSelector::Name(name) => { - 
log_store.read(cx).language_servers.iter().find_map( - |(id, state)| { - if state.name.as_ref() == Some(&name) { - Some(*id) - } else { - None - } - }, - ) - } - }; - if let Some(server_id) = server_id { - log_view.show_logs_for_server(server_id, window, cx); - } - }); - }) - .ok(); - }) - .detach(); - } - - pub fn open_server_trace( - &mut self, - workspace: WeakEntity, - server: LanguageServerSelector, - window: &mut Window, - cx: &mut Context, - ) { - cx.spawn_in(window, async move |log_store, cx| { - let Some(log_store) = log_store.upgrade() else { - return; - }; - workspace - .update_in(cx, |workspace, window, cx| { - let project = workspace.project().clone(); - let tool_log_store = log_store.clone(); - let log_view = get_or_create_tool( - workspace, - SplitDirection::Right, - window, - cx, - move |window, cx| LspLogView::new(project, tool_log_store, window, cx), - ); - log_view.update(cx, |log_view, cx| { - let server_id = match server { - LanguageServerSelector::Id(id) => Some(id), - LanguageServerSelector::Name(name) => { - log_store.read(cx).language_servers.iter().find_map( - |(id, state)| { - if state.name.as_ref() == Some(&name) { - Some(*id) - } else { - None - } - }, - ) - } - }; - if let Some(server_id) = server_id { - log_view.show_rpc_trace_for_server(server_id, window, cx); - } - }); - }) - .ok(); - }) - .detach(); - } - - fn on_io( - &mut self, - language_server_id: LanguageServerId, - io_kind: IoKind, - message: &str, - cx: &mut Context, - ) -> Option<()> { - let is_received = match io_kind { - IoKind::StdOut => true, - IoKind::StdIn => false, - IoKind::StdErr => { - self.add_language_server_log(language_server_id, MessageType::LOG, &message, cx); - return Some(()); - } - }; - - let state = self - .get_language_server_state(language_server_id)? 
- .rpc_state - .as_mut()?; - let kind = if is_received { - MessageKind::Receive - } else { - MessageKind::Send - }; - - let rpc_log_lines = &mut state.rpc_messages; - if state.last_message_kind != Some(kind) { - while rpc_log_lines.len() + 1 >= MAX_STORED_LOG_ENTRIES { - rpc_log_lines.pop_front(); - } - let line_before_message = match kind { - MessageKind::Send => SEND_LINE, - MessageKind::Receive => RECEIVE_LINE, - }; - rpc_log_lines.push_back(RpcMessage { - message: line_before_message.to_string(), - }); - cx.emit(Event::NewServerLogEntry { - id: language_server_id, - kind: LogKind::Rpc, - text: line_before_message.to_string(), - }); - } - - while rpc_log_lines.len() + 1 >= MAX_STORED_LOG_ENTRIES { - rpc_log_lines.pop_front(); - } - - let message = message.trim(); - rpc_log_lines.push_back(RpcMessage { - message: message.to_string(), - }); - cx.emit(Event::NewServerLogEntry { - id: language_server_id, - kind: LogKind::Rpc, - text: message.to_string(), + let log_store = log_store.clone(); + workspace.register_action(move |workspace, _: &OpenLanguageServerLogs, window, cx| { + let log_store = log_store.clone(); + let project = workspace.project().clone(); + get_or_create_tool( + workspace, + SplitDirection::Right, + window, + cx, + move |window, cx| LspLogView::new(project, log_store, window, cx), + ); }); - cx.notify(); - Some(()) - } + }) + .detach(); } impl LspLogView { @@ -733,16 +196,14 @@ impl LspLogView { let first_server_id_for_project = store.read(cx).server_ids_for_project(&weak_project).next(); if let Some(current_lsp) = this.current_server_id { - if !store.read(cx).language_servers.contains_key(¤t_lsp) { - if let Some(server_id) = first_server_id_for_project { - match this.active_entry_kind { - LogKind::Rpc => { - this.show_rpc_trace_for_server(server_id, window, cx) - } - LogKind::Trace => this.show_trace_for_server(server_id, window, cx), - LogKind::Logs => this.show_logs_for_server(server_id, window, cx), - LogKind::ServerInfo => 
this.show_server_info(server_id, window, cx), - } + if !store.read(cx).language_servers.contains_key(¤t_lsp) + && let Some(server_id) = first_server_id_for_project + { + match this.active_entry_kind { + LogKind::Rpc => this.show_rpc_trace_for_server(server_id, window, cx), + LogKind::Trace => this.show_trace_for_server(server_id, window, cx), + LogKind::Logs => this.show_logs_for_server(server_id, window, cx), + LogKind::ServerInfo => this.show_server_info(server_id, window, cx), } } } else if let Some(server_id) = first_server_id_for_project { @@ -756,13 +217,14 @@ impl LspLogView { cx.notify(); }); + let events_subscriptions = cx.subscribe_in( &log_store, window, move |log_view, _, e, window, cx| match e { Event::NewServerLogEntry { id, kind, text } => { if log_view.current_server_id == Some(*id) - && *kind == log_view.active_entry_kind + && LogKind::from_server_log_type(kind) == log_view.active_entry_kind { log_view.editor.update(cx, |editor, cx| { editor.set_read_only(false); @@ -776,21 +238,17 @@ impl LspLogView { ], cx, ); - if text.len() > 1024 { - if let Some((fold_offset, _)) = + if text.len() > 1024 + && let Some((fold_offset, _)) = text.char_indices().dropping(1024).next() - { - if fold_offset < text.len() { - editor.fold_ranges( - vec![ - last_offset + fold_offset..last_offset + text.len(), - ], - false, - window, - cx, - ); - } - } + && fold_offset < text.len() + { + editor.fold_ranges( + vec![last_offset + fold_offset..last_offset + text.len()], + false, + window, + cx, + ); } if newest_cursor_is_at_end { @@ -809,7 +267,20 @@ impl LspLogView { window.focus(&log_view.editor.focus_handle(cx)); }); - let mut this = Self { + cx.on_release(|log_view, cx| { + log_view.log_store.update(cx, |log_store, cx| { + for (server_id, state) in &log_store.language_servers { + if let Some(log_kind) = state.toggled_log_kind { + if let Some(log_type) = log_type(log_kind) { + send_toggle_log_message(state, *server_id, false, log_type, cx); + } + } + } + }); + }) + 
.detach(); + + let mut lsp_log_view = Self { focus_handle, editor, editor_subscriptions, @@ -824,9 +295,9 @@ impl LspLogView { ], }; if let Some(server_id) = server_id { - this.show_logs_for_server(server_id, window, cx); + lsp_log_view.show_logs_for_server(server_id, window, cx); } - this + lsp_log_view } fn editor_for_logs( @@ -847,14 +318,14 @@ impl LspLogView { } fn editor_for_server_info( - server: &LanguageServer, + info: ServerInfo, window: &mut Window, cx: &mut Context, ) -> (Entity, Vec) { let server_info = format!( "* Server: {NAME} (id {ID}) -* Binary: {BINARY:#?} +* Binary: {BINARY} * Registered workspace folders: {WORKSPACE_FOLDERS} @@ -862,22 +333,21 @@ impl LspLogView { * Capabilities: {CAPABILITIES} * Configuration: {CONFIGURATION}", - NAME = server.name(), - ID = server.server_id(), - BINARY = server.binary(), - WORKSPACE_FOLDERS = server - .workspace_folders() - .into_iter() - .filter_map(|path| path - .to_file_path() - .ok() - .map(|path| path.to_string_lossy().into_owned())) - .collect::>() - .join(", "), - CAPABILITIES = serde_json::to_string_pretty(&server.capabilities()) + NAME = info.name, + ID = info.id, + BINARY = info + .binary + .as_ref() + .map_or_else(|| "Unknown".to_string(), |binary| format!("{binary:#?}")), + WORKSPACE_FOLDERS = info.workspace_folders.join(", "), + CAPABILITIES = serde_json::to_string_pretty(&info.capabilities) .unwrap_or_else(|e| format!("Failed to serialize capabilities: {e}")), - CONFIGURATION = serde_json::to_string_pretty(server.configuration()) - .unwrap_or_else(|e| format!("Failed to serialize configuration: {e}")), + CONFIGURATION = info + .configuration + .map(|configuration| serde_json::to_string_pretty(&configuration)) + .transpose() + .unwrap_or_else(|e| Some(format!("Failed to serialize configuration: {e}"))) + .unwrap_or_else(|| "Unknown".to_string()), ); let editor = initialize_new_editor(server_info, false, window, cx); let editor_subscription = cx.subscribe( @@ -900,7 +370,9 @@ impl LspLogView { 
.language_servers .iter() .map(|(server_id, state)| match &state.kind { - LanguageServerKind::Local { .. } | LanguageServerKind::Remote { .. } => { + LanguageServerKind::Local { .. } + | LanguageServerKind::Remote { .. } + | LanguageServerKind::LocalSsh { .. } => { let worktree_root_name = state .worktree_id .and_then(|id| self.project.read(cx).worktree_for_id(id, cx)) @@ -936,7 +408,7 @@ impl LspLogView { let state = log_store.language_servers.get(&server_id)?; Some(LogMenuItem { server_id, - server_name: name.clone(), + server_name: name, server_kind: state.kind.clone(), worktree_root_name: "supplementary".to_string(), rpc_trace_enabled: state.rpc_state.is_some(), @@ -978,6 +450,12 @@ impl LspLogView { cx.notify(); } self.editor.read(cx).focus_handle(cx).focus(window); + self.log_store.update(cx, |log_store, cx| { + let state = log_store.get_language_server_state(server_id)?; + state.toggled_log_kind = Some(LogKind::Logs); + send_toggle_log_message(state, server_id, true, LogType::Log, cx); + Some(()) + }); } fn update_log_level( @@ -1012,17 +490,29 @@ impl LspLogView { window: &mut Window, cx: &mut Context, ) { + let trace_level = self + .log_store + .update(cx, |log_store, _| { + Some(log_store.get_language_server_state(server_id)?.trace_level) + }) + .unwrap_or(TraceValue::Messages); let log_contents = self .log_store .read(cx) .server_trace(server_id) - .map(|v| log_contents(v, ())); + .map(|v| log_contents(v, trace_level)); if let Some(log_contents) = log_contents { self.current_server_id = Some(server_id); self.active_entry_kind = LogKind::Trace; let (editor, editor_subscriptions) = Self::editor_for_logs(log_contents, window, cx); self.editor = editor; self.editor_subscriptions = editor_subscriptions; + self.log_store.update(cx, |log_store, cx| { + let state = log_store.get_language_server_state(server_id)?; + state.toggled_log_kind = Some(LogKind::Trace); + send_toggle_log_message(state, server_id, true, LogType::Trace, cx); + Some(()) + }); cx.notify(); } 
self.editor.read(cx).focus_handle(cx).focus(window); @@ -1034,6 +524,7 @@ impl LspLogView { window: &mut Window, cx: &mut Context, ) { + self.toggle_rpc_trace_for_server(server_id, true, window, cx); let rpc_log = self.log_store.update(cx, |log_store, _| { log_store .enable_rpc_trace_for_language_server(server_id) @@ -1078,12 +569,16 @@ impl LspLogView { window: &mut Window, cx: &mut Context, ) { - self.log_store.update(cx, |log_store, _| { + self.log_store.update(cx, |log_store, cx| { if enabled { log_store.enable_rpc_trace_for_language_server(server_id); } else { log_store.disable_rpc_trace_for_language_server(server_id); } + + if let Some(server_state) = log_store.language_servers.get(&server_id) { + send_toggle_log_message(server_state, server_id, enabled, LogType::Rpc, cx); + }; }); if !enabled && Some(server_id) == self.current_server_id { self.show_logs_for_server(server_id, window, cx); @@ -1122,17 +617,85 @@ impl LspLogView { window: &mut Window, cx: &mut Context, ) { - let lsp_store = self.project.read(cx).lsp_store(); - let Some(server) = lsp_store.read(cx).language_server_for_id(server_id) else { + let Some(server_info) = self + .project + .read(cx) + .lsp_store() + .update(cx, |lsp_store, _| { + lsp_store + .language_server_for_id(server_id) + .as_ref() + .map(|language_server| ServerInfo::new(language_server)) + .or_else(move || { + let capabilities = + lsp_store.lsp_server_capabilities.get(&server_id)?.clone(); + let name = lsp_store + .language_server_statuses + .get(&server_id) + .map(|status| status.name.clone())?; + Some(ServerInfo { + id: server_id, + capabilities, + binary: None, + name, + workspace_folders: Vec::new(), + configuration: None, + }) + }) + }) + else { return; }; self.current_server_id = Some(server_id); self.active_entry_kind = LogKind::ServerInfo; - let (editor, editor_subscriptions) = Self::editor_for_server_info(&server, window, cx); + let (editor, editor_subscriptions) = Self::editor_for_server_info(server_info, window, cx); 
self.editor = editor; self.editor_subscriptions = editor_subscriptions; cx.notify(); self.editor.read(cx).focus_handle(cx).focus(window); + self.log_store.update(cx, |log_store, cx| { + let state = log_store.get_language_server_state(server_id)?; + if let Some(log_kind) = state.toggled_log_kind.take() { + if let Some(log_type) = log_type(log_kind) { + send_toggle_log_message(state, server_id, false, log_type, cx); + } + }; + Some(()) + }); + } +} + +fn log_type(log_kind: LogKind) -> Option { + match log_kind { + LogKind::Rpc => Some(LogType::Rpc), + LogKind::Trace => Some(LogType::Trace), + LogKind::Logs => Some(LogType::Log), + LogKind::ServerInfo => None, + } +} + +fn send_toggle_log_message( + server_state: &log_store::LanguageServerState, + server_id: LanguageServerId, + enabled: bool, + log_type: LogType, + cx: &mut App, +) { + if let LanguageServerKind::Remote { project } = &server_state.kind { + project + .update(cx, |project, cx| { + if let Some((client, project_id)) = project.lsp_store().read(cx).upstream_client() { + client + .send(proto::ToggleLspLogs { + project_id, + log_type: log_type as i32, + server_id: server_id.to_proto(), + enabled, + }) + .log_err(); + } + }) + .ok(); } } @@ -1311,14 +874,14 @@ impl ToolbarItemView for LspLogToolbarItemView { _: &mut Window, cx: &mut Context, ) -> workspace::ToolbarItemLocation { - if let Some(item) = active_pane_item { - if let Some(log_view) = item.downcast::() { - self.log_view = Some(log_view.clone()); - self._log_view_subscription = Some(cx.observe(&log_view, |_, _, cx| { - cx.notify(); - })); - return ToolbarItemLocation::PrimaryLeft; - } + if let Some(item) = active_pane_item + && let Some(log_view) = item.downcast::() + { + self.log_view = Some(log_view.clone()); + self._log_view_subscription = Some(cx.observe(&log_view, |_, _, cx| { + cx.notify(); + })); + return ToolbarItemLocation::PrimaryLeft; } self.log_view = None; self._log_view_subscription = None; @@ -1358,7 +921,7 @@ impl Render for 
LspLogToolbarItemView { }) .collect(); - let log_toolbar_view = cx.entity().clone(); + let log_toolbar_view = cx.entity(); let lsp_menu = PopoverMenu::new("LspLogView") .anchor(Corner::TopLeft) @@ -1425,13 +988,18 @@ impl Render for LspLogToolbarItemView { let view_selector = current_server.map(|server| { let server_id = server.server_id; - let is_remote = server.server_kind.is_remote(); let rpc_trace_enabled = server.rpc_trace_enabled; let log_view = log_view.clone(); + let label = match server.selected_entry { + LogKind::Rpc => RPC_MESSAGES, + LogKind::Trace => SERVER_TRACE, + LogKind::Logs => SERVER_LOGS, + LogKind::ServerInfo => SERVER_INFO, + }; PopoverMenu::new("LspViewSelector") .anchor(Corner::TopLeft) .trigger( - Button::new("language_server_menu_header", server.selected_entry.label()) + Button::new("language_server_menu_header", label) .icon(IconName::ChevronDown) .icon_size(IconSize::Small) .icon_color(Color::Muted), @@ -1447,55 +1015,53 @@ impl Render for LspLogToolbarItemView { view.show_logs_for_server(server_id, window, cx); }), ) - .when(!is_remote, |this| { - this.entry( - SERVER_TRACE, - None, - window.handler_for(&log_view, move |view, window, cx| { - view.show_trace_for_server(server_id, window, cx); - }), - ) - .custom_entry( - { - let log_toolbar_view = log_toolbar_view.clone(); - move |window, _| { - h_flex() - .w_full() - .justify_between() - .child(Label::new(RPC_MESSAGES)) - .child( - div().child( - Checkbox::new( - "LspLogEnableRpcTrace", - if rpc_trace_enabled { + .entry( + SERVER_TRACE, + None, + window.handler_for(&log_view, move |view, window, cx| { + view.show_trace_for_server(server_id, window, cx); + }), + ) + .custom_entry( + { + let log_toolbar_view = log_toolbar_view.clone(); + move |window, _| { + h_flex() + .w_full() + .justify_between() + .child(Label::new(RPC_MESSAGES)) + .child( + div().child( + Checkbox::new( + "LspLogEnableRpcTrace", + if rpc_trace_enabled { + ToggleState::Selected + } else { + ToggleState::Unselected + 
}, + ) + .on_click(window.listener_for( + &log_toolbar_view, + move |view, selection, window, cx| { + let enabled = matches!( + selection, ToggleState::Selected - } else { - ToggleState::Unselected - }, - ) - .on_click(window.listener_for( - &log_toolbar_view, - move |view, selection, window, cx| { - let enabled = matches!( - selection, - ToggleState::Selected - ); - view.toggle_rpc_logging_for_server( - server_id, enabled, window, cx, - ); - cx.stop_propagation(); - }, - )), - ), - ) - .into_any_element() - } - }, - window.handler_for(&log_view, move |view, window, cx| { - view.show_rpc_trace_for_server(server_id, window, cx); - }), - ) - }) + ); + view.toggle_rpc_logging_for_server( + server_id, enabled, window, cx, + ); + cx.stop_propagation(); + }, + )), + ), + ) + .into_any_element() + } + }, + window.handler_for(&log_view, move |view, window, cx| { + view.show_rpc_trace_for_server(server_id, window, cx); + }), + ) .entry( SERVER_INFO, None, @@ -1533,7 +1099,7 @@ impl Render for LspLogToolbarItemView { .icon_color(Color::Muted), ) .menu({ - let log_view = log_view.clone(); + let log_view = log_view; move |window, cx| { let id = log_view.read(cx).current_server_id?; @@ -1601,7 +1167,7 @@ impl Render for LspLogToolbarItemView { .icon_color(Color::Muted), ) .menu({ - let log_view = log_view.clone(); + let log_view = log_view; move |window, cx| { let id = log_view.read(cx).current_server_id?; @@ -1705,12 +1271,6 @@ const SERVER_LOGS: &str = "Server Logs"; const SERVER_TRACE: &str = "Server Trace"; const SERVER_INFO: &str = "Server Info"; -impl Default for LspLogToolbarItemView { - fn default() -> Self { - Self::new() - } -} - impl LspLogToolbarItemView { pub fn new() -> Self { Self { @@ -1743,15 +1303,35 @@ impl LspLogToolbarItemView { } } -pub enum Event { - NewServerLogEntry { - id: LanguageServerId, - kind: LogKind, - text: String, - }, +struct ServerInfo { + id: LanguageServerId, + capabilities: lsp::ServerCapabilities, + binary: Option, + name: 
LanguageServerName, + workspace_folders: Vec, + configuration: Option, +} + +impl ServerInfo { + fn new(server: &LanguageServer) -> Self { + Self { + id: server.server_id(), + capabilities: server.capabilities(), + binary: Some(server.binary().clone()), + name: server.name(), + workspace_folders: server + .workspace_folders() + .into_iter() + .filter_map(|path| { + path.to_file_path() + .ok() + .map(|path| path.to_string_lossy().into_owned()) + }) + .collect::>(), + configuration: Some(server.configuration().clone()), + } + } } -impl EventEmitter for LogStore {} -impl EventEmitter for LspLogView {} impl EventEmitter for LspLogView {} impl EventEmitter for LspLogView {} diff --git a/crates/language_tools/src/lsp_log_tests.rs b/crates/language_tools/src/lsp_log_view_tests.rs similarity index 91% rename from crates/language_tools/src/lsp_log_tests.rs rename to crates/language_tools/src/lsp_log_view_tests.rs index ad2b653fdcfd4dc228cac58da7ed15f844b4bb26..d572c4375ed09997dc57d6c58e6c90f3e55775b6 100644 --- a/crates/language_tools/src/lsp_log_tests.rs +++ b/crates/language_tools/src/lsp_log_view_tests.rs @@ -1,20 +1,22 @@ use std::sync::Arc; -use crate::lsp_log::LogMenuItem; +use crate::lsp_log_view::LogMenuItem; use super::*; use futures::StreamExt; use gpui::{AppContext as _, SemanticVersion, TestAppContext, VisualTestContext}; use language::{FakeLspAdapter, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; use lsp::LanguageServerName; -use lsp_log::LogKind; -use project::{FakeFs, Project}; +use project::{ + FakeFs, Project, + lsp_store::log_store::{LanguageServerKind, LogKind, LogStore}, +}; use serde_json::json; use settings::SettingsStore; use util::path; #[gpui::test] -async fn test_lsp_logs(cx: &mut TestAppContext) { +async fn test_lsp_log_view(cx: &mut TestAppContext) { zlog::init_test(); init_test(cx); @@ -51,7 +53,7 @@ async fn test_lsp_logs(cx: &mut TestAppContext) { }, ); - let log_store = cx.new(LogStore::new); + let log_store = cx.new(|cx| 
LogStore::new(false, cx)); log_store.update(cx, |store, cx| store.add_project(&project, cx)); let _rust_buffer = project @@ -94,7 +96,7 @@ async fn test_lsp_logs(cx: &mut TestAppContext) { rpc_trace_enabled: false, selected_entry: LogKind::Logs, trace_level: lsp::TraceValue::Off, - server_kind: lsp_log::LanguageServerKind::Local { + server_kind: LanguageServerKind::Local { project: project.downgrade() } }] diff --git a/crates/language_tools/src/syntax_tree_view.rs b/crates/language_tools/src/syntax_tree_view.rs index eadba2c1d2f4c96c4f0ad2646c2e9957bbae3bdc..5700d8d487e990937597295fb5bab761a46f2ba3 100644 --- a/crates/language_tools/src/syntax_tree_view.rs +++ b/crates/language_tools/src/syntax_tree_view.rs @@ -1,17 +1,22 @@ +use command_palette_hooks::CommandPaletteFilter; use editor::{Anchor, Editor, ExcerptId, SelectionEffects, scroll::Autoscroll}; use gpui::{ - App, AppContext as _, Context, Div, Entity, EventEmitter, FocusHandle, Focusable, Hsla, - InteractiveElement, IntoElement, MouseButton, MouseDownEvent, MouseMoveEvent, ParentElement, - Render, ScrollStrategy, SharedString, Styled, UniformListScrollHandle, WeakEntity, Window, - actions, div, rems, uniform_list, + App, AppContext as _, Context, Div, Entity, EntityId, EventEmitter, FocusHandle, Focusable, + Hsla, InteractiveElement, IntoElement, MouseButton, MouseDownEvent, MouseMoveEvent, + ParentElement, Render, ScrollStrategy, SharedString, Styled, UniformListScrollHandle, + WeakEntity, Window, actions, div, rems, uniform_list, }; use language::{Buffer, OwnedSyntaxLayer}; -use std::{mem, ops::Range}; +use std::{any::TypeId, mem, ops::Range}; use theme::ActiveTheme; use tree_sitter::{Node, TreeCursor}; -use ui::{ButtonLike, Color, ContextMenu, Label, LabelCommon, PopoverMenu, h_flex}; +use ui::{ + ButtonCommon, ButtonLike, Clickable, Color, ContextMenu, FluentBuilder as _, IconButton, + IconName, Label, LabelCommon, LabelSize, PopoverMenu, StyledExt, Tooltip, h_flex, v_flex, +}; use workspace::{ - 
SplitDirection, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace, + Event as WorkspaceEvent, SplitDirection, ToolbarItemEvent, ToolbarItemLocation, + ToolbarItemView, Workspace, item::{Item, ItemHandle}, }; @@ -19,17 +24,51 @@ actions!( dev, [ /// Opens the syntax tree view for the current file. - OpenSyntaxTreeView + OpenSyntaxTreeView, + ] +); + +actions!( + syntax_tree_view, + [ + /// Update the syntax tree view to show the last focused file. + UseActiveEditor ] ); pub fn init(cx: &mut App) { - cx.observe_new(|workspace: &mut Workspace, _, _| { - workspace.register_action(|workspace, _: &OpenSyntaxTreeView, window, cx| { + let syntax_tree_actions = [TypeId::of::()]; + + CommandPaletteFilter::update_global(cx, |this, _| { + this.hide_action_types(&syntax_tree_actions); + }); + + cx.observe_new(move |workspace: &mut Workspace, _, _| { + workspace.register_action(move |workspace, _: &OpenSyntaxTreeView, window, cx| { + CommandPaletteFilter::update_global(cx, |this, _| { + this.show_action_types(&syntax_tree_actions); + }); + let active_item = workspace.active_item(cx); let workspace_handle = workspace.weak_handle(); - let syntax_tree_view = - cx.new(|cx| SyntaxTreeView::new(workspace_handle, active_item, window, cx)); + let syntax_tree_view = cx.new(|cx| { + cx.on_release(move |view: &mut SyntaxTreeView, cx| { + if view + .workspace_handle + .read_with(cx, |workspace, cx| { + workspace.item_of_type::(cx).is_none() + }) + .unwrap_or_default() + { + CommandPaletteFilter::update_global(cx, |this, _| { + this.hide_action_types(&syntax_tree_actions); + }); + } + }) + .detach(); + + SyntaxTreeView::new(workspace_handle, active_item, window, cx) + }); workspace.split_item( SplitDirection::Right, Box::new(syntax_tree_view), @@ -37,6 +76,13 @@ pub fn init(cx: &mut App) { cx, ) }); + workspace.register_action(|workspace, _: &UseActiveEditor, window, cx| { + if let Some(tree_view) = workspace.item_of_type::(cx) { + tree_view.update(cx, |view, cx| { + 
view.update_active_editor(&Default::default(), window, cx) + }) + } + }); }) .detach(); } @@ -45,6 +91,9 @@ pub struct SyntaxTreeView { workspace_handle: WeakEntity, editor: Option, list_scroll_handle: UniformListScrollHandle, + /// The last active editor in the workspace. Note that this is specifically not the + /// currently shown editor. + last_active_editor: Option>, selected_descendant_ix: Option, hovered_descendant_ix: Option, focus_handle: FocusHandle, @@ -61,6 +110,14 @@ struct EditorState { _subscription: gpui::Subscription, } +impl EditorState { + fn has_language(&self) -> bool { + self.active_buffer + .as_ref() + .is_some_and(|buffer| buffer.active_layer.is_some()) + } +} + #[derive(Clone)] struct BufferState { buffer: Entity, @@ -79,17 +136,25 @@ impl SyntaxTreeView { workspace_handle: workspace_handle.clone(), list_scroll_handle: UniformListScrollHandle::new(), editor: None, + last_active_editor: None, hovered_descendant_ix: None, selected_descendant_ix: None, focus_handle: cx.focus_handle(), }; - this.workspace_updated(active_item, window, cx); - cx.observe_in( + this.handle_item_updated(active_item, window, cx); + + cx.subscribe_in( &workspace_handle.upgrade().unwrap(), window, - |this, workspace, window, cx| { - this.workspace_updated(workspace.read(cx).active_item(cx), window, cx); + move |this, workspace, event, window, cx| match event { + WorkspaceEvent::ItemAdded { .. 
} | WorkspaceEvent::ActiveItemChanged => { + this.handle_item_updated(workspace.read(cx).active_item(cx), window, cx) + } + WorkspaceEvent::ItemRemoved { item_id } => { + this.handle_item_removed(item_id, window, cx); + } + _ => {} }, ) .detach(); @@ -97,21 +162,56 @@ impl SyntaxTreeView { this } - fn workspace_updated( + fn handle_item_updated( &mut self, active_item: Option>, window: &mut Window, cx: &mut Context, ) { - if let Some(item) = active_item { - if item.item_id() != cx.entity_id() { - if let Some(editor) = item.act_as::(cx) { - self.set_editor(editor, window, cx); - } - } + let Some(editor) = active_item + .filter(|item| item.item_id() != cx.entity_id()) + .and_then(|item| item.act_as::(cx)) + else { + return; + }; + + if let Some(editor_state) = self.editor.as_ref().filter(|state| state.has_language()) { + self.last_active_editor = (editor_state.editor != editor).then_some(editor); + } else { + self.set_editor(editor, window, cx); } } + fn handle_item_removed( + &mut self, + item_id: &EntityId, + window: &mut Window, + cx: &mut Context, + ) { + if self + .editor + .as_ref() + .is_some_and(|state| state.editor.entity_id() == *item_id) + { + self.editor = None; + // Try activating the last active editor if there is one + self.update_active_editor(&Default::default(), window, cx); + cx.notify(); + } + } + + fn update_active_editor( + &mut self, + _: &UseActiveEditor, + window: &mut Window, + cx: &mut Context, + ) { + let Some(editor) = self.last_active_editor.take() else { + return; + }; + self.set_editor(editor, window, cx); + } + fn set_editor(&mut self, editor: Entity, window: &mut Window, cx: &mut Context) { if let Some(state) = &self.editor { if state.editor == editor { @@ -157,7 +257,7 @@ impl SyntaxTreeView { .buffer_snapshot .range_to_buffer_ranges(selection_range) .pop()?; - let buffer = multi_buffer.buffer(buffer.remote_id()).unwrap().clone(); + let buffer = multi_buffer.buffer(buffer.remote_id()).unwrap(); Some((buffer, range, excerpt_id)) })?; 
@@ -295,101 +395,153 @@ impl SyntaxTreeView { .pl(rems(depth as f32)) .hover(|style| style.bg(colors.element_hover)) } -} - -impl Render for SyntaxTreeView { - fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let mut rendered = div().flex_1().bg(cx.theme().colors().editor_background); - if let Some(layer) = self - .editor - .as_ref() - .and_then(|editor| editor.active_buffer.as_ref()) - .and_then(|buffer| buffer.active_layer.as_ref()) - { - let layer = layer.clone(); - rendered = rendered.child(uniform_list( - "SyntaxTreeView", - layer.node().descendant_count(), - cx.processor(move |this, range: Range, _, cx| { - let mut items = Vec::new(); - let mut cursor = layer.node().walk(); - let mut descendant_ix = range.start; - cursor.goto_descendant(descendant_ix); - let mut depth = cursor.depth(); - let mut visited_children = false; - while descendant_ix < range.end { - if visited_children { - if cursor.goto_next_sibling() { - visited_children = false; - } else if cursor.goto_parent() { - depth -= 1; - } else { - break; - } - } else { - items.push( - Self::render_node( - &cursor, - depth, - Some(descendant_ix) == this.selected_descendant_ix, + fn compute_items( + &mut self, + layer: &OwnedSyntaxLayer, + range: Range, + cx: &Context, + ) -> Vec
{ + let mut items = Vec::new(); + let mut cursor = layer.node().walk(); + let mut descendant_ix = range.start; + cursor.goto_descendant(descendant_ix); + let mut depth = cursor.depth(); + let mut visited_children = false; + while descendant_ix < range.end { + if visited_children { + if cursor.goto_next_sibling() { + visited_children = false; + } else if cursor.goto_parent() { + depth -= 1; + } else { + break; + } + } else { + items.push( + Self::render_node( + &cursor, + depth, + Some(descendant_ix) == self.selected_descendant_ix, + cx, + ) + .on_mouse_down( + MouseButton::Left, + cx.listener(move |tree_view, _: &MouseDownEvent, window, cx| { + tree_view.update_editor_with_range_for_descendant_ix( + descendant_ix, + window, + cx, + |editor, mut range, window, cx| { + // Put the cursor at the beginning of the node. + mem::swap(&mut range.start, &mut range.end); + + editor.change_selections( + SelectionEffects::scroll(Autoscroll::newest()), + window, + cx, + |selections| { + selections.select_ranges(vec![range]); + }, + ); + }, + ); + }), + ) + .on_mouse_move(cx.listener( + move |tree_view, _: &MouseMoveEvent, window, cx| { + if tree_view.hovered_descendant_ix != Some(descendant_ix) { + tree_view.hovered_descendant_ix = Some(descendant_ix); + tree_view.update_editor_with_range_for_descendant_ix( + descendant_ix, + window, cx, - ) - .on_mouse_down( - MouseButton::Left, - cx.listener(move |tree_view, _: &MouseDownEvent, window, cx| { - tree_view.update_editor_with_range_for_descendant_ix( - descendant_ix, - window, cx, - |editor, mut range, window, cx| { - // Put the cursor at the beginning of the node. 
- mem::swap(&mut range.start, &mut range.end); - - editor.change_selections( - SelectionEffects::scroll(Autoscroll::newest()), - window, cx, - |selections| { - selections.select_ranges(vec![range]); - }, - ); + |editor, range, _, cx| { + editor.clear_background_highlights::(cx); + editor.highlight_background::( + &[range], + |theme| { + theme + .colors() + .editor_document_highlight_write_background }, + cx, ); - }), - ) - .on_mouse_move(cx.listener( - move |tree_view, _: &MouseMoveEvent, window, cx| { - if tree_view.hovered_descendant_ix != Some(descendant_ix) { - tree_view.hovered_descendant_ix = Some(descendant_ix); - tree_view.update_editor_with_range_for_descendant_ix(descendant_ix, window, cx, |editor, range, _, cx| { - editor.clear_background_highlights::( cx); - editor.highlight_background::( - &[range], - |theme| theme.colors().editor_document_highlight_write_background, - cx, - ); - }); - cx.notify(); - } }, - )), - ); - descendant_ix += 1; - if cursor.goto_first_child() { - depth += 1; - } else { - visited_children = true; + ); + cx.notify(); } - } - } - items - }), - ) - .size_full() - .track_scroll(self.list_scroll_handle.clone()) - .text_bg(cx.theme().colors().background).into_any_element()); + }, + )), + ); + descendant_ix += 1; + if cursor.goto_first_child() { + depth += 1; + } else { + visited_children = true; + } + } } + items + } +} - rendered +impl Render for SyntaxTreeView { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + div() + .flex_1() + .bg(cx.theme().colors().editor_background) + .map(|this| { + let editor_state = self.editor.as_ref(); + + if let Some(layer) = editor_state + .and_then(|editor| editor.active_buffer.as_ref()) + .and_then(|buffer| buffer.active_layer.as_ref()) + { + let layer = layer.clone(); + this.child( + uniform_list( + "SyntaxTreeView", + layer.node().descendant_count(), + cx.processor(move |this, range: Range, _, cx| { + this.compute_items(&layer, range, cx) + }), + ) + .size_full() + 
.track_scroll(self.list_scroll_handle.clone()) + .text_bg(cx.theme().colors().background) + .into_any_element(), + ) + } else { + let inner_content = v_flex() + .items_center() + .text_center() + .gap_2() + .max_w_3_5() + .map(|this| { + if editor_state.is_some_and(|state| !state.has_language()) { + this.child(Label::new("Current editor has no associated language")) + .child( + Label::new(concat!( + "Try assigning a language or", + "switching to a different buffer" + )) + .size(LabelSize::Small), + ) + } else { + this.child(Label::new("Not attached to an editor")).child( + Label::new("Focus an editor to show a new tree view") + .size(LabelSize::Small), + ) + } + }); + + this.h_flex() + .size_full() + .justify_center() + .child(inner_content) + } + }) } } @@ -456,7 +608,7 @@ impl SyntaxTreeToolbarItemView { let active_layer = buffer_state.active_layer.clone()?; let active_buffer = buffer_state.buffer.read(cx).snapshot(); - let view = cx.entity().clone(); + let view = cx.entity(); Some( PopoverMenu::new("Syntax Tree") .trigger(Self::render_header(&active_layer)) @@ -507,6 +659,26 @@ impl SyntaxTreeToolbarItemView { .child(Label::new(active_layer.language.name())) .child(Label::new(format_node_range(active_layer.node()))) } + + fn render_update_button(&mut self, cx: &mut Context) -> Option { + self.tree_view.as_ref().and_then(|view| { + view.update(cx, |view, cx| { + view.last_active_editor.as_ref().map(|editor| { + IconButton::new("syntax-view-update", IconName::RotateCw) + .tooltip({ + let active_tab_name = editor.read_with(cx, |editor, cx| { + editor.tab_content_text(Default::default(), cx) + }); + + Tooltip::text(format!("Update view to '{active_tab_name}'")) + }) + .on_click(cx.listener(|this, _, window, cx| { + this.update_active_editor(&Default::default(), window, cx); + })) + }) + }) + }) + } } fn format_node_range(node: Node) -> String { @@ -523,8 +695,10 @@ fn format_node_range(node: Node) -> String { impl Render for SyntaxTreeToolbarItemView { fn 
render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - self.render_menu(cx) - .unwrap_or_else(|| PopoverMenu::new("Empty Syntax Tree")) + h_flex() + .gap_1() + .children(self.render_menu(cx)) + .children(self.render_update_button(cx)) } } @@ -537,12 +711,12 @@ impl ToolbarItemView for SyntaxTreeToolbarItemView { window: &mut Window, cx: &mut Context, ) -> ToolbarItemLocation { - if let Some(item) = active_pane_item { - if let Some(view) = item.downcast::() { - self.tree_view = Some(view.clone()); - self.subscription = Some(cx.observe_in(&view, window, |_, _, _, cx| cx.notify())); - return ToolbarItemLocation::PrimaryLeft; - } + if let Some(item) = active_pane_item + && let Some(view) = item.downcast::() + { + self.tree_view = Some(view.clone()); + self.subscription = Some(cx.observe_in(&view, window, |_, _, _, cx| cx.notify())); + return ToolbarItemLocation::PrimaryLeft; } self.tree_view = None; self.subscription = None; diff --git a/crates/languages/Cargo.toml b/crates/languages/Cargo.toml index 8e258180702626bb3dd32b28bfb0e82722a1f12f..e09b27d4742d4660cffb3d86905fb67268e617fa 100644 --- a/crates/languages/Cargo.toml +++ b/crates/languages/Cargo.toml @@ -42,7 +42,6 @@ async-trait.workspace = true chrono.workspace = true collections.workspace = true dap.workspace = true -feature_flags.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true diff --git a/extensions/toml/languages/toml/overrides.scm b/crates/languages/src/bash/overrides.scm similarity index 100% rename from extensions/toml/languages/toml/overrides.scm rename to crates/languages/src/bash/overrides.scm diff --git a/crates/languages/src/c.rs b/crates/languages/src/c.rs index aee1abee95fa2ea21931084ebe442c2ecd41da3c..afdf49e66e59b78c82f234160a9c4bc1efa83574 100644 --- a/crates/languages/src/c.rs +++ b/crates/languages/src/c.rs @@ -5,8 +5,9 @@ use gpui::{App, AsyncApp}; use http_client::github::{AssetKind, GitHubLspBinaryVersion, 
latest_github_release}; pub use language::*; use lsp::{InitializeParams, LanguageServerBinary, LanguageServerName}; -use project::lsp_store::clangd_ext; +use project::{lsp_store::clangd_ext, project_settings::ProjectSettings}; use serde_json::json; +use settings::Settings as _; use smol::fs; use std::{any::Any, env::consts, path::PathBuf, sync::Arc}; use util::{ResultExt, fs::remove_matching, maybe, merge_json_value_into}; @@ -22,13 +23,13 @@ impl CLspAdapter { #[async_trait(?Send)] impl super::LspAdapter for CLspAdapter { fn name(&self) -> LanguageServerName { - Self::SERVER_NAME.clone() + Self::SERVER_NAME } async fn check_if_user_installed( &self, delegate: &dyn LspAdapterDelegate, - _: Arc, + _: Option, _: &AsyncApp, ) -> Option { let path = delegate.which(Self::SERVER_NAME.as_ref()).await?; @@ -42,9 +43,19 @@ impl super::LspAdapter for CLspAdapter { async fn fetch_latest_server_version( &self, delegate: &dyn LspAdapterDelegate, + cx: &AsyncApp, ) -> Result> { - let release = - latest_github_release("clangd/clangd", true, false, delegate.http_client()).await?; + let release = latest_github_release( + "clangd/clangd", + true, + ProjectSettings::try_read_global(cx, |s| { + s.lsp.get(&Self::SERVER_NAME)?.fetch.as_ref()?.pre_release + }) + .flatten() + .unwrap_or(false), + delegate.http_client(), + ) + .await?; let os_suffix = match consts::OS { "macos" => "mac", "linux" => "linux", @@ -253,8 +264,7 @@ impl super::LspAdapter for CLspAdapter { .grammar() .and_then(|g| g.highlight_id_for_name(highlight_name?)) { - let mut label = - CodeLabel::plain(label.to_string(), completion.filter_text.as_deref()); + let mut label = CodeLabel::plain(label, completion.filter_text.as_deref()); label.runs.push(( 0..label.text.rfind('(').unwrap_or(label.text.len()), highlight_id, @@ -264,10 +274,7 @@ impl super::LspAdapter for CLspAdapter { } _ => {} } - Some(CodeLabel::plain( - label.to_string(), - completion.filter_text.as_deref(), - )) + Some(CodeLabel::plain(label, 
completion.filter_text.as_deref())) } async fn label_for_symbol( diff --git a/crates/languages/src/cpp/highlights.scm b/crates/languages/src/cpp/highlights.scm index 6fa8bd7b0858d3a1844ce2d322564ce9c39babea..bd988445bb155e8851ffa8bc3771bdd235fc7dff 100644 --- a/crates/languages/src/cpp/highlights.scm +++ b/crates/languages/src/cpp/highlights.scm @@ -3,8 +3,27 @@ (namespace_identifier) @namespace (concept_definition - (identifier) @concept) + name: (identifier) @concept) +(requires_clause + constraint: (template_type + name: (type_identifier) @concept)) + +(module_name + (identifier) @module) + +(module_declaration + name: (module_name + (identifier) @module)) + +(import_declaration + name: (module_name + (identifier) @module)) + +(import_declaration + partition: (module_partition + (module_name + (identifier) @module))) (call_expression function: (qualified_identifier @@ -61,6 +80,9 @@ (operator_name (identifier)? @operator) @function +(operator_name + "<=>" @operator.spaceship) + (destructor_name (identifier) @function) ((namespace_identifier) @type @@ -68,21 +90,17 @@ (auto) @type (type_identifier) @type -type :(primitive_type) @type.primitive -(sized_type_specifier) @type.primitive - -(requires_clause - constraint: (template_type - name: (type_identifier) @concept)) +type: (primitive_type) @type.builtin +(sized_type_specifier) @type.builtin (attribute - name: (identifier) @keyword) + name: (identifier) @attribute) -((identifier) @constant - (#match? @constant "^_*[A-Z][A-Z\\d_]*$")) +((identifier) @constant.builtin + (#match? 
@constant.builtin "^_*[A-Z][A-Z\\d_]*$")) (statement_identifier) @label -(this) @variable.special +(this) @variable.builtin ("static_assert") @function.builtin [ @@ -96,7 +114,9 @@ type :(primitive_type) @type.primitive "co_return" "co_yield" "concept" + "consteval" "constexpr" + "constinit" "continue" "decltype" "default" @@ -105,15 +125,20 @@ type :(primitive_type) @type.primitive "else" "enum" "explicit" + "export" "extern" "final" "for" "friend" + "goto" "if" + "import" "inline" + "module" "namespace" "new" "noexcept" + "operator" "override" "private" "protected" @@ -124,6 +149,7 @@ type :(primitive_type) @type.primitive "struct" "switch" "template" + "thread_local" "throw" "try" "typedef" @@ -146,7 +172,7 @@ type :(primitive_type) @type.primitive "#ifndef" "#include" (preproc_directive) -] @keyword +] @keyword.directive (comment) @comment @@ -224,10 +250,24 @@ type :(primitive_type) @type.primitive ">" "<=" ">=" - "<=>" - "||" "?" + "and" + "and_eq" + "bitand" + "bitor" + "compl" + "not" + "not_eq" + "or" + "or_eq" + "xor" + "xor_eq" ] @operator +"<=>" @operator.spaceship + +(binary_expression + operator: "<=>" @operator.spaceship) + (conditional_expression ":" @operator) (user_defined_literal (literal_suffix) @operator) diff --git a/crates/languages/src/css.rs b/crates/languages/src/css.rs index 19329fcc6edeea8bce1a6e09ec774793f1098811..79e3a62a342092a6203361e5611b9ff81481e984 100644 --- a/crates/languages/src/css.rs +++ b/crates/languages/src/css.rs @@ -2,9 +2,9 @@ use anyhow::{Context as _, Result}; use async_trait::async_trait; use futures::StreamExt; use gpui::AsyncApp; -use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; +use language::{LspAdapter, LspAdapterDelegate, Toolchain}; use lsp::{LanguageServerBinary, LanguageServerName}; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use project::{Fs, lsp_store::language_server_settings}; use serde_json::json; use smol::fs; @@ -43,7 +43,7 @@ impl 
LspAdapter for CssLspAdapter { async fn check_if_user_installed( &self, delegate: &dyn LspAdapterDelegate, - _: Arc, + _: Option, _: &AsyncApp, ) -> Option { let path = delegate @@ -61,6 +61,7 @@ impl LspAdapter for CssLspAdapter { async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { Ok(Box::new( self.node @@ -106,9 +107,8 @@ impl LspAdapter for CssLspAdapter { .should_install_npm_package( Self::PACKAGE_NAME, &server_path, - &container_dir, - &version, - Default::default(), + container_dir, + VersionStrategy::Latest(version), ) .await; @@ -145,7 +145,7 @@ impl LspAdapter for CssLspAdapter { self: Arc, _: &dyn Fs, delegate: &Arc, - _: Arc, + _: Option, cx: &mut AsyncApp, ) -> Result { let mut default_config = json!({ @@ -237,7 +237,7 @@ mod tests { .unindent(); let buffer = cx.new(|cx| language::Buffer::local(text, cx).with_language(language, cx)); - let outline = buffer.read_with(cx, |buffer, _| buffer.snapshot().outline(None).unwrap()); + let outline = buffer.read_with(cx, |buffer, _| buffer.snapshot().outline(None)); assert_eq!( outline .items diff --git a/crates/languages/src/github_download.rs b/crates/languages/src/github_download.rs index 5b0f1d0729c6ca620c6983ce3c3d64c5d7274314..766c894fbb2b660778f09933b4facd2114ebb5bf 100644 --- a/crates/languages/src/github_download.rs +++ b/crates/languages/src/github_download.rs @@ -96,7 +96,7 @@ async fn stream_response_archive( AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?, AssetKind::Gz => extract_gz(destination_path, url, response).await?, AssetKind::Zip => { - util::archive::extract_zip(&destination_path, response).await?; + util::archive::extract_zip(destination_path, response).await?; } }; Ok(()) @@ -113,11 +113,11 @@ async fn stream_file_archive( AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?, #[cfg(not(windows))] AssetKind::Zip => { - util::archive::extract_seekable_zip(&destination_path, 
file_archive).await?; + util::archive::extract_seekable_zip(destination_path, file_archive).await?; } #[cfg(windows)] AssetKind::Zip => { - util::archive::extract_zip(&destination_path, file_archive).await?; + util::archive::extract_zip(destination_path, file_archive).await?; } }; Ok(()) diff --git a/crates/languages/src/go.rs b/crates/languages/src/go.rs index 14f646133bf22ba7977cb23dca38a4700b527e1b..55441c33b8fa076654a2231b0a69a0579306ce18 100644 --- a/crates/languages/src/go.rs +++ b/crates/languages/src/go.rs @@ -53,12 +53,13 @@ const BINARY: &str = if cfg!(target_os = "windows") { #[async_trait(?Send)] impl super::LspAdapter for GoLspAdapter { fn name(&self) -> LanguageServerName { - Self::SERVER_NAME.clone() + Self::SERVER_NAME } async fn fetch_latest_server_version( &self, delegate: &dyn LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { let release = latest_github_release("golang/tools", false, false, delegate.http_client()).await?; @@ -75,7 +76,7 @@ impl super::LspAdapter for GoLspAdapter { async fn check_if_user_installed( &self, delegate: &dyn LspAdapterDelegate, - _: Arc, + _: Option, _: &AsyncApp, ) -> Option { let path = delegate.which(Self::SERVER_NAME.as_ref()).await?; @@ -131,19 +132,19 @@ impl super::LspAdapter for GoLspAdapter { if let Some(version) = *version { let binary_path = container_dir.join(format!("gopls_{version}_go_{go_version}")); - if let Ok(metadata) = fs::metadata(&binary_path).await { - if metadata.is_file() { - remove_matching(&container_dir, |entry| { - entry != binary_path && entry.file_name() != Some(OsStr::new("gobin")) - }) - .await; + if let Ok(metadata) = fs::metadata(&binary_path).await + && metadata.is_file() + { + remove_matching(&container_dir, |entry| { + entry != binary_path && entry.file_name() != Some(OsStr::new("gobin")) + }) + .await; - return Ok(LanguageServerBinary { - path: binary_path.to_path_buf(), - arguments: server_binary_arguments(), - env: None, - }); - } + return Ok(LanguageServerBinary { + path: 
binary_path.to_path_buf(), + arguments: server_binary_arguments(), + env: None, + }); } } else if let Some(path) = this .cached_server_binary(container_dir.clone(), delegate) @@ -203,7 +204,7 @@ impl super::LspAdapter for GoLspAdapter { _: &Arc, ) -> Result> { Ok(Some(json!({ - "usePlaceholders": true, + "usePlaceholders": false, "hints": { "assignVariableTypes": true, "compositeLiteralFields": true, @@ -452,7 +453,7 @@ async fn get_cached_server_binary(container_dir: PathBuf) -> Option"]) @punctuation.bracket.jsx) (jsx_self_closing_element (["<" "/>"]) @punctuation.bracket.jsx) (jsx_attribute "=" @punctuation.delimiter.jsx) -(jsx_text) @text.jsx \ No newline at end of file +(jsx_text) @text.jsx diff --git a/crates/languages/src/javascript/injections.scm b/crates/languages/src/javascript/injections.scm index 7baba5f227eb0df31cd753029296e165dfff0180..987be660d3c5ebd706284990d7d21a481b24a2af 100644 --- a/crates/languages/src/javascript/injections.scm +++ b/crates/languages/src/javascript/injections.scm @@ -11,6 +11,21 @@ (#set! injection.language "css")) ) +(call_expression + function: (member_expression + object: (identifier) @_obj (#eq? @_obj "styled") + property: (property_identifier)) + arguments: (template_string (string_fragment) @injection.content + (#set! injection.language "css")) +) + +(call_expression + function: (call_expression + function: (identifier) @_name (#eq? @_name "styled")) + arguments: (template_string (string_fragment) @injection.content + (#set! injection.language "css")) +) + (call_expression function: (identifier) @_name (#eq? @_name "html") arguments: (template_string) @injection.content @@ -58,3 +73,9 @@ arguments: (arguments (template_string (string_fragment) @injection.content (#set! injection.language "graphql"))) ) + +(call_expression + function: (identifier) @_name(#match? @_name "^iso$") + arguments: (arguments (template_string (string_fragment) @injection.content + (#set! 
injection.language "isograph"))) +) diff --git a/crates/languages/src/json.rs b/crates/languages/src/json.rs index 019b45d396891b434c6b5e8457353ee0ee3e0d69..a33f5c9836f621a59b45aba7a249f7b1c2d1489d 100644 --- a/crates/languages/src/json.rs +++ b/crates/languages/src/json.rs @@ -8,11 +8,11 @@ use futures::StreamExt; use gpui::{App, AsyncApp, Task}; use http_client::github::{GitHubLspBinaryVersion, latest_github_release}; use language::{ - ContextProvider, LanguageName, LanguageRegistry, LanguageToolchainStore, LocalFile as _, - LspAdapter, LspAdapterDelegate, + ContextProvider, LanguageName, LanguageRegistry, LocalFile as _, LspAdapter, + LspAdapterDelegate, Toolchain, }; use lsp::{LanguageServerBinary, LanguageServerName}; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use project::{Fs, lsp_store::language_server_settings}; use serde_json::{Value, json}; use settings::{KeymapFile, SettingsJsonSchemaParams, SettingsStore}; @@ -234,7 +234,7 @@ impl JsonLspAdapter { schemas .as_array_mut() .unwrap() - .extend(cx.all_action_names().into_iter().map(|&name| { + .extend(cx.all_action_names().iter().map(|&name| { project::lsp_store::json_language_server_ext::url_schema_for_action(name) })); @@ -280,7 +280,7 @@ impl JsonLspAdapter { ) })?; writer.replace(config.clone()); - return Ok(config); + Ok(config) } } @@ -303,7 +303,7 @@ impl LspAdapter for JsonLspAdapter { async fn check_if_user_installed( &self, delegate: &dyn LspAdapterDelegate, - _: Arc, + _: Option, _: &AsyncApp, ) -> Option { let path = delegate @@ -321,6 +321,7 @@ impl LspAdapter for JsonLspAdapter { async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { Ok(Box::new( self.node @@ -343,9 +344,8 @@ impl LspAdapter for JsonLspAdapter { .should_install_npm_package( Self::PACKAGE_NAME, &server_path, - &container_dir, - &version, - Default::default(), + container_dir, + VersionStrategy::Latest(version), ) .await; @@ -405,7 +405,7 @@ 
impl LspAdapter for JsonLspAdapter { self: Arc, _: &dyn Fs, delegate: &Arc, - _: Arc, + _: Option, cx: &mut AsyncApp, ) -> Result { let mut config = self.get_or_init_workspace_config(cx).await?; @@ -489,12 +489,13 @@ impl NodeVersionAdapter { #[async_trait(?Send)] impl LspAdapter for NodeVersionAdapter { fn name(&self) -> LanguageServerName { - Self::SERVER_NAME.clone() + Self::SERVER_NAME } async fn fetch_latest_server_version( &self, delegate: &dyn LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { let release = latest_github_release( "zed-industries/package-version-server", @@ -530,7 +531,7 @@ impl LspAdapter for NodeVersionAdapter { async fn check_if_user_installed( &self, delegate: &dyn LspAdapterDelegate, - _: Arc, + _: Option, _: &AsyncApp, ) -> Option { let path = delegate.which(Self::SERVER_NAME.as_ref()).await?; diff --git a/crates/languages/src/jsonc/overrides.scm b/crates/languages/src/jsonc/overrides.scm index cc966ad4c13e0cc7f7fc27a1152b461f24e3c6b0..81fec9a5f57b28fc67b4781ec37df43559e21dc9 100644 --- a/crates/languages/src/jsonc/overrides.scm +++ b/crates/languages/src/jsonc/overrides.scm @@ -1 +1,2 @@ +(comment) @comment.inclusive (string) @string diff --git a/crates/languages/src/lib.rs b/crates/languages/src/lib.rs index 195ba79e1d0e96acea7ac1a53590c1a947334069..95fe1312183a3412509375050b1e1ff67642ef3e 100644 --- a/crates/languages/src/lib.rs +++ b/crates/languages/src/lib.rs @@ -1,6 +1,5 @@ use anyhow::Context as _; -use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; -use gpui::{App, UpdateGlobal}; +use gpui::{App, SharedString, UpdateGlobal}; use node_runtime::NodeRuntime; use python::PyprojectTomlManifestProvider; use rust::CargoManifestProvider; @@ -54,12 +53,6 @@ pub static LANGUAGE_GIT_COMMIT: std::sync::LazyLock> = )) }); -struct BasedPyrightFeatureFlag; - -impl FeatureFlag for BasedPyrightFeatureFlag { - const NAME: &'static str = "basedpyright"; -} - pub fn init(languages: Arc, node: NodeRuntime, cx: &mut App) { #[cfg(feature = 
"load-grammars")] languages.register_native_grammars([ @@ -97,14 +90,14 @@ pub fn init(languages: Arc, node: NodeRuntime, cx: &mut App) { let python_context_provider = Arc::new(python::PythonContextProvider); let python_lsp_adapter = Arc::new(python::PythonLspAdapter::new(node.clone())); let basedpyright_lsp_adapter = Arc::new(BasedPyrightLspAdapter::new()); - let python_toolchain_provider = Arc::new(python::PythonToolchainProvider::default()); + let python_toolchain_provider = Arc::new(python::PythonToolchainProvider); let rust_context_provider = Arc::new(rust::RustContextProvider); let rust_lsp_adapter = Arc::new(rust::RustLspAdapter); let tailwind_adapter = Arc::new(tailwind::TailwindLspAdapter::new(node.clone())); let typescript_context = Arc::new(typescript::TypeScriptContextProvider::new()); let typescript_lsp_adapter = Arc::new(typescript::TypeScriptLspAdapter::new(node.clone())); let vtsls_adapter = Arc::new(vtsls::VtslsLspAdapter::new(node.clone())); - let yaml_lsp_adapter = Arc::new(yaml::YamlLspAdapter::new(node.clone())); + let yaml_lsp_adapter = Arc::new(yaml::YamlLspAdapter::new(node)); let built_in_languages = [ LanguageInfo { @@ -119,12 +112,12 @@ pub fn init(languages: Arc, node: NodeRuntime, cx: &mut App) { }, LanguageInfo { name: "cpp", - adapters: vec![c_lsp_adapter.clone()], + adapters: vec![c_lsp_adapter], ..Default::default() }, LanguageInfo { name: "css", - adapters: vec![css_lsp_adapter.clone()], + adapters: vec![css_lsp_adapter], ..Default::default() }, LanguageInfo { @@ -146,20 +139,20 @@ pub fn init(languages: Arc, node: NodeRuntime, cx: &mut App) { }, LanguageInfo { name: "gowork", - adapters: vec![go_lsp_adapter.clone()], - context: Some(go_context_provider.clone()), + adapters: vec![go_lsp_adapter], + context: Some(go_context_provider), ..Default::default() }, LanguageInfo { name: "json", - adapters: vec![json_lsp_adapter.clone(), node_version_lsp_adapter.clone()], + adapters: vec![json_lsp_adapter.clone(), node_version_lsp_adapter], 
context: Some(json_context_provider.clone()), ..Default::default() }, LanguageInfo { name: "jsonc", - adapters: vec![json_lsp_adapter.clone()], - context: Some(json_context_provider.clone()), + adapters: vec![json_lsp_adapter], + context: Some(json_context_provider), ..Default::default() }, LanguageInfo { @@ -174,14 +167,16 @@ pub fn init(languages: Arc, node: NodeRuntime, cx: &mut App) { }, LanguageInfo { name: "python", - adapters: vec![python_lsp_adapter.clone(), py_lsp_adapter.clone()], + adapters: vec![basedpyright_lsp_adapter], context: Some(python_context_provider), toolchain: Some(python_toolchain_provider), + manifest_name: Some(SharedString::new_static("pyproject.toml").into()), }, LanguageInfo { name: "rust", adapters: vec![rust_lsp_adapter], context: Some(rust_context_provider), + manifest_name: Some(SharedString::new_static("Cargo.toml").into()), ..Default::default() }, LanguageInfo { @@ -199,7 +194,7 @@ pub fn init(languages: Arc, node: NodeRuntime, cx: &mut App) { LanguageInfo { name: "javascript", adapters: vec![typescript_lsp_adapter.clone(), vtsls_adapter.clone()], - context: Some(typescript_context.clone()), + context: Some(typescript_context), ..Default::default() }, LanguageInfo { @@ -234,23 +229,10 @@ pub fn init(languages: Arc, node: NodeRuntime, cx: &mut App) { registration.adapters, registration.context, registration.toolchain, + registration.manifest_name, ); } - let mut basedpyright_lsp_adapter = Some(basedpyright_lsp_adapter); - cx.observe_flag::({ - let languages = languages.clone(); - move |enabled, _| { - if enabled { - if let Some(adapter) = basedpyright_lsp_adapter.take() { - languages - .register_available_lsp_adapter(adapter.name(), move || adapter.clone()); - } - } - } - }) - .detach(); - // Register globally available language servers. 
// // This will allow users to add support for a built-in language server (e.g., Tailwind) @@ -267,27 +249,19 @@ pub fn init(languages: Arc, node: NodeRuntime, cx: &mut App) { // ``` languages.register_available_lsp_adapter( LanguageServerName("tailwindcss-language-server".into()), - { - let adapter = tailwind_adapter.clone(); - move || adapter.clone() - }, + tailwind_adapter.clone(), ); - languages.register_available_lsp_adapter(LanguageServerName("eslint".into()), { - let adapter = eslint_adapter.clone(); - move || adapter.clone() - }); - languages.register_available_lsp_adapter(LanguageServerName("vtsls".into()), { - let adapter = vtsls_adapter.clone(); - move || adapter.clone() - }); + languages.register_available_lsp_adapter( + LanguageServerName("eslint".into()), + eslint_adapter.clone(), + ); + languages.register_available_lsp_adapter(LanguageServerName("vtsls".into()), vtsls_adapter); languages.register_available_lsp_adapter( LanguageServerName("typescript-language-server".into()), - { - let adapter = typescript_lsp_adapter.clone(); - move || adapter.clone() - }, + typescript_lsp_adapter, ); - + languages.register_available_lsp_adapter(python_lsp_adapter.name(), python_lsp_adapter); + languages.register_available_lsp_adapter(py_lsp_adapter.name(), py_lsp_adapter); // Register Tailwind for the existing languages that should have it by default. 
// // This can be driven by the `language_servers` setting once we have a way for @@ -296,6 +270,7 @@ pub fn init(languages: Arc, node: NodeRuntime, cx: &mut App) { "Astro", "CSS", "ERB", + "HTML+ERB", "HTML/ERB", "HEEX", "HTML", @@ -340,7 +315,7 @@ pub fn init(languages: Arc, node: NodeRuntime, cx: &mut App) { Arc::from(PyprojectTomlManifestProvider), ]; for provider in manifest_providers { - project::ManifestProviders::global(cx).register(provider); + project::ManifestProvidersStore::global(cx).register(provider); } } @@ -350,6 +325,7 @@ struct LanguageInfo { adapters: Vec>, context: Option>, toolchain: Option>, + manifest_name: Option, } fn register_language( @@ -358,6 +334,7 @@ fn register_language( adapters: Vec>, context: Option>, toolchain: Option>, + manifest_name: Option, ) { let config = load_config(name); for adapter in adapters { @@ -368,12 +345,14 @@ fn register_language( config.grammar.clone(), config.matcher.clone(), config.hidden, + manifest_name.clone(), Arc::new(move || { Ok(LoadedLanguage { config: config.clone(), queries: load_queries(name), context_provider: context.clone(), toolchain_provider: toolchain.clone(), + manifest_name: manifest_name.clone(), }) }), ); diff --git a/crates/languages/src/markdown-inline/highlights.scm b/crates/languages/src/markdown-inline/highlights.scm index 61c3e34c62973c822a07415aaf56fadffbabc2e2..3c9f6fbcc340bd085466055c7b35551dd71b8c53 100644 --- a/crates/languages/src/markdown-inline/highlights.scm +++ b/crates/languages/src/markdown-inline/highlights.scm @@ -1,6 +1,22 @@ -(emphasis) @emphasis -(strong_emphasis) @emphasis.strong -(code_span) @text.literal -(link_text) @link_text -(link_label) @link_text -(link_destination) @link_uri +(emphasis) @emphasis.markup +(strong_emphasis) @emphasis.strong.markup +(code_span) @text.literal.markup +(strikethrough) @strikethrough.markup + +[ + (inline_link) + (shortcut_link) + (collapsed_reference_link) + (full_reference_link) + (image) + (link_text) + (link_label) +] 
@link_text.markup + +(inline_link ["(" ")"] @link_uri.markup) +(image ["(" ")"] @link_uri.markup) +[ + (link_destination) + (uri_autolink) + (email_autolink) +] @link_uri.markup diff --git a/crates/languages/src/markdown/config.toml b/crates/languages/src/markdown/config.toml index 926dcd70d9f9207c03154690e7d4e9866f9aacea..36071cb5392462a51c10e0513b39979580ec67f5 100644 --- a/crates/languages/src/markdown/config.toml +++ b/crates/languages/src/markdown/config.toml @@ -12,6 +12,7 @@ brackets = [ { start = "\"", end = "\"", close = false, newline = false }, { start = "'", end = "'", close = false, newline = false }, { start = "`", end = "`", close = false, newline = false }, + { start = "*", end = "*", close = false, newline = false, surround = true }, ] rewrap_prefixes = [ "[-*+]\\s+", diff --git a/crates/languages/src/markdown/highlights.scm b/crates/languages/src/markdown/highlights.scm index 6b9fa3482298c93207b4ab480116751259911d85..707bcc0816366f5cc875c9f1197b42a2363cab99 100644 --- a/crates/languages/src/markdown/highlights.scm +++ b/crates/languages/src/markdown/highlights.scm @@ -1,7 +1,15 @@ +[ + (paragraph) + (indented_code_block) + (pipe_table) +] @text + [ (atx_heading) (setext_heading) -] @title + (thematic_break) +] @title.markup +(setext_heading (paragraph) @title.markup) [ (list_marker_plus) @@ -9,8 +17,18 @@ (list_marker_star) (list_marker_dot) (list_marker_parenthesis) -] @punctuation.list_marker +] @punctuation.list_marker.markup + +(block_quote_marker) @punctuation.markup +(pipe_table_header "|" @punctuation.markup) +(pipe_table_row "|" @punctuation.markup) +(pipe_table_delimiter_row "|" @punctuation.markup) +(pipe_table_delimiter_cell "-" @punctuation.markup) + +[ + (fenced_code_block_delimiter) + (info_string) +] @punctuation.embedded.markup -(fenced_code_block - (info_string - (language) @text.literal)) +(link_reference_definition) @link_text.markup +(link_destination) @link_uri.markup diff --git a/crates/languages/src/python.rs 
b/crates/languages/src/python.rs index 551332448770ffabc4b834662980bd5bb00248c5..ede46b38e5ada332e10db4feb72ae11767e839bd 100644 --- a/crates/languages/src/python.rs +++ b/crates/languages/src/python.rs @@ -2,6 +2,7 @@ use anyhow::{Context as _, ensure}; use anyhow::{Result, anyhow}; use async_trait::async_trait; use collections::HashMap; +use futures::AsyncBufReadExt; use gpui::{App, Task}; use gpui::{AsyncApp, SharedString}; use language::ToolchainList; @@ -10,13 +11,13 @@ use language::language_settings::language_settings; use language::{ContextLocation, LanguageToolchainStore}; use language::{ContextProvider, LspAdapter, LspAdapterDelegate}; use language::{LanguageName, ManifestName, ManifestProvider, ManifestQuery}; -use language::{Toolchain, WorkspaceFoldersContent}; +use language::{Toolchain, ToolchainMetadata}; use lsp::LanguageServerBinary; use lsp::LanguageServerName; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use pet_core::Configuration; use pet_core::os_environment::Environment; -use pet_core::python_environment::PythonEnvironmentKind; +use pet_core::python_environment::{PythonEnvironment, PythonEnvironmentKind}; use project::Fs; use project::lsp_store::language_server_settings; use serde_json::{Value, json}; @@ -30,13 +31,11 @@ use std::{ borrow::Cow, ffi::OsString, fmt::Write, - fs, - io::{self, BufRead}, path::{Path, PathBuf}, sync::Arc, }; -use task::{TaskTemplate, TaskTemplates, VariableName}; -use util::ResultExt; +use task::{ShellKind, TaskTemplate, TaskTemplates, VariableName}; +use util::{ResultExt, maybe}; pub(crate) struct PyprojectTomlManifestProvider; @@ -88,6 +87,18 @@ fn server_binary_arguments(server_path: &Path) -> Vec { vec![server_path.into(), "--stdio".into()] } +/// Pyright assigns each completion item a `sortText` of the form `XX.YYYY.name`. +/// Where `XX` is the sorting category, `YYYY` is based on most recent usage, +/// and `name` is the symbol name itself. 
+/// +/// The problem with it is that Pyright adjusts the sort text based on previous resolutions (items for which we've issued `completion/resolve` call have their sortText adjusted), +/// which - long story short - makes completion items list non-stable. Pyright probably relies on VSCode's implementation detail. +/// see https://github.com/microsoft/pyright/blob/95ef4e103b9b2f129c9320427e51b73ea7cf78bd/packages/pyright-internal/src/languageService/completionProvider.ts#LL2873 +fn process_pyright_completions(items: &mut [lsp::CompletionItem]) { + for item in items { + item.sort_text.take(); + } +} pub struct PythonLspAdapter { node: NodeRuntime, } @@ -103,7 +114,7 @@ impl PythonLspAdapter { #[async_trait(?Send)] impl LspAdapter for PythonLspAdapter { fn name(&self) -> LanguageServerName { - Self::SERVER_NAME.clone() + Self::SERVER_NAME } async fn initialization_options( @@ -127,7 +138,7 @@ impl LspAdapter for PythonLspAdapter { async fn check_if_user_installed( &self, delegate: &dyn LspAdapterDelegate, - _: Arc, + _: Option, _: &AsyncApp, ) -> Option { if let Some(pyright_bin) = delegate.which("pyright-langserver".as_ref()).await { @@ -158,6 +169,7 @@ impl LspAdapter for PythonLspAdapter { async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { Ok(Box::new( self.node @@ -204,9 +216,8 @@ impl LspAdapter for PythonLspAdapter { .should_install_npm_package( Self::SERVER_NAME.as_ref(), &server_path, - &container_dir, - &version, - Default::default(), + container_dir, + VersionStrategy::Latest(version), ) .await; @@ -233,26 +244,7 @@ impl LspAdapter for PythonLspAdapter { } async fn process_completions(&self, items: &mut [lsp::CompletionItem]) { - // Pyright assigns each completion item a `sortText` of the form `XX.YYYY.name`. - // Where `XX` is the sorting category, `YYYY` is based on most recent usage, - // and `name` is the symbol name itself. 
- // - // Because the symbol name is included, there generally are not ties when - // sorting by the `sortText`, so the symbol's fuzzy match score is not taken - // into account. Here, we remove the symbol name from the sortText in order - // to allow our own fuzzy score to be used to break ties. - // - // see https://github.com/microsoft/pyright/blob/95ef4e103b9b2f129c9320427e51b73ea7cf78bd/packages/pyright-internal/src/languageService/completionProvider.ts#LL2873 - for item in items { - let Some(sort_text) = &mut item.sort_text else { - continue; - }; - let mut parts = sort_text.split('.'); - let Some(first) = parts.next() else { continue }; - let Some(second) = parts.next() else { continue }; - let Some(_) = parts.next() else { continue }; - sort_text.replace_range(first.len() + second.len() + 1.., ""); - } + process_pyright_completions(items); } async fn label_for_completion( @@ -263,20 +255,34 @@ impl LspAdapter for PythonLspAdapter { let label = &item.label; let grammar = language.grammar()?; let highlight_id = match item.kind? 
{ - lsp::CompletionItemKind::METHOD => grammar.highlight_id_for_name("function.method")?, - lsp::CompletionItemKind::FUNCTION => grammar.highlight_id_for_name("function")?, - lsp::CompletionItemKind::CLASS => grammar.highlight_id_for_name("type")?, - lsp::CompletionItemKind::CONSTANT => grammar.highlight_id_for_name("constant")?, - _ => return None, + lsp::CompletionItemKind::METHOD => grammar.highlight_id_for_name("function.method"), + lsp::CompletionItemKind::FUNCTION => grammar.highlight_id_for_name("function"), + lsp::CompletionItemKind::CLASS => grammar.highlight_id_for_name("type"), + lsp::CompletionItemKind::CONSTANT => grammar.highlight_id_for_name("constant"), + lsp::CompletionItemKind::VARIABLE => grammar.highlight_id_for_name("variable"), + _ => { + return None; + } }; let filter_range = item .filter_text .as_deref() .and_then(|filter| label.find(filter).map(|ix| ix..ix + filter.len())) .unwrap_or(0..label.len()); + let mut text = label.clone(); + if let Some(completion_details) = item + .label_details + .as_ref() + .and_then(|details| details.description.as_ref()) + { + write!(&mut text, " {}", completion_details).ok(); + } Some(language::CodeLabel { - text: label.clone(), - runs: vec![(0..label.len(), highlight_id)], + runs: highlight_id + .map(|id| (0..label.len(), id)) + .into_iter() + .collect(), + text, filter_range, }) } @@ -320,17 +326,9 @@ impl LspAdapter for PythonLspAdapter { self: Arc, _: &dyn Fs, adapter: &Arc, - toolchains: Arc, + toolchain: Option, cx: &mut AsyncApp, ) -> Result { - let toolchain = toolchains - .active_toolchain( - adapter.worktree_id(), - Arc::from("".as_ref()), - LanguageName::new("Python"), - cx, - ) - .await; cx.update(move |cx| { let mut user_settings = language_server_settings(adapter.as_ref(), &Self::SERVER_NAME, cx) @@ -338,41 +336,35 @@ impl LspAdapter for PythonLspAdapter { .unwrap_or_default(); // If we have a detected toolchain, configure Pyright to use it - if let Some(toolchain) = toolchain { + if let 
Some(toolchain) = toolchain + && let Ok(env) = serde_json::from_value::< + pet_core::python_environment::PythonEnvironment, + >(toolchain.as_json.clone()) + { if user_settings.is_null() { user_settings = Value::Object(serde_json::Map::default()); } let object = user_settings.as_object_mut().unwrap(); let interpreter_path = toolchain.path.to_string(); + if let Some(venv_dir) = env.prefix { + // Set venvPath and venv at the root level + // This matches the format of a pyrightconfig.json file + if let Some(parent) = venv_dir.parent() { + // Use relative path if the venv is inside the workspace + let venv_path = if parent == adapter.worktree_root_path() { + ".".to_string() + } else { + parent.to_string_lossy().into_owned() + }; + object.insert("venvPath".to_string(), Value::String(venv_path)); + } - // Detect if this is a virtual environment - if let Some(interpreter_dir) = Path::new(&interpreter_path).parent() { - if let Some(venv_dir) = interpreter_dir.parent() { - // Check if this looks like a virtual environment - if venv_dir.join("pyvenv.cfg").exists() - || venv_dir.join("bin/activate").exists() - || venv_dir.join("Scripts/activate.bat").exists() - { - // Set venvPath and venv at the root level - // This matches the format of a pyrightconfig.json file - if let Some(parent) = venv_dir.parent() { - // Use relative path if the venv is inside the workspace - let venv_path = if parent == adapter.worktree_root_path() { - ".".to_string() - } else { - parent.to_string_lossy().into_owned() - }; - object.insert("venvPath".to_string(), Value::String(venv_path)); - } - - if let Some(venv_name) = venv_dir.file_name() { - object.insert( - "venv".to_owned(), - Value::String(venv_name.to_string_lossy().into_owned()), - ); - } - } + if let Some(venv_name) = venv_dir.file_name() { + object.insert( + "venv".to_owned(), + Value::String(venv_name.to_string_lossy().into_owned()), + ); } } @@ -398,12 +390,6 @@ impl LspAdapter for PythonLspAdapter { user_settings }) } - fn 
manifest_name(&self) -> Option { - Some(SharedString::new_static("pyproject.toml").into()) - } - fn workspace_folders_content(&self) -> WorkspaceFoldersContent { - WorkspaceFoldersContent::WorktreeRoot - } } async fn get_cached_server_binary( @@ -431,9 +417,6 @@ const PYTHON_TEST_TARGET_TASK_VARIABLE: VariableName = const PYTHON_ACTIVE_TOOLCHAIN_PATH: VariableName = VariableName::Custom(Cow::Borrowed("PYTHON_ACTIVE_ZED_TOOLCHAIN")); -const PYTHON_ACTIVE_TOOLCHAIN_PATH_RAW: VariableName = - VariableName::Custom(Cow::Borrowed("PYTHON_ACTIVE_ZED_TOOLCHAIN_RAW")); - const PYTHON_MODULE_NAME_TASK_VARIABLE: VariableName = VariableName::Custom(Cow::Borrowed("PYTHON_MODULE_NAME")); @@ -457,7 +440,7 @@ impl ContextProvider for PythonContextProvider { let worktree_id = location_file.as_ref().map(|f| f.worktree_id(cx)); cx.spawn(async move |cx| { - let raw_toolchain = if let Some(worktree_id) = worktree_id { + let active_toolchain = if let Some(worktree_id) = worktree_id { let file_path = location_file .as_ref() .and_then(|f| f.path().parent()) @@ -475,15 +458,13 @@ impl ContextProvider for PythonContextProvider { String::from("python3") }; - let active_toolchain = format!("\"{raw_toolchain}\""); let toolchain = (PYTHON_ACTIVE_TOOLCHAIN_PATH, active_toolchain); - let raw_toolchain_var = (PYTHON_ACTIVE_TOOLCHAIN_PATH_RAW, raw_toolchain); Ok(task::TaskVariables::from_iter( test_target .into_iter() .chain(module_target.into_iter()) - .chain([toolchain, raw_toolchain_var]), + .chain([toolchain]), )) }) } @@ -500,31 +481,31 @@ impl ContextProvider for PythonContextProvider { // Execute a selection TaskTemplate { label: "execute selection".to_owned(), - command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value(), + command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value_with_whitespace(), args: vec![ "-c".to_owned(), VariableName::SelectedText.template_value_with_whitespace(), ], - cwd: Some("$ZED_WORKTREE_ROOT".into()), + cwd: Some(VariableName::WorktreeRoot.template_value()), 
..TaskTemplate::default() }, // Execute an entire file TaskTemplate { label: format!("run '{}'", VariableName::File.template_value()), - command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value(), + command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value_with_whitespace(), args: vec![VariableName::File.template_value_with_whitespace()], - cwd: Some("$ZED_WORKTREE_ROOT".into()), + cwd: Some(VariableName::WorktreeRoot.template_value()), ..TaskTemplate::default() }, // Execute a file as module TaskTemplate { label: format!("run module '{}'", VariableName::File.template_value()), - command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value(), + command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value_with_whitespace(), args: vec![ "-m".to_owned(), - PYTHON_MODULE_NAME_TASK_VARIABLE.template_value(), + PYTHON_MODULE_NAME_TASK_VARIABLE.template_value_with_whitespace(), ], - cwd: Some("$ZED_WORKTREE_ROOT".into()), + cwd: Some(VariableName::WorktreeRoot.template_value()), tags: vec!["python-module-main-method".to_owned()], ..TaskTemplate::default() }, @@ -536,19 +517,19 @@ impl ContextProvider for PythonContextProvider { // Run tests for an entire file TaskTemplate { label: format!("unittest '{}'", VariableName::File.template_value()), - command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value(), + command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value_with_whitespace(), args: vec![ "-m".to_owned(), "unittest".to_owned(), VariableName::File.template_value_with_whitespace(), ], - cwd: Some("$ZED_WORKTREE_ROOT".into()), + cwd: Some(VariableName::WorktreeRoot.template_value()), ..TaskTemplate::default() }, // Run test(s) for a specific target within a file TaskTemplate { label: "unittest $ZED_CUSTOM_PYTHON_TEST_TARGET".to_owned(), - command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value(), + command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value_with_whitespace(), args: vec![ "-m".to_owned(), "unittest".to_owned(), @@ -558,7 +539,7 @@ impl ContextProvider for PythonContextProvider { 
"python-unittest-class".to_owned(), "python-unittest-method".to_owned(), ], - cwd: Some("$ZED_WORKTREE_ROOT".into()), + cwd: Some(VariableName::WorktreeRoot.template_value()), ..TaskTemplate::default() }, ] @@ -568,25 +549,25 @@ impl ContextProvider for PythonContextProvider { // Run tests for an entire file TaskTemplate { label: format!("pytest '{}'", VariableName::File.template_value()), - command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value(), + command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value_with_whitespace(), args: vec![ "-m".to_owned(), "pytest".to_owned(), VariableName::File.template_value_with_whitespace(), ], - cwd: Some("$ZED_WORKTREE_ROOT".into()), + cwd: Some(VariableName::WorktreeRoot.template_value()), ..TaskTemplate::default() }, // Run test(s) for a specific target within a file TaskTemplate { label: "pytest $ZED_CUSTOM_PYTHON_TEST_TARGET".to_owned(), - command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value(), + command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value_with_whitespace(), args: vec![ "-m".to_owned(), "pytest".to_owned(), PYTHON_TEST_TARGET_TASK_VARIABLE.template_value_with_whitespace(), ], - cwd: Some("$ZED_WORKTREE_ROOT".into()), + cwd: Some(VariableName::WorktreeRoot.template_value()), tags: vec![ "python-pytest-class".to_owned(), "python-pytest-method".to_owned(), @@ -714,19 +695,9 @@ fn python_env_kind_display(k: &PythonEnvironmentKind) -> &'static str { } } -pub(crate) struct PythonToolchainProvider { - term: SharedString, -} - -impl Default for PythonToolchainProvider { - fn default() -> Self { - Self { - term: SharedString::new_static("Virtual Environment"), - } - } -} +pub(crate) struct PythonToolchainProvider; -static ENV_PRIORITY_LIST: &'static [PythonEnvironmentKind] = &[ +static ENV_PRIORITY_LIST: &[PythonEnvironmentKind] = &[ // Prioritize non-Conda environments. 
PythonEnvironmentKind::Poetry, PythonEnvironmentKind::Pipenv, @@ -756,25 +727,24 @@ fn env_priority(kind: Option) -> usize { /// Return the name of environment declared in Option { - fs::File::open(worktree_root.join(".venv")) - .and_then(|file| { - let mut venv_name = String::new(); - io::BufReader::new(file).read_line(&mut venv_name)?; - Ok(venv_name.trim().to_string()) - }) - .ok() +async fn get_worktree_venv_declaration(worktree_root: &Path) -> Option { + let file = async_fs::File::open(worktree_root.join(".venv")) + .await + .ok()?; + let mut venv_name = String::new(); + smol::io::BufReader::new(file) + .read_line(&mut venv_name) + .await + .ok()?; + Some(venv_name.trim().to_string()) } #[async_trait] impl ToolchainLister for PythonToolchainProvider { - fn manifest_name(&self) -> language::ManifestName { - ManifestName::from(SharedString::new_static("pyproject.toml")) - } async fn list( &self, worktree_root: PathBuf, - subroot_relative_path: Option>, + subroot_relative_path: Arc, project_env: Option>, ) -> ToolchainList { let env = project_env.unwrap_or_default(); @@ -786,13 +756,15 @@ impl ToolchainLister for PythonToolchainProvider { ); let mut config = Configuration::default(); - let mut directories = vec![worktree_root.clone()]; - if let Some(subroot_relative_path) = subroot_relative_path { - debug_assert!(subroot_relative_path.is_relative()); - directories.push(worktree_root.join(subroot_relative_path)); - } - - config.workspace_directories = Some(directories); + debug_assert!(subroot_relative_path.is_relative()); + // `.ancestors()` will yield at least one path, so in case of empty `subroot_relative_path`, we'll just use + // worktree root as the workspace directory. 
+ config.workspace_directories = Some( + subroot_relative_path + .ancestors() + .map(|ancestor| worktree_root.join(ancestor)) + .collect(), + ); for locator in locators.iter() { locator.configure(&config); } @@ -806,7 +778,7 @@ impl ToolchainLister for PythonToolchainProvider { .map_or(Vec::new(), |mut guard| std::mem::take(&mut guard)); let wr = worktree_root; - let wr_venv = get_worktree_venv_declaration(&wr); + let wr_venv = get_worktree_venv_declaration(&wr).await; // Sort detected environments by: // environment name matching activation file (/.venv) // environment project dir matching worktree_root @@ -843,7 +815,7 @@ impl ToolchainLister for PythonToolchainProvider { .get_env_var("CONDA_PREFIX".to_string()) .map(|conda_prefix| { let is_match = |exe: &Option| { - exe.as_ref().map_or(false, |e| e.starts_with(&conda_prefix)) + exe.as_ref().is_some_and(|e| e.starts_with(&conda_prefix)) }; match (is_match(&lhs.executable), is_match(&rhs.executable)) { (true, false) => Ordering::Less, @@ -869,32 +841,7 @@ impl ToolchainLister for PythonToolchainProvider { let mut toolchains: Vec<_> = toolchains .into_iter() - .filter_map(|toolchain| { - let mut name = String::from("Python"); - if let Some(ref version) = toolchain.version { - _ = write!(name, " {version}"); - } - - let name_and_kind = match (&toolchain.name, &toolchain.kind) { - (Some(name), Some(kind)) => { - Some(format!("({name}; {})", python_env_kind_display(kind))) - } - (Some(name), None) => Some(format!("({name})")), - (None, Some(kind)) => Some(format!("({})", python_env_kind_display(kind))), - (None, None) => None, - }; - - if let Some(nk) = name_and_kind { - _ = write!(name, " {nk}"); - } - - Some(Toolchain { - name: name.into(), - path: toolchain.executable.as_ref()?.to_str()?.to_owned().into(), - language_name: LanguageName::new("Python"), - as_json: serde_json::to_value(toolchain).ok()?, - }) - }) + .filter_map(venv_to_toolchain) .collect(); toolchains.dedup(); ToolchainList { @@ -903,9 +850,125 @@ 
impl ToolchainLister for PythonToolchainProvider { groups: Default::default(), } } - fn term(&self) -> SharedString { - self.term.clone() + fn meta(&self) -> ToolchainMetadata { + ToolchainMetadata { + term: SharedString::new_static("Virtual Environment"), + new_toolchain_placeholder: SharedString::new_static( + "A path to the python3 executable within a virtual environment, or path to virtual environment itself", + ), + manifest_name: ManifestName::from(SharedString::new_static("pyproject.toml")), + } + } + + async fn resolve( + &self, + path: PathBuf, + env: Option>, + ) -> anyhow::Result { + let env = env.unwrap_or_default(); + let environment = EnvironmentApi::from_env(&env); + let locators = pet::locators::create_locators( + Arc::new(pet_conda::Conda::from(&environment)), + Arc::new(pet_poetry::Poetry::from(&environment)), + &environment, + ); + let toolchain = pet::resolve::resolve_environment(&path, &locators, &environment) + .context("Could not find a virtual environment in provided path")?; + let venv = toolchain.resolved.unwrap_or(toolchain.discovered); + venv_to_toolchain(venv).context("Could not convert a venv into a toolchain") } + + async fn activation_script( + &self, + toolchain: &Toolchain, + shell: ShellKind, + fs: &dyn Fs, + ) -> Vec { + let Ok(toolchain) = serde_json::from_value::( + toolchain.as_json.clone(), + ) else { + return vec![]; + }; + let mut activation_script = vec![]; + + match toolchain.kind { + Some(PythonEnvironmentKind::Conda) => { + if let Some(name) = &toolchain.name { + activation_script.push(format!("conda activate {name}")); + } else { + activation_script.push("conda activate".to_string()); + } + } + Some(PythonEnvironmentKind::Venv | PythonEnvironmentKind::VirtualEnv) => { + if let Some(prefix) = &toolchain.prefix { + let activate_keyword = match shell { + ShellKind::Cmd => ".", + ShellKind::Nushell => "overlay use", + ShellKind::Powershell => ".", + ShellKind::Fish => "source", + ShellKind::Csh => "source", + 
ShellKind::Posix => "source", + }; + let activate_script_name = match shell { + ShellKind::Posix => "activate", + ShellKind::Csh => "activate.csh", + ShellKind::Fish => "activate.fish", + ShellKind::Nushell => "activate.nu", + ShellKind::Powershell => "activate.ps1", + ShellKind::Cmd => "activate.bat", + }; + let path = prefix.join(BINARY_DIR).join(activate_script_name); + if fs.is_file(&path).await { + activation_script + .push(format!("{activate_keyword} \"{}\"", path.display())); + } + } + } + Some(PythonEnvironmentKind::Pyenv) => { + let Some(manager) = toolchain.manager else { + return vec![]; + }; + let version = toolchain.version.as_deref().unwrap_or("system"); + let pyenv = manager.executable; + let pyenv = pyenv.display(); + activation_script.extend(match shell { + ShellKind::Fish => Some(format!("\"{pyenv}\" shell - fish {version}")), + ShellKind::Posix => Some(format!("\"{pyenv}\" shell - sh {version}")), + ShellKind::Nushell => Some(format!("\"{pyenv}\" shell - nu {version}")), + ShellKind::Powershell => None, + ShellKind::Csh => None, + ShellKind::Cmd => None, + }) + } + _ => {} + } + activation_script + } +} + +fn venv_to_toolchain(venv: PythonEnvironment) -> Option { + let mut name = String::from("Python"); + if let Some(ref version) = venv.version { + _ = write!(name, " {version}"); + } + + let name_and_kind = match (&venv.name, &venv.kind) { + (Some(name), Some(kind)) => Some(format!("({name}; {})", python_env_kind_display(kind))), + (Some(name), None) => Some(format!("({name})")), + (None, Some(kind)) => Some(format!("({})", python_env_kind_display(kind))), + (None, None) => None, + }; + + if let Some(nk) = name_and_kind { + _ = write!(name, " {nk}"); + } + + Some(Toolchain { + name: name.into(), + path: venv.executable.as_ref()?.to_str()?.to_owned().into(), + language_name: LanguageName::new("Python"), + as_json: serde_json::to_value(venv).ok()?, + }) } pub struct EnvironmentApi<'a> { @@ -1041,14 +1104,14 @@ const BINARY_DIR: &str = if 
cfg!(target_os = "windows") { #[async_trait(?Send)] impl LspAdapter for PyLspAdapter { fn name(&self) -> LanguageServerName { - Self::SERVER_NAME.clone() + Self::SERVER_NAME } async fn check_if_user_installed( &self, delegate: &dyn LspAdapterDelegate, - toolchains: Arc, - cx: &AsyncApp, + toolchain: Option, + _: &AsyncApp, ) -> Option { if let Some(pylsp_bin) = delegate.which(Self::SERVER_NAME.as_ref()).await { let env = delegate.shell_env().await; @@ -1058,17 +1121,10 @@ impl LspAdapter for PyLspAdapter { arguments: vec![], }) } else { - let venv = toolchains - .active_toolchain( - delegate.worktree_id(), - Arc::from("".as_ref()), - LanguageName::new("Python"), - &mut cx.clone(), - ) - .await?; - let pylsp_path = Path::new(venv.path.as_ref()).parent()?.join("pylsp"); + let toolchain = toolchain?; + let pylsp_path = Path::new(toolchain.path.as_ref()).parent()?.join("pylsp"); pylsp_path.exists().then(|| LanguageServerBinary { - path: venv.path.to_string().into(), + path: toolchain.path.to_string().into(), arguments: vec![pylsp_path.into()], env: None, }) @@ -1078,6 +1134,7 @@ impl LspAdapter for PyLspAdapter { async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { Ok(Box::new(()) as Box<_>) } @@ -1212,17 +1269,9 @@ impl LspAdapter for PyLspAdapter { self: Arc, _: &dyn Fs, adapter: &Arc, - toolchains: Arc, + toolchain: Option, cx: &mut AsyncApp, ) -> Result { - let toolchain = toolchains - .active_toolchain( - adapter.worktree_id(), - Arc::from("".as_ref()), - LanguageName::new("Python"), - cx, - ) - .await; cx.update(move |cx| { let mut user_settings = language_server_settings(adapter.as_ref(), &Self::SERVER_NAME, cx) @@ -1283,12 +1332,6 @@ impl LspAdapter for PyLspAdapter { user_settings }) } - fn manifest_name(&self) -> Option { - Some(SharedString::new_static("pyproject.toml").into()) - } - fn workspace_folders_content(&self) -> WorkspaceFoldersContent { - WorkspaceFoldersContent::WorktreeRoot - } } pub(crate) 
struct BasedPyrightLspAdapter { @@ -1354,7 +1397,7 @@ impl BasedPyrightLspAdapter { #[async_trait(?Send)] impl LspAdapter for BasedPyrightLspAdapter { fn name(&self) -> LanguageServerName { - Self::SERVER_NAME.clone() + Self::SERVER_NAME } async fn initialization_options( @@ -1378,8 +1421,8 @@ impl LspAdapter for BasedPyrightLspAdapter { async fn check_if_user_installed( &self, delegate: &dyn LspAdapterDelegate, - toolchains: Arc, - cx: &AsyncApp, + toolchain: Option, + _: &AsyncApp, ) -> Option { if let Some(bin) = delegate.which(Self::BINARY_NAME.as_ref()).await { let env = delegate.shell_env().await; @@ -1389,15 +1432,7 @@ impl LspAdapter for BasedPyrightLspAdapter { arguments: vec!["--stdio".into()], }) } else { - let venv = toolchains - .active_toolchain( - delegate.worktree_id(), - Arc::from("".as_ref()), - LanguageName::new("Python"), - &mut cx.clone(), - ) - .await?; - let path = Path::new(venv.path.as_ref()) + let path = Path::new(toolchain?.path.as_ref()) .parent()? .join(Self::BINARY_NAME); path.exists().then(|| LanguageServerBinary { @@ -1411,6 +1446,7 @@ impl LspAdapter for BasedPyrightLspAdapter { async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { Ok(Box::new(()) as Box<_>) } @@ -1457,26 +1493,7 @@ impl LspAdapter for BasedPyrightLspAdapter { } async fn process_completions(&self, items: &mut [lsp::CompletionItem]) { - // Pyright assigns each completion item a `sortText` of the form `XX.YYYY.name`. - // Where `XX` is the sorting category, `YYYY` is based on most recent usage, - // and `name` is the symbol name itself. - // - // Because the symbol name is included, there generally are not ties when - // sorting by the `sortText`, so the symbol's fuzzy match score is not taken - // into account. Here, we remove the symbol name from the sortText in order - // to allow our own fuzzy score to be used to break ties. 
- // - // see https://github.com/microsoft/pyright/blob/95ef4e103b9b2f129c9320427e51b73ea7cf78bd/packages/pyright-internal/src/languageService/completionProvider.ts#LL2873 - for item in items { - let Some(sort_text) = &mut item.sort_text else { - continue; - }; - let mut parts = sort_text.split('.'); - let Some(first) = parts.next() else { continue }; - let Some(second) = parts.next() else { continue }; - let Some(_) = parts.next() else { continue }; - sort_text.replace_range(first.len() + second.len() + 1.., ""); - } + process_pyright_completions(items); } async fn label_for_completion( @@ -1487,20 +1504,34 @@ impl LspAdapter for BasedPyrightLspAdapter { let label = &item.label; let grammar = language.grammar()?; let highlight_id = match item.kind? { - lsp::CompletionItemKind::METHOD => grammar.highlight_id_for_name("function.method")?, - lsp::CompletionItemKind::FUNCTION => grammar.highlight_id_for_name("function")?, - lsp::CompletionItemKind::CLASS => grammar.highlight_id_for_name("type")?, - lsp::CompletionItemKind::CONSTANT => grammar.highlight_id_for_name("constant")?, - _ => return None, + lsp::CompletionItemKind::METHOD => grammar.highlight_id_for_name("function.method"), + lsp::CompletionItemKind::FUNCTION => grammar.highlight_id_for_name("function"), + lsp::CompletionItemKind::CLASS => grammar.highlight_id_for_name("type"), + lsp::CompletionItemKind::CONSTANT => grammar.highlight_id_for_name("constant"), + lsp::CompletionItemKind::VARIABLE => grammar.highlight_id_for_name("variable"), + _ => { + return None; + } }; let filter_range = item .filter_text .as_deref() .and_then(|filter| label.find(filter).map(|ix| ix..ix + filter.len())) .unwrap_or(0..label.len()); + let mut text = label.clone(); + if let Some(completion_details) = item + .label_details + .as_ref() + .and_then(|details| details.description.as_ref()) + { + write!(&mut text, " {}", completion_details).ok(); + } Some(language::CodeLabel { - text: label.clone(), - runs: vec![(0..label.len(), 
highlight_id)], + runs: highlight_id + .map(|id| (0..label.len(), id)) + .into_iter() + .collect(), + text, filter_range, }) } @@ -1544,17 +1575,9 @@ impl LspAdapter for BasedPyrightLspAdapter { self: Arc, _: &dyn Fs, adapter: &Arc, - toolchains: Arc, + toolchain: Option, cx: &mut AsyncApp, ) -> Result { - let toolchain = toolchains - .active_toolchain( - adapter.worktree_id(), - Arc::from("".as_ref()), - LanguageName::new("Python"), - cx, - ) - .await; cx.update(move |cx| { let mut user_settings = language_server_settings(adapter.as_ref(), &Self::SERVER_NAME, cx) @@ -1562,74 +1585,74 @@ impl LspAdapter for BasedPyrightLspAdapter { .unwrap_or_default(); // If we have a detected toolchain, configure Pyright to use it - if let Some(toolchain) = toolchain { + if let Some(toolchain) = toolchain + && let Ok(env) = serde_json::from_value::< + pet_core::python_environment::PythonEnvironment, + >(toolchain.as_json.clone()) + { if user_settings.is_null() { user_settings = Value::Object(serde_json::Map::default()); } let object = user_settings.as_object_mut().unwrap(); let interpreter_path = toolchain.path.to_string(); + if let Some(venv_dir) = env.prefix { + // Set venvPath and venv at the root level + // This matches the format of a pyrightconfig.json file + if let Some(parent) = venv_dir.parent() { + // Use relative path if the venv is inside the workspace + let venv_path = if parent == adapter.worktree_root_path() { + ".".to_string() + } else { + parent.to_string_lossy().into_owned() + }; + object.insert("venvPath".to_string(), Value::String(venv_path)); + } - // Detect if this is a virtual environment - if let Some(interpreter_dir) = Path::new(&interpreter_path).parent() { - if let Some(venv_dir) = interpreter_dir.parent() { - // Check if this looks like a virtual environment - if venv_dir.join("pyvenv.cfg").exists() - || venv_dir.join("bin/activate").exists() - || venv_dir.join("Scripts/activate.bat").exists() - { - // Set venvPath and venv at the root level - // This 
matches the format of a pyrightconfig.json file - if let Some(parent) = venv_dir.parent() { - // Use relative path if the venv is inside the workspace - let venv_path = if parent == adapter.worktree_root_path() { - ".".to_string() - } else { - parent.to_string_lossy().into_owned() - }; - object.insert("venvPath".to_string(), Value::String(venv_path)); - } - - if let Some(venv_name) = venv_dir.file_name() { - object.insert( - "venv".to_owned(), - Value::String(venv_name.to_string_lossy().into_owned()), - ); - } - } + if let Some(venv_name) = venv_dir.file_name() { + object.insert( + "venv".to_owned(), + Value::String(venv_name.to_string_lossy().into_owned()), + ); } } - // Always set the python interpreter path - // Get or create the python section - let python = object + // Set both pythonPath and defaultInterpreterPath for compatibility + if let Some(python) = object .entry("python") .or_insert(Value::Object(serde_json::Map::default())) .as_object_mut() - .unwrap(); - - // Set both pythonPath and defaultInterpreterPath for compatibility - python.insert( - "pythonPath".to_owned(), - Value::String(interpreter_path.clone()), - ); - python.insert( - "defaultInterpreterPath".to_owned(), - Value::String(interpreter_path), - ); + { + python.insert( + "pythonPath".to_owned(), + Value::String(interpreter_path.clone()), + ); + python.insert( + "defaultInterpreterPath".to_owned(), + Value::String(interpreter_path), + ); + } + // Basedpyright by default uses `strict` type checking, we tone it down as to not surpris users + maybe!({ + let basedpyright = object + .entry("basedpyright") + .or_insert(Value::Object(serde_json::Map::default())); + let analysis = basedpyright + .as_object_mut()? 
+ .entry("analysis") + .or_insert(Value::Object(serde_json::Map::default())); + if let serde_json::map::Entry::Vacant(v) = + analysis.as_object_mut()?.entry("typeCheckingMode") + { + v.insert(Value::String("standard".to_owned())); + } + Some(()) + }); } user_settings }) } - - fn manifest_name(&self) -> Option { - Some(SharedString::new_static("pyproject.toml").into()) - } - - fn workspace_folders_content(&self) -> WorkspaceFoldersContent { - WorkspaceFoldersContent::WorktreeRoot - } } #[cfg(test)] diff --git a/crates/languages/src/rust.rs b/crates/languages/src/rust.rs index 3baaec18421f10cfd83aff44c348e7635e295acf..3d5ff1cd06149594e54de03d454bef64483b7e47 100644 --- a/crates/languages/src/rust.rs +++ b/crates/languages/src/rust.rs @@ -106,17 +106,13 @@ impl ManifestProvider for CargoManifestProvider { #[async_trait(?Send)] impl LspAdapter for RustLspAdapter { fn name(&self) -> LanguageServerName { - SERVER_NAME.clone() - } - - fn manifest_name(&self) -> Option { - Some(SharedString::new_static("Cargo.toml").into()) + SERVER_NAME } async fn check_if_user_installed( &self, delegate: &dyn LspAdapterDelegate, - _: Arc, + _: Option, _: &AsyncApp, ) -> Option { let path = delegate.which("rust-analyzer".as_ref()).await?; @@ -151,11 +147,16 @@ impl LspAdapter for RustLspAdapter { async fn fetch_latest_server_version( &self, delegate: &dyn LspAdapterDelegate, + cx: &AsyncApp, ) -> Result> { let release = latest_github_release( "rust-lang/rust-analyzer", true, - false, + ProjectSettings::try_read_global(cx, |s| { + s.lsp.get(&SERVER_NAME)?.fetch.as_ref()?.pre_release + }) + .flatten() + .unwrap_or(false), delegate.http_client(), ) .await?; @@ -407,7 +408,7 @@ impl LspAdapter for RustLspAdapter { } else if completion .detail .as_ref() - .map_or(false, |detail| detail.starts_with("macro_rules! ")) + .is_some_and(|detail| detail.starts_with("macro_rules! 
")) { let text = completion.label.clone(); let len = text.len(); @@ -500,7 +501,7 @@ impl LspAdapter for RustLspAdapter { let enable_lsp_tasks = ProjectSettings::get_global(cx) .lsp .get(&SERVER_NAME) - .map_or(false, |s| s.enable_lsp_tasks); + .is_some_and(|s| s.enable_lsp_tasks); if enable_lsp_tasks { let experimental = json!({ "runnables": { @@ -514,20 +515,6 @@ impl LspAdapter for RustLspAdapter { } } - let cargo_diagnostics_fetched_separately = ProjectSettings::get_global(cx) - .diagnostics - .fetch_cargo_diagnostics(); - if cargo_diagnostics_fetched_separately { - let disable_check_on_save = json!({ - "checkOnSave": false, - }); - if let Some(initialization_options) = &mut original.initialization_options { - merge_json_value_into(disable_check_on_save, initialization_options); - } else { - original.initialization_options = Some(disable_check_on_save); - } - } - Ok(original) } } @@ -585,7 +572,7 @@ impl ContextProvider for RustContextProvider { if let (Some(path), Some(stem)) = (&local_abs_path, task_variables.get(&VariableName::Stem)) { - let fragment = test_fragment(&variables, &path, stem); + let fragment = test_fragment(&variables, path, stem); variables.insert(RUST_TEST_FRAGMENT_TASK_VARIABLE, fragment); }; if let Some(test_name) = @@ -602,16 +589,14 @@ impl ContextProvider for RustContextProvider { if let Some(path) = local_abs_path .as_deref() .and_then(|local_abs_path| local_abs_path.parent()) - { - if let Some(package_name) = + && let Some(package_name) = human_readable_package_name(path, project_env.as_ref()).await - { - variables.insert(RUST_PACKAGE_TASK_VARIABLE.clone(), package_name); - } + { + variables.insert(RUST_PACKAGE_TASK_VARIABLE.clone(), package_name); } if let Some(path) = local_abs_path.as_ref() && let Some((target, manifest_path)) = - target_info_from_abs_path(&path, project_env.as_ref()).await + target_info_from_abs_path(path, project_env.as_ref()).await { if let Some(target) = target { variables.extend(TaskVariables::from_iter([ @@ 
-665,7 +650,7 @@ impl ContextProvider for RustContextProvider { .variables .get(CUSTOM_TARGET_DIR) .cloned(); - let run_task_args = if let Some(package_to_run) = package_to_run.clone() { + let run_task_args = if let Some(package_to_run) = package_to_run { vec!["run".into(), "-p".into(), package_to_run] } else { vec!["run".into()] @@ -1025,8 +1010,8 @@ async fn get_cached_server_binary(container_dir: PathBuf) -> Option path.clone(), // Tar and gzip extract in place. - AssetKind::Zip => path.clone().join("rust-analyzer.exe"), // zip contains a .exe + AssetKind::TarGz | AssetKind::Gz => path, // Tar and gzip extract in place. + AssetKind::Zip => path.join("rust-analyzer.exe"), // zip contains a .exe }; anyhow::Ok(LanguageServerBinary { @@ -1078,7 +1063,7 @@ mod tests { #[gpui::test] async fn test_process_rust_diagnostics() { let mut params = lsp::PublishDiagnosticsParams { - uri: lsp::Url::from_file_path(path!("/a")).unwrap(), + uri: lsp::Uri::from_file_path(path!("/a")).unwrap(), version: None, diagnostics: vec![ // no newlines @@ -1574,7 +1559,7 @@ mod tests { let found = test_fragment( &TaskVariables::from_iter(variables.into_iter().map(|(k, v)| (k, v.to_owned()))), path, - &path.file_stem().unwrap().to_str().unwrap(), + path.file_stem().unwrap().to_str().unwrap(), ); assert_eq!(expected, found); } diff --git a/crates/languages/src/rust/highlights.scm b/crates/languages/src/rust/highlights.scm index 1c46061827cd504df669aadacd0a489172d1ce5a..b0daac71a097b922aa810aadef64a18e95b5b649 100644 --- a/crates/languages/src/rust/highlights.scm +++ b/crates/languages/src/rust/highlights.scm @@ -5,6 +5,7 @@ (primitive_type) @type.builtin (self) @variable.special (field_identifier) @property +(shorthand_field_identifier) @property (trait_item name: (type_identifier) @type.interface) (impl_item trait: (type_identifier) @type.interface) @@ -195,12 +196,13 @@ operator: "/" @operator (attribute_item (attribute [ (identifier) @attribute (scoped_identifier name: (identifier) 
@attribute) + (token_tree (identifier) @attribute (#match? @attribute "^[a-z\\d_]*$")) + (token_tree (identifier) @none "::" (#match? @none "^[a-z\\d_]*$")) ])) + (inner_attribute_item (attribute [ (identifier) @attribute (scoped_identifier name: (identifier) @attribute) + (token_tree (identifier) @attribute (#match? @attribute "^[a-z\\d_]*$")) + (token_tree (identifier) @none "::" (#match? @none "^[a-z\\d_]*$")) ])) -; Match nested snake case identifiers in attribute items. -(token_tree (identifier) @attribute (#match? @attribute "^[a-z\\d_]*$")) -; Override the attribute match for paths in scoped type/enum identifiers. -(token_tree (identifier) @variable "::" (identifier) @type (#match? @type "^[A-Z]")) diff --git a/crates/languages/src/tailwind.rs b/crates/languages/src/tailwind.rs index 6f03eeda8d2414578dcda939eb89fb1b5f812768..af7653ea9e3059a40e0023e5d5d2ccd3c8b02556 100644 --- a/crates/languages/src/tailwind.rs +++ b/crates/languages/src/tailwind.rs @@ -3,9 +3,9 @@ use async_trait::async_trait; use collections::HashMap; use futures::StreamExt; use gpui::AsyncApp; -use language::{LanguageName, LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; +use language::{LanguageName, LspAdapter, LspAdapterDelegate, Toolchain}; use lsp::{LanguageServerBinary, LanguageServerName}; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use project::{Fs, lsp_store::language_server_settings}; use serde_json::{Value, json}; use smol::fs; @@ -44,13 +44,13 @@ impl TailwindLspAdapter { #[async_trait(?Send)] impl LspAdapter for TailwindLspAdapter { fn name(&self) -> LanguageServerName { - Self::SERVER_NAME.clone() + Self::SERVER_NAME } async fn check_if_user_installed( &self, delegate: &dyn LspAdapterDelegate, - _: Arc, + _: Option, _: &AsyncApp, ) -> Option { let path = delegate.which(Self::SERVER_NAME.as_ref()).await?; @@ -66,6 +66,7 @@ impl LspAdapter for TailwindLspAdapter { async fn fetch_latest_server_version( &self, _: &dyn 
LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { Ok(Box::new( self.node @@ -111,9 +112,8 @@ impl LspAdapter for TailwindLspAdapter { .should_install_npm_package( Self::PACKAGE_NAME, &server_path, - &container_dir, - &version, - Default::default(), + container_dir, + VersionStrategy::Latest(version), ) .await; @@ -156,7 +156,7 @@ impl LspAdapter for TailwindLspAdapter { self: Arc, _: &dyn Fs, delegate: &Arc, - _: Arc, + _: Option, cx: &mut AsyncApp, ) -> Result { let mut tailwind_user_settings = cx.update(|cx| { @@ -185,6 +185,7 @@ impl LspAdapter for TailwindLspAdapter { (LanguageName::new("Elixir"), "phoenix-heex".to_string()), (LanguageName::new("HEEX"), "phoenix-heex".to_string()), (LanguageName::new("ERB"), "erb".to_string()), + (LanguageName::new("HTML+ERB"), "erb".to_string()), (LanguageName::new("HTML/ERB"), "erb".to_string()), (LanguageName::new("PHP"), "php".to_string()), (LanguageName::new("Vue.js"), "vue".to_string()), diff --git a/crates/languages/src/tsx/config.toml b/crates/languages/src/tsx/config.toml index 5849b9842fd7f3483f89bbedbdb7b74b3fc1572d..b5ef5bd56df2097bc920f02b87d07e4118d7b0d1 100644 --- a/crates/languages/src/tsx/config.toml +++ b/crates/languages/src/tsx/config.toml @@ -4,6 +4,7 @@ path_suffixes = ["tsx"] line_comments = ["// "] block_comment = { start = "/*", prefix = "* ", end = "*/", tab_size = 1 } documentation_comment = { start = "/**", prefix = "* ", end = "*/", tab_size = 1 } +wrap_characters = { start_prefix = "<", start_suffix = ">", end_prefix = "" } autoclose_before = ";:.,=}])>" brackets = [ { start = "{", end = "}", close = true, newline = true }, diff --git a/crates/languages/src/tsx/highlights.scm b/crates/languages/src/tsx/highlights.scm index 5e2fbbf63ac9bce667599955c90bb5416dc29ec5..f7cb987831578f1d3e78decbf89f71c91d3a3b7e 100644 --- a/crates/languages/src/tsx/highlights.scm +++ b/crates/languages/src/tsx/highlights.scm @@ -237,6 +237,7 @@ "implements" "interface" "keyof" + "module" "namespace" "private" "protected" 
@@ -256,4 +257,4 @@ (jsx_closing_element ([""]) @punctuation.bracket.jsx) (jsx_self_closing_element (["<" "/>"]) @punctuation.bracket.jsx) (jsx_attribute "=" @punctuation.delimiter.jsx) -(jsx_text) @text.jsx \ No newline at end of file +(jsx_text) @text.jsx diff --git a/crates/languages/src/tsx/injections.scm b/crates/languages/src/tsx/injections.scm index 48da80995bba86765e3dc78748eea6b4d5811bed..f749aac43a713dadc6abe81a0523f241610b2675 100644 --- a/crates/languages/src/tsx/injections.scm +++ b/crates/languages/src/tsx/injections.scm @@ -11,6 +11,21 @@ (#set! injection.language "css")) ) +(call_expression + function: (member_expression + object: (identifier) @_obj (#eq? @_obj "styled") + property: (property_identifier)) + arguments: (template_string (string_fragment) @injection.content + (#set! injection.language "css")) +) + +(call_expression + function: (call_expression + function: (identifier) @_name (#eq? @_name "styled")) + arguments: (template_string (string_fragment) @injection.content + (#set! injection.language "css")) +) + (call_expression function: (identifier) @_name (#eq? @_name "html") arguments: (template_string (string_fragment) @injection.content @@ -58,3 +73,9 @@ arguments: (arguments (template_string (string_fragment) @injection.content (#set! injection.language "graphql"))) ) + +(call_expression + function: (identifier) @_name(#match? @_name "^iso$") + arguments: (arguments (template_string (string_fragment) @injection.content + (#set! 
injection.language "isograph"))) +) diff --git a/crates/languages/src/typescript.rs b/crates/languages/src/typescript.rs index a8ba880889b36071bdb0c474c4790f4c37ad165c..d8f18f5f045f67e658a5458ee5db443f5c85d92e 100644 --- a/crates/languages/src/typescript.rs +++ b/crates/languages/src/typescript.rs @@ -7,10 +7,10 @@ use gpui::{App, AppContext, AsyncApp, Task}; use http_client::github::{AssetKind, GitHubLspBinaryVersion, build_asset_url}; use language::{ ContextLocation, ContextProvider, File, LanguageName, LanguageToolchainStore, LspAdapter, - LspAdapterDelegate, + LspAdapterDelegate, Toolchain, }; use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName}; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use project::{Fs, lsp_store::language_server_settings}; use serde_json::{Value, json}; use smol::{fs, lock::RwLock, stream::StreamExt}; @@ -341,10 +341,10 @@ async fn detect_package_manager( fs: Arc, package_json_data: Option, ) -> &'static str { - if let Some(package_json_data) = package_json_data { - if let Some(package_manager) = package_json_data.package_manager { - return package_manager; - } + if let Some(package_json_data) = package_json_data + && let Some(package_manager) = package_json_data.package_manager + { + return package_manager; } if fs.is_file(&worktree_root.join("pnpm-lock.yaml")).await { return "pnpm"; @@ -557,12 +557,13 @@ struct TypeScriptVersions { #[async_trait(?Send)] impl LspAdapter for TypeScriptLspAdapter { fn name(&self) -> LanguageServerName { - Self::SERVER_NAME.clone() + Self::SERVER_NAME } async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { Ok(Box::new(TypeScriptVersions { typescript_version: self.node.npm_package_latest_version("typescript").await?, @@ -587,9 +588,8 @@ impl LspAdapter for TypeScriptLspAdapter { .should_install_npm_package( Self::PACKAGE_NAME, &server_path, - &container_dir, - version.typescript_version.as_str(), - 
Default::default(), + container_dir, + VersionStrategy::Latest(version.typescript_version.as_str()), ) .await; @@ -723,7 +723,7 @@ impl LspAdapter for TypeScriptLspAdapter { self: Arc, _: &dyn Fs, delegate: &Arc, - _: Arc, + _: Option, cx: &mut AsyncApp, ) -> Result { let override_options = cx.update(|cx| { @@ -823,7 +823,7 @@ impl LspAdapter for EsLintLspAdapter { self: Arc, _: &dyn Fs, delegate: &Arc, - _: Arc, + _: Option, cx: &mut AsyncApp, ) -> Result { let workspace_root = delegate.worktree_root_path(); @@ -880,12 +880,13 @@ impl LspAdapter for EsLintLspAdapter { } fn name(&self) -> LanguageServerName { - Self::SERVER_NAME.clone() + Self::SERVER_NAME } async fn fetch_latest_server_version( &self, _delegate: &dyn LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { let url = build_asset_url( "zed-industries/vscode-eslint", @@ -911,7 +912,7 @@ impl LspAdapter for EsLintLspAdapter { let server_path = destination_path.join(Self::SERVER_PATH); if fs::metadata(&server_path).await.is_err() { - remove_matching(&container_dir, |entry| entry != destination_path).await; + remove_matching(&container_dir, |_| true).await; download_server_binary( delegate, @@ -1029,7 +1030,7 @@ mod tests { .unindent(); let buffer = cx.new(|cx| language::Buffer::local(text, cx).with_language(language, cx)); - let outline = buffer.read_with(cx, |buffer, _| buffer.snapshot().outline(None).unwrap()); + let outline = buffer.read_with(cx, |buffer, _| buffer.snapshot().outline(None)); assert_eq!( outline .items @@ -1083,7 +1084,7 @@ mod tests { .unindent(); let buffer = cx.new(|cx| language::Buffer::local(text, cx).with_language(language, cx)); - let outline = buffer.read_with(cx, |buffer, _| buffer.snapshot().outline(None).unwrap()); + let outline = buffer.read_with(cx, |buffer, _| buffer.snapshot().outline(None)); assert_eq!( outline .items diff --git a/crates/languages/src/typescript/config.toml b/crates/languages/src/typescript/config.toml index 
d7e3e4bd3d1569f96636b7f7572deea306b46df7..2344f6209da7756049438669ee55d5376fdb47f8 100644 --- a/crates/languages/src/typescript/config.toml +++ b/crates/languages/src/typescript/config.toml @@ -5,6 +5,7 @@ first_line_pattern = '^#!.*\b(?:deno run|ts-node|bun|tsx|[/ ]node)\b' line_comments = ["// "] block_comment = { start = "/*", prefix = "* ", end = "*/", tab_size = 1 } documentation_comment = { start = "/**", prefix = "* ", end = "*/", tab_size = 1 } +wrap_characters = { start_prefix = "<", start_suffix = ">", end_prefix = "" } autoclose_before = ";:.,=}])>" brackets = [ { start = "{", end = "}", close = true, newline = true }, diff --git a/crates/languages/src/typescript/highlights.scm b/crates/languages/src/typescript/highlights.scm index af37ef6415ba501c1623977c04b7a7b7d110eeb5..84cbbae77d43c96e62578c444ee913055604e11a 100644 --- a/crates/languages/src/typescript/highlights.scm +++ b/crates/languages/src/typescript/highlights.scm @@ -248,6 +248,7 @@ "is" "keyof" "let" + "module" "namespace" "new" "of" @@ -272,4 +273,4 @@ "while" "with" "yield" -] @keyword \ No newline at end of file +] @keyword diff --git a/crates/languages/src/typescript/injections.scm b/crates/languages/src/typescript/injections.scm index 7affdc5b758deb5ff717476f0de934a1786469aa..331f42fa913ff8ce79bde5c50599e679ef780962 100644 --- a/crates/languages/src/typescript/injections.scm +++ b/crates/languages/src/typescript/injections.scm @@ -15,6 +15,21 @@ (#set! injection.language "css")) ) +(call_expression + function: (member_expression + object: (identifier) @_obj (#eq? @_obj "styled") + property: (property_identifier)) + arguments: (template_string (string_fragment) @injection.content + (#set! injection.language "css")) +) + +(call_expression + function: (call_expression + function: (identifier) @_name (#eq? @_name "styled")) + arguments: (template_string (string_fragment) @injection.content + (#set! injection.language "css")) +) + (call_expression function: (identifier) @_name (#eq? 
@_name "html") arguments: (template_string) @injection.content @@ -63,6 +78,12 @@ (#set! injection.language "graphql"))) ) +(call_expression + function: (identifier) @_name(#match? @_name "^iso$") + arguments: (arguments (template_string (string_fragment) @injection.content + (#set! injection.language "isograph"))) +) + ;; Angular Component template injection (call_expression function: [ diff --git a/crates/languages/src/vtsls.rs b/crates/languages/src/vtsls.rs index 73498fc5795b936bd45b973a8dbc87d3a3f1f5fb..1cf3c8aa52db7bbfc1e784afecc62972884a3d47 100644 --- a/crates/languages/src/vtsls.rs +++ b/crates/languages/src/vtsls.rs @@ -2,9 +2,9 @@ use anyhow::Result; use async_trait::async_trait; use collections::HashMap; use gpui::AsyncApp; -use language::{LanguageName, LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; +use language::{LanguageName, LspAdapter, LspAdapterDelegate, Toolchain}; use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName}; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use project::{Fs, lsp_store::language_server_settings}; use serde_json::Value; use std::{ @@ -67,12 +67,13 @@ const SERVER_NAME: LanguageServerName = LanguageServerName::new_static("vtsls"); #[async_trait(?Send)] impl LspAdapter for VtslsLspAdapter { fn name(&self) -> LanguageServerName { - SERVER_NAME.clone() + SERVER_NAME } async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { Ok(Box::new(TypeScriptVersions { typescript_version: self.node.npm_package_latest_version("typescript").await?, @@ -86,7 +87,7 @@ impl LspAdapter for VtslsLspAdapter { async fn check_if_user_installed( &self, delegate: &dyn LspAdapterDelegate, - _: Arc, + _: Option, _: &AsyncApp, ) -> Option { let env = delegate.shell_env().await; @@ -115,8 +116,7 @@ impl LspAdapter for VtslsLspAdapter { Self::PACKAGE_NAME, &server_path, &container_dir, - &latest_version.server_version, - Default::default(), + 
VersionStrategy::Latest(&latest_version.server_version), ) .await { @@ -129,8 +129,7 @@ impl LspAdapter for VtslsLspAdapter { Self::TYPESCRIPT_PACKAGE_NAME, &container_dir.join(Self::TYPESCRIPT_TSDK_PATH), &container_dir, - &latest_version.typescript_version, - Default::default(), + VersionStrategy::Latest(&latest_version.typescript_version), ) .await { @@ -213,7 +212,7 @@ impl LspAdapter for VtslsLspAdapter { self: Arc, fs: &dyn Fs, delegate: &Arc, - _: Arc, + _: Option, cx: &mut AsyncApp, ) -> Result { let tsdk_path = Self::tsdk_path(fs, delegate).await; diff --git a/crates/languages/src/yaml.rs b/crates/languages/src/yaml.rs index 28be2cc1a45130a723084529e0c6164ab2a042c2..bf634aafbab9fb3312f63fa818a55ddae90a05f3 100644 --- a/crates/languages/src/yaml.rs +++ b/crates/languages/src/yaml.rs @@ -2,11 +2,9 @@ use anyhow::{Context as _, Result}; use async_trait::async_trait; use futures::StreamExt; use gpui::AsyncApp; -use language::{ - LanguageToolchainStore, LspAdapter, LspAdapterDelegate, language_settings::AllLanguageSettings, -}; +use language::{LspAdapter, LspAdapterDelegate, Toolchain, language_settings::AllLanguageSettings}; use lsp::{LanguageServerBinary, LanguageServerName}; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use project::{Fs, lsp_store::language_server_settings}; use serde_json::Value; use settings::{Settings, SettingsLocation}; @@ -40,12 +38,13 @@ impl YamlLspAdapter { #[async_trait(?Send)] impl LspAdapter for YamlLspAdapter { fn name(&self) -> LanguageServerName { - Self::SERVER_NAME.clone() + Self::SERVER_NAME } async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { Ok(Box::new( self.node @@ -57,7 +56,7 @@ impl LspAdapter for YamlLspAdapter { async fn check_if_user_installed( &self, delegate: &dyn LspAdapterDelegate, - _: Arc, + _: Option, _: &AsyncApp, ) -> Option { let path = delegate.which(Self::SERVER_NAME.as_ref()).await?; @@ -107,9 +106,8 @@ impl 
LspAdapter for YamlLspAdapter { .should_install_npm_package( Self::PACKAGE_NAME, &server_path, - &container_dir, - &version, - Default::default(), + container_dir, + VersionStrategy::Latest(version), ) .await; @@ -136,7 +134,7 @@ impl LspAdapter for YamlLspAdapter { self: Arc, _: &dyn Fs, delegate: &Arc, - _: Arc, + _: Option, cx: &mut AsyncApp, ) -> Result { let location = SettingsLocation { diff --git a/crates/languages/src/yaml/overrides.scm b/crates/languages/src/yaml/overrides.scm new file mode 100644 index 0000000000000000000000000000000000000000..9503051a62080eb2fdfca3416ef9e5286464dd17 --- /dev/null +++ b/crates/languages/src/yaml/overrides.scm @@ -0,0 +1,5 @@ +(comment) @comment.inclusive +[ + (single_quote_scalar) + (double_quote_scalar) +] @string diff --git a/crates/line_ending_selector/Cargo.toml b/crates/line_ending_selector/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..7c5c8f6d8f3996771f832c28d5d71b857bb0b3b6 --- /dev/null +++ b/crates/line_ending_selector/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "line_ending_selector" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/line_ending_selector.rs" +doctest = false + +[dependencies] +editor.workspace = true +gpui.workspace = true +language.workspace = true +picker.workspace = true +project.workspace = true +ui.workspace = true +util.workspace = true +workspace.workspace = true +workspace-hack.workspace = true diff --git a/crates/line_ending_selector/LICENSE-GPL b/crates/line_ending_selector/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/line_ending_selector/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/line_ending_selector/src/line_ending_selector.rs b/crates/line_ending_selector/src/line_ending_selector.rs new file mode 
100644 index 0000000000000000000000000000000000000000..7f75a1ebe3550595c8fa78643ef5446ab2fa3a44 --- /dev/null +++ b/crates/line_ending_selector/src/line_ending_selector.rs @@ -0,0 +1,192 @@ +use editor::Editor; +use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Task, WeakEntity, actions}; +use language::{Buffer, LineEnding}; +use picker::{Picker, PickerDelegate}; +use project::Project; +use std::sync::Arc; +use ui::{ListItem, ListItemSpacing, prelude::*}; +use util::ResultExt; +use workspace::ModalView; + +actions!( + line_ending, + [ + /// Toggles the line ending selector modal. + Toggle + ] +); + +pub fn init(cx: &mut App) { + cx.observe_new(LineEndingSelector::register).detach(); +} + +pub struct LineEndingSelector { + picker: Entity>, +} + +impl LineEndingSelector { + fn register(editor: &mut Editor, _window: Option<&mut Window>, cx: &mut Context) { + let editor_handle = cx.weak_entity(); + editor + .register_action(move |_: &Toggle, window, cx| { + Self::toggle(&editor_handle, window, cx); + }) + .detach(); + } + + fn toggle(editor: &WeakEntity, window: &mut Window, cx: &mut App) { + let Some((workspace, buffer)) = editor + .update(cx, |editor, cx| { + Some((editor.workspace()?, editor.active_excerpt(cx)?.1)) + }) + .ok() + .flatten() + else { + return; + }; + + workspace.update(cx, |workspace, cx| { + let project = workspace.project().clone(); + workspace.toggle_modal(window, cx, move |window, cx| { + LineEndingSelector::new(buffer, project, window, cx) + }); + }) + } + + fn new( + buffer: Entity, + project: Entity, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let line_ending = buffer.read(cx).line_ending(); + let delegate = + LineEndingSelectorDelegate::new(cx.entity().downgrade(), buffer, project, line_ending); + let picker = cx.new(|cx| Picker::nonsearchable_uniform_list(delegate, window, cx)); + Self { picker } + } +} + +impl Render for LineEndingSelector { + fn render(&mut self, _window: &mut Window, _cx: &mut Context) 
-> impl IntoElement { + v_flex().w(rems(34.)).child(self.picker.clone()) + } +} + +impl Focusable for LineEndingSelector { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.picker.focus_handle(cx) + } +} + +impl EventEmitter for LineEndingSelector {} +impl ModalView for LineEndingSelector {} + +struct LineEndingSelectorDelegate { + line_ending_selector: WeakEntity, + buffer: Entity, + project: Entity, + line_ending: LineEnding, + matches: Vec, + selected_index: usize, +} + +impl LineEndingSelectorDelegate { + fn new( + line_ending_selector: WeakEntity, + buffer: Entity, + project: Entity, + line_ending: LineEnding, + ) -> Self { + Self { + line_ending_selector, + buffer, + project, + line_ending, + matches: vec![LineEnding::Unix, LineEnding::Windows], + selected_index: 0, + } + } +} + +impl PickerDelegate for LineEndingSelectorDelegate { + type ListItem = ListItem; + + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { + "Select a line ending…".into() + } + + fn match_count(&self) -> usize { + self.matches.len() + } + + fn confirm(&mut self, _: bool, window: &mut Window, cx: &mut Context>) { + if let Some(line_ending) = self.matches.get(self.selected_index) { + self.buffer.update(cx, |this, cx| { + this.set_line_ending(*line_ending, cx); + }); + let buffer = self.buffer.clone(); + let project = self.project.clone(); + cx.defer(move |cx| { + project.update(cx, |this, cx| { + this.save_buffer(buffer, cx).detach(); + }); + }); + } + self.dismissed(window, cx); + } + + fn dismissed(&mut self, _: &mut Window, cx: &mut Context>) { + self.line_ending_selector + .update(cx, |_, cx| cx.emit(DismissEvent)) + .log_err(); + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index( + &mut self, + ix: usize, + _window: &mut Window, + _: &mut Context>, + ) { + self.selected_index = ix; + } + + fn update_matches( + &mut self, + _query: String, + _window: &mut Window, + _cx: &mut Context>, + ) -> gpui::Task<()> { + 
return Task::ready(()); + } + + fn render_match( + &self, + ix: usize, + selected: bool, + _: &mut Window, + _: &mut Context>, + ) -> Option { + let line_ending = self.matches.get(ix)?; + let label = match line_ending { + LineEnding::Unix => "LF", + LineEnding::Windows => "CRLF", + }; + + let mut list_item = ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .child(Label::new(label)); + + if &self.line_ending == line_ending { + list_item = list_item.end_slot(Icon::new(IconName::Check).color(Color::Muted)); + } + + Some(list_item) + } +} diff --git a/crates/livekit_client/Cargo.toml b/crates/livekit_client/Cargo.toml index 821fd5d39006b517d264687d7fb9a25fb570d0c2..80e4960c0df31f6a3d8115bd4bd66c0de09b76f0 100644 --- a/crates/livekit_client/Cargo.toml +++ b/crates/livekit_client/Cargo.toml @@ -22,6 +22,7 @@ test-support = ["collections/test-support", "gpui/test-support"] [dependencies] anyhow.workspace = true async-trait.workspace = true +audio.workspace = true collections.workspace = true cpal.workspace = true futures.workspace = true @@ -34,6 +35,10 @@ log.workspace = true nanoid.workspace = true parking_lot.workspace = true postage.workspace = true +rodio = { workspace = true, features = ["wav_output", "recording"] } +serde.workspace = true +serde_urlencoded.workspace = true +settings.workspace = true smallvec.workspace = true tokio-tungstenite.workspace = true util.workspace = true diff --git a/crates/livekit_client/examples/test_app.rs b/crates/livekit_client/examples/test_app.rs index e1d01df534e142502abb5f17392e19299f8ae158..c99abb292ef6d99e8adc3ab9007f4c49eeb05be2 100644 --- a/crates/livekit_client/examples/test_app.rs +++ b/crates/livekit_client/examples/test_app.rs @@ -159,14 +159,14 @@ impl LivekitWindow { if output .audio_output_stream .as_ref() - .map_or(false, |(track, _)| track.sid() == unpublish_sid) + .is_some_and(|(track, _)| track.sid() == unpublish_sid) { output.audio_output_stream.take(); } if output 
.screen_share_output_view .as_ref() - .map_or(false, |(track, _)| track.sid() == unpublish_sid) + .is_some_and(|(track, _)| track.sid() == unpublish_sid) { output.screen_share_output_view.take(); } @@ -183,7 +183,7 @@ impl LivekitWindow { match track { livekit_client::RemoteTrack::Audio(track) => { output.audio_output_stream = Some(( - publication.clone(), + publication, room.play_remote_audio_track(&track, cx).unwrap(), )); } @@ -255,7 +255,10 @@ impl LivekitWindow { } else { let room = self.room.clone(); cx.spawn_in(window, async move |this, cx| { - let (publication, stream) = room.publish_local_microphone_track(cx).await.unwrap(); + let (publication, stream) = room + .publish_local_microphone_track("test_user".to_string(), false, cx) + .await + .unwrap(); this.update(cx, |this, cx| { this.microphone_track = Some(publication); this.microphone_stream = Some(stream); diff --git a/crates/livekit_client/src/lib.rs b/crates/livekit_client/src/lib.rs index 149859fdc8ecd8533332c9462a090adb5496f100..055aa3704e06f25a21c69294343539289d8acb49 100644 --- a/crates/livekit_client/src/lib.rs +++ b/crates/livekit_client/src/lib.rs @@ -1,7 +1,13 @@ +use anyhow::Context as _; use collections::HashMap; mod remote_video_track_view; +use cpal::traits::HostTrait as _; pub use remote_video_track_view::{RemoteVideoTrackView, RemoteVideoTrackViewEvent}; +use rodio::DeviceTrait as _; + +mod record; +pub use record::CaptureInput; #[cfg(not(any( test, @@ -18,6 +24,11 @@ mod livekit_client; )))] pub use livekit_client::*; +// If you need proper LSP in livekit_client you've got to comment +// - the cfg blocks above +// - the mods: mock_client & test and their conditional blocks +// - the pub use mock_client::* and their conditional blocks + #[cfg(any( test, feature = "test-support", @@ -168,3 +179,59 @@ pub enum RoomEvent { Reconnecting, Reconnected, } + +pub(crate) fn default_device( + input: bool, +) -> anyhow::Result<(cpal::Device, cpal::SupportedStreamConfig)> { + let device; + let 
config; + if input { + device = cpal::default_host() + .default_input_device() + .context("no audio input device available")?; + config = device + .default_input_config() + .context("failed to get default input config")?; + } else { + device = cpal::default_host() + .default_output_device() + .context("no audio output device available")?; + config = device + .default_output_config() + .context("failed to get default output config")?; + } + Ok((device, config)) +} + +pub(crate) fn get_sample_data( + sample_format: cpal::SampleFormat, + data: &cpal::Data, +) -> anyhow::Result> { + match sample_format { + cpal::SampleFormat::I8 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::I16 => Ok(data.as_slice::().unwrap().to_vec()), + cpal::SampleFormat::I24 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::I32 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::I64 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::U8 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::U16 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::U32 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::U64 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::F32 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::F64 => Ok(convert_sample_data::(data)), + _ => anyhow::bail!("Unsupported sample format"), + } +} + +pub(crate) fn convert_sample_data< + TSource: cpal::SizedSample, + TDest: cpal::SizedSample + cpal::FromSample, +>( + data: &cpal::Data, +) -> Vec { + data.as_slice::() + .unwrap() + .iter() + .map(|e| e.to_sample::()) + .collect() +} diff --git a/crates/livekit_client/src/livekit_client.rs b/crates/livekit_client/src/livekit_client.rs index 8f0ac1a456aea1ac32879e121961e87930035dba..45e929cb2ec0bebf054497632d614af1975f6397 100644 --- a/crates/livekit_client/src/livekit_client.rs +++ b/crates/livekit_client/src/livekit_client.rs @@ -1,11 +1,14 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; +use audio::AudioSettings; use 
collections::HashMap; use futures::{SinkExt, channel::mpsc}; use gpui::{App, AsyncApp, ScreenCaptureSource, ScreenCaptureStream, Task}; use gpui_tokio::Tokio; +use log::info; use playback::capture_local_video_track; +use settings::Settings; mod playback; @@ -94,9 +97,13 @@ impl Room { pub async fn publish_local_microphone_track( &self, + user_name: String, + is_staff: bool, cx: &mut AsyncApp, ) -> Result<(LocalTrackPublication, playback::AudioStream)> { - let (track, stream) = self.playback.capture_local_microphone_track()?; + let (track, stream) = self + .playback + .capture_local_microphone_track(user_name, is_staff, &cx)?; let publication = self .local_participant() .publish_track( @@ -123,9 +130,14 @@ impl Room { pub fn play_remote_audio_track( &self, track: &RemoteAudioTrack, - _cx: &App, + cx: &mut App, ) -> Result { - Ok(self.playback.play_remote_audio_track(&track.0)) + if AudioSettings::get_global(cx).rodio_audio { + info!("Using experimental.rodio_audio audio pipeline for output"); + playback::play_remote_audio_track(&track.0, cx) + } else { + Ok(self.playback.play_remote_audio_track(&track.0)) + } } } diff --git a/crates/livekit_client/src/livekit_client/playback.rs b/crates/livekit_client/src/livekit_client/playback.rs index f14e156125f6da815fe24aabd798e53c6c3e82b8..d1b2cee4aa1750ba4b8af3033e44b1fe9fbe78de 100644 --- a/crates/livekit_client/src/livekit_client/playback.rs +++ b/crates/livekit_client/src/livekit_client/playback.rs @@ -1,11 +1,12 @@ use anyhow::{Context as _, Result}; -use cpal::traits::{DeviceTrait, HostTrait, StreamTrait as _}; -use cpal::{Data, FromSample, I24, SampleFormat, SizedSample}; +use audio::{AudioSettings, CHANNEL_COUNT, SAMPLE_RATE}; +use cpal::traits::{DeviceTrait, StreamTrait as _}; use futures::channel::mpsc::UnboundedSender; use futures::{Stream, StreamExt as _}; use gpui::{ - BackgroundExecutor, ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream, Task, + AsyncApp, BackgroundExecutor, ScreenCaptureFrame, 
ScreenCaptureSource, ScreenCaptureStream, + Task, }; use libwebrtc::native::{apm, audio_mixer, audio_resampler}; use livekit::track; @@ -18,14 +19,20 @@ use livekit::webrtc::{ video_source::{RtcVideoSource, VideoResolution, native::NativeVideoSource}, video_stream::native::NativeVideoStream, }; +use log::info; use parking_lot::Mutex; +use rodio::Source; +use serde::{Deserialize, Serialize}; +use settings::Settings; use std::cell::RefCell; use std::sync::Weak; -use std::sync::atomic::{self, AtomicI32}; +use std::sync::atomic::{AtomicBool, AtomicI32, Ordering}; use std::time::Duration; use std::{borrow::Cow, collections::VecDeque, sync::Arc, thread}; use util::{ResultExt as _, maybe}; +mod source; + pub(crate) struct AudioStack { executor: BackgroundExecutor, apm: Arc>, @@ -34,12 +41,36 @@ pub(crate) struct AudioStack { next_ssrc: AtomicI32, } -// NOTE: We use WebRTC's mixer which only supports -// 16kHz, 32kHz and 48kHz. As 48 is the most common "next step up" -// for audio output devices like speakers/bluetooth, we just hard-code -// this; and downsample when we need to. 
-const SAMPLE_RATE: u32 = 48000; -const NUM_CHANNELS: u32 = 2; +pub(crate) fn play_remote_audio_track( + track: &livekit::track::RemoteAudioTrack, + cx: &mut gpui::App, +) -> Result { + let stop_handle = Arc::new(AtomicBool::new(false)); + let stop_handle_clone = stop_handle.clone(); + let stream = source::LiveKitStream::new(cx.background_executor(), track); + + let stream = stream + .stoppable() + .periodic_access(Duration::from_millis(50), move |s| { + if stop_handle.load(Ordering::Relaxed) { + s.stop(); + } + }); + + let speaker: Speaker = serde_urlencoded::from_str(&track.name()).unwrap_or_else(|_| Speaker { + name: track.name(), + is_staff: false, + }); + audio::Audio::play_voip_stream(stream, speaker.name, speaker.is_staff, cx) + .context("Could not play audio")?; + + let on_drop = util::defer(move || { + stop_handle_clone.store(true, Ordering::Relaxed); + }); + Ok(AudioStream::Output { + _drop: Box::new(on_drop), + }) +} impl AudioStack { pub(crate) fn new(executor: BackgroundExecutor) -> Self { @@ -62,11 +93,11 @@ impl AudioStack { ) -> AudioStream { let output_task = self.start_output(); - let next_ssrc = self.next_ssrc.fetch_add(1, atomic::Ordering::Relaxed); + let next_ssrc = self.next_ssrc.fetch_add(1, Ordering::Relaxed); let source = AudioMixerSource { ssrc: next_ssrc, - sample_rate: SAMPLE_RATE, - num_channels: NUM_CHANNELS, + sample_rate: SAMPLE_RATE.get(), + num_channels: CHANNEL_COUNT.get() as u32, buffer: Arc::default(), }; self.mixer.lock().add_source(source.clone()); @@ -98,19 +129,45 @@ impl AudioStack { } } + fn start_output(&self) -> Arc> { + if let Some(task) = self._output_task.borrow().upgrade() { + return task; + } + let task = Arc::new(self.executor.spawn({ + let apm = self.apm.clone(); + let mixer = self.mixer.clone(); + async move { + Self::play_output(apm, mixer, SAMPLE_RATE.get(), CHANNEL_COUNT.get().into()) + .await + .log_err(); + } + })); + *self._output_task.borrow_mut() = Arc::downgrade(&task); + task + } + pub(crate) fn 
capture_local_microphone_track( &self, + user_name: String, + is_staff: bool, + cx: &AsyncApp, ) -> Result<(crate::LocalAudioTrack, AudioStream)> { let source = NativeAudioSource::new( // n.b. this struct's options are always ignored, noise cancellation is provided by apm. AudioSourceOptions::default(), - SAMPLE_RATE, - NUM_CHANNELS, + SAMPLE_RATE.get(), + CHANNEL_COUNT.get().into(), 10, ); + let track_name = serde_urlencoded::to_string(Speaker { + name: user_name, + is_staff, + }) + .context("Could not encode user information in track name")?; + let track = track::LocalAudioTrack::create_audio_track( - "microphone", + &track_name, RtcAudioSource::Native(source.clone()), ); @@ -118,44 +175,41 @@ impl AudioStack { let (frame_tx, mut frame_rx) = futures::channel::mpsc::unbounded(); let transmit_task = self.executor.spawn({ - let source = source.clone(); async move { while let Some(frame) = frame_rx.next().await { source.capture_frame(&frame).await.log_err(); } } }); - let capture_task = self.executor.spawn(async move { - Self::capture_input(apm, frame_tx, SAMPLE_RATE, NUM_CHANNELS).await - }); + let rodio_pipeline = + AudioSettings::try_read_global(cx, |setting| setting.rodio_audio).unwrap_or_default(); + let capture_task = if rodio_pipeline { + info!("Using experimental.rodio_audio audio pipeline"); + let voip_parts = audio::VoipParts::new(cx)?; + thread::spawn(move || { + // microphone is non send on mac + let microphone = audio::Audio::open_microphone(voip_parts)?; + send_to_livekit(frame_tx, microphone); + Ok::<(), anyhow::Error>(()) + }); + Task::ready(Ok(())) + } else { + self.executor.spawn(async move { + Self::capture_input(apm, frame_tx, SAMPLE_RATE.get(), CHANNEL_COUNT.get().into()) + .await + }) + }; let on_drop = util::defer(|| { drop(transmit_task); drop(capture_task); }); - return Ok(( + Ok(( super::LocalAudioTrack(track), AudioStream::Output { _drop: Box::new(on_drop), }, - )); - } - - fn start_output(&self) -> Arc> { - if let Some(task) = 
self._output_task.borrow().upgrade() { - return task; - } - let task = Arc::new(self.executor.spawn({ - let apm = self.apm.clone(); - let mixer = self.mixer.clone(); - async move { - Self::play_output(apm, mixer, SAMPLE_RATE, NUM_CHANNELS) - .await - .log_err(); - } - })); - *self._output_task.borrow_mut() = Arc::downgrade(&task); - task + )) } async fn play_output( @@ -166,7 +220,7 @@ impl AudioStack { ) -> Result<()> { loop { let mut device_change_listener = DeviceChangeListener::new(false)?; - let (output_device, output_config) = default_device(false)?; + let (output_device, output_config) = crate::default_device(false)?; let (end_on_drop_tx, end_on_drop_rx) = std::sync::mpsc::channel::<()>(); let mixer = mixer.clone(); let apm = apm.clone(); @@ -238,7 +292,7 @@ impl AudioStack { ) -> Result<()> { loop { let mut device_change_listener = DeviceChangeListener::new(true)?; - let (device, config) = default_device(true)?; + let (device, config) = crate::default_device(true)?; let (end_on_drop_tx, end_on_drop_rx) = std::sync::mpsc::channel::<()>(); let apm = apm.clone(); let frame_tx = frame_tx.clone(); @@ -262,7 +316,7 @@ impl AudioStack { config.sample_format(), move |data, _: &_| { let data = - Self::get_sample_data(config.sample_format(), data).log_err(); + crate::get_sample_data(config.sample_format(), data).log_err(); let Some(data) = data else { return; }; @@ -320,32 +374,35 @@ impl AudioStack { drop(end_on_drop_tx) } } +} - fn get_sample_data(sample_format: SampleFormat, data: &Data) -> Result> { - match sample_format { - SampleFormat::I8 => Ok(Self::convert_sample_data::(data)), - SampleFormat::I16 => Ok(data.as_slice::().unwrap().to_vec()), - SampleFormat::I24 => Ok(Self::convert_sample_data::(data)), - SampleFormat::I32 => Ok(Self::convert_sample_data::(data)), - SampleFormat::I64 => Ok(Self::convert_sample_data::(data)), - SampleFormat::U8 => Ok(Self::convert_sample_data::(data)), - SampleFormat::U16 => Ok(Self::convert_sample_data::(data)), - 
SampleFormat::U32 => Ok(Self::convert_sample_data::(data)), - SampleFormat::U64 => Ok(Self::convert_sample_data::(data)), - SampleFormat::F32 => Ok(Self::convert_sample_data::(data)), - SampleFormat::F64 => Ok(Self::convert_sample_data::(data)), - _ => anyhow::bail!("Unsupported sample format"), - } - } +#[derive(Serialize, Deserialize)] +struct Speaker { + name: String, + is_staff: bool, +} - fn convert_sample_data>( - data: &Data, - ) -> Vec { - data.as_slice::() - .unwrap() - .iter() - .map(|e| e.to_sample::()) - .collect() +fn send_to_livekit(frame_tx: UnboundedSender>, mut microphone: impl Source) { + use cpal::Sample; + loop { + let sampled: Vec<_> = microphone + .by_ref() + .take(audio::BUFFER_SIZE) + .map(|s| s.to_sample()) + .collect(); + + if frame_tx + .unbounded_send(AudioFrame { + sample_rate: SAMPLE_RATE.get(), + num_channels: CHANNEL_COUNT.get() as u32, + samples_per_channel: sampled.len() as u32 / CHANNEL_COUNT.get() as u32, + data: Cow::Owned(sampled), + }) + .is_err() + { + // must rx has dropped or is not consuming + break; + } } } @@ -393,27 +450,6 @@ pub(crate) async fn capture_local_video_track( )) } -fn default_device(input: bool) -> Result<(cpal::Device, cpal::SupportedStreamConfig)> { - let device; - let config; - if input { - device = cpal::default_host() - .default_input_device() - .context("no audio input device available")?; - config = device - .default_input_config() - .context("failed to get default input config")?; - } else { - device = cpal::default_host() - .default_output_device() - .context("no audio output device available")?; - config = device - .default_output_config() - .context("failed to get default output config")?; - } - Ok((device, config)) -} - #[derive(Clone)] struct AudioMixerSource { ssrc: i32, diff --git a/crates/livekit_client/src/livekit_client/playback/source.rs b/crates/livekit_client/src/livekit_client/playback/source.rs new file mode 100644 index 
0000000000000000000000000000000000000000..67bfe793902da94a114ca617ce5bfa33c68d02e7 --- /dev/null +++ b/crates/livekit_client/src/livekit_client/playback/source.rs @@ -0,0 +1,84 @@ +use std::num::NonZero; + +use futures::StreamExt; +use libwebrtc::{audio_stream::native::NativeAudioStream, prelude::AudioFrame}; +use livekit::track::RemoteAudioTrack; +use rodio::{Source, buffer::SamplesBuffer, conversions::SampleTypeConverter, nz}; + +use audio::{CHANNEL_COUNT, SAMPLE_RATE}; + +fn frame_to_samplesbuffer(frame: AudioFrame) -> SamplesBuffer { + let samples = frame.data.iter().copied(); + let samples = SampleTypeConverter::<_, _>::new(samples); + let samples: Vec = samples.collect(); + SamplesBuffer::new( + // here be dragons + // NonZero::new(frame.num_channels as u16).expect("audio frame channels is nonzero"), + nz!(2), + NonZero::new(frame.sample_rate).expect("audio frame sample rate is nonzero"), + samples, + ) +} + +pub struct LiveKitStream { + // shared_buffer: SharedBuffer, + inner: rodio::queue::SourcesQueueOutput, + _receiver_task: gpui::Task<()>, +} + +impl LiveKitStream { + pub fn new(executor: &gpui::BackgroundExecutor, track: &RemoteAudioTrack) -> Self { + let mut stream = NativeAudioStream::new( + track.rtc_track(), + SAMPLE_RATE.get() as i32, + CHANNEL_COUNT.get().into(), + ); + let (queue_input, queue_output) = rodio::queue::queue(true); + // spawn rtc stream + let receiver_task = executor.spawn({ + async move { + while let Some(frame) = stream.next().await { + let samples = frame_to_samplesbuffer(frame); + queue_input.append(samples); + } + } + }); + + LiveKitStream { + _receiver_task: receiver_task, + inner: queue_output, + } + } +} + +impl Iterator for LiveKitStream { + type Item = rodio::Sample; + + fn next(&mut self) -> Option { + self.inner.next() + } +} + +impl Source for LiveKitStream { + fn current_span_len(&self) -> Option { + self.inner.current_span_len() + } + + fn channels(&self) -> rodio::ChannelCount { + // This must be hardcoded because 
the playback source assumes constant + // sample rate and channel count. The queue upon which this is build + // will however report different counts and rates. Even though we put in + // only items with our (constant) CHANNEL_COUNT & SAMPLE_RATE this will + // play silence on one channel and at 44100 which is not what our + // constants are. + CHANNEL_COUNT + } + + fn sample_rate(&self) -> rodio::SampleRate { + SAMPLE_RATE // see comment on channels + } + + fn total_duration(&self) -> Option { + self.inner.total_duration() + } +} diff --git a/crates/livekit_client/src/record.rs b/crates/livekit_client/src/record.rs new file mode 100644 index 0000000000000000000000000000000000000000..24e260e71665704c1010d07e082a03fbe6306a30 --- /dev/null +++ b/crates/livekit_client/src/record.rs @@ -0,0 +1,96 @@ +use std::{ + env, + num::NonZero, + path::{Path, PathBuf}, + sync::{Arc, Mutex}, + time::Duration, +}; + +use anyhow::{Context, Result}; +use cpal::traits::{DeviceTrait, StreamTrait}; +use rodio::{buffer::SamplesBuffer, conversions::SampleTypeConverter}; +use util::ResultExt; + +pub struct CaptureInput { + pub name: String, + config: cpal::SupportedStreamConfig, + samples: Arc>>, + _stream: cpal::Stream, +} + +impl CaptureInput { + pub fn start() -> anyhow::Result { + let (device, config) = crate::default_device(true)?; + let name = device.name().unwrap_or("".to_string()); + log::info!("Using microphone: {}", name); + + let samples = Arc::new(Mutex::new(Vec::new())); + let stream = start_capture(device, config.clone(), samples.clone())?; + + Ok(Self { + name, + _stream: stream, + config, + samples, + }) + } + + pub fn finish(self) -> Result { + let name = self.name; + let mut path = env::current_dir().context("Could not get current dir")?; + path.push(&format!("test_recording_{name}.wav")); + log::info!("Test recording written to: {}", path.display()); + write_out(self.samples, self.config, &path)?; + Ok(path) + } +} + +fn start_capture( + device: cpal::Device, + config: 
cpal::SupportedStreamConfig, + samples: Arc>>, +) -> Result { + let stream = device + .build_input_stream_raw( + &config.config(), + config.sample_format(), + move |data, _: &_| { + let data = crate::get_sample_data(config.sample_format(), data).log_err(); + let Some(data) = data else { + return; + }; + samples + .try_lock() + .expect("Only locked after stream ends") + .extend_from_slice(&data); + }, + |err| log::error!("error capturing audio track: {:?}", err), + Some(Duration::from_millis(100)), + ) + .context("failed to build input stream")?; + + stream.play()?; + Ok(stream) +} + +fn write_out( + samples: Arc>>, + config: cpal::SupportedStreamConfig, + path: &Path, +) -> Result<()> { + let samples = std::mem::take( + &mut *samples + .try_lock() + .expect("Stream has ended, callback cant hold the lock"), + ); + let samples: Vec = SampleTypeConverter::<_, f32>::new(samples.into_iter()).collect(); + let mut samples = SamplesBuffer::new( + NonZero::new(config.channels()).expect("config channel is never zero"), + NonZero::new(config.sample_rate().0).expect("config sample_rate is never zero"), + samples, + ); + match rodio::wav_to_file(&mut samples, path) { + Ok(_) => Ok(()), + Err(e) => Err(anyhow::anyhow!("Failed to write wav file: {}", e)), + } +} diff --git a/crates/livekit_client/src/test.rs b/crates/livekit_client/src/test.rs index e02c4d876fbe3411cf1730f3d97aaf8db3e208b6..fd3163598203ac26443cae1b733372b6c3bdf1d1 100644 --- a/crates/livekit_client/src/test.rs +++ b/crates/livekit_client/src/test.rs @@ -421,7 +421,7 @@ impl TestServer { track_sid: &TrackSid, muted: bool, ) -> Result<()> { - let claims = livekit_api::token::validate(&token, &self.secret_key)?; + let claims = livekit_api::token::validate(token, &self.secret_key)?; let room_name = claims.video.room.unwrap(); let identity = ParticipantIdentity(claims.sub.unwrap().to_string()); let mut server_rooms = self.rooms.lock(); @@ -475,7 +475,7 @@ impl TestServer { } pub(crate) fn is_track_muted(&self, token: 
&str, track_sid: &TrackSid) -> Option { - let claims = livekit_api::token::validate(&token, &self.secret_key).ok()?; + let claims = livekit_api::token::validate(token, &self.secret_key).ok()?; let room_name = claims.video.room.unwrap(); let mut server_rooms = self.rooms.lock(); @@ -728,6 +728,8 @@ impl Room { pub async fn publish_local_microphone_track( &self, + _track_name: String, + _is_staff: bool, cx: &mut AsyncApp, ) -> Result<(LocalTrackPublication, AudioStream)> { self.local_participant().publish_microphone_track(cx).await @@ -736,14 +738,14 @@ impl Room { impl Drop for RoomState { fn drop(&mut self) { - if self.connection_state == ConnectionState::Connected { - if let Ok(server) = TestServer::get(&self.url) { - let executor = server.executor.clone(); - let token = self.token.clone(); - executor - .spawn(async move { server.leave_room(token).await.ok() }) - .detach(); - } + if self.connection_state == ConnectionState::Connected + && let Ok(server) = TestServer::get(&self.url) + { + let executor = server.executor.clone(); + let token = self.token.clone(); + executor + .spawn(async move { server.leave_room(token).await.ok() }) + .detach(); } } } diff --git a/crates/lmstudio/src/lmstudio.rs b/crates/lmstudio/src/lmstudio.rs index 43c78115cdd4f517a51052991121620a0a93c363..ef2f7b6208f62e079609049b8eff83a80034741e 100644 --- a/crates/lmstudio/src/lmstudio.rs +++ b/crates/lmstudio/src/lmstudio.rs @@ -86,11 +86,12 @@ impl Model { } #[derive(Debug, Serialize, Deserialize)] -#[serde(untagged)] +#[serde(rename_all = "lowercase")] pub enum ToolChoice { Auto, Required, None, + #[serde(untagged)] Other(ToolDefinition), } diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index ce9e2fe229c0aded6fac31c260e334445f987f03..7af51ef6fff8bddefac993fb5eb40e10d054977c 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -45,7 +45,7 @@ use util::{ConnectionResult, ResultExt, TryFutureExt, redact}; const JSON_RPC_VERSION: &str = "2.0"; const CONTENT_LEN_HEADER: 
&str = "Content-Length: "; -const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2); +pub const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2); const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); type NotificationHandler = Box, Value, &mut AsyncApp)>; @@ -100,8 +100,8 @@ pub struct LanguageServer { io_tasks: Mutex>, Task>)>>, output_done_rx: Mutex>, server: Arc>>, - workspace_folders: Option>>>, - root_uri: Url, + workspace_folders: Option>>>, + root_uri: Uri, } #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -166,6 +166,12 @@ impl<'a> From<&'a str> for LanguageServerName { } } +impl PartialEq for LanguageServerName { + fn eq(&self, other: &str) -> bool { + self.0 == other + } +} + /// Handle to a language server RPC activity subscription. pub enum Subscription { Notification { @@ -310,7 +316,7 @@ impl LanguageServer { binary: LanguageServerBinary, root_path: &Path, code_action_kinds: Option>, - workspace_folders: Option>>>, + workspace_folders: Option>>>, cx: &mut AsyncApp, ) -> Result { let working_dir = if root_path.is_dir() { @@ -318,7 +324,7 @@ impl LanguageServer { } else { root_path.parent().unwrap_or_else(|| Path::new("/")) }; - let root_uri = Url::from_file_path(&working_dir) + let root_uri = Uri::from_file_path(&working_dir) .map_err(|()| anyhow!("{working_dir:?} is not a valid URI"))?; log::info!( @@ -384,8 +390,8 @@ impl LanguageServer { server: Option, code_action_kinds: Option>, binary: LanguageServerBinary, - root_uri: Url, - workspace_folders: Option>>>, + root_uri: Uri, + workspace_folders: Option>>>, cx: &mut AsyncApp, on_unhandled_notification: F, ) -> Self @@ -1350,7 +1356,7 @@ impl LanguageServer { } /// Add new workspace folder to the list. - pub fn add_workspace_folder(&self, uri: Url) { + pub fn add_workspace_folder(&self, uri: Uri) { if self .capabilities() .workspace @@ -1383,8 +1389,9 @@ impl LanguageServer { self.notify::(¶ms).ok(); } } - /// Add new workspace folder to the list. 
- pub fn remove_workspace_folder(&self, uri: Url) { + + /// Remove existing workspace folder from the list. + pub fn remove_workspace_folder(&self, uri: Uri) { if self .capabilities() .workspace @@ -1416,7 +1423,7 @@ impl LanguageServer { self.notify::(¶ms).ok(); } } - pub fn set_workspace_folders(&self, folders: BTreeSet) { + pub fn set_workspace_folders(&self, folders: BTreeSet) { let Some(workspace_folders) = self.workspace_folders.as_ref() else { return; }; @@ -1449,7 +1456,7 @@ impl LanguageServer { } } - pub fn workspace_folders(&self) -> BTreeSet { + pub fn workspace_folders(&self) -> BTreeSet { self.workspace_folders.as_ref().map_or_else( || BTreeSet::from_iter([self.root_uri.clone()]), |folders| folders.lock().clone(), @@ -1458,7 +1465,7 @@ impl LanguageServer { pub fn register_buffer( &self, - uri: Url, + uri: Uri, language_id: String, version: i32, initial_text: String, @@ -1469,7 +1476,7 @@ impl LanguageServer { .ok(); } - pub fn unregister_buffer(&self, uri: Url) { + pub fn unregister_buffer(&self, uri: Uri) { self.notify::(&DidCloseTextDocumentParams { text_document: TextDocumentIdentifier::new(uri), }) @@ -1586,7 +1593,7 @@ impl FakeLanguageServer { let server_name = LanguageServerName(name.clone().into()); let process_name = Arc::from(name.as_str()); let root = Self::root_path(); - let workspace_folders: Arc>> = Default::default(); + let workspace_folders: Arc>> = Default::default(); let mut server = LanguageServer::new_internal( server_id, server_name.clone(), @@ -1656,13 +1663,13 @@ impl FakeLanguageServer { (server, fake) } #[cfg(target_os = "windows")] - fn root_path() -> Url { - Url::from_file_path("C:/").unwrap() + fn root_path() -> Uri { + Uri::from_file_path("C:/").unwrap() } #[cfg(not(target_os = "windows"))] - fn root_path() -> Url { - Url::from_file_path("/").unwrap() + fn root_path() -> Uri { + Uri::from_file_path("/").unwrap() } } @@ -1864,7 +1871,7 @@ mod tests { server .notify::(&DidOpenTextDocumentParams { text_document: 
TextDocumentItem::new( - Url::from_str("file://a/b").unwrap(), + Uri::from_str("file://a/b").unwrap(), "rust".to_string(), 0, "".to_string(), @@ -1885,7 +1892,7 @@ mod tests { message: "ok".to_string(), }); fake.notify::(&PublishDiagnosticsParams { - uri: Url::from_str("file://b/c").unwrap(), + uri: Uri::from_str("file://b/c").unwrap(), version: Some(5), diagnostics: vec![], }); diff --git a/crates/markdown/examples/markdown_as_child.rs b/crates/markdown/examples/markdown_as_child.rs index 862b657c8c50c7adc88642f1af21a4c075ff77f2..16c198601a31707602aea3dd250e3958c4c8f0fb 100644 --- a/crates/markdown/examples/markdown_as_child.rs +++ b/crates/markdown/examples/markdown_as_child.rs @@ -30,7 +30,7 @@ pub fn main() { let node_runtime = NodeRuntime::unavailable(); let language_registry = Arc::new(LanguageRegistry::new(cx.background_executor().clone())); - languages::init(language_registry.clone(), node_runtime, cx); + languages::init(language_registry, node_runtime, cx); theme::init(LoadThemes::JustBase, cx); Assets.load_fonts(cx).unwrap(); diff --git a/crates/markdown/src/markdown.rs b/crates/markdown/src/markdown.rs index a3235a977359270a9c1db0850ad7bb096a90d02d..4e1d3ac51e148439e57a4a1c305dabc31cbc2046 100644 --- a/crates/markdown/src/markdown.rs +++ b/crates/markdown/src/markdown.rs @@ -69,6 +69,7 @@ pub struct MarkdownStyle { pub heading_level_styles: Option, pub table_overflow_x_scroll: bool, pub height_is_multiple_of_line_height: bool, + pub prevent_mouse_interaction: bool, } impl Default for MarkdownStyle { @@ -89,6 +90,7 @@ impl Default for MarkdownStyle { heading_level_styles: None, table_overflow_x_scroll: false, height_is_multiple_of_line_height: false, + prevent_mouse_interaction: false, } } } @@ -340,27 +342,26 @@ impl Markdown { } for (range, event) in &events { - if let MarkdownEvent::Start(MarkdownTag::Image { dest_url, .. 
}) = event { - if let Some(data_url) = dest_url.strip_prefix("data:") { - let Some((mime_info, data)) = data_url.split_once(',') else { - continue; - }; - let Some((mime_type, encoding)) = mime_info.split_once(';') else { - continue; - }; - let Some(format) = ImageFormat::from_mime_type(mime_type) else { - continue; - }; - let is_base64 = encoding == "base64"; - if is_base64 { - if let Some(bytes) = base64::prelude::BASE64_STANDARD - .decode(data) - .log_with_level(Level::Debug) - { - let image = Arc::new(Image::from_bytes(format, bytes)); - images_by_source_offset.insert(range.start, image); - } - } + if let MarkdownEvent::Start(MarkdownTag::Image { dest_url, .. }) = event + && let Some(data_url) = dest_url.strip_prefix("data:") + { + let Some((mime_info, data)) = data_url.split_once(',') else { + continue; + }; + let Some((mime_type, encoding)) = mime_info.split_once(';') else { + continue; + }; + let Some(format) = ImageFormat::from_mime_type(mime_type) else { + continue; + }; + let is_base64 = encoding == "base64"; + if is_base64 + && let Some(bytes) = base64::prelude::BASE64_STANDARD + .decode(data) + .log_with_level(Level::Debug) + { + let image = Arc::new(Image::from_bytes(format, bytes)); + images_by_source_offset.insert(range.start, image); } } } @@ -576,16 +577,22 @@ impl MarkdownElement { window: &mut Window, cx: &mut App, ) { + if self.style.prevent_mouse_interaction { + return; + } + let is_hovering_link = hitbox.is_hovered(window) && !self.markdown.read(cx).selection.pending && rendered_text .link_for_position(window.mouse_position()) .is_some(); - if is_hovering_link { - window.set_cursor_style(CursorStyle::PointingHand, hitbox); - } else { - window.set_cursor_style(CursorStyle::IBeam, hitbox); + if !self.style.prevent_mouse_interaction { + if is_hovering_link { + window.set_cursor_style(CursorStyle::PointingHand, hitbox); + } else { + window.set_cursor_style(CursorStyle::IBeam, hitbox); + } } let on_open_url = self.on_url_click.take(); @@ -659,13 
+666,13 @@ impl MarkdownElement { let rendered_text = rendered_text.clone(); move |markdown, event: &MouseUpEvent, phase, window, cx| { if phase.bubble() { - if let Some(pressed_link) = markdown.pressed_link.take() { - if Some(&pressed_link) == rendered_text.link_for_position(event.position) { - if let Some(open_url) = on_open_url.as_ref() { - open_url(pressed_link.destination_url, window, cx); - } else { - cx.open_url(&pressed_link.destination_url); - } + if let Some(pressed_link) = markdown.pressed_link.take() + && Some(&pressed_link) == rendered_text.link_for_position(event.position) + { + if let Some(open_url) = on_open_url.as_ref() { + open_url(pressed_link.destination_url, window, cx); + } else { + cx.open_url(&pressed_link.destination_url); } } } else if markdown.selection.pending { @@ -758,10 +765,10 @@ impl Element for MarkdownElement { let mut current_img_block_range: Option> = None; for (range, event) in parsed_markdown.events.iter() { // Skip alt text for images that rendered - if let Some(current_img_block_range) = ¤t_img_block_range { - if current_img_block_range.end > range.end { - continue; - } + if let Some(current_img_block_range) = ¤t_img_block_range + && current_img_block_range.end > range.end + { + continue; } match event { @@ -875,7 +882,7 @@ impl Element for MarkdownElement { (CodeBlockRenderer::Custom { render, .. 
}, _) => { let parent_container = render( kind, - &parsed_markdown, + parsed_markdown, range.clone(), metadata.clone(), window, @@ -1085,7 +1092,13 @@ impl Element for MarkdownElement { cx, ); el.child( - div().absolute().top_1().right_0p5().w_5().child(codeblock), + h_flex() + .w_4() + .absolute() + .top_1p5() + .right_1p5() + .justify_end() + .child(codeblock), ) }); } @@ -1110,11 +1123,12 @@ impl Element for MarkdownElement { cx, ); el.child( - div() + h_flex() + .w_4() .absolute() .top_0() .right_0() - .w_5() + .justify_end() .visible_on_hover("code_block") .child(codeblock), ) @@ -1315,11 +1329,11 @@ fn render_copy_code_block_button( ) .icon_color(Color::Muted) .icon_size(IconSize::Small) + .style(ButtonStyle::Filled) .shape(ui::IconButtonShape::Square) - .tooltip(Tooltip::text("Copy Code")) + .tooltip(Tooltip::text("Copy")) .on_click({ - let id = id.clone(); - let markdown = markdown.clone(); + let markdown = markdown; move |_event, _window, cx| { let id = id.clone(); markdown.update(cx, |this, cx| { @@ -1696,10 +1710,10 @@ impl RenderedText { while let Some(line) = lines.next() { let line_bounds = line.layout.bounds(); if position.y > line_bounds.bottom() { - if let Some(next_line) = lines.peek() { - if position.y < next_line.layout.bounds().top() { - return Err(line.source_end); - } + if let Some(next_line) = lines.peek() + && position.y < next_line.layout.bounds().top() + { + return Err(line.source_end); } continue; diff --git a/crates/markdown/src/parser.rs b/crates/markdown/src/parser.rs index 1035335ccb40f63133c727b5a5be8930d42b818f..3720e5b1ef5f61f0a209ac5617119de61ed05517 100644 --- a/crates/markdown/src/parser.rs +++ b/crates/markdown/src/parser.rs @@ -247,7 +247,7 @@ pub fn parse_markdown( events.push(event_for( text, range.source_range.start..range.source_range.start + prefix_len, - &head, + head, )); range.parsed = CowStr::Boxed(tail.into()); range.merged_range.start += prefix_len; diff --git a/crates/markdown_preview/Cargo.toml 
b/crates/markdown_preview/Cargo.toml index ebdd8a9eb6c0ffbe99f7c14d1e97b13b3a95d8a3..55646cdcf43617223665e9dc48f13c55f966d99d 100644 --- a/crates/markdown_preview/Cargo.toml +++ b/crates/markdown_preview/Cargo.toml @@ -19,19 +19,21 @@ anyhow.workspace = true async-recursion.workspace = true collections.workspace = true editor.workspace = true +fs.workspace = true gpui.workspace = true +html5ever.workspace = true language.workspace = true linkify.workspace = true log.workspace = true +markup5ever_rcdom.workspace = true pretty_assertions.workspace = true pulldown-cmark.workspace = true settings.workspace = true theme.workspace = true ui.workspace = true util.workspace = true -workspace.workspace = true workspace-hack.workspace = true -fs.workspace = true +workspace.workspace = true [dev-dependencies] editor = { workspace = true, features = ["test-support"] } diff --git a/crates/markdown_preview/src/markdown_elements.rs b/crates/markdown_preview/src/markdown_elements.rs index a570e79f5344d0f35693072f82f947004e24ac65..560e468439efce22aa72d91054d68d491e125b23 100644 --- a/crates/markdown_preview/src/markdown_elements.rs +++ b/crates/markdown_preview/src/markdown_elements.rs @@ -1,5 +1,6 @@ use gpui::{ - FontStyle, FontWeight, HighlightStyle, SharedString, StrikethroughStyle, UnderlineStyle, px, + DefiniteLength, FontStyle, FontWeight, HighlightStyle, SharedString, StrikethroughStyle, + UnderlineStyle, px, }; use language::HighlightId; use std::{fmt::Display, ops::Range, path::PathBuf}; @@ -15,6 +16,7 @@ pub enum ParsedMarkdownElement { /// A paragraph of text and other inline elements. 
Paragraph(MarkdownParagraph), HorizontalRule(Range), + Image(Image), } impl ParsedMarkdownElement { @@ -30,6 +32,7 @@ impl ParsedMarkdownElement { MarkdownParagraphChunk::Image(image) => image.source_range.clone(), }, Self::HorizontalRule(range) => range.clone(), + Self::Image(image) => image.source_range.clone(), }) } @@ -290,6 +293,8 @@ pub struct Image { pub link: Link, pub source_range: Range, pub alt_text: Option, + pub width: Option, + pub height: Option, } impl Image { @@ -303,10 +308,20 @@ impl Image { source_range, link, alt_text: None, + width: None, + height: None, }) } pub fn set_alt_text(&mut self, alt_text: SharedString) { self.alt_text = Some(alt_text); } + + pub fn set_width(&mut self, width: DefiniteLength) { + self.width = Some(width); + } + + pub fn set_height(&mut self, height: DefiniteLength) { + self.height = Some(height); + } } diff --git a/crates/markdown_preview/src/markdown_parser.rs b/crates/markdown_preview/src/markdown_parser.rs index 27691f2ecffadb7a7df1e9647e7d1d6487135974..1b116c50d9820dc4fea9d6b2e5816543d75e7d52 100644 --- a/crates/markdown_preview/src/markdown_parser.rs +++ b/crates/markdown_preview/src/markdown_parser.rs @@ -1,10 +1,12 @@ use crate::markdown_elements::*; use async_recursion::async_recursion; use collections::FxHashMap; -use gpui::FontWeight; +use gpui::{DefiniteLength, FontWeight, px, relative}; +use html5ever::{ParseOpts, local_name, parse_document, tendril::TendrilSink}; use language::LanguageRegistry; +use markup5ever_rcdom::RcDom; use pulldown_cmark::{Alignment, Event, Options, Parser, Tag, TagEnd}; -use std::{ops::Range, path::PathBuf, sync::Arc, vec}; +use std::{cell::RefCell, collections::HashMap, ops::Range, path::PathBuf, rc::Rc, sync::Arc, vec}; pub async fn parse_markdown( markdown_input: &str, @@ -76,22 +78,22 @@ impl<'a> MarkdownParser<'a> { if self.eof() || (steps + self.cursor) >= self.tokens.len() { return self.tokens.last(); } - return self.tokens.get(self.cursor + steps); + 
self.tokens.get(self.cursor + steps) } fn previous(&self) -> Option<&(Event<'_>, Range)> { if self.cursor == 0 || self.cursor > self.tokens.len() { return None; } - return self.tokens.get(self.cursor - 1); + self.tokens.get(self.cursor - 1) } fn current(&self) -> Option<&(Event<'_>, Range)> { - return self.peek(0); + self.peek(0) } fn current_event(&self) -> Option<&Event<'_>> { - return self.current().map(|(event, _)| event); + self.current().map(|(event, _)| event) } fn is_text_like(event: &Event) -> bool { @@ -172,13 +174,17 @@ impl<'a> MarkdownParser<'a> { self.cursor += 1; - let code_block = self.parse_code_block(language).await; + let code_block = self.parse_code_block(language).await?; Some(vec![ParsedMarkdownElement::CodeBlock(code_block)]) } + Tag::HtmlBlock => { + self.cursor += 1; + + Some(self.parse_html_block().await) + } _ => None, }, Event::Rule => { - let source_range = source_range.clone(); self.cursor += 1; Some(vec![ParsedMarkdownElement::HorizontalRule(source_range)]) } @@ -300,13 +306,12 @@ impl<'a> MarkdownParser<'a> { if style != MarkdownHighlightStyle::default() && last_run_len < text.len() { let mut new_highlight = true; - if let Some((last_range, last_style)) = highlights.last_mut() { - if last_range.end == last_run_len - && last_style == &MarkdownHighlight::Style(style.clone()) - { - last_range.end = text.len(); - new_highlight = false; - } + if let Some((last_range, last_style)) = highlights.last_mut() + && last_range.end == last_run_len + && last_style == &MarkdownHighlight::Style(style.clone()) + { + last_range.end = text.len(); + new_highlight = false; } if new_highlight { highlights.push(( @@ -380,7 +385,7 @@ impl<'a> MarkdownParser<'a> { TagEnd::Image => { if let Some(mut image) = image.take() { if !text.is_empty() { - image.alt_text = Some(std::mem::take(&mut text).into()); + image.set_alt_text(std::mem::take(&mut text).into()); } markdown_text_like.push(MarkdownParagraphChunk::Image(image)); } @@ -402,7 +407,7 @@ impl<'a> 
MarkdownParser<'a> { } if !text.is_empty() { markdown_text_like.push(MarkdownParagraphChunk::Text(ParsedMarkdownText { - source_range: source_range.clone(), + source_range, contents: text, highlights, regions, @@ -421,7 +426,7 @@ impl<'a> MarkdownParser<'a> { self.cursor += 1; ParsedMarkdownHeading { - source_range: source_range.clone(), + source_range, level: match level { pulldown_cmark::HeadingLevel::H1 => HeadingLevel::H1, pulldown_cmark::HeadingLevel::H2 => HeadingLevel::H2, @@ -579,10 +584,10 @@ impl<'a> MarkdownParser<'a> { } } else { let block = self.parse_block().await; - if let Some(block) = block { - if let Some(list_item) = items_stack.last_mut() { - list_item.content.extend(block); - } + if let Some(block) = block + && let Some(list_item) = items_stack.last_mut() + { + list_item.content.extend(block); } } } @@ -697,13 +702,22 @@ impl<'a> MarkdownParser<'a> { } } - async fn parse_code_block(&mut self, language: Option) -> ParsedMarkdownCodeBlock { - let (_event, source_range) = self.previous().unwrap(); + async fn parse_code_block( + &mut self, + language: Option, + ) -> Option { + let Some((_event, source_range)) = self.previous() else { + return None; + }; + let source_range = source_range.clone(); let mut code = String::new(); while !self.eof() { - let (current, _source_range) = self.current().unwrap(); + let Some((current, _source_range)) = self.current() else { + break; + }; + match current { Event::Text(text) => { code.push_str(text); @@ -736,23 +750,190 @@ impl<'a> MarkdownParser<'a> { None }; - ParsedMarkdownCodeBlock { + Some(ParsedMarkdownCodeBlock { source_range, contents: code.into(), language, highlights, + }) + } + + async fn parse_html_block(&mut self) -> Vec { + let mut elements = Vec::new(); + let Some((_event, _source_range)) = self.previous() else { + return elements; + }; + + while !self.eof() { + let Some((current, source_range)) = self.current() else { + break; + }; + let source_range = source_range.clone(); + match current { + 
Event::Html(html) => { + let mut cursor = std::io::Cursor::new(html.as_bytes()); + let Some(dom) = parse_document(RcDom::default(), ParseOpts::default()) + .from_utf8() + .read_from(&mut cursor) + .ok() + else { + self.cursor += 1; + continue; + }; + + self.cursor += 1; + + self.parse_html_node(source_range, &dom.document, &mut elements); + } + Event::End(TagEnd::CodeBlock) => { + self.cursor += 1; + break; + } + _ => { + break; + } + } + } + + elements + } + + fn parse_html_node( + &self, + source_range: Range, + node: &Rc, + elements: &mut Vec, + ) { + match &node.data { + markup5ever_rcdom::NodeData::Document => { + self.consume_children(source_range, node, elements); + } + markup5ever_rcdom::NodeData::Doctype { .. } => {} + markup5ever_rcdom::NodeData::Text { contents } => { + elements.push(ParsedMarkdownElement::Paragraph(vec![ + MarkdownParagraphChunk::Text(ParsedMarkdownText { + source_range, + contents: contents.borrow().to_string(), + highlights: Vec::default(), + region_ranges: Vec::default(), + regions: Vec::default(), + }), + ])); + } + markup5ever_rcdom::NodeData::Comment { .. } => {} + markup5ever_rcdom::NodeData::Element { name, attrs, .. } => { + if local_name!("img") == name.local { + if let Some(image) = self.extract_image(source_range, attrs) { + elements.push(ParsedMarkdownElement::Image(image)); + } + } else { + self.consume_children(source_range, node, elements); + } + } + markup5ever_rcdom::NodeData::ProcessingInstruction { .. 
} => {} + } + } + + fn consume_children( + &self, + source_range: Range, + node: &Rc, + elements: &mut Vec, + ) { + for node in node.children.borrow().iter() { + self.parse_html_node(source_range.clone(), node, elements); + } + } + + fn attr_value( + attrs: &RefCell>, + name: html5ever::LocalName, + ) -> Option { + attrs.borrow().iter().find_map(|attr| { + if attr.name.local == name { + Some(attr.value.to_string()) + } else { + None + } + }) + } + + fn extract_styles_from_attributes( + attrs: &RefCell>, + ) -> HashMap { + let mut styles = HashMap::new(); + + if let Some(style) = Self::attr_value(attrs, local_name!("style")) { + for decl in style.split(';') { + let mut parts = decl.splitn(2, ':'); + if let Some((key, value)) = parts.next().zip(parts.next()) { + styles.insert( + key.trim().to_lowercase().to_string(), + value.trim().to_string(), + ); + } + } + } + + styles + } + + fn extract_image( + &self, + source_range: Range, + attrs: &RefCell>, + ) -> Option { + let src = Self::attr_value(attrs, local_name!("src"))?; + + let mut image = Image::identify(src, source_range, self.file_location_directory.clone())?; + + if let Some(alt) = Self::attr_value(attrs, local_name!("alt")) { + image.set_alt_text(alt.into()); + } + + let styles = Self::extract_styles_from_attributes(attrs); + + if let Some(width) = Self::attr_value(attrs, local_name!("width")) + .or_else(|| styles.get("width").cloned()) + .and_then(|width| Self::parse_length(&width)) + { + image.set_width(width); + } + + if let Some(height) = Self::attr_value(attrs, local_name!("height")) + .or_else(|| styles.get("height").cloned()) + .and_then(|height| Self::parse_length(&height)) + { + image.set_height(height); + } + + Some(image) + } + + /// Parses the width/height attribute value of an html element (e.g. 
img element) + fn parse_length(value: &str) -> Option { + if value.ends_with("%") { + value + .trim_end_matches("%") + .parse::() + .ok() + .map(|value| relative(value / 100.)) + } else { + value + .trim_end_matches("px") + .parse() + .ok() + .map(|value| px(value).into()) } } } #[cfg(test)] mod tests { - use core::panic; - use super::*; - use ParsedMarkdownListItemType::*; - use gpui::BackgroundExecutor; + use core::panic; + use gpui::{AbsoluteLength, BackgroundExecutor, DefiniteLength}; use language::{ HighlightId, Language, LanguageConfig, LanguageMatcher, LanguageRegistry, tree_sitter_rust, }; @@ -927,6 +1108,8 @@ mod tests { url: "https://blog.logrocket.com/wp-content/uploads/2024/04/exploring-zed-open-source-code-editor-rust-2.png".to_string(), }, alt_text: Some("test".into()), + height: None, + width: None, },) ); } @@ -948,6 +1131,8 @@ mod tests { url: "http://example.com/foo.png".to_string(), }, alt_text: None, + height: None, + width: None, },) ); } @@ -967,6 +1152,8 @@ mod tests { url: "http://example.com/foo.png".to_string(), }, alt_text: Some("foo bar baz".into()), + height: None, + width: None, }),], ); } @@ -992,6 +1179,8 @@ mod tests { url: "http://example.com/foo.png".to_string(), }, alt_text: Some("foo".into()), + height: None, + width: None, }), MarkdownParagraphChunk::Text(ParsedMarkdownText { source_range: 0..81, @@ -1006,11 +1195,168 @@ mod tests { url: "http://example.com/bar.png".to_string(), }, alt_text: Some("bar".into()), + height: None, + width: None, }) ] ); } + #[test] + fn test_parse_length() { + // Test percentage values + assert_eq!( + MarkdownParser::parse_length("50%"), + Some(DefiniteLength::Fraction(0.5)) + ); + assert_eq!( + MarkdownParser::parse_length("100%"), + Some(DefiniteLength::Fraction(1.0)) + ); + assert_eq!( + MarkdownParser::parse_length("25%"), + Some(DefiniteLength::Fraction(0.25)) + ); + assert_eq!( + MarkdownParser::parse_length("0%"), + Some(DefiniteLength::Fraction(0.0)) + ); + + // Test pixel values + 
assert_eq!( + MarkdownParser::parse_length("100px"), + Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(px(100.0)))) + ); + assert_eq!( + MarkdownParser::parse_length("50px"), + Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(px(50.0)))) + ); + assert_eq!( + MarkdownParser::parse_length("0px"), + Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(px(0.0)))) + ); + + // Test values without units (should be treated as pixels) + assert_eq!( + MarkdownParser::parse_length("100"), + Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(px(100.0)))) + ); + assert_eq!( + MarkdownParser::parse_length("42"), + Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(px(42.0)))) + ); + + // Test invalid values + assert_eq!(MarkdownParser::parse_length("invalid"), None); + assert_eq!(MarkdownParser::parse_length("px"), None); + assert_eq!(MarkdownParser::parse_length("%"), None); + assert_eq!(MarkdownParser::parse_length(""), None); + assert_eq!(MarkdownParser::parse_length("abc%"), None); + assert_eq!(MarkdownParser::parse_length("abcpx"), None); + + // Test decimal values + assert_eq!( + MarkdownParser::parse_length("50.5%"), + Some(DefiniteLength::Fraction(0.505)) + ); + assert_eq!( + MarkdownParser::parse_length("100.25px"), + Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(px(100.25)))) + ); + assert_eq!( + MarkdownParser::parse_length("42.0"), + Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(px(42.0)))) + ); + } + + #[gpui::test] + async fn test_html_image_tag() { + let parsed = parse("").await; + + let ParsedMarkdownElement::Image(image) = &parsed.children[0] else { + panic!("Expected a image element"); + }; + assert_eq!( + image.clone(), + Image { + source_range: 0..40, + link: Link::Web { + url: "http://example.com/foo.png".to_string(), + }, + alt_text: None, + height: None, + width: None, + }, + ); + } + + #[gpui::test] + async fn test_html_image_tag_with_alt_text() { + let parsed = parse("\"Foo\"").await; + + let ParsedMarkdownElement::Image(image) 
= &parsed.children[0] else { + panic!("Expected a image element"); + }; + assert_eq!( + image.clone(), + Image { + source_range: 0..50, + link: Link::Web { + url: "http://example.com/foo.png".to_string(), + }, + alt_text: Some("Foo".into()), + height: None, + width: None, + }, + ); + } + + #[gpui::test] + async fn test_html_image_tag_with_height_and_width() { + let parsed = + parse("").await; + + let ParsedMarkdownElement::Image(image) = &parsed.children[0] else { + panic!("Expected a image element"); + }; + assert_eq!( + image.clone(), + Image { + source_range: 0..65, + link: Link::Web { + url: "http://example.com/foo.png".to_string(), + }, + alt_text: None, + height: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(px(100.)))), + width: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(px(200.)))), + }, + ); + } + + #[gpui::test] + async fn test_html_image_style_tag_with_height_and_width() { + let parsed = parse( + "", + ) + .await; + + let ParsedMarkdownElement::Image(image) = &parsed.children[0] else { + panic!("Expected a image element"); + }; + assert_eq!( + image.clone(), + Image { + source_range: 0..75, + link: Link::Web { + url: "http://example.com/foo.png".to_string(), + }, + alt_text: None, + height: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(px(100.)))), + width: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(px(200.)))), + }, + ); + } + #[gpui::test] async fn test_header_only_table() { let markdown = "\ diff --git a/crates/markdown_preview/src/markdown_preview_view.rs b/crates/markdown_preview/src/markdown_preview_view.rs index a0c8819991d68336a306af85a4dd709353222fa1..1121d64655f6c7e02f0b0d621605c9ba1aae7cde 100644 --- a/crates/markdown_preview/src/markdown_preview_view.rs +++ b/crates/markdown_preview/src/markdown_preview_view.rs @@ -115,8 +115,7 @@ impl MarkdownPreviewView { pane.activate_item(existing_follow_view_idx, true, true, window, cx); }); } else { - let view = - Self::create_following_markdown_view(workspace, 
editor.clone(), window, cx); + let view = Self::create_following_markdown_view(workspace, editor, window, cx); workspace.active_pane().update(cx, |pane, cx| { pane.add_item(Box::new(view.clone()), true, true, None, window, cx) }); @@ -151,10 +150,9 @@ impl MarkdownPreviewView { if let Some(editor) = workspace .active_item(cx) .and_then(|item| item.act_as::(cx)) + && Self::is_markdown_file(&editor, cx) { - if Self::is_markdown_file(&editor, cx) { - return Some(editor); - } + return Some(editor); } None } @@ -243,32 +241,30 @@ impl MarkdownPreviewView { window: &mut Window, cx: &mut Context, ) { - if let Some(item) = active_item { - if item.item_id() != cx.entity_id() { - if let Some(editor) = item.act_as::(cx) { - if Self::is_markdown_file(&editor, cx) { - self.set_editor(editor, window, cx); - } - } - } + if let Some(item) = active_item + && item.item_id() != cx.entity_id() + && let Some(editor) = item.act_as::(cx) + && Self::is_markdown_file(&editor, cx) + { + self.set_editor(editor, window, cx); } } pub fn is_markdown_file(editor: &Entity, cx: &mut Context) -> bool { let buffer = editor.read(cx).buffer().read(cx); - if let Some(buffer) = buffer.as_singleton() { - if let Some(language) = buffer.read(cx).language() { - return language.name() == "Markdown".into(); - } + if let Some(buffer) = buffer.as_singleton() + && let Some(language) = buffer.read(cx).language() + { + return language.name() == "Markdown".into(); } false } fn set_editor(&mut self, editor: Entity, window: &mut Window, cx: &mut Context) { - if let Some(active) = &self.active_editor { - if active.editor == editor { - return; - } + if let Some(active) = &self.active_editor + && active.editor == editor + { + return; } let subscription = cx.subscribe_in( @@ -552,21 +548,20 @@ impl Render for MarkdownPreviewView { .group("markdown-block") .on_click(cx.listener( move |this, event: &ClickEvent, window, cx| { - if event.click_count() == 2 { - if let Some(source_range) = this + if event.click_count() == 2 + 
&& let Some(source_range) = this .contents .as_ref() .and_then(|c| c.children.get(ix)) .and_then(|block: &ParsedMarkdownElement| { block.source_range() }) - { - this.move_cursor_to_block( - window, - cx, - source_range.start..source_range.start, - ); - } + { + this.move_cursor_to_block( + window, + cx, + source_range.start..source_range.start, + ); } }, )) diff --git a/crates/markdown_preview/src/markdown_renderer.rs b/crates/markdown_preview/src/markdown_renderer.rs index 37d2ca21105566f1e2e3271f49c75a3ce1d7846b..5835797d9ecd1fa8cf97b18a9e518ab93ab4599c 100644 --- a/crates/markdown_preview/src/markdown_renderer.rs +++ b/crates/markdown_preview/src/markdown_renderer.rs @@ -1,5 +1,5 @@ use crate::markdown_elements::{ - HeadingLevel, Link, MarkdownParagraph, MarkdownParagraphChunk, ParsedMarkdown, + HeadingLevel, Image, Link, MarkdownParagraph, MarkdownParagraphChunk, ParsedMarkdown, ParsedMarkdownBlockQuote, ParsedMarkdownCodeBlock, ParsedMarkdownElement, ParsedMarkdownHeading, ParsedMarkdownListItem, ParsedMarkdownListItemType, ParsedMarkdownTable, ParsedMarkdownTableAlignment, ParsedMarkdownTableRow, @@ -111,11 +111,10 @@ impl RenderContext { /// buffer font size changes. 
The callees of this function should be reimplemented to use real /// relative sizing once that is implemented in GPUI pub fn scaled_rems(&self, rems: f32) -> Rems { - return self - .buffer_text_style + self.buffer_text_style .font_size .to_rems(self.window_rem_size) - .mul(rems); + .mul(rems) } /// This ensures that children inside of block quotes @@ -165,6 +164,7 @@ pub fn render_markdown_block(block: &ParsedMarkdownElement, cx: &mut RenderConte BlockQuote(block_quote) => render_markdown_block_quote(block_quote, cx), CodeBlock(code_block) => render_markdown_code_block(code_block, cx), HorizontalRule(_) => render_markdown_rule(cx), + Image(image) => render_markdown_image(image, cx), } } @@ -277,7 +277,11 @@ fn render_markdown_list_item( .items_start() .children(vec![ bullet, - div().children(contents).pr(cx.scaled_rems(1.0)).w_full(), + v_flex() + .children(contents) + .gap(cx.scaled_rems(1.0)) + .pr(cx.scaled_rems(1.0)) + .w_full(), ]); cx.with_common_p(item).into_any() @@ -459,13 +463,13 @@ fn render_markdown_table(parsed: &ParsedMarkdownTable, cx: &mut RenderContext) - let mut max_lengths: Vec = vec![0; parsed.header.children.len()]; for (index, cell) in parsed.header.children.iter().enumerate() { - let length = paragraph_len(&cell); + let length = paragraph_len(cell); max_lengths[index] = length; } for row in &parsed.body { for (index, cell) in row.children.iter().enumerate() { - let length = paragraph_len(&cell); + let length = paragraph_len(cell); if length > max_lengths[index] { max_lengths[index] = length; @@ -723,65 +727,7 @@ fn render_markdown_text(parsed_new: &MarkdownParagraph, cx: &mut RenderContext) } MarkdownParagraphChunk::Image(image) => { - let image_resource = match image.link.clone() { - Link::Web { url } => Resource::Uri(url.into()), - Link::Path { path, .. 
} => Resource::Path(Arc::from(path)), - }; - - let element_id = cx.next_id(&image.source_range); - - let image_element = div() - .id(element_id) - .cursor_pointer() - .child( - img(ImageSource::Resource(image_resource)) - .max_w_full() - .with_fallback({ - let alt_text = image.alt_text.clone(); - move || div().children(alt_text.clone()).into_any_element() - }), - ) - .tooltip({ - let link = image.link.clone(); - move |_, cx| { - InteractiveMarkdownElementTooltip::new( - Some(link.to_string()), - "open image", - cx, - ) - .into() - } - }) - .on_click({ - let workspace = workspace_clone.clone(); - let link = image.link.clone(); - move |_, window, cx| { - if window.modifiers().secondary() { - match &link { - Link::Web { url } => cx.open_url(url), - Link::Path { path, .. } => { - if let Some(workspace) = &workspace { - _ = workspace.update(cx, |workspace, cx| { - workspace - .open_abs_path( - path.clone(), - OpenOptions { - visible: Some(OpenVisible::None), - ..Default::default() - }, - window, - cx, - ) - .detach(); - }); - } - } - } - } - } - }) - .into_any(); - any_element.push(image_element); + any_element.push(render_markdown_image(image, cx)); } } } @@ -794,18 +740,86 @@ fn render_markdown_rule(cx: &mut RenderContext) -> AnyElement { div().py(cx.scaled_rems(0.5)).child(rule).into_any() } +fn render_markdown_image(image: &Image, cx: &mut RenderContext) -> AnyElement { + let image_resource = match image.link.clone() { + Link::Web { url } => Resource::Uri(url.into()), + Link::Path { path, .. 
} => Resource::Path(Arc::from(path)), + }; + + let element_id = cx.next_id(&image.source_range); + let workspace = cx.workspace.clone(); + + div() + .id(element_id) + .cursor_pointer() + .child( + img(ImageSource::Resource(image_resource)) + .max_w_full() + .with_fallback({ + let alt_text = image.alt_text.clone(); + move || div().children(alt_text.clone()).into_any_element() + }) + .when_some(image.height, |this, height| this.h(height)) + .when_some(image.width, |this, width| this.w(width)), + ) + .tooltip({ + let link = image.link.clone(); + let alt_text = image.alt_text.clone(); + move |_, cx| { + InteractiveMarkdownElementTooltip::new( + Some(alt_text.clone().unwrap_or(link.to_string().into())), + "open image", + cx, + ) + .into() + } + }) + .on_click({ + let link = image.link.clone(); + move |_, window, cx| { + if window.modifiers().secondary() { + match &link { + Link::Web { url } => cx.open_url(url), + Link::Path { path, .. } => { + if let Some(workspace) = &workspace { + _ = workspace.update(cx, |workspace, cx| { + workspace + .open_abs_path( + path.clone(), + OpenOptions { + visible: Some(OpenVisible::None), + ..Default::default() + }, + window, + cx, + ) + .detach(); + }); + } + } + } + } + } + }) + .into_any() +} + struct InteractiveMarkdownElementTooltip { tooltip_text: Option, - action_text: String, + action_text: SharedString, } impl InteractiveMarkdownElementTooltip { - pub fn new(tooltip_text: Option, action_text: &str, cx: &mut App) -> Entity { + pub fn new( + tooltip_text: Option, + action_text: impl Into, + cx: &mut App, + ) -> Entity { let tooltip_text = tooltip_text.map(|t| util::truncate_and_trailoff(&t, 50).into()); cx.new(|_cx| Self { tooltip_text, - action_text: action_text.to_string(), + action_text: action_text.into(), }) } } diff --git a/crates/migrator/src/migrations/m_2025_01_02/settings.rs b/crates/migrator/src/migrations/m_2025_01_02/settings.rs index 3ce85e6b2611b69dfaac5479ee3404eeda9c0ebc..a35b1ebd2e9d8e2c658de0623b7c2e8377662b18 
100644 --- a/crates/migrator/src/migrations/m_2025_01_02/settings.rs +++ b/crates/migrator/src/migrations/m_2025_01_02/settings.rs @@ -20,14 +20,14 @@ fn replace_deprecated_settings_values( .nodes_for_capture_index(parent_object_capture_ix) .next()? .byte_range(); - let parent_object_name = contents.get(parent_object_range.clone())?; + let parent_object_name = contents.get(parent_object_range)?; let setting_name_ix = query.capture_index_for_name("setting_name")?; let setting_name_range = mat .nodes_for_capture_index(setting_name_ix) .next()? .byte_range(); - let setting_name = contents.get(setting_name_range.clone())?; + let setting_name = contents.get(setting_name_range)?; let setting_value_ix = query.capture_index_for_name("setting_value")?; let setting_value_range = mat diff --git a/crates/migrator/src/migrations/m_2025_01_29/keymap.rs b/crates/migrator/src/migrations/m_2025_01_29/keymap.rs index 646af8f63dc90b6ebe3faef9432eecc54140b438..eed2c46e0816452af6813ae699eab6cec1d65eec 100644 --- a/crates/migrator/src/migrations/m_2025_01_29/keymap.rs +++ b/crates/migrator/src/migrations/m_2025_01_29/keymap.rs @@ -242,22 +242,22 @@ static STRING_REPLACE: LazyLock> = LazyLock::new(|| { "inline_completion::ToggleMenu", "edit_prediction::ToggleMenu", ), - ("editor::NextEditPrediction", "editor::NextEditPrediction"), + ("editor::NextInlineCompletion", "editor::NextEditPrediction"), ( - "editor::PreviousEditPrediction", + "editor::PreviousInlineCompletion", "editor::PreviousEditPrediction", ), ( - "editor::AcceptPartialEditPrediction", + "editor::AcceptPartialInlineCompletion", "editor::AcceptPartialEditPrediction", ), - ("editor::ShowEditPrediction", "editor::ShowEditPrediction"), + ("editor::ShowInlineCompletion", "editor::ShowEditPrediction"), ( - "editor::AcceptEditPrediction", + "editor::AcceptInlineCompletion", "editor::AcceptEditPrediction", ), ( - "editor::ToggleEditPredictions", + "editor::ToggleInlineCompletions", "editor::ToggleEditPrediction", ), ]) @@ -279,7 
+279,7 @@ fn rename_context_key( new_predicate = new_predicate.replace(old_key, new_key); } if new_predicate != old_predicate { - Some((context_predicate_range, new_predicate.to_string())) + Some((context_predicate_range, new_predicate)) } else { None } diff --git a/crates/migrator/src/migrations/m_2025_01_29/settings.rs b/crates/migrator/src/migrations/m_2025_01_29/settings.rs index 8d3261676b731d00e3dd85f3f5d94737931d74fe..46cfe2f178f1e4416cb404f26b5b77b55663aa29 100644 --- a/crates/migrator/src/migrations/m_2025_01_29/settings.rs +++ b/crates/migrator/src/migrations/m_2025_01_29/settings.rs @@ -57,7 +57,7 @@ pub fn replace_edit_prediction_provider_setting( .nodes_for_capture_index(parent_object_capture_ix) .next()? .byte_range(); - let parent_object_name = contents.get(parent_object_range.clone())?; + let parent_object_name = contents.get(parent_object_range)?; let setting_name_ix = query.capture_index_for_name("setting_name")?; let setting_range = mat diff --git a/crates/migrator/src/migrations/m_2025_01_30/settings.rs b/crates/migrator/src/migrations/m_2025_01_30/settings.rs index 23a3243b827b7d44e673208e56858b6cd2e8f2b7..2d763e4722cb2119f0b2f982b5841aab37e55c12 100644 --- a/crates/migrator/src/migrations/m_2025_01_30/settings.rs +++ b/crates/migrator/src/migrations/m_2025_01_30/settings.rs @@ -25,7 +25,7 @@ fn replace_tab_close_button_setting_key( .nodes_for_capture_index(parent_object_capture_ix) .next()? .byte_range(); - let parent_object_name = contents.get(parent_object_range.clone())?; + let parent_object_name = contents.get(parent_object_range)?; let setting_name_ix = query.capture_index_for_name("setting_name")?; let setting_range = mat @@ -51,14 +51,14 @@ fn replace_tab_close_button_setting_value( .nodes_for_capture_index(parent_object_capture_ix) .next()? 
.byte_range(); - let parent_object_name = contents.get(parent_object_range.clone())?; + let parent_object_name = contents.get(parent_object_range)?; let setting_name_ix = query.capture_index_for_name("setting_name")?; let setting_name_range = mat .nodes_for_capture_index(setting_name_ix) .next()? .byte_range(); - let setting_name = contents.get(setting_name_range.clone())?; + let setting_name = contents.get(setting_name_range)?; let setting_value_ix = query.capture_index_for_name("setting_value")?; let setting_value_range = mat diff --git a/crates/migrator/src/migrations/m_2025_03_29/settings.rs b/crates/migrator/src/migrations/m_2025_03_29/settings.rs index 47f65b407da2b7079fb68a4877275339d6309433..8f83d8e39ea050de0ec9291199804f0e62dab392 100644 --- a/crates/migrator/src/migrations/m_2025_03_29/settings.rs +++ b/crates/migrator/src/migrations/m_2025_03_29/settings.rs @@ -19,7 +19,7 @@ fn replace_setting_value( .nodes_for_capture_index(setting_capture_ix) .next()? .byte_range(); - let setting_name = contents.get(setting_name_range.clone())?; + let setting_name = contents.get(setting_name_range)?; if setting_name != "hide_mouse_while_typing" { return None; diff --git a/crates/migrator/src/migrations/m_2025_05_05/settings.rs b/crates/migrator/src/migrations/m_2025_05_05/settings.rs index 88c6c338d18bc9c648a6c09e8fe1755bc3f77cd9..77da1b9a077b4acc2e6df6d47713f8e15f0fd090 100644 --- a/crates/migrator/src/migrations/m_2025_05_05/settings.rs +++ b/crates/migrator/src/migrations/m_2025_05_05/settings.rs @@ -24,7 +24,7 @@ fn rename_assistant( .nodes_for_capture_index(key_capture_ix) .next()? .byte_range(); - return Some((key_range, "agent".to_string())); + Some((key_range, "agent".to_string())) } fn rename_edit_prediction_assistant( @@ -37,5 +37,5 @@ fn rename_edit_prediction_assistant( .nodes_for_capture_index(key_capture_ix) .next()? 
.byte_range(); - return Some((key_range, "enabled_in_text_threads".to_string())); + Some((key_range, "enabled_in_text_threads".to_string())) } diff --git a/crates/migrator/src/migrations/m_2025_05_29/settings.rs b/crates/migrator/src/migrations/m_2025_05_29/settings.rs index 56d72836fa396810db2a220f57b8144c939a872a..37ef0e45cc0730c9861ca4362a4b93f025002c6d 100644 --- a/crates/migrator/src/migrations/m_2025_05_29/settings.rs +++ b/crates/migrator/src/migrations/m_2025_05_29/settings.rs @@ -19,7 +19,7 @@ fn replace_preferred_completion_mode_value( .nodes_for_capture_index(parent_object_capture_ix) .next()? .byte_range(); - let parent_object_name = contents.get(parent_object_range.clone())?; + let parent_object_name = contents.get(parent_object_range)?; if parent_object_name != "agent" { return None; @@ -30,7 +30,7 @@ fn replace_preferred_completion_mode_value( .nodes_for_capture_index(setting_name_capture_ix) .next()? .byte_range(); - let setting_name = contents.get(setting_name_range.clone())?; + let setting_name = contents.get(setting_name_range)?; if setting_name != "preferred_completion_mode" { return None; diff --git a/crates/migrator/src/migrations/m_2025_06_16/settings.rs b/crates/migrator/src/migrations/m_2025_06_16/settings.rs index cce407e21b81bf9064c1261c142b216b622712a8..cd79eae2048ca9809b720b7913eba12b3e6cb1ce 100644 --- a/crates/migrator/src/migrations/m_2025_06_16/settings.rs +++ b/crates/migrator/src/migrations/m_2025_06_16/settings.rs @@ -40,20 +40,20 @@ fn migrate_context_server_settings( // Parse the server settings to check what keys it contains let mut cursor = server_settings.walk(); for child in server_settings.children(&mut cursor) { - if child.kind() == "pair" { - if let Some(key_node) = child.child_by_field_name("key") { - if let (None, Some(quote_content)) = (column, key_node.child(0)) { - column = Some(quote_content.start_position().column); - } - if let Some(string_content) = key_node.child(1) { - let key = 
&contents[string_content.byte_range()]; - match key { - // If it already has a source key, don't modify it - "source" => return None, - "command" => has_command = true, - "settings" => has_settings = true, - _ => other_keys += 1, - } + if child.kind() == "pair" + && let Some(key_node) = child.child_by_field_name("key") + { + if let (None, Some(quote_content)) = (column, key_node.child(0)) { + column = Some(quote_content.start_position().column); + } + if let Some(string_content) = key_node.child(1) { + let key = &contents[string_content.byte_range()]; + match key { + // If it already has a source key, don't modify it + "source" => return None, + "command" => has_command = true, + "settings" => has_settings = true, + _ => other_keys += 1, } } } diff --git a/crates/migrator/src/migrations/m_2025_06_25/settings.rs b/crates/migrator/src/migrations/m_2025_06_25/settings.rs index 5dd6c3093a43b00acff3db6c1e316a3fc6664175..2bf7658eeb9036c0b1d08d2af446c0aba788d402 100644 --- a/crates/migrator/src/migrations/m_2025_06_25/settings.rs +++ b/crates/migrator/src/migrations/m_2025_06_25/settings.rs @@ -84,10 +84,10 @@ fn remove_pair_with_whitespace( } } else { // If no next sibling, check if there's a comma before - if let Some(prev_sibling) = pair_node.prev_sibling() { - if prev_sibling.kind() == "," { - range_to_remove.start = prev_sibling.start_byte(); - } + if let Some(prev_sibling) = pair_node.prev_sibling() + && prev_sibling.kind() == "," + { + range_to_remove.start = prev_sibling.start_byte(); } } @@ -123,10 +123,10 @@ fn remove_pair_with_whitespace( // Also check if we need to include trailing whitespace up to the next line let text_after = &contents[range_to_remove.end..]; - if let Some(newline_pos) = text_after.find('\n') { - if text_after[..newline_pos].chars().all(|c| c.is_whitespace()) { - range_to_remove.end += newline_pos + 1; - } + if let Some(newline_pos) = text_after.find('\n') + && text_after[..newline_pos].chars().all(|c| c.is_whitespace()) + { + 
range_to_remove.end += newline_pos + 1; } Some((range_to_remove, String::new())) diff --git a/crates/migrator/src/migrations/m_2025_06_27/settings.rs b/crates/migrator/src/migrations/m_2025_06_27/settings.rs index 6156308fcec05dfb10b5b258d31077e5d4b09adc..e3e951b1a69e39d19e93a152a264750caf51a81e 100644 --- a/crates/migrator/src/migrations/m_2025_06_27/settings.rs +++ b/crates/migrator/src/migrations/m_2025_06_27/settings.rs @@ -56,19 +56,18 @@ fn flatten_context_server_command( let mut cursor = command_object.walk(); for child in command_object.children(&mut cursor) { - if child.kind() == "pair" { - if let Some(key_node) = child.child_by_field_name("key") { - if let Some(string_content) = key_node.child(1) { - let key = &contents[string_content.byte_range()]; - if let Some(value_node) = child.child_by_field_name("value") { - let value_range = value_node.byte_range(); - match key { - "path" => path_value = Some(&contents[value_range]), - "args" => args_value = Some(&contents[value_range]), - "env" => env_value = Some(&contents[value_range]), - _ => {} - } - } + if child.kind() == "pair" + && let Some(key_node) = child.child_by_field_name("key") + && let Some(string_content) = key_node.child(1) + { + let key = &contents[string_content.byte_range()]; + if let Some(value_node) = child.child_by_field_name("value") { + let value_range = value_node.byte_range(); + match key { + "path" => path_value = Some(&contents[value_range]), + "args" => args_value = Some(&contents[value_range]), + "env" => env_value = Some(&contents[value_range]), + _ => {} } } } diff --git a/crates/migrator/src/migrator.rs b/crates/migrator/src/migrator.rs index b425f7f1d5dc691ed1501d712ab72556412f7eb6..2180a049d03daf5fcd2a60e1f1f7ddd0013c7d1f 100644 --- a/crates/migrator/src/migrator.rs +++ b/crates/migrator/src/migrator.rs @@ -28,7 +28,7 @@ fn migrate(text: &str, patterns: MigrationPatterns, query: &Query) -> Result Result Result> { pub fn migrate_edit_prediction_provider_settings(text: &str) -> 
Result> { migrate( - &text, + text, &[( SETTINGS_NESTED_KEY_VALUE_PATTERN, migrations::m_2025_01_29::replace_edit_prediction_provider_setting, @@ -293,12 +293,12 @@ mod tests { use super::*; fn assert_migrate_keymap(input: &str, output: Option<&str>) { - let migrated = migrate_keymap(&input).unwrap(); + let migrated = migrate_keymap(input).unwrap(); pretty_assertions::assert_eq!(migrated.as_deref(), output); } fn assert_migrate_settings(input: &str, output: Option<&str>) { - let migrated = migrate_settings(&input).unwrap(); + let migrated = migrate_settings(input).unwrap(); pretty_assertions::assert_eq!(migrated.as_deref(), output); } diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs index 5b4d05377c7132f47828aa6afafbb5c850e940a8..d6f62cfaa07bc211881817e6178a8673a9a670a6 100644 --- a/crates/mistral/src/mistral.rs +++ b/crates/mistral/src/mistral.rs @@ -286,12 +286,13 @@ pub enum Prediction { } #[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] +#[serde(rename_all = "lowercase")] pub enum ToolChoice { Auto, Required, None, Any, + #[serde(untagged)] Function(ToolDefinition), } @@ -482,7 +483,7 @@ pub async fn stream_completion( .method(Method::POST) .uri(uri) .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)); + .header("Authorization", format!("Bearer {}", api_key.trim())); let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?; diff --git a/crates/multi_buffer/src/anchor.rs b/crates/multi_buffer/src/anchor.rs index 1305328d384023517dbb80d25e210b44e632eed8..6bed0a4028c5c4b0816355a397046560fb6b8618 100644 --- a/crates/multi_buffer/src/anchor.rs +++ b/crates/multi_buffer/src/anchor.rs @@ -76,27 +76,26 @@ impl Anchor { if text_cmp.is_ne() { return text_cmp; } - if self.diff_base_anchor.is_some() || other.diff_base_anchor.is_some() { - if let Some(base_text) = snapshot + if 
(self.diff_base_anchor.is_some() || other.diff_base_anchor.is_some()) + && let Some(base_text) = snapshot .diffs .get(&excerpt.buffer_id) .map(|diff| diff.base_text()) - { - let self_anchor = self.diff_base_anchor.filter(|a| base_text.can_resolve(a)); - let other_anchor = other.diff_base_anchor.filter(|a| base_text.can_resolve(a)); - return match (self_anchor, other_anchor) { - (Some(a), Some(b)) => a.cmp(&b, base_text), - (Some(_), None) => match other.text_anchor.bias { - Bias::Left => Ordering::Greater, - Bias::Right => Ordering::Less, - }, - (None, Some(_)) => match self.text_anchor.bias { - Bias::Left => Ordering::Less, - Bias::Right => Ordering::Greater, - }, - (None, None) => Ordering::Equal, - }; - } + { + let self_anchor = self.diff_base_anchor.filter(|a| base_text.can_resolve(a)); + let other_anchor = other.diff_base_anchor.filter(|a| base_text.can_resolve(a)); + return match (self_anchor, other_anchor) { + (Some(a), Some(b)) => a.cmp(&b, base_text), + (Some(_), None) => match other.text_anchor.bias { + Bias::Left => Ordering::Greater, + Bias::Right => Ordering::Less, + }, + (None, Some(_)) => match self.text_anchor.bias { + Bias::Left => Ordering::Less, + Bias::Right => Ordering::Greater, + }, + (None, None) => Ordering::Equal, + }; } } Ordering::Equal @@ -107,51 +106,49 @@ impl Anchor { } pub fn bias_left(&self, snapshot: &MultiBufferSnapshot) -> Anchor { - if self.text_anchor.bias != Bias::Left { - if let Some(excerpt) = snapshot.excerpt(self.excerpt_id) { - return Self { - buffer_id: self.buffer_id, - excerpt_id: self.excerpt_id, - text_anchor: self.text_anchor.bias_left(&excerpt.buffer), - diff_base_anchor: self.diff_base_anchor.map(|a| { - if let Some(base_text) = snapshot - .diffs - .get(&excerpt.buffer_id) - .map(|diff| diff.base_text()) - { - if a.buffer_id == Some(base_text.remote_id()) { - return a.bias_left(base_text); - } - } - a - }), - }; - } + if self.text_anchor.bias != Bias::Left + && let Some(excerpt) = snapshot.excerpt(self.excerpt_id) 
+ { + return Self { + buffer_id: self.buffer_id, + excerpt_id: self.excerpt_id, + text_anchor: self.text_anchor.bias_left(&excerpt.buffer), + diff_base_anchor: self.diff_base_anchor.map(|a| { + if let Some(base_text) = snapshot + .diffs + .get(&excerpt.buffer_id) + .map(|diff| diff.base_text()) + && a.buffer_id == Some(base_text.remote_id()) + { + return a.bias_left(base_text); + } + a + }), + }; } *self } pub fn bias_right(&self, snapshot: &MultiBufferSnapshot) -> Anchor { - if self.text_anchor.bias != Bias::Right { - if let Some(excerpt) = snapshot.excerpt(self.excerpt_id) { - return Self { - buffer_id: self.buffer_id, - excerpt_id: self.excerpt_id, - text_anchor: self.text_anchor.bias_right(&excerpt.buffer), - diff_base_anchor: self.diff_base_anchor.map(|a| { - if let Some(base_text) = snapshot - .diffs - .get(&excerpt.buffer_id) - .map(|diff| diff.base_text()) - { - if a.buffer_id == Some(base_text.remote_id()) { - return a.bias_right(&base_text); - } - } - a - }), - }; - } + if self.text_anchor.bias != Bias::Right + && let Some(excerpt) = snapshot.excerpt(self.excerpt_id) + { + return Self { + buffer_id: self.buffer_id, + excerpt_id: self.excerpt_id, + text_anchor: self.text_anchor.bias_right(&excerpt.buffer), + diff_base_anchor: self.diff_base_anchor.map(|a| { + if let Some(base_text) = snapshot + .diffs + .get(&excerpt.buffer_id) + .map(|diff| diff.base_text()) + && a.buffer_id == Some(base_text.remote_id()) + { + return a.bias_right(base_text); + } + a + }), + }; } *self } @@ -212,7 +209,7 @@ impl AnchorRangeExt for Range { } fn includes(&self, other: &Range, buffer: &MultiBufferSnapshot) -> bool { - self.start.cmp(&other.start, &buffer).is_le() && other.end.cmp(&self.end, &buffer).is_le() + self.start.cmp(&other.start, buffer).is_le() && other.end.cmp(&self.end, buffer).is_le() } fn overlaps(&self, other: &Range, buffer: &MultiBufferSnapshot) -> bool { diff --git a/crates/multi_buffer/src/multi_buffer.rs b/crates/multi_buffer/src/multi_buffer.rs index 
eb12e6929cbc4bf74f44a2cb6eb9970c825d0fe3..8fa8c2c08c25aa6f003365556594ed0719f9861e 100644 --- a/crates/multi_buffer/src/multi_buffer.rs +++ b/crates/multi_buffer/src/multi_buffer.rs @@ -113,15 +113,10 @@ pub enum Event { transaction_id: TransactionId, }, Reloaded, - ReloadNeeded, - LanguageChanged(BufferId), - CapabilityChanged, Reparsed(BufferId), Saved, FileHandleChanged, - Closed, - Discarded, DirtyChanged, DiagnosticsUpdated, BufferDiffChanged, @@ -735,7 +730,7 @@ impl MultiBuffer { pub fn as_singleton(&self) -> Option> { if self.singleton { - return Some( + Some( self.buffers .borrow() .values() @@ -743,7 +738,7 @@ impl MultiBuffer { .unwrap() .buffer .clone(), - ); + ) } else { None } @@ -835,7 +830,7 @@ impl MultiBuffer { this.convert_edits_to_buffer_edits(edits, &snapshot, &original_indent_columns); drop(snapshot); - let mut buffer_ids = Vec::new(); + let mut buffer_ids = Vec::with_capacity(buffer_edits.len()); for (buffer_id, mut edits) in buffer_edits { buffer_ids.push(buffer_id); edits.sort_by_key(|edit| edit.range.start); @@ -1082,11 +1077,11 @@ impl MultiBuffer { let mut ranges: Vec> = Vec::new(); for edit in edits { - if let Some(last_range) = ranges.last_mut() { - if edit.range.start <= last_range.end { - last_range.end = last_range.end.max(edit.range.end); - continue; - } + if let Some(last_range) = ranges.last_mut() + && edit.range.start <= last_range.end + { + last_range.end = last_range.end.max(edit.range.end); + continue; } ranges.push(edit.range); } @@ -1146,13 +1141,13 @@ impl MultiBuffer { pub fn last_transaction_id(&self, cx: &App) -> Option { if let Some(buffer) = self.as_singleton() { - return buffer + buffer .read(cx) .peek_undo_stack() - .map(|history_entry| history_entry.transaction_id()); + .map(|history_entry| history_entry.transaction_id()) } else { let last_transaction = self.history.undo_stack.last()?; - return Some(last_transaction.id); + Some(last_transaction.id) } } @@ -1212,25 +1207,24 @@ impl MultiBuffer { for range in 
buffer.edited_ranges_for_transaction_id::(*buffer_transaction) { for excerpt_id in &buffer_state.excerpts { cursor.seek(excerpt_id, Bias::Left); - if let Some(excerpt) = cursor.item() { - if excerpt.locator == *excerpt_id { - let excerpt_buffer_start = - excerpt.range.context.start.summary::(buffer); - let excerpt_buffer_end = excerpt.range.context.end.summary::(buffer); - let excerpt_range = excerpt_buffer_start..excerpt_buffer_end; - if excerpt_range.contains(&range.start) - && excerpt_range.contains(&range.end) - { - let excerpt_start = D::from_text_summary(&cursor.start().text); + if let Some(excerpt) = cursor.item() + && excerpt.locator == *excerpt_id + { + let excerpt_buffer_start = excerpt.range.context.start.summary::(buffer); + let excerpt_buffer_end = excerpt.range.context.end.summary::(buffer); + let excerpt_range = excerpt_buffer_start..excerpt_buffer_end; + if excerpt_range.contains(&range.start) + && excerpt_range.contains(&range.end) + { + let excerpt_start = D::from_text_summary(&cursor.start().text); - let mut start = excerpt_start; - start.add_assign(&(range.start - excerpt_buffer_start)); - let mut end = excerpt_start; - end.add_assign(&(range.end - excerpt_buffer_start)); + let mut start = excerpt_start; + start.add_assign(&(range.start - excerpt_buffer_start)); + let mut end = excerpt_start; + end.add_assign(&(range.end - excerpt_buffer_start)); - ranges.push(start..end); - break; - } + ranges.push(start..end); + break; } } } @@ -1251,25 +1245,25 @@ impl MultiBuffer { buffer.update(cx, |buffer, _| { buffer.merge_transactions(transaction, destination) }); - } else if let Some(transaction) = self.history.forget(transaction) { - if let Some(destination) = self.history.transaction_mut(destination) { - for (buffer_id, buffer_transaction_id) in transaction.buffer_transactions { - if let Some(destination_buffer_transaction_id) = - destination.buffer_transactions.get(&buffer_id) - { - if let Some(state) = self.buffers.borrow().get(&buffer_id) { - 
state.buffer.update(cx, |buffer, _| { - buffer.merge_transactions( - buffer_transaction_id, - *destination_buffer_transaction_id, - ) - }); - } - } else { - destination - .buffer_transactions - .insert(buffer_id, buffer_transaction_id); + } else if let Some(transaction) = self.history.forget(transaction) + && let Some(destination) = self.history.transaction_mut(destination) + { + for (buffer_id, buffer_transaction_id) in transaction.buffer_transactions { + if let Some(destination_buffer_transaction_id) = + destination.buffer_transactions.get(&buffer_id) + { + if let Some(state) = self.buffers.borrow().get(&buffer_id) { + state.buffer.update(cx, |buffer, _| { + buffer.merge_transactions( + buffer_transaction_id, + *destination_buffer_transaction_id, + ) + }); } + } else { + destination + .buffer_transactions + .insert(buffer_id, buffer_transaction_id); } } } @@ -1562,11 +1556,11 @@ impl MultiBuffer { }); let mut merged_ranges: Vec> = Vec::new(); for range in expanded_ranges { - if let Some(last_range) = merged_ranges.last_mut() { - if last_range.context.end >= range.context.start { - last_range.context.end = range.context.end; - continue; - } + if let Some(last_range) = merged_ranges.last_mut() + && last_range.context.end >= range.context.start + { + last_range.context.end = range.context.end; + continue; } merged_ranges.push(range) } @@ -1686,7 +1680,7 @@ impl MultiBuffer { cx: &mut Context, ) -> (Vec>, bool) { let (excerpt_ids, added_a_new_excerpt) = - self.update_path_excerpts(path, buffer, &buffer_snapshot, new, cx); + self.update_path_excerpts(path, buffer, buffer_snapshot, new, cx); let mut result = Vec::new(); let mut ranges = ranges.into_iter(); @@ -1726,7 +1720,7 @@ impl MultiBuffer { merged_ranges.push(range.clone()); counts.push(1); } - return (merged_ranges, counts); + (merged_ranges, counts) } fn update_path_excerpts( @@ -1784,7 +1778,7 @@ impl MultiBuffer { } Some(( *existing_id, - excerpt.range.context.to_point(&buffer_snapshot), + 
excerpt.range.context.to_point(buffer_snapshot), )) } else { None @@ -1794,25 +1788,25 @@ impl MultiBuffer { }; if let Some((last_id, last)) = to_insert.last_mut() { - if let Some(new) = new { - if last.context.end >= new.context.start { - last.context.end = last.context.end.max(new.context.end); - excerpt_ids.push(*last_id); - new_iter.next(); - continue; - } + if let Some(new) = new + && last.context.end >= new.context.start + { + last.context.end = last.context.end.max(new.context.end); + excerpt_ids.push(*last_id); + new_iter.next(); + continue; } - if let Some((existing_id, existing_range)) = &existing { - if last.context.end >= existing_range.start { - last.context.end = last.context.end.max(existing_range.end); - to_remove.push(*existing_id); - self.snapshot - .borrow_mut() - .replaced_excerpts - .insert(*existing_id, *last_id); - existing_iter.next(); - continue; - } + if let Some((existing_id, existing_range)) = &existing + && last.context.end >= existing_range.start + { + last.context.end = last.context.end.max(existing_range.end); + to_remove.push(*existing_id); + self.snapshot + .borrow_mut() + .replaced_excerpts + .insert(*existing_id, *last_id); + existing_iter.next(); + continue; } } @@ -2105,10 +2099,10 @@ impl MultiBuffer { .flatten() { cursor.seek_forward(&Some(locator), Bias::Left); - if let Some(excerpt) = cursor.item() { - if excerpt.locator == *locator { - excerpts.push((excerpt.id, excerpt.range.clone())); - } + if let Some(excerpt) = cursor.item() + && excerpt.locator == *locator + { + excerpts.push((excerpt.id, excerpt.range.clone())); } } @@ -2132,22 +2126,21 @@ impl MultiBuffer { let mut result = Vec::new(); for locator in locators { excerpts.seek_forward(&Some(locator), Bias::Left); - if let Some(excerpt) = excerpts.item() { - if excerpt.locator == *locator { - let excerpt_start = excerpts.start().1.clone(); - let excerpt_end = - ExcerptDimension(excerpt_start.0 + excerpt.text_summary.lines); + if let Some(excerpt) = excerpts.item() + && 
excerpt.locator == *locator + { + let excerpt_start = excerpts.start().1.clone(); + let excerpt_end = ExcerptDimension(excerpt_start.0 + excerpt.text_summary.lines); - diff_transforms.seek_forward(&excerpt_start, Bias::Left); - let overshoot = excerpt_start.0 - diff_transforms.start().0.0; - let start = diff_transforms.start().1.0 + overshoot; + diff_transforms.seek_forward(&excerpt_start, Bias::Left); + let overshoot = excerpt_start.0 - diff_transforms.start().0.0; + let start = diff_transforms.start().1.0 + overshoot; - diff_transforms.seek_forward(&excerpt_end, Bias::Right); - let overshoot = excerpt_end.0 - diff_transforms.start().0.0; - let end = diff_transforms.start().1.0 + overshoot; + diff_transforms.seek_forward(&excerpt_end, Bias::Right); + let overshoot = excerpt_end.0 - diff_transforms.start().0.0; + let end = diff_transforms.start().1.0 + overshoot; - result.push(start..end) - } + result.push(start..end) } } result @@ -2198,6 +2191,15 @@ impl MultiBuffer { }) } + pub fn buffer_for_anchor(&self, anchor: Anchor, cx: &App) -> Option> { + if let Some(buffer_id) = anchor.buffer_id { + self.buffer(buffer_id) + } else { + let (_, buffer, _) = self.excerpt_containing(anchor, cx)?; + Some(buffer) + } + } + // If point is at the end of the buffer, the last excerpt is returned pub fn point_to_buffer_offset( &self, @@ -2316,12 +2318,12 @@ impl MultiBuffer { // Skip over any subsequent excerpts that are also removed. 
if let Some(&next_excerpt_id) = excerpt_ids.peek() { let next_locator = snapshot.excerpt_locator_for_id(next_excerpt_id); - if let Some(next_excerpt) = cursor.item() { - if next_excerpt.locator == *next_locator { - excerpt_ids.next(); - excerpt = next_excerpt; - continue 'remove_excerpts; - } + if let Some(next_excerpt) = cursor.item() + && next_excerpt.locator == *next_locator + { + excerpt_ids.next(); + excerpt = next_excerpt; + continue 'remove_excerpts; } } @@ -2426,28 +2428,24 @@ impl MultiBuffer { event: &language::BufferEvent, cx: &mut Context, ) { + use language::BufferEvent; cx.emit(match event { - language::BufferEvent::Edited => Event::Edited { + BufferEvent::Edited => Event::Edited { singleton_buffer_edited: true, - edited_buffer: Some(buffer.clone()), + edited_buffer: Some(buffer), }, - language::BufferEvent::DirtyChanged => Event::DirtyChanged, - language::BufferEvent::Saved => Event::Saved, - language::BufferEvent::FileHandleChanged => Event::FileHandleChanged, - language::BufferEvent::Reloaded => Event::Reloaded, - language::BufferEvent::ReloadNeeded => Event::ReloadNeeded, - language::BufferEvent::LanguageChanged => { - Event::LanguageChanged(buffer.read(cx).remote_id()) - } - language::BufferEvent::Reparsed => Event::Reparsed(buffer.read(cx).remote_id()), - language::BufferEvent::DiagnosticsUpdated => Event::DiagnosticsUpdated, - language::BufferEvent::Closed => Event::Closed, - language::BufferEvent::Discarded => Event::Discarded, - language::BufferEvent::CapabilityChanged => { + BufferEvent::DirtyChanged => Event::DirtyChanged, + BufferEvent::Saved => Event::Saved, + BufferEvent::FileHandleChanged => Event::FileHandleChanged, + BufferEvent::Reloaded => Event::Reloaded, + BufferEvent::LanguageChanged => Event::LanguageChanged(buffer.read(cx).remote_id()), + BufferEvent::Reparsed => Event::Reparsed(buffer.read(cx).remote_id()), + BufferEvent::DiagnosticsUpdated => Event::DiagnosticsUpdated, + BufferEvent::CapabilityChanged => { self.capability = 
buffer.read(cx).capability(); - Event::CapabilityChanged + return; } - language::BufferEvent::Operation { .. } => return, + BufferEvent::Operation { .. } | BufferEvent::ReloadNeeded => return, }); } @@ -2484,7 +2482,7 @@ impl MultiBuffer { let base_text_changed = snapshot .diffs .get(&buffer_id) - .map_or(true, |old_diff| !new_diff.base_texts_eq(old_diff)); + .is_none_or(|old_diff| !new_diff.base_texts_eq(old_diff)); snapshot.diffs.insert(buffer_id, new_diff); @@ -2494,33 +2492,33 @@ impl MultiBuffer { .excerpts .cursor::, ExcerptOffset>>(&()); cursor.seek_forward(&Some(locator), Bias::Left); - if let Some(excerpt) = cursor.item() { - if excerpt.locator == *locator { - let excerpt_buffer_range = excerpt.range.context.to_offset(&excerpt.buffer); - if diff_change_range.end < excerpt_buffer_range.start - || diff_change_range.start > excerpt_buffer_range.end - { - continue; - } - let excerpt_start = cursor.start().1; - let excerpt_len = ExcerptOffset::new(excerpt.text_summary.len); - let diff_change_start_in_excerpt = ExcerptOffset::new( - diff_change_range - .start - .saturating_sub(excerpt_buffer_range.start), - ); - let diff_change_end_in_excerpt = ExcerptOffset::new( - diff_change_range - .end - .saturating_sub(excerpt_buffer_range.start), - ); - let edit_start = excerpt_start + diff_change_start_in_excerpt.min(excerpt_len); - let edit_end = excerpt_start + diff_change_end_in_excerpt.min(excerpt_len); - excerpt_edits.push(Edit { - old: edit_start..edit_end, - new: edit_start..edit_end, - }); + if let Some(excerpt) = cursor.item() + && excerpt.locator == *locator + { + let excerpt_buffer_range = excerpt.range.context.to_offset(&excerpt.buffer); + if diff_change_range.end < excerpt_buffer_range.start + || diff_change_range.start > excerpt_buffer_range.end + { + continue; } + let excerpt_start = cursor.start().1; + let excerpt_len = ExcerptOffset::new(excerpt.text_summary.len); + let diff_change_start_in_excerpt = ExcerptOffset::new( + diff_change_range + .start + 
.saturating_sub(excerpt_buffer_range.start), + ); + let diff_change_end_in_excerpt = ExcerptOffset::new( + diff_change_range + .end + .saturating_sub(excerpt_buffer_range.start), + ); + let edit_start = excerpt_start + diff_change_start_in_excerpt.min(excerpt_len); + let edit_end = excerpt_start + diff_change_end_in_excerpt.min(excerpt_len); + excerpt_edits.push(Edit { + old: edit_start..edit_end, + new: edit_start..edit_end, + }); } } @@ -2545,6 +2543,10 @@ impl MultiBuffer { .collect() } + pub fn all_buffer_ids(&self) -> Vec { + self.buffers.borrow().keys().copied().collect() + } + pub fn buffer(&self, buffer_id: BufferId) -> Option> { self.buffers .borrow() @@ -2778,7 +2780,7 @@ impl MultiBuffer { if diff_hunk.excerpt_id.cmp(&end_excerpt_id, &snapshot).is_gt() { continue; } - if last_hunk_row.map_or(false, |row| row >= diff_hunk.row_range.start) { + if last_hunk_row.is_some_and(|row| row >= diff_hunk.row_range.start) { continue; } let start = Anchor::in_buffer( @@ -3042,7 +3044,7 @@ impl MultiBuffer { is_dirty |= buffer.is_dirty(); has_deleted_file |= buffer .file() - .map_or(false, |file| file.disk_state() == DiskState::Deleted); + .is_some_and(|file| file.disk_state() == DiskState::Deleted); has_conflict |= buffer.has_conflict(); } if edited { @@ -3056,7 +3058,7 @@ impl MultiBuffer { snapshot.has_conflict = has_conflict; for (id, diff) in self.diffs.iter() { - if snapshot.diffs.get(&id).is_none() { + if snapshot.diffs.get(id).is_none() { snapshot.diffs.insert(*id, diff.diff.read(cx).snapshot(cx)); } } @@ -3155,13 +3157,12 @@ impl MultiBuffer { at_transform_boundary = false; let transforms_before_edit = old_diff_transforms.slice(&edit.old.start, Bias::Left); self.append_diff_transforms(&mut new_diff_transforms, transforms_before_edit); - if let Some(transform) = old_diff_transforms.item() { - if old_diff_transforms.end().0 == edit.old.start - && old_diff_transforms.start().0 < edit.old.start - { - self.push_diff_transform(&mut new_diff_transforms, 
transform.clone()); - old_diff_transforms.next(); - } + if let Some(transform) = old_diff_transforms.item() + && old_diff_transforms.end().0 == edit.old.start + && old_diff_transforms.start().0 < edit.old.start + { + self.push_diff_transform(&mut new_diff_transforms, transform.clone()); + old_diff_transforms.next(); } } @@ -3177,7 +3178,7 @@ impl MultiBuffer { &mut new_diff_transforms, &mut end_of_current_insert, &mut old_expanded_hunks, - &snapshot, + snapshot, change_kind, ); @@ -3201,9 +3202,10 @@ impl MultiBuffer { // If this is the last edit that intersects the current diff transform, // then recreate the content up to the end of this transform, to prepare // for reusing additional slices of the old transforms. - if excerpt_edits.peek().map_or(true, |next_edit| { - next_edit.old.start >= old_diff_transforms.end().0 - }) { + if excerpt_edits + .peek() + .is_none_or(|next_edit| next_edit.old.start >= old_diff_transforms.end().0) + { let keep_next_old_transform = (old_diff_transforms.start().0 >= edit.old.end) && match old_diff_transforms.item() { Some(DiffTransform::BufferContent { @@ -3223,7 +3225,7 @@ impl MultiBuffer { old_expanded_hunks.clear(); self.push_buffer_content_transform( - &snapshot, + snapshot, &mut new_diff_transforms, excerpt_offset, end_of_current_insert, @@ -3431,18 +3433,17 @@ impl MultiBuffer { inserted_hunk_info, summary, }) = subtree.first() - { - if self.extend_last_buffer_content_transform( + && self.extend_last_buffer_content_transform( new_transforms, *inserted_hunk_info, *summary, - ) { - let mut cursor = subtree.cursor::<()>(&()); - cursor.next(); - cursor.next(); - new_transforms.append(cursor.suffix(), &()); - return; - } + ) + { + let mut cursor = subtree.cursor::<()>(&()); + cursor.next(); + cursor.next(); + new_transforms.append(cursor.suffix(), &()); + return; } new_transforms.append(subtree, &()); } @@ -3456,14 +3457,13 @@ impl MultiBuffer { inserted_hunk_info: inserted_hunk_anchor, summary, } = transform - { - if 
self.extend_last_buffer_content_transform( + && self.extend_last_buffer_content_transform( new_transforms, inserted_hunk_anchor, summary, - ) { - return; - } + ) + { + return; } new_transforms.push(transform, &()); } @@ -3518,11 +3518,10 @@ impl MultiBuffer { summary, inserted_hunk_info: inserted_hunk_anchor, } = last_transform + && *inserted_hunk_anchor == new_inserted_hunk_info { - if *inserted_hunk_anchor == new_inserted_hunk_info { - *summary += summary_to_add; - did_extend = true; - } + *summary += summary_to_add; + did_extend = true; } }, &(), @@ -3565,9 +3564,7 @@ impl MultiBuffer { let multi = cx.new(|_| Self::new(Capability::ReadWrite)); for (text, ranges) in excerpts { let buffer = cx.new(|cx| Buffer::local(text, cx)); - let excerpt_ranges = ranges - .into_iter() - .map(|range| ExcerptRange::new(range.clone())); + let excerpt_ranges = ranges.into_iter().map(ExcerptRange::new); multi.update(cx, |multi, cx| { multi.push_excerpts(buffer, excerpt_ranges, cx) }); @@ -3583,7 +3580,7 @@ impl MultiBuffer { pub fn build_random(rng: &mut impl rand::Rng, cx: &mut gpui::App) -> Entity { cx.new(|cx| { let mut multibuffer = MultiBuffer::new(Capability::ReadWrite); - let mutation_count = rng.gen_range(1..=5); + let mutation_count = rng.random_range(1..=5); multibuffer.randomly_edit_excerpts(rng, mutation_count, cx); multibuffer }) @@ -3601,21 +3598,22 @@ impl MultiBuffer { let mut edits: Vec<(Range, Arc)> = Vec::new(); let mut last_end = None; for _ in 0..edit_count { - if last_end.map_or(false, |last_end| last_end >= snapshot.len()) { + if last_end.is_some_and(|last_end| last_end >= snapshot.len()) { break; } let new_start = last_end.map_or(0, |last_end| last_end + 1); - let end = snapshot.clip_offset(rng.gen_range(new_start..=snapshot.len()), Bias::Right); - let start = snapshot.clip_offset(rng.gen_range(new_start..=end), Bias::Right); + let end = + snapshot.clip_offset(rng.random_range(new_start..=snapshot.len()), Bias::Right); + let start = 
snapshot.clip_offset(rng.random_range(new_start..=end), Bias::Right); last_end = Some(end); let mut range = start..end; - if rng.gen_bool(0.2) { + if rng.random_bool(0.2) { mem::swap(&mut range.start, &mut range.end); } - let new_text_len = rng.gen_range(0..10); + let new_text_len = rng.random_range(0..10); let new_text: String = RandomCharIter::new(&mut *rng).take(new_text_len).collect(); edits.push((range, new_text.into())); @@ -3642,18 +3640,18 @@ impl MultiBuffer { let mut buffers = Vec::new(); for _ in 0..mutation_count { - if rng.gen_bool(0.05) { + if rng.random_bool(0.05) { log::info!("Clearing multi-buffer"); self.clear(cx); continue; - } else if rng.gen_bool(0.1) && !self.excerpt_ids().is_empty() { + } else if rng.random_bool(0.1) && !self.excerpt_ids().is_empty() { let ids = self.excerpt_ids(); let mut excerpts = HashSet::default(); - for _ in 0..rng.gen_range(0..ids.len()) { + for _ in 0..rng.random_range(0..ids.len()) { excerpts.extend(ids.choose(rng).copied()); } - let line_count = rng.gen_range(0..5); + let line_count = rng.random_range(0..5); log::info!("Expanding excerpts {excerpts:?} by {line_count} lines"); @@ -3667,8 +3665,8 @@ impl MultiBuffer { } let excerpt_ids = self.excerpt_ids(); - if excerpt_ids.is_empty() || (rng.r#gen() && excerpt_ids.len() < max_excerpts) { - let buffer_handle = if rng.r#gen() || self.buffers.borrow().is_empty() { + if excerpt_ids.is_empty() || (rng.random() && excerpt_ids.len() < max_excerpts) { + let buffer_handle = if rng.random() || self.buffers.borrow().is_empty() { let text = RandomCharIter::new(&mut *rng).take(10).collect::(); buffers.push(cx.new(|cx| Buffer::local(text, cx))); let buffer = buffers.last().unwrap().read(cx); @@ -3690,11 +3688,11 @@ impl MultiBuffer { let buffer = buffer_handle.read(cx); let buffer_text = buffer.text(); - let ranges = (0..rng.gen_range(0..5)) + let ranges = (0..rng.random_range(0..5)) .map(|_| { let end_ix = - buffer.clip_offset(rng.gen_range(0..=buffer.len()), Bias::Right); - let 
start_ix = buffer.clip_offset(rng.gen_range(0..=end_ix), Bias::Left); + buffer.clip_offset(rng.random_range(0..=buffer.len()), Bias::Right); + let start_ix = buffer.clip_offset(rng.random_range(0..=end_ix), Bias::Left); ExcerptRange::new(start_ix..end_ix) }) .collect::>(); @@ -3711,7 +3709,7 @@ impl MultiBuffer { let excerpt_id = self.push_excerpts(buffer_handle.clone(), ranges, cx); log::info!("Inserted with ids: {:?}", excerpt_id); } else { - let remove_count = rng.gen_range(1..=excerpt_ids.len()); + let remove_count = rng.random_range(1..=excerpt_ids.len()); let mut excerpts_to_remove = excerpt_ids .choose_multiple(rng, remove_count) .cloned() @@ -3733,7 +3731,7 @@ impl MultiBuffer { ) { use rand::prelude::*; - if rng.gen_bool(0.7) || self.singleton { + if rng.random_bool(0.7) || self.singleton { let buffer = self .buffers .borrow() @@ -3743,7 +3741,7 @@ impl MultiBuffer { if let Some(buffer) = buffer { buffer.update(cx, |buffer, cx| { - if rng.r#gen() { + if rng.random() { buffer.randomly_edit(rng, mutation_count, cx); } else { buffer.randomly_undo_redo(rng, cx); @@ -3916,8 +3914,8 @@ impl MultiBufferSnapshot { &self, range: Range, ) -> Vec<(&BufferSnapshot, Range, ExcerptId)> { - let start = range.start.to_offset(&self); - let end = range.end.to_offset(&self); + let start = range.start.to_offset(self); + let end = range.end.to_offset(self); let mut cursor = self.cursor::(); cursor.seek(&start); @@ -3955,8 +3953,8 @@ impl MultiBufferSnapshot { &self, range: Range, ) -> impl Iterator, ExcerptId, Option)> + '_ { - let start = range.start.to_offset(&self); - let end = range.end.to_offset(&self); + let start = range.start.to_offset(self); + let end = range.end.to_offset(self); let mut cursor = self.cursor::(); cursor.seek(&start); @@ -4037,10 +4035,10 @@ impl MultiBufferSnapshot { cursor.seek(&query_range.start); - if let Some(region) = cursor.region().filter(|region| !region.is_main_buffer) { - if region.range.start > D::zero(&()) { - cursor.prev() - } + if let 
Some(region) = cursor.region().filter(|region| !region.is_main_buffer) + && region.range.start > D::zero(&()) + { + cursor.prev() } iter::from_fn(move || { @@ -4070,19 +4068,15 @@ impl MultiBufferSnapshot { buffer_start = cursor.main_buffer_position()?; }; let mut buffer_end = excerpt.range.context.end.summary::(&excerpt.buffer); - if let Some((end_excerpt_id, end_buffer_offset)) = range_end { - if excerpt.id == end_excerpt_id { - buffer_end = buffer_end.min(end_buffer_offset); - } - } - - if let Some(iterator) = - get_buffer_metadata(&excerpt.buffer, buffer_start..buffer_end) + if let Some((end_excerpt_id, end_buffer_offset)) = range_end + && excerpt.id == end_excerpt_id { - Some(&mut current_excerpt_metadata.insert((excerpt.id, iterator)).1) - } else { - None + buffer_end = buffer_end.min(end_buffer_offset); } + + get_buffer_metadata(&excerpt.buffer, buffer_start..buffer_end).map(|iterator| { + &mut current_excerpt_metadata.insert((excerpt.id, iterator)).1 + }) }; // Visit each metadata item. @@ -4144,10 +4138,10 @@ impl MultiBufferSnapshot { // When there are no more metadata items for this excerpt, move to the next excerpt. 
else { current_excerpt_metadata.take(); - if let Some((end_excerpt_id, _)) = range_end { - if excerpt.id == end_excerpt_id { - return None; - } + if let Some((end_excerpt_id, _)) = range_end + && excerpt.id == end_excerpt_id + { + return None; } cursor.next_excerpt(); } @@ -4186,7 +4180,7 @@ impl MultiBufferSnapshot { } let start = Anchor::in_buffer(excerpt.id, excerpt.buffer_id, hunk.buffer_range.start) - .to_point(&self); + .to_point(self); return Some(MultiBufferRow(start.row)); } } @@ -4204,7 +4198,7 @@ impl MultiBufferSnapshot { continue; }; let start = Anchor::in_buffer(excerpt.id, excerpt.buffer_id, hunk.buffer_range.start) - .to_point(&self); + .to_point(self); return Some(MultiBufferRow(start.row)); } } @@ -4455,7 +4449,7 @@ impl MultiBufferSnapshot { let mut buffer_position = region.buffer_range.start; buffer_position.add_assign(&overshoot); let clipped_buffer_position = - clip_buffer_position(®ion.buffer, buffer_position, bias); + clip_buffer_position(region.buffer, buffer_position, bias); let mut position = region.range.start; position.add_assign(&(clipped_buffer_position - region.buffer_range.start)); position @@ -4485,7 +4479,7 @@ impl MultiBufferSnapshot { let buffer_start_value = region.buffer_range.start.value.unwrap(); let mut buffer_key = buffer_start_key; buffer_key.add_assign(&(key - start_key)); - let buffer_value = convert_buffer_dimension(®ion.buffer, buffer_key); + let buffer_value = convert_buffer_dimension(region.buffer, buffer_key); let mut result = start_value; result.add_assign(&(buffer_value - buffer_start_value)); result @@ -4622,20 +4616,20 @@ impl MultiBufferSnapshot { pub fn indent_and_comment_for_line(&self, row: MultiBufferRow, cx: &App) -> String { let mut indent = self.indent_size_for_line(row).chars().collect::(); - if self.language_settings(cx).extend_comment_on_newline { - if let Some(language_scope) = self.language_scope_at(Point::new(row.0, 0)) { - let delimiters = language_scope.line_comment_prefixes(); - for delimiter 
in delimiters { - if *self - .chars_at(Point::new(row.0, indent.len() as u32)) - .take(delimiter.chars().count()) - .collect::() - .as_str() - == **delimiter - { - indent.push_str(&delimiter); - break; - } + if self.language_settings(cx).extend_comment_on_newline + && let Some(language_scope) = self.language_scope_at(Point::new(row.0, 0)) + { + let delimiters = language_scope.line_comment_prefixes(); + for delimiter in delimiters { + if *self + .chars_at(Point::new(row.0, indent.len() as u32)) + .take(delimiter.chars().count()) + .collect::() + .as_str() + == **delimiter + { + indent.push_str(delimiter); + break; } } } @@ -4655,7 +4649,7 @@ impl MultiBufferSnapshot { return true; } } - return true; + true } pub fn prev_non_blank_row(&self, mut row: MultiBufferRow) -> Option { @@ -4893,25 +4887,22 @@ impl MultiBufferSnapshot { base_text_byte_range, .. }) => { - if let Some(diff_base_anchor) = &anchor.diff_base_anchor { - if let Some(base_text) = + if let Some(diff_base_anchor) = &anchor.diff_base_anchor + && let Some(base_text) = self.diffs.get(buffer_id).map(|diff| diff.base_text()) + && base_text.can_resolve(diff_base_anchor) + { + let base_text_offset = diff_base_anchor.to_offset(base_text); + if base_text_offset >= base_text_byte_range.start + && base_text_offset <= base_text_byte_range.end { - if base_text.can_resolve(&diff_base_anchor) { - let base_text_offset = diff_base_anchor.to_offset(&base_text); - if base_text_offset >= base_text_byte_range.start - && base_text_offset <= base_text_byte_range.end - { - let position_in_hunk = base_text - .text_summary_for_range::( - base_text_byte_range.start..base_text_offset, - ); - position.add_assign(&position_in_hunk); - } else if at_transform_end { - diff_transforms.next(); - continue; - } - } + let position_in_hunk = base_text.text_summary_for_range::( + base_text_byte_range.start..base_text_offset, + ); + position.add_assign(&position_in_hunk); + } else if at_transform_end { + diff_transforms.next(); + continue; } 
} } @@ -4941,20 +4932,19 @@ impl MultiBufferSnapshot { } let mut position = cursor.start().1; - if let Some(excerpt) = cursor.item() { - if excerpt.id == anchor.excerpt_id { - let excerpt_buffer_start = excerpt - .buffer - .offset_for_anchor(&excerpt.range.context.start); - let excerpt_buffer_end = - excerpt.buffer.offset_for_anchor(&excerpt.range.context.end); - let buffer_position = cmp::min( - excerpt_buffer_end, - excerpt.buffer.offset_for_anchor(&anchor.text_anchor), - ); - if buffer_position > excerpt_buffer_start { - position.value += buffer_position - excerpt_buffer_start; - } + if let Some(excerpt) = cursor.item() + && excerpt.id == anchor.excerpt_id + { + let excerpt_buffer_start = excerpt + .buffer + .offset_for_anchor(&excerpt.range.context.start); + let excerpt_buffer_end = excerpt.buffer.offset_for_anchor(&excerpt.range.context.end); + let buffer_position = cmp::min( + excerpt_buffer_end, + excerpt.buffer.offset_for_anchor(&anchor.text_anchor), + ); + if buffer_position > excerpt_buffer_start { + position.value += buffer_position - excerpt_buffer_start; } } position @@ -4964,7 +4954,7 @@ impl MultiBufferSnapshot { while let Some(replacement) = self.replaced_excerpts.get(&excerpt_id) { excerpt_id = *replacement; } - return excerpt_id; + excerpt_id } pub fn summaries_for_anchors<'a, D, I>(&'a self, anchors: I) -> Vec @@ -5082,9 +5072,9 @@ impl MultiBufferSnapshot { if point == region.range.end.key && region.has_trailing_newline { position.add_assign(&D::from_text_summary(&TextSummary::newline())); } - return Some(position); + Some(position) } else { - return Some(D::from_text_summary(&self.text_summary())); + Some(D::from_text_summary(&self.text_summary())) } }) } @@ -5124,7 +5114,7 @@ impl MultiBufferSnapshot { // Leave min and max anchors unchanged if invalid or // if the old excerpt still exists at this location let mut kept_position = next_excerpt - .map_or(false, |e| e.id == old_excerpt_id && e.contains(&anchor)) + .is_some_and(|e| e.id == 
old_excerpt_id && e.contains(&anchor)) || old_excerpt_id == ExcerptId::max() || old_excerpt_id == ExcerptId::min(); @@ -5211,15 +5201,12 @@ impl MultiBufferSnapshot { .cursor::>(&()); diff_transforms.seek(&offset, Bias::Right); - if offset == diff_transforms.start().0 && bias == Bias::Left { - if let Some(prev_item) = diff_transforms.prev_item() { - match prev_item { - DiffTransform::DeletedHunk { .. } => { - diff_transforms.prev(); - } - _ => {} - } - } + if offset == diff_transforms.start().0 + && bias == Bias::Left + && let Some(prev_item) = diff_transforms.prev_item() + && let DiffTransform::DeletedHunk { .. } = prev_item + { + diff_transforms.prev(); } let offset_in_transform = offset - diff_transforms.start().0; let mut excerpt_offset = diff_transforms.start().1; @@ -5246,15 +5233,6 @@ impl MultiBufferSnapshot { excerpt_offset += ExcerptOffset::new(offset_in_transform); }; - if let Some((excerpt_id, buffer_id, buffer)) = self.as_singleton() { - return Anchor { - buffer_id: Some(buffer_id), - excerpt_id: *excerpt_id, - text_anchor: buffer.anchor_at(excerpt_offset.value, bias), - diff_base_anchor, - }; - } - let mut excerpts = self .excerpts .cursor::>>(&()); @@ -5278,10 +5256,17 @@ impl MultiBufferSnapshot { text_anchor, diff_base_anchor, } - } else if excerpt_offset.is_zero() && bias == Bias::Left { - Anchor::min() } else { - Anchor::max() + let mut anchor = if excerpt_offset.is_zero() && bias == Bias::Left { + Anchor::min() + } else { + Anchor::max() + }; + // TODO this is a hack, remove it + if let Some((excerpt_id, _, _)) = self.as_singleton() { + anchor.excerpt_id = *excerpt_id; + } + anchor } } @@ -5296,17 +5281,17 @@ impl MultiBufferSnapshot { let locator = self.excerpt_locator_for_id(excerpt_id); let mut cursor = self.excerpts.cursor::>(&()); cursor.seek(locator, Bias::Left); - if let Some(excerpt) = cursor.item() { - if excerpt.id == excerpt_id { - let text_anchor = excerpt.clip_anchor(text_anchor); - drop(cursor); - return Some(Anchor { - buffer_id: 
Some(excerpt.buffer_id), - excerpt_id, - text_anchor, - diff_base_anchor: None, - }); - } + if let Some(excerpt) = cursor.item() + && excerpt.id == excerpt_id + { + let text_anchor = excerpt.clip_anchor(text_anchor); + drop(cursor); + return Some(Anchor { + buffer_id: Some(excerpt.buffer_id), + excerpt_id, + text_anchor, + diff_base_anchor: None, + }); } None } @@ -5491,7 +5476,7 @@ impl MultiBufferSnapshot { let range_filter = |open: Range, close: Range| -> bool { excerpt_buffer_range.contains(&open.start) && excerpt_buffer_range.contains(&close.end) - && range_filter.map_or(true, |filter| filter(buffer, open, close)) + && range_filter.is_none_or(|filter| filter(buffer, open, close)) }; let (open, close) = excerpt.buffer().innermost_enclosing_bracket_ranges( @@ -5651,10 +5636,10 @@ impl MultiBufferSnapshot { .buffer .line_indents_in_row_range(buffer_start_row..buffer_end_row); cursor.next(); - return Some(line_indents.map(move |(buffer_row, indent)| { + Some(line_indents.map(move |(buffer_row, indent)| { let row = region.range.start.row + (buffer_row - region.buffer_range.start.row); (MultiBufferRow(row), indent, ®ion.excerpt.buffer) - })); + })) }) .flatten() } @@ -5691,10 +5676,10 @@ impl MultiBufferSnapshot { .buffer .reversed_line_indents_in_row_range(buffer_start_row..buffer_end_row); cursor.prev(); - return Some(line_indents.map(move |(buffer_row, indent)| { + Some(line_indents.map(move |(buffer_row, indent)| { let row = region.range.start.row + (buffer_row - region.buffer_range.start.row); (MultiBufferRow(row), indent, ®ion.excerpt.buffer) - })); + })) }) .flatten() } @@ -5860,10 +5845,10 @@ impl MultiBufferSnapshot { let current_depth = indent_stack.len() as u32; // Avoid retrieving the language settings repeatedly for every buffer row. 
- if let Some((prev_buffer_id, _)) = &prev_settings { - if prev_buffer_id != &buffer.remote_id() { - prev_settings.take(); - } + if let Some((prev_buffer_id, _)) = &prev_settings + && prev_buffer_id != &buffer.remote_id() + { + prev_settings.take(); } let settings = &prev_settings .get_or_insert_with(|| { @@ -6110,9 +6095,31 @@ impl MultiBufferSnapshot { Some((node, range)) } + pub fn syntax_next_sibling( + &self, + range: Range, + ) -> Option> { + let range = range.start.to_offset(self)..range.end.to_offset(self); + let mut excerpt = self.excerpt_containing(range.clone())?; + excerpt + .buffer() + .syntax_next_sibling(excerpt.map_range_to_buffer(range)) + } + + pub fn syntax_prev_sibling( + &self, + range: Range, + ) -> Option> { + let range = range.start.to_offset(self)..range.end.to_offset(self); + let mut excerpt = self.excerpt_containing(range.clone())?; + excerpt + .buffer() + .syntax_prev_sibling(excerpt.map_range_to_buffer(range)) + } + pub fn outline(&self, theme: Option<&SyntaxTheme>) -> Option> { let (excerpt_id, _, buffer) = self.as_singleton()?; - let outline = buffer.outline(theme)?; + let outline = buffer.outline(theme); Some(Outline::new( outline .items @@ -6157,7 +6164,6 @@ impl MultiBufferSnapshot { .buffer .symbols_containing(anchor.text_anchor, theme) .into_iter() - .flatten() .flat_map(|item| { Some(OutlineItem { depth: item.depth, @@ -6192,10 +6198,10 @@ impl MultiBufferSnapshot { } else { let mut cursor = self.excerpt_ids.cursor::(&()); cursor.seek(&id, Bias::Left); - if let Some(entry) = cursor.item() { - if entry.id == id { - return &entry.locator; - } + if let Some(entry) = cursor.item() + && entry.id == id + { + return &entry.locator; } panic!("invalid excerpt id {id:?}") } @@ -6272,10 +6278,10 @@ impl MultiBufferSnapshot { pub fn buffer_range_for_excerpt(&self, excerpt_id: ExcerptId) -> Option> { let mut cursor = self.excerpts.cursor::>(&()); let locator = self.excerpt_locator_for_id(excerpt_id); - if cursor.seek(&Some(locator), 
Bias::Left) { - if let Some(excerpt) = cursor.item() { - return Some(excerpt.range.context.clone()); - } + if cursor.seek(&Some(locator), Bias::Left) + && let Some(excerpt) = cursor.item() + { + return Some(excerpt.range.context.clone()); } None } @@ -6284,10 +6290,10 @@ impl MultiBufferSnapshot { let mut cursor = self.excerpts.cursor::>(&()); let locator = self.excerpt_locator_for_id(excerpt_id); cursor.seek(&Some(locator), Bias::Left); - if let Some(excerpt) = cursor.item() { - if excerpt.id == excerpt_id { - return Some(excerpt); - } + if let Some(excerpt) = cursor.item() + && excerpt.id == excerpt_id + { + return Some(excerpt); } None } @@ -6323,6 +6329,14 @@ impl MultiBufferSnapshot { }) } + pub fn buffer_id_for_anchor(&self, anchor: Anchor) -> Option { + if let Some(id) = anchor.buffer_id { + return Some(id); + } + let excerpt = self.excerpt_containing(anchor..anchor)?; + Some(excerpt.buffer_id()) + } + pub fn selections_in_range<'a>( &'a self, range: &'a Range, @@ -6396,8 +6410,8 @@ impl MultiBufferSnapshot { #[cfg(any(test, feature = "test-support"))] impl MultiBufferSnapshot { pub fn random_byte_range(&self, start_offset: usize, rng: &mut impl rand::Rng) -> Range { - let end = self.clip_offset(rng.gen_range(start_offset..=self.len()), Bias::Right); - let start = self.clip_offset(rng.gen_range(start_offset..=end), Bias::Right); + let end = self.clip_offset(rng.random_range(start_offset..=self.len()), Bias::Right); + let start = self.clip_offset(rng.random_range(start_offset..=end), Bias::Right); start..end } @@ -6418,7 +6432,7 @@ impl MultiBufferSnapshot { for (ix, entry) in excerpt_ids.iter().enumerate() { if ix == 0 { - if entry.id.cmp(&ExcerptId::min(), &self).is_le() { + if entry.id.cmp(&ExcerptId::min(), self).is_le() { panic!("invalid first excerpt id {:?}", entry.id); } } else if entry.id <= excerpt_ids[ix - 1].id { @@ -6446,13 +6460,12 @@ impl MultiBufferSnapshot { inserted_hunk_info: prev_inserted_hunk_info, .. 
}) = prev_transform + && *inserted_hunk_info == *prev_inserted_hunk_info { - if *inserted_hunk_info == *prev_inserted_hunk_info { - panic!( - "multiple adjacent buffer content transforms with is_inserted_hunk = {inserted_hunk_info:?}. transforms: {:+?}", - self.diff_transforms.items(&()) - ); - } + panic!( + "multiple adjacent buffer content transforms with is_inserted_hunk = {inserted_hunk_info:?}. transforms: {:+?}", + self.diff_transforms.items(&()) + ); } if summary.len == 0 && !self.is_empty() { panic!("empty buffer content transform"); @@ -6552,14 +6565,12 @@ where self.excerpts.next(); } else if let Some(DiffTransform::DeletedHunk { hunk_info, .. }) = self.diff_transforms.item() - { - if self + && self .excerpts .item() - .map_or(false, |excerpt| excerpt.id != hunk_info.excerpt_id) - { - self.excerpts.next(); - } + .is_some_and(|excerpt| excerpt.id != hunk_info.excerpt_id) + { + self.excerpts.next(); } } } @@ -6604,7 +6615,7 @@ where let prev_transform = self.diff_transforms.item(); self.diff_transforms.next(); - prev_transform.map_or(true, |next_transform| { + prev_transform.is_none_or(|next_transform| { matches!(next_transform, DiffTransform::BufferContent { .. }) }) } @@ -6619,12 +6630,12 @@ where } let next_transform = self.diff_transforms.next_item(); - next_transform.map_or(true, |next_transform| match next_transform { + next_transform.is_none_or(|next_transform| match next_transform { DiffTransform::BufferContent { .. } => true, DiffTransform::DeletedHunk { hunk_info, .. } => self .excerpts .item() - .map_or(false, |excerpt| excerpt.id != hunk_info.excerpt_id), + .is_some_and(|excerpt| excerpt.id != hunk_info.excerpt_id), }) } @@ -6648,7 +6659,7 @@ where hunk_info, .. 
} => { - let diff = self.diffs.get(&buffer_id)?; + let diff = self.diffs.get(buffer_id)?; let buffer = diff.base_text(); let mut rope_cursor = buffer.as_rope().cursor(0); let buffer_start = rope_cursor.summary::(base_text_byte_range.start); @@ -6657,7 +6668,7 @@ where buffer_end.add_assign(&buffer_range_len); let start = self.diff_transforms.start().output_dimension.0; let end = self.diff_transforms.end().output_dimension.0; - return Some(MultiBufferRegion { + Some(MultiBufferRegion { buffer, excerpt, has_trailing_newline: *has_trailing_newline, @@ -6667,7 +6678,7 @@ where )), buffer_range: buffer_start..buffer_end, range: start..end, - }); + }) } DiffTransform::BufferContent { inserted_hunk_info, .. @@ -7004,7 +7015,7 @@ impl Excerpt { } fn contains(&self, anchor: &Anchor) -> bool { - Some(self.buffer_id) == anchor.buffer_id + (anchor.buffer_id == None || anchor.buffer_id == Some(self.buffer_id)) && self .range .context @@ -7164,7 +7175,7 @@ impl ExcerptId { Self(usize::MAX) } - pub fn to_proto(&self) -> u64 { + pub fn to_proto(self) -> u64 { self.0 as _ } @@ -7505,61 +7516,59 @@ impl Iterator for MultiBufferRows<'_> { self.cursor.next(); if let Some(next_region) = self.cursor.region() { region = next_region; - } else { - if self.point == self.cursor.diff_transforms.end().output_dimension.0 { - let multibuffer_row = MultiBufferRow(self.point.row); - let last_excerpt = self - .cursor - .excerpts - .item() - .or(self.cursor.excerpts.prev_item())?; - let last_row = last_excerpt - .range - .context - .end - .to_point(&last_excerpt.buffer) - .row; + } else if self.point == self.cursor.diff_transforms.end().output_dimension.0 { + let multibuffer_row = MultiBufferRow(self.point.row); + let last_excerpt = self + .cursor + .excerpts + .item() + .or(self.cursor.excerpts.prev_item())?; + let last_row = last_excerpt + .range + .context + .end + .to_point(&last_excerpt.buffer) + .row; - let first_row = last_excerpt - .range - .context - .start - .to_point(&last_excerpt.buffer) 
- .row; + let first_row = last_excerpt + .range + .context + .start + .to_point(&last_excerpt.buffer) + .row; - let expand_info = if self.is_singleton { - None - } else { - let needs_expand_up = first_row == last_row - && last_row > 0 - && !region.diff_hunk_status.is_some_and(|d| d.is_deleted()); - let needs_expand_down = last_row < last_excerpt.buffer.max_point().row; - - if needs_expand_up && needs_expand_down { - Some(ExpandExcerptDirection::UpAndDown) - } else if needs_expand_up { - Some(ExpandExcerptDirection::Up) - } else if needs_expand_down { - Some(ExpandExcerptDirection::Down) - } else { - None - } - .map(|direction| ExpandInfo { - direction, - excerpt_id: last_excerpt.id, - }) - }; - self.point += Point::new(1, 0); - return Some(RowInfo { - buffer_id: Some(last_excerpt.buffer_id), - buffer_row: Some(last_row), - multibuffer_row: Some(multibuffer_row), - diff_status: None, - expand_info, - }); + let expand_info = if self.is_singleton { + None } else { - return None; - } + let needs_expand_up = first_row == last_row + && last_row > 0 + && !region.diff_hunk_status.is_some_and(|d| d.is_deleted()); + let needs_expand_down = last_row < last_excerpt.buffer.max_point().row; + + if needs_expand_up && needs_expand_down { + Some(ExpandExcerptDirection::UpAndDown) + } else if needs_expand_up { + Some(ExpandExcerptDirection::Up) + } else if needs_expand_down { + Some(ExpandExcerptDirection::Down) + } else { + None + } + .map(|direction| ExpandInfo { + direction, + excerpt_id: last_excerpt.id, + }) + }; + self.point += Point::new(1, 0); + return Some(RowInfo { + buffer_id: Some(last_excerpt.buffer_id), + buffer_row: Some(last_row), + multibuffer_row: Some(multibuffer_row), + diff_status: None, + expand_info, + }); + } else { + return None; }; } @@ -7731,12 +7740,21 @@ impl<'a> Iterator for MultiBufferChunks<'a> { let diff_transform_end = diff_transform_end.min(self.range.end); if diff_transform_end < chunk_end { - let (before, after) = - 
chunk.text.split_at(diff_transform_end - self.range.start); + let split_idx = diff_transform_end - self.range.start; + let (before, after) = chunk.text.split_at(split_idx); self.range.start = diff_transform_end; + let mask = (1 << split_idx) - 1; + let chars = chunk.chars & mask; + let tabs = chunk.tabs & mask; + chunk.text = after; + chunk.chars = chunk.chars >> split_idx; + chunk.tabs = chunk.tabs >> split_idx; + Some(Chunk { text: before, + chars, + tabs, ..chunk.clone() }) } else { @@ -7767,7 +7785,7 @@ impl<'a> Iterator for MultiBufferChunks<'a> { } chunks } else { - let base_buffer = &self.diffs.get(&buffer_id)?.base_text(); + let base_buffer = &self.diffs.get(buffer_id)?.base_text(); base_buffer.chunks(base_text_start..base_text_end, self.language_aware) }; @@ -7780,6 +7798,7 @@ impl<'a> Iterator for MultiBufferChunks<'a> { self.range.start += "\n".len(); Chunk { text: "\n", + chars: 1u128, ..Default::default() } }; @@ -7855,10 +7874,11 @@ impl io::Read for ReversedMultiBufferBytes<'_> { if len > 0 { self.range.end -= len; self.chunk = &self.chunk[..self.chunk.len() - len]; - if !self.range.is_empty() && self.chunk.is_empty() { - if let Some(chunk) = self.chunks.next() { - self.chunk = chunk.as_bytes(); - } + if !self.range.is_empty() + && self.chunk.is_empty() + && let Some(chunk) = self.chunks.next() + { + self.chunk = chunk.as_bytes(); } } Ok(len) @@ -7875,9 +7895,11 @@ impl<'a> Iterator for ExcerptChunks<'a> { if self.footer_height > 0 { let text = unsafe { str::from_utf8_unchecked(&NEWLINES[..self.footer_height]) }; + let chars = (1 << self.footer_height) - 1; self.footer_height = 0; return Some(Chunk { text, + chars, ..Default::default() }); } diff --git a/crates/multi_buffer/src/multi_buffer_tests.rs b/crates/multi_buffer/src/multi_buffer_tests.rs index 824efa559f6d52bf654d8f6c6ff9655eaf4a0e52..1be82500786b36fc014c2acf4fb49d4e8abc4d6b 100644 --- a/crates/multi_buffer/src/multi_buffer_tests.rs +++ b/crates/multi_buffer/src/multi_buffer_tests.rs @@ -7,6 
+7,7 @@ use parking_lot::RwLock; use rand::prelude::*; use settings::SettingsStore; use std::env; +use util::RandomCharIter; use util::test::sample_text; #[ctor::ctor] @@ -473,7 +474,7 @@ fn test_editing_text_in_diff_hunks(cx: &mut TestAppContext) { let base_text = "one\ntwo\nfour\nfive\nsix\nseven\n"; let text = "one\ntwo\nTHREE\nfour\nfive\nseven\n"; let buffer = cx.new(|cx| Buffer::local(text, cx)); - let diff = cx.new(|cx| BufferDiff::new_with_base_text(&base_text, &buffer, cx)); + let diff = cx.new(|cx| BufferDiff::new_with_base_text(base_text, &buffer, cx)); let multibuffer = cx.new(|cx| MultiBuffer::singleton(buffer.clone(), cx)); let (mut snapshot, mut subscription) = multibuffer.update(cx, |multibuffer, cx| { @@ -2250,11 +2251,11 @@ impl ReferenceMultibuffer { let base_buffer = diff.base_text(); let mut offset = buffer_range.start; - let mut hunks = diff + let hunks = diff .hunks_intersecting_range(excerpt.range.clone(), buffer, cx) .peekable(); - while let Some(hunk) = hunks.next() { + for hunk in hunks { // Ignore hunks that are outside the excerpt range. 
let mut hunk_range = hunk.buffer_range.to_offset(buffer); @@ -2265,14 +2266,14 @@ impl ReferenceMultibuffer { } if !excerpt.expanded_diff_hunks.iter().any(|expanded_anchor| { - expanded_anchor.to_offset(&buffer).max(buffer_range.start) + expanded_anchor.to_offset(buffer).max(buffer_range.start) == hunk_range.start.max(buffer_range.start) }) { log::trace!("skipping a hunk that's not marked as expanded"); continue; } - if !hunk.buffer_range.start.is_valid(&buffer) { + if !hunk.buffer_range.start.is_valid(buffer) { log::trace!("skipping hunk with deleted start: {:?}", hunk.range); continue; } @@ -2449,7 +2450,7 @@ impl ReferenceMultibuffer { return false; } while let Some(hunk) = hunks.peek() { - match hunk.buffer_range.start.cmp(&hunk_anchor, &buffer) { + match hunk.buffer_range.start.cmp(hunk_anchor, &buffer) { cmp::Ordering::Less => { hunks.next(); } @@ -2491,12 +2492,12 @@ async fn test_random_set_ranges(cx: &mut TestAppContext, mut rng: StdRng) { for _ in 0..operations { let snapshot = buf.update(cx, |buf, _| buf.snapshot()); - let num_ranges = rng.gen_range(0..=10); + let num_ranges = rng.random_range(0..=10); let max_row = snapshot.max_point().row; let mut ranges = (0..num_ranges) .map(|_| { - let start = rng.gen_range(0..max_row); - let end = rng.gen_range(start + 1..max_row + 1); + let start = rng.random_range(0..max_row); + let end = rng.random_range(start + 1..max_row + 1); Point::row_range(start..end) }) .collect::>(); @@ -2519,8 +2520,8 @@ async fn test_random_set_ranges(cx: &mut TestAppContext, mut rng: StdRng) { let mut seen_ranges = Vec::default(); for (_, buf, range) in snapshot.excerpts() { - let start = range.context.start.to_point(&buf); - let end = range.context.end.to_point(&buf); + let start = range.context.start.to_point(buf); + let end = range.context.end.to_point(buf); seen_ranges.push(start..end); if let Some(last_end) = last_end.take() { @@ -2562,11 +2563,11 @@ async fn test_random_multibuffer(cx: &mut TestAppContext, mut rng: StdRng) { let 
mut needs_diff_calculation = false; for _ in 0..operations { - match rng.gen_range(0..100) { + match rng.random_range(0..100) { 0..=14 if !buffers.is_empty() => { let buffer = buffers.choose(&mut rng).unwrap(); buffer.update(cx, |buf, cx| { - let edit_count = rng.gen_range(1..5); + let edit_count = rng.random_range(1..5); buf.randomly_edit(&mut rng, edit_count, cx); log::info!("buffer text:\n{}", buf.text()); needs_diff_calculation = true; @@ -2577,11 +2578,11 @@ async fn test_random_multibuffer(cx: &mut TestAppContext, mut rng: StdRng) { multibuffer.update(cx, |multibuffer, cx| { let ids = multibuffer.excerpt_ids(); let mut excerpts = HashSet::default(); - for _ in 0..rng.gen_range(0..ids.len()) { + for _ in 0..rng.random_range(0..ids.len()) { excerpts.extend(ids.choose(&mut rng).copied()); } - let line_count = rng.gen_range(0..5); + let line_count = rng.random_range(0..5); let excerpt_ixs = excerpts .iter() @@ -2600,7 +2601,7 @@ async fn test_random_multibuffer(cx: &mut TestAppContext, mut rng: StdRng) { } 20..=29 if !reference.excerpts.is_empty() => { let mut ids_to_remove = vec![]; - for _ in 0..rng.gen_range(1..=3) { + for _ in 0..rng.random_range(1..=3) { let Some(excerpt) = reference.excerpts.choose(&mut rng) else { break; }; @@ -2620,8 +2621,12 @@ async fn test_random_multibuffer(cx: &mut TestAppContext, mut rng: StdRng) { let multibuffer = multibuffer.read_with(cx, |multibuffer, cx| multibuffer.snapshot(cx)); let offset = - multibuffer.clip_offset(rng.gen_range(0..=multibuffer.len()), Bias::Left); - let bias = if rng.r#gen() { Bias::Left } else { Bias::Right }; + multibuffer.clip_offset(rng.random_range(0..=multibuffer.len()), Bias::Left); + let bias = if rng.random() { + Bias::Left + } else { + Bias::Right + }; log::info!("Creating anchor at {} with bias {:?}", offset, bias); anchors.push(multibuffer.anchor_at(offset, bias)); anchors.sort_by(|a, b| a.cmp(b, &multibuffer)); @@ -2654,7 +2659,7 @@ async fn test_random_multibuffer(cx: &mut TestAppContext, mut 
rng: StdRng) { 45..=55 if !reference.excerpts.is_empty() => { multibuffer.update(cx, |multibuffer, cx| { let snapshot = multibuffer.snapshot(cx); - let excerpt_ix = rng.gen_range(0..reference.excerpts.len()); + let excerpt_ix = rng.random_range(0..reference.excerpts.len()); let excerpt = &reference.excerpts[excerpt_ix]; let start = excerpt.range.start; let end = excerpt.range.end; @@ -2691,7 +2696,7 @@ async fn test_random_multibuffer(cx: &mut TestAppContext, mut rng: StdRng) { }); } _ => { - let buffer_handle = if buffers.is_empty() || rng.gen_bool(0.4) { + let buffer_handle = if buffers.is_empty() || rng.random_bool(0.4) { let mut base_text = util::RandomCharIter::new(&mut rng) .take(256) .collect::(); @@ -2708,7 +2713,7 @@ async fn test_random_multibuffer(cx: &mut TestAppContext, mut rng: StdRng) { buffers.choose(&mut rng).unwrap() }; - let prev_excerpt_ix = rng.gen_range(0..=reference.excerpts.len()); + let prev_excerpt_ix = rng.random_range(0..=reference.excerpts.len()); let prev_excerpt_id = reference .excerpts .get(prev_excerpt_ix) @@ -2716,8 +2721,8 @@ async fn test_random_multibuffer(cx: &mut TestAppContext, mut rng: StdRng) { let excerpt_ix = (prev_excerpt_ix + 1).min(reference.excerpts.len()); let (range, anchor_range) = buffer_handle.read_with(cx, |buffer, _| { - let end_row = rng.gen_range(0..=buffer.max_point().row); - let start_row = rng.gen_range(0..=end_row); + let end_row = rng.random_range(0..=buffer.max_point().row); + let start_row = rng.random_range(0..=end_row); let end_ix = buffer.point_to_offset(Point::new(end_row, 0)); let start_ix = buffer.point_to_offset(Point::new(start_row, 0)); let anchor_range = buffer.anchor_before(start_ix)..buffer.anchor_after(end_ix); @@ -2739,9 +2744,8 @@ async fn test_random_multibuffer(cx: &mut TestAppContext, mut rng: StdRng) { let id = buffer_handle.read(cx).remote_id(); if multibuffer.diff_for(id).is_none() { let base_text = base_texts.get(&id).unwrap(); - let diff = cx.new(|cx| { - 
BufferDiff::new_with_base_text(base_text, &buffer_handle, cx) - }); + let diff = cx + .new(|cx| BufferDiff::new_with_base_text(base_text, buffer_handle, cx)); reference.add_diff(diff.clone(), cx); multibuffer.add_diff(diff, cx) } @@ -2767,7 +2771,7 @@ async fn test_random_multibuffer(cx: &mut TestAppContext, mut rng: StdRng) { } } - if rng.gen_bool(0.3) { + if rng.random_bool(0.3) { multibuffer.update(cx, |multibuffer, cx| { old_versions.push((multibuffer.snapshot(cx), multibuffer.subscribe())); }) @@ -2816,7 +2820,7 @@ async fn test_random_multibuffer(cx: &mut TestAppContext, mut rng: StdRng) { pretty_assertions::assert_eq!(actual_row_infos, expected_row_infos); for _ in 0..5 { - let start_row = rng.gen_range(0..=expected_row_infos.len()); + let start_row = rng.random_range(0..=expected_row_infos.len()); assert_eq!( snapshot .row_infos(MultiBufferRow(start_row as u32)) @@ -2873,8 +2877,8 @@ async fn test_random_multibuffer(cx: &mut TestAppContext, mut rng: StdRng) { let text_rope = Rope::from(expected_text.as_str()); for _ in 0..10 { - let end_ix = text_rope.clip_offset(rng.gen_range(0..=text_rope.len()), Bias::Right); - let start_ix = text_rope.clip_offset(rng.gen_range(0..=end_ix), Bias::Left); + let end_ix = text_rope.clip_offset(rng.random_range(0..=text_rope.len()), Bias::Right); + let start_ix = text_rope.clip_offset(rng.random_range(0..=end_ix), Bias::Left); let text_for_range = snapshot .text_for_range(start_ix..end_ix) @@ -2909,7 +2913,7 @@ async fn test_random_multibuffer(cx: &mut TestAppContext, mut rng: StdRng) { } for _ in 0..10 { - let end_ix = text_rope.clip_offset(rng.gen_range(0..=text_rope.len()), Bias::Right); + let end_ix = text_rope.clip_offset(rng.random_range(0..=text_rope.len()), Bias::Right); assert_eq!( snapshot.reversed_chars_at(end_ix).collect::(), expected_text[..end_ix].chars().rev().collect::(), @@ -2917,8 +2921,8 @@ async fn test_random_multibuffer(cx: &mut TestAppContext, mut rng: StdRng) { } for _ in 0..10 { - let end_ix = 
rng.gen_range(0..=text_rope.len()); - let start_ix = rng.gen_range(0..=end_ix); + let end_ix = rng.random_range(0..=text_rope.len()); + let start_ix = rng.random_range(0..=end_ix); assert_eq!( snapshot .bytes_in_range(start_ix..end_ix) @@ -3593,24 +3597,20 @@ fn assert_position_translation(snapshot: &MultiBufferSnapshot) { for (anchors, bias) in [(&left_anchors, Bias::Left), (&right_anchors, Bias::Right)] { for (ix, (offset, anchor)) in offsets.iter().zip(anchors).enumerate() { - if ix > 0 { - if *offset == 252 { - if offset > &offsets[ix - 1] { - let prev_anchor = left_anchors[ix - 1]; - assert!( - anchor.cmp(&prev_anchor, snapshot).is_gt(), - "anchor({}, {bias:?}).cmp(&anchor({}, {bias:?}).is_gt()", - offsets[ix], - offsets[ix - 1], - ); - assert!( - prev_anchor.cmp(&anchor, snapshot).is_lt(), - "anchor({}, {bias:?}).cmp(&anchor({}, {bias:?}).is_lt()", - offsets[ix - 1], - offsets[ix], - ); - } - } + if ix > 0 && *offset == 252 && offset > &offsets[ix - 1] { + let prev_anchor = left_anchors[ix - 1]; + assert!( + anchor.cmp(&prev_anchor, snapshot).is_gt(), + "anchor({}, {bias:?}).cmp(&anchor({}, {bias:?}).is_gt()", + offsets[ix], + offsets[ix - 1], + ); + assert!( + prev_anchor.cmp(anchor, snapshot).is_lt(), + "anchor({}, {bias:?}).cmp(&anchor({}, {bias:?}).is_lt()", + offsets[ix - 1], + offsets[ix], + ); } } } @@ -3717,3 +3717,235 @@ fn test_new_empty_buffers_title_can_be_set(cx: &mut App) { }); assert_eq!(multibuffer.read(cx).title(cx), "Hey"); } + +#[gpui::test(iterations = 100)] +fn test_random_chunk_bitmaps(cx: &mut App, mut rng: StdRng) { + let multibuffer = if rng.random() { + let len = rng.random_range(0..10000); + let text = RandomCharIter::new(&mut rng).take(len).collect::(); + let buffer = cx.new(|cx| Buffer::local(text, cx)); + cx.new(|cx| MultiBuffer::singleton(buffer, cx)) + } else { + MultiBuffer::build_random(&mut rng, cx) + }; + + let snapshot = multibuffer.read(cx).snapshot(cx); + + let chunks = snapshot.chunks(0..snapshot.len(), false); + + for 
chunk in chunks { + let chunk_text = chunk.text; + let chars_bitmap = chunk.chars; + let tabs_bitmap = chunk.tabs; + + if chunk_text.is_empty() { + assert_eq!( + chars_bitmap, 0, + "Empty chunk should have empty chars bitmap" + ); + assert_eq!(tabs_bitmap, 0, "Empty chunk should have empty tabs bitmap"); + continue; + } + + assert!( + chunk_text.len() <= 128, + "Chunk text length {} exceeds 128 bytes", + chunk_text.len() + ); + + // Verify chars bitmap + let char_indices = chunk_text + .char_indices() + .map(|(i, _)| i) + .collect::>(); + + for byte_idx in 0..chunk_text.len() { + let should_have_bit = char_indices.contains(&byte_idx); + let has_bit = chars_bitmap & (1 << byte_idx) != 0; + + if has_bit != should_have_bit { + eprintln!("Chunk text bytes: {:?}", chunk_text.as_bytes()); + eprintln!("Char indices: {:?}", char_indices); + eprintln!("Chars bitmap: {:#b}", chars_bitmap); + } + + assert_eq!( + has_bit, should_have_bit, + "Chars bitmap mismatch at byte index {} in chunk {:?}. Expected bit: {}, Got bit: {}", + byte_idx, chunk_text, should_have_bit, has_bit + ); + } + + for (byte_idx, byte) in chunk_text.bytes().enumerate() { + let is_tab = byte == b'\t'; + let has_bit = tabs_bitmap & (1 << byte_idx) != 0; + + if has_bit != is_tab { + eprintln!("Chunk text bytes: {:?}", chunk_text.as_bytes()); + eprintln!("Tabs bitmap: {:#b}", tabs_bitmap); + assert_eq!( + has_bit, is_tab, + "Tabs bitmap mismatch at byte index {} in chunk {:?}. 
Byte: {:?}, Expected bit: {}, Got bit: {}", + byte_idx, chunk_text, byte as char, is_tab, has_bit + ); + } + } + } +} + +#[gpui::test(iterations = 100)] +fn test_random_chunk_bitmaps_with_diffs(cx: &mut App, mut rng: StdRng) { + use buffer_diff::BufferDiff; + use util::RandomCharIter; + + let multibuffer = if rng.random() { + let len = rng.random_range(100..10000); + let text = RandomCharIter::new(&mut rng).take(len).collect::(); + let buffer = cx.new(|cx| Buffer::local(text, cx)); + cx.new(|cx| MultiBuffer::singleton(buffer, cx)) + } else { + MultiBuffer::build_random(&mut rng, cx) + }; + + let _diff_count = rng.random_range(1..5); + let mut diffs = Vec::new(); + + multibuffer.update(cx, |multibuffer, cx| { + for buffer_id in multibuffer.excerpt_buffer_ids() { + if rng.random_bool(0.7) { + if let Some(buffer_handle) = multibuffer.buffer(buffer_id) { + let buffer_text = buffer_handle.read(cx).text(); + let mut base_text = String::new(); + + for line in buffer_text.lines() { + if rng.random_bool(0.3) { + continue; + } else if rng.random_bool(0.3) { + let line_len = rng.random_range(0..50); + let modified_line = RandomCharIter::new(&mut rng) + .take(line_len) + .collect::(); + base_text.push_str(&modified_line); + base_text.push('\n'); + } else { + base_text.push_str(line); + base_text.push('\n'); + } + } + + if rng.random_bool(0.5) { + let extra_lines = rng.random_range(1..5); + for _ in 0..extra_lines { + let line_len = rng.random_range(0..50); + let extra_line = RandomCharIter::new(&mut rng) + .take(line_len) + .collect::(); + base_text.push_str(&extra_line); + base_text.push('\n'); + } + } + + let diff = + cx.new(|cx| BufferDiff::new_with_base_text(&base_text, &buffer_handle, cx)); + diffs.push(diff.clone()); + multibuffer.add_diff(diff, cx); + } + } + } + }); + + multibuffer.update(cx, |multibuffer, cx| { + if rng.random_bool(0.5) { + multibuffer.set_all_diff_hunks_expanded(cx); + } else { + let snapshot = multibuffer.snapshot(cx); + let text = snapshot.text(); 
+ + let mut ranges = Vec::new(); + for _ in 0..rng.random_range(1..5) { + if snapshot.len() == 0 { + break; + } + + let diff_size = rng.random_range(5..1000); + let mut start = rng.random_range(0..snapshot.len()); + + while !text.is_char_boundary(start) { + start = start.saturating_sub(1); + } + + let mut end = rng.random_range(start..snapshot.len().min(start + diff_size)); + + while !text.is_char_boundary(end) { + end = end.saturating_add(1); + } + let start_anchor = snapshot.anchor_after(start); + let end_anchor = snapshot.anchor_before(end); + ranges.push(start_anchor..end_anchor); + } + multibuffer.expand_diff_hunks(ranges, cx); + } + }); + + let snapshot = multibuffer.read(cx).snapshot(cx); + + let chunks = snapshot.chunks(0..snapshot.len(), false); + + for chunk in chunks { + let chunk_text = chunk.text; + let chars_bitmap = chunk.chars; + let tabs_bitmap = chunk.tabs; + + if chunk_text.is_empty() { + assert_eq!( + chars_bitmap, 0, + "Empty chunk should have empty chars bitmap" + ); + assert_eq!(tabs_bitmap, 0, "Empty chunk should have empty tabs bitmap"); + continue; + } + + assert!( + chunk_text.len() <= 128, + "Chunk text length {} exceeds 128 bytes", + chunk_text.len() + ); + + let char_indices = chunk_text + .char_indices() + .map(|(i, _)| i) + .collect::>(); + + for byte_idx in 0..chunk_text.len() { + let should_have_bit = char_indices.contains(&byte_idx); + let has_bit = chars_bitmap & (1 << byte_idx) != 0; + + if has_bit != should_have_bit { + eprintln!("Chunk text bytes: {:?}", chunk_text.as_bytes()); + eprintln!("Char indices: {:?}", char_indices); + eprintln!("Chars bitmap: {:#b}", chars_bitmap); + } + + assert_eq!( + has_bit, should_have_bit, + "Chars bitmap mismatch at byte index {} in chunk {:?}. 
Expected bit: {}, Got bit: {}", + byte_idx, chunk_text, should_have_bit, has_bit + ); + } + + for (byte_idx, byte) in chunk_text.bytes().enumerate() { + let is_tab = byte == b'\t'; + let has_bit = tabs_bitmap & (1 << byte_idx) != 0; + + if has_bit != is_tab { + eprintln!("Chunk text bytes: {:?}", chunk_text.as_bytes()); + eprintln!("Tabs bitmap: {:#b}", tabs_bitmap); + assert_eq!( + has_bit, is_tab, + "Tabs bitmap mismatch at byte index {} in chunk {:?}. Byte: {:?}, Expected bit: {}, Got bit: {}", + byte_idx, chunk_text, byte as char, is_tab, has_bit + ); + } + } + } +} diff --git a/crates/multi_buffer/src/position.rs b/crates/multi_buffer/src/position.rs index 06508750597b97d7275b964114bcdad0d0e34c79..8a3ce78d0d9f7a6880dbc3202c002507c800b7b0 100644 --- a/crates/multi_buffer/src/position.rs +++ b/crates/multi_buffer/src/position.rs @@ -126,17 +126,17 @@ impl Default for TypedRow { impl PartialOrd for TypedOffset { fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(&other)) + Some(self.cmp(other)) } } impl PartialOrd for TypedPoint { fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(&other)) + Some(self.cmp(other)) } } impl PartialOrd for TypedRow { fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(&other)) + Some(self.cmp(other)) } } diff --git a/crates/node_runtime/src/node_runtime.rs b/crates/node_runtime/src/node_runtime.rs index 6fcc3a728af903f581046c9a3e069f9fbcfc9ecf..9d41eb1562943683aae3e785b1daac8bc3bfeb1a 100644 --- a/crates/node_runtime/src/node_runtime.rs +++ b/crates/node_runtime/src/node_runtime.rs @@ -29,13 +29,11 @@ pub struct NodeBinaryOptions { pub use_paths: Option<(PathBuf, PathBuf)>, } -#[derive(Default)] -pub enum VersionCheck { - /// Check whether the installed and requested version have a mismatch - VersionMismatch, - /// Only check whether the currently installed version is older than the newest one - #[default] - OlderVersion, +pub enum VersionStrategy<'a> { + /// Install if current version doesn't 
match pinned version + Pin(&'a str), + /// Install if current version is older than latest version + Latest(&'a str), } #[derive(Clone)] @@ -78,9 +76,8 @@ impl NodeRuntime { let mut state = self.0.lock().await; let options = loop { - match state.options.borrow().as_ref() { - Some(options) => break options.clone(), - None => {} + if let Some(options) = state.options.borrow().as_ref() { + break options.clone(); } match state.options.changed().await { Ok(()) => {} @@ -199,7 +196,7 @@ impl NodeRuntime { state.instance = Some(instance.boxed_clone()); state.last_options = Some(options); - return instance; + instance } pub async fn binary_path(&self) -> Result { @@ -295,8 +292,7 @@ impl NodeRuntime { package_name: &str, local_executable_path: &Path, local_package_directory: &Path, - latest_version: &str, - version_check: VersionCheck, + version_strategy: VersionStrategy<'_>, ) -> bool { // In the case of the local system not having the package installed, // or in the instances where we fail to parse package.json data, @@ -317,13 +313,20 @@ impl NodeRuntime { let Some(installed_version) = Version::parse(&installed_version).log_err() else { return true; }; - let Some(latest_version) = Version::parse(latest_version).log_err() else { - return true; - }; - match version_check { - VersionCheck::VersionMismatch => installed_version != latest_version, - VersionCheck::OlderVersion => installed_version < latest_version, + match version_strategy { + VersionStrategy::Pin(pinned_version) => { + let Some(pinned_version) = Version::parse(pinned_version).log_err() else { + return true; + }; + installed_version != pinned_version + } + VersionStrategy::Latest(latest_version) => { + let Some(latest_version) = Version::parse(latest_version).log_err() else { + return true; + }; + installed_version < latest_version + } } } } diff --git a/crates/notifications/Cargo.toml b/crates/notifications/Cargo.toml index baf5444ef4903dd1d0efc64e7553abe3ed414720..39acfe2b384c8a2264c5c2dac91024edad89d33a 
100644 --- a/crates/notifications/Cargo.toml +++ b/crates/notifications/Cargo.toml @@ -24,7 +24,6 @@ test-support = [ anyhow.workspace = true channel.workspace = true client.workspace = true -collections.workspace = true component.workspace = true db.workspace = true gpui.workspace = true diff --git a/crates/notifications/src/notification_store.rs b/crates/notifications/src/notification_store.rs index 29653748e4873a271f58f932ee71c820aa755b9a..7db17da9ff92bce492cc8414be8db28c219d61e7 100644 --- a/crates/notifications/src/notification_store.rs +++ b/crates/notifications/src/notification_store.rs @@ -1,7 +1,6 @@ use anyhow::{Context as _, Result}; -use channel::{ChannelMessage, ChannelMessageId, ChannelStore}; +use channel::ChannelStore; use client::{ChannelId, Client, UserStore}; -use collections::HashMap; use db::smol::stream::StreamExt; use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task}; use rpc::{Notification, TypedEnvelope, proto}; @@ -22,7 +21,6 @@ impl Global for GlobalNotificationStore {} pub struct NotificationStore { client: Arc, user_store: Entity, - channel_messages: HashMap, channel_store: Entity, notifications: SumTree, loaded_all_notifications: bool, @@ -100,12 +98,10 @@ impl NotificationStore { channel_store: ChannelStore::global(cx), notifications: Default::default(), loaded_all_notifications: false, - channel_messages: Default::default(), _watch_connection_status: watch_connection_status, _subscriptions: vec![ client.add_message_handler(cx.weak_entity(), Self::handle_new_notification), client.add_message_handler(cx.weak_entity(), Self::handle_delete_notification), - client.add_message_handler(cx.weak_entity(), Self::handle_update_notification), ], user_store, client, @@ -120,10 +116,6 @@ impl NotificationStore { self.notifications.summary().unread_count } - pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> { - self.channel_messages.get(&id) - } - // Get the nth newest notification. 
pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> { let count = self.notifications.summary().count; @@ -138,10 +130,10 @@ impl NotificationStore { pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> { let mut cursor = self.notifications.cursor::(&()); cursor.seek(&NotificationId(id), Bias::Left); - if let Some(item) = cursor.item() { - if item.id == id { - return Some(item); - } + if let Some(item) = cursor.item() + && item.id == id + { + return Some(item); } None } @@ -185,7 +177,6 @@ impl NotificationStore { fn handle_connect(&mut self, cx: &mut Context) -> Option>> { self.notifications = Default::default(); - self.channel_messages = Default::default(); cx.notify(); self.load_more_notifications(true, cx) } @@ -223,36 +214,6 @@ impl NotificationStore { })? } - async fn handle_update_notification( - this: Entity, - envelope: TypedEnvelope, - mut cx: AsyncApp, - ) -> Result<()> { - this.update(&mut cx, |this, cx| { - if let Some(notification) = envelope.payload.notification { - if let Some(rpc::Notification::ChannelMessageMention { message_id, .. }) = - Notification::from_proto(¬ification) - { - let fetch_message_task = this.channel_store.update(cx, |this, cx| { - this.fetch_channel_messages(vec![message_id], cx) - }); - - cx.spawn(async move |this, cx| { - let messages = fetch_message_task.await?; - this.update(cx, move |this, cx| { - for message in messages { - this.channel_messages.insert(message_id, message); - } - cx.notify(); - }) - }) - .detach_and_log_err(cx) - } - } - Ok(()) - })? - } - async fn add_notifications( this: Entity, notifications: Vec, @@ -260,7 +221,6 @@ impl NotificationStore { cx: &mut AsyncApp, ) -> Result<()> { let mut user_ids = Vec::new(); - let mut message_ids = Vec::new(); let notifications = notifications .into_iter() @@ -294,29 +254,14 @@ impl NotificationStore { } => { user_ids.push(contact_id); } - Notification::ChannelMessageMention { - sender_id, - message_id, - .. 
- } => { - user_ids.push(sender_id); - message_ids.push(message_id); - } } } - let (user_store, channel_store) = this.read_with(cx, |this, _| { - (this.user_store.clone(), this.channel_store.clone()) - })?; + let user_store = this.read_with(cx, |this, _| this.user_store.clone())?; user_store .update(cx, |store, cx| store.get_users(user_ids, cx))? .await?; - let messages = channel_store - .update(cx, |store, cx| { - store.fetch_channel_messages(message_ids, cx) - })? - .await?; this.update(cx, |this, cx| { if options.clear_old { cx.emit(NotificationEvent::NotificationsUpdated { @@ -324,7 +269,6 @@ impl NotificationStore { new_count: 0, }); this.notifications = SumTree::default(); - this.channel_messages.clear(); this.loaded_all_notifications = false; } @@ -332,15 +276,6 @@ impl NotificationStore { this.loaded_all_notifications = true; } - this.channel_messages - .extend(messages.into_iter().filter_map(|message| { - if let ChannelMessageId::Saved(id) = message.id { - Some((id, message)) - } else { - None - } - })); - this.splice_notifications( notifications .into_iter() @@ -390,12 +325,12 @@ impl NotificationStore { }); } } - } else if let Some(new_notification) = &new_notification { - if is_new { - cx.emit(NotificationEvent::NewNotification { - entry: new_notification.clone(), - }); - } + } else if let Some(new_notification) = &new_notification + && is_new + { + cx.emit(NotificationEvent::NewNotification { + entry: new_notification.clone(), + }); } if let Some(notification) = new_notification { diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs index 64cd1cc0cbc06607ee9b3b72ee81cbeb9489c344..c61108d8bd59375256b7eb8b511527e8a0a119c2 100644 --- a/crates/ollama/src/ollama.rs +++ b/crates/ollama/src/ollama.rs @@ -46,19 +46,19 @@ fn get_max_tokens(name: &str) -> u64 { /// Default context length for unknown models. const DEFAULT_TOKENS: u64 = 4096; /// Magic number. Lets many Ollama models work with ~16GB of ram. 
+ /// Models that support context beyond 16k such as codestral (32k) or devstral (128k) will be clamped down to 16k const MAXIMUM_TOKENS: u64 = 16384; match name.split(':').next().unwrap() { - "phi" | "tinyllama" | "granite-code" => 2048, - "llama2" | "yi" | "vicuna" | "stablelm2" => 4096, - "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192, + "granite-code" | "phi" | "tinyllama" => 2048, + "llama2" | "stablelm2" | "vicuna" | "yi" => 4096, + "aya" | "codegemma" | "gemma" | "gemma2" | "llama3" | "starcoder" => 8192, "codellama" | "starcoder2" => 16384, - "mistral" | "codestral" | "mixstral" | "llava" | "qwen2" | "qwen2.5-coder" - | "dolphin-mixtral" => 32768, - "magistral" => 40000, - "llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r" - | "qwen3" | "gemma3" | "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder" - | "devstral" | "gpt-oss" => 128000, + "codestral" | "dolphin-mixtral" | "llava" | "magistral" | "mistral" | "mixstral" + | "qwen2" | "qwen2.5-coder" => 32768, + "cogito" | "command-r" | "deepseek-coder-v2" | "deepseek-r1" | "deepseek-v3" + | "devstral" | "gemma3" | "gpt-oss" | "granite3.3" | "llama3.1" | "llama3.2" + | "llama3.3" | "mistral-nemo" | "phi3" | "phi3.5" | "phi4" | "qwen3" | "yi-coder" => 128000, _ => DEFAULT_TOKENS, } .clamp(1, MAXIMUM_TOKENS) @@ -117,6 +117,10 @@ pub enum ChatMessage { System { content: String, }, + Tool { + tool_name: String, + content: String, + }, } #[derive(Serialize, Deserialize, Debug)] diff --git a/crates/onboarding/src/ai_setup_page.rs b/crates/onboarding/src/ai_setup_page.rs index bb1932bdf21ee9c927f085c8d5ad0a7cbd4c7fbd..3631ad00dfb8662d5d4142a4cbd11186c1b1b137 100644 --- a/crates/onboarding/src/ai_setup_page.rs +++ b/crates/onboarding/src/ai_setup_page.rs @@ -19,7 +19,7 @@ use util::ResultExt; use workspace::{ModalView, Workspace}; use zed_actions::agent::OpenSettings; -const FEATURED_PROVIDERS: [&'static str; 4] = ["anthropic", "google", "openai", 
"ollama"]; +const FEATURED_PROVIDERS: [&str; 4] = ["anthropic", "google", "openai", "ollama"]; fn render_llm_provider_section( tab_index: &mut isize, @@ -264,13 +264,9 @@ pub(crate) fn render_ai_setup_page( ); let fs = ::global(cx); - update_settings_file::( - fs, - cx, - move |ai_settings: &mut Option, _| { - *ai_settings = Some(enabled); - }, - ); + update_settings_file::(fs, cx, move |ai_settings, _| { + ai_settings.disable_ai = Some(enabled); + }); }, ) .tab_index({ @@ -283,17 +279,13 @@ pub(crate) fn render_ai_setup_page( v_flex() .mt_2() .gap_6() - .child({ - let mut ai_upsell_card = - AiUpsellCard::new(client, &user_store, user_store.read(cx).plan(), cx); - - ai_upsell_card.tab_index = Some({ - tab_index += 1; - tab_index - 1 - }); - - ai_upsell_card - }) + .child( + AiUpsellCard::new(client, &user_store, user_store.read(cx).plan(), cx) + .tab_index(Some({ + tab_index += 1; + tab_index - 1 + })), + ) .child(render_llm_provider_section( &mut tab_index, workspace, @@ -329,7 +321,11 @@ impl AiConfigurationModal { cx: &mut Context, ) -> Self { let focus_handle = cx.focus_handle(); - let configuration_view = selected_provider.configuration_view(window, cx); + let configuration_view = selected_provider.configuration_view( + language_model::ConfigurationViewTargetAgent::ZedAgent, + window, + cx, + ); Self { focus_handle, @@ -406,7 +402,7 @@ impl AiPrivacyTooltip { impl Render for AiPrivacyTooltip { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - const DESCRIPTION: &'static str = "We believe in opt-in data sharing as the default for building AI products, rather than opt-out. We'll only use or store your data if you affirmatively send it to us. "; + const DESCRIPTION: &str = "We believe in opt-in data sharing as the default for building AI products, rather than opt-out. We'll only use or store your data if you affirmatively send it to us. 
"; tooltip_container(window, cx, move |this, _, _| { this.child( diff --git a/crates/onboarding/src/base_keymap_picker.rs b/crates/onboarding/src/base_keymap_picker.rs index 0ac07d9a9d3b17921112d6accf6f4c9c9dd65ef6..950ed1d3133549a566808bd7cee002437708a2ad 100644 --- a/crates/onboarding/src/base_keymap_picker.rs +++ b/crates/onboarding/src/base_keymap_picker.rs @@ -187,7 +187,7 @@ impl PickerDelegate for BaseKeymapSelectorDelegate { ); update_settings_file::(self.fs.clone(), cx, move |setting, _| { - *setting = Some(base_keymap) + setting.base_keymap = Some(base_keymap) }); } @@ -213,7 +213,7 @@ impl PickerDelegate for BaseKeymapSelectorDelegate { _window: &mut Window, _cx: &mut Context>, ) -> Option { - let keymap_match = &self.matches[ix]; + let keymap_match = &self.matches.get(ix)?; Some( ListItem::new(ix) diff --git a/crates/onboarding/src/basics_page.rs b/crates/onboarding/src/basics_page.rs index 86ddc22a8687b5f591afb810ead541a0294dc7d9..aef9dcca86ce49a70f1a508c0a43614737a653c7 100644 --- a/crates/onboarding/src/basics_page.rs +++ b/crates/onboarding/src/basics_page.rs @@ -16,6 +16,23 @@ use vim_mode_setting::VimModeSetting; use crate::theme_preview::{ThemePreviewStyle, ThemePreviewTile}; +const LIGHT_THEMES: [&str; 3] = ["One Light", "Ayu Light", "Gruvbox Light"]; +const DARK_THEMES: [&str; 3] = ["One Dark", "Ayu Dark", "Gruvbox Dark"]; +const FAMILY_NAMES: [SharedString; 3] = [ + SharedString::new_static("One"), + SharedString::new_static("Ayu"), + SharedString::new_static("Gruvbox"), +]; + +fn get_theme_family_themes(theme_name: &str) -> Option<(&'static str, &'static str)> { + for i in 0..LIGHT_THEMES.len() { + if LIGHT_THEMES[i] == theme_name || DARK_THEMES[i] == theme_name { + return Some((LIGHT_THEMES[i], DARK_THEMES[i])); + } + } + None +} + fn render_theme_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement { let theme_selection = ThemeSettings::get_global(cx).theme_selection.clone(); let system_appearance = 
theme::SystemAppearance::global(cx); @@ -51,6 +68,12 @@ fn render_theme_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement MODE_NAMES[mode as usize].clone(), move |_, _, cx| { write_mode_change(mode, cx); + + telemetry::event!( + "Welcome Theme mode Changed", + from = theme_mode, + to = mode + ); }, ) }), @@ -88,15 +111,7 @@ fn render_theme_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement ThemeMode::Dark => Appearance::Dark, ThemeMode::System => *system_appearance, }; - let current_theme_name = theme_selection.theme(appearance); - - const LIGHT_THEMES: [&'static str; 3] = ["One Light", "Ayu Light", "Gruvbox Light"]; - const DARK_THEMES: [&'static str; 3] = ["One Dark", "Ayu Dark", "Gruvbox Dark"]; - const FAMILY_NAMES: [SharedString; 3] = [ - SharedString::new_static("One"), - SharedString::new_static("Ayu"), - SharedString::new_static("Gruvbox"), - ]; + let current_theme_name = SharedString::new(theme_selection.theme(appearance)); let theme_names = match appearance { Appearance::Light => LIGHT_THEMES, @@ -105,7 +120,7 @@ fn render_theme_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement let themes = theme_names.map(|theme| theme_registry.get(theme).unwrap()); - let theme_previews = [0, 1, 2].map(|index| { + [0, 1, 2].map(|index| { let theme = &themes[index]; let is_selected = theme.name == current_theme_name; let name = theme.name.clone(); @@ -117,7 +132,7 @@ fn render_theme_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement .gap_1() .child( h_flex() - .id(name.clone()) + .id(name) .relative() .w_full() .border_2() @@ -140,8 +155,15 @@ fn render_theme_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement }) .on_click({ let theme_name = theme.name.clone(); + let current_theme_name = current_theme_name.clone(); + move |_, _, cx| { write_theme_change(theme_name.clone(), theme_mode, cx); + telemetry::event!( + "Welcome Theme Changed", + from = current_theme_name, + to = theme_name + ); } }) 
.map(|this| { @@ -167,9 +189,7 @@ fn render_theme_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement .color(Color::Muted) .size(LabelSize::Small), ) - }); - - theme_previews + }) } fn write_mode_change(mode: ThemeMode, cx: &mut App) { @@ -184,14 +204,17 @@ fn render_theme_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement let theme = theme.into(); update_settings_file::(fs, cx, move |settings, cx| { if theme_mode == ThemeMode::System { + let (light_theme, dark_theme) = + get_theme_family_themes(&theme).unwrap_or((theme.as_ref(), theme.as_ref())); + settings.theme = Some(ThemeSelection::Dynamic { mode: ThemeMode::System, - light: ThemeName(theme.clone()), - dark: ThemeName(theme.clone()), + light: ThemeName(light_theme.into()), + dark: ThemeName(dark_theme.into()), }); } else { let appearance = *SystemAppearance::global(cx); - settings.set_theme(theme.clone(), appearance); + settings.set_theme(theme, appearance); } }); } @@ -229,6 +252,17 @@ fn render_telemetry_section(tab_index: &mut isize, cx: &App) -> impl IntoElement cx, move |setting, _| setting.metrics = Some(enabled), ); + + // This telemetry event shouldn't fire when it's off. If it does we'll be alerted + // and can fix it in a timely manner to respect a user's choice. + telemetry::event!("Welcome Page Telemetry Metrics Toggled", + options = if enabled { + "on" + } else { + "off" + } + ); + }}, ).tab_index({ *tab_index += 1; @@ -257,6 +291,16 @@ fn render_telemetry_section(tab_index: &mut isize, cx: &App) -> impl IntoElement cx, move |setting, _| setting.diagnostics = Some(enabled), ); + + // This telemetry event shouldn't fire when it's off. If it does we'll be alerted + // and can fix it in a timely manner to respect a user's choice. 
+ telemetry::event!("Welcome Page Telemetry Diagnostics Toggled", + options = if enabled { + "on" + } else { + "off" + } + ); } } ).tab_index({ @@ -315,8 +359,10 @@ fn render_base_keymap_section(tab_index: &mut isize, cx: &mut App) -> impl IntoE let fs = ::global(cx); update_settings_file::(fs, cx, move |setting, _| { - *setting = Some(keymap_base); + setting.base_keymap = Some(keymap_base); }); + + telemetry::event!("Welcome Keymap Changed", keymap = keymap_base); } } @@ -334,13 +380,21 @@ fn render_vim_mode_switch(tab_index: &mut isize, cx: &mut App) -> impl IntoEleme { let fs = ::global(cx); move |&selection, _, cx| { - update_settings_file::(fs.clone(), cx, move |setting, _| { - *setting = match selection { - ToggleState::Selected => Some(true), - ToggleState::Unselected => Some(false), - ToggleState::Indeterminate => None, + let vim_mode = match selection { + ToggleState::Selected => true, + ToggleState::Unselected => false, + ToggleState::Indeterminate => { + return; } + }; + update_settings_file::(fs.clone(), cx, move |setting, _| { + setting.vim_mode = Some(vim_mode); }); + + telemetry::event!( + "Welcome Vim Mode Toggled", + options = if vim_mode { "on" } else { "off" }, + ); } }, ) diff --git a/crates/onboarding/src/editing_page.rs b/crates/onboarding/src/editing_page.rs index d941a0315afd726d493ecb121d2862d0067e6dd3..297016abd4a1499feb6f637d028056ca0b412d31 100644 --- a/crates/onboarding/src/editing_page.rs +++ b/crates/onboarding/src/editing_page.rs @@ -104,7 +104,7 @@ fn write_ui_font_family(font: SharedString, cx: &mut App) { "Welcome Font Changed", type = "ui font", old = theme_settings.ui_font_family, - new = font.clone() + new = font ); theme_settings.ui_font_family = Some(FontFamilyName(font.into())); }); @@ -134,7 +134,7 @@ fn write_buffer_font_family(font_family: SharedString, cx: &mut App) { "Welcome Font Changed", type = "editor font", old = theme_settings.buffer_font_family, - new = font_family.clone() + new = font_family ); 
theme_settings.buffer_font_family = Some(FontFamilyName(font_family.into())); @@ -314,7 +314,7 @@ fn render_font_customization_section( .child( PopoverMenu::new("ui-font-picker") .menu({ - let ui_font_picker = ui_font_picker.clone(); + let ui_font_picker = ui_font_picker; move |_window, _cx| Some(ui_font_picker.clone()) }) .trigger( @@ -378,7 +378,7 @@ fn render_font_customization_section( .child( PopoverMenu::new("buffer-font-picker") .menu({ - let buffer_font_picker = buffer_font_picker.clone(); + let buffer_font_picker = buffer_font_picker; move |_window, _cx| Some(buffer_font_picker.clone()) }) .trigger( @@ -449,28 +449,28 @@ impl FontPickerDelegate { ) -> Self { let font_family_cache = FontFamilyCache::global(cx); - let fonts: Vec = font_family_cache - .list_font_families(cx) - .into_iter() - .collect(); - + let fonts = font_family_cache + .try_list_font_families() + .unwrap_or_else(|| vec![current_font.clone()]); let selected_index = fonts .iter() .position(|font| *font == current_font) .unwrap_or(0); + let filtered_fonts = fonts + .iter() + .enumerate() + .map(|(index, font)| StringMatch { + candidate_id: index, + string: font.to_string(), + positions: Vec::new(), + score: 0.0, + }) + .collect(); + Self { - fonts: fonts.clone(), - filtered_fonts: fonts - .iter() - .enumerate() - .map(|(index, font)| StringMatch { - candidate_id: index, - string: font.to_string(), - positions: Vec::new(), - score: 0.0, - }) - .collect(), + fonts, + filtered_fonts, selected_index, current_font, on_font_changed: Arc::new(on_font_changed), @@ -605,8 +605,8 @@ fn render_popular_settings_section( window: &mut Window, cx: &mut App, ) -> impl IntoElement { - const LIGATURE_TOOLTIP: &'static str = - "Font ligatures combine two characters into one. For example, turning =/= into ≠."; + const LIGATURE_TOOLTIP: &str = + "Font ligatures combine two characters into one. 
For example, turning != into ≠."; v_flex() .pt_6() diff --git a/crates/onboarding/src/onboarding.rs b/crates/onboarding/src/onboarding.rs index e07a8dc9fb6c6c20b311863da1414dfd6e83eecd..9dcf27c7cbebf6621bbeb558619944c768e63fb6 100644 --- a/crates/onboarding/src/onboarding.rs +++ b/crates/onboarding/src/onboarding.rs @@ -242,12 +242,25 @@ struct Onboarding { impl Onboarding { fn new(workspace: &Workspace, cx: &mut App) -> Entity { - cx.new(|cx| Self { - workspace: workspace.weak_handle(), - focus_handle: cx.focus_handle(), - selected_page: SelectedPage::Basics, - user_store: workspace.user_store().clone(), - _settings_subscription: cx.observe_global::(move |_, cx| cx.notify()), + let font_family_cache = theme::FontFamilyCache::global(cx); + + cx.new(|cx| { + cx.spawn(async move |this, cx| { + font_family_cache.prefetch(cx).await; + this.update(cx, |_, cx| { + cx.notify(); + }) + }) + .detach(); + + Self { + workspace: workspace.weak_handle(), + focus_handle: cx.focus_handle(), + selected_page: SelectedPage::Basics, + user_store: workspace.user_store().clone(), + _settings_subscription: cx + .observe_global::(move |_, cx| cx.notify()), + } }) } @@ -476,6 +489,7 @@ impl Onboarding { .map(|kb| kb.size(rems_from_px(12.))), ) .on_click(|_, window, cx| { + telemetry::event!("Welcome Sign In Clicked"); window.dispatch_action(SignIn.boxed_clone(), cx); }) .into_any_element() @@ -494,7 +508,7 @@ impl Onboarding { window .spawn(cx, async move |cx| { client - .sign_in_with_optional_connect(true, &cx) + .sign_in_with_optional_connect(true, cx) .await .notify_async_err(cx); }) @@ -850,13 +864,19 @@ impl workspace::SerializableItem for Onboarding { } mod persistence { - use db::{define_connection, query, sqlez_macros::sql}; + use db::{ + query, + sqlez::{domain::Domain, thread_safe_connection::ThreadSafeConnection}, + sqlez_macros::sql, + }; use workspace::WorkspaceDb; - define_connection! 
{ - pub static ref ONBOARDING_PAGES: OnboardingPagesDb = - &[ - sql!( + pub struct OnboardingPagesDb(ThreadSafeConnection); + + impl Domain for OnboardingPagesDb { + const NAME: &str = stringify!(OnboardingPagesDb); + + const MIGRATIONS: &[&str] = &[sql!( CREATE TABLE onboarding_pages ( workspace_id INTEGER, item_id INTEGER UNIQUE, @@ -866,10 +886,11 @@ mod persistence { FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) ON DELETE CASCADE ) STRICT; - ), - ]; + )]; } + db::static_connection!(ONBOARDING_PAGES, OnboardingPagesDb, [WorkspaceDb]); + impl OnboardingPagesDb { query! { pub async fn save_onboarding_page( diff --git a/crates/onboarding/src/theme_preview.rs b/crates/onboarding/src/theme_preview.rs index 9f299eb6ea0a994097bac282b60f08decb7ed838..8bd65d8a2707acdc53333071486f41741398a82a 100644 --- a/crates/onboarding/src/theme_preview.rs +++ b/crates/onboarding/src/theme_preview.rs @@ -206,16 +206,16 @@ impl ThemePreviewTile { sidebar_width, skeleton_height.clone(), )) - .child(Self::render_pane(seed, theme, skeleton_height.clone())) + .child(Self::render_pane(seed, theme, skeleton_height)) } fn render_borderless(seed: f32, theme: Arc) -> impl IntoElement { - return Self::render_editor( + Self::render_editor( seed, theme, Self::SIDEBAR_WIDTH_DEFAULT, Self::SKELETON_HEIGHT_DEFAULT, - ); + ) } fn render_border(seed: f32, theme: Arc) -> impl IntoElement { @@ -246,7 +246,7 @@ impl ThemePreviewTile { ) -> impl IntoElement { let sidebar_width = relative(0.20); - return div() + div() .size_full() .p(Self::ROOT_PADDING) .rounded(Self::ROOT_RADIUS) @@ -260,7 +260,7 @@ impl ThemePreviewTile { .overflow_hidden() .child(div().size_full().child(Self::render_editor( seed, - theme.clone(), + theme, sidebar_width, Self::SKELETON_HEIGHT_DEFAULT, ))) @@ -278,7 +278,7 @@ impl ThemePreviewTile { )), ), ) - .into_any_element(); + .into_any_element() } } @@ -329,9 +329,9 @@ impl Component for ThemePreviewTile { let themes_to_preview = vec![ one_dark.clone().ok(), - 
one_light.clone().ok(), - gruvbox_dark.clone().ok(), - gruvbox_light.clone().ok(), + one_light.ok(), + gruvbox_dark.ok(), + gruvbox_light.ok(), ] .into_iter() .flatten() @@ -348,7 +348,7 @@ impl Component for ThemePreviewTile { div() .w(px(240.)) .h(px(180.)) - .child(ThemePreviewTile::new(one_dark.clone(), 0.42)) + .child(ThemePreviewTile::new(one_dark, 0.42)) .into_any_element(), )])] } else { @@ -362,13 +362,12 @@ impl Component for ThemePreviewTile { .gap_4() .children( themes_to_preview - .iter() - .enumerate() - .map(|(_, theme)| { + .into_iter() + .map(|theme| { div() .w(px(200.)) .h(px(140.)) - .child(ThemePreviewTile::new(theme.clone(), 0.42)) + .child(ThemePreviewTile::new(theme, 0.42)) }) .collect::>(), ) diff --git a/crates/onboarding/src/welcome.rs b/crates/onboarding/src/welcome.rs index ba0053a3b68cf880918dcc60618dc4440e168968..8ff55d812b007d1b210781ec747b30cd1f505f35 100644 --- a/crates/onboarding/src/welcome.rs +++ b/crates/onboarding/src/welcome.rs @@ -37,7 +37,7 @@ const CONTENT: (Section<4>, Section<3>) = ( }, SectionEntry { icon: IconName::CloudDownload, - title: "Clone a Repo", + title: "Clone Repository", action: &git::Clone, }, SectionEntry { @@ -104,7 +104,7 @@ impl Section { self.entries .iter() .enumerate() - .map(|(index, entry)| entry.render(index_offset + index, &focus, window, cx)), + .map(|(index, entry)| entry.render(index_offset + index, focus, window, cx)), ) } } @@ -414,13 +414,19 @@ impl workspace::SerializableItem for WelcomePage { } mod persistence { - use db::{define_connection, query, sqlez_macros::sql}; + use db::{ + query, + sqlez::{domain::Domain, thread_safe_connection::ThreadSafeConnection}, + sqlez_macros::sql, + }; use workspace::WorkspaceDb; - define_connection! 
{ - pub static ref WELCOME_PAGES: WelcomePagesDb = - &[ - sql!( + pub struct WelcomePagesDb(ThreadSafeConnection); + + impl Domain for WelcomePagesDb { + const NAME: &str = stringify!(WelcomePagesDb); + + const MIGRATIONS: &[&str] = (&[sql!( CREATE TABLE welcome_pages ( workspace_id INTEGER, item_id INTEGER UNIQUE, @@ -430,10 +436,11 @@ mod persistence { FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) ON DELETE CASCADE ) STRICT; - ), - ]; + )]); } + db::static_connection!(WELCOME_PAGES, WelcomePagesDb, [WorkspaceDb]); + impl WelcomePagesDb { query! { pub async fn save_welcome_page( diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 5801f29623ce8d9c515ef3d1756a2375bfdfcf4b..fda0544be1748f3bf958cd159bc55edccdbb5c14 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -9,7 +9,7 @@ use strum::EnumIter; pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1"; fn is_none_or_empty, U>(opt: &Option) -> bool { - opt.as_ref().map_or(true, |v| v.as_ref().is_empty()) + opt.as_ref().is_none_or(|v| v.as_ref().is_empty()) } #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] @@ -236,6 +236,13 @@ impl Model { Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false, } } + + /// Returns whether the given model supports the `prompt_cache_key` parameter. + /// + /// If the model does not support the parameter, do not pass it up. 
+ pub fn supports_prompt_cache_key(&self) -> bool { + true + } } #[derive(Debug, Serialize, Deserialize)] @@ -257,15 +264,17 @@ pub struct Request { pub tools: Vec, #[serde(default, skip_serializing_if = "Option::is_none")] pub prompt_cache_key: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] pub reasoning_effort: Option, } #[derive(Debug, Serialize, Deserialize)] -#[serde(untagged)] +#[serde(rename_all = "lowercase")] pub enum ToolChoice { Auto, Required, None, + #[serde(untagged)] Other(ToolDefinition), } @@ -424,16 +433,20 @@ pub struct ChoiceDelta { pub finish_reason: Option, } +#[derive(Serialize, Deserialize, Debug)] +pub struct OpenAiError { + message: String, +} + #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum ResponseStreamResult { Ok(ResponseStreamEvent), - Err { error: String }, + Err { error: OpenAiError }, } #[derive(Serialize, Deserialize, Debug)] pub struct ResponseStreamEvent { - pub model: String, pub choices: Vec, pub usage: Option, } @@ -449,7 +462,7 @@ pub async fn stream_completion( .method(Method::POST) .uri(uri) .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)); + .header("Authorization", format!("Bearer {}", api_key.trim())); let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?; @@ -460,14 +473,14 @@ pub async fn stream_completion( .filter_map(|line| async move { match line { Ok(line) => { - let line = line.strip_prefix("data: ")?; + let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?; if line == "[DONE]" { None } else { match serde_json::from_str(line) { Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)), Ok(ResponseStreamResult::Err { error }) => { - Some(Err(anyhow!(error))) + Some(Err(anyhow!(error.message))) } Err(error) => { log::error!( @@ -494,11 +507,6 @@ pub async fn stream_completion( error: OpenAiError, } - 
#[derive(Deserialize)] - struct OpenAiError { - message: String, - } - match serde_json::from_str::(&body) { Ok(response) if !response.error.message.is_empty() => Err(anyhow!( "API request to {} failed: {}", @@ -558,7 +566,7 @@ pub fn embed<'a>( .method(Method::POST) .uri(uri) .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)) + .header("Authorization", format!("Bearer {}", api_key.trim())) .body(body) .map(|request| client.send(request)); diff --git a/crates/open_router/Cargo.toml b/crates/open_router/Cargo.toml index bbc4fe190fa3985ef82505078d76dd06adf2abd9..8920c157dc3d6ea0974bd978816eb58cde19919d 100644 --- a/crates/open_router/Cargo.toml +++ b/crates/open_router/Cargo.toml @@ -22,4 +22,6 @@ http_client.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true +thiserror.workspace = true +strum.workspace = true workspace-hack.workspace = true diff --git a/crates/open_router/src/open_router.rs b/crates/open_router/src/open_router.rs index 3e6e406d9842d5996f2e866d534094ded23fd61c..cbc6c243d87c8f9ea3d0186dbecb8f0ac2e10a90 100644 --- a/crates/open_router/src/open_router.rs +++ b/crates/open_router/src/open_router.rs @@ -1,14 +1,33 @@ -use anyhow::{Context, Result, anyhow}; +use anyhow::{Result, anyhow}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; -use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http}; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::convert::TryFrom; +use std::{convert::TryFrom, io, time::Duration}; +use strum::EnumString; +use thiserror::Error; pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1"; +fn extract_retry_after(headers: &http::HeaderMap) -> Option { + if let Some(reset) = headers.get("X-RateLimit-Reset") { + if let Ok(s) = reset.to_str() { + if let 
Ok(epoch_ms) = s.parse::() { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64; + if epoch_ms > now { + return Some(std::time::Duration::from_millis(epoch_ms - now)); + } + } + } + } + None +} + fn is_none_or_empty, U>(opt: &Option) -> bool { - opt.as_ref().map_or(true, |v| v.as_ref().is_empty()) + opt.as_ref().is_none_or(|v| v.as_ref().is_empty()) } #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] @@ -240,10 +259,10 @@ impl MessageContent { impl From> for MessageContent { fn from(parts: Vec) -> Self { - if parts.len() == 1 { - if let MessagePart::Text { text } = &parts[0] { - return Self::Plain(text.clone()); - } + if parts.len() == 1 + && let MessagePart::Text { text } = &parts[0] + { + return Self::Plain(text.clone()); } Self::Multipart(parts) } @@ -413,76 +432,12 @@ pub struct ModelArchitecture { pub input_modalities: Vec, } -pub async fn complete( - client: &dyn HttpClient, - api_url: &str, - api_key: &str, - request: Request, -) -> Result { - let uri = format!("{api_url}/chat/completions"); - let request_builder = HttpRequest::builder() - .method(Method::POST) - .uri(uri) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)) - .header("HTTP-Referer", "https://zed.dev") - .header("X-Title", "Zed Editor"); - - let mut request_body = request; - request_body.stream = false; - - let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?; - let mut response = client.send(request).await?; - - if response.status().is_success() { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - let response: Response = serde_json::from_str(&body)?; - Ok(response) - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - #[derive(Deserialize)] - struct OpenRouterResponse { - error: OpenRouterError, - } - - 
#[derive(Deserialize)] - struct OpenRouterError { - message: String, - #[serde(default)] - code: String, - } - - match serde_json::from_str::(&body) { - Ok(response) if !response.error.message.is_empty() => { - let error_message = if !response.error.code.is_empty() { - format!("{}: {}", response.error.code, response.error.message) - } else { - response.error.message - }; - - Err(anyhow!( - "Failed to connect to OpenRouter API: {}", - error_message - )) - } - _ => Err(anyhow!( - "Failed to connect to OpenRouter API: {} {}", - response.status(), - body, - )), - } - } -} - pub async fn stream_completion( client: &dyn HttpClient, api_url: &str, api_key: &str, request: Request, -) -> Result>> { +) -> Result>, OpenRouterError> { let uri = format!("{api_url}/chat/completions"); let request_builder = HttpRequest::builder() .method(Method::POST) @@ -492,8 +447,15 @@ pub async fn stream_completion( .header("HTTP-Referer", "https://zed.dev") .header("X-Title", "Zed Editor"); - let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; - let mut response = client.send(request).await?; + let request = request_builder + .body(AsyncBody::from( + serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?, + )) + .map_err(OpenRouterError::BuildRequestBody)?; + let mut response = client + .send(request) + .await + .map_err(OpenRouterError::HttpSend)?; if response.status().is_success() { let reader = BufReader::new(response.into_body()); @@ -513,86 +475,89 @@ pub async fn stream_completion( match serde_json::from_str::(line) { Ok(response) => Some(Ok(response)), Err(error) => { - #[derive(Deserialize)] - struct ErrorResponse { - error: String, - } - - match serde_json::from_str::(line) { - Ok(err_response) => Some(Err(anyhow!(err_response.error))), - Err(_) => { - if line.trim().is_empty() { - None - } else { - Some(Err(anyhow!( - "Failed to parse response: {}. 
Original content: '{}'", - error, line - ))) - } - } + if line.trim().is_empty() { + None + } else { + Some(Err(OpenRouterError::DeserializeResponse(error))) } } } } } - Err(error) => Some(Err(anyhow!(error))), + Err(error) => Some(Err(OpenRouterError::ReadResponse(error))), } }) .boxed()) } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; + let code = ApiErrorCode::from_status(response.status().as_u16()); - #[derive(Deserialize)] - struct OpenRouterResponse { - error: OpenRouterError, - } - - #[derive(Deserialize)] - struct OpenRouterError { - message: String, - #[serde(default)] - code: String, - } - - match serde_json::from_str::(&body) { - Ok(response) if !response.error.message.is_empty() => { - let error_message = if !response.error.code.is_empty() { - format!("{}: {}", response.error.code, response.error.message) - } else { - response.error.message - }; - - Err(anyhow!( - "Failed to connect to OpenRouter API: {}", - error_message - )) + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(OpenRouterError::ReadResponse)?; + + let error_response = match serde_json::from_str::(&body) { + Ok(OpenRouterErrorResponse { error }) => error, + Err(_) => OpenRouterErrorBody { + code: response.status().as_u16(), + message: body, + metadata: None, + }, + }; + + match code { + ApiErrorCode::RateLimitError => { + let retry_after = extract_retry_after(response.headers()); + Err(OpenRouterError::RateLimit { + retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)), + }) + } + ApiErrorCode::OverloadedError => { + let retry_after = extract_retry_after(response.headers()); + Err(OpenRouterError::ServerOverloaded { retry_after }) } - _ => Err(anyhow!( - "Failed to connect to OpenRouter API: {} {}", - response.status(), - body, - )), + _ => Err(OpenRouterError::ApiError(ApiError { + code: code, + message: error_response.message, + })), } } } -pub async fn 
list_models(client: &dyn HttpClient, api_url: &str) -> Result> { - let uri = format!("{api_url}/models"); +pub async fn list_models( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, +) -> Result, OpenRouterError> { + let uri = format!("{api_url}/models/user"); let request_builder = HttpRequest::builder() .method(Method::GET) .uri(uri) - .header("Accept", "application/json"); + .header("Accept", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .header("HTTP-Referer", "https://zed.dev") + .header("X-Title", "Zed Editor"); - let request = request_builder.body(AsyncBody::default())?; - let mut response = client.send(request).await?; + let request = request_builder + .body(AsyncBody::default()) + .map_err(OpenRouterError::BuildRequestBody)?; + let mut response = client + .send(request) + .await + .map_err(OpenRouterError::HttpSend)?; let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(OpenRouterError::ReadResponse)?; if response.status().is_success() { let response: ListModelsResponse = - serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?; + serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?; let models = response .data @@ -637,10 +602,141 @@ pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result(&body) { + Ok(OpenRouterErrorResponse { error }) => error, + Err(_) => OpenRouterErrorBody { + code: response.status().as_u16(), + message: body, + metadata: None, + }, + }; + + match code { + ApiErrorCode::RateLimitError => { + let retry_after = extract_retry_after(response.headers()); + Err(OpenRouterError::RateLimit { + retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)), + }) + } + ApiErrorCode::OverloadedError => { + let retry_after = extract_retry_after(response.headers()); + Err(OpenRouterError::ServerOverloaded { 
retry_after }) + } + _ => Err(OpenRouterError::ApiError(ApiError { + code: code, + message: error_response.message, + })), + } + } +} + +#[derive(Debug)] +pub enum OpenRouterError { + /// Failed to serialize the HTTP request body to JSON + SerializeRequest(serde_json::Error), + + /// Failed to construct the HTTP request body + BuildRequestBody(http::Error), + + /// Failed to send the HTTP request + HttpSend(anyhow::Error), + + /// Failed to deserialize the response from JSON + DeserializeResponse(serde_json::Error), + + /// Failed to read from response stream + ReadResponse(io::Error), + + /// Rate limit exceeded + RateLimit { retry_after: Duration }, + + /// Server overloaded + ServerOverloaded { retry_after: Option }, + + /// API returned an error response + ApiError(ApiError), +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct OpenRouterErrorBody { + pub code: u16, + pub message: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub metadata: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct OpenRouterErrorResponse { + pub error: OpenRouterErrorBody, +} + +#[derive(Debug, Serialize, Deserialize, Error)] +#[error("OpenRouter API Error: {code}: {message}")] +pub struct ApiError { + pub code: ApiErrorCode, + pub message: String, +} + +/// An OpenRouter API error code. +/// +#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)] +#[strum(serialize_all = "snake_case")] +pub enum ApiErrorCode { + /// 400: Bad Request (invalid or missing params, CORS) + InvalidRequestError, + /// 401: Invalid credentials (OAuth session expired, disabled/invalid API key) + AuthenticationError, + /// 402: Your account or API key has insufficient credits. Add more credits and retry the request. 
+ PaymentRequiredError, + /// 403: Your chosen model requires moderation and your input was flagged + PermissionError, + /// 408: Your request timed out + RequestTimedOut, + /// 429: You are being rate limited + RateLimitError, + /// 502: Your chosen model is down or we received an invalid response from it + ApiError, + /// 503: There is no available model provider that meets your routing requirements + OverloadedError, +} + +impl std::fmt::Display for ApiErrorCode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + ApiErrorCode::InvalidRequestError => "invalid_request_error", + ApiErrorCode::AuthenticationError => "authentication_error", + ApiErrorCode::PaymentRequiredError => "payment_required_error", + ApiErrorCode::PermissionError => "permission_error", + ApiErrorCode::RequestTimedOut => "request_timed_out", + ApiErrorCode::RateLimitError => "rate_limit_error", + ApiErrorCode::ApiError => "api_error", + ApiErrorCode::OverloadedError => "overloaded_error", + }; + write!(f, "{s}") + } +} + +impl ApiErrorCode { + pub fn from_status(status: u16) -> Self { + match status { + 400 => ApiErrorCode::InvalidRequestError, + 401 => ApiErrorCode::AuthenticationError, + 402 => ApiErrorCode::PaymentRequiredError, + 403 => ApiErrorCode::PermissionError, + 408 => ApiErrorCode::RequestTimedOut, + 429 => ApiErrorCode::RateLimitError, + 502 => ApiErrorCode::ApiError, + 503 => ApiErrorCode::OverloadedError, + _ => ApiErrorCode::ApiError, + } } } diff --git a/crates/outline/src/outline.rs b/crates/outline/src/outline.rs index 8c5e78d77bce76e62ef94d2501dbef588cd76f00..1f85d08cee850f18a30cca9b56061854f7cbc7b1 100644 --- a/crates/outline/src/outline.rs +++ b/crates/outline/src/outline.rs @@ -81,8 +81,19 @@ impl ModalView for OutlineView { } impl Render for OutlineView { - fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - v_flex().w(rems(34.)).child(self.picker.clone()) + fn render(&mut self, _window: &mut 
Window, cx: &mut Context) -> impl IntoElement { + v_flex() + .w(rems(34.)) + .on_action(cx.listener( + |_this: &mut OutlineView, + _: &zed_actions::outline::ToggleOutline, + _window: &mut Window, + cx: &mut Context| { + // When outline::Toggle is triggered while the outline is open, dismiss it + cx.emit(DismissEvent); + }, + )) + .child(self.picker.clone()) } } diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index 1cda3897ec356c76b8abf4751bad6c35873c1300..a8485248dbe2e544c80d59b3ad549c54c49e6e51 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -503,16 +503,16 @@ impl SearchData { && multi_buffer_snapshot .chars_at(extended_context_left_border) .last() - .map_or(false, |c| !c.is_whitespace()); + .is_some_and(|c| !c.is_whitespace()); let truncated_right = entire_context_text .chars() .last() - .map_or(true, |c| !c.is_whitespace()) + .is_none_or(|c| !c.is_whitespace()) && extended_context_right_border > context_right_border && multi_buffer_snapshot .chars_at(extended_context_right_border) .next() - .map_or(false, |c| !c.is_whitespace()); + .is_some_and(|c| !c.is_whitespace()); search_match_indices.iter_mut().for_each(|range| { range.start = multi_buffer_snapshot.clip_offset( range.start.saturating_sub(left_whitespaces_offset), @@ -733,10 +733,11 @@ impl OutlinePanel { ) -> Entity { let project = workspace.project().clone(); let workspace_handle = cx.entity().downgrade(); - let outline_panel = cx.new(|cx| { + + cx.new(|cx| { let filter_editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text("Filter...", cx); + editor.set_placeholder_text("Filter...", window, cx); editor }); let filter_update_subscription = cx.subscribe_in( @@ -912,9 +913,7 @@ impl OutlinePanel { outline_panel.replace_active_editor(item, editor, window, cx); } outline_panel - }); - - outline_panel + }) } fn serialization_key(workspace: &Workspace) -> 
Option { @@ -1170,12 +1169,11 @@ impl OutlinePanel { }); } else { let mut offset = Point::default(); - if let Some(buffer_id) = scroll_to_buffer { - if multi_buffer_snapshot.as_singleton().is_none() - && !active_editor.read(cx).is_buffer_folded(buffer_id, cx) - { - offset.y = -(active_editor.read(cx).file_header_size() as f32); - } + if let Some(buffer_id) = scroll_to_buffer + && multi_buffer_snapshot.as_singleton().is_none() + && !active_editor.read(cx).is_buffer_folded(buffer_id, cx) + { + offset.y = -(active_editor.read(cx).file_header_size() as f32); } active_editor.update(cx, |editor, cx| { @@ -1260,7 +1258,7 @@ impl OutlinePanel { dirs_worktree_id == worktree_id && dirs .last() - .map_or(false, |dir| dir.path.as_ref() == parent_path) + .is_some_and(|dir| dir.path.as_ref() == parent_path) } _ => false, }) @@ -1454,9 +1452,7 @@ impl OutlinePanel { if self .unfolded_dirs .get(&directory_worktree) - .map_or(true, |unfolded_dirs| { - !unfolded_dirs.contains(&directory_entry.id) - }) + .is_none_or(|unfolded_dirs| !unfolded_dirs.contains(&directory_entry.id)) { return false; } @@ -1606,16 +1602,14 @@ impl OutlinePanel { } PanelEntry::FoldedDirs(folded_dirs) => { let mut folded = false; - if let Some(dir_entry) = folded_dirs.entries.last() { - if self + if let Some(dir_entry) = folded_dirs.entries.last() + && self .collapsed_entries .insert(CollapsedEntry::Dir(folded_dirs.worktree_id, dir_entry.id)) - { - folded = true; - buffers_to_fold.extend( - self.buffers_inside_directory(folded_dirs.worktree_id, dir_entry), - ); - } + { + folded = true; + buffers_to_fold + .extend(self.buffers_inside_directory(folded_dirs.worktree_id, dir_entry)); } folded } @@ -2108,11 +2102,11 @@ impl OutlinePanel { dirs_to_expand.push(current_entry.id); } - if traversal.back_to_parent() { - if let Some(parent_entry) = traversal.entry() { - current_entry = parent_entry.clone(); - continue; - } + if traversal.back_to_parent() + && let Some(parent_entry) = traversal.entry() + { + current_entry 
= parent_entry.clone(); + continue; } break; } @@ -2159,7 +2153,7 @@ impl OutlinePanel { ExcerptOutlines::Invalidated(outlines) => Some(outlines), ExcerptOutlines::NotFetched => None, }) - .map_or(false, |outlines| !outlines.is_empty()); + .is_some_and(|outlines| !outlines.is_empty()); let is_expanded = !self .collapsed_entries .contains(&CollapsedEntry::Excerpt(excerpt.buffer_id, excerpt.id)); @@ -2475,17 +2469,17 @@ impl OutlinePanel { let search_data = match render_data.get() { Some(search_data) => search_data, None => { - if let ItemsDisplayMode::Search(search_state) = &mut self.mode { - if let Some(multi_buffer_snapshot) = multi_buffer_snapshot { - search_state - .highlight_search_match_tx - .try_send(HighlightArguments { - multi_buffer_snapshot: multi_buffer_snapshot.clone(), - match_range: match_range.clone(), - search_data: Arc::clone(render_data), - }) - .ok(); - } + if let ItemsDisplayMode::Search(search_state) = &mut self.mode + && let Some(multi_buffer_snapshot) = multi_buffer_snapshot + { + search_state + .highlight_search_match_tx + .try_send(HighlightArguments { + multi_buffer_snapshot: multi_buffer_snapshot.clone(), + match_range: match_range.clone(), + search_data: Arc::clone(render_data), + }) + .ok(); } return None; } @@ -2629,7 +2623,7 @@ impl OutlinePanel { } fn entry_name(&self, worktree_id: &WorktreeId, entry: &Entry, cx: &App) -> String { - let name = match self.project.read(cx).worktree_for_id(*worktree_id, cx) { + match self.project.read(cx).worktree_for_id(*worktree_id, cx) { Some(worktree) => { let worktree = worktree.read(cx); match worktree.snapshot().root_entry() { @@ -2650,8 +2644,7 @@ impl OutlinePanel { } } None => file_name(entry.path.as_ref()), - }; - name + } } fn update_fs_entries( @@ -2686,7 +2679,8 @@ impl OutlinePanel { new_collapsed_entries = outline_panel.collapsed_entries.clone(); new_unfolded_dirs = outline_panel.unfolded_dirs.clone(); let multi_buffer_snapshot = active_multi_buffer.read(cx).snapshot(cx); - let 
buffer_excerpts = multi_buffer_snapshot.excerpts().fold( + + multi_buffer_snapshot.excerpts().fold( HashMap::default(), |mut buffer_excerpts, (excerpt_id, buffer_snapshot, excerpt_range)| { let buffer_id = buffer_snapshot.remote_id(); @@ -2733,8 +2727,7 @@ impl OutlinePanel { ); buffer_excerpts }, - ); - buffer_excerpts + ) }) else { return; }; @@ -2833,11 +2826,12 @@ impl OutlinePanel { let new_entry_added = entries_to_add .insert(current_entry.id, current_entry) .is_none(); - if new_entry_added && traversal.back_to_parent() { - if let Some(parent_entry) = traversal.entry() { - current_entry = parent_entry.to_owned(); - continue; - } + if new_entry_added + && traversal.back_to_parent() + && let Some(parent_entry) = traversal.entry() + { + current_entry = parent_entry.to_owned(); + continue; } break; } @@ -2878,18 +2872,17 @@ impl OutlinePanel { entries .into_iter() .filter_map(|entry| { - if auto_fold_dirs { - if let Some(parent) = entry.path.parent() { - let children = new_children_count - .entry(worktree_id) - .or_default() - .entry(Arc::from(parent)) - .or_default(); - if entry.is_dir() { - children.dirs += 1; - } else { - children.files += 1; - } + if auto_fold_dirs && let Some(parent) = entry.path.parent() + { + let children = new_children_count + .entry(worktree_id) + .or_default() + .entry(Arc::from(parent)) + .or_default(); + if entry.is_dir() { + children.dirs += 1; + } else { + children.files += 1; } } @@ -2956,7 +2949,7 @@ impl OutlinePanel { .map(|(parent_dir_id, _)| { new_unfolded_dirs .get(&directory.worktree_id) - .map_or(true, |unfolded_dirs| { + .is_none_or(|unfolded_dirs| { unfolded_dirs .contains(parent_dir_id) }) @@ -3357,13 +3350,11 @@ impl OutlinePanel { let buffer_language = buffer_snapshot.language().cloned(); let fetched_outlines = cx .background_spawn(async move { - let mut outlines = buffer_snapshot - .outline_items_containing( - excerpt_range.context, - false, - Some(&syntax_theme), - ) - .unwrap_or_default(); + let mut outlines = 
buffer_snapshot.outline_items_containing( + excerpt_range.context, + false, + Some(&syntax_theme), + ); outlines.retain(|outline| { buffer_language.is_none() || buffer_language.as_ref() @@ -3409,30 +3400,29 @@ impl OutlinePanel { { excerpt.outlines = ExcerptOutlines::Outlines(fetched_outlines); - if let Some(default_depth) = pending_default_depth { - if let ExcerptOutlines::Outlines(outlines) = + if let Some(default_depth) = pending_default_depth + && let ExcerptOutlines::Outlines(outlines) = &excerpt.outlines - { - outlines - .iter() - .filter(|outline| { - (default_depth == 0 - || outline.depth >= default_depth) - && outlines_with_children.contains(&( - outline.range.clone(), - outline.depth, - )) - }) - .for_each(|outline| { - outline_panel.collapsed_entries.insert( - CollapsedEntry::Outline( - buffer_id, - excerpt_id, - outline.range.clone(), - ), - ); - }); - } + { + outlines + .iter() + .filter(|outline| { + (default_depth == 0 + || outline.depth >= default_depth) + && outlines_with_children.contains(&( + outline.range.clone(), + outline.depth, + )) + }) + .for_each(|outline| { + outline_panel.collapsed_entries.insert( + CollapsedEntry::Outline( + buffer_id, + excerpt_id, + outline.range.clone(), + ), + ); + }); } // Even if no outlines to check, we still need to update cached entries @@ -3448,9 +3438,8 @@ impl OutlinePanel { } fn is_singleton_active(&self, cx: &App) -> bool { - self.active_editor().map_or(false, |active_editor| { - active_editor.read(cx).buffer().read(cx).is_singleton() - }) + self.active_editor() + .is_some_and(|active_editor| active_editor.read(cx).buffer().read(cx).is_singleton()) } fn invalidate_outlines(&mut self, ids: &[ExcerptId]) { @@ -3611,10 +3600,9 @@ impl OutlinePanel { .update_in(cx, |outline_panel, window, cx| { outline_panel.cached_entries = new_cached_entries; outline_panel.max_width_item_index = max_width_item_index; - if outline_panel.selected_entry.is_invalidated() - || matches!(outline_panel.selected_entry, 
SelectedEntry::None) - { - if let Some(new_selected_entry) = + if (outline_panel.selected_entry.is_invalidated() + || matches!(outline_panel.selected_entry, SelectedEntry::None)) + && let Some(new_selected_entry) = outline_panel.active_editor().and_then(|active_editor| { outline_panel.location_for_editor_selection( &active_editor, @@ -3622,9 +3610,8 @@ impl OutlinePanel { cx, ) }) - { - outline_panel.select_entry(new_selected_entry, false, window, cx); - } + { + outline_panel.select_entry(new_selected_entry, false, window, cx); } outline_panel.autoscroll(cx); @@ -3670,7 +3657,7 @@ impl OutlinePanel { let is_root = project .read(cx) .worktree_for_id(directory_entry.worktree_id, cx) - .map_or(false, |worktree| { + .is_some_and(|worktree| { worktree.read(cx).root_entry() == Some(&directory_entry.entry) }); let folded = auto_fold_dirs @@ -3678,7 +3665,7 @@ impl OutlinePanel { && outline_panel .unfolded_dirs .get(&directory_entry.worktree_id) - .map_or(true, |unfolded_dirs| { + .is_none_or(|unfolded_dirs| { !unfolded_dirs.contains(&directory_entry.entry.id) }); let fs_depth = outline_panel @@ -3758,7 +3745,7 @@ impl OutlinePanel { .iter() .rev() .nth(folded_dirs.entries.len() + 1) - .map_or(true, |parent| parent.expanded); + .is_none_or(|parent| parent.expanded); if start_of_collapsed_dir_sequence || parent_expanded || query.is_some() @@ -3818,7 +3805,7 @@ impl OutlinePanel { .iter() .all(|entry| entry.path != parent.path) }) - .map_or(true, |parent| parent.expanded); + .is_none_or(|parent| parent.expanded); if !is_singleton && (parent_expanded || query.is_some()) { outline_panel.push_entry( &mut generation_state, @@ -3843,7 +3830,7 @@ impl OutlinePanel { .iter() .all(|entry| entry.path != parent.path) }) - .map_or(true, |parent| parent.expanded); + .is_none_or(|parent| parent.expanded); if !is_singleton && (parent_expanded || query.is_some()) { outline_panel.push_entry( &mut generation_state, @@ -3921,19 +3908,19 @@ impl OutlinePanel { } else { None }; - if let 
Some((buffer_id, entry_excerpts)) = excerpts_to_consider { - if !active_editor.read(cx).is_buffer_folded(buffer_id, cx) { - outline_panel.add_excerpt_entries( - &mut generation_state, - buffer_id, - entry_excerpts, - depth, - track_matches, - is_singleton, - query.as_deref(), - cx, - ); - } + if let Some((buffer_id, entry_excerpts)) = excerpts_to_consider + && !active_editor.read(cx).is_buffer_folded(buffer_id, cx) + { + outline_panel.add_excerpt_entries( + &mut generation_state, + buffer_id, + entry_excerpts, + depth, + track_matches, + is_singleton, + query.as_deref(), + cx, + ); } } } @@ -3964,7 +3951,7 @@ impl OutlinePanel { .iter() .all(|entry| entry.path != parent.path) }) - .map_or(true, |parent| parent.expanded); + .is_none_or(|parent| parent.expanded); if parent_expanded || query.is_some() { outline_panel.push_entry( &mut generation_state, @@ -4404,15 +4391,16 @@ impl OutlinePanel { }) .filter(|(match_range, _)| { let editor = active_editor.read(cx); - if let Some(buffer_id) = match_range.start.buffer_id { - if editor.is_buffer_folded(buffer_id, cx) { - return false; - } + let snapshot = editor.buffer().read(cx).snapshot(cx); + if let Some(buffer_id) = snapshot.buffer_id_for_anchor(match_range.start) + && editor.is_buffer_folded(buffer_id, cx) + { + return false; } - if let Some(buffer_id) = match_range.start.buffer_id { - if editor.is_buffer_folded(buffer_id, cx) { - return false; - } + if let Some(buffer_id) = snapshot.buffer_id_for_anchor(match_range.end) + && editor.is_buffer_folded(buffer_id, cx) + { + return false; } true }); @@ -4444,7 +4432,7 @@ impl OutlinePanel { } fn should_replace_active_item(&self, new_active_item: &dyn ItemHandle) -> bool { - self.active_item().map_or(true, |active_item| { + self.active_item().is_none_or(|active_item| { !self.pinned && active_item.item_id() != new_active_item.item_id() }) } @@ -4456,16 +4444,14 @@ impl OutlinePanel { cx: &mut Context, ) { self.pinned = !self.pinned; - if !self.pinned { - if let 
Some((active_item, active_editor)) = self + if !self.pinned + && let Some((active_item, active_editor)) = self .workspace .upgrade() .and_then(|workspace| workspace_active_editor(workspace.read(cx), cx)) - { - if self.should_replace_active_item(active_item.as_ref()) { - self.replace_active_editor(active_item, active_editor, window, cx); - } - } + && self.should_replace_active_item(active_item.as_ref()) + { + self.replace_active_editor(active_item, active_editor, window, cx); } cx.notify(); @@ -4815,51 +4801,45 @@ impl OutlinePanel { .when(show_indent_guides, |list| { list.with_decoration( ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx)) - .with_compute_indents_fn( - cx.entity().clone(), - |outline_panel, range, _, _| { - let entries = outline_panel.cached_entries.get(range); - if let Some(entries) = entries { - entries.into_iter().map(|item| item.depth).collect() - } else { - smallvec::SmallVec::new() - } - }, - ) - .with_render_fn( - cx.entity().clone(), - move |outline_panel, params, _, _| { - const LEFT_OFFSET: Pixels = px(14.); - - let indent_size = params.indent_size; - let item_height = params.item_height; - let active_indent_guide_ix = find_active_indent_guide_ix( - outline_panel, - ¶ms.indent_guides, - ); + .with_compute_indents_fn(cx.entity(), |outline_panel, range, _, _| { + let entries = outline_panel.cached_entries.get(range); + if let Some(entries) = entries { + entries.iter().map(|item| item.depth).collect() + } else { + smallvec::SmallVec::new() + } + }) + .with_render_fn(cx.entity(), move |outline_panel, params, _, _| { + const LEFT_OFFSET: Pixels = px(14.); + + let indent_size = params.indent_size; + let item_height = params.item_height; + let active_indent_guide_ix = find_active_indent_guide_ix( + outline_panel, + ¶ms.indent_guides, + ); - params - .indent_guides - .into_iter() - .enumerate() - .map(|(ix, layout)| { - let bounds = Bounds::new( - point( - layout.offset.x * indent_size + LEFT_OFFSET, - layout.offset.y * item_height, 
- ), - size(px(1.), layout.length * item_height), - ); - ui::RenderedIndentGuide { - bounds, - layout, - is_active: active_indent_guide_ix == Some(ix), - hitbox: None, - } - }) - .collect() - }, - ), + params + .indent_guides + .into_iter() + .enumerate() + .map(|(ix, layout)| { + let bounds = Bounds::new( + point( + layout.offset.x * indent_size + LEFT_OFFSET, + layout.offset.y * item_height, + ), + size(px(1.), layout.length * item_height), + ); + ui::RenderedIndentGuide { + bounds, + layout, + is_active: active_indent_guide_ix == Some(ix), + hitbox: None, + } + }) + .collect() + }), ) }) }; @@ -5073,24 +5053,23 @@ impl Panel for OutlinePanel { let old_active = outline_panel.active; outline_panel.active = active; if old_active != active { - if active { - if let Some((active_item, active_editor)) = + if active + && let Some((active_item, active_editor)) = outline_panel.workspace.upgrade().and_then(|workspace| { workspace_active_editor(workspace.read(cx), cx) }) - { - if outline_panel.should_replace_active_item(active_item.as_ref()) { - outline_panel.replace_active_editor( - active_item, - active_editor, - window, - cx, - ); - } else { - outline_panel.update_fs_entries(active_editor, None, window, cx) - } - return; + { + if outline_panel.should_replace_active_item(active_item.as_ref()) { + outline_panel.replace_active_editor( + active_item, + active_editor, + window, + cx, + ); + } else { + outline_panel.update_fs_entries(active_editor, None, window, cx) } + return; } if !outline_panel.pinned { @@ -5111,7 +5090,7 @@ impl Panel for OutlinePanel { impl Focusable for OutlinePanel { fn focus_handle(&self, cx: &App) -> FocusHandle { - self.filter_editor.focus_handle(cx).clone() + self.filter_editor.focus_handle(cx) } } @@ -5121,9 +5100,9 @@ impl EventEmitter for OutlinePanel {} impl Render for OutlinePanel { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let (is_local, is_via_ssh) = self - .project - .read_with(cx, |project, _| 
(project.is_local(), project.is_via_ssh())); + let (is_local, is_via_ssh) = self.project.read_with(cx, |project, _| { + (project.is_local(), project.is_via_remote_server()) + }); let query = self.query(cx); let pinned = self.pinned; let settings = OutlinePanelSettings::get_global(cx); @@ -5325,8 +5304,8 @@ fn subscribe_for_editor_events( }) .copied(), ); - if !ignore_selections_change { - if let Some(entry_to_select) = latest_unfolded_buffer_id + if !ignore_selections_change + && let Some(entry_to_select) = latest_unfolded_buffer_id .or(latest_folded_buffer_id) .and_then(|toggled_buffer_id| { outline_panel.fs_entries.iter().find_map( @@ -5350,16 +5329,15 @@ fn subscribe_for_editor_events( ) }) .map(PanelEntry::Fs) - { - outline_panel.select_entry(entry_to_select, true, window, cx); - } + { + outline_panel.select_entry(entry_to_select, true, window, cx); } outline_panel.update_fs_entries(editor.clone(), debounce, window, cx); } EditorEvent::Reparsed(buffer_id) => { if let Some(excerpts) = outline_panel.excerpts.get_mut(buffer_id) { - for (_, excerpt) in excerpts { + for excerpt in excerpts.values_mut() { excerpt.invalidate_outlines(); } } @@ -5422,8 +5400,9 @@ mod tests { init_test(cx); let fs = FakeFs::new(cx.background_executor.clone()); - populate_with_test_ra_project(&fs, "/rust-analyzer").await; - let project = Project::test(fs.clone(), ["/rust-analyzer".as_ref()], cx).await; + let root = path!("/rust-analyzer"); + populate_with_test_ra_project(&fs, root).await; + let project = Project::test(fs.clone(), [Path::new(root)], cx).await; project.read_with(cx, |project, _| { project.languages().add(Arc::new(rust_lang())) }); @@ -5468,15 +5447,16 @@ mod tests { }); }); - let all_matches = r#"/rust-analyzer/ + let all_matches = format!( + r#"{root}/ crates/ ide/src/ inlay_hints/ fn_lifetime_fn.rs - search: match config.param_names_for_lifetime_elision_hints { - search: allocated_lifetimes.push(if config.param_names_for_lifetime_elision_hints { - search: Some(it) if 
config.param_names_for_lifetime_elision_hints => { - search: InlayHintsConfig { param_names_for_lifetime_elision_hints: true, ..TEST_CONFIG }, + search: match config.param_names_for_lifetime_elision_hints {{ + search: allocated_lifetimes.push(if config.param_names_for_lifetime_elision_hints {{ + search: Some(it) if config.param_names_for_lifetime_elision_hints => {{ + search: InlayHintsConfig {{ param_names_for_lifetime_elision_hints: true, ..TEST_CONFIG }}, inlay_hints.rs search: pub param_names_for_lifetime_elision_hints: bool, search: param_names_for_lifetime_elision_hints: self @@ -5487,7 +5467,9 @@ mod tests { analysis_stats.rs search: param_names_for_lifetime_elision_hints: true, config.rs - search: param_names_for_lifetime_elision_hints: self"#; + search: param_names_for_lifetime_elision_hints: self"# + ); + let select_first_in_all_matches = |line_to_select: &str| { assert!(all_matches.contains(line_to_select)); all_matches.replacen( @@ -5504,7 +5486,7 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -5520,7 +5502,7 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -5538,13 +5520,13 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, ), format!( - r#"/rust-analyzer/ + r#"{root}/ crates/ ide/src/ inlay_hints/ @@ -5575,7 +5557,7 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -5589,7 +5571,7 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), 
&outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -5608,13 +5590,13 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, ), format!( - r#"/rust-analyzer/ + r#"{root}/ crates/ ide/src/{SELECTED_MARKER} rust-analyzer/src/ @@ -5636,7 +5618,7 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -5651,8 +5633,9 @@ mod tests { init_test(cx); let fs = FakeFs::new(cx.background_executor.clone()); - populate_with_test_ra_project(&fs, "/rust-analyzer").await; - let project = Project::test(fs.clone(), ["/rust-analyzer".as_ref()], cx).await; + let root = path!("/rust-analyzer"); + populate_with_test_ra_project(&fs, root).await; + let project = Project::test(fs.clone(), [Path::new(root)], cx).await; project.read_with(cx, |project, _| { project.languages().add(Arc::new(rust_lang())) }); @@ -5696,15 +5679,16 @@ mod tests { ); }); }); - let all_matches = r#"/rust-analyzer/ + let all_matches = format!( + r#"{root}/ crates/ ide/src/ inlay_hints/ fn_lifetime_fn.rs - search: match config.param_names_for_lifetime_elision_hints { - search: allocated_lifetimes.push(if config.param_names_for_lifetime_elision_hints { - search: Some(it) if config.param_names_for_lifetime_elision_hints => { - search: InlayHintsConfig { param_names_for_lifetime_elision_hints: true, ..TEST_CONFIG }, + search: match config.param_names_for_lifetime_elision_hints {{ + search: allocated_lifetimes.push(if config.param_names_for_lifetime_elision_hints {{ + search: Some(it) if config.param_names_for_lifetime_elision_hints => {{ + search: InlayHintsConfig {{ param_names_for_lifetime_elision_hints: true, ..TEST_CONFIG }}, inlay_hints.rs search: pub param_names_for_lifetime_elision_hints: bool, search: 
param_names_for_lifetime_elision_hints: self @@ -5715,7 +5699,8 @@ mod tests { analysis_stats.rs search: param_names_for_lifetime_elision_hints: true, config.rs - search: param_names_for_lifetime_elision_hints: self"#; + search: param_names_for_lifetime_elision_hints: self"# + ); cx.executor() .advance_clock(UPDATE_DEBOUNCE + Duration::from_millis(100)); @@ -5724,7 +5709,7 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, None, cx, @@ -5747,7 +5732,7 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, None, cx, @@ -5773,7 +5758,7 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, None, cx, @@ -5788,8 +5773,9 @@ mod tests { init_test(cx); let fs = FakeFs::new(cx.background_executor.clone()); - populate_with_test_ra_project(&fs, path!("/rust-analyzer")).await; - let project = Project::test(fs.clone(), [path!("/rust-analyzer").as_ref()], cx).await; + let root = path!("/rust-analyzer"); + populate_with_test_ra_project(&fs, root).await; + let project = Project::test(fs.clone(), [Path::new(root)], cx).await; project.read_with(cx, |project, _| { project.languages().add(Arc::new(rust_lang())) }); @@ -5833,9 +5819,8 @@ mod tests { ); }); }); - let root_path = format!("{}/", path!("/rust-analyzer")); let all_matches = format!( - r#"{root_path} + r#"{root}/ crates/ ide/src/ inlay_hints/ @@ -5879,7 +5864,7 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -5902,7 +5887,7 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), 
cx, @@ -5939,7 +5924,7 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -5976,7 +5961,7 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -5997,7 +5982,7 @@ mod tests { let fs = FakeFs::new(cx.background_executor.clone()); fs.insert_tree( - "/root", + path!("/root"), json!({ "one": { "a.txt": "aaa aaa" @@ -6009,7 +5994,7 @@ mod tests { }), ) .await; - let project = Project::test(fs.clone(), [Path::new("/root/one")], cx).await; + let project = Project::test(fs.clone(), [Path::new(path!("/root/one"))], cx).await; let workspace = add_outline_panel(&project, cx).await; let cx = &mut VisualTestContext::from_window(*workspace, cx); let outline_panel = outline_panel(&workspace, cx); @@ -6020,7 +6005,7 @@ mod tests { let items = workspace .update(cx, |workspace, window, cx| { workspace.open_paths( - vec![PathBuf::from("/root/two")], + vec![PathBuf::from(path!("/root/two"))], OpenOptions { visible: Some(OpenVisible::OnlyDirectories), ..Default::default() @@ -6079,18 +6064,22 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, ), - r#"/root/one/ + format!( + r#"{}/ a.txt search: aaa aaa <==== selected search: aaa aaa -/root/two/ +{}/ b.txt - search: a aaa"# + search: a aaa"#, + path!("/root/one"), + path!("/root/two"), + ), ); }); @@ -6105,16 +6094,20 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, ), - r#"/root/one/ + format!( + r#"{}/ a.txt <==== selected -/root/two/ +{}/ b.txt - search: a aaa"# + search: a aaa"#, + 
path!("/root/one"), + path!("/root/two"), + ), ); }); @@ -6129,14 +6122,18 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, ), - r#"/root/one/ + format!( + r#"{}/ a.txt -/root/two/ <==== selected"# +{}/ <==== selected"#, + path!("/root/one"), + path!("/root/two"), + ), ); }); @@ -6150,16 +6147,20 @@ mod tests { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, ), - r#"/root/one/ + format!( + r#"{}/ a.txt -/root/two/ <==== selected +{}/ <==== selected b.txt - search: a aaa"# + search: a aaa"#, + path!("/root/one"), + path!("/root/two"), + ) ); }); } @@ -6185,7 +6186,7 @@ struct OutlineEntryExcerpt { }), ) .await; - let project = Project::test(fs.clone(), [root.as_ref()], cx).await; + let project = Project::test(fs.clone(), [Path::new(root)], cx).await; project.read_with(cx, |project, _| { project.languages().add(Arc::new( rust_lang() @@ -6238,7 +6239,7 @@ struct OutlineEntryExcerpt { assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -6265,7 +6266,7 @@ outline: struct OutlineEntryExcerpt assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -6292,7 +6293,7 @@ outline: struct OutlineEntryExcerpt <==== selected assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -6319,7 +6320,7 @@ outline: struct OutlineEntryExcerpt assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, 
outline_panel.selected_entry(), cx, @@ -6346,7 +6347,7 @@ outline: struct OutlineEntryExcerpt assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -6373,7 +6374,7 @@ outline: struct OutlineEntryExcerpt assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -6400,7 +6401,7 @@ outline: struct OutlineEntryExcerpt <==== selected assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -6427,7 +6428,7 @@ outline: struct OutlineEntryExcerpt assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -6454,7 +6455,7 @@ outline: struct OutlineEntryExcerpt assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -6481,7 +6482,7 @@ outline: struct OutlineEntryExcerpt assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -6508,7 +6509,7 @@ outline: struct OutlineEntryExcerpt <==== selected assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -6528,7 +6529,7 @@ outline: struct OutlineEntryExcerpt async fn test_frontend_repo_structure(cx: &mut TestAppContext) { init_test(cx); - let root = "/frontend-project"; + let root = path!("/frontend-project"); let fs = FakeFs::new(cx.background_executor.clone()); fs.insert_tree( root, @@ -6565,7 +6566,7 @@ outline: struct 
OutlineEntryExcerpt }), ) .await; - let project = Project::test(fs.clone(), [root.as_ref()], cx).await; + let project = Project::test(fs.clone(), [Path::new(root)], cx).await; let workspace = add_outline_panel(&project, cx).await; let cx = &mut VisualTestContext::from_window(*workspace, cx); let outline_panel = outline_panel(&workspace, cx); @@ -6614,15 +6615,16 @@ outline: struct OutlineEntryExcerpt assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, ), - r#"/frontend-project/ + format!( + r#"{root}/ public/lottie/ syntax-tree.json - search: { "something": "static" } <==== selected + search: {{ "something": "static" }} <==== selected src/ app/(site)/ (about)/jobs/[slug]/ @@ -6634,6 +6636,7 @@ outline: struct OutlineEntryExcerpt components/ ErrorBoundary.tsx search: static"# + ) ); }); @@ -6651,20 +6654,22 @@ outline: struct OutlineEntryExcerpt assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, ), - r#"/frontend-project/ + format!( + r#"{root}/ public/lottie/ syntax-tree.json - search: { "something": "static" } + search: {{ "something": "static" }} src/ app/(site)/ <==== selected components/ ErrorBoundary.tsx search: static"# + ) ); }); @@ -6679,20 +6684,22 @@ outline: struct OutlineEntryExcerpt assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, ), - r#"/frontend-project/ + format!( + r#"{root}/ public/lottie/ syntax-tree.json - search: { "something": "static" } + search: {{ "something": "static" }} src/ app/(site)/ components/ ErrorBoundary.tsx search: static <==== selected"# + ) ); }); @@ -6711,19 +6718,21 @@ outline: struct OutlineEntryExcerpt assert_eq!( display_entries( &project, - &snapshot(&outline_panel, 
cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, ), - r#"/frontend-project/ + format!( + r#"{root}/ public/lottie/ syntax-tree.json - search: { "something": "static" } + search: {{ "something": "static" }} src/ app/(site)/ components/ ErrorBoundary.tsx <==== selected"# + ) ); }); @@ -6742,20 +6751,22 @@ outline: struct OutlineEntryExcerpt assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, ), - r#"/frontend-project/ + format!( + r#"{root}/ public/lottie/ syntax-tree.json - search: { "something": "static" } + search: {{ "something": "static" }} src/ app/(site)/ components/ ErrorBoundary.tsx <==== selected search: static"# + ) ); }); } @@ -6870,7 +6881,7 @@ outline: struct OutlineEntryExcerpt .render_data .get_or_init(|| SearchData::new( &search_entry.match_range, - &multi_buffer_snapshot + multi_buffer_snapshot )) .context_text ) @@ -7261,7 +7272,7 @@ outline: struct OutlineEntryExcerpt assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -7320,7 +7331,7 @@ outline: fn main()" assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -7344,7 +7355,7 @@ outline: fn main()" assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -7409,7 +7420,7 @@ outline: fn main()" assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -7550,7 +7561,7 @@ outline: fn main()" assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + 
&snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -7588,7 +7599,7 @@ outline: fn main()" assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -7622,7 +7633,7 @@ outline: fn main()" assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, @@ -7654,7 +7665,7 @@ outline: fn main()" assert_eq!( display_entries( &project, - &snapshot(&outline_panel, cx), + &snapshot(outline_panel, cx), &outline_panel.cached_entries, outline_panel.selected_entry(), cx, diff --git a/crates/outline_panel/src/outline_panel_settings.rs b/crates/outline_panel/src/outline_panel_settings.rs index 133d28b748d2978e07a540b3c8c7517b03dc4767..dc123f2ba5fb38dd80b72aee8fc6ad6a000be23d 100644 --- a/crates/outline_panel/src/outline_panel_settings.rs +++ b/crates/outline_panel/src/outline_panel_settings.rs @@ -2,7 +2,7 @@ use editor::ShowScrollbar; use gpui::Pixels; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, Copy, PartialEq)] #[serde(rename_all = "snake_case")] @@ -61,7 +61,8 @@ pub struct IndentGuidesSettingsContent { pub show: Option, } -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug, SettingsUi, SettingsKey)] +#[settings_key(key = "outline_panel")] pub struct OutlinePanelSettingsContent { /// Whether to show the outline panel button in the status bar. 
/// @@ -116,8 +117,6 @@ pub struct OutlinePanelSettingsContent { } impl Settings for OutlinePanelSettings { - const KEY: Option<&'static str> = Some("outline_panel"); - type FileContent = OutlinePanelSettingsContent; fn load( diff --git a/crates/panel/src/panel.rs b/crates/panel/src/panel.rs index 658a51167ba7da3f02c49ab77b50e72dabbbae57..1930f654e9b632e52719103e5b0a399cfe94f70a 100644 --- a/crates/panel/src/panel.rs +++ b/crates/panel/src/panel.rs @@ -52,7 +52,7 @@ impl RenderOnce for PanelTab { pub fn panel_button(label: impl Into) -> ui::Button { let label = label.into(); - let id = ElementId::Name(label.clone().to_lowercase().replace(' ', "_").into()); + let id = ElementId::Name(label.to_lowercase().replace(' ', "_").into()); ui::Button::new(id, label) .label_size(ui::LabelSize::Small) .icon_size(ui::IconSize::Small) diff --git a/crates/paths/src/paths.rs b/crates/paths/src/paths.rs index 47a0f12c0634dbde48d015e4f577519babc67b34..ede42af0272902892afd2e9dfdafb5c5eae2f8f5 100644 --- a/crates/paths/src/paths.rs +++ b/crates/paths/src/paths.rs @@ -33,6 +33,11 @@ pub fn remote_server_dir_relative() -> &'static Path { Path::new(".zed_server") } +/// Returns the relative path to the zed_wsl_server directory on the wsl host. +pub fn remote_wsl_server_dir_relative() -> &'static Path { + Path::new(".zed_wsl_server") +} + /// Sets a custom directory for all user data, overriding the default data directory. /// This function must be called before any other path operations that depend on the data directory. /// The directory's path will be canonicalized to an absolute path by a blocking FS operation. @@ -41,7 +46,7 @@ pub fn remote_server_dir_relative() -> &'static Path { /// # Arguments /// /// * `dir` - The path to use as the custom data directory. This will be used as the base -/// directory for all user data, including databases, extensions, and logs. +/// directory for all user data, including databases, extensions, and logs. 
/// /// # Returns /// @@ -63,7 +68,7 @@ pub fn set_custom_data_dir(dir: &str) -> &'static PathBuf { let abs_path = path .canonicalize() .expect("failed to canonicalize custom data directory's path to an absolute path"); - path = PathBuf::from(util::paths::SanitizedPath::from(abs_path)) + path = util::paths::SanitizedPath::new(&abs_path).into() } std::fs::create_dir_all(&path).expect("failed to create custom data directory"); path diff --git a/crates/picker/src/head.rs b/crates/picker/src/head.rs index aba7b8a1d05afdc1f485574178914f50f55bc12c..700896e3412bf96ceff25891c106d5a4dbc51460 100644 --- a/crates/picker/src/head.rs +++ b/crates/picker/src/head.rs @@ -23,7 +23,7 @@ impl Head { ) -> Self { let editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); - editor.set_placeholder_text(placeholder_text, cx); + editor.set_placeholder_text(placeholder_text.as_ref(), window, cx); editor }); cx.subscribe_in(&editor, window, edit_handler).detach(); diff --git a/crates/picker/src/picker.rs b/crates/picker/src/picker.rs index 34af5fed02e66fe242c398ebcf910bc89d81a256..4bd8ac99cbd9b5fe793e8b0cfe1926732920d0a1 100644 --- a/crates/picker/src/picker.rs +++ b/crates/picker/src/picker.rs @@ -615,7 +615,7 @@ impl Picker { Head::Editor(editor) => { let placeholder = self.delegate.placeholder_text(window, cx); editor.update(cx, |editor, cx| { - editor.set_placeholder_text(placeholder, cx); + editor.set_placeholder_text(placeholder.as_ref(), window, cx); cx.notify(); }); } diff --git a/crates/picker/src/popover_menu.rs b/crates/picker/src/popover_menu.rs index d05308ee71e87a472ffcb33e9727ef74fae70602..baf0918fd6c8e20211d04a150af9220cb2d66839 100644 --- a/crates/picker/src/popover_menu.rs +++ b/crates/picker/src/popover_menu.rs @@ -85,7 +85,7 @@ where .menu(move |_window, _cx| Some(picker.clone())) .trigger_with_tooltip(self.trigger, self.tooltip) .anchor(self.anchor) - .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle)) + .when_some(self.handle, 
|menu, handle| menu.with_handle(handle)) .offset(gpui::Point { x: px(0.0), y: px(-2.0), diff --git a/crates/prettier/src/prettier.rs b/crates/prettier/src/prettier.rs index 33320e6845964932aa7bfe051f3ffe4fba1a6168..32e39d466f1a236da72b746fb4bf2a24b7300385 100644 --- a/crates/prettier/src/prettier.rs +++ b/crates/prettier/src/prettier.rs @@ -119,7 +119,7 @@ impl Prettier { None } }).any(|workspace_definition| { - workspace_definition == subproject_path.to_string_lossy() || PathMatcher::new(&[workspace_definition]).ok().map_or(false, |path_matcher| path_matcher.is_match(subproject_path)) + workspace_definition == subproject_path.to_string_lossy() || PathMatcher::new(&[workspace_definition]).ok().is_some_and(|path_matcher| path_matcher.is_match(subproject_path)) }) { anyhow::ensure!(has_prettier_in_node_modules(fs, &path_to_check).await?, "Path {path_to_check:?} is the workspace root for project in {closest_package_json_path:?}, but it has no prettier installed"); log::info!("Found prettier path {path_to_check:?} in the workspace root for project in {closest_package_json_path:?}"); @@ -185,11 +185,11 @@ impl Prettier { .metadata(&ignore_path) .await .with_context(|| format!("fetching metadata for {ignore_path:?}"))? 
+ && !metadata.is_dir + && !metadata.is_symlink { - if !metadata.is_dir && !metadata.is_symlink { - log::info!("Found prettier ignore at {ignore_path:?}"); - return Ok(ControlFlow::Continue(Some(path_to_check))); - } + log::info!("Found prettier ignore at {ignore_path:?}"); + return Ok(ControlFlow::Continue(Some(path_to_check))); } match &closest_package_json_path { None => closest_package_json_path = Some(path_to_check.clone()), @@ -217,19 +217,19 @@ impl Prettier { workspace_definition == subproject_path.to_string_lossy() || PathMatcher::new(&[workspace_definition]) .ok() - .map_or(false, |path_matcher| { + .is_some_and(|path_matcher| { path_matcher.is_match(subproject_path) }) }) { let workspace_ignore = path_to_check.join(".prettierignore"); - if let Some(metadata) = fs.metadata(&workspace_ignore).await? { - if !metadata.is_dir { - log::info!( - "Found prettier ignore at workspace root {workspace_ignore:?}" - ); - return Ok(ControlFlow::Continue(Some(path_to_check))); - } + if let Some(metadata) = fs.metadata(&workspace_ignore).await? + && !metadata.is_dir + { + log::info!( + "Found prettier ignore at workspace root {workspace_ignore:?}" + ); + return Ok(ControlFlow::Continue(Some(path_to_check))); } } } @@ -549,18 +549,16 @@ async fn read_package_json( .metadata(&possible_package_json) .await .with_context(|| format!("fetching metadata for package json {possible_package_json:?}"))? 
+ && !package_json_metadata.is_dir + && !package_json_metadata.is_symlink { - if !package_json_metadata.is_dir && !package_json_metadata.is_symlink { - let package_json_contents = fs - .load(&possible_package_json) - .await - .with_context(|| format!("reading {possible_package_json:?} file contents"))?; - return serde_json::from_str::>( - &package_json_contents, - ) + let package_json_contents = fs + .load(&possible_package_json) + .await + .with_context(|| format!("reading {possible_package_json:?} file contents"))?; + return serde_json::from_str::>(&package_json_contents) .map(Some) .with_context(|| format!("parsing {possible_package_json:?} file contents")); - } } Ok(None) } diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index 57d6d6ca283af0fd51ed10622f55edc9fb086e7e..3d46a44770ec2504991899e98c1504116611c20b 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -67,6 +67,7 @@ regex.workspace = true remote.workspace = true rpc.workspace = true schemars.workspace = true +semver.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true @@ -85,6 +86,7 @@ text.workspace = true toml.workspace = true url.workspace = true util.workspace = true +watch.workspace = true which.workspace = true worktree.workspace = true zlog.workspace = true diff --git a/crates/project/src/agent_server_store.rs b/crates/project/src/agent_server_store.rs new file mode 100644 index 0000000000000000000000000000000000000000..bdb2297624e4a404cb3c918f07eab15004944f97 --- /dev/null +++ b/crates/project/src/agent_server_store.rs @@ -0,0 +1,1106 @@ +use std::{ + any::Any, + borrow::Borrow, + path::{Path, PathBuf}, + str::FromStr as _, + sync::Arc, + time::Duration, +}; + +use anyhow::{Context as _, Result, bail}; +use collections::HashMap; +use fs::{Fs, RemoveOptions, RenameOptions}; +use futures::StreamExt as _; +use gpui::{ + App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, SharedString, Subscription, Task, 
+}; +use node_runtime::NodeRuntime; +use remote::RemoteClient; +use rpc::{ + AnyProtoClient, TypedEnvelope, + proto::{self, ToProto}, +}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{SettingsKey, SettingsSources, SettingsStore, SettingsUi}; +use util::{ResultExt as _, debug_panic}; + +use crate::ProjectEnvironment; + +#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)] +pub struct AgentServerCommand { + #[serde(rename = "command")] + pub path: PathBuf, + #[serde(default)] + pub args: Vec, + pub env: Option>, +} + +impl std::fmt::Debug for AgentServerCommand { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let filtered_env = self.env.as_ref().map(|env| { + env.iter() + .map(|(k, v)| { + ( + k, + if util::redact::should_redact(k) { + "[REDACTED]" + } else { + v + }, + ) + }) + .collect::>() + }); + + f.debug_struct("AgentServerCommand") + .field("path", &self.path) + .field("args", &self.args) + .field("env", &filtered_env) + .finish() + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ExternalAgentServerName(pub SharedString); + +impl std::fmt::Display for ExternalAgentServerName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From<&'static str> for ExternalAgentServerName { + fn from(value: &'static str) -> Self { + ExternalAgentServerName(value.into()) + } +} + +impl From for SharedString { + fn from(value: ExternalAgentServerName) -> Self { + value.0 + } +} + +impl Borrow for ExternalAgentServerName { + fn borrow(&self) -> &str { + &self.0 + } +} + +pub trait ExternalAgentServer { + fn get_command( + &mut self, + root_dir: Option<&str>, + extra_env: HashMap, + status_tx: Option>, + new_version_available_tx: Option>>, + cx: &mut AsyncApp, + ) -> Task)>>; + + fn as_any_mut(&mut self) -> &mut dyn Any; +} + +impl dyn ExternalAgentServer { + fn downcast_mut(&mut self) -> Option<&mut T> { + 
self.as_any_mut().downcast_mut() + } +} + +enum AgentServerStoreState { + Local { + node_runtime: NodeRuntime, + fs: Arc, + project_environment: Entity, + downstream_client: Option<(u64, AnyProtoClient)>, + settings: Option, + _subscriptions: [Subscription; 1], + }, + Remote { + project_id: u64, + upstream_client: Entity, + }, + Collab, +} + +pub struct AgentServerStore { + state: AgentServerStoreState, + external_agents: HashMap>, +} + +pub struct AgentServersUpdated; + +impl EventEmitter for AgentServerStore {} + +impl AgentServerStore { + pub fn init_remote(session: &AnyProtoClient) { + session.add_entity_message_handler(Self::handle_external_agents_updated); + session.add_entity_message_handler(Self::handle_loading_status_updated); + session.add_entity_message_handler(Self::handle_new_version_available); + } + + pub fn init_headless(session: &AnyProtoClient) { + session.add_entity_request_handler(Self::handle_get_agent_server_command); + } + + fn agent_servers_settings_changed(&mut self, cx: &mut Context) { + let AgentServerStoreState::Local { + node_runtime, + fs, + project_environment, + downstream_client, + settings: old_settings, + .. 
+ } = &mut self.state + else { + debug_panic!( + "should not be subscribed to agent server settings changes in non-local project" + ); + return; + }; + + let new_settings = cx + .global::() + .get::(None) + .clone(); + if Some(&new_settings) == old_settings.as_ref() { + return; + } + + self.external_agents.clear(); + self.external_agents.insert( + GEMINI_NAME.into(), + Box::new(LocalGemini { + fs: fs.clone(), + node_runtime: node_runtime.clone(), + project_environment: project_environment.clone(), + custom_command: new_settings + .gemini + .clone() + .and_then(|settings| settings.custom_command()), + ignore_system_version: new_settings + .gemini + .as_ref() + .and_then(|settings| settings.ignore_system_version) + .unwrap_or(true), + }), + ); + self.external_agents.insert( + CLAUDE_CODE_NAME.into(), + Box::new(LocalClaudeCode { + fs: fs.clone(), + node_runtime: node_runtime.clone(), + project_environment: project_environment.clone(), + custom_command: new_settings + .claude + .clone() + .and_then(|settings| settings.custom_command()), + }), + ); + self.external_agents + .extend(new_settings.custom.iter().map(|(name, settings)| { + ( + ExternalAgentServerName(name.clone()), + Box::new(LocalCustomAgent { + command: settings.command.clone(), + project_environment: project_environment.clone(), + }) as Box, + ) + })); + + *old_settings = Some(new_settings.clone()); + + if let Some((project_id, downstream_client)) = downstream_client { + downstream_client + .send(proto::ExternalAgentsUpdated { + project_id: *project_id, + names: self + .external_agents + .keys() + .map(|name| name.to_string()) + .collect(), + }) + .log_err(); + } + cx.emit(AgentServersUpdated); + } + + pub fn local( + node_runtime: NodeRuntime, + fs: Arc, + project_environment: Entity, + cx: &mut Context, + ) -> Self { + let subscription = cx.observe_global::(|this, cx| { + this.agent_servers_settings_changed(cx); + }); + let this = Self { + state: AgentServerStoreState::Local { + node_runtime, + fs, + 
project_environment, + downstream_client: None, + settings: None, + _subscriptions: [subscription], + }, + external_agents: Default::default(), + }; + cx.spawn(async move |this, cx| { + cx.background_executor().timer(Duration::from_secs(1)).await; + this.update(cx, |this, cx| { + this.agent_servers_settings_changed(cx); + }) + .ok(); + }) + .detach(); + this + } + + pub(crate) fn remote( + project_id: u64, + upstream_client: Entity, + _cx: &mut Context, + ) -> Self { + // Set up the builtin agents here so they're immediately available in + // remote projects--we know that the HeadlessProject on the other end + // will have them. + let external_agents = [ + ( + GEMINI_NAME.into(), + Box::new(RemoteExternalAgentServer { + project_id, + upstream_client: upstream_client.clone(), + name: GEMINI_NAME.into(), + status_tx: None, + new_version_available_tx: None, + }) as Box, + ), + ( + CLAUDE_CODE_NAME.into(), + Box::new(RemoteExternalAgentServer { + project_id, + upstream_client: upstream_client.clone(), + name: CLAUDE_CODE_NAME.into(), + status_tx: None, + new_version_available_tx: None, + }) as Box, + ), + ] + .into_iter() + .collect(); + + Self { + state: AgentServerStoreState::Remote { + project_id, + upstream_client, + }, + external_agents, + } + } + + pub(crate) fn collab(_cx: &mut Context) -> Self { + Self { + state: AgentServerStoreState::Collab, + external_agents: Default::default(), + } + } + + pub fn shared(&mut self, project_id: u64, client: AnyProtoClient) { + match &mut self.state { + AgentServerStoreState::Local { + downstream_client, .. + } => { + client + .send(proto::ExternalAgentsUpdated { + project_id, + names: self + .external_agents + .keys() + .map(|name| name.to_string()) + .collect(), + }) + .log_err(); + *downstream_client = Some((project_id, client)); + } + AgentServerStoreState::Remote { .. 
} => { + debug_panic!( + "external agents over collab not implemented, remote project should not be shared" + ); + } + AgentServerStoreState::Collab => { + debug_panic!("external agents over collab not implemented, should not be shared"); + } + } + } + + pub fn get_external_agent( + &mut self, + name: &ExternalAgentServerName, + ) -> Option<&mut (dyn ExternalAgentServer + 'static)> { + self.external_agents + .get_mut(name) + .map(|agent| agent.as_mut()) + } + + pub fn external_agents(&self) -> impl Iterator { + self.external_agents.keys() + } + + async fn handle_get_agent_server_command( + this: Entity, + envelope: TypedEnvelope, + mut cx: AsyncApp, + ) -> Result { + let (command, root_dir, login) = this + .update(&mut cx, |this, cx| { + let AgentServerStoreState::Local { + downstream_client, .. + } = &this.state + else { + debug_panic!("should not receive GetAgentServerCommand in a non-local project"); + bail!("unexpected GetAgentServerCommand request in a non-local project"); + }; + let agent = this + .external_agents + .get_mut(&*envelope.payload.name) + .with_context(|| format!("agent `{}` not found", envelope.payload.name))?; + let (status_tx, new_version_available_tx) = downstream_client + .clone() + .map(|(project_id, downstream_client)| { + let (status_tx, mut status_rx) = watch::channel(SharedString::from("")); + let (new_version_available_tx, mut new_version_available_rx) = + watch::channel(None); + cx.spawn({ + let downstream_client = downstream_client.clone(); + let name = envelope.payload.name.clone(); + async move |_, _| { + while let Some(status) = status_rx.recv().await.ok() { + downstream_client.send( + proto::ExternalAgentLoadingStatusUpdated { + project_id, + name: name.clone(), + status: status.to_string(), + }, + )?; + } + anyhow::Ok(()) + } + }) + .detach_and_log_err(cx); + cx.spawn({ + let name = envelope.payload.name.clone(); + async move |_, _| { + if let Some(version) = + new_version_available_rx.recv().await.ok().flatten() + { + 
downstream_client.send( + proto::NewExternalAgentVersionAvailable { + project_id, + name: name.clone(), + version, + }, + )?; + } + anyhow::Ok(()) + } + }) + .detach_and_log_err(cx); + (status_tx, new_version_available_tx) + }) + .unzip(); + anyhow::Ok(agent.get_command( + envelope.payload.root_dir.as_deref(), + HashMap::default(), + status_tx, + new_version_available_tx, + &mut cx.to_async(), + )) + })?? + .await?; + Ok(proto::AgentServerCommand { + path: command.path.to_string_lossy().to_string(), + args: command.args, + env: command + .env + .map(|env| env.into_iter().collect()) + .unwrap_or_default(), + root_dir: root_dir, + login: login.map(|login| login.to_proto()), + }) + } + + async fn handle_external_agents_updated( + this: Entity, + envelope: TypedEnvelope, + mut cx: AsyncApp, + ) -> Result<()> { + this.update(&mut cx, |this, cx| { + let AgentServerStoreState::Remote { + project_id, + upstream_client, + } = &this.state + else { + debug_panic!( + "handle_external_agents_updated should not be called for a non-remote project" + ); + bail!("unexpected ExternalAgentsUpdated message") + }; + + let mut status_txs = this + .external_agents + .iter_mut() + .filter_map(|(name, agent)| { + Some(( + name.clone(), + agent + .downcast_mut::()? + .status_tx + .take(), + )) + }) + .collect::>(); + let mut new_version_available_txs = this + .external_agents + .iter_mut() + .filter_map(|(name, agent)| { + Some(( + name.clone(), + agent + .downcast_mut::()? 
+ .new_version_available_tx + .take(), + )) + }) + .collect::>(); + + this.external_agents = envelope + .payload + .names + .into_iter() + .map(|name| { + let agent = RemoteExternalAgentServer { + project_id: *project_id, + upstream_client: upstream_client.clone(), + name: ExternalAgentServerName(name.clone().into()), + status_tx: status_txs.remove(&*name).flatten(), + new_version_available_tx: new_version_available_txs + .remove(&*name) + .flatten(), + }; + ( + ExternalAgentServerName(name.into()), + Box::new(agent) as Box, + ) + }) + .collect(); + cx.emit(AgentServersUpdated); + Ok(()) + })? + } + + async fn handle_loading_status_updated( + this: Entity, + envelope: TypedEnvelope, + mut cx: AsyncApp, + ) -> Result<()> { + this.update(&mut cx, |this, _| { + if let Some(agent) = this.external_agents.get_mut(&*envelope.payload.name) + && let Some(agent) = agent.downcast_mut::() + && let Some(status_tx) = &mut agent.status_tx + { + status_tx.send(envelope.payload.status.into()).ok(); + } + }) + } + + async fn handle_new_version_available( + this: Entity, + envelope: TypedEnvelope, + mut cx: AsyncApp, + ) -> Result<()> { + this.update(&mut cx, |this, _| { + if let Some(agent) = this.external_agents.get_mut(&*envelope.payload.name) + && let Some(agent) = agent.downcast_mut::() + && let Some(new_version_available_tx) = &mut agent.new_version_available_tx + { + new_version_available_tx + .send(Some(envelope.payload.version)) + .ok(); + } + }) + } +} + +fn get_or_npm_install_builtin_agent( + binary_name: SharedString, + package_name: SharedString, + entrypoint_path: PathBuf, + minimum_version: Option, + status_tx: Option>, + new_version_available: Option>>, + fs: Arc, + node_runtime: NodeRuntime, + cx: &mut AsyncApp, +) -> Task> { + cx.spawn(async move |cx| { + let node_path = node_runtime.binary_path().await?; + let dir = paths::data_dir() + .join("external_agents") + .join(binary_name.as_str()); + fs.create_dir(&dir).await?; + + let mut stream = 
fs.read_dir(&dir).await?; + let mut versions = Vec::new(); + let mut to_delete = Vec::new(); + while let Some(entry) = stream.next().await { + let Ok(entry) = entry else { continue }; + let Some(file_name) = entry.file_name() else { + continue; + }; + + if let Some(name) = file_name.to_str() + && let Some(version) = semver::Version::from_str(name).ok() + && fs + .is_file(&dir.join(file_name).join(&entrypoint_path)) + .await + { + versions.push((version, file_name.to_owned())); + } else { + to_delete.push(file_name.to_owned()) + } + } + + versions.sort(); + let newest_version = if let Some((version, file_name)) = versions.last().cloned() + && minimum_version.is_none_or(|minimum_version| version >= minimum_version) + { + versions.pop(); + Some(file_name) + } else { + None + }; + log::debug!("existing version of {package_name}: {newest_version:?}"); + to_delete.extend(versions.into_iter().map(|(_, file_name)| file_name)); + + cx.background_spawn({ + let fs = fs.clone(); + let dir = dir.clone(); + async move { + for file_name in to_delete { + fs.remove_dir( + &dir.join(file_name), + RemoveOptions { + recursive: true, + ignore_if_not_exists: false, + }, + ) + .await + .ok(); + } + } + }) + .detach(); + + let version = if let Some(file_name) = newest_version { + cx.background_spawn({ + let file_name = file_name.clone(); + let dir = dir.clone(); + let fs = fs.clone(); + async move { + let latest_version = + node_runtime.npm_package_latest_version(&package_name).await; + if let Ok(latest_version) = latest_version + && &latest_version != &file_name.to_string_lossy() + { + download_latest_version( + fs, + dir.clone(), + node_runtime, + package_name.clone(), + ) + .await + .log_err(); + if let Some(mut new_version_available) = new_version_available { + new_version_available.send(Some(latest_version)).ok(); + } + } + } + }) + .detach(); + file_name + } else { + if let Some(mut status_tx) = status_tx { + status_tx.send("Installing…".into()).ok(); + } + let dir = dir.clone(); + 
cx.background_spawn(download_latest_version( + fs.clone(), + dir.clone(), + node_runtime, + package_name.clone(), + )) + .await? + .into() + }; + + let agent_server_path = dir.join(version).join(entrypoint_path); + let agent_server_path_exists = fs.is_file(&agent_server_path).await; + anyhow::ensure!( + agent_server_path_exists, + "Missing entrypoint path {} after installation", + agent_server_path.to_string_lossy() + ); + + anyhow::Ok(AgentServerCommand { + path: node_path, + args: vec![agent_server_path.to_string_lossy().to_string()], + env: None, + }) + }) +} + +fn find_bin_in_path( + bin_name: SharedString, + root_dir: PathBuf, + env: HashMap, + cx: &mut AsyncApp, +) -> Task> { + cx.background_executor().spawn(async move { + let which_result = if cfg!(windows) { + which::which(bin_name.as_str()) + } else { + let shell_path = env.get("PATH").cloned(); + which::which_in(bin_name.as_str(), shell_path.as_ref(), &root_dir) + }; + + if let Err(which::Error::CannotFindBinaryPath) = which_result { + return None; + } + + which_result.log_err() + }) +} + +async fn download_latest_version( + fs: Arc, + dir: PathBuf, + node_runtime: NodeRuntime, + package_name: SharedString, +) -> Result { + log::debug!("downloading latest version of {package_name}"); + + let tmp_dir = tempfile::tempdir_in(&dir)?; + + node_runtime + .npm_install_packages(tmp_dir.path(), &[(&package_name, "latest")]) + .await?; + + let version = node_runtime + .npm_package_installed_version(tmp_dir.path(), &package_name) + .await? 
+ .context("expected package to be installed")?; + + fs.rename( + &tmp_dir.keep(), + &dir.join(&version), + RenameOptions { + ignore_if_exists: true, + overwrite: false, + }, + ) + .await?; + + anyhow::Ok(version) +} + +struct RemoteExternalAgentServer { + project_id: u64, + upstream_client: Entity, + name: ExternalAgentServerName, + status_tx: Option>, + new_version_available_tx: Option>>, +} + +// new method: status_updated +// does nothing in the all-local case +// for RemoteExternalAgentServer, sends on the stored tx +// etc. + +impl ExternalAgentServer for RemoteExternalAgentServer { + fn get_command( + &mut self, + root_dir: Option<&str>, + extra_env: HashMap, + status_tx: Option>, + new_version_available_tx: Option>>, + cx: &mut AsyncApp, + ) -> Task)>> { + let project_id = self.project_id; + let name = self.name.to_string(); + let upstream_client = self.upstream_client.downgrade(); + let root_dir = root_dir.map(|root_dir| root_dir.to_owned()); + self.status_tx = status_tx; + self.new_version_available_tx = new_version_available_tx; + cx.spawn(async move |cx| { + let mut response = upstream_client + .update(cx, |upstream_client, _| { + upstream_client + .proto_client() + .request(proto::GetAgentServerCommand { + project_id, + name, + root_dir: root_dir.clone(), + }) + })? 
+ .await?; + let root_dir = response.root_dir; + response.env.extend(extra_env); + let command = upstream_client.update(cx, |client, _| { + client.build_command( + Some(response.path), + &response.args, + &response.env.into_iter().collect(), + Some(root_dir.clone()), + None, + ) + })??; + Ok(( + AgentServerCommand { + path: command.program.into(), + args: command.args, + env: Some(command.env), + }, + root_dir, + response + .login + .map(|login| task::SpawnInTerminal::from_proto(login)), + )) + }) + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } +} + +struct LocalGemini { + fs: Arc, + node_runtime: NodeRuntime, + project_environment: Entity, + custom_command: Option, + ignore_system_version: bool, +} + +impl ExternalAgentServer for LocalGemini { + fn get_command( + &mut self, + root_dir: Option<&str>, + extra_env: HashMap, + status_tx: Option>, + new_version_available_tx: Option>>, + cx: &mut AsyncApp, + ) -> Task)>> { + let fs = self.fs.clone(); + let node_runtime = self.node_runtime.clone(); + let project_environment = self.project_environment.downgrade(); + let custom_command = self.custom_command.clone(); + let ignore_system_version = self.ignore_system_version; + let root_dir: Arc = root_dir + .map(|root_dir| Path::new(root_dir)) + .unwrap_or(paths::home_dir()) + .into(); + + cx.spawn(async move |cx| { + let mut env = project_environment + .update(cx, |project_environment, cx| { + project_environment.get_directory_environment(root_dir.clone(), cx) + })? 
+ .await + .unwrap_or_default(); + + let mut command = if let Some(mut custom_command) = custom_command { + env.extend(custom_command.env.unwrap_or_default()); + custom_command.env = Some(env); + custom_command + } else if !ignore_system_version + && let Some(bin) = + find_bin_in_path("gemini".into(), root_dir.to_path_buf(), env.clone(), cx).await + { + AgentServerCommand { + path: bin, + args: Vec::new(), + env: Some(env), + } + } else { + let mut command = get_or_npm_install_builtin_agent( + GEMINI_NAME.into(), + "@google/gemini-cli".into(), + "node_modules/@google/gemini-cli/dist/index.js".into(), + Some("0.2.1".parse().unwrap()), + status_tx, + new_version_available_tx, + fs, + node_runtime, + cx, + ) + .await?; + command.env = Some(env); + command + }; + + // Gemini CLI doesn't seem to have a dedicated invocation for logging in--we just run it normally without any arguments. + let login = task::SpawnInTerminal { + command: Some(command.path.clone().to_proto()), + args: command.args.clone(), + env: command.env.clone().unwrap_or_default(), + label: "gemini /auth".into(), + ..Default::default() + }; + + command.env.get_or_insert_default().extend(extra_env); + command.args.push("--experimental-acp".into()); + Ok((command, root_dir.to_proto(), Some(login))) + }) + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } +} + +struct LocalClaudeCode { + fs: Arc, + node_runtime: NodeRuntime, + project_environment: Entity, + custom_command: Option, +} + +impl ExternalAgentServer for LocalClaudeCode { + fn get_command( + &mut self, + root_dir: Option<&str>, + extra_env: HashMap, + status_tx: Option>, + new_version_available_tx: Option>>, + cx: &mut AsyncApp, + ) -> Task)>> { + let fs = self.fs.clone(); + let node_runtime = self.node_runtime.clone(); + let project_environment = self.project_environment.downgrade(); + let custom_command = self.custom_command.clone(); + let root_dir: Arc = root_dir + .map(|root_dir| Path::new(root_dir)) + .unwrap_or(paths::home_dir()) 
+ .into(); + + cx.spawn(async move |cx| { + let mut env = project_environment + .update(cx, |project_environment, cx| { + project_environment.get_directory_environment(root_dir.clone(), cx) + })? + .await + .unwrap_or_default(); + env.insert("ANTHROPIC_API_KEY".into(), "".into()); + + let (mut command, login) = if let Some(mut custom_command) = custom_command { + env.extend(custom_command.env.unwrap_or_default()); + custom_command.env = Some(env); + (custom_command, None) + } else { + let mut command = get_or_npm_install_builtin_agent( + "claude-code-acp".into(), + "@zed-industries/claude-code-acp".into(), + "node_modules/@zed-industries/claude-code-acp/dist/index.js".into(), + Some("0.2.5".parse().unwrap()), + status_tx, + new_version_available_tx, + fs, + node_runtime, + cx, + ) + .await?; + command.env = Some(env); + let login = command + .args + .first() + .and_then(|path| { + path.strip_suffix("/@zed-industries/claude-code-acp/dist/index.js") + }) + .map(|path_prefix| task::SpawnInTerminal { + command: Some(command.path.clone().to_proto()), + args: vec![ + Path::new(path_prefix) + .join("@anthropic-ai/claude-code/cli.js") + .to_string_lossy() + .to_string(), + "/login".into(), + ], + env: command.env.clone().unwrap_or_default(), + label: "claude /login".into(), + ..Default::default() + }); + (command, login) + }; + + command.env.get_or_insert_default().extend(extra_env); + Ok((command, root_dir.to_proto(), login)) + }) + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } +} + +struct LocalCustomAgent { + project_environment: Entity, + command: AgentServerCommand, +} + +impl ExternalAgentServer for LocalCustomAgent { + fn get_command( + &mut self, + root_dir: Option<&str>, + extra_env: HashMap, + _status_tx: Option>, + _new_version_available_tx: Option>>, + cx: &mut AsyncApp, + ) -> Task)>> { + let mut command = self.command.clone(); + let root_dir: Arc = root_dir + .map(|root_dir| Path::new(root_dir)) + .unwrap_or(paths::home_dir()) + .into(); + let 
project_environment = self.project_environment.downgrade(); + cx.spawn(async move |cx| { + let mut env = project_environment + .update(cx, |project_environment, cx| { + project_environment.get_directory_environment(root_dir.clone(), cx) + })? + .await + .unwrap_or_default(); + env.extend(command.env.unwrap_or_default()); + env.extend(extra_env); + command.env = Some(env); + Ok((command, root_dir.to_proto(), None)) + }) + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } +} + +pub const GEMINI_NAME: &'static str = "gemini"; +pub const CLAUDE_CODE_NAME: &'static str = "claude"; + +#[derive( + Default, Deserialize, Serialize, Clone, JsonSchema, Debug, SettingsUi, SettingsKey, PartialEq, +)] +#[settings_key(key = "agent_servers")] +pub struct AllAgentServersSettings { + pub gemini: Option, + pub claude: Option, + + /// Custom agent servers configured by the user + #[serde(flatten)] + pub custom: HashMap, +} + +#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug, PartialEq)] +pub struct BuiltinAgentServerSettings { + /// Absolute path to a binary to be used when launching this agent. + /// + /// This can be used to run a specific binary without automatic downloads or searching `$PATH`. + #[serde(rename = "command")] + pub path: Option, + /// If a binary is specified in `command`, it will be passed these arguments. + pub args: Option>, + /// If a binary is specified in `command`, it will be passed these environment variables. + pub env: Option>, + /// Whether to skip searching `$PATH` for an agent server binary when + /// launching this agent. + /// + /// This has no effect if a `command` is specified. Otherwise, when this is + /// `false`, Zed will search `$PATH` for an agent server binary and, if one + /// is found, use it for threads with this agent. If no agent binary is + /// found on `$PATH`, Zed will automatically install and use its own binary. 
+ /// When this is `true`, Zed will not search `$PATH`, and will always use
+ /// its own binary.
+ ///
+ /// Default: true
+ pub ignore_system_version: Option,
+ /// The default mode to use for this agent.
+ ///
+ /// Note: Not all agents support modes.
+ ///
+ /// Default: None
+ pub default_mode: Option,
+}
+
+impl BuiltinAgentServerSettings {
+ pub(crate) fn custom_command(self) -> Option {
+ self.path.map(|path| AgentServerCommand {
+ path,
+ args: self.args.unwrap_or_default(),
+ env: self.env,
+ })
+ }
+}
+
+impl From for BuiltinAgentServerSettings {
+ fn from(value: AgentServerCommand) -> Self {
+ BuiltinAgentServerSettings {
+ path: Some(value.path),
+ args: Some(value.args),
+ env: value.env,
+ ..Default::default()
+ }
+ }
+}
+
+#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug, PartialEq)]
+pub struct CustomAgentServerSettings {
+ #[serde(flatten)]
+ pub command: AgentServerCommand,
+ /// The default mode to use for this agent.
+ ///
+ /// Note: Not all agents support modes. 
+ /// + /// Default: None + pub default_mode: Option, +} + +impl settings::Settings for AllAgentServersSettings { + type FileContent = Self; + + fn load(sources: SettingsSources, _: &mut App) -> Result { + let mut settings = AllAgentServersSettings::default(); + + for AllAgentServersSettings { + gemini, + claude, + custom, + } in sources.defaults_and_customizations() + { + if gemini.is_some() { + settings.gemini = gemini.clone(); + } + if claude.is_some() { + settings.claude = claude.clone(); + } + + // Merge custom agents + for (name, config) in custom { + // Skip built-in agent names to avoid conflicts + if name != GEMINI_NAME && name != CLAUDE_CODE_NAME { + settings.custom.insert(name.clone(), config.clone()); + } + } + } + + Ok(settings) + } + + fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} +} diff --git a/crates/project/src/buffer_store.rs b/crates/project/src/buffer_store.rs index b8101e14f39b4faf54b76eaab955864e4ef82ae5..07f8e0c95cf8551803d5f5828703dbec090fcedb 100644 --- a/crates/project/src/buffer_store.rs +++ b/crates/project/src/buffer_store.rs @@ -20,7 +20,7 @@ use language::{ }, }; use rpc::{ - AnyProtoClient, ErrorExt as _, TypedEnvelope, + AnyProtoClient, ErrorCode, ErrorExt as _, TypedEnvelope, proto::{self, ToProto}, }; use smol::channel::Receiver; @@ -88,9 +88,18 @@ pub enum BufferStoreEvent { }, } -#[derive(Default, Debug)] +#[derive(Default, Debug, Clone)] pub struct ProjectTransaction(pub HashMap, language::Transaction>); +impl PartialEq for ProjectTransaction { + fn eq(&self, other: &Self) -> bool { + self.0.len() == other.0.len() + && self.0.iter().all(|(buffer, transaction)| { + other.0.get(buffer).is_some_and(|t| t.id == transaction.id) + }) + } +} + impl EventEmitter for BufferStore {} impl RemoteBufferStore { @@ -168,7 +177,7 @@ impl RemoteBufferStore { .with_context(|| { format!("no worktree found for id {}", file.worktree_id) })?; - buffer_file = Some(Arc::new(File::from_proto(file, 
worktree.clone(), cx)?) + buffer_file = Some(Arc::new(File::from_proto(file, worktree, cx)?) as Arc); } Buffer::from_proto(replica_id, capability, state, buffer_file) @@ -234,7 +243,7 @@ impl RemoteBufferStore { } } } - return Ok(None); + Ok(None) } pub fn incomplete_buffer_ids(&self) -> Vec { @@ -310,7 +319,11 @@ impl RemoteBufferStore { }) } - fn create_buffer(&self, cx: &mut Context) -> Task>> { + fn create_buffer( + &self, + project_searchable: bool, + cx: &mut Context, + ) -> Task>> { let create = self.upstream_client.request(proto::OpenNewBuffer { project_id: self.project_id, }); @@ -318,8 +331,13 @@ impl RemoteBufferStore { let response = create.await?; let buffer_id = BufferId::new(response.buffer_id)?; - this.update(cx, |this, cx| this.wait_for_remote_buffer(buffer_id, cx))? - .await + this.update(cx, |this, cx| { + if !project_searchable { + this.non_searchable_buffers.insert(buffer_id); + } + this.wait_for_remote_buffer(buffer_id, cx) + })? + .await }) } @@ -413,13 +431,10 @@ impl LocalBufferStore { cx: &mut Context, ) { cx.subscribe(worktree, |this, worktree, event, cx| { - if worktree.read(cx).is_local() { - match event { - worktree::Event::UpdatedEntries(changes) => { - Self::local_worktree_entries_changed(this, &worktree, changes, cx); - } - _ => {} - } + if worktree.read(cx).is_local() + && let worktree::Event::UpdatedEntries(changes) = event + { + Self::local_worktree_entries_changed(this, &worktree, changes, cx); } }) .detach(); @@ -467,6 +482,7 @@ impl LocalBufferStore { Some(buffer) } else { this.opened_buffers.remove(&buffer_id); + this.non_searchable_buffers.remove(&buffer_id); None }; @@ -594,7 +610,7 @@ impl LocalBufferStore { else { return Task::ready(Err(anyhow!("no such worktree"))); }; - self.save_local_buffer(buffer, worktree, path.path.clone(), true, cx) + self.save_local_buffer(buffer, worktree, path.path, true, cx) } fn open_buffer( @@ -664,12 +680,21 @@ impl LocalBufferStore { }) } - fn create_buffer(&self, cx: &mut Context) -> 
Task>> { + fn create_buffer( + &self, + project_searchable: bool, + cx: &mut Context, + ) -> Task>> { cx.spawn(async move |buffer_store, cx| { let buffer = cx.new(|cx| Buffer::local("", cx).with_language(language::PLAIN_TEXT.clone(), cx))?; buffer_store.update(cx, |buffer_store, cx| { buffer_store.add_buffer(buffer.clone(), cx).log_err(); + if !project_searchable { + buffer_store + .non_searchable_buffers + .insert(buffer.read(cx).remote_id()); + } })?; Ok(buffer) }) @@ -831,13 +856,25 @@ impl BufferStore { } }; - cx.background_spawn(async move { task.await.map_err(|e| anyhow!("{e}")) }) + cx.background_spawn(async move { + task.await.map_err(|e| { + if e.error_code() != ErrorCode::Internal { + anyhow!(e.error_code()) + } else { + anyhow!("{e}") + } + }) + }) } - pub fn create_buffer(&mut self, cx: &mut Context) -> Task>> { + pub fn create_buffer( + &mut self, + project_searchable: bool, + cx: &mut Context, + ) -> Task>> { match &self.state { - BufferStoreState::Local(this) => this.create_buffer(cx), - BufferStoreState::Remote(this) => this.create_buffer(cx), + BufferStoreState::Local(this) => this.create_buffer(project_searchable, cx), + BufferStoreState::Remote(this) => this.create_buffer(project_searchable, cx), } } @@ -848,7 +885,7 @@ impl BufferStore { ) -> Task> { match &mut self.state { BufferStoreState::Local(this) => this.save_buffer(buffer, cx), - BufferStoreState::Remote(this) => this.save_remote_buffer(buffer.clone(), None, cx), + BufferStoreState::Remote(this) => this.save_remote_buffer(buffer, None, cx), } } @@ -938,7 +975,15 @@ impl BufferStore { ) -> impl Iterator>>)> { self.loading_buffers.iter().map(|(path, task)| { let task = task.clone(); - (path, async move { task.await.map_err(|e| anyhow!("{e}")) }) + (path, async move { + task.await.map_err(|e| { + if e.error_code() != ErrorCode::Internal { + anyhow!(e.error_code()) + } else { + anyhow!("{e}") + } + }) + }) }) } @@ -947,10 +992,9 @@ impl BufferStore { } pub fn get_by_path(&self, path: 
&ProjectPath) -> Option> { - self.path_to_buffer_id.get(path).and_then(|buffer_id| { - let buffer = self.get(*buffer_id); - buffer - }) + self.path_to_buffer_id + .get(path) + .and_then(|buffer_id| self.get(*buffer_id)) } pub fn get(&self, buffer_id: BufferId) -> Option> { @@ -1094,10 +1138,10 @@ impl BufferStore { .collect::>() })?; for buffer_task in buffers { - if let Some(buffer) = buffer_task.await.log_err() { - if tx.send(buffer).await.is_err() { - return anyhow::Ok(()); - } + if let Some(buffer) = buffer_task.await.log_err() + && tx.send(buffer).await.is_err() + { + return anyhow::Ok(()); } } } @@ -1142,7 +1186,7 @@ impl BufferStore { envelope: TypedEnvelope, mut cx: AsyncApp, ) -> Result { - let payload = envelope.payload.clone(); + let payload = envelope.payload; let buffer_id = BufferId::new(payload.buffer_id)?; let ops = payload .operations @@ -1173,11 +1217,11 @@ impl BufferStore { buffer_id: BufferId, handle: OpenLspBufferHandle, ) { - if let Some(shared_buffers) = self.shared_buffers.get_mut(&peer_id) { - if let Some(buffer) = shared_buffers.get_mut(&buffer_id) { - buffer.lsp_handle = Some(handle); - return; - } + if let Some(shared_buffers) = self.shared_buffers.get_mut(&peer_id) + && let Some(buffer) = shared_buffers.get_mut(&buffer_id) + { + buffer.lsp_handle = Some(handle); + return; } debug_panic!("tried to register shared lsp handle, but buffer was not shared") } @@ -1313,10 +1357,7 @@ impl BufferStore { let new_path = file.path.clone(); buffer.file_updated(Arc::new(file), cx); - if old_file - .as_ref() - .map_or(true, |old| *old.path() != new_path) - { + if old_file.as_ref().is_none_or(|old| *old.path() != new_path) { Some(old_file) } else { None @@ -1345,7 +1386,7 @@ impl BufferStore { mut cx: AsyncApp, ) -> Result { let buffer_id = BufferId::new(envelope.payload.buffer_id)?; - let (buffer, project_id) = this.read_with(&mut cx, |this, _| { + let (buffer, project_id) = this.read_with(&cx, |this, _| { anyhow::Ok(( this.get_existing(buffer_id)?, 
this.downstream_client @@ -1359,7 +1400,7 @@ impl BufferStore { buffer.wait_for_version(deserialize_version(&envelope.payload.version)) })? .await?; - let buffer_id = buffer.read_with(&mut cx, |buffer, _| buffer.remote_id())?; + let buffer_id = buffer.read_with(&cx, |buffer, _| buffer.remote_id())?; if let Some(new_path) = envelope.payload.new_path { let new_path = ProjectPath::from_proto(new_path); @@ -1372,7 +1413,7 @@ impl BufferStore { .await?; } - buffer.read_with(&mut cx, |buffer, _| proto::BufferSaved { + buffer.read_with(&cx, |buffer, _| proto::BufferSaved { project_id, buffer_id: buffer_id.into(), version: serialize_version(buffer.saved_version()), @@ -1388,14 +1429,14 @@ impl BufferStore { let peer_id = envelope.sender_id; let buffer_id = BufferId::new(envelope.payload.buffer_id)?; this.update(&mut cx, |this, cx| { - if let Some(shared) = this.shared_buffers.get_mut(&peer_id) { - if shared.remove(&buffer_id).is_some() { - cx.emit(BufferStoreEvent::SharedBufferClosed(peer_id, buffer_id)); - if shared.is_empty() { - this.shared_buffers.remove(&peer_id); - } - return; + if let Some(shared) = this.shared_buffers.get_mut(&peer_id) + && shared.remove(&buffer_id).is_some() + { + cx.emit(BufferStoreEvent::SharedBufferClosed(peer_id, buffer_id)); + if shared.is_empty() { + this.shared_buffers.remove(&peer_id); } + return; } debug_panic!( "peer_id {} closed buffer_id {} which was either not open or already closed", @@ -1592,6 +1633,7 @@ impl BufferStore { &mut self, text: &str, language: Option>, + project_searchable: bool, cx: &mut Context, ) -> Entity { let buffer = cx.new(|cx| { @@ -1601,6 +1643,9 @@ impl BufferStore { self.add_buffer(buffer.clone(), cx).log_err(); let buffer_id = buffer.read(cx).remote_id(); + if !project_searchable { + self.non_searchable_buffers.insert(buffer_id); + } if let Some(file) = File::from_dyn(buffer.read(cx).file()) { self.path_to_buffer_id.insert( @@ -1670,10 +1715,6 @@ impl BufferStore { } serialized_transaction } - - pub(crate) 
fn mark_buffer_as_non_searchable(&mut self, buffer_id: BufferId) { - self.non_searchable_buffers.insert(buffer_id); - } } impl OpenBuffer { diff --git a/crates/project/src/color_extractor.rs b/crates/project/src/color_extractor.rs index 5473da88af5bee6e66b005956366a289478f7ee4..6e9907e30b7393a3074f4af579536d74140418f9 100644 --- a/crates/project/src/color_extractor.rs +++ b/crates/project/src/color_extractor.rs @@ -4,8 +4,8 @@ use gpui::{Hsla, Rgba}; use lsp::{CompletionItem, Documentation}; use regex::{Regex, RegexBuilder}; -const HEX: &'static str = r#"(#(?:[\da-fA-F]{3}){1,2})"#; -const RGB_OR_HSL: &'static str = r#"(rgba?|hsla?)\(\s*(\d{1,3}%?)\s*,\s*(\d{1,3}%?)\s*,\s*(\d{1,3}%?)\s*(?:,\s*(1|0?\.\d+))?\s*\)"#; +const HEX: &str = r#"(#(?:[\da-fA-F]{3}){1,2})"#; +const RGB_OR_HSL: &str = r#"(rgba?|hsla?)\(\s*(\d{1,3}%?)\s*,\s*(\d{1,3}%?)\s*,\s*(\d{1,3}%?)\s*(?:,\s*(1|0?\.\d+))?\s*\)"#; static RELAXED_HEX_REGEX: LazyLock = LazyLock::new(|| { RegexBuilder::new(HEX) @@ -102,7 +102,7 @@ fn parse(str: &str, mode: ParseMode) -> Option { }; } - return None; + None } fn parse_component(value: &str, max: f32) -> Option { @@ -141,7 +141,7 @@ mod tests { use gpui::rgba; use lsp::{CompletionItem, CompletionItemKind}; - pub const COLOR_TABLE: &[(&'static str, Option)] = &[ + pub const COLOR_TABLE: &[(&str, Option)] = &[ // -- Invalid -- // Invalid hex ("f0f", None), diff --git a/crates/project/src/context_server_store.rs b/crates/project/src/context_server_store.rs index c96ab4e8f3ba87133d9b64e9701130f5d32adfb9..20188df5c4ae38b2ae305daee5b3eecc25319951 100644 --- a/crates/project/src/context_server_store.rs +++ b/crates/project/src/context_server_store.rs @@ -368,7 +368,7 @@ impl ContextServerStore { } pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context) -> Result<()> { - if let Some(state) = self.servers.get(&id) { + if let Some(state) = self.servers.get(id) { let configuration = state.configuration(); self.stop_server(&state.server().id(), cx)?; @@ 
-397,9 +397,8 @@ impl ContextServerStore { let server = server.clone(); let configuration = configuration.clone(); async move |this, cx| { - match server.clone().start(&cx).await { + match server.clone().start(cx).await { Ok(_) => { - log::info!("Started {} context server", id); debug_assert!(server.client().is_some()); this.update(cx, |this, cx| { @@ -588,7 +587,7 @@ impl ContextServerStore { for server_id in this.servers.keys() { // All servers that are not in desired_servers should be removed from the store. // This can happen if the user removed a server from the context server settings. - if !configured_servers.contains_key(&server_id) { + if !configured_servers.contains_key(server_id) { if disabled_servers.contains_key(&server_id.0) { servers_to_stop.insert(server_id.clone()); } else { @@ -642,8 +641,8 @@ mod tests { #[gpui::test] async fn test_context_server_status(cx: &mut TestAppContext) { - const SERVER_1_ID: &'static str = "mcp-1"; - const SERVER_2_ID: &'static str = "mcp-2"; + const SERVER_1_ID: &str = "mcp-1"; + const SERVER_2_ID: &str = "mcp-2"; let (_fs, project) = setup_context_server_test( cx, @@ -722,8 +721,8 @@ mod tests { #[gpui::test] async fn test_context_server_status_events(cx: &mut TestAppContext) { - const SERVER_1_ID: &'static str = "mcp-1"; - const SERVER_2_ID: &'static str = "mcp-2"; + const SERVER_1_ID: &str = "mcp-1"; + const SERVER_2_ID: &str = "mcp-2"; let (_fs, project) = setup_context_server_test( cx, @@ -761,7 +760,7 @@ mod tests { &store, vec![ (server_1_id.clone(), ContextServerStatus::Starting), - (server_1_id.clone(), ContextServerStatus::Running), + (server_1_id, ContextServerStatus::Running), (server_2_id.clone(), ContextServerStatus::Starting), (server_2_id.clone(), ContextServerStatus::Running), (server_2_id.clone(), ContextServerStatus::Stopped), @@ -784,7 +783,7 @@ mod tests { #[gpui::test(iterations = 25)] async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) { - const SERVER_1_ID: &'static str = 
"mcp-1"; + const SERVER_1_ID: &str = "mcp-1"; let (_fs, project) = setup_context_server_test( cx, @@ -845,8 +844,8 @@ mod tests { #[gpui::test] async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) { - const SERVER_1_ID: &'static str = "mcp-1"; - const SERVER_2_ID: &'static str = "mcp-2"; + const SERVER_1_ID: &str = "mcp-1"; + const SERVER_2_ID: &str = "mcp-2"; let server_1_id = ContextServerId(SERVER_1_ID.into()); let server_2_id = ContextServerId(SERVER_2_ID.into()); @@ -977,6 +976,7 @@ mod tests { path: "somebinary".into(), args: vec!["arg".to_string()], env: None, + timeout: None, }, }, ), @@ -1017,6 +1017,7 @@ mod tests { path: "somebinary".into(), args: vec!["anotherArg".to_string()], env: None, + timeout: None, }, }, ), @@ -1084,7 +1085,7 @@ mod tests { #[gpui::test] async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) { - const SERVER_1_ID: &'static str = "mcp-1"; + const SERVER_1_ID: &str = "mcp-1"; let server_1_id = ContextServerId(SERVER_1_ID.into()); @@ -1099,6 +1100,7 @@ mod tests { path: "somebinary".into(), args: vec!["arg".to_string()], env: None, + timeout: None, }, }, )], @@ -1151,6 +1153,7 @@ mod tests { path: "somebinary".into(), args: vec!["arg".to_string()], env: None, + timeout: None, }, }, )], @@ -1178,6 +1181,7 @@ mod tests { command: ContextServerCommand { path: "somebinary".into(), args: vec!["arg".to_string()], + timeout: None, env: None, }, }, @@ -1231,6 +1235,7 @@ mod tests { path: "somebinary".into(), args: vec!["arg".to_string()], env: None, + timeout: None, }, } } @@ -1319,6 +1324,7 @@ mod tests { path: self.path.clone(), args: vec!["arg1".to_string(), "arg2".to_string()], env: None, + timeout: None, })) } diff --git a/crates/project/src/context_server_store/extension.rs b/crates/project/src/context_server_store/extension.rs index 1eb0fe7da129ba9dbd3ee640cb6e02474a3990b6..ca5cacf3b549523dee8b85242bea86653eecbf7a 100644 --- a/crates/project/src/context_server_store/extension.rs +++ 
b/crates/project/src/context_server_store/extension.rs @@ -63,12 +63,13 @@ impl registry::ContextServerDescriptor for ContextServerDescriptor { .await?; command.command = extension.path_from_extension(&command.command); - log::info!("loaded command for context server {id}: {command:?}"); + log::debug!("loaded command for context server {id}: {command:?}"); Ok(ContextServerCommand { path: command.command, args: command.args, env: Some(command.env.into_iter().collect()), + timeout: None, }) }) } diff --git a/crates/project/src/debugger.rs b/crates/project/src/debugger.rs index 6c22468040097768688d93cde0720320a9e45be9..0bf6a0d61b792bd747992a821adc82150d93c8bf 100644 --- a/crates/project/src/debugger.rs +++ b/crates/project/src/debugger.rs @@ -6,9 +6,9 @@ //! //! There are few reasons for this divide: //! - Breakpoints persist across debug sessions and they're not really specific to any particular session. Sure, we have to send protocol messages for them -//! (so they're a "thing" in the protocol), but we also want to set them before any session starts up. +//! (so they're a "thing" in the protocol), but we also want to set them before any session starts up. //! - Debug clients are doing the heavy lifting, and this is where UI grabs all of it's data from. They also rely on breakpoint store during initialization to obtain -//! current set of breakpoints. +//! current set of breakpoints. //! - Since DAP store knows about all of the available debug sessions, it is responsible for routing RPC requests to sessions. It also knows how to find adapters for particular kind of session. 
pub mod breakpoint_store; diff --git a/crates/project/src/debugger/breakpoint_store.rs b/crates/project/src/debugger/breakpoint_store.rs index 025dca410069db0350d8d32509244a4889c62415..c47e5d35d5948eb0c176bbc6d14281faa3f60451 100644 --- a/crates/project/src/debugger/breakpoint_store.rs +++ b/crates/project/src/debugger/breakpoint_store.rs @@ -192,7 +192,7 @@ impl BreakpointStore { } pub(crate) fn shared(&mut self, project_id: u64, downstream_client: AnyProtoClient) { - self.downstream_client = Some((downstream_client.clone(), project_id)); + self.downstream_client = Some((downstream_client, project_id)); } pub(crate) fn unshared(&mut self, cx: &mut Context) { @@ -267,7 +267,7 @@ impl BreakpointStore { message: TypedEnvelope, mut cx: AsyncApp, ) -> Result { - let breakpoints = this.read_with(&mut cx, |this, _| this.breakpoint_store())?; + let breakpoints = this.read_with(&cx, |this, _| this.breakpoint_store())?; let path = this .update(&mut cx, |this, cx| { this.project_path_for_absolute_path(message.payload.path.as_ref(), cx) @@ -317,8 +317,8 @@ impl BreakpointStore { .iter() .filter_map(|breakpoint| { breakpoint.bp.bp.to_proto( - &path, - &breakpoint.position(), + path, + breakpoint.position(), &breakpoint.session_state, ) }) @@ -450,9 +450,9 @@ impl BreakpointStore { }); if let Some(found_bp) = found_bp { - found_bp.message = Some(log_message.clone()); + found_bp.message = Some(log_message); } else { - breakpoint.bp.message = Some(log_message.clone()); + breakpoint.bp.message = Some(log_message); // We did not remove any breakpoint, hence let's toggle one. breakpoint_set .breakpoints @@ -482,9 +482,9 @@ impl BreakpointStore { }); if let Some(found_bp) = found_bp { - found_bp.hit_condition = Some(hit_condition.clone()); + found_bp.hit_condition = Some(hit_condition); } else { - breakpoint.bp.hit_condition = Some(hit_condition.clone()); + breakpoint.bp.hit_condition = Some(hit_condition); // We did not remove any breakpoint, hence let's toggle one. 
breakpoint_set .breakpoints @@ -514,9 +514,9 @@ impl BreakpointStore { }); if let Some(found_bp) = found_bp { - found_bp.condition = Some(condition.clone()); + found_bp.condition = Some(condition); } else { - breakpoint.bp.condition = Some(condition.clone()); + breakpoint.bp.condition = Some(condition); // We did not remove any breakpoint, hence let's toggle one. breakpoint_set .breakpoints @@ -591,7 +591,7 @@ impl BreakpointStore { cx: &mut Context, ) { if let Some(breakpoints) = self.breakpoints.remove(&old_path) { - self.breakpoints.insert(new_path.clone(), breakpoints); + self.breakpoints.insert(new_path, breakpoints); cx.notify(); } @@ -623,12 +623,11 @@ impl BreakpointStore { file_breakpoints.breakpoints.iter().filter_map({ let range = range.clone(); move |bp| { - if let Some(range) = &range { - if bp.position().cmp(&range.start, buffer_snapshot).is_lt() - || bp.position().cmp(&range.end, buffer_snapshot).is_gt() - { - return None; - } + if let Some(range) = &range + && (bp.position().cmp(&range.start, buffer_snapshot).is_lt() + || bp.position().cmp(&range.end, buffer_snapshot).is_gt()) + { + return None; } let session_state = active_session_id .and_then(|id| bp.session_state.get(&id)) @@ -753,7 +752,7 @@ impl BreakpointStore { .iter() .map(|breakpoint| { let position = snapshot - .summary_for_anchor::(&breakpoint.position()) + .summary_for_anchor::(breakpoint.position()) .row; let breakpoint = &breakpoint.bp; SourceBreakpoint { @@ -832,7 +831,6 @@ impl BreakpointStore { new_breakpoints.insert(path, breakpoints_for_file); } this.update(cx, |this, cx| { - log::info!("Finish deserializing breakpoints & initializing breakpoint store"); for (path, count) in new_breakpoints.iter().map(|(path, bp_in_file)| { (path.to_string_lossy(), bp_in_file.breakpoints.len()) }) { @@ -906,7 +904,7 @@ impl BreakpointState { } #[inline] - pub fn to_int(&self) -> i32 { + pub fn to_int(self) -> i32 { match self { BreakpointState::Enabled => 0, BreakpointState::Disabled => 1, diff 
--git a/crates/project/src/debugger/dap_command.rs b/crates/project/src/debugger/dap_command.rs index 3be3192369452b58fd2382471ca2f41f4aeac75f..772ff2dcfeb98fcda794092f8071fad5c6fcdcd4 100644 --- a/crates/project/src/debugger/dap_command.rs +++ b/crates/project/src/debugger/dap_command.rs @@ -1454,7 +1454,7 @@ impl DapCommand for EvaluateCommand { variables_reference: message.variable_reference, named_variables: message.named_variables, indexed_variables: message.indexed_variables, - memory_reference: message.memory_reference.clone(), + memory_reference: message.memory_reference, value_location_reference: None, //TODO }) } diff --git a/crates/project/src/debugger/dap_store.rs b/crates/project/src/debugger/dap_store.rs index 6f834b5dc0cfd3fc6357d92403bdb7cbfefdd4b0..6c1449b728d3ee5b8c8b019d5e527e9adfb3bf25 100644 --- a/crates/project/src/debugger/dap_store.rs +++ b/crates/project/src/debugger/dap_store.rs @@ -5,11 +5,8 @@ use super::{ session::{self, Session, SessionStateEvent}, }; use crate::{ - InlayHint, InlayHintLabel, ProjectEnvironment, ResolveState, - debugger::session::SessionQuirks, - project_settings::ProjectSettings, - terminals::{SshCommand, wrap_for_ssh}, - worktree_store::WorktreeStore, + InlayHint, InlayHintLabel, ProjectEnvironment, ResolveState, debugger::session::SessionQuirks, + project_settings::ProjectSettings, worktree_store::WorktreeStore, }; use anyhow::{Context as _, Result, anyhow}; use async_trait::async_trait; @@ -34,7 +31,7 @@ use http_client::HttpClient; use language::{Buffer, LanguageToolchainStore, language_settings::InlayHintKind}; use node_runtime::NodeRuntime; -use remote::{SshRemoteClient, ssh_session::SshArgs}; +use remote::RemoteClient; use rpc::{ AnyProtoClient, TypedEnvelope, proto::{self}, @@ -68,7 +65,7 @@ pub enum DapStoreEvent { enum DapStoreMode { Local(LocalDapStore), - Ssh(SshDapStore), + Remote(RemoteDapStore), Collab, } @@ -80,8 +77,8 @@ pub struct LocalDapStore { toolchain_store: Arc, } -pub struct SshDapStore { - 
ssh_client: Entity, +pub struct RemoteDapStore { + remote_client: Entity, upstream_client: AnyProtoClient, upstream_project_id: u64, } @@ -147,16 +144,16 @@ impl DapStore { Self::new(mode, breakpoint_store, worktree_store, cx) } - pub fn new_ssh( + pub fn new_remote( project_id: u64, - ssh_client: Entity, + remote_client: Entity, breakpoint_store: Entity, worktree_store: Entity, cx: &mut Context, ) -> Self { - let mode = DapStoreMode::Ssh(SshDapStore { - upstream_client: ssh_client.read(cx).proto_client(), - ssh_client, + let mode = DapStoreMode::Remote(RemoteDapStore { + upstream_client: remote_client.read(cx).proto_client(), + remote_client, upstream_project_id: project_id, }); @@ -215,7 +212,7 @@ impl DapStore { dap_settings.and_then(|s| s.binary.as_ref().map(PathBuf::from)); let user_args = dap_settings.map(|s| s.args.clone()); - let delegate = self.delegate(&worktree, console, cx); + let delegate = self.delegate(worktree, console, cx); let cwd: Arc = worktree.read(cx).abs_path().as_ref().into(); cx.spawn(async move |this, cx| { @@ -242,59 +239,57 @@ impl DapStore { Ok(binary) }) } - DapStoreMode::Ssh(ssh) => { - let request = ssh.upstream_client.request(proto::GetDebugAdapterBinary { - session_id: session_id.to_proto(), - project_id: ssh.upstream_project_id, - worktree_id: worktree.read(cx).id().to_proto(), - definition: Some(definition.to_proto()), - }); - let ssh_client = ssh.ssh_client.clone(); + DapStoreMode::Remote(remote) => { + let request = remote + .upstream_client + .request(proto::GetDebugAdapterBinary { + session_id: session_id.to_proto(), + project_id: remote.upstream_project_id, + worktree_id: worktree.read(cx).id().to_proto(), + definition: Some(definition.to_proto()), + }); + let remote = remote.remote_client.clone(); cx.spawn(async move |_, cx| { let response = request.await?; let binary = DebugAdapterBinary::from_proto(response)?; - let (mut ssh_command, envs, path_style) = - ssh_client.read_with(cx, |ssh, _| { - let (SshArgs { arguments, 
envs }, path_style) = - ssh.ssh_info().context("SSH arguments not found")?; - anyhow::Ok(( - SshCommand { arguments }, - envs.unwrap_or_default(), - path_style, - )) - })??; - - let mut connection = None; - if let Some(c) = binary.connection { - let local_bind_addr = Ipv4Addr::LOCALHOST; - let port = - dap::transport::TcpTransport::unused_port(local_bind_addr).await?; - ssh_command.add_port_forwarding(port, c.host.to_string(), c.port); + let port_forwarding; + let connection; + if let Some(c) = binary.connection { + let host = Ipv4Addr::LOCALHOST; + let port; + if remote.read_with(cx, |remote, _cx| remote.shares_network_interface())? { + port = c.port; + port_forwarding = None; + } else { + port = dap::transport::TcpTransport::unused_port(host).await?; + port_forwarding = Some((port, c.host.to_string(), c.port)); + } connection = Some(TcpArguments { port, - host: local_bind_addr, + host, timeout: c.timeout, }) + } else { + port_forwarding = None; + connection = None; } - let (program, args) = wrap_for_ssh( - &ssh_command, - binary - .command - .as_ref() - .map(|command| (command, &binary.arguments)), - binary.cwd.as_deref(), - binary.envs, - None, - path_style, - ); + let command = remote.read_with(cx, |remote, _cx| { + remote.build_command( + binary.command, + &binary.arguments, + &binary.envs, + binary.cwd.map(|path| path.display().to_string()), + port_forwarding, + ) + })??; Ok(DebugAdapterBinary { - command: Some(program), - arguments: args, - envs, + command: Some(command.program), + arguments: command.args, + envs: command.env, cwd: None, connection, request_args: binary.request_args, @@ -360,9 +355,9 @@ impl DapStore { ))) } } - DapStoreMode::Ssh(ssh) => { - let request = ssh.upstream_client.request(proto::RunDebugLocators { - project_id: ssh.upstream_project_id, + DapStoreMode::Remote(remote) => { + let request = remote.upstream_client.request(proto::RunDebugLocators { + project_id: remote.upstream_project_id, build_command: Some(build_command.to_proto()), 
locator: locator_name.to_owned(), }); @@ -470,9 +465,8 @@ impl DapStore { session_id: impl Borrow, ) -> Option> { let session_id = session_id.borrow(); - let client = self.sessions.get(session_id).cloned(); - client + self.sessions.get(session_id).cloned() } pub fn sessions(&self) -> impl Iterator> { self.sessions.values() @@ -685,7 +679,7 @@ impl DapStore { let shutdown_id = parent_session.update(cx, |parent_session, _| { parent_session.remove_child_session_id(session_id); - if parent_session.child_session_ids().len() == 0 { + if parent_session.child_session_ids().is_empty() { Some(parent_session.session_id()) } else { None @@ -702,7 +696,7 @@ impl DapStore { cx.emit(DapStoreEvent::DebugClientShutdown(session_id)); cx.background_spawn(async move { - if shutdown_children.len() > 0 { + if !shutdown_children.is_empty() { let _ = join_all(shutdown_children).await; } @@ -722,7 +716,7 @@ impl DapStore { downstream_client: AnyProtoClient, _: &mut Context, ) { - self.downstream_client = Some((downstream_client.clone(), project_id)); + self.downstream_client = Some((downstream_client, project_id)); } pub fn unshared(&mut self, cx: &mut Context) { @@ -902,7 +896,7 @@ impl dap::adapters::DapDelegate for DapAdapterDelegate { } fn worktree_root_path(&self) -> &Path { - &self.worktree.abs_path() + self.worktree.abs_path() } fn http_client(&self) -> Arc { self.http_client.clone() diff --git a/crates/project/src/debugger/locators/cargo.rs b/crates/project/src/debugger/locators/cargo.rs index fa265dae586148f9c8efe14187ee26c805c65e42..b2f9580f9ced893448f86bfb2f7aab4a0de8a52e 100644 --- a/crates/project/src/debugger/locators/cargo.rs +++ b/crates/project/src/debugger/locators/cargo.rs @@ -117,7 +117,7 @@ impl DapLocator for CargoLocator { .cwd .clone() .context("Couldn't get cwd from debug config which is needed for locators")?; - let builder = ShellBuilder::new(true, &build_config.shell).non_interactive(); + let builder = ShellBuilder::new(None, 
&build_config.shell).non_interactive(); let (program, args) = builder.build( Some("cargo".into()), &build_config @@ -126,7 +126,7 @@ impl DapLocator for CargoLocator { .cloned() .take_while(|arg| arg != "--") .chain(Some("--message-format=json".to_owned())) - .collect(), + .collect::>(), ); let mut child = util::command::new_smol_command(program) .args(args) @@ -146,7 +146,7 @@ impl DapLocator for CargoLocator { let is_test = build_config .args .first() - .map_or(false, |arg| arg == "test" || arg == "t"); + .is_some_and(|arg| arg == "test" || arg == "t"); let executables = output .lines() @@ -187,12 +187,12 @@ impl DapLocator for CargoLocator { .cloned(); } let executable = { - if let Some(ref name) = test_name.as_ref().and_then(|name| { + if let Some(name) = test_name.as_ref().and_then(|name| { name.strip_prefix('$') .map(|name| build_config.env.get(name)) .unwrap_or(Some(name)) }) { - find_best_executable(&executables, &name).await + find_best_executable(&executables, name).await } else { None } diff --git a/crates/project/src/debugger/locators/go.rs b/crates/project/src/debugger/locators/go.rs index 61436fce8f3659d4b12c3010b82e0d845654c4e9..eec06084ec78548e1a627080663d2afccc8a0aca 100644 --- a/crates/project/src/debugger/locators/go.rs +++ b/crates/project/src/debugger/locators/go.rs @@ -174,7 +174,7 @@ impl DapLocator for GoLocator { request: "launch".to_string(), mode: "test".to_string(), program, - args: args, + args, build_flags, cwd: build_config.cwd.clone(), env: build_config.env.clone(), @@ -185,7 +185,7 @@ impl DapLocator for GoLocator { label: resolved_label.to_string().into(), adapter: adapter.0.clone(), build: None, - config: config, + config, tcp_connection: None, }) } @@ -220,7 +220,7 @@ impl DapLocator for GoLocator { request: "launch".to_string(), mode: "debug".to_string(), program, - args: args, + args, build_flags, }) .unwrap(); diff --git a/crates/project/src/debugger/locators/python.rs b/crates/project/src/debugger/locators/python.rs index 
3de1281aed36c6a96970d08e0e4f5cb0ef3bd67f..06f7ab2e796c8139f2f8723b95f7f4503250a0c3 100644 --- a/crates/project/src/debugger/locators/python.rs +++ b/crates/project/src/debugger/locators/python.rs @@ -25,23 +25,15 @@ impl DapLocator for PythonLocator { if adapter.0.as_ref() != "Debugpy" { return None; } - let valid_program = build_config.command.starts_with("$ZED_") + let valid_program = build_config.command.starts_with("\"$ZED_") || Path::new(&build_config.command) .file_name() - .map_or(false, |name| { - name.to_str().is_some_and(|path| path.starts_with("python")) - }); + .is_some_and(|name| name.to_str().is_some_and(|path| path.starts_with("python"))); if !valid_program || build_config.args.iter().any(|arg| arg == "-c") { // We cannot debug selections. return None; } - let command = if build_config.command - == VariableName::Custom("PYTHON_ACTIVE_ZED_TOOLCHAIN".into()).template_value() - { - VariableName::Custom("PYTHON_ACTIVE_ZED_TOOLCHAIN_RAW".into()).template_value() - } else { - build_config.command.clone() - }; + let command = build_config.command.clone(); let module_specifier_position = build_config .args .iter() @@ -59,10 +51,8 @@ impl DapLocator for PythonLocator { let program_position = mod_name .is_none() .then(|| { - build_config - .args - .iter() - .position(|arg| *arg == "\"$ZED_FILE\"") + let zed_file = VariableName::File.template_value_with_whitespace(); + build_config.args.iter().position(|arg| *arg == zed_file) }) .flatten(); let args = if let Some(position) = program_position { diff --git a/crates/project/src/debugger/memory.rs b/crates/project/src/debugger/memory.rs index fec3c344c5a433eebb3a1f314a8fd911bd603022..42ad64e6880ba653c6c95cb13f0e6bcc23c9bdae 100644 --- a/crates/project/src/debugger/memory.rs +++ b/crates/project/src/debugger/memory.rs @@ -3,6 +3,7 @@ //! Each byte in memory can either be mapped or unmapped. We try to mimic that twofold: //! - We assume that the memory is divided into pages of a fixed size. //! 
- We assume that each page can be either mapped or unmapped. +//! //! These two assumptions drive the shape of the memory representation. //! In particular, we want the unmapped pages to be represented without allocating any memory, as *most* //! of the memory in a program space is usually unmapped. @@ -165,8 +166,8 @@ impl Memory { /// - If it succeeds/fails wholesale, cool; we have no unknown memory regions in this page. /// - If it succeeds partially, we know # of mapped bytes. /// We might also know the # of unmapped bytes. -/// However, we're still unsure about what's *after* the unreadable region. /// +/// However, we're still unsure about what's *after* the unreadable region. /// This is where this builder comes in. It lets us track the state of figuring out contents of a single page. pub(super) struct MemoryPageBuilder { chunks: MappedPageContents, @@ -318,19 +319,18 @@ impl Iterator for MemoryIterator { return None; } if let Some((current_page_address, current_memory_chunk)) = self.current_known_page.as_mut() + && current_page_address.0 <= self.start { - if current_page_address.0 <= self.start { - if let Some(next_cell) = current_memory_chunk.next() { - self.start += 1; - return Some(next_cell); - } else { - self.current_known_page.take(); - } + if let Some(next_cell) = current_memory_chunk.next() { + self.start += 1; + return Some(next_cell); + } else { + self.current_known_page.take(); } } if !self.fetch_next_page() { self.start += 1; - return Some(MemoryCell(None)); + Some(MemoryCell(None)) } else { self.next() } diff --git a/crates/project/src/debugger/session.rs b/crates/project/src/debugger/session.rs index d9c28df497b3baa4543e6271106ddb1cd11b4419..81cb3ade2e18b6430c4b644495529c4567344da5 100644 --- a/crates/project/src/debugger/session.rs +++ b/crates/project/src/debugger/session.rs @@ -226,7 +226,7 @@ impl RunningMode { fn unset_breakpoints_from_paths(&self, paths: &Vec>, cx: &mut App) -> Task<()> { let tasks: Vec<_> = paths - .into_iter() + 
.iter() .map(|path| { self.request(dap_command::SetBreakpoints { source: client_source(path), @@ -431,7 +431,7 @@ impl RunningMode { let should_send_exception_breakpoints = capabilities .exception_breakpoint_filters .as_ref() - .map_or(false, |filters| !filters.is_empty()) + .is_some_and(|filters| !filters.is_empty()) || !configuration_done_supported; let supports_exception_filters = capabilities .supports_exception_filter_options @@ -508,13 +508,12 @@ impl RunningMode { .ok(); } - let ret = if configuration_done_supported { + if configuration_done_supported { this.request(ConfigurationDone {}) } else { Task::ready(Ok(())) } - .await; - ret + .await } }); @@ -710,9 +709,7 @@ where T: LocalDapCommand + PartialEq + Eq + Hash, { fn dyn_eq(&self, rhs: &dyn CacheableCommand) -> bool { - (rhs as &dyn Any) - .downcast_ref::() - .map_or(false, |rhs| self == rhs) + (rhs as &dyn Any).downcast_ref::() == Some(self) } fn dyn_hash(&self, mut hasher: &mut dyn Hasher) { @@ -841,7 +838,7 @@ impl Session { }) .detach(); - let this = Self { + Self { mode: SessionState::Booting(None), id: session_id, child_session_ids: HashSet::default(), @@ -870,9 +867,7 @@ impl Session { task_context, memory: memory::Memory::new(), quirks, - }; - - this + } }) } @@ -1085,7 +1080,7 @@ impl Session { }) .detach(); - return tx; + tx } pub fn is_started(&self) -> bool { @@ -1399,7 +1394,7 @@ impl Session { let breakpoint_store = self.breakpoint_store.clone(); if let Some((local, path)) = self.as_running_mut().and_then(|local| { let breakpoint = local.tmp_breakpoint.take()?; - let path = breakpoint.path.clone(); + let path = breakpoint.path; Some((local, path)) }) { local @@ -1630,7 +1625,7 @@ impl Session { + 'static, cx: &mut Context, ) -> Task> { - if !T::is_supported(&capabilities) { + if !T::is_supported(capabilities) { log::warn!( "Attempted to send a DAP request that isn't supported: {:?}", request @@ -1688,7 +1683,7 @@ impl Session { self.requests .entry((&*key.0 as &dyn Any).type_id()) 
.and_modify(|request_map| { - request_map.remove(&key); + request_map.remove(key); }); } @@ -1715,7 +1710,7 @@ impl Session { this.threads = result .into_iter() - .map(|thread| (ThreadId(thread.id), Thread::from(thread.clone()))) + .map(|thread| (ThreadId(thread.id), Thread::from(thread))) .collect(); this.invalidate_command_type::(); @@ -2558,10 +2553,7 @@ impl Session { mode: Option, cx: &mut Context, ) -> Task> { - let command = DataBreakpointInfoCommand { - context: context.clone(), - mode, - }; + let command = DataBreakpointInfoCommand { context, mode }; self.request(command, |_, response, _| response.ok(), cx) } diff --git a/crates/project/src/environment.rs b/crates/project/src/environment.rs index 7379a7ef726c6004fc2b29a5b61a47cb9603fbb3..d109e307a89181f0a416d6d01a3fa74684a138a7 100644 --- a/crates/project/src/environment.rs +++ b/crates/project/src/environment.rs @@ -198,7 +198,7 @@ async fn load_directory_shell_environment( ); }; - load_shell_environment(&dir, load_direnv).await + load_shell_environment(dir, load_direnv).await } Err(err) => ( None, diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index 891d74fea63f9000440290209a40cd0a8b1051fe..678ab73023f071024c7e6f902d1a085a0e95f264 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -44,7 +44,7 @@ use parking_lot::Mutex; use postage::stream::Stream as _; use rpc::{ AnyProtoClient, TypedEnvelope, - proto::{self, FromProto, SSH_PROJECT_ID, ToProto, git_reset, split_repository_update}, + proto::{self, FromProto, ToProto, git_reset, split_repository_update}, }; use serde::Deserialize; use std::{ @@ -62,7 +62,7 @@ use std::{ }; use sum_tree::{Edit, SumTree, TreeSet}; use text::{Bias, BufferId}; -use util::{ResultExt, debug_panic, post_inc}; +use util::{ResultExt, debug_panic, paths::SanitizedPath, post_inc}; use worktree::{ File, PathChange, PathKey, PathProgress, PathSummary, PathTarget, ProjectEntryId, UpdatedGitRepositoriesSet, 
UpdatedGitRepository, Worktree, @@ -141,14 +141,10 @@ enum GitStoreState { project_environment: Entity, fs: Arc, }, - Ssh { - upstream_client: AnyProtoClient, - upstream_project_id: ProjectId, - downstream: Option<(AnyProtoClient, ProjectId)>, - }, Remote { upstream_client: AnyProtoClient, - upstream_project_id: ProjectId, + upstream_project_id: u64, + downstream: Option<(AnyProtoClient, ProjectId)>, }, } @@ -355,7 +351,7 @@ impl GitStore { worktree_store: &Entity, buffer_store: Entity, upstream_client: AnyProtoClient, - project_id: ProjectId, + project_id: u64, cx: &mut Context, ) -> Self { Self::new( @@ -364,23 +360,6 @@ impl GitStore { GitStoreState::Remote { upstream_client, upstream_project_id: project_id, - }, - cx, - ) - } - - pub fn ssh( - worktree_store: &Entity, - buffer_store: Entity, - upstream_client: AnyProtoClient, - cx: &mut Context, - ) -> Self { - Self::new( - worktree_store.clone(), - buffer_store, - GitStoreState::Ssh { - upstream_client, - upstream_project_id: ProjectId(SSH_PROJECT_ID), downstream: None, }, cx, @@ -452,7 +431,7 @@ impl GitStore { pub fn shared(&mut self, project_id: u64, client: AnyProtoClient, cx: &mut Context) { match &mut self.state { - GitStoreState::Ssh { + GitStoreState::Remote { downstream: downstream_client, .. } => { @@ -528,9 +507,6 @@ impl GitStore { }), }); } - GitStoreState::Remote { .. } => { - debug_panic!("shared called on remote store"); - } } } @@ -542,15 +518,12 @@ impl GitStore { } => { downstream_client.take(); } - GitStoreState::Ssh { + GitStoreState::Remote { downstream: downstream_client, .. } => { downstream_client.take(); } - GitStoreState::Remote { .. 
} => { - debug_panic!("unshared called on remote store"); - } } self.shared_diffs.clear(); } @@ -562,7 +535,7 @@ impl GitStore { pub fn active_repository(&self) -> Option> { self.active_repo_id .as_ref() - .map(|id| self.repositories[&id].clone()) + .map(|id| self.repositories[id].clone()) } pub fn open_unstaged_diff( @@ -571,23 +544,22 @@ impl GitStore { cx: &mut Context, ) -> Task>> { let buffer_id = buffer.read(cx).remote_id(); - if let Some(diff_state) = self.diffs.get(&buffer_id) { - if let Some(unstaged_diff) = diff_state + if let Some(diff_state) = self.diffs.get(&buffer_id) + && let Some(unstaged_diff) = diff_state .read(cx) .unstaged_diff .as_ref() .and_then(|weak| weak.upgrade()) + { + if let Some(task) = + diff_state.update(cx, |diff_state, _| diff_state.wait_for_recalculation()) { - if let Some(task) = - diff_state.update(cx, |diff_state, _| diff_state.wait_for_recalculation()) - { - return cx.background_executor().spawn(async move { - task.await; - Ok(unstaged_diff) - }); - } - return Task::ready(Ok(unstaged_diff)); + return cx.background_executor().spawn(async move { + task.await; + Ok(unstaged_diff) + }); } + return Task::ready(Ok(unstaged_diff)); } let Some((repo, repo_path)) = @@ -628,23 +600,22 @@ impl GitStore { ) -> Task>> { let buffer_id = buffer.read(cx).remote_id(); - if let Some(diff_state) = self.diffs.get(&buffer_id) { - if let Some(uncommitted_diff) = diff_state + if let Some(diff_state) = self.diffs.get(&buffer_id) + && let Some(uncommitted_diff) = diff_state .read(cx) .uncommitted_diff .as_ref() .and_then(|weak| weak.upgrade()) + { + if let Some(task) = + diff_state.update(cx, |diff_state, _| diff_state.wait_for_recalculation()) { - if let Some(task) = - diff_state.update(cx, |diff_state, _| diff_state.wait_for_recalculation()) - { - return cx.background_executor().spawn(async move { - task.await; - Ok(uncommitted_diff) - }); - } - return Task::ready(Ok(uncommitted_diff)); + return cx.background_executor().spawn(async move { + 
task.await; + Ok(uncommitted_diff) + }); } + return Task::ready(Ok(uncommitted_diff)); } let Some((repo, repo_path)) = @@ -765,29 +736,26 @@ impl GitStore { log::debug!("open conflict set"); let buffer_id = buffer.read(cx).remote_id(); - if let Some(git_state) = self.diffs.get(&buffer_id) { - if let Some(conflict_set) = git_state + if let Some(git_state) = self.diffs.get(&buffer_id) + && let Some(conflict_set) = git_state .read(cx) .conflict_set .as_ref() .and_then(|weak| weak.upgrade()) - { - let conflict_set = conflict_set.clone(); - let buffer_snapshot = buffer.read(cx).text_snapshot(); + { + let conflict_set = conflict_set; + let buffer_snapshot = buffer.read(cx).text_snapshot(); - git_state.update(cx, |state, cx| { - let _ = state.reparse_conflict_markers(buffer_snapshot, cx); - }); + git_state.update(cx, |state, cx| { + let _ = state.reparse_conflict_markers(buffer_snapshot, cx); + }); - return conflict_set; - } + return conflict_set; } let is_unmerged = self .repository_and_path_for_buffer_id(buffer_id, cx) - .map_or(false, |(repo, path)| { - repo.read(cx).snapshot.has_conflict(&path) - }); + .is_some_and(|(repo, path)| repo.read(cx).snapshot.has_conflict(&path)); let git_store = cx.weak_entity(); let buffer_git_state = self .diffs @@ -918,7 +886,7 @@ impl GitStore { return Task::ready(Err(anyhow!("failed to find a git repository for buffer"))); }; let content = match &version { - Some(version) => buffer.rope_for_version(version).clone(), + Some(version) => buffer.rope_for_version(version), None => buffer.as_rope().clone(), }; let version = version.unwrap_or(buffer.version()); @@ -1053,21 +1021,17 @@ impl GitStore { } => downstream_client .as_ref() .map(|state| (state.client.clone(), state.project_id)), - GitStoreState::Ssh { + GitStoreState::Remote { downstream: downstream_client, .. } => downstream_client.clone(), - GitStoreState::Remote { .. } => None, } } fn upstream_client(&self) -> Option { match &self.state { GitStoreState::Local { .. 
} => None, - GitStoreState::Ssh { - upstream_client, .. - } - | GitStoreState::Remote { + GitStoreState::Remote { upstream_client, .. } => Some(upstream_client.clone()), } @@ -1152,29 +1116,26 @@ impl GitStore { for (buffer_id, diff) in self.diffs.iter() { if let Some((buffer_repo, repo_path)) = self.repository_and_path_for_buffer_id(*buffer_id, cx) + && buffer_repo == repo { - if buffer_repo == repo { - diff.update(cx, |diff, cx| { - if let Some(conflict_set) = &diff.conflict_set { - let conflict_status_changed = - conflict_set.update(cx, |conflict_set, cx| { - let has_conflict = repo_snapshot.has_conflict(&repo_path); - conflict_set.set_has_conflict(has_conflict, cx) - })?; - if conflict_status_changed { - let buffer_store = self.buffer_store.read(cx); - if let Some(buffer) = buffer_store.get(*buffer_id) { - let _ = diff.reparse_conflict_markers( - buffer.read(cx).text_snapshot(), - cx, - ); - } + diff.update(cx, |diff, cx| { + if let Some(conflict_set) = &diff.conflict_set { + let conflict_status_changed = + conflict_set.update(cx, |conflict_set, cx| { + let has_conflict = repo_snapshot.has_conflict(&repo_path); + conflict_set.set_has_conflict(has_conflict, cx) + })?; + if conflict_status_changed { + let buffer_store = self.buffer_store.read(cx); + if let Some(buffer) = buffer_store.get(*buffer_id) { + let _ = diff + .reparse_conflict_markers(buffer.read(cx).text_snapshot(), cx); } } - anyhow::Ok(()) - }) - .ok(); - } + } + anyhow::Ok(()) + }) + .ok(); } } cx.emit(GitStoreEvent::RepositoryUpdated( @@ -1278,7 +1239,7 @@ impl GitStore { ) { match event { BufferStoreEvent::BufferAdded(buffer) => { - cx.subscribe(&buffer, |this, buffer, event, cx| { + cx.subscribe(buffer, |this, buffer, event, cx| { if let BufferEvent::LanguageChanged = event { let buffer_id = buffer.read(cx).remote_id(); if let Some(diff_state) = this.diffs.get(&buffer_id) { @@ -1296,7 +1257,7 @@ impl GitStore { } } BufferStoreEvent::BufferDropped(buffer_id) => { - self.diffs.remove(&buffer_id); + 
self.diffs.remove(buffer_id); for diffs in self.shared_diffs.values_mut() { diffs.remove(buffer_id); } @@ -1385,8 +1346,8 @@ impl GitStore { repository.update(cx, |repository, cx| { let repo_abs_path = &repository.work_directory_abs_path; if changed_repos.iter().any(|update| { - update.old_work_directory_abs_path.as_ref() == Some(&repo_abs_path) - || update.new_work_directory_abs_path.as_ref() == Some(&repo_abs_path) + update.old_work_directory_abs_path.as_ref() == Some(repo_abs_path) + || update.new_work_directory_abs_path.as_ref() == Some(repo_abs_path) }) { repository.reload_buffer_diff_bases(cx); } @@ -1441,12 +1402,7 @@ impl GitStore { cx.background_executor() .spawn(async move { fs.git_init(&path, fallback_branch_name) }) } - GitStoreState::Ssh { - upstream_client, - upstream_project_id: project_id, - .. - } - | GitStoreState::Remote { + GitStoreState::Remote { upstream_client, upstream_project_id: project_id, .. @@ -1456,7 +1412,7 @@ impl GitStore { cx.background_executor().spawn(async move { client .request(proto::GitInit { - project_id: project_id.0, + project_id: project_id, abs_path: path.to_string_lossy().to_string(), fallback_branch_name, }) @@ -1480,13 +1436,18 @@ impl GitStore { cx.background_executor() .spawn(async move { fs.git_clone(&repo, &path).await }) } - GitStoreState::Ssh { + GitStoreState::Remote { upstream_client, upstream_project_id, .. } => { + if upstream_client.is_via_collab() { + return Task::ready(Err(anyhow!( + "Git Clone isn't supported for project guests" + ))); + } let request = upstream_client.request(proto::GitClone { - project_id: upstream_project_id.0, + project_id: *upstream_project_id, abs_path: path.to_string_lossy().to_string(), remote_repo: repo, }); @@ -1500,9 +1461,6 @@ impl GitStore { } }) } - GitStoreState::Remote { .. 
} => { - Task::ready(Err(anyhow!("Git Clone isn't supported for remote users"))) - } } } @@ -1515,10 +1473,7 @@ impl GitStore { let mut update = envelope.payload; let id = RepositoryId::from_proto(update.id); - let client = this - .upstream_client() - .context("no upstream client")? - .clone(); + let client = this.upstream_client().context("no upstream client")?; let mut is_new = false; let repo = this.repositories.entry(id).or_insert_with(|| { @@ -1537,7 +1492,7 @@ impl GitStore { }); if is_new { this._subscriptions - .push(cx.subscribe(&repo, Self::on_repository_event)) + .push(cx.subscribe(repo, Self::on_repository_event)) } repo.update(cx, { @@ -2251,13 +2206,13 @@ impl GitStore { ) -> Result<()> { let buffer_id = BufferId::new(request.payload.buffer_id)?; this.update(&mut cx, |this, cx| { - if let Some(diff_state) = this.diffs.get_mut(&buffer_id) { - if let Some(buffer) = this.buffer_store.read(cx).get(buffer_id) { - let buffer = buffer.read(cx).text_snapshot(); - diff_state.update(cx, |diff_state, cx| { - diff_state.handle_base_texts_updated(buffer, request.payload, cx); - }) - } + if let Some(diff_state) = this.diffs.get_mut(&buffer_id) + && let Some(buffer) = this.buffer_store.read(cx).get(buffer_id) + { + let buffer = buffer.read(cx).text_snapshot(); + diff_state.update(cx, |diff_state, cx| { + diff_state.handle_base_texts_updated(buffer, request.payload, cx); + }) } }) } @@ -2369,16 +2324,20 @@ impl GitStore { return None; }; - let mut paths = vec![]; + let mut paths = Vec::new(); // All paths prefixed by a given repo will constitute a continuous range. 
while let Some(path) = entries.get(ix) && let Some(repo_path) = - RepositorySnapshot::abs_path_to_repo_path_inner(&repo_path, &path) + RepositorySnapshot::abs_path_to_repo_path_inner(&repo_path, path) { paths.push((repo_path, ix)); ix += 1; } - Some((repo, paths)) + if paths.is_empty() { + None + } else { + Some((repo, paths)) + } }); tasks.push_back(task); } @@ -2523,14 +2482,14 @@ impl BufferGitState { pub fn wait_for_recalculation(&mut self) -> Option + use<>> { if *self.recalculating_tx.borrow() { let mut rx = self.recalculating_tx.subscribe(); - return Some(async move { + Some(async move { loop { let is_recalculating = rx.recv().await; if is_recalculating != Some(true) { break; } } - }); + }) } else { None } @@ -2790,6 +2749,7 @@ impl RepositorySnapshot { .iter() .map(|repo_path| repo_path.to_proto()) .collect(), + merge_message: self.merge.message.as_ref().map(|msg| msg.to_string()), project_id, id: self.id.to_proto(), abs_path: self.work_directory_abs_path.to_proto(), @@ -2852,6 +2812,7 @@ impl RepositorySnapshot { .iter() .map(|path| path.as_ref().to_proto()) .collect(), + merge_message: self.merge.message.as_ref().map(|msg| msg.to_string()), project_id, id: self.id.to_proto(), abs_path: self.work_directory_abs_path.to_proto(), @@ -2891,15 +2852,15 @@ impl RepositorySnapshot { } pub fn had_conflict_on_last_merge_head_change(&self, repo_path: &RepoPath) -> bool { - self.merge.conflicted_paths.contains(&repo_path) + self.merge.conflicted_paths.contains(repo_path) } pub fn has_conflict(&self, repo_path: &RepoPath) -> bool { let had_conflict_on_last_merge_head_change = - self.merge.conflicted_paths.contains(&repo_path); + self.merge.conflicted_paths.contains(repo_path); let has_conflict_currently = self - .status_for_path(&repo_path) - .map_or(false, |entry| entry.status.is_conflicted()); + .status_for_path(repo_path) + .is_some_and(|entry| entry.status.is_conflicted()); had_conflict_on_last_merge_head_change || has_conflict_currently } @@ -3293,6 +3254,7 @@ 
impl Repository { let git_store = self.git_store.upgrade()?; let worktree_store = git_store.read(cx).worktree_store.read(cx); let abs_path = self.snapshot.work_directory_abs_path.join(&path.0); + let abs_path = SanitizedPath::new(&abs_path); let (worktree, relative_path) = worktree_store.find_worktree(abs_path, cx)?; Some(ProjectPath { worktree_id: worktree.read(cx).id(), @@ -3375,7 +3337,7 @@ impl Repository { ) -> Task>> { cx.spawn(async move |repository, cx| { let buffer = buffer_store - .update(cx, |buffer_store, cx| buffer_store.create_buffer(cx))? + .update(cx, |buffer_store, cx| buffer_store.create_buffer(false, cx))? .await?; if let Some(language_registry) = language_registry { @@ -3440,7 +3402,6 @@ impl Repository { reset_mode: ResetMode, _cx: &mut App, ) -> oneshot::Receiver> { - let commit = commit.to_string(); let id = self.id; self.send_job(None, move |git_repo, _| async move { @@ -3547,14 +3508,13 @@ impl Repository { let Some(project_path) = self.repo_path_to_project_path(path, cx) else { continue; }; - if let Some(buffer) = buffer_store.get_by_path(&project_path) { - if buffer + if let Some(buffer) = buffer_store.get_by_path(&project_path) + && buffer .read(cx) .file() - .map_or(false, |file| file.disk_state().exists()) - { - save_futures.push(buffer_store.save_buffer(buffer, cx)); - } + .is_some_and(|file| file.disk_state().exists()) + { + save_futures.push(buffer_store.save_buffer(buffer, cx)); } } }) @@ -3614,14 +3574,13 @@ impl Repository { let Some(project_path) = self.repo_path_to_project_path(path, cx) else { continue; }; - if let Some(buffer) = buffer_store.get_by_path(&project_path) { - if buffer + if let Some(buffer) = buffer_store.get_by_path(&project_path) + && buffer .read(cx) .file() - .map_or(false, |file| file.disk_state().exists()) - { - save_futures.push(buffer_store.save_buffer(buffer, cx)); - } + .is_some_and(|file| file.disk_state().exists()) + { + save_futures.push(buffer_store.save_buffer(buffer, cx)); } } }) @@ -3668,7 
+3627,7 @@ impl Repository { let to_stage = self .cached_status() .filter(|entry| !entry.status.staging().is_fully_staged()) - .map(|entry| entry.repo_path.clone()) + .map(|entry| entry.repo_path) .collect(); self.stage_entries(to_stage, cx) } @@ -3677,16 +3636,13 @@ impl Repository { let to_unstage = self .cached_status() .filter(|entry| entry.status.staging().has_staged()) - .map(|entry| entry.repo_path.clone()) + .map(|entry| entry.repo_path) .collect(); self.unstage_entries(to_unstage, cx) } pub fn stash_all(&mut self, cx: &mut Context) -> Task> { - let to_stash = self - .cached_status() - .map(|entry| entry.repo_path.clone()) - .collect(); + let to_stash = self.cached_status().map(|entry| entry.repo_path).collect(); self.stash_entries(to_stash, cx) } @@ -4312,6 +4268,7 @@ impl Repository { .map(proto_to_commit_details); self.snapshot.merge.conflicted_paths = conflicted_paths; + self.snapshot.merge.message = update.merge_message.map(SharedString::from); let edits = update .removed_statuses @@ -4388,7 +4345,8 @@ impl Repository { bail!("not a local repository") }; let (snapshot, events) = this - .read_with(&mut cx, |this, _| { + .update(&mut cx, |this, _| { + this.paths_needing_status_update.clear(); compute_snapshot( this.id, this.work_directory_abs_path.clone(), @@ -4463,14 +4421,13 @@ impl Repository { } if let Some(job) = jobs.pop_front() { - if let Some(current_key) = &job.key { - if jobs + if let Some(current_key) = &job.key + && jobs .iter() .any(|other_job| other_job.key.as_ref() == Some(current_key)) { continue; } - } (job.job)(state.clone(), cx).await; } else if let Some(job) = job_rx.next().await { jobs.push_back(job); @@ -4501,13 +4458,12 @@ impl Repository { } if let Some(job) = jobs.pop_front() { - if let Some(current_key) = &job.key { - if jobs + if let Some(current_key) = &job.key + && jobs .iter() .any(|other_job| other_job.key.as_ref() == Some(current_key)) - { - continue; - } + { + continue; } (job.job)(state.clone(), cx).await; } else if let 
Some(job) = job_rx.next().await { @@ -4618,6 +4574,9 @@ impl Repository { }; let paths = changed_paths.iter().cloned().collect::>(); + if paths.is_empty() { + return Ok(()); + } let statuses = backend.status(&paths).await?; let changed_path_statuses = cx @@ -4628,10 +4587,10 @@ impl Repository { for (repo_path, status) in &*statuses.entries { changed_paths.remove(repo_path); - if cursor.seek_forward(&PathTarget::Path(repo_path), Bias::Left) { - if cursor.item().is_some_and(|entry| entry.status == *status) { - continue; - } + if cursor.seek_forward(&PathTarget::Path(repo_path), Bias::Left) + && cursor.item().is_some_and(|entry| entry.status == *status) + { + continue; } changed_path_statuses.push(Edit::Insert(StatusEntry { @@ -4843,6 +4802,7 @@ fn branch_to_proto(branch: &git::repository::Branch) -> proto::Branch { sha: commit.sha.to_string(), subject: commit.subject.to_string(), commit_timestamp: commit.commit_timestamp, + author_name: commit.author_name.to_string(), }), } } @@ -4872,6 +4832,7 @@ fn proto_to_branch(proto: &proto::Branch) -> git::repository::Branch { sha: commit.sha.to_string().into(), subject: commit.subject.to_string().into(), commit_timestamp: commit.commit_timestamp, + author_name: commit.author_name.to_string().into(), has_parent: true, } }), diff --git a/crates/project/src/git_store/conflict_set.rs b/crates/project/src/git_store/conflict_set.rs index 27b191f65f896e6488a4d9c52f37e9426cac1c46..313a1e90adc2fde8a62dbe6aa60b4d3a366af22c 100644 --- a/crates/project/src/git_store/conflict_set.rs +++ b/crates/project/src/git_store/conflict_set.rs @@ -369,7 +369,7 @@ mod tests { .unindent(); let buffer_id = BufferId::new(1).unwrap(); - let buffer = Buffer::new(0, buffer_id, test_content.to_string()); + let buffer = Buffer::new(0, buffer_id, test_content); let snapshot = buffer.snapshot(); let conflict_snapshot = ConflictSet::parse(&snapshot); @@ -400,7 +400,7 @@ mod tests { >>>>>>> "# .unindent(); let buffer_id = BufferId::new(1).unwrap(); - let buffer 
= Buffer::new(0, buffer_id, test_content.to_string()); + let buffer = Buffer::new(0, buffer_id, test_content); let snapshot = buffer.snapshot(); let conflict_snapshot = ConflictSet::parse(&snapshot); @@ -653,7 +653,7 @@ mod tests { cx.run_until_parked(); conflict_set.update(cx, |conflict_set, _| { - assert_eq!(conflict_set.has_conflict, false); + assert!(!conflict_set.has_conflict); assert_eq!(conflict_set.snapshot.conflicts.len(), 0); }); diff --git a/crates/project/src/git_store/git_traversal.rs b/crates/project/src/git_store/git_traversal.rs index bbcffe046debd8ab4529cf2b661abbebefd13f47..eee492e482daf746c60836cab172f84b2834b468 100644 --- a/crates/project/src/git_store/git_traversal.rs +++ b/crates/project/src/git_store/git_traversal.rs @@ -42,8 +42,8 @@ impl<'a> GitTraversal<'a> { // other_repo/ // .git/ // our_query.txt - let mut query = path.ancestors(); - while let Some(query) = query.next() { + let query = path.ancestors(); + for query in query { let (_, snapshot) = self .repo_root_to_snapshot .range(Path::new("")..=query) @@ -182,11 +182,11 @@ impl<'a> Iterator for ChildEntriesGitIter<'a> { type Item = GitEntryRef<'a>; fn next(&mut self) -> Option { - if let Some(item) = self.traversal.entry() { - if item.path.starts_with(self.parent_path) { - self.traversal.advance_to_sibling(); - return Some(item); - } + if let Some(item) = self.traversal.entry() + && item.path.starts_with(self.parent_path) + { + self.traversal.advance_to_sibling(); + return Some(item); } None } @@ -199,7 +199,7 @@ pub struct GitEntryRef<'a> { } impl GitEntryRef<'_> { - pub fn to_owned(&self) -> GitEntry { + pub fn to_owned(self) -> GitEntry { GitEntry { entry: self.entry.clone(), git_summary: self.git_summary, @@ -211,7 +211,7 @@ impl Deref for GitEntryRef<'_> { type Target = Entry; fn deref(&self) -> &Self::Target { - &self.entry + self.entry } } diff --git a/crates/project/src/image_store.rs b/crates/project/src/image_store.rs index 
79f134b91a36a2f7d1f3f256506931b47ae8cf9c..e499d4e026f724f12e023738f12afb2735f9ce2d 100644 --- a/crates/project/src/image_store.rs +++ b/crates/project/src/image_store.rs @@ -224,7 +224,7 @@ impl ProjectItem for ImageItem { path: &ProjectPath, cx: &mut App, ) -> Option>>> { - if is_image_file(&project, &path, cx) { + if is_image_file(project, path, cx) { Some(cx.spawn({ let path = path.clone(); let project = project.clone(); @@ -244,7 +244,7 @@ impl ProjectItem for ImageItem { } fn project_path(&self, cx: &App) -> Option { - Some(self.project_path(cx).clone()) + Some(self.project_path(cx)) } fn is_dirty(&self) -> bool { @@ -375,7 +375,6 @@ impl ImageStore { let (mut tx, rx) = postage::watch::channel(); entry.insert(rx.clone()); - let project_path = project_path.clone(); let load_image = self .state .open_image(project_path.path.clone(), worktree, cx); @@ -446,15 +445,12 @@ impl ImageStore { event: &ImageItemEvent, cx: &mut Context, ) { - match event { - ImageItemEvent::FileHandleChanged => { - if let Some(local) = self.state.as_local() { - local.update(cx, |local, cx| { - local.image_changed_file(image, cx); - }) - } - } - _ => {} + if let ImageItemEvent::FileHandleChanged = event + && let Some(local) = self.state.as_local() + { + local.update(cx, |local, cx| { + local.image_changed_file(image, cx); + }) } } } @@ -531,13 +527,10 @@ impl ImageStoreImpl for Entity { impl LocalImageStore { fn subscribe_to_worktree(&mut self, worktree: &Entity, cx: &mut Context) { cx.subscribe(worktree, |this, worktree, event, cx| { - if worktree.read(cx).is_local() { - match event { - worktree::Event::UpdatedEntries(changes) => { - this.local_worktree_entries_changed(&worktree, changes, cx); - } - _ => {} - } + if worktree.read(cx).is_local() + && let worktree::Event::UpdatedEntries(changes) = event + { + this.local_worktree_entries_changed(&worktree, changes, cx); } }) .detach(); diff --git a/crates/project/src/lsp_command.rs b/crates/project/src/lsp_command.rs index 
c458b6b300c34ec03d144cf297277faf4a94f5db..a960e1183dd46537ef3aee829cd9753b28001480 100644 --- a/crates/project/src/lsp_command.rs +++ b/crates/project/src/lsp_command.rs @@ -50,8 +50,8 @@ pub fn lsp_formatting_options(settings: &LanguageSettings) -> lsp::FormattingOpt } } -pub fn file_path_to_lsp_url(path: &Path) -> Result { - match lsp::Url::from_file_path(path) { +pub fn file_path_to_lsp_url(path: &Path) -> Result { + match lsp::Uri::from_file_path(path) { Ok(url) => Ok(url), Err(()) => anyhow::bail!("Invalid file path provided to LSP request: {path:?}"), } @@ -332,9 +332,9 @@ impl LspCommand for PrepareRename { _: Entity, buffer: Entity, _: LanguageServerId, - mut cx: AsyncApp, + cx: AsyncApp, ) -> Result { - buffer.read_with(&mut cx, |buffer, _| match message { + buffer.read_with(&cx, |buffer, _| match message { Some(lsp::PrepareRenameResponse::Range(range)) | Some(lsp::PrepareRenameResponse::RangeWithPlaceholder { range, .. }) => { let Range { start, end } = range_from_lsp(range); @@ -386,7 +386,7 @@ impl LspCommand for PrepareRename { .await?; Ok(Self { - position: buffer.read_with(&mut cx, |buffer, _| position.to_point_utf16(buffer))?, + position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer))?, }) } @@ -500,13 +500,12 @@ impl LspCommand for PerformRename { mut cx: AsyncApp, ) -> Result { if let Some(edit) = message { - let (lsp_adapter, lsp_server) = + let (_, lsp_server) = language_server_for_buffer(&lsp_store, &buffer, server_id, &mut cx)?; LocalLspStore::deserialize_workspace_edit( lsp_store, edit, self.push_to_history, - lsp_adapter, lsp_server, &mut cx, ) @@ -544,7 +543,7 @@ impl LspCommand for PerformRename { })? .await?; Ok(Self { - position: buffer.read_with(&mut cx, |buffer, _| position.to_point_utf16(buffer))?, + position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer))?, new_name: message.new_name, push_to_history: false, }) @@ -659,7 +658,7 @@ impl LspCommand for GetDefinitions { })? 
.await?; Ok(Self { - position: buffer.read_with(&mut cx, |buffer, _| position.to_point_utf16(buffer))?, + position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer))?, }) } @@ -762,7 +761,7 @@ impl LspCommand for GetDeclarations { })? .await?; Ok(Self { - position: buffer.read_with(&mut cx, |buffer, _| position.to_point_utf16(buffer))?, + position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer))?, }) } @@ -864,7 +863,7 @@ impl LspCommand for GetImplementations { })? .await?; Ok(Self { - position: buffer.read_with(&mut cx, |buffer, _| position.to_point_utf16(buffer))?, + position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer))?, }) } @@ -963,7 +962,7 @@ impl LspCommand for GetTypeDefinitions { })? .await?; Ok(Self { - position: buffer.read_with(&mut cx, |buffer, _| position.to_point_utf16(buffer))?, + position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer))?, }) } @@ -1116,18 +1115,12 @@ pub async fn location_links_from_lsp( } } - let (lsp_adapter, language_server) = - language_server_for_buffer(&lsp_store, &buffer, server_id, &mut cx)?; + let (_, language_server) = language_server_for_buffer(&lsp_store, &buffer, server_id, &mut cx)?; let mut definitions = Vec::new(); for (origin_range, target_uri, target_range) in unresolved_links { let target_buffer_handle = lsp_store .update(&mut cx, |this, cx| { - this.open_local_buffer_via_lsp( - target_uri, - language_server.server_id(), - lsp_adapter.name.clone(), - cx, - ) + this.open_local_buffer_via_lsp(target_uri, language_server.server_id(), cx) })? 
.await?; @@ -1172,8 +1165,7 @@ pub async fn location_link_from_lsp( server_id: LanguageServerId, cx: &mut AsyncApp, ) -> Result { - let (lsp_adapter, language_server) = - language_server_for_buffer(&lsp_store, &buffer, server_id, cx)?; + let (_, language_server) = language_server_for_buffer(lsp_store, buffer, server_id, cx)?; let (origin_range, target_uri, target_range) = ( link.origin_selection_range, @@ -1183,12 +1175,7 @@ pub async fn location_link_from_lsp( let target_buffer_handle = lsp_store .update(cx, |lsp_store, cx| { - lsp_store.open_local_buffer_via_lsp( - target_uri, - language_server.server_id(), - lsp_adapter.name.clone(), - cx, - ) + lsp_store.open_local_buffer_via_lsp(target_uri, language_server.server_id(), cx) })? .await?; @@ -1326,7 +1313,7 @@ impl LspCommand for GetReferences { mut cx: AsyncApp, ) -> Result> { let mut references = Vec::new(); - let (lsp_adapter, language_server) = + let (_, language_server) = language_server_for_buffer(&lsp_store, &buffer, server_id, &mut cx)?; if let Some(locations) = locations { @@ -1336,7 +1323,6 @@ impl LspCommand for GetReferences { lsp_store.open_local_buffer_via_lsp( lsp_location.uri, language_server.server_id(), - lsp_adapter.name.clone(), cx, ) })? @@ -1344,7 +1330,7 @@ impl LspCommand for GetReferences { target_buffer_handle .clone() - .read_with(&mut cx, |target_buffer, _| { + .read_with(&cx, |target_buffer, _| { let target_start = target_buffer .clip_point_utf16(point_from_lsp(lsp_location.range.start), Bias::Left); let target_end = target_buffer @@ -1388,7 +1374,7 @@ impl LspCommand for GetReferences { })? 
.await?; Ok(Self { - position: buffer.read_with(&mut cx, |buffer, _| position.to_point_utf16(buffer))?, + position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer))?, }) } @@ -1498,9 +1484,9 @@ impl LspCommand for GetDocumentHighlights { _: Entity, buffer: Entity, _: LanguageServerId, - mut cx: AsyncApp, + cx: AsyncApp, ) -> Result> { - buffer.read_with(&mut cx, |buffer, _| { + buffer.read_with(&cx, |buffer, _| { let mut lsp_highlights = lsp_highlights.unwrap_or_default(); lsp_highlights.sort_unstable_by_key(|h| (h.range.start, Reverse(h.range.end))); lsp_highlights @@ -1548,7 +1534,7 @@ impl LspCommand for GetDocumentHighlights { })? .await?; Ok(Self { - position: buffer.read_with(&mut cx, |buffer, _| position.to_point_utf16(buffer))?, + position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer))?, }) } @@ -1879,7 +1865,7 @@ impl LspCommand for GetSignatureHelp { })? .await .with_context(|| format!("waiting for version for buffer {}", buffer.entity_id()))?; - let buffer_snapshot = buffer.read_with(&mut cx, |buffer, _| buffer.snapshot())?; + let buffer_snapshot = buffer.read_with(&cx, |buffer, _| buffer.snapshot())?; Ok(Self { position: payload .position @@ -1961,13 +1947,13 @@ impl LspCommand for GetHover { _: Entity, buffer: Entity, _: LanguageServerId, - mut cx: AsyncApp, + cx: AsyncApp, ) -> Result { let Some(hover) = message else { return Ok(None); }; - let (language, range) = buffer.read_with(&mut cx, |buffer, _| { + let (language, range) = buffer.read_with(&cx, |buffer, _| { ( buffer.language().cloned(), hover.range.map(|range| { @@ -2053,7 +2039,7 @@ impl LspCommand for GetHover { })? 
.await?; Ok(Self { - position: buffer.read_with(&mut cx, |buffer, _| position.to_point_utf16(buffer))?, + position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer))?, }) } @@ -2127,7 +2113,7 @@ impl LspCommand for GetHover { return Ok(None); } - let language = buffer.read_with(&mut cx, |buffer, _| buffer.language().cloned())?; + let language = buffer.read_with(&cx, |buffer, _| buffer.language().cloned())?; let range = if let (Some(start), Some(end)) = (message.start, message.end) { language::proto::deserialize_anchor(start) .and_then(|start| language::proto::deserialize_anchor(end).map(|end| start..end)) @@ -2222,7 +2208,7 @@ impl LspCommand for GetCompletions { let unfiltered_completions_count = completions.len(); let language_server_adapter = lsp_store - .read_with(&mut cx, |lsp_store, _| { + .read_with(&cx, |lsp_store, _| { lsp_store.language_server_adapter_for_id(server_id) })? .with_context(|| format!("no language server with id {server_id}"))?; @@ -2355,15 +2341,14 @@ impl LspCommand for GetCompletions { .zip(completion_edits) .map(|(mut lsp_completion, mut edit)| { LineEnding::normalize(&mut edit.new_text); - if lsp_completion.data.is_none() { - if let Some(default_data) = lsp_defaults + if lsp_completion.data.is_none() + && let Some(default_data) = lsp_defaults .as_ref() .and_then(|item_defaults| item_defaults.data.clone()) - { - // Servers (e.g. JDTLS) prefer unchanged completions, when resolving the items later, - // so we do not insert the defaults here, but `data` is needed for resolving, so this is an exception. - lsp_completion.data = Some(default_data); - } + { + // Servers (e.g. JDTLS) prefer unchanged completions, when resolving the items later, + // so we do not insert the defaults here, but `data` is needed for resolving, so this is an exception. 
+ lsp_completion.data = Some(default_data); } CoreCompletion { replace_range: edit.replace_range, @@ -2409,7 +2394,7 @@ impl LspCommand for GetCompletions { .position .and_then(language::proto::deserialize_anchor) .map(|p| { - buffer.read_with(&mut cx, |buffer, _| { + buffer.read_with(&cx, |buffer, _| { buffer.clip_point_utf16(Unclipped(p.to_point_utf16(buffer)), Bias::Left) }) }) @@ -2516,8 +2501,8 @@ pub(crate) fn parse_completion_text_edit( }; Some(ParsedCompletionEdit { - insert_range: insert_range, - replace_range: replace_range, + insert_range, + replace_range, new_text: new_text.clone(), }) } @@ -2610,11 +2595,9 @@ impl LspCommand for GetCodeActions { server_id: LanguageServerId, cx: AsyncApp, ) -> Result> { - let requested_kinds_set = if let Some(kinds) = self.kinds { - Some(kinds.into_iter().collect::>()) - } else { - None - }; + let requested_kinds_set = self + .kinds + .map(|kinds| kinds.into_iter().collect::>()); let language_server = cx.update(|cx| { lsp_store @@ -2637,10 +2620,10 @@ impl LspCommand for GetCodeActions { .filter_map(|entry| { let (lsp_action, resolved) = match entry { lsp::CodeActionOrCommand::CodeAction(lsp_action) => { - if let Some(command) = lsp_action.command.as_ref() { - if !available_commands.contains(&command.command) { - return None; - } + if let Some(command) = lsp_action.command.as_ref() + && !available_commands.contains(&command.command) + { + return None; } (LspAction::Action(Box::new(lsp_action)), false) } @@ -2655,10 +2638,9 @@ impl LspCommand for GetCodeActions { if let Some((requested_kinds, kind)) = requested_kinds_set.as_ref().zip(lsp_action.action_kind()) + && !requested_kinds.contains(&kind) { - if !requested_kinds.contains(&kind) { - return None; - } + return None; } Some(CodeAction { @@ -2755,7 +2737,7 @@ impl GetCodeActions { Some(lsp::CodeActionProviderCapability::Options(CodeActionOptions { code_action_kinds: Some(supported_action_kinds), .. 
- })) => Some(supported_action_kinds.clone()), + })) => Some(supported_action_kinds), _ => capabilities.code_action_kinds, } } @@ -2878,7 +2860,7 @@ impl LspCommand for OnTypeFormatting { })?; Ok(Self { - position: buffer.read_with(&mut cx, |buffer, _| position.to_point_utf16(buffer))?, + position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer))?, trigger: message.trigger.clone(), options, push_to_history: false, @@ -3153,7 +3135,7 @@ impl InlayHints { Some(((uri, range), server_id)) => Some(( LanguageServerId(server_id as usize), lsp::Location { - uri: lsp::Url::parse(&uri) + uri: lsp::Uri::from_str(&uri) .context("invalid uri in hint part {part:?}")?, range: lsp::Range::new( point_to_lsp(PointUtf16::new( @@ -3462,10 +3444,7 @@ impl LspCommand for GetCodeLens { capabilities .server_capabilities .code_lens_provider - .as_ref() - .map_or(false, |code_lens_options| { - code_lens_options.resolve_provider.unwrap_or(false) - }) + .is_some() } fn to_lsp( @@ -3490,9 +3469,9 @@ impl LspCommand for GetCodeLens { lsp_store: Entity, buffer: Entity, server_id: LanguageServerId, - mut cx: AsyncApp, + cx: AsyncApp, ) -> anyhow::Result> { - let snapshot = buffer.read_with(&mut cx, |buffer, _| buffer.snapshot())?; + let snapshot = buffer.read_with(&cx, |buffer, _| buffer.snapshot())?; let language_server = cx.update(|cx| { lsp_store .read(cx) @@ -3754,7 +3733,7 @@ impl GetDocumentDiagnostics { .filter_map(|diagnostics| { Some(LspPullDiagnostics::Response { server_id: LanguageServerId::from_proto(diagnostics.server_id), - uri: lsp::Url::from_str(diagnostics.uri.as_str()).log_err()?, + uri: lsp::Uri::from_str(diagnostics.uri.as_str()).log_err()?, diagnostics: if diagnostics.changed { PulledDiagnostics::Unchanged { result_id: diagnostics.result_id?, @@ -3809,9 +3788,9 @@ impl GetDocumentDiagnostics { start: point_to_lsp(PointUtf16::new(start.row, start.column)), end: point_to_lsp(PointUtf16::new(end.row, end.column)), }, - uri: 
lsp::Url::parse(&info.location_url.unwrap()).unwrap(), + uri: lsp::Uri::from_str(&info.location_url.unwrap()).unwrap(), }, - message: info.message.clone(), + message: info.message, } }) .collect::>(); @@ -3839,12 +3818,11 @@ impl GetDocumentDiagnostics { _ => None, }, code, - code_description: match diagnostic.code_description { - Some(code_description) => Some(CodeDescription { - href: Some(lsp::Url::parse(&code_description).unwrap()), + code_description: diagnostic + .code_description + .map(|code_description| CodeDescription { + href: Some(lsp::Uri::from_str(&code_description).unwrap()), }), - None => None, - }, related_information: Some(related_information), tags: Some(tags), source: diagnostic.source.clone(), @@ -3983,7 +3961,7 @@ pub struct WorkspaceLspPullDiagnostics { } fn process_full_workspace_diagnostics_report( - diagnostics: &mut HashMap, + diagnostics: &mut HashMap, server_id: LanguageServerId, report: lsp::WorkspaceFullDocumentDiagnosticReport, ) { @@ -4006,7 +3984,7 @@ fn process_full_workspace_diagnostics_report( } fn process_unchanged_workspace_diagnostics_report( - diagnostics: &mut HashMap, + diagnostics: &mut HashMap, server_id: LanguageServerId, report: lsp::WorkspaceUnchangedDocumentDiagnosticReport, ) { @@ -4365,9 +4343,9 @@ impl LspCommand for GetDocumentColor { } fn process_related_documents( - diagnostics: &mut HashMap, + diagnostics: &mut HashMap, server_id: LanguageServerId, - documents: impl IntoIterator, + documents: impl IntoIterator, ) { for (url, report_kind) in documents { match report_kind { @@ -4382,9 +4360,9 @@ fn process_related_documents( } fn process_unchanged_diagnostics_report( - diagnostics: &mut HashMap, + diagnostics: &mut HashMap, server_id: LanguageServerId, - uri: lsp::Url, + uri: lsp::Uri, report: lsp::UnchangedDocumentDiagnosticReport, ) { let result_id = report.result_id; @@ -4426,9 +4404,9 @@ fn process_unchanged_diagnostics_report( } fn process_full_diagnostics_report( - diagnostics: &mut HashMap, + diagnostics: 
&mut HashMap, server_id: LanguageServerId, - uri: lsp::Url, + uri: lsp::Uri, report: lsp::FullDocumentDiagnosticReport, ) { let result_id = report.result_id; @@ -4509,9 +4487,8 @@ mod tests { data: Some(json!({"detail": "test detail"})), }; - let proto_diagnostic = - GetDocumentDiagnostics::serialize_lsp_diagnostic(lsp_diagnostic.clone()) - .expect("Failed to serialize diagnostic"); + let proto_diagnostic = GetDocumentDiagnostics::serialize_lsp_diagnostic(lsp_diagnostic) + .expect("Failed to serialize diagnostic"); let start = proto_diagnostic.start.unwrap(); let end = proto_diagnostic.end.unwrap(); @@ -4563,7 +4540,7 @@ mod tests { fn test_related_information() { let related_info = lsp::DiagnosticRelatedInformation { location: lsp::Location { - uri: lsp::Url::parse("file:///test.rs").unwrap(), + uri: lsp::Uri::from_str("file:///test.rs").unwrap(), range: lsp::Range { start: lsp::Position::new(1, 1), end: lsp::Position::new(1, 5), diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index 60d847023fbcf1d0839dfd53f688a7a11eb156bb..7430aa016f029ef61186a1a02b4fb91fa67a1b1a 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -1,25 +1,40 @@ +//! LSP store provides unified access to the language server protocol. +//! The consumers of LSP store can interact with language servers without knowing exactly which language server they're interacting with. +//! +//! # Local/Remote LSP Stores +//! This module is split up into three distinct parts: +//! - [`LocalLspStore`], which is ran on the host machine (either project host or SSH host), that manages the lifecycle of language servers. +//! - [`RemoteLspStore`], which is ran on the remote machine (project guests) which is mostly about passing through the requests via RPC. +//! The remote stores don't really care about which language server they're running against - they don't usually get to decide which language server is going to responsible for handling their request. 
+//! - [`LspStore`], which unifies the two under one consistent interface for interacting with language servers. +//! +//! Most of the interesting work happens at the local layer, as bulk of the complexity is with managing the lifecycle of language servers. The actual implementation of the LSP protocol is handled by [`lsp`] crate. pub mod clangd_ext; pub mod json_language_server_ext; +pub mod log_store; pub mod lsp_ext_command; pub mod rust_analyzer_ext; use crate::{ - CodeAction, ColorPresentation, Completion, CompletionResponse, CompletionSource, - CoreCompletion, DocumentColor, Hover, InlayHint, LocationLink, LspAction, LspPullDiagnostics, - ProjectItem, ProjectPath, ProjectTransaction, PulledDiagnostics, ResolveState, Symbol, - ToolchainStore, + CodeAction, ColorPresentation, Completion, CompletionDisplayOptions, CompletionResponse, + CompletionSource, CoreCompletion, DocumentColor, Hover, InlayHint, LocationLink, LspAction, + LspPullDiagnostics, ManifestProvidersStore, Project, ProjectItem, ProjectPath, + ProjectTransaction, PulledDiagnostics, ResolveState, Symbol, buffer_store::{BufferStore, BufferStoreEvent}, environment::ProjectEnvironment, lsp_command::{self, *}, - lsp_store, + lsp_store::{ + self, + log_store::{GlobalLogStore, LanguageServerKind}, + }, manifest_tree::{ - AdapterQuery, LanguageServerTree, LanguageServerTreeNode, LaunchDisposition, - ManifestQueryDelegate, ManifestTree, + LanguageServerTree, LanguageServerTreeNode, LaunchDisposition, ManifestQueryDelegate, + ManifestTree, }, prettier_store::{self, PrettierStore, PrettierStoreEvent}, project_settings::{LspSettings, ProjectSettings}, relativize_path, resolve_path, - toolchain_store::{EmptyToolchainStore, ToolchainStoreEvent}, + toolchain_store::{LocalToolchainStore, ToolchainStoreEvent}, worktree_store::{WorktreeStore, WorktreeStoreEvent}, yarn::YarnPathStore, }; @@ -44,9 +59,9 @@ use itertools::Itertools as _; use language::{ Bias, BinaryStatus, Buffer, BufferSnapshot, CachedLspAdapter, 
CodeLabel, Diagnostic, DiagnosticEntry, DiagnosticSet, DiagnosticSourceKind, Diff, File as _, Language, LanguageName, - LanguageRegistry, LanguageToolchainStore, LocalFile, LspAdapter, LspAdapterDelegate, Patch, - PointUtf16, TextBufferSnapshot, ToOffset, ToPointUtf16, Transaction, Unclipped, - WorkspaceFoldersContent, + LanguageRegistry, LocalFile, LspAdapter, LspAdapterDelegate, ManifestDelegate, ManifestName, + Patch, PointUtf16, TextBufferSnapshot, ToOffset, ToPoint, ToPointUtf16, Toolchain, Transaction, + Unclipped, language_settings::{ FormatOnSave, Formatter, LanguageSettings, SelectedFormatter, language_settings, }, @@ -61,19 +76,19 @@ use lsp::{ AdapterServerCapabilities, CodeActionKind, CompletionContext, DiagnosticSeverity, DiagnosticTag, DidChangeWatchedFilesRegistrationOptions, Edit, FileOperationFilter, FileOperationPatternKind, FileOperationRegistrationOptions, FileRename, FileSystemWatcher, - LanguageServer, LanguageServerBinary, LanguageServerBinaryOptions, LanguageServerId, - LanguageServerName, LanguageServerSelector, LspRequestFuture, MessageActionItem, MessageType, - OneOf, RenameFilesParams, SymbolKind, TextEdit, WillRenameFiles, WorkDoneProgressCancelParams, + LSP_REQUEST_TIMEOUT, LanguageServer, LanguageServerBinary, LanguageServerBinaryOptions, + LanguageServerId, LanguageServerName, LanguageServerSelector, LspRequestFuture, + MessageActionItem, MessageType, OneOf, RenameFilesParams, SymbolKind, + TextDocumentSyncSaveOptions, TextEdit, Uri, WillRenameFiles, WorkDoneProgressCancelParams, WorkspaceFolder, notification::DidRenameFiles, }; use node_runtime::read_package_installed_version; use parking_lot::Mutex; use postage::{mpsc, sink::Sink, stream::Stream, watch}; use rand::prelude::*; - use rpc::{ AnyProtoClient, - proto::{FromProto, ToProto}, + proto::{FromProto, LspRequestId, LspRequestMessage as _, ToProto}, }; use serde::Serialize; use settings::{Settings, SettingsLocation, SettingsStore}; @@ -81,7 +96,7 @@ use sha2::{Digest, Sha256}; 
use smol::channel::Sender; use snippet::Snippet; use std::{ - any::Any, + any::{Any, TypeId}, borrow::Cow, cell::RefCell, cmp::{Ordering, Reverse}, @@ -98,7 +113,7 @@ use std::{ }; use sum_tree::Dimensions; use text::{Anchor, BufferId, LineEnding, OffsetRangeExt}; -use url::Url; + use util::{ ConnectionResult, ResultExt as _, debug_panic, defer, maybe, merge_json_value_into, paths::{PathExt, SanitizedPath}, @@ -140,6 +155,20 @@ impl FormatTrigger { } } +#[derive(Clone)] +struct UnifiedLanguageServer { + id: LanguageServerId, + project_roots: HashSet>, +} + +#[derive(Clone, Hash, PartialEq, Eq)] +struct LanguageServerSeed { + worktree_id: WorktreeId, + name: LanguageServerName, + toolchain: Option, + settings: Arc, +} + #[derive(Debug)] pub struct DocumentDiagnosticsUpdate<'a, D> { pub diagnostics: D, @@ -157,17 +186,18 @@ pub struct DocumentDiagnostics { pub struct LocalLspStore { weak: WeakEntity, worktree_store: Entity, - toolchain_store: Entity, + toolchain_store: Entity, http_client: Arc, environment: Entity, fs: Arc, languages: Arc, - language_server_ids: HashMap<(WorktreeId, LanguageServerName), BTreeSet>, + language_server_ids: HashMap, yarn: Entity, pub language_servers: HashMap, buffers_being_formatted: HashSet, last_workspace_edits_by_language_server: HashMap, language_server_watched_paths: HashMap, + watched_manifest_filenames: HashSet, language_server_paths_watched_for_rename: HashMap, language_server_watcher_registrations: @@ -188,7 +218,7 @@ pub struct LocalLspStore { >, buffer_snapshots: HashMap>>, // buffer_id -> server_id -> vec of snapshots _subscription: gpui::Subscription, - lsp_tree: Entity, + lsp_tree: LanguageServerTree, registered_buffers: HashMap, buffers_opened_in_servers: HashMap>, buffer_pull_diagnostics_result_ids: HashMap>>, @@ -208,31 +238,82 @@ impl LocalLspStore { } } + fn get_or_insert_language_server( + &mut self, + worktree_handle: &Entity, + delegate: Arc, + disposition: &Arc, + language_name: &LanguageName, + cx: &mut App, + ) 
-> LanguageServerId { + let key = LanguageServerSeed { + worktree_id: worktree_handle.read(cx).id(), + name: disposition.server_name.clone(), + settings: disposition.settings.clone(), + toolchain: disposition.toolchain.clone(), + }; + if let Some(state) = self.language_server_ids.get_mut(&key) { + state.project_roots.insert(disposition.path.path.clone()); + state.id + } else { + let adapter = self + .languages + .lsp_adapters(language_name) + .into_iter() + .find(|adapter| adapter.name() == disposition.server_name) + .expect("To find LSP adapter"); + let new_language_server_id = self.start_language_server( + worktree_handle, + delegate, + adapter, + disposition.settings.clone(), + key.clone(), + cx, + ); + if let Some(state) = self.language_server_ids.get_mut(&key) { + state.project_roots.insert(disposition.path.path.clone()); + } else { + debug_assert!( + false, + "Expected `start_language_server` to ensure that `key` exists in a map" + ); + } + new_language_server_id + } + } + fn start_language_server( &mut self, worktree_handle: &Entity, delegate: Arc, adapter: Arc, settings: Arc, + key: LanguageServerSeed, cx: &mut App, ) -> LanguageServerId { let worktree = worktree_handle.read(cx); - let worktree_id = worktree.id(); - let root_path = worktree.abs_path(); - let key = (worktree_id, adapter.name.clone()); + let root_path = worktree.abs_path(); + let toolchain = key.toolchain.clone(); let override_options = settings.initialization_options.clone(); let stderr_capture = Arc::new(Mutex::new(Some(String::new()))); let server_id = self.languages.next_language_server_id(); - log::info!( + log::trace!( "attempting to start language server {:?}, path: {root_path:?}, id: {server_id}", adapter.name.0 ); - let binary = self.get_language_server_binary(adapter.clone(), delegate.clone(), true, cx); - let pending_workspace_folders: Arc>> = Default::default(); + let binary = self.get_language_server_binary( + adapter.clone(), + settings, + toolchain.clone(), + delegate.clone(), 
+ true, + cx, + ); + let pending_workspace_folders: Arc>> = Default::default(); let pending_server = cx.spawn({ let adapter = adapter.clone(); @@ -267,10 +348,7 @@ impl LocalLspStore { binary, &root_path, code_action_kinds, - Some(pending_workspace_folders).filter(|_| { - adapter.adapter.workspace_folders_content() - == WorkspaceFoldersContent::SubprojectRoots - }), + Some(pending_workspace_folders), cx, ) } @@ -290,15 +368,13 @@ impl LocalLspStore { .enabled; cx.spawn(async move |cx| { let result = async { - let toolchains = - lsp_store.update(cx, |lsp_store, cx| lsp_store.toolchain_store(cx))?; let language_server = pending_server.await?; let workspace_config = Self::workspace_configuration_for_adapter( adapter.adapter.clone(), fs.as_ref(), &delegate, - toolchains.clone(), + toolchain, cx, ) .await?; @@ -370,14 +446,14 @@ impl LocalLspStore { match result { Ok(server) => { lsp_store - .update(cx, |lsp_store, mut cx| { + .update(cx, |lsp_store, cx| { lsp_store.insert_newly_running_language_server( adapter, server.clone(), server_id, key, pending_workspace_folders, - &mut cx, + cx, ); }) .ok(); @@ -417,31 +493,26 @@ impl LocalLspStore { self.language_servers.insert(server_id, state); self.language_server_ids .entry(key) - .or_default() - .insert(server_id); + .or_insert(UnifiedLanguageServer { + id: server_id, + project_roots: Default::default(), + }); server_id } fn get_language_server_binary( &self, adapter: Arc, + settings: Arc, + toolchain: Option, delegate: Arc, allow_binary_download: bool, cx: &mut App, ) -> Task> { - let settings = ProjectSettings::get( - Some(SettingsLocation { - worktree_id: delegate.worktree_id(), - path: Path::new(""), - }), - cx, - ) - .lsp - .get(&adapter.name) - .and_then(|s| s.binary.clone()); - - if settings.as_ref().is_some_and(|b| b.path.is_some()) { - let settings = settings.unwrap(); + if let Some(settings) = settings.binary.as_ref() + && settings.path.is_some() + { + let settings = settings.clone(); return 
cx.background_spawn(async move { let mut env = delegate.shell_env().await; @@ -461,16 +532,17 @@ impl LocalLspStore { } let lsp_binary_options = LanguageServerBinaryOptions { allow_path_lookup: !settings + .binary .as_ref() .and_then(|b| b.ignore_system_version) .unwrap_or_default(), allow_binary_download, }; - let toolchains = self.toolchain_store.read(cx).as_language_toolchain_store(); + cx.spawn(async move |cx| { let binary_result = adapter .clone() - .get_language_server_command(delegate.clone(), toolchains, lsp_binary_options, cx) + .get_language_server_command(delegate.clone(), toolchain, lsp_binary_options, cx) .await; delegate.update_status(adapter.name.clone(), BinaryStatus::None); @@ -480,12 +552,12 @@ impl LocalLspStore { shell_env.extend(binary.env.unwrap_or_default()); - if let Some(settings) = settings { - if let Some(arguments) = settings.arguments { - binary.arguments = arguments.into_iter().map(Into::into).collect(); + if let Some(settings) = settings.binary.as_ref() { + if let Some(arguments) = &settings.arguments { + binary.arguments = arguments.iter().map(Into::into).collect(); } - if let Some(env) = settings.env { - shell_env.extend(env); + if let Some(env) = &settings.env { + shell_env.extend(env.iter().map(|(k, v)| (k.clone(), v.clone()))); } } @@ -559,14 +631,20 @@ impl LocalLspStore { let fs = fs.clone(); let mut cx = cx.clone(); async move { - let toolchains = - this.update(&mut cx, |this, cx| this.toolchain_store(cx))?; - + let toolchain_for_id = this + .update(&mut cx, |this, _| { + this.as_local()?.language_server_ids.iter().find_map( + |(seed, value)| { + (value.id == server_id).then(|| seed.toolchain.clone()) + }, + ) + })? 
+ .context("Expected the LSP store to be in a local mode")?; let workspace_config = Self::workspace_configuration_for_adapter( adapter.clone(), fs.as_ref(), &delegate, - toolchains.clone(), + toolchain_for_id, &mut cx, ) .await?; @@ -595,10 +673,10 @@ impl LocalLspStore { let this = this.clone(); move |_, cx| { let this = this.clone(); - let mut cx = cx.clone(); + let cx = cx.clone(); async move { - let Some(server) = this - .read_with(&mut cx, |this, _| this.language_server_for_id(server_id))? + let Some(server) = + this.read_with(&cx, |this, _| this.language_server_for_id(server_id))? else { return Ok(None); }; @@ -627,10 +705,9 @@ impl LocalLspStore { async move { this.update(&mut cx, |this, _| { if let Some(status) = this.language_server_statuses.get_mut(&server_id) + && let lsp::NumberOrString::String(token) = params.token { - if let lsp::NumberOrString::String(token) = params.token { - status.progress_tokens.insert(token); - } + status.progress_tokens.insert(token); } })?; @@ -700,18 +777,15 @@ impl LocalLspStore { language_server .on_request::({ - let adapter = adapter.clone(); let this = this.clone(); move |params, cx| { let mut cx = cx.clone(); let this = this.clone(); - let adapter = adapter.clone(); async move { LocalLspStore::on_lsp_workspace_edit( this.clone(), params, server_id, - adapter.clone(), &mut cx, ) .await @@ -847,7 +921,7 @@ impl LocalLspStore { message: params.message, actions: vec![], response_channel: tx, - lsp_name: name.clone(), + lsp_name: name, }; let _ = this.update(&mut cx, |_, cx| { @@ -906,7 +980,9 @@ impl LocalLspStore { this.update(&mut cx, |_, cx| { cx.emit(LspStoreEvent::LanguageServerLog( server_id, - LanguageServerLogType::Trace(params.verbose), + LanguageServerLogType::Trace { + verbose_info: params.verbose, + }, params.message, )); }) @@ -944,10 +1020,10 @@ impl LocalLspStore { } } LanguageServerState::Starting { startup, .. 
} => { - if let Some(server) = startup.await { - if let Some(shutdown) = server.shutdown() { - shutdown.await; - } + if let Some(server) = startup.await + && let Some(shutdown) = server.shutdown() + { + shutdown.await; } } } @@ -960,19 +1036,18 @@ impl LocalLspStore { ) -> impl Iterator> { self.language_server_ids .iter() - .flat_map(move |((language_server_path, _), ids)| { - ids.iter().filter_map(move |id| { - if *language_server_path != worktree_id { - return None; - } - if let Some(LanguageServerState::Running { server, .. }) = - self.language_servers.get(id) - { - return Some(server); - } else { - None - } - }) + .filter_map(move |(seed, state)| { + if seed.worktree_id != worktree_id { + return None; + } + + if let Some(LanguageServerState::Running { server, .. }) = + self.language_servers.get(&state.id) + { + Some(server) + } else { + None + } }) } @@ -989,19 +1064,18 @@ impl LocalLspStore { else { return Vec::new(); }; - let delegate = Arc::new(ManifestQueryDelegate::new(worktree.read(cx).snapshot())); - let root = self.lsp_tree.update(cx, |this, cx| { - this.get( + let delegate: Arc = + Arc::new(ManifestQueryDelegate::new(worktree.read(cx).snapshot())); + + self.lsp_tree + .get( project_path, - AdapterQuery::Language(&language.name()), - delegate, + language.name(), + language.manifest(), + &delegate, cx, ) - .filter_map(|node| node.server_id()) .collect::>() - }); - - root } fn language_server_ids_for_buffer( @@ -1083,7 +1157,7 @@ impl LocalLspStore { .collect::>() }) })?; - for (lsp_adapter, language_server) in adapters_and_servers.iter() { + for (_, language_server) in adapters_and_servers.iter() { let actions = Self::get_server_code_actions_from_action_kinds( &lsp_store, language_server.server_id(), @@ -1095,7 +1169,6 @@ impl LocalLspStore { Self::execute_code_actions_on_server( &lsp_store, language_server, - lsp_adapter, actions, push_to_history, &mut project_transaction, @@ -1810,7 +1883,7 @@ impl LocalLspStore { ) -> Result, Arc)>> { let capabilities 
= &language_server.capabilities(); let range_formatting_provider = capabilities.document_range_formatting_provider.as_ref(); - if range_formatting_provider.map_or(false, |provider| provider == &OneOf::Left(false)) { + if range_formatting_provider == Some(&OneOf::Left(false)) { anyhow::bail!( "{} language server does not support range formatting", language_server.name() @@ -1857,7 +1930,7 @@ impl LocalLspStore { if let Some(lsp_edits) = lsp_edits { this.update(cx, |this, cx| { this.as_local_mut().unwrap().edits_from_lsp( - &buffer_handle, + buffer_handle, lsp_edits, language_server.server_id(), None, @@ -2038,13 +2111,14 @@ impl LocalLspStore { let buffer = buffer_handle.read(cx); let file = buffer.file().cloned(); + let Some(file) = File::from_dyn(file.as_ref()) else { return; }; if !file.is_local() { return; } - + let path = ProjectPath::from_file(file, cx); let worktree_id = file.worktree_id(cx); let language = buffer.language().cloned(); @@ -2067,46 +2141,52 @@ impl LocalLspStore { let Some(language) = language else { return; }; - for adapter in self.languages.lsp_adapters(&language.name()) { - let servers = self - .language_server_ids - .get(&(worktree_id, adapter.name.clone())); - if let Some(server_ids) = servers { - for server_id in server_ids { - let server = self - .language_servers - .get(server_id) - .and_then(|server_state| { - if let LanguageServerState::Running { server, .. 
} = server_state { - Some(server.clone()) - } else { - None - } - }); - let server = match server { - Some(server) => server, - None => continue, - }; + let Some(snapshot) = self + .worktree_store + .read(cx) + .worktree_for_id(worktree_id, cx) + .map(|worktree| worktree.read(cx).snapshot()) + else { + return; + }; + let delegate: Arc = Arc::new(ManifestQueryDelegate::new(snapshot)); - buffer_handle.update(cx, |buffer, cx| { - buffer.set_completion_triggers( - server.server_id(), - server - .capabilities() - .completion_provider + for server_id in + self.lsp_tree + .get(path, language.name(), language.manifest(), &delegate, cx) + { + let server = self + .language_servers + .get(&server_id) + .and_then(|server_state| { + if let LanguageServerState::Running { server, .. } = server_state { + Some(server.clone()) + } else { + None + } + }); + let server = match server { + Some(server) => server, + None => continue, + }; + + buffer_handle.update(cx, |buffer, cx| { + buffer.set_completion_triggers( + server.server_id(), + server + .capabilities() + .completion_provider + .as_ref() + .and_then(|provider| { + provider + .trigger_characters .as_ref() - .and_then(|provider| { - provider - .trigger_characters - .as_ref() - .map(|characters| characters.iter().cloned().collect()) - }) - .unwrap_or_default(), - cx, - ); - }); - } - } + .map(|characters| characters.iter().cloned().collect()) + }) + .unwrap_or_default(), + cx, + ); + }); } } @@ -2216,6 +2296,31 @@ impl LocalLspStore { Ok(()) } + fn register_language_server_for_invisible_worktree( + &mut self, + worktree: &Entity, + language_server_id: LanguageServerId, + cx: &mut App, + ) { + let worktree = worktree.read(cx); + let worktree_id = worktree.id(); + debug_assert!(!worktree.is_visible()); + let Some(mut origin_seed) = self + .language_server_ids + .iter() + .find_map(|(seed, state)| (state.id == language_server_id).then(|| seed.clone())) + else { + return; + }; + origin_seed.worktree_id = worktree_id; + 
self.language_server_ids + .entry(origin_seed) + .or_insert_with(|| UnifiedLanguageServer { + id: language_server_id, + project_roots: Default::default(), + }); + } + fn register_buffer_with_language_servers( &mut self, buffer_handle: &Entity, @@ -2256,27 +2361,23 @@ impl LocalLspStore { }; let language_name = language.name(); let (reused, delegate, servers) = self - .lsp_tree - .update(cx, |lsp_tree, cx| { - self.reuse_existing_language_server(lsp_tree, &worktree, &language_name, cx) - }) - .map(|(delegate, servers)| (true, delegate, servers)) + .reuse_existing_language_server(&self.lsp_tree, &worktree, &language_name, cx) + .map(|(delegate, apply)| (true, delegate, apply(&mut self.lsp_tree))) .unwrap_or_else(|| { let lsp_delegate = LocalLspAdapterDelegate::from_local_lsp(self, &worktree, cx); - let delegate = Arc::new(ManifestQueryDelegate::new(worktree.read(cx).snapshot())); + let delegate: Arc = + Arc::new(ManifestQueryDelegate::new(worktree.read(cx).snapshot())); + let servers = self .lsp_tree - .clone() - .update(cx, |language_server_tree, cx| { - language_server_tree - .get( - ProjectPath { worktree_id, path }, - AdapterQuery::Language(&language.name()), - delegate.clone(), - cx, - ) - .collect::>() - }); + .walk( + ProjectPath { worktree_id, path }, + language.name(), + language.manifest(), + &delegate, + cx, + ) + .collect::>(); (false, lsp_delegate, servers) }); let servers_and_adapters = servers @@ -2286,67 +2387,46 @@ impl LocalLspStore { return None; } if !only_register_servers.is_empty() { - if let Some(server_id) = server_node.server_id() { - if !only_register_servers.contains(&LanguageServerSelector::Id(server_id)) { - return None; - } + if let Some(server_id) = server_node.server_id() + && !only_register_servers.contains(&LanguageServerSelector::Id(server_id)) + { + return None; } - if let Some(name) = server_node.name() { - if !only_register_servers.contains(&LanguageServerSelector::Name(name)) { - return None; - } + if let Some(name) = 
server_node.name() + && !only_register_servers.contains(&LanguageServerSelector::Name(name)) + { + return None; } } - let server_id = server_node.server_id_or_init( - |LaunchDisposition { - server_name, - path, - settings, - }| { - let server_id = - { - let uri = Url::from_file_path( - worktree.read(cx).abs_path().join(&path.path), - ); - let key = (worktree_id, server_name.clone()); - if !self.language_server_ids.contains_key(&key) { - let language_name = language.name(); - let adapter = self.languages - .lsp_adapters(&language_name) - .into_iter() - .find(|adapter| &adapter.name() == server_name) - .expect("To find LSP adapter"); - self.start_language_server( - &worktree, - delegate.clone(), - adapter, - settings, - cx, - ); - } - if let Some(server_ids) = self - .language_server_ids - .get(&key) - { - debug_assert_eq!(server_ids.len(), 1); - let server_id = server_ids.iter().cloned().next().unwrap(); - if let Some(state) = self.language_servers.get(&server_id) { - if let Ok(uri) = uri { - state.add_workspace_folder(uri); - }; - } - server_id - } else { - unreachable!("Language server ID should be available, as it's registered on demand") - } + let server_id = server_node.server_id_or_init(|disposition| { + let path = &disposition.path; + + { + let uri = + Uri::from_file_path(worktree.read(cx).abs_path().join(&path.path)); + + let server_id = self.get_or_insert_language_server( + &worktree, + delegate.clone(), + disposition, + &language_name, + cx, + ); + if let Some(state) = self.language_servers.get(&server_id) + && let Ok(uri) = uri + { + state.add_workspace_folder(uri); }; server_id - }, - )?; + } + })?; let server_state = self.language_servers.get(&server_id)?; - if let LanguageServerState::Running { server, adapter, .. } = server_state { + if let LanguageServerState::Running { + server, adapter, .. 
+ } = server_state + { Some((server.clone(), adapter.clone())) } else { None @@ -2413,13 +2493,16 @@ impl LocalLspStore { } } - fn reuse_existing_language_server( + fn reuse_existing_language_server<'lang_name>( &self, - server_tree: &mut LanguageServerTree, + server_tree: &LanguageServerTree, worktree: &Entity, - language_name: &LanguageName, + language_name: &'lang_name LanguageName, cx: &mut App, - ) -> Option<(Arc, Vec)> { + ) -> Option<( + Arc, + impl FnOnce(&mut LanguageServerTree) -> Vec + use<'lang_name>, + )> { if worktree.read(cx).is_visible() { return None; } @@ -2458,16 +2541,16 @@ impl LocalLspStore { .into_values() .max_by_key(|servers| servers.len())?; - for server_node in &servers { - server_tree.register_reused( - worktree.read(cx).id(), - language_name.clone(), - server_node.clone(), - ); - } + let worktree_id = worktree.read(cx).id(); + let apply = move |tree: &mut LanguageServerTree| { + for server_node in &servers { + tree.register_reused(worktree_id, language_name.clone(), server_node.clone()); + } + servers + }; let delegate = LocalLspAdapterDelegate::from_local_lsp(self, worktree, cx); - Some((delegate, servers)) + Some((delegate, apply)) } pub(crate) fn unregister_old_buffer_from_language_servers( @@ -2481,11 +2564,8 @@ impl LocalLspStore { None => return, }; - let Ok(file_url) = lsp::Url::from_file_path(old_path.as_path()) else { - debug_panic!( - "`{}` is not parseable as an URI", - old_path.to_string_lossy() - ); + let Ok(file_url) = lsp::Uri::from_file_path(old_path.as_path()) else { + debug_panic!("{old_path:?} is not parseable as an URI"); return; }; self.unregister_buffer_from_language_servers(buffer, &file_url, cx); @@ -2494,7 +2574,7 @@ impl LocalLspStore { pub(crate) fn unregister_buffer_from_language_servers( &mut self, buffer: &Entity, - file_url: &lsp::Url, + file_url: &lsp::Uri, cx: &mut App, ) { buffer.update(cx, |buffer, cx| { @@ -2562,13 +2642,13 @@ impl LocalLspStore { this.request_lsp(buffer.clone(), server, request, cx) 
})? .await?; - return Ok(actions); + Ok(actions) } pub async fn execute_code_actions_on_server( lsp_store: &WeakEntity, language_server: &Arc, - lsp_adapter: &Arc, + actions: Vec, push_to_history: bool, project_transaction: &mut ProjectTransaction, @@ -2588,7 +2668,6 @@ impl LocalLspStore { lsp_store.upgrade().context("project dropped")?, edit.clone(), push_to_history, - lsp_adapter.clone(), language_server.clone(), cx, ) @@ -2639,7 +2718,7 @@ impl LocalLspStore { } } } - return Ok(()); + Ok(()) } pub async fn deserialize_text_edits( @@ -2769,7 +2848,6 @@ impl LocalLspStore { this: Entity, edit: lsp::WorkspaceEdit, push_to_history: bool, - lsp_adapter: Arc, language_server: Arc, cx: &mut AsyncApp, ) -> Result { @@ -2870,7 +2948,6 @@ impl LocalLspStore { this.open_local_buffer_via_lsp( op.text_document.uri.clone(), language_server.server_id(), - lsp_adapter.name.clone(), cx, ) })? @@ -2880,11 +2957,11 @@ impl LocalLspStore { .update(cx, |this, cx| { let path = buffer_to_edit.read(cx).project_path(cx); let active_entry = this.active_entry; - let is_active_entry = path.clone().map_or(false, |project_path| { + let is_active_entry = path.is_some_and(|project_path| { this.worktree_store .read(cx) .entry_for_path(&project_path, cx) - .map_or(false, |entry| Some(entry.id) == active_entry) + .is_some_and(|entry| Some(entry.id) == active_entry) }); let local = this.as_local_mut().unwrap(); @@ -2970,16 +3047,14 @@ impl LocalLspStore { buffer.edit([(range, text)], None, cx); } - let transaction = buffer.end_transaction(cx).and_then(|transaction_id| { + buffer.end_transaction(cx).and_then(|transaction_id| { if push_to_history { buffer.finalize_last_transaction(); buffer.get_transaction(transaction_id).cloned() } else { buffer.forget_transaction(transaction_id) } - }); - - transaction + }) })?; if let Some(transaction) = transaction { project_transaction.0.insert(buffer_to_edit, transaction); @@ -2995,7 +3070,6 @@ impl LocalLspStore { this: WeakEntity, params: 
lsp::ApplyWorkspaceEditParams, server_id: LanguageServerId, - adapter: Arc, cx: &mut AsyncApp, ) -> Result { let this = this.upgrade().context("project project closed")?; @@ -3006,7 +3080,6 @@ impl LocalLspStore { this.clone(), params.edit, true, - adapter.clone(), language_server.clone(), cx, ) @@ -3037,23 +3110,19 @@ impl LocalLspStore { prettier_store.remove_worktree(id_to_remove, cx); }); - let mut servers_to_remove = BTreeMap::default(); + let mut servers_to_remove = BTreeSet::default(); let mut servers_to_preserve = HashSet::default(); - for ((path, server_name), ref server_ids) in &self.language_server_ids { - if *path == id_to_remove { - servers_to_remove.extend(server_ids.iter().map(|id| (*id, server_name.clone()))); + for (seed, state) in &self.language_server_ids { + if seed.worktree_id == id_to_remove { + servers_to_remove.insert(state.id); } else { - servers_to_preserve.extend(server_ids.iter().cloned()); + servers_to_preserve.insert(state.id); } } - servers_to_remove.retain(|server_id, _| !servers_to_preserve.contains(server_id)); - - for (server_id_to_remove, _) in &servers_to_remove { - self.language_server_ids - .values_mut() - .for_each(|server_ids| { - server_ids.remove(server_id_to_remove); - }); + servers_to_remove.retain(|server_id| !servers_to_preserve.contains(server_id)); + self.language_server_ids + .retain(|_, state| !servers_to_remove.contains(&state.id)); + for server_id_to_remove in &servers_to_remove { self.language_server_watched_paths .remove(server_id_to_remove); self.language_server_paths_watched_for_rename @@ -3068,7 +3137,7 @@ impl LocalLspStore { } cx.emit(LspStoreEvent::LanguageServerRemoved(*server_id_to_remove)); } - servers_to_remove.into_keys().collect() + servers_to_remove.into_iter().collect() } fn rebuild_watched_paths_inner<'a>( @@ -3097,7 +3166,7 @@ impl LocalLspStore { for watcher in watchers { if let Some((worktree, literal_prefix, pattern)) = - self.worktree_and_path_for_file_watcher(&worktrees, &watcher, cx) + 
self.worktree_and_path_for_file_watcher(&worktrees, watcher, cx) { worktree.update(cx, |worktree, _| { if let Some((tree, glob)) = @@ -3113,7 +3182,7 @@ impl LocalLspStore { } else { let (path, pattern) = match &watcher.glob_pattern { lsp::GlobPattern::String(s) => { - let watcher_path = SanitizedPath::from(s); + let watcher_path = SanitizedPath::new(s); let path = glob_literal_prefix(watcher_path.as_path()); let pattern = watcher_path .as_path() @@ -3205,7 +3274,7 @@ impl LocalLspStore { let worktree_root_path = tree.abs_path(); match &watcher.glob_pattern { lsp::GlobPattern::String(s) => { - let watcher_path = SanitizedPath::from(s); + let watcher_path = SanitizedPath::new(s); let relative = watcher_path .as_path() .strip_prefix(&worktree_root_path) @@ -3326,16 +3395,20 @@ impl LocalLspStore { Ok(Some(initialization_config)) } + fn toolchain_store(&self) -> &Entity { + &self.toolchain_store + } + async fn workspace_configuration_for_adapter( adapter: Arc, fs: &dyn Fs, delegate: &Arc, - toolchains: Arc, + toolchain: Option, cx: &mut AsyncApp, ) -> Result { let mut workspace_config = adapter .clone() - .workspace_configuration(fs, delegate, toolchains.clone(), cx) + .workspace_configuration(fs, delegate, toolchain, cx) .await?; for other_adapter in delegate.registered_lsp_adapters() { @@ -3344,13 +3417,7 @@ impl LocalLspStore { } if let Ok(Some(target_config)) = other_adapter .clone() - .additional_workspace_configuration( - adapter.name(), - fs, - delegate, - toolchains.clone(), - cx, - ) + .additional_workspace_configuration(adapter.name(), fs, delegate, cx) .await { merge_json_value_into(target_config.clone(), &mut workspace_config); @@ -3416,17 +3483,17 @@ pub struct LspStore { nonce: u128, buffer_store: Entity, worktree_store: Entity, - toolchain_store: Option>, pub languages: Arc, - language_server_statuses: BTreeMap, + pub language_server_statuses: BTreeMap, active_entry: Option, _maintain_workspace_config: (Task>, watch::Sender<()>), 
_maintain_buffer_languages: Task<()>, diagnostic_summaries: HashMap, HashMap>>, - pub(super) lsp_server_capabilities: HashMap, + pub lsp_server_capabilities: HashMap, lsp_document_colors: HashMap, lsp_code_lens: HashMap, + running_lsp_requests: HashMap>)>, } #[derive(Debug, Default, Clone)] @@ -3436,7 +3503,7 @@ pub struct DocumentColors { } type DocumentColorTask = Shared>>>; -type CodeLensTask = Shared, Arc>>>; +type CodeLensTask = Shared>, Arc>>>; #[derive(Debug, Default)] struct DocumentColorData { @@ -3500,6 +3567,7 @@ pub struct LanguageServerStatus { pub pending_work: BTreeMap, pub has_pending_diagnostic_updates: bool, progress_tokens: HashSet, + pub worktree: Option, } #[derive(Clone, Debug)] @@ -3516,6 +3584,8 @@ struct CoreSymbol { impl LspStore { pub fn init(client: &AnyProtoClient) { + client.add_entity_request_handler(Self::handle_lsp_query); + client.add_entity_message_handler(Self::handle_lsp_query_response); client.add_entity_request_handler(Self::handle_multi_lsp_query); client.add_entity_request_handler(Self::handle_restart_language_servers); client.add_entity_request_handler(Self::handle_stop_language_servers); @@ -3607,7 +3677,7 @@ impl LspStore { buffer_store: Entity, worktree_store: Entity, prettier_store: Entity, - toolchain_store: Entity, + toolchain_store: Entity, environment: Entity, manifest_tree: Entity, languages: Arc, @@ -3649,7 +3719,7 @@ impl LspStore { mode: LspStoreMode::Local(LocalLspStore { weak: cx.weak_entity(), worktree_store: worktree_store.clone(), - toolchain_store: toolchain_store.clone(), + supplementary_language_servers: Default::default(), languages: languages.clone(), language_server_ids: Default::default(), @@ -3672,23 +3742,30 @@ impl LspStore { .unwrap() .shutdown_language_servers_on_quit(cx) }), - lsp_tree: LanguageServerTree::new(manifest_tree, languages.clone(), cx), + lsp_tree: LanguageServerTree::new( + manifest_tree, + languages.clone(), + toolchain_store.clone(), + ), + toolchain_store, registered_buffers: 
HashMap::default(), buffers_opened_in_servers: HashMap::default(), buffer_pull_diagnostics_result_ids: HashMap::default(), + watched_manifest_filenames: ManifestProvidersStore::global(cx) + .manifest_file_names(), }), last_formatting_failure: None, downstream_client: None, buffer_store, worktree_store, - toolchain_store: Some(toolchain_store), languages: languages.clone(), language_server_statuses: Default::default(), - nonce: StdRng::from_entropy().r#gen(), + nonce: StdRng::from_os_rng().random(), diagnostic_summaries: HashMap::default(), lsp_server_capabilities: HashMap::default(), lsp_document_colors: HashMap::default(), lsp_code_lens: HashMap::default(), + running_lsp_requests: HashMap::default(), active_entry: None, _maintain_workspace_config, _maintain_buffer_languages: Self::maintain_buffer_languages(languages, cx), @@ -3719,7 +3796,6 @@ impl LspStore { pub(super) fn new_remote( buffer_store: Entity, worktree_store: Entity, - toolchain_store: Option>, languages: Arc, upstream_client: AnyProtoClient, project_id: u64, @@ -3746,13 +3822,14 @@ impl LspStore { worktree_store, languages: languages.clone(), language_server_statuses: Default::default(), - nonce: StdRng::from_entropy().r#gen(), + nonce: StdRng::from_os_rng().random(), diagnostic_summaries: HashMap::default(), lsp_server_capabilities: HashMap::default(), lsp_document_colors: HashMap::default(), lsp_code_lens: HashMap::default(), + running_lsp_requests: HashMap::default(), active_entry: None, - toolchain_store, + _maintain_workspace_config, _maintain_buffer_languages: Self::maintain_buffer_languages(languages.clone(), cx), } @@ -3770,13 +3847,13 @@ impl LspStore { } BufferStoreEvent::BufferChangedFilePath { buffer, old_file } => { let buffer_id = buffer.read(cx).remote_id(); - if let Some(local) = self.as_local_mut() { - if let Some(old_file) = File::from_dyn(old_file.as_ref()) { - local.reset_buffer(buffer, old_file, cx); + if let Some(local) = self.as_local_mut() + && let Some(old_file) = 
File::from_dyn(old_file.as_ref()) + { + local.reset_buffer(buffer, old_file, cx); - if local.registered_buffers.contains_key(&buffer_id) { - local.unregister_old_buffer_from_language_servers(buffer, old_file, cx); - } + if local.registered_buffers.contains_key(&buffer_id) { + local.unregister_old_buffer_from_language_servers(buffer, old_file, cx); } } @@ -3851,14 +3928,12 @@ impl LspStore { fn on_toolchain_store_event( &mut self, - _: Entity, + _: Entity, event: &ToolchainStoreEvent, _: &mut Context, ) { - match event { - ToolchainStoreEvent::ToolchainActivated { .. } => { - self.request_workspace_config_refresh() - } + if let ToolchainStoreEvent::ToolchainActivated = event { + self.request_workspace_config_refresh() } } @@ -3930,9 +4005,9 @@ impl LspStore { let local = this.as_local()?; let mut servers = Vec::new(); - for ((worktree_id, _), server_ids) in &local.language_server_ids { - for server_id in server_ids { - let Some(states) = local.language_servers.get(server_id) else { + for (seed, state) in &local.language_server_ids { + + let Some(states) = local.language_servers.get(&state.id) else { continue; }; let (json_adapter, json_server) = match states { @@ -3947,7 +4022,7 @@ impl LspStore { let Some(worktree) = this .worktree_store .read(cx) - .worktree_for_id(*worktree_id, cx) + .worktree_for_id(seed.worktree_id, cx) else { continue; }; @@ -3963,9 +4038,9 @@ impl LspStore { ); servers.push((json_adapter, json_server, json_delegate)); - } + } - return Some(servers); + Some(servers) }) .ok() .flatten(); @@ -3974,10 +4049,10 @@ impl LspStore { return; }; - let Ok(Some((fs, toolchain_store))) = this.read_with(cx, |this, cx| { + let Ok(Some((fs, _))) = this.read_with(cx, |this, _| { let local = this.as_local()?; - let toolchain_store = this.toolchain_store(cx); - return Some((local.fs.clone(), toolchain_store)); + let toolchain_store = local.toolchain_store().clone(); + Some((local.fs.clone(), toolchain_store)) }) else { return; }; @@ -3988,7 +4063,7 @@ impl 
LspStore { adapter, fs.as_ref(), &delegate, - toolchain_store.clone(), + None, cx, ) .await @@ -4057,7 +4132,7 @@ impl LspStore { local.registered_buffers.remove(&buffer_id); local.buffers_opened_in_servers.remove(&buffer_id); if let Some(file) = File::from_dyn(buffer.read(cx).file()).cloned() { - local.unregister_old_buffer_from_language_servers(&buffer, &file, cx); + local.unregister_old_buffer_from_language_servers(buffer, &file, cx); } } }) @@ -4127,14 +4202,12 @@ impl LspStore { if local .registered_buffers .contains_key(&buffer.read(cx).remote_id()) - { - if let Some(file_url) = + && let Some(file_url) = file_path_to_lsp_url(&f.abs_path(cx)).log_err() - { - local.unregister_buffer_from_language_servers( - &buffer, &file_url, cx, - ); - } + { + local.unregister_buffer_from_language_servers( + &buffer, &file_url, cx, + ); } } } @@ -4232,25 +4305,19 @@ impl LspStore { let buffer = buffer_entity.read(cx); let buffer_file = buffer.file().cloned(); let buffer_id = buffer.remote_id(); - if let Some(local_store) = self.as_local_mut() { - if local_store.registered_buffers.contains_key(&buffer_id) { - if let Some(abs_path) = - File::from_dyn(buffer_file.as_ref()).map(|file| file.abs_path(cx)) - { - if let Some(file_url) = file_path_to_lsp_url(&abs_path).log_err() { - local_store.unregister_buffer_from_language_servers( - buffer_entity, - &file_url, - cx, - ); - } - } - } + if let Some(local_store) = self.as_local_mut() + && local_store.registered_buffers.contains_key(&buffer_id) + && let Some(abs_path) = + File::from_dyn(buffer_file.as_ref()).map(|file| file.abs_path(cx)) + && let Some(file_url) = file_path_to_lsp_url(&abs_path).log_err() + { + local_store.unregister_buffer_from_language_servers(buffer_entity, &file_url, cx); } buffer_entity.update(cx, |buffer, cx| { - if buffer.language().map_or(true, |old_language| { - !Arc::ptr_eq(old_language, &new_language) - }) { + if buffer + .language() + .is_none_or(|old_language| !Arc::ptr_eq(old_language, &new_language)) + { 
buffer.set_language(Some(new_language.clone()), cx); } }); @@ -4262,33 +4329,28 @@ impl LspStore { let worktree_id = if let Some(file) = buffer_file { let worktree = file.worktree.clone(); - if let Some(local) = self.as_local_mut() { - if local.registered_buffers.contains_key(&buffer_id) { - local.register_buffer_with_language_servers( - buffer_entity, - HashSet::default(), - cx, - ); - } + if let Some(local) = self.as_local_mut() + && local.registered_buffers.contains_key(&buffer_id) + { + local.register_buffer_with_language_servers(buffer_entity, HashSet::default(), cx); } Some(worktree.read(cx).id()) } else { None }; - if settings.prettier.allowed { - if let Some(prettier_plugins) = prettier_store::prettier_plugins_for_language(&settings) - { - let prettier_store = self.as_local().map(|s| s.prettier_store.clone()); - if let Some(prettier_store) = prettier_store { - prettier_store.update(cx, |prettier_store, cx| { - prettier_store.install_default_prettier( - worktree_id, - prettier_plugins.iter().map(|s| Arc::from(s.as_str())), - cx, - ) - }) - } + if settings.prettier.allowed + && let Some(prettier_plugins) = prettier_store::prettier_plugins_for_language(&settings) + { + let prettier_store = self.as_local().map(|s| s.prettier_store.clone()); + if let Some(prettier_store) = prettier_store { + prettier_store.update(cx, |prettier_store, cx| { + prettier_store.install_default_prettier( + worktree_id, + prettier_plugins.iter().map(|s| Arc::from(s.as_str())), + cx, + ) + }) } } @@ -4307,32 +4369,27 @@ impl LspStore { } pub(crate) fn send_diagnostic_summaries(&self, worktree: &mut Worktree) { - if let Some((client, downstream_project_id)) = self.downstream_client.clone() { - if let Some(diangostic_summaries) = self.diagnostic_summaries.get(&worktree.id()) { - let mut summaries = - diangostic_summaries - .into_iter() - .flat_map(|(path, summaries)| { - summaries - .into_iter() - .map(|(server_id, summary)| summary.to_proto(*server_id, path)) - }); - if let Some(summary) 
= summaries.next() { - client - .send(proto::UpdateDiagnosticSummary { - project_id: downstream_project_id, - worktree_id: worktree.id().to_proto(), - summary: Some(summary), - more_summaries: summaries.collect(), - }) - .log_err(); - } + if let Some((client, downstream_project_id)) = self.downstream_client.clone() + && let Some(diangostic_summaries) = self.diagnostic_summaries.get(&worktree.id()) + { + let mut summaries = diangostic_summaries.iter().flat_map(|(path, summaries)| { + summaries + .iter() + .map(|(server_id, summary)| summary.to_proto(*server_id, path)) + }); + if let Some(summary) = summaries.next() { + client + .send(proto::UpdateDiagnosticSummary { + project_id: downstream_project_id, + worktree_id: worktree.id().to_proto(), + summary: Some(summary), + more_summaries: summaries.collect(), + }) + .log_err(); } } } - // TODO: remove MultiLspQuery: instead, the proto handler should pick appropriate server(s) - // Then, use `send_lsp_proto_request` or analogue for most of the LSP proto requests and inline this check inside fn is_capable_for_proto_request( &self, buffer: &Entity, @@ -4379,7 +4436,7 @@ impl LspStore { .contains(&server_status.name) .then_some(server_id) }) - .filter_map(|server_id| self.lsp_server_capabilities.get(&server_id)) + .filter_map(|server_id| self.lsp_server_capabilities.get(server_id)) .any(check) } @@ -4456,7 +4513,7 @@ impl LspStore { if !request.check_capabilities(language_server.adapter_server_capabilities()) { return Task::ready(Ok(Default::default())); } - return cx.spawn(async move |this, cx| { + cx.spawn(async move |this, cx| { let lsp_request = language_server.request::(lsp_params); let id = lsp_request.id(); @@ -4505,7 +4562,7 @@ impl LspStore { anyhow::anyhow!(message) })?; - let response = request + request .response_from_lsp( response, this.upgrade().context("no app context")?, @@ -4513,9 +4570,8 @@ impl LspStore { language_server.server_id(), cx.clone(), ) - .await; - response - }); + .await + }) } fn 
on_settings_changed(&mut self, cx: &mut Context) { @@ -4533,7 +4589,7 @@ impl LspStore { } } - self.refresh_server_tree(cx); + self.request_workspace_config_refresh(); if let Some(prettier_store) = self.as_local().map(|s| s.prettier_store.clone()) { prettier_store.update(cx, |prettier_store, cx| { @@ -4546,158 +4602,148 @@ impl LspStore { fn refresh_server_tree(&mut self, cx: &mut Context) { let buffer_store = self.buffer_store.clone(); - if let Some(local) = self.as_local_mut() { - let mut adapters = BTreeMap::default(); - let get_adapter = { - let languages = local.languages.clone(); - let environment = local.environment.clone(); - let weak = local.weak.clone(); - let worktree_store = local.worktree_store.clone(); - let http_client = local.http_client.clone(); - let fs = local.fs.clone(); - move |worktree_id, cx: &mut App| { - let worktree = worktree_store.read(cx).worktree_for_id(worktree_id, cx)?; - Some(LocalLspAdapterDelegate::new( - languages.clone(), - &environment, - weak.clone(), - &worktree, - http_client.clone(), - fs.clone(), - cx, - )) - } - }; - - let mut messages_to_report = Vec::new(); - let to_stop = local.lsp_tree.clone().update(cx, |lsp_tree, cx| { - let mut rebase = lsp_tree.rebase(); - for buffer_handle in buffer_store.read(cx).buffers().sorted_by_key(|buffer| { - Reverse( - File::from_dyn(buffer.read(cx).file()) - .map(|file| file.worktree.read(cx).is_visible()), - ) - }) { - let buffer = buffer_handle.read(cx); - let buffer_id = buffer.remote_id(); - if !local.registered_buffers.contains_key(&buffer_id) { - continue; - } - if let Some((file, language)) = File::from_dyn(buffer.file()) - .cloned() - .zip(buffer.language().map(|l| l.name())) + let Some(local) = self.as_local_mut() else { + return; + }; + let mut adapters = BTreeMap::default(); + let get_adapter = { + let languages = local.languages.clone(); + let environment = local.environment.clone(); + let weak = local.weak.clone(); + let worktree_store = local.worktree_store.clone(); + let 
http_client = local.http_client.clone(); + let fs = local.fs.clone(); + move |worktree_id, cx: &mut App| { + let worktree = worktree_store.read(cx).worktree_for_id(worktree_id, cx)?; + Some(LocalLspAdapterDelegate::new( + languages.clone(), + &environment, + weak.clone(), + &worktree, + http_client.clone(), + fs.clone(), + cx, + )) + } + }; + + let mut messages_to_report = Vec::new(); + let (new_tree, to_stop) = { + let mut rebase = local.lsp_tree.rebase(); + let buffers = buffer_store + .read(cx) + .buffers() + .filter_map(|buffer| { + let raw_buffer = buffer.read(cx); + if !local + .registered_buffers + .contains_key(&raw_buffer.remote_id()) { - let worktree_id = file.worktree_id(cx); - let Some(worktree) = local - .worktree_store - .read(cx) - .worktree_for_id(worktree_id, cx) - else { - continue; - }; + return None; + } + let file = File::from_dyn(raw_buffer.file()).cloned()?; + let language = raw_buffer.language().cloned()?; + Some((file, language, raw_buffer.remote_id())) + }) + .sorted_by_key(|(file, _, _)| Reverse(file.worktree.read(cx).is_visible())); + for (file, language, buffer_id) in buffers { + let worktree_id = file.worktree_id(cx); + let Some(worktree) = local + .worktree_store + .read(cx) + .worktree_for_id(worktree_id, cx) + else { + continue; + }; + + if let Some((_, apply)) = local.reuse_existing_language_server( + rebase.server_tree(), + &worktree, + &language.name(), + cx, + ) { + (apply)(rebase.server_tree()); + } else if let Some(lsp_delegate) = adapters + .entry(worktree_id) + .or_insert_with(|| get_adapter(worktree_id, cx)) + .clone() + { + let delegate = + Arc::new(ManifestQueryDelegate::new(worktree.read(cx).snapshot())); + let path = file + .path() + .parent() + .map(Arc::from) + .unwrap_or_else(|| file.path().clone()); + let worktree_path = ProjectPath { worktree_id, path }; + let abs_path = file.abs_path(cx); + let worktree_root = worktree.read(cx).abs_path(); + let nodes = rebase + .walk( + worktree_path, + language.name(), + 
language.manifest(), + delegate.clone(), + cx, + ) + .collect::>(); + for node in nodes { + let server_id = node.server_id_or_init(|disposition| { + let path = &disposition.path; + let uri = Uri::from_file_path(worktree_root.join(&path.path)); + let key = LanguageServerSeed { + worktree_id, + name: disposition.server_name.clone(), + settings: disposition.settings.clone(), + toolchain: local.toolchain_store.read(cx).active_toolchain( + path.worktree_id, + &path.path, + language.name(), + ), + }; + local.language_server_ids.remove(&key); - let Some((reused, delegate, nodes)) = local - .reuse_existing_language_server( - rebase.server_tree(), + let server_id = local.get_or_insert_language_server( &worktree, - &language, + lsp_delegate.clone(), + disposition, + &language.name(), cx, - ) - .map(|(delegate, servers)| (true, delegate, servers)) - .or_else(|| { - let lsp_delegate = adapters - .entry(worktree_id) - .or_insert_with(|| get_adapter(worktree_id, cx)) - .clone()?; - let delegate = Arc::new(ManifestQueryDelegate::new( - worktree.read(cx).snapshot(), - )); - let path = file - .path() - .parent() - .map(Arc::from) - .unwrap_or_else(|| file.path().clone()); - let worktree_path = ProjectPath { worktree_id, path }; - - let nodes = rebase.get( - worktree_path, - AdapterQuery::Language(&language), - delegate.clone(), - cx, - ); - - Some((false, lsp_delegate, nodes.collect())) - }) - else { - continue; - }; - - let abs_path = file.abs_path(cx); - for node in nodes { - if !reused { - let server_id = node.server_id_or_init( - |LaunchDisposition { - server_name, - - path, - settings, - }| - { - let uri = Url::from_file_path( - worktree.read(cx).abs_path().join(&path.path), - ); - let key = (worktree_id, server_name.clone()); - local.language_server_ids.remove(&key); - - let adapter = local - .languages - .lsp_adapters(&language) - .into_iter() - .find(|adapter| &adapter.name() == server_name) - .expect("To find LSP adapter"); - let server_id = local.start_language_server( - 
&worktree, - delegate.clone(), - adapter, - settings, - cx, - ); - if let Some(state) = - local.language_servers.get(&server_id) - { - if let Ok(uri) = uri { - state.add_workspace_folder(uri); - }; - } - server_id - } - ); + ); + if let Some(state) = local.language_servers.get(&server_id) + && let Ok(uri) = uri + { + state.add_workspace_folder(uri); + }; + server_id + }); - if let Some(language_server_id) = server_id { - messages_to_report.push(LspStoreEvent::LanguageServerUpdate { - language_server_id, - name: node.name(), - message: - proto::update_language_server::Variant::RegisteredForBuffer( - proto::RegisteredForBuffer { - buffer_abs_path: abs_path.to_string_lossy().to_string(), - buffer_id: buffer_id.to_proto(), - }, - ), - }); - } - } + if let Some(language_server_id) = server_id { + messages_to_report.push(LspStoreEvent::LanguageServerUpdate { + language_server_id, + name: node.name(), + message: + proto::update_language_server::Variant::RegisteredForBuffer( + proto::RegisteredForBuffer { + buffer_abs_path: abs_path.to_string_lossy().to_string(), + buffer_id: buffer_id.to_proto(), + }, + ), + }); } } + } else { + continue; } - rebase.finish() - }); - for message in messages_to_report { - cx.emit(message); - } - for (id, _) in to_stop { - self.stop_local_language_server(id, cx).detach(); } + rebase.finish() + }; + for message in messages_to_report { + cx.emit(message); + } + local.lsp_tree = new_tree; + for (id, _) in to_stop { + self.stop_local_language_server(id, cx).detach(); } } @@ -4729,7 +4775,7 @@ impl LspStore { .await }) } else if self.mode.is_local() { - let Some((lsp_adapter, lang_server)) = buffer_handle.update(cx, |buffer, cx| { + let Some((_, lang_server)) = buffer_handle.update(cx, |buffer, cx| { self.language_server_for_local_buffer(buffer, action.server_id, cx) .map(|(adapter, server)| (adapter.clone(), server.clone())) }) else { @@ -4739,19 +4785,18 @@ impl LspStore { LocalLspStore::try_resolve_code_action(&lang_server, &mut action) .await 
.context("resolving a code action")?; - if let Some(edit) = action.lsp_action.edit() { - if edit.changes.is_some() || edit.document_changes.is_some() { + if let Some(edit) = action.lsp_action.edit() + && (edit.changes.is_some() || edit.document_changes.is_some()) { return LocalLspStore::deserialize_workspace_edit( this.upgrade().context("no app present")?, edit.clone(), push_to_history, - lsp_adapter.clone(), + lang_server.clone(), cx, ) .await; } - } if let Some(command) = action.lsp_action.command() { let server_capabilities = lang_server.capabilities(); @@ -4803,7 +4848,7 @@ impl LspStore { push_to_history: bool, cx: &mut Context, ) -> Task> { - if let Some(_) = self.as_local() { + if self.as_local().is_some() { cx.spawn(async move |lsp_store, cx| { let buffers = buffers.into_iter().collect::>(); let result = LocalLspStore::execute_code_action_kind_locally( @@ -5193,154 +5238,130 @@ impl LspStore { pub fn definitions( &mut self, - buffer_handle: &Entity, + buffer: &Entity, position: PointUtf16, cx: &mut Context, - ) -> Task>> { + ) -> Task>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { let request = GetDefinitions { position }; - if !self.is_capable_for_proto_request(buffer_handle, &request, cx) { - return Task::ready(Ok(Vec::new())); + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(None)); } - let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + let request_task = upstream_client.request_lsp( project_id, - strategy: Some(proto::multi_lsp_query::Strategy::All( - proto::AllLanguageServers {}, - )), - request: Some(proto::multi_lsp_query::Request::GetDefinition( - request.to_proto(project_id, buffer_handle.read(cx)), - )), - }); - let buffer = buffer_handle.clone(); + LSP_REQUEST_TIMEOUT, + cx.background_executor().clone(), + request.to_proto(project_id, 
buffer.read(cx)), + ); + let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { - return Ok(Vec::new()); + return Ok(None); }; - let responses = request_task.await?.responses; - let actions = join_all( - responses - .into_iter() - .filter_map(|lsp_response| match lsp_response.response? { - proto::lsp_response::Response::GetDefinitionResponse(response) => { - Some(response) - } - unexpected => { - debug_panic!("Unexpected response: {unexpected:?}"); - None - } - }) - .map(|definitions_response| { - GetDefinitions { position }.response_from_proto( - definitions_response, - project.clone(), - buffer.clone(), - cx.clone(), - ) - }), - ) + let Some(responses) = request_task.await? else { + return Ok(None); + }; + let actions = join_all(responses.payload.into_iter().map(|response| { + GetDefinitions { position }.response_from_proto( + response.response, + project.clone(), + buffer.clone(), + cx.clone(), + ) + })) .await; - Ok(actions - .into_iter() - .collect::>>>()? - .into_iter() - .flatten() - .dedup() - .collect()) + Ok(Some( + actions + .into_iter() + .collect::>>>()? 
+ .into_iter() + .flatten() + .dedup() + .collect(), + )) }) } else { let definitions_task = self.request_multiple_lsp_locally( - buffer_handle, + buffer, Some(position), GetDefinitions { position }, cx, ); cx.background_spawn(async move { - Ok(definitions_task - .await - .into_iter() - .flat_map(|(_, definitions)| definitions) - .dedup() - .collect()) + Ok(Some( + definitions_task + .await + .into_iter() + .flat_map(|(_, definitions)| definitions) + .dedup() + .collect(), + )) }) } } pub fn declarations( &mut self, - buffer_handle: &Entity, + buffer: &Entity, position: PointUtf16, cx: &mut Context, - ) -> Task>> { + ) -> Task>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { let request = GetDeclarations { position }; - if !self.is_capable_for_proto_request(buffer_handle, &request, cx) { - return Task::ready(Ok(Vec::new())); + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(None)); } - let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer_handle.read(cx).remote_id().into(), - version: serialize_version(&buffer_handle.read(cx).version()), + let request_task = upstream_client.request_lsp( project_id, - strategy: Some(proto::multi_lsp_query::Strategy::All( - proto::AllLanguageServers {}, - )), - request: Some(proto::multi_lsp_query::Request::GetDeclaration( - request.to_proto(project_id, buffer_handle.read(cx)), - )), - }); - let buffer = buffer_handle.clone(); + LSP_REQUEST_TIMEOUT, + cx.background_executor().clone(), + request.to_proto(project_id, buffer.read(cx)), + ); + let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { - return Ok(Vec::new()); + return Ok(None); }; - let responses = request_task.await?.responses; - let actions = join_all( - responses - .into_iter() - .filter_map(|lsp_response| match lsp_response.response? 
{ - proto::lsp_response::Response::GetDeclarationResponse(response) => { - Some(response) - } - unexpected => { - debug_panic!("Unexpected response: {unexpected:?}"); - None - } - }) - .map(|declarations_response| { - GetDeclarations { position }.response_from_proto( - declarations_response, - project.clone(), - buffer.clone(), - cx.clone(), - ) - }), - ) + let Some(responses) = request_task.await? else { + return Ok(None); + }; + let actions = join_all(responses.payload.into_iter().map(|response| { + GetDeclarations { position }.response_from_proto( + response.response, + project.clone(), + buffer.clone(), + cx.clone(), + ) + })) .await; - Ok(actions - .into_iter() - .collect::>>>()? - .into_iter() - .flatten() - .dedup() - .collect()) + Ok(Some( + actions + .into_iter() + .collect::>>>()? + .into_iter() + .flatten() + .dedup() + .collect(), + )) }) } else { let declarations_task = self.request_multiple_lsp_locally( - buffer_handle, + buffer, Some(position), GetDeclarations { position }, cx, ); cx.background_spawn(async move { - Ok(declarations_task - .await - .into_iter() - .flat_map(|(_, declarations)| declarations) - .dedup() - .collect()) + Ok(Some( + declarations_task + .await + .into_iter() + .flat_map(|(_, declarations)| declarations) + .dedup() + .collect(), + )) }) } } @@ -5350,59 +5371,45 @@ impl LspStore { buffer: &Entity, position: PointUtf16, cx: &mut Context, - ) -> Task>> { + ) -> Task>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { let request = GetTypeDefinitions { position }; - if !self.is_capable_for_proto_request(&buffer, &request, cx) { - return Task::ready(Ok(Vec::new())); + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(None)); } - let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer.read(cx).remote_id().into(), - version: serialize_version(&buffer.read(cx).version()), + let request_task = upstream_client.request_lsp( project_id, - strategy: 
Some(proto::multi_lsp_query::Strategy::All( - proto::AllLanguageServers {}, - )), - request: Some(proto::multi_lsp_query::Request::GetTypeDefinition( - request.to_proto(project_id, buffer.read(cx)), - )), - }); + LSP_REQUEST_TIMEOUT, + cx.background_executor().clone(), + request.to_proto(project_id, buffer.read(cx)), + ); let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { - return Ok(Vec::new()); + return Ok(None); }; - let responses = request_task.await?.responses; - let actions = join_all( - responses - .into_iter() - .filter_map(|lsp_response| match lsp_response.response? { - proto::lsp_response::Response::GetTypeDefinitionResponse(response) => { - Some(response) - } - unexpected => { - debug_panic!("Unexpected response: {unexpected:?}"); - None - } - }) - .map(|type_definitions_response| { - GetTypeDefinitions { position }.response_from_proto( - type_definitions_response, - project.clone(), - buffer.clone(), - cx.clone(), - ) - }), - ) + let Some(responses) = request_task.await? else { + return Ok(None); + }; + let actions = join_all(responses.payload.into_iter().map(|response| { + GetTypeDefinitions { position }.response_from_proto( + response.response, + project.clone(), + buffer.clone(), + cx.clone(), + ) + })) .await; - Ok(actions - .into_iter() - .collect::>>>()? - .into_iter() - .flatten() - .dedup() - .collect()) + Ok(Some( + actions + .into_iter() + .collect::>>>()? 
+ .into_iter() + .flatten() + .dedup() + .collect(), + )) }) } else { let type_definitions_task = self.request_multiple_lsp_locally( @@ -5412,12 +5419,14 @@ impl LspStore { cx, ); cx.background_spawn(async move { - Ok(type_definitions_task - .await - .into_iter() - .flat_map(|(_, type_definitions)| type_definitions) - .dedup() - .collect()) + Ok(Some( + type_definitions_task + .await + .into_iter() + .flat_map(|(_, type_definitions)| type_definitions) + .dedup() + .collect(), + )) }) } } @@ -5427,59 +5436,45 @@ impl LspStore { buffer: &Entity, position: PointUtf16, cx: &mut Context, - ) -> Task>> { + ) -> Task>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { let request = GetImplementations { position }; if !self.is_capable_for_proto_request(buffer, &request, cx) { - return Task::ready(Ok(Vec::new())); + return Task::ready(Ok(None)); } - let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer.read(cx).remote_id().into(), - version: serialize_version(&buffer.read(cx).version()), + let request_task = upstream_client.request_lsp( project_id, - strategy: Some(proto::multi_lsp_query::Strategy::All( - proto::AllLanguageServers {}, - )), - request: Some(proto::multi_lsp_query::Request::GetImplementation( - request.to_proto(project_id, buffer.read(cx)), - )), - }); + LSP_REQUEST_TIMEOUT, + cx.background_executor().clone(), + request.to_proto(project_id, buffer.read(cx)), + ); let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { - return Ok(Vec::new()); + return Ok(None); }; - let responses = request_task.await?.responses; - let actions = join_all( - responses - .into_iter() - .filter_map(|lsp_response| match lsp_response.response? 
{ - proto::lsp_response::Response::GetImplementationResponse(response) => { - Some(response) - } - unexpected => { - debug_panic!("Unexpected response: {unexpected:?}"); - None - } - }) - .map(|implementations_response| { - GetImplementations { position }.response_from_proto( - implementations_response, - project.clone(), - buffer.clone(), - cx.clone(), - ) - }), - ) + let Some(responses) = request_task.await? else { + return Ok(None); + }; + let actions = join_all(responses.payload.into_iter().map(|response| { + GetImplementations { position }.response_from_proto( + response.response, + project.clone(), + buffer.clone(), + cx.clone(), + ) + })) .await; - Ok(actions - .into_iter() - .collect::>>>()? - .into_iter() - .flatten() - .dedup() - .collect()) + Ok(Some( + actions + .into_iter() + .collect::>>>()? + .into_iter() + .flatten() + .dedup() + .collect(), + )) }) } else { let implementations_task = self.request_multiple_lsp_locally( @@ -5489,12 +5484,14 @@ impl LspStore { cx, ); cx.background_spawn(async move { - Ok(implementations_task - .await - .into_iter() - .flat_map(|(_, implementations)| implementations) - .dedup() - .collect()) + Ok(Some( + implementations_task + .await + .into_iter() + .flat_map(|(_, implementations)| implementations) + .dedup() + .collect(), + )) }) } } @@ -5504,59 +5501,44 @@ impl LspStore { buffer: &Entity, position: PointUtf16, cx: &mut Context, - ) -> Task>> { + ) -> Task>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { let request = GetReferences { position }; - if !self.is_capable_for_proto_request(&buffer, &request, cx) { - return Task::ready(Ok(Vec::new())); + if !self.is_capable_for_proto_request(buffer, &request, cx) { + return Task::ready(Ok(None)); } - let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer.read(cx).remote_id().into(), - version: serialize_version(&buffer.read(cx).version()), + + let request_task = upstream_client.request_lsp( project_id, - strategy: 
Some(proto::multi_lsp_query::Strategy::All( - proto::AllLanguageServers {}, - )), - request: Some(proto::multi_lsp_query::Request::GetReferences( - request.to_proto(project_id, buffer.read(cx)), - )), - }); + LSP_REQUEST_TIMEOUT, + cx.background_executor().clone(), + request.to_proto(project_id, buffer.read(cx)), + ); let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { - return Ok(Vec::new()); + return Ok(None); + }; + let Some(responses) = request_task.await? else { + return Ok(None); }; - let responses = request_task.await?.responses; - let actions = join_all( - responses - .into_iter() - .filter_map(|lsp_response| match lsp_response.response? { - proto::lsp_response::Response::GetReferencesResponse(response) => { - Some(response) - } - unexpected => { - debug_panic!("Unexpected response: {unexpected:?}"); - None - } - }) - .map(|references_response| { - GetReferences { position }.response_from_proto( - references_response, - project.clone(), - buffer.clone(), - cx.clone(), - ) - }), - ) - .await; - Ok(actions - .into_iter() - .collect::>>>()? - .into_iter() - .flatten() - .dedup() - .collect()) + let locations = join_all(responses.payload.into_iter().map(|lsp_response| { + GetReferences { position }.response_from_proto( + lsp_response.response, + project.clone(), + buffer.clone(), + cx.clone(), + ) + })) + .await + .into_iter() + .collect::>>>()? 
+ .into_iter() + .flatten() + .dedup() + .collect(); + Ok(Some(locations)) }) } else { let references_task = self.request_multiple_lsp_locally( @@ -5566,12 +5548,14 @@ impl LspStore { cx, ); cx.background_spawn(async move { - Ok(references_task - .await - .into_iter() - .flat_map(|(_, references)| references) - .dedup() - .collect()) + Ok(Some( + references_task + .await + .into_iter() + .flat_map(|(_, references)| references) + .dedup() + .collect(), + )) }) } } @@ -5582,82 +5566,67 @@ impl LspStore { range: Range, kinds: Option>, cx: &mut Context, - ) -> Task>> { + ) -> Task>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { let request = GetCodeActions { range: range.clone(), kinds: kinds.clone(), }; if !self.is_capable_for_proto_request(buffer, &request, cx) { - return Task::ready(Ok(Vec::new())); + return Task::ready(Ok(None)); } - let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer.read(cx).remote_id().into(), - version: serialize_version(&buffer.read(cx).version()), + let request_task = upstream_client.request_lsp( project_id, - strategy: Some(proto::multi_lsp_query::Strategy::All( - proto::AllLanguageServers {}, - )), - request: Some(proto::multi_lsp_query::Request::GetCodeActions( - request.to_proto(project_id, buffer.read(cx)), - )), - }); + LSP_REQUEST_TIMEOUT, + cx.background_executor().clone(), + request.to_proto(project_id, buffer.read(cx)), + ); let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { let Some(project) = weak_project.upgrade() else { - return Ok(Vec::new()); + return Ok(None); }; - let responses = request_task.await?.responses; - let actions = join_all( - responses - .into_iter() - .filter_map(|lsp_response| match lsp_response.response? 
{ - proto::lsp_response::Response::GetCodeActionsResponse(response) => { - Some(response) - } - unexpected => { - debug_panic!("Unexpected response: {unexpected:?}"); - None - } - }) - .map(|code_actions_response| { - GetCodeActions { - range: range.clone(), - kinds: kinds.clone(), - } - .response_from_proto( - code_actions_response, - project.clone(), - buffer.clone(), - cx.clone(), - ) - }), - ) + let Some(responses) = request_task.await? else { + return Ok(None); + }; + let actions = join_all(responses.payload.into_iter().map(|response| { + GetCodeActions { + range: range.clone(), + kinds: kinds.clone(), + } + .response_from_proto( + response.response, + project.clone(), + buffer.clone(), + cx.clone(), + ) + })) .await; - Ok(actions - .into_iter() - .collect::>>>()? - .into_iter() - .flatten() - .collect()) + Ok(Some( + actions + .into_iter() + .collect::>>>()? + .into_iter() + .flatten() + .collect(), + )) }) } else { let all_actions_task = self.request_multiple_lsp_locally( buffer, Some(range.start), - GetCodeActions { - range: range.clone(), - kinds: kinds.clone(), - }, + GetCodeActions { range, kinds }, cx, ); cx.background_spawn(async move { - Ok(all_actions_task - .await - .into_iter() - .flat_map(|(_, actions)| actions) - .collect()) + Ok(Some( + all_actions_task + .await + .into_iter() + .flat_map(|(_, actions)| actions) + .collect(), + )) }) } } @@ -5670,28 +5639,30 @@ impl LspStore { let version_queried_for = buffer.read(cx).version(); let buffer_id = buffer.read(cx).remote_id(); - if let Some(cached_data) = self.lsp_code_lens.get(&buffer_id) { - if !version_queried_for.changed_since(&cached_data.lens_for_version) { - let has_different_servers = self.as_local().is_some_and(|local| { - local - .buffers_opened_in_servers - .get(&buffer_id) - .cloned() - .unwrap_or_default() - != cached_data.lens.keys().copied().collect() - }); - if !has_different_servers { - return Task::ready(Ok(cached_data.lens.values().flatten().cloned().collect())) - .shared(); - } + 
if let Some(cached_data) = self.lsp_code_lens.get(&buffer_id) + && !version_queried_for.changed_since(&cached_data.lens_for_version) + { + let has_different_servers = self.as_local().is_some_and(|local| { + local + .buffers_opened_in_servers + .get(&buffer_id) + .cloned() + .unwrap_or_default() + != cached_data.lens.keys().copied().collect() + }); + if !has_different_servers { + return Task::ready(Ok(Some( + cached_data.lens.values().flatten().cloned().collect(), + ))) + .shared(); } } let lsp_data = self.lsp_code_lens.entry(buffer_id).or_default(); - if let Some((updating_for, running_update)) = &lsp_data.update { - if !version_queried_for.changed_since(&updating_for) { - return running_update.clone(); - } + if let Some((updating_for, running_update)) = &lsp_data.update + && !version_queried_for.changed_since(updating_for) + { + return running_update.clone(); } let buffer = buffer.clone(); let query_version_queried_for = version_queried_for.clone(); @@ -5721,17 +5692,19 @@ impl LspStore { lsp_store .update(cx, |lsp_store, _| { let lsp_data = lsp_store.lsp_code_lens.entry(buffer_id).or_default(); - if lsp_data.lens_for_version == query_version_queried_for { - lsp_data.lens.extend(fetched_lens.clone()); - } else if !lsp_data - .lens_for_version - .changed_since(&query_version_queried_for) - { - lsp_data.lens_for_version = query_version_queried_for; - lsp_data.lens = fetched_lens.clone(); + if let Some(fetched_lens) = fetched_lens { + if lsp_data.lens_for_version == query_version_queried_for { + lsp_data.lens.extend(fetched_lens); + } else if !lsp_data + .lens_for_version + .changed_since(&query_version_queried_for) + { + lsp_data.lens_for_version = query_version_queried_for; + lsp_data.lens = fetched_lens; + } } lsp_data.update = None; - lsp_data.lens.values().flatten().cloned().collect() + Some(lsp_data.lens.values().flatten().cloned().collect()) }) .map_err(Arc::new) }) @@ -5744,64 +5717,40 @@ impl LspStore { &mut self, buffer: &Entity, cx: &mut Context, - ) -> 
Task>>> { + ) -> Task>>>> { if let Some((upstream_client, project_id)) = self.upstream_client() { let request = GetCodeLens; if !self.is_capable_for_proto_request(buffer, &request, cx) { - return Task::ready(Ok(HashMap::default())); + return Task::ready(Ok(None)); } - let request_task = upstream_client.request(proto::MultiLspQuery { - buffer_id: buffer.read(cx).remote_id().into(), - version: serialize_version(&buffer.read(cx).version()), + let request_task = upstream_client.request_lsp( project_id, - strategy: Some(proto::multi_lsp_query::Strategy::All( - proto::AllLanguageServers {}, - )), - request: Some(proto::multi_lsp_query::Request::GetCodeLens( - request.to_proto(project_id, buffer.read(cx)), - )), - }); + LSP_REQUEST_TIMEOUT, + cx.background_executor().clone(), + request.to_proto(project_id, buffer.read(cx)), + ); let buffer = buffer.clone(); cx.spawn(async move |weak_lsp_store, cx| { let Some(lsp_store) = weak_lsp_store.upgrade() else { - return Ok(HashMap::default()); + return Ok(None); }; - let responses = request_task.await?.responses; - let code_lens_actions = join_all( - responses - .into_iter() - .filter_map(|lsp_response| { - let response = match lsp_response.response? { - proto::lsp_response::Response::GetCodeLensResponse(response) => { - Some(response) - } - unexpected => { - debug_panic!("Unexpected response: {unexpected:?}"); - None - } - }?; - let server_id = LanguageServerId::from_proto(lsp_response.server_id); - Some((server_id, response)) - }) - .map(|(server_id, code_lens_response)| { - let lsp_store = lsp_store.clone(); - let buffer = buffer.clone(); - let cx = cx.clone(); - async move { - ( - server_id, - GetCodeLens - .response_from_proto( - code_lens_response, - lsp_store, - buffer, - cx, - ) - .await, - ) - } - }), - ) + let Some(responses) = request_task.await? 
else { + return Ok(None); + }; + + let code_lens_actions = join_all(responses.payload.into_iter().map(|response| { + let lsp_store = lsp_store.clone(); + let buffer = buffer.clone(); + let cx = cx.clone(); + async move { + ( + LanguageServerId::from_proto(response.server_id), + GetCodeLens + .response_from_proto(response.response, lsp_store, buffer, cx) + .await, + ) + } + })) .await; let mut has_errors = false; @@ -5820,14 +5769,14 @@ impl LspStore { !has_errors || !code_lens_actions.is_empty(), "Failed to fetch code lens" ); - Ok(code_lens_actions) + Ok(Some(code_lens_actions)) }) } else { let code_lens_actions_task = self.request_multiple_lsp_locally(buffer, None::, GetCodeLens, cx); - cx.background_spawn( - async move { Ok(code_lens_actions_task.await.into_iter().collect()) }, - ) + cx.background_spawn(async move { + Ok(Some(code_lens_actions_task.await.into_iter().collect())) + }) } } @@ -5875,6 +5824,7 @@ impl LspStore { .await; Ok(vec![CompletionResponse { completions, + display_options: CompletionDisplayOptions::default(), is_incomplete: completion_response.is_incomplete, }]) }) @@ -5967,6 +5917,7 @@ impl LspStore { .await; Some(CompletionResponse { completions, + display_options: CompletionDisplayOptions::default(), is_incomplete: completion_response.is_incomplete, }) }); @@ -6306,11 +6257,11 @@ impl LspStore { .old_replace_start .and_then(deserialize_anchor) .zip(response.old_replace_end.and_then(deserialize_anchor)); - if let Some((old_replace_start, old_replace_end)) = replace_range { - if !response.new_text.is_empty() { - completion.new_text = response.new_text; - completion.replace_range = old_replace_start..old_replace_end; - } + if let Some((old_replace_start, old_replace_end)) = replace_range + && !response.new_text.is_empty() + { + completion.new_text = response.new_text; + completion.replace_range = old_replace_start..old_replace_end; } Ok(()) @@ -6405,14 +6356,38 @@ impl LspStore { for (range, text) in edits { let primary = 
&completion.replace_range; - let start_within = primary.start.cmp(&range.start, buffer).is_le() - && primary.end.cmp(&range.start, buffer).is_ge(); - let end_within = range.start.cmp(&primary.end, buffer).is_le() - && range.end.cmp(&primary.end, buffer).is_ge(); + + // Special case: if both ranges start at the very beginning of the file (line 0, column 0), + // and the primary completion is just an insertion (empty range), then this is likely + // an auto-import scenario and should not be considered overlapping + // https://github.com/zed-industries/zed/issues/26136 + let is_file_start_auto_import = { + let snapshot = buffer.snapshot(); + let primary_start_point = primary.start.to_point(&snapshot); + let range_start_point = range.start.to_point(&snapshot); + + let result = primary_start_point.row == 0 + && primary_start_point.column == 0 + && range_start_point.row == 0 + && range_start_point.column == 0; + + result + }; + + let has_overlap = if is_file_start_auto_import { + false + } else { + let start_within = primary.start.cmp(&range.start, buffer).is_le() + && primary.end.cmp(&range.start, buffer).is_ge(); + let end_within = range.start.cmp(&primary.end, buffer).is_le() + && range.end.cmp(&primary.end, buffer).is_ge(); + let result = start_within || end_within; + result + }; //Skip additional edits which overlap with the primary completion edit //https://github.com/zed-industries/zed/pull/1871 - if !start_within && !end_within { + if !has_overlap { buffer.edit([(range, text)], None, cx); } } @@ -6443,48 +6418,23 @@ impl LspStore { let buffer_id = buffer.read(cx).remote_id(); if let Some((client, upstream_project_id)) = self.upstream_client() { - if !self.is_capable_for_proto_request( - &buffer, - &GetDocumentDiagnostics { - previous_result_id: None, - }, - cx, - ) { + let request = GetDocumentDiagnostics { + previous_result_id: None, + }; + if !self.is_capable_for_proto_request(&buffer, &request, cx) { return Task::ready(Ok(None)); } - let request_task = 
client.request(proto::MultiLspQuery { - buffer_id: buffer_id.to_proto(), - version: serialize_version(&buffer.read(cx).version()), - project_id: upstream_project_id, - strategy: Some(proto::multi_lsp_query::Strategy::All( - proto::AllLanguageServers {}, - )), - request: Some(proto::multi_lsp_query::Request::GetDocumentDiagnostics( - proto::GetDocumentDiagnostics { - project_id: upstream_project_id, - buffer_id: buffer_id.to_proto(), - version: serialize_version(&buffer.read(cx).version()), - }, - )), - }); + let request_task = client.request_lsp( + upstream_project_id, + LSP_REQUEST_TIMEOUT, + cx.background_executor().clone(), + request.to_proto(upstream_project_id, buffer.read(cx)), + ); cx.background_spawn(async move { - let _proto_responses = request_task - .await? - .responses - .into_iter() - .filter_map(|lsp_response| match lsp_response.response? { - proto::lsp_response::Response::GetDocumentDiagnosticsResponse(response) => { - Some(response) - } - unexpected => { - debug_panic!("Unexpected response: {unexpected:?}"); - None - } - }) - .collect::>(); // Proto requests cause the diagnostics to be pulled from language server(s) on the local side // and then, buffer state updated with the diagnostics received, which will be later propagated to the client. // Do not attempt to further process the dummy responses here. 
+ let _response = request_task.await?; Ok(None) }) } else { @@ -6650,7 +6600,7 @@ impl LspStore { File::from_dyn(buffer.file()) .and_then(|file| { let abs_path = file.as_local()?.abs_path(cx); - lsp::Url::from_file_path(abs_path).ok() + lsp::Uri::from_file_path(abs_path).ok() }) .is_none_or(|buffer_uri| { unchanged_buffers.contains(&buffer_uri) @@ -6685,33 +6635,33 @@ impl LspStore { LspFetchStrategy::UseCache { known_cache_version, } => { - if let Some(cached_data) = self.lsp_document_colors.get(&buffer_id) { - if !version_queried_for.changed_since(&cached_data.colors_for_version) { - let has_different_servers = self.as_local().is_some_and(|local| { - local - .buffers_opened_in_servers - .get(&buffer_id) - .cloned() - .unwrap_or_default() - != cached_data.colors.keys().copied().collect() - }); - if !has_different_servers { - if Some(cached_data.cache_version) == known_cache_version { - return None; - } else { - return Some( - Task::ready(Ok(DocumentColors { - colors: cached_data - .colors - .values() - .flatten() - .cloned() - .collect(), - cache_version: Some(cached_data.cache_version), - })) - .shared(), - ); - } + if let Some(cached_data) = self.lsp_document_colors.get(&buffer_id) + && !version_queried_for.changed_since(&cached_data.colors_for_version) + { + let has_different_servers = self.as_local().is_some_and(|local| { + local + .buffers_opened_in_servers + .get(&buffer_id) + .cloned() + .unwrap_or_default() + != cached_data.colors.keys().copied().collect() + }); + if !has_different_servers { + if Some(cached_data.cache_version) == known_cache_version { + return None; + } else { + return Some( + Task::ready(Ok(DocumentColors { + colors: cached_data + .colors + .values() + .flatten() + .cloned() + .collect(), + cache_version: Some(cached_data.cache_version), + })) + .shared(), + ); } } } @@ -6719,10 +6669,10 @@ impl LspStore { } let lsp_data = self.lsp_document_colors.entry(buffer_id).or_default(); - if let Some((updating_for, running_update)) = 
&lsp_data.colors_update { - if !version_queried_for.changed_since(&updating_for) { - return Some(running_update.clone()); - } + if let Some((updating_for, running_update)) = &lsp_data.colors_update + && !version_queried_for.changed_since(updating_for) + { + return Some(running_update.clone()); } let query_version_queried_for = version_queried_for.clone(); let new_task = cx @@ -6769,16 +6719,18 @@ impl LspStore { .update(cx, |lsp_store, _| { let lsp_data = lsp_store.lsp_document_colors.entry(buffer_id).or_default(); - if lsp_data.colors_for_version == query_version_queried_for { - lsp_data.colors.extend(fetched_colors.clone()); - lsp_data.cache_version += 1; - } else if !lsp_data - .colors_for_version - .changed_since(&query_version_queried_for) - { - lsp_data.colors_for_version = query_version_queried_for; - lsp_data.colors = fetched_colors.clone(); - lsp_data.cache_version += 1; + if let Some(fetched_colors) = fetched_colors { + if lsp_data.colors_for_version == query_version_queried_for { + lsp_data.colors.extend(fetched_colors); + lsp_data.cache_version += 1; + } else if !lsp_data + .colors_for_version + .changed_since(&query_version_queried_for) + { + lsp_data.colors_for_version = query_version_queried_for; + lsp_data.colors = fetched_colors; + lsp_data.cache_version += 1; + } } lsp_data.colors_update = None; let colors = lsp_data @@ -6803,56 +6755,45 @@ impl LspStore { &mut self, buffer: &Entity, cx: &mut Context, - ) -> Task>>> { + ) -> Task>>>> { if let Some((client, project_id)) = self.upstream_client() { let request = GetDocumentColor {}; if !self.is_capable_for_proto_request(buffer, &request, cx) { - return Task::ready(Ok(HashMap::default())); + return Task::ready(Ok(None)); } - let request_task = client.request(proto::MultiLspQuery { + let request_task = client.request_lsp( project_id, - buffer_id: buffer.read(cx).remote_id().to_proto(), - version: serialize_version(&buffer.read(cx).version()), - strategy: Some(proto::multi_lsp_query::Strategy::All( - 
proto::AllLanguageServers {}, - )), - request: Some(proto::multi_lsp_query::Request::GetDocumentColor( - request.to_proto(project_id, buffer.read(cx)), - )), - }); + LSP_REQUEST_TIMEOUT, + cx.background_executor().clone(), + request.to_proto(project_id, buffer.read(cx)), + ); let buffer = buffer.clone(); - cx.spawn(async move |project, cx| { - let Some(project) = project.upgrade() else { - return Ok(HashMap::default()); + cx.spawn(async move |lsp_store, cx| { + let Some(project) = lsp_store.upgrade() else { + return Ok(None); }; let colors = join_all( request_task .await .log_err() - .map(|response| response.responses) + .flatten() + .map(|response| response.payload) .unwrap_or_default() .into_iter() - .filter_map(|lsp_response| match lsp_response.response? { - proto::lsp_response::Response::GetDocumentColorResponse(response) => { - Some(( - LanguageServerId::from_proto(lsp_response.server_id), - response, - )) - } - unexpected => { - debug_panic!("Unexpected response: {unexpected:?}"); - None - } - }) - .map(|(server_id, color_response)| { + .map(|color_response| { let response = request.response_from_proto( - color_response, + color_response.response, project.clone(), buffer.clone(), cx.clone(), ); - async move { (server_id, response.await.log_err().unwrap_or_default()) } + async move { + ( + LanguageServerId::from_proto(color_response.server_id), + response.await.log_err().unwrap_or_default(), + ) + } }), ) .await @@ -6863,23 +6804,25 @@ impl LspStore { .extend(colors); acc }); - Ok(colors) + Ok(Some(colors)) }) } else { let document_colors_task = self.request_multiple_lsp_locally(buffer, None::, GetDocumentColor, cx); cx.background_spawn(async move { - Ok(document_colors_task - .await - .into_iter() - .fold(HashMap::default(), |mut acc, (server_id, colors)| { - acc.entry(server_id) - .or_insert_with(HashSet::default) - .extend(colors); - acc - }) - .into_iter() - .collect()) + Ok(Some( + document_colors_task + .await + .into_iter() + .fold(HashMap::default(), 
|mut acc, (server_id, colors)| { + acc.entry(server_id) + .or_insert_with(HashSet::default) + .extend(colors); + acc + }) + .into_iter() + .collect(), + )) }) } } @@ -6889,49 +6832,34 @@ impl LspStore { buffer: &Entity, position: T, cx: &mut Context, - ) -> Task> { + ) -> Task>> { let position = position.to_point_utf16(buffer.read(cx)); if let Some((client, upstream_project_id)) = self.upstream_client() { let request = GetSignatureHelp { position }; if !self.is_capable_for_proto_request(buffer, &request, cx) { - return Task::ready(Vec::new()); + return Task::ready(None); } - let request_task = client.request(proto::MultiLspQuery { - buffer_id: buffer.read(cx).remote_id().into(), - version: serialize_version(&buffer.read(cx).version()), - project_id: upstream_project_id, - strategy: Some(proto::multi_lsp_query::Strategy::All( - proto::AllLanguageServers {}, - )), - request: Some(proto::multi_lsp_query::Request::GetSignatureHelp( - request.to_proto(upstream_project_id, buffer.read(cx)), - )), - }); + let request_task = client.request_lsp( + upstream_project_id, + LSP_REQUEST_TIMEOUT, + cx.background_executor().clone(), + request.to_proto(upstream_project_id, buffer.read(cx)), + ); let buffer = buffer.clone(); cx.spawn(async move |weak_project, cx| { - let Some(project) = weak_project.upgrade() else { - return Vec::new(); - }; - join_all( + let project = weak_project.upgrade()?; + let signatures = join_all( request_task .await - .log_err() - .map(|response| response.responses) - .unwrap_or_default() - .into_iter() - .filter_map(|lsp_response| match lsp_response.response? 
{ - proto::lsp_response::Response::GetSignatureHelpResponse(response) => { - Some(response) - } - unexpected => { - debug_panic!("Unexpected response: {unexpected:?}"); - None - } - }) - .map(|signature_response| { + .log_err() + .flatten() + .map(|response| response.payload) + .unwrap_or_default() + .into_iter() + .map(|response| { let response = GetSignatureHelp { position }.response_from_proto( - signature_response, + response.response, project.clone(), buffer.clone(), cx.clone(), @@ -6942,7 +6870,8 @@ impl LspStore { .await .into_iter() .flatten() - .collect() + .collect(); + Some(signatures) }) } else { let all_actions_task = self.request_multiple_lsp_locally( @@ -6952,11 +6881,13 @@ impl LspStore { cx, ); cx.background_spawn(async move { - all_actions_task - .await - .into_iter() - .flat_map(|(_, actions)| actions) - .collect::>() + Some( + all_actions_task + .await + .into_iter() + .flat_map(|(_, actions)| actions) + .collect::>(), + ) }) } } @@ -6966,47 +6897,32 @@ impl LspStore { buffer: &Entity, position: PointUtf16, cx: &mut Context, - ) -> Task> { + ) -> Task>> { if let Some((client, upstream_project_id)) = self.upstream_client() { let request = GetHover { position }; if !self.is_capable_for_proto_request(buffer, &request, cx) { - return Task::ready(Vec::new()); + return Task::ready(None); } - let request_task = client.request(proto::MultiLspQuery { - buffer_id: buffer.read(cx).remote_id().into(), - version: serialize_version(&buffer.read(cx).version()), - project_id: upstream_project_id, - strategy: Some(proto::multi_lsp_query::Strategy::All( - proto::AllLanguageServers {}, - )), - request: Some(proto::multi_lsp_query::Request::GetHover( - request.to_proto(upstream_project_id, buffer.read(cx)), - )), - }); + let request_task = client.request_lsp( + upstream_project_id, + LSP_REQUEST_TIMEOUT, + cx.background_executor().clone(), + request.to_proto(upstream_project_id, buffer.read(cx)), + ); let buffer = buffer.clone(); cx.spawn(async move |weak_project, 
cx| { - let Some(project) = weak_project.upgrade() else { - return Vec::new(); - }; - join_all( + let project = weak_project.upgrade()?; + let hovers = join_all( request_task .await .log_err() - .map(|response| response.responses) + .flatten() + .map(|response| response.payload) .unwrap_or_default() .into_iter() - .filter_map(|lsp_response| match lsp_response.response? { - proto::lsp_response::Response::GetHoverResponse(response) => { - Some(response) - } - unexpected => { - debug_panic!("Unexpected response: {unexpected:?}"); - None - } - }) - .map(|hover_response| { + .map(|response| { let response = GetHover { position }.response_from_proto( - hover_response, + response.response, project.clone(), buffer.clone(), cx.clone(), @@ -7023,7 +6939,8 @@ impl LspStore { .await .into_iter() .flatten() - .collect() + .collect(); + Some(hovers) }) } else { let all_actions_task = self.request_multiple_lsp_locally( @@ -7033,11 +6950,13 @@ impl LspStore { cx, ); cx.background_spawn(async move { - all_actions_task - .await - .into_iter() - .filter_map(|(_, hover)| remove_empty_hover_blocks(hover?)) - .collect::>() + Some( + all_actions_task + .await + .into_iter() + .filter_map(|(_, hover)| remove_empty_hover_blocks(hover?)) + .collect::>(), + ) }) } } @@ -7073,11 +6992,11 @@ impl LspStore { let mut requests = Vec::new(); let mut requested_servers = BTreeSet::new(); - 'next_server: for ((worktree_id, _), server_ids) in local.language_server_ids.iter() { + for (seed, state) in local.language_server_ids.iter() { let Some(worktree_handle) = self .worktree_store .read(cx) - .worktree_for_id(*worktree_id, cx) + .worktree_for_id(seed.worktree_id, cx) else { continue; }; @@ -7086,31 +7005,30 @@ impl LspStore { continue; } - let mut servers_to_query = server_ids - .difference(&requested_servers) - .cloned() - .collect::>(); - for server_id in &servers_to_query { - let (lsp_adapter, server) = match local.language_servers.get(server_id) { - Some(LanguageServerState::Running { - adapter, 
server, .. - }) => (adapter.clone(), server), - - _ => continue 'next_server, + if !requested_servers.insert(state.id) { + continue; + } + + let (lsp_adapter, server) = match local.language_servers.get(&state.id) { + Some(LanguageServerState::Running { + adapter, server, .. + }) => (adapter.clone(), server), + + _ => continue, + }; + let supports_workspace_symbol_request = + match server.capabilities().workspace_symbol_provider { + Some(OneOf::Left(supported)) => supported, + Some(OneOf::Right(_)) => true, + None => false, }; - let supports_workspace_symbol_request = - match server.capabilities().workspace_symbol_provider { - Some(OneOf::Left(supported)) => supported, - Some(OneOf::Right(_)) => true, - None => false, - }; - if !supports_workspace_symbol_request { - continue 'next_server; - } - let worktree_abs_path = worktree.abs_path().clone(); - let worktree_handle = worktree_handle.clone(); - let server_id = server.server_id(); - requests.push( + if !supports_workspace_symbol_request { + continue; + } + let worktree_abs_path = worktree.abs_path().clone(); + let worktree_handle = worktree_handle.clone(); + let server_id = server.server_id(); + requests.push( server .request::( lsp::WorkspaceSymbolParams { @@ -7152,8 +7070,6 @@ impl LspStore { } }), ); - } - requested_servers.append(&mut servers_to_query); } cx.spawn(async move |this, cx| { @@ -7182,7 +7098,7 @@ impl LspStore { worktree = tree; path = rel_path; } else { - worktree = source_worktree.clone(); + worktree = source_worktree; path = relativize_path(&result.worktree_abs_path, &abs_path); } @@ -7231,6 +7147,36 @@ impl LspStore { summary } + /// Returns the diagnostic summary for a specific project path. 
+ pub fn diagnostic_summary_for_path( + &self, + project_path: &ProjectPath, + _: &App, + ) -> DiagnosticSummary { + if let Some(summaries) = self + .diagnostic_summaries + .get(&project_path.worktree_id) + .and_then(|map| map.get(&project_path.path)) + { + let (error_count, warning_count) = summaries.iter().fold( + (0, 0), + |(error_count, warning_count), (_language_server_id, summary)| { + ( + error_count + summary.error_count, + warning_count + summary.warning_count, + ) + }, + ); + + DiagnosticSummary { + error_count, + warning_count, + } + } else { + DiagnosticSummary::default() + } + } + pub fn diagnostic_summaries<'a>( &'a self, include_ignored: bool, @@ -7251,7 +7197,7 @@ impl LspStore { include_ignored || worktree .entry_for_path(path.as_ref()) - .map_or(false, |entry| !entry.is_ignored) + .is_some_and(|entry| !entry.is_ignored) }) .flat_map(move |(path, summaries)| { summaries.iter().map(move |(server_id, summary)| { @@ -7285,7 +7231,7 @@ impl LspStore { let buffer = buffer.read(cx); let file = File::from_dyn(buffer.file())?; let abs_path = file.as_local()?.abs_path(cx); - let uri = lsp::Url::from_file_path(abs_path).unwrap(); + let uri = lsp::Uri::from_file_path(abs_path).unwrap(); let next_snapshot = buffer.text_snapshot(); for language_server in language_servers { let language_server = language_server.clone(); @@ -7416,7 +7362,7 @@ impl LspStore { None } - pub(crate) async fn refresh_workspace_configurations( + async fn refresh_workspace_configurations( lsp_store: &WeakEntity, fs: Arc, cx: &mut AsyncApp, @@ -7425,90 +7371,83 @@ impl LspStore { let mut refreshed_servers = HashSet::default(); let servers = lsp_store .update(cx, |lsp_store, cx| { - let toolchain_store = lsp_store.toolchain_store(cx); - let Some(local) = lsp_store.as_local() else { - return Vec::default(); - }; - local + let local = lsp_store.as_local()?; + + let servers = local .language_server_ids .iter() - .flat_map(|((worktree_id, _), server_ids)| { + .filter_map(|(seed, state)| { let 
worktree = lsp_store .worktree_store .read(cx) - .worktree_for_id(*worktree_id, cx); - let delegate = worktree.map(|worktree| { - LocalLspAdapterDelegate::new( - local.languages.clone(), - &local.environment, - cx.weak_entity(), - &worktree, - local.http_client.clone(), - local.fs.clone(), - cx, - ) - }); + .worktree_for_id(seed.worktree_id, cx); + let delegate: Arc = + worktree.map(|worktree| { + LocalLspAdapterDelegate::new( + local.languages.clone(), + &local.environment, + cx.weak_entity(), + &worktree, + local.http_client.clone(), + local.fs.clone(), + cx, + ) + })?; + let server_id = state.id; - let fs = fs.clone(); - let toolchain_store = toolchain_store.clone(); - server_ids.iter().filter_map(|server_id| { - let delegate = delegate.clone()? as Arc; - let states = local.language_servers.get(server_id)?; - - match states { - LanguageServerState::Starting { .. } => None, - LanguageServerState::Running { - adapter, server, .. - } => { - let fs = fs.clone(); - let toolchain_store = toolchain_store.clone(); - let adapter = adapter.clone(); - let server = server.clone(); - refreshed_servers.insert(server.name()); - Some(cx.spawn(async move |_, cx| { - let settings = - LocalLspStore::workspace_configuration_for_adapter( - adapter.adapter.clone(), - fs.as_ref(), - &delegate, - toolchain_store, - cx, - ) - .await - .ok()?; - server - .notify::( - &lsp::DidChangeConfigurationParams { settings }, - ) - .ok()?; - Some(()) - })) - } + let states = local.language_servers.get(&server_id)?; + + match states { + LanguageServerState::Starting { .. } => None, + LanguageServerState::Running { + adapter, server, .. 
+ } => { + let fs = fs.clone(); + + let adapter = adapter.clone(); + let server = server.clone(); + refreshed_servers.insert(server.name()); + let toolchain = seed.toolchain.clone(); + Some(cx.spawn(async move |_, cx| { + let settings = + LocalLspStore::workspace_configuration_for_adapter( + adapter.adapter.clone(), + fs.as_ref(), + &delegate, + toolchain, + cx, + ) + .await + .ok()?; + server + .notify::( + &lsp::DidChangeConfigurationParams { settings }, + ) + .ok()?; + Some(()) + })) } - }).collect::>() + } }) - .collect::>() + .collect::>(); + + Some(servers) }) - .ok()?; + .ok() + .flatten()?; - log::info!("Refreshing workspace configurations for servers {refreshed_servers:?}"); + log::debug!("Refreshing workspace configurations for servers {refreshed_servers:?}"); // TODO this asynchronous job runs concurrently with extension (de)registration and may take enough time for a certain extension // to stop and unregister its language server wrapper. // This is racy : an extension might have already removed all `local.language_servers` state, but here we `.clone()` and hold onto it anyway. // This now causes errors in the logs, we should find a way to remove such servers from the processing everywhere. let _: Vec> = join_all(servers).await; + Some(()) }) .await; } - fn toolchain_store(&self, cx: &App) -> Arc { - if let Some(toolchain_store) = self.toolchain_store.as_ref() { - toolchain_store.read(cx).as_language_toolchain_store() - } else { - Arc::new(EmptyToolchainStore) - } - } fn maintain_workspace_config( fs: Arc, external_refresh_requests: watch::Receiver<()>, @@ -7523,8 +7462,19 @@ impl LspStore { let mut joint_future = futures::stream::select(settings_changed_rx, external_refresh_requests); + // Multiple things can happen when a workspace environment (selected toolchain + settings) change: + // - We might shut down a language server if it's no longer enabled for a given language (and there are no buffers using it otherwise). 
+ // - We might also shut it down when the workspace configuration of all of the users of a given language server converges onto that of the other. + // - In the same vein, we might also decide to start a new language server if the workspace configuration *diverges* from the other. + // - In the easiest case (where we're not wrangling the lifetime of a language server anyhow), if none of the roots of a single language server diverge in their configuration, + // but it is still different to what we had before, we're gonna send out a workspace configuration update. cx.spawn(async move |this, cx| { while let Some(()) = joint_future.next().await { + this.update(cx, |this, cx| { + this.refresh_server_tree(cx); + }) + .ok(); + Self::refresh_workspace_configurations(&this, fs.clone(), cx).await; } @@ -7592,7 +7542,7 @@ impl LspStore { server: Some(proto::LanguageServer { id: server_id.to_proto(), name: status.name.to_string(), - worktree_id: None, + worktree_id: status.worktree.map(|id| id.to_proto()), }), capabilities: serde_json::to_string(&server.capabilities()) .expect("serializing server LSP capabilities"), @@ -7617,9 +7567,15 @@ impl LspStore { pub(crate) fn set_language_server_statuses_from_proto( &mut self, + project: WeakEntity, language_servers: Vec, server_capabilities: Vec, + cx: &mut Context, ) { + let lsp_logs = cx + .try_global::() + .map(|lsp_store| lsp_store.0.clone()); + self.language_server_statuses = language_servers .into_iter() .zip(server_capabilities) @@ -7629,60 +7585,40 @@ impl LspStore { self.lsp_server_capabilities .insert(server_id, server_capabilities); } + + let name = LanguageServerName::from_proto(server.name); + let worktree = server.worktree_id.map(WorktreeId::from_proto); + + if let Some(lsp_logs) = &lsp_logs { + lsp_logs.update(cx, |lsp_logs, cx| { + lsp_logs.add_language_server( + // Only remote clients get their language servers set from proto + LanguageServerKind::Remote { + project: project.clone(), + }, + server_id, + 
Some(name.clone()), + worktree, + None, + cx, + ); + }); + } + ( server_id, LanguageServerStatus { - name: LanguageServerName::from_proto(server.name), + name, pending_work: Default::default(), has_pending_diagnostic_updates: false, progress_tokens: Default::default(), + worktree, }, ) }) .collect(); } - fn register_local_language_server( - &mut self, - worktree: Entity, - language_server_name: LanguageServerName, - language_server_id: LanguageServerId, - cx: &mut App, - ) { - let Some(local) = self.as_local_mut() else { - return; - }; - - let worktree_id = worktree.read(cx).id(); - if worktree.read(cx).is_visible() { - let path = ProjectPath { - worktree_id, - path: Arc::from("".as_ref()), - }; - let delegate = Arc::new(ManifestQueryDelegate::new(worktree.read(cx).snapshot())); - local.lsp_tree.update(cx, |language_server_tree, cx| { - for node in language_server_tree.get( - path, - AdapterQuery::Adapter(&language_server_name), - delegate, - cx, - ) { - node.server_id_or_init(|disposition| { - assert_eq!(disposition.server_name, &language_server_name); - - language_server_id - }); - } - }); - } - - local - .language_server_ids - .entry((worktree_id, language_server_name)) - .or_default() - .insert(language_server_id); - } - #[cfg(test)] pub fn update_diagnostic_entries( &mut self, @@ -7738,19 +7674,16 @@ impl LspStore { let snapshot = buffer_handle.read(cx).snapshot(); let buffer = buffer_handle.read(cx); let reused_diagnostics = buffer - .get_diagnostics(server_id) - .into_iter() - .flat_map(|diag| { - diag.iter() - .filter(|v| merge(buffer, &v.diagnostic, cx)) - .map(|v| { - let start = Unclipped(v.range.start.to_point_utf16(&snapshot)); - let end = Unclipped(v.range.end.to_point_utf16(&snapshot)); - DiagnosticEntry { - range: start..end, - diagnostic: v.diagnostic.clone(), - } - }) + .buffer_diagnostics(Some(server_id)) + .iter() + .filter(|v| merge(buffer, &v.diagnostic, cx)) + .map(|v| { + let start = Unclipped(v.range.start.to_point_utf16(&snapshot)); + let 
end = Unclipped(v.range.end.to_point_utf16(&snapshot)); + DiagnosticEntry { + range: start..end, + diagnostic: v.diagnostic.clone(), + } }) .collect::>(); @@ -7794,7 +7727,7 @@ impl LspStore { } None => { diagnostics_summary = Some(proto::UpdateDiagnosticSummary { - project_id: project_id, + project_id, worktree_id: worktree_id.to_proto(), summary: Some(proto::DiagnosticSummary { path: project_path.path.as_ref().to_proto(), @@ -7912,17 +7845,12 @@ impl LspStore { .await }) } else if let Some(local) = self.as_local() { - let Some(language_server_id) = local - .language_server_ids - .get(&( - symbol.source_worktree_id, - symbol.language_server_name.clone(), - )) - .and_then(|ids| { - ids.contains(&symbol.source_language_server_id) - .then_some(symbol.source_language_server_id) - }) - else { + let is_valid = local.language_server_ids.iter().any(|(seed, state)| { + seed.worktree_id == symbol.source_worktree_id + && state.id == symbol.source_language_server_id + && symbol.language_server_name == seed.name + }); + if !is_valid { return Task::ready(Err(anyhow!( "language server for worktree and language not found" ))); @@ -7940,34 +7868,28 @@ impl LspStore { }; let symbol_abs_path = resolve_path(&worktree_abs_path, &symbol.path.path); - let symbol_uri = if let Ok(uri) = lsp::Url::from_file_path(symbol_abs_path) { + let symbol_uri = if let Ok(uri) = lsp::Uri::from_file_path(symbol_abs_path) { uri } else { return Task::ready(Err(anyhow!("invalid symbol path"))); }; - self.open_local_buffer_via_lsp( - symbol_uri, - language_server_id, - symbol.language_server_name.clone(), - cx, - ) + self.open_local_buffer_via_lsp(symbol_uri, symbol.source_language_server_id, cx) } else { Task::ready(Err(anyhow!("no upstream client or local store"))) } } - pub fn open_local_buffer_via_lsp( + pub(crate) fn open_local_buffer_via_lsp( &mut self, - mut abs_path: lsp::Url, + abs_path: lsp::Uri, language_server_id: LanguageServerId, - language_server_name: LanguageServerName, cx: &mut Context, ) 
-> Task>> { cx.spawn(async move |lsp_store, cx| { // Escape percent-encoded string. let current_scheme = abs_path.scheme().to_owned(); - let _ = abs_path.set_scheme("file"); + // Uri is immutable, so we can't modify the scheme let abs_path = abs_path .to_file_path() @@ -8012,12 +7934,13 @@ impl LspStore { if worktree.read_with(cx, |worktree, _| worktree.is_local())? { lsp_store .update(cx, |lsp_store, cx| { - lsp_store.register_local_language_server( - worktree.clone(), - language_server_name, - language_server_id, - cx, - ) + if let Some(local) = lsp_store.as_local_mut() { + local.register_language_server_for_invisible_worktree( + &worktree, + language_server_id, + cx, + ) + } }) .ok(); } @@ -8150,12 +8073,209 @@ impl LspStore { })? } + async fn handle_lsp_query( + lsp_store: Entity, + envelope: TypedEnvelope, + mut cx: AsyncApp, + ) -> Result { + use proto::lsp_query::Request; + let sender_id = envelope.original_sender_id().unwrap_or_default(); + let lsp_query = envelope.payload; + let lsp_request_id = LspRequestId(lsp_query.lsp_request_id); + match lsp_query.request.context("invalid LSP query request")? 
{ + Request::GetReferences(get_references) => { + let position = get_references.position.clone().and_then(deserialize_anchor); + Self::query_lsp_locally::( + lsp_store, + sender_id, + lsp_request_id, + get_references, + position, + cx.clone(), + ) + .await?; + } + Request::GetDocumentColor(get_document_color) => { + Self::query_lsp_locally::( + lsp_store, + sender_id, + lsp_request_id, + get_document_color, + None, + cx.clone(), + ) + .await?; + } + Request::GetHover(get_hover) => { + let position = get_hover.position.clone().and_then(deserialize_anchor); + Self::query_lsp_locally::( + lsp_store, + sender_id, + lsp_request_id, + get_hover, + position, + cx.clone(), + ) + .await?; + } + Request::GetCodeActions(get_code_actions) => { + Self::query_lsp_locally::( + lsp_store, + sender_id, + lsp_request_id, + get_code_actions, + None, + cx.clone(), + ) + .await?; + } + Request::GetSignatureHelp(get_signature_help) => { + let position = get_signature_help + .position + .clone() + .and_then(deserialize_anchor); + Self::query_lsp_locally::( + lsp_store, + sender_id, + lsp_request_id, + get_signature_help, + position, + cx.clone(), + ) + .await?; + } + Request::GetCodeLens(get_code_lens) => { + Self::query_lsp_locally::( + lsp_store, + sender_id, + lsp_request_id, + get_code_lens, + None, + cx.clone(), + ) + .await?; + } + Request::GetDefinition(get_definition) => { + let position = get_definition.position.clone().and_then(deserialize_anchor); + Self::query_lsp_locally::( + lsp_store, + sender_id, + lsp_request_id, + get_definition, + position, + cx.clone(), + ) + .await?; + } + Request::GetDeclaration(get_declaration) => { + let position = get_declaration + .position + .clone() + .and_then(deserialize_anchor); + Self::query_lsp_locally::( + lsp_store, + sender_id, + lsp_request_id, + get_declaration, + position, + cx.clone(), + ) + .await?; + } + Request::GetTypeDefinition(get_type_definition) => { + let position = get_type_definition + .position + .clone() + 
.and_then(deserialize_anchor); + Self::query_lsp_locally::( + lsp_store, + sender_id, + lsp_request_id, + get_type_definition, + position, + cx.clone(), + ) + .await?; + } + Request::GetImplementation(get_implementation) => { + let position = get_implementation + .position + .clone() + .and_then(deserialize_anchor); + Self::query_lsp_locally::( + lsp_store, + sender_id, + lsp_request_id, + get_implementation, + position, + cx.clone(), + ) + .await?; + } + // Diagnostics pull synchronizes internally via the buffer state, and cannot be handled generically as the other requests. + Request::GetDocumentDiagnostics(get_document_diagnostics) => { + let buffer_id = BufferId::new(get_document_diagnostics.buffer_id())?; + let version = deserialize_version(get_document_diagnostics.buffer_version()); + let buffer = lsp_store.update(&mut cx, |this, cx| { + this.buffer_store.read(cx).get_existing(buffer_id) + })??; + buffer + .update(&mut cx, |buffer, _| { + buffer.wait_for_version(version.clone()) + })? 
+ .await?; + lsp_store.update(&mut cx, |lsp_store, cx| { + let existing_queries = lsp_store + .running_lsp_requests + .entry(TypeId::of::()) + .or_default(); + if ::ProtoRequest::stop_previous_requests( + ) || buffer.read(cx).version.changed_since(&existing_queries.0) + { + existing_queries.1.clear(); + } + existing_queries.1.insert( + lsp_request_id, + cx.spawn(async move |lsp_store, cx| { + let diagnostics_pull = lsp_store + .update(cx, |lsp_store, cx| { + lsp_store.pull_diagnostics_for_buffer(buffer, cx) + }) + .ok(); + if let Some(diagnostics_pull) = diagnostics_pull { + match diagnostics_pull.await { + Ok(()) => {} + Err(e) => log::error!("Failed to pull diagnostics: {e:#}"), + }; + } + }), + ); + })?; + } + } + Ok(proto::Ack {}) + } + + async fn handle_lsp_query_response( + lsp_store: Entity, + envelope: TypedEnvelope, + cx: AsyncApp, + ) -> Result<()> { + lsp_store.read_with(&cx, |lsp_store, _| { + if let Some((upstream_client, _)) = lsp_store.upstream_client() { + upstream_client.handle_lsp_response(envelope.clone()); + } + })?; + Ok(()) + } + + // todo(lsp) remove after Zed Stable hits v0.204.x async fn handle_multi_lsp_query( lsp_store: Entity, envelope: TypedEnvelope, mut cx: AsyncApp, ) -> Result { - let response_from_ssh = lsp_store.read_with(&mut cx, |this, _| { + let response_from_ssh = lsp_store.read_with(&cx, |this, _| { let (upstream_client, project_id) = this.upstream_client()?; let mut payload = envelope.payload.clone(); payload.project_id = project_id; @@ -8177,7 +8297,7 @@ impl LspStore { buffer.wait_for_version(version.clone()) })? .await?; - let buffer_version = buffer.read_with(&mut cx, |buffer, _| buffer.version())?; + let buffer_version = buffer.read_with(&cx, |buffer, _| buffer.version())?; match envelope .payload .strategy @@ -8718,12 +8838,12 @@ impl LspStore { })? 
.context("worktree not found")?; let (old_abs_path, new_abs_path) = { - let root_path = worktree.read_with(&mut cx, |this, _| this.abs_path())?; + let root_path = worktree.read_with(&cx, |this, _| this.abs_path())?; let new_path = PathBuf::from_proto(envelope.payload.new_path.clone()); (root_path.join(&old_path), root_path.join(&new_path)) }; - Self::will_rename_entry( + let _transaction = Self::will_rename_entry( this.downgrade(), worktree_id, &old_abs_path, @@ -8733,7 +8853,7 @@ impl LspStore { ) .await; let response = Worktree::handle_rename_entry(worktree, envelope.payload, cx.clone()).await; - this.read_with(&mut cx, |this, _| { + this.read_with(&cx, |this, _| { this.did_rename_entry(worktree_id, &old_abs_path, &new_abs_path, is_dir); }) .ok(); @@ -8769,12 +8889,11 @@ impl LspStore { if summary.is_empty() { if let Some(worktree_summaries) = lsp_store.diagnostic_summaries.get_mut(&worktree_id) + && let Some(summaries) = worktree_summaries.get_mut(&path) { - if let Some(summaries) = worktree_summaries.get_mut(&path) { - summaries.remove(&server_id); - if summaries.is_empty() { - worktree_summaries.remove(&path); - } + summaries.remove(&server_id); + if summaries.is_empty() { + worktree_summaries.remove(&path); } } } else { @@ -8859,6 +8978,7 @@ impl LspStore { pending_work: Default::default(), has_pending_diagnostic_updates: false, progress_tokens: Default::default(), + worktree: server.worktree_id.map(WorktreeId::from_proto), }, ); cx.emit(LspStoreEvent::LanguageServerAdded( @@ -8968,10 +9088,10 @@ impl LspStore { async fn handle_lsp_ext_cancel_flycheck( lsp_store: Entity, envelope: TypedEnvelope, - mut cx: AsyncApp, + cx: AsyncApp, ) -> Result { let server_id = LanguageServerId(envelope.payload.language_server_id as usize); - lsp_store.read_with(&mut cx, |lsp_store, _| { + lsp_store.read_with(&cx, |lsp_store, _| { if let Some(server) = lsp_store.language_server_for_id(server_id) { server .notify::(&()) @@ -8993,13 +9113,22 @@ impl LspStore { 
lsp_store.update(&mut cx, |lsp_store, cx| { if let Some(server) = lsp_store.language_server_for_id(server_id) { let text_document = if envelope.payload.current_file_only { - let buffer_id = BufferId::new(envelope.payload.buffer_id)?; - lsp_store - .buffer_store() - .read(cx) - .get(buffer_id) - .and_then(|buffer| Some(buffer.read(cx).file()?.as_local()?.abs_path(cx))) - .map(|path| make_text_document_identifier(&path)) + let buffer_id = envelope + .payload + .buffer_id + .map(|id| BufferId::new(id)) + .transpose()?; + buffer_id + .and_then(|buffer_id| { + lsp_store + .buffer_store() + .read(cx) + .get(buffer_id) + .and_then(|buffer| { + Some(buffer.read(cx).file()?.as_local()?.abs_path(cx)) + }) + .map(|path| make_text_document_identifier(&path)) + }) .transpose()? } else { None @@ -9020,10 +9149,10 @@ impl LspStore { async fn handle_lsp_ext_clear_flycheck( lsp_store: Entity, envelope: TypedEnvelope, - mut cx: AsyncApp, + cx: AsyncApp, ) -> Result { let server_id = LanguageServerId(envelope.payload.language_server_id as usize); - lsp_store.read_with(&mut cx, |lsp_store, _| { + lsp_store.read_with(&cx, |lsp_store, _| { if let Some(server) = lsp_store.language_server_for_id(server_id) { server .notify::(&()) @@ -9153,8 +9282,12 @@ impl LspStore { maybe!({ let local_store = self.as_local()?; - let old_uri = lsp::Url::from_file_path(old_path).ok().map(String::from)?; - let new_uri = lsp::Url::from_file_path(new_path).ok().map(String::from)?; + let old_uri = lsp::Uri::from_file_path(old_path) + .ok() + .map(|uri| uri.to_string())?; + let new_uri = lsp::Uri::from_file_path(new_path) + .ok() + .map(|uri| uri.to_string())?; for language_server in local_store.language_servers_for_worktree(worktree_id) { let Some(filter) = local_store @@ -9186,9 +9319,13 @@ impl LspStore { new_path: &Path, is_dir: bool, cx: AsyncApp, - ) -> Task<()> { - let old_uri = lsp::Url::from_file_path(old_path).ok().map(String::from); - let new_uri = 
lsp::Url::from_file_path(new_path).ok().map(String::from); + ) -> Task { + let old_uri = lsp::Uri::from_file_path(old_path) + .ok() + .map(|uri| uri.to_string()); + let new_uri = lsp::Uri::from_file_path(new_path) + .ok() + .map(|uri| uri.to_string()); cx.spawn(async move |cx| { let mut tasks = vec![]; this.update(cx, |this, cx| { @@ -9202,11 +9339,7 @@ impl LspStore { else { continue; }; - let Some(adapter) = - this.language_server_adapter_for_id(language_server.server_id()) - else { - continue; - }; + if filter.should_send_will_rename(&old_uri, is_dir) { let apply_edit = cx.spawn({ let old_uri = old_uri.clone(); @@ -9223,17 +9356,16 @@ impl LspStore { .log_err() .flatten()?; - LocalLspStore::deserialize_workspace_edit( + let transaction = LocalLspStore::deserialize_workspace_edit( this.upgrade()?, edit, false, - adapter.clone(), language_server.clone(), cx, ) .await - .ok(); - Some(()) + .ok()?; + Some(transaction) } }); tasks.push(apply_edit); @@ -9243,11 +9375,17 @@ impl LspStore { }) .ok() .flatten(); + let mut merged_transaction = ProjectTransaction::default(); for task in tasks { // Await on tasks sequentially so that the order of application of edits is deterministic // (at least with regards to the order of registration of language servers) - task.await; + if let Some(transaction) = task.await { + for (buffer, buffer_transaction) in transaction.0 { + merged_transaction.0.insert(buffer, buffer_transaction); + } + } } + merged_transaction }) } @@ -9348,9 +9486,7 @@ impl LspStore { let is_disk_based_diagnostics_progress = disk_based_diagnostics_progress_token .as_ref() - .map_or(false, |disk_based_token| { - token.starts_with(disk_based_token) - }); + .is_some_and(|disk_based_token| token.starts_with(disk_based_token)); match progress { lsp::WorkDoneProgress::Begin(report) => { @@ -9480,10 +9616,10 @@ impl LspStore { cx: &mut Context, ) { if let Some(status) = self.language_server_statuses.get_mut(&language_server_id) { - if let Some(work) = 
status.pending_work.remove(&token) { - if !work.is_disk_based_diagnostics_progress { - cx.emit(LspStoreEvent::RefreshInlayHints); - } + if let Some(work) = status.pending_work.remove(&token) + && !work.is_disk_based_diagnostics_progress + { + cx.emit(LspStoreEvent::RefreshInlayHints); } cx.notify(); } @@ -9796,7 +9932,7 @@ impl LspStore { let peer_id = envelope.original_sender_id().unwrap_or_default(); let symbol = envelope.payload.symbol.context("invalid symbol")?; let symbol = Self::deserialize_symbol(symbol)?; - let symbol = this.read_with(&mut cx, |this, _| { + let symbol = this.read_with(&cx, |this, _| { let signature = this.symbol_signature(&symbol.path); anyhow::ensure!(signature == symbol.signature, "invalid symbol signature"); Ok(symbol) @@ -10046,7 +10182,7 @@ impl LspStore { ) -> Shared>>> { if let Some(environment) = &self.as_local().map(|local| local.environment.clone()) { environment.update(cx, |env, cx| { - env.get_buffer_environment(&buffer, &self.worktree_store, cx) + env.get_buffer_environment(buffer, &self.worktree_store, cx) }) } else { Task::ready(None).shared() @@ -10062,7 +10198,7 @@ impl LspStore { cx: &mut Context, ) -> Task> { let logger = zlog::scoped!("format"); - if let Some(_) = self.as_local() { + if self.as_local().is_some() { zlog::trace!(logger => "Formatting locally"); let logger = zlog::scoped!(logger => "local"); let buffers = buffers @@ -10277,10 +10413,10 @@ impl LspStore { None => None, }; - if let Some(server) = server { - if let Some(shutdown) = server.shutdown() { - shutdown.await; - } + if let Some(server) = server + && let Some(shutdown) = server.shutdown() + { + shutdown.await; } } @@ -10290,28 +10426,18 @@ impl LspStore { &mut self, server_id: LanguageServerId, cx: &mut Context, - ) -> Task> { + ) -> Task<()> { let local = match &mut self.mode { LspStoreMode::Local(local) => local, _ => { - return Task::ready(Vec::new()); + return Task::ready(()); } }; - let mut orphaned_worktrees = Vec::new(); // Remove this server ID 
from all entries in the given worktree. - local.language_server_ids.retain(|(worktree, _), ids| { - if !ids.remove(&server_id) { - return true; - } - - if ids.is_empty() { - orphaned_worktrees.push(*worktree); - false - } else { - true - } - }); + local + .language_server_ids + .retain(|_, state| state.id != server_id); self.buffer_store.update(cx, |buffer_store, cx| { for buffer in buffer_store.buffers() { buffer.update(cx, |buffer, cx| { @@ -10364,7 +10490,7 @@ impl LspStore { let name = self .language_server_statuses .remove(&server_id) - .map(|status| status.name.clone()) + .map(|status| status.name) .or_else(|| { if let Some(LanguageServerState::Running { adapter, .. }) = server_state.as_ref() { Some(adapter.name()) @@ -10390,14 +10516,13 @@ impl LspStore { cx.notify(); }) .ok(); - orphaned_worktrees }); } if server_state.is_some() { cx.emit(LspStoreEvent::LanguageServerRemoved(server_id)); } - Task::ready(orphaned_worktrees) + Task::ready(()) } pub fn stop_all_language_servers(&mut self, cx: &mut Context) { @@ -10416,12 +10541,9 @@ impl LspStore { let language_servers_to_stop = local .language_server_ids .values() - .flatten() - .copied() + .map(|state| state.id) .collect(); - local.lsp_tree.update(cx, |this, _| { - this.remove_nodes(&language_servers_to_stop); - }); + local.lsp_tree.remove_nodes(&language_servers_to_stop); let tasks = language_servers_to_stop .into_iter() .map(|server| self.stop_local_language_server(server, cx)) @@ -10568,37 +10690,31 @@ impl LspStore { for buffer in buffers { buffer.update(cx, |buffer, cx| { language_servers_to_stop.extend(local.language_server_ids_for_buffer(buffer, cx)); - if let Some(worktree_id) = buffer.file().map(|f| f.worktree_id(cx)) { - if covered_worktrees.insert(worktree_id) { - language_server_names_to_stop.retain(|name| { - match local.language_server_ids.get(&(worktree_id, name.clone())) { - Some(server_ids) => { - language_servers_to_stop - .extend(server_ids.into_iter().copied()); - false - } - None => 
true, - } - }); - } + if let Some(worktree_id) = buffer.file().map(|f| f.worktree_id(cx)) + && covered_worktrees.insert(worktree_id) + { + language_server_names_to_stop.retain(|name| { + let old_ids_count = language_servers_to_stop.len(); + let all_language_servers_with_this_name = local + .language_server_ids + .iter() + .filter_map(|(seed, state)| seed.name.eq(name).then(|| state.id)); + language_servers_to_stop.extend(all_language_servers_with_this_name); + old_ids_count == language_servers_to_stop.len() + }); } }); } for name in language_server_names_to_stop { - if let Some(server_ids) = local - .language_server_ids - .iter() - .filter(|((_, server_name), _)| server_name == &name) - .map(|((_, _), server_ids)| server_ids) - .max_by_key(|server_ids| server_ids.len()) - { - language_servers_to_stop.extend(server_ids.into_iter().copied()); - } + language_servers_to_stop.extend( + local + .language_server_ids + .iter() + .filter_map(|(seed, v)| seed.name.eq(&name).then(|| v.id)), + ); } - local.lsp_tree.update(cx, |this, _| { - this.remove_nodes(&language_servers_to_stop); - }); + local.lsp_tree.remove_nodes(&language_servers_to_stop); let tasks = language_servers_to_stop .into_iter() .map(|server| self.stop_local_language_server(server, cx)) @@ -10703,7 +10819,7 @@ impl LspStore { let is_supporting = diagnostic .related_information .as_ref() - .map_or(false, |infos| { + .is_some_and(|infos| { infos.iter().any(|info| { primary_diagnostic_group_ids.contains_key(&( source, @@ -10716,11 +10832,11 @@ impl LspStore { let is_unnecessary = diagnostic .tags .as_ref() - .map_or(false, |tags| tags.contains(&DiagnosticTag::UNNECESSARY)); + .is_some_and(|tags| tags.contains(&DiagnosticTag::UNNECESSARY)); let underline = self .language_server_adapter_for_id(server_id) - .map_or(true, |adapter| adapter.underline_diagnostic(diagnostic)); + .is_none_or(|adapter| adapter.underline_diagnostic(diagnostic)); if is_supporting { supporting_diagnostics.insert( @@ -10730,7 +10846,7 @@ 
impl LspStore { } else { let group_id = post_inc(&mut self.as_local_mut().unwrap().next_diagnostic_group_id); let is_disk_based = - source.map_or(false, |source| disk_based_sources.contains(source)); + source.is_some_and(|source| disk_based_sources.contains(source)); sources_by_group_id.insert(group_id, source); primary_diagnostic_group_ids @@ -10821,8 +10937,8 @@ impl LspStore { adapter: Arc, language_server: Arc, server_id: LanguageServerId, - key: (WorktreeId, LanguageServerName), - workspace_folders: Arc>>, + key: LanguageServerSeed, + workspace_folders: Arc>>, cx: &mut Context, ) { let Some(local) = self.as_local_mut() else { @@ -10833,7 +10949,7 @@ impl LspStore { if local .language_server_ids .get(&key) - .map(|ids| !ids.contains(&server_id)) + .map(|state| state.id != server_id) .unwrap_or(false) { return; @@ -10884,13 +11000,14 @@ impl LspStore { pending_work: Default::default(), has_pending_diagnostic_updates: false, progress_tokens: Default::default(), + worktree: Some(key.worktree_id), }, ); cx.emit(LspStoreEvent::LanguageServerAdded( server_id, language_server.name(), - Some(key.0), + Some(key.worktree_id), )); cx.emit(LspStoreEvent::RefreshInlayHints); @@ -10902,7 +11019,7 @@ impl LspStore { server: Some(proto::LanguageServer { id: server_id.to_proto(), name: language_server.name().to_string(), - worktree_id: Some(key.0.to_proto()), + worktree_id: Some(key.worktree_id.to_proto()), }), capabilities: serde_json::to_string(&server_capabilities) .expect("serializing server LSP capabilities"), @@ -10914,13 +11031,13 @@ impl LspStore { // Tell the language server about every open buffer in the worktree that matches the language. 
// Also check for buffers in worktrees that reused this server - let mut worktrees_using_server = vec![key.0]; + let mut worktrees_using_server = vec![key.worktree_id]; if let Some(local) = self.as_local() { // Find all worktrees that have this server in their language server tree - for (worktree_id, servers) in &local.lsp_tree.read(cx).instances { - if *worktree_id != key.0 { - for (_, server_map) in &servers.roots { - if server_map.contains_key(&key.1) { + for (worktree_id, servers) in &local.lsp_tree.instances { + if *worktree_id != key.worktree_id { + for server_map in servers.roots.values() { + if server_map.contains_key(&key.name) { worktrees_using_server.push(*worktree_id); } } @@ -10946,7 +11063,7 @@ impl LspStore { .languages .lsp_adapters(&language.name()) .iter() - .any(|a| a.name == key.1) + .any(|a| a.name == key.name) { continue; } @@ -10981,7 +11098,7 @@ impl LspStore { let snapshot = versions.last().unwrap(); let version = snapshot.version; let initial_snapshot = &snapshot.snapshot; - let uri = lsp::Url::from_file_path(file.abs_path(cx)).unwrap(); + let uri = lsp::Uri::from_file_path(file.abs_path(cx)).unwrap(); language_server.register_buffer( uri, adapter.language_id(&language.name()), @@ -11090,10 +11207,10 @@ impl LspStore { if let Some((LanguageServerState::Running { server, .. 
}, status)) = server.zip(status) { for (token, progress) in &status.pending_work { - if let Some(token_to_cancel) = token_to_cancel.as_ref() { - if token != token_to_cancel { - continue; - } + if let Some(token_to_cancel) = token_to_cancel.as_ref() + && token != token_to_cancel + { + continue; } if progress.is_cancellable { server @@ -11184,18 +11301,14 @@ impl LspStore { let Some(local) = self.as_local() else { return }; local.prettier_store.update(cx, |prettier_store, cx| { - prettier_store.update_prettier_settings(&worktree_handle, changes, cx) + prettier_store.update_prettier_settings(worktree_handle, changes, cx) }); let worktree_id = worktree_handle.read(cx).id(); let mut language_server_ids = local .language_server_ids .iter() - .flat_map(|((server_worktree, _), server_ids)| { - server_ids - .iter() - .filter_map(|server_id| server_worktree.eq(&worktree_id).then(|| *server_id)) - }) + .filter_map(|(seed, v)| seed.worktree_id.eq(&worktree_id).then(|| v.id)) .collect::>(); language_server_ids.sort(); language_server_ids.dedup(); @@ -11204,41 +11317,47 @@ impl LspStore { for server_id in &language_server_ids { if let Some(LanguageServerState::Running { server, .. 
}) = local.language_servers.get(server_id) - { - if let Some(watched_paths) = local + && let Some(watched_paths) = local .language_server_watched_paths .get(server_id) .and_then(|paths| paths.worktree_paths.get(&worktree_id)) - { - let params = lsp::DidChangeWatchedFilesParams { - changes: changes - .iter() - .filter_map(|(path, _, change)| { - if !watched_paths.is_match(path) { - return None; - } - let typ = match change { - PathChange::Loaded => return None, - PathChange::Added => lsp::FileChangeType::CREATED, - PathChange::Removed => lsp::FileChangeType::DELETED, - PathChange::Updated => lsp::FileChangeType::CHANGED, - PathChange::AddedOrUpdated => lsp::FileChangeType::CHANGED, - }; - Some(lsp::FileEvent { - uri: lsp::Url::from_file_path(abs_path.join(path)).unwrap(), - typ, - }) + { + let params = lsp::DidChangeWatchedFilesParams { + changes: changes + .iter() + .filter_map(|(path, _, change)| { + if !watched_paths.is_match(path) { + return None; + } + let typ = match change { + PathChange::Loaded => return None, + PathChange::Added => lsp::FileChangeType::CREATED, + PathChange::Removed => lsp::FileChangeType::DELETED, + PathChange::Updated => lsp::FileChangeType::CHANGED, + PathChange::AddedOrUpdated => lsp::FileChangeType::CHANGED, + }; + Some(lsp::FileEvent { + uri: lsp::Uri::from_file_path(abs_path.join(path)).unwrap(), + typ, }) - .collect(), - }; - if !params.changes.is_empty() { - server - .notify::(¶ms) - .ok(); - } + }) + .collect(), + }; + if !params.changes.is_empty() { + server + .notify::(¶ms) + .ok(); } } } + for (path, _, _) in changes { + if let Some(file_name) = path.file_name().and_then(|file_name| file_name.to_str()) + && local.watched_manifest_filenames.contains(file_name) + { + self.request_workspace_config_refresh(); + break; + } + } } pub fn wait_for_remote_buffer( @@ -11630,7 +11749,7 @@ impl LspStore { File::from_dyn(buffer.file()) .and_then(|file| { let abs_path = file.as_local()?.abs_path(cx); - lsp::Url::from_file_path(abs_path).ok() 
+ lsp::Uri::from_file_path(abs_path).ok() }) .is_none_or(|buffer_uri| { unchanged_buffers.contains(&buffer_uri) @@ -11679,13 +11798,26 @@ impl LspStore { "workspace/didChangeConfiguration" => { // Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings. } + "workspace/didChangeWorkspaceFolders" => { + // In this case register options is an empty object, we can ignore it + let caps = lsp::WorkspaceFoldersServerCapabilities { + supported: Some(true), + change_notifications: Some(OneOf::Right(reg.id)), + }; + server.update_capabilities(|capabilities| { + capabilities + .workspace + .get_or_insert_default() + .workspace_folders = Some(caps); + }); + notify_server_capabilities_updated(&server, cx); + } "workspace/symbol" => { - if let Some(options) = parse_register_capabilities(reg)? { - server.update_capabilities(|capabilities| { - capabilities.workspace_symbol_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); - } + let options = parse_register_capabilities(reg)?; + server.update_capabilities(|capabilities| { + capabilities.workspace_symbol_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); } "workspace/fileOperations" => { if let Some(options) = reg.register_options { @@ -11709,12 +11841,11 @@ impl LspStore { } } "textDocument/rangeFormatting" => { - if let Some(options) = parse_register_capabilities(reg)? 
{ - server.update_capabilities(|capabilities| { - capabilities.document_range_formatting_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); - } + let options = parse_register_capabilities(reg)?; + server.update_capabilities(|capabilities| { + capabilities.document_range_formatting_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); } "textDocument/onTypeFormatting" => { if let Some(options) = reg @@ -11729,57 +11860,50 @@ impl LspStore { } } "textDocument/formatting" => { - if let Some(options) = parse_register_capabilities(reg)? { - server.update_capabilities(|capabilities| { - capabilities.document_formatting_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); - } + let options = parse_register_capabilities(reg)?; + server.update_capabilities(|capabilities| { + capabilities.document_formatting_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); } "textDocument/rename" => { - if let Some(options) = parse_register_capabilities(reg)? { - server.update_capabilities(|capabilities| { - capabilities.rename_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); - } + let options = parse_register_capabilities(reg)?; + server.update_capabilities(|capabilities| { + capabilities.rename_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); } "textDocument/inlayHint" => { - if let Some(options) = parse_register_capabilities(reg)? { - server.update_capabilities(|capabilities| { - capabilities.inlay_hint_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); - } + let options = parse_register_capabilities(reg)?; + server.update_capabilities(|capabilities| { + capabilities.inlay_hint_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); } "textDocument/documentSymbol" => { - if let Some(options) = parse_register_capabilities(reg)? 
{ - server.update_capabilities(|capabilities| { - capabilities.document_symbol_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); - } + let options = parse_register_capabilities(reg)?; + server.update_capabilities(|capabilities| { + capabilities.document_symbol_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); } "textDocument/codeAction" => { - if let Some(options) = reg - .register_options - .map(serde_json::from_value) - .transpose()? - { - server.update_capabilities(|capabilities| { - capabilities.code_action_provider = - Some(lsp::CodeActionProviderCapability::Options(options)); - }); - notify_server_capabilities_updated(&server, cx); - } + let options = parse_register_capabilities(reg)?; + let provider = match options { + OneOf::Left(value) => lsp::CodeActionProviderCapability::Simple(value), + OneOf::Right(caps) => caps, + }; + server.update_capabilities(|capabilities| { + capabilities.code_action_provider = Some(provider); + }); + notify_server_capabilities_updated(&server, cx); } "textDocument/definition" => { - if let Some(options) = parse_register_capabilities(reg)? 
{ - server.update_capabilities(|capabilities| { - capabilities.definition_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); - } + let options = parse_register_capabilities(reg)?; + server.update_capabilities(|capabilities| { + capabilities.definition_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); } "textDocument/completion" => { if let Some(caps) = reg @@ -11794,37 +11918,70 @@ impl LspStore { } } "textDocument/hover" => { + let options = parse_register_capabilities(reg)?; + let provider = match options { + OneOf::Left(value) => lsp::HoverProviderCapability::Simple(value), + OneOf::Right(caps) => caps, + }; + server.update_capabilities(|capabilities| { + capabilities.hover_provider = Some(provider); + }); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/signatureHelp" => { if let Some(caps) = reg .register_options .map(serde_json::from_value) .transpose()? { server.update_capabilities(|capabilities| { - capabilities.hover_provider = Some(caps); + capabilities.signature_help_provider = Some(caps); }); notify_server_capabilities_updated(&server, cx); } } - "textDocument/signatureHelp" => { - if let Some(caps) = reg + "textDocument/didChange" => { + if let Some(sync_kind) = reg .register_options - .map(serde_json::from_value) + .and_then(|opts| opts.get("syncKind").cloned()) + .map(serde_json::from_value::) .transpose()? 
{ server.update_capabilities(|capabilities| { - capabilities.signature_help_provider = Some(caps); + let mut sync_options = + Self::take_text_document_sync_options(capabilities); + sync_options.change = Some(sync_kind); + capabilities.text_document_sync = + Some(lsp::TextDocumentSyncCapability::Options(sync_options)); }); notify_server_capabilities_updated(&server, cx); } } - "textDocument/synchronization" => { - if let Some(caps) = reg + "textDocument/didSave" => { + if let Some(include_text) = reg .register_options - .map(serde_json::from_value) + .map(|opts| { + let transpose = opts + .get("includeText") + .cloned() + .map(serde_json::from_value::>) + .transpose(); + match transpose { + Ok(value) => Ok(value.flatten()), + Err(e) => Err(e), + } + }) .transpose()? { server.update_capabilities(|capabilities| { - capabilities.text_document_sync = Some(caps); + let mut sync_options = + Self::take_text_document_sync_options(capabilities); + sync_options.save = + Some(TextDocumentSyncSaveOptions::SaveOptions(lsp::SaveOptions { + include_text, + })); + capabilities.text_document_sync = + Some(lsp::TextDocumentSyncCapability::Options(sync_options)); }); notify_server_capabilities_updated(&server, cx); } @@ -11853,17 +12010,16 @@ impl LspStore { notify_server_capabilities_updated(&server, cx); } } - "textDocument/colorProvider" => { - if let Some(caps) = reg - .register_options - .map(serde_json::from_value) - .transpose()? 
- { - server.update_capabilities(|capabilities| { - capabilities.color_provider = Some(caps); - }); - notify_server_capabilities_updated(&server, cx); - } + "textDocument/documentColor" => { + let options = parse_register_capabilities(reg)?; + let provider = match options { + OneOf::Left(value) => lsp::ColorProviderCapability::Simple(value), + OneOf::Right(caps) => caps, + }; + server.update_capabilities(|capabilities| { + capabilities.color_provider = Some(provider); + }); + notify_server_capabilities_updated(&server, cx); } _ => log::warn!("unhandled capability registration: {reg:?}"), } @@ -11898,6 +12054,18 @@ impl LspStore { "workspace/didChangeConfiguration" => { // Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings. } + "workspace/didChangeWorkspaceFolders" => { + server.update_capabilities(|capabilities| { + capabilities + .workspace + .get_or_insert_with(|| lsp::WorkspaceServerCapabilities { + workspace_folders: None, + file_operations: None, + }) + .workspace_folders = None; + }); + notify_server_capabilities_updated(&server, cx); + } "workspace/symbol" => { server.update_capabilities(|capabilities| { capabilities.workspace_symbol_provider = None @@ -11974,9 +12142,21 @@ impl LspStore { }); notify_server_capabilities_updated(&server, cx); } - "textDocument/synchronization" => { + "textDocument/didChange" => { + server.update_capabilities(|capabilities| { + let mut sync_options = Self::take_text_document_sync_options(capabilities); + sync_options.change = None; + capabilities.text_document_sync = + Some(lsp::TextDocumentSyncCapability::Options(sync_options)); + }); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/didSave" => { server.update_capabilities(|capabilities| { - capabilities.text_document_sync = None; + let mut sync_options = Self::take_text_document_sync_options(capabilities); + sync_options.save = None; + capabilities.text_document_sync = + 
Some(lsp::TextDocumentSyncCapability::Options(sync_options)); }); notify_server_capabilities_updated(&server, cx); } @@ -11992,7 +12172,7 @@ impl LspStore { }); notify_server_capabilities_updated(&server, cx); } - "textDocument/colorProvider" => { + "textDocument/documentColor" => { server.update_capabilities(|capabilities| { capabilities.color_provider = None; }); @@ -12004,18 +12184,127 @@ impl LspStore { Ok(()) } + + async fn query_lsp_locally( + lsp_store: Entity, + sender_id: proto::PeerId, + lsp_request_id: LspRequestId, + proto_request: T::ProtoRequest, + position: Option, + mut cx: AsyncApp, + ) -> Result<()> + where + T: LspCommand + Clone, + T::ProtoRequest: proto::LspRequestMessage, + ::Response: + Into<::Response>, + { + let buffer_id = BufferId::new(proto_request.buffer_id())?; + let version = deserialize_version(proto_request.buffer_version()); + let buffer = lsp_store.update(&mut cx, |this, cx| { + this.buffer_store.read(cx).get_existing(buffer_id) + })??; + buffer + .update(&mut cx, |buffer, _| { + buffer.wait_for_version(version.clone()) + })? 
+ .await?; + let buffer_version = buffer.read_with(&cx, |buffer, _| buffer.version())?; + let request = + T::from_proto(proto_request, lsp_store.clone(), buffer.clone(), cx.clone()).await?; + lsp_store.update(&mut cx, |lsp_store, cx| { + let request_task = + lsp_store.request_multiple_lsp_locally(&buffer, position, request, cx); + let existing_queries = lsp_store + .running_lsp_requests + .entry(TypeId::of::()) + .or_default(); + if T::ProtoRequest::stop_previous_requests() + || buffer_version.changed_since(&existing_queries.0) + { + existing_queries.1.clear(); + } + existing_queries.1.insert( + lsp_request_id, + cx.spawn(async move |lsp_store, cx| { + let response = request_task.await; + lsp_store + .update(cx, |lsp_store, cx| { + if let Some((client, project_id)) = lsp_store.downstream_client.clone() + { + let response = response + .into_iter() + .map(|(server_id, response)| { + ( + server_id.to_proto(), + T::response_to_proto( + response, + lsp_store, + sender_id, + &buffer_version, + cx, + ) + .into(), + ) + }) + .collect::>(); + match client.send_lsp_response::( + project_id, + lsp_request_id, + response, + ) { + Ok(()) => {} + Err(e) => { + log::error!("Failed to send LSP response: {e:#}",) + } + } + } + }) + .ok(); + }), + ); + })?; + Ok(()) + } + + fn take_text_document_sync_options( + capabilities: &mut lsp::ServerCapabilities, + ) -> lsp::TextDocumentSyncOptions { + match capabilities.text_document_sync.take() { + Some(lsp::TextDocumentSyncCapability::Options(sync_options)) => sync_options, + Some(lsp::TextDocumentSyncCapability::Kind(sync_kind)) => { + let mut sync_options = lsp::TextDocumentSyncOptions::default(); + sync_options.change = Some(sync_kind); + sync_options + } + None => lsp::TextDocumentSyncOptions::default(), + } + } + + #[cfg(any(test, feature = "test-support"))] + pub fn forget_code_lens_task(&mut self, buffer_id: BufferId) -> Option { + let data = self.lsp_code_lens.get_mut(&buffer_id)?; + Some(data.update.take()?.1) + } + + pub fn 
downstream_client(&self) -> Option<(AnyProtoClient, u64)> { + self.downstream_client.clone() + } + + pub fn worktree_store(&self) -> Entity { + self.worktree_store.clone() + } } -// Registration with empty capabilities should be ignored. -// https://github.com/microsoft/vscode-languageserver-node/blob/d90a87f9557a0df9142cfb33e251cfa6fe27d970/client/src/common/formatting.ts#L67-L70 +// Registration with registerOptions as null, should fallback to true. +// https://github.com/microsoft/vscode-languageserver-node/blob/d90a87f9557a0df9142cfb33e251cfa6fe27d970/client/src/common/client.ts#L2133 fn parse_register_capabilities( reg: lsp::Registration, -) -> anyhow::Result>> { - Ok(reg - .register_options - .map(|options| serde_json::from_value::(options)) - .transpose()? - .map(OneOf::Right)) +) -> Result> { + Ok(match reg.register_options { + Some(options) => OneOf::Right(serde_json::from_value::(options)?), + None => OneOf::Left(true), + }) } fn subscribe_to_binary_statuses( @@ -12250,11 +12539,10 @@ async fn populate_labels_for_completions( let lsp_completions = new_completions .iter() .filter_map(|new_completion| { - if let Some(lsp_completion) = new_completion.source.lsp_completion(true) { - Some(lsp_completion.into_owned()) - } else { - None - } + new_completion + .source + .lsp_completion(true) + .map(|lsp_completion| lsp_completion.into_owned()) }) .collect::>(); @@ -12274,11 +12562,7 @@ async fn populate_labels_for_completions( for completion in new_completions { match completion.source.lsp_completion(true) { Some(lsp_completion) => { - let documentation = if let Some(docs) = lsp_completion.documentation.clone() { - Some(docs.into()) - } else { - None - }; + let documentation = lsp_completion.documentation.clone().map(|docs| docs.into()); let mut label = labels.next().flatten().unwrap_or_else(|| { CodeLabel::fallback_for_completion(&lsp_completion, language.as_deref()) @@ -12374,7 +12658,7 @@ impl TryFrom<&FileOperationFilter> for RenameActionPredicate { 
ops.pattern .options .as_ref() - .map_or(false, |ops| ops.ignore_case.unwrap_or(false)), + .is_some_and(|ops| ops.ignore_case.unwrap_or(false)), ) .build()? .compile_matcher(), @@ -12389,7 +12673,7 @@ struct RenameActionPredicate { impl RenameActionPredicate { // Returns true if language server should be notified fn eval(&self, path: &str, is_dir: bool) -> bool { - self.kind.as_ref().map_or(true, |kind| { + self.kind.as_ref().is_none_or(|kind| { let expected_kind = if is_dir { FileOperationPatternKind::Folder } else { @@ -12519,45 +12803,69 @@ impl PartialEq for LanguageServerPromptRequest { #[derive(Clone, Debug, PartialEq)] pub enum LanguageServerLogType { Log(MessageType), - Trace(Option), + Trace { verbose_info: Option }, + Rpc { received: bool }, } impl LanguageServerLogType { pub fn to_proto(&self) -> proto::language_server_log::LogType { match self { Self::Log(log_type) => { - let message_type = match *log_type { - MessageType::ERROR => 1, - MessageType::WARNING => 2, - MessageType::INFO => 3, - MessageType::LOG => 4, + use proto::log_message::LogLevel; + let level = match *log_type { + MessageType::ERROR => LogLevel::Error, + MessageType::WARNING => LogLevel::Warning, + MessageType::INFO => LogLevel::Info, + MessageType::LOG => LogLevel::Log, other => { - log::warn!("Unknown lsp log message type: {:?}", other); - 4 + log::warn!("Unknown lsp log message type: {other:?}"); + LogLevel::Log } }; - proto::language_server_log::LogType::LogMessageType(message_type) + proto::language_server_log::LogType::Log(proto::LogMessage { + level: level as i32, + }) } - Self::Trace(message) => { - proto::language_server_log::LogType::LogTrace(proto::LspLogTrace { - message: message.clone(), + Self::Trace { verbose_info } => { + proto::language_server_log::LogType::Trace(proto::TraceMessage { + verbose_info: verbose_info.to_owned(), }) } + Self::Rpc { received } => { + let kind = if *received { + proto::rpc_message::Kind::Received + } else { + proto::rpc_message::Kind::Sent + 
}; + let kind = kind as i32; + proto::language_server_log::LogType::Rpc(proto::RpcMessage { kind }) + } } } pub fn from_proto(log_type: proto::language_server_log::LogType) -> Self { + use proto::log_message::LogLevel; + use proto::rpc_message; match log_type { - proto::language_server_log::LogType::LogMessageType(message_type) => { - Self::Log(match message_type { - 1 => MessageType::ERROR, - 2 => MessageType::WARNING, - 3 => MessageType::INFO, - 4 => MessageType::LOG, - _ => MessageType::LOG, - }) - } - proto::language_server_log::LogType::LogTrace(trace) => Self::Trace(trace.message), + proto::language_server_log::LogType::Log(message_type) => Self::Log( + match LogLevel::from_i32(message_type.level).unwrap_or(LogLevel::Log) { + LogLevel::Error => MessageType::ERROR, + LogLevel::Warning => MessageType::WARNING, + LogLevel::Info => MessageType::INFO, + LogLevel::Log => MessageType::LOG, + }, + ), + proto::language_server_log::LogType::Trace(trace_message) => Self::Trace { + verbose_info: trace_message.verbose_info, + }, + proto::language_server_log::LogType::Rpc(message) => Self::Rpc { + received: match rpc_message::Kind::from_i32(message.kind) + .unwrap_or(rpc_message::Kind::Received) + { + rpc_message::Kind::Received => true, + rpc_message::Kind::Sent => false, + }, + }, } } } @@ -12573,7 +12881,7 @@ pub enum LanguageServerState { Starting { startup: Task>>, /// List of language servers that will be added to the workspace once it's initialization completes. 
- pending_workspace_folders: Arc>>, + pending_workspace_folders: Arc>>, }, Running { @@ -12585,7 +12893,7 @@ pub enum LanguageServerState { } impl LanguageServerState { - fn add_workspace_folder(&self, uri: Url) { + fn add_workspace_folder(&self, uri: Uri) { match self { LanguageServerState::Starting { pending_workspace_folders, @@ -12598,7 +12906,7 @@ impl LanguageServerState { } } } - fn _remove_workspace_folder(&self, uri: Url) { + fn _remove_workspace_folder(&self, uri: Uri) { match self { LanguageServerState::Starting { pending_workspace_folders, @@ -12666,7 +12974,7 @@ impl DiagnosticSummary { } pub fn to_proto( - &self, + self, language_server_id: LanguageServerId, path: &Path, ) -> proto::DiagnosticSummary { @@ -12696,6 +13004,21 @@ pub enum CompletionDocumentation { }, } +impl CompletionDocumentation { + #[cfg(any(test, feature = "test-support"))] + pub fn text(&self) -> SharedString { + match self { + CompletionDocumentation::Undocumented => "".into(), + CompletionDocumentation::SingleLine(s) => s.clone(), + CompletionDocumentation::MultiLinePlainText(s) => s.clone(), + CompletionDocumentation::MultiLineMarkdown(s) => s.clone(), + CompletionDocumentation::SingleLineAndMultiLinePlainText { single_line, .. 
} => { + single_line.clone() + } + } + } +} + impl From for CompletionDocumentation { fn from(docs: lsp::Documentation) -> Self { match docs { @@ -12783,7 +13106,7 @@ impl LspAdapter for SshLspAdapter { async fn check_if_user_installed( &self, _: &dyn LspAdapterDelegate, - _: Arc, + _: Option, _: &AsyncApp, ) -> Option { Some(self.binary.clone()) @@ -12800,6 +13123,7 @@ impl LspAdapter for SshLspAdapter { async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, + _: &AsyncApp, ) -> Result> { anyhow::bail!("SshLspAdapter does not support fetch_latest_server_version") } @@ -13106,24 +13430,18 @@ async fn populate_labels_for_symbols( fn include_text(server: &lsp::LanguageServer) -> Option { match server.capabilities().text_document_sync.as_ref()? { - lsp::TextDocumentSyncCapability::Kind(kind) => match *kind { - lsp::TextDocumentSyncKind::NONE => None, - lsp::TextDocumentSyncKind::FULL => Some(true), - lsp::TextDocumentSyncKind::INCREMENTAL => Some(false), - _ => None, - }, - lsp::TextDocumentSyncCapability::Options(options) => match options.save.as_ref()? { - lsp::TextDocumentSyncSaveOptions::Supported(supported) => { - if *supported { - Some(true) - } else { - None - } - } + lsp::TextDocumentSyncCapability::Options(opts) => match opts.save.as_ref()? { + // Server wants didSave but didn't specify includeText. + lsp::TextDocumentSyncSaveOptions::Supported(true) => Some(false), + // Server doesn't want didSave at all. + lsp::TextDocumentSyncSaveOptions::Supported(false) => None, + // Server provided SaveOptions. lsp::TextDocumentSyncSaveOptions::SaveOptions(save_options) => { Some(save_options.include_text.unwrap_or(false)) } }, + // We do not have any save info. Kind affects didChange only. 
+ lsp::TextDocumentSyncCapability::Kind(_) => None, } } @@ -13140,10 +13458,10 @@ fn ensure_uniform_list_compatible_label(label: &mut CodeLabel) { let mut offset_map = vec![0; label.text.len() + 1]; let mut last_char_was_space = false; let mut new_idx = 0; - let mut chars = label.text.char_indices().fuse(); + let chars = label.text.char_indices().fuse(); let mut newlines_removed = false; - while let Some((idx, c)) = chars.next() { + for (idx, c) in chars { offset_map[idx] = new_idx; match c { diff --git a/crates/project/src/lsp_store/clangd_ext.rs b/crates/project/src/lsp_store/clangd_ext.rs index 274b1b898086eeddf72710052397dd9963833663..b02f68dd4d1271ca9a8fa97e9ef41e03fdfe9763 100644 --- a/crates/project/src/lsp_store/clangd_ext.rs +++ b/crates/project/src/lsp_store/clangd_ext.rs @@ -58,7 +58,7 @@ pub fn register_notifications( language_server .on_notification::({ - let adapter = adapter.clone(); + let adapter = adapter; let this = lsp_store; move |params: InactiveRegionsParams, cx| { diff --git a/crates/project/src/lsp_store/log_store.rs b/crates/project/src/lsp_store/log_store.rs new file mode 100644 index 0000000000000000000000000000000000000000..00098712bf0092a6795de2ed48c7ccf15925c555 --- /dev/null +++ b/crates/project/src/lsp_store/log_store.rs @@ -0,0 +1,712 @@ +use std::{collections::VecDeque, sync::Arc}; + +use collections::HashMap; +use futures::{StreamExt, channel::mpsc}; +use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, Subscription, WeakEntity}; +use lsp::{ + IoKind, LanguageServer, LanguageServerId, LanguageServerName, LanguageServerSelector, + MessageType, TraceValue, +}; +use rpc::proto; +use settings::WorktreeId; + +use crate::{LanguageServerLogType, LspStore, Project, ProjectItem as _}; + +const SEND_LINE: &str = "\n// Send:"; +const RECEIVE_LINE: &str = "\n// Receive:"; +const MAX_STORED_LOG_ENTRIES: usize = 2000; + +pub fn init(on_headless_host: bool, cx: &mut App) -> Entity { + let log_store = cx.new(|cx| 
LogStore::new(on_headless_host, cx)); + cx.set_global(GlobalLogStore(log_store.clone())); + log_store +} + +pub struct GlobalLogStore(pub Entity); + +impl Global for GlobalLogStore {} + +#[derive(Debug)] +pub enum Event { + NewServerLogEntry { + id: LanguageServerId, + kind: LanguageServerLogType, + text: String, + }, +} + +impl EventEmitter for LogStore {} + +pub struct LogStore { + on_headless_host: bool, + projects: HashMap, ProjectState>, + pub copilot_log_subscription: Option, + pub language_servers: HashMap, + io_tx: mpsc::UnboundedSender<(LanguageServerId, IoKind, String)>, +} + +struct ProjectState { + _subscriptions: [Subscription; 2], +} + +pub trait Message: AsRef { + type Level: Copy + std::fmt::Debug; + fn should_include(&self, _: Self::Level) -> bool { + true + } +} + +#[derive(Debug)] +pub struct LogMessage { + message: String, + typ: MessageType, +} + +impl AsRef for LogMessage { + fn as_ref(&self) -> &str { + &self.message + } +} + +impl Message for LogMessage { + type Level = MessageType; + + fn should_include(&self, level: Self::Level) -> bool { + match (self.typ, level) { + (MessageType::ERROR, _) => true, + (_, MessageType::ERROR) => false, + (MessageType::WARNING, _) => true, + (_, MessageType::WARNING) => false, + (MessageType::INFO, _) => true, + (_, MessageType::INFO) => false, + _ => true, + } + } +} + +#[derive(Debug)] +pub struct TraceMessage { + message: String, + is_verbose: bool, +} + +impl AsRef for TraceMessage { + fn as_ref(&self) -> &str { + &self.message + } +} + +impl Message for TraceMessage { + type Level = TraceValue; + + fn should_include(&self, level: Self::Level) -> bool { + match level { + TraceValue::Off => false, + TraceValue::Messages => !self.is_verbose, + TraceValue::Verbose => true, + } + } +} + +#[derive(Debug)] +pub struct RpcMessage { + message: String, +} + +impl AsRef for RpcMessage { + fn as_ref(&self) -> &str { + &self.message + } +} + +impl Message for RpcMessage { + type Level = (); +} + +pub struct 
LanguageServerState { + pub name: Option, + pub worktree_id: Option, + pub kind: LanguageServerKind, + log_messages: VecDeque, + trace_messages: VecDeque, + pub rpc_state: Option, + pub trace_level: TraceValue, + pub log_level: MessageType, + io_logs_subscription: Option, + pub toggled_log_kind: Option, +} + +impl std::fmt::Debug for LanguageServerState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LanguageServerState") + .field("name", &self.name) + .field("worktree_id", &self.worktree_id) + .field("kind", &self.kind) + .field("log_messages", &self.log_messages) + .field("trace_messages", &self.trace_messages) + .field("rpc_state", &self.rpc_state) + .field("trace_level", &self.trace_level) + .field("log_level", &self.log_level) + .field("toggled_log_kind", &self.toggled_log_kind) + .finish_non_exhaustive() + } +} + +#[derive(PartialEq, Clone)] +pub enum LanguageServerKind { + Local { project: WeakEntity }, + Remote { project: WeakEntity }, + LocalSsh { lsp_store: WeakEntity }, + Global, +} + +impl std::fmt::Debug for LanguageServerKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LanguageServerKind::Local { .. } => write!(f, "LanguageServerKind::Local"), + LanguageServerKind::Remote { .. } => write!(f, "LanguageServerKind::Remote"), + LanguageServerKind::LocalSsh { .. } => write!(f, "LanguageServerKind::LocalSsh"), + LanguageServerKind::Global => write!(f, "LanguageServerKind::Global"), + } + } +} + +impl LanguageServerKind { + pub fn project(&self) -> Option<&WeakEntity> { + match self { + Self::Local { project } => Some(project), + Self::Remote { project } => Some(project), + Self::LocalSsh { .. } => None, + Self::Global { .. 
} => None, + } + } +} + +#[derive(Debug)] +pub struct LanguageServerRpcState { + pub rpc_messages: VecDeque, + last_message_kind: Option, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum MessageKind { + Send, + Receive, +} + +#[derive(Clone, Copy, Debug, Default, PartialEq)] +pub enum LogKind { + Rpc, + Trace, + #[default] + Logs, + ServerInfo, +} + +impl LogKind { + pub fn from_server_log_type(log_type: &LanguageServerLogType) -> Self { + match log_type { + LanguageServerLogType::Log(_) => Self::Logs, + LanguageServerLogType::Trace { .. } => Self::Trace, + LanguageServerLogType::Rpc { .. } => Self::Rpc, + } + } +} + +impl LogStore { + pub fn new(on_headless_host: bool, cx: &mut Context) -> Self { + let (io_tx, mut io_rx) = mpsc::unbounded(); + + let log_store = Self { + projects: HashMap::default(), + language_servers: HashMap::default(), + copilot_log_subscription: None, + on_headless_host, + io_tx, + }; + cx.spawn(async move |log_store, cx| { + while let Some((server_id, io_kind, message)) = io_rx.next().await { + if let Some(log_store) = log_store.upgrade() { + log_store.update(cx, |log_store, cx| { + log_store.on_io(server_id, io_kind, &message, cx); + })?; + } + } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + + log_store + } + + pub fn add_project(&mut self, project: &Entity, cx: &mut Context) { + let weak_project = project.downgrade(); + self.projects.insert( + project.downgrade(), + ProjectState { + _subscriptions: [ + cx.observe_release(project, move |this, _, _| { + this.projects.remove(&weak_project); + this.language_servers + .retain(|_, state| state.kind.project() != Some(&weak_project)); + }), + cx.subscribe(project, move |log_store, project, event, cx| { + let server_kind = if project.read(cx).is_local() { + LanguageServerKind::Local { + project: project.downgrade(), + } + } else { + LanguageServerKind::Remote { + project: project.downgrade(), + } + }; + match event { + crate::Event::LanguageServerAdded(id, name, worktree_id) => { + 
log_store.add_language_server( + server_kind, + *id, + Some(name.clone()), + *worktree_id, + project + .read(cx) + .lsp_store() + .read(cx) + .language_server_for_id(*id), + cx, + ); + } + crate::Event::LanguageServerBufferRegistered { + server_id, + buffer_id, + name, + .. + } => { + let worktree_id = project + .read(cx) + .buffer_for_id(*buffer_id, cx) + .and_then(|buffer| { + Some(buffer.read(cx).project_path(cx)?.worktree_id) + }); + let name = name.clone().or_else(|| { + project + .read(cx) + .lsp_store() + .read(cx) + .language_server_statuses + .get(server_id) + .map(|status| status.name.clone()) + }); + log_store.add_language_server( + server_kind, + *server_id, + name, + worktree_id, + None, + cx, + ); + } + crate::Event::LanguageServerRemoved(id) => { + log_store.remove_language_server(*id, cx); + } + crate::Event::LanguageServerLog(id, typ, message) => { + log_store.add_language_server( + server_kind, + *id, + None, + None, + None, + cx, + ); + match typ { + crate::LanguageServerLogType::Log(typ) => { + log_store.add_language_server_log(*id, *typ, message, cx); + } + crate::LanguageServerLogType::Trace { verbose_info } => { + log_store.add_language_server_trace( + *id, + message, + verbose_info.clone(), + cx, + ); + } + crate::LanguageServerLogType::Rpc { received } => { + let kind = if *received { + MessageKind::Receive + } else { + MessageKind::Send + }; + log_store.add_language_server_rpc(*id, kind, message, cx); + } + } + } + crate::Event::ToggleLspLogs { + server_id, + enabled, + toggled_log_kind, + } => { + if let Some(server_state) = + log_store.get_language_server_state(*server_id) + { + if *enabled { + server_state.toggled_log_kind = Some(*toggled_log_kind); + } else { + server_state.toggled_log_kind = None; + } + } + if LogKind::Rpc == *toggled_log_kind { + if *enabled { + log_store.enable_rpc_trace_for_language_server(*server_id); + } else { + log_store.disable_rpc_trace_for_language_server(*server_id); + } + } + } + _ => {} + } + }), + ], + 
}, + ); + } + + pub fn get_language_server_state( + &mut self, + id: LanguageServerId, + ) -> Option<&mut LanguageServerState> { + self.language_servers.get_mut(&id) + } + + pub fn add_language_server( + &mut self, + kind: LanguageServerKind, + server_id: LanguageServerId, + name: Option, + worktree_id: Option, + server: Option>, + cx: &mut Context, + ) -> Option<&mut LanguageServerState> { + let server_state = self.language_servers.entry(server_id).or_insert_with(|| { + cx.notify(); + LanguageServerState { + name: None, + worktree_id: None, + kind, + rpc_state: None, + log_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), + trace_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), + trace_level: TraceValue::Off, + log_level: MessageType::LOG, + io_logs_subscription: None, + toggled_log_kind: None, + } + }); + + if let Some(name) = name { + server_state.name = Some(name); + } + if let Some(worktree_id) = worktree_id { + server_state.worktree_id = Some(worktree_id); + } + + if let Some(server) = server.filter(|_| server_state.io_logs_subscription.is_none()) { + let io_tx = self.io_tx.clone(); + let server_id = server.server_id(); + server_state.io_logs_subscription = Some(server.on_io(move |io_kind, message| { + io_tx + .unbounded_send((server_id, io_kind, message.to_string())) + .ok(); + })); + } + + Some(server_state) + } + + pub fn add_language_server_log( + &mut self, + id: LanguageServerId, + typ: MessageType, + message: &str, + cx: &mut Context, + ) -> Option<()> { + let store_logs = !self.on_headless_host; + let language_server_state = self.get_language_server_state(id)?; + + let log_lines = &mut language_server_state.log_messages; + let message = message.trim_end().to_string(); + if !store_logs { + // Send all messages regardless of the visibility in case of not storing, to notify the receiver anyway + self.emit_event( + Event::NewServerLogEntry { + id, + kind: LanguageServerLogType::Log(typ), + text: message, + }, + cx, + ); + } else if let 
Some(new_message) = Self::push_new_message( + log_lines, + LogMessage { message, typ }, + language_server_state.log_level, + ) { + self.emit_event( + Event::NewServerLogEntry { + id, + kind: LanguageServerLogType::Log(typ), + text: new_message, + }, + cx, + ); + } + Some(()) + } + + fn add_language_server_trace( + &mut self, + id: LanguageServerId, + message: &str, + verbose_info: Option, + cx: &mut Context, + ) -> Option<()> { + let store_logs = !self.on_headless_host; + let language_server_state = self.get_language_server_state(id)?; + + let log_lines = &mut language_server_state.trace_messages; + if !store_logs { + // Send all messages regardless of the visibility in case of not storing, to notify the receiver anyway + self.emit_event( + Event::NewServerLogEntry { + id, + kind: LanguageServerLogType::Trace { verbose_info }, + text: message.trim().to_string(), + }, + cx, + ); + } else if let Some(new_message) = Self::push_new_message( + log_lines, + TraceMessage { + message: message.trim().to_string(), + is_verbose: false, + }, + TraceValue::Messages, + ) { + if let Some(verbose_message) = verbose_info.as_ref() { + Self::push_new_message( + log_lines, + TraceMessage { + message: verbose_message.clone(), + is_verbose: true, + }, + TraceValue::Verbose, + ); + } + self.emit_event( + Event::NewServerLogEntry { + id, + kind: LanguageServerLogType::Trace { verbose_info }, + text: new_message, + }, + cx, + ); + } + Some(()) + } + + fn push_new_message( + log_lines: &mut VecDeque, + message: T, + current_severity: ::Level, + ) -> Option { + while log_lines.len() + 1 >= MAX_STORED_LOG_ENTRIES { + log_lines.pop_front(); + } + let visible = message.should_include(current_severity); + + let visible_message = visible.then(|| message.as_ref().to_string()); + log_lines.push_back(message); + visible_message + } + + fn add_language_server_rpc( + &mut self, + language_server_id: LanguageServerId, + kind: MessageKind, + message: &str, + cx: &mut Context<'_, Self>, + ) { + let 
store_logs = !self.on_headless_host; + let Some(state) = self + .get_language_server_state(language_server_id) + .and_then(|state| state.rpc_state.as_mut()) + else { + return; + }; + + let received = kind == MessageKind::Receive; + let rpc_log_lines = &mut state.rpc_messages; + if state.last_message_kind != Some(kind) { + while rpc_log_lines.len() + 1 >= MAX_STORED_LOG_ENTRIES { + rpc_log_lines.pop_front(); + } + let line_before_message = match kind { + MessageKind::Send => SEND_LINE, + MessageKind::Receive => RECEIVE_LINE, + }; + if store_logs { + rpc_log_lines.push_back(RpcMessage { + message: line_before_message.to_string(), + }); + } + // Do not send a synthetic message over the wire, it will be derived from the actual RPC message + cx.emit(Event::NewServerLogEntry { + id: language_server_id, + kind: LanguageServerLogType::Rpc { received }, + text: line_before_message.to_string(), + }); + } + + while rpc_log_lines.len() + 1 >= MAX_STORED_LOG_ENTRIES { + rpc_log_lines.pop_front(); + } + + if store_logs { + rpc_log_lines.push_back(RpcMessage { + message: message.trim().to_owned(), + }); + } + + self.emit_event( + Event::NewServerLogEntry { + id: language_server_id, + kind: LanguageServerLogType::Rpc { received }, + text: message.to_owned(), + }, + cx, + ); + } + + pub fn remove_language_server(&mut self, id: LanguageServerId, cx: &mut Context) { + self.language_servers.remove(&id); + cx.notify(); + } + + pub fn server_logs(&self, server_id: LanguageServerId) -> Option<&VecDeque> { + Some(&self.language_servers.get(&server_id)?.log_messages) + } + + pub fn server_trace(&self, server_id: LanguageServerId) -> Option<&VecDeque> { + Some(&self.language_servers.get(&server_id)?.trace_messages) + } + + pub fn server_ids_for_project<'a>( + &'a self, + lookup_project: &'a WeakEntity, + ) -> impl Iterator + 'a { + self.language_servers + .iter() + .filter_map(move |(id, state)| match &state.kind { + LanguageServerKind::Local { project } | LanguageServerKind::Remote { 
project } => { + if project == lookup_project { + Some(*id) + } else { + None + } + } + LanguageServerKind::Global | LanguageServerKind::LocalSsh { .. } => Some(*id), + }) + } + + pub fn enable_rpc_trace_for_language_server( + &mut self, + server_id: LanguageServerId, + ) -> Option<&mut LanguageServerRpcState> { + let rpc_state = self + .language_servers + .get_mut(&server_id)? + .rpc_state + .get_or_insert_with(|| LanguageServerRpcState { + rpc_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), + last_message_kind: None, + }); + Some(rpc_state) + } + + pub fn disable_rpc_trace_for_language_server( + &mut self, + server_id: LanguageServerId, + ) -> Option<()> { + self.language_servers.get_mut(&server_id)?.rpc_state.take(); + Some(()) + } + + pub fn has_server_logs(&self, server: &LanguageServerSelector) -> bool { + match server { + LanguageServerSelector::Id(id) => self.language_servers.contains_key(id), + LanguageServerSelector::Name(name) => self + .language_servers + .iter() + .any(|(_, state)| state.name.as_ref() == Some(name)), + } + } + + fn on_io( + &mut self, + language_server_id: LanguageServerId, + io_kind: IoKind, + message: &str, + cx: &mut Context, + ) -> Option<()> { + let is_received = match io_kind { + IoKind::StdOut => true, + IoKind::StdIn => false, + IoKind::StdErr => { + self.add_language_server_log(language_server_id, MessageType::LOG, message, cx); + return Some(()); + } + }; + + let kind = if is_received { + MessageKind::Receive + } else { + MessageKind::Send + }; + + self.add_language_server_rpc(language_server_id, kind, message, cx); + cx.notify(); + Some(()) + } + + fn emit_event(&mut self, e: Event, cx: &mut Context) { + let on_headless_host = self.on_headless_host; + match &e { + Event::NewServerLogEntry { id, kind, text } => { + if let Some(state) = self.get_language_server_state(*id) { + let downstream_client = match &state.kind { + LanguageServerKind::Remote { project } + | LanguageServerKind::Local { project } => project + 
.upgrade() + .map(|project| project.read(cx).lsp_store()), + LanguageServerKind::LocalSsh { lsp_store } => lsp_store.upgrade(), + LanguageServerKind::Global => None, + } + .and_then(|lsp_store| lsp_store.read(cx).downstream_client()); + if let Some((client, project_id)) = downstream_client { + if on_headless_host + || Some(LogKind::from_server_log_type(kind)) == state.toggled_log_kind + { + client + .send(proto::LanguageServerLog { + project_id, + language_server_id: id.to_proto(), + message: text.clone(), + log_type: Some(kind.to_proto()), + }) + .ok(); + } + } + } + } + } + + cx.emit(e); + } +} diff --git a/crates/project/src/lsp_store/lsp_ext_command.rs b/crates/project/src/lsp_store/lsp_ext_command.rs index cb13fa5efcfd753e0ffb12fbcc0f3d84e09ff370..0263946b25ed58969a3a7a98a9f537ce81d86ab1 100644 --- a/crates/project/src/lsp_store/lsp_ext_command.rs +++ b/crates/project/src/lsp_store/lsp_ext_command.rs @@ -115,14 +115,14 @@ impl LspCommand for ExpandMacro { message: Self::ProtoRequest, _: Entity, buffer: Entity, - mut cx: AsyncApp, + cx: AsyncApp, ) -> anyhow::Result { let position = message .position .and_then(deserialize_anchor) .context("invalid position")?; Ok(Self { - position: buffer.read_with(&mut cx, |buffer, _| position.to_point_utf16(buffer))?, + position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer))?, }) } @@ -213,7 +213,7 @@ impl LspCommand for OpenDocs { ) -> Result { Ok(OpenDocsParams { text_document: lsp::TextDocumentIdentifier { - uri: lsp::Url::from_file_path(path).unwrap(), + uri: lsp::Uri::from_file_path(path).unwrap(), }, position: point_to_lsp(self.position), }) @@ -249,14 +249,14 @@ impl LspCommand for OpenDocs { message: Self::ProtoRequest, _: Entity, buffer: Entity, - mut cx: AsyncApp, + cx: AsyncApp, ) -> anyhow::Result { let position = message .position .and_then(deserialize_anchor) .context("invalid position")?; Ok(Self { - position: buffer.read_with(&mut cx, |buffer, _| position.to_point_utf16(buffer))?, + 
position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer))?, }) } @@ -462,14 +462,14 @@ impl LspCommand for GoToParentModule { request: Self::ProtoRequest, _: Entity, buffer: Entity, - mut cx: AsyncApp, + cx: AsyncApp, ) -> anyhow::Result { let position = request .position .and_then(deserialize_anchor) .context("bad request with bad position")?; Ok(Self { - position: buffer.read_with(&mut cx, |buffer, _| position.to_point_utf16(buffer))?, + position: buffer.read_with(&cx, |buffer, _| position.to_point_utf16(buffer))?, }) } diff --git a/crates/project/src/lsp_store/rust_analyzer_ext.rs b/crates/project/src/lsp_store/rust_analyzer_ext.rs index 6c425717a82e94985c60db8d1034d470f1aeec35..54f63220b1ef8bab1db22a0808fd2ccb9277b73c 100644 --- a/crates/project/src/lsp_store/rust_analyzer_ext.rs +++ b/crates/project/src/lsp_store/rust_analyzer_ext.rs @@ -1,8 +1,8 @@ use ::serde::{Deserialize, Serialize}; use anyhow::Context as _; -use gpui::{App, Entity, Task, WeakEntity}; -use language::ServerHealth; -use lsp::{LanguageServer, LanguageServerName}; +use gpui::{App, AsyncApp, Entity, Task, WeakEntity}; +use language::{Buffer, ServerHealth}; +use lsp::{LanguageServer, LanguageServerId, LanguageServerName}; use rpc::proto; use crate::{LspStore, LspStoreEvent, Project, ProjectPath, lsp_store}; @@ -34,7 +34,6 @@ pub fn register_notifications(lsp_store: WeakEntity, language_server: language_server .on_notification::({ - let name = name.clone(); move |params, cx| { let message = params.message; let log_message = message.as_ref().map(|message| { @@ -84,31 +83,32 @@ pub fn register_notifications(lsp_store: WeakEntity, language_server: pub fn cancel_flycheck( project: Entity, - buffer_path: ProjectPath, + buffer_path: Option, cx: &mut App, ) -> Task> { let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); let lsp_store = project.read(cx).lsp_store(); - let buffer = project.update(cx, |project, cx| { - project.buffer_store().update(cx, 
|buffer_store, cx| { - buffer_store.open_buffer(buffer_path, cx) + let buffer = buffer_path.map(|buffer_path| { + project.update(cx, |project, cx| { + project.buffer_store().update(cx, |buffer_store, cx| { + buffer_store.open_buffer(buffer_path, cx) + }) }) }); cx.spawn(async move |cx| { - let buffer = buffer.await?; - let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { - project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) - })? + let buffer = match buffer { + Some(buffer) => Some(buffer.await?), + None => None, + }; + let Some(rust_analyzer_server) = find_rust_analyzer_server(&project, buffer.as_ref(), cx) else { return Ok(()); }; - let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id().to_proto())?; if let Some((client, project_id)) = upstream_client { let request = proto::LspExtCancelFlycheck { project_id, - buffer_id, language_server_id: rust_analyzer_server.to_proto(), }; client @@ -131,28 +131,33 @@ pub fn cancel_flycheck( pub fn run_flycheck( project: Entity, - buffer_path: ProjectPath, + buffer_path: Option, cx: &mut App, ) -> Task> { let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); let lsp_store = project.read(cx).lsp_store(); - let buffer = project.update(cx, |project, cx| { - project.buffer_store().update(cx, |buffer_store, cx| { - buffer_store.open_buffer(buffer_path, cx) + let buffer = buffer_path.map(|buffer_path| { + project.update(cx, |project, cx| { + project.buffer_store().update(cx, |buffer_store, cx| { + buffer_store.open_buffer(buffer_path, cx) + }) }) }); cx.spawn(async move |cx| { - let buffer = buffer.await?; - let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { - project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) - })? 
+ let buffer = match buffer { + Some(buffer) => Some(buffer.await?), + None => None, + }; + let Some(rust_analyzer_server) = find_rust_analyzer_server(&project, buffer.as_ref(), cx) else { return Ok(()); }; - let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id().to_proto())?; if let Some((client, project_id)) = upstream_client { + let buffer_id = buffer + .map(|buffer| buffer.read_with(cx, |buffer, _| buffer.remote_id().to_proto())) + .transpose()?; let request = proto::LspExtRunFlycheck { project_id, buffer_id, @@ -183,31 +188,32 @@ pub fn run_flycheck( pub fn clear_flycheck( project: Entity, - buffer_path: ProjectPath, + buffer_path: Option, cx: &mut App, ) -> Task> { let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); let lsp_store = project.read(cx).lsp_store(); - let buffer = project.update(cx, |project, cx| { - project.buffer_store().update(cx, |buffer_store, cx| { - buffer_store.open_buffer(buffer_path, cx) + let buffer = buffer_path.map(|buffer_path| { + project.update(cx, |project, cx| { + project.buffer_store().update(cx, |buffer_store, cx| { + buffer_store.open_buffer(buffer_path, cx) + }) }) }); cx.spawn(async move |cx| { - let buffer = buffer.await?; - let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { - project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) - })? 
+ let buffer = match buffer { + Some(buffer) => Some(buffer.await?), + None => None, + }; + let Some(rust_analyzer_server) = find_rust_analyzer_server(&project, buffer.as_ref(), cx) else { return Ok(()); }; - let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id().to_proto())?; if let Some((client, project_id)) = upstream_client { let request = proto::LspExtClearFlycheck { project_id, - buffer_id, language_server_id: rust_analyzer_server.to_proto(), }; client @@ -227,3 +233,40 @@ pub fn clear_flycheck( anyhow::Ok(()) }) } + +fn find_rust_analyzer_server( + project: &Entity, + buffer: Option<&Entity>, + cx: &mut AsyncApp, +) -> Option { + project + .read_with(cx, |project, cx| { + buffer + .and_then(|buffer| { + project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) + }) + // If no rust-analyzer found for the current buffer (e.g. `settings.json`), fall back to the project lookup + // and use project's rust-analyzer if it's the only one. + .or_else(|| { + let rust_analyzer_servers = project + .lsp_store() + .read(cx) + .language_server_statuses + .iter() + .filter_map(|(server_id, server_status)| { + if server_status.name == RUST_ANALYZER_NAME { + Some(*server_id) + } else { + None + } + }) + .collect::>(); + if rust_analyzer_servers.len() == 1 { + rust_analyzer_servers.first().copied() + } else { + None + } + }) + }) + .ok()? 
+} diff --git a/crates/project/src/manifest_tree.rs b/crates/project/src/manifest_tree.rs index 7266acb5b4a29b68d8863feb760334de46260424..5a3c7bd40fb11ee5bebe340ddc57ec71a112270b 100644 --- a/crates/project/src/manifest_tree.rs +++ b/crates/project/src/manifest_tree.rs @@ -7,18 +7,12 @@ mod manifest_store; mod path_trie; mod server_tree; -use std::{ - borrow::Borrow, - collections::{BTreeMap, hash_map::Entry}, - ops::ControlFlow, - path::Path, - sync::Arc, -}; +use std::{borrow::Borrow, collections::hash_map::Entry, ops::ControlFlow, path::Path, sync::Arc}; use collections::HashMap; -use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription}; +use gpui::{App, AppContext as _, Context, Entity, Subscription}; use language::{ManifestDelegate, ManifestName, ManifestQuery}; -pub use manifest_store::ManifestProviders; +pub use manifest_store::ManifestProvidersStore; use path_trie::{LabelPresence, RootPathTrie, TriePath}; use settings::{SettingsStore, WorktreeId}; use worktree::{Event as WorktreeEvent, Snapshot, Worktree}; @@ -28,9 +22,7 @@ use crate::{ worktree_store::{WorktreeStore, WorktreeStoreEvent}, }; -pub(crate) use server_tree::{ - AdapterQuery, LanguageServerTree, LanguageServerTreeNode, LaunchDisposition, -}; +pub(crate) use server_tree::{LanguageServerTree, LanguageServerTreeNode, LaunchDisposition}; struct WorktreeRoots { roots: RootPathTrie, @@ -51,12 +43,9 @@ impl WorktreeRoots { match event { WorktreeEvent::UpdatedEntries(changes) => { for (path, _, kind) in changes.iter() { - match kind { - worktree::PathChange::Removed => { - let path = TriePath::from(path.as_ref()); - this.roots.remove(&path); - } - _ => {} + if kind == &worktree::PathChange::Removed { + let path = TriePath::from(path.as_ref()); + this.roots.remove(&path); } } } @@ -81,14 +70,6 @@ pub struct ManifestTree { _subscriptions: [Subscription; 2], } -#[derive(PartialEq)] -pub(crate) enum ManifestTreeEvent { - WorktreeRemoved(WorktreeId), - Cleared, -} - -impl EventEmitter 
for ManifestTree {} - impl ManifestTree { pub fn new(worktree_store: Entity, cx: &mut App) -> Entity { cx.new(|cx| Self { @@ -96,35 +77,33 @@ impl ManifestTree { _subscriptions: [ cx.subscribe(&worktree_store, Self::on_worktree_store_event), cx.observe_global::(|this, cx| { - for (_, roots) in &mut this.root_points { + for roots in this.root_points.values_mut() { roots.update(cx, |worktree_roots, _| { worktree_roots.roots = RootPathTrie::new(); }) } - cx.emit(ManifestTreeEvent::Cleared); }), ], worktree_store, }) } + pub(crate) fn root_for_path( &mut self, - ProjectPath { worktree_id, path }: ProjectPath, - manifests: &mut dyn Iterator, - delegate: Arc, + ProjectPath { worktree_id, path }: &ProjectPath, + manifest_name: &ManifestName, + delegate: &Arc, cx: &mut App, - ) -> BTreeMap { - debug_assert_eq!(delegate.worktree_id(), worktree_id); - let mut roots = BTreeMap::from_iter( - manifests.map(|manifest| (manifest, (None, LabelPresence::KnownAbsent))), - ); - let worktree_roots = match self.root_points.entry(worktree_id) { + ) -> Option { + debug_assert_eq!(delegate.worktree_id(), *worktree_id); + let (mut marked_path, mut current_presence) = (None, LabelPresence::KnownAbsent); + let worktree_roots = match self.root_points.entry(*worktree_id) { Entry::Occupied(occupied_entry) => occupied_entry.get().clone(), Entry::Vacant(vacant_entry) => { let Some(worktree) = self .worktree_store .read(cx) - .worktree_for_id(worktree_id, cx) + .worktree_for_id(*worktree_id, cx) else { return Default::default(); }; @@ -133,16 +112,16 @@ impl ManifestTree { } }; - let key = TriePath::from(&*path); + let key = TriePath::from(&**path); worktree_roots.read_with(cx, |this, _| { this.roots.walk(&key, &mut |path, labels| { for (label, presence) in labels { - if let Some((marked_path, current_presence)) = roots.get_mut(label) { - if *current_presence > *presence { + if label == manifest_name { + if current_presence > *presence { debug_assert!(false, "RootPathTrie precondition violation; 
while walking the tree label presence is only allowed to increase"); } - *marked_path = Some(ProjectPath {worktree_id, path: path.clone()}); - *current_presence = *presence; + marked_path = Some(ProjectPath {worktree_id: *worktree_id, path: path.clone()}); + current_presence = *presence; } } @@ -150,12 +129,9 @@ impl ManifestTree { }); }); - for (manifest_name, (root_path, presence)) in &mut roots { - if *presence == LabelPresence::Present { - continue; - } - - let depth = root_path + if current_presence == LabelPresence::KnownAbsent { + // Some part of the path is unexplored. + let depth = marked_path .as_ref() .map(|root_path| { path.strip_prefix(&root_path.path) @@ -165,13 +141,10 @@ impl ManifestTree { }) .unwrap_or_else(|| path.components().count() + 1); - if depth > 0 { - let Some(provider) = ManifestProviders::global(cx).get(manifest_name.borrow()) - else { - log::warn!("Manifest provider `{}` not found", manifest_name.as_ref()); - continue; - }; - + if depth > 0 + && let Some(provider) = + ManifestProvidersStore::global(cx).get(manifest_name.borrow()) + { let root = provider.search(ManifestQuery { path: path.clone(), depth, @@ -182,9 +155,9 @@ impl ManifestTree { let root = TriePath::from(&*known_root); this.roots .insert(&root, manifest_name.clone(), LabelPresence::Present); - *presence = LabelPresence::Present; - *root_path = Some(ProjectPath { - worktree_id, + current_presence = LabelPresence::Present; + marked_path = Some(ProjectPath { + worktree_id: *worktree_id, path: known_root, }); }), @@ -195,27 +168,34 @@ impl ManifestTree { } } } + marked_path.filter(|_| current_presence.eq(&LabelPresence::Present)) + } - roots - .into_iter() - .filter_map(|(k, (path, presence))| { - let path = path?; - presence.eq(&LabelPresence::Present).then(|| (k, path)) + pub(crate) fn root_for_path_or_worktree_root( + &mut self, + project_path: &ProjectPath, + manifest_name: Option<&ManifestName>, + delegate: &Arc, + cx: &mut App, + ) -> ProjectPath { + let worktree_id = 
project_path.worktree_id; + // Backwards-compat: Fill in any adapters for which we did not detect the root as having the project root at the root of a worktree. + manifest_name + .and_then(|manifest_name| self.root_for_path(project_path, manifest_name, delegate, cx)) + .unwrap_or_else(|| ProjectPath { + worktree_id, + path: Arc::from(Path::new("")), }) - .collect() } + fn on_worktree_store_event( &mut self, _: Entity, evt: &WorktreeStoreEvent, - cx: &mut Context, + _: &mut Context, ) { - match evt { - WorktreeStoreEvent::WorktreeRemoved(_, worktree_id) => { - self.root_points.remove(&worktree_id); - cx.emit(ManifestTreeEvent::WorktreeRemoved(*worktree_id)); - } - _ => {} + if let WorktreeStoreEvent::WorktreeRemoved(_, worktree_id) = evt { + self.root_points.remove(worktree_id); } } } @@ -223,6 +203,7 @@ impl ManifestTree { pub(crate) struct ManifestQueryDelegate { worktree: Snapshot, } + impl ManifestQueryDelegate { pub fn new(worktree: Snapshot) -> Self { Self { worktree } @@ -231,10 +212,8 @@ impl ManifestQueryDelegate { impl ManifestDelegate for ManifestQueryDelegate { fn exists(&self, path: &Path, is_dir: Option) -> bool { - self.worktree.entry_for_path(path).map_or(false, |entry| { - is_dir.map_or(true, |is_required_to_be_dir| { - is_required_to_be_dir == entry.is_dir() - }) + self.worktree.entry_for_path(path).is_some_and(|entry| { + is_dir.is_none_or(|is_required_to_be_dir| is_required_to_be_dir == entry.is_dir()) }) } diff --git a/crates/project/src/manifest_tree/manifest_store.rs b/crates/project/src/manifest_tree/manifest_store.rs index 0462b257985c6ec554519c565f1e935853654e59..cf9f81aee470646d5800ca4a1a4ed7aff4cbd03d 100644 --- a/crates/project/src/manifest_tree/manifest_store.rs +++ b/crates/project/src/manifest_tree/manifest_store.rs @@ -1,4 +1,4 @@ -use collections::HashMap; +use collections::{HashMap, HashSet}; use gpui::{App, Global, SharedString}; use parking_lot::RwLock; use std::{ops::Deref, sync::Arc}; @@ -11,13 +11,13 @@ struct 
ManifestProvidersState { } #[derive(Clone, Default)] -pub struct ManifestProviders(Arc>); +pub struct ManifestProvidersStore(Arc>); #[derive(Default)] -struct GlobalManifestProvider(ManifestProviders); +struct GlobalManifestProvider(ManifestProvidersStore); impl Deref for GlobalManifestProvider { - type Target = ManifestProviders; + type Target = ManifestProvidersStore; fn deref(&self) -> &Self::Target { &self.0 @@ -26,7 +26,7 @@ impl Deref for GlobalManifestProvider { impl Global for GlobalManifestProvider {} -impl ManifestProviders { +impl ManifestProvidersStore { /// Returns the global [`ManifestStore`]. /// /// Inserts a default [`ManifestStore`] if one does not yet exist. @@ -45,4 +45,7 @@ impl ManifestProviders { pub(super) fn get(&self, name: &SharedString) -> Option> { self.0.read().providers.get(name).cloned() } + pub(crate) fn manifest_file_names(&self) -> HashSet { + self.0.read().providers.keys().cloned().collect() + } } diff --git a/crates/project/src/manifest_tree/path_trie.rs b/crates/project/src/manifest_tree/path_trie.rs index 1a0736765a43b9e1365334de95eacbe9dbf64382..9cebfda25c69fa35b06cefe9ec744b5e6152a820 100644 --- a/crates/project/src/manifest_tree/path_trie.rs +++ b/crates/project/src/manifest_tree/path_trie.rs @@ -22,9 +22,9 @@ pub(super) struct RootPathTrie
-

Crate serde

- - source · - -