Allow AI interactions to be proxied through Zed's server so you don't need an API key (#7367)

Resurrected this from some assistant work I did in Spring of 2023.
- [x] Resurrect streaming responses
- [x] Use streaming responses to enable AI via Zed's servers by default
(but preserve the API key option for now; see the provider sketch after this list)
- [x] Simplify protobuf
- [x] Proxy to OpenAI on zed.dev
- [x] Proxy to Gemini on zed.dev
- [x] Improve UX for switching between OpenAI and Google models
  - We currently disallow cycling when setting a custom model, but we need
a better solution to keep OpenAI models available while testing the Google
ones
- [x] Show remaining tokens correctly for Google models
- [x] Remove semantic index
- [x] Delete `ai` crate
- [x] CloudFront so we can ban abusers
- [x] Rate-limiting (see the token-bucket sketch after this list)
- [x] Fix panic when using inline assistant
- [x] Double-check that the upgraded `AssistantSettings` are
backwards-compatible
- [x] Add hosted LLM interaction behind a `language-models` feature flag.
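
The checklist hinges on one idea: the assistant talks to a provider abstraction, and the hosted provider streams completion chunks over the existing zed.dev connection instead of hitting OpenAI directly with a user key. Below is a minimal sketch of that shape using hypothetical names; the real types live in `crates/assistant/src/completion_provider/`:

```rust
use anyhow::Result;
use futures::{stream, Stream, StreamExt};

// Hypothetical provider enum for illustration; the real types live in
// crates/assistant/src/completion_provider/{zed,open_ai}.rs.
#[allow(dead_code)]
enum CompletionProvider {
    // Proxied through Zed's server: the RPC connection already
    // authenticates the user, so the client needs no API key.
    ZedDotDev,
    // Direct connection to the OpenAI API with a user-supplied key.
    OpenAi { api_key: String },
}

impl CompletionProvider {
    // Both providers yield the same stream-of-text-chunks shape, so the
    // assistant UI doesn't care where a completion comes from.
    fn complete(&self, _prompt: &str) -> impl Stream<Item = Result<String>> {
        // Stubbed: a real implementation would open an SSE or RPC stream
        // and forward chunks as they arrive from the upstream model.
        stream::iter(vec![Ok("Hello, ".to_string()), Ok("world!".to_string())])
    }
}

#[tokio::main] // assumes tokio with the "macros" and "rt" features
async fn main() -> Result<()> {
    let provider = CompletionProvider::ZedDotDev;
    let mut chunks = Box::pin(provider.complete("Say hello"));
    while let Some(chunk) = chunks.next().await {
        print!("{}", chunk?); // render each chunk as it streams in
    }
    Ok(())
}
```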
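
Rate limiting lands in `crates/collab/src/rate_limiter.rs`, backed by a new `rate_buckets` table (see the migration in the change summary), which suggests a bucket-based scheme. Here is a token-bucket sketch of the general technique; the field names are assumptions, not the real schema, and the real limiter evidently persists its buckets to the database rather than keeping them only in memory:

```rust
use std::time::Instant;

// Illustrative token bucket; these fields are assumptions, not the
// schema of the real `rate_buckets` table.
struct RateBucket {
    capacity: f64,       // maximum burst size
    tokens: f64,         // tokens currently available
    refill_per_sec: f64, // sustained allowed request rate
    last_refill: Instant,
}

impl RateBucket {
    fn new(capacity: f64, refill_per_sec: f64) -> Self {
        Self {
            capacity,
            tokens: capacity,
            refill_per_sec,
            last_refill: Instant::now(),
        }
    }

    // Allow a request if a token is available, first refilling the bucket
    // in proportion to the time elapsed since the last check.
    fn allow(&mut self) -> bool {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_per_sec).min(self.capacity);
        self.last_refill = now;
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}

fn main() {
    // Allow bursts of 5 requests, refilling at 1 request per second.
    let mut bucket = RateBucket::new(5.0, 1.0);
    for i in 0..7 {
        println!("request {i}: allowed = {}", bucket.allow());
    }
}
```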

Release Notes:

- We are temporarily removing the semantic index in order to redesign it
from scratch.

---------

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Thorsten <thorsten@zed.dev>
Co-authored-by: Max <max@zed.dev>

Change summary

Cargo.lock                                                     |  234 
Cargo.toml                                                     |    7 
assets/keymaps/default-macos.json                              |    3 
assets/settings/default.json                                   |   32 
crates/ai/Cargo.toml                                           |   41 
crates/ai/LICENSE-GPL                                          |    1 
crates/ai/src/ai.rs                                            |    8 
crates/ai/src/auth.rs                                          |   23 
crates/ai/src/completion.rs                                    |   23 
crates/ai/src/embedding.rs                                     |  121 
crates/ai/src/models.rs                                        |   16 
crates/ai/src/prompts/base.rs                                  |  337 
crates/ai/src/prompts/file_context.rs                          |  164 
crates/ai/src/prompts/generate.rs                              |   99 
crates/ai/src/prompts/mod.rs                                   |    5 
crates/ai/src/prompts/preamble.rs                              |   52 
crates/ai/src/prompts/repository_context.rs                    |   96 
crates/ai/src/providers.rs                                     |    1 
crates/ai/src/providers/open_ai.rs                             |    9 
crates/ai/src/providers/open_ai/completion.rs                  |  421 
crates/ai/src/providers/open_ai/embedding.rs                   |  345 
crates/ai/src/providers/open_ai/model.rs                       |   59 
crates/ai/src/test.rs                                          |  206 
crates/assistant/Cargo.toml                                    |    9 
crates/assistant/src/assistant.rs                              |  222 
crates/assistant/src/assistant_panel.rs                        |  809 -
crates/assistant/src/assistant_settings.rs                     |  567 
crates/assistant/src/codegen.rs                                |   92 
crates/assistant/src/completion_provider.rs                    |  188 
crates/assistant/src/completion_provider/fake.rs               |   29 
crates/assistant/src/completion_provider/open_ai.rs            |  301 
crates/assistant/src/completion_provider/zed.rs                |  167 
crates/assistant/src/prompts.rs                                |  443 
crates/assistant/src/saved_conversation.rs                     |  121 
crates/assistant/src/streaming_diff.rs                         |    8 
crates/client/src/client.rs                                    |   35 
crates/client/src/telemetry.rs                                 |    2 
crates/collab/Cargo.toml                                       |    3 
crates/collab/migrations.sqlite/20221109000000_test_schema.sql |   10 
crates/collab/migrations/20240220234826_add_rate_buckets.sql   |   11 
crates/collab/src/ai.rs                                        |   75 
crates/collab/src/api/extensions.rs                            |    4 
crates/collab/src/db/queries.rs                                |    1 
crates/collab/src/db/queries/rate_buckets.rs                   |   58 
crates/collab/src/db/tables.rs                                 |    1 
crates/collab/src/db/tables/rate_buckets.rs                    |   31 
crates/collab/src/lib.rs                                       |   14 
crates/collab/src/main.rs                                      |   19 
crates/collab/src/rate_limiter.rs                              |  274 
crates/collab/src/rpc.rs                                       |  313 
crates/collab/src/tests/test_server.rs                         |   20 
crates/google_ai/Cargo.toml                                    |   14 
crates/google_ai/src/google_ai.rs                              |  266 
crates/open_ai/Cargo.toml                                      |   19 
crates/open_ai/src/open_ai.rs                                  |  182 
crates/rpc/proto/zed.proto                                     |   52 
crates/rpc/src/error.rs                                        |    2 
crates/rpc/src/peer.rs                                         |  142 
crates/rpc/src/proto.rs                                        |    7 
crates/search/Cargo.toml                                       |    1 
crates/search/src/buffer_search.rs                             |    7 
crates/search/src/mode.rs                                      |   16 
crates/search/src/project_search.rs                            |  460 
crates/search/src/search.rs                                    |    1 
crates/semantic_index/Cargo.toml                               |   66 
crates/semantic_index/LICENSE-GPL                              |    1 
crates/semantic_index/README.md                                |   20 
crates/semantic_index/eval/gpt-engineer.json                   |  114 
crates/semantic_index/eval/tree-sitter.json                    |  104 
crates/semantic_index/src/db.rs                                |  594 -
crates/semantic_index/src/embedding_queue.rs                   |  169 
crates/semantic_index/src/parsing.rs                           |  414 
crates/semantic_index/src/semantic_index.rs                    | 1308 ---
crates/semantic_index/src/semantic_index_settings.rs           |   33 
crates/semantic_index/src/semantic_index_tests.rs              | 1725 ----
crates/settings/src/settings_store.rs                          |   23 
crates/util/src/http.rs                                        |    3 
crates/zed/Cargo.toml                                          |    1 
crates/zed/src/main.rs                                         |    3 
crates/zed/src/zed.rs                                          |    2 
docs/src/configuring_zed.md                                    |   22 
script/bootstrap                                               |    5 
script/evaluate_semantic_index                                 |    3 
script/gemini.py                                               |   91 
script/linux                                                   |    4 
script/script.py                                               |    1 
script/sqlx                                                    |    7 
87 files changed, 3,530 insertions(+), 8,482 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -85,32 +85,6 @@ dependencies = [
  "memchr",
 ]
 
-[[package]]
-name = "ai"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "async-trait",
- "bincode",
- "futures 0.3.28",
- "gpui",
- "isahc",
- "language",
- "log",
- "matrixmultiply",
- "ordered-float 2.10.0",
- "parking_lot",
- "parse_duration",
- "postage",
- "rand 0.8.5",
- "rusqlite",
- "schemars",
- "serde",
- "serde_json",
- "tiktoken-rs",
- "util",
-]
-
 [[package]]
 name = "alacritty_terminal"
 version = "0.22.1-dev"
@@ -339,9 +313,9 @@ dependencies = [
 name = "assistant"
 version = "0.1.0"
 dependencies = [
- "ai",
  "anyhow",
  "chrono",
+ "client",
  "collections",
  "ctor",
  "editor",
@@ -354,13 +328,14 @@ dependencies = [
  "log",
  "menu",
  "multi_buffer",
+ "open_ai",
  "ordered-float 2.10.0",
+ "parking_lot",
  "project",
  "rand 0.8.5",
  "regex",
  "schemars",
  "search",
- "semantic_index",
  "serde",
  "serde_json",
  "settings",
@@ -1339,7 +1314,7 @@ version = "0.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a6773ddc0eafc0e509fb60e48dff7f450f8e674a0686ae8605e8d9901bd5eefa"
 dependencies = [
- "num-bigint 0.4.4",
+ "num-bigint",
  "num-integer",
  "num-traits",
 ]
@@ -2209,11 +2184,11 @@ dependencies = [
  "fs",
  "futures 0.3.28",
  "git",
+ "google_ai",
  "gpui",
  "hex",
  "indoc",
  "language",
- "lazy_static",
  "live_kit_client",
  "live_kit_server",
  "log",
@@ -2222,6 +2197,7 @@ dependencies = [
  "nanoid",
  "node_runtime",
  "notifications",
+ "open_ai",
  "parking_lot",
  "pretty_assertions",
  "project",
@@ -3554,24 +3530,12 @@ dependencies = [
  "workspace",
 ]
 
-[[package]]
-name = "fallible-iterator"
-version = "0.2.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
-
 [[package]]
 name = "fallible-iterator"
 version = "0.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649"
 
-[[package]]
-name = "fallible-streaming-iterator"
-version = "0.1.9"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
-
 [[package]]
 name = "fancy-regex"
 version = "0.11.0"
@@ -4183,7 +4147,7 @@ version = "0.28.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0"
 dependencies = [
- "fallible-iterator 0.3.0",
+ "fallible-iterator",
  "indexmap 2.0.0",
  "stable_deref_trait",
 ]
@@ -4279,6 +4243,17 @@ dependencies = [
  "workspace",
 ]
 
+[[package]]
+name = "google_ai"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.28",
+ "serde",
+ "serde_json",
+ "util",
+]
+
 [[package]]
 name = "gpu-alloc"
 version = "0.6.0"
@@ -5667,16 +5642,6 @@ version = "0.7.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
 
-[[package]]
-name = "matrixmultiply"
-version = "0.3.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
-dependencies = [
- "autocfg",
- "rawpointer",
-]
-
 [[package]]
 name = "maybe-owned"
 version = "0.3.4"
@@ -5946,19 +5911,6 @@ dependencies = [
  "tempfile",
 ]
 
-[[package]]
-name = "ndarray"
-version = "0.15.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
-dependencies = [
- "matrixmultiply",
- "num-complex 0.4.4",
- "num-integer",
- "num-traits",
- "rawpointer",
-]
-
 [[package]]
 name = "ndk"
 version = "0.7.0"
@@ -6111,45 +6063,20 @@ dependencies = [
  "winapi",
 ]
 
-[[package]]
-name = "num"
-version = "0.2.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
-dependencies = [
- "num-bigint 0.2.6",
- "num-complex 0.2.4",
- "num-integer",
- "num-iter",
- "num-rational 0.2.4",
- "num-traits",
-]
-
 [[package]]
 name = "num"
 version = "0.4.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
 dependencies = [
- "num-bigint 0.4.4",
- "num-complex 0.4.4",
+ "num-bigint",
+ "num-complex",
  "num-integer",
  "num-iter",
  "num-rational 0.4.1",
  "num-traits",
 ]
 
-[[package]]
-name = "num-bigint"
-version = "0.2.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304"
-dependencies = [
- "autocfg",
- "num-integer",
- "num-traits",
-]
-
 [[package]]
 name = "num-bigint"
 version = "0.4.4"
@@ -6196,16 +6123,6 @@ dependencies = [
  "zeroize",
 ]
 
-[[package]]
-name = "num-complex"
-version = "0.2.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
-dependencies = [
- "autocfg",
- "num-traits",
-]
-
 [[package]]
 name = "num-complex"
 version = "0.4.4"
@@ -6247,18 +6164,6 @@ dependencies = [
  "num-traits",
 ]
 
-[[package]]
-name = "num-rational"
-version = "0.2.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef"
-dependencies = [
- "autocfg",
- "num-bigint 0.2.6",
- "num-integer",
- "num-traits",
-]
-
 [[package]]
 name = "num-rational"
 version = "0.3.2"
@@ -6277,7 +6182,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
 dependencies = [
  "autocfg",
- "num-bigint 0.4.4",
+ "num-bigint",
  "num-integer",
  "num-traits",
 ]
@@ -6436,7 +6341,7 @@ dependencies = [
  "futures-util",
  "hkdf",
  "hmac 0.12.1",
- "num 0.4.1",
+ "num",
  "num-bigint-dig 0.8.4",
  "pbkdf2 0.12.2",
  "rand 0.8.5",
@@ -6464,6 +6369,18 @@ dependencies = [
  "pathdiff",
 ]
 
+[[package]]
+name = "open_ai"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.28",
+ "schemars",
+ "serde",
+ "serde_json",
+ "util",
+]
+
 [[package]]
 name = "openssl"
 version = "0.10.57"
@@ -6679,17 +6596,6 @@ dependencies = [
  "windows-targets 0.48.5",
 ]
 
-[[package]]
-name = "parse_duration"
-version = "2.1.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7037e5e93e0172a5a96874380bf73bc6ecef022e26fa25f2be26864d6b3ba95d"
-dependencies = [
- "lazy_static",
- "num 0.2.1",
- "regex",
-]
-
 [[package]]
 name = "password-hash"
 version = "0.2.1"
@@ -7471,12 +7377,6 @@ dependencies = [
  "raw-window-handle 0.5.2",
 ]
 
-[[package]]
-name = "rawpointer"
-version = "0.2.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
-
 [[package]]
 name = "rayon"
 version = "1.8.0"
@@ -7935,20 +7835,6 @@ dependencies = [
  "zeroize",
 ]
 
-[[package]]
-name = "rusqlite"
-version = "0.29.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2"
-dependencies = [
- "bitflags 2.4.2",
- "fallible-iterator 0.2.0",
- "fallible-streaming-iterator",
- "hashlink",
- "libsqlite3-sys",
- "smallvec",
-]
-
 [[package]]
 name = "rust-embed"
 version = "8.2.0"
@@ -8378,7 +8264,6 @@ dependencies = [
  "language",
  "menu",
  "project",
- "semantic_index",
  "serde",
  "serde_json",
  "settings",
@@ -8434,52 +8319,6 @@ version = "1.0.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "58bf37232d3bb9a2c4e641ca2a11d83b5062066f88df7fed36c28772046d65ba"
 
-[[package]]
-name = "semantic_index"
-version = "0.1.0"
-dependencies = [
- "ai",
- "anyhow",
- "collections",
- "ctor",
- "env_logger",
- "futures 0.3.28",
- "gpui",
- "language",
- "lazy_static",
- "log",
- "ndarray",
- "ordered-float 2.10.0",
- "parking_lot",
- "postage",
- "pretty_assertions",
- "project",
- "rand 0.8.5",
- "release_channel",
- "rpc",
- "rusqlite",
- "schemars",
- "serde",
- "serde_json",
- "settings",
- "sha1",
- "smol",
- "tempfile",
- "tree-sitter",
- "tree-sitter-cpp",
- "tree-sitter-elixir",
- "tree-sitter-json 0.20.0",
- "tree-sitter-lua",
- "tree-sitter-php",
- "tree-sitter-ruby",
- "tree-sitter-rust",
- "tree-sitter-toml",
- "tree-sitter-typescript",
- "unindent",
- "util",
- "workspace",
-]
-
 [[package]]
 name = "semver"
 version = "1.0.18"
@@ -8766,7 +8605,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8eb4ea60fb301dc81dfc113df680571045d375ab7345d171c5dc7d7e13107a80"
 dependencies = [
  "chrono",
- "num-bigint 0.4.4",
+ "num-bigint",
  "num-traits",
  "thiserror",
 ]
@@ -9197,7 +9036,7 @@ dependencies = [
  "log",
  "md-5",
  "memchr",
- "num-bigint 0.4.4",
+ "num-bigint",
  "once_cell",
  "rand 0.8.5",
  "rust_decimal",
@@ -12729,7 +12568,6 @@ dependencies = [
  "release_channel",
  "rope",
  "search",
- "semantic_index",
  "serde",
  "serde_json",
  "settings",

Cargo.toml 🔗

@@ -1,7 +1,6 @@
 [workspace]
 members = [
     "crates/activity_indicator",
-    "crates/ai",
     "crates/assets",
     "crates/assistant",
     "crates/audio",
@@ -34,6 +33,7 @@ members = [
     "crates/fuzzy",
     "crates/git",
     "crates/go_to_line",
+    "crates/google_ai",
     "crates/gpui",
     "crates/gpui_macros",
     "crates/image_viewer",
@@ -52,6 +52,7 @@ members = [
     "crates/multi_buffer",
     "crates/node_runtime",
     "crates/notifications",
+    "crates/open_ai",
     "crates/outline",
     "crates/picker",
     "crates/prettier",
@@ -69,7 +70,6 @@ members = [
     "crates/task",
     "crates/tasks_ui",
     "crates/search",
-    "crates/semantic_index",
     "crates/settings",
     "crates/snippet",
     "crates/sqlez",
@@ -138,6 +138,7 @@ fsevent = { path = "crates/fsevent" }
 fuzzy = { path = "crates/fuzzy" }
 git = { path = "crates/git" }
 go_to_line = { path = "crates/go_to_line" }
+google_ai = { path = "crates/google_ai" }
 gpui = { path = "crates/gpui" }
 gpui_macros = { path = "crates/gpui_macros" }
 install_cli = { path = "crates/install_cli" }
@@ -156,6 +157,7 @@ menu = { path = "crates/menu" }
 multi_buffer = { path = "crates/multi_buffer" }
 node_runtime = { path = "crates/node_runtime" }
 notifications = { path = "crates/notifications" }
+open_ai = { path = "crates/open_ai" }
 outline = { path = "crates/outline" }
 picker = { path = "crates/picker" }
 plugin = { path = "crates/plugin" }
@@ -174,7 +176,6 @@ rpc = { path = "crates/rpc" }
 task = { path = "crates/task" }
 tasks_ui = { path = "crates/tasks_ui" }
 search = { path = "crates/search" }
-semantic_index = { path = "crates/semantic_index" }
 settings = { path = "crates/settings" }
 snippet = { path = "crates/snippet" }
 sqlez = { path = "crates/sqlez" }

assets/keymaps/default-macos.json 🔗

@@ -251,7 +251,6 @@
       "alt-tab": "search::CycleMode",
       "cmd-shift-h": "search::ToggleReplace",
       "alt-cmd-g": "search::ActivateRegexMode",
-      "alt-cmd-s": "search::ActivateSemanticMode",
       "alt-cmd-x": "search::ActivateTextMode"
     }
   },
@@ -276,7 +275,6 @@
       "alt-tab": "search::CycleMode",
       "cmd-shift-h": "search::ToggleReplace",
       "alt-cmd-g": "search::ActivateRegexMode",
-      "alt-cmd-s": "search::ActivateSemanticMode",
       "alt-cmd-x": "search::ActivateTextMode"
     }
   },
@@ -302,7 +300,6 @@
       "alt-tab": "search::CycleMode",
       "alt-cmd-f": "project_search::ToggleFilters",
       "alt-cmd-g": "search::ActivateRegexMode",
-      "alt-cmd-s": "search::ActivateSemanticMode",
       "alt-cmd-x": "search::ActivateTextMode"
     }
   },

assets/settings/default.json 🔗

@@ -237,6 +237,8 @@
     "default_width": 380
   },
   "assistant": {
+    // Version of this setting.
+    "version": "1",
     // Whether to show the assistant panel button in the status bar.
     "button": true,
     // Where to dock the assistant panel. Can be 'left', 'right' or 'bottom'.
@@ -245,28 +247,16 @@
     "default_width": 640,
     // Default height when the assistant is docked to the bottom.
     "default_height": 320,
-    // Deprecated: Please use `provider.api_url` instead.
-    // The default OpenAI API endpoint to use when starting new conversations.
-    "openai_api_url": "https://api.openai.com/v1",
-    // Deprecated: Please use `provider.default_model` instead.
-    // The default OpenAI model to use when starting new conversations. This
-    // setting can take three values:
-    //
-    // 1. "gpt-3.5-turbo-0613""
-    // 2. "gpt-4-0613""
-    // 3. "gpt-4-1106-preview"
-    "default_open_ai_model": "gpt-4-1106-preview",
+    // AI provider.
     "provider": {
-      "type": "openai",
-      // The default OpenAI API endpoint to use when starting new conversations.
-      "api_url": "https://api.openai.com/v1",
-      // The default OpenAI model to use when starting new conversations. This
+      "name": "openai",
+      // The default model to use when starting new conversations. This
       // setting can take three values:
       //
-      // 1. "gpt-3.5-turbo-0613""
-      // 2. "gpt-4-0613""
-      // 3. "gpt-4-1106-preview"
-      "default_model": "gpt-4-1106-preview"
+      // 1. "gpt-3.5-turbo"
+      // 2. "gpt-4"
+      // 3. "gpt-4-turbo-preview"
+      "default_model": "gpt-4-turbo-preview"
     }
   },
   // Whether the screen sharing icon is shown in the os status bar.
@@ -505,10 +495,6 @@
     // Existing terminals will not pick up this change until they are recreated.
     // "max_scroll_history_lines": 10000,
   },
-  // Difference settings for semantic_index
-  "semantic_index": {
-    "enabled": true
-  },
   // Settings specific to our elixir integration
   "elixir": {
     // Change the LSP zed uses for elixir.
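
For reference, overriding the new versioned assistant settings from a user `settings.json` might look like the sketch below, using only keys that appear in the diff above. Note that the provider key is now `name` rather than `type`, and the block carries an explicit `"version"` so the settings can be migrated again later:

```json
{
  "assistant": {
    "version": "1",
    "provider": {
      "name": "openai",
      "default_model": "gpt-4"
    }
  }
}
```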

crates/ai/Cargo.toml 🔗

@@ -1,41 +0,0 @@
-[package]
-name = "ai"
-version = "0.1.0"
-edition = "2021"
-publish = false
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/ai.rs"
-doctest = false
-
-[features]
-test-support = []
-
-[dependencies]
-anyhow.workspace = true
-async-trait.workspace = true
-bincode = "1.3.3"
-futures.workspace = true
-gpui.workspace = true
-isahc.workspace = true
-language.workspace = true
-log.workspace = true
-matrixmultiply = "0.3.7"
-ordered-float.workspace = true
-parking_lot.workspace = true
-parse_duration = "2.1.1"
-postage.workspace = true
-rand.workspace = true
-rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
-schemars.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-tiktoken-rs.workspace = true
-util.workspace = true
-
-[dev-dependencies]
-gpui = { workspace = true, features = ["test-support"] }

crates/ai/src/ai.rs 🔗

@@ -1,8 +0,0 @@
-pub mod auth;
-pub mod completion;
-pub mod embedding;
-pub mod models;
-pub mod prompts;
-pub mod providers;
-#[cfg(any(test, feature = "test-support"))]
-pub mod test;

crates/ai/src/auth.rs 🔗

@@ -1,23 +0,0 @@
-use futures::future::BoxFuture;
-use gpui::AppContext;
-
-#[derive(Clone, Debug)]
-pub enum ProviderCredential {
-    Credentials { api_key: String },
-    NoCredentials,
-    NotNeeded,
-}
-
-pub trait CredentialProvider: Send + Sync {
-    fn has_credentials(&self) -> bool;
-    #[must_use]
-    fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential>;
-    #[must_use]
-    fn save_credentials(
-        &self,
-        cx: &mut AppContext,
-        credential: ProviderCredential,
-    ) -> BoxFuture<()>;
-    #[must_use]
-    fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()>;
-}

crates/ai/src/completion.rs 🔗

@@ -1,23 +0,0 @@
-use anyhow::Result;
-use futures::{future::BoxFuture, stream::BoxStream};
-
-use crate::{auth::CredentialProvider, models::LanguageModel};
-
-pub trait CompletionRequest: Send + Sync {
-    fn data(&self) -> serde_json::Result<String>;
-}
-
-pub trait CompletionProvider: CredentialProvider {
-    fn base_model(&self) -> Box<dyn LanguageModel>;
-    fn complete(
-        &self,
-        prompt: Box<dyn CompletionRequest>,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
-    fn box_clone(&self) -> Box<dyn CompletionProvider>;
-}
-
-impl Clone for Box<dyn CompletionProvider> {
-    fn clone(&self) -> Box<dyn CompletionProvider> {
-        self.box_clone()
-    }
-}

crates/ai/src/embedding.rs 🔗

@@ -1,121 +0,0 @@
-use std::time::Instant;
-
-use anyhow::Result;
-use async_trait::async_trait;
-use ordered_float::OrderedFloat;
-use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
-use rusqlite::ToSql;
-
-use crate::auth::CredentialProvider;
-use crate::models::LanguageModel;
-
-#[derive(Debug, PartialEq, Clone)]
-pub struct Embedding(pub Vec<f32>);
-
-// This is needed for semantic index functionality
-// Unfortunately it has to live wherever the "Embedding" struct is created.
-// Keeping this in here though, introduces a 'rusqlite' dependency into AI
-// which is less than ideal
-impl FromSql for Embedding {
-    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
-        let bytes = value.as_blob()?;
-        let embedding =
-            bincode::deserialize(bytes).map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
-        Ok(Embedding(embedding))
-    }
-}
-
-impl ToSql for Embedding {
-    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
-        let bytes = bincode::serialize(&self.0)
-            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
-        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
-    }
-}
-impl From<Vec<f32>> for Embedding {
-    fn from(value: Vec<f32>) -> Self {
-        Embedding(value)
-    }
-}
-
-impl Embedding {
-    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
-        let len = self.0.len();
-        assert_eq!(len, other.0.len());
-
-        let mut result = 0.0;
-        unsafe {
-            matrixmultiply::sgemm(
-                1,
-                len,
-                1,
-                1.0,
-                self.0.as_ptr(),
-                len as isize,
-                1,
-                other.0.as_ptr(),
-                1,
-                len as isize,
-                0.0,
-                &mut result as *mut f32,
-                1,
-                1,
-            );
-        }
-        OrderedFloat(result)
-    }
-}
-
-#[async_trait]
-pub trait EmbeddingProvider: CredentialProvider {
-    fn base_model(&self) -> Box<dyn LanguageModel>;
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
-    fn max_tokens_per_batch(&self) -> usize;
-    fn rate_limit_expiration(&self) -> Option<Instant>;
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use rand::prelude::*;
-
-    #[gpui::test]
-    fn test_similarity(mut rng: StdRng) {
-        assert_eq!(
-            Embedding::from(vec![1., 0., 0., 0., 0.])
-                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
-            0.
-        );
-        assert_eq!(
-            Embedding::from(vec![2., 0., 0., 0., 0.])
-                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
-            6.
-        );
-
-        for _ in 0..100 {
-            let size = 1536;
-            let mut a = vec![0.; size];
-            let mut b = vec![0.; size];
-            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
-                *a = rng.gen();
-                *b = rng.gen();
-            }
-            let a = Embedding::from(a);
-            let b = Embedding::from(b);
-
-            assert_eq!(
-                round_to_decimals(a.similarity(&b), 1),
-                round_to_decimals(reference_dot(&a.0, &b.0), 1)
-            );
-        }
-
-        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
-            let factor = 10.0_f32.powi(decimal_places);
-            (n * factor).round() / factor
-        }
-
-        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
-            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
-        }
-    }
-}

crates/ai/src/models.rs 🔗

@@ -1,16 +0,0 @@
-pub enum TruncationDirection {
-    Start,
-    End,
-}
-
-pub trait LanguageModel {
-    fn name(&self) -> String;
-    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
-    fn truncate(
-        &self,
-        content: &str,
-        length: usize,
-        direction: TruncationDirection,
-    ) -> anyhow::Result<String>;
-    fn capacity(&self) -> anyhow::Result<usize>;
-}

crates/ai/src/prompts/base.rs 🔗

@@ -1,337 +0,0 @@
-use std::cmp::Reverse;
-use std::ops::Range;
-use std::sync::Arc;
-
-use language::BufferSnapshot;
-use util::ResultExt;
-
-use crate::models::LanguageModel;
-use crate::prompts::repository_context::PromptCodeSnippet;
-
-pub(crate) enum PromptFileType {
-    Text,
-    Code,
-}
-
-// TODO: Set this up to manage for defaults well
-pub struct PromptArguments {
-    pub model: Arc<dyn LanguageModel>,
-    pub user_prompt: Option<String>,
-    pub language_name: Option<String>,
-    pub project_name: Option<String>,
-    pub snippets: Vec<PromptCodeSnippet>,
-    pub reserved_tokens: usize,
-    pub buffer: Option<BufferSnapshot>,
-    pub selected_range: Option<Range<usize>>,
-}
-
-impl PromptArguments {
-    pub(crate) fn get_file_type(&self) -> PromptFileType {
-        if self
-            .language_name
-            .as_ref()
-            .map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
-            .unwrap_or(true)
-        {
-            PromptFileType::Code
-        } else {
-            PromptFileType::Text
-        }
-    }
-}
-
-pub trait PromptTemplate {
-    fn generate(
-        &self,
-        args: &PromptArguments,
-        max_token_length: Option<usize>,
-    ) -> anyhow::Result<(String, usize)>;
-}
-
-#[repr(i8)]
-#[derive(PartialEq, Eq)]
-pub enum PromptPriority {
-    /// Ignores truncation.
-    Mandatory,
-    /// Truncates based on priority.
-    Ordered { order: usize },
-}
-
-impl PartialOrd for PromptPriority {
-    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
-        Some(self.cmp(other))
-    }
-}
-
-impl Ord for PromptPriority {
-    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
-        match (self, other) {
-            (Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
-            (Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
-            (Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
-            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
-        }
-    }
-}
-
-pub struct PromptChain {
-    args: PromptArguments,
-    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
-}
-
-impl PromptChain {
-    pub fn new(
-        args: PromptArguments,
-        templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
-    ) -> Self {
-        PromptChain { args, templates }
-    }
-
-    pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
-        // Argsort based on Prompt Priority
-        let separator = "\n";
-        let separator_tokens = self.args.model.count_tokens(separator)?;
-        let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
-        sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
-
-        let mut tokens_outstanding = if truncate {
-            Some(self.args.model.capacity()? - self.args.reserved_tokens)
-        } else {
-            None
-        };
-
-        let mut prompts = vec!["".to_string(); sorted_indices.len()];
-        for idx in sorted_indices {
-            let (_, template) = &self.templates[idx];
-
-            if let Some((template_prompt, prompt_token_count)) =
-                template.generate(&self.args, tokens_outstanding).log_err()
-            {
-                if template_prompt != "" {
-                    prompts[idx] = template_prompt;
-
-                    if let Some(remaining_tokens) = tokens_outstanding {
-                        let new_tokens = prompt_token_count + separator_tokens;
-                        tokens_outstanding = if remaining_tokens > new_tokens {
-                            Some(remaining_tokens - new_tokens)
-                        } else {
-                            Some(0)
-                        };
-                    }
-                }
-            }
-        }
-
-        prompts.retain(|x| x != "");
-
-        let full_prompt = prompts.join(separator);
-        let total_token_count = self.args.model.count_tokens(&full_prompt)?;
-        anyhow::Ok((prompts.join(separator), total_token_count))
-    }
-}
-
-#[cfg(test)]
-pub(crate) mod tests {
-    use crate::models::TruncationDirection;
-    use crate::test::FakeLanguageModel;
-
-    use super::*;
-
-    #[test]
-    pub fn test_prompt_chain() {
-        struct TestPromptTemplate {}
-        impl PromptTemplate for TestPromptTemplate {
-            fn generate(
-                &self,
-                args: &PromptArguments,
-                max_token_length: Option<usize>,
-            ) -> anyhow::Result<(String, usize)> {
-                let mut content = "This is a test prompt template".to_string();
-
-                let mut token_count = args.model.count_tokens(&content)?;
-                if let Some(max_token_length) = max_token_length {
-                    if token_count > max_token_length {
-                        content = args.model.truncate(
-                            &content,
-                            max_token_length,
-                            TruncationDirection::End,
-                        )?;
-                        token_count = max_token_length;
-                    }
-                }
-
-                anyhow::Ok((content, token_count))
-            }
-        }
-
-        struct TestLowPriorityTemplate {}
-        impl PromptTemplate for TestLowPriorityTemplate {
-            fn generate(
-                &self,
-                args: &PromptArguments,
-                max_token_length: Option<usize>,
-            ) -> anyhow::Result<(String, usize)> {
-                let mut content = "This is a low priority test prompt template".to_string();
-
-                let mut token_count = args.model.count_tokens(&content)?;
-                if let Some(max_token_length) = max_token_length {
-                    if token_count > max_token_length {
-                        content = args.model.truncate(
-                            &content,
-                            max_token_length,
-                            TruncationDirection::End,
-                        )?;
-                        token_count = max_token_length;
-                    }
-                }
-
-                anyhow::Ok((content, token_count))
-            }
-        }
-
-        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
-        let args = PromptArguments {
-            model: model.clone(),
-            language_name: None,
-            project_name: None,
-            snippets: Vec::new(),
-            reserved_tokens: 0,
-            buffer: None,
-            selected_range: None,
-            user_prompt: None,
-        };
-
-        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
-            (
-                PromptPriority::Ordered { order: 0 },
-                Box::new(TestPromptTemplate {}),
-            ),
-            (
-                PromptPriority::Ordered { order: 1 },
-                Box::new(TestLowPriorityTemplate {}),
-            ),
-        ];
-        let chain = PromptChain::new(args, templates);
-
-        let (prompt, token_count) = chain.generate(false).unwrap();
-
-        assert_eq!(
-            prompt,
-            "This is a test prompt template\nThis is a low priority test prompt template"
-                .to_string()
-        );
-
-        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
-
-        // Testing with Truncation Off
-        // Should ignore capacity and return all prompts
-        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
-        let args = PromptArguments {
-            model: model.clone(),
-            language_name: None,
-            project_name: None,
-            snippets: Vec::new(),
-            reserved_tokens: 0,
-            buffer: None,
-            selected_range: None,
-            user_prompt: None,
-        };
-
-        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
-            (
-                PromptPriority::Ordered { order: 0 },
-                Box::new(TestPromptTemplate {}),
-            ),
-            (
-                PromptPriority::Ordered { order: 1 },
-                Box::new(TestLowPriorityTemplate {}),
-            ),
-        ];
-        let chain = PromptChain::new(args, templates);
-
-        let (prompt, token_count) = chain.generate(false).unwrap();
-
-        assert_eq!(
-            prompt,
-            "This is a test prompt template\nThis is a low priority test prompt template"
-                .to_string()
-        );
-
-        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
-
-        // Testing with Truncation Off
-        // Should ignore capacity and return all prompts
-        let capacity = 20;
-        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
-        let args = PromptArguments {
-            model: model.clone(),
-            language_name: None,
-            project_name: None,
-            snippets: Vec::new(),
-            reserved_tokens: 0,
-            buffer: None,
-            selected_range: None,
-            user_prompt: None,
-        };
-
-        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
-            (
-                PromptPriority::Ordered { order: 0 },
-                Box::new(TestPromptTemplate {}),
-            ),
-            (
-                PromptPriority::Ordered { order: 1 },
-                Box::new(TestLowPriorityTemplate {}),
-            ),
-            (
-                PromptPriority::Ordered { order: 2 },
-                Box::new(TestLowPriorityTemplate {}),
-            ),
-        ];
-        let chain = PromptChain::new(args, templates);
-
-        let (prompt, token_count) = chain.generate(true).unwrap();
-
-        assert_eq!(prompt, "This is a test promp".to_string());
-        assert_eq!(token_count, capacity);
-
-        // Change Ordering of Prompts Based on Priority
-        let capacity = 120;
-        let reserved_tokens = 10;
-        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
-        let args = PromptArguments {
-            model: model.clone(),
-            language_name: None,
-            project_name: None,
-            snippets: Vec::new(),
-            reserved_tokens,
-            buffer: None,
-            selected_range: None,
-            user_prompt: None,
-        };
-        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
-            (
-                PromptPriority::Mandatory,
-                Box::new(TestLowPriorityTemplate {}),
-            ),
-            (
-                PromptPriority::Ordered { order: 0 },
-                Box::new(TestPromptTemplate {}),
-            ),
-            (
-                PromptPriority::Ordered { order: 1 },
-                Box::new(TestLowPriorityTemplate {}),
-            ),
-        ];
-        let chain = PromptChain::new(args, templates);
-
-        let (prompt, token_count) = chain.generate(true).unwrap();
-
-        assert_eq!(
-            prompt,
-            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
-                .to_string()
-        );
-        assert_eq!(token_count, capacity - reserved_tokens);
-    }
-}

crates/ai/src/prompts/file_context.rs 🔗

@@ -1,164 +0,0 @@
-use anyhow::anyhow;
-use language::BufferSnapshot;
-use language::ToOffset;
-
-use crate::models::LanguageModel;
-use crate::models::TruncationDirection;
-use crate::prompts::base::PromptArguments;
-use crate::prompts::base::PromptTemplate;
-use std::fmt::Write;
-use std::ops::Range;
-use std::sync::Arc;
-
-fn retrieve_context(
-    buffer: &BufferSnapshot,
-    selected_range: &Option<Range<usize>>,
-    model: Arc<dyn LanguageModel>,
-    max_token_count: Option<usize>,
-) -> anyhow::Result<(String, usize, bool)> {
-    let mut prompt = String::new();
-    let mut truncated = false;
-    if let Some(selected_range) = selected_range {
-        let start = selected_range.start.to_offset(buffer);
-        let end = selected_range.end.to_offset(buffer);
-
-        let start_window = buffer.text_for_range(0..start).collect::<String>();
-
-        let mut selected_window = String::new();
-        if start == end {
-            write!(selected_window, "<|START|>").unwrap();
-        } else {
-            write!(selected_window, "<|START|").unwrap();
-        }
-
-        write!(
-            selected_window,
-            "{}",
-            buffer.text_for_range(start..end).collect::<String>()
-        )
-        .unwrap();
-
-        if start != end {
-            write!(selected_window, "|END|>").unwrap();
-        }
-
-        let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
-
-        if let Some(max_token_count) = max_token_count {
-            let selected_tokens = model.count_tokens(&selected_window)?;
-            if selected_tokens > max_token_count {
-                return Err(anyhow!(
-                    "selected range is greater than model context window, truncation not possible"
-                ));
-            };
-
-            let mut remaining_tokens = max_token_count - selected_tokens;
-            let start_window_tokens = model.count_tokens(&start_window)?;
-            let end_window_tokens = model.count_tokens(&end_window)?;
-            let outside_tokens = start_window_tokens + end_window_tokens;
-            if outside_tokens > remaining_tokens {
-                let (start_goal_tokens, end_goal_tokens) =
-                    if start_window_tokens < end_window_tokens {
-                        let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
-                        remaining_tokens -= start_goal_tokens;
-                        let end_goal_tokens = remaining_tokens.min(end_window_tokens);
-                        (start_goal_tokens, end_goal_tokens)
-                    } else {
-                        let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
-                        remaining_tokens -= end_goal_tokens;
-                        let start_goal_tokens = remaining_tokens.min(start_window_tokens);
-                        (start_goal_tokens, end_goal_tokens)
-                    };
-
-                let truncated_start_window =
-                    model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
-                let truncated_end_window =
-                    model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
-                writeln!(
-                    prompt,
-                    "{truncated_start_window}{selected_window}{truncated_end_window}"
-                )
-                .unwrap();
-                truncated = true;
-            } else {
-                writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
-            }
-        } else {
-            // If we dont have a selected range, include entire file.
-            writeln!(prompt, "{}", &buffer.text()).unwrap();
-
-            // Dumb truncation strategy
-            if let Some(max_token_count) = max_token_count {
-                if model.count_tokens(&prompt)? > max_token_count {
-                    truncated = true;
-                    prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
-                }
-            }
-        }
-    }
-
-    let token_count = model.count_tokens(&prompt)?;
-    anyhow::Ok((prompt, token_count, truncated))
-}
-
-pub struct FileContext {}
-
-impl PromptTemplate for FileContext {
-    fn generate(
-        &self,
-        args: &PromptArguments,
-        max_token_length: Option<usize>,
-    ) -> anyhow::Result<(String, usize)> {
-        if let Some(buffer) = &args.buffer {
-            let mut prompt = String::new();
-            // Add Initial Preamble
-            // TODO: Do we want to add the path in here?
-            writeln!(
-                prompt,
-                "The file you are currently working on has the following content:"
-            )
-            .unwrap();
-
-            let language_name = args
-                .language_name
-                .clone()
-                .unwrap_or("".to_string())
-                .to_lowercase();
-
-            let (context, _, truncated) = retrieve_context(
-                buffer,
-                &args.selected_range,
-                args.model.clone(),
-                max_token_length,
-            )?;
-            writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
-
-            if truncated {
-                writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
-            }
-
-            if let Some(selected_range) = &args.selected_range {
-                let start = selected_range.start.to_offset(buffer);
-                let end = selected_range.end.to_offset(buffer);
-
-                if start == end {
-                    writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
-                } else {
-                    writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
-                }
-            }
-
-            // Really dumb truncation strategy
-            if let Some(max_tokens) = max_token_length {
-                prompt = args
-                    .model
-                    .truncate(&prompt, max_tokens, TruncationDirection::End)?;
-            }
-
-            let token_count = args.model.count_tokens(&prompt)?;
-            anyhow::Ok((prompt, token_count))
-        } else {
-            Err(anyhow!("no buffer provided to retrieve file context from"))
-        }
-    }
-}

crates/ai/src/prompts/generate.rs 🔗

@@ -1,99 +0,0 @@
-use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
-use anyhow::anyhow;
-use std::fmt::Write;
-
-pub fn capitalize(s: &str) -> String {
-    let mut c = s.chars();
-    match c.next() {
-        None => String::new(),
-        Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
-    }
-}
-
-pub struct GenerateInlineContent {}
-
-impl PromptTemplate for GenerateInlineContent {
-    fn generate(
-        &self,
-        args: &PromptArguments,
-        max_token_length: Option<usize>,
-    ) -> anyhow::Result<(String, usize)> {
-        let Some(user_prompt) = &args.user_prompt else {
-            return Err(anyhow!("user prompt not provided"));
-        };
-
-        let file_type = args.get_file_type();
-        let content_type = match &file_type {
-            PromptFileType::Code => "code",
-            PromptFileType::Text => "text",
-        };
-
-        let mut prompt = String::new();
-
-        if let Some(selected_range) = &args.selected_range {
-            if selected_range.start == selected_range.end {
-                writeln!(
-                    prompt,
-                    "Assume the cursor is located where the `<|START|>` span is."
-                )
-                .unwrap();
-                writeln!(
-                    prompt,
-                    "{} can't be replaced, so assume your answer will be inserted at the cursor.",
-                    capitalize(content_type)
-                )
-                .unwrap();
-                writeln!(
-                    prompt,
-                    "Generate {content_type} based on the users prompt: {user_prompt}",
-                )
-                .unwrap();
-            } else {
-                writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
-                writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
-                writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
-            }
-        } else {
-            writeln!(
-                prompt,
-                "Generate {content_type} based on the users prompt: {user_prompt}"
-            )
-            .unwrap();
-        }
-
-        if let Some(language_name) = &args.language_name {
-            writeln!(
-                prompt,
-                "Your answer MUST always and only be valid {}.",
-                language_name
-            )
-            .unwrap();
-        }
-        writeln!(prompt, "Never make remarks about the output.").unwrap();
-        writeln!(
-            prompt,
-            "Do not return anything else, except the generated {content_type}."
-        )
-        .unwrap();
-
-        match file_type {
-            PromptFileType::Code => {
-                // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
-            }
-            _ => {}
-        }
-
-        // Really dumb truncation strategy
-        if let Some(max_tokens) = max_token_length {
-            prompt = args.model.truncate(
-                &prompt,
-                max_tokens,
-                crate::models::TruncationDirection::End,
-            )?;
-        }
-
-        let token_count = args.model.count_tokens(&prompt)?;
-
-        anyhow::Ok((prompt, token_count))
-    }
-}

crates/ai/src/prompts/preamble.rs 🔗

@@ -1,52 +0,0 @@
-use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
-use std::fmt::Write;
-
-pub struct EngineerPreamble {}
-
-impl PromptTemplate for EngineerPreamble {
-    fn generate(
-        &self,
-        args: &PromptArguments,
-        max_token_length: Option<usize>,
-    ) -> anyhow::Result<(String, usize)> {
-        let mut prompts = Vec::new();
-
-        match args.get_file_type() {
-            PromptFileType::Code => {
-                prompts.push(format!(
-                    "You are an expert {}engineer.",
-                    args.language_name.clone().unwrap_or("".to_string()) + " "
-                ));
-            }
-            PromptFileType::Text => {
-                prompts.push("You are an expert engineer.".to_string());
-            }
-        }
-
-        if let Some(project_name) = args.project_name.clone() {
-            prompts.push(format!(
-                "You are currently working inside the '{project_name}' project in code editor Zed."
-            ));
-        }
-
-        if let Some(mut remaining_tokens) = max_token_length {
-            let mut prompt = String::new();
-            let mut total_count = 0;
-            for prompt_piece in prompts {
-                let prompt_token_count =
-                    args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
-                if remaining_tokens > prompt_token_count {
-                    writeln!(prompt, "{prompt_piece}").unwrap();
-                    remaining_tokens -= prompt_token_count;
-                    total_count += prompt_token_count;
-                }
-            }
-
-            anyhow::Ok((prompt, total_count))
-        } else {
-            let prompt = prompts.join("\n");
-            let token_count = args.model.count_tokens(&prompt)?;
-            anyhow::Ok((prompt, token_count))
-        }
-    }
-}

crates/ai/src/prompts/repository_context.rs 🔗

@@ -1,96 +0,0 @@
-use crate::prompts::base::{PromptArguments, PromptTemplate};
-use std::fmt::Write;
-use std::{ops::Range, path::PathBuf};
-
-use gpui::{AsyncAppContext, Model};
-use language::{Anchor, Buffer};
-
-#[derive(Clone)]
-pub struct PromptCodeSnippet {
-    path: Option<PathBuf>,
-    language_name: Option<String>,
-    content: String,
-}
-
-impl PromptCodeSnippet {
-    pub fn new(
-        buffer: Model<Buffer>,
-        range: Range<Anchor>,
-        cx: &mut AsyncAppContext,
-    ) -> anyhow::Result<Self> {
-        let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
-            let snapshot = buffer.snapshot();
-            let content = snapshot.text_for_range(range.clone()).collect::<String>();
-
-            let language_name = buffer
-                .language()
-                .map(|language| language.name().to_string().to_lowercase());
-
-            let file_path = buffer.file().map(|file| file.path().to_path_buf());
-
-            (content, language_name, file_path)
-        })?;
-
-        anyhow::Ok(PromptCodeSnippet {
-            path: file_path,
-            language_name,
-            content,
-        })
-    }
-}
-
-impl ToString for PromptCodeSnippet {
-    fn to_string(&self) -> String {
-        let path = self
-            .path
-            .as_ref()
-            .map(|path| path.to_string_lossy().to_string())
-            .unwrap_or("".to_string());
-        let language_name = self.language_name.clone().unwrap_or("".to_string());
-        let content = self.content.clone();
-
-        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
-    }
-}
-
-pub struct RepositoryContext {}
-
-impl PromptTemplate for RepositoryContext {
-    fn generate(
-        &self,
-        args: &PromptArguments,
-        max_token_length: Option<usize>,
-    ) -> anyhow::Result<(String, usize)> {
-        const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
-        let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
-        let mut prompt = String::new();
-
-        let mut remaining_tokens = max_token_length;
-        let separator_token_length = args.model.count_tokens("\n")?;
-        for snippet in &args.snippets {
-            let mut snippet_prompt = template.to_string();
-            let content = snippet.to_string();
-            writeln!(snippet_prompt, "{content}").unwrap();
-
-            let token_count = args.model.count_tokens(&snippet_prompt)?;
-            if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
-                if let Some(tokens_left) = remaining_tokens {
-                    if tokens_left >= token_count {
-                        writeln!(prompt, "{snippet_prompt}").unwrap();
-                        remaining_tokens = if tokens_left >= (token_count + separator_token_length)
-                        {
-                            Some(tokens_left - token_count - separator_token_length)
-                        } else {
-                            Some(0)
-                        };
-                    }
-                } else {
-                    writeln!(prompt, "{snippet_prompt}").unwrap();
-                }
-            }
-        }
-
-        let total_token_count = args.model.count_tokens(&prompt)?;
-        anyhow::Ok((prompt, total_token_count))
-    }
-}

crates/ai/src/providers/open_ai.rs 🔗

@@ -1,9 +0,0 @@
-pub mod completion;
-pub mod embedding;
-pub mod model;
-
-pub use completion::*;
-pub use embedding::*;
-pub use model::OpenAiLanguageModel;
-
-pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";

crates/ai/src/providers/open_ai/completion.rs 🔗

@@ -1,421 +0,0 @@
-use std::{
-    env,
-    fmt::{self, Display},
-    io,
-    sync::Arc,
-};
-
-use anyhow::{anyhow, Result};
-use futures::{
-    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
-    Stream, StreamExt,
-};
-use gpui::{AppContext, BackgroundExecutor};
-use isahc::{http::StatusCode, Request, RequestExt};
-use parking_lot::RwLock;
-use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
-use util::ResultExt;
-
-use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
-use crate::{
-    auth::{CredentialProvider, ProviderCredential},
-    completion::{CompletionProvider, CompletionRequest},
-    models::LanguageModel,
-};
-
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
-    User,
-    Assistant,
-    System,
-}
-
-impl Role {
-    pub fn cycle(&mut self) {
-        *self = match self {
-            Role::User => Role::Assistant,
-            Role::Assistant => Role::System,
-            Role::System => Role::User,
-        }
-    }
-}
-
-impl Display for Role {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Role::User => write!(f, "User"),
-            Role::Assistant => write!(f, "Assistant"),
-            Role::System => write!(f, "System"),
-        }
-    }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct RequestMessage {
-    pub role: Role,
-    pub content: String,
-}
-
-#[derive(Debug, Default, Serialize)]
-pub struct OpenAiRequest {
-    pub model: String,
-    pub messages: Vec<RequestMessage>,
-    pub stream: bool,
-    pub stop: Vec<String>,
-    pub temperature: f32,
-}
-
-impl CompletionRequest for OpenAiRequest {
-    fn data(&self) -> serde_json::Result<String> {
-        serde_json::to_string(self)
-    }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ResponseMessage {
-    pub role: Option<Role>,
-    pub content: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAiUsage {
-    pub prompt_tokens: u32,
-    pub completion_tokens: u32,
-    pub total_tokens: u32,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct ChatChoiceDelta {
-    pub index: u32,
-    pub delta: ResponseMessage,
-    pub finish_reason: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAiResponseStreamEvent {
-    pub id: Option<String>,
-    pub object: String,
-    pub created: u32,
-    pub model: String,
-    pub choices: Vec<ChatChoiceDelta>,
-    pub usage: Option<OpenAiUsage>,
-}
-
-async fn stream_completion(
-    api_url: String,
-    kind: OpenAiCompletionProviderKind,
-    credential: ProviderCredential,
-    executor: BackgroundExecutor,
-    request: Box<dyn CompletionRequest>,
-) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
-    let api_key = match credential {
-        ProviderCredential::Credentials { api_key } => api_key,
-        _ => {
-            return Err(anyhow!("no credentials provider for completion"));
-        }
-    };
-
-    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
-
-    let (auth_header_name, auth_header_value) = kind.auth_header(api_key);
-    let json_data = request.data()?;
-    let mut response = Request::post(kind.completions_endpoint_url(&api_url))
-        .header("Content-Type", "application/json")
-        .header(auth_header_name, auth_header_value)
-        .body(json_data)?
-        .send_async()
-        .await?;
-
-    let status = response.status();
-    if status == StatusCode::OK {
-        executor
-            .spawn(async move {
-                let mut lines = BufReader::new(response.body_mut()).lines();
-
-                fn parse_line(
-                    line: Result<String, io::Error>,
-                ) -> Result<Option<OpenAiResponseStreamEvent>> {
-                    if let Some(data) = line?.strip_prefix("data: ") {
-                        let event = serde_json::from_str(data)?;
-                        Ok(Some(event))
-                    } else {
-                        Ok(None)
-                    }
-                }
-
-                while let Some(line) = lines.next().await {
-                    if let Some(event) = parse_line(line).transpose() {
-                        let done = event.as_ref().map_or(false, |event| {
-                            event
-                                .choices
-                                .last()
-                                .map_or(false, |choice| choice.finish_reason.is_some())
-                        });
-                        if tx.unbounded_send(event).is_err() {
-                            break;
-                        }
-
-                        if done {
-                            break;
-                        }
-                    }
-                }
-
-                anyhow::Ok(())
-            })
-            .detach();
-
-        Ok(rx)
-    } else {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        #[derive(Deserialize)]
-        struct OpenAiResponse {
-            error: OpenAiError,
-        }
-
-        #[derive(Deserialize)]
-        struct OpenAiError {
-            message: String,
-        }
-
-        match serde_json::from_str::<OpenAiResponse>(&body) {
-            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
-                "Failed to connect to OpenAI API: {}",
-                response.error.message,
-            )),
-
-            _ => Err(anyhow!(
-                "Failed to connect to OpenAI API: {} {}",
-                response.status(),
-                body,
-            )),
-        }
-    }
-}
-
-#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
-pub enum AzureOpenAiApiVersion {
-    /// Retiring April 2, 2024.
-    #[serde(rename = "2023-03-15-preview")]
-    V2023_03_15Preview,
-    #[serde(rename = "2023-05-15")]
-    V2023_05_15,
-    /// Retiring April 2, 2024.
-    #[serde(rename = "2023-06-01-preview")]
-    V2023_06_01Preview,
-    /// Retiring April 2, 2024.
-    #[serde(rename = "2023-07-01-preview")]
-    V2023_07_01Preview,
-    /// Retiring April 2, 2024.
-    #[serde(rename = "2023-08-01-preview")]
-    V2023_08_01Preview,
-    /// Retiring April 2, 2024.
-    #[serde(rename = "2023-09-01-preview")]
-    V2023_09_01Preview,
-    #[serde(rename = "2023-12-01-preview")]
-    V2023_12_01Preview,
-    #[serde(rename = "2024-02-15-preview")]
-    V2024_02_15Preview,
-}
-
-impl fmt::Display for AzureOpenAiApiVersion {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(
-            f,
-            "{}",
-            match self {
-                Self::V2023_03_15Preview => "2023-03-15-preview",
-                Self::V2023_05_15 => "2023-05-15",
-                Self::V2023_06_01Preview => "2023-06-01-preview",
-                Self::V2023_07_01Preview => "2023-07-01-preview",
-                Self::V2023_08_01Preview => "2023-08-01-preview",
-                Self::V2023_09_01Preview => "2023-09-01-preview",
-                Self::V2023_12_01Preview => "2023-12-01-preview",
-                Self::V2024_02_15Preview => "2024-02-15-preview",
-            }
-        )
-    }
-}
-
-#[derive(Clone)]
-pub enum OpenAiCompletionProviderKind {
-    OpenAi,
-    AzureOpenAi {
-        deployment_id: String,
-        api_version: AzureOpenAiApiVersion,
-    },
-}
-
-impl OpenAiCompletionProviderKind {
-    /// Returns the chat completion endpoint URL for this [`OpenAiCompletionProviderKind`].
-    fn completions_endpoint_url(&self, api_url: &str) -> String {
-        match self {
-            Self::OpenAi => {
-                // https://platform.openai.com/docs/api-reference/chat/create
-                format!("{api_url}/chat/completions")
-            }
-            Self::AzureOpenAi {
-                deployment_id,
-                api_version,
-            } => {
-                // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
-                format!("{api_url}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}")
-            }
-        }
-    }
-
-    /// Returns the authentication header for this [`OpenAiCompletionProviderKind`].
-    fn auth_header(&self, api_key: String) -> (&'static str, String) {
-        match self {
-            Self::OpenAi => ("Authorization", format!("Bearer {api_key}")),
-            Self::AzureOpenAi { .. } => ("Api-Key", api_key),
-        }
-    }
-}
-
-#[derive(Clone)]
-pub struct OpenAiCompletionProvider {
-    api_url: String,
-    kind: OpenAiCompletionProviderKind,
-    model: OpenAiLanguageModel,
-    credential: Arc<RwLock<ProviderCredential>>,
-    executor: BackgroundExecutor,
-}
-
-impl OpenAiCompletionProvider {
-    pub async fn new(
-        api_url: String,
-        kind: OpenAiCompletionProviderKind,
-        model_name: String,
-        executor: BackgroundExecutor,
-    ) -> Self {
-        let model = executor
-            .spawn(async move { OpenAiLanguageModel::load(&model_name) })
-            .await;
-        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
-        Self {
-            api_url,
-            kind,
-            model,
-            credential,
-            executor,
-        }
-    }
-}
-
-impl CredentialProvider for OpenAiCompletionProvider {
-    fn has_credentials(&self) -> bool {
-        match *self.credential.read() {
-            ProviderCredential::Credentials { .. } => true,
-            _ => false,
-        }
-    }
-
-    fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
-        let existing_credential = self.credential.read().clone();
-        let retrieved_credential = match existing_credential {
-            ProviderCredential::Credentials { .. } => {
-                return async move { existing_credential }.boxed()
-            }
-            _ => {
-                if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
-                    async move { ProviderCredential::Credentials { api_key } }.boxed()
-                } else {
-                    let credentials = cx.read_credentials(OPEN_AI_API_URL);
-                    async move {
-                        if let Some(Some((_, api_key))) = credentials.await.log_err() {
-                            if let Some(api_key) = String::from_utf8(api_key).log_err() {
-                                ProviderCredential::Credentials { api_key }
-                            } else {
-                                ProviderCredential::NoCredentials
-                            }
-                        } else {
-                            ProviderCredential::NoCredentials
-                        }
-                    }
-                    .boxed()
-                }
-            }
-        };
-
-        async move {
-            let retrieved_credential = retrieved_credential.await;
-            *self.credential.write() = retrieved_credential.clone();
-            retrieved_credential
-        }
-        .boxed()
-    }
-
-    fn save_credentials(
-        &self,
-        cx: &mut AppContext,
-        credential: ProviderCredential,
-    ) -> BoxFuture<()> {
-        *self.credential.write() = credential.clone();
-        let credential = credential.clone();
-        let write_credentials = match credential {
-            ProviderCredential::Credentials { api_key } => {
-                Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
-            }
-            _ => None,
-        };
-
-        async move {
-            if let Some(write_credentials) = write_credentials {
-                write_credentials.await.log_err();
-            }
-        }
-        .boxed()
-    }
-
-    fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
-        *self.credential.write() = ProviderCredential::NoCredentials;
-        let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
-        async move {
-            delete_credentials.await.log_err();
-        }
-        .boxed()
-    }
-}
-
-impl CompletionProvider for OpenAiCompletionProvider {
-    fn base_model(&self) -> Box<dyn LanguageModel> {
-        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
-        model
-    }
-
-    fn complete(
-        &self,
-        prompt: Box<dyn CompletionRequest>,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        // Currently the CompletionRequest for OpenAI, includes a 'model' parameter
-        // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
-        // which is currently model based, due to the language model.
-        // At some point in the future we should rectify this.
-        let credential = self.credential.read().clone();
-        let api_url = self.api_url.clone();
-        let kind = self.kind.clone();
-        let request = stream_completion(api_url, kind, credential, self.executor.clone(), prompt);
-        async move {
-            let response = request.await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
-        }
-        .boxed()
-    }
-
-    fn box_clone(&self) -> Box<dyn CompletionProvider> {
-        Box::new((*self).clone())
-    }
-}

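The bulk of the file deleted above is streaming plumbing: the response body is read line by line, each `data: `-prefixed line is deserialized as a JSON event, and the stream counts as finished once a choice reports a `finish_reason` (the Azure variant differs only in endpoint URL and auth header). A distilled sketch of that line handling; `StreamEvent` and `Choice` here are simplified stand-ins for the deleted `OpenAiResponseStreamEvent` types:

```rust
use serde::Deserialize;

#[derive(Deserialize)]
struct Choice {
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct StreamEvent {
    choices: Vec<Choice>, // Extra JSON fields (id, delta, ...) are ignored by serde.
}

/// Returns the parsed event, if the line carried one, plus a "stream is
/// done" flag derived from the last choice's `finish_reason`.
fn parse_sse_line(line: &str) -> serde_json::Result<Option<(StreamEvent, bool)>> {
    let Some(data) = line.strip_prefix("data: ") else {
        return Ok(None); // Non-data lines (keep-alives, blanks) are skipped.
    };
    let event: StreamEvent = serde_json::from_str(data)?;
    let done = event
        .choices
        .last()
        .map_or(false, |choice| choice.finish_reason.is_some());
    Ok(Some((event, done)))
}
```
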
crates/ai/src/providers/open_ai/embedding.rs 🔗

@@ -1,345 +0,0 @@
-use anyhow::{anyhow, Result};
-use async_trait::async_trait;
-use futures::future::BoxFuture;
-use futures::AsyncReadExt;
-use futures::FutureExt;
-use gpui::AppContext;
-use gpui::BackgroundExecutor;
-use isahc::http::StatusCode;
-use isahc::prelude::Configurable;
-use isahc::{AsyncBody, Response};
-use parking_lot::{Mutex, RwLock};
-use parse_duration::parse;
-use postage::watch;
-use serde::{Deserialize, Serialize};
-use serde_json;
-use std::env;
-use std::ops::Add;
-use std::sync::{Arc, OnceLock};
-use std::time::{Duration, Instant};
-use tiktoken_rs::{cl100k_base, CoreBPE};
-use util::http::{HttpClient, Request};
-use util::ResultExt;
-
-use crate::auth::{CredentialProvider, ProviderCredential};
-use crate::embedding::{Embedding, EmbeddingProvider};
-use crate::models::LanguageModel;
-use crate::providers::open_ai::OpenAiLanguageModel;
-
-use crate::providers::open_ai::OPEN_AI_API_URL;
-
-pub(crate) fn open_ai_bpe_tokenizer() -> &'static CoreBPE {
-    static OPEN_AI_BPE_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
-    OPEN_AI_BPE_TOKENIZER.get_or_init(|| cl100k_base().unwrap())
-}
-
-#[derive(Clone)]
-pub struct OpenAiEmbeddingProvider {
-    api_url: String,
-    model: OpenAiLanguageModel,
-    credential: Arc<RwLock<ProviderCredential>>,
-    pub client: Arc<dyn HttpClient>,
-    pub executor: BackgroundExecutor,
-    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
-    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
-}
-
-#[derive(Serialize)]
-struct OpenAiEmbeddingRequest<'a> {
-    model: &'static str,
-    input: Vec<&'a str>,
-}
-
-#[derive(Deserialize)]
-struct OpenAiEmbeddingResponse {
-    data: Vec<OpenAiEmbedding>,
-    usage: OpenAiEmbeddingUsage,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAiEmbedding {
-    embedding: Vec<f32>,
-    index: usize,
-    object: String,
-}
-
-#[derive(Deserialize)]
-struct OpenAiEmbeddingUsage {
-    prompt_tokens: usize,
-    total_tokens: usize,
-}
-
-impl OpenAiEmbeddingProvider {
-    pub async fn new(
-        api_url: String,
-        client: Arc<dyn HttpClient>,
-        executor: BackgroundExecutor,
-    ) -> Self {
-        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
-        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
-
-        // Loading the model is expensive, so ensure this runs off the main thread.
-        let model = executor
-            .spawn(async move { OpenAiLanguageModel::load("text-embedding-ada-002") })
-            .await;
-        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
-
-        OpenAiEmbeddingProvider {
-            api_url,
-            model,
-            credential,
-            client,
-            executor,
-            rate_limit_count_rx,
-            rate_limit_count_tx,
-        }
-    }
-
-    fn get_api_key(&self) -> Result<String> {
-        match self.credential.read().clone() {
-            ProviderCredential::Credentials { api_key } => Ok(api_key),
-            _ => Err(anyhow!("api credentials not provided")),
-        }
-    }
-
-    fn resolve_rate_limit(&self) {
-        let reset_time = *self.rate_limit_count_tx.lock().borrow();
-
-        if let Some(reset_time) = reset_time {
-            if Instant::now() >= reset_time {
-                *self.rate_limit_count_tx.lock().borrow_mut() = None
-            }
-        }
-
-        log::trace!(
-            "resolving reset time: {:?}",
-            *self.rate_limit_count_tx.lock().borrow()
-        );
-    }
-
-    fn update_reset_time(&self, reset_time: Instant) {
-        let original_time = *self.rate_limit_count_tx.lock().borrow();
-
-        let updated_time = if let Some(original_time) = original_time {
-            if reset_time < original_time {
-                Some(reset_time)
-            } else {
-                Some(original_time)
-            }
-        } else {
-            Some(reset_time)
-        };
-
-        log::trace!("updating rate limit time: {:?}", updated_time);
-
-        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
-    }
-    async fn send_request(
-        &self,
-        api_url: &str,
-        api_key: &str,
-        spans: Vec<&str>,
-        request_timeout: u64,
-    ) -> Result<Response<AsyncBody>> {
-        let request = Request::post(format!("{api_url}/embeddings"))
-            .redirect_policy(isahc::config::RedirectPolicy::Follow)
-            .timeout(Duration::from_secs(request_timeout))
-            .header("Content-Type", "application/json")
-            .header("Authorization", format!("Bearer {}", api_key))
-            .body(
-                serde_json::to_string(&OpenAiEmbeddingRequest {
-                    input: spans.clone(),
-                    model: "text-embedding-ada-002",
-                })
-                .unwrap()
-                .into(),
-            )?;
-
-        Ok(self.client.send(request).await?)
-    }
-}
-
-impl CredentialProvider for OpenAiEmbeddingProvider {
-    fn has_credentials(&self) -> bool {
-        match *self.credential.read() {
-            ProviderCredential::Credentials { .. } => true,
-            _ => false,
-        }
-    }
-
-    fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
-        let existing_credential = self.credential.read().clone();
-        let retrieved_credential = match existing_credential {
-            ProviderCredential::Credentials { .. } => {
-                return async move { existing_credential }.boxed()
-            }
-            _ => {
-                if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
-                    async move { ProviderCredential::Credentials { api_key } }.boxed()
-                } else {
-                    let credentials = cx.read_credentials(OPEN_AI_API_URL);
-                    async move {
-                        if let Some(Some((_, api_key))) = credentials.await.log_err() {
-                            if let Some(api_key) = String::from_utf8(api_key).log_err() {
-                                ProviderCredential::Credentials { api_key }
-                            } else {
-                                ProviderCredential::NoCredentials
-                            }
-                        } else {
-                            ProviderCredential::NoCredentials
-                        }
-                    }
-                    .boxed()
-                }
-            }
-        };
-
-        async move {
-            let retrieved_credential = retrieved_credential.await;
-            *self.credential.write() = retrieved_credential.clone();
-            retrieved_credential
-        }
-        .boxed()
-    }
-
-    fn save_credentials(
-        &self,
-        cx: &mut AppContext,
-        credential: ProviderCredential,
-    ) -> BoxFuture<()> {
-        *self.credential.write() = credential.clone();
-        let credential = credential.clone();
-        let write_credentials = match credential {
-            ProviderCredential::Credentials { api_key } => {
-                Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
-            }
-            _ => None,
-        };
-
-        async move {
-            if let Some(write_credentials) = write_credentials {
-                write_credentials.await.log_err();
-            }
-        }
-        .boxed()
-    }
-
-    fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
-        *self.credential.write() = ProviderCredential::NoCredentials;
-        let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
-        async move {
-            delete_credentials.await.log_err();
-        }
-        .boxed()
-    }
-}
-
-#[async_trait]
-impl EmbeddingProvider for OpenAiEmbeddingProvider {
-    fn base_model(&self) -> Box<dyn LanguageModel> {
-        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
-        model
-    }
-
-    fn max_tokens_per_batch(&self) -> usize {
-        50000
-    }
-
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        *self.rate_limit_count_rx.borrow()
-    }
-
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
-        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
-        const MAX_RETRIES: usize = 4;
-
-        let api_url = self.api_url.as_str();
-        let api_key = self.get_api_key()?;
-
-        let mut request_number = 0;
-        let mut rate_limiting = false;
-        let mut request_timeout: u64 = 15;
-        let mut response: Response<AsyncBody>;
-        while request_number < MAX_RETRIES {
-            response = self
-                .send_request(
-                    &api_url,
-                    &api_key,
-                    spans.iter().map(|x| &**x).collect(),
-                    request_timeout,
-                )
-                .await?;
-
-            request_number += 1;
-
-            match response.status() {
-                StatusCode::REQUEST_TIMEOUT => {
-                    request_timeout += 5;
-                }
-                StatusCode::OK => {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    let response: OpenAiEmbeddingResponse = serde_json::from_str(&body)?;
-
-                    log::trace!(
-                        "openai embedding completed. tokens: {:?}",
-                        response.usage.total_tokens
-                    );
-
-                    // If we complete a request successfully that was previously rate_limited
-                    // resolve the rate limit
-                    if rate_limiting {
-                        self.resolve_rate_limit()
-                    }
-
-                    return Ok(response
-                        .data
-                        .into_iter()
-                        .map(|embedding| Embedding::from(embedding.embedding))
-                        .collect());
-                }
-                StatusCode::TOO_MANY_REQUESTS => {
-                    rate_limiting = true;
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-
-                    let delay_duration = {
-                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
-                        if let Some(time_to_reset) =
-                            response.headers().get("x-ratelimit-reset-tokens")
-                        {
-                            if let Ok(time_str) = time_to_reset.to_str() {
-                                parse(time_str).unwrap_or(delay)
-                            } else {
-                                delay
-                            }
-                        } else {
-                            delay
-                        }
-                    };
-
-                    // If we've previously rate limited, increment the duration but not the count
-                    let reset_time = Instant::now().add(delay_duration);
-                    self.update_reset_time(reset_time);
-
-                    log::trace!(
-                        "openai rate limiting: waiting {:?} until lifted",
-                        &delay_duration
-                    );
-
-                    self.executor.timer(delay_duration).await;
-                }
-                _ => {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    return Err(anyhow!(
-                        "open ai bad request: {:?} {:?}",
-                        &response.status(),
-                        body
-                    ));
-                }
-            }
-        }
-        Err(anyhow!("openai max retries"))
-    }
-}

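One detail worth noting from the deleted embedding provider is how it picked a retry delay on `429 Too Many Requests`: a fixed backoff schedule as the fallback, overridden by the server's `x-ratelimit-reset-tokens` header whenever that parses. A sketch of just that decision, using the same `parse_duration` crate the deleted code pulled in (the zero-based `attempt` indexing is a small simplification of the original's one-based counter):

```rust
use std::time::Duration;

const BACKOFF_SECONDS: [u64; 4] = [3, 5, 15, 45];

/// `attempt` is zero-based; out-of-range attempts reuse the last entry.
fn retry_delay(attempt: usize, reset_header: Option<&str>) -> Duration {
    let fallback =
        Duration::from_secs(BACKOFF_SECONDS[attempt.min(BACKOFF_SECONDS.len() - 1)]);
    reset_header
        .and_then(|value| parse_duration::parse(value).ok())
        .unwrap_or(fallback)
}
```
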
crates/ai/src/providers/open_ai/model.rs 🔗

@@ -1,59 +0,0 @@
-use anyhow::anyhow;
-use tiktoken_rs::CoreBPE;
-
-use crate::models::{LanguageModel, TruncationDirection};
-
-use super::open_ai_bpe_tokenizer;
-
-#[derive(Clone)]
-pub struct OpenAiLanguageModel {
-    name: String,
-    bpe: Option<CoreBPE>,
-}
-
-impl OpenAiLanguageModel {
-    pub fn load(model_name: &str) -> Self {
-        let bpe = tiktoken_rs::get_bpe_from_model(model_name)
-            .unwrap_or(open_ai_bpe_tokenizer().to_owned());
-        OpenAiLanguageModel {
-            name: model_name.to_string(),
-            bpe: Some(bpe),
-        }
-    }
-}
-
-impl LanguageModel for OpenAiLanguageModel {
-    fn name(&self) -> String {
-        self.name.clone()
-    }
-    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
-        if let Some(bpe) = &self.bpe {
-            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn truncate(
-        &self,
-        content: &str,
-        length: usize,
-        direction: TruncationDirection,
-    ) -> anyhow::Result<String> {
-        if let Some(bpe) = &self.bpe {
-            let tokens = bpe.encode_with_special_tokens(content);
-            if tokens.len() > length {
-                match direction {
-                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
-                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
-                }
-            } else {
-                bpe.decode(tokens)
-            }
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn capacity(&self) -> anyhow::Result<usize> {
-        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
-    }
-}

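The deleted model wrapper was a thin shim over `tiktoken_rs`: count tokens by encoding, truncate by decoding a token slice, and report capacity from the model's context size. An end-truncation sketch using the same calls, with `cl100k_base` as the fallback tokenizer the deleted code also cached:

```rust
use tiktoken_rs::cl100k_base;

/// Truncate `content` to at most `max_tokens` tokens, keeping the start.
fn truncate_end(content: &str, max_tokens: usize) -> anyhow::Result<String> {
    let bpe = cl100k_base()?;
    let tokens = bpe.encode_with_special_tokens(content);
    if tokens.len() <= max_tokens {
        return Ok(content.to_string());
    }
    // Decoding a prefix of the token vector yields the truncated text.
    bpe.decode(tokens[..max_tokens].to_vec())
}
```
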
crates/ai/src/test.rs 🔗

@@ -1,206 +0,0 @@
-use std::{
-    sync::atomic::{self, AtomicUsize, Ordering},
-    time::Instant,
-};
-
-use async_trait::async_trait;
-use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
-use gpui::AppContext;
-use parking_lot::Mutex;
-
-use crate::{
-    auth::{CredentialProvider, ProviderCredential},
-    completion::{CompletionProvider, CompletionRequest},
-    embedding::{Embedding, EmbeddingProvider},
-    models::{LanguageModel, TruncationDirection},
-};
-
-#[derive(Clone)]
-pub struct FakeLanguageModel {
-    pub capacity: usize,
-}
-
-impl LanguageModel for FakeLanguageModel {
-    fn name(&self) -> String {
-        "dummy".to_string()
-    }
-    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
-        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
-    }
-    fn truncate(
-        &self,
-        content: &str,
-        length: usize,
-        direction: TruncationDirection,
-    ) -> anyhow::Result<String> {
-        println!("TRYING TO TRUNCATE: {:?}", length.clone());
-
-        if length > self.count_tokens(content)? {
-            println!("NOT TRUNCATING");
-            return anyhow::Ok(content.to_string());
-        }
-
-        anyhow::Ok(match direction {
-            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
-                .into_iter()
-                .collect::<String>(),
-            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
-                .into_iter()
-                .collect::<String>(),
-        })
-    }
-    fn capacity(&self) -> anyhow::Result<usize> {
-        anyhow::Ok(self.capacity)
-    }
-}
-
-#[derive(Default)]
-pub struct FakeEmbeddingProvider {
-    pub embedding_count: AtomicUsize,
-}
-
-impl Clone for FakeEmbeddingProvider {
-    fn clone(&self) -> Self {
-        FakeEmbeddingProvider {
-            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
-        }
-    }
-}
-
-impl FakeEmbeddingProvider {
-    pub fn embedding_count(&self) -> usize {
-        self.embedding_count.load(atomic::Ordering::SeqCst)
-    }
-
-    pub fn embed_sync(&self, span: &str) -> Embedding {
-        let mut result = vec![1.0; 26];
-        for letter in span.chars() {
-            let letter = letter.to_ascii_lowercase();
-            if letter as u32 >= 'a' as u32 {
-                let ix = (letter as u32) - ('a' as u32);
-                if ix < 26 {
-                    result[ix as usize] += 1.0;
-                }
-            }
-        }
-
-        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
-        for x in &mut result {
-            *x /= norm;
-        }
-
-        result.into()
-    }
-}
-
-impl CredentialProvider for FakeEmbeddingProvider {
-    fn has_credentials(&self) -> bool {
-        true
-    }
-
-    fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
-        async { ProviderCredential::NotNeeded }.boxed()
-    }
-
-    fn save_credentials(
-        &self,
-        _cx: &mut AppContext,
-        _credential: ProviderCredential,
-    ) -> BoxFuture<()> {
-        async {}.boxed()
-    }
-
-    fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
-        async {}.boxed()
-    }
-}
-
-#[async_trait]
-impl EmbeddingProvider for FakeEmbeddingProvider {
-    fn base_model(&self) -> Box<dyn LanguageModel> {
-        Box::new(FakeLanguageModel { capacity: 1000 })
-    }
-    fn max_tokens_per_batch(&self) -> usize {
-        1000
-    }
-
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        None
-    }
-
-    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
-        self.embedding_count
-            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
-
-        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
-    }
-}
-
-pub struct FakeCompletionProvider {
-    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
-}
-
-impl Clone for FakeCompletionProvider {
-    fn clone(&self) -> Self {
-        Self {
-            last_completion_tx: Mutex::new(None),
-        }
-    }
-}
-
-impl FakeCompletionProvider {
-    pub fn new() -> Self {
-        Self {
-            last_completion_tx: Mutex::new(None),
-        }
-    }
-
-    pub fn send_completion(&self, completion: impl Into<String>) {
-        let mut tx = self.last_completion_tx.lock();
-        tx.as_mut().unwrap().try_send(completion.into()).unwrap();
-    }
-
-    pub fn finish_completion(&self) {
-        self.last_completion_tx.lock().take().unwrap();
-    }
-}
-
-impl CredentialProvider for FakeCompletionProvider {
-    fn has_credentials(&self) -> bool {
-        true
-    }
-
-    fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
-        async { ProviderCredential::NotNeeded }.boxed()
-    }
-
-    fn save_credentials(
-        &self,
-        _cx: &mut AppContext,
-        _credential: ProviderCredential,
-    ) -> BoxFuture<()> {
-        async {}.boxed()
-    }
-
-    fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
-        async {}.boxed()
-    }
-}
-
-impl CompletionProvider for FakeCompletionProvider {
-    fn base_model(&self) -> Box<dyn LanguageModel> {
-        let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
-        model
-    }
-    fn complete(
-        &self,
-        _prompt: Box<dyn CompletionRequest>,
-    ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
-        let (tx, rx) = mpsc::channel(1);
-        *self.last_completion_tx.lock() = Some(tx);
-        async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
-    }
-    fn box_clone(&self) -> Box<dyn CompletionProvider> {
-        Box::new((*self).clone())
-    }
-}

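The deleted `FakeCompletionProvider` boils down to one trick: hand the code under test a stream backed by an `mpsc` channel, push chunks through the sender to simulate streamed tokens, and drop the sender to end the stream. The same mechanism in isolation:

```rust
use futures::{channel::mpsc, executor::block_on, StreamExt};

fn main() {
    let (mut tx, rx) = mpsc::channel::<String>(8);
    tx.try_send("hello ".into()).unwrap(); // Simulated streamed chunks.
    tx.try_send("world".into()).unwrap();
    drop(tx); // Dropping the sender finishes the stream.
    let chunks = block_on(rx.collect::<Vec<_>>());
    assert_eq!(chunks.concat(), "hello world");
}
```
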
crates/assistant/Cargo.toml 🔗

@@ -5,17 +5,14 @@ edition = "2021"
 publish = false
 license = "GPL-3.0-or-later"
 
-[lints]
-workspace = true
-
 [lib]
 path = "src/assistant.rs"
 doctest = false
 
 [dependencies]
-ai.workspace = true
 anyhow.workspace = true
 chrono.workspace = true
+client.workspace = true
 collections.workspace = true
 editor.workspace = true
 fs.workspace = true
@@ -26,12 +23,13 @@ language.workspace = true
 log.workspace = true
 menu.workspace = true
 multi_buffer.workspace = true
+open_ai = { workspace = true, features = ["schemars"] }
 ordered-float.workspace = true
+parking_lot.workspace = true
 project.workspace = true
 regex.workspace = true
 schemars.workspace = true
 search.workspace = true
-semantic_index.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true
@@ -45,7 +43,6 @@ uuid.workspace = true
 workspace.workspace = true
 
 [dev-dependencies]
-ai = { workspace = true, features = ["test-support"] }
 ctor.workspace = true
 editor = { workspace = true, features = ["test-support"] }
 env_logger.workspace = true

crates/assistant/src/assistant.rs 🔗

@@ -1,22 +1,24 @@
 pub mod assistant_panel;
 pub mod assistant_settings;
 mod codegen;
+mod completion_provider;
 mod prompts;
+mod saved_conversation;
 mod streaming_diff;
 
-use ai::providers::open_ai::Role;
-use anyhow::Result;
 pub use assistant_panel::AssistantPanel;
-use assistant_settings::OpenAiModel;
+use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
 use chrono::{DateTime, Local};
-use collections::HashMap;
-use fs::Fs;
-use futures::StreamExt;
+use client::{proto, Client};
+pub(crate) use completion_provider::*;
 use gpui::{actions, AppContext, SharedString};
-use regex::Regex;
+pub(crate) use saved_conversation::*;
 use serde::{Deserialize, Serialize};
-use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
-use util::paths::CONVERSATIONS_DIR;
+use settings::Settings;
+use std::{
+    fmt::{self, Display},
+    sync::Arc,
+};
 
 actions!(
     assistant,
@@ -30,7 +32,6 @@ actions!(
         ResetKey,
         InlineAssist,
         ToggleIncludeConversation,
-        ToggleRetrieveContext,
     ]
 );
 
@@ -39,6 +40,139 @@ actions!(
 )]
 struct MessageId(usize);
 
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+impl Role {
+    pub fn cycle(&mut self) {
+        *self = match self {
+            Role::User => Role::Assistant,
+            Role::Assistant => Role::System,
+            Role::System => Role::User,
+        }
+    }
+}
+
+impl Display for Role {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Role::User => write!(f, "user"),
+            Role::Assistant => write!(f, "assistant"),
+            Role::System => write!(f, "system"),
+        }
+    }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub enum LanguageModel {
+    ZedDotDev(ZedDotDevModel),
+    OpenAi(OpenAiModel),
+}
+
+impl Default for LanguageModel {
+    fn default() -> Self {
+        LanguageModel::ZedDotDev(ZedDotDevModel::default())
+    }
+}
+
+impl LanguageModel {
+    pub fn telemetry_id(&self) -> String {
+        match self {
+            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
+            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
+        }
+    }
+
+    pub fn display_name(&self) -> String {
+        match self {
+            LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
+            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
+        }
+    }
+
+    pub fn max_token_count(&self) -> usize {
+        match self {
+            LanguageModel::OpenAi(model) => tiktoken_rs::model::get_context_size(model.id()),
+            LanguageModel::ZedDotDev(model) => match model {
+                ZedDotDevModel::GptThreePointFiveTurbo
+                | ZedDotDevModel::GptFour
+                | ZedDotDevModel::GptFourTurbo => tiktoken_rs::model::get_context_size(model.id()),
+                ZedDotDevModel::Custom(_) => 30720, // TODO: Base this on the selected model.
+            },
+        }
+    }
+
+    pub fn id(&self) -> &str {
+        match self {
+            LanguageModel::OpenAi(model) => model.id(),
+            LanguageModel::ZedDotDev(model) => model.id(),
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LanguageModelRequestMessage {
+    pub role: Role,
+    pub content: String,
+}
+
+impl LanguageModelRequestMessage {
+    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
+        proto::LanguageModelRequestMessage {
+            role: match self.role {
+                Role::User => proto::LanguageModelRole::LanguageModelUser,
+                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
+                Role::System => proto::LanguageModelRole::LanguageModelSystem,
+            } as i32,
+            content: self.content.clone(),
+        }
+    }
+}
+
+#[derive(Debug, Default, Serialize)]
+pub struct LanguageModelRequest {
+    pub model: LanguageModel,
+    pub messages: Vec<LanguageModelRequestMessage>,
+    pub stop: Vec<String>,
+    pub temperature: f32,
+}
+
+impl LanguageModelRequest {
+    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
+        proto::CompleteWithLanguageModel {
+            model: self.model.id().to_string(),
+            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
+            stop: self.stop.clone(),
+            temperature: self.temperature,
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LanguageModelResponseMessage {
+    pub role: Option<Role>,
+    pub content: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct LanguageModelUsage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct LanguageModelChoiceDelta {
+    pub index: u32,
+    pub delta: LanguageModelResponseMessage,
+    pub finish_reason: Option<String>,
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 struct MessageMetadata {
     role: Role,
@@ -53,71 +187,9 @@ enum MessageStatus {
     Error(SharedString),
 }
 
-#[derive(Serialize, Deserialize)]
-struct SavedMessage {
-    id: MessageId,
-    start: usize,
-}
-
-#[derive(Serialize, Deserialize)]
-struct SavedConversation {
-    id: Option<String>,
-    zed: String,
-    version: String,
-    text: String,
-    messages: Vec<SavedMessage>,
-    message_metadata: HashMap<MessageId, MessageMetadata>,
-    summary: String,
-    api_url: Option<String>,
-    model: OpenAiModel,
-}
-
-impl SavedConversation {
-    const VERSION: &'static str = "0.1.0";
-}
-
-struct SavedConversationMetadata {
-    title: String,
-    path: PathBuf,
-    mtime: chrono::DateTime<chrono::Local>,
-}
-
-impl SavedConversationMetadata {
-    pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
-        fs.create_dir(&CONVERSATIONS_DIR).await?;
-
-        let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
-        let mut conversations = Vec::<SavedConversationMetadata>::new();
-        while let Some(path) = paths.next().await {
-            let path = path?;
-            if path.extension() != Some(OsStr::new("json")) {
-                continue;
-            }
-
-            let pattern = r" - \d+.zed.json$";
-            let re = Regex::new(pattern).unwrap();
-
-            let metadata = fs.metadata(&path).await?;
-            if let Some((file_name, metadata)) = path
-                .file_name()
-                .and_then(|name| name.to_str())
-                .zip(metadata)
-            {
-                let title = re.replace(file_name, "");
-                conversations.push(Self {
-                    title: title.into_owned(),
-                    path,
-                    mtime: metadata.mtime.into(),
-                });
-            }
-        }
-        conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
-
-        Ok(conversations)
-    }
-}
-
-pub fn init(cx: &mut AppContext) {
+pub fn init(client: Arc<Client>, cx: &mut AppContext) {
+    AssistantSettings::register(cx);
+    completion_provider::init(client, cx);
     assistant_panel::init(cx);
 }
 

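For orientation, a hypothetical construction of the request types added above, as it might appear inside the assistant crate; the type, field, and method names come straight from the diff, while the values are purely illustrative. `to_proto` is what lets the same request be proxied through Zed's server:

```rust
let request = LanguageModelRequest {
    model: LanguageModel::default(), // ZedDotDev(ZedDotDevModel::default())
    messages: vec![LanguageModelRequestMessage {
        role: Role::User,
        content: "Explain this function.".into(),
    }],
    stop: Vec::new(),
    temperature: 1.0,
};
// Convert to the protobuf message sent to zed.dev.
let proto_request = request.to_proto();
```
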
crates/assistant/src/assistant_panel.rs 🔗

@@ -1,21 +1,13 @@
 use crate::{
-    assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAiModel},
+    assistant_settings::{AssistantDockPosition, AssistantSettings, ZedDotDevModel},
     codegen::{self, Codegen, CodegenKind},
     prompts::generate_content_prompt,
-    Assist, CycleMessageRole, InlineAssist, MessageId, MessageMetadata, MessageStatus,
+    Assist, CompletionProvider, CycleMessageRole, InlineAssist, LanguageModel,
+    LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus,
     NewConversation, QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata,
-    SavedMessage, Split, ToggleFocus, ToggleIncludeConversation, ToggleRetrieveContext,
+    SavedMessage, Split, ToggleFocus, ToggleIncludeConversation,
 };
-use ai::prompts::repository_context::PromptCodeSnippet;
-use ai::{
-    auth::ProviderCredential,
-    completion::{CompletionProvider, CompletionRequest},
-    providers::open_ai::{
-        OpenAiCompletionProvider, OpenAiCompletionProviderKind, OpenAiRequest, RequestMessage,
-        OPEN_AI_API_URL,
-    },
-};
-use anyhow::{anyhow, Result};
+use anyhow::Result;
 use chrono::{DateTime, Local};
 use collections::{hash_map, HashMap, HashSet, VecDeque};
 use editor::{
@@ -24,35 +16,25 @@ use editor::{
         BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint,
     },
     scroll::{Autoscroll, AutoscrollStrategy},
-    Anchor, Editor, EditorElement, EditorEvent, EditorStyle, MultiBufferSnapshot, ToOffset,
+    Anchor, Editor, EditorElement, EditorEvent, EditorStyle, MultiBufferSnapshot, ToOffset as _,
     ToPoint,
 };
 use fs::Fs;
 use futures::StreamExt;
 use gpui::{
-    canvas, div, point, relative, rems, uniform_list, Action, AnyElement, AppContext,
-    AsyncAppContext, AsyncWindowContext, ClipboardItem, Context, EventEmitter, FocusHandle,
-    FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement, IntoElement, Model,
-    ModelContext, ParentElement, Pixels, PromptLevel, Render, SharedString,
+    canvas, div, point, relative, rems, uniform_list, Action, AnyElement, AnyView, AppContext,
+    AsyncAppContext, AsyncWindowContext, AvailableSpace, ClipboardItem, Context, EventEmitter,
+    FocusHandle, FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement,
+    IntoElement, Model, ModelContext, ParentElement, Pixels, Render, SharedString,
     StatefulInteractiveElement, Styled, Subscription, Task, TextStyle, UniformListScrollHandle,
     View, ViewContext, VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext,
 };
 use language::{language_settings::SoftWrap, Buffer, BufferId, LanguageRegistry, ToOffset as _};
+use parking_lot::Mutex;
 use project::Project;
 use search::{buffer_search::DivRegistrar, BufferSearchBar};
-use semantic_index::{SemanticIndex, SemanticIndexStatus};
 use settings::Settings;
-use std::{
-    cell::Cell,
-    cmp,
-    fmt::Write,
-    iter,
-    ops::Range,
-    path::{Path, PathBuf},
-    rc::Rc,
-    sync::Arc,
-    time::{Duration, Instant},
-};
+use std::{cmp, fmt::Write, iter, ops::Range, path::PathBuf, sync::Arc, time::Duration};
 use telemetry_events::AssistantKind;
 use theme::ThemeSettings;
 use ui::{
@@ -69,7 +51,6 @@ use workspace::{
 };
 
 pub fn init(cx: &mut AppContext) {
-    AssistantSettings::register(cx);
     cx.observe_new_views(
         |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
             workspace
@@ -88,27 +69,29 @@ pub struct AssistantPanel {
     workspace: WeakView<Workspace>,
     width: Option<Pixels>,
     height: Option<Pixels>,
-    active_editor_index: Option<usize>,
-    prev_active_editor_index: Option<usize>,
-    editors: Vec<View<ConversationEditor>>,
+    active_conversation_editor: Option<ActiveConversationEditor>,
+    show_saved_conversations: bool,
     saved_conversations: Vec<SavedConversationMetadata>,
     saved_conversations_scroll_handle: UniformListScrollHandle,
     zoomed: bool,
     focus_handle: FocusHandle,
     toolbar: View<Toolbar>,
-    completion_provider: Arc<dyn CompletionProvider>,
-    api_key_editor: Option<View<Editor>>,
     languages: Arc<LanguageRegistry>,
     fs: Arc<dyn Fs>,
-    subscriptions: Vec<Subscription>,
+    _subscriptions: Vec<Subscription>,
     next_inline_assist_id: usize,
     pending_inline_assists: HashMap<usize, PendingInlineAssist>,
     pending_inline_assist_ids_by_editor: HashMap<WeakView<Editor>, Vec<usize>>,
     include_conversation_in_next_inline_assist: bool,
     inline_prompt_history: VecDeque<String>,
     _watch_saved_conversations: Task<Result<()>>,
-    semantic_index: Option<Model<SemanticIndex>>,
-    retrieve_context_in_next_inline_assist: bool,
+    model: LanguageModel,
+    authentication_prompt: Option<AnyView>,
+}
+
+struct ActiveConversationEditor {
+    editor: View<ConversationEditor>,
+    _subscriptions: Vec<Subscription>,
 }
 
 impl AssistantPanel {
@@ -124,22 +107,6 @@ impl AssistantPanel {
                 .await
                 .log_err()
                 .unwrap_or_default();
-            let (provider_kind, api_url, model_name) = cx.update(|cx| {
-                let settings = AssistantSettings::get_global(cx);
-                anyhow::Ok((
-                    settings.provider_kind()?,
-                    settings.provider_api_url()?,
-                    settings.provider_model_name()?,
-                ))
-            })??;
-
-            let completion_provider = OpenAiCompletionProvider::new(
-                api_url,
-                provider_kind,
-                model_name,
-                cx.background_executor().clone(),
-            )
-            .await;
 
             // TODO: deserialize state.
             let workspace_handle = workspace.clone();
@@ -168,41 +135,48 @@ impl AssistantPanel {
                     let toolbar = cx.new_view(|cx| {
                         let mut toolbar = Toolbar::new();
                         toolbar.set_can_navigate(false, cx);
-                        toolbar.add_item(cx.new_view(|cx| BufferSearchBar::new(cx)), cx);
+                        toolbar.add_item(cx.new_view(BufferSearchBar::new), cx);
                         toolbar
                     });
 
-                    let semantic_index = SemanticIndex::global(cx);
-
                     let focus_handle = cx.focus_handle();
-                    cx.on_focus_in(&focus_handle, Self::focus_in).detach();
-                    cx.on_focus_out(&focus_handle, Self::focus_out).detach();
+                    let subscriptions = vec![
+                        cx.on_focus_in(&focus_handle, Self::focus_in),
+                        cx.on_focus_out(&focus_handle, Self::focus_out),
+                        cx.observe_global::<CompletionProvider>({
+                            let mut prev_settings_version =
+                                CompletionProvider::global(cx).settings_version();
+                            move |this, cx| {
+                                this.completion_provider_changed(prev_settings_version, cx);
+                                prev_settings_version =
+                                    CompletionProvider::global(cx).settings_version();
+                            }
+                        }),
+                    ];
+                    let model = CompletionProvider::global(cx).default_model();
 
                     Self {
                         workspace: workspace_handle,
-                        active_editor_index: Default::default(),
-                        prev_active_editor_index: Default::default(),
-                        editors: Default::default(),
+                        active_conversation_editor: None,
+                        show_saved_conversations: false,
                         saved_conversations,
                         saved_conversations_scroll_handle: Default::default(),
                         zoomed: false,
                         focus_handle,
                         toolbar,
-                        completion_provider: Arc::new(completion_provider),
-                        api_key_editor: None,
                         languages: workspace.app_state().languages.clone(),
                         fs: workspace.app_state().fs.clone(),
                         width: None,
                         height: None,
-                        subscriptions: Default::default(),
+                        _subscriptions: subscriptions,
                         next_inline_assist_id: 0,
                         pending_inline_assists: Default::default(),
                         pending_inline_assist_ids_by_editor: Default::default(),
                         include_conversation_in_next_inline_assist: false,
                         inline_prompt_history: Default::default(),
                         _watch_saved_conversations,
-                        semantic_index,
-                        retrieve_context_in_next_inline_assist: false,
+                        model,
+                        authentication_prompt: None,
                     }
                 })
             })
@@ -214,14 +188,8 @@ impl AssistantPanel {
             .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
         cx.notify();
         if self.focus_handle.is_focused(cx) {
-            if self.has_credentials() {
-                if let Some(editor) = self.active_editor() {
-                    cx.focus_view(editor);
-                }
-            }
-
-            if let Some(api_key_editor) = self.api_key_editor.as_ref() {
-                cx.focus_view(api_key_editor);
+            if let Some(editor) = self.active_conversation_editor() {
+                cx.focus_view(editor);
             }
         }
     }
@@ -232,6 +200,30 @@ impl AssistantPanel {
         cx.notify();
     }
 
+    fn completion_provider_changed(
+        &mut self,
+        prev_settings_version: usize,
+        cx: &mut ViewContext<Self>,
+    ) {
+        if self.is_authenticated(cx) {
+            self.authentication_prompt = None;
+
+            let model = CompletionProvider::global(cx).default_model();
+            self.set_model(model, cx);
+
+            if self.active_conversation_editor().is_none() {
+                self.new_conversation(cx);
+            }
+        } else if self.authentication_prompt.is_none()
+            || prev_settings_version != CompletionProvider::global(cx).settings_version()
+        {
+            self.authentication_prompt =
+                Some(cx.update_global::<CompletionProvider, _>(|provider, cx| {
+                    provider.authentication_prompt(cx)
+                }));
+        }
+    }
+
     pub fn inline_assist(
         workspace: &mut Workspace,
         _: &InlineAssist,
@@ -250,7 +242,7 @@ impl AssistantPanel {
         };
         let project = workspace.project().clone();
 
-        if assistant.update(cx, |assistant, _| assistant.has_credentials()) {
+        if assistant.update(cx, |assistant, cx| assistant.is_authenticated(cx)) {
             assistant.update(cx, |assistant, cx| {
                 assistant.new_inline_assist(&active_editor, cx, &project)
             });
@@ -258,9 +250,9 @@ impl AssistantPanel {
             let assistant = assistant.downgrade();
             cx.spawn(|workspace, mut cx| async move {
                 assistant
-                    .update(&mut cx, |assistant, cx| assistant.load_credentials(cx))?
-                    .await;
-                if assistant.update(&mut cx, |assistant, _| assistant.has_credentials())? {
+                    .update(&mut cx, |assistant, cx| assistant.authenticate(cx))?
+                    .await?;
+                if assistant.update(&mut cx, |assistant, cx| assistant.is_authenticated(cx))? {
                     assistant.update(&mut cx, |assistant, cx| {
                         assistant.new_inline_assist(&active_editor, cx, &project)
                     })?;
@@ -311,34 +303,11 @@ impl AssistantPanel {
         };
 
         let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
-        let provider = self.completion_provider.clone();
-
-        let codegen = cx.new_model(|cx| {
-            Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
-        });
 
-        if let Some(semantic_index) = self.semantic_index.clone() {
-            let project = project.clone();
-            cx.spawn(|_, mut cx| async move {
-                let previously_indexed = semantic_index
-                    .update(&mut cx, |index, cx| {
-                        index.project_previously_indexed(&project, cx)
-                    })?
-                    .await
-                    .unwrap_or(false);
-                if previously_indexed {
-                    let _ = semantic_index
-                        .update(&mut cx, |index, cx| {
-                            index.index_project(project.clone(), cx)
-                        })?
-                        .await;
-                }
-                anyhow::Ok(())
-            })
-            .detach_and_log_err(cx);
-        }
+        let codegen =
+            cx.new_model(|cx| Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, cx));
 
-        let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
+        let measurements = Arc::new(Mutex::new(BlockMeasurements::default()));
         let inline_assistant = cx.new_view(|cx| {
             InlineAssistant::new(
                 inline_assist_id,
@@ -348,9 +317,6 @@ impl AssistantPanel {
                 codegen.clone(),
                 self.workspace.clone(),
                 cx,
-                self.retrieve_context_in_next_inline_assist,
-                self.semantic_index.clone(),
-                project.clone(),
             )
         });
         let block_id = editor.update(cx, |editor, cx| {
@@ -365,10 +331,10 @@ impl AssistantPanel {
                     render: Arc::new({
                         let inline_assistant = inline_assistant.clone();
                         move |cx: &mut BlockContext| {
-                            measurements.set(BlockMeasurements {
+                            *measurements.lock() = BlockMeasurements {
                                 anchor_x: cx.anchor_x,
                                 gutter_width: cx.gutter_dimensions.width,
-                            });
+                            };
                             inline_assistant.clone().into_any_element()
                         }
                     }),
@@ -456,7 +422,7 @@ impl AssistantPanel {
             .entry(editor.downgrade())
             .or_default()
             .push(inline_assist_id);
-        self.update_highlights_for_editor(&editor, cx);
+        self.update_highlights_for_editor(editor, cx);
     }
 
     fn handle_inline_assistant_event(
@@ -470,15 +436,8 @@ impl AssistantPanel {
             InlineAssistantEvent::Confirmed {
                 prompt,
                 include_conversation,
-                retrieve_context,
             } => {
-                self.confirm_inline_assist(
-                    assist_id,
-                    prompt,
-                    *include_conversation,
-                    cx,
-                    *retrieve_context,
-                );
+                self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
             }
             InlineAssistantEvent::Canceled => {
                 self.finish_inline_assist(assist_id, true, cx);
@@ -491,9 +450,6 @@ impl AssistantPanel {
             } => {
                 self.include_conversation_in_next_inline_assist = *include_conversation;
             }
-            InlineAssistantEvent::RetrieveContextToggled { retrieve_context } => {
-                self.retrieve_context_in_next_inline_assist = *retrieve_context
-            }
         }
     }
 
@@ -575,10 +531,9 @@ impl AssistantPanel {
         user_prompt: &str,
         include_conversation: bool,
         cx: &mut ViewContext<Self>,
-        retrieve_context: bool,
     ) {
         let conversation = if include_conversation {
-            self.active_editor()
+            self.active_conversation_editor()
                 .map(|editor| editor.read(cx).conversation.clone())
         } else {
             None
@@ -599,17 +554,13 @@ impl AssistantPanel {
 
         let project = pending_assist.project.clone();
 
-        let project_name = if let Some(project) = project.upgrade() {
-            Some(
-                project
-                    .read(cx)
-                    .worktree_root_names(cx)
-                    .collect::<Vec<&str>>()
-                    .join("/"),
-            )
-        } else {
-            None
-        };
+        let project_name = project.upgrade().map(|project| {
+            project
+                .read(cx)
+                .worktree_root_names(cx)
+                .collect::<Vec<&str>>()
+                .join("/")
+        });
 
         self.inline_prompt_history
             .retain(|prompt| prompt != user_prompt);
@@ -652,7 +603,7 @@ impl AssistantPanel {
         // If the language is Markdown or unknown, increase the randomness for more creative output.
         // If it's code, decrease the temperature to get more deterministic outputs.
         let temperature = if let Some(language) = language_name.clone() {
-            if *language != *"Markdown" {
+            if language.as_ref() != "Markdown" {
                 0.5
             } else {
                 1.0
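
The temperature heuristic in this hunk reads cleanly in isolation; a self-contained sketch of the same decision (the function name is hypothetical):

```rust
/// Pick a sampling temperature for the inline assistant.
fn temperature_for_language(language_name: Option<&str>) -> f32 {
    match language_name {
        // Markdown or no known language: more creative output.
        Some("Markdown") | None => 1.0,
        // Real code: more deterministic output.
        Some(_) => 0.5,
    }
}

fn main() {
    assert_eq!(temperature_for_language(Some("Rust")), 0.5);
    assert_eq!(temperature_for_language(Some("Markdown")), 1.0);
    assert_eq!(temperature_for_language(None), 1.0);
}
```
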
@@ -663,61 +614,9 @@ impl AssistantPanel {
 
         let user_prompt = user_prompt.to_string();
 
-        let snippets = if retrieve_context {
-            let Some(project) = project.upgrade() else {
-                return;
-            };
-
-            let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
-                let search_results = semantic_index.update(cx, |this, cx| {
-                    this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
-                });
-
-                cx.background_executor()
-                    .spawn(async move { search_results.await.unwrap_or_default() })
-            } else {
-                Task::ready(Vec::new())
-            };
-
-            let snippets = cx.spawn(|_, mut cx| async move {
-                let mut snippets = Vec::new();
-                for result in search_results.await {
-                    snippets.push(PromptCodeSnippet::new(
-                        result.buffer,
-                        result.range,
-                        &mut cx,
-                    )?);
-                }
-                anyhow::Ok(snippets)
-            });
-            snippets
-        } else {
-            Task::ready(Ok(Vec::new()))
-        };
-
-        let Some(mut model_name) = AssistantSettings::get_global(cx)
-            .provider_model_name()
-            .log_err()
-        else {
-            return;
-        };
-
-        let prompt = cx.background_executor().spawn({
-            let model_name = model_name.clone();
-            async move {
-                let snippets = snippets.await?;
-
-                let language_name = language_name.as_deref();
-                generate_content_prompt(
-                    user_prompt,
-                    language_name,
-                    buffer,
-                    range,
-                    snippets,
-                    &model_name,
-                    project_name,
-                )
-            }
+        let prompt = cx.background_executor().spawn(async move {
+            let language_name = language_name.as_deref();
+            generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
         });
 
         let mut messages = Vec::new();
@@ -729,25 +628,24 @@ impl AssistantPanel {
                     .messages(cx)
                     .map(|message| message.to_open_ai_message(buffer)),
             );
-            model_name = conversation.model.full_name().to_string();
         }
+        let model = self.model.clone();
 
         cx.spawn(|_, mut cx| async move {
             // I don't know if we want to return a ? here.
             let prompt = prompt.await?;
 
-            messages.push(RequestMessage {
+            messages.push(LanguageModelRequestMessage {
                 role: Role::User,
                 content: prompt,
             });
 
-            let request = Box::new(OpenAiRequest {
-                model: model_name,
+            let request = LanguageModelRequest {
+                model,
                 messages,
-                stream: true,
                 stop: vec!["|END|>".to_string()],
                 temperature,
-            });
+            };
 
             codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
             anyhow::Ok(())
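
For reference, the `LanguageModelRequest` replacing the boxed `OpenAiRequest` is plain data with no `stream` flag; a hand-constructed sketch using only the fields visible in this diff (any fields or defaults not shown here are assumptions):

```rust
// Sketch only: the field set is inferred from this diff.
fn example_request() -> LanguageModelRequest {
    LanguageModelRequest {
        model: LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo),
        messages: vec![LanguageModelRequestMessage {
            role: Role::User,
            content: "Add doc comments to this function".into(),
        }],
        // The inline assistant stops generation at this sentinel.
        stop: vec!["|END|>".to_string()],
        temperature: 0.5,
    }
}
```
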
@@ -781,7 +679,7 @@ impl AssistantPanel {
             } else {
                 editor.highlight_background::<PendingInlineAssist>(
                     background_ranges,
-                    |theme| theme.editor_active_line_background, // todo("use the appropriate color")
+                    |theme| theme.editor_active_line_background, // todo!("use the appropriate color")
                     cx,
                 );
             }
@@ -801,54 +699,82 @@ impl AssistantPanel {
         });
     }
 
-    fn build_api_key_editor(&mut self, cx: &mut WindowContext<'_>) {
-        self.api_key_editor = Some(build_api_key_editor(cx));
-    }
-
     fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> View<ConversationEditor> {
         let editor = cx.new_view(|cx| {
             ConversationEditor::new(
-                self.completion_provider.clone(),
+                self.model.clone(),
                 self.languages.clone(),
                 self.fs.clone(),
                 self.workspace.clone(),
                 cx,
             )
         });
-        self.add_conversation(editor.clone(), cx);
+        self.show_conversation(editor.clone(), cx);
         editor
     }
 
-    fn add_conversation(&mut self, editor: View<ConversationEditor>, cx: &mut ViewContext<Self>) {
-        self.subscriptions
-            .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
+    fn show_conversation(
+        &mut self,
+        conversation_editor: View<ConversationEditor>,
+        cx: &mut ViewContext<Self>,
+    ) {
+        let mut subscriptions = Vec::new();
+        subscriptions
+            .push(cx.subscribe(&conversation_editor, Self::handle_conversation_editor_event));
 
-        let conversation = editor.read(cx).conversation.clone();
-        self.subscriptions
-            .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
+        let conversation = conversation_editor.read(cx).conversation.clone();
+        subscriptions.push(cx.observe(&conversation, |_, _, cx| cx.notify()));
+
+        let editor = conversation_editor.read(cx).editor.clone();
+        self.toolbar.update(cx, |toolbar, cx| {
+            toolbar.set_active_item(Some(&editor), cx);
+        });
+        if self.focus_handle.contains_focused(cx) {
+            cx.focus_view(&editor);
+        }
+        self.active_conversation_editor = Some(ActiveConversationEditor {
+            editor: conversation_editor,
+            _subscriptions: subscriptions,
+        });
+        self.show_saved_conversations = false;
 
-        let index = self.editors.len();
-        self.editors.push(editor);
-        self.set_active_editor_index(Some(index), cx);
+        cx.notify();
     }
 
-    fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
-        self.prev_active_editor_index = self.active_editor_index;
-        self.active_editor_index = index;
-        if let Some(editor) = self.active_editor() {
-            let editor = editor.read(cx).editor.clone();
-            self.toolbar.update(cx, |toolbar, cx| {
-                toolbar.set_active_item(Some(&editor), cx);
-            });
-            if self.focus_handle.contains_focused(cx) {
-                cx.focus_view(&editor);
-            }
-        } else {
-            self.toolbar.update(cx, |toolbar, cx| {
-                toolbar.set_active_item(None, cx);
-            });
-        }
+    fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
+        let next_model = match &self.model {
+            LanguageModel::OpenAi(model) => LanguageModel::OpenAi(match &model {
+                open_ai::Model::ThreePointFiveTurbo => open_ai::Model::Four,
+                open_ai::Model::Four => open_ai::Model::FourTurbo,
+                open_ai::Model::FourTurbo => open_ai::Model::ThreePointFiveTurbo,
+            }),
+            LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
+                ZedDotDevModel::GptThreePointFiveTurbo => ZedDotDevModel::GptFour,
+                ZedDotDevModel::GptFour => ZedDotDevModel::GptFourTurbo,
+                ZedDotDevModel::GptFourTurbo => {
+                    match CompletionProvider::global(cx).default_model() {
+                        LanguageModel::ZedDotDev(custom) => custom,
+                        _ => ZedDotDevModel::GptThreePointFiveTurbo,
+                    }
+                }
+                ZedDotDevModel::Custom(_) => ZedDotDevModel::GptThreePointFiveTurbo,
+            }),
+        };
+
+        self.set_model(next_model, cx);
+    }
 
+    fn set_model(&mut self, model: LanguageModel, cx: &mut ViewContext<Self>) {
+        self.model = model.clone();
+        if let Some(editor) = self.active_conversation_editor() {
+            editor.update(cx, |active_conversation, cx| {
+                active_conversation
+                    .conversation
+                    .update(cx, |conversation, cx| {
+                        conversation.set_model(model, cx);
+                    })
+            })
+        }
         cx.notify();
     }
 
@@ -863,49 +789,6 @@ impl AssistantPanel {
         }
     }
 
-    fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
-        if let Some(api_key) = self
-            .api_key_editor
-            .as_ref()
-            .map(|editor| editor.read(cx).text(cx))
-        {
-            if !api_key.is_empty() {
-                let credential = ProviderCredential::Credentials {
-                    api_key: api_key.clone(),
-                };
-
-                let completion_provider = self.completion_provider.clone();
-                cx.spawn(|this, mut cx| async move {
-                    cx.update(|cx| completion_provider.save_credentials(cx, credential))?
-                        .await;
-
-                    this.update(&mut cx, |this, cx| {
-                        this.api_key_editor.take();
-                        this.focus_handle.focus(cx);
-                        cx.notify();
-                    })
-                })
-                .detach_and_log_err(cx);
-            }
-        } else {
-            cx.propagate();
-        }
-    }
-
-    fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
-        let completion_provider = self.completion_provider.clone();
-        cx.spawn(|this, mut cx| async move {
-            cx.update(|cx| completion_provider.delete_credentials(cx))?
-                .await;
-            this.update(&mut cx, |this, cx| {
-                this.build_api_key_editor(cx);
-                this.focus_handle.focus(cx);
-                cx.notify();
-            })
-        })
-        .detach_and_log_err(cx);
-    }
-
     fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
         if self.zoomed {
             cx.emit(PanelEvent::ZoomOut)
@@ -958,58 +841,27 @@ impl AssistantPanel {
         }
     }
 
-    fn active_editor(&self) -> Option<&View<ConversationEditor>> {
-        self.editors.get(self.active_editor_index?)
+    fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
+        CompletionProvider::global(cx)
+            .reset_credentials(cx)
+            .detach_and_log_err(cx);
     }
 
-    fn render_api_key_editor(
-        &self,
-        editor: &View<Editor>,
-        cx: &mut ViewContext<Self>,
-    ) -> impl IntoElement {
-        let settings = ThemeSettings::get_global(cx);
-        let text_style = TextStyle {
-            color: if editor.read(cx).read_only(cx) {
-                cx.theme().colors().text_disabled
-            } else {
-                cx.theme().colors().text
-            },
-            font_family: settings.ui_font.family.clone(),
-            font_features: settings.ui_font.features,
-            font_size: rems(0.875).into(),
-            font_weight: FontWeight::NORMAL,
-            font_style: FontStyle::Normal,
-            line_height: relative(1.3),
-            background_color: None,
-            underline: None,
-            strikethrough: None,
-            white_space: WhiteSpace::Normal,
-        };
-        EditorElement::new(
-            &editor,
-            EditorStyle {
-                background: cx.theme().colors().editor_background,
-                local_player: cx.theme().players().local(),
-                text: text_style,
-                ..Default::default()
-            },
-        )
+    fn active_conversation_editor(&self) -> Option<&View<ConversationEditor>> {
+        Some(&self.active_conversation_editor.as_ref()?.editor)
     }
 
     fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl IntoElement {
         IconButton::new("hamburger_button", IconName::Menu)
             .on_click(cx.listener(|this, _event, cx| {
-                if this.active_editor().is_some() {
-                    this.set_active_editor_index(None, cx);
-                } else {
-                    this.set_active_editor_index(this.prev_active_editor_index, cx);
-                }
+                this.show_saved_conversations = !this.show_saved_conversations;
+                cx.notify();
             }))
             .tooltip(|cx| Tooltip::text("Conversation History", cx))
     }
 
     fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement> {
-        if self.active_editor().is_some() {
+        if self.active_conversation_editor().is_some() {
             vec![
                 Self::render_split_button(cx).into_any_element(),
                 Self::render_quote_button(cx).into_any_element(),
@@ -1023,7 +875,7 @@ impl AssistantPanel {
     fn render_split_button(cx: &mut ViewContext<Self>) -> impl IntoElement {
         IconButton::new("split_button", IconName::Snip)
             .on_click(cx.listener(|this, _event, cx| {
-                if let Some(active_editor) = this.active_editor() {
+                if let Some(active_editor) = this.active_conversation_editor() {
                     active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
                 }
             }))
@@ -1034,7 +886,7 @@ impl AssistantPanel {
     fn render_assist_button(cx: &mut ViewContext<Self>) -> impl IntoElement {
         IconButton::new("assist_button", IconName::MagicWand)
             .on_click(cx.listener(|this, _event, cx| {
-                if let Some(active_editor) = this.active_editor() {
+                if let Some(active_editor) = this.active_conversation_editor() {
                     active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
                 }
             }))
@@ -1111,202 +963,185 @@ impl AssistantPanel {
     fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
         cx.focus(&self.focus_handle);
 
-        if let Some(ix) = self.editor_index_for_path(&path, cx) {
-            self.set_active_editor_index(Some(ix), cx);
-            return Task::ready(Ok(()));
-        }
-
         let fs = self.fs.clone();
         let workspace = self.workspace.clone();
         let languages = self.languages.clone();
         cx.spawn(|this, mut cx| async move {
-            let saved_conversation = fs.load(&path).await?;
-            let saved_conversation = serde_json::from_str(&saved_conversation)?;
-            let conversation =
-                Conversation::deserialize(saved_conversation, path.clone(), languages, &mut cx)
-                    .await?;
+            let saved_conversation = SavedConversation::load(&path, fs.as_ref()).await?;
+            let model = this.update(&mut cx, |this, _| this.model.clone())?;
+            let conversation = Conversation::deserialize(
+                saved_conversation,
+                model,
+                path.clone(),
+                languages,
+                &mut cx,
+            )
+            .await?;
 
             this.update(&mut cx, |this, cx| {
-                // If, by the time we've loaded the conversation, the user has already opened
-                // the same conversation, we don't want to open it again.
-                if let Some(ix) = this.editor_index_for_path(&path, cx) {
-                    this.set_active_editor_index(Some(ix), cx);
-                } else {
-                    let editor = cx.new_view(|cx| {
-                        ConversationEditor::for_conversation(conversation, fs, workspace, cx)
-                    });
-                    this.add_conversation(editor, cx);
-                }
+                let editor = cx.new_view(|cx| {
+                    ConversationEditor::for_conversation(conversation, fs, workspace, cx)
+                });
+                this.show_conversation(editor, cx);
             })?;
             Ok(())
         })
     }
 
-    fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
-        self.editors
-            .iter()
-            .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
-    }
-
-    fn has_credentials(&mut self) -> bool {
-        self.completion_provider.has_credentials()
+    fn is_authenticated(&mut self, cx: &mut ViewContext<Self>) -> bool {
+        CompletionProvider::global(cx).is_authenticated()
     }
 
-    fn load_credentials(&mut self, cx: &mut ViewContext<Self>) -> Task<()> {
-        let completion_provider = self.completion_provider.clone();
-        cx.spawn(|_, mut cx| async move {
-            if let Some(retrieve_credentials) = cx
-                .update(|cx| completion_provider.retrieve_credentials(cx))
-                .log_err()
-            {
-                retrieve_credentials.await;
-            }
-        })
+    fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
+        cx.update_global::<CompletionProvider, _>(|provider, cx| provider.authenticate(cx))
     }
-}
 
-fn build_api_key_editor(cx: &mut WindowContext) -> View<Editor> {
-    cx.new_view(|cx| {
-        let mut editor = Editor::single_line(cx);
-        editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
-        editor
-    })
-}
-
-impl Render for AssistantPanel {
-    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
-        if let Some(api_key_editor) = self.api_key_editor.clone() {
-            const INSTRUCTIONS: [&'static str; 6] = [
-                "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
-                " - You can create an API key at: platform.openai.com/api-keys",
-                " - Make sure your OpenAI account has credits",
-                " - Having a subscription for another service like GitHub Copilot won't work.",
-                " ",
-                "Paste your OpenAI API key and press Enter to use the assistant:"
-            ];
-
-            v_flex()
-                .p_4()
-                .size_full()
-                .on_action(cx.listener(AssistantPanel::save_credentials))
-                .track_focus(&self.focus_handle)
-                .children(
-                    INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
-                )
-                .child(
-                    h_flex()
-                        .w_full()
-                        .my_2()
-                        .px_2()
-                        .py_1()
-                        .bg(cx.theme().colors().editor_background)
-                        .rounded_md()
-                        .child(self.render_api_key_editor(&api_key_editor, cx)),
-                )
-                .child(
+    fn render_signed_in(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        let header = TabBar::new("assistant_header")
+            .start_child(
+                h_flex().gap_1().child(Self::render_hamburger_button(cx)), // .children(title),
+            )
+            .children(self.active_conversation_editor().map(|editor| {
+                h_flex()
+                    .h(rems(Tab::CONTAINER_HEIGHT_IN_REMS))
+                    .flex_1()
+                    .px_2()
+                    .child(Label::new(editor.read(cx).title(cx)).into_element())
+            }))
+            .when(self.focus_handle.contains_focused(cx), |this| {
+                this.end_child(
                     h_flex()
                         .gap_2()
-                        .child(Label::new("Click on").size(LabelSize::Small))
-                        .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
+                        .when(self.active_conversation_editor().is_some(), |this| {
+                            this.child(h_flex().gap_1().children(self.render_editor_tools(cx)))
+                                .child(
+                                    ui::Divider::vertical()
+                                        .inset()
+                                        .color(ui::DividerColor::Border),
+                                )
+                        })
                         .child(
-                            Label::new("in the status bar to close this panel.")
-                                .size(LabelSize::Small),
+                            h_flex()
+                                .gap_1()
+                                .child(Self::render_plus_button(cx))
+                                .child(self.render_zoom_button(cx)),
                         ),
                 )
-        } else {
-            let header = TabBar::new("assistant_header")
-                .start_child(
-                    h_flex().gap_1().child(Self::render_hamburger_button(cx)), // .children(title),
-                )
-                .children(self.active_editor().map(|editor| {
-                    h_flex()
-                        .h(rems(Tab::CONTAINER_HEIGHT_IN_REMS))
-                        .flex_1()
-                        .px_2()
-                        .child(Label::new(editor.read(cx).title(cx)).into_element())
-                }))
-                .when(self.focus_handle.contains_focused(cx), |this| {
-                    this.end_child(
-                        h_flex()
-                            .gap_2()
-                            .when(self.active_editor().is_some(), |this| {
-                                this.child(h_flex().gap_1().children(self.render_editor_tools(cx)))
-                                    .child(
-                                        ui::Divider::vertical()
-                                            .inset()
-                                            .color(ui::DividerColor::Border),
-                                    )
-                            })
-                            .child(
-                                h_flex()
-                                    .gap_1()
-                                    .child(Self::render_plus_button(cx))
-                                    .child(self.render_zoom_button(cx)),
-                            ),
-                    )
-                });
+            });
 
-            let contents = if self.active_editor().is_some() {
-                let mut registrar = DivRegistrar::new(
-                    |panel, cx| panel.toolbar.read(cx).item_of_type::<BufferSearchBar>(),
-                    cx,
-                );
-                BufferSearchBar::register(&mut registrar);
-                registrar.into_div()
+        let contents = if self.active_conversation_editor().is_some() {
+            let mut registrar = DivRegistrar::new(
+                |panel, cx| panel.toolbar.read(cx).item_of_type::<BufferSearchBar>(),
+                cx,
+            );
+            BufferSearchBar::register(&mut registrar);
+            registrar.into_div()
+        } else {
+            div()
+        };
+        v_flex()
+            .key_context("AssistantPanel")
+            .size_full()
+            .on_action(cx.listener(|this, _: &workspace::NewFile, cx| {
+                this.new_conversation(cx);
+            }))
+            .on_action(cx.listener(AssistantPanel::toggle_zoom))
+            .on_action(cx.listener(AssistantPanel::deploy))
+            .on_action(cx.listener(AssistantPanel::select_next_match))
+            .on_action(cx.listener(AssistantPanel::select_prev_match))
+            .on_action(cx.listener(AssistantPanel::handle_editor_cancel))
+            .on_action(cx.listener(AssistantPanel::reset_credentials))
+            .track_focus(&self.focus_handle)
+            .child(header)
+            .children(if self.toolbar.read(cx).hidden() {
+                None
             } else {
-                div()
-            };
-            v_flex()
-                .key_context("AssistantPanel")
-                .size_full()
-                .on_action(cx.listener(|this, _: &workspace::NewFile, cx| {
-                    this.new_conversation(cx);
-                }))
-                .on_action(cx.listener(AssistantPanel::reset_credentials))
-                .on_action(cx.listener(AssistantPanel::toggle_zoom))
-                .on_action(cx.listener(AssistantPanel::deploy))
-                .on_action(cx.listener(AssistantPanel::select_next_match))
-                .on_action(cx.listener(AssistantPanel::select_prev_match))
-                .on_action(cx.listener(AssistantPanel::handle_editor_cancel))
-                .track_focus(&self.focus_handle)
-                .child(header)
-                .children(if self.toolbar.read(cx).hidden() {
-                    None
-                } else {
-                    Some(self.toolbar.clone())
-                })
-                .child(
-                    contents
-                        .flex_1()
-                        .child(if let Some(editor) = self.active_editor() {
-                            editor.clone().into_any_element()
-                        } else {
-                            let view = cx.view().clone();
-                            let scroll_handle = self.saved_conversations_scroll_handle.clone();
-                            let conversation_count = self.saved_conversations.len();
-                            canvas(
-                                move |bounds, cx| {
-                                    let mut list = uniform_list(
-                                        view,
-                                        "saved_conversations",
-                                        conversation_count,
-                                        |this, range, cx| {
-                                            range
-                                                .map(|ix| this.render_saved_conversation(ix, cx))
-                                                .collect()
-                                        },
-                                    )
-                                    .track_scroll(scroll_handle)
-                                    .into_any_element();
-                                    list.layout(bounds.origin, bounds.size.into(), cx);
-                                    list
+                Some(self.toolbar.clone())
+            })
+            .child(contents.flex_1().child(
+                if self.show_saved_conversations || self.active_conversation_editor().is_none() {
+                    let view = cx.view().clone();
+                    let scroll_handle = self.saved_conversations_scroll_handle.clone();
+                    let conversation_count = self.saved_conversations.len();
+                    canvas(
+                        move |bounds, cx| {
+                            let mut saved_conversations = uniform_list(
+                                view,
+                                "saved_conversations",
+                                conversation_count,
+                                |this, range, cx| {
+                                    range
+                                        .map(|ix| this.render_saved_conversation(ix, cx))
+                                        .collect()
                                 },
-                                |_bounds, mut list, cx| list.paint(cx),
                             )
-                            .size_full()
-                            .into_any_element()
-                        }),
-                )
+                            .track_scroll(scroll_handle)
+                            .into_any_element();
+                            saved_conversations.layout(
+                                bounds.origin,
+                                bounds.size.map(AvailableSpace::Definite),
+                                cx,
+                            );
+                            saved_conversations
+                        },
+                        |_bounds, mut saved_conversations, cx| saved_conversations.paint(cx),
+                    )
+                    .size_full()
+                    .into_any_element()
+                } else {
+                    let editor = self.active_conversation_editor().unwrap();
+                    let conversation = editor.read(cx).conversation.clone();
+                    div()
+                        .size_full()
+                        .child(editor.clone())
+                        .child(
+                            h_flex()
+                                .absolute()
+                                .gap_1()
+                                .top_3()
+                                .right_5()
+                                .child(self.render_model(&conversation, cx))
+                                .children(self.render_remaining_tokens(&conversation, cx)),
+                        )
+                        .into_any_element()
+                },
+            ))
+    }
+
+    fn render_model(
+        &self,
+        conversation: &Model<Conversation>,
+        cx: &mut ViewContext<Self>,
+    ) -> impl IntoElement {
+        Button::new("current_model", conversation.read(cx).model.display_name())
+            .style(ButtonStyle::Filled)
+            .tooltip(move |cx| Tooltip::text("Change Model", cx))
+            .on_click(cx.listener(|this, _, cx| this.cycle_model(cx)))
+    }
+
+    fn render_remaining_tokens(
+        &self,
+        conversation: &Model<Conversation>,
+        cx: &mut ViewContext<Self>,
+    ) -> Option<impl IntoElement> {
+        let remaining_tokens = conversation.read(cx).remaining_tokens()?;
+        let remaining_tokens_color = if remaining_tokens <= 0 {
+            Color::Error
+        } else if remaining_tokens <= 500 {
+            Color::Warning
+        } else {
+            Color::Default
+        };
+        Some(Label::new(remaining_tokens.to_string()).color(remaining_tokens_color))
+    }
+}
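
The token indicator maps the remaining count to a color with two thresholds; a standalone sketch of that mapping (using a hypothetical integer type and a local `Color` enum standing in for the UI one):

```rust
#[derive(Debug, PartialEq)]
enum Color {
    Error,
    Warning,
    Default,
}

fn remaining_tokens_color(remaining: isize) -> Color {
    if remaining <= 0 {
        Color::Error
    } else if remaining <= 500 {
        Color::Warning
    } else {
        Color::Default
    }
}

fn main() {
    assert_eq!(remaining_tokens_color(-10), Color::Error);
    assert_eq!(remaining_tokens_color(200), Color::Warning);
    assert_eq!(remaining_tokens_color(4_000), Color::Default);
}
```
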
+
+impl Render for AssistantPanel {
+    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        if let Some(authentication_prompt) = self.authentication_prompt.as_ref() {
+            authentication_prompt.clone().into_any()
+        } else {
+            self.render_signed_in(cx).into_any_element()
         }
     }
 }

crates/assistant/src/assistant_settings.rs

@@ -1,169 +1,296 @@
-use ai::providers::open_ai::{
-    AzureOpenAiApiVersion, OpenAiCompletionProviderKind, OPEN_AI_API_URL,
-};
-use anyhow::anyhow;
+use std::fmt;
+
 use gpui::Pixels;
-use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
+pub use open_ai::Model as OpenAiModel;
+use schemars::{
+    schema::{InstanceType, Metadata, Schema, SchemaObject},
+    JsonSchema,
+};
+use serde::{
+    de::{self, Visitor},
+    Deserialize, Deserializer, Serialize, Serializer,
+};
 use settings::Settings;
 
-#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
-#[serde(rename_all = "snake_case")]
-pub enum OpenAiModel {
-    #[serde(rename = "gpt-3.5-turbo-0613")]
-    ThreePointFiveTurbo,
-    #[serde(rename = "gpt-4-0613")]
-    Four,
-    #[serde(rename = "gpt-4-1106-preview")]
-    FourTurbo,
+#[derive(Clone, Debug, Default, PartialEq)]
+pub enum ZedDotDevModel {
+    GptThreePointFiveTurbo,
+    GptFour,
+    #[default]
+    GptFourTurbo,
+    Custom(String),
 }
 
-impl OpenAiModel {
-    pub fn full_name(&self) -> &'static str {
-        match self {
-            Self::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
-            Self::Four => "gpt-4-0613",
-            Self::FourTurbo => "gpt-4-1106-preview",
+impl Serialize for ZedDotDevModel {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_str(self.id())
+    }
+}
+
+impl<'de> Deserialize<'de> for ZedDotDevModel {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        struct ZedDotDevModelVisitor;
+
+        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
+            type Value = ZedDotDevModel;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
+            }
+
+            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+            where
+                E: de::Error,
+            {
+                match value {
+                    "gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo),
+                    "gpt-4" => Ok(ZedDotDevModel::GptFour),
+                    "gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo),
+                    _ => Ok(ZedDotDevModel::Custom(value.to_owned())),
+                }
+            }
         }
+
+        deserializer.deserialize_str(ZedDotDevModelVisitor)
     }
+}
+
+impl JsonSchema for ZedDotDevModel {
+    fn schema_name() -> String {
+        "ZedDotDevModel".to_owned()
+    }
+
+    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
+        let variants = vec![
+            "gpt-3.5-turbo".to_owned(),
+            "gpt-4".to_owned(),
+            "gpt-4-turbo-preview".to_owned(),
+        ];
+        Schema::Object(SchemaObject {
+            instance_type: Some(InstanceType::String.into()),
+            enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
+            metadata: Some(Box::new(Metadata {
+                title: Some("ZedDotDevModel".to_owned()),
+                default: Some(serde_json::json!("gpt-4-turbo-preview")),
+                examples: vec![
+                    serde_json::json!("gpt-3.5-turbo"),
+                    serde_json::json!("gpt-4"),
+                    serde_json::json!("gpt-4-turbo-preview"),
+                    serde_json::json!("custom-model-name"),
+                ],
+                ..Default::default()
+            })),
+            ..Default::default()
+        })
+    }
+}
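
Together, these impls let unknown model ids round-trip through `ZedDotDevModel::Custom` instead of failing deserialization; a sketch of that behavior, assuming `serde_json` is available in tests:

```rust
#[test]
fn custom_model_ids_round_trip() {
    // Known ids map to named variants...
    let known: ZedDotDevModel = serde_json::from_str("\"gpt-4\"").unwrap();
    assert_eq!(known, ZedDotDevModel::GptFour);

    // ...while anything else is preserved as a custom model.
    let custom: ZedDotDevModel = serde_json::from_str("\"my-finetune\"").unwrap();
    assert_eq!(custom, ZedDotDevModel::Custom("my-finetune".into()));
    assert_eq!(serde_json::to_string(&custom).unwrap(), "\"my-finetune\"");
}
```
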
 
-    pub fn short_name(&self) -> &'static str {
+impl ZedDotDevModel {
+    pub fn id(&self) -> &str {
         match self {
-            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
-            Self::Four => "gpt-4",
-            Self::FourTurbo => "gpt-4-turbo",
+            Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
+            Self::GptFour => "gpt-4",
+            Self::GptFourTurbo => "gpt-4-turbo-preview",
+            Self::Custom(id) => id,
         }
     }
 
-    pub fn cycle(&self) -> Self {
+    pub fn display_name(&self) -> &str {
         match self {
-            Self::ThreePointFiveTurbo => Self::Four,
-            Self::Four => Self::FourTurbo,
-            Self::FourTurbo => Self::ThreePointFiveTurbo,
+            Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
+            Self::GptFour => "gpt-4",
+            Self::GptFourTurbo => "gpt-4-turbo",
+            Self::Custom(id) => id.as_str(),
         }
     }
 }
 
-#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
+#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
 #[serde(rename_all = "snake_case")]
 pub enum AssistantDockPosition {
     Left,
+    #[default]
     Right,
     Bottom,
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+#[serde(tag = "name", rename_all = "snake_case")]
+pub enum AssistantProvider {
+    #[serde(rename = "zed.dev")]
+    ZedDotDev {
+        #[serde(default)]
+        default_model: ZedDotDevModel,
+    },
+    #[serde(rename = "openai")]
+    OpenAi {
+        #[serde(default)]
+        default_model: OpenAiModel,
+        #[serde(default = "open_ai_url")]
+        api_url: String,
+    },
+}
+
+impl Default for AssistantProvider {
+    fn default() -> Self {
+        Self::ZedDotDev {
+            default_model: ZedDotDevModel::default(),
+        }
+    }
+}
+
+fn open_ai_url() -> String {
+    "https://api.openai.com/v1".into()
+}
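
Since `AssistantProvider` is internally tagged on `name`, the settings JSON picks a provider via a `name` field; a sketch of what deserializes into each variant (again assuming `serde_json` in tests, and the `gpt-4-0613` id used by the test below):

```rust
#[test]
fn provider_settings_are_tagged_by_name() {
    let zed: AssistantProvider =
        serde_json::from_str(r#"{ "name": "zed.dev", "default_model": "gpt-4" }"#).unwrap();
    assert_eq!(
        zed,
        AssistantProvider::ZedDotDev {
            default_model: ZedDotDevModel::GptFour
        }
    );

    // `api_url` falls back to open_ai_url() when omitted.
    let open_ai: AssistantProvider =
        serde_json::from_str(r#"{ "name": "openai", "default_model": "gpt-4-0613" }"#).unwrap();
    assert!(matches!(open_ai, AssistantProvider::OpenAi { .. }));
}
```
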
+
+#[derive(Default, Debug, Deserialize, Serialize)]
 pub struct AssistantSettings {
-    /// Whether to show the assistant panel button in the status bar.
     pub button: bool,
-    /// Where to dock the assistant.
     pub dock: AssistantDockPosition,
-    /// Default width in pixels when the assistant is docked to the left or right.
     pub default_width: Pixels,
-    /// Default height in pixels when the assistant is docked to the bottom.
     pub default_height: Pixels,
-    /// The default OpenAI model to use when starting new conversations.
-    #[deprecated = "Please use `provider.default_model` instead."]
-    pub default_open_ai_model: OpenAiModel,
-    /// OpenAI API base URL to use when starting new conversations.
-    #[deprecated = "Please use `provider.api_url` instead."]
-    pub openai_api_url: String,
-    /// The settings for the AI provider.
-    pub provider: AiProviderSettings,
+    pub provider: AssistantProvider,
 }
 
-impl AssistantSettings {
-    pub fn provider_kind(&self) -> anyhow::Result<OpenAiCompletionProviderKind> {
-        match &self.provider {
-            AiProviderSettings::OpenAi(_) => Ok(OpenAiCompletionProviderKind::OpenAi),
-            AiProviderSettings::AzureOpenAi(settings) => {
-                let deployment_id = settings
-                    .deployment_id
-                    .clone()
-                    .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
-                let api_version = settings
-                    .api_version
-                    .ok_or_else(|| anyhow!("no Azure OpenAI API version"))?;
-
-                Ok(OpenAiCompletionProviderKind::AzureOpenAi {
-                    deployment_id,
-                    api_version,
-                })
-            }
-        }
+/// Assistant panel settings
+#[derive(Clone, Serialize, Deserialize, Debug)]
+#[serde(untagged)]
+pub enum AssistantSettingsContent {
+    Versioned(VersionedAssistantSettingsContent),
+    Legacy(LegacyAssistantSettingsContent),
+}
+
+impl JsonSchema for AssistantSettingsContent {
+    fn schema_name() -> String {
+        VersionedAssistantSettingsContent::schema_name()
     }
 
-    pub fn provider_api_url(&self) -> anyhow::Result<String> {
-        match &self.provider {
-            AiProviderSettings::OpenAi(settings) => Ok(settings
-                .api_url
-                .clone()
-                .unwrap_or_else(|| OPEN_AI_API_URL.to_string())),
-            AiProviderSettings::AzureOpenAi(settings) => settings
-                .api_url
-                .clone()
-                .ok_or_else(|| anyhow!("no Azure OpenAI API URL")),
-        }
+    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
+        VersionedAssistantSettingsContent::json_schema(gen)
     }
 
-    pub fn provider_model(&self) -> anyhow::Result<OpenAiModel> {
-        match &self.provider {
-            AiProviderSettings::OpenAi(settings) => {
-                Ok(settings.default_model.unwrap_or(OpenAiModel::FourTurbo))
-            }
-            AiProviderSettings::AzureOpenAi(settings) => {
-                let deployment_id = settings
-                    .deployment_id
-                    .as_deref()
-                    .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
-
-                match deployment_id {
-                    // https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-preview
-                    "gpt-4" | "gpt-4-32k" => Ok(OpenAiModel::Four),
-                    // https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-35
-                    "gpt-35-turbo" | "gpt-35-turbo-16k" | "gpt-35-turbo-instruct" => {
-                        Ok(OpenAiModel::ThreePointFiveTurbo)
+    fn is_referenceable() -> bool {
+        VersionedAssistantSettingsContent::is_referenceable()
+    }
+}
+
+impl Default for AssistantSettingsContent {
+    fn default() -> Self {
+        Self::Versioned(VersionedAssistantSettingsContent::default())
+    }
+}
+
+impl AssistantSettingsContent {
+    fn upgrade(&self) -> AssistantSettingsContentV1 {
+        match self {
+            AssistantSettingsContent::Versioned(settings) => match settings {
+                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
+            },
+            AssistantSettingsContent::Legacy(settings) => {
+                if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
+                    AssistantSettingsContentV1 {
+                        button: settings.button,
+                        dock: settings.dock,
+                        default_width: settings.default_width,
+                        default_height: settings.default_height,
+                        provider: Some(AssistantProvider::OpenAi {
+                            default_model: settings
+                                .default_open_ai_model
+                                .clone()
+                                .unwrap_or_default(),
+                            api_url: open_ai_api_url.clone(),
+                        }),
+                    }
+                } else if let Some(open_ai_model) = settings.default_open_ai_model.clone() {
+                    AssistantSettingsContentV1 {
+                        button: settings.button,
+                        dock: settings.dock,
+                        default_width: settings.default_width,
+                        default_height: settings.default_height,
+                        provider: Some(AssistantProvider::OpenAi {
+                            default_model: open_ai_model,
+                            api_url: open_ai_url(),
+                        }),
+                    }
+                } else {
+                    AssistantSettingsContentV1 {
+                        button: settings.button,
+                        dock: settings.dock,
+                        default_width: settings.default_width,
+                        default_height: settings.default_height,
+                        provider: None,
                     }
-                    _ => Err(anyhow!(
-                        "no matching OpenAI model found for deployment ID: '{deployment_id}'"
-                    )),
                 }
             }
         }
     }
 
-    pub fn provider_model_name(&self) -> anyhow::Result<String> {
-        match &self.provider {
-            AiProviderSettings::OpenAi(settings) => Ok(settings
-                .default_model
-                .unwrap_or(OpenAiModel::FourTurbo)
-                .full_name()
-                .to_string()),
-            AiProviderSettings::AzureOpenAi(settings) => settings
-                .deployment_id
-                .clone()
-                .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID")),
+    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
+        match self {
+            AssistantSettingsContent::Versioned(settings) => match settings {
+                VersionedAssistantSettingsContent::V1(settings) => {
+                    settings.dock = Some(dock);
+                }
+            },
+            AssistantSettingsContent::Legacy(settings) => {
+                settings.dock = Some(dock);
+            }
         }
     }
 }
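
The `upgrade` path is what keeps old settings files working; a sketch of the legacy → V1 conversion for an old-style `default_open_ai_model`, written as an in-module test since `upgrade` and these fields are private:

```rust
#[test]
fn legacy_settings_upgrade_to_v1() {
    let legacy = AssistantSettingsContent::Legacy(LegacyAssistantSettingsContent {
        button: None,
        dock: None,
        default_width: None,
        default_height: None,
        default_open_ai_model: Some(OpenAiModel::Four),
        openai_api_url: None,
    });
    // An old `default_open_ai_model` becomes an explicit OpenAI provider
    // pointing at the default API URL.
    assert_eq!(
        legacy.upgrade().provider,
        Some(AssistantProvider::OpenAi {
            default_model: OpenAiModel::Four,
            api_url: open_ai_url(),
        })
    );
}
```
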
 
-impl Settings for AssistantSettings {
-    const KEY: Option<&'static str> = Some("assistant");
-
-    type FileContent = AssistantSettingsContent;
+#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
+#[serde(tag = "version")]
+pub enum VersionedAssistantSettingsContent {
+    #[serde(rename = "1")]
+    V1(AssistantSettingsContentV1),
+}
 
-    fn load(
-        default_value: &Self::FileContent,
-        user_values: &[&Self::FileContent],
-        _: &mut gpui::AppContext,
-    ) -> anyhow::Result<Self> {
-        Self::load_via_json_merge(default_value, user_values)
+impl Default for VersionedAssistantSettingsContent {
+    fn default() -> Self {
+        Self::V1(AssistantSettingsContentV1 {
+            button: None,
+            dock: None,
+            default_width: None,
+            default_height: None,
+            provider: None,
+        })
     }
 }
 
-/// Assistant panel settings
-#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
-pub struct AssistantSettingsContent {
+#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
+pub struct AssistantSettingsContentV1 {
+    /// Whether to show the assistant panel button in the status bar.
+    ///
+    /// Default: true
+    button: Option<bool>,
+    /// Where to dock the assistant.
+    ///
+    /// Default: right
+    dock: Option<AssistantDockPosition>,
+    /// Default width in pixels when the assistant is docked to the left or right.
+    ///
+    /// Default: 640
+    default_width: Option<f32>,
+    /// Default height in pixels when the assistant is docked to the bottom.
+    ///
+    /// Default: 320
+    default_height: Option<f32>,
+    /// The provider of the assistant service.
+    ///
+    /// This can either be the internal `zed.dev` service or an external `openai` service,
+    /// each with their respective default models and configurations.
+    provider: Option<AssistantProvider>,
+}
+
+#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
+pub struct LegacyAssistantSettingsContent {
     /// Whether to show the assistant panel button in the status bar.
     ///
     /// Default: true
@@ -180,88 +307,164 @@ pub struct AssistantSettingsContent {
     ///
     /// Default: 320
     pub default_height: Option<f32>,
-    /// Deprecated: Please use `provider.default_model` instead.
     /// The default OpenAI model to use when starting new conversations.
     ///
     /// Default: gpt-4-1106-preview
-    #[deprecated = "Please use `provider.default_model` instead."]
     pub default_open_ai_model: Option<OpenAiModel>,
-    /// Deprecated: Please use `provider.api_url` instead.
     /// OpenAI API base URL to use when starting new conversations.
     ///
     /// Default: https://api.openai.com/v1
-    #[deprecated = "Please use `provider.api_url` instead."]
     pub openai_api_url: Option<String>,
-    /// The settings for the AI provider.
-    #[serde(default)]
-    pub provider: AiProviderSettingsContent,
 }
 
-#[derive(Debug, Clone, Deserialize)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum AiProviderSettings {
-    /// The settings for the OpenAI provider.
-    #[serde(rename = "openai")]
-    OpenAi(OpenAiProviderSettings),
-    /// The settings for the Azure OpenAI provider.
-    #[serde(rename = "azure_openai")]
-    AzureOpenAi(AzureOpenAiProviderSettings),
-}
+impl Settings for AssistantSettings {
+    const KEY: Option<&'static str> = Some("assistant");
 
-/// The settings for the AI provider used by the Zed Assistant.
-#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum AiProviderSettingsContent {
-    /// The settings for the OpenAI provider.
-    #[serde(rename = "openai")]
-    OpenAi(OpenAiProviderSettingsContent),
-    /// The settings for the Azure OpenAI provider.
-    #[serde(rename = "azure_openai")]
-    AzureOpenAi(AzureOpenAiProviderSettingsContent),
-}
+    type FileContent = AssistantSettingsContent;
 
-impl Default for AiProviderSettingsContent {
-    fn default() -> Self {
-        Self::OpenAi(OpenAiProviderSettingsContent::default())
+    fn load(
+        default_value: &Self::FileContent,
+        user_values: &[&Self::FileContent],
+        _: &mut gpui::AppContext,
+    ) -> anyhow::Result<Self> {
+        let mut settings = AssistantSettings::default();
+
+        for value in [default_value].iter().chain(user_values) {
+            let value = value.upgrade();
+            merge(&mut settings.button, value.button);
+            merge(&mut settings.dock, value.dock);
+            merge(
+                &mut settings.default_width,
+                value.default_width.map(Into::into),
+            );
+            merge(
+                &mut settings.default_height,
+                value.default_height.map(Into::into),
+            );
+            if let Some(provider) = value.provider.clone() {
+                match (&mut settings.provider, provider) {
+                    (
+                        AssistantProvider::ZedDotDev { default_model },
+                        AssistantProvider::ZedDotDev {
+                            default_model: default_model_override,
+                        },
+                    ) => {
+                        *default_model = default_model_override;
+                    }
+                    (
+                        AssistantProvider::OpenAi {
+                            default_model,
+                            api_url,
+                        },
+                        AssistantProvider::OpenAi {
+                            default_model: default_model_override,
+                            api_url: api_url_override,
+                        },
+                    ) => {
+                        *default_model = default_model_override;
+                        *api_url = api_url_override;
+                    }
+                    (merged, provider_override) => {
+                        *merged = provider_override;
+                    }
+                }
+            }
+        }
+
+        Ok(settings)
     }
 }
 
-#[derive(Debug, Clone, Deserialize)]
-pub struct OpenAiProviderSettings {
-    /// The OpenAI API base URL to use when starting new conversations.
-    pub api_url: Option<String>,
-    /// The default OpenAI model to use when starting new conversations.
-    pub default_model: Option<OpenAiModel>,
+fn merge<T: Copy>(target: &mut T, value: Option<T>) {
+    if let Some(value) = value {
+        *target = value;
+    }
 }
 
-#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
-pub struct OpenAiProviderSettingsContent {
-    /// The OpenAI API base URL to use when starting new conversations.
-    ///
-    /// Default: https://api.openai.com/v1
-    pub api_url: Option<String>,
-    /// The default OpenAI model to use when starting new conversations.
-    ///
-    /// Default: gpt-4-1106-preview
-    pub default_model: Option<OpenAiModel>,
-}
+#[cfg(test)]
+mod tests {
+    use gpui::AppContext;
+    use settings::SettingsStore;
 
-#[derive(Debug, Clone, Deserialize)]
-pub struct AzureOpenAiProviderSettings {
-    /// The Azure OpenAI API base URL to use when starting new conversations.
-    pub api_url: Option<String>,
-    /// The Azure OpenAI API version.
-    pub api_version: Option<AzureOpenAiApiVersion>,
-    /// The Azure OpenAI API deployment ID.
-    pub deployment_id: Option<String>,
-}
+    use super::*;
+
+    #[gpui::test]
+    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
+        let store = settings::SettingsStore::test(cx);
+        cx.set_global(store);
 
-#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
-pub struct AzureOpenAiProviderSettingsContent {
-    /// The Azure OpenAI API base URL to use when starting new conversations.
-    pub api_url: Option<String>,
-    /// The Azure OpenAI API version.
-    pub api_version: Option<AzureOpenAiApiVersion>,
-    /// The Azure OpenAI deployment ID.
-    pub deployment_id: Option<String>,
+        // Settings default to gpt-4-turbo.
+        AssistantSettings::register(cx);
+        assert_eq!(
+            AssistantSettings::get_global(cx).provider,
+            AssistantProvider::OpenAi {
+                default_model: OpenAiModel::FourTurbo,
+                api_url: open_ai_url()
+            }
+        );
+
+        // Ensure backward-compatibility.
+        cx.update_global::<SettingsStore, _>(|store, cx| {
+            store
+                .set_user_settings(
+                    r#"{
+                        "assistant": {
+                            "openai_api_url": "test-url",
+                        }
+                    }"#,
+                    cx,
+                )
+                .unwrap();
+        });
+        assert_eq!(
+            AssistantSettings::get_global(cx).provider,
+            AssistantProvider::OpenAi {
+                default_model: OpenAiModel::FourTurbo,
+                api_url: "test-url".into()
+            }
+        );
+        cx.update_global::<SettingsStore, _>(|store, cx| {
+            store
+                .set_user_settings(
+                    r#"{
+                        "assistant": {
+                            "default_open_ai_model": "gpt-4-0613"
+                        }
+                    }"#,
+                    cx,
+                )
+                .unwrap();
+        });
+        assert_eq!(
+            AssistantSettings::get_global(cx).provider,
+            AssistantProvider::OpenAi {
+                default_model: OpenAiModel::Four,
+                api_url: open_ai_url()
+            }
+        );
+
+        // The new version supports setting a custom model when using zed.dev.
+        cx.update_global::<SettingsStore, _>(|store, cx| {
+            store
+                .set_user_settings(
+                    r#"{
+                        "assistant": {
+                            "version": "1",
+                            "provider": {
+                                "name": "zed.dev",
+                                "default_model": "custom"
+                            }
+                        }
+                    }"#,
+                    cx,
+                )
+                .unwrap();
+        });
+        assert_eq!(
+            AssistantSettings::get_global(cx).provider,
+            AssistantProvider::ZedDotDev {
+                default_model: ZedDotDevModel::Custom("custom".into())
+            }
+        );
+    }
 }

crates/assistant/src/codegen.rs

@@ -1,12 +1,13 @@
-use crate::streaming_diff::{Hunk, StreamingDiff};
-use ai::completion::{CompletionProvider, CompletionRequest};
+use crate::{
+    streaming_diff::{Hunk, StreamingDiff},
+    CompletionProvider, LanguageModelRequest,
+};
 use anyhow::Result;
 use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
 use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
 use gpui::{EventEmitter, Model, ModelContext, Task};
 use language::{Rope, TransactionId};
-use multi_buffer;
-use std::{cmp, future, ops::Range, sync::Arc};
+use std::{cmp, future, ops::Range};
 
 pub enum Event {
     Finished,
@@ -20,7 +21,6 @@ pub enum CodegenKind {
 }
 
 pub struct Codegen {
-    provider: Arc<dyn CompletionProvider>,
     buffer: Model<MultiBuffer>,
     snapshot: MultiBufferSnapshot,
     kind: CodegenKind,
@@ -35,15 +35,9 @@ pub struct Codegen {
 impl EventEmitter<Event> for Codegen {}
 
 impl Codegen {
-    pub fn new(
-        buffer: Model<MultiBuffer>,
-        kind: CodegenKind,
-        provider: Arc<dyn CompletionProvider>,
-        cx: &mut ModelContext<Self>,
-    ) -> Self {
+    pub fn new(buffer: Model<MultiBuffer>, kind: CodegenKind, cx: &mut ModelContext<Self>) -> Self {
         let snapshot = buffer.read(cx).snapshot(cx);
         Self {
-            provider,
             buffer: buffer.clone(),
             snapshot,
             kind,
@@ -94,7 +88,7 @@ impl Codegen {
         self.error.as_ref()
     }
 
-    pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
+    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
         let range = self.range();
         let snapshot = self.snapshot.clone();
         let selected_text = snapshot
@@ -108,7 +102,7 @@ impl Codegen {
             .next()
             .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
 
-        let response = self.provider.complete(prompt);
+        let response = CompletionProvider::global(cx).complete(prompt);
         self.generation = cx.spawn(|this, mut cx| {
             async move {
                 let generate = async {
@@ -305,7 +299,7 @@ fn strip_invalid_spans_from_codeblock(
         }
 
         if first_line {
-            if buffer == "" || buffer == "`" || buffer == "``" {
+            if buffer.is_empty() || buffer == "`" || buffer == "``" {
                 return future::ready(None);
             } else if buffer.starts_with("```") {
                 starts_with_markdown_codeblock = true;
@@ -360,8 +354,9 @@ fn strip_invalid_spans_from_codeblock(
 mod tests {
     use std::sync::Arc;
 
+    use crate::FakeCompletionProvider;
+
     use super::*;
-    use ai::test::FakeCompletionProvider;
     use futures::stream::{self};
     use gpui::{Context, TestAppContext};
     use indoc::indoc;
@@ -378,15 +373,11 @@ mod tests {
         pub name: String,
     }
 
-    impl CompletionRequest for DummyCompletionRequest {
-        fn data(&self) -> serde_json::Result<String> {
-            serde_json::to_string(self)
-        }
-    }
-
     #[gpui::test(iterations = 10)]
     async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
+        let provider = FakeCompletionProvider::default();
         cx.set_global(cx.update(SettingsStore::test));
+        cx.set_global(CompletionProvider::Fake(provider.clone()));
         cx.update(language_settings::init);
 
         let text = indoc! {"
@@ -405,19 +396,10 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
         });
-        let provider = Arc::new(FakeCompletionProvider::new());
-        let codegen = cx.new_model(|cx| {
-            Codegen::new(
-                buffer.clone(),
-                CodegenKind::Transform { range },
-                provider.clone(),
-                cx,
-            )
-        });
+        let codegen =
+            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Transform { range }, cx));
 
-        let request = Box::new(DummyCompletionRequest {
-            name: "test".to_string(),
-        });
+        let request = LanguageModelRequest::default();
         codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
@@ -430,8 +412,7 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            println!("CHUNK: {:?}", &chunk);
-            provider.send_completion(chunk);
+            provider.send_completion(chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
@@ -456,6 +437,8 @@ mod tests {
         cx: &mut TestAppContext,
         mut rng: StdRng,
     ) {
+        let provider = FakeCompletionProvider::default();
+        cx.set_global(CompletionProvider::Fake(provider.clone()));
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_settings::init);
 
@@ -472,19 +455,10 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 6))
         });
-        let provider = Arc::new(FakeCompletionProvider::new());
-        let codegen = cx.new_model(|cx| {
-            Codegen::new(
-                buffer.clone(),
-                CodegenKind::Generate { position },
-                provider.clone(),
-                cx,
-            )
-        });
+        let codegen =
+            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
 
-        let request = Box::new(DummyCompletionRequest {
-            name: "test".to_string(),
-        });
+        let request = LanguageModelRequest::default();
         codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
@@ -497,7 +471,7 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            provider.send_completion(chunk);
+            provider.send_completion(chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
@@ -522,6 +496,8 @@ mod tests {
         cx: &mut TestAppContext,
         mut rng: StdRng,
     ) {
+        let provider = FakeCompletionProvider::default();
+        cx.set_global(CompletionProvider::Fake(provider.clone()));
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_settings::init);
 
@@ -538,19 +514,10 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 2))
         });
-        let provider = Arc::new(FakeCompletionProvider::new());
-        let codegen = cx.new_model(|cx| {
-            Codegen::new(
-                buffer.clone(),
-                CodegenKind::Generate { position },
-                provider.clone(),
-                cx,
-            )
-        });
+        let codegen =
+            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
 
-        let request = Box::new(DummyCompletionRequest {
-            name: "test".to_string(),
-        });
+        let request = LanguageModelRequest::default();
         codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
@@ -563,8 +530,7 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            println!("{:?}", &chunk);
-            provider.send_completion(chunk);
+            provider.send_completion(chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
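
One subtlety preserved by the `buffer.is_empty()` cleanup above: on the first line of a streamed response, fewer than three buffered backticks are ambiguous, so the filter must keep waiting before deciding whether a markdown fence is present. A standalone sketch of just that decision:

```rust
// Minimal sketch of the first-line fence check in
// `strip_invalid_spans_from_codeblock`: with up to two backticks buffered
// we can't yet tell whether the stream opens with a ``` fence.
fn starts_with_fence(buffer: &str) -> Option<bool> {
    if buffer.is_empty() || buffer == "`" || buffer == "``" {
        None // not enough data yet; keep buffering
    } else {
        Some(buffer.starts_with("```"))
    }
}

fn main() {
    assert_eq!(starts_with_fence(""), None);
    assert_eq!(starts_with_fence("``"), None);
    assert_eq!(starts_with_fence("```rust"), Some(true));
    assert_eq!(starts_with_fence("fn main"), Some(false));
}
```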

crates/assistant/src/completion_provider.rs 🔗

@@ -0,0 +1,188 @@
+#[cfg(test)]
+mod fake;
+mod open_ai;
+mod zed;
+
+#[cfg(test)]
+pub use fake::*;
+pub use open_ai::*;
+pub use zed::*;
+
+use crate::{
+    assistant_settings::{AssistantProvider, AssistantSettings},
+    LanguageModel, LanguageModelRequest,
+};
+use anyhow::Result;
+use client::Client;
+use futures::{future::BoxFuture, stream::BoxStream};
+use gpui::{AnyView, AppContext, Task, WindowContext};
+use settings::{Settings, SettingsStore};
+use std::sync::Arc;
+
+pub fn init(client: Arc<Client>, cx: &mut AppContext) {
+    let mut settings_version = 0;
+    let provider = match &AssistantSettings::get_global(cx).provider {
+        AssistantProvider::ZedDotDev { default_model } => {
+            CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
+                default_model.clone(),
+                client.clone(),
+                settings_version,
+                cx,
+            ))
+        }
+        AssistantProvider::OpenAi {
+            default_model,
+            api_url,
+        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
+            default_model.clone(),
+            api_url.clone(),
+            client.http_client(),
+            settings_version,
+        )),
+    };
+    cx.set_global(provider);
+
+    cx.observe_global::<SettingsStore>(move |cx| {
+        settings_version += 1;
+        cx.update_global::<CompletionProvider, _>(|provider, cx| {
+            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
+                (
+                    CompletionProvider::OpenAi(provider),
+                    AssistantProvider::OpenAi {
+                        default_model,
+                        api_url,
+                    },
+                ) => {
+                    provider.update(default_model.clone(), api_url.clone(), settings_version);
+                }
+                (
+                    CompletionProvider::ZedDotDev(provider),
+                    AssistantProvider::ZedDotDev { default_model },
+                ) => {
+                    provider.update(default_model.clone(), settings_version);
+                }
+                (CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
+                    *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
+                        default_model.clone(),
+                        client.clone(),
+                        settings_version,
+                        cx,
+                    ));
+                }
+                (
+                    CompletionProvider::ZedDotDev(_),
+                    AssistantProvider::OpenAi {
+                        default_model,
+                        api_url,
+                    },
+                ) => {
+                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
+                        default_model.clone(),
+                        api_url.clone(),
+                        client.http_client(),
+                        settings_version,
+                    ));
+                }
+                #[cfg(test)]
+                (CompletionProvider::Fake(_), _) => unimplemented!(),
+            }
+        })
+    })
+    .detach();
+}
+
+pub enum CompletionProvider {
+    OpenAi(OpenAiCompletionProvider),
+    ZedDotDev(ZedDotDevCompletionProvider),
+    #[cfg(test)]
+    Fake(FakeCompletionProvider),
+}
+
+impl gpui::Global for CompletionProvider {}
+
+impl CompletionProvider {
+    pub fn global(cx: &AppContext) -> &Self {
+        cx.global::<Self>()
+    }
+
+    pub fn settings_version(&self) -> usize {
+        match self {
+            CompletionProvider::OpenAi(provider) => provider.settings_version(),
+            CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
+            #[cfg(test)]
+            CompletionProvider::Fake(_) => unimplemented!(),
+        }
+    }
+
+    pub fn is_authenticated(&self) -> bool {
+        match self {
+            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
+            CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
+            #[cfg(test)]
+            CompletionProvider::Fake(_) => true,
+        }
+    }
+
+    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+        match self {
+            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
+            CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
+            #[cfg(test)]
+            CompletionProvider::Fake(_) => Task::ready(Ok(())),
+        }
+    }
+
+    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+        match self {
+            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
+            CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
+            #[cfg(test)]
+            CompletionProvider::Fake(_) => unimplemented!(),
+        }
+    }
+
+    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+        match self {
+            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
+            CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
+            #[cfg(test)]
+            CompletionProvider::Fake(_) => Task::ready(Ok(())),
+        }
+    }
+
+    pub fn default_model(&self) -> LanguageModel {
+        match self {
+            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
+            CompletionProvider::ZedDotDev(provider) => {
+                LanguageModel::ZedDotDev(provider.default_model())
+            }
+            #[cfg(test)]
+            CompletionProvider::Fake(_) => unimplemented!(),
+        }
+    }
+
+    pub fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        match self {
+            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
+            CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
+            #[cfg(test)]
+            CompletionProvider::Fake(_) => unimplemented!(),
+        }
+    }
+
+    pub fn complete(
+        &self,
+        request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        match self {
+            CompletionProvider::OpenAi(provider) => provider.complete(request),
+            CompletionProvider::ZedDotDev(provider) => provider.complete(request),
+            #[cfg(test)]
+            CompletionProvider::Fake(provider) => provider.complete(),
+        }
+    }
+}
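
Two design points in this module are easy to miss: the provider is an enum rather than a `dyn Trait`, so the `#[cfg(test)]` fake variant compiles out of release builds, and the settings observer updates the current provider in place, only rebuilding it when the variant changes. A simplified standalone sketch of that reconcile pattern (stand-in types, no gpui):

```rust
// Simplified stand-ins for the provider and settings variants.
enum Provider {
    OpenAi { api_url: String },
    ZedDotDev,
}

enum Setting {
    OpenAi { api_url: String },
    ZedDotDev,
}

// Update in place when the variant matches; otherwise rebuild.
fn reconcile(provider: &mut Provider, setting: &Setting) {
    match (&mut *provider, setting) {
        (Provider::OpenAi { api_url }, Setting::OpenAi { api_url: new_url }) => {
            *api_url = new_url.clone(); // same variant: mutate in place
        }
        (Provider::ZedDotDev, Setting::ZedDotDev) => {}
        (_, Setting::OpenAi { api_url }) => {
            *provider = Provider::OpenAi { api_url: api_url.clone() };
        }
        (_, Setting::ZedDotDev) => *provider = Provider::ZedDotDev,
    }
}

fn main() {
    let mut provider = Provider::ZedDotDev;
    reconcile(
        &mut provider,
        &Setting::OpenAi { api_url: "https://api.openai.com/v1".into() },
    );
    assert!(matches!(provider, Provider::OpenAi { .. }));
}
```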

crates/assistant/src/completion_provider/fake.rs 🔗

@@ -0,0 +1,29 @@
+use anyhow::Result;
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use std::sync::Arc;
+
+#[derive(Clone, Default)]
+pub struct FakeCompletionProvider {
+    current_completion_tx: Arc<parking_lot::Mutex<Option<mpsc::UnboundedSender<String>>>>,
+}
+
+impl FakeCompletionProvider {
+    pub fn complete(&self) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let (tx, rx) = mpsc::unbounded();
+        *self.current_completion_tx.lock() = Some(tx);
+        async move { Ok(rx.map(Ok).boxed()) }.boxed()
+    }
+
+    pub fn send_completion(&self, chunk: String) {
+        self.current_completion_tx
+            .lock()
+            .as_ref()
+            .unwrap()
+            .unbounded_send(chunk)
+            .unwrap();
+    }
+
+    pub fn finish_completion(&self) {
+        self.current_completion_tx.lock().take();
+    }
+}
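
A hypothetical test-side usage of the fake, assuming the `FakeCompletionProvider` above is in scope (the real tests drive gpui's background executor rather than `block_on`):

```rust
use futures::{executor::block_on, StreamExt};

fn main() {
    // Assumes the FakeCompletionProvider defined above is in scope.
    let provider = FakeCompletionProvider::default();
    let completion = provider.complete();

    // Chunks are buffered by the unbounded channel, so they can be sent
    // before the stream is consumed.
    provider.send_completion("hello ".into());
    provider.send_completion("world".into());
    provider.finish_completion(); // drop the sender to end the stream

    let chunks: Vec<String> = block_on(async {
        completion.await.unwrap().map(|chunk| chunk.unwrap()).collect().await
    });
    assert_eq!(chunks.concat(), "hello world");
}
```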

crates/assistant/src/completion_provider/open_ai.rs 🔗

@@ -0,0 +1,301 @@
+use crate::{
+    assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
+};
+use anyhow::{anyhow, Result};
+use editor::{Editor, EditorElement, EditorStyle};
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
+use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
+use settings::Settings;
+use std::{env, sync::Arc};
+use theme::ThemeSettings;
+use ui::prelude::*;
+use util::{http::HttpClient, ResultExt};
+
+pub struct OpenAiCompletionProvider {
+    api_key: Option<String>,
+    api_url: String,
+    default_model: OpenAiModel,
+    http_client: Arc<dyn HttpClient>,
+    settings_version: usize,
+}
+
+impl OpenAiCompletionProvider {
+    pub fn new(
+        default_model: OpenAiModel,
+        api_url: String,
+        http_client: Arc<dyn HttpClient>,
+        settings_version: usize,
+    ) -> Self {
+        Self {
+            api_key: None,
+            api_url,
+            default_model,
+            http_client,
+            settings_version,
+        }
+    }
+
+    pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
+        self.default_model = default_model;
+        self.api_url = api_url;
+        self.settings_version = settings_version;
+    }
+
+    pub fn settings_version(&self) -> usize {
+        self.settings_version
+    }
+
+    pub fn is_authenticated(&self) -> bool {
+        self.api_key.is_some()
+    }
+
+    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+        if self.is_authenticated() {
+            Task::ready(Ok(()))
+        } else {
+            let api_url = self.api_url.clone();
+            cx.spawn(|mut cx| async move {
+                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+                    api_key
+                } else {
+                    let (_, api_key) = cx
+                        .update(|cx| cx.read_credentials(&api_url))?
+                        .await?
+                        .ok_or_else(|| anyhow!("credentials not found"))?;
+                    String::from_utf8(api_key)?
+                };
+                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
+                    if let CompletionProvider::OpenAi(provider) = provider {
+                        provider.api_key = Some(api_key);
+                    }
+                })
+            })
+        }
+    }
+
+    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+        let delete_credentials = cx.delete_credentials(&self.api_url);
+        cx.spawn(|mut cx| async move {
+            delete_credentials.await.log_err();
+            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
+                if let CompletionProvider::OpenAi(provider) = provider {
+                    provider.api_key = None;
+                }
+            })
+        })
+    }
+
+    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
+            .into()
+    }
+
+    pub fn default_model(&self) -> OpenAiModel {
+        self.default_model.clone()
+    }
+
+    pub fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        count_open_ai_tokens(request, cx.background_executor())
+    }
+
+    pub fn complete(
+        &self,
+        request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let request = self.to_open_ai_request(request);
+
+        let http_client = self.http_client.clone();
+        let api_key = self.api_key.clone();
+        let api_url = self.api_url.clone();
+        async move {
+            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+            let response = request.await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        }
+        .boxed()
+    }
+
+    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
+        let model = match request.model {
+            LanguageModel::ZedDotDev(_) => self.default_model(),
+            LanguageModel::OpenAi(model) => model,
+        };
+
+        Request {
+            model,
+            messages: request
+                .messages
+                .into_iter()
+                .map(|msg| RequestMessage {
+                    role: msg.role.into(),
+                    content: msg.content,
+                })
+                .collect(),
+            stream: true,
+            stop: request.stop,
+            temperature: request.temperature,
+        }
+    }
+}
+
+pub fn count_open_ai_tokens(
+    request: LanguageModelRequest,
+    background_executor: &gpui::BackgroundExecutor,
+) -> BoxFuture<'static, Result<usize>> {
+    background_executor
+        .spawn(async move {
+            let messages = request
+                .messages
+                .into_iter()
+                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+                    role: match message.role {
+                        Role::User => "user".into(),
+                        Role::Assistant => "assistant".into(),
+                        Role::System => "system".into(),
+                    },
+                    content: Some(message.content),
+                    name: None,
+                    function_call: None,
+                })
+                .collect::<Vec<_>>();
+
+            tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
+        })
+        .boxed()
+}
+
+impl From<Role> for open_ai::Role {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => OpenAiRole::User,
+            Role::Assistant => OpenAiRole::Assistant,
+            Role::System => OpenAiRole::System,
+        }
+    }
+}
+
+struct AuthenticationPrompt {
+    api_key: View<Editor>,
+    api_url: String,
+}
+
+impl AuthenticationPrompt {
+    fn new(api_url: String, cx: &mut WindowContext) -> Self {
+        Self {
+            api_key: cx.new_view(|cx| {
+                let mut editor = Editor::single_line(cx);
+                editor.set_placeholder_text(
+                    "sk-000000000000000000000000000000000000000000000000",
+                    cx,
+                );
+                editor
+            }),
+            api_url,
+        }
+    }
+
+    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
+        let api_key = self.api_key.read(cx).text(cx);
+        if api_key.is_empty() {
+            return;
+        }
+
+        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
+        cx.spawn(|_, mut cx| async move {
+            write_credentials.await?;
+            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
+                if let CompletionProvider::OpenAi(provider) = provider {
+                    provider.api_key = Some(api_key);
+                }
+            })
+        })
+        .detach_and_log_err(cx);
+    }
+
+    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        let settings = ThemeSettings::get_global(cx);
+        let text_style = TextStyle {
+            color: cx.theme().colors().text,
+            font_family: settings.ui_font.family.clone(),
+            font_features: settings.ui_font.features,
+            font_size: rems(0.875).into(),
+            font_weight: FontWeight::NORMAL,
+            font_style: FontStyle::Normal,
+            line_height: relative(1.3),
+            background_color: None,
+            underline: None,
+            strikethrough: None,
+            white_space: WhiteSpace::Normal,
+        };
+        EditorElement::new(
+            &self.api_key,
+            EditorStyle {
+                background: cx.theme().colors().editor_background,
+                local_player: cx.theme().players().local(),
+                text: text_style,
+                ..Default::default()
+            },
+        )
+    }
+}
+
+impl Render for AuthenticationPrompt {
+    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        const INSTRUCTIONS: [&str; 6] = [
+            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
+            " - You can create an API key at: platform.openai.com/api-keys",
+            " - Make sure your OpenAI account has credits",
+            " - Having a subscription for another service like GitHub Copilot won't work.",
+            "",
+            "Paste your OpenAI API key below and hit enter to use the assistant:",
+        ];
+
+        v_flex()
+            .p_4()
+            .size_full()
+            .on_action(cx.listener(Self::save_api_key))
+            .children(
+                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
+            )
+            .child(
+                h_flex()
+                    .w_full()
+                    .my_2()
+                    .px_2()
+                    .py_1()
+                    .bg(cx.theme().colors().editor_background)
+                    .rounded_md()
+                    .child(self.render_api_key_editor(cx)),
+            )
+            .child(
+                Label::new(
+                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
+                )
+                .size(LabelSize::Small),
+            )
+            .child(
+                h_flex()
+                    .gap_2()
+                    .child(Label::new("Click on").size(LabelSize::Small))
+                    .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
+                    .child(
+                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
+                    ),
+            )
+            .into_any()
+    }
+}
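
The counting path above hands the work to `tiktoken_rs` on the background executor. A minimal standalone equivalent of the same call, using the message shape from the code above (`"gpt-4"` stands in for `request.model.id()`):

```rust
use tiktoken_rs::{num_tokens_from_messages, ChatCompletionRequestMessage};

fn main() -> anyhow::Result<()> {
    let messages = vec![ChatCompletionRequestMessage {
        role: "user".into(),
        content: Some("Say hello.".into()),
        name: None,
        function_call: None,
    }];
    // Counts prompt tokens locally, without calling the API.
    let token_count = num_tokens_from_messages("gpt-4", &messages)?;
    println!("{token_count} prompt tokens");
    Ok(())
}
```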

crates/assistant/src/completion_provider/zed.rs 🔗

@@ -0,0 +1,167 @@
+use crate::{
+    assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider,
+    LanguageModelRequest,
+};
+use anyhow::{anyhow, Result};
+use client::{proto, Client};
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
+use gpui::{AnyView, AppContext, Task};
+use std::{future, sync::Arc};
+use ui::prelude::*;
+
+pub struct ZedDotDevCompletionProvider {
+    client: Arc<Client>,
+    default_model: ZedDotDevModel,
+    settings_version: usize,
+    status: client::Status,
+    _maintain_client_status: Task<()>,
+}
+
+impl ZedDotDevCompletionProvider {
+    pub fn new(
+        default_model: ZedDotDevModel,
+        client: Arc<Client>,
+        settings_version: usize,
+        cx: &mut AppContext,
+    ) -> Self {
+        let mut status_rx = client.status();
+        let status = *status_rx.borrow();
+        let maintain_client_status = cx.spawn(|mut cx| async move {
+            while let Some(status) = status_rx.next().await {
+                let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
+                    if let CompletionProvider::ZedDotDev(provider) = provider {
+                        provider.status = status;
+                    } else {
+                        unreachable!()
+                    }
+                });
+            }
+        });
+        Self {
+            client,
+            default_model,
+            settings_version,
+            status,
+            _maintain_client_status: maintain_client_status,
+        }
+    }
+
+    pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) {
+        self.default_model = default_model;
+        self.settings_version = settings_version;
+    }
+
+    pub fn settings_version(&self) -> usize {
+        self.settings_version
+    }
+
+    pub fn default_model(&self) -> ZedDotDevModel {
+        self.default_model.clone()
+    }
+
+    pub fn is_authenticated(&self) -> bool {
+        self.status.is_connected()
+    }
+
+    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+        let client = self.client.clone();
+        cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
+    }
+
+    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+        cx.new_view(|_cx| AuthenticationPrompt).into()
+    }
+
+    pub fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        match request.model {
+            crate::LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
+            crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFour)
+            | crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo)
+            | crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptThreePointFiveTurbo) => {
+                count_open_ai_tokens(request, cx.background_executor())
+            }
+            crate::LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
+                let request = self.client.request(proto::CountTokensWithLanguageModel {
+                    model,
+                    messages: request
+                        .messages
+                        .iter()
+                        .map(|message| message.to_proto())
+                        .collect(),
+                });
+                async move {
+                    let response = request.await?;
+                    Ok(response.token_count as usize)
+                }
+                .boxed()
+            }
+        }
+    }
+
+    pub fn complete(
+        &self,
+        request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let request = proto::CompleteWithLanguageModel {
+            model: request.model.id().to_string(),
+            messages: request
+                .messages
+                .iter()
+                .map(|message| message.to_proto())
+                .collect(),
+            stop: request.stop,
+            temperature: request.temperature,
+        };
+
+        self.client
+            .request_stream(request)
+            .map_ok(|stream| {
+                stream
+                    .filter_map(|response| async move {
+                        match response {
+                            Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
+                            Err(error) => Some(Err(error)),
+                        }
+                    })
+                    .boxed()
+            })
+            .boxed()
+    }
+}
+
+struct AuthenticationPrompt;
+
+impl Render for AuthenticationPrompt {
+    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
+        const LABEL: &str = "Generate and analyze code with language models. You can chat with the assistant in this panel or transform code inline.";
+
+        v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
+            v_flex()
+                .gap_2()
+                .child(
+                    Button::new("sign_in", "Sign in")
+                        .icon_color(Color::Muted)
+                        .icon(IconName::Github)
+                        .icon_position(IconPosition::Start)
+                        .style(ButtonStyle::Filled)
+                        .full_width()
+                        .on_click(|_, cx| {
+                            CompletionProvider::global(cx)
+                                .authenticate(cx)
+                                .detach_and_log_err(cx);
+                        }),
+                )
+                .child(
+                    div().flex().w_full().items_center().child(
+                        Label::new("Sign in to enable collaboration.")
+                            .color(Color::Muted)
+                            .size(LabelSize::Small),
+                    ),
+                ),
+        )
+    }
+}

crates/assistant/src/prompts.rs 🔗

@@ -1,394 +1,95 @@
-use ai::models::LanguageModel;
-use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
-use ai::prompts::file_context::FileContext;
-use ai::prompts::generate::GenerateInlineContent;
-use ai::prompts::preamble::EngineerPreamble;
-use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
-use ai::providers::open_ai::OpenAiLanguageModel;
-use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
-use std::cmp::{self, Reverse};
-use std::ops::Range;
-use std::sync::Arc;
-
-#[allow(dead_code)]
-fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
-    #[derive(Debug)]
-    struct Match {
-        collapse: Range<usize>,
-        keep: Vec<Range<usize>>,
-    }
-
-    let selected_range = selected_range.to_offset(buffer);
-    let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
-        Some(&grammar.embedding_config.as_ref()?.query)
-    });
-    let configs = ts_matches
-        .grammars()
-        .iter()
-        .map(|g| g.embedding_config.as_ref().unwrap())
-        .collect::<Vec<_>>();
-    let mut matches = Vec::new();
-    while let Some(mat) = ts_matches.peek() {
-        let config = &configs[mat.grammar_index];
-        if let Some(collapse) = mat.captures.iter().find_map(|cap| {
-            if Some(cap.index) == config.collapse_capture_ix {
-                Some(cap.node.byte_range())
-            } else {
-                None
-            }
-        }) {
-            let mut keep = Vec::new();
-            for capture in mat.captures.iter() {
-                if Some(capture.index) == config.keep_capture_ix {
-                    keep.push(capture.node.byte_range());
-                } else {
-                    continue;
-                }
-            }
-            ts_matches.advance();
-            matches.push(Match { collapse, keep });
-        } else {
-            ts_matches.advance();
-        }
-    }
-    matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
-    let mut matches = matches.into_iter().peekable();
-
-    let mut summary = String::new();
-    let mut offset = 0;
-    let mut flushed_selection = false;
-    while let Some(mat) = matches.next() {
-        // Keep extending the collapsed range if the next match surrounds
-        // the current one.
-        while let Some(next_mat) = matches.peek() {
-            if mat.collapse.start <= next_mat.collapse.start
-                && mat.collapse.end >= next_mat.collapse.end
-            {
-                matches.next().unwrap();
-            } else {
-                break;
-            }
-        }
-
-        if offset > mat.collapse.start {
-            // Skip collapsed nodes that have already been summarized.
-            offset = cmp::max(offset, mat.collapse.end);
-            continue;
-        }
-
-        if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
-            if !flushed_selection {
-                // The collapsed node ends after the selection starts, so we'll flush the selection first.
-                summary.extend(buffer.text_for_range(offset..selected_range.start));
-                summary.push_str("<|S|");
-                if selected_range.end == selected_range.start {
-                    summary.push_str(">");
-                } else {
-                    summary.extend(buffer.text_for_range(selected_range.clone()));
-                    summary.push_str("|E|>");
-                }
-                offset = selected_range.end;
-                flushed_selection = true;
-            }
-
-            // If the selection intersects the collapsed node, we won't collapse it.
-            if selected_range.end >= mat.collapse.start {
-                continue;
-            }
-        }
-
-        summary.extend(buffer.text_for_range(offset..mat.collapse.start));
-        for keep in mat.keep {
-            summary.extend(buffer.text_for_range(keep));
-        }
-        offset = mat.collapse.end;
-    }
-
-    // Flush selection if we haven't already done so.
-    if !flushed_selection && offset <= selected_range.start {
-        summary.extend(buffer.text_for_range(offset..selected_range.start));
-        summary.push_str("<|S|");
-        if selected_range.end == selected_range.start {
-            summary.push_str(">");
-        } else {
-            summary.extend(buffer.text_for_range(selected_range.clone()));
-            summary.push_str("|E|>");
-        }
-        offset = selected_range.end;
-    }
-
-    summary.extend(buffer.text_for_range(offset..buffer.len()));
-    summary
-}
+use language::BufferSnapshot;
+use std::{fmt::Write, ops::Range};
 
 pub fn generate_content_prompt(
     user_prompt: String,
     language_name: Option<&str>,
     buffer: BufferSnapshot,
     range: Range<usize>,
-    search_results: Vec<PromptCodeSnippet>,
-    model: &str,
     project_name: Option<String>,
 ) -> anyhow::Result<String> {
-    // Using new Prompt Templates
-    let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAiLanguageModel::load(model));
-    let lang_name = if let Some(language_name) = language_name {
-        Some(language_name.to_string())
-    } else {
-        None
-    };
+    let mut prompt = String::new();
 
-    let args = PromptArguments {
-        model: openai_model,
-        language_name: lang_name.clone(),
-        project_name,
-        snippets: search_results.clone(),
-        reserved_tokens: 1000,
-        buffer: Some(buffer),
-        selected_range: Some(range),
-        user_prompt: Some(user_prompt.clone()),
+    let content_type = match language_name {
+        None | Some("Markdown" | "Plain Text") => {
+            writeln!(prompt, "You are an expert engineer.")?;
+            "Text"
+        }
+        Some(language_name) => {
+            writeln!(prompt, "You are an expert {language_name} engineer.")?;
+            writeln!(
+                prompt,
+                "Your answer MUST always and only be valid {}.",
+                language_name
+            )?;
+            "Code"
+        }
     };
 
-    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
-        (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
-        (
-            PromptPriority::Ordered { order: 1 },
-            Box::new(RepositoryContext {}),
-        ),
-        (
-            PromptPriority::Ordered { order: 0 },
-            Box::new(FileContext {}),
-        ),
-        (
-            PromptPriority::Mandatory,
-            Box::new(GenerateInlineContent {}),
-        ),
-    ];
-    let chain = PromptChain::new(args, templates);
-    let (prompt, _) = chain.generate(true)?;
-
-    anyhow::Ok(prompt)
-}
+    if let Some(project_name) = project_name {
+        writeln!(
+            prompt,
+            "You are currently working inside the '{project_name}' project in code editor Zed."
+        )?;
+    }
 
-#[cfg(test)]
-pub(crate) mod tests {
-    use super::*;
-    use gpui::{AppContext, Context};
-    use indoc::indoc;
-    use language::{
-        language_settings, tree_sitter_rust, Buffer, BufferId, Language, LanguageConfig,
-        LanguageMatcher, Point,
-    };
-    use settings::SettingsStore;
-    use std::sync::Arc;
+    // Include file content.
+    for chunk in buffer.text_for_range(0..range.start) {
+        prompt.push_str(chunk);
+    }
 
-    pub(crate) fn rust_lang() -> Language {
-        Language::new(
-            LanguageConfig {
-                name: "Rust".into(),
-                matcher: LanguageMatcher {
-                    path_suffixes: vec!["rs".to_string()],
-                    ..Default::default()
-                },
-                ..Default::default()
-            },
-            Some(tree_sitter_rust::language()),
-        )
-        .with_embedding_query(
-            r#"
-            (
-                [(line_comment) (attribute_item)]* @context
-                .
-                [
-                    (struct_item
-                        name: (_) @name)
+    if range.is_empty() {
+        prompt.push_str("<|START|>");
+    } else {
+        prompt.push_str("<|START|");
+    }
 
-                    (enum_item
-                        name: (_) @name)
+    for chunk in buffer.text_for_range(range.clone()) {
+        prompt.push_str(chunk);
+    }
 
-                    (impl_item
-                        trait: (_)? @name
-                        "for"? @name
-                        type: (_) @name)
+    if !range.is_empty() {
+        prompt.push_str("|END|>");
+    }
 
-                    (trait_item
-                        name: (_) @name)
+    for chunk in buffer.text_for_range(range.end..buffer.len()) {
+        prompt.push_str(chunk);
+    }
 
-                    (function_item
-                        name: (_) @name
-                        body: (block
-                            "{" @keep
-                            "}" @keep) @collapse)
+    prompt.push('\n');
 
-                    (macro_definition
-                        name: (_) @name)
-                    ] @item
-                )
-            "#,
+    if range.is_empty() {
+        writeln!(
+            prompt,
+            "Assume the cursor is located where the `<|START|>` span is."
+        )
+        .unwrap();
+        writeln!(
+            prompt,
+            "{content_type} can't be replaced, so assume your answer will be inserted at the cursor.",
         )
-        .unwrap()
+        .unwrap();
+        writeln!(
+            prompt,
+            "Generate {content_type} based on the users prompt: {user_prompt}",
+        )
+        .unwrap();
+    } else {
+        writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
+        writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
+        writeln!(
+            prompt,
+            "Double check that you only return code and not the '<|START|' and '|END|'> spans"
+        )
+        .unwrap();
     }
 
-    #[gpui::test]
-    fn test_outline_for_prompt(cx: &mut AppContext) {
-        let settings_store = SettingsStore::test(cx);
-        cx.set_global(settings_store);
-        language_settings::init(cx);
-        let text = indoc! {"
-            struct X {
-                a: usize,
-                b: usize,
-            }
-
-            impl X {
-
-                fn new() -> Self {
-                    let a = 1;
-                    let b = 2;
-                    Self { a, b }
-                }
-
-                pub fn a(&self, param: bool) -> usize {
-                    self.a
-                }
-
-                pub fn b(&self) -> usize {
-                    self.b
-                }
-            }
-        "};
-        let buffer = cx.new_model(|cx| {
-            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
-        });
-        let snapshot = buffer.read(cx).snapshot();
-
-        assert_eq!(
-            summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
-            indoc! {"
-                struct X {
-                    <|S|>a: usize,
-                    b: usize,
-                }
-
-                impl X {
-
-                    fn new() -> Self {}
-
-                    pub fn a(&self, param: bool) -> usize {}
-
-                    pub fn b(&self) -> usize {}
-                }
-            "}
-        );
-
-        assert_eq!(
-            summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
-            indoc! {"
-                struct X {
-                    a: usize,
-                    b: usize,
-                }
-
-                impl X {
-
-                    fn new() -> Self {
-                        let <|S|a |E|>= 1;
-                        let b = 2;
-                        Self { a, b }
-                    }
-
-                    pub fn a(&self, param: bool) -> usize {}
-
-                    pub fn b(&self) -> usize {}
-                }
-            "}
-        );
-
-        assert_eq!(
-            summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
-            indoc! {"
-                struct X {
-                    a: usize,
-                    b: usize,
-                }
+    writeln!(prompt, "Never make remarks about the output.").unwrap();
+    writeln!(
+        prompt,
+        "Do not return anything else, except the generated {content_type}."
+    )
+    .unwrap();
 
-                impl X {
-                <|S|>
-                    fn new() -> Self {}
-
-                    pub fn a(&self, param: bool) -> usize {}
-
-                    pub fn b(&self) -> usize {}
-                }
-            "}
-        );
-
-        assert_eq!(
-            summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
-            indoc! {"
-                struct X {
-                    a: usize,
-                    b: usize,
-                }
-
-                impl X {
-
-                    fn new() -> Self {}
-
-                    pub fn a(&self, param: bool) -> usize {}
-
-                    pub fn b(&self) -> usize {}
-                }
-                <|S|>"}
-        );
-
-        // Ensure nested functions get collapsed properly.
-        let text = indoc! {"
-            struct X {
-                a: usize,
-                b: usize,
-            }
-
-            impl X {
-
-                fn new() -> Self {
-                    let a = 1;
-                    let b = 2;
-                    Self { a, b }
-                }
-
-                pub fn a(&self, param: bool) -> usize {
-                    let a = 30;
-                    fn nested() -> usize {
-                        3
-                    }
-                    self.a + nested()
-                }
-
-                pub fn b(&self) -> usize {
-                    self.b
-                }
-            }
-        "};
-        buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
-        let snapshot = buffer.read(cx).snapshot();
-        assert_eq!(
-            summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
-            indoc! {"
-                <|S|>struct X {
-                    a: usize,
-                    b: usize,
-                }
-
-                impl X {
-
-                    fn new() -> Self {}
-
-                    pub fn a(&self, param: bool) -> usize {}
-
-                    pub fn b(&self) -> usize {}
-                }
-            "}
-        );
-    }
+    Ok(prompt)
 }
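
The rewritten prompt drops the template chain in favor of direct string assembly around `<|START|`/`|END|>` markers. A standalone sketch of just the marker scheme on a plain string (the real function streams chunks out of a `BufferSnapshot`):

```rust
// An empty range marks a cursor position; a non-empty range wraps the
// selection between the start and end markers.
fn mark_selection(text: &str, range: std::ops::Range<usize>) -> String {
    let mut out = String::new();
    out.push_str(&text[..range.start]);
    if range.is_empty() {
        out.push_str("<|START|>");
    } else {
        out.push_str("<|START|");
        out.push_str(&text[range.clone()]);
        out.push_str("|END|>");
    }
    out.push_str(&text[range.end..]);
    out
}

fn main() {
    assert_eq!(mark_selection("let x = 1;", 4..5), "let <|START|x|END|> = 1;");
    assert_eq!(mark_selection("let x = 1;", 4..4), "let <|START|>x = 1;");
}
```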

crates/assistant/src/saved_conversation.rs 🔗

@@ -0,0 +1,121 @@
+use crate::{assistant_settings::OpenAiModel, MessageId, MessageMetadata};
+use anyhow::{anyhow, Result};
+use collections::HashMap;
+use fs::Fs;
+use futures::StreamExt;
+use regex::Regex;
+use serde::{Deserialize, Serialize};
+use std::{
+    cmp::Reverse,
+    ffi::OsStr,
+    path::{Path, PathBuf},
+    sync::Arc,
+};
+use util::paths::CONVERSATIONS_DIR;
+
+#[derive(Serialize, Deserialize)]
+pub struct SavedMessage {
+    pub id: MessageId,
+    pub start: usize,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct SavedConversation {
+    pub id: Option<String>,
+    pub zed: String,
+    pub version: String,
+    pub text: String,
+    pub messages: Vec<SavedMessage>,
+    pub message_metadata: HashMap<MessageId, MessageMetadata>,
+    pub summary: String,
+}
+
+impl SavedConversation {
+    pub const VERSION: &'static str = "0.2.0";
+
+    pub async fn load(path: &Path, fs: &dyn Fs) -> Result<Self> {
+        let saved_conversation = fs.load(path).await?;
+        let saved_conversation_json =
+            serde_json::from_str::<serde_json::Value>(&saved_conversation)?;
+        match saved_conversation_json
+            .get("version")
+            .ok_or_else(|| anyhow!("version not found"))?
+        {
+            serde_json::Value::String(version) => match version.as_str() {
+                Self::VERSION => Ok(serde_json::from_value::<Self>(saved_conversation_json)?),
+                "0.1.0" => {
+                    let saved_conversation =
+                        serde_json::from_value::<SavedConversationV0_1_0>(saved_conversation_json)?;
+                    Ok(Self {
+                        id: saved_conversation.id,
+                        zed: saved_conversation.zed,
+                        version: saved_conversation.version,
+                        text: saved_conversation.text,
+                        messages: saved_conversation.messages,
+                        message_metadata: saved_conversation.message_metadata,
+                        summary: saved_conversation.summary,
+                    })
+                }
+                _ => Err(anyhow!(
+                    "unrecognized saved conversation version: {}",
+                    version
+                )),
+            },
+            _ => Err(anyhow!("version not found on saved conversation")),
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize)]
+struct SavedConversationV0_1_0 {
+    id: Option<String>,
+    zed: String,
+    version: String,
+    text: String,
+    messages: Vec<SavedMessage>,
+    message_metadata: HashMap<MessageId, MessageMetadata>,
+    summary: String,
+    api_url: Option<String>,
+    model: OpenAiModel,
+}
+
+pub struct SavedConversationMetadata {
+    pub title: String,
+    pub path: PathBuf,
+    pub mtime: chrono::DateTime<chrono::Local>,
+}
+
+impl SavedConversationMetadata {
+    pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
+        fs.create_dir(&CONVERSATIONS_DIR).await?;
+
+        let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
+        let mut conversations = Vec::<SavedConversationMetadata>::new();
+        while let Some(path) = paths.next().await {
+            let path = path?;
+            if path.extension() != Some(OsStr::new("json")) {
+                continue;
+            }
+
+            let pattern = r" - \d+.zed.json$";
+            let re = Regex::new(pattern).unwrap();
+
+            let metadata = fs.metadata(&path).await?;
+            if let Some((file_name, metadata)) = path
+                .file_name()
+                .and_then(|name| name.to_str())
+                .zip(metadata)
+            {
+                let title = re.replace(file_name, "");
+                conversations.push(Self {
+                    title: title.into_owned(),
+                    path,
+                    mtime: metadata.mtime.into(),
+                });
+            }
+        }
+        conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
+
+        Ok(conversations)
+    }
+}
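
One detail worth flagging in `list`: the filename pattern leaves the dots in `.zed.json` unescaped, so `.` matches any character there. A quick standalone check of the title extraction as written:

```rust
use regex::Regex;

fn main() {
    // Same pattern as above; note the unescaped dots in ".zed.json".
    let re = Regex::new(r" - \d+.zed.json$").unwrap();
    let title = re.replace("quicksort - 1708443301.zed.json", "");
    assert_eq!(title, "quicksort");
}
```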

crates/assistant/src/streaming_diff.rs 🔗

@@ -197,12 +197,10 @@ impl StreamingDiff {
                     } else {
                         hunks.push(Hunk::Remove { len: char_len })
                     }
+                } else if let Some(Hunk::Keep { len }) = hunks.last_mut() {
+                    *len += char_len;
                 } else {
-                    if let Some(Hunk::Keep { len }) = hunks.last_mut() {
-                        *len += char_len;
-                    } else {
-                        hunks.push(Hunk::Keep { len: char_len })
-                    }
+                    hunks.push(Hunk::Keep { len: char_len })
                 }
             }
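
The change above is purely structural (collapsing a nested `if`/`else`), but the behavior it preserves is worth spelling out: adjacent spans of the same kind are merged by extending the last hunk rather than pushing a new one. A minimal sketch:

```rust
#[derive(Debug, PartialEq)]
enum Hunk {
    Keep { len: usize },
    Remove { len: usize },
}

// Coalesce consecutive Keep spans into a single hunk.
fn push_keep(hunks: &mut Vec<Hunk>, char_len: usize) {
    if let Some(Hunk::Keep { len }) = hunks.last_mut() {
        *len += char_len;
    } else {
        hunks.push(Hunk::Keep { len: char_len });
    }
}

fn main() {
    let mut hunks = vec![Hunk::Remove { len: 2 }];
    push_keep(&mut hunks, 1);
    push_keep(&mut hunks, 3);
    assert_eq!(hunks, vec![Hunk::Remove { len: 2 }, Hunk::Keep { len: 4 }]);
}
```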
 

crates/client/src/client.rs 🔗

@@ -13,7 +13,7 @@ use async_tungstenite::tungstenite::{
 use clock::SystemClock;
 use collections::HashMap;
 use futures::{
-    channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt,
+    channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
     TryFutureExt as _, TryStreamExt,
 };
 use gpui::{
@@ -36,7 +36,10 @@ use std::{
     future::Future,
     marker::PhantomData,
     path::PathBuf,
-    sync::{atomic::AtomicU64, Arc, Weak},
+    sync::{
+        atomic::{AtomicU64, Ordering},
+        Arc, Weak,
+    },
     time::{Duration, Instant},
 };
 use telemetry::Telemetry;
@@ -442,7 +445,7 @@ impl Client {
     }
 
     pub fn id(&self) -> u64 {
-        self.id.load(std::sync::atomic::Ordering::SeqCst)
+        self.id.load(Ordering::SeqCst)
     }
 
     pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
@@ -450,7 +453,7 @@ impl Client {
     }
 
     pub fn set_id(&self, id: u64) -> &Self {
-        self.id.store(id, std::sync::atomic::Ordering::SeqCst);
+        self.id.store(id, Ordering::SeqCst);
         self
     }
 
@@ -1260,6 +1263,30 @@ impl Client {
             .map_ok(|envelope| envelope.payload)
     }
 
+    pub fn request_stream<T: RequestMessage>(
+        &self,
+        request: T,
+    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
+        let client_id = self.id.load(Ordering::SeqCst);
+        log::debug!(
+            "rpc request start. client_id:{}. name:{}",
+            client_id,
+            T::NAME
+        );
+        let response = self
+            .connection_id()
+            .map(|conn_id| self.peer.request_stream(conn_id, request));
+        async move {
+            let response = response?.await;
+            log::debug!(
+                "rpc request finish. client_id:{}. name:{}",
+                client_id,
+                T::NAME
+            );
+            response
+        }
+    }
+
     pub fn request_envelope<T: RequestMessage>(
         &self,
         request: T,
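
`request_stream` has a two-stage shape: the outer future resolves once the request is dispatched, and the value it yields is itself a stream of response messages. A standalone sketch of consuming that shape (names are illustrative; the real method is generic over `T: RequestMessage`):

```rust
use anyhow::Result;
use futures::{executor::block_on, stream, Stream, StreamExt};

// Stand-in for `client.request_stream(...)`: the future resolves to a
// stream of response chunks.
async fn request_stream() -> Result<impl Stream<Item = Result<String>>> {
    Ok(stream::iter(vec![Ok("hel".to_string()), Ok("lo".to_string())]))
}

fn main() -> Result<()> {
    block_on(async {
        let stream = request_stream().await?; // stage one: request accepted
        futures::pin_mut!(stream);
        let mut out = String::new();
        while let Some(chunk) = stream.next().await {
            out.push_str(&chunk?); // stage two: consume streamed messages
        }
        assert_eq!(out, "hello");
        Ok(())
    })
}
```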

crates/client/src/telemetry.rs 🔗

@@ -261,7 +261,7 @@ impl Telemetry {
         self: &Arc<Self>,
         conversation_id: Option<String>,
         kind: AssistantKind,
-        model: &str,
+        model: String,
     ) {
         let event = Event::Assistant(AssistantEvent {
             conversation_id,

crates/collab/Cargo.toml 🔗

@@ -31,10 +31,12 @@ collections.workspace = true
 dashmap = "5.4"
 envy = "0.4.2"
 futures.workspace = true
+google_ai.workspace = true
 hex.workspace = true
 live_kit_server.workspace = true
 log.workspace = true
 nanoid = "0.4"
+open_ai.workspace = true
 parking_lot.workspace = true
 prometheus = "0.13"
 prost.workspace = true
@@ -80,7 +82,6 @@ git = { workspace = true, features = ["test-support"] }
 gpui = { workspace = true, features = ["test-support"] }
 indoc.workspace = true
 language = { workspace = true, features = ["test-support"] }
-lazy_static.workspace = true
 live_kit_client = { workspace = true, features = ["test-support"] }
 lsp = { workspace = true, features = ["test-support"] }
 menu.workspace = true

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -379,6 +379,16 @@ CREATE TABLE extension_versions (
 CREATE UNIQUE INDEX "index_extensions_external_id" ON "extensions" ("external_id");
 CREATE INDEX "index_extensions_total_download_count" ON "extensions" ("total_download_count");
 
+CREATE TABLE rate_buckets (
+    user_id INT NOT NULL,
+    rate_limit_name VARCHAR(255) NOT NULL,
+    token_count INT NOT NULL,
+    last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL,
+    PRIMARY KEY (user_id, rate_limit_name),
+    FOREIGN KEY (user_id) REFERENCES users(id)
+);
+CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name);
+
 CREATE TABLE hosted_projects (
     id INTEGER PRIMARY KEY AUTOINCREMENT,
     channel_id INTEGER NOT NULL REFERENCES channels(id),

crates/collab/migrations/20240220234826_add_rate_buckets.sql 🔗

@@ -0,0 +1,11 @@
+CREATE TABLE IF NOT EXISTS rate_buckets (
+    user_id INT NOT NULL,
+    rate_limit_name VARCHAR(255) NOT NULL,
+    token_count INT NOT NULL,
+    last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL,
+    PRIMARY KEY (user_id, rate_limit_name),
+    CONSTRAINT fk_user
+        FOREIGN KEY (user_id) REFERENCES users(id)
+);
+
+CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name);
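
These columns persist per-user, per-limit token-bucket state. A hedged sketch of the refill arithmetic such a schema typically implies (illustrative only; the actual collab rate limiter may use different units or policy):

```rust
struct RateBucket {
    token_count: i32,      // maps to the token_count column
    last_refill_secs: u64, // stands in for the last_refill timestamp
}

// Accrue tokens for the time elapsed since the last refill, capped at
// the bucket's capacity, then advance the refill timestamp.
fn refill(bucket: &mut RateBucket, now_secs: u64, capacity: i32, tokens_per_sec: i32) {
    let elapsed = now_secs.saturating_sub(bucket.last_refill_secs) as i32;
    bucket.token_count = bucket
        .token_count
        .saturating_add(elapsed.saturating_mul(tokens_per_sec))
        .min(capacity);
    bucket.last_refill_secs = now_secs;
}

fn main() {
    let mut bucket = RateBucket { token_count: 0, last_refill_secs: 100 };
    refill(&mut bucket, 110, 50, 2); // 10 s elapsed x 2 tokens/s = 20 tokens
    assert_eq!(bucket.token_count, 20);
}
```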

crates/collab/src/ai.rs 🔗

@@ -0,0 +1,75 @@
+use anyhow::{anyhow, Result};
+use rpc::proto;
+
+pub fn language_model_request_to_open_ai(
+    request: proto::CompleteWithLanguageModel,
+) -> Result<open_ai::Request> {
+    Ok(open_ai::Request {
+        model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
+        messages: request
+            .messages
+            .into_iter()
+            .map(|message| {
+                let role = proto::LanguageModelRole::from_i32(message.role)
+                    .ok_or_else(|| anyhow!("invalid role {}", message.role))?;
+                Ok(open_ai::RequestMessage {
+                    role: match role {
+                        proto::LanguageModelRole::LanguageModelUser => open_ai::Role::User,
+                        proto::LanguageModelRole::LanguageModelAssistant => {
+                            open_ai::Role::Assistant
+                        }
+                        proto::LanguageModelRole::LanguageModelSystem => open_ai::Role::System,
+                    },
+                    content: message.content,
+                })
+            })
+            .collect::<Result<Vec<open_ai::RequestMessage>>>()?,
+        stream: true,
+        stop: request.stop,
+        temperature: request.temperature,
+    })
+}
+
+pub fn language_model_request_to_google_ai(
+    request: proto::CompleteWithLanguageModel,
+) -> Result<google_ai::GenerateContentRequest> {
+    Ok(google_ai::GenerateContentRequest {
+        contents: request
+            .messages
+            .into_iter()
+            .map(language_model_request_message_to_google_ai)
+            .collect::<Result<Vec<_>>>()?,
+        generation_config: None,
+        safety_settings: None,
+    })
+}
+
+pub fn language_model_request_message_to_google_ai(
+    message: proto::LanguageModelRequestMessage,
+) -> Result<google_ai::Content> {
+    let role = proto::LanguageModelRole::from_i32(message.role)
+        .ok_or_else(|| anyhow!("invalid role {}", message.role))?;
+
+    Ok(google_ai::Content {
+        parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
+            text: message.content,
+        })],
+        role: match role {
+            proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
+            proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
+            proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
+        },
+    })
+}
+
+pub fn count_tokens_request_to_google_ai(
+    request: proto::CountTokensWithLanguageModel,
+) -> Result<google_ai::CountTokensRequest> {
+    Ok(google_ai::CountTokensRequest {
+        contents: request
+            .messages
+            .into_iter()
+            .map(language_model_request_message_to_google_ai)
+            .collect::<Result<Vec<_>>>()?,
+    })
+}
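
For reviewers, a minimal sketch of how these converters are meant to be driven. The proto field values mirror the `CompleteWithLanguageModel` message added to zed.proto in this PR; the wrapping function is hypothetical:

```rust
use rpc::proto;

// Hypothetical: build a proto request and convert it for the OpenAI proxy path.
fn example_conversion() -> anyhow::Result<open_ai::Request> {
    let request = proto::CompleteWithLanguageModel {
        model: "gpt-4-turbo-preview".to_string(),
        messages: vec![proto::LanguageModelRequestMessage {
            role: proto::LanguageModelRole::LanguageModelUser as i32,
            content: "Explain this function.".to_string(),
        }],
        stop: Vec::new(),
        temperature: 1.0,
    };
    // Unrecognized model ids fall back to `open_ai::Model::FourTurbo`.
    language_model_request_to_open_ai(request)
}
```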

crates/collab/src/api/extensions.rs 🔗

@@ -1,6 +1,5 @@
 use crate::{
     db::{ExtensionMetadata, NewExtensionVersion},
-    executor::Executor,
     AppState, Error, Result,
 };
 use anyhow::{anyhow, Context as _};
@@ -136,7 +135,7 @@ async fn download_extension(
 const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60);
 const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60);
 
-pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>, executor: Executor) {
+pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
     let Some(blob_store_client) = app_state.blob_store_client.clone() else {
         log::info!("no blob store client");
         return;
@@ -146,6 +145,7 @@ pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>, e
         return;
     };
 
+    let executor = app_state.executor.clone();
     executor.spawn_detached({
         let executor = executor.clone();
         async move {

crates/collab/src/db/queries.rs 🔗

@@ -10,6 +10,7 @@ pub mod hosted_projects;
 pub mod messages;
 pub mod notifications;
 pub mod projects;
+pub mod rate_buckets;
 pub mod rooms;
 pub mod servers;
 pub mod users;

crates/collab/src/db/queries/rate_buckets.rs 🔗

@@ -0,0 +1,58 @@
+use super::*;
+use crate::db::tables::rate_buckets;
+use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
+
+impl Database {
+    /// Saves the given rate buckets, inserting a row for each bucket and
+    /// updating the token count and last refill time on conflict.
+    pub async fn save_rate_buckets(&self, buckets: &[rate_buckets::Model]) -> Result<()> {
+        if buckets.is_empty() {
+            return Ok(());
+        }
+
+        self.transaction(|tx| async move {
+            rate_buckets::Entity::insert_many(buckets.iter().map(|bucket| {
+                rate_buckets::ActiveModel {
+                    user_id: ActiveValue::Set(bucket.user_id),
+                    rate_limit_name: ActiveValue::Set(bucket.rate_limit_name.clone()),
+                    token_count: ActiveValue::Set(bucket.token_count),
+                    last_refill: ActiveValue::Set(bucket.last_refill),
+                }
+            }))
+            .on_conflict(
+                OnConflict::columns([
+                    rate_buckets::Column::UserId,
+                    rate_buckets::Column::RateLimitName,
+                ])
+                .update_columns([
+                    rate_buckets::Column::TokenCount,
+                    rate_buckets::Column::LastRefill,
+                ])
+                .to_owned(),
+            )
+            .exec(&*tx)
+            .await?;
+
+            Ok(())
+        })
+        .await
+    }
+
+    /// Retrieves the rate limit for the given user and rate limit name.
+    pub async fn get_rate_bucket(
+        &self,
+        user_id: UserId,
+        rate_limit_name: &str,
+    ) -> Result<Option<rate_buckets::Model>> {
+        self.transaction(|tx| async move {
+            let rate_limit = rate_buckets::Entity::find()
+                .filter(rate_buckets::Column::UserId.eq(user_id))
+                .filter(rate_buckets::Column::RateLimitName.eq(rate_limit_name))
+                .one(&*tx)
+                .await?;
+
+            Ok(rate_limit)
+        })
+        .await
+    }
+}
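
A hedged sketch of the intended round trip, assuming a `Database` handle and a `UserId` are already in hand (both types exist in this crate; the bucket values and function name are illustrative):

```rust
use chrono::Utc;

// Illustrative: upsert one bucket row, then read it back.
async fn round_trip(db: &Database, user_id: UserId) -> Result<()> {
    db.save_rate_buckets(&[rate_buckets::Model {
        user_id,
        rate_limit_name: "complete-with-language-model".to_string(),
        token_count: 119,
        last_refill: Utc::now().naive_utc(),
    }])
    .await?;

    // `None` is returned only if no row was ever saved for this key pair.
    let bucket = db
        .get_rate_bucket(user_id, "complete-with-language-model")
        .await?;
    assert!(bucket.is_some());
    Ok(())
}
```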

crates/collab/src/db/tables.rs 🔗

@@ -22,6 +22,7 @@ pub mod observed_buffer_edits;
 pub mod observed_channel_messages;
 pub mod project;
 pub mod project_collaborator;
+pub mod rate_buckets;
 pub mod room;
 pub mod room_participant;
 pub mod server;

crates/collab/src/db/tables/rate_buckets.rs 🔗

@@ -0,0 +1,31 @@
+use crate::db::UserId;
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "rate_buckets")]
+pub struct Model {
+    #[sea_orm(primary_key, auto_increment = false)]
+    pub user_id: UserId,
+    #[sea_orm(primary_key, auto_increment = false)]
+    pub rate_limit_name: String,
+    pub token_count: i32,
+    pub last_refill: DateTime,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+    #[sea_orm(
+        belongs_to = "super::user::Entity",
+        from = "Column::UserId",
+        to = "super::user::Column::Id"
+    )]
+    User,
+}
+
+impl Related<super::user::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::User.def()
+    }
+}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/lib.rs 🔗

@@ -1,8 +1,10 @@
+pub mod ai;
 pub mod api;
 pub mod auth;
 pub mod db;
 pub mod env;
 pub mod executor;
+mod rate_limiter;
 pub mod rpc;
 
 #[cfg(test)]
@@ -13,6 +15,7 @@ use aws_config::{BehaviorVersion, Region};
 use axum::{http::StatusCode, response::IntoResponse};
 use db::{ChannelId, Database};
 use executor::Executor;
+pub use rate_limiter::*;
 use serde::Deserialize;
 use std::{path::PathBuf, sync::Arc};
 use util::ResultExt;
@@ -126,6 +129,8 @@ pub struct Config {
     pub blob_store_secret_key: Option<String>,
     pub blob_store_bucket: Option<String>,
     pub zed_environment: Arc<str>,
+    pub openai_api_key: Option<Arc<str>>,
+    pub google_ai_api_key: Option<Arc<str>>,
     pub zed_client_checksum_seed: Option<String>,
     pub slack_panics_webhook: Option<String>,
     pub auto_join_channel_id: Option<ChannelId>,
@@ -147,12 +152,14 @@ pub struct AppState {
     pub db: Arc<Database>,
     pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
     pub blob_store_client: Option<aws_sdk_s3::Client>,
+    pub rate_limiter: Arc<RateLimiter>,
+    pub executor: Executor,
     pub clickhouse_client: Option<clickhouse::Client>,
     pub config: Config,
 }
 
 impl AppState {
-    pub async fn new(config: Config) -> Result<Arc<Self>> {
+    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
         let mut db_options = db::ConnectOptions::new(config.database_url.clone());
         db_options.max_connections(config.database_max_connections);
         let mut db = Database::new(db_options, Executor::Production).await?;
@@ -173,10 +180,13 @@ impl AppState {
             None
         };
 
+        let db = Arc::new(db);
         let this = Self {
-            db: Arc::new(db),
+            db: db.clone(),
             live_kit_client,
             blob_store_client: build_blob_store_client(&config).await.log_err(),
+            rate_limiter: Arc::new(RateLimiter::new(db)),
+            executor,
             clickhouse_client: config
                 .clickhouse_url
                 .as_ref()

crates/collab/src/main.rs 🔗

@@ -7,7 +7,7 @@ use axum::{
 };
 use collab::{
     api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
-    Config, MigrateConfig, Result,
+    Config, MigrateConfig, RateLimiter, Result,
 };
 use db::Database;
 use std::{
@@ -62,18 +62,19 @@ async fn main() -> Result<()> {
 
             run_migrations().await?;
 
-            let state = AppState::new(config).await?;
+            let state = AppState::new(config, Executor::Production).await?;
 
             let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
                 .expect("failed to bind TCP listener");
 
+            RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
+
             let rpc_server = if is_collab {
                 let epoch = state
                     .db
                     .create_server(&state.config.zed_environment)
                     .await?;
-                let rpc_server =
-                    collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
+                let rpc_server = collab::rpc::Server::new(epoch, state.clone());
                 rpc_server.start().await?;
 
                 Some(rpc_server)
@@ -82,7 +83,7 @@
             };
 
             if is_api {
-                fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
+                fetch_extensions_from_blob_store_periodically(state.clone());
             }
 
             let mut app = collab::api::routes(rpc_server.clone(), state.clone());

crates/collab/src/rate_limiter.rs 🔗

@@ -0,0 +1,274 @@
+use crate::{db::UserId, executor::Executor, Database, Error, Result};
+use anyhow::anyhow;
+use chrono::{DateTime, Duration, Utc};
+use dashmap::{DashMap, DashSet};
+use sea_orm::prelude::DateTimeUtc;
+use std::sync::Arc;
+use util::ResultExt;
+
+pub trait RateLimit: 'static {
+    fn capacity() -> usize;
+    fn refill_duration() -> Duration;
+    fn db_name() -> &'static str;
+}
+
+/// Used to enforce per-user rate limits.
+pub struct RateLimiter {
+    buckets: DashMap<(UserId, String), RateBucket>,
+    dirty_buckets: DashSet<(UserId, String)>,
+    db: Arc<Database>,
+}
+
+impl RateLimiter {
+    pub fn new(db: Arc<Database>) -> Self {
+        RateLimiter {
+            buckets: DashMap::new(),
+            dirty_buckets: DashSet::new(),
+            db,
+        }
+    }
+
+    /// Spawns a new task that periodically saves rate limit data to the database.
+    pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
+        const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
+
+        executor.clone().spawn_detached(async move {
+            loop {
+                executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
+                rate_limiter.save().await.log_err();
+            }
+        });
+    }
+
+    /// Returns an error if the user has exceeded the specified `RateLimit`.
+    /// Attempts to read the bucket from the database if no cached `RateBucket` currently exists.
+    pub async fn check<T: RateLimit>(&self, user_id: UserId) -> Result<()> {
+        self.check_internal::<T>(user_id, Utc::now()).await
+    }
+
+    async fn check_internal<T: RateLimit>(&self, user_id: UserId, now: DateTimeUtc) -> Result<()> {
+        let bucket_key = (user_id, T::db_name().to_string());
+
+        // Attempt to fetch the bucket from the database if it hasn't been cached.
+        // For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
+        // but this enforces limits across restarts so long as the database is reachable.
+        if !self.buckets.contains_key(&bucket_key) {
+            if let Some(bucket) = self.load_bucket::<T>(user_id).await.log_err().flatten() {
+                self.buckets.insert(bucket_key.clone(), bucket);
+                self.dirty_buckets.insert(bucket_key.clone());
+            }
+        }
+
+        let mut bucket = self
+            .buckets
+            .entry(bucket_key.clone())
+            .or_insert_with(|| RateBucket::new(T::capacity(), T::refill_duration(), now));
+
+        if bucket.value_mut().allow(now) {
+            self.dirty_buckets.insert(bucket_key);
+            Ok(())
+        } else {
+            Err(anyhow!("rate limit exceeded"))?
+        }
+    }
+
+    async fn load_bucket<K: RateLimit>(
+        &self,
+        user_id: UserId,
+    ) -> Result<Option<RateBucket>, Error> {
+        Ok(self
+            .db
+            .get_rate_bucket(user_id, K::db_name())
+            .await?
+            .map(|saved_bucket| RateBucket {
+                capacity: K::capacity(),
+                refill_time_per_token: K::refill_duration(),
+                token_count: saved_bucket.token_count as usize,
+                last_refill: DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
+            }))
+    }
+
+    pub async fn save(&self) -> Result<()> {
+        let mut buckets = Vec::new();
+        self.dirty_buckets.retain(|key| {
+            if let Some(bucket) = self.buckets.get(&key) {
+                buckets.push(crate::db::rate_buckets::Model {
+                    user_id: key.0,
+                    rate_limit_name: key.1.clone(),
+                    token_count: bucket.token_count as i32,
+                    last_refill: bucket.last_refill.naive_utc(),
+                });
+            }
+            false
+        });
+
+        match self.db.save_rate_buckets(&buckets).await {
+            Ok(()) => Ok(()),
+            Err(err) => {
+                for bucket in buckets {
+                    self.dirty_buckets
+                        .insert((bucket.user_id, bucket.rate_limit_name));
+                }
+                Err(err)
+            }
+        }
+    }
+}
+
+#[derive(Clone)]
+struct RateBucket {
+    capacity: usize,
+    token_count: usize,
+    refill_time_per_token: Duration,
+    last_refill: DateTimeUtc,
+}
+
+impl RateBucket {
+    fn new(capacity: usize, refill_duration: Duration, now: DateTimeUtc) -> Self {
+        RateBucket {
+            capacity,
+            token_count: capacity,
+            refill_time_per_token: refill_duration / capacity as i32,
+            last_refill: now,
+        }
+    }
+
+    fn allow(&mut self, now: DateTimeUtc) -> bool {
+        self.refill(now);
+        if self.token_count > 0 {
+            self.token_count -= 1;
+            true
+        } else {
+            false
+        }
+    }
+
+    fn refill(&mut self, now: DateTimeUtc) {
+        let elapsed = now - self.last_refill;
+        if elapsed >= self.refill_time_per_token {
+            let new_tokens =
+                elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
+
+            self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
+            self.last_refill = now;
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::db::{NewUserParams, TestDb};
+    use gpui::TestAppContext;
+
+    #[gpui::test]
+    async fn test_rate_limiter(cx: &mut TestAppContext) {
+        let test_db = TestDb::sqlite(cx.executor().clone());
+        let db = test_db.db().clone();
+        let user_1 = db
+            .create_user(
+                "user-1@zed.dev",
+                false,
+                NewUserParams {
+                    github_login: "user-1".into(),
+                    github_user_id: 1,
+                },
+            )
+            .await
+            .unwrap()
+            .user_id;
+        let user_2 = db
+            .create_user(
+                "user-2@zed.dev",
+                false,
+                NewUserParams {
+                    github_login: "user-2".into(),
+                    github_user_id: 2,
+                },
+            )
+            .await
+            .unwrap()
+            .user_id;
+
+        let mut now = Utc::now();
+
+        let rate_limiter = RateLimiter::new(db.clone());
+
+        // User 1 can access resource A two times before being rate-limited.
+        rate_limiter
+            .check_internal::<RateLimitA>(user_1, now)
+            .await
+            .unwrap();
+        rate_limiter
+            .check_internal::<RateLimitA>(user_1, now)
+            .await
+            .unwrap();
+        rate_limiter
+            .check_internal::<RateLimitA>(user_1, now)
+            .await
+            .unwrap_err();
+
+        // User 2 can access resource A and user 1 can access resource B.
+        rate_limiter
+            .check_internal::<RateLimitB>(user_2, now)
+            .await
+            .unwrap();
+        rate_limiter
+            .check_internal::<RateLimitB>(user_1, now)
+            .await
+            .unwrap();
+
+        // After one second, user 1 can make another request before being rate-limited again.
+        now += Duration::seconds(1);
+        rate_limiter
+            .check_internal::<RateLimitA>(user_1, now)
+            .await
+            .unwrap();
+        rate_limiter
+            .check_internal::<RateLimitA>(user_1, now)
+            .await
+            .unwrap_err();
+
+        rate_limiter.save().await.unwrap();
+
+        // Rate limits are reloaded from the database, so user 1 is still rate-limited
+        // for resource A.
+        let rate_limiter = RateLimiter::new(db.clone());
+        rate_limiter
+            .check_internal::<RateLimitA>(user_1, now)
+            .await
+            .unwrap_err();
+    }
+
+    struct RateLimitA;
+
+    impl RateLimit for RateLimitA {
+        fn capacity() -> usize {
+            2
+        }
+
+        fn refill_duration() -> Duration {
+            Duration::seconds(2)
+        }
+
+        fn db_name() -> &'static str {
+            "rate-limit-a"
+        }
+    }
+
+    struct RateLimitB;
+
+    impl RateLimit for RateLimitB {
+        fn capacity() -> usize {
+            10
+        }
+
+        fn refill_duration() -> Duration {
+            Duration::seconds(3)
+        }
+
+        fn db_name() -> &'static str {
+            "rate-limit-b"
+        }
+    }
+}
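
To make the refill arithmetic concrete, here is a standalone sketch of the same token-bucket math using only chrono, with the numbers the test above exercises: a capacity of 2 refilled over 2 seconds means each token takes 1 second to come back.

```rust
use chrono::{Duration, Utc};

fn main() {
    // Mirrors `RateBucket::new`: 2 tokens refilled over 2 seconds
    // works out to one token per second.
    let capacity = 2usize;
    let refill_time_per_token = Duration::seconds(2) / capacity as i32;
    assert_eq!(refill_time_per_token, Duration::seconds(1));

    // Drain the bucket: the first two requests pass, a third would fail.
    let mut token_count = capacity;
    token_count -= 1;
    token_count -= 1;
    assert_eq!(token_count, 0);

    // One second later, exactly one whole token has accrued, so one more
    // request succeeds before the user is rate-limited again.
    let last_refill = Utc::now();
    let now = last_refill + Duration::seconds(1);
    let new_tokens =
        (now - last_refill).num_milliseconds() / refill_time_per_token.num_milliseconds();
    token_count = (token_count + new_tokens as usize).min(capacity);
    assert_eq!(token_count, 1);
}
```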

crates/collab/src/rpc.rs 🔗

@@ -9,9 +9,9 @@ use crate::{
         User, UserId,
     },
     executor::Executor,
-    AppState, Error, Result,
+    AppState, Error, RateLimit, RateLimiter, Result,
 };
-use anyhow::anyhow;
+use anyhow::{anyhow, Context as _};
 use async_tungstenite::tungstenite::{
     protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
 };
@@ -30,6 +30,8 @@ use axum::{
 };
 use collections::{HashMap, HashSet};
 pub use connection_pool::{ConnectionPool, ZedVersion};
+use core::fmt::{self, Debug, Formatter};
+
 use futures::{
     channel::oneshot,
     future::{self, BoxFuture},
@@ -39,15 +41,14 @@ use futures::{
 use prometheus::{register_int_gauge, IntGauge};
 use rpc::{
     proto::{
-        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
-        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
+        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
+        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
     },
     Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
 };
 use serde::{Serialize, Serializer};
 use std::{
     any::TypeId,
-    fmt,
     future::Future,
     marker::PhantomData,
     mem,
@@ -64,7 +65,7 @@ use time::OffsetDateTime;
 use tokio::sync::{watch, Semaphore};
 use tower::ServiceBuilder;
 use tracing::{field, info_span, instrument, Instrument};
-use util::SemanticVersion;
+use util::{http::IsahcHttpClient, SemanticVersion};
 
 pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
 
@@ -92,6 +93,18 @@ impl<R: RequestMessage> Response<R> {
     }
 }
 
+struct StreamingResponse<R: RequestMessage> {
+    peer: Arc<Peer>,
+    receipt: Receipt<R>,
+}
+
+impl<R: RequestMessage> StreamingResponse<R> {
+    fn send(&self, payload: R::Response) -> Result<()> {
+        self.peer.respond(self.receipt, payload)?;
+        Ok(())
+    }
+}
+
 #[derive(Clone)]
 struct Session {
     user_id: UserId,
@@ -100,6 +113,8 @@ struct Session {
     peer: Arc<Peer>,
     connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
     live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
+    http_client: IsahcHttpClient,
+    rate_limiter: Arc<RateLimiter>,
     _executor: Executor,
 }
 
@@ -124,8 +139,8 @@ impl Session {
     }
 }
 
-impl fmt::Debug for Session {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+impl Debug for Session {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
         f.debug_struct("Session")
             .field("user_id", &self.user_id)
             .field("connection_id", &self.connection_id)
@@ -148,7 +163,6 @@ pub struct Server {
     peer: Arc<Peer>,
     pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
     app_state: Arc<AppState>,
-    executor: Executor,
     handlers: HashMap<TypeId, MessageHandler>,
     teardown: watch::Sender<bool>,
 }
@@ -175,12 +189,11 @@ where
 }
 
 impl Server {
-    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
+    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
         let mut server = Self {
             id: parking_lot::Mutex::new(id),
             peer: Peer::new(id.0 as u32),
-            app_state,
-            executor,
+            app_state: app_state.clone(),
             connection_pool: Default::default(),
             handlers: Default::default(),
             teardown: watch::channel(false).0,
@@ -280,7 +293,30 @@ impl Server {
             .add_message_handler(update_followers)
             .add_request_handler(get_private_user_info)
             .add_message_handler(acknowledge_channel_message)
-            .add_message_handler(acknowledge_buffer_version);
+            .add_message_handler(acknowledge_buffer_version)
+            .add_streaming_request_handler({
+                let app_state = app_state.clone();
+                move |request, response, session| {
+                    complete_with_language_model(
+                        request,
+                        response,
+                        session,
+                        app_state.config.openai_api_key.clone(),
+                        app_state.config.google_ai_api_key.clone(),
+                    )
+                }
+            })
+            .add_request_handler({
+                let app_state = app_state.clone();
+                move |request, response, session| {
+                    count_tokens_with_language_model(
+                        request,
+                        response,
+                        session,
+                        app_state.config.google_ai_api_key.clone(),
+                    )
+                }
+            });
 
         Arc::new(server)
     }
@@ -289,12 +325,12 @@ impl Server {
         let server_id = *self.id.lock();
         let app_state = self.app_state.clone();
         let peer = self.peer.clone();
-        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
+        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
         let pool = self.connection_pool.clone();
         let live_kit_client = self.app_state.live_kit_client.clone();
 
         let span = info_span!("start server");
-        self.executor.spawn_detached(
+        self.app_state.executor.spawn_detached(
             async move {
                 tracing::info!("waiting for cleanup timeout");
                 timeout.await;
@@ -536,6 +572,40 @@ impl Server {
         })
     }
 
+    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
+    where
+        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
+        Fut: Send + Future<Output = Result<()>>,
+        M: RequestMessage,
+    {
+        let handler = Arc::new(handler);
+        self.add_handler(move |envelope, session| {
+            let receipt = envelope.receipt();
+            let handler = handler.clone();
+            async move {
+                let peer = session.peer.clone();
+                let response = StreamingResponse {
+                    peer: peer.clone(),
+                    receipt,
+                };
+                match (handler)(envelope.payload, response, session).await {
+                    Ok(()) => {
+                        peer.end_stream(receipt)?;
+                        Ok(())
+                    }
+                    Err(error) => {
+                        let proto_err = match &error {
+                            Error::Internal(err) => err.to_proto(),
+                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
+                        };
+                        peer.respond_with_error(receipt, proto_err)?;
+                        Err(error)
+                    }
+                }
+            }
+        })
+    }
+
     #[allow(clippy::too_many_arguments)]
     pub fn handle_connection(
         self: &Arc<Self>,
@@ -569,6 +639,14 @@ impl Server {
             tracing::Span::current().record("connection_id", format!("{}", connection_id));
             tracing::info!("connection opened");
 
+            let http_client = match IsahcHttpClient::new() {
+                Ok(http_client) => http_client,
+                Err(error) => {
+                    tracing::error!(?error, "failed to create HTTP client");
+                    return;
+                }
+            };
+
             let session = Session {
                 user_id,
                 connection_id,
@@ -576,7 +654,9 @@ impl Server {
                 peer: this.peer.clone(),
                 connection_pool: this.connection_pool.clone(),
                 live_kit_client: this.app_state.live_kit_client.clone(),
-                _executor: executor.clone()
+                http_client,
+                rate_limiter: this.app_state.rate_limiter.clone(),
+                _executor: executor.clone(),
             };
 
             if let Err(error) = this.send_initial_client_update(connection_id, user, zed_version, send_connection_id, &session).await {
@@ -3220,6 +3300,207 @@ async fn acknowledge_buffer_version(
     Ok(())
 }
 
+struct CompleteWithLanguageModelRateLimit;
+
+impl RateLimit for CompleteWithLanguageModelRateLimit {
+    fn capacity() -> usize {
+        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(120) // Picked arbitrarily
+    }
+
+    fn refill_duration() -> chrono::Duration {
+        chrono::Duration::hours(1)
+    }
+
+    fn db_name() -> &'static str {
+        "complete-with-language-model"
+    }
+}
+
+async fn complete_with_language_model(
+    request: proto::CompleteWithLanguageModel,
+    response: StreamingResponse<proto::CompleteWithLanguageModel>,
+    session: Session,
+    open_ai_api_key: Option<Arc<str>>,
+    google_ai_api_key: Option<Arc<str>>,
+) -> Result<()> {
+    authorize_access_to_language_models(&session).await?;
+    session
+        .rate_limiter
+        .check::<CompleteWithLanguageModelRateLimit>(session.user_id)
+        .await?;
+
+    if request.model.starts_with("gpt") {
+        let api_key =
+            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
+        complete_with_open_ai(request, response, session, api_key).await?;
+    } else if request.model.starts_with("gemini") {
+        let api_key = google_ai_api_key
+            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
+        complete_with_google_ai(request, response, session, api_key).await?;
+    } else {
+        Err(anyhow!("unsupported completion model {:?}", request.model))?;
+    }
+
+    Ok(())
+}
+
+async fn complete_with_open_ai(
+    request: proto::CompleteWithLanguageModel,
+    response: StreamingResponse<proto::CompleteWithLanguageModel>,
+    session: Session,
+    api_key: Arc<str>,
+) -> Result<()> {
+    const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
+
+    let mut completion_stream = open_ai::stream_completion(
+        &session.http_client,
+        OPEN_AI_API_URL,
+        &api_key,
+        crate::ai::language_model_request_to_open_ai(request)?,
+    )
+    .await
+    .context("open_ai::stream_completion request failed")?;
+
+    while let Some(event) = completion_stream.next().await {
+        let event = event?;
+        response.send(proto::LanguageModelResponse {
+            choices: event
+                .choices
+                .into_iter()
+                .map(|choice| proto::LanguageModelChoiceDelta {
+                    index: choice.index,
+                    delta: Some(proto::LanguageModelResponseMessage {
+                        role: choice.delta.role.map(|role| match role {
+                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
+                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
+                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
+                        } as i32),
+                        content: choice.delta.content,
+                    }),
+                    finish_reason: choice.finish_reason,
+                })
+                .collect(),
+        })?;
+    }
+
+    Ok(())
+}
+
+async fn complete_with_google_ai(
+    request: proto::CompleteWithLanguageModel,
+    response: StreamingResponse<proto::CompleteWithLanguageModel>,
+    session: Session,
+    api_key: Arc<str>,
+) -> Result<()> {
+    let mut stream = google_ai::stream_generate_content(
+        &session.http_client,
+        google_ai::API_URL,
+        api_key.as_ref(),
+        crate::ai::language_model_request_to_google_ai(request)?,
+    )
+    .await
+    .context("google_ai::stream_generate_content request failed")?;
+
+    while let Some(event) = stream.next().await {
+        let event = event?;
+        response.send(proto::LanguageModelResponse {
+            choices: event
+                .candidates
+                .unwrap_or_default()
+                .into_iter()
+                .map(|candidate| proto::LanguageModelChoiceDelta {
+                    index: candidate.index as u32,
+                    delta: Some(proto::LanguageModelResponseMessage {
+                        role: Some(match candidate.content.role {
+                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
+                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
+                        } as i32),
+                        content: Some(
+                            candidate
+                                .content
+                                .parts
+                                .into_iter()
+                                .filter_map(|part| match part {
+                                    google_ai::Part::TextPart(part) => Some(part.text),
+                                    google_ai::Part::InlineDataPart(_) => None,
+                                })
+                                .collect(),
+                        ),
+                    }),
+                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
+                })
+                .collect(),
+        })?;
+    }
+
+    Ok(())
+}
+
+struct CountTokensWithLanguageModelRateLimit;
+
+impl RateLimit for CountTokensWithLanguageModelRateLimit {
+    fn capacity() -> usize {
+        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(600) // Picked arbitrarily
+    }
+
+    fn refill_duration() -> chrono::Duration {
+        chrono::Duration::hours(1)
+    }
+
+    fn db_name() -> &'static str {
+        "count-tokens-with-language-model"
+    }
+}
+
+async fn count_tokens_with_language_model(
+    request: proto::CountTokensWithLanguageModel,
+    response: Response<proto::CountTokensWithLanguageModel>,
+    session: Session,
+    google_ai_api_key: Option<Arc<str>>,
+) -> Result<()> {
+    authorize_access_to_language_models(&session).await?;
+
+    if !request.model.starts_with("gemini") {
+        return Err(anyhow!(
+            "counting tokens for model: {:?} is not supported",
+            request.model
+        ))?;
+    }
+
+    session
+        .rate_limiter
+        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id)
+        .await?;
+
+    let api_key = google_ai_api_key
+        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
+    let tokens_response = google_ai::count_tokens(
+        &session.http_client,
+        google_ai::API_URL,
+        &api_key,
+        crate::ai::count_tokens_request_to_google_ai(request)?,
+    )
+    .await?;
+    response.send(proto::CountTokensResponse {
+        token_count: tokens_response.total_tokens as u32,
+    })?;
+    Ok(())
+}
+
+async fn authorize_access_to_language_models(session: &Session) -> Result<(), Error> {
+    let db = session.db().await;
+    let flags = db.get_user_flags(session.user_id).await?;
+    if flags.iter().any(|flag| flag == "language-models") {
+        Ok(())
+    } else {
+        Err(anyhow!("permission denied"))?
+    }
+}
+
 /// Start receiving chat updates for a channel
 async fn join_channel_chat(
     request: proto::JoinChannelChat,

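Taken together, adding another rate-limited endpoint to this server should only require a new `RateLimit` marker type plus a `check` call at the top of the handler. A hedged sketch, with hypothetical names and limits:

```rust
struct ExampleRequestRateLimit;

impl RateLimit for ExampleRequestRateLimit {
    fn capacity() -> usize {
        60 // Hypothetical: 60 requests per hour.
    }

    fn refill_duration() -> chrono::Duration {
        chrono::Duration::hours(1)
    }

    fn db_name() -> &'static str {
        "example-request" // Stored as `rate_limit_name` in the rate_buckets table.
    }
}

async fn example_handler(session: Session) -> Result<()> {
    // Fails with "rate limit exceeded" once the user's bucket is empty.
    session
        .rate_limiter
        .check::<ExampleRequestRateLimit>(session.user_id)
        .await?;
    // ...handle the request...
    Ok(())
}
```
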
crates/collab/src/tests/test_server.rs 🔗

@@ -2,7 +2,7 @@ use crate::{
     db::{tests::TestDb, NewUserParams, UserId},
     executor::Executor,
     rpc::{Server, ZedVersion, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
-    AppState, Config,
+    AppState, Config, RateLimiter,
 };
 use anyhow::anyhow;
 use call::ActiveCall;
@@ -93,17 +93,14 @@ impl TestServer {
             deterministic.clone(),
         )
         .unwrap();
-        let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
+        let executor = Executor::Deterministic(deterministic.clone());
+        let app_state = Self::build_app_state(&test_db, &live_kit_server, executor.clone()).await;
         let epoch = app_state
             .db
             .create_server(&app_state.config.zed_environment)
             .await
             .unwrap();
-        let server = Server::new(
-            epoch,
-            app_state.clone(),
-            Executor::Deterministic(deterministic.clone()),
-        );
+        let server = Server::new(epoch, app_state.clone());
         server.start().await.unwrap();
         // Advance clock to ensure the server's cleanup task is finished.
         deterministic.advance_clock(CLEANUP_TIMEOUT);
@@ -482,12 +479,15 @@ impl TestServer {
 
     pub async fn build_app_state(
         test_db: &TestDb,
-        fake_server: &live_kit_client::TestServer,
+        live_kit_test_server: &live_kit_client::TestServer,
+        executor: Executor,
     ) -> Arc<AppState> {
         Arc::new(AppState {
             db: test_db.db().clone(),
-            live_kit_client: Some(Arc::new(fake_server.create_api_client())),
+            live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
             blob_store_client: None,
+            rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
+            executor,
             clickhouse_client: None,
             config: Config {
                 http_port: 0,
@@ -506,6 +506,8 @@ impl TestServer {
                 blob_store_access_key: None,
                 blob_store_secret_key: None,
                 blob_store_bucket: None,
+                openai_api_key: None,
+                google_ai_api_key: None,
                 clickhouse_url: None,
                 clickhouse_user: None,
                 clickhouse_password: None,

crates/google_ai/Cargo.toml 🔗

@@ -0,0 +1,14 @@
+[package]
+name = "google_ai"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+path = "src/google_ai.rs"
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+util.workspace = true

crates/google_ai/src/google_ai.rs 🔗

@@ -0,0 +1,266 @@
+use anyhow::{anyhow, Result};
+use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use serde::{Deserialize, Serialize};
+use util::http::HttpClient;
+
+pub const API_URL: &str = "https://generativelanguage.googleapis.com";
+
+pub async fn stream_generate_content<T: HttpClient>(
+    client: &T,
+    api_url: &str,
+    api_key: &str,
+    request: GenerateContentRequest,
+) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
+    let uri = format!(
+        "{}/v1beta/models/gemini-pro:streamGenerateContent?alt=sse&key={}",
+        api_url, api_key
+    );
+
+    let request = serde_json::to_string(&request)?;
+    let mut response = client.post_json(&uri, request.into()).await?;
+    if response.status().is_success() {
+        let reader = BufReader::new(response.into_body());
+        Ok(reader
+            .lines()
+            .filter_map(|line| async move {
+                match line {
+                    Ok(line) => {
+                        if let Some(line) = line.strip_prefix("data: ") {
+                            match serde_json::from_str(line) {
+                                Ok(response) => Some(Ok(response)),
+                                Err(error) => Some(Err(anyhow!(error))),
+                            }
+                        } else {
+                            None
+                        }
+                    }
+                    Err(error) => Some(Err(anyhow!(error))),
+                }
+            })
+            .boxed())
+    } else {
+        let mut text = String::new();
+        response.body_mut().read_to_string(&mut text).await?;
+        Err(anyhow!(
+            "error during streamGenerateContent, status code: {:?}, body: {}",
+            response.status(),
+            text
+        ))
+    }
+}
+
+pub async fn count_tokens<T: HttpClient>(
+    client: &T,
+    api_url: &str,
+    api_key: &str,
+    request: CountTokensRequest,
+) -> Result<CountTokensResponse> {
+    let uri = format!(
+        "{}/v1beta/models/gemini-pro:countTokens?key={}",
+        api_url, api_key
+    );
+    let request = serde_json::to_string(&request)?;
+    let mut response = client.post_json(&uri, request.into()).await?;
+    let mut text = String::new();
+    response.body_mut().read_to_string(&mut text).await?;
+    if response.status().is_success() {
+        Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
+    } else {
+        Err(anyhow!(
+            "error during countTokens, status code: {:?}, body: {}",
+            response.status(),
+            text
+        ))
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum Task {
+    #[serde(rename = "generateContent")]
+    GenerateContent,
+    #[serde(rename = "streamGenerateContent")]
+    StreamGenerateContent,
+    #[serde(rename = "countTokens")]
+    CountTokens,
+    #[serde(rename = "embedContent")]
+    EmbedContent,
+    #[serde(rename = "batchEmbedContents")]
+    BatchEmbedContents,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerateContentRequest {
+    pub contents: Vec<Content>,
+    pub generation_config: Option<GenerationConfig>,
+    pub safety_settings: Option<Vec<SafetySetting>>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerateContentResponse {
+    pub candidates: Option<Vec<GenerateContentCandidate>>,
+    pub prompt_feedback: Option<PromptFeedback>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerateContentCandidate {
+    pub index: usize,
+    pub content: Content,
+    pub finish_reason: Option<String>,
+    pub finish_message: Option<String>,
+    pub safety_ratings: Option<Vec<SafetyRating>>,
+    pub citation_metadata: Option<CitationMetadata>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Content {
+    pub parts: Vec<Part>,
+    pub role: Role,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub enum Role {
+    User,
+    Model,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum Part {
+    TextPart(TextPart),
+    InlineDataPart(InlineDataPart),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TextPart {
+    pub text: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct InlineDataPart {
+    pub inline_data: GenerativeContentBlob,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerativeContentBlob {
+    pub mime_type: String,
+    pub data: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CitationSource {
+    pub start_index: Option<usize>,
+    pub end_index: Option<usize>,
+    pub uri: Option<String>,
+    pub license: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CitationMetadata {
+    pub citation_sources: Vec<CitationSource>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptFeedback {
+    pub block_reason: Option<String>,
+    pub safety_ratings: Vec<SafetyRating>,
+    pub block_reason_message: Option<String>,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerationConfig {
+    pub candidate_count: Option<usize>,
+    pub stop_sequences: Option<Vec<String>>,
+    pub max_output_tokens: Option<usize>,
+    pub temperature: Option<f64>,
+    pub top_p: Option<f64>,
+    pub top_k: Option<usize>,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SafetySetting {
+    pub category: HarmCategory,
+    pub threshold: HarmBlockThreshold,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum HarmCategory {
+    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
+    Unspecified,
+    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
+    Derogatory,
+    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
+    Toxicity,
+    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
+    Violence,
+    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
+    Sexual,
+    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
+    Medical,
+    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
+    Dangerous,
+    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
+    Harassment,
+    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
+    HateSpeech,
+    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
+    SexuallyExplicit,
+    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
+    DangerousContent,
+}
+
+#[derive(Debug, Serialize)]
+pub enum HarmBlockThreshold {
+    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
+    Unspecified,
+    #[serde(rename = "BLOCK_LOW_AND_ABOVE")]
+    BlockLowAndAbove,
+    #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
+    BlockMediumAndAbove,
+    #[serde(rename = "BLOCK_ONLY_HIGH")]
+    BlockOnlyHigh,
+    #[serde(rename = "BLOCK_NONE")]
+    BlockNone,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum HarmProbability {
+    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
+    Unspecified,
+    Negligible,
+    Low,
+    Medium,
+    High,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SafetyRating {
+    pub category: HarmCategory,
+    pub probability: HarmProbability,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CountTokensRequest {
+    pub contents: Vec<Content>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CountTokensResponse {
+    pub total_tokens: usize,
+}
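
An end-to-end sketch of this client, reusing the `IsahcHttpClient` this PR constructs in rpc.rs (any `util::http::HttpClient` implementation should work); the prompt, function name, and output handling are illustrative:

```rust
use futures::StreamExt;
use util::http::IsahcHttpClient;

// Illustrative: stream a single-turn Gemini completion and print the text parts.
async fn example(api_key: &str) -> anyhow::Result<()> {
    let client = IsahcHttpClient::new()?;
    let request = google_ai::GenerateContentRequest {
        contents: vec![google_ai::Content {
            parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
                text: "Say hello.".into(),
            })],
            role: google_ai::Role::User,
        }],
        generation_config: None,
        safety_settings: None,
    };

    let mut events =
        google_ai::stream_generate_content(&client, google_ai::API_URL, api_key, request).await?;
    while let Some(event) = events.next().await {
        for candidate in event?.candidates.unwrap_or_default() {
            for part in candidate.content.parts {
                if let google_ai::Part::TextPart(part) = part {
                    print!("{}", part.text);
                }
            }
        }
    }
    Ok(())
}
```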

crates/open_ai/Cargo.toml 🔗

@@ -0,0 +1,19 @@
+[package]
+name = "open_ai"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+path = "src/open_ai.rs"
+
+[features]
+default = []
+schemars = ["dep:schemars"]
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+schemars = { workspace = true, optional = true }
+serde.workspace = true
+serde_json.workspace = true
+util.workspace = true

crates/open_ai/src/open_ai.rs 🔗

@@ -0,0 +1,182 @@
+use anyhow::{anyhow, Result};
+use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use serde::{Deserialize, Serialize};
+use std::convert::TryFrom;
+use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+impl TryFrom<String> for Role {
+    type Error = anyhow::Error;
+
+    fn try_from(value: String) -> Result<Self> {
+        match value.as_str() {
+            "user" => Ok(Self::User),
+            "assistant" => Ok(Self::Assistant),
+            "system" => Ok(Self::System),
+            _ => Err(anyhow!("invalid role '{value}'")),
+        }
+    }
+}
+
+impl From<Role> for String {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => "user".to_owned(),
+            Role::Assistant => "assistant".to_owned(),
+            Role::System => "system".to_owned(),
+        }
+    }
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub enum Model {
+    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
+    ThreePointFiveTurbo,
+    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
+    Four,
+    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
+    #[default]
+    FourTurbo,
+}
+
+impl Model {
+    pub fn from_id(id: &str) -> Result<Self> {
+        match id {
+            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
+            "gpt-4" => Ok(Self::Four),
+            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
+            _ => Err(anyhow!("invalid model id")),
+        }
+    }
+
+    pub fn id(&self) -> &'static str {
+        match self {
+            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
+            Self::Four => "gpt-4",
+            Self::FourTurbo => "gpt-4-turbo-preview",
+        }
+    }
+
+    pub fn display_name(&self) -> &'static str {
+        match self {
+            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
+            Self::Four => "gpt-4",
+            Self::FourTurbo => "gpt-4-turbo",
+        }
+    }
+}
+
+#[derive(Debug, Serialize)]
+pub struct Request {
+    pub model: Model,
+    pub messages: Vec<RequestMessage>,
+    pub stream: bool,
+    pub stop: Vec<String>,
+    pub temperature: f32,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct RequestMessage {
+    pub role: Role,
+    pub content: String,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ResponseMessage {
+    pub role: Option<Role>,
+    pub content: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct Usage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ChoiceDelta {
+    pub index: u32,
+    pub delta: ResponseMessage,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ResponseStreamEvent {
+    pub created: u32,
+    pub model: String,
+    pub choices: Vec<ChoiceDelta>,
+    pub usage: Option<Usage>,
+}
+
+pub async fn stream_completion(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: Request,
+) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
+    let uri = format!("{api_url}/chat/completions");
+    let request = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", api_key))
+        .body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let mut response = client.send(request).await?;
+    if response.status().is_success() {
+        let reader = BufReader::new(response.into_body());
+        Ok(reader
+            .lines()
+            .filter_map(|line| async move {
+                match line {
+                    Ok(line) => {
+                        let line = line.strip_prefix("data: ")?;
+                        if line == "[DONE]" {
+                            None
+                        } else {
+                            match serde_json::from_str(line) {
+                                Ok(response) => Some(Ok(response)),
+                                Err(error) => Some(Err(anyhow!(error))),
+                            }
+                        }
+                    }
+                    Err(error) => Some(Err(anyhow!(error))),
+                }
+            })
+            .boxed())
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        #[derive(Deserialize)]
+        struct OpenAiResponse {
+            error: OpenAiError,
+        }
+
+        #[derive(Deserialize)]
+        struct OpenAiError {
+            message: String,
+        }
+
+        match serde_json::from_str::<OpenAiResponse>(&body) {
+            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+                "Failed to connect to OpenAI API: {}",
+                response.error.message,
+            )),
+
+            _ => Err(anyhow!(
+                "Failed to connect to OpenAI API: {} {}",
+                response.status(),
+                body,
+            )),
+        }
+    }
+}
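
And the matching sketch for this OpenAI client. The URL literal mirrors the `OPEN_AI_API_URL` constant used in rpc.rs; everything else here is illustrative:

```rust
use futures::StreamExt;
use util::http::IsahcHttpClient;

// Illustrative: stream a chat completion and print content deltas as they arrive.
async fn example(api_key: &str) -> anyhow::Result<()> {
    let client = IsahcHttpClient::new()?;
    let request = Request {
        model: Model::FourTurbo,
        messages: vec![RequestMessage {
            role: Role::User,
            content: "Say hello.".into(),
        }],
        stream: true,
        stop: Vec::new(),
        temperature: 1.0,
    };

    let mut events =
        stream_completion(&client, "https://api.openai.com/v1", api_key, request).await?;
    while let Some(event) = events.next().await {
        for choice in event?.choices {
            if let Some(content) = choice.delta.content {
                print!("{content}");
            }
        }
    }
    Ok(())
}
```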

crates/rpc/proto/zed.proto 🔗

@@ -1,7 +1,7 @@
 syntax = "proto3";
 package zed.messages;
 
-// Looking for a number? Search "// Current max"
+// Looking for a number? Search "// current max"
 
 message PeerId {
     uint32 owner_id = 1;
@@ -26,6 +26,7 @@ message Envelope {
         Error error = 6;
         Ping ping = 7;
         Test test = 8;
+        EndStream end_stream = 165;
 
         CreateRoom create_room = 9;
         CreateRoomResponse create_room_response = 10;
@@ -198,6 +199,11 @@ message Envelope {
         GetImplementationResponse get_implementation_response = 163;
 
         JoinHostedProject join_hosted_project = 164;
+
+        CompleteWithLanguageModel complete_with_language_model = 166;
+        LanguageModelResponse language_model_response = 167;
+        CountTokensWithLanguageModel count_tokens_with_language_model = 168;
+        CountTokensResponse count_tokens_response = 169; // current max
     }
 
     reserved 158 to 161;
@@ -236,6 +242,8 @@ enum ErrorCode {
     reserved 6;
 }
 
+message EndStream {}
+
 message Test {
     uint64 id = 1;
 }
@@ -1718,3 +1726,45 @@ message SetRoomParticipantRole {
     uint64 user_id = 2;
     ChannelRole role = 3;
 }
+
+message CompleteWithLanguageModel {
+    string model = 1;
+    repeated LanguageModelRequestMessage messages = 2;
+    repeated string stop = 3;
+    float temperature = 4;
+}
+
+message LanguageModelRequestMessage {
+    LanguageModelRole role = 1;
+    string content = 2;
+}
+
+enum LanguageModelRole {
+    LanguageModelUser = 0;
+    LanguageModelAssistant = 1;
+    LanguageModelSystem = 2;
+}
+
+message LanguageModelResponseMessage {
+    optional LanguageModelRole role = 1;
+    optional string content = 2;
+}
+
+message LanguageModelResponse {
+    repeated LanguageModelChoiceDelta choices = 1;
+}
+
+message LanguageModelChoiceDelta {
+    uint32 index = 1;
+    LanguageModelResponseMessage delta = 2;
+    optional string finish_reason = 3;
+}
+
+message CountTokensWithLanguageModel {
+    string model = 1;
+    repeated LanguageModelRequestMessage messages = 2;
+}
+
+message CountTokensResponse {
+    uint32 token_count = 1;
+}
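
On the wire, a `CompleteWithLanguageModel` request is answered by a sequence of `LanguageModelResponse` envelopes and terminated by `EndStream`, pairing with the new `Peer::request_stream` in crates/rpc/src/peer.rs below. A hedged client-side sketch (the `RequestMessage` impl tying the two message types together lives elsewhere in the rpc crate and is assumed here):

```rust
use futures::StreamExt;
use rpc::{proto, ConnectionId, Peer};

// Illustrative: issue a streaming completion over an established connection.
async fn example(peer: &Peer, connection_id: ConnectionId) -> anyhow::Result<()> {
    let request = proto::CompleteWithLanguageModel {
        model: "gpt-4-turbo-preview".into(),
        messages: vec![proto::LanguageModelRequestMessage {
            role: proto::LanguageModelRole::LanguageModelUser as i32,
            content: "Say hello.".into(),
        }],
        stop: Vec::new(),
        temperature: 1.0,
    };

    // Each item is one `LanguageModelResponse`; the stream ends when the
    // server sends `EndStream`.
    let mut responses = peer.request_stream(connection_id, request).await?;
    while let Some(response) = responses.next().await {
        for choice in response?.choices {
            if let Some(content) = choice.delta.and_then(|delta| delta.content) {
                print!("{content}");
            }
        }
    }
    Ok(())
}
```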

crates/rpc/src/error.rs 🔗

@@ -80,7 +80,7 @@ pub trait ErrorExt {
     fn error_tag(&self, k: &str) -> Option<&str>;
     /// to_proto() converts the error into a proto::Error
     fn to_proto(&self) -> proto::Error;
-    ///
+    /// Clones the error and turns it into an [anyhow::Error].
     fn cloned(&self) -> anyhow::Error;
 }
 

crates/rpc/src/peer.rs 🔗

@@ -9,19 +9,21 @@ use collections::HashMap;
 use futures::{
     channel::{mpsc, oneshot},
     stream::BoxStream,
-    FutureExt, SinkExt, StreamExt, TryFutureExt,
+    FutureExt, SinkExt, Stream, StreamExt, TryFutureExt,
 };
 use parking_lot::{Mutex, RwLock};
 use serde::{ser::SerializeStruct, Serialize};
-use std::{fmt, sync::atomic::Ordering::SeqCst, time::Instant};
 use std::{
+    fmt, future,
     future::Future,
     marker::PhantomData,
+    sync::atomic::Ordering::SeqCst,
     sync::{
         atomic::{self, AtomicU32},
         Arc,
     },
     time::Duration,
+    time::Instant,
 };
 use tracing::instrument;
 
@@ -118,6 +120,15 @@ pub struct ConnectionState {
             >,
         >,
     >,
+    #[allow(clippy::type_complexity)]
+    #[serde(skip)]
+    stream_response_channels: Arc<
+        Mutex<
+            Option<
+                HashMap<u32, mpsc::UnboundedSender<(Result<proto::Envelope>, oneshot::Sender<()>)>>,
+            >,
+        >,
+    >,
 }
 
 const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
@@ -171,17 +182,28 @@ impl Peer {
             outgoing_tx,
             next_message_id: Default::default(),
             response_channels: Arc::new(Mutex::new(Some(Default::default()))),
+            stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))),
         };
         let mut writer = MessageStream::new(connection.tx);
         let mut reader = MessageStream::new(connection.rx);
 
         let this = self.clone();
         let response_channels = connection_state.response_channels.clone();
+        let stream_response_channels = connection_state.stream_response_channels.clone();
+
         let handle_io = async move {
             tracing::trace!(%connection_id, "handle io future: start");
 
             let _end_connection = util::defer(|| {
                 response_channels.lock().take();
+                if let Some(channels) = stream_response_channels.lock().take() {
+                    for channel in channels.values() {
+                        let _ = channel.unbounded_send((
+                            Err(anyhow!("connection closed")),
+                            oneshot::channel().0,
+                        ));
+                    }
+                }
                 this.connections.write().remove(&connection_id);
                 tracing::trace!(%connection_id, "handle io future: end");
             });
@@ -273,12 +295,14 @@ impl Peer {
         };
 
         let response_channels = connection_state.response_channels.clone();
+        let stream_response_channels = connection_state.stream_response_channels.clone();
         self.connections
             .write()
             .insert(connection_id, connection_state);
 
         let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
             let response_channels = response_channels.clone();
+            let stream_response_channels = stream_response_channels.clone();
             async move {
                 let message_id = incoming.id;
                 tracing::trace!(?incoming, "incoming message future: start");
@@ -293,8 +317,15 @@ impl Peer {
                         responding_to,
                         "incoming response: received"
                     );
-                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
-                    if let Some(tx) = channel {
+                    let response_channel =
+                        response_channels.lock().as_mut()?.remove(&responding_to);
+                    let stream_response_channel = stream_response_channels
+                        .lock()
+                        .as_ref()?
+                        .get(&responding_to)
+                        .cloned();
+
+                    if let Some(tx) = response_channel {
                         let requester_resumed = oneshot::channel();
                         if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
                             tracing::trace!(
@@ -319,6 +350,31 @@ impl Peer {
                             responding_to,
                             "incoming response: requester resumed"
                         );
+                    } else if let Some(tx) = stream_response_channel {
+                        let requester_resumed = oneshot::channel();
+                        if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) {
+                            tracing::debug!(
+                                %connection_id,
+                                message_id,
+                                responding_to = responding_to,
+                                ?error,
+                                "incoming stream response: request future dropped",
+                            );
+                        }
+
+                        tracing::debug!(
+                            %connection_id,
+                            message_id,
+                            responding_to,
+                            "incoming stream response: waiting to resume requester"
+                        );
+                        let _ = requester_resumed.1.await;
+                        tracing::debug!(
+                            %connection_id,
+                            message_id,
+                            responding_to,
+                            "incoming stream response: requester resumed"
+                        );
                     } else {
                         let message_type =
                             proto::build_typed_envelope(connection_id, received_at, incoming)
@@ -451,6 +507,66 @@ impl Peer {
         }
     }
 
+    pub fn request_stream<T: RequestMessage>(
+        &self,
+        receiver_id: ConnectionId,
+        request: T,
+    ) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> {
+        let (tx, rx) = mpsc::unbounded();
+        let send = self.connection_state(receiver_id).and_then(|connection| {
+            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
+            let stream_response_channels = connection.stream_response_channels.clone();
+            stream_response_channels
+                .lock()
+                .as_mut()
+                .ok_or_else(|| anyhow!("connection was closed"))?
+                .insert(message_id, tx);
+            connection
+                .outgoing_tx
+                .unbounded_send(proto::Message::Envelope(
+                    request.into_envelope(message_id, None, None),
+                ))
+                .map_err(|_| anyhow!("connection was closed"))?;
+            Ok((message_id, stream_response_channels))
+        });
+
+        async move {
+            let (message_id, stream_response_channels) = send?;
+            let stream_response_channels = Arc::downgrade(&stream_response_channels);
+
+            Ok(rx.filter_map(move |(response, _barrier)| {
+                let stream_response_channels = stream_response_channels.clone();
+                future::ready(match response {
+                    Ok(response) => {
+                        if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
+                            Some(Err(anyhow!(
+                                "RPC request {} failed - {}",
+                                T::NAME,
+                                error.message
+                            )))
+                        } else if let Some(proto::envelope::Payload::EndStream(_)) =
+                            &response.payload
+                        {
+                            // Remove the transmitting end of the response channel to end the stream.
+                            if let Some(channels) = stream_response_channels.upgrade() {
+                                if let Some(channels) = channels.lock().as_mut() {
+                                    channels.remove(&message_id);
+                                }
+                            }
+                            None
+                        } else {
+                            Some(
+                                T::Response::from_envelope(response)
+                                    .ok_or_else(|| anyhow!("received response of the wrong type")),
+                            )
+                        }
+                    }
+                    Err(error) => Some(Err(error)),
+                })
+            }))
+        }
+    }
+
     pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
         let connection = self.connection_state(receiver_id)?;
         let message_id = connection
@@ -503,6 +619,24 @@ impl Peer {
         Ok(())
     }
 
+    pub fn end_stream<T: RequestMessage>(&self, receipt: Receipt<T>) -> Result<()> {
+        let connection = self.connection_state(receipt.sender_id)?;
+        let message_id = connection
+            .next_message_id
+            .fetch_add(1, atomic::Ordering::SeqCst);
+
+        let message = proto::EndStream {};
+
+        connection
+            .outgoing_tx
+            .unbounded_send(proto::Message::Envelope(message.into_envelope(
+                message_id,
+                Some(receipt.message_id),
+                None,
+            )))?;
+        Ok(())
+    }
+
     pub fn respond_with_error<T: RequestMessage>(
         &self,
         receipt: Receipt<T>,

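The `request_stream` addition above is the client half of the new streaming RPC: awaiting the returned future sends the request and registers an unbounded channel keyed by `message_id`, and the resulting stream yields one decoded response per incoming envelope, surfaces `Error` payloads as `Err` items, and terminates once an `EndStream` envelope removes the channel. A minimal sketch of a caller (illustrative, not from this diff), assuming a `peer: Arc<Peer>` and an established `connection_id`; the `CompleteWithLanguageModel`/`LanguageModelResponse` pairing is the one registered in `crates/rpc/src/proto.rs` below:

```rust
use std::sync::Arc;

use anyhow::Result;
use futures::StreamExt as _;

// Illustrative only: drain a streamed language-model completion.
async fn stream_completion(
    peer: Arc<Peer>,
    connection_id: ConnectionId,
    request: proto::CompleteWithLanguageModel,
) -> Result<()> {
    // Awaiting the outer future performs the send and registers the
    // stream response channel; the inner stream then yields chunks
    // until the server's `EndStream` arrives.
    let mut chunks = peer.request_stream(connection_id, request).await?;
    while let Some(chunk) = chunks.next().await {
        let chunk = chunk?; // an `Error` payload from the peer surfaces here
        // ...accumulate the delta or forward it to the UI...
        let _ = chunk;
    }
    Ok(())
}
```

Note that if the connection drops mid-stream, the `_end_connection` defer above pushes a final `Err("connection closed")` into every live stream channel, so the loop ends with an error rather than hanging.
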
crates/rpc/src/proto.rs 🔗

@@ -149,7 +149,10 @@ messages!(
     (CallCanceled, Foreground),
     (CancelCall, Foreground),
     (ChannelMessageSent, Foreground),
+    (CompleteWithLanguageModel, Background),
     (CopyProjectEntry, Foreground),
+    (CountTokensWithLanguageModel, Background),
+    (CountTokensResponse, Background),
     (CreateBufferForPeer, Foreground),
     (CreateChannel, Foreground),
     (CreateChannelResponse, Foreground),
@@ -160,6 +163,7 @@ messages!(
     (DeleteChannel, Foreground),
     (DeleteNotification, Foreground),
     (DeleteProjectEntry, Foreground),
+    (EndStream, Foreground),
     (Error, Foreground),
     (ExpandProjectEntry, Foreground),
     (ExpandProjectEntryResponse, Foreground),
@@ -211,6 +215,7 @@ messages!(
     (JoinProjectResponse, Foreground),
     (JoinRoom, Foreground),
     (JoinRoomResponse, Foreground),
+    (LanguageModelResponse, Background),
     (LeaveChannelBuffer, Background),
     (LeaveChannelChat, Foreground),
     (LeaveProject, Foreground),
@@ -300,6 +305,8 @@ request_messages!(
     (Call, Ack),
     (CancelCall, Ack),
     (CopyProjectEntry, ProjectEntryResponse),
+    (CompleteWithLanguageModel, LanguageModelResponse),
+    (CountTokensWithLanguageModel, CountTokensResponse),
     (CreateChannel, CreateChannelResponse),
     (CreateProjectEntry, ProjectEntryResponse),
     (CreateRoom, CreateRoomResponse),

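On the responding side, each ordinary `respond` to the request's receipt becomes one element of the requester's stream, and the new `EndStream` message (registered above, sent via `Peer::end_stream`) is what closes it. A sketch under the assumption that `Peer::respond` takes a `Receipt<T>` plus `T::Response` as elsewhere in this crate, and that `Receipt` is `Copy` so it can be reused per chunk:

```rust
use anyhow::Result;

// Illustrative only: emit `chunks` as a streamed response, then
// terminate the requester's stream.
fn respond_with_stream(
    peer: &Peer,
    receipt: Receipt<proto::CompleteWithLanguageModel>,
    chunks: Vec<proto::LanguageModelResponse>,
) -> Result<()> {
    for chunk in chunks {
        // Each respond() is delivered to the requester's
        // stream_response_channel and yielded as one stream item.
        peer.respond(receipt, chunk)?;
    }
    // EndStream makes request_stream's filter_map drop the channel,
    // ending the stream on the requester's side.
    peer.end_stream(receipt)
}
```
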
crates/search/Cargo.toml 🔗

@@ -22,7 +22,6 @@ gpui.workspace = true
 language.workspace = true
 menu.workspace = true
 project.workspace = true
-semantic_index.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true

crates/search/src/buffer_search.rs 🔗

@@ -705,11 +705,6 @@ impl BufferSearchBar {
         option.as_button(is_active, action)
     }
     pub fn activate_search_mode(&mut self, mode: SearchMode, cx: &mut ViewContext<Self>) {
-        assert_ne!(
-            mode,
-            SearchMode::Semantic,
-            "Semantic search is not supported in buffer search"
-        );
         if mode == self.current_mode {
             return;
         }
@@ -1022,7 +1017,7 @@ impl BufferSearchBar {
         }
     }
     fn cycle_mode(&mut self, _: &CycleMode, cx: &mut ViewContext<Self>) {
-        self.activate_search_mode(next_mode(&self.current_mode, false), cx);
+        self.activate_search_mode(next_mode(&self.current_mode), cx);
     }
     fn toggle_replace(&mut self, _: &ToggleReplace, cx: &mut ViewContext<Self>) {
         if let Some(_) = &self.active_searchable_item {

crates/search/src/mode.rs 🔗

@@ -1,13 +1,12 @@
 use gpui::{Action, SharedString};
 
-use crate::{ActivateRegexMode, ActivateSemanticMode, ActivateTextMode};
+use crate::{ActivateRegexMode, ActivateTextMode};
 
 // TODO: Update the default search mode to get from config
 #[derive(Copy, Clone, Debug, Default, PartialEq)]
 pub enum SearchMode {
     #[default]
     Text,
-    Semantic,
     Regex,
 }
 
@@ -15,7 +14,6 @@ impl SearchMode {
     pub(crate) fn label(&self) -> &'static str {
         match self {
             SearchMode::Text => "Text",
-            SearchMode::Semantic => "Semantic",
             SearchMode::Regex => "Regex",
         }
     }
@@ -25,22 +23,14 @@ impl SearchMode {
     pub(crate) fn action(&self) -> Box<dyn Action> {
         match self {
             SearchMode::Text => ActivateTextMode.boxed_clone(),
-            SearchMode::Semantic => ActivateSemanticMode.boxed_clone(),
             SearchMode::Regex => ActivateRegexMode.boxed_clone(),
         }
     }
 }
 
-pub(crate) fn next_mode(mode: &SearchMode, semantic_enabled: bool) -> SearchMode {
+pub(crate) fn next_mode(mode: &SearchMode) -> SearchMode {
     match mode {
         SearchMode::Text => SearchMode::Regex,
-        SearchMode::Regex => {
-            if semantic_enabled {
-                SearchMode::Semantic
-            } else {
-                SearchMode::Text
-            }
-        }
-        SearchMode::Semantic => SearchMode::Text,
+        SearchMode::Regex => SearchMode::Text,
     }
 }

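With the `Semantic` variant gone, `next_mode` collapses to a two-state toggle and no longer needs the `semantic_enabled` flag threaded through callers. A quick illustrative check (hypothetical test, placed in the same crate since `next_mode` is `pub(crate)`):

```rust
#[cfg(test)]
mod next_mode_tests {
    use super::*;

    #[test]
    fn cycles_between_text_and_regex() {
        // Text -> Regex -> Text, with no semantic branch left to take.
        assert_eq!(next_mode(&SearchMode::Text), SearchMode::Regex);
        assert_eq!(next_mode(&SearchMode::Regex), SearchMode::Text);
    }
}
```
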
crates/search/src/project_search.rs 🔗

@@ -1,33 +1,26 @@
 use crate::{
-    history::SearchHistory, mode::SearchMode, ActivateRegexMode, ActivateSemanticMode,
-    ActivateTextMode, CycleMode, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext,
-    SearchOptions, SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleIncludeIgnored,
-    ToggleReplace, ToggleWholeWord,
+    history::SearchHistory, mode::SearchMode, ActivateRegexMode, ActivateTextMode, CycleMode,
+    NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, SearchOptions,
+    SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleIncludeIgnored, ToggleReplace,
+    ToggleWholeWord,
 };
-use anyhow::{Context as _, Result};
-use collections::HashMap;
+use anyhow::Context as _;
+use collections::{HashMap, HashSet};
 use editor::{
     actions::SelectAll,
     items::active_match_index,
     scroll::{Autoscroll, Axis},
-    Anchor, Editor, EditorEvent, MultiBuffer, MAX_TAB_TITLE_LEN,
+    Anchor, Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer, MAX_TAB_TITLE_LEN,
 };
-use editor::{EditorElement, EditorStyle};
 use gpui::{
     actions, div, Action, AnyElement, AnyView, AppContext, Context as _, Element, EntityId,
     EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight, Global, Hsla,
-    InteractiveElement, IntoElement, KeyContext, Model, ModelContext, ParentElement, Point,
-    PromptLevel, Render, SharedString, Styled, Subscription, Task, TextStyle, View, ViewContext,
-    VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext,
+    InteractiveElement, IntoElement, KeyContext, Model, ModelContext, ParentElement, Point, Render,
+    SharedString, Styled, Subscription, Task, TextStyle, View, ViewContext, VisualContext,
+    WeakModel, WeakView, WhiteSpace, WindowContext,
 };
 use menu::Confirm;
-use project::{
-    search::{SearchInputs, SearchQuery},
-    Project,
-};
-use semantic_index::{SemanticIndex, SemanticIndexStatus};
-
-use collections::HashSet;
+use project::{search::SearchQuery, Project};
 use settings::Settings;
 use smol::stream::StreamExt;
 use std::{
@@ -35,22 +28,20 @@ use std::{
     mem,
     ops::{Not, Range},
     path::{Path, PathBuf},
-    time::{Duration, Instant},
 };
 use theme::ThemeSettings;
-use workspace::{DeploySearch, NewSearch};
-
 use ui::{
     h_flex, prelude::*, v_flex, Icon, IconButton, IconName, Label, LabelCommon, LabelSize,
     Selectable, ToggleButton, Tooltip,
 };
-use util::{paths::PathMatcher, ResultExt as _};
+use util::paths::PathMatcher;
 use workspace::{
     item::{BreadcrumbText, Item, ItemEvent, ItemHandle},
     searchable::{Direction, SearchableItem, SearchableItemHandle},
     ItemNavHistory, Pane, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
     WorkspaceId,
 };
+use workspace::{DeploySearch, NewSearch};
 
 const MIN_INPUT_WIDTH_REMS: f32 = 15.;
 const MAX_INPUT_WIDTH_REMS: f32 = 30.;
@@ -86,12 +77,6 @@ pub fn init(cx: &mut AppContext) {
         register_workspace_action(workspace, move |search_bar, _: &ActivateTextMode, cx| {
             search_bar.activate_search_mode(SearchMode::Text, cx)
         });
-        register_workspace_action(
-            workspace,
-            move |search_bar, _: &ActivateSemanticMode, cx| {
-                search_bar.activate_search_mode(SearchMode::Semantic, cx)
-            },
-        );
         register_workspace_action(workspace, move |search_bar, action: &CycleMode, cx| {
             search_bar.cycle_mode(action, cx)
         });
@@ -159,8 +144,6 @@ pub struct ProjectSearchView {
     query_editor: View<Editor>,
     replacement_editor: View<Editor>,
     results_editor: View<Editor>,
-    semantic_state: Option<SemanticState>,
-    semantic_permissioned: Option<bool>,
     search_options: SearchOptions,
     panels_with_errors: HashSet<InputPanel>,
     active_match_index: Option<usize>,
@@ -174,12 +157,6 @@ pub struct ProjectSearchView {
     _subscriptions: Vec<Subscription>,
 }
 
-struct SemanticState {
-    index_status: SemanticIndexStatus,
-    maintain_rate_limit: Option<Task<()>>,
-    _subscription: Subscription,
-}
-
 #[derive(Debug, Clone)]
 struct ProjectSearchSettings {
     search_options: SearchOptions,
@@ -282,68 +259,6 @@ impl ProjectSearch {
         }));
         cx.notify();
     }
-
-    fn semantic_search(&mut self, inputs: &SearchInputs, cx: &mut ModelContext<Self>) {
-        let search = SemanticIndex::global(cx).map(|index| {
-            index.update(cx, |semantic_index, cx| {
-                semantic_index.search_project(
-                    self.project.clone(),
-                    inputs.as_str().to_owned(),
-                    10,
-                    inputs.files_to_include().to_vec(),
-                    inputs.files_to_exclude().to_vec(),
-                    cx,
-                )
-            })
-        });
-        self.search_id += 1;
-        self.match_ranges.clear();
-        self.search_history.add(inputs.as_str().to_string());
-        self.no_results = None;
-        self.pending_search = Some(cx.spawn(|this, mut cx| async move {
-            let results = search?.await.log_err()?;
-            let matches = results
-                .into_iter()
-                .map(|result| (result.buffer, vec![result.range.start..result.range.start]));
-
-            this.update(&mut cx, |this, cx| {
-                this.no_results = Some(true);
-                this.excerpts.update(cx, |excerpts, cx| {
-                    excerpts.clear(cx);
-                });
-            })
-            .ok()?;
-            for (buffer, ranges) in matches {
-                let mut match_ranges = this
-                    .update(&mut cx, |this, cx| {
-                        this.no_results = Some(false);
-                        this.excerpts.update(cx, |excerpts, cx| {
-                            excerpts.stream_excerpts_with_context_lines(buffer, ranges, 3, cx)
-                        })
-                    })
-                    .ok()?;
-                while let Some(match_range) = match_ranges.next().await {
-                    this.update(&mut cx, |this, cx| {
-                        this.match_ranges.push(match_range);
-                        while let Ok(Some(match_range)) = match_ranges.try_next() {
-                            this.match_ranges.push(match_range);
-                        }
-                        cx.notify();
-                    })
-                    .ok()?;
-                }
-            }
-
-            this.update(&mut cx, |this, cx| {
-                this.pending_search.take();
-                cx.notify();
-            })
-            .ok()?;
-
-            None
-        }));
-        cx.notify();
-    }
 }
 
 #[derive(Clone, Debug, PartialEq, Eq)]
@@ -358,8 +273,6 @@ impl EventEmitter<ViewEvent> for ProjectSearchView {}
 
 impl Render for ProjectSearchView {
     fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
-        const PLEASE_AUTHENTICATE: &str = "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables. If you authenticated using the Assistant Panel, please restart Zed to Authenticate.";
-
         if self.has_matches() {
             div()
                 .flex_1()
@@ -370,7 +283,7 @@ impl Render for ProjectSearchView {
             let model = self.model.read(cx);
             let has_no_results = model.no_results.unwrap_or(false);
             let is_search_underway = model.pending_search.is_some();
-            let mut major_text = if is_search_underway {
+            let major_text = if is_search_underway {
                 Label::new("Searching...")
             } else if has_no_results {
                 Label::new("No results")
@@ -378,43 +291,6 @@ impl Render for ProjectSearchView {
                 Label::new(format!("{} search all files", self.current_mode.label()))
             };
 
-            let mut show_minor_text = true;
-            let semantic_status = self.semantic_state.as_ref().and_then(|semantic| {
-                let status = semantic.index_status;
-                match status {
-                    SemanticIndexStatus::NotAuthenticated => {
-                        major_text = Label::new("Not Authenticated");
-                        show_minor_text = false;
-                        Some(PLEASE_AUTHENTICATE.to_string())
-                    }
-                    SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
-                    SemanticIndexStatus::Indexing {
-                        remaining_files,
-                        rate_limit_expiry,
-                    } => {
-                        if remaining_files == 0 {
-                            Some("Indexing...".to_string())
-                        } else {
-                            if let Some(rate_limit_expiry) = rate_limit_expiry {
-                                let remaining_seconds =
-                                    rate_limit_expiry.duration_since(Instant::now());
-                                if remaining_seconds > Duration::from_secs(0) {
-                                    Some(format!(
-                                        "Remaining files to index (rate limit resets in {}s): {}",
-                                        remaining_seconds.as_secs(),
-                                        remaining_files
-                                    ))
-                                } else {
-                                    Some(format!("Remaining files to index: {}", remaining_files))
-                                }
-                            } else {
-                                Some(format!("Remaining files to index: {}", remaining_files))
-                            }
-                        }
-                    }
-                    SemanticIndexStatus::NotIndexed => None,
-                }
-            });
             let major_text = div().justify_center().max_w_96().child(major_text);
 
             let minor_text: Option<SharedString> = if let Some(no_results) = model.no_results {
@@ -424,12 +300,7 @@ impl Render for ProjectSearchView {
                     None
                 }
             } else {
-                if let Some(mut semantic_status) = semantic_status {
-                    semantic_status.extend(self.landing_text_minor().chars());
-                    Some(semantic_status.into())
-                } else {
-                    Some(self.landing_text_minor())
-                }
+                Some(self.landing_text_minor())
             };
             let minor_text = minor_text.map(|text| {
                 div()
@@ -676,58 +547,6 @@ impl ProjectSearchView {
         });
     }
 
-    fn index_project(&mut self, cx: &mut ViewContext<Self>) {
-        if let Some(semantic_index) = SemanticIndex::global(cx) {
-            // Semantic search uses no options
-            self.search_options = SearchOptions::none();
-
-            let project = self.model.read(cx).project.clone();
-
-            semantic_index.update(cx, |semantic_index, cx| {
-                semantic_index
-                    .index_project(project.clone(), cx)
-                    .detach_and_log_err(cx);
-            });
-
-            self.semantic_state = Some(SemanticState {
-                index_status: semantic_index.read(cx).status(&project),
-                maintain_rate_limit: None,
-                _subscription: cx.observe(&semantic_index, Self::semantic_index_changed),
-            });
-            self.semantic_index_changed(semantic_index, cx);
-        }
-    }
-
-    fn semantic_index_changed(
-        &mut self,
-        semantic_index: Model<SemanticIndex>,
-        cx: &mut ViewContext<Self>,
-    ) {
-        let project = self.model.read(cx).project.clone();
-        if let Some(semantic_state) = self.semantic_state.as_mut() {
-            cx.notify();
-            semantic_state.index_status = semantic_index.read(cx).status(&project);
-            if let SemanticIndexStatus::Indexing {
-                rate_limit_expiry: Some(_),
-                ..
-            } = &semantic_state.index_status
-            {
-                if semantic_state.maintain_rate_limit.is_none() {
-                    semantic_state.maintain_rate_limit =
-                        Some(cx.spawn(|this, mut cx| async move {
-                            loop {
-                                cx.background_executor().timer(Duration::from_secs(1)).await;
-                                this.update(&mut cx, |_, cx| cx.notify()).log_err();
-                            }
-                        }));
-                    return;
-                }
-            } else {
-                semantic_state.maintain_rate_limit = None;
-            }
-        }
-    }
-
     fn clear_search(&mut self, cx: &mut ViewContext<Self>) {
         self.model.update(cx, |model, cx| {
             model.pending_search = None;
@@ -750,63 +569,7 @@ impl ProjectSearchView {
         self.clear_search(cx);
         self.current_mode = mode;
         self.active_match_index = None;
-
-        match mode {
-            SearchMode::Semantic => {
-                let has_permission = self.semantic_permissioned(cx);
-                self.active_match_index = None;
-                cx.spawn(|this, mut cx| async move {
-                    let has_permission = has_permission.await?;
-
-                    if !has_permission {
-                        let answer = this.update(&mut cx, |this, cx| {
-                            let project = this.model.read(cx).project.clone();
-                            let project_name = project
-                                .read(cx)
-                                .worktree_root_names(cx)
-                                .collect::<Vec<&str>>()
-                                .join("/");
-                            let is_plural =
-                                project_name.chars().filter(|letter| *letter == '/').count() > 0;
-                            let prompt_text = format!("Would you like to index the '{}' project{} for semantic search? This requires sending code to the OpenAI API", project_name,
-                                if is_plural {
-                                    "s"
-                                } else {""});
-                            cx.prompt(
-                                PromptLevel::Info,
-                                prompt_text.as_str(),
-                                None,
-                                &["Continue", "Cancel"],
-                            )
-                        })?;
-
-                        if answer.await? == 0 {
-                            this.update(&mut cx, |this, _| {
-                                this.semantic_permissioned = Some(true);
-                            })?;
-                        } else {
-                            this.update(&mut cx, |this, cx| {
-                                this.semantic_permissioned = Some(false);
-                                debug_assert_ne!(previous_mode, SearchMode::Semantic, "Tried to re-enable semantic search mode after user modal was rejected");
-                                this.activate_search_mode(previous_mode, cx);
-                            })?;
-                            return anyhow::Ok(());
-                        }
-                    }
-
-                    this.update(&mut cx, |this, cx| {
-                        this.index_project(cx);
-                    })?;
-
-                    anyhow::Ok(())
-                }).detach_and_log_err(cx);
-            }
-            SearchMode::Regex | SearchMode::Text => {
-                self.semantic_state = None;
-                self.active_match_index = None;
-                self.search(cx);
-            }
-        }
+        self.search(cx);
 
         cx.update_global(|state: &mut ActiveSettings, cx| {
             state.0.insert(
@@ -973,8 +736,6 @@ impl ProjectSearchView {
             model,
             query_editor,
             results_editor,
-            semantic_state: None,
-            semantic_permissioned: None,
             search_options: options,
             panels_with_errors: HashSet::default(),
             active_match_index: None,
@@ -990,19 +751,6 @@ impl ProjectSearchView {
         this
     }
 
-    fn semantic_permissioned(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<bool>> {
-        if let Some(value) = self.semantic_permissioned {
-            return Task::ready(Ok(value));
-        }
-
-        SemanticIndex::global(cx)
-            .map(|semantic| {
-                let project = self.model.read(cx).project.clone();
-                semantic.update(cx, |this, cx| this.project_previously_indexed(&project, cx))
-            })
-            .unwrap_or(Task::ready(Ok(false)))
-    }
-
     pub fn new_search_in_directory(
         workspace: &mut Workspace,
         dir_path: &Path,
@@ -1126,22 +874,8 @@ impl ProjectSearchView {
     }
 
     fn search(&mut self, cx: &mut ViewContext<Self>) {
-        let mode = self.current_mode;
-        match mode {
-            SearchMode::Semantic => {
-                if self.semantic_state.is_some() {
-                    if let Some(query) = self.build_search_query(cx) {
-                        self.model
-                            .update(cx, |model, cx| model.semantic_search(query.as_inner(), cx));
-                    }
-                }
-            }
-
-            _ => {
-                if let Some(query) = self.build_search_query(cx) {
-                    self.model.update(cx, |model, cx| model.search(query, cx));
-                }
-            }
+        if let Some(query) = self.build_search_query(cx) {
+            self.model.update(cx, |model, cx| model.search(query, cx));
         }
     }
 
@@ -1356,7 +1090,6 @@ impl ProjectSearchView {
     fn landing_text_minor(&self) -> SharedString {
         match self.current_mode {
             SearchMode::Text | SearchMode::Regex => "Include/exclude specific paths with the filter option. Matching exact word and/or casing is available too.".into(),
-            SearchMode::Semantic => "\nSimply explain the code you are looking to find. ex. 'prompt user for permissions to index their project'".into()
         }
     }
     fn border_color_for(&self, panel: InputPanel, cx: &WindowContext) -> Hsla {
@@ -1387,8 +1120,7 @@ impl ProjectSearchBar {
     fn cycle_mode(&self, _: &CycleMode, cx: &mut ViewContext<Self>) {
         if let Some(view) = self.active_project_search.as_ref() {
             view.update(cx, |this, cx| {
-                let new_mode =
-                    crate::mode::next_mode(&this.current_mode, SemanticIndex::enabled(cx));
+                let new_mode = crate::mode::next_mode(&this.current_mode);
                 this.activate_search_mode(new_mode, cx);
                 let editor_handle = this.query_editor.focus_handle(cx);
                 cx.focus(&editor_handle);
@@ -1681,7 +1413,6 @@ impl Render for ProjectSearchBar {
             });
         }
         let search = search.read(cx);
-        let semantic_is_available = SemanticIndex::enabled(cx);
 
         let query_column = h_flex()
             .flex_1()
@@ -1711,12 +1442,8 @@ impl Render for ProjectSearchBar {
                                     .unwrap_or_default(),
                             ),
                     )
-                    .when(search.current_mode != SearchMode::Semantic, |this| {
-                        this.child(
-                            IconButton::new(
-                                "project-search-case-sensitive",
-                                IconName::CaseSensitive,
-                            )
+                    .child(
+                        IconButton::new("project-search-case-sensitive", IconName::CaseSensitive)
                             .tooltip(|cx| {
                                 Tooltip::for_action(
                                     "Toggle case sensitive",
@@ -1728,18 +1455,17 @@ impl Render for ProjectSearchBar {
                             .on_click(cx.listener(|this, _, cx| {
                                 this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx);
                             })),
-                        )
-                        .child(
-                            IconButton::new("project-search-whole-word", IconName::WholeWord)
-                                .tooltip(|cx| {
-                                    Tooltip::for_action("Toggle whole word", &ToggleWholeWord, cx)
-                                })
-                                .selected(self.is_option_enabled(SearchOptions::WHOLE_WORD, cx))
-                                .on_click(cx.listener(|this, _, cx| {
-                                    this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
-                                })),
-                        )
-                    }),
+                    )
+                    .child(
+                        IconButton::new("project-search-whole-word", IconName::WholeWord)
+                            .tooltip(|cx| {
+                                Tooltip::for_action("Toggle whole word", &ToggleWholeWord, cx)
+                            })
+                            .selected(self.is_option_enabled(SearchOptions::WHOLE_WORD, cx))
+                            .on_click(cx.listener(|this, _, cx| {
+                                this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
+                            })),
+                    ),
             );
 
         let mode_column = v_flex().items_start().justify_start().child(
@@ -1775,33 +1501,8 @@ impl Render for ProjectSearchBar {
                                         cx,
                                     )
                                 })
-                                .map(|this| {
-                                    if semantic_is_available {
-                                        this.middle()
-                                    } else {
-                                        this.last()
-                                    }
-                                }),
-                        )
-                        .when(semantic_is_available, |this| {
-                            this.child(
-                                ToggleButton::new("project-search-semantic-button", "Semantic")
-                                    .style(ButtonStyle::Filled)
-                                    .size(ButtonSize::Large)
-                                    .selected(search.current_mode == SearchMode::Semantic)
-                                    .on_click(cx.listener(|this, _, cx| {
-                                        this.activate_search_mode(SearchMode::Semantic, cx)
-                                    }))
-                                    .tooltip(|cx| {
-                                        Tooltip::for_action(
-                                            "Toggle semantic search",
-                                            &ActivateSemanticMode,
-                                            cx,
-                                        )
-                                    })
-                                    .last(),
-                            )
-                        }),
+                                .last(),
+                        ),
                 )
                 .child(
                     IconButton::new("project-search-toggle-replace", IconName::Replace)
@@ -1929,21 +1630,16 @@ impl Render for ProjectSearchBar {
                         .border_color(search.border_color_for(InputPanel::Include, cx))
                         .rounded_lg()
                         .child(self.render_text_input(&search.included_files_editor, cx))
-                        .when(search.current_mode != SearchMode::Semantic, |this| {
-                            this.child(
-                                SearchOptions::INCLUDE_IGNORED.as_button(
-                                    search
-                                        .search_options
-                                        .contains(SearchOptions::INCLUDE_IGNORED),
-                                    cx.listener(|this, _, cx| {
-                                        this.toggle_search_option(
-                                            SearchOptions::INCLUDE_IGNORED,
-                                            cx,
-                                        );
-                                    }),
-                                ),
-                            )
-                        }),
+                        .child(
+                            SearchOptions::INCLUDE_IGNORED.as_button(
+                                search
+                                    .search_options
+                                    .contains(SearchOptions::INCLUDE_IGNORED),
+                                cx.listener(|this, _, cx| {
+                                    this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx);
+                                }),
+                            ),
+                        ),
                 )
                 .child(
                     h_flex()
@@ -1972,9 +1668,6 @@ impl Render for ProjectSearchBar {
             .on_action(cx.listener(|this, _: &ActivateRegexMode, cx| {
                 this.activate_search_mode(SearchMode::Regex, cx)
             }))
-            .on_action(cx.listener(|this, _: &ActivateSemanticMode, cx| {
-                this.activate_search_mode(SearchMode::Semantic, cx)
-            }))
             .capture_action(cx.listener(|this, action, cx| {
                 this.tab(action, cx);
                 cx.stop_propagation();
@@ -1987,35 +1680,33 @@ impl Render for ProjectSearchBar {
             .on_action(cx.listener(|this, action, cx| {
                 this.cycle_mode(action, cx);
             }))
-            .when(search.current_mode != SearchMode::Semantic, |this| {
-                this.on_action(cx.listener(|this, action, cx| {
-                    this.toggle_replace(action, cx);
-                }))
-                .on_action(cx.listener(|this, _: &ToggleWholeWord, cx| {
-                    this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
-                }))
-                .on_action(cx.listener(|this, _: &ToggleCaseSensitive, cx| {
-                    this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx);
-                }))
-                .on_action(cx.listener(|this, action, cx| {
-                    if let Some(search) = this.active_project_search.as_ref() {
-                        search.update(cx, |this, cx| {
-                            this.replace_next(action, cx);
-                        })
-                    }
-                }))
-                .on_action(cx.listener(|this, action, cx| {
-                    if let Some(search) = this.active_project_search.as_ref() {
-                        search.update(cx, |this, cx| {
-                            this.replace_all(action, cx);
-                        })
-                    }
+            .on_action(cx.listener(|this, action, cx| {
+                this.toggle_replace(action, cx);
+            }))
+            .on_action(cx.listener(|this, _: &ToggleWholeWord, cx| {
+                this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
+            }))
+            .on_action(cx.listener(|this, _: &ToggleCaseSensitive, cx| {
+                this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx);
+            }))
+            .on_action(cx.listener(|this, action, cx| {
+                if let Some(search) = this.active_project_search.as_ref() {
+                    search.update(cx, |this, cx| {
+                        this.replace_next(action, cx);
+                    })
+                }
+            }))
+            .on_action(cx.listener(|this, action, cx| {
+                if let Some(search) = this.active_project_search.as_ref() {
+                    search.update(cx, |this, cx| {
+                        this.replace_all(action, cx);
+                    })
+                }
+            }))
+            .when(search.filters_enabled, |this| {
+                this.on_action(cx.listener(|this, _: &ToggleIncludeIgnored, cx| {
+                    this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx);
                 }))
-                .when(search.filters_enabled, |this| {
-                    this.on_action(cx.listener(|this, _: &ToggleIncludeIgnored, cx| {
-                        this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx);
-                    }))
-                })
             })
             .on_action(cx.listener(Self::select_next_match))
             .on_action(cx.listener(Self::select_prev_match))
@@ -2039,12 +1730,6 @@ impl ToolbarItemView for ProjectSearchBar {
         self.subscription = None;
         self.active_project_search = None;
         if let Some(search) = active_pane_item.and_then(|i| i.downcast::<ProjectSearchView>()) {
-            search.update(cx, |search, cx| {
-                if search.current_mode == SearchMode::Semantic {
-                    search.index_project(cx);
-                }
-            });
-
             self.subscription = Some(cx.observe(&search, |_, _, cx| cx.notify()));
             self.active_project_search = Some(search);
             ToolbarItemLocation::PrimaryLeft {}
@@ -2123,9 +1808,8 @@ pub mod tests {
     use editor::DisplayPoint;
     use gpui::{Action, TestAppContext, WindowHandle};
     use project::FakeFs;
-    use semantic_index::semantic_index_settings::SemanticIndexSettings;
     use serde_json::json;
-    use settings::{Settings, SettingsStore};
+    use settings::SettingsStore;
     use std::sync::Arc;
     use workspace::DeploySearch;
 
@@ -3446,8 +3130,6 @@ pub mod tests {
             let settings = SettingsStore::test(cx);
             cx.set_global(settings);
 
-            SemanticIndexSettings::register(cx);
-
             theme::init(theme::LoadThemes::JustBase, cx);
 
             language::init(cx);

crates/search/src/search.rs 🔗

@@ -33,7 +33,6 @@ actions!(
         NextHistoryQuery,
         PreviousHistoryQuery,
         ActivateTextMode,
-        ActivateSemanticMode,
         ActivateRegexMode,
         ReplaceAll,
         ReplaceNext,

crates/semantic_index/Cargo.toml 🔗

@@ -1,66 +0,0 @@
-[package]
-name = "semantic_index"
-version = "0.1.0"
-edition = "2021"
-publish = false
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/semantic_index.rs"
-doctest = false
-
-[dependencies]
-ai.workspace = true
-anyhow.workspace = true
-collections.workspace = true
-futures.workspace = true
-gpui.workspace = true
-language.workspace = true
-lazy_static.workspace = true
-log.workspace = true
-ndarray = { version = "0.15.0" }
-ordered-float.workspace = true
-parking_lot.workspace = true
-postage.workspace = true
-project.workspace = true
-rand.workspace = true
-release_channel.workspace = true
-rpc.workspace = true
-rusqlite.workspace = true
-schemars.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-settings.workspace = true
-sha1 = "0.10.5"
-smol.workspace = true
-tree-sitter.workspace = true
-util.workspace = true
-workspace.workspace = true
-
-[dev-dependencies]
-ai = { workspace = true, features = ["test-support"] }
-collections = { workspace = true, features = ["test-support"] }
-ctor.workspace = true
-env_logger.workspace = true
-gpui = { workspace = true, features = ["test-support"] }
-language = { workspace = true, features = ["test-support"] }
-pretty_assertions.workspace = true
-project = { workspace = true, features = ["test-support"] }
-rand.workspace = true
-rpc = { workspace = true, features = ["test-support"] }
-settings = { workspace = true, features = ["test-support"]}
-tempfile.workspace = true
-tree-sitter-cpp.workspace = true
-tree-sitter-elixir.workspace = true
-tree-sitter-json.workspace = true
-tree-sitter-lua.workspace = true
-tree-sitter-php.workspace = true
-tree-sitter-ruby.workspace = true
-tree-sitter-rust.workspace = true
-tree-sitter-toml.workspace = true
-tree-sitter-typescript.workspace = true
-unindent.workspace = true
-workspace = { workspace = true, features = ["test-support"] }

crates/semantic_index/README.md 🔗

@@ -1,20 +0,0 @@
-
-# Semantic Index
-
-## Evaluation
-
-### Metrics
-
-nDCG@k:
-- "The value of NDCG is determined by comparing the relevance of the items returned by the search engine to the relevance of the item that a hypothetical "ideal" search engine would return.
-- "The relevance of result is represented by a score (also known as a 'grade') that is assigned to the search query. The scores of these results are then discounted based on their position in the search results -- did they get recommended first or last?"
-
-MRR@k:
-- "Mean reciprocal rank quantifies the rank of the first relevant item found in the recommendation list."
-
-MAP@k:
-- "Mean average precision averages the precision@k metric at each relevant item position in the recommendation list.
-
-Resources:
-- [Evaluating recommendation metrics](https://www.shaped.ai/blog/evaluating-recommendation-systems-map-mmr-ndcg)
-- [Math Walkthrough](https://towardsdatascience.com/demystifying-ndcg-bee3be58cfe0)

crates/semantic_index/eval/gpt-engineer.json 🔗

@@ -1,114 +0,0 @@
-{
-  "repo": "https://github.com/AntonOsika/gpt-engineer.git",
-  "commit": "7735a6445bae3611c62f521e6464c67c957f87c2",
-  "assertions": [
-    {
-      "query": "How do I contribute to this project?",
-      "matches": [
-        ".github/CONTRIBUTING.md:1",
-        "ROADMAP.md:48"
-      ]
-    },
-    {
-      "query": "What version of the openai package is active?",
-      "matches": [
-        "pyproject.toml:14"
-      ]
-    },
-    {
-      "query": "Ask user for clarification",
-      "matches": [
-        "gpt_engineer/steps.py:69"
-      ]
-    },
-    {
-      "query": "generate tests for python code",
-      "matches": [
-        "gpt_engineer/steps.py:153"
-      ]
-    },
-    {
-      "query": "get item from database based on key",
-      "matches": [
-        "gpt_engineer/db.py:42",
-        "gpt_engineer/db.py:68"
-      ]
-    },
-    {
-      "query": "prompt user to select files",
-      "matches": [
-        "gpt_engineer/file_selector.py:171",
-        "gpt_engineer/file_selector.py:306",
-        "gpt_engineer/file_selector.py:289",
-        "gpt_engineer/file_selector.py:234"
-      ]
-    },
-    {
-      "query": "send to rudderstack",
-      "matches": [
-        "gpt_engineer/collect.py:11",
-        "gpt_engineer/collect.py:38"
-      ]
-    },
-    {
-      "query": "parse code blocks from chat messages",
-      "matches": [
-        "gpt_engineer/chat_to_files.py:10",
-        "docs/intro/chat_parsing.md:1"
-      ]
-    },
-    {
-      "query": "how do I use the docker cli?",
-      "matches": [
-        "docker/README.md:1"
-      ]
-    },
-    {
-      "query": "ask the user if the code ran successfully?",
-      "matches": [
-        "gpt_engineer/learning.py:54"
-      ]
-    },
-    {
-      "query": "how is consent granted by the user?",
-      "matches": [
-        "gpt_engineer/learning.py:107",
-        "gpt_engineer/learning.py:130",
-        "gpt_engineer/learning.py:152"
-      ]
-    },
-    {
-      "query": "what are all the different steps the agent can take?",
-      "matches": [
-        "docs/intro/steps_module.md:1",
-        "gpt_engineer/steps.py:391"
-      ]
-    },
-    {
-      "query": "ask the user for clarification?",
-      "matches": [
-        "gpt_engineer/steps.py:69"
-      ]
-    },
-    {
-      "query": "what models are available?",
-      "matches": [
-        "gpt_engineer/ai.py:315",
-        "gpt_engineer/ai.py:341",
-        "docs/open-models.md:1"
-      ]
-    },
-    {
-      "query": "what is the current focus of the project?",
-      "matches": [
-        "ROADMAP.md:11"
-      ]
-    },
-    {
-      "query": "does the agent know how to fix code?",
-      "matches": [
-        "gpt_engineer/steps.py:367"
-      ]
-    }
-  ]
-}

crates/semantic_index/eval/tree-sitter.json 🔗

@@ -1,104 +0,0 @@
-{
-  "repo": "https://github.com/tree-sitter/tree-sitter.git",
-  "commit": "46af27796a76c72d8466627d499f2bca4af958ee",
-  "assertions": [
-    {
-      "query": "What attributes are available for the tags configuration struct?",
-      "matches": [
-        "tags/src/lib.rs:24"
-      ]
-    },
-    {
-      "query": "create a new tag configuration",
-      "matches": [
-        "tags/src/lib.rs:119"
-      ]
-    },
-    {
-      "query": "generate tags based on config",
-      "matches": [
-        "tags/src/lib.rs:261"
-      ]
-    },
-    {
-      "query": "match on ts quantifier in rust",
-      "matches": [
-        "lib/binding_rust/lib.rs:139"
-      ]
-    },
-    {
-      "query": "cli command to generate tags",
-      "matches": [
-        "cli/src/tags.rs:10"
-      ]
-    },
-    {
-      "query": "what version of the tree-sitter-tags package is active?",
-      "matches": [
-        "tags/Cargo.toml:4"
-      ]
-    },
-    {
-      "query": "Insert a new parse state",
-      "matches": [
-        "cli/src/generate/build_tables/build_parse_table.rs:153"
-      ]
-    },
-    {
-      "query": "Handle conflict when numerous actions occur on the same symbol",
-      "matches": [
-        "cli/src/generate/build_tables/build_parse_table.rs:363",
-        "cli/src/generate/build_tables/build_parse_table.rs:442"
-      ]
-    },
-    {
-      "query": "Match based on associativity of actions",
-      "matches": [
-        "cri/src/generate/build_tables/build_parse_table.rs:542"
-      ]
-    },
-    {
-      "query": "Format token set display",
-      "matches": [
-        "cli/src/generate/build_tables/item.rs:246"
-      ]
-    },
-    {
-      "query": "extract choices from rule",
-      "matches": [
-        "cli/src/generate/prepare_grammar/flatten_grammar.rs:124"
-      ]
-    },
-    {
-      "query": "How do we identify if a symbol is being used?",
-      "matches": [
-        "cli/src/generate/prepare_grammar/flatten_grammar.rs:175"
-      ]
-    },
-    {
-      "query": "How do we launch the playground?",
-      "matches": [
-        "cli/src/playground.rs:46"
-      ]
-    },
-    {
-      "query": "How do we test treesitter query matches in rust?",
-      "matches": [
-        "cli/src/query_testing.rs:152",
-        "cli/src/tests/query_test.rs:781",
-        "cli/src/tests/query_test.rs:2163",
-        "cli/src/tests/query_test.rs:3781",
-        "cli/src/tests/query_test.rs:887"
-      ]
-    },
-    {
-      "query": "What does the CLI do?",
-      "matches": [
-        "cli/README.md:10",
-        "cli/loader/README.md:3",
-        "docs/section-5-implementation.md:14",
-        "docs/section-5-implementation.md:18"
-      ]
-    }
-  ]
-}

crates/semantic_index/src/db.rs 🔗

@@ -1,594 +0,0 @@
-use crate::{
-    parsing::{Span, SpanDigest},
-    SEMANTIC_INDEX_VERSION,
-};
-use ai::embedding::Embedding;
-use anyhow::{anyhow, Context, Result};
-use collections::HashMap;
-use futures::channel::oneshot;
-use gpui::BackgroundExecutor;
-use ndarray::{Array1, Array2};
-use ordered_float::OrderedFloat;
-use project::Fs;
-use rpc::proto::Timestamp;
-use rusqlite::params;
-use rusqlite::types::Value;
-use std::{
-    future::Future,
-    ops::Range,
-    path::{Path, PathBuf},
-    rc::Rc,
-    sync::Arc,
-    time::SystemTime,
-};
-use util::{paths::PathMatcher, TryFutureExt};
-
-pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
-    let mut indices = (0..data.len()).collect::<Vec<_>>();
-    indices.sort_by_key(|&i| &data[i]);
-    indices.reverse();
-    indices
-}
-
-#[derive(Debug)]
-pub struct FileRecord {
-    pub id: usize,
-    pub relative_path: String,
-    pub mtime: Timestamp,
-}
-
-#[derive(Clone)]
-pub struct VectorDatabase {
-    path: Arc<Path>,
-    transactions:
-        smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
-}
-
-impl VectorDatabase {
-    pub async fn new(
-        fs: Arc<dyn Fs>,
-        path: Arc<Path>,
-        executor: BackgroundExecutor,
-    ) -> Result<Self> {
-        if let Some(db_directory) = path.parent() {
-            fs.create_dir(db_directory).await?;
-        }
-
-        let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
-            Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
-        >();
-        executor
-            .spawn({
-                let path = path.clone();
-                async move {
-                    let mut connection = rusqlite::Connection::open(&path)?;
-
-                    connection.pragma_update(None, "journal_mode", "wal")?;
-                    connection.pragma_update(None, "synchronous", "normal")?;
-                    connection.pragma_update(None, "cache_size", 1000000)?;
-                    connection.pragma_update(None, "temp_store", "MEMORY")?;
-
-                    while let Ok(transaction) = transactions_rx.recv().await {
-                        transaction(&mut connection);
-                    }
-
-                    anyhow::Ok(())
-                }
-                .log_err()
-            })
-            .detach();
-        let this = Self {
-            transactions: transactions_tx,
-            path,
-        };
-        this.initialize_database().await?;
-        Ok(this)
-    }
-
-    pub fn path(&self) -> &Arc<Path> {
-        &self.path
-    }
-
-    fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
-    where
-        F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
-        T: 'static + Send,
-    {
-        let (tx, rx) = oneshot::channel();
-        let transactions = self.transactions.clone();
-        async move {
-            if transactions
-                .send(Box::new(|connection| {
-                    let result = connection
-                        .transaction()
-                        .map_err(|err| anyhow!(err))
-                        .and_then(|transaction| {
-                            let result = f(&transaction)?;
-                            transaction.commit()?;
-                            Ok(result)
-                        });
-                    let _ = tx.send(result);
-                }))
-                .await
-                .is_err()
-            {
-                return Err(anyhow!("connection was dropped"))?;
-            }
-            rx.await?
-        }
-    }
-
-    fn initialize_database(&self) -> impl Future<Output = Result<()>> {
-        self.transact(|db| {
-            rusqlite::vtab::array::load_module(&db)?;
-
-            // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
-            let version_query = db.prepare("SELECT version from semantic_index_config");
-            let version = version_query
-                .and_then(|mut query| query.query_row([], |row| row.get::<_, i64>(0)));
-            if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
-                log::trace!("vector database schema up to date");
-                return Ok(());
-            }
-
-            log::trace!("vector database schema out of date. updating...");
-            // We renamed the `documents` table to `spans`, so we want to drop
-            // `documents` without recreating it if it exists.
-            db.execute("DROP TABLE IF EXISTS documents", [])
-                .context("failed to drop 'documents' table")?;
-            db.execute("DROP TABLE IF EXISTS spans", [])
-                .context("failed to drop 'spans' table")?;
-            db.execute("DROP TABLE IF EXISTS files", [])
-                .context("failed to drop 'files' table")?;
-            db.execute("DROP TABLE IF EXISTS worktrees", [])
-                .context("failed to drop 'worktrees' table")?;
-            db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
-                .context("failed to drop 'semantic_index_config' table")?;
-
-            // Initialize Vector Databasing Tables
-            db.execute(
-                "CREATE TABLE semantic_index_config (
-                    version INTEGER NOT NULL
-                )",
-                [],
-            )?;
-
-            db.execute(
-                "INSERT INTO semantic_index_config (version) VALUES (?1)",
-                params![SEMANTIC_INDEX_VERSION],
-            )?;
-
-            db.execute(
-                "CREATE TABLE worktrees (
-                    id INTEGER PRIMARY KEY AUTOINCREMENT,
-                    absolute_path VARCHAR NOT NULL
-                );
-                CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
-                ",
-                [],
-            )?;
-
-            db.execute(
-                "CREATE TABLE files (
-                    id INTEGER PRIMARY KEY AUTOINCREMENT,
-                    worktree_id INTEGER NOT NULL,
-                    relative_path VARCHAR NOT NULL,
-                    mtime_seconds INTEGER NOT NULL,
-                    mtime_nanos INTEGER NOT NULL,
-                    FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
-                )",
-                [],
-            )?;
-
-            db.execute(
-                "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
-                [],
-            )?;
-
-            db.execute(
-                "CREATE TABLE spans (
-                    id INTEGER PRIMARY KEY AUTOINCREMENT,
-                    file_id INTEGER NOT NULL,
-                    start_byte INTEGER NOT NULL,
-                    end_byte INTEGER NOT NULL,
-                    name VARCHAR NOT NULL,
-                    embedding BLOB NOT NULL,
-                    digest BLOB NOT NULL,
-                    FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
-                )",
-                [],
-            )?;
-            db.execute(
-                "CREATE INDEX spans_digest ON spans (digest)",
-                [],
-            )?;
-
-            log::trace!("vector database initialized with updated schema.");
-            Ok(())
-        })
-    }
-
-    pub fn delete_file(
-        &self,
-        worktree_id: i64,
-        delete_path: Arc<Path>,
-    ) -> impl Future<Output = Result<()>> {
-        self.transact(move |db| {
-            db.execute(
-                "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
-                params![worktree_id, delete_path.to_str()],
-            )?;
-            Ok(())
-        })
-    }
-
-    pub fn insert_file(
-        &self,
-        worktree_id: i64,
-        path: Arc<Path>,
-        mtime: SystemTime,
-        spans: Vec<Span>,
-    ) -> impl Future<Output = Result<()>> {
-        self.transact(move |db| {
-            // Return the existing ID, if both the file and mtime match
-            let mtime = Timestamp::from(mtime);
-
-            db.execute(
-                "
-                REPLACE INTO files
-                (worktree_id, relative_path, mtime_seconds, mtime_nanos)
-                VALUES (?1, ?2, ?3, ?4)
-                ",
-                params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
-            )?;
-
-            let file_id = db.last_insert_rowid();
-
-            let mut query = db.prepare(
-                "
-                INSERT INTO spans
-                (file_id, start_byte, end_byte, name, embedding, digest)
-                VALUES (?1, ?2, ?3, ?4, ?5, ?6)
-                ",
-            )?;
-
-            for span in spans {
-                query.execute(params![
-                    file_id,
-                    span.range.start.to_string(),
-                    span.range.end.to_string(),
-                    span.name,
-                    span.embedding,
-                    span.digest
-                ])?;
-            }
-
-            Ok(())
-        })
-    }
-
-    pub fn worktree_previously_indexed(
-        &self,
-        worktree_root_path: &Path,
-    ) -> impl Future<Output = Result<bool>> {
-        let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
-        self.transact(move |db| {
-            let mut worktree_query =
-                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
-            let worktree_id =
-                worktree_query.query_row(params![worktree_root_path], |row| row.get::<_, i64>(0));
-
-            Ok(worktree_id.is_ok())
-        })
-    }
-
-    pub fn embeddings_for_digests(
-        &self,
-        digests: Vec<SpanDigest>,
-    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
-        self.transact(move |db| {
-            let mut query = db.prepare(
-                "
-                SELECT digest, embedding
-                FROM spans
-                WHERE digest IN rarray(?)
-                ",
-            )?;
-            let mut embeddings_by_digest = HashMap::default();
-            let digests = Rc::new(
-                digests
-                    .into_iter()
-                    .map(|digest| Value::Blob(digest.0.to_vec()))
-                    .collect::<Vec<_>>(),
-            );
-            let rows = query.query_map(params![digests], |row| {
-                Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
-            })?;
-
-            for (digest, embedding) in rows.flatten() {
-                embeddings_by_digest.insert(digest, embedding);
-            }
-
-            Ok(embeddings_by_digest)
-        })
-    }
-
-    pub fn embeddings_for_files(
-        &self,
-        worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
-    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
-        self.transact(move |db| {
-            let mut query = db.prepare(
-                "
-                SELECT digest, embedding
-                FROM spans
-                LEFT JOIN files ON files.id = spans.file_id
-                WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
-            ",
-            )?;
-            let mut embeddings_by_digest = HashMap::default();
-            for (worktree_id, file_paths) in worktree_id_file_paths {
-                let file_paths = Rc::new(
-                    file_paths
-                        .into_iter()
-                        .map(|p| Value::Text(p.to_string_lossy().into_owned()))
-                        .collect::<Vec<_>>(),
-                );
-                let rows = query.query_map(params![worktree_id, file_paths], |row| {
-                    Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
-                })?;
-
-                for (digest, embedding) in rows.flatten() {
-                    embeddings_by_digest.insert(digest, embedding);
-                }
-            }
-
-            Ok(embeddings_by_digest)
-        })
-    }
-
-    pub fn find_or_create_worktree(
-        &self,
-        worktree_root_path: Arc<Path>,
-    ) -> impl Future<Output = Result<i64>> {
-        self.transact(move |db| {
-            let mut worktree_query =
-                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
-            let worktree_id = worktree_query
-                .query_row(params![worktree_root_path.to_string_lossy()], |row| {
-                    row.get::<_, i64>(0)
-                });
-
-            if worktree_id.is_ok() {
-                return Ok(worktree_id?);
-            }
-
-            // If worktree_id is Err, insert new worktree
-            db.execute(
-                "INSERT into worktrees (absolute_path) VALUES (?1)",
-                params![worktree_root_path.to_string_lossy()],
-            )?;
-            Ok(db.last_insert_rowid())
-        })
-    }
-
-    pub fn get_file_mtimes(
-        &self,
-        worktree_id: i64,
-    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
-        self.transact(move |db| {
-            let mut statement = db.prepare(
-                "
-                SELECT relative_path, mtime_seconds, mtime_nanos
-                FROM files
-                WHERE worktree_id = ?1
-                ORDER BY relative_path",
-            )?;
-            let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
-            for row in statement.query_map(params![worktree_id], |row| {
-                Ok((
-                    row.get::<_, String>(0)?.into(),
-                    Timestamp {
-                        seconds: row.get(1)?,
-                        nanos: row.get(2)?,
-                    }
-                    .into(),
-                ))
-            })? {
-                let row = row?;
-                result.insert(row.0, row.1);
-            }
-            Ok(result)
-        })
-    }
-
-    pub fn top_k_search(
-        &self,
-        query_embedding: &Embedding,
-        limit: usize,
-        file_ids: &[i64],
-    ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
-        let file_ids = file_ids.to_vec();
-        let query = query_embedding.clone().0;
-        let query = Array1::from_vec(query);
-        self.transact(move |db| {
-            let mut query_statement = db.prepare(
-                "
-                    SELECT
-                        id, embedding
-                    FROM
-                        spans
-                    WHERE
-                        file_id IN rarray(?)
-                    ",
-            )?;
-
-            let deserialized_rows = query_statement
-                .query_map(params![ids_to_sql(&file_ids)], |row| {
-                    Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
-                })?
-                .filter_map(|row| row.ok())
-                .collect::<Vec<(usize, Embedding)>>();
-
-            if deserialized_rows.len() == 0 {
-                return Ok(Vec::new());
-            }
-
-            // Get the length of the returned embeddings
-            let embedding_len = deserialized_rows[0].1 .0.len();
-
-            let batch_n = 1000;
-            let mut batches = Vec::new();
-            let mut batch_ids = Vec::new();
-            let mut batch_embeddings: Vec<f32> = Vec::new();
-            deserialized_rows.iter().for_each(|(id, embedding)| {
-                batch_ids.push(id);
-                batch_embeddings.extend(&embedding.0);
-
-                if batch_ids.len() == batch_n {
-                    let embeddings = std::mem::take(&mut batch_embeddings);
-                    let ids = std::mem::take(&mut batch_ids);
-                    let array = Array2::from_shape_vec((ids.len(), embedding_len), embeddings);
-                    match array {
-                        Ok(array) => {
-                            batches.push((ids, array));
-                        }
-                        Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
-                    }
-                }
-            });
-
-            if batch_ids.len() > 0 {
-                let array = Array2::from_shape_vec(
-                    (batch_ids.len(), embedding_len),
-                    batch_embeddings.clone(),
-                );
-                match array {
-                    Ok(array) => {
-                        batches.push((batch_ids.clone(), array));
-                    }
-                    Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
-                }
-            }
-
-            let mut ids: Vec<usize> = Vec::new();
-            let mut results = Vec::new();
-            for (batch_ids, array) in batches {
-                let scores = array
-                    .dot(&query.t())
-                    .to_vec()
-                    .iter()
-                    .map(|score| OrderedFloat(*score))
-                    .collect::<Vec<OrderedFloat<f32>>>();
-                results.extend(scores);
-                ids.extend(batch_ids);
-            }
-
-            let sorted_idx = argsort(&results);
-            let mut sorted_results = Vec::new();
-            let last_idx = limit.min(sorted_idx.len());
-            for idx in &sorted_idx[0..last_idx] {
-                sorted_results.push((ids[*idx] as i64, results[*idx]))
-            }
-
-            Ok(sorted_results)
-        })
-    }
-
-    pub fn retrieve_included_file_ids(
-        &self,
-        worktree_ids: &[i64],
-        includes: &[PathMatcher],
-        excludes: &[PathMatcher],
-    ) -> impl Future<Output = Result<Vec<i64>>> {
-        let worktree_ids = worktree_ids.to_vec();
-        let includes = includes.to_vec();
-        let excludes = excludes.to_vec();
-        self.transact(move |db| {
-            let mut file_query = db.prepare(
-                "
-                SELECT
-                    id, relative_path
-                FROM
-                    files
-                WHERE
-                    worktree_id IN rarray(?)
-                ",
-            )?;
-
-            let mut file_ids = Vec::<i64>::new();
-            let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
-
-            while let Some(row) = rows.next()? {
-                let file_id = row.get(0)?;
-                let relative_path = row.get_ref(1)?.as_str()?;
-                let included =
-                    includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
-                let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
-                if included && !excluded {
-                    file_ids.push(file_id);
-                }
-            }
-
-            anyhow::Ok(file_ids)
-        })
-    }
-
-    pub fn spans_for_ids(
-        &self,
-        ids: &[i64],
-    ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
-        let ids = ids.to_vec();
-        self.transact(move |db| {
-            let mut statement = db.prepare(
-                "
-                    SELECT
-                        spans.id,
-                        files.worktree_id,
-                        files.relative_path,
-                        spans.start_byte,
-                        spans.end_byte
-                    FROM
-                        spans, files
-                    WHERE
-                        spans.file_id = files.id AND
-                        spans.id in rarray(?)
-                ",
-            )?;
-
-            let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
-                Ok((
-                    row.get::<_, i64>(0)?,
-                    row.get::<_, i64>(1)?,
-                    row.get::<_, String>(2)?.into(),
-                    row.get(3)?..row.get(4)?,
-                ))
-            })?;
-
-            let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
-            for row in result_iter {
-                let (id, worktree_id, path, range) = row?;
-                values_by_id.insert(id, (worktree_id, path, range));
-            }
-
-            let mut results = Vec::with_capacity(ids.len());
-            for id in &ids {
-                let value = values_by_id
-                    .remove(id)
-                    .ok_or(anyhow!("missing span id {}", id))?;
-                results.push(value);
-            }
-
-            Ok(results)
-        })
-    }
-}
-
-fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
-    Rc::new(
-        ids.iter()
-            .copied()
-            .map(|v| rusqlite::types::Value::from(v))
-            .collect::<Vec<_>>(),
-    )
-}
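
For reference, the search that the removed `VectorDatabase::top_k_search` performed was a brute-force scan: load every span embedding for the candidate file ids, stack them into 1000-row `ndarray` matrices, dot each batch against the query embedding, and argsort the scores. A dependency-free sketch of the same idea (the `top_k` helper below is illustrative, not code from this PR):

```rust
use std::cmp::Ordering;

// Score every candidate span against the query and keep the `limit`
// highest-scoring ids. Mirrors the deleted `top_k_search`, minus the
// ndarray batching and the SQLite round-trip.
fn top_k(query: &[f32], spans: &[(i64, Vec<f32>)], limit: usize) -> Vec<(i64, f32)> {
    let mut scored: Vec<(i64, f32)> = spans
        .iter()
        .map(|(id, embedding)| {
            // Dot product; for unit-length vectors this is cosine similarity.
            let score: f32 = embedding.iter().zip(query).map(|(a, b)| a * b).sum();
            (*id, score)
        })
        .collect();

    // Highest similarity first.
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
    scored.truncate(limit);
    scored
}
```

Since OpenAI embeddings are (approximately) unit-length, the dot product doubles as cosine similarity; the batching in the deleted code only amortized the matrix multiply and didn't change the ranking.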

crates/semantic_index/src/embedding_queue.rs 🔗

@@ -1,169 +0,0 @@
-use crate::{parsing::Span, JobHandle};
-use ai::embedding::EmbeddingProvider;
-use gpui::BackgroundExecutor;
-use parking_lot::Mutex;
-use smol::channel;
-use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
-
-#[derive(Clone)]
-pub struct FileToEmbed {
-    pub worktree_id: i64,
-    pub path: Arc<Path>,
-    pub mtime: SystemTime,
-    pub spans: Vec<Span>,
-    pub job_handle: JobHandle,
-}
-
-impl std::fmt::Debug for FileToEmbed {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("FileToEmbed")
-            .field("worktree_id", &self.worktree_id)
-            .field("path", &self.path)
-            .field("mtime", &self.mtime)
-            .field("spans", &self.spans)
-            .finish_non_exhaustive()
-    }
-}
-
-impl PartialEq for FileToEmbed {
-    fn eq(&self, other: &Self) -> bool {
-        self.worktree_id == other.worktree_id
-            && self.path == other.path
-            && self.mtime == other.mtime
-            && self.spans == other.spans
-    }
-}
-
-pub struct EmbeddingQueue {
-    embedding_provider: Arc<dyn EmbeddingProvider>,
-    pending_batch: Vec<FileFragmentToEmbed>,
-    executor: BackgroundExecutor,
-    pending_batch_token_count: usize,
-    finished_files_tx: channel::Sender<FileToEmbed>,
-    finished_files_rx: channel::Receiver<FileToEmbed>,
-}
-
-#[derive(Clone)]
-pub struct FileFragmentToEmbed {
-    file: Arc<Mutex<FileToEmbed>>,
-    span_range: Range<usize>,
-}
-
-impl EmbeddingQueue {
-    pub fn new(
-        embedding_provider: Arc<dyn EmbeddingProvider>,
-        executor: BackgroundExecutor,
-    ) -> Self {
-        let (finished_files_tx, finished_files_rx) = channel::unbounded();
-        Self {
-            embedding_provider,
-            executor,
-            pending_batch: Vec::new(),
-            pending_batch_token_count: 0,
-            finished_files_tx,
-            finished_files_rx,
-        }
-    }
-
-    pub fn push(&mut self, file: FileToEmbed) {
-        if file.spans.is_empty() {
-            self.finished_files_tx.try_send(file).unwrap();
-            return;
-        }
-
-        let file = Arc::new(Mutex::new(file));
-
-        self.pending_batch.push(FileFragmentToEmbed {
-            file: file.clone(),
-            span_range: 0..0,
-        });
-
-        let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
-        for (ix, span) in file.lock().spans.iter().enumerate() {
-            let span_token_count = if span.embedding.is_none() {
-                span.token_count
-            } else {
-                0
-            };
-
-            let next_token_count = self.pending_batch_token_count + span_token_count;
-            if next_token_count > self.embedding_provider.max_tokens_per_batch() {
-                let range_end = fragment_range.end;
-                self.flush();
-                self.pending_batch.push(FileFragmentToEmbed {
-                    file: file.clone(),
-                    span_range: range_end..range_end,
-                });
-                fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
-            }
-
-            fragment_range.end = ix + 1;
-            self.pending_batch_token_count += span_token_count;
-        }
-    }
-
-    pub fn flush(&mut self) {
-        let batch = mem::take(&mut self.pending_batch);
-        self.pending_batch_token_count = 0;
-        if batch.is_empty() {
-            return;
-        }
-
-        let finished_files_tx = self.finished_files_tx.clone();
-        let embedding_provider = self.embedding_provider.clone();
-
-        self.executor
-            .spawn(async move {
-                let mut spans = Vec::new();
-                for fragment in &batch {
-                    let file = fragment.file.lock();
-                    spans.extend(
-                        file.spans[fragment.span_range.clone()]
-                            .iter()
-                            .filter(|d| d.embedding.is_none())
-                            .map(|d| d.content.clone()),
-                    );
-                }
-
-                // Nothing in this batch needs embedding; forward each file to the finished channel once this is the last fragment holding a reference to it.
-                if spans.is_empty() {
-                    for fragment in batch.clone() {
-                        if let Some(file) = Arc::into_inner(fragment.file) {
-                            finished_files_tx.try_send(file.into_inner()).unwrap();
-                        }
-                    }
-                    return;
-                };
-
-                match embedding_provider.embed_batch(spans).await {
-                    Ok(embeddings) => {
-                        let mut embeddings = embeddings.into_iter();
-                        for fragment in batch {
-                            for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
-                                .iter_mut()
-                                .filter(|d| d.embedding.is_none())
-                            {
-                                if let Some(embedding) = embeddings.next() {
-                                    span.embedding = Some(embedding);
-                                } else {
-                                    log::error!("number of embeddings != number of documents");
-                                }
-                            }
-
-                            if let Some(file) = Arc::into_inner(fragment.file) {
-                                finished_files_tx.try_send(file.into_inner()).unwrap();
-                            }
-                        }
-                    }
-                    Err(error) => {
-                        log::error!("{:?}", error);
-                    }
-                }
-            })
-            .detach();
-    }
-
-    pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
-        self.finished_files_rx.clone()
-    }
-}
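
The heart of the removed `EmbeddingQueue` is token-budget batching: `push` accumulates span contents until the provider's `max_tokens_per_batch` would be exceeded, at which point the pending batch is sealed and handed to `embed_batch` on a background task. Stripped of the channel and `Arc<Mutex<FileToEmbed>>` plumbing, the bookkeeping looks roughly like this sketch (names are illustrative, not the deleted API):

```rust
// Minimal sketch of the deleted queue's batching bookkeeping.
struct BatchingQueue {
    max_tokens_per_batch: usize,
    pending: Vec<String>,
    pending_token_count: usize,
    ready_batches: Vec<Vec<String>>,
}

impl BatchingQueue {
    // Queue one span's content; flush first if adding it would exceed the budget.
    fn push(&mut self, content: String, token_count: usize) {
        if self.pending_token_count + token_count > self.max_tokens_per_batch {
            self.flush();
        }
        self.pending_token_count += token_count;
        self.pending.push(content);
    }

    // Seal the pending batch. The real queue spawned a task here that
    // called `EmbeddingProvider::embed_batch` on the sealed batch.
    fn flush(&mut self) {
        self.pending_token_count = 0;
        if self.pending.is_empty() {
            return;
        }
        self.ready_batches.push(std::mem::take(&mut self.pending));
    }
}
```

The fragment/`Arc::into_inner` dance in the real code exists so that a file is reported as finished only after the last batch containing one of its fragments completes.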

crates/semantic_index/src/parsing.rs 🔗

@@ -1,414 +0,0 @@
-use ai::{
-    embedding::{Embedding, EmbeddingProvider},
-    models::TruncationDirection,
-};
-use anyhow::{anyhow, Result};
-use collections::HashSet;
-use language::{Grammar, Language};
-use rusqlite::{
-    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
-    ToSql,
-};
-use sha1::{Digest, Sha1};
-use std::{
-    borrow::Cow,
-    cmp::{self, Reverse},
-    ops::Range,
-    path::Path,
-    sync::Arc,
-};
-use tree_sitter::{Parser, QueryCursor};
-
-#[derive(Debug, PartialEq, Eq, Clone, Hash)]
-pub struct SpanDigest(pub [u8; 20]);
-
-impl FromSql for SpanDigest {
-    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
-        let blob = value.as_blob()?;
-        let bytes =
-            blob.try_into()
-                .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
-                    expected_size: 20,
-                    blob_size: blob.len(),
-                })?;
-        return Ok(SpanDigest(bytes));
-    }
-}
-
-impl ToSql for SpanDigest {
-    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
-        self.0.to_sql()
-    }
-}
-
-impl From<&'_ str> for SpanDigest {
-    fn from(value: &'_ str) -> Self {
-        let mut sha1 = Sha1::new();
-        sha1.update(value);
-        Self(sha1.finalize().into())
-    }
-}
-
-#[derive(Debug, PartialEq, Clone)]
-pub struct Span {
-    pub name: String,
-    pub range: Range<usize>,
-    pub content: String,
-    pub embedding: Option<Embedding>,
-    pub digest: SpanDigest,
-    pub token_count: usize,
-}
-
-const CODE_CONTEXT_TEMPLATE: &str =
-    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
-const ENTIRE_FILE_TEMPLATE: &str =
-    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
-const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
-pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[
-    "TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML", "Scheme",
-];
-
-pub struct CodeContextRetriever {
-    pub parser: Parser,
-    pub cursor: QueryCursor,
-    pub embedding_provider: Arc<dyn EmbeddingProvider>,
-}
-
-// Every match has an item capture; it represents the fundamental tree-sitter symbol and anchors the search.
-// Every match has one or more 'name' captures, which indicate the display range of the item for deduplication.
-// If there are preceding comments, we track them with a context capture.
-// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture.
-// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture.
-#[derive(Debug, Clone)]
-pub struct CodeContextMatch {
-    pub start_col: usize,
-    pub item_range: Option<Range<usize>>,
-    pub name_range: Option<Range<usize>>,
-    pub context_ranges: Vec<Range<usize>>,
-    pub collapse_ranges: Vec<Range<usize>>,
-}
-
-impl CodeContextRetriever {
-    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
-        Self {
-            parser: Parser::new(),
-            cursor: QueryCursor::new(),
-            embedding_provider,
-        }
-    }
-
-    fn parse_entire_file(
-        &self,
-        relative_path: Option<&Path>,
-        language_name: Arc<str>,
-        content: &str,
-    ) -> Result<Vec<Span>> {
-        let document_span = ENTIRE_FILE_TEMPLATE
-            .replace(
-                "<path>",
-                &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
-            )
-            .replace("<language>", language_name.as_ref())
-            .replace("<item>", &content);
-        let digest = SpanDigest::from(document_span.as_str());
-        let model = self.embedding_provider.base_model();
-        let document_span = model.truncate(
-            &document_span,
-            model.capacity()?,
-            ai::models::TruncationDirection::End,
-        )?;
-        let token_count = model.count_tokens(&document_span)?;
-
-        Ok(vec![Span {
-            range: 0..content.len(),
-            content: document_span,
-            embedding: Default::default(),
-            name: language_name.to_string(),
-            digest,
-            token_count,
-        }])
-    }
-
-    fn parse_markdown_file(
-        &self,
-        relative_path: Option<&Path>,
-        content: &str,
-    ) -> Result<Vec<Span>> {
-        let document_span = MARKDOWN_CONTEXT_TEMPLATE
-            .replace(
-                "<path>",
-                &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
-            )
-            .replace("<item>", &content);
-        let digest = SpanDigest::from(document_span.as_str());
-
-        let model = self.embedding_provider.base_model();
-        let document_span = model.truncate(
-            &document_span,
-            model.capacity()?,
-            ai::models::TruncationDirection::End,
-        )?;
-        let token_count = model.count_tokens(&document_span)?;
-
-        Ok(vec![Span {
-            range: 0..content.len(),
-            content: document_span,
-            embedding: None,
-            name: "Markdown".to_string(),
-            digest,
-            token_count,
-        }])
-    }
-
-    fn get_matches_in_file(
-        &mut self,
-        content: &str,
-        grammar: &Arc<Grammar>,
-    ) -> Result<Vec<CodeContextMatch>> {
-        let embedding_config = grammar
-            .embedding_config
-            .as_ref()
-            .ok_or_else(|| anyhow!("no embedding queries"))?;
-        self.parser.set_language(&grammar.ts_language).unwrap();
-
-        let tree = self
-            .parser
-            .parse(&content, None)
-            .ok_or_else(|| anyhow!("parsing failed"))?;
-
-        let mut captures: Vec<CodeContextMatch> = Vec::new();
-        let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
-        let mut keep_ranges: Vec<Range<usize>> = Vec::new();
-        for mat in self.cursor.matches(
-            &embedding_config.query,
-            tree.root_node(),
-            content.as_bytes(),
-        ) {
-            let mut start_col = 0;
-            let mut item_range: Option<Range<usize>> = None;
-            let mut name_range: Option<Range<usize>> = None;
-            let mut context_ranges: Vec<Range<usize>> = Vec::new();
-            collapse_ranges.clear();
-            keep_ranges.clear();
-            for capture in mat.captures {
-                if capture.index == embedding_config.item_capture_ix {
-                    item_range = Some(capture.node.byte_range());
-                    start_col = capture.node.start_position().column;
-                } else if Some(capture.index) == embedding_config.name_capture_ix {
-                    name_range = Some(capture.node.byte_range());
-                } else if Some(capture.index) == embedding_config.context_capture_ix {
-                    context_ranges.push(capture.node.byte_range());
-                } else if Some(capture.index) == embedding_config.collapse_capture_ix {
-                    collapse_ranges.push(capture.node.byte_range());
-                } else if Some(capture.index) == embedding_config.keep_capture_ix {
-                    keep_ranges.push(capture.node.byte_range());
-                }
-            }
-
-            captures.push(CodeContextMatch {
-                start_col,
-                item_range,
-                name_range,
-                context_ranges,
-                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
-            });
-        }
-        Ok(captures)
-    }
-
-    pub fn parse_file_with_template(
-        &mut self,
-        relative_path: Option<&Path>,
-        content: &str,
-        language: Arc<Language>,
-    ) -> Result<Vec<Span>> {
-        let language_name = language.name();
-
-        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
-            return self.parse_entire_file(relative_path, language_name, &content);
-        } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
-            return self.parse_markdown_file(relative_path, &content);
-        }
-
-        let mut spans = self.parse_file(content, language)?;
-        for span in &mut spans {
-            let document_content = CODE_CONTEXT_TEMPLATE
-                .replace(
-                    "<path>",
-                    &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
-                )
-                .replace("<language>", language_name.as_ref())
-                .replace("<item>", &span.content);
-
-            let model = self.embedding_provider.base_model();
-            let document_content = model.truncate(
-                &document_content,
-                model.capacity()?,
-                TruncationDirection::End,
-            )?;
-            let token_count = model.count_tokens(&document_content)?;
-
-            span.content = document_content;
-            span.token_count = token_count;
-        }
-        Ok(spans)
-    }
-
-    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
-        let grammar = language
-            .grammar()
-            .ok_or_else(|| anyhow!("no grammar for language"))?;
-
-        // Iterate through query matches
-        let matches = self.get_matches_in_file(content, grammar)?;
-
-        let language_scope = language.default_scope();
-        let placeholder = language_scope.collapsed_placeholder();
-
-        let mut spans = Vec::new();
-        let mut collapsed_ranges_within = Vec::new();
-        let mut parsed_name_ranges = HashSet::default();
-        for (i, context_match) in matches.iter().enumerate() {
-            // Items which are collapsible but not embeddable have no item range
-            let item_range = if let Some(item_range) = context_match.item_range.clone() {
-                item_range
-            } else {
-                continue;
-            };
-
-            // Checks for deduplication
-            let name;
-            if let Some(name_range) = context_match.name_range.clone() {
-                name = content
-                    .get(name_range.clone())
-                    .map_or(String::new(), |s| s.to_string());
-                if parsed_name_ranges.contains(&name_range) {
-                    continue;
-                }
-                parsed_name_ranges.insert(name_range);
-            } else {
-                name = String::new();
-            }
-
-            collapsed_ranges_within.clear();
-            'outer: for remaining_match in &matches[(i + 1)..] {
-                for collapsed_range in &remaining_match.collapse_ranges {
-                    if item_range.start <= collapsed_range.start
-                        && item_range.end >= collapsed_range.end
-                    {
-                        collapsed_ranges_within.push(collapsed_range.clone());
-                    } else {
-                        break 'outer;
-                    }
-                }
-            }
-
-            collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
-
-            let mut span_content = String::new();
-            for context_range in &context_match.context_ranges {
-                add_content_from_range(
-                    &mut span_content,
-                    content,
-                    context_range.clone(),
-                    context_match.start_col,
-                );
-                span_content.push_str("\n");
-            }
-
-            let mut offset = item_range.start;
-            for collapsed_range in &collapsed_ranges_within {
-                if collapsed_range.start > offset {
-                    add_content_from_range(
-                        &mut span_content,
-                        content,
-                        offset..collapsed_range.start,
-                        context_match.start_col,
-                    );
-                    offset = collapsed_range.start;
-                }
-
-                if collapsed_range.end > offset {
-                    span_content.push_str(placeholder);
-                    offset = collapsed_range.end;
-                }
-            }
-
-            if offset < item_range.end {
-                add_content_from_range(
-                    &mut span_content,
-                    content,
-                    offset..item_range.end,
-                    context_match.start_col,
-                );
-            }
-
-            let sha1 = SpanDigest::from(span_content.as_str());
-            spans.push(Span {
-                name,
-                content: span_content,
-                range: item_range.clone(),
-                embedding: None,
-                digest: sha1,
-                token_count: 0,
-            })
-        }
-
-        return Ok(spans);
-    }
-}
-
-pub(crate) fn subtract_ranges(
-    ranges: &[Range<usize>],
-    ranges_to_subtract: &[Range<usize>],
-) -> Vec<Range<usize>> {
-    let mut result = Vec::new();
-
-    let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();
-
-    for range in ranges {
-        let mut offset = range.start;
-
-        while offset < range.end {
-            if let Some(range_to_subtract) = ranges_to_subtract.peek() {
-                if offset < range_to_subtract.start {
-                    let next_offset = cmp::min(range_to_subtract.start, range.end);
-                    result.push(offset..next_offset);
-                    offset = next_offset;
-                } else {
-                    let next_offset = cmp::min(range_to_subtract.end, range.end);
-                    offset = next_offset;
-                }
-
-                if offset >= range_to_subtract.end {
-                    ranges_to_subtract.next();
-                }
-            } else {
-                result.push(offset..range.end);
-                offset = range.end;
-            }
-        }
-    }
-
-    result
-}
-
-fn add_content_from_range(
-    output: &mut String,
-    content: &str,
-    range: Range<usize>,
-    start_col: usize,
-) {
-    for mut line in content.get(range.clone()).unwrap_or("").lines() {
-        for _ in 0..start_col {
-            if line.starts_with(' ') {
-                line = &line[1..];
-            } else {
-                break;
-            }
-        }
-        output.push_str(line);
-        output.push('\n');
-    }
-    output.pop();
-}
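
One subtlety in the removed parser: `@collapse` captures are replaced by the language's collapsed placeholder, except for the pieces covered by `@keep` captures, which `subtract_ranges` punches back out of the collapse ranges (both inputs are assumed sorted, with keep ranges nested inside collapse ranges). The assertions below, written for this note rather than taken from the deleted tests, illustrate the behavior of the `subtract_ranges` shown above:

```rust
#[test]
fn subtract_ranges_examples() {
    // A keep range splits the enclosing collapse range around itself...
    assert_eq!(subtract_ranges(&[0..20], &[5..10]), vec![0..5, 10..20]);

    // ...and collapse ranges containing no keep range pass through intact.
    assert_eq!(
        subtract_ranges(&[0..10, 20..30], &[22..25]),
        vec![0..10, 20..22, 25..30]
    );
}
```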

crates/semantic_index/src/semantic_index.rs 🔗

@@ -1,1308 +0,0 @@
-mod db;
-mod embedding_queue;
-mod parsing;
-pub mod semantic_index_settings;
-
-#[cfg(test)]
-mod semantic_index_tests;
-
-use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider};
-use ai::providers::open_ai::{OpenAiEmbeddingProvider, OPEN_AI_API_URL};
-use anyhow::{anyhow, Context as _, Result};
-use collections::{BTreeMap, HashMap, HashSet};
-use db::VectorDatabase;
-use embedding_queue::{EmbeddingQueue, FileToEmbed};
-use futures::{future, FutureExt, StreamExt};
-use gpui::{
-    AppContext, AsyncAppContext, BorrowWindow, Context, Global, Model, ModelContext, Task,
-    ViewContext, WeakModel,
-};
-use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
-use lazy_static::lazy_static;
-use ordered_float::OrderedFloat;
-use parking_lot::Mutex;
-use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
-use postage::watch;
-use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
-use release_channel::ReleaseChannel;
-use settings::Settings;
-use smol::channel;
-use std::{
-    cmp::Reverse,
-    env,
-    future::Future,
-    mem,
-    ops::Range,
-    path::{Path, PathBuf},
-    sync::{Arc, Weak},
-    time::{Duration, Instant, SystemTime},
-};
-use util::paths::PathMatcher;
-use util::{http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
-use workspace::Workspace;
-
-const SEMANTIC_INDEX_VERSION: usize = 11;
-const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
-const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
-
-lazy_static! {
-    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
-}
-
-pub fn init(
-    fs: Arc<dyn Fs>,
-    http_client: Arc<dyn HttpClient>,
-    language_registry: Arc<LanguageRegistry>,
-    cx: &mut AppContext,
-) {
-    SemanticIndexSettings::register(cx);
-
-    let db_file_path = EMBEDDINGS_DIR
-        .join(Path::new(ReleaseChannel::global(cx).dev_name()))
-        .join("embeddings_db");
-
-    cx.observe_new_views(
-        |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
-            let Some(semantic_index) = SemanticIndex::global(cx) else {
-                return;
-            };
-            let project = workspace.project().clone();
-
-            if project.read(cx).is_local() {
-                cx.app_mut()
-                    .spawn(|mut cx| async move {
-                        let previously_indexed = semantic_index
-                            .update(&mut cx, |index, cx| {
-                                index.project_previously_indexed(&project, cx)
-                            })?
-                            .await?;
-                        if previously_indexed {
-                            semantic_index
-                                .update(&mut cx, |index, cx| index.index_project(project, cx))?
-                                .await?;
-                        }
-                        anyhow::Ok(())
-                    })
-                    .detach_and_log_err(cx);
-            }
-        },
-    )
-    .detach();
-
-    cx.spawn(move |cx| async move {
-        let embedding_provider = OpenAiEmbeddingProvider::new(
-            // TODO: We should read it from config, but I'm not sure whether to reuse `openai_api_url` in assistant settings or not
-            OPEN_AI_API_URL.to_string(),
-            http_client,
-            cx.background_executor().clone(),
-        )
-        .await;
-        let semantic_index = SemanticIndex::new(
-            fs,
-            db_file_path,
-            Arc::new(embedding_provider),
-            language_registry,
-            cx.clone(),
-        )
-        .await?;
-
-        cx.update(|cx| cx.set_global(GlobalSemanticIndex(semantic_index.clone())))?;
-
-        anyhow::Ok(())
-    })
-    .detach();
-}
-
-#[derive(Copy, Clone, Debug)]
-pub enum SemanticIndexStatus {
-    NotAuthenticated,
-    NotIndexed,
-    Indexed,
-    Indexing {
-        remaining_files: usize,
-        rate_limit_expiry: Option<Instant>,
-    },
-}
-
-pub struct SemanticIndex {
-    fs: Arc<dyn Fs>,
-    db: VectorDatabase,
-    embedding_provider: Arc<dyn EmbeddingProvider>,
-    language_registry: Arc<LanguageRegistry>,
-    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
-    _embedding_task: Task<()>,
-    _parsing_files_tasks: Vec<Task<()>>,
-    projects: HashMap<WeakModel<Project>, ProjectState>,
-}
-
-struct GlobalSemanticIndex(Model<SemanticIndex>);
-
-impl Global for GlobalSemanticIndex {}
-
-struct ProjectState {
-    worktrees: HashMap<WorktreeId, WorktreeState>,
-    pending_file_count_rx: watch::Receiver<usize>,
-    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
-    pending_index: usize,
-    _subscription: gpui::Subscription,
-    _observe_pending_file_count: Task<()>,
-}
-
-enum WorktreeState {
-    Registering(RegisteringWorktreeState),
-    Registered(RegisteredWorktreeState),
-}
-
-impl WorktreeState {
-    fn is_registered(&self) -> bool {
-        matches!(self, Self::Registered(_))
-    }
-
-    fn paths_changed(
-        &mut self,
-        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
-        worktree: &Worktree,
-    ) {
-        let changed_paths = match self {
-            Self::Registering(state) => &mut state.changed_paths,
-            Self::Registered(state) => &mut state.changed_paths,
-        };
-
-        for (path, entry_id, change) in changes.iter() {
-            let Some(entry) = worktree.entry_for_id(*entry_id) else {
-                continue;
-            };
-            let Some(mtime) = entry.mtime else {
-                continue;
-            };
-            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
-                continue;
-            }
-            changed_paths.insert(
-                path.clone(),
-                ChangedPathInfo {
-                    mtime,
-                    is_deleted: *change == PathChange::Removed,
-                },
-            );
-        }
-    }
-}
-
-struct RegisteringWorktreeState {
-    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
-    done_rx: watch::Receiver<Option<()>>,
-    _registration: Task<()>,
-}
-
-impl RegisteringWorktreeState {
-    fn done(&self) -> impl Future<Output = ()> {
-        let mut done_rx = self.done_rx.clone();
-        async move {
-            while let Some(result) = done_rx.next().await {
-                if result.is_some() {
-                    break;
-                }
-            }
-        }
-    }
-}
-
-struct RegisteredWorktreeState {
-    db_id: i64,
-    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
-}
-
-struct ChangedPathInfo {
-    mtime: SystemTime,
-    is_deleted: bool,
-}
-
-#[derive(Clone)]
-pub struct JobHandle {
-    /// The outer Arc is here to count the clones of a JobHandle instance;
-    /// when the last handle to a given job is dropped, we decrement a counter (just once).
-    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
-}
-
-impl JobHandle {
-    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
-        *tx.lock().borrow_mut() += 1;
-        Self {
-            tx: Arc::new(Arc::downgrade(&tx)),
-        }
-    }
-}
-
-impl ProjectState {
-    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
-        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
-        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
-        Self {
-            worktrees: Default::default(),
-            pending_file_count_rx: pending_file_count_rx.clone(),
-            pending_file_count_tx,
-            pending_index: 0,
-            _subscription: subscription,
-            _observe_pending_file_count: cx.spawn({
-                let mut pending_file_count_rx = pending_file_count_rx.clone();
-                |this, mut cx| async move {
-                    while let Some(_) = pending_file_count_rx.next().await {
-                        if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
-                            break;
-                        }
-                    }
-                }
-            }),
-        }
-    }
-
-    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
-        self.worktrees
-            .iter()
-            .find_map(|(worktree_id, worktree_state)| match worktree_state {
-                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
-                _ => None,
-            })
-    }
-}
-
-#[derive(Clone)]
-pub struct PendingFile {
-    worktree_db_id: i64,
-    relative_path: Arc<Path>,
-    absolute_path: PathBuf,
-    language: Option<Arc<Language>>,
-    modified_time: SystemTime,
-    job_handle: JobHandle,
-}
-
-#[derive(Clone)]
-pub struct SearchResult {
-    pub buffer: Model<Buffer>,
-    pub range: Range<Anchor>,
-    pub similarity: OrderedFloat<f32>,
-}
-
-impl SemanticIndex {
-    pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
-        cx.try_global::<GlobalSemanticIndex>()
-            .map(|semantic_index| semantic_index.0.clone())
-    }
-
-    pub fn authenticate(&mut self, cx: &mut AppContext) -> Task<bool> {
-        if !self.embedding_provider.has_credentials() {
-            let embedding_provider = self.embedding_provider.clone();
-            cx.spawn(|cx| async move {
-                if let Some(retrieve_credentials) = cx
-                    .update(|cx| embedding_provider.retrieve_credentials(cx))
-                    .log_err()
-                {
-                    retrieve_credentials.await;
-                }
-
-                embedding_provider.has_credentials()
-            })
-        } else {
-            Task::ready(true)
-        }
-    }
-
-    pub fn is_authenticated(&self) -> bool {
-        self.embedding_provider.has_credentials()
-    }
-
-    pub fn enabled(cx: &AppContext) -> bool {
-        SemanticIndexSettings::get_global(cx).enabled
-    }
-
-    pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
-        if !self.is_authenticated() {
-            return SemanticIndexStatus::NotAuthenticated;
-        }
-
-        if let Some(project_state) = self.projects.get(&project.downgrade()) {
-            if project_state
-                .worktrees
-                .values()
-                .all(|worktree| worktree.is_registered())
-                && project_state.pending_index == 0
-            {
-                SemanticIndexStatus::Indexed
-            } else {
-                SemanticIndexStatus::Indexing {
-                    remaining_files: *project_state.pending_file_count_rx.borrow(),
-                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
-                }
-            }
-        } else {
-            SemanticIndexStatus::NotIndexed
-        }
-    }
-
-    pub async fn new(
-        fs: Arc<dyn Fs>,
-        database_path: PathBuf,
-        embedding_provider: Arc<dyn EmbeddingProvider>,
-        language_registry: Arc<LanguageRegistry>,
-        mut cx: AsyncAppContext,
-    ) -> Result<Model<Self>> {
-        let t0 = Instant::now();
-        let database_path = Arc::from(database_path);
-        let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
-            .await?;
-
-        log::trace!(
-            "db initialization took {:?} milliseconds",
-            t0.elapsed().as_millis()
-        );
-
-        cx.new_model(|cx| {
-            let t0 = Instant::now();
-            let embedding_queue =
-                EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
-            let _embedding_task = cx.background_executor().spawn({
-                let embedded_files = embedding_queue.finished_files();
-                let db = db.clone();
-                async move {
-                    while let Ok(file) = embedded_files.recv().await {
-                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
-                            .await
-                            .log_err();
-                    }
-                }
-            });
-
-            // Parse files into embeddable spans.
-            let (parsing_files_tx, parsing_files_rx) =
-                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
-            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
-            let mut _parsing_files_tasks = Vec::new();
-            for _ in 0..cx.background_executor().num_cpus() {
-                let fs = fs.clone();
-                let mut parsing_files_rx = parsing_files_rx.clone();
-                let embedding_provider = embedding_provider.clone();
-                let embedding_queue = embedding_queue.clone();
-                let background = cx.background_executor().clone();
-                _parsing_files_tasks.push(cx.background_executor().spawn(async move {
-                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
-                    loop {
-                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
-                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
-                        futures::select_biased! {
-                            next_file_to_parse = next_file_to_parse => {
-                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
-                                    Self::parse_file(
-                                        &fs,
-                                        pending_file,
-                                        &mut retriever,
-                                        &embedding_queue,
-                                        &embeddings_for_digest,
-                                    )
-                                    .await
-                                } else {
-                                    break;
-                                }
-                            },
-                            _ = timer => {
-                                embedding_queue.lock().flush();
-                            }
-                        }
-                    }
-                }));
-            }
-
-            log::trace!(
-                "semantic index task initialization took {:?} milliseconds",
-                t0.elapsed().as_millis()
-            );
-            Self {
-                fs,
-                db,
-                embedding_provider,
-                language_registry,
-                parsing_files_tx,
-                _embedding_task,
-                _parsing_files_tasks,
-                projects: Default::default(),
-            }
-        })
-    }
-
-    async fn parse_file(
-        fs: &Arc<dyn Fs>,
-        pending_file: PendingFile,
-        retriever: &mut CodeContextRetriever,
-        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
-        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
-    ) {
-        let Some(language) = pending_file.language else {
-            return;
-        };
-
-        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
-            if let Some(mut spans) = retriever
-                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
-                .log_err()
-            {
-                log::trace!(
-                    "parsed path {:?}: {} spans",
-                    pending_file.relative_path,
-                    spans.len()
-                );
-
-                for span in &mut spans {
-                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
-                        span.embedding = Some(embedding.to_owned());
-                    }
-                }
-
-                embedding_queue.lock().push(FileToEmbed {
-                    worktree_id: pending_file.worktree_db_id,
-                    path: pending_file.relative_path,
-                    mtime: pending_file.modified_time,
-                    job_handle: pending_file.job_handle,
-                    spans,
-                });
-            }
-        }
-    }
-
-    pub fn project_previously_indexed(
-        &mut self,
-        project: &Model<Project>,
-        cx: &mut ModelContext<Self>,
-    ) -> Task<Result<bool>> {
-        let worktrees_indexed_previously = project
-            .read(cx)
-            .worktrees()
-            .map(|worktree| {
-                self.db
-                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
-            })
-            .collect::<Vec<_>>();
-        cx.spawn(|_, _cx| async move {
-            let worktree_indexed_previously =
-                futures::future::join_all(worktrees_indexed_previously).await;
-
-            Ok(worktree_indexed_previously
-                .iter()
-                .filter(|worktree| worktree.is_ok())
-                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
-        })
-    }
-
-    fn project_entries_changed(
-        &mut self,
-        project: Model<Project>,
-        worktree_id: WorktreeId,
-        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
-        cx: &mut ModelContext<Self>,
-    ) {
-        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) else {
-            return;
-        };
-        let project = project.downgrade();
-        let Some(project_state) = self.projects.get_mut(&project) else {
-            return;
-        };
-
-        let worktree = worktree.read(cx);
-        let worktree_state =
-            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
-                worktree_state
-            } else {
-                return;
-            };
-        worktree_state.paths_changed(changes, worktree);
-        if let WorktreeState::Registered(_) = worktree_state {
-            cx.spawn(|this, mut cx| async move {
-                cx.background_executor()
-                    .timer(BACKGROUND_INDEXING_DELAY)
-                    .await;
-                if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
-                    this.update(&mut cx, |this, cx| {
-                        this.index_project(project, cx).detach_and_log_err(cx)
-                    })?;
-                }
-                anyhow::Ok(())
-            })
-            .detach_and_log_err(cx);
-        }
-    }
-
-    fn register_worktree(
-        &mut self,
-        project: Model<Project>,
-        worktree: Model<Worktree>,
-        cx: &mut ModelContext<Self>,
-    ) {
-        let project = project.downgrade();
-        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
-            project_state
-        } else {
-            return;
-        };
-        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
-            worktree
-        } else {
-            return;
-        };
-        let worktree_abs_path = worktree.abs_path().clone();
-        let scan_complete = worktree.scan_complete();
-        let worktree_id = worktree.id();
-        let db = self.db.clone();
-        let language_registry = self.language_registry.clone();
-        let (mut done_tx, done_rx) = watch::channel();
-        let registration = cx.spawn(|this, mut cx| {
-            async move {
-                let register = async {
-                    scan_complete.await;
-                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
-                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
-                    let worktree = if let Some(project) = project.upgrade() {
-                        project
-                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
-                            .ok()
-                            .flatten()
-                            .context("worktree not found")?
-                    } else {
-                        return anyhow::Ok(());
-                    };
-                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
-                    let mut changed_paths = cx
-                        .background_executor()
-                        .spawn(async move {
-                            let mut changed_paths = BTreeMap::new();
-                            for file in worktree.files(false, 0) {
-                                let absolute_path = worktree.absolutize(&file.path)?;
-
-                                if file.is_external || file.is_ignored || file.is_symlink {
-                                    continue;
-                                }
-
-                                if let Ok(language) = language_registry
-                                    .language_for_file_path(&absolute_path)
-                                    .await
-                                {
-                                    // Skip files that can't be parsed into embeddable spans
-                                    if !PARSEABLE_ENTIRE_FILE_TYPES
-                                        .contains(&language.name().as_ref())
-                                        && &language.name().as_ref() != &"Markdown"
-                                        && language
-                                            .grammar()
-                                            .and_then(|grammar| grammar.embedding_config.as_ref())
-                                            .is_none()
-                                    {
-                                        continue;
-                                    }
-                                    let Some(new_mtime) = file.mtime else {
-                                        continue;
-                                    };
-
-                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
-                                    let already_stored = stored_mtime == Some(new_mtime);
-
-                                    if !already_stored {
-                                        changed_paths.insert(
-                                            file.path.clone(),
-                                            ChangedPathInfo {
-                                                mtime: new_mtime,
-                                                is_deleted: false,
-                                            },
-                                        );
-                                    }
-                                }
-                            }
-
-                            // Clean up entries from database that are no longer in the worktree.
-                            for (path, mtime) in file_mtimes {
-                                changed_paths.insert(
-                                    path.into(),
-                                    ChangedPathInfo {
-                                        mtime,
-                                        is_deleted: true,
-                                    },
-                                );
-                            }
-
-                            anyhow::Ok(changed_paths)
-                        })
-                        .await?;
-                    this.update(&mut cx, |this, cx| {
-                        let project_state = this
-                            .projects
-                            .get_mut(&project)
-                            .context("project not registered")?;
-                        let project = project.upgrade().context("project was dropped")?;
-
-                        if let Some(WorktreeState::Registering(state)) =
-                            project_state.worktrees.remove(&worktree_id)
-                        {
-                            changed_paths.extend(state.changed_paths);
-                        }
-                        project_state.worktrees.insert(
-                            worktree_id,
-                            WorktreeState::Registered(RegisteredWorktreeState {
-                                db_id,
-                                changed_paths,
-                            }),
-                        );
-                        this.index_project(project, cx).detach_and_log_err(cx);
-
-                        anyhow::Ok(())
-                    })??;
-
-                    anyhow::Ok(())
-                };
-
-                if register.await.log_err().is_none() {
-                    // Stop tracking this worktree if the registration failed.
-                    this.update(&mut cx, |this, _| {
-                        if let Some(project_state) = this.projects.get_mut(&project) {
-                            project_state.worktrees.remove(&worktree_id);
-                        }
-                    })
-                    .ok();
-                }
-
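-                // Signal `done_rx` watchers that registration has finished, whether or not it succeeded.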
-                *done_tx.borrow_mut() = Some(());
-            }
-        });
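-        // Track the in-flight registration so indexing and searches can wait
-        // on `done_rx` until this worktree is fully registered.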
-        project_state.worktrees.insert(
-            worktree_id,
-            WorktreeState::Registering(RegisteringWorktreeState {
-                changed_paths: Default::default(),
-                done_rx,
-                _registration: registration,
-            }),
-        );
-    }
-
-    fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
-        let Some(project_state) = self.projects.get_mut(&project.downgrade()) else {
-            return;
-        };
-
-        let mut worktrees = project
-            .read(cx)
-            .worktrees()
-            .filter(|worktree| worktree.read(cx).is_local())
-            .collect::<Vec<_>>();
-        let worktree_ids = worktrees
-            .iter()
-            .map(|worktree| worktree.read(cx).id())
-            .collect::<HashSet<_>>();
-
-        // Remove worktrees that are no longer present
-        project_state
-            .worktrees
-            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
-
-        // Register new worktrees
-        worktrees.retain(|worktree| {
-            let worktree_id = worktree.read(cx).id();
-            !project_state.worktrees.contains_key(&worktree_id)
-        });
-        for worktree in worktrees {
-            self.register_worktree(project.clone(), worktree, cx);
-        }
-    }
-
-    pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
-        Some(
-            self.projects
-                .get(&project.downgrade())?
-                .pending_file_count_rx
-                .clone(),
-        )
-    }
-
-    pub fn search_project(
-        &mut self,
-        project: Model<Project>,
-        query: String,
-        limit: usize,
-        includes: Vec<PathMatcher>,
-        excludes: Vec<PathMatcher>,
-        cx: &mut ModelContext<Self>,
-    ) -> Task<Result<Vec<SearchResult>>> {
-        if query.is_empty() {
-            return Task::ready(Ok(Vec::new()));
-        }
-
-        let index = self.index_project(project.clone(), cx);
-        let embedding_provider = self.embedding_provider.clone();
-
-        cx.spawn(|this, mut cx| async move {
-            index.await?;
-            let t0 = Instant::now();
-
-            let query = embedding_provider
-                .embed_batch(vec![query])
-                .await?
-                .pop()
-                .context("could not embed query")?;
-            log::trace!("embedded search query in {}ms", t0.elapsed().as_millis());
-
-            let search_start = Instant::now();
-            let modified_buffer_results = this.update(&mut cx, |this, cx| {
-                this.search_modified_buffers(
-                    &project,
-                    query.clone(),
-                    limit,
-                    &includes,
-                    &excludes,
-                    cx,
-                )
-            })?;
-            let file_results = this.update(&mut cx, |this, cx| {
-                this.search_files(project, query, limit, includes, excludes, cx)
-            })?;
-            let (modified_buffer_results, file_results) =
-                futures::join!(modified_buffer_results, file_results);
-
-            // Weave together the results from modified buffers and files.
-            let mut results = Vec::new();
-            let mut modified_buffers = HashSet::default();
-            for result in modified_buffer_results.log_err().unwrap_or_default() {
-                modified_buffers.insert(result.buffer.clone());
-                results.push(result);
-            }
-            for result in file_results.log_err().unwrap_or_default() {
-                if !modified_buffers.contains(&result.buffer) {
-                    results.push(result);
-                }
-            }
-            results.sort_by_key(|result| Reverse(result.similarity));
-            results.truncate(limit);
-            log::trace!("Semantic search took {:?}", search_start.elapsed());
-            Ok(results)
-        })
-    }
-
-    pub fn search_files(
-        &mut self,
-        project: Model<Project>,
-        query: Embedding,
-        limit: usize,
-        includes: Vec<PathMatcher>,
-        excludes: Vec<PathMatcher>,
-        cx: &mut ModelContext<Self>,
-    ) -> Task<Result<Vec<SearchResult>>> {
-        let db_path = self.db.path().clone();
-        let fs = self.fs.clone();
-        cx.spawn(|this, mut cx| async move {
-            let database = VectorDatabase::new(
-                fs.clone(),
-                db_path.clone(),
-                cx.background_executor().clone(),
-            )
-            .await?;
-
-            let worktree_db_ids = this.read_with(&cx, |this, _| {
-                let project_state = this
-                    .projects
-                    .get(&project.downgrade())
-                    .context("project was not indexed")?;
-                let worktree_db_ids = project_state
-                    .worktrees
-                    .values()
-                    .filter_map(|worktree| {
-                        if let WorktreeState::Registered(worktree) = worktree {
-                            Some(worktree.db_id)
-                        } else {
-                            None
-                        }
-                    })
-                    .collect::<Vec<i64>>();
-                anyhow::Ok(worktree_db_ids)
-            })??;
-
-            let file_ids = database
-                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
-                .await?;
-
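-            // Fan the top-k search out over chunks of file ids, roughly one
-            // chunk per CPU, with a floor on the chunk size so that small
-            // projects don't spawn many tiny queries.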
-            let batch_n = cx.background_executor().num_cpus();
-            let ids_len = file_ids.len();
-            let minimum_batch_size = 50;
-
-            let batch_size = (ids_len / batch_n).max(minimum_batch_size);
-
-            let mut batch_results = Vec::new();
-            for batch in file_ids.chunks(batch_size) {
-                let batch = batch.to_vec();
-                let fs = fs.clone();
-                let db_path = db_path.clone();
-                let query = query.clone();
-                if let Some(db) =
-                    VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
-                        .await
-                        .log_err()
-                {
-                    batch_results.push(async move {
-                        db.top_k_search(&query, limit, batch.as_slice()).await
-                    });
-                }
-            }
-
-            let batch_results = futures::future::join_all(batch_results).await;
-
-            let mut results = Vec::new();
-            for batch_result in batch_results {
-                if let Ok(batch_result) = batch_result {
-                    for (id, similarity) in batch_result {
-                        // Insert in descending-similarity order, keeping only
-                        // the top `limit` results.
-                        let ix = results
-                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
-                            .unwrap_or_else(|ix| ix);
-                        results.insert(ix, (id, similarity));
-                        results.truncate(limit);
-                    }
-                }
-            }
-
-            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
-            let scores = results
-                .into_iter()
-                .map(|(_, score)| score)
-                .collect::<Vec<_>>();
-            let spans = database.spans_for_ids(ids.as_slice()).await?;
-
-            let mut tasks = Vec::new();
-            let mut ranges = Vec::new();
-            let weak_project = project.downgrade();
-            project.update(&mut cx, |project, cx| {
-                let this = this.upgrade().context("index was dropped")?;
-                for (worktree_db_id, file_path, byte_range) in spans {
-                    let project_state = this
-                        .read(cx)
-                        .projects
-                        .get(&weak_project)
-                        .context("project not added")?;
-                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
-                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
-                        ranges.push(byte_range);
-                    }
-                }
-
-                Ok(())
-            })??;
-
-            let buffers = futures::future::join_all(tasks).await;
-            Ok(buffers
-                .into_iter()
-                .zip(ranges)
-                .zip(scores)
-                .filter_map(|((buffer, range), similarity)| {
-                    let buffer = buffer.log_err()?;
-                    let range = buffer
-                        .read_with(&cx, |buffer, _| {
-                            let start = buffer.clip_offset(range.start, Bias::Left);
-                            let end = buffer.clip_offset(range.end, Bias::Right);
-                            buffer.anchor_before(start)..buffer.anchor_after(end)
-                        })
-                        .log_err()?;
-                    Some(SearchResult {
-                        buffer,
-                        range,
-                        similarity,
-                    })
-                })
-                .collect())
-        })
-    }
-
-    fn search_modified_buffers(
-        &self,
-        project: &Model<Project>,
-        query: Embedding,
-        limit: usize,
-        includes: &[PathMatcher],
-        excludes: &[PathMatcher],
-        cx: &mut ModelContext<Self>,
-    ) -> Task<Result<Vec<SearchResult>>> {
-        let modified_buffers = project
-            .read(cx)
-            .opened_buffers()
-            .into_iter()
-            .filter_map(|buffer_handle| {
-                let buffer = buffer_handle.read(cx);
-                let snapshot = buffer.snapshot();
-                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
-                    excludes.iter().any(|matcher| matcher.is_match(&path))
-                });
-
-                let included = if includes.is_empty() {
-                    true
-                } else {
-                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
-                        includes.iter().any(|matcher| matcher.is_match(&path))
-                    })
-                };
-
-                if buffer.is_dirty() && !excluded && included {
-                    Some((buffer_handle, snapshot))
-                } else {
-                    None
-                }
-            })
-            .collect::<HashMap<_, _>>();
-
-        let embedding_provider = self.embedding_provider.clone();
-        let fs = self.fs.clone();
-        let db_path = self.db.path().clone();
-        let background = cx.background_executor().clone();
-        cx.background_executor().spawn(async move {
-            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
-            let mut results = Vec::<SearchResult>::new();
-
-            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
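-            // Dirty buffers aren't covered by the on-disk index, so parse and
-            // embed their current contents on the fly before scoring them.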
-            for (buffer, snapshot) in modified_buffers {
-                let language = snapshot
-                    .language_at(0)
-                    .cloned()
-                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
-                let mut spans = retriever
-                    .parse_file_with_template(None, &snapshot.text(), language)
-                    .log_err()
-                    .unwrap_or_default();
-                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
-                    .await
-                    .log_err()
-                    .is_some()
-                {
-                    for span in spans {
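-                        // `embed_spans` populated an embedding for every span above.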
-                        let similarity = span.embedding.unwrap().similarity(&query);
-                        let ix = results
-                            .binary_search_by_key(&Reverse(similarity), |result| {
-                                Reverse(result.similarity)
-                            })
-                            .unwrap_or_else(|ix| ix);
-
-                        let range = {
-                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
-                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
-                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
-                        };
-
-                        results.insert(
-                            ix,
-                            SearchResult {
-                                buffer: buffer.clone(),
-                                range,
-                                similarity,
-                            },
-                        );
-                        results.truncate(limit);
-                    }
-                }
-            }
-
-            Ok(results)
-        })
-    }
-
-    pub fn index_project(
-        &mut self,
-        project: Model<Project>,
-        cx: &mut ModelContext<Self>,
-    ) -> Task<Result<()>> {
-        if self.is_authenticated() {
-            self.index_project_internal(project, cx)
-        } else {
-            let authenticate = self.authenticate(cx);
-            cx.spawn(|this, mut cx| async move {
-                if authenticate.await {
-                    this.update(&mut cx, |this, cx| this.index_project_internal(project, cx))?
-                        .await
-                } else {
-                    Err(anyhow!("user is not authenticated"))
-                }
-            })
-        }
-    }
-
-    fn index_project_internal(
-        &mut self,
-        project: Model<Project>,
-        cx: &mut ModelContext<Self>,
-    ) -> Task<Result<()>> {
-        if !self.projects.contains_key(&project.downgrade()) {
-            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
-                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
-                    this.project_worktrees_changed(project.clone(), cx);
-                }
-                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
-                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
-                }
-                _ => {}
-            });
-            let project_state = ProjectState::new(subscription, cx);
-            self.projects.insert(project.downgrade(), project_state);
-            self.project_worktrees_changed(project.clone(), cx);
-        }
-        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
-        project_state.pending_index += 1;
-        cx.notify();
-
-        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
-        let db = self.db.clone();
-        let language_registry = self.language_registry.clone();
-        let parsing_files_tx = self.parsing_files_tx.clone();
-        let worktree_registration = self.wait_for_worktree_registration(&project, cx);
-
-        cx.spawn(|this, mut cx| async move {
-            worktree_registration.await?;
-
-            let mut pending_files = Vec::new();
-            let mut files_to_delete = Vec::new();
-            this.update(&mut cx, |this, cx| {
-                let project_state = this
-                    .projects
-                    .get_mut(&project.downgrade())
-                    .context("project was dropped")?;
-                let pending_file_count_tx = &project_state.pending_file_count_tx;
-
-                project_state
-                    .worktrees
-                    .retain(|worktree_id, worktree_state| {
-                        let Some(worktree) = project.read(cx).worktree_for_id(*worktree_id, cx)
-                        else {
-                            return false;
-                        };
-                        let WorktreeState::Registered(worktree_state) = worktree_state else {
-                            return true;
-                        };
-
-                        for (path, info) in &worktree_state.changed_paths {
-                            if info.is_deleted {
-                                files_to_delete.push((worktree_state.db_id, path.clone()));
-                            } else if let Ok(absolute_path) = worktree.read(cx).absolutize(path) {
-                                let job_handle = JobHandle::new(pending_file_count_tx);
-                                pending_files.push(PendingFile {
-                                    absolute_path,
-                                    relative_path: path.clone(),
-                                    language: None,
-                                    job_handle,
-                                    modified_time: info.mtime,
-                                    worktree_db_id: worktree_state.db_id,
-                                });
-                            }
-                        }
-                        worktree_state.changed_paths.clear();
-                        true
-                    });
-
-                anyhow::Ok(())
-            })??;
-
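-            // On the background thread: apply deletions, fetch any previously
-            // stored embeddings, and queue the remaining files for parsing.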
-            cx.background_executor()
-                .spawn(async move {
-                    for (worktree_db_id, path) in files_to_delete {
-                        db.delete_file(worktree_db_id, path).await.log_err();
-                    }
-
-                    let embeddings_for_digest = {
-                        let mut files = HashMap::default();
-                        for pending_file in &pending_files {
-                            files
-                                .entry(pending_file.worktree_db_id)
-                                .or_insert(Vec::new())
-                                .push(pending_file.relative_path.clone());
-                        }
-                        Arc::new(
-                            db.embeddings_for_files(files)
-                                .await
-                                .log_err()
-                                .unwrap_or_default(),
-                        )
-                    };
-
-                    for mut pending_file in pending_files {
-                        if let Ok(language) = language_registry
-                            .language_for_file_path(&pending_file.relative_path)
-                            .await
-                        {
-                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
-                                && language.name().as_ref() != "Markdown"
-                                && language
-                                    .grammar()
-                                    .and_then(|grammar| grammar.embedding_config.as_ref())
-                                    .is_none()
-                            {
-                                continue;
-                            }
-                            pending_file.language = Some(language);
-                        }
-                        parsing_files_tx
-                            .try_send((embeddings_for_digest.clone(), pending_file))
-                            .ok();
-                    }
-
-                    // Wait until we're done indexing.
-                    while let Some(count) = pending_file_count_rx.next().await {
-                        if count == 0 {
-                            break;
-                        }
-                    }
-                })
-                .await;
-
-            this.update(&mut cx, |this, cx| {
-                let project_state = this
-                    .projects
-                    .get_mut(&project.downgrade())
-                    .context("project was dropped")?;
-                project_state.pending_index -= 1;
-                cx.notify();
-                anyhow::Ok(())
-            })??;
-
-            Ok(())
-        })
-    }
-
-    fn wait_for_worktree_registration(
-        &self,
-        project: &Model<Project>,
-        cx: &mut ModelContext<Self>,
-    ) -> Task<Result<()>> {
-        let project = project.downgrade();
-        cx.spawn(|this, cx| async move {
-            loop {
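-                // Snapshot the worktrees that are still registering and wait
-                // for them, looping in case new registrations appear meanwhile.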
-                let mut pending_worktrees = Vec::new();
-                this.upgrade()
-                    .context("semantic index dropped")?
-                    .read_with(&cx, |this, _| {
-                        if let Some(project) = this.projects.get(&project) {
-                            for worktree in project.worktrees.values() {
-                                if let WorktreeState::Registering(worktree) = worktree {
-                                    pending_worktrees.push(worktree.done());
-                                }
-                            }
-                        }
-                    })?;
-
-                if pending_worktrees.is_empty() {
-                    break;
-                } else {
-                    future::join_all(pending_worktrees).await;
-                }
-            }
-            Ok(())
-        })
-    }
-
-    async fn embed_spans(
-        spans: &mut [Span],
-        embedding_provider: &dyn EmbeddingProvider,
-        db: &VectorDatabase,
-    ) -> Result<()> {
-        let mut batch = Vec::new();
-        let mut batch_tokens = 0;
-        let mut embeddings = Vec::new();
-
-        let digests = spans
-            .iter()
-            .map(|span| span.digest.clone())
-            .collect::<Vec<_>>();
-        let embeddings_for_digests = db
-            .embeddings_for_digests(digests)
-            .await
-            .log_err()
-            .unwrap_or_default();
-
-        for span in &*spans {
-            if embeddings_for_digests.contains_key(&span.digest) {
-                continue;
-            }
-
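-            // Flush the current batch before it would exceed the provider's
-            // per-batch token budget.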
-            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
-                let batch_embeddings = embedding_provider
-                    .embed_batch(mem::take(&mut batch))
-                    .await?;
-                embeddings.extend(batch_embeddings);
-                batch_tokens = 0;
-            }
-
-            batch_tokens += span.token_count;
-            batch.push(span.content.clone());
-        }
-
-        if !batch.is_empty() {
-            let batch_embeddings = embedding_provider
-                .embed_batch(mem::take(&mut batch))
-                .await?;
-
-            embeddings.extend(batch_embeddings);
-        }
-
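-        // Stitch the results back onto the spans: reuse a stored embedding
-        // when the span's digest was already in the database, otherwise take
-        // the next freshly computed embedding in order.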
-        let mut embeddings = embeddings.into_iter();
-        for span in spans {
-            let embedding = embeddings_for_digests
-                .get(&span.digest)
-                .cloned()
-                .or_else(|| embeddings.next())
-                .context("failed to embed spans")?;
-            span.embedding = Some(embedding);
-        }
-        Ok(())
-    }
-}
-
-impl Drop for JobHandle {
-    fn drop(&mut self) {
-        if let Some(inner) = Arc::get_mut(&mut self.tx) {
-            // `Arc::get_mut` succeeds only for the last remaining instance of this `JobHandle`, whether original or cloned.
-            if let Some(tx) = inner.upgrade() {
-                let mut tx = tx.lock();
-                *tx.borrow_mut() -= 1;
-            }
-        }
-    }
-}
-
-#[cfg(test)]
-mod tests {
-
-    use super::*;
-    #[test]
-    fn test_job_handle() {
-        let (job_count_tx, job_count_rx) = watch::channel_with(0);
-        let tx = Arc::new(Mutex::new(job_count_tx));
-        let job_handle = JobHandle::new(&tx);
-
-        assert_eq!(1, *job_count_rx.borrow());
-        let new_job_handle = job_handle.clone();
-        assert_eq!(1, *job_count_rx.borrow());
-        drop(job_handle);
-        assert_eq!(1, *job_count_rx.borrow());
-        drop(new_job_handle);
-        assert_eq!(0, *job_count_rx.borrow());
-    }
-}

crates/semantic_index/src/semantic_index_settings.rs 🔗

@@ -1,33 +0,0 @@
-use anyhow;
-use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
-use settings::Settings;
-
-#[derive(Deserialize, Debug)]
-pub struct SemanticIndexSettings {
-    pub enabled: bool,
-}
-
-/// Configuration of the semantic index, an alternate search engine
-/// available in project search.
-#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
-pub struct SemanticIndexSettingsContent {
-    /// Whether or not to display the Semantic mode in project search.
-    ///
-    /// Default: true
-    pub enabled: Option<bool>,
-}
-
-impl Settings for SemanticIndexSettings {
-    const KEY: Option<&'static str> = Some("semantic_index");
-
-    type FileContent = SemanticIndexSettingsContent;
-
-    fn load(
-        default_value: &Self::FileContent,
-        user_values: &[&Self::FileContent],
-        _: &mut gpui::AppContext,
-    ) -> anyhow::Result<Self> {
-        Self::load_via_json_merge(default_value, user_values)
-    }
-}

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -1,1725 +0,0 @@
-use crate::{
-    embedding_queue::EmbeddingQueue,
-    parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
-    semantic_index_settings::SemanticIndexSettings,
-    FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
-};
-use ai::test::FakeEmbeddingProvider;
-use gpui::TestAppContext;
-use language::{Language, LanguageConfig, LanguageMatcher, LanguageRegistry, ToOffset};
-use parking_lot::Mutex;
-use pretty_assertions::assert_eq;
-use project::{FakeFs, Fs, Project};
-use rand::{rngs::StdRng, Rng};
-use serde_json::json;
-use settings::{Settings, SettingsStore};
-use std::{path::Path, sync::Arc, time::SystemTime};
-use unindent::Unindent;
-use util::{paths::PathMatcher, RandomCharIter};
-
-#[ctor::ctor]
-fn init_logger() {
-    if std::env::var("RUST_LOG").is_ok() {
-        env_logger::init();
-    }
-}
-
-#[gpui::test]
-async fn test_semantic_index(cx: &mut TestAppContext) {
-    init_test(cx);
-
-    let fs = FakeFs::new(cx.background_executor.clone());
-    fs.insert_tree(
-        "/the-root",
-        json!({
-            "src": {
-                "file1.rs": "
-                    fn aaa() {
-                        println!(\"aaaaaaaaaaaa!\");
-                    }
-
-                    fn zzzzz() {
-                        println!(\"SLEEPING\");
-                    }
-                ".unindent(),
-                "file2.rs": "
-                    fn bbb() {
-                        println!(\"bbbbbbbbbbbbb!\");
-                    }
-                    struct pqpqpqp {}
-                ".unindent(),
-                "file3.toml": "
-                    ZZZZZZZZZZZZZZZZZZ = 5
-                ".unindent(),
-            }
-        }),
-    )
-    .await;
-
-    let languages = Arc::new(LanguageRegistry::test(cx.executor().clone()));
-    let rust_language = rust_lang();
-    let toml_language = toml_lang();
-    languages.add(rust_language);
-    languages.add(toml_language);
-
-    let db_dir = tempfile::Builder::new()
-        .prefix("vector-store")
-        .tempdir()
-        .unwrap();
-    let db_path = db_dir.path().join("db.sqlite");
-
-    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-    let semantic_index = SemanticIndex::new(
-        fs.clone(),
-        db_path,
-        embedding_provider.clone(),
-        languages,
-        cx.to_async(),
-    )
-    .await
-    .unwrap();
-
-    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
-
-    let search_results = semantic_index.update(cx, |store, cx| {
-        store.search_project(
-            project.clone(),
-            "aaaaaabbbbzz".to_string(),
-            5,
-            vec![],
-            vec![],
-            cx,
-        )
-    });
-    let pending_file_count =
-        semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
-    cx.background_executor.run_until_parked();
-    assert_eq!(*pending_file_count.borrow(), 3);
-    cx.background_executor
-        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
-    assert_eq!(*pending_file_count.borrow(), 0);
-
-    let search_results = search_results.await.unwrap();
-    assert_search_results(
-        &search_results,
-        &[
-            (Path::new("src/file1.rs").into(), 0),
-            (Path::new("src/file2.rs").into(), 0),
-            (Path::new("src/file3.toml").into(), 0),
-            (Path::new("src/file1.rs").into(), 45),
-            (Path::new("src/file2.rs").into(), 45),
-        ],
-        cx,
-    );
-
-    // Test include and exclude filtering in project search.
-    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
-    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
-    let rust_only_search_results = semantic_index
-        .update(cx, |store, cx| {
-            store.search_project(
-                project.clone(),
-                "aaaaaabbbbzz".to_string(),
-                5,
-                include_files,
-                vec![],
-                cx,
-            )
-        })
-        .await
-        .unwrap();
-
-    assert_search_results(
-        &rust_only_search_results,
-        &[
-            (Path::new("src/file1.rs").into(), 0),
-            (Path::new("src/file2.rs").into(), 0),
-            (Path::new("src/file1.rs").into(), 45),
-            (Path::new("src/file2.rs").into(), 45),
-        ],
-        cx,
-    );
-
-    let no_rust_search_results = semantic_index
-        .update(cx, |store, cx| {
-            store.search_project(
-                project.clone(),
-                "aaaaaabbbbzz".to_string(),
-                5,
-                vec![],
-                exclude_files,
-                cx,
-            )
-        })
-        .await
-        .unwrap();
-
-    assert_search_results(
-        &no_rust_search_results,
-        &[(Path::new("src/file3.toml").into(), 0)],
-        cx,
-    );
-
-    fs.save(
-        "/the-root/src/file2.rs".as_ref(),
-        &"
-            fn dddd() { println!(\"ddddd!\"); }
-            struct pqpqpqp {}
-        "
-        .unindent()
-        .into(),
-        Default::default(),
-    )
-    .await
-    .unwrap();
-
-    cx.background_executor
-        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
-
-    let prev_embedding_count = embedding_provider.embedding_count();
-    let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
-    cx.background_executor.run_until_parked();
-    assert_eq!(*pending_file_count.borrow(), 1);
-    cx.background_executor
-        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
-    assert_eq!(*pending_file_count.borrow(), 0);
-    index.await.unwrap();
-
-    assert_eq!(
-        embedding_provider.embedding_count() - prev_embedding_count,
-        1
-    );
-}
-
-#[gpui::test(iterations = 10)]
-async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
-    let (outstanding_job_count, _) = postage::watch::channel_with(0);
-    let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
-
-    let files = (1..=3)
-        .map(|file_ix| FileToEmbed {
-            worktree_id: 5,
-            path: Path::new(&format!("path-{file_ix}")).into(),
-            mtime: SystemTime::now(),
-            spans: (0..rng.gen_range(4..22))
-                .map(|document_ix| {
-                    let content_len = rng.gen_range(10..100);
-                    let content = RandomCharIter::new(&mut rng)
-                        .with_simple_text()
-                        .take(content_len)
-                        .collect::<String>();
-                    let digest = SpanDigest::from(content.as_str());
-                    Span {
-                        range: 0..10,
-                        embedding: None,
-                        name: format!("document {document_ix}"),
-                        content,
-                        digest,
-                        token_count: rng.gen_range(10..30),
-                    }
-                })
-                .collect(),
-            job_handle: JobHandle::new(&outstanding_job_count),
-        })
-        .collect::<Vec<_>>();
-
-    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-
-    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor.clone());
-    for file in &files {
-        queue.push(file.clone());
-    }
-    queue.flush();
-
-    cx.background_executor.run_until_parked();
-    let finished_files = queue.finished_files();
-    let mut embedded_files: Vec<_> = files
-        .iter()
-        .map(|_| finished_files.try_recv().expect("no finished file"))
-        .collect();
-
-    let expected_files: Vec<_> = files
-        .iter()
-        .map(|file| {
-            let mut file = file.clone();
-            for doc in &mut file.spans {
-                doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
-            }
-            file
-        })
-        .collect();
-
-    embedded_files.sort_by_key(|f| f.path.clone());
-
-    assert_eq!(embedded_files, expected_files);
-}
-
-#[track_caller]
-fn assert_search_results(
-    actual: &[SearchResult],
-    expected: &[(Arc<Path>, usize)],
-    cx: &TestAppContext,
-) {
-    let actual = actual
-        .iter()
-        .map(|search_result| {
-            search_result.buffer.read_with(cx, |buffer, _cx| {
-                (
-                    buffer.file().unwrap().path().clone(),
-                    search_result.range.start.to_offset(buffer),
-                )
-            })
-        })
-        .collect::<Vec<_>>();
-    assert_eq!(actual, expected);
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_rust() {
-    let language = rust_lang();
-    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-    let mut retriever = CodeContextRetriever::new(embedding_provider);
-
-    let text = "
-        /// A doc comment
-        /// that spans multiple lines
-        #[gpui::test]
-        fn a() {
-            b
-        }
-
-        impl C for D {
-        }
-
-        impl E {
-            // This is also a preceding comment
-            pub fn function_1() -> Option<()> {
-                unimplemented!();
-            }
-
-            // This is a preceding comment
-            fn function_2() -> Result<()> {
-                unimplemented!();
-            }
-        }
-
-        #[derive(Clone)]
-        struct D {
-            name: String
-        }
-    "
-    .unindent();
-
-    let documents = retriever.parse_file(&text, language).unwrap();
-
-    assert_documents_eq(
-        &documents,
-        &[
-            (
-                "
-                /// A doc comment
-                /// that spans multiple lines
-                #[gpui::test]
-                fn a() {
-                    b
-                }"
-                .unindent(),
-                text.find("fn a").unwrap(),
-            ),
-            (
-                "
-                impl C for D {
-                }"
-                .unindent(),
-                text.find("impl C").unwrap(),
-            ),
-            (
-                "
-                impl E {
-                    // This is also a preceding comment
-                    pub fn function_1() -> Option<()> { /* ... */ }
-
-                    // This is a preceding comment
-                    fn function_2() -> Result<()> { /* ... */ }
-                }"
-                .unindent(),
-                text.find("impl E").unwrap(),
-            ),
-            (
-                "
-                // This is also a preceding comment
-                pub fn function_1() -> Option<()> {
-                    unimplemented!();
-                }"
-                .unindent(),
-                text.find("pub fn function_1").unwrap(),
-            ),
-            (
-                "
-                // This is a preceding comment
-                fn function_2() -> Result<()> {
-                    unimplemented!();
-                }"
-                .unindent(),
-                text.find("fn function_2").unwrap(),
-            ),
-            (
-                "
-                #[derive(Clone)]
-                struct D {
-                    name: String
-                }"
-                .unindent(),
-                text.find("struct D").unwrap(),
-            ),
-        ],
-    );
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_json() {
-    let language = json_lang();
-    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-    let mut retriever = CodeContextRetriever::new(embedding_provider);
-
-    let text = r#"
-        {
-            "array": [1, 2, 3, 4],
-            "string": "abcdefg",
-            "nested_object": {
-                "array_2": [5, 6, 7, 8],
-                "string_2": "hijklmnop",
-                "boolean": true,
-                "none": null
-            }
-        }
-    "#
-    .unindent();
-
-    let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
-    assert_documents_eq(
-        &documents,
-        &[(
-            r#"
-                {
-                    "array": [],
-                    "string": "",
-                    "nested_object": {
-                        "array_2": [],
-                        "string_2": "",
-                        "boolean": true,
-                        "none": null
-                    }
-                }"#
-            .unindent(),
-            text.find('{').unwrap(),
-        )],
-    );
-
-    let text = r#"
-        [
-            {
-                "name": "somebody",
-                "age": 42
-            },
-            {
-                "name": "somebody else",
-                "age": 43
-            }
-        ]
-    "#
-    .unindent();
-
-    let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
-    assert_documents_eq(
-        &documents,
-        &[(
-            r#"
-            [{
-                    "name": "",
-                    "age": 42
-                }]"#
-            .unindent(),
-            text.find('[').unwrap(),
-        )],
-    );
-}
-
-fn assert_documents_eq(
-    documents: &[Span],
-    expected_contents_and_start_offsets: &[(String, usize)],
-) {
-    assert_eq!(
-        documents
-            .iter()
-            .map(|document| (document.content.clone(), document.range.start))
-            .collect::<Vec<_>>(),
-        expected_contents_and_start_offsets
-    );
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_javascript() {
-    let language = js_lang();
-    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-    let mut retriever = CodeContextRetriever::new(embedding_provider);
-
-    let text = "
-        /* globals importScripts, backend */
-        function _authorize() {}
-
-        /**
-         * Sometimes the frontend build is way faster than backend.
-         */
-        export async function authorizeBank() {
-            _authorize(pushModal, upgradingAccountId, {});
-        }
-
-        export class SettingsPage {
-            /* This is a test setting */
-            constructor(page) {
-                this.page = page;
-            }
-        }
-
-        /* This is a test comment */
-        class TestClass {}
-
-        /* Schema for editor_events in Clickhouse. */
-        export interface ClickhouseEditorEvent {
-            installation_id: string
-            operation: string
-        }
-        "
-    .unindent();
-
-    let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
-    assert_documents_eq(
-        &documents,
-        &[
-            (
-                "
-            /* globals importScripts, backend */
-            function _authorize() {}"
-                    .unindent(),
-                37,
-            ),
-            (
-                "
-            /**
-             * Sometimes the frontend build is way faster than backend.
-             */
-            export async function authorizeBank() {
-                _authorize(pushModal, upgradingAccountId, {});
-            }"
-                .unindent(),
-                131,
-            ),
-            (
-                "
-                export class SettingsPage {
-                    /* This is a test setting */
-                    constructor(page) {
-                        this.page = page;
-                    }
-                }"
-                .unindent(),
-                225,
-            ),
-            (
-                "
-                /* This is a test setting */
-                constructor(page) {
-                    this.page = page;
-                }"
-                .unindent(),
-                290,
-            ),
-            (
-                "
-                /* This is a test comment */
-                class TestClass {}"
-                    .unindent(),
-                374,
-            ),
-            (
-                "
-                /* Schema for editor_events in Clickhouse. */
-                export interface ClickhouseEditorEvent {
-                    installation_id: string
-                    operation: string
-                }"
-                .unindent(),
-                440,
-            ),
-        ],
-    )
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_lua() {
-    let language = lua_lang();
-    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-    let mut retriever = CodeContextRetriever::new(embedding_provider);
-
-    let text = r#"
-        -- Creates a new class
-        -- @param baseclass The Baseclass of this class, or nil.
-        -- @return A new class reference.
-        function classes.class(baseclass)
-            -- Create the class definition and metatable.
-            local classdef = {}
-            -- Find the super class, either Object or user-defined.
-            baseclass = baseclass or classes.Object
-            -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
-            setmetatable(classdef, { __index = baseclass })
-            -- All class instances have a reference to the class object.
-            classdef.class = classdef
-            --- Recursively allocates the inheritance tree of the instance.
-            -- @param mastertable The 'root' of the inheritance tree.
-            -- @return Returns the instance with the allocated inheritance tree.
-            function classdef.alloc(mastertable)
-                -- All class instances have a reference to a superclass object.
-                local instance = { super = baseclass.alloc(mastertable) }
-                -- Any functions this instance does not know of will 'look up' to the superclass definition.
-                setmetatable(instance, { __index = classdef, __newindex = mastertable })
-                return instance
-            end
-        end
-        "#.unindent();
-
-    let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
-    assert_documents_eq(
-        &documents,
-        &[
-            (r#"
-                -- Creates a new class
-                -- @param baseclass The Baseclass of this class, or nil.
-                -- @return A new class reference.
-                function classes.class(baseclass)
-                    -- Create the class definition and metatable.
-                    local classdef = {}
-                    -- Find the super class, either Object or user-defined.
-                    baseclass = baseclass or classes.Object
-                    -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
-                    setmetatable(classdef, { __index = baseclass })
-                    -- All class instances have a reference to the class object.
-                    classdef.class = classdef
-                    --- Recursively allocates the inheritance tree of the instance.
-                    -- @param mastertable The 'root' of the inheritance tree.
-                    -- @return Returns the instance with the allocated inheritance tree.
-                    function classdef.alloc(mastertable)
-                        --[ ... ]--
-                        --[ ... ]--
-                    end
-                end"#.unindent(),
-            114),
-            (r#"
-            --- Recursively allocates the inheritance tree of the instance.
-            -- @param mastertable The 'root' of the inheritance tree.
-            -- @return Returns the instance with the allocated inheritance tree.
-            function classdef.alloc(mastertable)
-                -- All class instances have a reference to a superclass object.
-                local instance = { super = baseclass.alloc(mastertable) }
-                -- Any functions this instance does not know of will 'look up' to the superclass definition.
-                setmetatable(instance, { __index = classdef, __newindex = mastertable })
-                return instance
-            end"#.unindent(), 810),
-        ]
-    );
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_elixir() {
-    let language = elixir_lang();
-    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-    let mut retriever = CodeContextRetriever::new(embedding_provider);
-
-    let text = r#"
-        defmodule File.Stream do
-            @moduledoc """
-            Defines a `File.Stream` struct returned by `File.stream!/3`.
-
-            The following fields are public:
-
-            * `path`          - the file path
-            * `modes`         - the file modes
-            * `raw`           - a boolean indicating if bin functions should be used
-            * `line_or_bytes` - if reading should read lines or a given number of bytes
-            * `node`          - the node the file belongs to
-
-            """
-
-            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
-
-            @type t :: %__MODULE__{}
-
-            @doc false
-            def __build__(path, modes, line_or_bytes) do
-            raw = :lists.keyfind(:encoding, 1, modes) == false
-
-            modes =
-                case raw do
-                true ->
-                    case :lists.keyfind(:read_ahead, 1, modes) do
-                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
-                    {:read_ahead, _} -> [:raw | modes]
-                    false -> [:raw, :read_ahead | modes]
-                    end
-
-                false ->
-                    modes
-                end
-
-            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
-
-            end"#
-    .unindent();
-
-    let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
-    assert_documents_eq(
-        &documents,
-        &[(
-            r#"
-        defmodule File.Stream do
-            @moduledoc """
-            Defines a `File.Stream` struct returned by `File.stream!/3`.
-
-            The following fields are public:
-
-            * `path`          - the file path
-            * `modes`         - the file modes
-            * `raw`           - a boolean indicating if bin functions should be used
-            * `line_or_bytes` - if reading should read lines or a given number of bytes
-            * `node`          - the node the file belongs to
-
-            """
-
-            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
-
-            @type t :: %__MODULE__{}
-
-            @doc false
-            def __build__(path, modes, line_or_bytes) do
-            raw = :lists.keyfind(:encoding, 1, modes) == false
-
-            modes =
-                case raw do
-                true ->
-                    case :lists.keyfind(:read_ahead, 1, modes) do
-                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
-                    {:read_ahead, _} -> [:raw | modes]
-                    false -> [:raw, :read_ahead | modes]
-                    end
-
-                false ->
-                    modes
-                end
-
-            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
-
-            end"#
-                .unindent(),
-            0,
-        ),(r#"
-            @doc false
-            def __build__(path, modes, line_or_bytes) do
-            raw = :lists.keyfind(:encoding, 1, modes) == false
-
-            modes =
-                case raw do
-                true ->
-                    case :lists.keyfind(:read_ahead, 1, modes) do
-                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
-                    {:read_ahead, _} -> [:raw | modes]
-                    false -> [:raw, :read_ahead | modes]
-                    end
-
-                false ->
-                    modes
-                end
-
-            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
-
-            end"#.unindent(), 574)],
-    );
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_cpp() {
-    let language = cpp_lang();
-    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-    let mut retriever = CodeContextRetriever::new(embedding_provider);
-
-    let text = "
-    /**
-     * @brief Main function
-     * @returns 0 on exit
-     */
-    int main() { return 0; }
-
-    /**
-    * This is a test comment
-    */
-    class MyClass {       // The class
-        public:           // Access specifier
-        int myNum;        // Attribute (int variable)
-        string myString;  // Attribute (string variable)
-    };
-
-    // This is a test comment
-    enum Color { red, green, blue };
-
-    /** This is a preceding block comment
-     * This is the second line
-     */
-    struct {           // Structure declaration
-        int myNum;       // Member (int variable)
-        string myString; // Member (string variable)
-    } myStructure;
-
-    /**
-     * @brief Matrix class.
-     */
-    template <typename T,
-              typename = typename std::enable_if<
-                std::is_integral<T>::value || std::is_floating_point<T>::value,
-                bool>::type>
-    class Matrix2 {
-        std::vector<std::vector<T>> _mat;
-
-        public:
-            /**
-            * @brief Constructor
-            * @tparam Integer ensuring integers are being evaluated and not other
-            * data types.
-            * @param size denoting the size of Matrix as size x size
-            */
-            template <typename Integer,
-                    typename = typename std::enable_if<std::is_integral<Integer>::value,
-                    Integer>::type>
-            explicit Matrix(const Integer size) {
-                for (size_t i = 0; i < size; ++i) {
-                    _mat.emplace_back(std::vector<T>(size, 0));
-                }
-            }
-    }"
-    .unindent();
-
-    let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
-    assert_documents_eq(
-        &documents,
-        &[
-            (
-                "
-        /**
-         * @brief Main function
-         * @returns 0 on exit
-         */
-        int main() { return 0; }"
-                    .unindent(),
-                54,
-            ),
-            (
-                "
-                /**
-                * This is a test comment
-                */
-                class MyClass {       // The class
-                    public:           // Access specifier
-                    int myNum;        // Attribute (int variable)
-                    string myString;  // Attribute (string variable)
-                }"
-                .unindent(),
-                112,
-            ),
-            (
-                "
-                // This is a test comment
-                enum Color { red, green, blue }"
-                    .unindent(),
-                322,
-            ),
-            (
-                "
-                /** This is a preceding block comment
-                 * This is the second line
-                 */
-                struct {           // Structure declaration
-                    int myNum;       // Member (int variable)
-                    string myString; // Member (string variable)
-                } myStructure;"
-                    .unindent(),
-                425,
-            ),
-            (
-                "
-                /**
-                 * @brief Matrix class.
-                 */
-                template <typename T,
-                          typename = typename std::enable_if<
-                            std::is_integral<T>::value || std::is_floating_point<T>::value,
-                            bool>::type>
-                class Matrix2 {
-                    std::vector<std::vector<T>> _mat;
-
-                    public:
-                        /**
-                        * @brief Constructor
-                        * @tparam Integer ensuring integers are being evaluated and not other
-                        * data types.
-                        * @param size denoting the size of Matrix as size x size
-                        */
-                        template <typename Integer,
-                                typename = typename std::enable_if<std::is_integral<Integer>::value,
-                                Integer>::type>
-                        explicit Matrix(const Integer size) {
-                            for (size_t i = 0; i < size; ++i) {
-                                _mat.emplace_back(std::vector<T>(size, 0));
-                            }
-                        }
-                }"
-                .unindent(),
-                612,
-            ),
-            (
-                "
-                explicit Matrix(const Integer size) {
-                    for (size_t i = 0; i < size; ++i) {
-                        _mat.emplace_back(std::vector<T>(size, 0));
-                    }
-                }"
-                .unindent(),
-                1226,
-            ),
-        ],
-    );
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_ruby() {
-    let language = ruby_lang();
-    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-    let mut retriever = CodeContextRetriever::new(embedding_provider);
-
-    let text = r#"
-        # This concern is inspired by "sudo mode" on GitHub. It
-        # is a way to re-authenticate a user before allowing them
-        # to see or perform an action.
-        #
-        # Add `before_action :require_challenge!` to actions you
-        # want to protect.
-        #
-        # The user will be shown a page to enter the challenge (which
-        # is either the password, or just the username when no
-        # password exists). Upon passing, there is a grace period
-        # during which no challenge will be asked from the user.
-        #
-        # Accessing challenge-protected resources during the grace
-        # period will refresh the grace period.
-        module ChallengableConcern
-            extend ActiveSupport::Concern
-
-            CHALLENGE_TIMEOUT = 1.hour.freeze
-
-            def require_challenge!
-                return if skip_challenge?
-
-                if challenge_passed_recently?
-                    session[:challenge_passed_at] = Time.now.utc
-                    return
-                end
-
-                @challenge = Form::Challenge.new(return_to: request.url)
-
-                if params.key?(:form_challenge)
-                    if challenge_passed?
-                        session[:challenge_passed_at] = Time.now.utc
-                    else
-                        flash.now[:alert] = I18n.t('challenge.invalid_password')
-                        render_challenge
-                    end
-                else
-                    render_challenge
-                end
-            end
-
-            def challenge_passed?
-                current_user.valid_password?(challenge_params[:current_password])
-            end
-        end
-
-        class Animal
-            include Comparable
-
-            attr_reader :legs
-
-            def initialize(name, legs)
-                @name, @legs = name, legs
-            end
-
-            def <=>(other)
-                legs <=> other.legs
-            end
-        end
-
-        # Singleton method for car object
-        def car.wheels
-            puts "There are four wheels"
-        end"#
-        .unindent();
-
-    let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
-    assert_documents_eq(
-        &documents,
-        &[
-            (
-                r#"
-        # This concern is inspired by "sudo mode" on GitHub. It
-        # is a way to re-authenticate a user before allowing them
-        # to see or perform an action.
-        #
-        # Add `before_action :require_challenge!` to actions you
-        # want to protect.
-        #
-        # The user will be shown a page to enter the challenge (which
-        # is either the password, or just the username when no
-        # password exists). Upon passing, there is a grace period
-        # during which no challenge will be asked from the user.
-        #
-        # Accessing challenge-protected resources during the grace
-        # period will refresh the grace period.
-        module ChallengableConcern
-            extend ActiveSupport::Concern
-
-            CHALLENGE_TIMEOUT = 1.hour.freeze
-
-            def require_challenge!
-                # ...
-            end
-
-            def challenge_passed?
-                # ...
-            end
-        end"#
-                    .unindent(),
-                558,
-            ),
-            (
-                r#"
-            def require_challenge!
-                return if skip_challenge?
-
-                if challenge_passed_recently?
-                    session[:challenge_passed_at] = Time.now.utc
-                    return
-                end
-
-                @challenge = Form::Challenge.new(return_to: request.url)
-
-                if params.key?(:form_challenge)
-                    if challenge_passed?
-                        session[:challenge_passed_at] = Time.now.utc
-                    else
-                        flash.now[:alert] = I18n.t('challenge.invalid_password')
-                        render_challenge
-                    end
-                else
-                    render_challenge
-                end
-            end"#
-                    .unindent(),
-                663,
-            ),
-            (
-                r#"
-                def challenge_passed?
-                    current_user.valid_password?(challenge_params[:current_password])
-                end"#
-                    .unindent(),
-                1254,
-            ),
-            (
-                r#"
-                class Animal
-                    include Comparable
-
-                    attr_reader :legs
-
-                    def initialize(name, legs)
-                        # ...
-                    end
-
-                    def <=>(other)
-                        # ...
-                    end
-                end"#
-                    .unindent(),
-                1363,
-            ),
-            (
-                r#"
-                def initialize(name, legs)
-                    @name, @legs = name, legs
-                end"#
-                    .unindent(),
-                1427,
-            ),
-            (
-                r#"
-                def <=>(other)
-                    legs <=> other.legs
-                end"#
-                    .unindent(),
-                1501,
-            ),
-            (
-                r#"
-                # Singleton method for car object
-                def car.wheels
-                    puts "There are four wheels"
-                end"#
-                    .unindent(),
-                1591,
-            ),
-        ],
-    );
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_php() {
-    let language = php_lang();
-    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-    let mut retriever = CodeContextRetriever::new(embedding_provider);
-
-    let text = r#"
-        <?php
-
-        namespace LevelUp\Experience\Concerns;
-
-        /*
-        This is a multiple-lines comment block
-        that spans over multiple
-        lines
-        */
-        function functionName() {
-            echo "Hello world!";
-        }
-
-        trait HasAchievements
-        {
-            /**
-            * @throws \Exception
-            */
-            public function grantAchievement(Achievement $achievement, $progress = null): void
-            {
-                if ($progress > 100) {
-                    throw new Exception(message: 'Progress cannot be greater than 100');
-                }
-
-                if ($this->achievements()->find($achievement->id)) {
-                    throw new Exception(message: 'User already has this Achievement');
-                }
-
-                $this->achievements()->attach($achievement, [
-                    'progress' => $progress ?? null,
-                ]);
-
-                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
-            }
-
-            public function achievements(): BelongsToMany
-            {
-                return $this->belongsToMany(related: Achievement::class)
-                ->withPivot(columns: 'progress')
-                ->where('is_secret', false)
-                ->using(AchievementUser::class);
-            }
-        }
-
-        interface Multiplier
-        {
-            public function qualifies(array $data): bool;
-
-            public function setMultiplier(): int;
-        }
-
-        enum AuditType: string
-        {
-            case Add = 'add';
-            case Remove = 'remove';
-            case Reset = 'reset';
-            case LevelUp = 'level_up';
-        }
-
-        ?>"#
-    .unindent();
-
-    let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
-    assert_documents_eq(
-        &documents,
-        &[
-            (
-                r#"
-        /*
-        This is a multiple-lines comment block
-        that spans over multiple
-        lines
-        */
-        function functionName() {
-            echo "Hello world!";
-        }"#
-                .unindent(),
-                123,
-            ),
-            (
-                r#"
-        trait HasAchievements
-        {
-            /**
-            * @throws \Exception
-            */
-            public function grantAchievement(Achievement $achievement, $progress = null): void
-            {/* ... */}
-
-            public function achievements(): BelongsToMany
-            {/* ... */}
-        }"#
-                .unindent(),
-                177,
-            ),
-            (r#"
-            /**
-            * @throws \Exception
-            */
-            public function grantAchievement(Achievement $achievement, $progress = null): void
-            {
-                if ($progress > 100) {
-                    throw new Exception(message: 'Progress cannot be greater than 100');
-                }
-
-                if ($this->achievements()->find($achievement->id)) {
-                    throw new Exception(message: 'User already has this Achievement');
-                }
-
-                $this->achievements()->attach($achievement, [
-                    'progress' => $progress ?? null,
-                ]);
-
-                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
-            }"#.unindent(), 245),
-            (r#"
-                public function achievements(): BelongsToMany
-                {
-                    return $this->belongsToMany(related: Achievement::class)
-                    ->withPivot(columns: 'progress')
-                    ->where('is_secret', false)
-                    ->using(AchievementUser::class);
-                }"#.unindent(), 902),
-            (r#"
-                interface Multiplier
-                {
-                    public function qualifies(array $data): bool;
-
-                    public function setMultiplier(): int;
-                }"#.unindent(),
-                1146),
-            (r#"
-                enum AuditType: string
-                {
-                    case Add = 'add';
-                    case Remove = 'remove';
-                    case Reset = 'reset';
-                    case LevelUp = 'level_up';
-                }"#.unindent(), 1265)
-        ],
-    );
-}
-
-fn js_lang() -> Arc<Language> {
-    Arc::new(
-        Language::new(
-            LanguageConfig {
-                name: "Javascript".into(),
-                matcher: LanguageMatcher {
-                    path_suffixes: vec!["js".into()],
-                    ..Default::default()
-                },
-                ..Default::default()
-            },
-            Some(tree_sitter_typescript::language_tsx()),
-        )
-        .with_embedding_query(
-            &r#"
-
-            (
-                (comment)* @context
-                .
-                [
-                (export_statement
-                    (function_declaration
-                        "async"? @name
-                        "function" @name
-                        name: (_) @name))
-                (function_declaration
-                    "async"? @name
-                    "function" @name
-                    name: (_) @name)
-                ] @item
-            )
-
-            (
-                (comment)* @context
-                .
-                [
-                (export_statement
-                    (class_declaration
-                        "class" @name
-                        name: (_) @name))
-                (class_declaration
-                    "class" @name
-                    name: (_) @name)
-                ] @item
-            )
-
-            (
-                (comment)* @context
-                .
-                [
-                (export_statement
-                    (interface_declaration
-                        "interface" @name
-                        name: (_) @name))
-                (interface_declaration
-                    "interface" @name
-                    name: (_) @name)
-                ] @item
-            )
-
-            (
-                (comment)* @context
-                .
-                [
-                (export_statement
-                    (enum_declaration
-                        "enum" @name
-                        name: (_) @name))
-                (enum_declaration
-                    "enum" @name
-                    name: (_) @name)
-                ] @item
-            )
-
-            (
-                (comment)* @context
-                .
-                (method_definition
-                    [
-                        "get"
-                        "set"
-                        "async"
-                        "*"
-                        "static"
-                    ]* @name
-                    name: (_) @name) @item
-            )
-
-                    "#
-            .unindent(),
-        )
-        .unwrap(),
-    )
-}
-
-fn rust_lang() -> Arc<Language> {
-    Arc::new(
-        Language::new(
-            LanguageConfig {
-                name: "Rust".into(),
-                matcher: LanguageMatcher {
-                    path_suffixes: vec!["rs".into()],
-                    ..Default::default()
-                },
-                collapsed_placeholder: " /* ... */ ".to_string(),
-                ..Default::default()
-            },
-            Some(tree_sitter_rust::language()),
-        )
-        .with_embedding_query(
-            r#"
-            (
-                [(line_comment) (attribute_item)]* @context
-                .
-                [
-                    (struct_item
-                        name: (_) @name)
-
-                    (enum_item
-                        name: (_) @name)
-
-                    (impl_item
-                        trait: (_)? @name
-                        "for"? @name
-                        type: (_) @name)
-
-                    (trait_item
-                        name: (_) @name)
-
-                    (function_item
-                        name: (_) @name
-                        body: (block
-                            "{" @keep
-                            "}" @keep) @collapse)
-
-                    (macro_definition
-                        name: (_) @name)
-                ] @item
-            )
-
-            (attribute_item) @collapse
-            (use_declaration) @collapse
-            "#,
-        )
-        .unwrap(),
-    )
-}
-
-fn json_lang() -> Arc<Language> {
-    Arc::new(
-        Language::new(
-            LanguageConfig {
-                name: "JSON".into(),
-                matcher: LanguageMatcher {
-                    path_suffixes: vec!["json".into()],
-                    ..Default::default()
-                },
-                ..Default::default()
-            },
-            Some(tree_sitter_json::language()),
-        )
-        .with_embedding_query(
-            r#"
-            (document) @item
-
-            (array
-                "[" @keep
-                .
-                (object)? @keep
-                "]" @keep) @collapse
-
-            (pair value: (string
-                "\"" @keep
-                "\"" @keep) @collapse)
-            "#,
-        )
-        .unwrap(),
-    )
-}
-
-fn toml_lang() -> Arc<Language> {
-    Arc::new(Language::new(
-        LanguageConfig {
-            name: "TOML".into(),
-            matcher: LanguageMatcher {
-                path_suffixes: vec!["toml".into()],
-                ..Default::default()
-            },
-            ..Default::default()
-        },
-        Some(tree_sitter_toml::language()),
-    ))
-}
-
-fn cpp_lang() -> Arc<Language> {
-    Arc::new(
-        Language::new(
-            LanguageConfig {
-                name: "CPP".into(),
-                matcher: LanguageMatcher {
-                    path_suffixes: vec!["cpp".into()],
-                    ..Default::default()
-                },
-                ..Default::default()
-            },
-            Some(tree_sitter_cpp::language()),
-        )
-        .with_embedding_query(
-            r#"
-            (
-                (comment)* @context
-                .
-                (function_definition
-                    (type_qualifier)? @name
-                    type: (_)? @name
-                    declarator: [
-                        (function_declarator
-                            declarator: (_) @name)
-                        (pointer_declarator
-                            "*" @name
-                            declarator: (function_declarator
-                            declarator: (_) @name))
-                        (pointer_declarator
-                            "*" @name
-                            declarator: (pointer_declarator
-                                "*" @name
-                            declarator: (function_declarator
-                                declarator: (_) @name)))
-                        (reference_declarator
-                            ["&" "&&"] @name
-                            (function_declarator
-                            declarator: (_) @name))
-                    ]
-                    (type_qualifier)? @name) @item
-                )
-
-            (
-                (comment)* @context
-                .
-                (template_declaration
-                    (class_specifier
-                        "class" @name
-                        name: (_) @name)
-                        ) @item
-            )
-
-            (
-                (comment)* @context
-                .
-                (class_specifier
-                    "class" @name
-                    name: (_) @name) @item
-                )
-
-            (
-                (comment)* @context
-                .
-                (enum_specifier
-                    "enum" @name
-                    name: (_) @name) @item
-                )
-
-            (
-                (comment)* @context
-                .
-                (declaration
-                    type: (struct_specifier
-                    "struct" @name)
-                    declarator: (_) @name) @item
-            )
-
-            "#,
-        )
-        .unwrap(),
-    )
-}
-
-fn lua_lang() -> Arc<Language> {
-    Arc::new(
-        Language::new(
-            LanguageConfig {
-                name: "Lua".into(),
-                matcher: LanguageMatcher {
-                    path_suffixes: vec!["lua".into()],
-                    ..Default::default()
-                },
-                collapsed_placeholder: "--[ ... ]--".to_string(),
-                ..Default::default()
-            },
-            Some(tree_sitter_lua::language()),
-        )
-        .with_embedding_query(
-            r#"
-            (
-                (comment)* @context
-                .
-                (function_declaration
-                    "function" @name
-                    name: (_) @name
-                    (comment)* @collapse
-                    body: (block) @collapse
-                ) @item
-            )
-        "#,
-        )
-        .unwrap(),
-    )
-}
-
-fn php_lang() -> Arc<Language> {
-    Arc::new(
-        Language::new(
-            LanguageConfig {
-                name: "PHP".into(),
-                matcher: LanguageMatcher {
-                    path_suffixes: vec!["php".into()],
-                    ..Default::default()
-                },
-                collapsed_placeholder: "/* ... */".into(),
-                ..Default::default()
-            },
-            Some(tree_sitter_php::language_php()),
-        )
-        .with_embedding_query(
-            r#"
-            (
-                (comment)* @context
-                .
-                [
-                    (function_definition
-                        "function" @name
-                        name: (_) @name
-                        body: (_
-                            "{" @keep
-                            "}" @keep) @collapse
-                        )
-
-                    (trait_declaration
-                        "trait" @name
-                        name: (_) @name)
-
-                    (method_declaration
-                        "function" @name
-                        name: (_) @name
-                        body: (_
-                            "{" @keep
-                            "}" @keep) @collapse
-                        )
-
-                    (interface_declaration
-                        "interface" @name
-                        name: (_) @name
-                        )
-
-                    (enum_declaration
-                        "enum" @name
-                        name: (_) @name
-                        )
-
-                ] @item
-            )
-            "#,
-        )
-        .unwrap(),
-    )
-}
-
-fn ruby_lang() -> Arc<Language> {
-    Arc::new(
-        Language::new(
-            LanguageConfig {
-                name: "Ruby".into(),
-                matcher: LanguageMatcher {
-                    path_suffixes: vec!["rb".into()],
-                    ..Default::default()
-                },
-                collapsed_placeholder: "# ...".to_string(),
-                ..Default::default()
-            },
-            Some(tree_sitter_ruby::language()),
-        )
-        .with_embedding_query(
-            r#"
-            (
-                (comment)* @context
-                .
-                [
-                (module
-                    "module" @name
-                    name: (_) @name)
-                (method
-                    "def" @name
-                    name: (_) @name
-                    body: (body_statement) @collapse)
-                (class
-                    "class" @name
-                    name: (_) @name)
-                (singleton_method
-                    "def" @name
-                    object: (_) @name
-                    "." @name
-                    name: (_) @name
-                    body: (body_statement) @collapse)
-                ] @item
-            )
-            "#,
-        )
-        .unwrap(),
-    )
-}
-
-fn elixir_lang() -> Arc<Language> {
-    Arc::new(
-        Language::new(
-            LanguageConfig {
-                name: "Elixir".into(),
-                matcher: LanguageMatcher {
-                    path_suffixes: vec!["rs".into()],
-                    ..Default::default()
-                },
-                ..Default::default()
-            },
-            Some(tree_sitter_elixir::language()),
-        )
-        .with_embedding_query(
-            r#"
-            (
-                (unary_operator
-                    operator: "@"
-                    operand: (call
-                        target: (identifier) @unary
-                        (#match? @unary "^(doc)$"))
-                    ) @context
-                .
-                (call
-                target: (identifier) @name
-                (arguments
-                [
-                (identifier) @name
-                (call
-                target: (identifier) @name)
-                (binary_operator
-                left: (call
-                target: (identifier) @name)
-                operator: "when")
-                ])
-                (#any-match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
-                )
-
-            (call
-                target: (identifier) @name
-                (arguments (alias) @name)
-                (#any-match? @name "^(defmodule|defprotocol)$")) @item
-            "#,
-        )
-        .unwrap(),
-    )
-}
-
-#[gpui::test]
-fn test_subtract_ranges() {
-    assert_eq!(
-        subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
-        vec![1..4, 10..21]
-    );
-
-    assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
-}
-
-fn init_test(cx: &mut TestAppContext) {
-    cx.update(|cx| {
-        let settings_store = SettingsStore::test(cx);
-        cx.set_global(settings_store);
-        SemanticIndexSettings::register(cx);
-        language::init(cx);
-        Project::init_settings(cx);
-    });
-}

crates/settings/src/settings_store.rs 🔗

@@ -479,7 +479,28 @@ impl SettingsStore {
             merge_schema(target_schema, setting_schema.schema);
         }
 
-        fn merge_schema(target: &mut SchemaObject, source: SchemaObject) {
+        fn merge_schema(target: &mut SchemaObject, mut source: SchemaObject) {
+            let source_subschemas = source.subschemas();
+            let target_subschemas = target.subschemas();
+            if let Some(all_of) = source_subschemas.all_of.take() {
+                target_subschemas
+                    .all_of
+                    .get_or_insert(Vec::new())
+                    .extend(all_of);
+            }
+            if let Some(any_of) = source_subschemas.any_of.take() {
+                target_subschemas
+                    .any_of
+                    .get_or_insert(Vec::new())
+                    .extend(any_of);
+            }
+            if let Some(one_of) = source_subschemas.one_of.take() {
+                target_subschemas
+                    .one_of
+                    .get_or_insert(Vec::new())
+                    .extend(one_of);
+            }
+
             if let Some(source) = source.object {
                 let target_properties = &mut target.object().properties;
                 for (key, value) in source.properties {
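
The new branches above make subschema combinators accumulate rather than overwrite when several settings types contribute to the same schema. A minimal sketch of the pattern, reusing the same `schemars` calls as the hunk (illustration only, not part of the diff):

```rust
use schemars::schema::SchemaObject;

/// Merge `source`'s `one_of` variants into `target`, keeping any variants
/// `target` already had; `all_of` and `any_of` follow the same pattern.
fn merge_one_of(target: &mut SchemaObject, mut source: SchemaObject) {
    if let Some(one_of) = source.subschemas().one_of.take() {
        target
            .subschemas()
            .one_of
            .get_or_insert(Vec::new())
            .extend(one_of);
    }
}
```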

crates/util/src/http.rs 🔗

@@ -5,9 +5,8 @@ use futures_lite::FutureExt;
 use isahc::config::{Configurable, RedirectPolicy};
 pub use isahc::{
     http::{Method, StatusCode, Uri},
-    Error,
+    AsyncBody, Error, HttpClient as IsahcHttpClient, Request, Response,
 };
-pub use isahc::{AsyncBody, Request, Response};
 #[cfg(feature = "test-support")]
 use std::fmt;
 use std::{
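
The isahc re-exports are consolidated into a single `pub use`, with isahc's `HttpClient` exported under the alias `IsahcHttpClient`, presumably so it does not clash with this crate's own `HttpClient` trait.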

crates/zed/Cargo.toml 🔗

@@ -71,7 +71,6 @@ recent_projects.workspace = true
 release_channel.workspace = true
 rope.workspace = true
 search.workspace = true
-semantic_index.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true

crates/zed/src/main.rs 🔗

@@ -174,7 +174,7 @@ fn main() {
             node_runtime.clone(),
             cx,
         );
-        assistant::init(cx);
+        assistant::init(client.clone(), cx);
 
         extension::init(
             fs.clone(),
@@ -247,7 +247,6 @@ fn main() {
         tasks_ui::init(cx);
         channel::init(&client, user_store.clone(), cx);
         search::init(cx);
-        semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx);
         vim::init(cx);
         terminal_view::init(cx);
 

crates/zed/src/zed.rs 🔗

@@ -3060,7 +3060,7 @@ mod tests {
             collab_ui::init(&app_state, cx);
             project_panel::init((), cx);
             terminal_view::init(cx);
-            assistant::init(cx);
+            assistant::init(app_state.client.clone(), cx);
             initialize_workspace(app_state.clone(), cx);
             app_state
         })

docs/src/configuring_zed.md 🔗

@@ -606,28 +606,6 @@ These values take in the same options as the root-level settings with the same n
 
 `boolean` values
 
-## Semantic Index
-
-- Description: Settings related to semantic index.
-- Setting: `semantic_index`
-- Default:
-
-```json
-"semantic_index": {
-  "enabled": false
-},
-```
-
-### Enabled
-
-- Description: Whether or not to display the `Semantic` mode in project search.
-- Setting: `enabled`
-- Default: `true`
-
-**Options**
-
-`boolean` values
-
 ## Show Call Status Icon
 
 - Description: Whether or not to show the call status icon in the status bar.

script/bootstrap 🔗

@@ -11,3 +11,8 @@ cargo run -p collab -- migrate
 
 echo "seeding database..."
 script/seed-db
+
+if [[ "$OSTYPE" == "linux-gnu"* ]]; then
+  echo "Linux dependencies..."
+  script/linux
+fi

script/gemini.py 🔗

@@ -0,0 +1,91 @@
+import subprocess
+import json
+import http.client
+import mimetypes
+import os
+
+def get_text_files():
+    text_files = []
+    # List all files tracked by Git
+    git_files_proc = subprocess.run(['git', 'ls-files'], stdout=subprocess.PIPE, text=True)
+    for file in git_files_proc.stdout.strip().split('\n'):
+        # Check MIME type for each file
+        mime_check_proc = subprocess.run(['file', '--mime', file], stdout=subprocess.PIPE, text=True)
+        if 'text' in mime_check_proc.stdout:
+            text_files.append(file)
+
+    print(f"File count: {len(text_files)}")
+
+    return text_files
+
+def get_file_contents(file):
+    # Read file content
+    with open(file, 'r') as f:
+        return f.read()
+
+
+def main():
+    GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY')
+
+    # Your prompt
+    prompt = "Document the data types and dataflow in this codebase in preparation to port a streaming implementation to rust:\n\n"
+    # Fetch all text files
+    text_files = get_text_files()
+    code_blocks = []
+    for file in text_files:
+        file_contents = get_file_contents(file)
+        # Create a code block for each text file
+        code_blocks.append(f"\n`{file}`\n\n```{file_contents}```\n")
+
+    # Construct the JSON payload
+    payload = json.dumps({
+        "contents": [{
+            "parts": [{
+                "text": prompt + "".join(code_blocks)
+            }]
+        }]
+    })
+
+    # Prepare the HTTP connection
+    conn = http.client.HTTPSConnection("generativelanguage.googleapis.com")
+
+    # Define headers
+    headers = {
+        'Content-Type': 'application/json',
+        'Content-Length': str(len(payload.encode('utf-8')))  # byte length, not character count
+    }
+
+    # Output the content length in kilobytes
+    print(f"Content Length in kilobytes: {len(payload.encode('utf-8')) / 1024:.2f} KB")
+
+
+    # Send a request to count the tokens
+    conn.request("POST", f"/v1beta/models/gemini-1.5-pro-latest:countTokens?key={GEMINI_API_KEY}", body=payload, headers=headers)
+    # Get the response
+    response = conn.getresponse()
+    if response.status == 200:
+        token_count = json.loads(response.read().decode('utf-8')).get('totalTokens')
+        print(f"Token count: {token_count}")
+    else:
+        print(f"Failed to get token count. Status code: {response.status}, Response body: {response.read().decode('utf-8')}")
+
+
+    # Open a fresh connection for the streaming request
+    conn = http.client.HTTPSConnection("generativelanguage.googleapis.com")
+    conn.request("POST", f"/v1beta/models/gemini-1.5-pro-latest:streamGenerateContent?key={GEMINI_API_KEY}", body=payload, headers=headers)
+
+    # Get the response in a streaming manner
+    response = conn.getresponse()
+    if response.status == 200:
+        print("Successfully sent the data to the API.")
+        # Read the response in chunks
+        while chunk := response.read(4096):
+            print(chunk.decode('utf-8'))
+    else:
+        print(f"Failed to send the data to the API. Status code: {response.status}, Response body: {response.read().decode('utf-8')}")
+
+    # Close the connection
+    conn.close()
+
+if __name__ == "__main__":
+    main()
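
The script assumes `GEMINI_API_KEY` is set in the environment and that it is run from the repository root, so `git ls-files` and `file --mime` can resolve every tracked path; files whose MIME type is not `text` are skipped.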

script/linux 🔗

@@ -1,4 +1,6 @@
-#!/usr/bin/bash -e
+#!/usr/bin/bash
+
+set -e
 
 # if sudo is not installed, define an empty alias
 maysudo=$(command -v sudo || command -v doas || true)
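
Moving `-e` out of the shebang into an explicit `set -e` preserves fail-fast behavior even when the script is invoked as `bash script/linux`, since shebang flags are ignored in that case.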

script/sqlx 🔗

@@ -3,12 +3,15 @@
 set -e
 
 # Install sqlx-cli if needed
-[[ "$(sqlx --version)" == "sqlx-cli 0.5.7" ]] || cargo install sqlx-cli --version 0.5.7
+if [[ "$(sqlx --version)" != "sqlx-cli 0.5.7" ]]; then
+    echo "sqlx-cli not found or not the required version, installing version 0.5.7..."
+    cargo install sqlx-cli --version 0.5.7
+fi
 
 cd crates/collab
 
 # Export contents of .env.toml
-eval "$(cargo run --quiet --bin dotenv)"
+eval "$(cargo run --bin dotenv)"
 
 # Run sqlx command
 sqlx $@
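
Dropping `--quiet` surfaces cargo's build progress for the `dotenv` helper; the `eval` is unaffected, since command substitution captures only the helper's stdout while cargo writes its messages to stderr.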