Detailed changes
@@ -704,6 +704,7 @@ dependencies = [
"assistant_tool",
"chrono",
"collections",
+ "feature_flags",
"futures 0.3.31",
"gpui",
"html_to_markdown",
@@ -721,9 +722,11 @@ dependencies = [
"ui",
"unindent",
"util",
+ "web_search",
"workspace",
"workspace-hack",
"worktree",
+ "zed_llm_client",
]
[[package]]
@@ -16609,6 +16612,36 @@ dependencies = [
"wasm-bindgen",
]
+[[package]]
+name = "web_search"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "collections",
+ "gpui",
+ "serde",
+ "workspace-hack",
+ "zed_llm_client",
+]
+
+[[package]]
+name = "web_search_providers"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "client",
+ "feature_flags",
+ "futures 0.3.31",
+ "gpui",
+ "http_client",
+ "language_model",
+ "serde",
+ "serde_json",
+ "web_search",
+ "workspace-hack",
+ "zed_llm_client",
+]
+
[[package]]
name = "webpki-root-certs"
version = "0.26.8"
@@ -18287,6 +18320,8 @@ dependencies = [
"uuid",
"vim",
"vim_mode_setting",
+ "web_search",
+ "web_search_providers",
"welcome",
"windows 0.61.1",
"winresource",
@@ -18351,9 +18386,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
-version = "0.4.2"
+version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1d28a5d6bdb0f40acf5261c39cabbf65a13b55ba4b86d9beb5b8b1c484373f1a"
+checksum = "57a5e1b5b3ace3fb55292a4c14036723bb8a01fac4aeaa3c2b63b51228412f94"
dependencies = [
"serde",
"serde_json",
@@ -165,6 +165,8 @@ members = [
"crates/util_macros",
"crates/vim",
"crates/vim_mode_setting",
+ "crates/web_search",
+ "crates/web_search_providers",
"crates/welcome",
"crates/workspace",
"crates/worktree",
@@ -370,6 +372,8 @@ util = { path = "crates/util" }
util_macros = { path = "crates/util_macros" }
vim = { path = "crates/vim" }
vim_mode_setting = { path = "crates/vim_mode_setting" }
+web_search = { path = "crates/web_search" }
+web_search_providers = { path = "crates/web_search_providers" }
welcome = { path = "crates/welcome" }
workspace = { path = "crates/workspace" }
worktree = { path = "crates/worktree" }
@@ -601,7 +605,7 @@ wasmtime-wasi = "29"
which = "6.0.0"
wit-component = "0.221"
workspace-hack = "0.1.0"
-zed_llm_client = "0.4.2"
+zed_llm_client = "0.5.0"
zstd = "0.11"
metal = "0.29"
@@ -652,7 +652,8 @@
"path_search": true,
"read_file": true,
"regex_search": true,
- "thinking": true
+ "thinking": true,
+ "web_search": true
}
},
"write": {
@@ -678,7 +679,8 @@
"regex_search": true,
"rename": true,
"symbol_info": true,
- "thinking": true
+ "thinking": true,
+ "web_search": true
}
}
},
@@ -5,11 +5,12 @@ use crate::thread::{
ThreadEvent, ThreadFeedback,
};
use crate::thread_store::{RulesLoadingError, ThreadStore};
-use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
+use crate::tool_use::{PendingToolUseStatus, ToolUse};
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
use anyhow::Context as _;
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
+use assistant_tool::ToolUseStatus;
use collections::{HashMap, HashSet};
use editor::scroll::Autoscroll;
use editor::{Editor, EditorElement, EditorStyle, MultiBuffer};
@@ -943,8 +944,8 @@ impl ActiveThread {
&tool_use.input,
self.thread
.read(cx)
- .tool_result(&tool_use.id)
- .map(|result| result.content.clone().into())
+ .output_for_tool(&tool_use.id)
+ .map(|output| output.clone().into())
.unwrap_or("".into()),
cx,
);
@@ -2279,12 +2280,15 @@ impl ActiveThread {
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement + use<> {
+ if let Some(card) = self.thread.read(cx).card_for_tool(&tool_use.id) {
+ return card.render(&tool_use.status, window, cx);
+ }
+
let is_open = self
.expanded_tool_uses
.get(&tool_use.id)
.copied()
.unwrap_or_default();
-
let is_status_finished = matches!(&tool_use.status, ToolUseStatus::Finished(_));
let fs = self
@@ -2381,6 +2385,7 @@ impl ActiveThread {
open_markdown_link(text, workspace.clone(), window, cx);
}
})
+ .into_any_element()
}),
)),
),
@@ -2437,6 +2442,7 @@ impl ActiveThread {
open_markdown_link(text, workspace.clone(), window, cx);
}
})
+ .into_any_element()
})),
),
),
@@ -2767,7 +2773,7 @@ impl ActiveThread {
)
})
}
- })
+ }).into_any_element()
}
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
@@ -6,7 +6,7 @@ use std::time::Instant;
use anyhow::{Context as _, Result, anyhow};
use assistant_settings::AssistantSettings;
-use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
+use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use feature_flags::{self, FeatureFlagAppExt};
@@ -631,6 +631,14 @@ impl Thread {
self.tool_use.tool_result(id)
}
+ pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
+ Some(&self.tool_use.tool_result(id)?.content)
+ }
+
+ pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
+ self.tool_use.tool_result_card(id).cloned()
+ }
+
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_use.message_has_tool_results(message_id)
}
@@ -1426,6 +1434,12 @@ impl Thread {
)
};
+ // Store the card separately if it exists
+ if let Some(card) = tool_result.card.clone() {
+ self.tool_use
+ .insert_tool_result_card(tool_use_id.clone(), card);
+ }
+
cx.spawn({
async move |thread: WeakEntity<Thread>, cx| {
let output = tool_result.output.await;
@@ -1,7 +1,7 @@
use std::sync::Arc;
use anyhow::Result;
-use assistant_tool::{Tool, ToolWorkingSet};
+use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
use collections::HashMap;
use futures::FutureExt as _;
use futures::future::Shared;
@@ -27,26 +27,7 @@ pub struct ToolUse {
pub needs_confirmation: bool,
}
-#[derive(Debug, Clone)]
-pub enum ToolUseStatus {
- NeedsConfirmation,
- Pending,
- Running,
- Finished(SharedString),
- Error(SharedString),
-}
-
-impl ToolUseStatus {
- pub fn text(&self) -> SharedString {
- match self {
- ToolUseStatus::NeedsConfirmation => "".into(),
- ToolUseStatus::Pending => "".into(),
- ToolUseStatus::Running => "".into(),
- ToolUseStatus::Finished(out) => out.clone(),
- ToolUseStatus::Error(out) => out.clone(),
- }
- }
-}
+pub const USING_TOOL_MARKER: &str = "<using_tool>";
pub struct ToolUseState {
tools: Entity<ToolWorkingSet>,
@@ -54,10 +35,9 @@ pub struct ToolUseState {
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
+ tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
}
-pub const USING_TOOL_MARKER: &str = "<using_tool>";
-
impl ToolUseState {
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
Self {
@@ -66,6 +46,7 @@ impl ToolUseState {
tool_uses_by_user_message: HashMap::default(),
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
+ tool_result_cards: HashMap::default(),
}
}
@@ -257,6 +238,18 @@ impl ToolUseState {
self.tool_results.get(tool_use_id)
}
+ pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
+ self.tool_result_cards.get(tool_use_id)
+ }
+
+ pub fn insert_tool_result_card(
+ &mut self,
+ tool_use_id: LanguageModelToolUseId,
+ card: AnyToolCard,
+ ) {
+ self.tool_result_cards.insert(tool_use_id, card);
+ }
+
pub fn request_tool_use(
&mut self,
assistant_message_id: MessageId,
@@ -9,6 +9,10 @@ use std::fmt::Formatter;
use std::sync::Arc;
use anyhow::Result;
+use gpui::AnyElement;
+use gpui::Context;
+use gpui::IntoElement;
+use gpui::Window;
use gpui::{App, Entity, SharedString, Task};
use icons::IconName;
use language_model::LanguageModelRequestMessage;
@@ -24,16 +28,87 @@ pub fn init(cx: &mut App) {
ToolRegistry::default_global(cx);
}
-/// The result of running a tool
+#[derive(Debug, Clone)]
+pub enum ToolUseStatus {
+ NeedsConfirmation,
+ Pending,
+ Running,
+ Finished(SharedString),
+ Error(SharedString),
+}
+
+impl ToolUseStatus {
+ pub fn text(&self) -> SharedString {
+ match self {
+ ToolUseStatus::NeedsConfirmation => "".into(),
+ ToolUseStatus::Pending => "".into(),
+ ToolUseStatus::Running => "".into(),
+ ToolUseStatus::Finished(out) => out.clone(),
+ ToolUseStatus::Error(out) => out.clone(),
+ }
+ }
+}
+
+/// The result of running a tool, containing both the asynchronous output
+/// and an optional card view that can be rendered immediately.
pub struct ToolResult {
/// The asynchronous task that will eventually resolve to the tool's output
pub output: Task<Result<String>>,
+ /// An optional view to present the output of the tool.
+ pub card: Option<AnyToolCard>,
+}
+
+pub trait ToolCard: 'static + Sized {
+ fn render(
+ &mut self,
+ status: &ToolUseStatus,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> impl IntoElement;
+}
+
+#[derive(Clone)]
+pub struct AnyToolCard {
+ entity: gpui::AnyEntity,
+ render: fn(
+ entity: gpui::AnyEntity,
+ status: &ToolUseStatus,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyElement,
+}
+
+impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
+ fn from(entity: Entity<T>) -> Self {
+ fn downcast_render<T: ToolCard>(
+ entity: gpui::AnyEntity,
+ status: &ToolUseStatus,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyElement {
+ let entity = entity.downcast::<T>().unwrap();
+ entity.update(cx, |entity, cx| {
+ entity.render(status, window, cx).into_any_element()
+ })
+ }
+
+ Self {
+ entity: entity.into(),
+ render: downcast_render::<T>,
+ }
+ }
+}
+
+impl AnyToolCard {
+ pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement {
+ (self.render)(self.entity.clone(), status, window, cx)
+ }
}
impl From<Task<Result<String>>> for ToolResult {
- /// Convert from a task to a ToolResult
+ /// Convert from a task to a ToolResult with no card
fn from(output: Task<Result<String>>) -> Self {
- Self { output }
+ Self { output, card: None }
}
}
@@ -16,6 +16,7 @@ anyhow.workspace = true
assistant_tool.workspace = true
chrono.workspace = true
collections.workspace = true
+feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
html_to_markdown.workspace = true
@@ -32,7 +33,9 @@ ui.workspace = true
util.workspace = true
worktree.workspace = true
open = { workspace = true }
+web_search.workspace = true
workspace-hack.workspace = true
+zed_llm_client.workspace = true
[dev-dependencies]
collections = { workspace = true, features = ["test-support"] }
@@ -22,14 +22,17 @@ mod schema;
mod symbol_info_tool;
mod terminal_tool;
mod thinking_tool;
+mod web_search_tool;
use std::sync::Arc;
use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool;
+use feature_flags::FeatureFlagAppExt;
use gpui::App;
use http_client::HttpClientWithUrl;
use move_path_tool::MovePathTool;
+use web_search_tool::WebSearchTool;
use crate::batch_tool::BatchTool;
use crate::code_action_tool::CodeActionTool;
@@ -56,28 +59,39 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
assistant_tool::init(cx);
let registry = ToolRegistry::global(cx);
- registry.register_tool(TerminalTool);
registry.register_tool(BatchTool);
+ registry.register_tool(CodeActionTool);
+ registry.register_tool(CodeSymbolsTool);
+ registry.register_tool(ContentsTool);
+ registry.register_tool(CopyPathTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CreateFileTool);
- registry.register_tool(CopyPathTool);
registry.register_tool(DeletePathTool);
- registry.register_tool(FindReplaceFileTool);
- registry.register_tool(SymbolInfoTool);
- registry.register_tool(CodeActionTool);
- registry.register_tool(MovePathTool);
registry.register_tool(DiagnosticsTool);
+ registry.register_tool(FetchTool::new(http_client));
+ registry.register_tool(FindReplaceFileTool);
registry.register_tool(ListDirectoryTool);
+ registry.register_tool(MovePathTool);
registry.register_tool(NowTool);
registry.register_tool(OpenTool);
- registry.register_tool(CodeSymbolsTool);
- registry.register_tool(ContentsTool);
registry.register_tool(PathSearchTool);
registry.register_tool(ReadFileTool);
registry.register_tool(RegexSearchTool);
registry.register_tool(RenameTool);
+ registry.register_tool(SymbolInfoTool);
+ registry.register_tool(TerminalTool);
registry.register_tool(ThinkingTool);
- registry.register_tool(FetchTool::new(http_client));
+
+ cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
+ move |is_enabled, cx| {
+ if is_enabled {
+ ToolRegistry::global(cx).register_tool(WebSearchTool);
+ } else {
+ ToolRegistry::global(cx).unregister_tool(WebSearchTool);
+ }
+ }
+ })
+ .detach();
}
#[cfg(test)]
@@ -0,0 +1,213 @@
+use std::{sync::Arc, time::Duration};
+
+use crate::schema::json_schema_for;
+use anyhow::{Context as _, Result, anyhow};
+use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
+use futures::{FutureExt, TryFutureExt};
+use gpui::{
+ Animation, AnimationExt, App, AppContext, Context, Entity, IntoElement, Task, Window,
+ pulsating_between,
+};
+use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
+use project::Project;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use ui::{IconName, Tooltip, prelude::*};
+use web_search::WebSearchRegistry;
+use zed_llm_client::WebSearchResponse;
+
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct WebSearchToolInput {
+ /// The search term or question to query on the web.
+ query: String,
+}
+
+pub struct WebSearchTool;
+
+impl Tool for WebSearchTool {
+ fn name(&self) -> String {
+ "web_search".into()
+ }
+
+ fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
+ false
+ }
+
+ fn description(&self) -> String {
+ "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
+ }
+
+ fn icon(&self) -> IconName {
+ IconName::Globe
+ }
+
+ fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
+ json_schema_for::<WebSearchToolInput>(format)
+ }
+
+ fn ui_text(&self, _input: &serde_json::Value) -> String {
+ "Web Search".to_string()
+ }
+
+ fn run(
+ self: Arc<Self>,
+ input: serde_json::Value,
+ _messages: &[LanguageModelRequestMessage],
+ _project: Entity<Project>,
+ _action_log: Entity<ActionLog>,
+ cx: &mut App,
+ ) -> ToolResult {
+ let input = match serde_json::from_value::<WebSearchToolInput>(input) {
+ Ok(input) => input,
+ Err(err) => return Task::ready(Err(anyhow!(err))).into(),
+ };
+ let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
+ return Task::ready(Err(anyhow!("Web search is not available."))).into();
+ };
+
+ let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
+ let output = cx.background_spawn({
+ let search_task = search_task.clone();
+ async move {
+ let response = search_task.await.map_err(|err| anyhow!(err))?;
+ serde_json::to_string(&response).context("Failed to serialize search results")
+ }
+ });
+
+ ToolResult {
+ output,
+ card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
+ }
+ }
+}
+
+struct WebSearchToolCard {
+ response: Option<Result<WebSearchResponse>>,
+ _task: Task<()>,
+}
+
+impl WebSearchToolCard {
+ fn new(
+ search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ let _task = cx.spawn(async move |this, cx| {
+ let response = search_task.await.map_err(|err| anyhow!(err));
+ this.update(cx, |this, cx| {
+ this.response = Some(response);
+ cx.notify();
+ })
+ .ok();
+ });
+
+ Self {
+ response: None,
+ _task,
+ }
+ }
+}
+
+impl ToolCard for WebSearchToolCard {
+ fn render(
+ &mut self,
+ _status: &ToolUseStatus,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> impl IntoElement {
+ let header = h_flex()
+ .id("tool-label-container")
+ .gap_1p5()
+ .max_w_full()
+ .overflow_x_scroll()
+ .child(
+ Icon::new(IconName::Globe)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
+ .child(match self.response.as_ref() {
+ Some(Ok(response)) => {
+ let text: SharedString = if response.citations.len() == 1 {
+ "1 result".into()
+ } else {
+ format!("{} results", response.citations.len()).into()
+ };
+ h_flex()
+ .gap_1p5()
+ .child(Label::new("Searched the Web").size(LabelSize::Small))
+ .child(
+ div()
+ .size(px(3.))
+ .rounded_full()
+ .bg(cx.theme().colors().text),
+ )
+ .child(Label::new(text).size(LabelSize::Small))
+ .into_any_element()
+ }
+ Some(Err(error)) => div()
+ .id("web-search-error")
+ .child(Label::new("Web Search failed").size(LabelSize::Small))
+ .tooltip(Tooltip::text(error.to_string()))
+ .into_any_element(),
+
+ None => Label::new("Searching the Web…")
+ .size(LabelSize::Small)
+ .with_animation(
+ "web-search-label",
+ Animation::new(Duration::from_secs(2))
+ .repeat()
+ .with_easing(pulsating_between(0.6, 1.)),
+ |label, delta| label.alpha(delta),
+ )
+ .into_any_element(),
+ })
+ .into_any();
+
+ let content =
+ self.response.as_ref().and_then(|response| match response {
+ Ok(response) => {
+ Some(
+ v_flex()
+ .ml_1p5()
+ .pl_1p5()
+ .border_l_1()
+ .border_color(cx.theme().colors().border_variant)
+ .gap_1()
+ .children(response.citations.iter().enumerate().map(
+ |(index, citation)| {
+ let title = citation.title.clone();
+ let url = citation.url.clone();
+
+ Button::new(("citation", index), title)
+ .label_size(LabelSize::Small)
+ .color(Color::Muted)
+ .icon(IconName::ArrowUpRight)
+ .icon_size(IconSize::XSmall)
+ .icon_position(IconPosition::End)
+ .truncate(true)
+ .tooltip({
+ let url = url.clone();
+ move |window, cx| {
+ Tooltip::with_meta(
+ "Citation Link",
+ None,
+ url.clone(),
+ window,
+ cx,
+ )
+ }
+ })
+ .on_click({
+ let url = url.clone();
+ move |_, _, cx| cx.open_url(&url)
+ })
+ },
+ ))
+ .into_any(),
+ )
+ }
+ Err(_) => None,
+ });
+
+ v_flex().my_2().gap_1().child(header).children(content)
+ }
+}
@@ -84,6 +84,11 @@ impl FeatureFlag for ZedPro {
const NAME: &'static str = "zed-pro";
}
+pub struct ZedProWebSearchTool {}
+impl FeatureFlag for ZedProWebSearchTool {
+ const NAME: &'static str = "zed-pro-web-search-tool";
+}
+
pub struct NotebookFeatureFlag;
impl FeatureFlag for NotebookFeatureFlag {
@@ -160,7 +160,11 @@ impl Render for Tooltip {
}),
)
.when_some(self.meta.clone(), |this, meta| {
- this.child(Label::new(meta).size(LabelSize::Small).color(Color::Muted))
+ this.child(
+ div()
+ .max_w_72()
+ .child(Label::new(meta).size(LabelSize::Small).color(Color::Muted)),
+ )
})
})
}
@@ -0,0 +1,20 @@
+[package]
+name = "web_search"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/web_search.rs"
+
+[dependencies]
+anyhow.workspace = true
+collections.workspace = true
+gpui.workspace = true
+serde.workspace = true
+workspace-hack.workspace = true
+zed_llm_client.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,64 @@
+use anyhow::Result;
+use collections::HashMap;
+use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
+use std::sync::Arc;
+use zed_llm_client::WebSearchResponse;
+
+pub fn init(cx: &mut App) {
+ let registry = cx.new(|_cx| WebSearchRegistry::default());
+ cx.set_global(GlobalWebSearchRegistry(registry));
+}
+
+#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
+pub struct WebSearchProviderId(pub SharedString);
+
+pub trait WebSearchProvider {
+ fn id(&self) -> WebSearchProviderId;
+ fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>>;
+}
+
+struct GlobalWebSearchRegistry(Entity<WebSearchRegistry>);
+
+impl Global for GlobalWebSearchRegistry {}
+
+#[derive(Default)]
+pub struct WebSearchRegistry {
+ providers: HashMap<WebSearchProviderId, Arc<dyn WebSearchProvider>>,
+ active_provider: Option<Arc<dyn WebSearchProvider>>,
+}
+
+impl WebSearchRegistry {
+ pub fn global(cx: &App) -> Entity<Self> {
+ cx.global::<GlobalWebSearchRegistry>().0.clone()
+ }
+
+ pub fn read_global(cx: &App) -> &Self {
+ cx.global::<GlobalWebSearchRegistry>().0.read(cx)
+ }
+
+ pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn WebSearchProvider>> {
+ self.providers.values()
+ }
+
+ pub fn active_provider(&self) -> Option<Arc<dyn WebSearchProvider>> {
+ self.active_provider.clone()
+ }
+
+ pub fn set_active_provider(&mut self, provider: Arc<dyn WebSearchProvider>) {
+ self.active_provider = Some(provider.clone());
+ self.providers.insert(provider.id(), provider);
+ }
+
+ pub fn register_provider<T: WebSearchProvider + 'static>(
+ &mut self,
+ provider: T,
+ _cx: &mut Context<Self>,
+ ) {
+ let id = provider.id();
+ let provider = Arc::new(provider);
+ self.providers.insert(id.clone(), provider.clone());
+ if self.active_provider.is_none() {
+ self.active_provider = Some(provider);
+ }
+ }
+}
@@ -0,0 +1,26 @@
+[package]
+name = "web_search_providers"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/web_search_providers.rs"
+
+[dependencies]
+anyhow.workspace = true
+client.workspace = true
+feature_flags.workspace = true
+futures.workspace = true
+gpui.workspace = true
+http_client.workspace = true
+language_model.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+web_search.workspace = true
+workspace-hack.workspace = true
+zed_llm_client.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,103 @@
+use std::sync::Arc;
+
+use anyhow::{Context as _, Result, anyhow};
+use client::Client;
+use futures::AsyncReadExt as _;
+use gpui::{App, AppContext, Context, Entity, Subscription, Task};
+use http_client::{HttpClient, Method};
+use language_model::{LlmApiToken, RefreshLlmTokenListener};
+use web_search::{WebSearchProvider, WebSearchProviderId};
+use zed_llm_client::{WebSearchBody, WebSearchResponse};
+
+pub struct CloudWebSearchProvider {
+ state: Entity<State>,
+}
+
+impl CloudWebSearchProvider {
+ pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
+ let state = cx.new(|cx| State::new(client, cx));
+
+ Self { state }
+ }
+}
+
+pub struct State {
+ client: Arc<Client>,
+ llm_api_token: LlmApiToken,
+ _llm_token_subscription: Subscription,
+}
+
+impl State {
+ pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
+ let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
+
+ Self {
+ client,
+ llm_api_token: LlmApiToken::default(),
+ _llm_token_subscription: cx.subscribe(
+ &refresh_llm_token_listener,
+ |this, _, _event, cx| {
+ let client = this.client.clone();
+ let llm_api_token = this.llm_api_token.clone();
+ cx.spawn(async move |_this, _cx| {
+ llm_api_token.refresh(&client).await?;
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+ },
+ ),
+ }
+ }
+}
+
+impl WebSearchProvider for CloudWebSearchProvider {
+ fn id(&self) -> WebSearchProviderId {
+ WebSearchProviderId("zed.dev".into())
+ }
+
+ fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
+ let state = self.state.read(cx);
+ let client = state.client.clone();
+ let llm_api_token = state.llm_api_token.clone();
+ let body = WebSearchBody { query };
+ cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
+ }
+}
+
+async fn perform_web_search(
+ client: Arc<Client>,
+ llm_api_token: LlmApiToken,
+ body: WebSearchBody,
+) -> Result<WebSearchResponse> {
+ let http_client = &client.http_client();
+
+ let token = llm_api_token.acquire(&client).await?;
+
+ let request_builder = http_client::Request::builder().method(Method::POST);
+ let request_builder = if let Ok(web_search_url) = std::env::var("ZED_WEB_SEARCH_URL") {
+ request_builder.uri(web_search_url)
+ } else {
+ request_builder.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
+ };
+ let request = request_builder
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {token}"))
+ .body(serde_json::to_string(&body)?.into())?;
+ let mut response = http_client
+ .send(request)
+ .await
+ .context("failed to send web search request")?;
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Ok(serde_json::from_str(&body)?);
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Err(anyhow!(
+ "error performing web search.\nStatus: {:?}\nBody: {body}",
+ response.status(),
+ ));
+ }
+}
@@ -0,0 +1,35 @@
+mod cloud;
+
+use client::Client;
+use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool};
+use gpui::{App, Context};
+use std::sync::Arc;
+use web_search::WebSearchRegistry;
+
+pub fn init(client: Arc<Client>, cx: &mut App) {
+ let registry = WebSearchRegistry::global(cx);
+ registry.update(cx, |registry, cx| {
+ register_web_search_providers(registry, client, cx);
+ });
+}
+
+fn register_web_search_providers(
+ _registry: &mut WebSearchRegistry,
+ client: Arc<Client>,
+ cx: &mut Context<WebSearchRegistry>,
+) {
+ cx.observe_flag::<ZedProWebSearchTool, _>({
+ let client = client.clone();
+ move |is_enabled, cx| {
+ if is_enabled {
+ WebSearchRegistry::global(cx).update(cx, |registry, cx| {
+ registry.register_provider(
+ cloud::CloudWebSearchProvider::new(client.clone(), cx),
+ cx,
+ );
+ });
+ }
+ }
+ })
+ .detach();
+}
@@ -133,6 +133,8 @@ util.workspace = true
uuid.workspace = true
vim.workspace = true
vim_mode_setting.workspace = true
+web_search.workspace = true
+web_search_providers.workspace = true
welcome.workspace = true
workspace.workspace = true
zed_actions.workspace = true
@@ -490,6 +490,8 @@ fn main() {
app_state.fs.clone(),
cx,
);
+ web_search::init(cx);
+ web_search_providers::init(app_state.client.clone(), cx);
snippet_provider::init(cx);
inline_completion_registry::init(
app_state.client.clone(),
@@ -4258,6 +4258,8 @@ mod tests {
app_state.fs.clone(),
cx,
);
+ web_search::init(cx);
+ web_search_providers::init(app_state.client.clone(), cx);
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);
assistant::init(
app_state.fs.clone(),