// eval.rs — semantic-index evaluation binary

use ::fs::{Fs, RealFs};
use anyhow::{Context as _, Result};
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use feature_flags::FeatureFlagAppExt as _;
use git::GitHostingProviderRegistry;
use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::NodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use semantic_index::{
    EmbeddingProvider, OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status,
};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use smol::Timer;
use std::ops::RangeInclusive;
use std::path::PathBuf;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{exit, Command, Stdio},
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
};
use ureq_client::UreqClient;
 36
 37const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
 38const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
 39const EVAL_DB_PATH: &'static str = "target/eval_db";
 40const SEARCH_RESULT_LIMIT: usize = 8;
 41const SKIP_EVAL_PATH: &'static str = ".skip_eval";
 42
// Top-level command-line interface for the evaluation binary.
// NOTE: plain `//` comments are used on purpose — `///` doc comments on clap
// derive items become user-visible help text and would change CLI output.
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    // The subcommand to run (`fetch` or `run`).
    #[command(subcommand)]
    command: Commands,
}
 49
// Subcommands: `fetch` downloads the CodeSearchNet dataset and evaluation
// repositories; `run` executes the evaluation against the local indexes.
// (Plain `//` comments so clap's help output is unchanged.)
#[derive(clap::Subcommand)]
enum Commands {
    Fetch {},
    Run {
        // Restrict the run to a single repository, given as "owner/name".
        #[arg(long)]
        repo: Option<String>,
    },
}
 58
/// One repository's worth of evaluation data, pinned to a specific commit.
/// Serialized into `evaluations.json` by `fetch_code_search_net_resources`.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    // "owner/name" as it appears on GitHub.
    repo: String,
    // Commit sha the annotations refer to.
    sha: String,
    // All annotated queries for this (repo, sha) pair.
    queries: Vec<EvaluationQuery>,
}
 65
/// A natural-language search query plus the human-annotated locations that
/// a good search should return.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
}
 71
/// A single expected (or actual) search hit: a file and an inclusive line range.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    // Path relative to the repository root.
    file: String,
    lines: RangeInclusive<u32>,
}
 77
/// Per-project evaluation outcome.
// NOTE(review): not constructed anywhere in this file — presumably consumed
// by external tooling that parses this binary's JSON output; verify before
// removing.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQueryOutcome>,
}
 84
/// Outcome of running one query against one indexed repository. One of these
/// is printed to stdout as a JSON line per query (see `run_eval_project`).
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    repo: String,
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
    actual_results: Vec<EvaluationSearchResult>,
    // Expected results whose file showed up anywhere in the actual results.
    covered_file_count: usize,
    // Expected results whose line range overlapped an actual result at
    // either endpoint.
    overlapped_result_count: usize,
    // Expected results whose line range was fully contained in an actual result.
    covered_result_count: usize,
    // Number of expected results for this query.
    total_result_count: usize,
    // Indices into `actual_results` that fully covered an expected result.
    covered_result_indices: Vec<usize>,
}
 97
/// Entry point: parses CLI arguments and dispatches to the chosen subcommand
/// inside a headless GPUI application.
///
/// Both subcommands call `exit(...)` from inside their spawned task, so the
/// process terminates from within the app's run loop; the trailing `Ok(())`
/// is only reached if the run loop returns without exiting.
fn main() -> Result<()> {
    let cli = Cli::parse();
    env_logger::init();

    gpui::App::headless().run(move |cx| {
        let executor = cx.background_executor().clone();
        // Shared HTTP client, also registered with the app context so
        // framework code picks it up via `cx.http_client()`.
        let client = Arc::new(UreqClient::new(
            None,
            "Zed LLM evals".to_string(),
            executor.clone(),
        ));
        cx.set_http_client(client.clone());
        match cli.command {
            Commands::Fetch {} => {
                // Pure download work: runs on the background executor.
                executor
                    .clone()
                    .spawn(async move {
                        if let Err(err) = fetch_evaluation_resources(client, &executor).await {
                            eprintln!("Error: {}", err);
                            exit(1);
                        }
                        exit(0);
                    })
                    .detach();
            }
            Commands::Run { repo } => {
                // Needs an `AsyncAppContext`, so spawn on the app itself.
                cx.spawn(|mut cx| async move {
                    if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
                        eprintln!("Error: {}", err);
                        exit(1);
                    }
                    exit(0);
                })
                .detach();
            }
        }
    });

    Ok(())
}
138
139async fn fetch_evaluation_resources(
140    http_client: Arc<dyn HttpClient>,
141    executor: &BackgroundExecutor,
142) -> Result<()> {
143    fetch_code_search_net_resources(&*http_client).await?;
144    fetch_eval_repos(executor, &*http_client).await?;
145    Ok(())
146}
147
148async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
149    eprintln!("Fetching CodeSearchNet evaluations...");
150
151    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
152
153    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
154    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
155
156    // Fetch the annotations CSV, which contains the human-annotated search relevances
157    let annotations_path = dataset_dir.join("annotations.csv");
158    let annotations_csv_content = if annotations_path.exists() {
159        fs::read_to_string(&annotations_path).expect("failed to read annotations")
160    } else {
161        let response = http_client
162            .get(annotations_url, Default::default(), true)
163            .await
164            .expect("failed to fetch annotations csv");
165        let mut body = String::new();
166        response
167            .into_body()
168            .read_to_string(&mut body)
169            .await
170            .expect("failed to read annotations.csv response");
171        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
172        body
173    };
174
175    // Parse the annotations CSV. Skip over queries with zero relevance.
176    let rows = annotations_csv_content.lines().filter_map(|line| {
177        let mut values = line.split(',');
178        let _language = values.next()?;
179        let query = values.next()?;
180        let github_url = values.next()?;
181        let score = values.next()?;
182
183        if score == "0" {
184            return None;
185        }
186
187        let url_path = github_url.strip_prefix("https://github.com/")?;
188        let (url_path, hash) = url_path.split_once('#')?;
189        let (repo_name, url_path) = url_path.split_once("/blob/")?;
190        let (sha, file_path) = url_path.split_once('/')?;
191        let line_range = if let Some((start, end)) = hash.split_once('-') {
192            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
193        } else {
194            let row = hash.strip_prefix("L")?.parse().ok()?;
195            row..=row
196        };
197        Some((repo_name, sha, query, file_path, line_range))
198    });
199
200    // Group the annotations by repo and sha.
201    let mut evaluations_by_repo = BTreeMap::new();
202    for (repo_name, sha, query, file_path, lines) in rows {
203        let evaluation_project = evaluations_by_repo
204            .entry((repo_name, sha))
205            .or_insert_with(|| EvaluationProject {
206                repo: repo_name.to_string(),
207                sha: sha.to_string(),
208                queries: Vec::new(),
209            });
210
211        let ix = evaluation_project
212            .queries
213            .iter()
214            .position(|entry| entry.query == query)
215            .unwrap_or_else(|| {
216                evaluation_project.queries.push(EvaluationQuery {
217                    query: query.to_string(),
218                    expected_results: Vec::new(),
219                });
220                evaluation_project.queries.len() - 1
221            });
222        let results = &mut evaluation_project.queries[ix].expected_results;
223        let result = EvaluationSearchResult {
224            file: file_path.to_string(),
225            lines,
226        };
227        if !results.contains(&result) {
228            results.push(result);
229        }
230    }
231
232    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
233    let evaluations_path = dataset_dir.join("evaluations.json");
234    fs::write(
235        &evaluations_path,
236        serde_json::to_vec_pretty(&evaluations).unwrap(),
237    )
238    .unwrap();
239
240    eprintln!(
241        "Fetched CodeSearchNet evaluations into {}",
242        evaluations_path.display()
243    );
244
245    Ok(())
246}
247
/// Running tallies accumulated across every evaluated query.
#[derive(Default, Debug)]
struct Counts {
    // Expected results whose line range was fully contained in a search hit.
    covered_results: usize,
    // Expected results whose line range touched a search hit at either end.
    overlapped_results: usize,
    // Expected results whose file appeared among the search hits at all.
    covered_files: usize,
    // Total number of expected results considered.
    total_results: usize,
}
255
/// Runs the evaluation: indexes each previously fetched repository with the
/// semantic index and measures how well its search results cover the
/// human-annotated expected results.
///
/// `only_repo` restricts the run to a single "owner/name" repository.
/// Requires the `OPENAI_API_KEY` environment variable (panics if unset).
/// Progress and failures go to stderr; per-query outcomes are emitted by
/// `run_eval_project` as JSON lines on stdout.
async fn run_evaluation(
    only_repo: Option<String>,
    executor: &BackgroundExecutor,
    cx: &mut AsyncAppContext,
) -> Result<()> {
    let mut http_client = None;
    cx.update(|cx| {
        // Install the minimal global app state (settings, client, language,
        // project) that Project and the semantic index expect to exist.
        let mut store = SettingsStore::new(cx);
        store
            .set_default_settings(settings::default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);
        client::init_settings(cx);
        language::init(cx);
        Project::init_settings(cx);
        // Reuse the HTTP client registered in `main`.
        http_client = Some(cx.http_client());
        cx.update_flags(false, vec![]);
    })
    .unwrap();
    let http_client = http_client.unwrap();
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);
    let db_path = Path::new(EVAL_DB_PATH);
    let api_key = std::env::var("OPENAI_API_KEY").unwrap();
    let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
    let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
    let clock = Arc::new(RealSystemClock);
    let client = cx
        .update(|cx| {
            Client::new(
                clock,
                Arc::new(http_client::HttpClientWithUrl::new(
                    http_client.clone(),
                    "https://zed.dev",
                    None,
                )),
                cx,
            )
        })
        .unwrap();
    let user_store = cx
        .new_model(|cx| UserStore::new(client.clone(), cx))
        .unwrap();
    // No node runtime needed for indexing/search in this eval.
    let node_runtime = NodeRuntime::unavailable();

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
        http_client.clone(),
        OpenAiEmbeddingModel::TextEmbedding3Small,
        open_ai::OPEN_AI_API_URL.to_string(),
        api_key,
    ));

    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
        .unwrap();

    let mut counts = Counts::default();
    eprint!("Running evals.");

    let mut failures = Vec::new();

    for evaluation_project in evaluations {
        // When `--repo` was given, skip everything else.
        if only_repo
            .as_ref()
            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
        {
            continue;
        }

        // "\r\x1B[2K" = carriage return + erase line: rewrite the progress line.
        eprint!("\r\x1B[2K");
        eprint!(
            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
            counts.covered_results,
            counts.total_results,
            counts.overlapped_results,
            counts.total_results,
            counts.covered_files,
            counts.total_results,
            evaluation_project.repo
        );

        let repo_dir = repos_dir.join(&evaluation_project.repo);
        // NOTE(review): this message also fires for repos that exist but carry
        // a `.skip_eval` marker, where "directory not found" is misleading.
        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
            eprintln!("Skipping {}: directory not found", evaluation_project.repo);
            continue;
        }

        // One LMDB database per repo, e.g. "owner_name.db".
        let repo_db_path =
            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));

        let project = cx
            .update(|cx| {
                Project::local(
                    client.clone(),
                    node_runtime.clone(),
                    user_store.clone(),
                    language_registry.clone(),
                    fs.clone(),
                    None,
                    cx,
                )
            })
            .unwrap();

        let repo = evaluation_project.repo.clone();
        // A failing project is recorded but does not abort the whole run.
        if let Err(err) = run_eval_project(
            evaluation_project,
            &user_store,
            repo_db_path,
            &repo_dir,
            &mut counts,
            project,
            embedding_provider.clone(),
            fs.clone(),
            cx,
        )
        .await
        {
            eprintln!("{repo} eval failed with error: {:?}", err);

            failures.push((repo, err));
        }
    }

    // Final summary line (no trailing "Project: ..." this time).
    eprintln!(
        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. {} failed.",
        counts.covered_results,
        counts.total_results,
        counts.overlapped_results,
        counts.total_results,
        counts.covered_files,
        counts.total_results,
        failures.len(),
    );

    if failures.is_empty() {
        Ok(())
    } else {
        eprintln!("Failures:\n");

        for (index, (repo, failure)) in failures.iter().enumerate() {
            eprintln!("Failure #{} - {repo}\n{:?}", index + 1, failure);
        }

        Err(anyhow::anyhow!("Some evals failed."))
    }
}
407
408#[allow(clippy::too_many_arguments)]
409async fn run_eval_project(
410    evaluation_project: EvaluationProject,
411    user_store: &Model<UserStore>,
412    repo_db_path: PathBuf,
413    repo_dir: &Path,
414    counts: &mut Counts,
415    project: Model<Project>,
416    embedding_provider: Arc<dyn EmbeddingProvider>,
417    fs: Arc<dyn Fs>,
418    cx: &mut AsyncAppContext,
419) -> Result<(), anyhow::Error> {
420    let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider, cx).await?;
421
422    let (worktree, _) = project
423        .update(cx, |project, cx| {
424            project.find_or_create_worktree(repo_dir, true, cx)
425        })?
426        .await?;
427
428    worktree
429        .update(cx, |worktree, _| {
430            worktree.as_local().unwrap().scan_complete()
431        })?
432        .await;
433
434    let project_index = cx.update(|cx| semantic_index.create_project_index(project.clone(), cx))?;
435    wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;
436
437    for query in evaluation_project.queries {
438        let results = {
439            // Retry search up to 3 times in case of timeout, network failure, etc.
440            let mut retries_remaining = 3;
441            let mut result;
442
443            loop {
444                match cx.update(|cx| {
445                    let project_index = project_index.read(cx);
446                    project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
447                }) {
448                    Ok(task) => match task.await {
449                        Ok(answer) => {
450                            result = Ok(answer);
451                            break;
452                        }
453                        Err(err) => {
454                            result = Err(err);
455                        }
456                    },
457                    Err(err) => {
458                        result = Err(err);
459                    }
460                }
461
462                if retries_remaining > 0 {
463                    eprintln!(
464                        "Retrying search after it failed on query {:?} with {:?}",
465                        query, result
466                    );
467                    retries_remaining -= 1;
468                } else {
469                    eprintln!(
470                        "Ran out of retries; giving up on search which failed on query {:?} with {:?}",
471                        query, result
472                    );
473                    break;
474                }
475            }
476
477            SemanticDb::load_results(result?, &fs.clone(), &cx).await?
478        };
479
480        let mut project_covered_result_count = 0;
481        let mut project_overlapped_result_count = 0;
482        let mut project_covered_file_count = 0;
483        let mut covered_result_indices = Vec::new();
484        for expected_result in &query.expected_results {
485            let mut file_matched = false;
486            let mut range_overlapped = false;
487            let mut range_covered = false;
488
489            for (ix, result) in results.iter().enumerate() {
490                if result.path.as_ref() == Path::new(&expected_result.file) {
491                    file_matched = true;
492                    let start_matched = result.row_range.contains(&expected_result.lines.start());
493                    let end_matched = result.row_range.contains(&expected_result.lines.end());
494
495                    if start_matched || end_matched {
496                        range_overlapped = true;
497                    }
498
499                    if start_matched && end_matched {
500                        range_covered = true;
501                        covered_result_indices.push(ix);
502                        break;
503                    }
504                }
505            }
506
507            if range_covered {
508                project_covered_result_count += 1
509            };
510            if range_overlapped {
511                project_overlapped_result_count += 1
512            };
513            if file_matched {
514                project_covered_file_count += 1
515            };
516        }
517        let outcome_repo = evaluation_project.repo.clone();
518
519        let query_results = EvaluationQueryOutcome {
520            repo: outcome_repo,
521            query: query.query,
522            total_result_count: query.expected_results.len(),
523            covered_result_count: project_covered_result_count,
524            overlapped_result_count: project_overlapped_result_count,
525            covered_file_count: project_covered_file_count,
526            expected_results: query.expected_results,
527            actual_results: results
528                .iter()
529                .map(|result| EvaluationSearchResult {
530                    file: result.path.to_string_lossy().to_string(),
531                    lines: result.row_range.clone(),
532                })
533                .collect(),
534            covered_result_indices,
535        };
536
537        counts.overlapped_results += query_results.overlapped_result_count;
538        counts.covered_results += query_results.covered_result_count;
539        counts.covered_files += query_results.covered_file_count;
540        counts.total_results += query_results.total_result_count;
541
542        println!("{}", serde_json::to_string(&query_results)?);
543    }
544
545    user_store.update(cx, |_, _| {
546        drop(semantic_index);
547        drop(project);
548        drop(worktree);
549        drop(project_index);
550    })
551}
552
/// Blocks until the project index reports `Status::Idle`, or until `timeout`
/// elapses (when given). On timeout a warning is printed and the function
/// returns normally — callers proceed with a possibly incomplete index.
async fn wait_for_indexing_complete(
    project_index: &Model<ProjectIndex>,
    cx: &mut AsyncAppContext,
    timeout: Option<Duration>,
) {
    // Capacity-1 channel: we only care about the first Idle event.
    let (tx, rx) = bounded(1);
    let subscription = cx.update(|cx| {
        cx.subscribe(project_index, move |_, event, _| {
            if let Status::Idle = event {
                // `try_send` so a full channel (already signaled) is ignored.
                let _ = tx.try_send(*event);
            }
        })
    });

    // Race the Idle notification against the timeout timer.
    let result = match timeout {
        Some(timeout_duration) => {
            smol::future::or(
                async {
                    rx.recv().await.map_err(|_| ())?;
                    Ok(())
                },
                async {
                    Timer::after(timeout_duration).await;
                    Err(())
                },
            )
            .await
        }
        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
    };

    match result {
        Ok(_) => (),
        Err(_) => {
            // Err with no timeout means the sender was dropped; only the
            // timeout case warrants a message.
            if let Some(timeout) = timeout {
                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
            }
        }
    }

    // Keep the subscription alive until here so events are not missed.
    drop(subscription);
}
595
596async fn fetch_eval_repos(
597    executor: &BackgroundExecutor,
598    http_client: &dyn HttpClient,
599) -> Result<()> {
600    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
601    let evaluations_path = dataset_dir.join("evaluations.json");
602    let repos_dir = Path::new(EVAL_REPOS_DIR);
603
604    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
605    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
606
607    eprintln!("Fetching evaluation repositories...");
608
609    executor
610        .scoped(move |scope| {
611            let done_count = Arc::new(AtomicUsize::new(0));
612            let len = evaluations.len();
613            for chunk in evaluations.chunks(evaluations.len() / 8) {
614                let chunk = chunk.to_vec();
615                let done_count = done_count.clone();
616                scope.spawn(async move {
617                    for EvaluationProject { repo, sha, .. } in chunk {
618                        eprint!(
619                            "\rFetching evaluation repositories ({}/{})...",
620                            done_count.load(SeqCst),
621                            len,
622                        );
623
624                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
625                        done_count.fetch_add(1, SeqCst);
626                    }
627                });
628            }
629        })
630        .await;
631
632    Ok(())
633}
634
635async fn fetch_eval_repo(
636    repo: String,
637    sha: String,
638    repos_dir: &Path,
639    http_client: &dyn HttpClient,
640) {
641    let Some((owner, repo_name)) = repo.split_once('/') else {
642        return;
643    };
644    let repo_dir = repos_dir.join(owner).join(repo_name);
645    fs::create_dir_all(&repo_dir).unwrap();
646    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
647    if skip_eval_path.exists() {
648        return;
649    }
650    if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
651        if head_content.trim() == sha {
652            return;
653        }
654    }
655    let repo_response = http_client
656        .send(
657            http_client::Request::builder()
658                .method(Method::HEAD)
659                .uri(format!("https://github.com/{}", repo))
660                .body(Default::default())
661                .expect(""),
662        )
663        .await
664        .expect("failed to check github repo");
665    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
666        fs::write(&skip_eval_path, "").unwrap();
667        eprintln!(
668            "Repo {repo} is no longer public ({:?}). Skipping",
669            repo_response.status()
670        );
671        return;
672    }
673    if !repo_dir.join(".git").exists() {
674        let init_output = Command::new("git")
675            .current_dir(&repo_dir)
676            .args(&["init"])
677            .output()
678            .unwrap();
679        if !init_output.status.success() {
680            eprintln!(
681                "Failed to initialize git repository for {}: {}",
682                repo,
683                String::from_utf8_lossy(&init_output.stderr)
684            );
685            return;
686        }
687    }
688    let url = format!("https://github.com/{}.git", repo);
689    Command::new("git")
690        .current_dir(&repo_dir)
691        .args(&["remote", "add", "-f", "origin", &url])
692        .stdin(Stdio::null())
693        .output()
694        .unwrap();
695    let fetch_output = Command::new("git")
696        .current_dir(&repo_dir)
697        .args(&["fetch", "--depth", "1", "origin", &sha])
698        .stdin(Stdio::null())
699        .output()
700        .unwrap();
701    if !fetch_output.status.success() {
702        eprintln!(
703            "Failed to fetch {} for {}: {}",
704            sha,
705            repo,
706            String::from_utf8_lossy(&fetch_output.stderr)
707        );
708        return;
709    }
710    let checkout_output = Command::new("git")
711        .current_dir(&repo_dir)
712        .args(&["checkout", &sha])
713        .output()
714        .unwrap();
715
716    if !checkout_output.status.success() {
717        eprintln!(
718            "Failed to checkout {} for {}: {}",
719            sha,
720            repo,
721            String::from_utf8_lossy(&checkout_output.stderr)
722        );
723    }
724}