@@ -6,19 +6,21 @@ mod onboarding_modal;
mod onboarding_telemetry;
mod rate_completion_modal;
+use arrayvec::ArrayVec;
pub(crate) use completion_diff_element::*;
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use edit_prediction::DataCollectionState;
pub use init::*;
use license_detection::LicenseDetectionWatcher;
+use project::git_store::Repository;
pub use rate_completion_modal::*;
use anyhow::{Context as _, Result, anyhow};
-use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
- PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
+ PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
+ ZED_VERSION_HEADER_NAME,
};
use collections::{HashMap, HashSet, VecDeque};
use futures::AsyncReadExt;
@@ -32,7 +34,7 @@ use language::{
Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
-use project::{Project, ProjectPath};
+use project::{Project, ProjectEntryId, ProjectPath};
use release_channel::AppVersion;
use settings::WorktreeId;
use std::str::FromStr;
@@ -70,6 +72,12 @@ const MAX_DIAGNOSTIC_GROUPS: usize = 10;
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;
+/// Maximum number of recent files to track.
+const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;
+
+/// Maximum number of edit predictions to store for feedback.
+const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
+
actions!(
edit_prediction,
[
@@ -212,7 +220,7 @@ impl std::fmt::Debug for EditPrediction {
}
pub struct Zeta {
- workspace: Option<WeakEntity<Workspace>>,
+ workspace: WeakEntity<Workspace>,
client: Arc<Client>,
events: VecDeque<Event>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
@@ -225,6 +233,7 @@ pub struct Zeta {
update_required: bool,
user_store: Entity<UserStore>,
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
+ recent_project_entries: VecDeque<(ProjectEntryId, Instant)>,
}
impl Zeta {
@@ -233,7 +242,7 @@ impl Zeta {
}
pub fn register(
- workspace: Option<WeakEntity<Workspace>>,
+ workspace: Option<Entity<Workspace>>,
worktree: Option<Entity<Worktree>>,
client: Arc<Client>,
user_store: Entity<UserStore>,
@@ -266,7 +275,7 @@ impl Zeta {
}
fn new(
- workspace: Option<WeakEntity<Workspace>>,
+ workspace: Option<Entity<Workspace>>,
client: Arc<Client>,
user_store: Entity<UserStore>,
cx: &mut Context<Self>,
@@ -276,11 +285,27 @@ impl Zeta {
let data_collection_choice = Self::load_data_collection_choices();
let data_collection_choice = cx.new(|_| data_collection_choice);
+ if let Some(workspace) = &workspace {
+ cx.subscribe(
+ &workspace.read(cx).project().clone(),
+ |this, _workspace, event, _cx| match event {
+ project::Event::ActiveEntryChanged(Some(project_entry_id)) => {
+ this.push_recent_project_entry(*project_entry_id)
+ }
+ _ => {}
+ },
+ )
+ .detach();
+ }
+
Self {
- workspace,
+ workspace: workspace.map_or_else(
+ || WeakEntity::new_invalid(),
+ |workspace| workspace.downgrade(),
+ ),
client,
- events: VecDeque::new(),
- shown_completions: VecDeque::new(),
+ events: VecDeque::with_capacity(MAX_EVENT_COUNT),
+ shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
rated_completions: HashSet::default(),
registered_buffers: HashMap::default(),
data_collection_choice,
@@ -300,6 +325,7 @@ impl Zeta {
update_required: false,
license_detection_watchers: HashMap::default(),
user_store,
+ recent_project_entries: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT),
}
}
@@ -327,11 +353,12 @@ impl Zeta {
}
}
- self.events.push_back(event);
if self.events.len() >= MAX_EVENT_COUNT {
// These are halved instead of popping to improve prompt caching.
self.events.drain(..MAX_EVENT_COUNT / 2);
}
+
+ self.events.push_back(event);
}
pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
@@ -393,12 +420,17 @@ impl Zeta {
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
- let git_info = if let (true, Some(project), Some(file)) =
+ let (git_info, recent_files) = if let (true, Some(project), Some(file)) =
(can_collect_data, project, snapshot.file())
+ && let Some(repository) =
+ git_repository_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
{
- git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
+ let repository = repository.read(cx);
+ let git_info = make_predict_edits_git_info(repository);
+ let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
+ (git_info, Some(recent_files))
} else {
- None
+ (None, None)
};
let full_path: Arc<Path> = snapshot
@@ -417,6 +449,7 @@ impl Zeta {
make_events_prompt,
can_collect_data,
git_info,
+ recent_files,
cx,
);
@@ -702,12 +735,8 @@ and then another
can_collect_data: bool,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
- let workspace = self
- .workspace
- .as_ref()
- .and_then(|workspace| workspace.upgrade());
self.request_completion_impl(
- workspace,
+ self.workspace.upgrade(),
project,
buffer,
position,
@@ -1021,11 +1050,11 @@ and then another
}
pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
- self.shown_completions.push_front(completion.clone());
- if self.shown_completions.len() > 50 {
+ if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
let completion = self.shown_completions.pop_back().unwrap();
self.rated_completions.remove(&completion.id);
}
+ self.shown_completions.push_front(completion.clone());
cx.notify();
}
@@ -1099,6 +1128,63 @@ and then another
None => DataCollectionChoice::NotAnswered,
}
}
+
+ fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
+ let now = Instant::now();
+ if let Some(existing_ix) = self
+ .recent_project_entries
+ .iter()
+ .rposition(|(id, _)| *id == project_entry_id)
+ {
+ self.recent_project_entries.remove(existing_ix);
+ }
+ if self.recent_project_entries.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
+ self.recent_project_entries.pop_front();
+ }
+ self.recent_project_entries
+ .push_back((project_entry_id, now));
+ }
+
+ fn recent_files(
+ &mut self,
+ now: &Instant,
+ repository: &Repository,
+ cx: &Context<Self>,
+ ) -> Vec<PredictEditsRecentFile> {
+ let Ok(project) = self
+ .workspace
+ .read_with(cx, |workspace, _cx| workspace.project().clone())
+ else {
+ return Vec::new();
+ };
+ let mut results = Vec::new();
+ for ix in (0..self.recent_project_entries.len()).rev() {
+ let (id, last_active_at) = &self.recent_project_entries[ix];
+ let Some(project_path) = project.read(cx).path_for_entry(*id, cx) else {
+ self.recent_project_entries.remove(ix);
+ continue;
+ };
+ let Some(repo_path) = repository.project_path_to_repo_path(&project_path, cx) else {
+ // entry not removed since queries involving other repositories might occur later
+ continue;
+ };
+ let Some(repo_path) = repo_path.to_str() else {
+ // paths may not be valid UTF-8
+ self.recent_project_entries.remove(ix);
+ continue;
+ };
+ let Ok(active_to_now_ms) = now.duration_since(*last_active_at).as_millis().try_into()
+ else {
+ self.recent_project_entries.remove(ix);
+ continue;
+ };
+ results.push(PredictEditsRecentFile {
+ repo_path: repo_path.to_string(),
+ active_to_now_ms,
+ });
+ }
+ results
+ }
}
pub struct PerformPredictEditsParams {
@@ -1123,33 +1209,32 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b:
.sum()
}
-fn git_info_for_file(
+fn git_repository_for_file(
project: &Entity<Project>,
project_path: &ProjectPath,
cx: &App,
-) -> Option<PredictEditsGitInfo> {
+) -> Option<Entity<Repository>> {
let git_store = project.read(cx).git_store().read(cx);
- if let Some((repository, _repo_path)) =
- git_store.repository_and_path_for_project_path(project_path, cx)
- {
- let repository = repository.read(cx);
- let head_sha = repository
- .head_commit
- .as_ref()
- .map(|head_commit| head_commit.sha.to_string());
- let remote_origin_url = repository.remote_origin_url.clone();
- let remote_upstream_url = repository.remote_upstream_url.clone();
- if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
- return None;
- }
- Some(PredictEditsGitInfo {
- head_sha,
- remote_origin_url,
- remote_upstream_url,
- })
- } else {
- None
+ git_store
+ .repository_and_path_for_project_path(project_path, cx)
+ .map(|(repo, _repo_path)| repo)
+}
+
+fn make_predict_edits_git_info(repository: &Repository) -> Option<PredictEditsGitInfo> {
+ let head_sha = repository
+ .head_commit
+ .as_ref()
+ .map(|head_commit| head_commit.sha.to_string());
+ let remote_origin_url = repository.remote_origin_url.clone();
+ let remote_upstream_url = repository.remote_upstream_url.clone();
+ if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
+ return None;
}
+ Some(PredictEditsGitInfo {
+ head_sha,
+ remote_origin_url,
+ remote_upstream_url,
+ })
}
pub struct GatherContextOutput {
@@ -1165,6 +1250,7 @@ pub fn gather_context(
make_events_prompt: impl FnOnce() -> String + Send + 'static,
can_collect_data: bool,
git_info: Option<PredictEditsGitInfo>,
+ recent_files: Option<Vec<PredictEditsRecentFile>>,
cx: &App,
) -> Task<Result<GatherContextOutput>> {
let local_lsp_store =
@@ -1216,6 +1302,7 @@ pub fn gather_context(
git_info,
outline: None,
speculated_output: None,
+ recent_files,
};
Ok(GatherContextOutput {