Detailed changes
@@ -3202,7 +3202,6 @@ dependencies = [
"rustc-hash 2.1.1",
"schemars 1.0.4",
"serde",
- "serde_json",
"strum 0.27.2",
]
@@ -8867,6 +8866,7 @@ dependencies = [
"open_router",
"parking_lot",
"proto",
+ "schemars 1.0.4",
"serde",
"serde_json",
"settings",
@@ -21685,6 +21685,7 @@ dependencies = [
"serde",
"serde_json",
"settings",
+ "smol",
"thiserror 2.0.17",
"util",
"uuid",
@@ -21702,6 +21703,7 @@ dependencies = [
"clap",
"client",
"cloud_llm_client",
+ "cloud_zeta2_prompt",
"collections",
"edit_prediction_context",
"editor",
@@ -21715,7 +21717,6 @@ dependencies = [
"ordered-float 2.10.1",
"pretty_assertions",
"project",
- "regex-syntax",
"serde",
"serde_json",
"settings",
@@ -6,7 +6,6 @@ mod native_agent_server;
pub mod outline;
mod templates;
mod thread;
-mod tool_schema;
mod tools;
#[cfg(test)]
@@ -1,6 +1,6 @@
use anyhow::Result;
use gpui::{AsyncApp, Entity};
-use language::{Buffer, OutlineItem, ParseStatus};
+use language::{Buffer, OutlineItem};
use regex::Regex;
use std::fmt::Write;
use text::Point;
@@ -30,10 +30,9 @@ pub async fn get_buffer_content_or_outline(
if file_size > AUTO_OUTLINE_SIZE {
// For large files, use outline instead of full content
// Wait until the buffer has been fully parsed, so we can read its outline
- let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
- while *parse_status.borrow() != ParseStatus::Idle {
- parse_status.changed().await?;
- }
+ buffer
+ .read_with(cx, |buffer, _| buffer.parsing_idle())?
+ .await;
let outline_items = buffer.read_with(cx, |buffer, _| {
let snapshot = buffer.snapshot();
@@ -2139,7 +2139,7 @@ where
/// Returns the JSON schema that describes the tool's input.
fn input_schema(format: LanguageModelToolSchemaFormat) -> Schema {
- crate::tool_schema::root_schema_for::<Self::Input>(format)
+ language_model::tool_schema::root_schema_for::<Self::Input>(format)
}
/// Some tools rely on a provider for the underlying billing or other reasons.
@@ -2226,7 +2226,7 @@ where
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
let mut json = serde_json::to_value(T::input_schema(format))?;
- crate::tool_schema::adapt_schema_to_format(&mut json, format)?;
+ language_model::tool_schema::adapt_schema_to_format(&mut json, format)?;
Ok(json)
}
@@ -165,7 +165,7 @@ impl AnyAgentTool for ContextServerTool {
format: language_model::LanguageModelToolSchemaFormat,
) -> Result<serde_json::Value> {
let mut schema = self.tool.input_schema.clone();
- crate::tool_schema::adapt_schema_to_format(&mut schema, format)?;
+ language_model::tool_schema::adapt_schema_to_format(&mut schema, format)?;
Ok(match schema {
serde_json::Value::Null => {
serde_json::json!({ "type": "object", "properties": [] })
@@ -19,5 +19,4 @@ ordered-float.workspace = true
rustc-hash.workspace = true
schemars.workspace = true
serde.workspace = true
-serde_json.workspace = true
strum.workspace = true
@@ -3,7 +3,7 @@ use cloud_llm_client::predict_edits_v3::{self, Excerpt};
use indoc::indoc;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use std::{fmt::Write, sync::LazyLock};
+use std::fmt::Write;
use crate::{push_events, write_codeblock};
@@ -15,7 +15,7 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R
push_events(&mut prompt, &request.events);
}
- writeln!(&mut prompt, "## Excerpt around the cursor\n")?;
+ writeln!(&mut prompt, "## Cursor context")?;
write_codeblock(
&request.excerpt_path,
&[Excerpt {
@@ -39,54 +39,56 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
pub struct SearchToolInput {
/// An array of queries to run for gathering context relevant to the next prediction
- #[schemars(length(max = 5))]
+ #[schemars(length(max = 3))]
pub queries: Box<[SearchToolQuery]>,
}
+/// Search for relevant code by path, syntax hierarchy, and content.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SearchToolQuery {
- /// A glob pattern to match file paths in the codebase
+ /// 1. A glob pattern to match file paths in the codebase to search in.
pub glob: String,
- /// A regular expression to match content within the files matched by the glob pattern
- pub regex: String,
+ /// 2. Regular expressions to match syntax nodes **by their first line** and hierarchy.
+ ///
+ /// Subsequent regexes match nodes within the full content of the nodes matched by the previous regexes.
+ ///
+ /// Example: Searching for a `User` class
+ /// ["class\s+User"]
+ ///
+ /// Example: Searching for a `get_full_name` method under a `User` class
+ /// ["class\s+User", "def\sget_full_name"]
+ ///
+ /// Skip this field to match on content alone.
+ #[schemars(length(max = 3))]
+ #[serde(default)]
+ pub syntax_node: Vec<String>,
+ /// 3. An optional regular expression to match the final content that should appear in the results.
+ ///
+ /// - Content will be matched within all lines of the matched syntax nodes.
+ /// - If syntax node regexes are provided, this field can be skipped to include as much of the node itself as possible.
+ /// - If no syntax node regexes are provided, the content will be matched within the entire file.
+ pub content: Option<String>,
}
-pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
- let schema = schemars::schema_for!(SearchToolInput);
-
- let description = schema
- .get("description")
- .and_then(|description| description.as_str())
- .unwrap()
- .to_string();
-
- (schema.into(), description)
-});
-
pub const TOOL_NAME: &str = "search";
const SEARCH_INSTRUCTIONS: &str = indoc! {r#"
- ## Task
+ You are part of an edit prediction system in a code editor.
+ Your role is to search for code that will serve as context for predicting the next edit.
- You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
- that will serve as context for predicting the next required edit.
-
- **Your task:**
- Analyze the user's recent edits and current cursor context
- - Use the `search` tool to find code that may be relevant for predicting the next edit
+ - Use the `search` tool to find code that is relevant for predicting the next edit
- Focus on finding:
- Code patterns that might need similar changes based on the recent edits
- Functions, variables, types, and constants referenced in the current cursor context
- Related implementations, usages, or dependencies that may require consistent updates
-
- **Important constraints:**
- - This conversation has exactly 2 turns
- - You must make ALL search queries in your first response via the `search` tool
- - All queries will be executed in parallel and results returned together
- - In the second turn, you will select the most relevant results via the `select` tool.
+ - How items defined in the cursor excerpt are used or altered
+ - You will not be able to filter results or perform subsequent queries, so keep searches as targeted as possible
+ - Use `syntax_node` parameter whenever you're looking for a particular type, class, or function
+ - Avoid using wildcard globs if you already know the file path of the content you're looking for
"#};
const TOOL_USE_REMINDER: &str = indoc! {"
--
- Use the `search` tool now
+ Analyze the user's intent in one to two sentences, then call the `search` tool.
"};
@@ -1618,6 +1618,18 @@ impl Buffer {
self.parse_status.1.clone()
}
+ /// Returns a future that resolves once the parse status is `ParseStatus::Idle`, or as soon as waiting for a status change fails.
+ pub fn parsing_idle(&self) -> impl Future<Output = ()> + use<> {
+ let mut parse_status = self.parse_status(); // taken up front so the returned future does not borrow `self`
+ async move {
+ while *parse_status.borrow() != ParseStatus::Idle {
+ if parse_status.changed().await.is_err() { // stop waiting rather than surfacing the error
+ break;
+ }
+ }
+ }
+ }
+
/// Assign to the buffer a set of diagnostics created by a given language server.
pub fn update_diagnostics(
&mut self,
@@ -17,7 +17,6 @@ test-support = []
[dependencies]
anthropic = { workspace = true, features = ["schemars"] }
-open_router.workspace = true
anyhow.workspace = true
base64.workspace = true
client.workspace = true
@@ -30,8 +29,10 @@ http_client.workspace = true
icons.workspace = true
image.workspace = true
log.workspace = true
+open_router.workspace = true
parking_lot.workspace = true
proto.workspace = true
+schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
@@ -4,6 +4,7 @@ mod registry;
mod request;
mod role;
mod telemetry;
+pub mod tool_schema;
#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;
@@ -35,6 +36,7 @@ pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;
+pub use crate::tool_schema::LanguageModelToolSchemaFormat;
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
LanguageModelProviderId::new("anthropic");
@@ -409,15 +411,6 @@ impl From<open_router::ApiError> for LanguageModelCompletionError {
}
}
-/// Indicates the format used to define the input schema for a language model tool.
-#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
-pub enum LanguageModelToolSchemaFormat {
- /// A JSON schema, see https://json-schema.org
- JsonSchema,
- /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
- JsonSchemaSubset,
-}
-
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
@@ -1,5 +1,4 @@
use anyhow::Result;
-use language_model::LanguageModelToolSchemaFormat;
use schemars::{
JsonSchema, Schema,
generate::SchemaSettings,
@@ -7,7 +6,16 @@ use schemars::{
};
use serde_json::Value;
-pub(crate) fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
+/// Indicates the format used to define the input schema for a language model tool.
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
+pub enum LanguageModelToolSchemaFormat {
+ /// A JSON schema, see <https://json-schema.org>
+ JsonSchema,
+ /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see <https://ai.google.dev/api/caching#Schema>
+ JsonSchemaSubset,
+}
+
+pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
let mut generator = match format {
LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
@@ -4060,7 +4060,7 @@ impl Project {
result_rx
}
- fn find_search_candidate_buffers(
+ pub fn find_search_candidate_buffers(
&mut self,
query: &SearchQuery,
limit: usize,
@@ -33,6 +33,7 @@ project.workspace = true
release_channel.workspace = true
serde.workspace = true
serde_json.workspace = true
+smol.workspace = true
thiserror.workspace = true
util.workspace = true
uuid.workspace = true
@@ -1,18 +1,19 @@
use std::ops::Range;
use anyhow::Result;
+use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
use collections::HashMap;
-use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
use futures::{
StreamExt,
channel::mpsc::{self, UnboundedSender},
};
use gpui::{AppContext, AsyncApp, Entity};
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, ToPoint as _};
+use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint};
use project::{
Project, WorktreeSettings,
search::{SearchQuery, SearchResult},
};
+use smol::channel;
use util::{
ResultExt as _,
paths::{PathMatcher, PathStyle},
@@ -21,7 +22,7 @@ use workspace::item::Settings as _;
pub async fn run_retrieval_searches(
project: Entity<Project>,
- regex_by_glob: HashMap<String, String>,
+ queries: Vec<SearchToolQuery>,
cx: &mut AsyncApp,
) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
@@ -37,14 +38,13 @@ pub async fn run_retrieval_searches(
let (results_tx, mut results_rx) = mpsc::unbounded();
- for (glob, regex) in regex_by_glob {
+ for query in queries {
let exclude_matcher = exclude_matcher.clone();
let results_tx = results_tx.clone();
let project = project.clone();
cx.spawn(async move |cx| {
run_query(
- &glob,
- &regex,
+ query,
results_tx.clone(),
path_style,
exclude_matcher,
@@ -108,87 +108,442 @@ pub async fn run_retrieval_searches(
.await
}
-const MIN_EXCERPT_LEN: usize = 16;
const MAX_EXCERPT_LEN: usize = 768;
const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5;
+struct SearchJob { // One unit of work consumed by `process_nested_search_job`.
+ buffer: Entity<Buffer>, // buffer the ranges belong to; forwarded with the final results
+ snapshot: BufferSnapshot, // snapshot that `ranges` are offsets into
+ ranges: Vec<Range<usize>>, // offset ranges still to be searched (or emitted as-is at the last stage)
+ query_ix: usize, // index of the next nested syntax query to run; past the end means "final stage"
+ jobs_tx: channel::Sender<SearchJob>, // used to enqueue the follow-up job for the next nesting level
+}
+
async fn run_query(
- glob: &str,
- regex: &str,
+ input_query: SearchToolQuery,
results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
path_style: PathStyle,
exclude_matcher: PathMatcher,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<()> {
- let include_matcher = PathMatcher::new(vec![glob], path_style)?;
-
- let query = SearchQuery::regex(
- regex,
- false,
- true,
- false,
- true,
- include_matcher,
- exclude_matcher,
- true,
- None,
- )?;
-
- let results = project.update(cx, |project, cx| project.search(query, cx))?;
- futures::pin_mut!(results);
-
- while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
- if results_tx.is_closed() {
- break;
- }
+ let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?;
- if ranges.is_empty() {
- continue;
- }
+ let make_search = |regex: &str| -> Result<SearchQuery> {
+ SearchQuery::regex(
+ regex,
+ false,
+ true,
+ false,
+ true,
+ include_matcher.clone(),
+ exclude_matcher.clone(),
+ true,
+ None,
+ )
+ };
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
- let results_tx = results_tx.clone();
+ if let Some(outer_syntax_regex) = input_query.syntax_node.first() {
+ let outer_syntax_query = make_search(outer_syntax_regex)?;
+ let nested_syntax_queries = input_query
+ .syntax_node
+ .into_iter()
+ .skip(1)
+ .map(|query| make_search(&query))
+ .collect::<Result<Vec<_>>>()?;
+ let content_query = input_query
+ .content
+ .map(|regex| make_search(&regex))
+ .transpose()?;
- cx.background_spawn(async move {
- let mut excerpts = Vec::with_capacity(ranges.len());
-
- for range in ranges {
- let offset_range = range.to_offset(&snapshot);
- let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);
-
- let excerpt = EditPredictionExcerpt::select_from_buffer(
- query_point,
- &snapshot,
- &EditPredictionExcerptOptions {
- max_bytes: MAX_EXCERPT_LEN,
- min_bytes: MIN_EXCERPT_LEN,
- target_before_cursor_over_total_bytes: 0.5,
- },
- None,
- );
+ let (jobs_tx, jobs_rx) = channel::unbounded();
- if let Some(excerpt) = excerpt
- && !excerpt.line_range.is_empty()
- {
- excerpts.push((
- snapshot.anchor_after(excerpt.range.start)
- ..snapshot.anchor_before(excerpt.range.end),
- excerpt.range.len(),
- ));
- }
+ let outer_search_results_rx =
+ project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?;
+
+ let outer_search_task = cx.spawn(async move |cx| {
+ futures::pin_mut!(outer_search_results_rx);
+ while let Some(SearchResult::Buffer { buffer, ranges }) =
+ outer_search_results_rx.next().await
+ {
+ buffer
+ .read_with(cx, |buffer, _| buffer.parsing_idle())?
+ .await;
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+ let expanded_ranges: Vec<_> = ranges
+ .into_iter()
+ .filter_map(|range| expand_to_parent_range(&range, &snapshot))
+ .collect();
+ jobs_tx
+ .send(SearchJob {
+ buffer,
+ snapshot,
+ ranges: expanded_ranges,
+ query_ix: 0,
+ jobs_tx: jobs_tx.clone(),
+ })
+ .await?;
}
+ anyhow::Ok(())
+ });
- let send_result = results_tx.unbounded_send((buffer, snapshot, excerpts));
+ let n_workers = cx.background_executor().num_cpus();
+ let search_job_task = cx.background_executor().scoped(|scope| {
+ for _ in 0..n_workers {
+ scope.spawn(async {
+ while let Ok(job) = jobs_rx.recv().await {
+ process_nested_search_job(
+ &results_tx,
+ &nested_syntax_queries,
+ &content_query,
+ job,
+ )
+ .await;
+ }
+ });
+ }
+ });
+
+ search_job_task.await;
+ outer_search_task.await?;
+ } else if let Some(content_regex) = &input_query.content {
+ let search_query = make_search(&content_regex)?;
+
+ let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?;
+ futures::pin_mut!(results_rx);
+
+ while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await {
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
+ let ranges = ranges
+ .into_iter()
+ .map(|range| {
+ let range = range.to_offset(&snapshot);
+ let range = expand_to_entire_lines(range, &snapshot);
+ let size = range.len();
+ let range =
+ snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
+ (range, size)
+ })
+ .collect();
+
+ let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges));
if let Err(err) = send_result
&& !err.is_disconnected()
{
log::error!("{err}");
}
- })
- .detach();
+ }
+ } else {
+ log::warn!("Context gathering model produced a glob-only search");
}
anyhow::Ok(())
}
+
+async fn process_nested_search_job( // Runs one stage of a nested search: either the next syntax query or the final content match.
+ results_tx: &UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
+ queries: &Vec<SearchQuery>,
+ content_query: &Option<SearchQuery>,
+ job: SearchJob,
+) {
+ if let Some(search_query) = queries.get(job.query_ix) { // a nested syntax query remains for this job
+ let mut subranges = Vec::new();
+ for range in job.ranges {
+ let start = range.start;
+ let search_results = search_query.search(&job.snapshot, Some(range)).await;
+ for subrange in search_results {
+ let subrange = start + subrange.start..start + subrange.end; // presumably `search` reports offsets relative to the searched range — verify
+ subranges.extend(expand_to_parent_range(&subrange, &job.snapshot)); // matches without an enclosing syntax node are dropped
+ }
+ }
+ job.jobs_tx
+ .send(SearchJob {
+ buffer: job.buffer,
+ snapshot: job.snapshot,
+ ranges: subranges,
+ query_ix: job.query_ix + 1, // the follow-up job runs the next query in `queries`
+ jobs_tx: job.jobs_tx.clone(),
+ })
+ .await
+ .ok(); // a failed send is ignored
+ } else { // no syntax queries left: produce the final results for this buffer
+ let ranges = if let Some(content_query) = content_query {
+ let mut subranges = Vec::new();
+ for range in job.ranges {
+ let start = range.start;
+ let search_results = content_query.search(&job.snapshot, Some(range)).await;
+ for subrange in search_results {
+ let subrange = start + subrange.start..start + subrange.end; // same offset shift as in the branch above
+ subranges.push(subrange);
+ }
+ }
+ subranges
+ } else {
+ job.ranges // without a content query the matched node ranges themselves are the results
+ };
+
+ let matches = ranges
+ .into_iter()
+ .map(|range| {
+ let snapshot = &job.snapshot;
+ let range = expand_to_entire_lines(range, snapshot);
+ let size = range.len(); // length of the line-expanded range, sent alongside the anchors
+ let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
+ (range, size)
+ })
+ .collect();
+
+ let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches));
+
+ if let Err(err) = send_result
+ && !err.is_disconnected() // a disconnected receiver is not logged
+ {
+ log::error!("{err}");
+ }
+ }
+}
+
+fn expand_to_entire_lines(range: Range<usize>, snapshot: &BufferSnapshot) -> Range<usize> { // Widens an offset range so it starts at column 0 and does not end mid-line.
+ let mut point_range = range.to_point(snapshot);
+ point_range.start.column = 0; // pull the start back to the beginning of its row
+ if point_range.end.column > 0 { // an end already at column 0 is left untouched
+ point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0)); // advance by `Point::new(1, 0)`, clamped to the end of the buffer
+ }
+ point_range.to_offset(snapshot)
+}
+
+fn expand_to_parent_range<T: ToPoint + ToOffset>(
+ range: &Range<T>,
+ snapshot: &BufferSnapshot,
+) -> Option<Range<usize>> { // Byte range of the node returned by `syntax_ancestor` for the matched lines; `None` when there is no such node.
+ let mut line_range = range.to_point(&snapshot);
+ line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len; // start right after the first row's indentation
+ line_range.end.column = snapshot.line_len(line_range.end.row); // extend to the end of the last row
+ // TODO skip result if matched line isn't the first node line?
+
+ let node = snapshot.syntax_ancestor(line_range)?;
+ Some(node.byte_range())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::merge_excerpts::merge_excerpts;
+ use cloud_zeta2_prompt::write_codeblock;
+ use edit_prediction_context::Line;
+ use gpui::TestAppContext;
+ use indoc::indoc;
+ use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
+ use pretty_assertions::assert_eq;
+ use project::FakeFs;
+ use serde_json::json;
+ use settings::SettingsStore;
+ use std::path::Path;
+ use util::path;
+
+ #[gpui::test]
+ async fn test_retrieval(cx: &mut TestAppContext) { // Exercises nested syntax-node, syntax-node + content, and content-only queries against a small fake Rust project.
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ "user.rs": indoc!{"
+ pub struct Organization {
+ owner: Arc<User>,
+ }
+
+ pub struct User {
+ first_name: String,
+ last_name: String,
+ }
+
+ impl Organization {
+ pub fn owner(&self) -> Arc<User> {
+ self.owner.clone()
+ }
+ }
+
+ impl User {
+ pub fn new(first_name: String, last_name: String) -> Self {
+ Self {
+ first_name,
+ last_name
+ }
+ }
+
+ pub fn first_name(&self) -> String {
+ self.first_name.clone()
+ }
+
+ pub fn last_name(&self) -> String {
+ self.last_name.clone()
+ }
+ }
+ "},
+ "main.rs": indoc!{r#"
+ fn main() {
+ let user = User::new(FIRST_NAME.clone(), "doe".into());
+ println!("user {:?}", user);
+ }
+ "#},
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await;
+ project.update(cx, |project, _cx| {
+ project.languages().add(rust_lang().into()) // register the Rust test language defined by `rust_lang`
+ });
+
+ assert_results( // case 1: two nested syntax-node regexes, no content regex
+ &project,
+ SearchToolQuery {
+ glob: "user.rs".into(),
+ syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()],
+ content: None,
+ },
+ indoc! {r#"
+ `````root/user.rs
+ …
+ impl User {
+ …
+ pub fn first_name(&self) -> String {
+ self.first_name.clone()
+ }
+ …
+ `````
+ "#},
+ cx,
+ )
+ .await;
+
+ assert_results( // case 2: one syntax-node regex narrowed by a content regex
+ &project,
+ SearchToolQuery {
+ glob: "user.rs".into(),
+ syntax_node: vec!["impl\\s+User".into()],
+ content: Some("\\.clone".into()),
+ },
+ indoc! {r#"
+ `````root/user.rs
+ …
+ impl User {
+ …
+ pub fn first_name(&self) -> String {
+ self.first_name.clone()
+ …
+ pub fn last_name(&self) -> String {
+ self.last_name.clone()
+ …
+ `````
+ "#},
+ cx,
+ )
+ .await;
+
+ assert_results( // case 3: content regex only, matched across every file selected by the glob
+ &project,
+ SearchToolQuery {
+ glob: "*.rs".into(),
+ syntax_node: vec![],
+ content: Some("\\.clone".into()),
+ },
+ indoc! {r#"
+ `````root/main.rs
+ fn main() {
+ let user = User::new(FIRST_NAME.clone(), "doe".into());
+ …
+ `````
+
+ `````root/user.rs
+ …
+ impl Organization {
+ pub fn owner(&self) -> Arc<User> {
+ self.owner.clone()
+ …
+ impl User {
+ …
+ pub fn first_name(&self) -> String {
+ self.first_name.clone()
+ …
+ pub fn last_name(&self) -> String {
+ self.last_name.clone()
+ …
+ `````
+ "#},
+ cx,
+ )
+ .await;
+ }
+
+ async fn assert_results( // Runs a single query through `run_retrieval_searches`, renders the hits with `write_codeblock`, and compares with `expected_output`.
+ project: &Entity<Project>,
+ query: SearchToolQuery,
+ expected_output: &str,
+ cx: &mut TestAppContext,
+ ) {
+ let results = run_retrieval_searches(project.clone(), vec![query], &mut cx.to_async())
+ .await
+ .unwrap();
+
+ let mut results = results.into_iter().collect::<Vec<_>>();
+ results.sort_by_key(|results| { // results come back as a map; order by file path so the rendered output is stable
+ results
+ .0
+ .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone())
+ });
+
+ let mut output = String::new();
+ for (buffer, ranges) in results {
+ buffer.read_with(cx, |buffer, cx| {
+ let excerpts = ranges.into_iter().map(|range| {
+ let point_range = range.to_point(buffer);
+ if point_range.end.column > 0 { // a range ending mid-row includes that row
+ Line(point_range.start.row)..Line(point_range.end.row + 1)
+ } else {
+ Line(point_range.start.row)..Line(point_range.end.row)
+ }
+ });
+
+ write_codeblock(
+ &buffer.file().unwrap().full_path(cx),
+ merge_excerpts(&buffer.snapshot(), excerpts).iter(),
+ &[],
+ Line(buffer.max_point().row),
+ false,
+ &mut output,
+ );
+ });
+ }
+ output.pop(); // drops the final character — presumably a trailing newline left by `write_codeblock`; verify
+
+ assert_eq!(output, expected_output);
+ }
+
+ fn rust_lang() -> Language { // Test-only Rust language: `.rs` suffix matcher, tree-sitter Rust grammar, and the repository's Rust outline query.
+ Language::new(
+ LanguageConfig {
+ name: "Rust".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["rs".to_string()],
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ Some(tree_sitter_rust::LANGUAGE.into()),
+ )
+ .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
+ .unwrap() // a test helper: an outline query that fails to load should abort the test
+ }
+
+ fn init_test(cx: &mut TestAppContext) { // Installs a test `SettingsStore` as a global and initializes test logging.
+ cx.update(move |cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ zlog::init_test();
+ });
+ }
+}
@@ -18,10 +18,10 @@ use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSn
use project::Project;
pub async fn parse_diff<'a>(
- diff: &'a str,
+ diff_str: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
- let mut diff = DiffParser::new(diff);
+ let mut diff = DiffParser::new(diff_str);
let mut edited_buffer = None;
let mut edits = Vec::new();
@@ -41,7 +41,10 @@ pub async fn parse_diff<'a>(
Some(ref current) => current,
};
- edits.extend(resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)?);
+ edits.extend(
+ resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)
+ .with_context(|| format!("Diff:\n{diff_str}"))?,
+ );
}
DiffEvent::FileEnd { renamed_to } => {
let (buffer, _) = edited_buffer
@@ -69,13 +72,13 @@ pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap<Cow<'a, str>, Entity<Buffe
#[must_use]
pub async fn apply_diff<'a>(
- diff: &'a str,
+ diff_str: &'a str,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers<'a>> {
let mut included_files = HashMap::default();
- for line in diff.lines() {
+ for line in diff_str.lines() {
let diff_line = DiffLine::parse(line);
if let DiffLine::OldPath { path } = diff_line {
@@ -97,7 +100,7 @@ pub async fn apply_diff<'a>(
let ranges = [Anchor::MIN..Anchor::MAX];
- let mut diff = DiffParser::new(diff);
+ let mut diff = DiffParser::new(diff_str);
let mut current_file = None;
let mut edits = vec![];
@@ -120,7 +123,10 @@ pub async fn apply_diff<'a>(
};
buffer.read_with(cx, |buffer, _| {
- edits.extend(resolve_hunk_edits_in_buffer(hunk, buffer, ranges)?);
+ edits.extend(
+ resolve_hunk_edits_in_buffer(hunk, buffer, ranges)
+ .with_context(|| format!("Diff:\n{diff_str}"))?,
+ );
anyhow::Ok(())
})??;
}
@@ -328,13 +334,7 @@ fn resolve_hunk_edits_in_buffer(
offset = Some(range.start + ix);
}
}
- offset.ok_or_else(|| {
- anyhow!(
- "Failed to match context:\n{}\n\nBuffer:\n{}",
- hunk.context,
- buffer.text(),
- )
- })
+ offset.ok_or_else(|| anyhow!("Failed to match context:\n{}", hunk.context))
}?;
let iter = hunk.edits.into_iter().flat_map(move |edit| {
let old_text = buffer
@@ -7,7 +7,7 @@ use cloud_llm_client::{
ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
-use cloud_zeta2_prompt::retrieval_prompt::SearchToolInput;
+use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
use collections::HashMap;
use edit_prediction_context::{
DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
@@ -35,7 +35,7 @@ use uuid::Uuid;
use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use thiserror::Error;
use util::rel_path::RelPathBuf;
@@ -88,6 +88,9 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
buffer_change_grouping_interval: Duration::from_secs(1),
};
+static MODEL_ID: LazyLock<String> = // Model id for requests: `ZED_ZETA2_MODEL` if set, else "yqvev8r3"; read once on first use.
+ LazyLock::new(|| std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string())); // NOTE(review): the search-planning request used a different fallback id before sharing this static — confirm one default is intended
+
pub struct Zeta2FeatureFlag;
impl FeatureFlag for Zeta2FeatureFlag {
@@ -180,7 +183,7 @@ pub struct ZetaEditPredictionDebugInfo {
pub struct ZetaSearchQueryDebugInfo {
pub project: Entity<Project>,
pub timestamp: Instant,
- pub regex_by_glob: HashMap<String, String>,
+ pub search_queries: Vec<SearchToolQuery>,
}
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
@@ -883,7 +886,7 @@ impl Zeta {
let (prompt, _) = prompt_result?;
let request = open_ai::Request {
- model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string()),
+ model: MODEL_ID.clone(),
messages: vec![open_ai::RequestMessage::User {
content: open_ai::MessageContent::Plain(prompt),
}],
@@ -1226,10 +1229,24 @@ impl Zeta {
.ok();
}
- let (tool_schema, tool_description) = &*cloud_zeta2_prompt::retrieval_prompt::TOOL_SCHEMA;
+ pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| { // Built once on first use: (schema for `SearchToolInput` as JSON, its top-level "description").
+ let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
+ language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
+ );
+
+ let description = schema
+ .get("description")
+ .and_then(|description| description.as_str())
+ .unwrap() // panics if the root schema has no string "description" — presumably `SearchToolInput` carries a doc comment; verify
+ .to_string();
+
+ (schema.into(), description)
+ });
+
+ let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
let request = open_ai::Request {
- model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("2327jz9q".to_string()),
+ model: MODEL_ID.clone(),
messages: vec![open_ai::RequestMessage::User {
content: open_ai::MessageContent::Plain(prompt),
}],
@@ -1242,8 +1259,8 @@ impl Zeta {
tools: vec![open_ai::ToolDefinition::Function {
function: FunctionDefinition {
name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
- description: Some(tool_description.clone()),
- parameters: Some(tool_schema.clone()),
+ description: Some(tool_description),
+ parameters: Some(tool_schema),
},
}],
prompt_cache_key: None,
@@ -1255,7 +1272,6 @@ impl Zeta {
let response =
Self::send_raw_llm_request(client, llm_token, app_version, request).await;
let mut response = Self::handle_api_response(&this, response, cx)?;
-
log::trace!("Got search planning response");
let choice = response
@@ -1270,7 +1286,7 @@ impl Zeta {
anyhow::bail!("Retrieval response didn't include an assistant message");
};
- let mut regex_by_glob: HashMap<String, String> = HashMap::default();
+ let mut queries: Vec<SearchToolQuery> = Vec::new();
for tool_call in tool_calls {
let open_ai::ToolCallContent::Function { function } = tool_call.content;
if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
@@ -1283,13 +1299,7 @@ impl Zeta {
}
let input: SearchToolInput = serde_json::from_str(&function.arguments)?;
- for query in input.queries {
- let regex = regex_by_glob.entry(query.glob).or_default();
- if !regex.is_empty() {
- regex.push('|');
- }
- regex.push_str(&query.regex);
- }
+ queries.extend(input.queries);
}
if let Some(debug_tx) = &debug_tx {
@@ -1298,16 +1308,16 @@ impl Zeta {
ZetaSearchQueryDebugInfo {
project: project.clone(),
timestamp: Instant::now(),
- regex_by_glob: regex_by_glob.clone(),
+ search_queries: queries.clone(),
},
))
.ok();
}
- log::trace!("Running retrieval search: {regex_by_glob:#?}");
+ log::trace!("Running retrieval search: {queries:#?}");
let related_excerpts_result =
- retrieval_search::run_retrieval_searches(project.clone(), regex_by_glob, cx).await;
+ retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;
log::trace!("Search queries executed");
@@ -1754,7 +1764,8 @@ mod tests {
arguments: serde_json::to_string(&SearchToolInput {
queries: Box::new([SearchToolQuery {
glob: "root/2.txt".to_string(),
- regex: ".".to_string(),
+ syntax_node: vec![],
+ content: Some(".".into()),
}]),
})
.unwrap(),
@@ -16,6 +16,7 @@ anyhow.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
+cloud_zeta2_prompt.workspace = true
collections.workspace = true
edit_prediction_context.workspace = true
editor.workspace = true
@@ -27,7 +28,6 @@ log.workspace = true
multi_buffer.workspace = true
ordered-float.workspace = true
project.workspace = true
-regex-syntax = "0.8.8"
serde.workspace = true
serde_json.workspace = true
telemetry.workspace = true
@@ -8,6 +8,7 @@ use std::{
use anyhow::Result;
use client::{Client, UserStore};
+use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
use editor::{Editor, PathKey};
use futures::StreamExt as _;
use gpui::{
@@ -41,19 +42,13 @@ pub struct Zeta2ContextView {
#[derive(Debug)]
struct RetrievalRun {
editor: Entity<Editor>,
- search_queries: Vec<GlobQueries>,
+ search_queries: Vec<SearchToolQuery>,
started_at: Instant,
search_results_generated_at: Option<Instant>,
search_results_executed_at: Option<Instant>,
finished_at: Option<Instant>,
}
-#[derive(Debug)]
-struct GlobQueries {
- glob: String,
- alternations: Vec<String>,
-}
-
actions!(
dev,
[
@@ -210,23 +205,7 @@ impl Zeta2ContextView {
};
run.search_results_generated_at = Some(info.timestamp);
- run.search_queries = info
- .regex_by_glob
- .into_iter()
- .map(|(glob, regex)| {
- let mut regex_parser = regex_syntax::ast::parse::Parser::new();
-
- GlobQueries {
- glob,
- alternations: match regex_parser.parse(&regex) {
- Ok(regex_syntax::ast::Ast::Alternation(ref alt)) => {
- alt.asts.iter().map(|ast| ast.to_string()).collect()
- }
- _ => vec![regex],
- },
- }
- })
- .collect();
+ run.search_queries = info.search_queries;
cx.notify();
}
@@ -292,18 +271,28 @@ impl Zeta2ContextView {
.enumerate()
.flat_map(|(ix, query)| {
std::iter::once(ListHeader::new(query.glob.clone()).into_any_element())
- .chain(query.alternations.iter().enumerate().map(
- move |(alt_ix, alt)| {
- ListItem::new(ix * 100 + alt_ix)
+ .chain(query.syntax_node.iter().enumerate().map(
+ move |(regex_ix, regex)| {
+ ListItem::new(ix * 100 + regex_ix)
.start_slot(
Icon::new(IconName::MagnifyingGlass)
.color(Color::Muted)
.size(IconSize::Small),
)
- .child(alt.clone())
+ .child(regex.clone())
.into_any_element()
},
))
+ .chain(query.content.as_ref().map(move |regex| {
+ ListItem::new(ix * 100 + query.syntax_node.len())
+ .start_slot(
+ Icon::new(IconName::MagnifyingGlass)
+ .color(Color::Muted)
+ .size(IconSize::Small),
+ )
+ .child(regex.clone())
+ .into_any_element()
+ }))
}),
),
)
@@ -2,11 +2,11 @@ use crate::example::{ActualExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::LOGS_DIR;
use ::serde::Serialize;
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Result, anyhow};
use clap::Args;
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use futures::StreamExt as _;
-use gpui::AsyncApp;
+use gpui::{AppContext, AsyncApp};
use project::Project;
use serde::Deserialize;
use std::cell::Cell;
@@ -14,6 +14,7 @@ use std::fs;
use std::io::Write;
use std::path::PathBuf;
use std::sync::Arc;
+use std::sync::Mutex;
use std::time::{Duration, Instant};
#[derive(Debug, Args)]
@@ -103,112 +104,126 @@ pub async fn zeta2_predict(
let _edited_buffers = example.apply_edit_history(&project, cx).await?;
let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
+ let result = Arc::new(Mutex::new(PredictionDetails::default()));
let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
- let refresh_task = zeta.update(cx, |zeta, cx| {
- zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
- })?;
+ let debug_task = cx.background_spawn({
+ let result = result.clone();
+ async move {
+ let mut context_retrieval_started_at = None;
+ let mut context_retrieval_finished_at = None;
+ let mut search_queries_generated_at = None;
+ let mut search_queries_executed_at = None;
+ while let Some(event) = debug_rx.next().await {
+ match event {
+ zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+ context_retrieval_started_at = Some(info.timestamp);
+ fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
+ }
+ zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
+ search_queries_generated_at = Some(info.timestamp);
+ fs::write(
+ LOGS_DIR.join("search_queries.json"),
+ serde_json::to_string_pretty(&info.search_queries).unwrap(),
+ )?;
+ }
+ zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
+ search_queries_executed_at = Some(info.timestamp);
+ }
+ zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
+ context_retrieval_finished_at = Some(info.timestamp);
+ }
+ zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
+ let prediction_started_at = Instant::now();
+ fs::write(
+ LOGS_DIR.join("prediction_prompt.md"),
+ &request.local_prompt.unwrap_or_default(),
+ )?;
- let mut context_retrieval_started_at = None;
- let mut context_retrieval_finished_at = None;
- let mut search_queries_generated_at = None;
- let mut search_queries_executed_at = None;
- let mut prediction_started_at = None;
- let mut prediction_finished_at = None;
- let mut excerpts_text = String::new();
- let mut prediction_task = None;
- let mut result = PredictionDetails::default();
- while let Some(event) = debug_rx.next().await {
- match event {
- zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
- context_retrieval_started_at = Some(info.timestamp);
- fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
- }
- zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
- search_queries_generated_at = Some(info.timestamp);
- fs::write(
- LOGS_DIR.join("search_queries.json"),
- serde_json::to_string_pretty(&info.regex_by_glob).unwrap(),
- )?;
- }
- zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
- search_queries_executed_at = Some(info.timestamp);
- }
- zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
- context_retrieval_finished_at = Some(info.timestamp);
+ {
+ let mut result = result.lock().unwrap();
- prediction_task = Some(zeta.update(cx, |zeta, cx| {
- zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
- })?);
- }
- zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
- prediction_started_at = Some(Instant::now());
- fs::write(
- LOGS_DIR.join("prediction_prompt.md"),
- &request.local_prompt.unwrap_or_default(),
- )?;
+ for included_file in request.request.included_files {
+ let insertions =
+ vec![(request.request.cursor_point, CURSOR_MARKER)];
+ result.excerpts.extend(included_file.excerpts.iter().map(
+ |excerpt| ActualExcerpt {
+ path: included_file.path.components().skip(1).collect(),
+ text: String::from(excerpt.text.as_ref()),
+ },
+ ));
+ write_codeblock(
+ &included_file.path,
+ included_file.excerpts.iter(),
+ if included_file.path == request.request.excerpt_path {
+ &insertions
+ } else {
+ &[]
+ },
+ included_file.max_row,
+ false,
+ &mut result.excerpts_text,
+ );
+ }
+ }
- for included_file in request.request.included_files {
- let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
- result
- .excerpts
- .extend(included_file.excerpts.iter().map(|excerpt| ActualExcerpt {
- path: included_file.path.components().skip(1).collect(),
- text: String::from(excerpt.text.as_ref()),
- }));
- write_codeblock(
- &included_file.path,
- included_file.excerpts.iter(),
- if included_file.path == request.request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- included_file.max_row,
- false,
- &mut excerpts_text,
- );
- }
+ let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
+ let response = zeta2::text_from_response(response).unwrap_or_default();
+ let prediction_finished_at = Instant::now();
+ fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;
- let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
- let response = zeta2::text_from_response(response).unwrap_or_default();
- prediction_finished_at = Some(Instant::now());
- fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;
+ let mut result = result.lock().unwrap();
- break;
+ result.planning_search_time = search_queries_generated_at.unwrap()
+ - context_retrieval_started_at.unwrap();
+ result.running_search_time = search_queries_executed_at.unwrap()
+ - search_queries_generated_at.unwrap();
+ result.filtering_search_time = context_retrieval_finished_at.unwrap()
+ - search_queries_executed_at.unwrap();
+ result.prediction_time = prediction_finished_at - prediction_started_at;
+ result.total_time =
+ prediction_finished_at - context_retrieval_started_at.unwrap();
+
+ break;
+ }
+ }
}
+ anyhow::Ok(())
}
- }
+ });
+
+ zeta.update(cx, |zeta, cx| {
+ zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
+ })?
+ .await?;
- refresh_task.await.context("context retrieval failed")?;
- let prediction = prediction_task.unwrap().await?;
+ let prediction = zeta
+ .update(cx, |zeta, cx| {
+ zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
+ })?
+ .await?;
+ debug_task.await?;
+
+ let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
result.diff = prediction
.map(|prediction| {
let old_text = prediction.snapshot.text();
- let new_text = prediction.buffer.update(cx, |buffer, cx| {
- buffer.edit(prediction.edits.iter().cloned(), None, cx);
- buffer.text()
- })?;
- anyhow::Ok(language::unified_diff(&old_text, &new_text))
+ let new_text = prediction
+ .buffer
+ .update(cx, |buffer, cx| {
+ buffer.edit(prediction.edits.iter().cloned(), None, cx);
+ buffer.text()
+ })
+ .unwrap();
+ language::unified_diff(&old_text, &new_text)
})
- .transpose()?
.unwrap_or_default();
- result.excerpts_text = excerpts_text;
-
- result.planning_search_time =
- search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
- result.running_search_time =
- search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap();
- result.filtering_search_time =
- context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
- result.prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
- result.total_time = prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap();
anyhow::Ok(result)
}
-#[derive(Debug, Default, Serialize, Deserialize)]
+#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
pub diff: String,
pub excerpts: Vec<ActualExcerpt>,