1use std::sync::Arc;
2
3use crate::{dot, embedding::EmbeddingProvider, VectorStore};
4use anyhow::Result;
5use async_trait::async_trait;
6use gpui::{Task, TestAppContext};
7use language::{Language, LanguageConfig, LanguageRegistry};
8use project::{FakeFs, Project};
9use rand::Rng;
10use serde_json::json;
11use unindent::Unindent;
12
13#[gpui::test]
14async fn test_vector_store(cx: &mut TestAppContext) {
15 let fs = FakeFs::new(cx.background());
16 fs.insert_tree(
17 "/the-root",
18 json!({
19 "src": {
20 "file1.rs": "
21 fn aaa() {
22 println!(\"aaaa!\");
23 }
24
25 fn zzzzzzzzz() {
26 println!(\"SLEEPING\");
27 }
28 ".unindent(),
29 "file2.rs": "
30 fn bbb() {
31 println!(\"bbbb!\");
32 }
33 ".unindent(),
34 }
35 }),
36 )
37 .await;
38
39 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
40 let rust_language = Arc::new(
41 Language::new(
42 LanguageConfig {
43 name: "Rust".into(),
44 path_suffixes: vec!["rs".into()],
45 ..Default::default()
46 },
47 Some(tree_sitter_rust::language()),
48 )
49 .with_outline_query(
50 r#"
51 (function_item
52 name: (identifier) @name
53 body: (block)) @item
54 "#,
55 )
56 .unwrap(),
57 );
58 languages.add(rust_language);
59
60 let store = cx.add_model(|_| {
61 VectorStore::new(
62 fs.clone(),
63 "foo".to_string(),
64 Arc::new(FakeEmbeddingProvider),
65 languages,
66 )
67 });
68
69 let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
70 store
71 .update(cx, |store, cx| store.add_project(project, cx))
72 .await
73 .unwrap();
74
75 let search_results = store
76 .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
77 .await
78 .unwrap();
79
80 assert_eq!(search_results[0].offset, 0);
81 assert_eq!(search_results[1].name, "aaa");
82}
83
/// Checks `dot` on two hand-computed inputs and then against a plain
/// multiply-and-sum reference on 100 pairs of random 32-element vectors.
#[test]
fn test_dot_product() {
    // Largest absolute difference tolerated between `dot` and the reference.
    // Comparing with a tolerance (rather than comparing rounded values) avoids
    // spurious failures when two nearly identical results fall on opposite
    // sides of a rounding boundary.
    const TOLERANCE: f32 = 1e-3;

    // Straightforward dot product used as the expected value.
    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);

    // One generator is enough for the whole loop.
    let mut rng = rand::thread_rng();
    for _ in 0..100 {
        let a: [f32; 32] = rng.gen();
        let b: [f32; 32] = rng.gen();

        let actual = dot(&a, &b);
        let expected = reference_dot(&a, &b);
        assert!(
            (actual - expected).abs() <= TOLERANCE,
            "dot returned {} but the reference implementation returned {}",
            actual,
            expected
        );
    }
}
108
109struct FakeEmbeddingProvider;
110
111#[async_trait]
112impl EmbeddingProvider for FakeEmbeddingProvider {
113 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
114 Ok(spans
115 .iter()
116 .map(|span| {
117 let mut result = vec![0.0; 26];
118 for letter in span.chars() {
119 if letter as u32 > 'a' as u32 {
120 let ix = (letter as u32) - ('a' as u32);
121 if ix < 26 {
122 result[ix as usize] += 1.0;
123 }
124 }
125 }
126
127 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
128 for x in &mut result {
129 *x /= norm;
130 }
131
132 result
133 })
134 .collect())
135 }
136}