@@ -81,7 +81,11 @@ impl CodeContextRetriever {
if let Some((item, byte_range)) = item.zip(byte_range) {
if !name.is_empty() {
- let item = format!("{}\n{}", context_spans.join("\n"), item);
+ let item = if context_spans.is_empty() {
+ item.to_string()
+ } else {
+ format!("{}\n{}", context_spans.join("\n"), item)
+ };
let document_text = CODE_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_str().unwrap())
@@ -1,5 +1,9 @@
use crate::{
- db::dot, embedding::EmbeddingProvider, vector_store_settings::VectorStoreSettings, VectorStore,
+ db::dot,
+ embedding::EmbeddingProvider,
+ parsing::{CodeContextRetriever, Document},
+ vector_store_settings::VectorStoreSettings,
+ VectorStore,
};
use anyhow::Result;
use async_trait::async_trait;
@@ -9,7 +13,7 @@ use project::{project_settings::ProjectSettings, FakeFs, Project};
use rand::{rngs::StdRng, Rng};
use serde_json::json;
use settings::SettingsStore;
-use std::sync::Arc;
+use std::{path::Path, sync::Arc};
use unindent::Unindent;
#[ctor::ctor]
@@ -52,24 +56,7 @@ async fn test_vector_store(cx: &mut TestAppContext) {
.await;
let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
- let rust_language = Arc::new(
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- path_suffixes: vec!["rs".into()],
- ..Default::default()
- },
- Some(tree_sitter_rust::language()),
- )
- .with_embedding_query(
- r#"
- (function_item
- name: (identifier) @name
- body: (block)) @item
- "#,
- )
- .unwrap(),
- );
+ let rust_language = rust_lang();
languages.add(rust_language);
let db_dir = tempdir::TempDir::new("vector-store").unwrap();
@@ -109,14 +96,59 @@ async fn test_vector_store(cx: &mut TestAppContext) {
#[gpui::test]
async fn test_code_context_retrieval(cx: &mut TestAppContext) {
- // let mut retriever = CodeContextRetriever::new(fs);
-
- // retriever::parse_file(
- // "
- // //
- // ",
- // );
- //
+ let language = rust_lang(); // Rust language fixture carrying the embedding query (see rust_lang)
+ let mut retriever = CodeContextRetriever::new(); // retriever under test
+
+ let text = "
+ /// A doc comment
+ /// that spans multiple lines
+ fn a() {
+ b
+ }
+
+ impl C for D {
+ }
+ "
+ .unindent(); // strip the shared leading indentation of the literal above
+
+ let parsed_files = retriever // despite the name, this holds the Documents parsed from one file
+ .parse_file(Path::new("foo.rs"), &text, language)
+ .unwrap(); // parsing this fixture must succeed, otherwise the test panics here
+
+ assert_eq!(
+ parsed_files,
+ &[
+ Document {
+ name: "a".into(),
+ range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1), // byte range from fn a through the first closing brace; the doc comment is outside the range but present in content
+ content: "
+ The below code snippet is from file 'foo.rs'
+
+ ```rust
+ /// A doc comment
+ /// that spans multiple lines
+ fn a() {
+ b
+ }
+ ```"
+ .unindent(),
+ embedding: vec![], // the expected document carries an empty embedding
+ },
+ Document {
+ name: "C for D".into(), // NOTE(review): presumably built from the trait, for and type @name captures of the impl pattern; confirm how they are joined
+ range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1), // rfind: the impl block ends at the last closing brace of text
+ content: "
+ The below code snippet is from file 'foo.rs'
+
+ ```rust
+ impl C for D {
+ }
+ ```"
+ .unindent(),
+ embedding: vec![], // the expected document carries an empty embedding
+ }
+ ]
+ );
}
#[gpui::test]
@@ -178,3 +210,71 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
.collect())
}
}
+
+fn rust_lang() -> Arc<Language> { // Test fixture shared by the tests in this file: a Rust Language with an embedding query attached
+ Arc::new(
+ Language::new(
+ LanguageConfig {
+ name: "Rust".into(),
+ path_suffixes: vec!["rs".into()],
+ ..Default::default() // every other config field keeps its default
+ },
+ Some(tree_sitter_rust::language()), // tree-sitter grammar for Rust
+ )
+ .with_embedding_query( // each pattern below: preceding line comments are @context, the item node is @item, its name is @name
+ r#"
+ (
+ (line_comment)* @context
+ .
+ (enum_item
+ name: (_) @name) @item
+ )
+
+ (
+ (line_comment)* @context
+ .
+ (struct_item
+ name: (_) @name) @item
+ )
+
+ (
+ (line_comment)* @context
+ .
+ (impl_item
+ trait: (_)? @name
+ "for"? @name
+ type: (_) @name) @item
+ )
+
+ (
+ (line_comment)* @context
+ .
+ (trait_item
+ name: (_) @name) @item
+ )
+
+ (
+ (line_comment)* @context
+ .
+ (function_item
+ name: (_) @name) @item
+ )
+
+ (
+ (line_comment)* @context
+ .
+ (macro_definition
+ name: (_) @name) @item
+ )
+
+ (
+ (line_comment)* @context
+ .
+ (function_signature_item
+ name: (_) @name) @item
+ )
+ "#, // patterns cover enum, struct, impl, trait, function, macro definition and function signature items
+ ) // note: the impl pattern captures the optional trait, the optional for keyword and the type all as @name
+ .unwrap(), // panics if with_embedding_query does not return a success value
+ )
+}