Detailed changes
@@ -3,5 +3,10 @@
"label": "clippy",
"command": "cargo",
"args": ["xtask", "clippy"]
+ },
+ {
+ "label": "assistant2",
+ "command": "cargo",
+ "args": ["run", "-p", "assistant2", "--example", "assistant_example"]
}
]
@@ -371,6 +371,50 @@ dependencies = [
"workspace",
]
+[[package]]
+name = "assistant2"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "assets",
+ "assistant_tooling",
+ "client",
+ "editor",
+ "env_logger",
+ "feature_flags",
+ "futures 0.3.28",
+ "gpui",
+ "language",
+ "languages",
+ "log",
+ "nanoid",
+ "node_runtime",
+ "open_ai",
+ "project",
+ "release_channel",
+ "rich_text",
+ "schemars",
+ "semantic_index",
+ "serde",
+ "serde_json",
+ "settings",
+ "theme",
+ "ui",
+ "util",
+ "workspace",
+]
+
+[[package]]
+name = "assistant_tooling"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "gpui",
+ "schemars",
+ "serde",
+ "serde_json",
+]
+
[[package]]
name = "async-broadcast"
version = "0.7.0"
@@ -643,7 +687,7 @@ checksum = "5fd55a5ba1179988837d24ab4c7cc8ed6efdeff578ede0416b4225a5fca35bd0"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -710,7 +754,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -741,7 +785,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -1385,7 +1429,7 @@ dependencies = [
"regex",
"rustc-hash",
"shlex",
- "syn 2.0.48",
+ "syn 2.0.59",
"which 4.4.2",
]
@@ -1468,7 +1512,7 @@ source = "git+https://github.com/kvark/blade?rev=810ec594358aafea29a4a3d8ab601d2
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -1634,7 +1678,7 @@ checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -2019,7 +2063,7 @@ dependencies = [
"heck 0.4.1",
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -2959,7 +3003,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e"
dependencies = [
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -3442,7 +3486,7 @@ checksum = "5c785274071b1b420972453b306eeca06acf4633829db4223b58a2a8c5953bc4"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -3954,7 +3998,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -4194,7 +4238,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -5061,7 +5105,7 @@ checksum = "ce243b1bfa62ffc028f1cc3b6034ec63d649f3031bc8a4fbbb004e1ac17d1f68"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -5682,7 +5726,7 @@ checksum = "ba125974b109d512fccbc6c0244e7580143e460895dfd6ea7f8bbb692fd94396"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -6643,7 +6687,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -6719,7 +6763,7 @@ dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -6799,7 +6843,7 @@ checksum = "e8890702dbec0bad9116041ae586f84805b13eecd1d8b1df27c29998a9969d6d"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -6977,7 +7021,7 @@ dependencies = [
"phf_shared",
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -7028,7 +7072,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -7252,7 +7296,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d"
dependencies = [
"proc-macro2",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -7309,9 +7353,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
-version = "1.0.78"
+version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae"
+checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba"
dependencies = [
"unicode-ident",
]
@@ -7332,7 +7376,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd"
dependencies = [
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -8175,7 +8219,7 @@ dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
- "syn 2.0.48",
+ "syn 2.0.59",
"walkdir",
]
@@ -8449,7 +8493,7 @@ dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -8490,7 +8534,7 @@ dependencies = [
"proc-macro2",
"quote",
"sea-bae",
- "syn 2.0.48",
+ "syn 2.0.59",
"unicode-ident",
]
@@ -8674,7 +8718,7 @@ checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -8739,7 +8783,7 @@ checksum = "8725e1dfadb3a50f7e5ce0b1a540466f6ed3fe7a0fca2ac2b8b831d31316bd00"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -9505,7 +9549,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -9634,9 +9678,9 @@ dependencies = [
[[package]]
name = "syn"
-version = "2.0.48"
+version = "2.0.59"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f"
+checksum = "4a6531ffc7b071655e4ce2e04bd464c4830bb585a61cabb96cf808f05172615a"
dependencies = [
"proc-macro2",
"quote",
@@ -10001,7 +10045,7 @@ checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -10180,7 +10224,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -10405,7 +10449,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -11172,7 +11216,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
"wasm-bindgen-shared",
]
@@ -11206,7 +11250,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@@ -11343,7 +11387,7 @@ dependencies = [
"anyhow",
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
"wasmtime-component-util",
"wasmtime-wit-bindgen",
"wit-parser",
@@ -11504,7 +11548,7 @@ checksum = "6d6d967f01032da7d4c6303da32f6a00d5efe1bac124b156e7342d8ace6ffdfc"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -11784,7 +11828,7 @@ dependencies = [
"proc-macro2",
"quote",
"shellexpand",
- "syn 2.0.48",
+ "syn 2.0.59",
"witx",
]
@@ -11796,7 +11840,7 @@ checksum = "512d816dbcd0113103b2eb2402ec9018e7f0755202a5b3e67db726f229d8dcae"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
"wiggle-generate",
]
@@ -11914,7 +11958,7 @@ checksum = "942ac266be9249c84ca862f0a164a39533dc2f6f33dc98ec89c8da99b82ea0bd"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -11925,7 +11969,7 @@ checksum = "da33557140a288fae4e1d5f8873aaf9eb6613a9cf82c3e070223ff177f598b60"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -12242,7 +12286,7 @@ dependencies = [
"anyhow",
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
"wit-bindgen-core",
"wit-bindgen-rust",
]
@@ -12567,6 +12611,7 @@ dependencies = [
"anyhow",
"assets",
"assistant",
+ "assistant2",
"audio",
"auto_update",
"backtrace",
@@ -12860,7 +12905,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -12880,7 +12925,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.48",
+ "syn 2.0.59",
]
[[package]]
@@ -4,6 +4,8 @@ members = [
"crates/anthropic",
"crates/assets",
"crates/assistant",
+ "crates/assistant_tooling",
+ "crates/assistant2",
"crates/audio",
"crates/auto_update",
"crates/breadcrumbs",
@@ -137,6 +139,8 @@ ai = { path = "crates/ai" }
anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
+assistant2 = { path = "crates/assistant2" }
+assistant_tooling = { path = "crates/assistant_tooling" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
base64 = "0.13"
@@ -208,6 +212,7 @@ rpc = { path = "crates/rpc" }
task = { path = "crates/task" }
tasks_ui = { path = "crates/tasks_ui" }
search = { path = "crates/search" }
+semantic_index = { path = "crates/semantic_index" }
semantic_version = { path = "crates/semantic_version" }
settings = { path = "crates/settings" }
snippet = { path = "crates/snippet" }
@@ -209,7 +209,14 @@
}
},
{
- "context": "AssistantPanel",
+ "context": "AssistantChat > Editor", // Used in the assistant2 crate
+ "bindings": {
+ "enter": ["assistant2::Submit", "Simple"],
+ "cmd-enter": ["assistant2::Submit", "Codebase"]
+ }
+ },
+ {
+ "context": "AssistantPanel", // Used in the assistant crate, which we're replacing
"bindings": {
"cmd-g": "search::SelectNextMatch",
"cmd-shift-g": "search::SelectPrevMatch"
@@ -5,6 +5,9 @@ edition = "2021"
publish = false
license = "GPL-3.0-or-later"
+[lib]
+path = "src/assets.rs"
+
[lints]
workspace = true
@@ -1,7 +1,7 @@
// This crate was essentially pulled out verbatim from main `zed` crate to avoid having to run RustEmbed macro whenever zed has to be rebuilt. It saves a second or two on an incremental build.
use anyhow::anyhow;
-use gpui::{AssetSource, Result, SharedString};
+use gpui::{AppContext, AssetSource, Result, SharedString};
use rust_embed::RustEmbed;
#[derive(RustEmbed)]
@@ -34,3 +34,19 @@ impl AssetSource for Assets {
.collect())
}
}
+
+impl Assets {
+ /// Populate the [`TextSystem`] of the given [`AppContext`] with all `.ttf` fonts in the `fonts` directory.
+ pub fn load_fonts(&self, cx: &AppContext) -> gpui::Result<()> {
+ let font_paths = self.list("fonts")?;
+ let mut embedded_fonts = Vec::new();
+ for font_path in font_paths {
+ if font_path.ends_with(".ttf") {
+ let font_bytes = cx.asset_source().load(&font_path)?;
+ embedded_fonts.push(font_bytes);
+ }
+ }
+
+ cx.text_system().add_fonts(embedded_fonts)
+ }
+}
@@ -128,6 +128,8 @@ impl LanguageModelRequestMessage {
Role::System => proto::LanguageModelRole::LanguageModelSystem,
} as i32,
content: self.content.clone(),
+ tool_calls: Vec::new(),
+ tool_call_id: None,
}
}
}
@@ -147,6 +149,8 @@ impl LanguageModelRequest {
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
stop: self.stop.clone(),
temperature: self.temperature,
+ tool_choice: None,
+ tools: Vec::new(),
}
}
}
@@ -140,14 +140,24 @@ impl OpenAiCompletionProvider {
messages: request
.messages
.into_iter()
- .map(|msg| RequestMessage {
- role: msg.role.into(),
- content: msg.content,
+ .map(|msg| match msg.role {
+ Role::User => RequestMessage::User {
+ content: msg.content,
+ },
+ Role::Assistant => RequestMessage::Assistant {
+ content: Some(msg.content),
+ tool_calls: Vec::new(),
+ },
+ Role::System => RequestMessage::System {
+ content: msg.content,
+ },
})
.collect(),
stream: true,
stop: request.stop,
temperature: request.temperature,
+ tools: Vec::new(),
+ tool_choice: None,
}
}
}
@@ -123,6 +123,8 @@ impl ZedDotDevCompletionProvider {
.collect(),
stop: request.stop,
temperature: request.temperature,
+ tools: Vec::new(),
+ tool_choice: None,
};
self.client
@@ -0,0 +1,56 @@
+[package]
+name = "assistant2"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lib]
+path = "src/assistant2.rs"
+
+[[example]]
+name = "assistant_example"
+path = "examples/assistant_example.rs"
+crate-type = ["bin"]
+
+[dependencies]
+anyhow.workspace = true
+assistant_tooling.workspace = true
+client.workspace = true
+editor.workspace = true
+feature_flags.workspace = true
+futures.workspace = true
+gpui.workspace = true
+language.workspace = true
+log.workspace = true
+open_ai.workspace = true
+project.workspace = true
+rich_text.workspace = true
+semantic_index.workspace = true
+schemars.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+settings.workspace = true
+theme.workspace = true
+ui.workspace = true
+util.workspace = true
+workspace.workspace = true
+nanoid = "0.4"
+
+[dev-dependencies]
+assets.workspace = true
+editor = { workspace = true, features = ["test-support"] }
+env_logger.workspace = true
+gpui = { workspace = true, features = ["test-support"] }
+language = { workspace = true, features = ["test-support"] }
+languages.workspace = true
+node_runtime.workspace = true
+project = { workspace = true, features = ["test-support"] }
+release_channel.workspace = true
+settings = { workspace = true, features = ["test-support"] }
+theme = { workspace = true, features = ["test-support"] }
+util = { workspace = true, features = ["test-support"] }
+workspace = { workspace = true, features = ["test-support"] }
+
+[lints]
+workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,129 @@
+use anyhow::Context as _;
+use assets::Assets;
+use assistant2::{tools::ProjectIndexTool, AssistantPanel};
+use assistant_tooling::ToolRegistry;
+use client::Client;
+use gpui::{actions, App, AppContext, KeyBinding, Task, View, WindowOptions};
+use language::LanguageRegistry;
+use project::Project;
+use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex};
+use settings::{KeymapFile, DEFAULT_KEYMAP_PATH};
+use std::{
+ path::{Path, PathBuf},
+ sync::Arc,
+};
+use theme::LoadThemes;
+use ui::{div, prelude::*, Render};
+use util::{http::HttpClientWithUrl, ResultExt as _};
+
+actions!(example, [Quit]);
+
+fn main() {
+ let args: Vec<String> = std::env::args().collect();
+
+ env_logger::init();
+ App::new().with_assets(Assets).run(|cx| {
+ cx.bind_keys(Some(KeyBinding::new("cmd-q", Quit, None)));
+ cx.on_action(|_: &Quit, cx: &mut AppContext| {
+ cx.quit();
+ });
+
+ if args.len() < 2 {
+ eprintln!(
+ "Usage: cargo run --example assistant_example -p assistant2 -- <project_path>"
+ );
+ cx.quit();
+ return;
+ }
+
+ settings::init(cx);
+ language::init(cx);
+ Project::init_settings(cx);
+ editor::init(cx);
+ theme::init(LoadThemes::JustBase, cx);
+ Assets.load_fonts(cx).unwrap();
+ KeymapFile::load_asset(DEFAULT_KEYMAP_PATH, cx).unwrap();
+ client::init_settings(cx);
+ release_channel::init("0.130.0", cx);
+
+ let client = Client::production(cx);
+ {
+ let client = client.clone();
+ cx.spawn(|cx| async move { client.authenticate_and_connect(false, &cx).await })
+ .detach_and_log_err(cx);
+ }
+ assistant2::init(client.clone(), cx);
+
+ let language_registry = Arc::new(LanguageRegistry::new(
+ Task::ready(()),
+ cx.background_executor().clone(),
+ ));
+ let node_runtime = node_runtime::RealNodeRuntime::new(client.http_client());
+ languages::init(language_registry.clone(), node_runtime, cx);
+
+ let http = Arc::new(HttpClientWithUrl::new("http://localhost:11434"));
+
+ let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
+ let embedding_provider = OpenAiEmbeddingProvider::new(
+ http.clone(),
+ OpenAiEmbeddingModel::TextEmbedding3Small,
+ open_ai::OPEN_AI_API_URL.to_string(),
+ api_key,
+ );
+
+ cx.spawn(|mut cx| async move {
+ let mut semantic_index = SemanticIndex::new(
+ PathBuf::from("/tmp/semantic-index-db.mdb"),
+ Arc::new(embedding_provider),
+ &mut cx,
+ )
+ .await?;
+
+ let project_path = Path::new(&args[1]);
+ let project = Project::example([project_path], &mut cx).await;
+
+ cx.update(|cx| {
+ let fs = project.read(cx).fs().clone();
+
+ let project_index = semantic_index.project_index(project.clone(), cx);
+
+ let mut tool_registry = ToolRegistry::new();
+ tool_registry
+ .register(ProjectIndexTool::new(project_index.clone(), fs.clone()))
+ .context("failed to register ProjectIndexTool")
+ .log_err();
+
+ let tool_registry = Arc::new(tool_registry);
+
+ cx.open_window(WindowOptions::default(), |cx| {
+ cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
+ });
+ cx.activate(true);
+ })
+ })
+ .detach_and_log_err(cx);
+ })
+}
+
+struct Example {
+ assistant_panel: View<AssistantPanel>,
+}
+
+impl Example {
+ fn new(
+ language_registry: Arc<LanguageRegistry>,
+ tool_registry: Arc<ToolRegistry>,
+ cx: &mut ViewContext<Self>,
+ ) -> Self {
+ Self {
+ assistant_panel: cx
+ .new_view(|cx| AssistantPanel::new(language_registry, tool_registry, cx)),
+ }
+ }
+}
+
+impl Render for Example {
+ fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl ui::prelude::IntoElement {
+ div().size_full().child(self.assistant_panel.clone())
+ }
+}
@@ -0,0 +1,952 @@
+mod assistant_settings;
+mod completion_provider;
+pub mod tools;
+
+use anyhow::{Context, Result};
+use assistant_tooling::{ToolFunctionCall, ToolRegistry};
+use client::{proto, Client};
+use completion_provider::*;
+use editor::{Editor, EditorEvent};
+use feature_flags::FeatureFlagAppExt as _;
+use futures::{channel::oneshot, future::join_all, Future, FutureExt, StreamExt};
+use gpui::{
+ list, prelude::*, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
+ FocusableView, Global, ListAlignment, ListState, Model, Render, Task, View, WeakView,
+};
+use language::{language_settings::SoftWrap, LanguageRegistry};
+use open_ai::{FunctionContent, ToolCall, ToolCallContent};
+use project::Fs;
+use rich_text::RichText;
+use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex};
+use serde::Deserialize;
+use settings::Settings;
+use std::{cmp, sync::Arc};
+use theme::ThemeSettings;
+use tools::ProjectIndexTool;
+use ui::{popover_menu, prelude::*, ButtonLike, CollapsibleContainer, Color, ContextMenu, Tooltip};
+use util::{paths::EMBEDDINGS_DIR, ResultExt};
+use workspace::{
+ dock::{DockPosition, Panel, PanelEvent},
+ Workspace,
+};
+
+pub use assistant_settings::AssistantSettings;
+
+const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5;
+
+// gpui::actions!(assistant, [Submit]);
+
+#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
+pub struct Submit(SubmitMode);
+
+/// There are multiple different ways to submit a model request, represented by this enum.
+#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
+pub enum SubmitMode {
+ /// Only include the conversation.
+ Simple,
+ /// Send the current file as context.
+ CurrentFile,
+ /// Search the codebase and send relevant excerpts.
+ Codebase,
+}
+
+gpui::actions!(assistant2, [ToggleFocus]);
+gpui::impl_actions!(assistant2, [Submit]);
+
+pub fn init(client: Arc<Client>, cx: &mut AppContext) {
+ AssistantSettings::register(cx);
+
+ cx.spawn(|mut cx| {
+ let client = client.clone();
+ async move {
+ let embedding_provider = CloudEmbeddingProvider::new(client.clone());
+ let semantic_index = SemanticIndex::new(
+ EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
+ Arc::new(embedding_provider),
+ &mut cx,
+ )
+ .await?;
+ cx.update(|cx| cx.set_global(semantic_index))
+ }
+ })
+ .detach();
+
+ cx.set_global(CompletionProvider::new(CloudCompletionProvider::new(
+ client,
+ )));
+
+ cx.observe_new_views(
+ |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
+ workspace.register_action(|workspace, _: &ToggleFocus, cx| {
+ workspace.toggle_panel_focus::<AssistantPanel>(cx);
+ });
+ },
+ )
+ .detach();
+}
+
+pub fn enabled(cx: &AppContext) -> bool {
+ cx.is_staff()
+}
+
+pub struct AssistantPanel {
+ chat: View<AssistantChat>,
+ width: Option<Pixels>,
+}
+
+impl AssistantPanel {
+ pub fn load(
+ workspace: WeakView<Workspace>,
+ cx: AsyncWindowContext,
+ ) -> Task<Result<View<Self>>> {
+ cx.spawn(|mut cx| async move {
+ let (app_state, project) = workspace.update(&mut cx, |workspace, _| {
+ (workspace.app_state().clone(), workspace.project().clone())
+ })?;
+
+ cx.new_view(|cx| {
+ // todo!("this will panic if the semantic index failed to load or has not loaded yet")
+ let project_index = cx.update_global(|semantic_index: &mut SemanticIndex, cx| {
+ semantic_index.project_index(project.clone(), cx)
+ });
+
+ let mut tool_registry = ToolRegistry::new();
+ tool_registry
+ .register(ProjectIndexTool::new(
+ project_index.clone(),
+ app_state.fs.clone(),
+ ))
+ .context("failed to register ProjectIndexTool")
+ .log_err();
+
+ let tool_registry = Arc::new(tool_registry);
+
+ Self::new(app_state.languages.clone(), tool_registry, cx)
+ })
+ })
+ }
+
+ pub fn new(
+ language_registry: Arc<LanguageRegistry>,
+ tool_registry: Arc<ToolRegistry>,
+ cx: &mut ViewContext<Self>,
+ ) -> Self {
+ let chat = cx.new_view(|cx| {
+ AssistantChat::new(language_registry.clone(), tool_registry.clone(), cx)
+ });
+
+ Self { width: None, chat }
+ }
+}
+
+impl Render for AssistantPanel {
+ fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ div()
+ .size_full()
+ .v_flex()
+ .p_2()
+ .bg(cx.theme().colors().background)
+ .child(self.chat.clone())
+ }
+}
+
+impl Panel for AssistantPanel {
+ fn persistent_name() -> &'static str {
+ "AssistantPanelv2"
+ }
+
+ fn position(&self, _cx: &WindowContext) -> workspace::dock::DockPosition {
+ // todo!("Add a setting / use assistant settings")
+ DockPosition::Right
+ }
+
+ fn position_is_valid(&self, position: workspace::dock::DockPosition) -> bool {
+ matches!(position, DockPosition::Right)
+ }
+
+ fn set_position(&mut self, _: workspace::dock::DockPosition, _: &mut ViewContext<Self>) {
+ // Do nothing until we have a setting for this
+ }
+
+ fn size(&self, _cx: &WindowContext) -> Pixels {
+ self.width.unwrap_or(px(400.))
+ }
+
+ fn set_size(&mut self, size: Option<Pixels>, cx: &mut ViewContext<Self>) {
+ self.width = size;
+ cx.notify();
+ }
+
+ fn icon(&self, _cx: &WindowContext) -> Option<ui::IconName> {
+ Some(IconName::Ai)
+ }
+
+ fn icon_tooltip(&self, _: &WindowContext) -> Option<&'static str> {
+ Some("Assistant Panel ✨")
+ }
+
+ fn toggle_action(&self) -> Box<dyn gpui::Action> {
+ Box::new(ToggleFocus)
+ }
+}
+
+impl EventEmitter<PanelEvent> for AssistantPanel {}
+
+impl FocusableView for AssistantPanel {
+ fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
+ self.chat
+ .read(cx)
+ .messages
+ .iter()
+ .rev()
+ .find_map(|msg| msg.focus_handle(cx))
+ .expect("no user message in chat")
+ }
+}
+
+struct AssistantChat {
+ model: String,
+ messages: Vec<ChatMessage>,
+ list_state: ListState,
+ language_registry: Arc<LanguageRegistry>,
+ next_message_id: MessageId,
+ pending_completion: Option<Task<()>>,
+ tool_registry: Arc<ToolRegistry>,
+}
+
+impl AssistantChat {
+ fn new(
+ language_registry: Arc<LanguageRegistry>,
+ tool_registry: Arc<ToolRegistry>,
+ cx: &mut ViewContext<Self>,
+ ) -> Self {
+ let model = CompletionProvider::get(cx).default_model();
+ let view = cx.view().downgrade();
+ let list_state = ListState::new(
+ 0,
+ ListAlignment::Bottom,
+ px(1024.),
+ move |ix, cx: &mut WindowContext| {
+ view.update(cx, |this, cx| this.render_message(ix, cx))
+ .unwrap()
+ },
+ );
+
+ let mut this = Self {
+ model,
+ messages: Vec::new(),
+ list_state,
+ language_registry,
+ next_message_id: MessageId(0),
+ pending_completion: None,
+ tool_registry,
+ };
+ this.push_new_user_message(true, cx);
+ this
+ }
+
+ fn focused_message_id(&self, cx: &WindowContext) -> Option<MessageId> {
+ self.messages.iter().find_map(|message| match message {
+ ChatMessage::User(message) => message
+ .body
+ .focus_handle(cx)
+ .contains_focused(cx)
+ .then_some(message.id),
+ ChatMessage::Assistant(_) => None,
+ })
+ }
+
+ fn submit(&mut self, Submit(mode): &Submit, cx: &mut ViewContext<Self>) {
+ let Some(focused_message_id) = self.focused_message_id(cx) else {
+ log::error!("unexpected state: no user message editor is focused.");
+ return;
+ };
+
+ self.truncate_messages(focused_message_id, cx);
+
+ let mode = *mode;
+ self.pending_completion = Some(cx.spawn(move |this, mut cx| async move {
+ Self::request_completion(
+ this.clone(),
+ mode,
+ MAX_COMPLETION_CALLS_PER_SUBMISSION,
+ &mut cx,
+ )
+ .await
+ .log_err();
+
+ this.update(&mut cx, |this, cx| {
+ let focus = this
+ .user_message(focused_message_id)
+ .body
+ .focus_handle(cx)
+ .contains_focused(cx);
+ this.push_new_user_message(focus, cx);
+ })
+ .context("Failed to push new user message")
+ .log_err();
+ }));
+ }
+
+ async fn request_completion(
+ this: WeakView<Self>,
+ mode: SubmitMode,
+ limit: usize,
+ cx: &mut AsyncWindowContext,
+ ) -> Result<()> {
+ let mut call_count = 0;
+ loop {
+ let complete = async {
+ let completion = this.update(cx, |this, cx| {
+ this.push_new_assistant_message(cx);
+
+ let definitions = if call_count < limit && matches!(mode, SubmitMode::Codebase)
+ {
+ this.tool_registry.definitions()
+ } else {
+ &[]
+ };
+ call_count += 1;
+
+ CompletionProvider::get(cx).complete(
+ this.model.clone(),
+ this.completion_messages(cx),
+ Vec::new(),
+ 1.0,
+ definitions,
+ )
+ });
+
+ let mut stream = completion?.await?;
+ let mut body = String::new();
+ while let Some(delta) = stream.next().await {
+ let delta = delta?;
+ this.update(cx, |this, cx| {
+ if let Some(ChatMessage::Assistant(AssistantMessage {
+ body: message_body,
+ tool_calls: message_tool_calls,
+ ..
+ })) = this.messages.last_mut()
+ {
+ if let Some(content) = &delta.content {
+ body.push_str(content);
+ }
+
+ for tool_call in delta.tool_calls {
+ let index = tool_call.index as usize;
+ if index >= message_tool_calls.len() {
+ message_tool_calls.resize_with(index + 1, Default::default);
+ }
+ let call = &mut message_tool_calls[index];
+
+ if let Some(id) = &tool_call.id {
+ call.id.push_str(id);
+ }
+
+ match tool_call.variant {
+ Some(proto::tool_call_delta::Variant::Function(tool_call)) => {
+ if let Some(name) = &tool_call.name {
+ call.name.push_str(name);
+ }
+ if let Some(arguments) = &tool_call.arguments {
+ call.arguments.push_str(arguments);
+ }
+ }
+ None => {}
+ }
+ }
+
+ *message_body =
+ RichText::new(body.clone(), &[], &this.language_registry);
+ cx.notify();
+ } else {
+ unreachable!()
+ }
+ })?;
+ }
+
+ anyhow::Ok(())
+ }
+ .await;
+
+ let mut tool_tasks = Vec::new();
+ this.update(cx, |this, cx| {
+ if let Some(ChatMessage::Assistant(AssistantMessage {
+ error: message_error,
+ tool_calls,
+ ..
+ })) = this.messages.last_mut()
+ {
+ if let Err(error) = complete {
+ message_error.replace(SharedString::from(error.to_string()));
+ cx.notify();
+ } else {
+ for tool_call in tool_calls.iter() {
+ tool_tasks.push(this.tool_registry.call(tool_call, cx));
+ }
+ }
+ }
+ })?;
+
+ if tool_tasks.is_empty() {
+ return Ok(());
+ }
+
+ let tools = join_all(tool_tasks.into_iter()).await;
+ this.update(cx, |this, cx| {
+ if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) =
+ this.messages.last_mut()
+ {
+ *tool_calls = tools;
+ cx.notify();
+ }
+ })?;
+ }
+ }
+
+ fn user_message(&mut self, message_id: MessageId) -> &mut UserMessage {
+ self.messages
+ .iter_mut()
+ .find_map(|message| match message {
+ ChatMessage::User(user_message) if user_message.id == message_id => {
+ Some(user_message)
+ }
+ _ => None,
+ })
+ .expect("User message not found")
+ }
+
+ fn push_new_user_message(&mut self, focus: bool, cx: &mut ViewContext<Self>) {
+ let id = self.next_message_id.post_inc();
+ let body = cx.new_view(|cx| {
+ let mut editor = Editor::auto_height(80, cx);
+ editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
+ if focus {
+ cx.focus_self();
+ }
+ editor
+ });
+ let _subscription = cx.subscribe(&body, move |this, editor, event, cx| match event {
+ EditorEvent::SelectionsChanged { .. } => {
+ if editor.read(cx).is_focused(cx) {
+ let (message_ix, _message) = this
+ .messages
+ .iter()
+ .enumerate()
+ .find_map(|(ix, message)| match message {
+ ChatMessage::User(user_message) if user_message.id == id => {
+ Some((ix, user_message))
+ }
+ _ => None,
+ })
+ .expect("user message not found");
+
+ this.list_state.scroll_to_reveal_item(message_ix);
+ }
+ }
+ _ => {}
+ });
+ let message = ChatMessage::User(UserMessage {
+ id,
+ body,
+ contexts: Vec::new(),
+ _subscription,
+ });
+ self.push_message(message, cx);
+ }
+
+ fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
+ let message = ChatMessage::Assistant(AssistantMessage {
+ id: self.next_message_id.post_inc(),
+ body: RichText::default(),
+ tool_calls: Vec::new(),
+ error: None,
+ });
+ self.push_message(message, cx);
+ }
+
+ fn push_message(&mut self, message: ChatMessage, cx: &mut ViewContext<Self>) {
+ let old_len = self.messages.len();
+ let focus_handle = Some(message.focus_handle(cx));
+ self.messages.push(message);
+ self.list_state
+ .splice_focusable(old_len..old_len, focus_handle);
+ cx.notify();
+ }
+
+ fn truncate_messages(&mut self, last_message_id: MessageId, cx: &mut ViewContext<Self>) {
+ if let Some(index) = self.messages.iter().position(|message| match message {
+ ChatMessage::User(message) => message.id == last_message_id,
+ ChatMessage::Assistant(message) => message.id == last_message_id,
+ }) {
+ self.list_state.splice(index + 1..self.messages.len(), 0);
+ self.messages.truncate(index + 1);
+ cx.notify();
+ }
+ }
+
+ fn render_error(
+ &self,
+ error: Option<SharedString>,
+ _ix: usize,
+ cx: &mut ViewContext<Self>,
+ ) -> AnyElement {
+ let theme = cx.theme();
+
+ if let Some(error) = error {
+ div()
+ .py_1()
+ .px_2()
+ .neg_mx_1()
+ .rounded_md()
+ .border()
+ .border_color(theme.status().error_border)
+ // .bg(theme.status().error_background)
+ .text_color(theme.status().error)
+ .child(error.clone())
+ .into_any_element()
+ } else {
+ div().into_any_element()
+ }
+ }
+
+ fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
+ let is_last = ix == self.messages.len() - 1;
+
+ match &self.messages[ix] {
+ ChatMessage::User(UserMessage {
+ body,
+ contexts: _contexts,
+ ..
+ }) => div()
+ .when(!is_last, |element| element.mb_2())
+ .child(div().p_2().child(Label::new("You").color(Color::Default)))
+ .child(
+ div()
+ .on_action(cx.listener(Self::submit))
+ .p_2()
+ .text_color(cx.theme().colors().editor_foreground)
+ .font(ThemeSettings::get_global(cx).buffer_font.clone())
+ .bg(cx.theme().colors().editor_background)
+ .child(body.clone()), // .children(contexts.iter().map(|context| context.render(cx))),
+ )
+ .into_any(),
+ ChatMessage::Assistant(AssistantMessage {
+ id,
+ body,
+ error,
+ tool_calls,
+ ..
+ }) => {
+ let assistant_body = if body.text.is_empty() && !tool_calls.is_empty() {
+ div()
+ } else {
+ div().p_2().child(body.element(ElementId::from(id.0), cx))
+ };
+
+ div()
+ .when(!is_last, |element| element.mb_2())
+ .child(
+ div()
+ .p_2()
+ .child(Label::new("Assistant").color(Color::Modified)),
+ )
+ .child(assistant_body)
+ .child(self.render_error(error.clone(), ix, cx))
+ .children(tool_calls.iter().map(|tool_call| {
+ let result = &tool_call.result;
+ let name = tool_call.name.clone();
+ match result {
+ Some(result) => div()
+ .p_2()
+ .child(result.render(&name, &tool_call.id, cx))
+ .into_any(),
+ None => div()
+ .p_2()
+ .child(Label::new(name).color(Color::Modified))
+ .child("Running...")
+ .into_any(),
+ }
+ }))
+ .into_any()
+ }
+ }
+ }
+
+ fn completion_messages(&self, cx: &WindowContext) -> Vec<CompletionMessage> {
+ let mut completion_messages = Vec::new();
+
+ for message in &self.messages {
+ match message {
+ ChatMessage::User(UserMessage { body, contexts, .. }) => {
+ // setup context for model
+ contexts.iter().for_each(|context| {
+ completion_messages.extend(context.completion_messages(cx))
+ });
+
+ // Show user's message last so that the assistant is grounded in the user's request
+ completion_messages.push(CompletionMessage::User {
+ content: body.read(cx).text(cx),
+ });
+ }
+ ChatMessage::Assistant(AssistantMessage {
+ body, tool_calls, ..
+ }) => {
+ // In no case do we want to send an empty message. This shouldn't happen, but we might as well
+ // not break the Chat API if it does.
+ if body.text.is_empty() && tool_calls.is_empty() {
+ continue;
+ }
+
+ let tool_calls_from_assistant = tool_calls
+ .iter()
+ .map(|tool_call| ToolCall {
+ content: ToolCallContent::Function {
+ function: FunctionContent {
+ name: tool_call.name.clone(),
+ arguments: tool_call.arguments.clone(),
+ },
+ },
+ id: tool_call.id.clone(),
+ })
+ .collect();
+
+ completion_messages.push(CompletionMessage::Assistant {
+ content: Some(body.text.to_string()),
+ tool_calls: tool_calls_from_assistant,
+ });
+
+ for tool_call in tool_calls {
+ // todo!(): we should not be sending when the tool is still running / has no result
+ // For now I'm going to have to assume we send an empty string because otherwise
+ // the Chat API will break -- there is a required message for every tool call by ID
+ let content = match &tool_call.result {
+ Some(result) => result.format(&tool_call.name),
+ None => "".to_string(),
+ };
+
+ completion_messages.push(CompletionMessage::Tool {
+ content,
+ tool_call_id: tool_call.id.clone(),
+ });
+ }
+ }
+ }
+ }
+
+ completion_messages
+ }
+
+ fn render_model_dropdown(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ let this = cx.view().downgrade();
+ div().h_flex().justify_end().child(
+ div().w_32().child(
+ popover_menu("user-menu")
+ .menu(move |cx| {
+ ContextMenu::build(cx, |mut menu, cx| {
+ for model in CompletionProvider::get(cx).available_models() {
+ menu = menu.custom_entry(
+ {
+ let model = model.clone();
+ move |_| Label::new(model.clone()).into_any_element()
+ },
+ {
+ let this = this.clone();
+ move |cx| {
+ _ = this.update(cx, |this, cx| {
+ this.model = model.clone();
+ cx.notify();
+ });
+ }
+ },
+ );
+ }
+ menu
+ })
+ .into()
+ })
+ .trigger(
+ ButtonLike::new("active-model")
+ .child(
+ h_flex()
+ .w_full()
+ .gap_0p5()
+ .child(
+ div()
+ .overflow_x_hidden()
+ .flex_grow()
+ .whitespace_nowrap()
+ .child(Label::new(self.model.clone())),
+ )
+ .child(div().child(
+ Icon::new(IconName::ChevronDown).color(Color::Muted),
+ )),
+ )
+ .style(ButtonStyle::Subtle)
+ .tooltip(move |cx| Tooltip::text("Change Model", cx)),
+ )
+ .anchor(gpui::AnchorCorner::TopRight),
+ ),
+ )
+ }
+}
+
+impl Render for AssistantChat {
+ fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ div()
+ .relative()
+ .flex_1()
+ .v_flex()
+ .key_context("AssistantChat")
+ .text_color(Color::Default.color(cx))
+ .child(self.render_model_dropdown(cx))
+ .child(list(self.list_state.clone()).flex_1())
+ }
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+struct MessageId(usize);
+
+impl MessageId {
+ fn post_inc(&mut self) -> Self {
+ let id = *self;
+ self.0 += 1;
+ id
+ }
+}
+
+enum ChatMessage {
+ User(UserMessage),
+ Assistant(AssistantMessage),
+}
+
+impl ChatMessage {
+ fn focus_handle(&self, cx: &AppContext) -> Option<FocusHandle> {
+ match self {
+ ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)),
+ ChatMessage::Assistant(_) => None,
+ }
+ }
+}
+
+struct UserMessage {
+ id: MessageId,
+ body: View<Editor>,
+ contexts: Vec<AssistantContext>,
+ _subscription: gpui::Subscription,
+}
+
+struct AssistantMessage {
+ id: MessageId,
+ body: RichText,
+ tool_calls: Vec<ToolFunctionCall>,
+ error: Option<SharedString>,
+}
+
+// Since we're swapping out for direct query usage, we might not need to use this injected context
+// It will be useful though for when the user _definitely_ wants the model to see a specific file,
+// query, error, etc.
+#[allow(dead_code)]
+enum AssistantContext {
+ Codebase(View<CodebaseContext>),
+}
+
+#[allow(dead_code)]
+struct CodebaseExcerpt {
+ element_id: ElementId,
+ path: SharedString,
+ text: SharedString,
+ score: f32,
+ expanded: bool,
+}
+
+impl AssistantContext {
+ #[allow(dead_code)]
+ fn render(&self, _cx: &mut ViewContext<AssistantChat>) -> AnyElement {
+ match self {
+ AssistantContext::Codebase(context) => context.clone().into_any_element(),
+ }
+ }
+
+ fn completion_messages(&self, cx: &WindowContext) -> Vec<CompletionMessage> {
+ match self {
+ AssistantContext::Codebase(context) => context.read(cx).completion_messages(),
+ }
+ }
+}
+
+enum CodebaseContext {
+ Pending { _task: Task<()> },
+ Done(Result<Vec<CodebaseExcerpt>>),
+}
+
+impl CodebaseContext {
+ fn toggle_expanded(&mut self, element_id: ElementId, cx: &mut ViewContext<Self>) {
+ if let CodebaseContext::Done(Ok(excerpts)) = self {
+ if let Some(excerpt) = excerpts
+ .iter_mut()
+ .find(|excerpt| excerpt.element_id == element_id)
+ {
+ excerpt.expanded = !excerpt.expanded;
+ cx.notify();
+ }
+ }
+ }
+}
+
+impl Render for CodebaseContext {
+ fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ match self {
+ CodebaseContext::Pending { .. } => div()
+ .h_flex()
+ .items_center()
+ .gap_1()
+ .child(Icon::new(IconName::Ai).color(Color::Muted).into_element())
+ .child("Searching codebase..."),
+ CodebaseContext::Done(Ok(excerpts)) => {
+ div()
+ .v_flex()
+ .gap_2()
+ .children(excerpts.iter().map(|excerpt| {
+ let expanded = excerpt.expanded;
+ let element_id = excerpt.element_id.clone();
+
+ CollapsibleContainer::new(element_id.clone(), expanded)
+ .start_slot(
+ h_flex()
+ .gap_1()
+ .child(Icon::new(IconName::File).color(Color::Muted))
+ .child(Label::new(excerpt.path.clone()).color(Color::Muted)),
+ )
+ .on_click(cx.listener(move |this, _, cx| {
+ this.toggle_expanded(element_id.clone(), cx);
+ }))
+ .child(
+ div()
+ .p_2()
+ .rounded_md()
+ .bg(cx.theme().colors().editor_background)
+ .child(
+ excerpt.text.clone(), // todo!(): Show as an editor block
+ ),
+ )
+ }))
+ }
+ CodebaseContext::Done(Err(error)) => div().child(error.to_string()),
+ }
+ }
+}
+
+impl CodebaseContext {
+ #[allow(dead_code)]
+ fn new(
+ query: impl 'static + Future<Output = Result<String>>,
+ populated: oneshot::Sender<bool>,
+ project_index: Model<ProjectIndex>,
+ fs: Arc<dyn Fs>,
+ cx: &mut ViewContext<Self>,
+ ) -> Self {
+ let query = query.boxed_local();
+ let _task = cx.spawn(|this, mut cx| async move {
+ let result = async {
+ let query = query.await?;
+ let results = this
+ .update(&mut cx, |_this, cx| {
+ project_index.read(cx).search(&query, 16, cx)
+ })?
+ .await;
+
+ let excerpts = results.into_iter().map(|result| {
+ let abs_path = result
+ .worktree
+ .read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path));
+ let fs = fs.clone();
+
+ async move {
+ let path = result.path.clone();
+ let text = fs.load(&abs_path?).await?;
+ // todo!("what should we do with stale ranges?");
+ let range = cmp::min(result.range.start, text.len())
+ ..cmp::min(result.range.end, text.len());
+
+ let text = SharedString::from(text[range].to_string());
+
+ anyhow::Ok(CodebaseExcerpt {
+ element_id: ElementId::Name(nanoid::nanoid!().into()),
+ path: path.to_string_lossy().to_string().into(),
+ text,
+ score: result.score,
+ expanded: false,
+ })
+ }
+ });
+
+ anyhow::Ok(
+ futures::future::join_all(excerpts)
+ .await
+ .into_iter()
+ .filter_map(|result| result.log_err())
+ .collect(),
+ )
+ }
+ .await;
+
+ this.update(&mut cx, |this, cx| {
+ this.populate(result, populated, cx);
+ })
+ .ok();
+ });
+
+ Self::Pending { _task }
+ }
+
+ #[allow(dead_code)]
+ fn populate(
+ &mut self,
+ result: Result<Vec<CodebaseExcerpt>>,
+ populated: oneshot::Sender<bool>,
+ cx: &mut ViewContext<Self>,
+ ) {
+ let success = result.is_ok();
+ *self = Self::Done(result);
+ populated.send(success).ok();
+ cx.notify();
+ }
+
+ fn completion_messages(&self) -> Vec<CompletionMessage> {
+ // One system message for the whole batch of excerpts:
+
+ // Semantic search results for user query:
+ //
+ // Excerpt from $path:
+ // ~~~
+ // `text`
+ // ~~~
+ //
+ // Excerpt from $path:
+
+ match self {
+ CodebaseContext::Done(Ok(excerpts)) => {
+ if excerpts.is_empty() {
+ return Vec::new();
+ }
+
+ let mut body = "Semantic search results for user query:\n".to_string();
+
+ for excerpt in excerpts {
+ body.push_str("Excerpt from ");
+ body.push_str(excerpt.path.as_ref());
+ body.push_str(", score ");
+ body.push_str(&excerpt.score.to_string());
+ body.push_str(":\n");
+ body.push_str("~~~\n");
+ body.push_str(excerpt.text.as_ref());
+ body.push_str("~~~\n");
+ }
+
+ vec![CompletionMessage::System { content: body }]
+ }
+ _ => vec![],
+ }
+ }
+}
@@ -0,0 +1,26 @@
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use settings::{Settings, SettingsSources};
+
+#[derive(Default, Debug, Deserialize, Serialize, Clone)]
+pub struct AssistantSettings {
+ pub enabled: bool,
+}
+
+#[derive(Default, Debug, Deserialize, Serialize, Clone, JsonSchema)]
+pub struct AssistantSettingsContent {
+ pub enabled: Option<bool>,
+}
+
+impl Settings for AssistantSettings {
+ const KEY: Option<&'static str> = Some("assistant_v2");
+
+ type FileContent = AssistantSettingsContent;
+
+ fn load(
+ sources: SettingsSources<Self::FileContent>,
+ _: &mut gpui::AppContext,
+ ) -> anyhow::Result<Self> {
+ Ok(sources.json_merge().unwrap_or_else(|_| Default::default()))
+ }
+}
@@ -0,0 +1,179 @@
+use anyhow::Result;
+use assistant_tooling::ToolFunctionDefinition;
+use client::{proto, Client};
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::Global;
+use std::sync::Arc;
+
+pub use open_ai::RequestMessage as CompletionMessage;
+
+#[derive(Clone)]
+pub struct CompletionProvider(Arc<dyn CompletionProviderBackend>);
+
+impl CompletionProvider {
+ pub fn new(backend: impl CompletionProviderBackend) -> Self {
+ Self(Arc::new(backend))
+ }
+
+ pub fn default_model(&self) -> String {
+ self.0.default_model()
+ }
+
+ pub fn available_models(&self) -> Vec<String> {
+ self.0.available_models()
+ }
+
+ pub fn complete(
+ &self,
+ model: String,
+ messages: Vec<CompletionMessage>,
+ stop: Vec<String>,
+ temperature: f32,
+ tools: &[ToolFunctionDefinition],
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
+ {
+ self.0.complete(model, messages, stop, temperature, tools)
+ }
+}
+
+impl Global for CompletionProvider {}
+
+pub trait CompletionProviderBackend: 'static {
+ fn default_model(&self) -> String;
+ fn available_models(&self) -> Vec<String>;
+ fn complete(
+ &self,
+ model: String,
+ messages: Vec<CompletionMessage>,
+ stop: Vec<String>,
+ temperature: f32,
+ tools: &[ToolFunctionDefinition],
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>;
+}
+
+pub struct CloudCompletionProvider {
+ client: Arc<Client>,
+}
+
+impl CloudCompletionProvider {
+ pub fn new(client: Arc<Client>) -> Self {
+ Self { client }
+ }
+}
+
+impl CompletionProviderBackend for CloudCompletionProvider {
+ fn default_model(&self) -> String {
+ "gpt-4-turbo".into()
+ }
+
+ fn available_models(&self) -> Vec<String> {
+ vec!["gpt-4-turbo".into(), "gpt-4".into(), "gpt-3.5-turbo".into()]
+ }
+
+ fn complete(
+ &self,
+ model: String,
+ messages: Vec<CompletionMessage>,
+ stop: Vec<String>,
+ temperature: f32,
+ tools: &[ToolFunctionDefinition],
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
+ {
+ let client = self.client.clone();
+ let tools: Vec<proto::ChatCompletionTool> = tools
+ .iter()
+ .filter_map(|tool| {
+ Some(proto::ChatCompletionTool {
+ variant: Some(proto::chat_completion_tool::Variant::Function(
+ proto::chat_completion_tool::FunctionObject {
+ name: tool.name.clone(),
+ description: Some(tool.description.clone()),
+ parameters: Some(serde_json::to_string(&tool.parameters).ok()?),
+ },
+ )),
+ })
+ })
+ .collect();
+
+ let tool_choice = match tools.is_empty() {
+ true => None,
+ false => Some("auto".into()),
+ };
+
+ async move {
+ let stream = client
+ .request_stream(proto::CompleteWithLanguageModel {
+ model,
+ messages: messages
+ .into_iter()
+ .map(|message| match message {
+ CompletionMessage::Assistant {
+ content,
+ tool_calls,
+ } => proto::LanguageModelRequestMessage {
+ role: proto::LanguageModelRole::LanguageModelAssistant as i32,
+ content: content.unwrap_or_default(),
+ tool_call_id: None,
+ tool_calls: tool_calls
+ .into_iter()
+ .map(|tool_call| match tool_call.content {
+ open_ai::ToolCallContent::Function { function } => {
+ proto::ToolCall {
+ id: tool_call.id,
+ variant: Some(proto::tool_call::Variant::Function(
+ proto::tool_call::FunctionCall {
+ name: function.name,
+ arguments: function.arguments,
+ },
+ )),
+ }
+ }
+ })
+ .collect(),
+ },
+ CompletionMessage::User { content } => {
+ proto::LanguageModelRequestMessage {
+ role: proto::LanguageModelRole::LanguageModelUser as i32,
+ content,
+ tool_call_id: None,
+ tool_calls: Vec::new(),
+ }
+ }
+ CompletionMessage::System { content } => {
+ proto::LanguageModelRequestMessage {
+ role: proto::LanguageModelRole::LanguageModelSystem as i32,
+ content,
+ tool_calls: Vec::new(),
+ tool_call_id: None,
+ }
+ }
+ CompletionMessage::Tool {
+ content,
+ tool_call_id,
+ } => proto::LanguageModelRequestMessage {
+ role: proto::LanguageModelRole::LanguageModelTool as i32,
+ content,
+ tool_call_id: Some(tool_call_id),
+ tool_calls: Vec::new(),
+ },
+ })
+ .collect(),
+ stop,
+ temperature,
+ tool_choice,
+ tools,
+ })
+ .await?;
+
+ Ok(stream
+ .filter_map(|response| async move {
+ match response {
+ Ok(mut response) => Some(Ok(response.choices.pop()?.delta?)),
+ Err(error) => Some(Err(error)),
+ }
+ })
+ .boxed())
+ }
+ .boxed()
+ }
+}
@@ -0,0 +1,176 @@
+use anyhow::Result;
+use assistant_tooling::LanguageModelTool;
+use gpui::{prelude::*, AnyElement, AppContext, Model, Task};
+use project::Fs;
+use schemars::JsonSchema;
+use semantic_index::ProjectIndex;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+use ui::{
+ div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString,
+ WindowContext,
+};
+use util::ResultExt as _;
+
+const DEFAULT_SEARCH_LIMIT: usize = 20;
+
+#[derive(Serialize, Clone)]
+pub struct CodebaseExcerpt {
+ path: SharedString,
+ text: SharedString,
+ score: f32,
+}
+
+// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model.
+// Any changes or deletions to the `CodebaseQuery` comments will change model behavior.
+
+#[derive(Deserialize, JsonSchema)]
+pub struct CodebaseQuery {
+ /// Semantic search query
+ query: String,
+ /// Maximum number of results to return, defaults to 20
+ limit: Option<usize>,
+}
+
+pub struct ProjectIndexTool {
+ project_index: Model<ProjectIndex>,
+ fs: Arc<dyn Fs>,
+}
+
+impl ProjectIndexTool {
+ pub fn new(project_index: Model<ProjectIndex>, fs: Arc<dyn Fs>) -> Self {
+ // TODO: setup a better description based on the user's current codebase.
+ Self { project_index, fs }
+ }
+}
+
+impl LanguageModelTool for ProjectIndexTool {
+ type Input = CodebaseQuery;
+ type Output = Vec<CodebaseExcerpt>;
+
+ fn name(&self) -> String {
+ "query_codebase".to_string()
+ }
+
+ fn description(&self) -> String {
+ "Semantic search against the user's current codebase, returning excerpts related to the query by computing a dot product against embeddings of chunks and an embedding of the query".to_string()
+ }
+
+ fn execute(&self, query: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>> {
+ let project_index = self.project_index.read(cx);
+
+ let results = project_index.search(
+ query.query.as_str(),
+ query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
+ cx,
+ );
+
+ let fs = self.fs.clone();
+
+ cx.spawn(|cx| async move {
+ let results = results.await;
+
+ let excerpts = results.into_iter().map(|result| {
+ let abs_path = result
+ .worktree
+ .read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path));
+ let fs = fs.clone();
+
+ async move {
+ let path = result.path.clone();
+ let text = fs.load(&abs_path?).await?;
+
+ let mut start = result.range.start;
+ let mut end = result.range.end.min(text.len());
+ while !text.is_char_boundary(start) {
+ start += 1;
+ }
+ while !text.is_char_boundary(end) {
+ end -= 1;
+ }
+
+ anyhow::Ok(CodebaseExcerpt {
+ path: path.to_string_lossy().to_string().into(),
+ text: SharedString::from(text[start..end].to_string()),
+ score: result.score,
+ })
+ }
+ });
+
+ let excerpts = futures::future::join_all(excerpts)
+ .await
+ .into_iter()
+ .filter_map(|result| result.log_err())
+ .collect();
+ anyhow::Ok(excerpts)
+ })
+ }
+
+ fn render(
+ _tool_call_id: &str,
+ input: &Self::Input,
+ excerpts: &Self::Output,
+ cx: &mut WindowContext,
+ ) -> AnyElement {
+ let query = input.query.clone();
+
+ div()
+ .v_flex()
+ .gap_2()
+ .child(
+ div()
+ .p_2()
+ .rounded_md()
+ .bg(cx.theme().colors().editor_background)
+ .child(
+ h_flex()
+ .child(Label::new("Query: ").color(Color::Modified))
+ .child(Label::new(query).color(Color::Muted)),
+ ),
+ )
+ .children(excerpts.iter().map(|excerpt| {
+ // This render doesn't have state/model, so we can't use the listener
+ // let expanded = excerpt.expanded;
+ // let element_id = excerpt.element_id.clone();
+ let element_id = ElementId::Name(nanoid::nanoid!().into());
+ let expanded = false;
+
+ CollapsibleContainer::new(element_id.clone(), expanded)
+ .start_slot(
+ h_flex()
+ .gap_1()
+ .child(Icon::new(IconName::File).color(Color::Muted))
+ .child(Label::new(excerpt.path.clone()).color(Color::Muted)),
+ )
+ // .on_click(cx.listener(move |this, _, cx| {
+ // this.toggle_expanded(element_id.clone(), cx);
+ // }))
+ .child(
+ div()
+ .p_2()
+ .rounded_md()
+ .bg(cx.theme().colors().editor_background)
+ .child(
+ excerpt.text.clone(), // todo!(): Show as an editor block
+ ),
+ )
+ }))
+ .into_any_element()
+ }
+
+ fn format(_input: &Self::Input, excerpts: &Self::Output) -> String {
+ let mut body = "Semantic search results:\n".to_string();
+
+ for excerpt in excerpts {
+ body.push_str("Excerpt from ");
+ body.push_str(excerpt.path.as_ref());
+ body.push_str(", score ");
+ body.push_str(&excerpt.score.to_string());
+ body.push_str(":\n");
+ body.push_str("~~~\n");
+ body.push_str(excerpt.text.as_ref());
+ body.push_str("~~~\n");
+ }
+ body
+ }
+}
@@ -0,0 +1,22 @@
+[package]
+name = "assistant_tooling"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/assistant_tooling.rs"
+
+[dependencies]
+anyhow.workspace = true
+gpui.workspace = true
+schemars.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+
+[dev-dependencies]
+gpui = { workspace = true, features = ["test-support"] }
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,208 @@
+# Assistant Tooling
+
+Bringing OpenAI compatible tool calling to GPUI.
+
+This unlocks:
+
+- **Structured Extraction** of model responses
+- **Validation** of model inputs
+- **Execution** of chosen toolsn
+
+## Overview
+
+Language Models can produce structured outputs that are perfect for calling functions. The most famous of these is OpenAI's tool calling. When make a chat completion you can pass a list of tools available to the model. The model will choose `0..n` tools to help them complete a user's task. It's up to _you_ to create the tools that the model can call.
+
+> **User**: "Hey I need help with implementing a collapsible panel in GPUI"
+>
+> **Assistant**: "Sure, I can help with that. Let me see what I can find."
+>
+> `tool_calls: ["name": "query_codebase", arguments: "{ 'query': 'GPUI collapsible panel' }"]`
+>
+> `result: "['crates/gpui/src/panel.rs:12: impl Panel { ... }', 'crates/gpui/src/panel.rs:20: impl Panel { ... }']"`
+>
+> **Assistant**: "Here are some excerpts from the GPUI codebase that might help you."
+
+This library is designed to facilitate this interaction mode by allowing you to go from `struct` to `tool` with a simple trait, `LanguageModelTool`.
+
+## Example
+
+Let's expose querying a semantic index directly by the model. First, we'll set up some _necessary_ imports
+
+```rust
+use anyhow::Result;
+use assistant_tooling::{LanguageModelTool, ToolRegistry};
+use gpui::{App, AppContext, Task};
+use schemars::JsonSchema;
+use serde::Deserialize;
+use serde_json::json;
+```
+
+Then we'll define the query structure the model must fill in. This _must_ derive `Deserialize` from `serde` and `JsonSchema` from the `schemars` crate.
+
+```rust
+#[derive(Deserialize, JsonSchema)]
+struct CodebaseQuery {
+ query: String,
+}
+```
+
+After that we can define our tool, with the expectation that it will need a `ProjectIndex` to search against. For this example, the index uses the same interface as `semantic_index::ProjectIndex`.
+
+```rust
+struct ProjectIndex {}
+
+impl ProjectIndex {
+ fn new() -> Self {
+ ProjectIndex {}
+ }
+
+ fn search(&self, _query: &str, _limit: usize, _cx: &AppContext) -> Task<Result<Vec<String>>> {
+ // Instead of hooking up a real index, we're going to fake it
+ if _query.contains("gpui") {
+ return Task::ready(Ok(vec![r#"// crates/gpui/src/gpui.rs
+ //! # Welcome to GPUI!
+ //!
+ //! GPUI is a hybrid immediate and retained mode, GPU accelerated, UI framework
+ //! for Rust, designed to support a wide variety of applications
+ "#
+ .to_string()]));
+ }
+ return Task::ready(Ok(vec![]));
+ }
+}
+
+struct ProjectIndexTool {
+ project_index: ProjectIndex,
+}
+```
+
+Now we can implement the `LanguageModelTool` trait for our tool by:
+
+- Defining the `Input` from the model, which is `CodebaseQuery`
+- Defining the `Output`
+- Implementing the `name` and `description` functions to provide the model information when it's choosing a tool
+- Implementing the `execute` function to run the tool
+
+```rust
+impl LanguageModelTool for ProjectIndexTool {
+ type Input = CodebaseQuery;
+ type Output = String;
+
+ fn name(&self) -> String {
+ "query_codebase".to_string()
+ }
+
+ fn description(&self) -> String {
+ "Executes a query against the codebase, returning excerpts related to the query".to_string()
+ }
+
+ fn execute(&self, query: Self::Input, cx: &AppContext) -> Task<Result<Self::Output>> {
+ let results = self.project_index.search(query.query.as_str(), 10, cx);
+
+ cx.spawn(|_cx| async move {
+ let results = results.await?;
+
+ if !results.is_empty() {
+ Ok(results.join("\n"))
+ } else {
+ Ok("No results".to_string())
+ }
+ })
+ }
+}
+```
+
+For the sake of this example, let's look at the types that OpenAI will be passing to us
+
+```rust
+// OpenAI definitions, shown here for demonstration
+#[derive(Deserialize)]
+struct FunctionCall {
+ name: String,
+ args: String,
+}
+
+#[derive(Deserialize, Eq, PartialEq)]
+enum ToolCallType {
+ #[serde(rename = "function")]
+ Function,
+ Other,
+}
+
+#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
+struct ToolCallId(String);
+
+#[derive(Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+enum ToolCall {
+ Function {
+ #[allow(dead_code)]
+ id: ToolCallId,
+ function: FunctionCall,
+ },
+ Other {
+ #[allow(dead_code)]
+ id: ToolCallId,
+ },
+}
+
+#[derive(Deserialize)]
+struct AssistantMessage {
+ role: String,
+ content: Option<String>,
+ tool_calls: Option<Vec<ToolCall>>,
+}
+```
+
+When the model wants to call tools, it will pass a list of `ToolCall`s. When those are `function`s that we can handle, we'll pass them to our `ToolRegistry` to get a future that we can await.
+
+```rust
+// Inside `fn main()`
+App::new().run(|cx: &mut AppContext| {
+ let tool = ProjectIndexTool {
+ project_index: ProjectIndex::new(),
+ };
+
+ let mut registry = ToolRegistry::new();
+ let registered = registry.register(tool);
+ assert!(registered.is_ok());
+```
+
+Let's pretend the model sent us back a message requesting
+
+```rust
+let model_response = json!({
+ "role": "assistant",
+ "tool_calls": [
+ {
+ "id": "call_1",
+ "function": {
+ "name": "query_codebase",
+ "args": r#"{"query":"GPUI Task background_executor"}"#
+ },
+ "type": "function"
+ }
+ ]
+});
+
+let message: AssistantMessage = serde_json::from_value(model_response).unwrap();
+
+// We know there's a tool call, so let's skip straight to it for this example
+let tool_calls = message.tool_calls.as_ref().unwrap();
+let tool_call = tool_calls.get(0).unwrap();
+```
+
+We can now use our registry to call the tool.
+
+```rust
+let task = registry.call(
+ tool_call.name,
+ tool_call.args,
+);
+
+cx.spawn(|_cx| async move {
+ let result = task.await?;
+ println!("{}", result.unwrap());
+ Ok(())
+})
+```
@@ -0,0 +1,5 @@
+pub mod registry;
+pub mod tool;
+
+pub use crate::registry::ToolRegistry;
+pub use crate::tool::{LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition};
@@ -0,0 +1,298 @@
+use anyhow::{anyhow, Result};
+use gpui::{AnyElement, AppContext, Task, WindowContext};
+use std::{any::Any, collections::HashMap};
+
+use crate::tool::{
+ LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
+};
+
+pub struct ToolRegistry {
+ tools: HashMap<String, Box<dyn Fn(&ToolFunctionCall, &AppContext) -> Task<ToolFunctionCall>>>,
+ definitions: Vec<ToolFunctionDefinition>,
+}
+
+impl ToolRegistry {
+ pub fn new() -> Self {
+ Self {
+ tools: HashMap::new(),
+ definitions: Vec::new(),
+ }
+ }
+
+ pub fn definitions(&self) -> &[ToolFunctionDefinition] {
+ &self.definitions
+ }
+
+ pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
+ fn render<T: 'static + LanguageModelTool>(
+ tool_call_id: &str,
+ input: &Box<dyn Any>,
+ output: &Box<dyn Any>,
+ cx: &mut WindowContext,
+ ) -> AnyElement {
+ T::render(
+ tool_call_id,
+ input.as_ref().downcast_ref::<T::Input>().unwrap(),
+ output.as_ref().downcast_ref::<T::Output>().unwrap(),
+ cx,
+ )
+ }
+
+ fn format<T: 'static + LanguageModelTool>(
+ input: &Box<dyn Any>,
+ output: &Box<dyn Any>,
+ ) -> String {
+ T::format(
+ input.as_ref().downcast_ref::<T::Input>().unwrap(),
+ output.as_ref().downcast_ref::<T::Output>().unwrap(),
+ )
+ }
+
+ self.definitions.push(tool.definition());
+ let name = tool.name();
+ let previous = self.tools.insert(
+ name.clone(),
+ Box::new(move |tool_call: &ToolFunctionCall, cx: &AppContext| {
+ let name = tool_call.name.clone();
+ let arguments = tool_call.arguments.clone();
+ let id = tool_call.id.clone();
+
+ let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
+ return Task::ready(ToolFunctionCall {
+ id,
+ name: name.clone(),
+ arguments,
+ result: Some(ToolFunctionCallResult::ParsingFailed),
+ });
+ };
+
+ let result = tool.execute(&input, cx);
+
+ cx.spawn(move |_cx| async move {
+ match result.await {
+ Ok(result) => {
+ let result: T::Output = result;
+ ToolFunctionCall {
+ id,
+ name: name.clone(),
+ arguments,
+ result: Some(ToolFunctionCallResult::Finished {
+ input: Box::new(input),
+ output: Box::new(result),
+ render_fn: render::<T>,
+ format_fn: format::<T>,
+ }),
+ }
+ }
+ Err(_error) => ToolFunctionCall {
+ id,
+ name: name.clone(),
+ arguments,
+ result: Some(ToolFunctionCallResult::ExecutionFailed {
+ input: Box::new(input),
+ }),
+ },
+ }
+ })
+ }),
+ );
+
+ if previous.is_some() {
+ return Err(anyhow!("already registered a tool with name {}", name));
+ }
+
+ Ok(())
+ }
+
+ pub fn call(&self, tool_call: &ToolFunctionCall, cx: &AppContext) -> Task<ToolFunctionCall> {
+ let name = tool_call.name.clone();
+ let arguments = tool_call.arguments.clone();
+ let id = tool_call.id.clone();
+
+ let tool = match self.tools.get(&name) {
+ Some(tool) => tool,
+ None => {
+ let name = name.clone();
+ return Task::ready(ToolFunctionCall {
+ id,
+ name: name.clone(),
+ arguments,
+ result: Some(ToolFunctionCallResult::NoSuchTool),
+ });
+ }
+ };
+
+ tool(tool_call, cx)
+ }
+}
+
+#[cfg(test)]
+mod test {
+
+ use super::*;
+
+ use schemars::schema_for;
+
+ use gpui::{div, AnyElement, Element, ParentElement, TestAppContext, WindowContext};
+ use schemars::JsonSchema;
+ use serde::{Deserialize, Serialize};
+ use serde_json::json;
+
+ #[derive(Deserialize, Serialize, JsonSchema)]
+ struct WeatherQuery {
+ location: String,
+ unit: String,
+ }
+
+ struct WeatherTool {
+ current_weather: WeatherResult,
+ }
+
+ #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
+ struct WeatherResult {
+ location: String,
+ temperature: f64,
+ unit: String,
+ }
+
+ impl LanguageModelTool for WeatherTool {
+ type Input = WeatherQuery;
+ type Output = WeatherResult;
+
+ fn name(&self) -> String {
+ "get_current_weather".to_string()
+ }
+
+ fn description(&self) -> String {
+ "Fetches the current weather for a given location.".to_string()
+ }
+
+ fn execute(&self, input: &WeatherQuery, _cx: &AppContext) -> Task<Result<Self::Output>> {
+ let _location = input.location.clone();
+ let _unit = input.unit.clone();
+
+ let weather = self.current_weather.clone();
+
+ Task::ready(Ok(weather))
+ }
+
+ fn render(
+ _tool_call_id: &str,
+ _input: &Self::Input,
+ output: &Self::Output,
+ _cx: &mut WindowContext,
+ ) -> AnyElement {
+ div()
+ .child(format!(
+ "The current temperature in {} is {} {}",
+ output.location, output.temperature, output.unit
+ ))
+ .into_any()
+ }
+
+ fn format(_input: &Self::Input, output: &Self::Output) -> String {
+ format!(
+ "The current temperature in {} is {} {}",
+ output.location, output.temperature, output.unit
+ )
+ }
+ }
+
+ #[gpui::test]
+ async fn test_function_registry(cx: &mut TestAppContext) {
+ cx.background_executor.run_until_parked();
+
+ let mut registry = ToolRegistry::new();
+
+ let tool = WeatherTool {
+ current_weather: WeatherResult {
+ location: "San Francisco".to_string(),
+ temperature: 21.0,
+ unit: "Celsius".to_string(),
+ },
+ };
+
+ registry.register(tool).unwrap();
+
+ let _result = cx
+ .update(|cx| {
+ registry.call(
+ &ToolFunctionCall {
+ name: "get_current_weather".to_string(),
+ arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
+ .to_string(),
+ id: "test-123".to_string(),
+ result: None,
+ },
+ cx,
+ )
+ })
+ .await;
+
+ // assert!(result.is_ok());
+ // let result = result.unwrap();
+
+ // let expected = r#"{"location":"San Francisco","temperature":21.0,"unit":"Celsius"}"#;
+
+ // todo!(): Put this back in after the interface is stabilized
+ // assert_eq!(result, expected);
+ }
+
+ #[gpui::test]
+ async fn test_openai_weather_example(cx: &mut TestAppContext) {
+ cx.background_executor.run_until_parked();
+
+ let tool = WeatherTool {
+ current_weather: WeatherResult {
+ location: "San Francisco".to_string(),
+ temperature: 21.0,
+ unit: "Celsius".to_string(),
+ },
+ };
+
+ let tools = vec![tool.definition()];
+ assert_eq!(tools.len(), 1);
+
+ let expected = ToolFunctionDefinition {
+ name: "get_current_weather".to_string(),
+ description: "Fetches the current weather for a given location.".to_string(),
+ parameters: schema_for!(WeatherQuery).schema,
+ };
+
+ assert_eq!(tools[0].name, expected.name);
+ assert_eq!(tools[0].description, expected.description);
+
+ let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap();
+
+ assert_eq!(
+ expected_schema,
+ json!({
+ "title": "WeatherQuery",
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string"
+ },
+ "unit": {
+ "type": "string"
+ }
+ },
+ "required": ["location", "unit"]
+ })
+ );
+
+ let args = json!({
+ "location": "San Francisco",
+ "unit": "Celsius"
+ });
+
+ let query: WeatherQuery = serde_json::from_value(args).unwrap();
+
+ let result = cx.update(|cx| tool.execute(&query, cx)).await;
+
+ assert!(result.is_ok());
+ let result = result.unwrap();
+
+ assert_eq!(result, tool.current_weather);
+ }
+}
@@ -0,0 +1,145 @@
+use anyhow::Result;
+use gpui::{div, AnyElement, AppContext, Element, ParentElement as _, Task, WindowContext};
+use schemars::{schema::SchemaObject, schema_for, JsonSchema};
+use serde::Deserialize;
+use std::{any::Any, fmt::Debug};
+
+#[derive(Default, Deserialize)]
+pub struct ToolFunctionCall {
+ pub id: String,
+ pub name: String,
+ pub arguments: String,
+ #[serde(skip)]
+ pub result: Option<ToolFunctionCallResult>,
+}
+
+pub enum ToolFunctionCallResult {
+ NoSuchTool,
+ ParsingFailed,
+ ExecutionFailed {
+ input: Box<dyn Any>,
+ },
+ Finished {
+ input: Box<dyn Any>,
+ output: Box<dyn Any>,
+ render_fn: fn(
+ // tool_call_id
+ &str,
+ // LanguageModelTool::Input
+ &Box<dyn Any>,
+ // LanguageModelTool::Output
+ &Box<dyn Any>,
+ &mut WindowContext,
+ ) -> AnyElement,
+ format_fn: fn(
+ // LanguageModelTool::Input
+ &Box<dyn Any>,
+ // LanguageModelTool::Output
+ &Box<dyn Any>,
+ ) -> String,
+ },
+}
+
+impl ToolFunctionCallResult {
+ pub fn render(
+ &self,
+ tool_name: &str,
+ tool_call_id: &str,
+ cx: &mut WindowContext,
+ ) -> AnyElement {
+ match self {
+ ToolFunctionCallResult::NoSuchTool => {
+ div().child(format!("no such tool {tool_name}")).into_any()
+ }
+ ToolFunctionCallResult::ParsingFailed => div()
+ .child(format!("failed to parse input for tool {tool_name}"))
+ .into_any(),
+ ToolFunctionCallResult::ExecutionFailed { .. } => div()
+ .child(format!("failed to execute tool {tool_name}"))
+ .into_any(),
+ ToolFunctionCallResult::Finished {
+ input,
+ output,
+ render_fn,
+ ..
+ } => render_fn(tool_call_id, input, output, cx),
+ }
+ }
+
+ pub fn format(&self, tool: &str) -> String {
+ match self {
+ ToolFunctionCallResult::NoSuchTool => format!("no such tool {tool}"),
+ ToolFunctionCallResult::ParsingFailed => {
+ format!("failed to parse input for tool {tool}")
+ }
+ ToolFunctionCallResult::ExecutionFailed { input: _input } => {
+ format!("failed to execute tool {tool}")
+ }
+ ToolFunctionCallResult::Finished {
+ input,
+ output,
+ format_fn,
+ ..
+ } => format_fn(input, output),
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct ToolFunctionDefinition {
+ pub name: String,
+ pub description: String,
+ pub parameters: SchemaObject,
+}
+
+impl Debug for ToolFunctionDefinition {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let schema = serde_json::to_string(&self.parameters).ok();
+ let schema = schema.unwrap_or("None".to_string());
+
+ f.debug_struct("ToolFunctionDefinition")
+ .field("name", &self.name)
+ .field("description", &self.description)
+ .field("parameters", &schema)
+ .finish()
+ }
+}
+
+pub trait LanguageModelTool {
+ /// The input type that will be passed in to `execute` when the tool is called
+ /// by the language model.
+ type Input: for<'de> Deserialize<'de> + JsonSchema;
+
+ /// The output returned by executing the tool.
+ type Output: 'static;
+
+ /// The name of the tool is exposed to the language model to allow
+ /// the model to pick which tools to use. As this name is used to
+ /// identify the tool within a tool registry, it should be unique.
+ fn name(&self) -> String;
+
+ /// A description of the tool that can be used to _prompt_ the model
+ /// as to what the tool does.
+ fn description(&self) -> String;
+
+ /// The OpenAI Function definition for the tool, for direct use with OpenAI's API.
+ fn definition(&self) -> ToolFunctionDefinition {
+ ToolFunctionDefinition {
+ name: self.name(),
+ description: self.description(),
+ parameters: schema_for!(Self::Input).schema,
+ }
+ }
+
+ /// Execute the tool
+ fn execute(&self, input: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>>;
+
+ fn render(
+ tool_call_id: &str,
+ input: &Self::Input,
+ output: &Self::Output,
+ cx: &mut WindowContext,
+ ) -> AnyElement;
+
+ fn format(input: &Self::Input, output: &Self::Output) -> String;
+}
@@ -457,6 +457,14 @@ impl Client {
})
}
+ pub fn production(cx: &mut AppContext) -> Arc<Self> {
+ let clock = Arc::new(clock::RealSystemClock);
+ let http = Arc::new(HttpClientWithUrl::new(
+ &ClientSettings::get_global(cx).server_url,
+ ));
+ Self::new(clock, http.clone(), cx)
+ }
+
pub fn id(&self) -> u64 {
self.id.load(Ordering::SeqCst)
}
@@ -1119,6 +1127,8 @@ impl Client {
if let Some((login, token)) =
IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
{
+ eprintln!("authenticate as admin {login}, {token}");
+
return Self::authenticate_as_admin(http, login.clone(), token.clone())
.await;
}
@@ -5,7 +5,8 @@
"maxbrunsfeld",
"iamnbutler",
"mikayla-maki",
- "JosephTLyons"
+ "JosephTLyons",
+ "rgbkrk"
],
"channels": ["zed"],
"number_of_users": 100
@@ -1,5 +1,6 @@
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context as _, Result};
use rpc::proto;
+use util::ResultExt as _;
pub fn language_model_request_to_open_ai(
request: proto::CompleteWithLanguageModel,
@@ -9,24 +10,83 @@ pub fn language_model_request_to_open_ai(
messages: request
.messages
.into_iter()
- .map(|message| {
+ .map(|message: proto::LanguageModelRequestMessage| {
let role = proto::LanguageModelRole::from_i32(message.role)
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
- Ok(open_ai::RequestMessage {
- role: match role {
- proto::LanguageModelRole::LanguageModelUser => open_ai::Role::User,
- proto::LanguageModelRole::LanguageModelAssistant => {
- open_ai::Role::Assistant
+
+ let openai_message = match role {
+ proto::LanguageModelRole::LanguageModelUser => open_ai::RequestMessage::User {
+ content: message.content,
+ },
+ proto::LanguageModelRole::LanguageModelAssistant => {
+ open_ai::RequestMessage::Assistant {
+ content: Some(message.content),
+ tool_calls: message
+ .tool_calls
+ .into_iter()
+ .filter_map(|call| {
+ Some(open_ai::ToolCall {
+ id: call.id,
+ content: match call.variant? {
+ proto::tool_call::Variant::Function(f) => {
+ open_ai::ToolCallContent::Function {
+ function: open_ai::FunctionContent {
+ name: f.name,
+ arguments: f.arguments,
+ },
+ }
+ }
+ },
+ })
+ })
+ .collect(),
+ }
+ }
+ proto::LanguageModelRole::LanguageModelSystem => {
+ open_ai::RequestMessage::System {
+ content: message.content,
}
- proto::LanguageModelRole::LanguageModelSystem => open_ai::Role::System,
+ }
+ proto::LanguageModelRole::LanguageModelTool => open_ai::RequestMessage::Tool {
+ tool_call_id: message
+ .tool_call_id
+ .ok_or_else(|| anyhow!("tool message is missing tool call id"))?,
+ content: message.content,
},
- content: message.content,
- })
+ };
+
+ Ok(openai_message)
})
.collect::<Result<Vec<open_ai::RequestMessage>>>()?,
stream: true,
stop: request.stop,
temperature: request.temperature,
+ tools: request
+ .tools
+ .into_iter()
+ .filter_map(|tool| {
+ Some(match tool.variant? {
+ proto::chat_completion_tool::Variant::Function(f) => {
+ open_ai::ToolDefinition::Function {
+ function: open_ai::FunctionDefinition {
+ name: f.name,
+ description: f.description,
+ parameters: if let Some(params) = &f.parameters {
+ Some(
+ serde_json::from_str(params)
+ .context("failed to deserialize tool parameters")
+ .log_err()?,
+ )
+ } else {
+ None
+ },
+ },
+ }
+ }
+ })
+ })
+ .collect(),
+ tool_choice: request.tool_choice,
})
}
@@ -58,6 +118,9 @@ pub fn language_model_request_message_to_google_ai(
proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
+ proto::LanguageModelRole::LanguageModelTool => {
+ Err(anyhow!("we don't handle tool calls with google ai yet"))?
+ }
},
})
}
@@ -775,9 +775,7 @@ impl Server {
Box::new(move |envelope, session| {
let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
let received_at = envelope.received_at;
- tracing::info!(
- "message received"
- );
+ tracing::info!("message received");
let start_time = Instant::now();
let future = (handler)(*envelope, session);
async move {
@@ -786,12 +784,24 @@ impl Server {
let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
let queue_duration_ms = total_duration_ms - processing_duration_ms;
let payload_type = M::NAME;
+
match result {
Err(error) => {
- // todo!(), why isn't this logged inside the span?
- tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, payload_type, "error handling message")
+ tracing::error!(
+ ?error,
+ total_duration_ms,
+ processing_duration_ms,
+ queue_duration_ms,
+ payload_type,
+ "error handling message"
+ )
}
- Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
+ Ok(()) => tracing::info!(
+ total_duration_ms,
+ processing_duration_ms,
+ queue_duration_ms,
+ "finished handling message"
+ ),
}
}
.boxed()
@@ -4098,7 +4108,7 @@ async fn complete_with_open_ai(
crate::ai::language_model_request_to_open_ai(request)?,
)
.await
- .context("open_ai::stream_completion request failed")?;
+ .context("open_ai::stream_completion request failed within collab")?;
while let Some(event) = completion_stream.next().await {
let event = event?;
@@ -4113,8 +4123,32 @@ async fn complete_with_open_ai(
open_ai::Role::User => LanguageModelRole::LanguageModelUser,
open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
+ open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
} as i32),
content: choice.delta.content,
+ tool_calls: choice
+ .delta
+ .tool_calls
+ .into_iter()
+ .map(|delta| proto::ToolCallDelta {
+ index: delta.index as u32,
+ id: delta.id,
+ variant: match delta.function {
+ Some(function) => {
+ let name = function.name;
+ let arguments = function.arguments;
+
+ Some(proto::tool_call_delta::Variant::Function(
+ proto::tool_call_delta::FunctionCallDelta {
+ name,
+ arguments,
+ },
+ ))
+ }
+ None => None,
+ },
+ })
+ .collect(),
}),
finish_reason: choice.finish_reason,
})
@@ -4165,6 +4199,8 @@ async fn complete_with_google_ai(
})
.collect(),
),
+ // Tool calls are not supported for Google
+ tool_calls: Vec::new(),
}),
finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
})
@@ -4187,24 +4223,28 @@ async fn complete_with_anthropic(
let messages = request
.messages
.into_iter()
- .filter_map(|message| match message.role() {
- LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
- role: anthropic::Role::User,
- content: message.content,
- }),
- LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
- role: anthropic::Role::Assistant,
- content: message.content,
- }),
- // Anthropic's API breaks system instructions out as a separate field rather
- // than having a system message role.
- LanguageModelRole::LanguageModelSystem => {
- if !system_message.is_empty() {
- system_message.push_str("\n\n");
- }
- system_message.push_str(&message.content);
+ .filter_map(|message| {
+ match message.role() {
+ LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
+ role: anthropic::Role::User,
+ content: message.content,
+ }),
+ LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
+ role: anthropic::Role::Assistant,
+ content: message.content,
+ }),
+ // Anthropic's API breaks system instructions out as a separate field rather
+ // than having a system message role.
+ LanguageModelRole::LanguageModelSystem => {
+ if !system_message.is_empty() {
+ system_message.push_str("\n\n");
+ }
+ system_message.push_str(&message.content);
- None
+ None
+ }
+ // We don't yet support tool calls for Anthropic
+ LanguageModelRole::LanguageModelTool => None,
}
})
.collect();
@@ -4248,6 +4288,7 @@ async fn complete_with_anthropic(
delta: Some(proto::LanguageModelResponseMessage {
role: Some(current_role as i32),
content: Some(text),
+ tool_calls: Vec::new(),
}),
finish_reason: None,
}],
@@ -4264,6 +4305,7 @@ async fn complete_with_anthropic(
delta: Some(proto::LanguageModelResponseMessage {
role: Some(current_role as i32),
content: Some(text),
+ tool_calls: Vec::new(),
}),
finish_reason: None,
}],
@@ -234,10 +234,11 @@ impl ChatPanel {
let channel_id = chat.read(cx).channel_id;
{
self.markdown_data.clear();
- let chat = chat.read(cx);
- self.message_list.reset(chat.message_count());
+ let chat = chat.read(cx);
let channel_name = chat.channel(cx).map(|channel| channel.name.clone());
+ let message_count = chat.message_count();
+ self.message_list.reset(message_count);
self.message_editor.update(cx, |editor, cx| {
editor.set_channel(channel_id, channel_name, cx);
editor.clear_reply_to_message_id();
@@ -766,7 +767,7 @@ impl ChatPanel {
body.push_str(MESSAGE_EDITED);
}
- let mut rich_text = rich_text::render_rich_text(body, &mentions, language_registry, None);
+ let mut rich_text = RichText::new(body, &mentions, language_registry);
if message.edited_at.is_some() {
let range = (rich_text.text.len() - MESSAGE_EDITED.len())..rich_text.text.len();
@@ -2947,7 +2947,7 @@ impl Render for DraggedChannelView {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl Element {
let ui_font = ThemeSettings::get_global(cx).ui_font.family.clone();
h_flex()
- .font(ui_font)
+ .font_family(ui_font)
.bg(cx.theme().colors().background)
.w(self.width)
.p_1()
@@ -125,7 +125,7 @@ impl Render for IncomingCallNotification {
cx.set_rem_size(ui_font_size);
- div().size_full().font(ui_font).child(
+ div().size_full().font_family(ui_font).child(
CollabNotification::new(
self.state.call.calling_user.avatar_uri.clone(),
Button::new("accept", "Accept").on_click({
@@ -129,7 +129,7 @@ impl Render for ProjectSharedNotification {
cx.set_rem_size(ui_font_size);
- div().size_full().font(ui_font).child(
+ div().size_full().font_family(ui_font).child(
CollabNotification::new(
self.owner.avatar_uri.clone(),
Button::new("open", "Open").on_click(cx.listener(move |this, _event, cx| {
@@ -61,13 +61,13 @@ use fuzzy::{StringMatch, StringMatchCandidate};
use git::blame::GitBlame;
use git::diff_hunk_to_display;
use gpui::{
- div, impl_actions, point, prelude::*, px, relative, rems, size, uniform_list, Action,
- AnyElement, AppContext, AsyncWindowContext, AvailableSpace, BackgroundExecutor, Bounds,
- ClipboardItem, Context, DispatchPhase, ElementId, EventEmitter, FocusHandle, FocusableView,
- FontId, FontStyle, FontWeight, HighlightStyle, Hsla, InteractiveText, KeyContext, Model,
- MouseButton, PaintQuad, ParentElement, Pixels, Render, SharedString, Size, StrikethroughStyle,
- Styled, StyledText, Subscription, Task, TextStyle, UnderlineStyle, UniformListScrollHandle,
- View, ViewContext, ViewInputHandler, VisualContext, WeakView, WhiteSpace, WindowContext,
+ div, impl_actions, point, prelude::*, px, relative, size, uniform_list, Action, AnyElement,
+ AppContext, AsyncWindowContext, AvailableSpace, BackgroundExecutor, Bounds, ClipboardItem,
+ Context, DispatchPhase, ElementId, EventEmitter, FocusHandle, FocusableView, FontId, FontStyle,
+ FontWeight, HighlightStyle, Hsla, InteractiveText, KeyContext, Model, MouseButton, PaintQuad,
+ ParentElement, Pixels, Render, SharedString, Size, StrikethroughStyle, Styled, StyledText,
+ Subscription, Task, TextStyle, UnderlineStyle, UniformListScrollHandle, View, ViewContext,
+ ViewInputHandler, VisualContext, WeakView, WhiteSpace, WindowContext,
};
use highlight_matching_bracket::refresh_matching_bracket_highlights;
use hover_popover::{hide_hover, HoverState};
@@ -8885,7 +8885,6 @@ impl Editor {
self.style = Some(style);
}
- #[cfg(any(test, feature = "test-support"))]
pub fn style(&self) -> Option<&EditorStyle> {
self.style.as_ref()
}
@@ -10322,21 +10321,9 @@ impl FocusableView for Editor {
impl Render for Editor {
fn render<'a>(&mut self, cx: &mut ViewContext<'a, Self>) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
- let text_style = match self.mode {
- EditorMode::SingleLine | EditorMode::AutoHeight { .. } => TextStyle {
- color: cx.theme().colors().editor_foreground,
- font_family: settings.ui_font.family.clone(),
- font_features: settings.ui_font.features,
- font_size: rems(0.875).into(),
- font_weight: FontWeight::NORMAL,
- font_style: FontStyle::Normal,
- line_height: relative(settings.buffer_line_height.value()),
- background_color: None,
- underline: None,
- strikethrough: None,
- white_space: WhiteSpace::Normal,
- },
+ let text_style = match self.mode {
+ EditorMode::SingleLine | EditorMode::AutoHeight { .. } => cx.text_style(),
EditorMode::Full => TextStyle {
color: cx.theme().colors().editor_foreground,
font_family: settings.buffer_font.family.clone(),
@@ -3056,7 +3056,7 @@ fn render_inline_blame_entry(
h_flex()
.id("inline-blame")
.w_full()
- .font(style.text.font().family)
+ .font_family(style.text.font().family)
.text_color(cx.theme().status().hint)
.line_height(style.text.line_height)
.child(Icon::new(IconName::FileGit).color(Color::Hint))
@@ -3108,7 +3108,7 @@ fn render_blame_entry(
h_flex()
.w_full()
- .font(style.text.font().family)
+ .font_family(style.text.font().family)
.line_height(style.text.line_height)
.id(("blame", ix))
.children([
@@ -1,7 +1,7 @@
use crate::{
self as gpui, hsla, point, px, relative, rems, AbsoluteLength, AlignItems, CursorStyle,
- DefiniteLength, Fill, FlexDirection, FlexWrap, FontStyle, FontWeight, Hsla, JustifyContent,
- Length, Position, SharedString, StyleRefinement, Visibility, WhiteSpace,
+ DefiniteLength, Fill, FlexDirection, FlexWrap, Font, FontStyle, FontWeight, Hsla,
+ JustifyContent, Length, Position, SharedString, StyleRefinement, Visibility, WhiteSpace,
};
use crate::{BoxShadow, TextStyleRefinement};
use smallvec::{smallvec, SmallVec};
@@ -771,14 +771,32 @@ pub trait Styled: Sized {
self
}
- /// Change the font on this element and its children.
- fn font(mut self, family_name: impl Into<SharedString>) -> Self {
+ /// Change the font family on this element and its children.
+ fn font_family(mut self, family_name: impl Into<SharedString>) -> Self {
self.text_style()
.get_or_insert_with(Default::default)
.font_family = Some(family_name.into());
self
}
+ /// Change the font of this element and its children.
+ fn font(mut self, font: Font) -> Self {
+ let Font {
+ family,
+ features,
+ weight,
+ style,
+ } = font;
+
+ let text_style = self.text_style().get_or_insert_with(Default::default);
+ text_style.font_family = Some(family);
+ text_style.font_features = Some(features);
+ text_style.font_weight = Some(weight);
+ text_style.font_style = Some(style);
+
+ self
+ }
+
/// Set the line height on this element and its children.
fn line_height(mut self, line_height: impl Into<DefiniteLength>) -> Self {
self.text_style()
@@ -1,6 +1,7 @@
use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use serde::{Deserialize, Serialize};
+use serde_json::{Map, Value};
use std::{convert::TryFrom, future::Future};
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
@@ -12,6 +13,7 @@ pub enum Role {
User,
Assistant,
System,
+ Tool,
}
impl TryFrom<String> for Role {
@@ -22,6 +24,7 @@ impl TryFrom<String> for Role {
"user" => Ok(Self::User),
"assistant" => Ok(Self::Assistant),
"system" => Ok(Self::System),
+ "tool" => Ok(Self::Tool),
_ => Err(anyhow!("invalid role '{value}'")),
}
}
@@ -33,6 +36,7 @@ impl From<Role> for String {
Role::User => "user".to_owned(),
Role::Assistant => "assistant".to_owned(),
Role::System => "system".to_owned(),
+ Role::Tool => "tool".to_owned(),
}
}
}
@@ -91,18 +95,88 @@ pub struct Request {
pub stream: bool,
pub stop: Vec<String>,
pub temperature: f32,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub tool_choice: Option<String>,
+ #[serde(skip_serializing_if = "Vec::is_empty")]
+ pub tools: Vec<ToolDefinition>,
+}
+
+#[derive(Debug, Serialize)]
+pub struct FunctionDefinition {
+ pub name: String,
+ pub description: Option<String>,
+ pub parameters: Option<Map<String, Value>>,
+}
+
+#[derive(Serialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ToolDefinition {
+ #[allow(dead_code)]
+ Function { function: FunctionDefinition },
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "role", rename_all = "lowercase")]
+pub enum RequestMessage {
+ Assistant {
+ content: Option<String>,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ tool_calls: Vec<ToolCall>,
+ },
+ User {
+ content: String,
+ },
+ System {
+ content: String,
+ },
+ Tool {
+ content: String,
+ tool_call_id: String,
+ },
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ToolCall {
+ pub id: String,
+ #[serde(flatten)]
+ pub content: ToolCallContent,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct RequestMessage {
- pub role: Role,
- pub content: String,
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolCallContent {
+ Function { function: FunctionContent },
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ResponseMessage {
+pub struct FunctionContent {
+ pub name: String,
+ pub arguments: String,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ResponseMessageDelta {
pub role: Option<Role>,
pub content: Option<String>,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub tool_calls: Vec<ToolCallChunk>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ToolCallChunk {
+ pub index: usize,
+ pub id: Option<String>,
+
+ // There is also an optional `type` field that would determine if a
+ // function is there. Sometimes this streams in with the `function` before
+ // it streams in the `type`
+ pub function: Option<FunctionChunk>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionChunk {
+ pub name: Option<String>,
+ pub arguments: Option<String>,
}
#[derive(Deserialize, Debug)]
@@ -115,7 +189,7 @@ pub struct Usage {
#[derive(Deserialize, Debug)]
pub struct ChoiceDelta {
pub index: u32,
- pub delta: ResponseMessage,
+ pub delta: ResponseMessageDelta,
pub finish_reason: Option<String>,
}
@@ -1843,7 +1843,7 @@ impl Render for DraggedProjectEntryView {
let settings = ProjectPanelSettings::get_global(cx);
let ui_font = ThemeSettings::get_global(cx).ui_font.family.clone();
h_flex()
- .font(ui_font)
+ .font_family(ui_font)
.bg(cx.theme().colors().background)
.w(self.width)
.child(
@@ -507,7 +507,7 @@ impl RemoteProjects {
.my_1()
.py_0p5()
.px_3()
- .font(ThemeSettings::get_global(cx).buffer_font.family.clone())
+ .font_family(ThemeSettings::get_global(cx).buffer_font.family.clone())
.child(Label::new(instructions))
)
.when(status == DevServerStatus::Offline, |this| {
@@ -43,6 +43,19 @@ pub struct RichText {
Option<Arc<dyn Fn(usize, Range<usize>, &mut WindowContext) -> Option<AnyView>>>,
}
+impl Default for RichText {
+ fn default() -> Self {
+ Self {
+ text: SharedString::default(),
+ highlights: Vec::new(),
+ link_ranges: Vec::new(),
+ link_urls: Arc::from([]),
+ custom_ranges: Vec::new(),
+ custom_ranges_tooltip_fn: None,
+ }
+ }
+}
+
/// Allows one to specify extra links to the rendered markdown, which can be used
/// for e.g. mentions.
#[derive(Debug)]
@@ -52,6 +65,37 @@ pub struct Mention {
}
impl RichText {
+ pub fn new(
+ block: String,
+ mentions: &[Mention],
+ language_registry: &Arc<LanguageRegistry>,
+ ) -> Self {
+ let mut text = String::new();
+ let mut highlights = Vec::new();
+ let mut link_ranges = Vec::new();
+ let mut link_urls = Vec::new();
+ render_markdown_mut(
+ &block,
+ mentions,
+ language_registry,
+ None,
+ &mut text,
+ &mut highlights,
+ &mut link_ranges,
+ &mut link_urls,
+ );
+ text.truncate(text.trim_end().len());
+
+ RichText {
+ text: SharedString::from(text),
+ link_urls: link_urls.into(),
+ link_ranges,
+ highlights,
+ custom_ranges: Vec::new(),
+ custom_ranges_tooltip_fn: None,
+ }
+ }
+
pub fn set_tooltip_builder_for_custom_ranges(
&mut self,
f: impl Fn(usize, Range<usize>, &mut WindowContext) -> Option<AnyView> + 'static,
@@ -347,38 +391,6 @@ pub fn render_markdown_mut(
}
}
-pub fn render_rich_text(
- block: String,
- mentions: &[Mention],
- language_registry: &Arc<LanguageRegistry>,
- language: Option<&Arc<Language>>,
-) -> RichText {
- let mut text = String::new();
- let mut highlights = Vec::new();
- let mut link_ranges = Vec::new();
- let mut link_urls = Vec::new();
- render_markdown_mut(
- &block,
- mentions,
- language_registry,
- language,
- &mut text,
- &mut highlights,
- &mut link_ranges,
- &mut link_urls,
- );
- text.truncate(text.trim_end().len());
-
- RichText {
- text: SharedString::from(text),
- link_urls: link_urls.into(),
- link_ranges,
- highlights,
- custom_ranges: Vec::new(),
- custom_ranges_tooltip_fn: None,
- }
-}
-
pub fn render_code(
text: &mut String,
highlights: &mut Vec<(Range<usize>, Highlight)>,
@@ -1880,22 +1880,70 @@ message CompleteWithLanguageModel {
repeated LanguageModelRequestMessage messages = 2;
repeated string stop = 3;
float temperature = 4;
+ repeated ChatCompletionTool tools = 5;
+ optional string tool_choice = 6;
}
+// A tool presented to the language model for its use
+message ChatCompletionTool {
+ oneof variant {
+ FunctionObject function = 1;
+ }
+
+ message FunctionObject {
+ string name = 1;
+ optional string description = 2;
+ optional string parameters = 3;
+ }
+}
+
+// A message to the language model
message LanguageModelRequestMessage {
LanguageModelRole role = 1;
string content = 2;
+ optional string tool_call_id = 3;
+ repeated ToolCall tool_calls = 4;
}
enum LanguageModelRole {
LanguageModelUser = 0;
LanguageModelAssistant = 1;
LanguageModelSystem = 2;
+ LanguageModelTool = 3;
}
message LanguageModelResponseMessage {
optional LanguageModelRole role = 1;
optional string content = 2;
+ repeated ToolCallDelta tool_calls = 3;
+}
+
+// A request to call a tool, by the language model
+message ToolCall {
+ string id = 1;
+
+ oneof variant {
+ FunctionCall function = 2;
+ }
+
+ message FunctionCall {
+ string name = 1;
+ string arguments = 2;
+ }
+}
+
+message ToolCallDelta {
+ uint32 index = 1;
+ optional string id = 2;
+
+ oneof variant {
+ FunctionCallDelta function = 3;
+ }
+
+ message FunctionCallDelta {
+ optional string name = 1;
+ optional string arguments = 2;
+ }
}
message LanguageModelResponse {
@@ -12,6 +12,11 @@ workspace = true
[lib]
path = "src/semantic_index.rs"
+[[example]]
+name = "index"
+path = "examples/index.rs"
+crate-type = ["bin"]
+
[dependencies]
anyhow.workspace = true
client.workspace = true
@@ -1,25 +1,16 @@
use client::Client;
use futures::channel::oneshot;
-use gpui::{App, Global, TestAppContext};
+use gpui::{App, Global};
use language::language_settings::AllLanguageSettings;
use project::Project;
use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex};
use settings::SettingsStore;
-use std::{path::Path, sync::Arc};
+use std::{
+ path::{Path, PathBuf},
+ sync::Arc,
+};
use util::http::HttpClientWithUrl;
-pub fn init_test(cx: &mut TestAppContext) {
- _ = cx.update(|cx| {
- let store = SettingsStore::test(cx);
- cx.set_global(store);
- language::init(cx);
- Project::init_settings(cx);
- SettingsStore::update(cx, |store, cx| {
- store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
- });
- });
-}
-
fn main() {
env_logger::init();
@@ -50,20 +41,21 @@ fn main() {
// let embedding_provider = semantic_index::FakeEmbeddingProvider;
let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
- let embedding_provider = OpenAiEmbeddingProvider::new(
+
+ let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
http.clone(),
OpenAiEmbeddingModel::TextEmbedding3Small,
open_ai::OPEN_AI_API_URL.to_string(),
api_key,
- );
-
- let semantic_index = SemanticIndex::new(
- Path::new("/tmp/semantic-index-db.mdb"),
- Arc::new(embedding_provider),
- cx,
- );
+ ));
cx.spawn(|mut cx| async move {
+ let semantic_index = SemanticIndex::new(
+ PathBuf::from("/tmp/semantic-index-db.mdb"),
+ embedding_provider,
+ &mut cx,
+ );
+
let mut semantic_index = semantic_index.await.unwrap();
let project_path = Path::new(&args[1]);
@@ -21,7 +21,7 @@ use std::{
cmp::Ordering,
future::Future,
ops::Range,
- path::Path,
+ path::{Path, PathBuf},
sync::Arc,
time::{Duration, SystemTime},
};
@@ -37,30 +37,29 @@ pub struct SemanticIndex {
impl Global for SemanticIndex {}
impl SemanticIndex {
- pub fn new(
- db_path: &Path,
+ pub async fn new(
+ db_path: PathBuf,
embedding_provider: Arc<dyn EmbeddingProvider>,
- cx: &mut AppContext,
- ) -> Task<Result<Self>> {
- let db_path = db_path.to_path_buf();
- cx.spawn(|cx| async move {
- let db_connection = cx
- .background_executor()
- .spawn(async move {
- unsafe {
- heed::EnvOpenOptions::new()
- .map_size(1024 * 1024 * 1024)
- .max_dbs(3000)
- .open(db_path)
- }
- })
- .await?;
-
- Ok(SemanticIndex {
- db_connection,
- embedding_provider,
- project_indices: HashMap::default(),
+ cx: &mut AsyncAppContext,
+ ) -> Result<Self> {
+ let db_connection = cx
+ .background_executor()
+ .spawn(async move {
+ std::fs::create_dir_all(&db_path)?;
+ unsafe {
+ heed::EnvOpenOptions::new()
+ .map_size(1024 * 1024 * 1024)
+ .max_dbs(3000)
+ .open(db_path)
+ }
})
+ .await
+ .context("opening database connection")?;
+
+ Ok(SemanticIndex {
+ db_connection,
+ embedding_provider,
+ project_indices: HashMap::default(),
})
}
@@ -91,7 +90,7 @@ pub struct ProjectIndex {
worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
- last_status: Status,
+ pub last_status: Status,
embedding_provider: Arc<dyn EmbeddingProvider>,
_subscription: Subscription,
}
@@ -397,7 +396,7 @@ impl WorktreeIndex {
) -> impl Future<Output = Result<()>> {
let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
let worktree_abs_path = worktree.abs_path().clone();
- let scan = self.scan_updated_entries(worktree, updated_entries, cx);
+ let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
let embed = self.embed_files(chunk.files, cx);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
@@ -498,7 +497,9 @@ impl WorktreeIndex {
| project::PathChange::Updated
| project::PathChange::AddedOrUpdated => {
if let Some(entry) = worktree.entry_for_id(*entry_id) {
- updated_entries_tx.send(entry.clone()).await?;
+ if entry.is_file() {
+ updated_entries_tx.send(entry.clone()).await?;
+ }
}
}
project::PathChange::Removed => {
@@ -539,7 +540,14 @@ impl WorktreeIndex {
cx.spawn(async {
while let Ok(entry) = entries.recv().await {
let entry_abs_path = worktree_abs_path.join(&entry.path);
- let Some(text) = fs.load(&entry_abs_path).await.log_err() else {
+ let Some(text) = fs
+ .load(&entry_abs_path)
+ .await
+ .with_context(|| {
+ format!("failed to read path {entry_abs_path:?}")
+ })
+ .log_err()
+ else {
continue;
};
let language = language_registry
@@ -683,7 +691,7 @@ impl WorktreeIndex {
.context("failed to create read transaction")?;
let db_entries = db.iter(&txn).context("failed to iterate database")?;
for db_entry in db_entries {
- let (_, db_embedded_file) = db_entry?;
+ let (_key, db_embedded_file) = db_entry?;
for chunk in db_embedded_file.chunks {
chunks_tx
.send((db_embedded_file.path.clone(), chunk))
@@ -700,6 +708,7 @@ impl WorktreeIndex {
cx.spawn(|cx| async move {
#[cfg(debug_assertions)]
let embedding_query_start = std::time::Instant::now();
+ log::info!("Searching for {query}");
let mut query_embeddings = embedding_provider
.embed(&[TextToEmbed::new(&query)])
@@ -876,17 +885,13 @@ mod tests {
let temp_dir = tempfile::tempdir().unwrap();
- let mut semantic_index = cx
- .update(|cx| {
- let semantic_index = SemanticIndex::new(
- Path::new(temp_dir.path()),
- Arc::new(TestEmbeddingProvider),
- cx,
- );
- semantic_index
- })
- .await
- .unwrap();
+ let mut semantic_index = SemanticIndex::new(
+ temp_dir.path().into(),
+ Arc::new(TestEmbeddingProvider),
+ &mut cx.to_async(),
+ )
+ .await
+ .unwrap();
let project_path = Path::new("./fixture");
@@ -2,6 +2,7 @@ mod keymap_file;
mod settings_file;
mod settings_store;
+use gpui::AppContext;
use rust_embed::RustEmbed;
use std::{borrow::Cow, str};
use util::asset_str;
@@ -19,6 +20,14 @@ pub use settings_store::{
#[exclude = "*.DS_Store"]
pub struct SettingsAssets;
+pub fn init(cx: &mut AppContext) {
+ let mut settings = SettingsStore::default();
+ settings
+ .set_default_settings(&default_settings(), cx)
+ .unwrap();
+ cx.set_global(settings);
+}
+
pub fn default_settings() -> Cow<'static, str> {
asset_str::<SettingsAssets>("settings/default.json")
}
@@ -29,14 +29,14 @@ pub enum ComponentStory {
ListHeader,
ListItem,
OverflowScroll,
+ Picker,
Scroll,
Tab,
TabBar,
+ Text,
TitleBar,
ToggleButton,
- Text,
ViewportUnits,
- Picker,
}
impl ComponentStory {
@@ -11,7 +11,7 @@ use gpui::{
};
use log::LevelFilter;
use project::Project;
-use settings::{default_settings, KeymapFile, Settings, SettingsStore};
+use settings::{KeymapFile, Settings};
use simplelog::SimpleLogger;
use strum::IntoEnumIterator;
use theme::{ThemeRegistry, ThemeSettings};
@@ -64,12 +64,7 @@ fn main() {
gpui::App::new().with_assets(Assets).run(move |cx| {
load_embedded_fonts(cx).unwrap();
- let mut store = SettingsStore::default();
- store
- .set_default_settings(default_settings().as_ref(), cx)
- .unwrap();
- cx.set_global(store);
-
+ settings::init(cx);
theme::init(theme::LoadThemes::All(Box::new(Assets)), cx);
let selector = story_selector;
@@ -122,7 +117,7 @@ impl Render for StoryWrapper {
.flex()
.flex_col()
.size_full()
- .font("Zed Mono")
+ .font_family("Zed Mono")
.child(self.story.clone())
}
}
@@ -1,6 +1,7 @@
mod avatar;
mod button;
mod checkbox;
+mod collapsible_container;
mod context_menu;
mod disclosure;
mod divider;
@@ -25,6 +26,7 @@ mod stories;
pub use avatar::*;
pub use button::*;
pub use checkbox::*;
+pub use collapsible_container::*;
pub use context_menu::*;
pub use disclosure::*;
pub use divider::*;
@@ -0,0 +1,152 @@
+use crate::{prelude::*, ButtonLike};
+use smallvec::SmallVec;
+
+use gpui::*;
+
+#[derive(Default, Clone, Copy, Debug, PartialEq)]
+pub enum ContainerStyle {
+ #[default]
+ None,
+ Card,
+}
+
+struct ContainerStyles {
+ pub background_color: Hsla,
+ pub border_color: Hsla,
+ pub text_color: Hsla,
+}
+
+#[derive(IntoElement)]
+pub struct CollapsibleContainer {
+ id: ElementId,
+ base: ButtonLike,
+ toggle: bool,
+ /// A slot for content that appears before the label, like an icon or avatar.
+ start_slot: Option<AnyElement>,
+ /// A slot for content that appears after the label, usually on the other side of the header.
+ /// This might be a button, a disclosure arrow, a face pile, etc.
+ end_slot: Option<AnyElement>,
+ style: ContainerStyle,
+ children: SmallVec<[AnyElement; 1]>,
+}
+
+impl CollapsibleContainer {
+ pub fn new(id: impl Into<ElementId>, toggle: bool) -> Self {
+ Self {
+ id: id.into(),
+ base: ButtonLike::new("button_base"),
+ toggle,
+ start_slot: None,
+ end_slot: None,
+ style: ContainerStyle::Card,
+ children: SmallVec::new(),
+ }
+ }
+
+ pub fn start_slot<E: IntoElement>(mut self, start_slot: impl Into<Option<E>>) -> Self {
+ self.start_slot = start_slot.into().map(IntoElement::into_any_element);
+ self
+ }
+
+ pub fn end_slot<E: IntoElement>(mut self, end_slot: impl Into<Option<E>>) -> Self {
+ self.end_slot = end_slot.into().map(IntoElement::into_any_element);
+ self
+ }
+
+ pub fn child<E: IntoElement>(mut self, child: E) -> Self {
+ self.children.push(child.into_any_element());
+ self
+ }
+}
+
+impl Clickable for CollapsibleContainer {
+ fn on_click(mut self, handler: impl Fn(&ClickEvent, &mut WindowContext) + 'static) -> Self {
+ self.base = self.base.on_click(handler);
+ self
+ }
+}
+
+impl RenderOnce for CollapsibleContainer {
+ fn render(self, cx: &mut WindowContext) -> impl IntoElement {
+ let color = cx.theme().colors();
+
+ let styles = match self.style {
+ ContainerStyle::None => ContainerStyles {
+ background_color: color.ghost_element_background,
+ border_color: color.border_transparent,
+ text_color: color.text,
+ },
+ ContainerStyle::Card => ContainerStyles {
+ background_color: color.elevated_surface_background,
+ border_color: color.border,
+ text_color: color.text,
+ },
+ };
+
+ v_flex()
+ .id(self.id)
+ .relative()
+ .rounded_md()
+ .bg(styles.background_color)
+ .border()
+ .border_color(styles.border_color)
+ .text_color(styles.text_color)
+ .overflow_hidden()
+ .child(
+ h_flex()
+ .overflow_hidden()
+ .w_full()
+ .group("toggleable_container_header")
+ .border_b()
+ .border_color(if self.toggle {
+ styles.border_color
+ } else {
+ color.border_transparent
+ })
+ .child(
+ self.base.full_width().style(ButtonStyle::Subtle).child(
+ div()
+ .h_7()
+ .p_1()
+ .flex()
+ .flex_1()
+ .items_center()
+ .justify_between()
+ .w_full()
+ .gap_1()
+ .cursor_pointer()
+ .group_hover("toggleable_container_header", |this| {
+ this.bg(color.element_hover)
+ })
+ .child(
+ h_flex()
+ .gap_1()
+ .child(
+ IconButton::new(
+ "toggle_icon",
+ match self.toggle {
+ true => IconName::ChevronDown,
+ false => IconName::ChevronRight,
+ },
+ )
+ .icon_color(Color::Muted)
+ .icon_size(IconSize::XSmall),
+ )
+ .child(
+ div()
+ .id("label_container")
+ .flex()
+ .gap_1()
+ .items_center()
+ .children(self.start_slot),
+ ),
+ )
+ .child(h_flex().children(self.end_slot)),
+ ),
+ ),
+ )
+ .when(self.toggle, |this| {
+ this.child(h_flex().flex_1().w_full().p_1().children(self.children))
+ })
+ }
+}
@@ -110,7 +110,7 @@ impl RenderOnce for WindowsCaptionButton {
.content_center()
.w(width)
.h_full()
- .font("Segoe Fluent Icons")
+ .font_family("Segoe Fluent Icons")
.text_size(px(10.0))
.hover(|style| style.bg(self.hover_background_color))
.active(|style| {
@@ -95,7 +95,7 @@ pub fn tooltip_container<V>(
div().pl_2().pt_2p5().child(
v_flex()
.elevation_2(cx)
- .font(ui_font)
+ .font_family(ui_font)
.text_ui()
.text_color(cx.theme().colors().text)
.py_1()
@@ -93,7 +93,7 @@ impl RenderOnce for Headline {
let ui_font = ThemeSettings::get_global(cx).ui_font.family.clone();
div()
- .font(ui_font)
+ .font_family(ui_font)
.line_height(self.size.line_height())
.text_size(self.size.size())
.text_color(cx.theme().colors().text)
@@ -2928,6 +2928,6 @@ impl Render for DraggedTab {
.selected(self.is_active)
.child(label)
.render(cx)
- .font(ui_font)
+ .font_family(ui_font)
}
}
@@ -4004,7 +4004,7 @@ impl Render for Workspace {
.size_full()
.flex()
.flex_col()
- .font(ui_font)
+ .font_family(ui_font)
.gap_0()
.justify_start()
.items_start()
@@ -19,6 +19,7 @@ activity_indicator.workspace = true
anyhow.workspace = true
assets.workspace = true
assistant.workspace = true
+assistant2.workspace = true
audio.workspace = true
auto_update.workspace = true
backtrace = "0.3"
@@ -231,27 +231,18 @@ fn init_ui(args: Args) {
load_embedded_fonts(cx);
- let mut store = SettingsStore::default();
- store
- .set_default_settings(default_settings().as_ref(), cx)
- .unwrap();
- cx.set_global(store);
+ settings::init(cx);
handle_settings_file_changes(user_settings_file_rx, cx);
handle_keymap_file_changes(user_keymap_file_rx, cx);
- client::init_settings(cx);
-
- let clock = Arc::new(clock::RealSystemClock);
- let http = Arc::new(HttpClientWithUrl::new(
- &client::ClientSettings::get_global(cx).server_url,
- ));
- let client = client::Client::new(clock, http.clone(), cx);
+ client::init_settings(cx);
+ let client = Client::production(cx);
let mut languages =
LanguageRegistry::new(login_shell_env_loaded, cx.background_executor().clone());
let copilot_language_server_id = languages.next_language_server_id();
languages.set_language_server_download_dir(paths::LANGUAGES_DIR.clone());
let languages = Arc::new(languages);
- let node_runtime = RealNodeRuntime::new(http.clone());
+ let node_runtime = RealNodeRuntime::new(client.http_client());
language::init(cx);
languages::init(languages.clone(), node_runtime.clone(), cx);
@@ -271,11 +262,14 @@ fn init_ui(args: Args) {
diagnostics::init(cx);
copilot::init(
copilot_language_server_id,
- http.clone(),
+ client.http_client(),
node_runtime.clone(),
cx,
);
+
assistant::init(client.clone(), cx);
+ assistant2::init(client.clone(), cx);
+
init_inline_completion_provider(client.telemetry().clone(), cx);
extension::init(
@@ -297,7 +291,7 @@ fn init_ui(args: Args) {
cx.observe_global::<SettingsStore>({
let languages = languages.clone();
- let http = http.clone();
+ let http = client.http_client();
let client = client.clone();
move |cx| {
@@ -345,7 +339,7 @@ fn init_ui(args: Args) {
AppState::set_global(Arc::downgrade(&app_state), cx);
audio::init(Assets, cx);
- auto_update::init(http.clone(), cx);
+ auto_update::init(client.http_client(), cx);
workspace::init(app_state.clone(), cx);
recent_projects::init(cx);
@@ -378,7 +372,7 @@ fn init_ui(args: Args) {
initialize_workspace(app_state.clone(), cx);
// todo(linux): unblock this
- upload_panics_and_crashes(http.clone(), cx);
+ upload_panics_and_crashes(client.http_client(), cx);
cx.activate(true);
@@ -3,7 +3,6 @@ mod only_instance;
mod open_listener;
pub use app_menus::*;
-use assistant::AssistantPanel;
use breadcrumbs::Breadcrumbs;
use client::ZED_URL_SCHEME;
use collections::VecDeque;
@@ -181,10 +180,12 @@ pub fn initialize_workspace(app_state: Arc<AppState>, cx: &mut AppContext) {
})
});
}
+
cx.spawn(|workspace_handle, mut cx| async move {
+ let assistant_panel =
+ assistant::AssistantPanel::load(workspace_handle.clone(), cx.clone());
let project_panel = ProjectPanel::load(workspace_handle.clone(), cx.clone());
let terminal_panel = TerminalPanel::load(workspace_handle.clone(), cx.clone());
- let assistant_panel = AssistantPanel::load(workspace_handle.clone(), cx.clone());
let channels_panel =
collab_ui::collab_panel::CollabPanel::load(workspace_handle.clone(), cx.clone());
let chat_panel =
@@ -193,6 +194,7 @@ pub fn initialize_workspace(app_state: Arc<AppState>, cx: &mut AppContext) {
workspace_handle.clone(),
cx.clone(),
);
+
let (
project_panel,
terminal_panel,
@@ -210,9 +212,9 @@ pub fn initialize_workspace(app_state: Arc<AppState>, cx: &mut AppContext) {
)?;
workspace_handle.update(&mut cx, |workspace, cx| {
+ workspace.add_panel(assistant_panel, cx);
workspace.add_panel(project_panel, cx);
workspace.add_panel(terminal_panel, cx);
- workspace.add_panel(assistant_panel, cx);
workspace.add_panel(channels_panel, cx);
workspace.add_panel(chat_panel, cx);
workspace.add_panel(notification_panel, cx);
@@ -221,6 +223,30 @@ pub fn initialize_workspace(app_state: Arc<AppState>, cx: &mut AppContext) {
})
.detach();
+ let mut current_user = app_state.user_store.read(cx).watch_current_user();
+
+ cx.spawn(|workspace_handle, mut cx| async move {
+ while let Some(user) = current_user.next().await {
+ if user.is_some() {
+ // User known now, can check feature flags / staff
+ // At this point, should have the user with staff status available
+ let use_assistant2 = cx.update(|cx| assistant2::enabled(cx))?;
+ if use_assistant2 {
+ let panel =
+ assistant2::AssistantPanel::load(workspace_handle.clone(), cx.clone())
+ .await?;
+ workspace_handle.update(&mut cx, |workspace, cx| {
+ workspace.add_panel(panel, cx);
+ })?;
+ }
+
+ break;
+ }
+ }
+ anyhow::Ok(())
+ })
+ .detach();
+
workspace
.register_action(about)
.register_action(|_, _: &Minimize, cx| {
@@ -3028,11 +3054,7 @@ mod tests {
])
.unwrap();
let themes = ThemeRegistry::default();
- let mut settings = SettingsStore::default();
- settings
- .set_default_settings(&settings::default_settings(), cx)
- .unwrap();
- cx.set_global(settings);
+ settings::init(cx);
theme::init(theme::LoadThemes::JustBase, cx);
let mut has_default_theme = false;
@@ -147,7 +147,7 @@ setTimeout(() => {
}
spawn(binaryPath, i == 0 ? args : [], {
stdio: "inherit",
- env: {
+ env: Object.assign({}, process.env, {
ZED_IMPERSONATE: users[i],
ZED_WINDOW_POSITION: position,
ZED_STATELESS: isStateful && i == 0 ? "1" : "",
@@ -157,9 +157,8 @@ setTimeout(() => {
ZED_ADMIN_API_TOKEN: "secret",
ZED_WINDOW_SIZE: size,
ZED_CLIENT_CHECKSUM_SEED: "development-checksum-seed",
- PATH: process.env.PATH,
RUST_LOG: process.env.RUST_LOG || "info",
- },
+ }),
});
}
}, 0.1);