Detailed changes
@@ -577,6 +577,12 @@ impl EditPredictionStore {
}
}
+ pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
+ if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
+ project_state.events.clear();
+ }
+ }
+
pub fn edit_history_for_project(
&self,
project: &Entity<Project>,
@@ -1,9 +1,6 @@
-use crate::{
- PredictionProvider, PromptFormat,
- metrics::ClassificationMetrics,
- paths::{REPOS_DIR, WORKTREES_DIR},
-};
+use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
use anyhow::{Context as _, Result};
+use collections::HashMap;
use edit_prediction::udiff::OpenedBuffers;
use gpui::Entity;
use http_client::Url;
@@ -102,7 +99,7 @@ pub struct ExampleScore {
}
impl Example {
- fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
+ pub fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
// git@github.com:owner/repo.git
if self.repository_url.contains('@') {
let (owner, repo) = self
@@ -134,17 +131,6 @@ impl Example {
Ok((owner.into(), repo.into()))
}
}
-
- pub fn worktree_path(&self) -> PathBuf {
- WORKTREES_DIR
- .join(&self.name)
- .join(self.repo_name().unwrap().1.as_ref())
- }
-
- pub fn repo_path(&self) -> PathBuf {
- let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
- REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
- }
}
pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
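Note: repo_name is now public so callers outside example.rs (such as setup_worktree below) can derive the on-disk layout. Per its inline comment it accepts both SSH-style and HTTPS remotes; a minimal usage sketch (the Default impl and the exact tuple contents are assumptions for illustration):

    // Hypothetical usage; assumes Example: Default and that the ".git"
    // suffix is stripped, as the "git@github.com:owner/repo.git" comment suggests.
    let example = Example {
        repository_url: "git@github.com:owner/repo.git".into(),
        ..Default::default()
    };
    let (owner, repo) = example.repo_name()?; // expected: ("owner", "repo")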
@@ -218,6 +204,8 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
}
}
}
+
+ sort_examples_by_repo_and_rev(&mut examples);
examples
}
@@ -235,6 +223,25 @@ pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
}
}
+pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
+ examples.sort_by(|a, b| {
+ a.repository_url
+ .cmp(&b.repository_url)
+ .then(b.revision.cmp(&a.revision))
+ });
+}
+
+pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
+ let mut examples_by_repo = HashMap::default();
+ for example in examples.iter_mut() {
+ examples_by_repo
+ .entry(example.repository_url.clone())
+ .or_insert_with(Vec::new)
+ .push(example);
+ }
+ examples_by_repo.into_values().collect()
+}
+
fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
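Together these two helpers set up the per-repository scheduling used in main.rs: read_examples now sorts so examples from the same repository are adjacent (in descending revision order), and group_examples_by_repo turns that into one mutable bucket per repository. A minimal sketch of the intended call order:

    // Sketch: prepare examples for per-repository batching.
    let mut examples = read_examples(&inputs); // already sorted by repo and rev
    let groups: Vec<Vec<&mut Example>> = group_examples_by_repo(&mut examples);
    for group in &groups {
        // every Example in a group shares one repository_url, so a single
        // cached project/worktree can be reused for the whole group
    }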
@@ -1,4 +1,5 @@
use client::{Client, ProxySettings, UserStore};
+use collections::HashMap;
use extension::ExtensionHostProxy;
use fs::RealFs;
use gpui::http_client::read_proxy_from_env;
@@ -7,12 +8,13 @@ use gpui_tokio::Tokio;
use language::LanguageRegistry;
use language_extension::LspAccess;
use node_runtime::{NodeBinaryOptions, NodeRuntime};
+use project::Project;
use project::project_settings::ProjectSettings;
use release_channel::{AppCommitSha, AppVersion};
use reqwest_client::ReqwestClient;
use settings::{Settings, SettingsStore};
use std::path::PathBuf;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
use util::ResultExt as _;
/// Headless subset of `workspace::AppState`.
@@ -22,9 +24,22 @@ pub struct EpAppState {
pub user_store: Entity<UserStore>,
pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime,
+ pub project_cache: ProjectCache,
+}
+
+#[derive(Default)]
+pub struct ProjectCache(Mutex<HashMap<String, Entity<Project>>>);
+
+impl ProjectCache {
+ pub fn insert(&self, repository_url: String, project: Entity<Project>) {
+ self.0.lock().unwrap().insert(repository_url, project);
+ }
+
+    pub fn get(&self, repository_url: &str) -> Option<Entity<Project>> {
+ self.0.lock().unwrap().get(repository_url).cloned()
+ }
}
-// TODO: dedupe with crates/eval/src/eval.rs
pub fn init(cx: &mut App) -> EpAppState {
let app_commit_sha = option_env!("ZED_COMMIT_SHA").map(|s| AppCommitSha::new(s.to_owned()));
@@ -112,11 +127,14 @@ pub fn init(cx: &mut App) -> EpAppState {
prompt_store::init(cx);
terminal_view::init(cx);
+ let project_cache = ProjectCache::default();
+
EpAppState {
languages,
client,
user_store,
fs,
node_runtime,
+ project_cache,
}
}
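ProjectCache is a Mutex-guarded map rather than an Entity because EpAppState is shared behind an Arc, so interior mutability is needed; get clones the Entity<Project> handle, keeping the lock held only for the lookup. A sketch of the round-trip (mirroring setup_project in load_project.rs):

    // First load of a repository populates the cache...
    app_state
        .project_cache
        .insert(example.repository_url.clone(), project.clone());
    // ...and later examples for the same repository reuse the project.
    if let Some(project) = app_state.project_cache.get(&example.repository_url) {
        // clear per-project prediction history and reload buffers before reuse
    }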
@@ -1,6 +1,7 @@
use crate::{
example::{Example, ExampleBuffer, ExampleState},
headless::EpAppState,
+ paths::{REPOS_DIR, WORKTREES_DIR},
};
use anyhow::{Result, anyhow};
use collections::HashMap;
@@ -29,29 +30,11 @@ pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>,
}
let project = setup_project(example, &app_state, &mut cx).await;
- let buffer_store = project
- .read_with(&cx, |project, _| project.buffer_store().clone())
- .unwrap();
-
- let ep_store = cx
- .update(|cx| EditPredictionStore::try_global(cx).unwrap())
- .unwrap();
-
- cx.subscribe(&buffer_store, {
- let project = project.clone();
- move |_, event, cx| match event {
- BufferStoreEvent::BufferAdded(buffer) => {
- ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
- }
- _ => {}
- }
- })
- .unwrap()
- .detach();
let _open_buffers = apply_edit_history(example, &project, &mut cx)
.await
.unwrap();
+
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
example.buffer = buffer
.read_with(&cx, |buffer, _cx| {
@@ -64,6 +47,7 @@ pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>,
})
})
.unwrap();
+
example.state = Some(ExampleState {
buffer,
project,
@@ -149,7 +133,35 @@ async fn setup_project(
app_state: &Arc<EpAppState>,
cx: &mut AsyncApp,
) -> Entity<Project> {
- setup_worktree(example).await;
+ let ep_store = cx
+ .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+ .unwrap();
+
+ let worktree_path = setup_worktree(example).await;
+
+ if let Some(project) = app_state.project_cache.get(&example.repository_url) {
+ ep_store
+ .update(cx, |ep_store, _| {
+ ep_store.clear_history_for_project(&project);
+ })
+ .unwrap();
+ let buffer_store = project
+ .read_with(cx, |project, _| project.buffer_store().clone())
+ .unwrap();
+ let buffers = buffer_store
+ .read_with(cx, |buffer_store, _| {
+ buffer_store.buffers().collect::<Vec<_>>()
+ })
+ .unwrap();
+ for buffer in buffers {
+ buffer
+ .update(cx, |buffer, cx| buffer.reload(cx))
+ .unwrap()
+ .await
+ .unwrap();
+ }
+ return project;
+ }
let project = cx
.update(|cx| {
@@ -168,30 +180,44 @@ async fn setup_project(
project
.update(cx, |project, cx| {
project.disable_worktree_scanner(cx);
- })
- .unwrap();
-
- let worktree = project
- .update(cx, |project, cx| {
- project.create_worktree(&example.worktree_path(), true, cx)
+ project.create_worktree(&worktree_path, true, cx)
})
.unwrap()
.await
.unwrap();
- worktree
- .read_with(cx, |worktree, _cx| {
- worktree.as_local().unwrap().scan_complete()
- })
- .unwrap()
- .await;
+
+ app_state
+ .project_cache
+ .insert(example.repository_url.clone(), project.clone());
+
+ let buffer_store = project
+ .read_with(cx, |project, _| project.buffer_store().clone())
+ .unwrap();
+ cx.subscribe(&buffer_store, {
+ let project = project.clone();
+ move |_, event, cx| match event {
+ BufferStoreEvent::BufferAdded(buffer) => {
+ ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
+ }
+ _ => {}
+ }
+ })
+ .unwrap()
+ .detach();
+
project
}
-pub async fn setup_worktree(example: &Example) {
- let repo_dir = example.repo_path();
+pub async fn setup_worktree(example: &Example) -> PathBuf {
+ let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name");
+ let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
+ let worktree_path = WORKTREES_DIR
+ .join(repo_owner.as_ref())
+ .join(repo_name.as_ref());
let repo_lock = lock_repo(&repo_dir).await;
if !repo_dir.is_dir() {
+ eprintln!("Cloning repository {}", example.repository_url);
fs::create_dir_all(&repo_dir).unwrap();
run_git(&repo_dir, &["init"]).await.unwrap();
run_git(
@@ -227,7 +253,6 @@ pub async fn setup_worktree(example: &Example) {
};
// Create the worktree for this example if needed.
- let worktree_path = example.worktree_path();
if worktree_path.is_dir() {
run_git(&worktree_path, &["clean", "--force", "-d"])
.await
@@ -288,6 +313,8 @@ pub async fn setup_worktree(example: &Example) {
);
}
}
+
+ worktree_path
}
async fn apply_edit_history(
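Worth noting: worktrees are now keyed by owner/repo instead of by example name, so every example for a repository shares one checkout; that shared path is what makes the project cache above sound. Roughly:

    // Before: one checkout per example        After: one checkout per repository
    //   WORKTREES_DIR/<example.name>/<repo>     WORKTREES_DIR/<owner>/<repo>
    let worktree_path = WORKTREES_DIR
        .join(repo_owner.as_ref())
        .join(repo_name.as_ref());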
@@ -15,10 +15,12 @@ use edit_prediction::EditPredictionStore;
use gpui::Application;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
+use std::sync::atomic::AtomicUsize;
+use std::sync::atomic::Ordering::SeqCst;
use std::{path::PathBuf, sync::Arc};
use crate::distill::run_distill;
-use crate::example::{read_examples, write_examples};
+use crate::example::{group_examples_by_repo, read_examples, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::predict::run_prediction;
@@ -145,31 +147,40 @@ fn main() {
EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
cx.spawn(async move |cx| {
- match &command {
- Command::Predict(args) => predict::sync_batches(&args.provider).await,
- _ => (),
+ if let Command::Predict(args) = &command {
+ predict::sync_batches(&args.provider).await
};
- let chunks = examples.chunks_mut(args.max_parallelism);
- let total_chunks = chunks.len();
- for (batch_ix, data) in chunks.enumerate() {
- let mut futures = Vec::new();
- eprintln!("Processing batch: {}/{}", batch_ix + 1, total_chunks);
-
- for example in data.iter_mut() {
- let cx = cx.clone();
- let app_state = app_state.clone();
- futures.push(async {
+ let example_count = examples.len();
+ let example_ix = AtomicUsize::new(0);
+ let mut grouped_examples = group_examples_by_repo(&mut examples);
+
+ let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
+ for example_batch in example_batches {
+            let futures = example_batch.iter_mut().map(|repo_examples| async {
+ for example in repo_examples.iter_mut() {
+                    // Single fetch_add so concurrent repo groups never print
+                    // the same index twice.
+                    let ix = example_ix.fetch_add(1, SeqCst);
+                    eprintln!("Processing example: {}/{}", ix + 1, example_count);
match &command {
Command::ParseExample => {}
Command::LoadProject => {
- run_load_project(example, app_state.clone(), cx).await;
+ run_load_project(example, app_state.clone(), cx.clone()).await;
}
Command::Context => {
- run_context_retrieval(example, app_state, cx).await;
+ run_context_retrieval(example, app_state.clone(), cx.clone()).await;
}
Command::FormatPrompt(args) => {
- run_format_prompt(example, args.prompt_format, app_state, cx).await;
+ run_format_prompt(
+ example,
+ args.prompt_format,
+ app_state.clone(),
+ cx.clone(),
+ )
+ .await;
}
Command::Predict(args) => {
run_prediction(
@@ -177,7 +188,7 @@ fn main() {
Some(args.provider),
args.repetitions,
app_state.clone(),
- cx,
+ cx.clone(),
)
.await;
}
@@ -185,14 +196,14 @@ fn main() {
run_distill(example).await;
}
Command::Score(args) | Command::Eval(args) => {
- run_scoring(example, &args, app_state, cx).await;
+ run_scoring(example, &args, app_state.clone(), cx.clone()).await;
}
Command::Clean => {
unreachable!()
}
}
- });
- }
+ }
+ });
futures::future::join_all(futures).await;
}
@@ -75,25 +75,24 @@ async fn wait_for_language_servers_to_start(
.read_with(cx, |project, _| project.lsp_store())
.unwrap();
- let lang_server_ids = buffer
+ let (language_server_ids, mut starting_language_server_ids) = buffer
.update(cx, |buffer, cx| {
lsp_store.update(cx, |lsp_store, cx| {
- lsp_store.language_servers_for_local_buffer(buffer, cx)
+ let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
+ let starting_ids = ids
+ .iter()
+ .copied()
+ .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
+ .collect::<HashSet<_>>();
+ (ids, starting_ids)
})
})
.unwrap_or_default();
- if !lang_server_ids.is_empty() {
- project
- .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
- .unwrap()
- .detach();
- }
-
eprintln!(
"{}⏵ Waiting for {} language servers",
log_prefix,
- lang_server_ids.len()
+ language_server_ids.len()
);
let timeout = cx
@@ -101,7 +100,7 @@ async fn wait_for_language_servers_to_start(
.timer(Duration::from_secs(60 * 5))
.shared();
- let (mut tx, mut rx) = mpsc::channel(lang_server_ids.len());
+ let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
let added_subscription = cx.subscribe(project, {
let log_prefix = log_prefix.clone();
move |_, event, _| match event {
@@ -113,12 +112,11 @@ async fn wait_for_language_servers_to_start(
}
});
- let mut pending_language_server_ids = HashSet::from_iter(lang_server_ids.iter());
- while !pending_language_server_ids.is_empty() {
+ while !starting_language_server_ids.is_empty() {
futures::select! {
language_server_id = rx.next() => {
if let Some(id) = language_server_id {
- pending_language_server_ids.remove(&id);
+ starting_language_server_ids.remove(&id);
}
},
_ = timeout.clone().fuse() => {
@@ -129,7 +127,14 @@ async fn wait_for_language_servers_to_start(
drop(added_subscription);
- let (mut tx, mut rx) = mpsc::channel(lang_server_ids.len());
+ if !language_server_ids.is_empty() {
+ project
+ .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+ .unwrap()
+ .detach();
+ }
+
+ let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
let subscriptions = [
cx.subscribe(&lsp_store, {
let log_prefix = log_prefix.clone();
@@ -172,7 +177,7 @@ async fn wait_for_language_servers_to_start(
.await
.unwrap();
- let mut pending_language_server_ids = HashSet::from_iter(lang_server_ids.into_iter());
+ let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
while !pending_language_server_ids.is_empty() {
futures::select! {
language_server_id = rx.next() => {
@@ -201,7 +201,10 @@ pub enum LspFormatTarget {
Ranges(BTreeMap<BufferId, Vec<Range<Anchor>>>),
}
-pub type OpenLspBufferHandle = Entity<Entity<Buffer>>;
+#[derive(Clone, PartialEq, Eq, Hash)]
+pub struct OpenLspBufferHandle(Entity<OpenLspBuffer>);
+
+struct OpenLspBuffer(Entity<Buffer>);
impl FormatTrigger {
fn from_proto(value: i32) -> FormatTrigger {
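Making OpenLspBufferHandle a newtype over a private OpenLspBuffer wrapper hides the Entity<Entity<Buffer>> plumbing from callers while keeping the release-observation refcounting below intact: dropping the last clone of the handle fires observe_release on handle.0 and decrements the buffer's registration count. Hypothetical call-site shape (the registering method's name is elided from the hunk below):

    // Holding the handle keeps the buffer registered with language servers.
    let handle: OpenLspBufferHandle = /* LspStore registration call */;
    // ...buffer participates in LSP while the handle is alive...
    drop(handle); // observe_release fires; refcount drops, possibly unregistering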
@@ -4208,7 +4211,7 @@ impl LspStore {
cx: &mut Context<Self>,
) -> OpenLspBufferHandle {
let buffer_id = buffer.read(cx).remote_id();
- let handle = cx.new(|_| buffer.clone());
+ let handle = OpenLspBufferHandle(cx.new(|_| OpenLspBuffer(buffer.clone())));
if let Some(local) = self.as_local_mut() {
let refcount = local.registered_buffers.entry(buffer_id).or_insert(0);
if !ignore_refcounts {
@@ -4230,7 +4233,7 @@ impl LspStore {
local.register_buffer_with_language_servers(buffer, only_register_servers, cx);
}
if !ignore_refcounts {
- cx.observe_release(&handle, move |lsp_store, buffer, cx| {
+ cx.observe_release(&handle.0, move |lsp_store, buffer, cx| {
let refcount = {
let local = lsp_store.as_local_mut().unwrap();
let Some(refcount) = local.registered_buffers.get_mut(&buffer_id) else {
@@ -4247,8 +4250,8 @@ impl LspStore {
local.registered_buffers.remove(&buffer_id);
local.buffers_opened_in_servers.remove(&buffer_id);
- if let Some(file) = File::from_dyn(buffer.read(cx).file()).cloned() {
- local.unregister_old_buffer_from_language_servers(buffer, &file, cx);
+ if let Some(file) = File::from_dyn(buffer.0.read(cx).file()).cloned() {
+ local.unregister_old_buffer_from_language_servers(&buffer.0, &file, cx);
let buffer_abs_path = file.abs_path(cx);
for (_, buffer_pull_diagnostics_result_ids) in