eval.rs

  1use anyhow::{anyhow, Result};
  2use client::{self, UserStore};
  3use collections::HashMap;
  4use git2::{Object, Oid, Repository};
  5use gpui::{AppContext, AssetSource, AsyncAppContext, ModelHandle, Task};
  6use language::LanguageRegistry;
  7use node_runtime::RealNodeRuntime;
  8use project::{Project, RealFs};
  9use rust_embed::RustEmbed;
 10use semantic_index::embedding::OpenAIEmbeddings;
 11use semantic_index::semantic_index_settings::SemanticIndexSettings;
 12use semantic_index::{SearchResult, SemanticIndex};
 13use serde::Deserialize;
 14use settings::{default_settings, handle_settings_file_changes, watch_config_file, SettingsStore};
 15use std::path::{Path, PathBuf};
 16use std::sync::Arc;
 17use std::time::{Duration, Instant};
 18use std::{cmp, env, fs};
 19use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 20use util::http::{self};
 21use util::paths::{self, EMBEDDINGS_DIR};
 22use zed::languages;
 23
/// Static assets embedded into the binary through the `RustEmbed` derive.
///
/// The embedded files come from the `../../assets` folder and are restricted to
/// the `include` patterns listed below (fonts, icons, themes, sounds and
/// markdown files); `.DS_Store` files are excluded.
#[derive(RustEmbed)]
#[folder = "../../assets"]
#[include = "fonts/**/*"]
#[include = "icons/**/*"]
#[include = "themes/**/*"]
#[include = "sounds/**/*"]
#[include = "*.md"]
#[exclude = "*.DS_Store"]
pub struct Assets;
 33
 34impl AssetSource for Assets {
 35    fn load(&self, path: &str) -> Result<std::borrow::Cow<[u8]>> {
 36        Self::get(path)
 37            .map(|f| f.data)
 38            .ok_or_else(|| anyhow!("could not find asset at path \"{}\"", path))
 39    }
 40
 41    fn list(&self, path: &str) -> Vec<std::borrow::Cow<'static, str>> {
 42        Self::iter().filter(|p| p.starts_with(path)).collect()
 43    }
 44}
 45
/// A single search query together with the locations that are expected to match it.
#[derive(Deserialize, Clone)]
struct EvaluationQuery {
    /// Search phrase handed to the semantic index.
    query: String,
    /// Expected matches; `match_pairs` parses each entry as `path:row`.
    matches: Vec<String>,
}
 51
 52impl EvaluationQuery {
 53    fn match_pairs(&self) -> Vec<(PathBuf, u32)> {
 54        let mut pairs = Vec::new();
 55        for match_identifier in self.matches.iter() {
 56            let mut match_parts = match_identifier.split(":");
 57
 58            if let Some(file_path) = match_parts.next() {
 59                if let Some(row_number) = match_parts.next() {
 60                    pairs.push((PathBuf::from(file_path), row_number.parse::<u32>().unwrap()));
 61                }
 62            }
 63        }
 64        pairs
 65    }
 66}
 67
/// One evaluation case: a repository pinned to a commit plus the queries to run against it.
#[derive(Deserialize, Clone)]
struct RepoEval {
    /// Repository URL handed to git when cloning.
    repo: String,
    /// Commit id (parsed with `Oid::from_str`) that is checked out after cloning.
    commit: String,
    /// Queries, with their expected matches, evaluated against the indexed repository.
    assertions: Vec<EvaluationQuery>,
}
 74
/// Container for the metrics of an evaluation run.
///
/// NOTE(review): nothing in the visible code constructs or reads this struct yet;
/// the field notes below are inferred from the names only — confirm before relying on them.
struct EvaluationResults {
    // presumably the number of tokens embedded while indexing — TODO confirm
    token_count: usize,
    // presumably the number of spans produced while indexing — TODO confirm
    span_count: usize,
    // presumably the wall-clock time of the indexing step — TODO confirm
    time_to_index: Duration,
    // presumably one search duration per query — TODO confirm
    time_to_search: Vec<Duration>,
    // presumably NDCG keyed by cut-off k — TODO confirm
    ndcg: HashMap<usize, f32>,
    // presumably MAP keyed by cut-off k — TODO confirm
    map: HashMap<usize, f32>,
}
 83
/// Name of the directory that evaluation repositories are cloned into; `clone_repo`
/// joins it onto the grandparent of the current working directory.
const TMP_REPO_PATH: &str = "eval_repos";
 85
 86fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
 87    let eval_folder = env::current_dir()?
 88        .as_path()
 89        .parent()
 90        .unwrap()
 91        .join("crates/semantic_index/eval");
 92
 93    let mut repo_evals: Vec<RepoEval> = Vec::new();
 94    for entry in fs::read_dir(eval_folder)? {
 95        let file_path = entry.unwrap().path();
 96        if let Some(extension) = file_path.extension() {
 97            if extension == "json" {
 98                if let Ok(file) = fs::read_to_string(file_path) {
 99                    let repo_eval = serde_json::from_str(file.as_str());
100
101                    match repo_eval {
102                        Ok(repo_eval) => {
103                            repo_evals.push(repo_eval);
104                        }
105                        Err(err) => {
106                            println!("Err: {:?}", err);
107                        }
108                    }
109                }
110            }
111        }
112    }
113
114    Ok(repo_evals)
115}
116
117fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<PathBuf> {
118    let repo_name = Path::new(repo_eval.repo.as_str())
119        .file_name()
120        .unwrap()
121        .to_str()
122        .unwrap()
123        .to_owned()
124        .replace(".git", "");
125
126    let clone_path = fs::canonicalize(env::current_dir()?)?
127        .parent()
128        .ok_or(anyhow!("path canonicalization failed"))?
129        .parent()
130        .unwrap()
131        .join(TMP_REPO_PATH)
132        .join(&repo_name);
133
134    // Delete Clone Path if already exists
135    let _ = fs::remove_dir_all(&clone_path);
136
137    // Clone in Repo
138    git2::build::RepoBuilder::new()
139        // .branch(repo_eval.sha.as_str())
140        .clone(repo_eval.repo.as_str(), clone_path.as_path())?;
141
142    let repo: Repository = Repository::open(clone_path.clone())?;
143    let obj: Object = repo
144        .find_commit(Oid::from_str(repo_eval.commit.as_str())?)?
145        .into_object();
146    repo.checkout_tree(&obj, None)?;
147    repo.set_head_detached(obj.id())?;
148
149    Ok(clone_path)
150}
151
/// Discounted cumulative gain of `hits`.
///
/// The gain at 0-based position `i` is divided by `log2(i + 2)`, so the first
/// result counts in full and later results count progressively less. The
/// discounted gains are then summed in order.
fn dcg(hits: Vec<usize>) -> f32 {
    hits.iter().enumerate().fold(0.0, |total, (rank, gain)| {
        total + *gain as f32 / (2.0 + rank as f32).log2()
    })
}
160
161fn get_hits(
162    eval_query: EvaluationQuery,
163    search_results: Vec<SearchResult>,
164    k: usize,
165    cx: &AsyncAppContext,
166) -> (Vec<usize>, Vec<usize>) {
167    let ideal = vec![1; cmp::min(eval_query.matches.len(), k)];
168
169    let mut hits = Vec::new();
170    for result in search_results {
171        let (path, start_row, end_row) = result.buffer.read_with(cx, |buffer, cx| {
172            let path = buffer.file().unwrap().path().to_path_buf();
173            let start_row = buffer.offset_to_point(result.range.start.offset).row;
174            let end_row = buffer.offset_to_point(result.range.end.offset).row;
175            (path, start_row, end_row)
176        });
177
178        let match_pairs = eval_query.match_pairs();
179        let mut found = 0;
180        for (match_path, match_row) in match_pairs {
181            if match_path == path {
182                if match_row >= start_row && match_row <= end_row {
183                    found = 1;
184                    break;
185                }
186            }
187        }
188
189        hits.push(found);
190    }
191
192    // For now, we are calculating ideal_hits a bit different, as technically
193    // with overlapping ranges, one match can result in more than result.
194    let mut ideal_hits = hits.clone();
195    ideal_hits.retain(|x| x == &1);
196
197    let ideal = if ideal.len() > ideal_hits.len() {
198        ideal
199    } else {
200        ideal_hits
201    };
202
203    // Fill ideal to 10 length
204    let mut filled_ideal = [0; 10];
205    for (idx, i) in ideal.to_vec().into_iter().enumerate() {
206        filled_ideal[idx] = i;
207    }
208
209    (filled_ideal.to_vec(), hits)
210}
211
212fn evaluate_ndcg(hits: Vec<usize>, ideal: Vec<usize>) -> Vec<f32> {
213    // NDCG or Normalized Discounted Cumulative Gain, is determined by comparing the relevance of
214    // items returned by the search engine relative to the hypothetical ideal.
215    // Relevance is represented as a series of booleans, in which each search result returned
216    // is identified as being inside the test set of matches (1) or not (0).
217
218    // For example, if result 1, 3 and 5 match the 3 relevant results provided
219    // actual dcg is calculated against a vector of [1, 0, 1, 0, 1]
220    // whereas ideal dcg is calculated against a vector of [1, 1, 1, 0, 0]
221    // as this ideal vector assumes the 3 relevant results provided were returned first
222    // normalized dcg is then calculated as actual dcg / ideal dcg.
223
224    // NDCG ranges from 0 to 1, which higher values indicating better performance
225    // Commonly NDCG is expressed as NDCG@k, in which k represents the metric calculated
226    // including only the top k values returned.
227    // The @k metrics can help you identify, at what point does the relevant results start to fall off.
228    // Ie. a NDCG@1 of 0.9 and a NDCG@3 of 0.5 may indicate that the first result returned in usually
229    // very high quality, whereas rank results quickly drop off after the first result.
230
231    let mut ndcg = Vec::new();
232    for idx in 1..(hits.len() + 1) {
233        let hits_at_k = hits[0..idx].to_vec();
234        let ideal_at_k = ideal[0..idx].to_vec();
235
236        let at_k = dcg(hits_at_k.clone()) / dcg(ideal_at_k.clone());
237
238        ndcg.push(at_k);
239    }
240
241    ndcg
242}
243
/// Computes average precision at every cut-off `k` in `1..=hits.len()`.
///
/// `hits` is the relevance (1 or 0) of each returned result in rank order.
/// Entry `k - 1` of the returned vector is the sum of precision@i over the
/// *relevant* ranks `i <= k`, divided by the total number of hits in the list.
/// Precision is accumulated only at relevant ranks; adding it at every rank
/// lets the score grow past 1.
///
/// Returns all zeros when there are no hits, and an empty vector for empty input.
fn evaluate_map(hits: Vec<usize>) -> Vec<f32> {
    let total_hits = hits.iter().sum::<usize>() as f32;
    if total_hits == 0.0 {
        return vec![0.0; hits.len()];
    }

    let mut map_at_k = Vec::with_capacity(hits.len());
    let mut hits_so_far = 0.0;
    let mut precision_sum = 0.0;
    for (idx, hit) in hits.into_iter().enumerate() {
        if hit != 0 {
            hits_so_far += hit as f32;
            // precision@(idx + 1), counted only because this rank is relevant
            precision_sum += hits_so_far / (idx + 1) as f32;
        }
        map_at_k.push(precision_sum / total_hits);
    }

    map_at_k
}
262
/// Initializes logging by calling `env_logger::init()`.
fn init_logger() {
    env_logger::init();
}
266
267async fn evaluate_repo(
268    index: ModelHandle<SemanticIndex>,
269    project: ModelHandle<Project>,
270    query_matches: Vec<EvaluationQuery>,
271    cx: &mut AsyncAppContext,
272) -> Result<()> {
273    // Index Project
274    let index_t0 = Instant::now();
275    index
276        .update(cx, |index, cx| index.index_project(project.clone(), cx))
277        .await?;
278    let index_time = index_t0.elapsed();
279    println!("Time to Index: {:?}", index_time.as_millis());
280
281    for query in query_matches {
282        // Query each match in order
283        let search_t0 = Instant::now();
284        let search_results = index
285            .update(cx, |index, cx| {
286                index.search_project(project.clone(), query.clone().query, 10, vec![], vec![], cx)
287            })
288            .await?;
289        let search_time = search_t0.elapsed();
290        println!("Time to Search: {:?}", search_time.as_millis());
291
292        // Get Hits/Ideal
293        let k = 10;
294        let (ideal, hits) = self::get_hits(query, search_results, k, cx);
295
296        // Evaluate ndcg@k, for k = 1, 3, 5, 10
297        let ndcg = evaluate_ndcg(hits.clone(), ideal);
298        println!("NDCG: {:?}", ndcg);
299
300        // Evaluate map@k, for k = 1, 3, 5, 10
301        let map = evaluate_map(hits);
302        println!("MAP: {:?}", map);
303
304        // Evaluate span count
305        // Evaluate token count
306    }
307
308    anyhow::Ok(())
309}
310
/// Entry point of the evaluation binary.
///
/// Boots a gpui app, initializes settings, languages, the project layer and the
/// semantic index, then spawns a task that — for every evaluation definition
/// returned by `parse_eval` — clones the repository, opens it as a local
/// project and runs `evaluate_repo` on it.
fn main() {
    // Launch new repo as a new Zed workspace/project
    let app = gpui::App::new(Assets).unwrap();
    let fs = Arc::new(RealFs);
    // Two clients are created via `http::client()`: `http` is cloned into the
    // client, node runtime and semantic index init below; `http_client` is
    // cloned into the user store and later moved into the embeddings provider.
    let http = http::client();
    let user_settings_file_rx =
        watch_config_file(app.background(), fs.clone(), paths::SETTINGS.clone());
    let http_client = http::client();
    init_logger();

    app.run(move |cx| {
        cx.set_global(*RELEASE_CHANNEL);

        let client = client::Client::new(http.clone(), cx);
        let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client.clone(), cx));

        // Initialize Settings
        let mut store = SettingsStore::default();
        store
            .set_default_settings(default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);
        handle_settings_file_changes(user_settings_file_rx, cx);

        // Initialize Languages
        let login_shell_env_loaded = Task::ready(());
        let mut languages = LanguageRegistry::new(login_shell_env_loaded);
        languages.set_executor(cx.background().clone());
        let languages = Arc::new(languages);

        let node_runtime = RealNodeRuntime::new(http.clone());
        languages::init(languages.clone(), node_runtime.clone());
        language::init(cx);

        project::Project::init(&client, cx);
        semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx);

        settings::register::<SemanticIndexSettings>(cx);

        // Embeddings database location: EMBEDDINGS_DIR/<release channel name>/embeddings_db
        let db_file_path = EMBEDDINGS_DIR
            .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
            .join("embeddings_db");

        // Fresh handles that the async task below takes ownership of.
        let languages = languages.clone();
        let fs = fs.clone();
        cx.spawn(|mut cx| async move {
            let semantic_index = SemanticIndex::new(
                fs.clone(),
                db_file_path,
                Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
                languages.clone(),
                cx.clone(),
            )
            .await?;

            // NOTE(review): an `Err` from `parse_eval` is ignored here, so a missing
            // or unreadable eval folder results in no evaluation and no message.
            if let Ok(repo_evals) = parse_eval() {
                for repo in repo_evals {
                    let cloned = clone_repo(repo.clone());
                    match cloned {
                        Ok(clone_path) => {
                            log::trace!(
                                "Cloned {:?} @ {:?} into {:?}",
                                repo.repo,
                                repo.commit,
                                &clone_path
                            );

                            // Create Project
                            let project = cx.update(|cx| {
                                Project::local(
                                    client.clone(),
                                    user_store.clone(),
                                    languages.clone(),
                                    fs.clone(),
                                    cx,
                                )
                            });

                            // Register Worktree
                            // The awaited result is discarded, so a failure to create
                            // the worktree is not detected at this point.
                            let _ = project
                                .update(&mut cx, |project, cx| {
                                    project.find_or_create_local_worktree(clone_path, true, cx)
                                })
                                .await;

                            // An error from this evaluation ends the task early via `?`,
                            // skipping the remaining repositories.
                            evaluate_repo(
                                semantic_index.clone(),
                                project,
                                repo.assertions,
                                &mut cx,
                            )
                            .await?;
                        }
                        Err(err) => {
                            // Clone failures are only logged at trace level; the loop
                            // continues with the next repository.
                            log::trace!("Error cloning: {:?}", err);
                        }
                    }
                }
            }

            anyhow::Ok(())
        })
        // NOTE(review): the task is detached, so an `Err` returned from it does not
        // appear to be reported anywhere — confirm.
        .detach();
    });
}