// eval.rs — evaluation binary for the semantic-index crate.

  1use ::fs::{Fs, RealFs};
  2use anyhow::Result;
  3use clap::Parser;
  4use client::{Client, UserStore};
  5use clock::RealSystemClock;
  6use collections::BTreeMap;
  7use feature_flags::FeatureFlagAppExt as _;
  8use git::GitHostingProviderRegistry;
  9use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
 10use http_client::{HttpClient, Method};
 11use language::LanguageRegistry;
 12use node_runtime::NodeRuntime;
 13use open_ai::OpenAiEmbeddingModel;
 14use project::Project;
 15use reqwest_client::ReqwestClient;
 16use semantic_index::{
 17    EmbeddingProvider, OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status,
 18};
 19use serde::{Deserialize, Serialize};
 20use settings::SettingsStore;
 21use smol::channel::bounded;
 22use smol::io::AsyncReadExt;
 23use smol::Timer;
 24use std::ops::RangeInclusive;
 25use std::path::PathBuf;
 26use std::time::Duration;
 27use std::{
 28    fs,
 29    path::Path,
 30    process::{exit, Stdio},
 31    sync::{
 32        atomic::{AtomicUsize, Ordering::SeqCst},
 33        Arc,
 34    },
 35};
 36
 37const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
 38const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
 39const EVAL_DB_PATH: &'static str = "target/eval_db";
 40const SEARCH_RESULT_LIMIT: usize = 8;
 41const SKIP_EVAL_PATH: &'static str = ".skip_eval";
 42
/// Command-line interface for the evaluation binary.
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    /// Which evaluation action to perform (see [`Commands`]).
    #[command(subcommand)]
    command: Commands,
}
 49
/// Subcommands supported by the eval binary.
#[derive(clap::Subcommand)]
enum Commands {
    /// Download the CodeSearchNet annotations and clone the evaluation repos.
    Fetch {},
    /// Run the evaluations against the previously fetched repositories.
    Run {
        /// Restrict the run to a single repository (`owner/name`).
        #[arg(long)]
        repo: Option<String>,
    },
}
 58
/// One repository's worth of evaluation data, as stored in `evaluations.json`.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    /// Repository identifier in `owner/name` form.
    repo: String,
    /// Commit SHA the annotations refer to.
    sha: String,
    /// The queries to evaluate against this repository.
    queries: Vec<EvaluationQuery>,
}
 65
/// A single search query together with its human-annotated relevant locations.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    /// Natural-language search query text.
    query: String,
    /// Locations annotated as relevant for this query.
    expected_results: Vec<EvaluationSearchResult>,
}
 71
/// A file location (path + inclusive line range) used both for expected and
/// actual search results.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    /// Path of the file, relative to the repository root.
    file: String,
    /// Inclusive 1-based line range within the file.
    lines: RangeInclusive<u32>,
}
 77
/// Per-repository outcome record mirroring [`EvaluationProject`].
// NOTE(review): this type is not referenced anywhere in this file — it may be
// consumed by external tooling that parses the JSON output; confirm before
// removing.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQueryOutcome>,
}
 84
/// Outcome of evaluating a single query, serialized as one JSON line on stdout.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    /// Repository the query was evaluated against (`owner/name`).
    repo: String,
    /// The query text.
    query: String,
    /// The annotated expected locations.
    expected_results: Vec<EvaluationSearchResult>,
    /// The locations the semantic index actually returned.
    actual_results: Vec<EvaluationSearchResult>,
    /// Expected results whose file appeared among the actual results.
    covered_file_count: usize,
    /// Expected results whose range partially overlapped an actual result.
    overlapped_result_count: usize,
    /// Expected results whose range was fully contained in an actual result.
    covered_result_count: usize,
    /// Total number of expected results for this query.
    total_result_count: usize,
    /// Indices into `actual_results` that fully covered an expected result.
    covered_result_indices: Vec<usize>,
}
 97
/// Entry point: parses the CLI and dispatches to the fetch or run flow inside
/// a headless GPUI application. Both flows terminate the process via `exit`,
/// so the trailing `Ok(())` is only reached if the app loop returns.
fn main() -> Result<()> {
    let cli = Cli::parse();
    env_logger::init();

    gpui::App::headless().run(move |cx| {
        let executor = cx.background_executor().clone();
        let client = Arc::new(ReqwestClient::user_agent("Zed LLM evals").unwrap());
        cx.set_http_client(client.clone());
        match cli.command {
            Commands::Fetch {} => {
                // Fetching only needs the background executor, no app context.
                executor
                    .clone()
                    .spawn(async move {
                        if let Err(err) = fetch_evaluation_resources(client, &executor).await {
                            eprintln!("Error: {}", err);
                            exit(1);
                        }
                        exit(0);
                    })
                    .detach();
            }
            Commands::Run { repo } => {
                // Running evals needs an async app context (settings, models).
                cx.spawn(|mut cx| async move {
                    if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
                        eprintln!("Error: {}", err);
                        exit(1);
                    }
                    exit(0);
                })
                .detach();
            }
        }
    });

    Ok(())
}
134
135async fn fetch_evaluation_resources(
136    http_client: Arc<dyn HttpClient>,
137    executor: &BackgroundExecutor,
138) -> Result<()> {
139    fetch_code_search_net_resources(&*http_client).await?;
140    fetch_eval_repos(executor, &*http_client).await?;
141    Ok(())
142}
143
144async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
145    eprintln!("Fetching CodeSearchNet evaluations...");
146
147    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
148
149    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
150    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
151
152    // Fetch the annotations CSV, which contains the human-annotated search relevances
153    let annotations_path = dataset_dir.join("annotations.csv");
154    let annotations_csv_content = if annotations_path.exists() {
155        fs::read_to_string(&annotations_path).expect("failed to read annotations")
156    } else {
157        let response = http_client
158            .get(annotations_url, Default::default(), true)
159            .await
160            .expect("failed to fetch annotations csv");
161        let mut body = String::new();
162        response
163            .into_body()
164            .read_to_string(&mut body)
165            .await
166            .expect("failed to read annotations.csv response");
167        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
168        body
169    };
170
171    // Parse the annotations CSV. Skip over queries with zero relevance.
172    let rows = annotations_csv_content.lines().filter_map(|line| {
173        let mut values = line.split(',');
174        let _language = values.next()?;
175        let query = values.next()?;
176        let github_url = values.next()?;
177        let score = values.next()?;
178
179        if score == "0" {
180            return None;
181        }
182
183        let url_path = github_url.strip_prefix("https://github.com/")?;
184        let (url_path, hash) = url_path.split_once('#')?;
185        let (repo_name, url_path) = url_path.split_once("/blob/")?;
186        let (sha, file_path) = url_path.split_once('/')?;
187        let line_range = if let Some((start, end)) = hash.split_once('-') {
188            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
189        } else {
190            let row = hash.strip_prefix("L")?.parse().ok()?;
191            row..=row
192        };
193        Some((repo_name, sha, query, file_path, line_range))
194    });
195
196    // Group the annotations by repo and sha.
197    let mut evaluations_by_repo = BTreeMap::new();
198    for (repo_name, sha, query, file_path, lines) in rows {
199        let evaluation_project = evaluations_by_repo
200            .entry((repo_name, sha))
201            .or_insert_with(|| EvaluationProject {
202                repo: repo_name.to_string(),
203                sha: sha.to_string(),
204                queries: Vec::new(),
205            });
206
207        let ix = evaluation_project
208            .queries
209            .iter()
210            .position(|entry| entry.query == query)
211            .unwrap_or_else(|| {
212                evaluation_project.queries.push(EvaluationQuery {
213                    query: query.to_string(),
214                    expected_results: Vec::new(),
215                });
216                evaluation_project.queries.len() - 1
217            });
218        let results = &mut evaluation_project.queries[ix].expected_results;
219        let result = EvaluationSearchResult {
220            file: file_path.to_string(),
221            lines,
222        };
223        if !results.contains(&result) {
224            results.push(result);
225        }
226    }
227
228    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
229    let evaluations_path = dataset_dir.join("evaluations.json");
230    fs::write(
231        &evaluations_path,
232        serde_json::to_vec_pretty(&evaluations).unwrap(),
233    )
234    .unwrap();
235
236    eprintln!(
237        "Fetched CodeSearchNet evaluations into {}",
238        evaluations_path.display()
239    );
240
241    Ok(())
242}
243
/// Aggregate coverage tallies accumulated across all evaluated queries.
#[derive(Default, Debug)]
struct Counts {
    /// Expected results whose line range was fully contained in a search hit.
    covered_results: usize,
    /// Expected results whose range at least partially overlapped a hit.
    overlapped_results: usize,
    /// Expected results whose file appeared among the search hits.
    covered_files: usize,
    /// Total number of expected results processed.
    total_results: usize,
}
251
252async fn run_evaluation(
253    only_repo: Option<String>,
254    executor: &BackgroundExecutor,
255    cx: &mut AsyncAppContext,
256) -> Result<()> {
257    let mut http_client = None;
258    cx.update(|cx| {
259        let mut store = SettingsStore::new(cx);
260        store
261            .set_default_settings(settings::default_settings().as_ref(), cx)
262            .unwrap();
263        cx.set_global(store);
264        client::init_settings(cx);
265        language::init(cx);
266        Project::init_settings(cx);
267        http_client = Some(cx.http_client());
268        cx.update_flags(false, vec![]);
269    })
270    .unwrap();
271    let http_client = http_client.unwrap();
272    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
273    let evaluations_path = dataset_dir.join("evaluations.json");
274    let repos_dir = Path::new(EVAL_REPOS_DIR);
275    let db_path = Path::new(EVAL_DB_PATH);
276    let api_key = std::env::var("OPENAI_API_KEY").unwrap();
277    let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
278    let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
279    let clock = Arc::new(RealSystemClock);
280    let client = cx
281        .update(|cx| {
282            Client::new(
283                clock,
284                Arc::new(http_client::HttpClientWithUrl::new(
285                    http_client.clone(),
286                    "https://zed.dev",
287                    None,
288                )),
289                cx,
290            )
291        })
292        .unwrap();
293    let user_store = cx
294        .new_model(|cx| UserStore::new(client.clone(), cx))
295        .unwrap();
296    let node_runtime = NodeRuntime::unavailable();
297
298    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
299    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
300
301    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
302        http_client.clone(),
303        OpenAiEmbeddingModel::TextEmbedding3Small,
304        open_ai::OPEN_AI_API_URL.to_string(),
305        api_key,
306    ));
307
308    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
309    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
310        .unwrap();
311
312    let mut counts = Counts::default();
313    eprint!("Running evals.");
314
315    let mut failures = Vec::new();
316
317    for evaluation_project in evaluations {
318        if only_repo
319            .as_ref()
320            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
321        {
322            continue;
323        }
324
325        eprint!("\r\x1B[2K");
326        eprint!(
327            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
328            counts.covered_results,
329            counts.total_results,
330            counts.overlapped_results,
331            counts.total_results,
332            counts.covered_files,
333            counts.total_results,
334            evaluation_project.repo
335        );
336
337        let repo_dir = repos_dir.join(&evaluation_project.repo);
338        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
339            eprintln!("Skipping {}: directory not found", evaluation_project.repo);
340            continue;
341        }
342
343        let repo_db_path =
344            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
345
346        let project = cx
347            .update(|cx| {
348                Project::local(
349                    client.clone(),
350                    node_runtime.clone(),
351                    user_store.clone(),
352                    language_registry.clone(),
353                    fs.clone(),
354                    None,
355                    cx,
356                )
357            })
358            .unwrap();
359
360        let repo = evaluation_project.repo.clone();
361        if let Err(err) = run_eval_project(
362            evaluation_project,
363            &user_store,
364            repo_db_path,
365            &repo_dir,
366            &mut counts,
367            project,
368            embedding_provider.clone(),
369            fs.clone(),
370            cx,
371        )
372        .await
373        {
374            eprintln!("{repo} eval failed with error: {:?}", err);
375
376            failures.push((repo, err));
377        }
378    }
379
380    eprintln!(
381        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. {} failed.",
382        counts.covered_results,
383        counts.total_results,
384        counts.overlapped_results,
385        counts.total_results,
386        counts.covered_files,
387        counts.total_results,
388        failures.len(),
389    );
390
391    if failures.is_empty() {
392        Ok(())
393    } else {
394        eprintln!("Failures:\n");
395
396        for (index, (repo, failure)) in failures.iter().enumerate() {
397            eprintln!("Failure #{} - {repo}\n{:?}", index + 1, failure);
398        }
399
400        Err(anyhow::anyhow!("Some evals failed."))
401    }
402}
403
404#[allow(clippy::too_many_arguments)]
405async fn run_eval_project(
406    evaluation_project: EvaluationProject,
407    user_store: &Model<UserStore>,
408    repo_db_path: PathBuf,
409    repo_dir: &Path,
410    counts: &mut Counts,
411    project: Model<Project>,
412    embedding_provider: Arc<dyn EmbeddingProvider>,
413    fs: Arc<dyn Fs>,
414    cx: &mut AsyncAppContext,
415) -> Result<(), anyhow::Error> {
416    let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider, cx).await?;
417
418    let (worktree, _) = project
419        .update(cx, |project, cx| {
420            project.find_or_create_worktree(repo_dir, true, cx)
421        })?
422        .await?;
423
424    worktree
425        .update(cx, |worktree, _| {
426            worktree.as_local().unwrap().scan_complete()
427        })?
428        .await;
429
430    let project_index = cx.update(|cx| semantic_index.create_project_index(project.clone(), cx))?;
431    wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;
432
433    for query in evaluation_project.queries {
434        let results = {
435            // Retry search up to 3 times in case of timeout, network failure, etc.
436            let mut retries_remaining = 3;
437            let mut result;
438
439            loop {
440                match cx.update(|cx| {
441                    let project_index = project_index.read(cx);
442                    project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
443                }) {
444                    Ok(task) => match task.await {
445                        Ok(answer) => {
446                            result = Ok(answer);
447                            break;
448                        }
449                        Err(err) => {
450                            result = Err(err);
451                        }
452                    },
453                    Err(err) => {
454                        result = Err(err);
455                    }
456                }
457
458                if retries_remaining > 0 {
459                    eprintln!(
460                        "Retrying search after it failed on query {:?} with {:?}",
461                        query, result
462                    );
463                    retries_remaining -= 1;
464                } else {
465                    eprintln!(
466                        "Ran out of retries; giving up on search which failed on query {:?} with {:?}",
467                        query, result
468                    );
469                    break;
470                }
471            }
472
473            SemanticDb::load_results(result?, &fs.clone(), &cx).await?
474        };
475
476        let mut project_covered_result_count = 0;
477        let mut project_overlapped_result_count = 0;
478        let mut project_covered_file_count = 0;
479        let mut covered_result_indices = Vec::new();
480        for expected_result in &query.expected_results {
481            let mut file_matched = false;
482            let mut range_overlapped = false;
483            let mut range_covered = false;
484
485            for (ix, result) in results.iter().enumerate() {
486                if result.path.as_ref() == Path::new(&expected_result.file) {
487                    file_matched = true;
488                    let start_matched = result.row_range.contains(&expected_result.lines.start());
489                    let end_matched = result.row_range.contains(&expected_result.lines.end());
490
491                    if start_matched || end_matched {
492                        range_overlapped = true;
493                    }
494
495                    if start_matched && end_matched {
496                        range_covered = true;
497                        covered_result_indices.push(ix);
498                        break;
499                    }
500                }
501            }
502
503            if range_covered {
504                project_covered_result_count += 1
505            };
506            if range_overlapped {
507                project_overlapped_result_count += 1
508            };
509            if file_matched {
510                project_covered_file_count += 1
511            };
512        }
513        let outcome_repo = evaluation_project.repo.clone();
514
515        let query_results = EvaluationQueryOutcome {
516            repo: outcome_repo,
517            query: query.query,
518            total_result_count: query.expected_results.len(),
519            covered_result_count: project_covered_result_count,
520            overlapped_result_count: project_overlapped_result_count,
521            covered_file_count: project_covered_file_count,
522            expected_results: query.expected_results,
523            actual_results: results
524                .iter()
525                .map(|result| EvaluationSearchResult {
526                    file: result.path.to_string_lossy().to_string(),
527                    lines: result.row_range.clone(),
528                })
529                .collect(),
530            covered_result_indices,
531        };
532
533        counts.overlapped_results += query_results.overlapped_result_count;
534        counts.covered_results += query_results.covered_result_count;
535        counts.covered_files += query_results.covered_file_count;
536        counts.total_results += query_results.total_result_count;
537
538        println!("{}", serde_json::to_string(&query_results)?);
539    }
540
541    user_store.update(cx, |_, _| {
542        drop(semantic_index);
543        drop(project);
544        drop(worktree);
545        drop(project_index);
546    })
547}
548
549async fn wait_for_indexing_complete(
550    project_index: &Model<ProjectIndex>,
551    cx: &mut AsyncAppContext,
552    timeout: Option<Duration>,
553) {
554    let (tx, rx) = bounded(1);
555    let subscription = cx.update(|cx| {
556        cx.subscribe(project_index, move |_, event, _| {
557            if let Status::Idle = event {
558                let _ = tx.try_send(*event);
559            }
560        })
561    });
562
563    let result = match timeout {
564        Some(timeout_duration) => {
565            smol::future::or(
566                async {
567                    rx.recv().await.map_err(|_| ())?;
568                    Ok(())
569                },
570                async {
571                    Timer::after(timeout_duration).await;
572                    Err(())
573                },
574            )
575            .await
576        }
577        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
578    };
579
580    match result {
581        Ok(_) => (),
582        Err(_) => {
583            if let Some(timeout) = timeout {
584                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
585            }
586        }
587    }
588
589    drop(subscription);
590}
591
592async fn fetch_eval_repos(
593    executor: &BackgroundExecutor,
594    http_client: &dyn HttpClient,
595) -> Result<()> {
596    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
597    let evaluations_path = dataset_dir.join("evaluations.json");
598    let repos_dir = Path::new(EVAL_REPOS_DIR);
599
600    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
601    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
602
603    eprintln!("Fetching evaluation repositories...");
604
605    executor
606        .scoped(move |scope| {
607            let done_count = Arc::new(AtomicUsize::new(0));
608            let len = evaluations.len();
609            for chunk in evaluations.chunks(evaluations.len() / 8) {
610                let chunk = chunk.to_vec();
611                let done_count = done_count.clone();
612                scope.spawn(async move {
613                    for EvaluationProject { repo, sha, .. } in chunk {
614                        eprint!(
615                            "\rFetching evaluation repositories ({}/{})...",
616                            done_count.load(SeqCst),
617                            len,
618                        );
619
620                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
621                        done_count.fetch_add(1, SeqCst);
622                    }
623                });
624            }
625        })
626        .await;
627
628    Ok(())
629}
630
631async fn fetch_eval_repo(
632    repo: String,
633    sha: String,
634    repos_dir: &Path,
635    http_client: &dyn HttpClient,
636) {
637    let Some((owner, repo_name)) = repo.split_once('/') else {
638        return;
639    };
640    let repo_dir = repos_dir.join(owner).join(repo_name);
641    fs::create_dir_all(&repo_dir).unwrap();
642    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
643    if skip_eval_path.exists() {
644        return;
645    }
646    if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
647        if head_content.trim() == sha {
648            return;
649        }
650    }
651    let repo_response = http_client
652        .send(
653            http_client::Request::builder()
654                .method(Method::HEAD)
655                .uri(format!("https://github.com/{}", repo))
656                .body(Default::default())
657                .expect(""),
658        )
659        .await
660        .expect("failed to check github repo");
661    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
662        fs::write(&skip_eval_path, "").unwrap();
663        eprintln!(
664            "Repo {repo} is no longer public ({:?}). Skipping",
665            repo_response.status()
666        );
667        return;
668    }
669    if !repo_dir.join(".git").exists() {
670        let init_output = util::command::new_std_command("git")
671            .current_dir(&repo_dir)
672            .args(&["init"])
673            .output()
674            .unwrap();
675        if !init_output.status.success() {
676            eprintln!(
677                "Failed to initialize git repository for {}: {}",
678                repo,
679                String::from_utf8_lossy(&init_output.stderr)
680            );
681            return;
682        }
683    }
684    let url = format!("https://github.com/{}.git", repo);
685    util::command::new_std_command("git")
686        .current_dir(&repo_dir)
687        .args(&["remote", "add", "-f", "origin", &url])
688        .stdin(Stdio::null())
689        .output()
690        .unwrap();
691    let fetch_output = util::command::new_std_command("git")
692        .current_dir(&repo_dir)
693        .args(&["fetch", "--depth", "1", "origin", &sha])
694        .stdin(Stdio::null())
695        .output()
696        .unwrap();
697    if !fetch_output.status.success() {
698        eprintln!(
699            "Failed to fetch {} for {}: {}",
700            sha,
701            repo,
702            String::from_utf8_lossy(&fetch_output.stderr)
703        );
704        return;
705    }
706    let checkout_output = util::command::new_std_command("git")
707        .current_dir(&repo_dir)
708        .args(&["checkout", &sha])
709        .output()
710        .unwrap();
711
712    if !checkout_output.status.success() {
713        eprintln!(
714            "Failed to checkout {} for {}: {}",
715            sha,
716            repo,
717            String::from_utf8_lossy(&checkout_output.stderr)
718        );
719    }
720}