//! eval.rs — semantic-index evaluation harness driven by the CodeSearchNet dataset.

use ::fs::{Fs, RealFs};
use anyhow::{Context as _, Result};
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use feature_flags::FeatureFlagAppExt as _;
use git::GitHostingProviderRegistry;
use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::FakeNodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use semantic_index::{OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use smol::Timer;
use std::ops::RangeInclusive;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{exit, Command, Stdio},
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
};
 32
 33const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
 34const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
 35const EVAL_DB_PATH: &'static str = "target/eval_db";
 36const SEARCH_RESULT_LIMIT: usize = 8;
 37const SKIP_EVAL_PATH: &'static str = ".skip_eval";
 38
/// Command-line interface for the evaluation harness.
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    /// Which subcommand to execute (see [`Commands`]).
    #[command(subcommand)]
    command: Commands,
}
 45
/// The two phases of the harness: fetching datasets and running the evaluation.
#[derive(clap::Subcommand)]
enum Commands {
    /// Download the CodeSearchNet annotations and the repositories they reference.
    Fetch {},
    /// Run the evaluation over the fetched repositories.
    Run {
        /// Restrict the run to a single repository (in `owner/name` form).
        #[arg(long)]
        repo: Option<String>,
    },
}
 54
/// All annotated queries for one repository at one commit, as stored in
/// `evaluations.json`.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    /// Repository in `owner/name` form.
    repo: String,
    /// Commit sha the annotations refer to.
    sha: String,
    /// Queries annotated against this repo + sha.
    queries: Vec<EvaluationQuery>,
}
 61
/// One natural-language query plus the locations human annotators marked relevant.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    query: String,
    /// File/line-range locations judged relevant for this query.
    expected_results: Vec<EvaluationSearchResult>,
}
 67
/// A single search hit: a file path and an inclusive line range
/// (parsed from a GitHub `#L<start>-L<end>` URL anchor).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    file: String,
    lines: RangeInclusive<u32>,
}
 73
/// Aggregated per-repository outcome record.
/// NOTE(review): not constructed anywhere in this file — outcomes are emitted
/// per-query via `EvaluationQueryOutcome`; verify whether external tooling
/// consumes this shape before removing.
#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQueryOutcome>,
}
 80
/// The scored outcome of running one query against one indexed repository;
/// serialized to stdout as one JSON object per query.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    repo: String,
    query: String,
    /// The annotator-provided relevant locations.
    expected_results: Vec<EvaluationSearchResult>,
    /// What the semantic index actually returned.
    actual_results: Vec<EvaluationSearchResult>,
    /// Expected results whose file appeared among the actual results.
    covered_file_count: usize,
    /// Expected results whose line range partially overlapped an actual result.
    overlapped_result_count: usize,
    /// Expected results whose line range was fully contained in an actual result.
    covered_result_count: usize,
    /// Total number of expected results for this query.
    total_result_count: usize,
    /// Indices (into `actual_results`) of the results that fully covered an
    /// expected result.
    covered_result_indices: Vec<usize>,
}
 93
/// Entry point: parse the CLI, then run the chosen subcommand inside a
/// headless GPUI application. The async tasks terminate the process via
/// `exit(...)`, since the app's run loop never returns on its own.
fn main() -> Result<()> {
    let cli = Cli::parse();
    env_logger::init();

    gpui::App::headless().run(move |cx| {
        let executor = cx.background_executor().clone();

        match cli.command {
            Commands::Fetch {} => {
                // Fetching needs no app context, so it runs entirely on the
                // background executor.
                executor
                    .clone()
                    .spawn(async move {
                        if let Err(err) = fetch_evaluation_resources(&executor).await {
                            eprintln!("Error: {}", err);
                            exit(1);
                        }
                        exit(0);
                    })
                    .detach();
            }
            Commands::Run { repo } => {
                // `run_evaluation` needs an `AsyncAppContext`, so spawn on the
                // app context rather than the background executor.
                cx.spawn(|mut cx| async move {
                    if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
                        eprintln!("Error: {}", err);
                        exit(1);
                    }
                    exit(0);
                })
                .detach();
            }
        }
    });

    Ok(())
}
129
/// Fetch everything the evaluation needs: first the CodeSearchNet annotations,
/// then a shallow clone of every repository those annotations reference.
async fn fetch_evaluation_resources(executor: &BackgroundExecutor) -> Result<()> {
    let http_client = http_client::HttpClientWithProxy::new(None, None);
    fetch_code_search_net_resources(&http_client).await?;
    fetch_eval_repos(executor, &http_client).await?;
    Ok(())
}
136
137async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
138    eprintln!("Fetching CodeSearchNet evaluations...");
139
140    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
141
142    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
143    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
144
145    // Fetch the annotations CSV, which contains the human-annotated search relevances
146    let annotations_path = dataset_dir.join("annotations.csv");
147    let annotations_csv_content = if annotations_path.exists() {
148        fs::read_to_string(&annotations_path).expect("failed to read annotations")
149    } else {
150        let response = http_client
151            .get(annotations_url, Default::default(), true)
152            .await
153            .expect("failed to fetch annotations csv");
154        let mut body = String::new();
155        response
156            .into_body()
157            .read_to_string(&mut body)
158            .await
159            .expect("failed to read annotations.csv response");
160        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
161        body
162    };
163
164    // Parse the annotations CSV. Skip over queries with zero relevance.
165    let rows = annotations_csv_content.lines().filter_map(|line| {
166        let mut values = line.split(',');
167        let _language = values.next()?;
168        let query = values.next()?;
169        let github_url = values.next()?;
170        let score = values.next()?;
171
172        if score == "0" {
173            return None;
174        }
175
176        let url_path = github_url.strip_prefix("https://github.com/")?;
177        let (url_path, hash) = url_path.split_once('#')?;
178        let (repo_name, url_path) = url_path.split_once("/blob/")?;
179        let (sha, file_path) = url_path.split_once('/')?;
180        let line_range = if let Some((start, end)) = hash.split_once('-') {
181            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
182        } else {
183            let row = hash.strip_prefix("L")?.parse().ok()?;
184            row..=row
185        };
186        Some((repo_name, sha, query, file_path, line_range))
187    });
188
189    // Group the annotations by repo and sha.
190    let mut evaluations_by_repo = BTreeMap::new();
191    for (repo_name, sha, query, file_path, lines) in rows {
192        let evaluation_project = evaluations_by_repo
193            .entry((repo_name, sha))
194            .or_insert_with(|| EvaluationProject {
195                repo: repo_name.to_string(),
196                sha: sha.to_string(),
197                queries: Vec::new(),
198            });
199
200        let ix = evaluation_project
201            .queries
202            .iter()
203            .position(|entry| entry.query == query)
204            .unwrap_or_else(|| {
205                evaluation_project.queries.push(EvaluationQuery {
206                    query: query.to_string(),
207                    expected_results: Vec::new(),
208                });
209                evaluation_project.queries.len() - 1
210            });
211        let results = &mut evaluation_project.queries[ix].expected_results;
212        let result = EvaluationSearchResult {
213            file: file_path.to_string(),
214            lines,
215        };
216        if !results.contains(&result) {
217            results.push(result);
218        }
219    }
220
221    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
222    let evaluations_path = dataset_dir.join("evaluations.json");
223    fs::write(
224        &evaluations_path,
225        serde_json::to_vec_pretty(&evaluations).unwrap(),
226    )
227    .unwrap();
228
229    eprintln!(
230        "Fetched CodeSearchNet evaluations into {}",
231        evaluations_path.display()
232    );
233
234    Ok(())
235}
236
/// Run the full evaluation: index each fetched repository with the semantic
/// index, execute every annotated query against it, and emit one
/// `EvaluationQueryOutcome` JSON object per query on stdout (progress and
/// totals go to stderr).
///
/// `only_repo`, when set, restricts the run to that single `owner/name` repo.
/// Requires the `OPENAI_API_KEY` environment variable (panics if unset).
async fn run_evaluation(
    only_repo: Option<String>,
    executor: &BackgroundExecutor,
    cx: &mut AsyncAppContext,
) -> Result<()> {
    // Install the minimum global app state (settings, client, language,
    // project) that the project/indexing machinery expects.
    cx.update(|cx| {
        let mut store = SettingsStore::new(cx);
        store
            .set_default_settings(settings::default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);
        client::init_settings(cx);
        language::init(cx);
        Project::init_settings(cx);
        cx.update_flags(false, vec![]);
    })
    .unwrap();

    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);
    let db_path = Path::new(EVAL_DB_PATH);
    let http_client = http_client::HttpClientWithProxy::new(None, None);
    let api_key = std::env::var("OPENAI_API_KEY").unwrap();
    let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
    let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
    let clock = Arc::new(RealSystemClock);
    let client = cx
        .update(|cx| {
            Client::new(
                clock,
                Arc::new(http_client::HttpClientWithUrl::new(
                    "https://zed.dev",
                    None,
                    None,
                )),
                cx,
            )
        })
        .unwrap();
    let user_store = cx
        .new_model(|cx| UserStore::new(client.clone(), cx))
        .unwrap();
    let node_runtime = Arc::new(FakeNodeRuntime {});

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
        http_client.clone(),
        OpenAiEmbeddingModel::TextEmbedding3Small,
        open_ai::OPEN_AI_API_URL.to_string(),
        api_key,
    ));

    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
        .unwrap();

    // Running totals across all projects/queries, reported on stderr.
    let mut covered_result_count = 0;
    let mut overlapped_result_count = 0;
    let mut covered_file_count = 0;
    let mut total_result_count = 0;
    eprint!("Running evals.");

    for evaluation_project in evaluations {
        // When --repo was given, skip every other project.
        if only_repo
            .as_ref()
            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
        {
            continue;
        }

        // "\r\x1B[2K" returns to column 0 and clears the line, so the
        // progress line is rewritten in place.
        eprint!("\r\x1B[2K");
        eprint!(
            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
            covered_result_count,
            total_result_count,
            overlapped_result_count,
            total_result_count,
            covered_file_count,
            total_result_count,
            evaluation_project.repo
        );

        // One database per repo, named after the repo with '/' flattened.
        let repo_db_path =
            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
        let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider.clone(), cx)
            .await
            .unwrap();

        let repo_dir = repos_dir.join(&evaluation_project.repo);
        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
            eprintln!("Skipping {}: directory not found", evaluation_project.repo);
            continue;
        }

        let project = cx
            .update(|cx| {
                Project::local(
                    client.clone(),
                    node_runtime.clone(),
                    user_store.clone(),
                    language_registry.clone(),
                    fs.clone(),
                    None,
                    cx,
                )
            })
            .unwrap();

        let (worktree, _) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(repo_dir, true, cx)
            })?
            .await?;

        // Wait until the worktree has finished scanning the repo's files
        // before indexing begins.
        worktree
            .update(cx, |worktree, _| {
                worktree.as_local().unwrap().scan_complete()
            })
            .unwrap()
            .await;

        let project_index = cx
            .update(|cx| semantic_index.create_project_index(project.clone(), cx))
            .unwrap();
        // Give indexing up to two minutes; on timeout we proceed with
        // whatever has been indexed so far.
        wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;

        for query in evaluation_project.queries {
            let results = cx
                .update(|cx| {
                    let project_index = project_index.read(cx);
                    project_index.search(query.query.clone(), SEARCH_RESULT_LIMIT, cx)
                })
                .unwrap()
                .await
                .unwrap();

            let results = SemanticDb::load_results(results, &fs.clone(), &cx)
                .await
                .unwrap();

            // Score this query: for each expected result, check whether any
            // actual result matched the file, overlapped the line range, or
            // fully covered it.
            let mut project_covered_result_count = 0;
            let mut project_overlapped_result_count = 0;
            let mut project_covered_file_count = 0;
            let mut covered_result_indices = Vec::new();
            for expected_result in &query.expected_results {
                let mut file_matched = false;
                let mut range_overlapped = false;
                let mut range_covered = false;

                for (ix, result) in results.iter().enumerate() {
                    if result.path.as_ref() == Path::new(&expected_result.file) {
                        file_matched = true;
                        let start_matched =
                            result.row_range.contains(&expected_result.lines.start());
                        let end_matched = result.row_range.contains(&expected_result.lines.end());

                        if start_matched || end_matched {
                            range_overlapped = true;
                        }

                        // Both endpoints inside the result's range counts as
                        // full coverage; stop scanning further results.
                        if start_matched && end_matched {
                            range_covered = true;
                            covered_result_indices.push(ix);
                            break;
                        }
                    }
                }

                if range_covered {
                    project_covered_result_count += 1
                };
                if range_overlapped {
                    project_overlapped_result_count += 1
                };
                if file_matched {
                    project_covered_file_count += 1
                };
            }
            let outcome_repo = evaluation_project.repo.clone();

            let query_results = EvaluationQueryOutcome {
                repo: outcome_repo,
                query: query.query,
                total_result_count: query.expected_results.len(),
                covered_result_count: project_covered_result_count,
                overlapped_result_count: project_overlapped_result_count,
                covered_file_count: project_covered_file_count,
                expected_results: query.expected_results,
                actual_results: results
                    .iter()
                    .map(|result| EvaluationSearchResult {
                        file: result.path.to_string_lossy().to_string(),
                        lines: result.row_range.clone(),
                    })
                    .collect(),
                covered_result_indices,
            };

            overlapped_result_count += query_results.overlapped_result_count;
            covered_result_count += query_results.covered_result_count;
            covered_file_count += query_results.covered_file_count;
            total_result_count += query_results.total_result_count;

            // Machine-readable per-query outcome goes to stdout.
            println!("{}", serde_json::to_string(&query_results).unwrap());
        }
    }

    eprint!(
        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured.",
        covered_result_count,
        total_result_count,
        overlapped_result_count,
        total_result_count,
        covered_file_count,
        total_result_count,
    );

    Ok(())
}
459
/// Wait until `project_index` reports `Status::Idle`, or until `timeout`
/// elapses (when one is provided). On timeout a diagnostic is printed and the
/// function returns normally, so the caller proceeds with a possibly
/// incomplete index.
async fn wait_for_indexing_complete(
    project_index: &Model<ProjectIndex>,
    cx: &mut AsyncAppContext,
    timeout: Option<Duration>,
) {
    // Capacity 1 is enough: only the first Idle event matters.
    let (tx, rx) = bounded(1);
    let subscription = cx.update(|cx| {
        cx.subscribe(project_index, move |_, event, _| {
            if let Status::Idle = event {
                // try_send so the (possibly dropped or full) channel never
                // blocks the event handler.
                let _ = tx.try_send(*event);
            }
        })
    });

    let result = match timeout {
        Some(timeout_duration) => {
            // Race the Idle notification against the timeout timer; whichever
            // future resolves first wins.
            smol::future::or(
                async {
                    rx.recv().await.map_err(|_| ())?;
                    Ok(())
                },
                async {
                    Timer::after(timeout_duration).await;
                    Err(())
                },
            )
            .await
        }
        // No timeout: wait indefinitely for the event (Err only if the
        // sender side is dropped).
        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
    };

    match result {
        Ok(_) => (),
        Err(_) => {
            if let Some(timeout) = timeout {
                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
            }
        }
    }

    // Keep the subscription alive for the whole wait; dropping it here
    // unsubscribes from further status events.
    drop(subscription);
}
502
503async fn fetch_eval_repos(
504    executor: &BackgroundExecutor,
505    http_client: &dyn HttpClient,
506) -> Result<()> {
507    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
508    let evaluations_path = dataset_dir.join("evaluations.json");
509    let repos_dir = Path::new(EVAL_REPOS_DIR);
510
511    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
512    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
513
514    eprint!("Fetching evaluation repositories...");
515
516    executor
517        .scoped(move |scope| {
518            let done_count = Arc::new(AtomicUsize::new(0));
519            let len = evaluations.len();
520            for chunk in evaluations.chunks(evaluations.len() / 8) {
521                let chunk = chunk.to_vec();
522                let done_count = done_count.clone();
523                scope.spawn(async move {
524                    for EvaluationProject { repo, sha, .. } in chunk {
525                        eprint!(
526                            "\rFetching evaluation repositories ({}/{})...",
527                            done_count.load(SeqCst),
528                            len,
529                        );
530
531                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
532                        done_count.fetch_add(1, SeqCst);
533                    }
534                });
535            }
536        })
537        .await;
538
539    Ok(())
540}
541
542async fn fetch_eval_repo(
543    repo: String,
544    sha: String,
545    repos_dir: &Path,
546    http_client: &dyn HttpClient,
547) {
548    let Some((owner, repo_name)) = repo.split_once('/') else {
549        return;
550    };
551    let repo_dir = repos_dir.join(owner).join(repo_name);
552    fs::create_dir_all(&repo_dir).unwrap();
553    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
554    if skip_eval_path.exists() {
555        return;
556    }
557    if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
558        if head_content.trim() == sha {
559            return;
560        }
561    }
562    let repo_response = http_client
563        .send(
564            http_client::Request::builder()
565                .method(Method::HEAD)
566                .uri(format!("https://github.com/{}", repo))
567                .body(Default::default())
568                .expect(""),
569        )
570        .await
571        .expect("failed to check github repo");
572    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
573        fs::write(&skip_eval_path, "").unwrap();
574        eprintln!(
575            "Repo {repo} is no longer public ({:?}). Skipping",
576            repo_response.status()
577        );
578        return;
579    }
580    if !repo_dir.join(".git").exists() {
581        let init_output = Command::new("git")
582            .current_dir(&repo_dir)
583            .args(&["init"])
584            .output()
585            .unwrap();
586        if !init_output.status.success() {
587            eprintln!(
588                "Failed to initialize git repository for {}: {}",
589                repo,
590                String::from_utf8_lossy(&init_output.stderr)
591            );
592            return;
593        }
594    }
595    let url = format!("https://github.com/{}.git", repo);
596    Command::new("git")
597        .current_dir(&repo_dir)
598        .args(&["remote", "add", "-f", "origin", &url])
599        .stdin(Stdio::null())
600        .output()
601        .unwrap();
602    let fetch_output = Command::new("git")
603        .current_dir(&repo_dir)
604        .args(&["fetch", "--depth", "1", "origin", &sha])
605        .stdin(Stdio::null())
606        .output()
607        .unwrap();
608    if !fetch_output.status.success() {
609        eprintln!(
610            "Failed to fetch {} for {}: {}",
611            sha,
612            repo,
613            String::from_utf8_lossy(&fetch_output.stderr)
614        );
615        return;
616    }
617    let checkout_output = Command::new("git")
618        .current_dir(&repo_dir)
619        .args(&["checkout", &sha])
620        .output()
621        .unwrap();
622
623    if !checkout_output.status.success() {
624        eprintln!(
625            "Failed to checkout {} for {}: {}",
626            sha,
627            repo,
628            String::from_utf8_lossy(&checkout_output.stderr)
629        );
630    }
631}