eval.rs

use ::fs::{Fs, RealFs};
use anyhow::{Context as _, Result};
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use feature_flags::FeatureFlagAppExt as _;
use git::GitHostingProviderRegistry;
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Entity};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::NodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use reqwest_client::ReqwestClient;
use semantic_index::{
    EmbeddingProvider, OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status,
};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use smol::Timer;
use std::ops::RangeInclusive;
use std::path::PathBuf;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{exit, Stdio},
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
};
 36
 37const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
 38const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
 39const EVAL_DB_PATH: &'static str = "target/eval_db";
 40const SEARCH_RESULT_LIMIT: usize = 8;
 41const SKIP_EVAL_PATH: &'static str = ".skip_eval";
 42
// Command-line interface for the evaluation binary.
// (Plain `//` comments here on purpose: `///` doc comments on clap-derive
// items become user-visible help text.)
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    // Which subcommand to run; see `Commands`.
    #[command(subcommand)]
    command: Commands,
}
 49
// Subcommands for the eval binary.
// (Plain `//` comments on purpose: `///` on clap-derive items becomes help text.)
#[derive(clap::Subcommand)]
enum Commands {
    // Download the CodeSearchNet annotations and clone the evaluation repos.
    Fetch {},
    // Run the evaluations, optionally restricted to a single repository.
    Run {
        // Only evaluate the repo with this `owner/name` (compared against
        // `EvaluationProject::repo`).
        #[arg(long)]
        repo: Option<String>,
    },
}
 58
/// One repository from the CodeSearchNet annotations, pinned to a commit,
/// together with every annotated query against it. Serialized into
/// `evaluations.json` by `fetch_code_search_net_resources`.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    /// Repository in `owner/name` form.
    repo: String,
    /// Commit SHA the annotations refer to.
    sha: String,
    /// The queries annotated for this repo.
    queries: Vec<EvaluationQuery>,
}
 65
/// A natural-language search query plus the locations annotators marked as
/// relevant. Rows with relevance score "0" are filtered out at fetch time.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    query: String,
    /// Deduplicated expected hits for this query.
    expected_results: Vec<EvaluationSearchResult>,
}
 71
/// A location in a repository: a file path (relative to the repo root) and an
/// inclusive line range parsed from a GitHub `#Lx` / `#Lx-Ly` URL fragment.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    file: String,
    /// Inclusive line range; a single-line anchor becomes `row..=row`.
    lines: RangeInclusive<u32>,
}
 77
/// Per-project outcome grouping query outcomes by repo and sha.
/// NOTE(review): not referenced anywhere in this file — presumably consumed by
/// external tooling that aggregates the stdout JSON stream; confirm before
/// removing.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQueryOutcome>,
}
 84
/// The outcome of running one query against one repo's index; printed as a
/// JSON line to stdout by `run_eval_project`.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    repo: String,
    query: String,
    /// The annotated expected results for the query.
    expected_results: Vec<EvaluationSearchResult>,
    /// What the semantic index actually returned.
    actual_results: Vec<EvaluationSearchResult>,
    /// Expected results whose file appeared among the actual results.
    covered_file_count: usize,
    /// Expected results whose range at least partially overlapped an actual result.
    overlapped_result_count: usize,
    /// Expected results fully contained within a single actual result.
    covered_result_count: usize,
    /// Number of expected results for this query.
    total_result_count: usize,
    /// Indices into `actual_results` that fully covered an expected result.
    covered_result_indices: Vec<usize>,
}
 97
/// Entry point: parses the CLI and runs the requested subcommand inside a
/// headless GPUI application. The spawned task terminates the process itself
/// via `exit`, so the trailing `Ok(())` is only reached if the app loop
/// returns without exiting.
fn main() -> Result<()> {
    let cli = Cli::parse();
    env_logger::init();

    gpui::Application::headless().run(move |cx| {
        let executor = cx.background_executor().clone();
        let client = Arc::new(ReqwestClient::user_agent("Zed LLM evals").unwrap());
        cx.set_http_client(client.clone());
        match cli.command {
            Commands::Fetch {} => {
                // Fetching needs no app context, so it runs entirely on the
                // background executor; the detached task exits the process
                // when done.
                executor
                    .clone()
                    .spawn(async move {
                        if let Err(err) = fetch_evaluation_resources(client, &executor).await {
                            eprintln!("Error: {}", err);
                            exit(1);
                        }
                        exit(0);
                    })
                    .detach();
            }
            Commands::Run { repo } => {
                // Running evals needs an `AsyncApp` handle, so spawn on the
                // app context instead of the background executor.
                cx.spawn(|mut cx| async move {
                    if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
                        eprintln!("Error: {}", err);
                        exit(1);
                    }
                    exit(0);
                })
                .detach();
            }
        }
    });

    Ok(())
}
134
135async fn fetch_evaluation_resources(
136    http_client: Arc<dyn HttpClient>,
137    executor: &BackgroundExecutor,
138) -> Result<()> {
139    fetch_code_search_net_resources(&*http_client).await?;
140    fetch_eval_repos(executor, &*http_client).await?;
141    Ok(())
142}
143
144async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
145    eprintln!("Fetching CodeSearchNet evaluations...");
146
147    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
148
149    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
150    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
151
152    // Fetch the annotations CSV, which contains the human-annotated search relevances
153    let annotations_path = dataset_dir.join("annotations.csv");
154    let annotations_csv_content = if annotations_path.exists() {
155        fs::read_to_string(&annotations_path).expect("failed to read annotations")
156    } else {
157        let response = http_client
158            .get(annotations_url, Default::default(), true)
159            .await
160            .expect("failed to fetch annotations csv");
161        let mut body = String::new();
162        response
163            .into_body()
164            .read_to_string(&mut body)
165            .await
166            .expect("failed to read annotations.csv response");
167        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
168        body
169    };
170
171    // Parse the annotations CSV. Skip over queries with zero relevance.
172    let rows = annotations_csv_content.lines().filter_map(|line| {
173        let mut values = line.split(',');
174        let _language = values.next()?;
175        let query = values.next()?;
176        let github_url = values.next()?;
177        let score = values.next()?;
178
179        if score == "0" {
180            return None;
181        }
182
183        let url_path = github_url.strip_prefix("https://github.com/")?;
184        let (url_path, hash) = url_path.split_once('#')?;
185        let (repo_name, url_path) = url_path.split_once("/blob/")?;
186        let (sha, file_path) = url_path.split_once('/')?;
187        let line_range = if let Some((start, end)) = hash.split_once('-') {
188            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
189        } else {
190            let row = hash.strip_prefix("L")?.parse().ok()?;
191            row..=row
192        };
193        Some((repo_name, sha, query, file_path, line_range))
194    });
195
196    // Group the annotations by repo and sha.
197    let mut evaluations_by_repo = BTreeMap::new();
198    for (repo_name, sha, query, file_path, lines) in rows {
199        let evaluation_project = evaluations_by_repo
200            .entry((repo_name, sha))
201            .or_insert_with(|| EvaluationProject {
202                repo: repo_name.to_string(),
203                sha: sha.to_string(),
204                queries: Vec::new(),
205            });
206
207        let ix = evaluation_project
208            .queries
209            .iter()
210            .position(|entry| entry.query == query)
211            .unwrap_or_else(|| {
212                evaluation_project.queries.push(EvaluationQuery {
213                    query: query.to_string(),
214                    expected_results: Vec::new(),
215                });
216                evaluation_project.queries.len() - 1
217            });
218        let results = &mut evaluation_project.queries[ix].expected_results;
219        let result = EvaluationSearchResult {
220            file: file_path.to_string(),
221            lines,
222        };
223        if !results.contains(&result) {
224            results.push(result);
225        }
226    }
227
228    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
229    let evaluations_path = dataset_dir.join("evaluations.json");
230    fs::write(
231        &evaluations_path,
232        serde_json::to_vec_pretty(&evaluations).unwrap(),
233    )
234    .unwrap();
235
236    eprintln!(
237        "Fetched CodeSearchNet evaluations into {}",
238        evaluations_path.display()
239    );
240
241    Ok(())
242}
243
/// Aggregate tallies accumulated across every evaluated query.
#[derive(Default, Debug)]
struct Counts {
    // Expected results whose line range was fully contained in one search result.
    covered_results: usize,
    // Expected results whose line range at least partially overlapped a result.
    overlapped_results: usize,
    // Expected results whose file appeared among the search results.
    covered_files: usize,
    // Total expected results examined (the denominator for the three above).
    total_results: usize,
}
251
252async fn run_evaluation(
253    only_repo: Option<String>,
254    executor: &BackgroundExecutor,
255    cx: &mut AsyncApp,
256) -> Result<()> {
257    let mut http_client = None;
258    cx.update(|cx| {
259        let mut store = SettingsStore::new(cx);
260        store
261            .set_default_settings(settings::default_settings().as_ref(), cx)
262            .unwrap();
263        cx.set_global(store);
264        client::init_settings(cx);
265        language::init(cx);
266        Project::init_settings(cx);
267        http_client = Some(cx.http_client());
268        cx.update_flags(false, vec![]);
269    })
270    .unwrap();
271    let http_client = http_client.unwrap();
272    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
273    let evaluations_path = dataset_dir.join("evaluations.json");
274    let repos_dir = Path::new(EVAL_REPOS_DIR);
275    let db_path = Path::new(EVAL_DB_PATH);
276    let api_key = std::env::var("OPENAI_API_KEY").unwrap();
277    let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
278    let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
279    let clock = Arc::new(RealSystemClock);
280    let client = cx
281        .update(|cx| {
282            Client::new(
283                clock,
284                Arc::new(http_client::HttpClientWithUrl::new(
285                    http_client.clone(),
286                    "https://zed.dev",
287                    None,
288                )),
289                cx,
290            )
291        })
292        .unwrap();
293    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)).unwrap();
294    let node_runtime = NodeRuntime::unavailable();
295
296    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
297    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
298
299    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
300        http_client.clone(),
301        OpenAiEmbeddingModel::TextEmbedding3Small,
302        open_ai::OPEN_AI_API_URL.to_string(),
303        api_key,
304    ));
305
306    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
307    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
308        .unwrap();
309
310    let mut counts = Counts::default();
311    eprint!("Running evals.");
312
313    let mut failures = Vec::new();
314
315    for evaluation_project in evaluations {
316        if only_repo
317            .as_ref()
318            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
319        {
320            continue;
321        }
322
323        eprint!("\r\x1B[2K");
324        eprint!(
325            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
326            counts.covered_results,
327            counts.total_results,
328            counts.overlapped_results,
329            counts.total_results,
330            counts.covered_files,
331            counts.total_results,
332            evaluation_project.repo
333        );
334
335        let repo_dir = repos_dir.join(&evaluation_project.repo);
336        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
337            eprintln!("Skipping {}: directory not found", evaluation_project.repo);
338            continue;
339        }
340
341        let repo_db_path =
342            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
343
344        let project = cx
345            .update(|cx| {
346                Project::local(
347                    client.clone(),
348                    node_runtime.clone(),
349                    user_store.clone(),
350                    language_registry.clone(),
351                    fs.clone(),
352                    None,
353                    cx,
354                )
355            })
356            .unwrap();
357
358        let repo = evaluation_project.repo.clone();
359        if let Err(err) = run_eval_project(
360            evaluation_project,
361            &user_store,
362            repo_db_path,
363            &repo_dir,
364            &mut counts,
365            project,
366            embedding_provider.clone(),
367            fs.clone(),
368            cx,
369        )
370        .await
371        {
372            eprintln!("{repo} eval failed with error: {:?}", err);
373
374            failures.push((repo, err));
375        }
376    }
377
378    eprintln!(
379        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. {} failed.",
380        counts.covered_results,
381        counts.total_results,
382        counts.overlapped_results,
383        counts.total_results,
384        counts.covered_files,
385        counts.total_results,
386        failures.len(),
387    );
388
389    if failures.is_empty() {
390        Ok(())
391    } else {
392        eprintln!("Failures:\n");
393
394        for (index, (repo, failure)) in failures.iter().enumerate() {
395            eprintln!("Failure #{} - {repo}\n{:?}", index + 1, failure);
396        }
397
398        Err(anyhow::anyhow!("Some evals failed."))
399    }
400}
401
402async fn run_eval_project(
403    evaluation_project: EvaluationProject,
404    user_store: &Entity<UserStore>,
405    repo_db_path: PathBuf,
406    repo_dir: &Path,
407    counts: &mut Counts,
408    project: Entity<Project>,
409    embedding_provider: Arc<dyn EmbeddingProvider>,
410    fs: Arc<dyn Fs>,
411    cx: &mut AsyncApp,
412) -> Result<(), anyhow::Error> {
413    let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider, cx).await?;
414
415    let (worktree, _) = project
416        .update(cx, |project, cx| {
417            project.find_or_create_worktree(repo_dir, true, cx)
418        })?
419        .await?;
420
421    worktree
422        .update(cx, |worktree, _| {
423            worktree.as_local().unwrap().scan_complete()
424        })?
425        .await;
426
427    let project_index = cx.update(|cx| semantic_index.create_project_index(project.clone(), cx))?;
428    wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;
429
430    for query in evaluation_project.queries {
431        let results = {
432            // Retry search up to 3 times in case of timeout, network failure, etc.
433            let mut retries_remaining = 3;
434            let mut result;
435
436            loop {
437                match cx.update(|cx| {
438                    let project_index = project_index.read(cx);
439                    project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
440                }) {
441                    Ok(task) => match task.await {
442                        Ok(answer) => {
443                            result = Ok(answer);
444                            break;
445                        }
446                        Err(err) => {
447                            result = Err(err);
448                        }
449                    },
450                    Err(err) => {
451                        result = Err(err);
452                    }
453                }
454
455                if retries_remaining > 0 {
456                    eprintln!(
457                        "Retrying search after it failed on query {:?} with {:?}",
458                        query, result
459                    );
460                    retries_remaining -= 1;
461                } else {
462                    eprintln!(
463                        "Ran out of retries; giving up on search which failed on query {:?} with {:?}",
464                        query, result
465                    );
466                    break;
467                }
468            }
469
470            SemanticDb::load_results(result?, &fs.clone(), &cx).await?
471        };
472
473        let mut project_covered_result_count = 0;
474        let mut project_overlapped_result_count = 0;
475        let mut project_covered_file_count = 0;
476        let mut covered_result_indices = Vec::new();
477        for expected_result in &query.expected_results {
478            let mut file_matched = false;
479            let mut range_overlapped = false;
480            let mut range_covered = false;
481
482            for (ix, result) in results.iter().enumerate() {
483                if result.path.as_ref() == Path::new(&expected_result.file) {
484                    file_matched = true;
485                    let start_matched = result.row_range.contains(&expected_result.lines.start());
486                    let end_matched = result.row_range.contains(&expected_result.lines.end());
487
488                    if start_matched || end_matched {
489                        range_overlapped = true;
490                    }
491
492                    if start_matched && end_matched {
493                        range_covered = true;
494                        covered_result_indices.push(ix);
495                        break;
496                    }
497                }
498            }
499
500            if range_covered {
501                project_covered_result_count += 1
502            };
503            if range_overlapped {
504                project_overlapped_result_count += 1
505            };
506            if file_matched {
507                project_covered_file_count += 1
508            };
509        }
510        let outcome_repo = evaluation_project.repo.clone();
511
512        let query_results = EvaluationQueryOutcome {
513            repo: outcome_repo,
514            query: query.query,
515            total_result_count: query.expected_results.len(),
516            covered_result_count: project_covered_result_count,
517            overlapped_result_count: project_overlapped_result_count,
518            covered_file_count: project_covered_file_count,
519            expected_results: query.expected_results,
520            actual_results: results
521                .iter()
522                .map(|result| EvaluationSearchResult {
523                    file: result.path.to_string_lossy().to_string(),
524                    lines: result.row_range.clone(),
525                })
526                .collect(),
527            covered_result_indices,
528        };
529
530        counts.overlapped_results += query_results.overlapped_result_count;
531        counts.covered_results += query_results.covered_result_count;
532        counts.covered_files += query_results.covered_file_count;
533        counts.total_results += query_results.total_result_count;
534
535        println!("{}", serde_json::to_string(&query_results)?);
536    }
537
538    user_store.update(cx, |_, _| {
539        drop(semantic_index);
540        drop(project);
541        drop(worktree);
542        drop(project_index);
543    })
544}
545
546async fn wait_for_indexing_complete(
547    project_index: &Entity<ProjectIndex>,
548    cx: &mut AsyncApp,
549    timeout: Option<Duration>,
550) {
551    let (tx, rx) = bounded(1);
552    let subscription = cx.update(|cx| {
553        cx.subscribe(project_index, move |_, event, _| {
554            if let Status::Idle = event {
555                let _ = tx.try_send(*event);
556            }
557        })
558    });
559
560    let result = match timeout {
561        Some(timeout_duration) => {
562            smol::future::or(
563                async {
564                    rx.recv().await.map_err(|_| ())?;
565                    Ok(())
566                },
567                async {
568                    Timer::after(timeout_duration).await;
569                    Err(())
570                },
571            )
572            .await
573        }
574        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
575    };
576
577    match result {
578        Ok(_) => (),
579        Err(_) => {
580            if let Some(timeout) = timeout {
581                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
582            }
583        }
584    }
585
586    drop(subscription);
587}
588
589async fn fetch_eval_repos(
590    executor: &BackgroundExecutor,
591    http_client: &dyn HttpClient,
592) -> Result<()> {
593    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
594    let evaluations_path = dataset_dir.join("evaluations.json");
595    let repos_dir = Path::new(EVAL_REPOS_DIR);
596
597    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
598    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
599
600    eprintln!("Fetching evaluation repositories...");
601
602    executor
603        .scoped(move |scope| {
604            let done_count = Arc::new(AtomicUsize::new(0));
605            let len = evaluations.len();
606            for chunk in evaluations.chunks(evaluations.len() / 8) {
607                let chunk = chunk.to_vec();
608                let done_count = done_count.clone();
609                scope.spawn(async move {
610                    for EvaluationProject { repo, sha, .. } in chunk {
611                        eprint!(
612                            "\rFetching evaluation repositories ({}/{})...",
613                            done_count.load(SeqCst),
614                            len,
615                        );
616
617                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
618                        done_count.fetch_add(1, SeqCst);
619                    }
620                });
621            }
622        })
623        .await;
624
625    Ok(())
626}
627
628async fn fetch_eval_repo(
629    repo: String,
630    sha: String,
631    repos_dir: &Path,
632    http_client: &dyn HttpClient,
633) {
634    let Some((owner, repo_name)) = repo.split_once('/') else {
635        return;
636    };
637    let repo_dir = repos_dir.join(owner).join(repo_name);
638    fs::create_dir_all(&repo_dir).unwrap();
639    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
640    if skip_eval_path.exists() {
641        return;
642    }
643    if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
644        if head_content.trim() == sha {
645            return;
646        }
647    }
648    let repo_response = http_client
649        .send(
650            http_client::Request::builder()
651                .method(Method::HEAD)
652                .uri(format!("https://github.com/{}", repo))
653                .body(Default::default())
654                .expect(""),
655        )
656        .await
657        .expect("failed to check github repo");
658    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
659        fs::write(&skip_eval_path, "").unwrap();
660        eprintln!(
661            "Repo {repo} is no longer public ({:?}). Skipping",
662            repo_response.status()
663        );
664        return;
665    }
666    if !repo_dir.join(".git").exists() {
667        let init_output = util::command::new_std_command("git")
668            .current_dir(&repo_dir)
669            .args(&["init"])
670            .output()
671            .unwrap();
672        if !init_output.status.success() {
673            eprintln!(
674                "Failed to initialize git repository for {}: {}",
675                repo,
676                String::from_utf8_lossy(&init_output.stderr)
677            );
678            return;
679        }
680    }
681    let url = format!("https://github.com/{}.git", repo);
682    util::command::new_std_command("git")
683        .current_dir(&repo_dir)
684        .args(&["remote", "add", "-f", "origin", &url])
685        .stdin(Stdio::null())
686        .output()
687        .unwrap();
688    let fetch_output = util::command::new_std_command("git")
689        .current_dir(&repo_dir)
690        .args(&["fetch", "--depth", "1", "origin", &sha])
691        .stdin(Stdio::null())
692        .output()
693        .unwrap();
694    if !fetch_output.status.success() {
695        eprintln!(
696            "Failed to fetch {} for {}: {}",
697            sha,
698            repo,
699            String::from_utf8_lossy(&fetch_output.stderr)
700        );
701        return;
702    }
703    let checkout_output = util::command::new_std_command("git")
704        .current_dir(&repo_dir)
705        .args(&["checkout", &sha])
706        .output()
707        .unwrap();
708
709    if !checkout_output.status.success() {
710        eprintln!(
711            "Failed to checkout {} for {}: {}",
712            sha,
713            repo,
714            String::from_utf8_lossy(&checkout_output.stderr)
715        );
716    }
717}