@@ -20,8 +20,8 @@ use anyhow::{Context as _, Result, anyhow};
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
- PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
- ZED_VERSION_HEADER_NAME,
+ PredictEditsAdditionalContext, PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile,
+ PredictEditsResponse, ZED_VERSION_HEADER_NAME,
};
use collections::{HashMap, HashSet, VecDeque};
use futures::AsyncReadExt;
@@ -68,7 +68,6 @@ const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_ch
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
const MAX_EVENT_TOKENS: usize = 500;
-const MAX_DIAGNOSTIC_GROUPS: usize = 10;
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;
@@ -76,12 +75,15 @@ const MAX_EVENT_COUNT: usize = 16;
/// Maximum number of recent files to track.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;
-/// Minimum number of milliseconds between recent project entries to keep them
+/// Minimum time between recent file entries.
const MIN_TIME_BETWEEN_RECENT_PROJECT_ENTRIES: Duration = Duration::from_millis(100);
/// Maximum file path length to include in recent files list.
const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;
+/// Maximum number of JSON bytes for diagnostics in additional context.
+const MAX_DIAGNOSTICS_BYTES: usize = 4096;
+
/// Maximum number of edit predictions to store for feedback.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
@@ -151,7 +153,6 @@ pub struct EditPrediction {
edits: Arc<[(Range<Anchor>, String)]>,
snapshot: BufferSnapshot,
edit_preview: EditPreview,
- input_outline: Arc<str>,
input_events: Arc<str>,
input_excerpt: Arc<str>,
output_excerpt: Arc<str>,
@@ -407,7 +408,7 @@ impl Zeta {
fn request_completion_impl<F, R>(
&mut self,
workspace: Option<Entity<Workspace>>,
- project: Option<&Entity<Project>>,
+ project: Option<Entity<Project>>,
buffer: &Entity<Buffer>,
cursor: language::Anchor,
can_collect_data: CanCollectData,
@@ -429,30 +430,21 @@ impl Zeta {
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
- let cursor_point = cursor.to_point(&snapshot);
- let cursor_offset = cursor_point.to_offset(&snapshot);
- let git_info = if matches!(can_collect_data, CanCollectData(true)) {
- self.gather_git_info(cursor_point, &buffer_snapshotted_at, &snapshot, project, cx)
- } else {
- None
- };
-
let full_path: Arc<Path> = snapshot
.file()
.map(|f| Arc::from(f.full_path(cx).as_path()))
.unwrap_or_else(|| Arc::from(Path::new("untitled")));
let full_path_str = full_path.to_string_lossy().to_string();
+ let cursor_point = cursor.to_point(&snapshot);
+ let cursor_offset = cursor_point.to_offset(&snapshot);
let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
- let gather_task = gather_context(
- project,
+ let gather_task = cx.background_spawn(gather_context(
full_path_str,
- &snapshot,
+ snapshot.clone(),
cursor_point,
make_events_prompt,
can_collect_data,
- git_info,
- cx,
- );
+ ));
cx.spawn(async move |this, cx| {
let GatherContextOutput {
@@ -461,13 +453,41 @@ impl Zeta {
} = gather_task.await?;
let done_gathering_context_at = Instant::now();
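+ // When data collection is allowed, gather extra context (git info, diagnostics,
+ // recent files) in a separate task and report it via telemetry.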
+ let additional_context_task = if matches!(can_collect_data, CanCollectData(true))
+ && let Some(file) = snapshot.file()
+ && let Ok(project_path) = cx.update(|cx| ProjectPath::from_file(file.as_ref(), cx))
+ {
+ // This is async to reduce latency of the edit prediction request. The downside is
+ // that it will see a slightly later state than was used when gathering context.
+ let snapshot = snapshot.clone();
+ let this = this.clone();
+ Some(cx.spawn(async move |cx| {
+ if let Ok(Some(task)) = this.update(cx, |this, cx| {
+ this.gather_additional_context(
+ cursor_point,
+ cursor_offset,
+ snapshot,
+ &buffer_snapshotted_at,
+ project_path,
+ project.as_ref(),
+ cx,
+ )
+ }) {
+ Some(task.await)
+ } else {
+ None
+ }
+ }))
+ } else {
+ None
+ };
+
log::debug!(
"Events:\n{}\nExcerpt:\n{:?}",
body.input_events,
body.input_excerpt
);
- let input_outline = body.outline.clone().unwrap_or_default();
let input_events = body.input_events.clone();
let input_excerpt = body.input_excerpt.clone();
@@ -524,6 +544,7 @@ impl Zeta {
.ok();
}
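+ // Keep the request id so the additional-context telemetry event below can reference it.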
+ let request_id = response.request_id.clone();
let edit_prediction = Self::process_completion_response(
response,
buffer,
@@ -531,7 +552,6 @@ impl Zeta {
editable_range,
cursor_offset,
full_path,
- input_outline,
input_events,
input_excerpt,
buffer_snapshotted_at,
@@ -555,6 +575,19 @@ impl Zeta {
);
}
+ if let Some(additional_context_task) = additional_context_task {
+ cx.background_spawn(async move {
+ if let Some(additional_context) = additional_context_task.await {
+ telemetry::event!(
+ "Edit Prediction Additional Context",
+ request_id = request_id,
+ additional_context = additional_context
+ );
+ }
+ })
+ .detach();
+ }
+
edit_prediction
})
}
@@ -572,13 +605,12 @@ impl Zeta {
and then another
"#};
- let project = None;
let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
let position = buffer.read(cx).anchor_before(Point::new(1, 0));
let completion_tasks = vec![
self.fake_completion(
- project,
+ None,
&buffer,
position,
PredictEditsResponse {
@@ -590,12 +622,12 @@ And maybe a short line
Then a few lines
and then another
{EDITABLE_REGION_END_MARKER}
- ", ),
+ "),
},
cx,
),
self.fake_completion(
- project,
+ None,
&buffer,
position,
PredictEditsResponse {
@@ -612,7 +644,7 @@ and then another
cx,
),
self.fake_completion(
- project,
+ None,
&buffer,
position,
PredictEditsResponse {
@@ -630,7 +662,7 @@ and then another
cx,
),
self.fake_completion(
- project,
+ None,
&buffer,
position,
PredictEditsResponse {
@@ -648,7 +680,7 @@ and then another
cx,
),
self.fake_completion(
- project,
+ None,
&buffer,
position,
PredictEditsResponse {
@@ -665,7 +697,7 @@ and then another
cx,
),
self.fake_completion(
- project,
+ None,
&buffer,
position,
PredictEditsResponse {
@@ -681,7 +713,7 @@ and then another
cx,
),
self.fake_completion(
- project,
+ None,
&buffer,
position,
PredictEditsResponse {
@@ -715,7 +747,7 @@ and then another
#[cfg(any(test, feature = "test-support"))]
pub fn fake_completion(
&mut self,
- project: Option<&Entity<Project>>,
+ project: Option<Entity<Project>>,
buffer: &Entity<Buffer>,
position: language::Anchor,
response: PredictEditsResponse,
@@ -736,7 +768,7 @@ and then another
pub fn request_completion(
&mut self,
- project: Option<&Entity<Project>>,
+ project: Option<Entity<Project>>,
buffer: &Entity<Buffer>,
position: language::Anchor,
can_collect_data: CanCollectData,
@@ -904,7 +936,6 @@ and then another
editable_range: Range<usize>,
cursor_offset: usize,
path: Arc<Path>,
- input_outline: String,
input_events: String,
input_excerpt: String,
buffer_snapshotted_at: Instant,
@@ -949,7 +980,6 @@ and then another
edits,
edit_preview,
snapshot,
- input_outline: input_outline.into(),
input_events: input_events.into(),
input_excerpt: input_excerpt.into(),
output_excerpt,
@@ -1078,7 +1108,6 @@ and then another
rating,
input_events = completion.input_events,
input_excerpt = completion.input_excerpt,
- input_outline = completion.input_outline,
output_excerpt = completion.output_excerpt,
feedback
);
@@ -1136,17 +1165,17 @@ and then another
}
}
- fn gather_git_info(
+ fn gather_additional_context(
&mut self,
cursor_point: language::Point,
+ cursor_offset: usize,
+ snapshot: BufferSnapshot,
buffer_snapshotted_at: &Instant,
- snapshot: &BufferSnapshot,
+ project_path: ProjectPath,
project: Option<&Entity<Project>>,
cx: &mut Context<Self>,
- ) -> Option<PredictEditsGitInfo> {
+ ) -> Option<Task<PredictEditsAdditionalContext>> {
let project = project?.read(cx);
- let file = snapshot.file()?;
- let project_path = ProjectPath::from_file(file.as_ref(), cx);
let entry = project.entry_for_path(&project_path, cx)?;
if !worktree_entry_eligible_for_collection(&entry) {
return None;
@@ -1155,7 +1184,25 @@ and then another
let git_store = project.git_store().read(cx);
let (repository, repo_path) =
git_store.repository_and_path_for_project_path(&project_path, cx)?;
- let repo_path_str = repo_path.to_str()?;
+ let repo_path_string = repo_path.to_str()?.to_string();
+
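+ // Capture diagnostics with their language server names while the LSP store is
+ // still accessible; grouping and serialization happen on a background thread below.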
+ let diagnostics = if let Some(local_lsp_store) = project.lsp_store().read(cx).as_local() {
+ snapshot
+ .diagnostics
+ .iter()
+ .filter_map(|(language_server_id, diagnostics)| {
+ let language_server =
+ local_lsp_store.running_language_server_for_id(*language_server_id)?;
+ Some((
+ *language_server_id,
+ language_server.name(),
+ diagnostics.clone(),
+ ))
+ })
+ .collect()
+ } else {
+ Vec::new()
+ };
repository.update(cx, |repository, cx| {
let head_sha = repository.head_commit.as_ref()?.sha.to_string();
@@ -1163,14 +1210,61 @@ and then another
let remote_upstream_url = repository.remote_upstream_url.clone();
let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
- Some(PredictEditsGitInfo {
- input_path: Some(repo_path_str.to_string()),
- cursor_point: Some(to_cloud_llm_client_point(cursor_point)),
- head_sha: Some(head_sha),
- remote_origin_url,
- remote_upstream_url,
- recent_files: Some(recent_files),
- })
+ // Group, resolve, and select diagnostics on a background thread.
+ Some(cx.background_spawn(async move {
+ let mut diagnostic_groups_with_name = Vec::new();
+ for (language_server_id, language_server_name, diagnostics) in
+ diagnostics.into_iter()
+ {
+ let mut groups = Vec::new();
+ diagnostics.groups(language_server_id, &mut groups, &snapshot);
+ diagnostic_groups_with_name.extend(groups.into_iter().map(|(_, group)| {
+ (
+ language_server_name.clone(),
+ group.resolve::<usize>(&snapshot),
+ )
+ }));
+ }
+
+ // Sort by the primary entry's distance from the cursor, nearest first.
+ diagnostic_groups_with_name.sort_by_key(|(_, group)| {
+ let range = &group.entries[group.primary_ix].range;
+ if range.start >= cursor_offset {
+ range.start - cursor_offset
+ } else if cursor_offset >= range.end {
+ cursor_offset - range.end
+ } else {
+ // The cursor is inside the diagnostic range.
+ 0
+ }
+ });
+
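+ // Serialize the groups nearest the cursor first, stopping once the JSON
+ // payload would exceed MAX_DIAGNOSTICS_BYTES.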
+ let mut diagnostic_groups = Vec::new();
+ let mut diagnostic_groups_truncated = false;
+ let mut diagnostics_byte_count = 0;
+ for (name, group) in diagnostic_groups_with_name {
+ let raw_value = serde_json::value::to_raw_value(&group).unwrap();
+ diagnostics_byte_count += name.0.len() + raw_value.get().len();
+ if diagnostics_byte_count > MAX_DIAGNOSTICS_BYTES {
+ diagnostic_groups_truncated = true;
+ break;
+ }
+ diagnostic_groups.push((name.to_string(), raw_value));
+ }
+
+ PredictEditsAdditionalContext {
+ input_path: repo_path_string,
+ cursor_point: to_cloud_llm_client_point(cursor_point),
+ cursor_offset,
+ git_info: PredictEditsGitInfo {
+ head_sha: Some(head_sha),
+ remote_origin_url,
+ remote_upstream_url,
+ },
+ diagnostic_groups,
+ diagnostic_groups_truncated,
+ recent_files,
+ }
+ }))
})
}
@@ -1319,74 +1413,34 @@ pub struct GatherContextOutput {
pub editable_range: Range<usize>,
}
-pub fn gather_context(
- project: Option<&Entity<Project>>,
+pub async fn gather_context(
full_path_str: String,
- snapshot: &BufferSnapshot,
+ snapshot: BufferSnapshot,
cursor_point: language::Point,
make_events_prompt: impl FnOnce() -> String + Send + 'static,
can_collect_data: CanCollectData,
- git_info: Option<PredictEditsGitInfo>,
- cx: &App,
-) -> Task<Result<GatherContextOutput>> {
- let local_lsp_store =
- project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
- let diagnostic_groups: Vec<(String, serde_json::Value)> =
- if matches!(can_collect_data, CanCollectData(true))
- && let Some(local_lsp_store) = local_lsp_store
- {
- snapshot
- .diagnostic_groups(None)
- .into_iter()
- .filter_map(|(language_server_id, diagnostic_group)| {
- let language_server =
- local_lsp_store.running_language_server_for_id(language_server_id)?;
- let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
- let language_server_name = language_server.name().to_string();
- let serialized = serde_json::to_value(diagnostic_group).unwrap();
- Some((language_server_name, serialized))
- })
- .collect::<Vec<_>>()
- } else {
- Vec::new()
- };
-
- cx.background_spawn({
- let snapshot = snapshot.clone();
- async move {
- let diagnostic_groups = if diagnostic_groups.is_empty()
- || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
- {
- None
- } else {
- Some(diagnostic_groups)
- };
-
- let input_excerpt = excerpt_for_cursor_position(
- cursor_point,
- &full_path_str,
- &snapshot,
- MAX_REWRITE_TOKENS,
- MAX_CONTEXT_TOKENS,
- );
- let input_events = make_events_prompt();
- let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
-
- let body = PredictEditsBody {
- input_events,
- input_excerpt: input_excerpt.prompt,
- can_collect_data: can_collect_data.0,
- diagnostic_groups,
- git_info,
- outline: None,
- speculated_output: None,
- };
-
- Ok(GatherContextOutput {
- body,
- editable_range,
- })
- }
+) -> Result<GatherContextOutput> {
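+ // This no longer touches project or app state, so the caller can run it on a
+ // background thread via `cx.background_spawn`.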
+ let input_excerpt = excerpt_for_cursor_position(
+ cursor_point,
+ &full_path_str,
+ &snapshot,
+ MAX_REWRITE_TOKENS,
+ MAX_CONTEXT_TOKENS,
+ );
+ let input_events = make_events_prompt();
+ let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
+
+ let body = PredictEditsBody {
+ input_events,
+ input_excerpt: input_excerpt.prompt,
+ can_collect_data: can_collect_data.0,
+ diagnostic_groups: None,
+ git_info: None,
+ };
+
+ Ok(GatherContextOutput {
+ body,
+ editable_range,
})
}
@@ -1763,13 +1817,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
let completion_request = this.update(cx, |this, cx| {
this.last_request_timestamp = Instant::now();
this.zeta.update(cx, |zeta, cx| {
- zeta.request_completion(
- project.as_ref(),
- &buffer,
- position,
- can_collect_data,
- cx,
- )
+ zeta.request_completion(project, &buffer, position, can_collect_data, cx)
})
});
@@ -1983,7 +2031,6 @@ mod tests {
id: EditPredictionId(Uuid::new_v4()),
excerpt_range: 0..0,
cursor_offset: 0,
- input_outline: "".into(),
input_events: "".into(),
input_excerpt: "".into(),
output_excerpt: "".into(),