eval.rs

  1use ::fs::{Fs, RealFs};
  2use anyhow::Result;
  3use clap::Parser;
  4use client::{Client, UserStore};
  5use clock::RealSystemClock;
  6use collections::BTreeMap;
  7use feature_flags::FeatureFlagAppExt as _;
  8use git::GitHostingProviderRegistry;
  9use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
 10use http_client::{HttpClient, Method};
 11use language::LanguageRegistry;
 12use node_runtime::NodeRuntime;
 13use open_ai::OpenAiEmbeddingModel;
 14use project::Project;
 15use semantic_index::{
 16    EmbeddingProvider, OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status,
 17};
 18use serde::{Deserialize, Serialize};
 19use settings::SettingsStore;
 20use smol::channel::bounded;
 21use smol::io::AsyncReadExt;
 22use smol::Timer;
 23use std::ops::RangeInclusive;
 24use std::path::PathBuf;
 25use std::time::Duration;
 26use std::{
 27    fs,
 28    path::Path,
 29    process::{exit, Command, Stdio},
 30    sync::{
 31        atomic::{AtomicUsize, Ordering::SeqCst},
 32        Arc,
 33    },
 34};
 35
 36const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
 37const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
 38const EVAL_DB_PATH: &'static str = "target/eval_db";
 39const SEARCH_RESULT_LIMIT: usize = 8;
 40const SKIP_EVAL_PATH: &'static str = ".skip_eval";
 41
 42#[derive(clap::Parser)]
 43#[command(author, version, about, long_about = None)]
 44struct Cli {
 45    #[command(subcommand)]
 46    command: Commands,
 47}
 48
 49#[derive(clap::Subcommand)]
 50enum Commands {
 51    Fetch {},
 52    Run {
 53        #[arg(long)]
 54        repo: Option<String>,
 55    },
 56}
 57
 58#[derive(Clone, Deserialize, Serialize)]
 59struct EvaluationProject {
 60    repo: String,
 61    sha: String,
 62    queries: Vec<EvaluationQuery>,
 63}
 64
 65#[derive(Clone, Debug, Deserialize, Serialize)]
 66struct EvaluationQuery {
 67    query: String,
 68    expected_results: Vec<EvaluationSearchResult>,
 69}
 70
 71#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
 72struct EvaluationSearchResult {
 73    file: String,
 74    lines: RangeInclusive<u32>,
 75}
 76
 77#[derive(Clone, Deserialize, Serialize)]
 78struct EvaluationProjectOutcome {
 79    repo: String,
 80    sha: String,
 81    queries: Vec<EvaluationQueryOutcome>,
 82}
 83
 84#[derive(Clone, Debug, Deserialize, Serialize)]
 85struct EvaluationQueryOutcome {
 86    repo: String,
 87    query: String,
 88    expected_results: Vec<EvaluationSearchResult>,
 89    actual_results: Vec<EvaluationSearchResult>,
 90    covered_file_count: usize,
 91    overlapped_result_count: usize,
 92    covered_result_count: usize,
 93    total_result_count: usize,
 94    covered_result_indices: Vec<usize>,
 95}
 96
 97fn main() -> Result<()> {
 98    let cli = Cli::parse();
 99    env_logger::init();
100
101    gpui::App::headless().run(move |cx| {
102        let executor = cx.background_executor().clone();
103        let client = isahc_http_client::IsahcHttpClient::new(None, None);
104        cx.set_http_client(client.clone());
105        match cli.command {
106            Commands::Fetch {} => {
107                executor
108                    .clone()
109                    .spawn(async move {
110                        if let Err(err) = fetch_evaluation_resources(client, &executor).await {
111                            eprintln!("Error: {}", err);
112                            exit(1);
113                        }
114                        exit(0);
115                    })
116                    .detach();
117            }
118            Commands::Run { repo } => {
119                cx.spawn(|mut cx| async move {
120                    if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
121                        eprintln!("Error: {}", err);
122                        exit(1);
123                    }
124                    exit(0);
125                })
126                .detach();
127            }
128        }
129    });
130
131    Ok(())
132}
133
134async fn fetch_evaluation_resources(
135    http_client: Arc<dyn HttpClient>,
136    executor: &BackgroundExecutor,
137) -> Result<()> {
138    fetch_code_search_net_resources(&*http_client).await?;
139    fetch_eval_repos(executor, &*http_client).await?;
140    Ok(())
141}
142
143async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
144    eprintln!("Fetching CodeSearchNet evaluations...");
145
146    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
147
148    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
149    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
150
151    // Fetch the annotations CSV, which contains the human-annotated search relevances
152    let annotations_path = dataset_dir.join("annotations.csv");
153    let annotations_csv_content = if annotations_path.exists() {
154        fs::read_to_string(&annotations_path).expect("failed to read annotations")
155    } else {
156        let response = http_client
157            .get(annotations_url, Default::default(), true)
158            .await
159            .expect("failed to fetch annotations csv");
160        let mut body = String::new();
161        response
162            .into_body()
163            .read_to_string(&mut body)
164            .await
165            .expect("failed to read annotations.csv response");
166        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
167        body
168    };
169
170    // Parse the annotations CSV. Skip over queries with zero relevance.
171    let rows = annotations_csv_content.lines().filter_map(|line| {
172        let mut values = line.split(',');
173        let _language = values.next()?;
174        let query = values.next()?;
175        let github_url = values.next()?;
176        let score = values.next()?;
177
178        if score == "0" {
179            return None;
180        }
181
182        let url_path = github_url.strip_prefix("https://github.com/")?;
183        let (url_path, hash) = url_path.split_once('#')?;
184        let (repo_name, url_path) = url_path.split_once("/blob/")?;
185        let (sha, file_path) = url_path.split_once('/')?;
186        let line_range = if let Some((start, end)) = hash.split_once('-') {
187            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
188        } else {
189            let row = hash.strip_prefix("L")?.parse().ok()?;
190            row..=row
191        };
192        Some((repo_name, sha, query, file_path, line_range))
193    });
194
195    // Group the annotations by repo and sha.
196    let mut evaluations_by_repo = BTreeMap::new();
197    for (repo_name, sha, query, file_path, lines) in rows {
198        let evaluation_project = evaluations_by_repo
199            .entry((repo_name, sha))
200            .or_insert_with(|| EvaluationProject {
201                repo: repo_name.to_string(),
202                sha: sha.to_string(),
203                queries: Vec::new(),
204            });
205
206        let ix = evaluation_project
207            .queries
208            .iter()
209            .position(|entry| entry.query == query)
210            .unwrap_or_else(|| {
211                evaluation_project.queries.push(EvaluationQuery {
212                    query: query.to_string(),
213                    expected_results: Vec::new(),
214                });
215                evaluation_project.queries.len() - 1
216            });
217        let results = &mut evaluation_project.queries[ix].expected_results;
218        let result = EvaluationSearchResult {
219            file: file_path.to_string(),
220            lines,
221        };
222        if !results.contains(&result) {
223            results.push(result);
224        }
225    }
226
227    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
228    let evaluations_path = dataset_dir.join("evaluations.json");
229    fs::write(
230        &evaluations_path,
231        serde_json::to_vec_pretty(&evaluations).unwrap(),
232    )
233    .unwrap();
234
235    eprintln!(
236        "Fetched CodeSearchNet evaluations into {}",
237        evaluations_path.display()
238    );
239
240    Ok(())
241}
242
/// Running totals accumulated across every evaluated query.
#[derive(Debug, Default)]
struct Counts {
    /// Expected results whose line range was fully covered by a returned result.
    covered_results: usize,
    /// Expected results whose line range was at least partially matched by a returned result.
    overlapped_results: usize,
    /// Expected results whose file appeared among the returned results.
    covered_files: usize,
    /// Total number of expected results seen so far.
    total_results: usize,
}
250
251async fn run_evaluation(
252    only_repo: Option<String>,
253    executor: &BackgroundExecutor,
254    cx: &mut AsyncAppContext,
255) -> Result<()> {
256    let mut http_client = None;
257    cx.update(|cx| {
258        let mut store = SettingsStore::new(cx);
259        store
260            .set_default_settings(settings::default_settings().as_ref(), cx)
261            .unwrap();
262        cx.set_global(store);
263        client::init_settings(cx);
264        language::init(cx);
265        Project::init_settings(cx);
266        http_client = Some(cx.http_client());
267        cx.update_flags(false, vec![]);
268    })
269    .unwrap();
270    let http_client = http_client.unwrap();
271    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
272    let evaluations_path = dataset_dir.join("evaluations.json");
273    let repos_dir = Path::new(EVAL_REPOS_DIR);
274    let db_path = Path::new(EVAL_DB_PATH);
275    let api_key = std::env::var("OPENAI_API_KEY").unwrap();
276    let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
277    let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
278    let clock = Arc::new(RealSystemClock);
279    let client = cx
280        .update(|cx| {
281            Client::new(
282                clock,
283                Arc::new(http_client::HttpClientWithUrl::new(
284                    http_client.clone(),
285                    "https://zed.dev",
286                    None,
287                )),
288                cx,
289            )
290        })
291        .unwrap();
292    let user_store = cx
293        .new_model(|cx| UserStore::new(client.clone(), cx))
294        .unwrap();
295    let node_runtime = NodeRuntime::unavailable();
296
297    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
298    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
299
300    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
301        http_client.clone(),
302        OpenAiEmbeddingModel::TextEmbedding3Small,
303        open_ai::OPEN_AI_API_URL.to_string(),
304        api_key,
305    ));
306
307    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
308    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
309        .unwrap();
310
311    let mut counts = Counts::default();
312    eprint!("Running evals.");
313
314    let mut failures = Vec::new();
315
316    for evaluation_project in evaluations {
317        if only_repo
318            .as_ref()
319            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
320        {
321            continue;
322        }
323
324        eprint!("\r\x1B[2K");
325        eprint!(
326            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
327            counts.covered_results,
328            counts.total_results,
329            counts.overlapped_results,
330            counts.total_results,
331            counts.covered_files,
332            counts.total_results,
333            evaluation_project.repo
334        );
335
336        let repo_dir = repos_dir.join(&evaluation_project.repo);
337        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
338            eprintln!("Skipping {}: directory not found", evaluation_project.repo);
339            continue;
340        }
341
342        let repo_db_path =
343            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
344
345        let project = cx
346            .update(|cx| {
347                Project::local(
348                    client.clone(),
349                    node_runtime.clone(),
350                    user_store.clone(),
351                    language_registry.clone(),
352                    fs.clone(),
353                    None,
354                    cx,
355                )
356            })
357            .unwrap();
358
359        let repo = evaluation_project.repo.clone();
360        if let Err(err) = run_eval_project(
361            evaluation_project,
362            &user_store,
363            repo_db_path,
364            &repo_dir,
365            &mut counts,
366            project,
367            embedding_provider.clone(),
368            fs.clone(),
369            cx,
370        )
371        .await
372        {
373            eprintln!("{repo} eval failed with error: {:?}", err);
374
375            failures.push((repo, err));
376        }
377    }
378
379    eprintln!(
380        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. {} failed.",
381        counts.covered_results,
382        counts.total_results,
383        counts.overlapped_results,
384        counts.total_results,
385        counts.covered_files,
386        counts.total_results,
387        failures.len(),
388    );
389
390    if failures.is_empty() {
391        Ok(())
392    } else {
393        eprintln!("Failures:\n");
394
395        for (index, (repo, failure)) in failures.iter().enumerate() {
396            eprintln!("Failure #{} - {repo}\n{:?}", index + 1, failure);
397        }
398
399        Err(anyhow::anyhow!("Some evals failed."))
400    }
401}
402
403#[allow(clippy::too_many_arguments)]
404async fn run_eval_project(
405    evaluation_project: EvaluationProject,
406    user_store: &Model<UserStore>,
407    repo_db_path: PathBuf,
408    repo_dir: &Path,
409    counts: &mut Counts,
410    project: Model<Project>,
411    embedding_provider: Arc<dyn EmbeddingProvider>,
412    fs: Arc<dyn Fs>,
413    cx: &mut AsyncAppContext,
414) -> Result<(), anyhow::Error> {
415    let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider, cx).await?;
416
417    let (worktree, _) = project
418        .update(cx, |project, cx| {
419            project.find_or_create_worktree(repo_dir, true, cx)
420        })?
421        .await?;
422
423    worktree
424        .update(cx, |worktree, _| {
425            worktree.as_local().unwrap().scan_complete()
426        })?
427        .await;
428
429    let project_index = cx.update(|cx| semantic_index.create_project_index(project.clone(), cx))?;
430    wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;
431
432    for query in evaluation_project.queries {
433        let results = {
434            // Retry search up to 3 times in case of timeout, network failure, etc.
435            let mut retries_remaining = 3;
436            let mut result;
437
438            loop {
439                match cx.update(|cx| {
440                    let project_index = project_index.read(cx);
441                    project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
442                }) {
443                    Ok(task) => match task.await {
444                        Ok(answer) => {
445                            result = Ok(answer);
446                            break;
447                        }
448                        Err(err) => {
449                            result = Err(err);
450                        }
451                    },
452                    Err(err) => {
453                        result = Err(err);
454                    }
455                }
456
457                if retries_remaining > 0 {
458                    eprintln!(
459                        "Retrying search after it failed on query {:?} with {:?}",
460                        query, result
461                    );
462                    retries_remaining -= 1;
463                } else {
464                    eprintln!(
465                        "Ran out of retries; giving up on search which failed on query {:?} with {:?}",
466                        query, result
467                    );
468                    break;
469                }
470            }
471
472            SemanticDb::load_results(result?, &fs.clone(), &cx).await?
473        };
474
475        let mut project_covered_result_count = 0;
476        let mut project_overlapped_result_count = 0;
477        let mut project_covered_file_count = 0;
478        let mut covered_result_indices = Vec::new();
479        for expected_result in &query.expected_results {
480            let mut file_matched = false;
481            let mut range_overlapped = false;
482            let mut range_covered = false;
483
484            for (ix, result) in results.iter().enumerate() {
485                if result.path.as_ref() == Path::new(&expected_result.file) {
486                    file_matched = true;
487                    let start_matched = result.row_range.contains(&expected_result.lines.start());
488                    let end_matched = result.row_range.contains(&expected_result.lines.end());
489
490                    if start_matched || end_matched {
491                        range_overlapped = true;
492                    }
493
494                    if start_matched && end_matched {
495                        range_covered = true;
496                        covered_result_indices.push(ix);
497                        break;
498                    }
499                }
500            }
501
502            if range_covered {
503                project_covered_result_count += 1
504            };
505            if range_overlapped {
506                project_overlapped_result_count += 1
507            };
508            if file_matched {
509                project_covered_file_count += 1
510            };
511        }
512        let outcome_repo = evaluation_project.repo.clone();
513
514        let query_results = EvaluationQueryOutcome {
515            repo: outcome_repo,
516            query: query.query,
517            total_result_count: query.expected_results.len(),
518            covered_result_count: project_covered_result_count,
519            overlapped_result_count: project_overlapped_result_count,
520            covered_file_count: project_covered_file_count,
521            expected_results: query.expected_results,
522            actual_results: results
523                .iter()
524                .map(|result| EvaluationSearchResult {
525                    file: result.path.to_string_lossy().to_string(),
526                    lines: result.row_range.clone(),
527                })
528                .collect(),
529            covered_result_indices,
530        };
531
532        counts.overlapped_results += query_results.overlapped_result_count;
533        counts.covered_results += query_results.covered_result_count;
534        counts.covered_files += query_results.covered_file_count;
535        counts.total_results += query_results.total_result_count;
536
537        println!("{}", serde_json::to_string(&query_results)?);
538    }
539
540    user_store.update(cx, |_, _| {
541        drop(semantic_index);
542        drop(project);
543        drop(worktree);
544        drop(project_index);
545    })
546}
547
548async fn wait_for_indexing_complete(
549    project_index: &Model<ProjectIndex>,
550    cx: &mut AsyncAppContext,
551    timeout: Option<Duration>,
552) {
553    let (tx, rx) = bounded(1);
554    let subscription = cx.update(|cx| {
555        cx.subscribe(project_index, move |_, event, _| {
556            if let Status::Idle = event {
557                let _ = tx.try_send(*event);
558            }
559        })
560    });
561
562    let result = match timeout {
563        Some(timeout_duration) => {
564            smol::future::or(
565                async {
566                    rx.recv().await.map_err(|_| ())?;
567                    Ok(())
568                },
569                async {
570                    Timer::after(timeout_duration).await;
571                    Err(())
572                },
573            )
574            .await
575        }
576        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
577    };
578
579    match result {
580        Ok(_) => (),
581        Err(_) => {
582            if let Some(timeout) = timeout {
583                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
584            }
585        }
586    }
587
588    drop(subscription);
589}
590
591async fn fetch_eval_repos(
592    executor: &BackgroundExecutor,
593    http_client: &dyn HttpClient,
594) -> Result<()> {
595    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
596    let evaluations_path = dataset_dir.join("evaluations.json");
597    let repos_dir = Path::new(EVAL_REPOS_DIR);
598
599    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
600    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
601
602    eprintln!("Fetching evaluation repositories...");
603
604    executor
605        .scoped(move |scope| {
606            let done_count = Arc::new(AtomicUsize::new(0));
607            let len = evaluations.len();
608            for chunk in evaluations.chunks(evaluations.len() / 8) {
609                let chunk = chunk.to_vec();
610                let done_count = done_count.clone();
611                scope.spawn(async move {
612                    for EvaluationProject { repo, sha, .. } in chunk {
613                        eprint!(
614                            "\rFetching evaluation repositories ({}/{})...",
615                            done_count.load(SeqCst),
616                            len,
617                        );
618
619                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
620                        done_count.fetch_add(1, SeqCst);
621                    }
622                });
623            }
624        })
625        .await;
626
627    Ok(())
628}
629
630async fn fetch_eval_repo(
631    repo: String,
632    sha: String,
633    repos_dir: &Path,
634    http_client: &dyn HttpClient,
635) {
636    let Some((owner, repo_name)) = repo.split_once('/') else {
637        return;
638    };
639    let repo_dir = repos_dir.join(owner).join(repo_name);
640    fs::create_dir_all(&repo_dir).unwrap();
641    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
642    if skip_eval_path.exists() {
643        return;
644    }
645    if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
646        if head_content.trim() == sha {
647            return;
648        }
649    }
650    let repo_response = http_client
651        .send(
652            http_client::Request::builder()
653                .method(Method::HEAD)
654                .uri(format!("https://github.com/{}", repo))
655                .body(Default::default())
656                .expect(""),
657        )
658        .await
659        .expect("failed to check github repo");
660    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
661        fs::write(&skip_eval_path, "").unwrap();
662        eprintln!(
663            "Repo {repo} is no longer public ({:?}). Skipping",
664            repo_response.status()
665        );
666        return;
667    }
668    if !repo_dir.join(".git").exists() {
669        let init_output = Command::new("git")
670            .current_dir(&repo_dir)
671            .args(&["init"])
672            .output()
673            .unwrap();
674        if !init_output.status.success() {
675            eprintln!(
676                "Failed to initialize git repository for {}: {}",
677                repo,
678                String::from_utf8_lossy(&init_output.stderr)
679            );
680            return;
681        }
682    }
683    let url = format!("https://github.com/{}.git", repo);
684    Command::new("git")
685        .current_dir(&repo_dir)
686        .args(&["remote", "add", "-f", "origin", &url])
687        .stdin(Stdio::null())
688        .output()
689        .unwrap();
690    let fetch_output = Command::new("git")
691        .current_dir(&repo_dir)
692        .args(&["fetch", "--depth", "1", "origin", &sha])
693        .stdin(Stdio::null())
694        .output()
695        .unwrap();
696    if !fetch_output.status.success() {
697        eprintln!(
698            "Failed to fetch {} for {}: {}",
699            sha,
700            repo,
701            String::from_utf8_lossy(&fetch_output.stderr)
702        );
703        return;
704    }
705    let checkout_output = Command::new("git")
706        .current_dir(&repo_dir)
707        .args(&["checkout", &sha])
708        .output()
709        .unwrap();
710
711    if !checkout_output.status.success() {
712        eprintln!(
713            "Failed to checkout {} for {}: {}",
714            sha,
715            repo,
716            String::from_utf8_lossy(&checkout_output.stderr)
717        );
718    }
719}