repository_context.rs

 1use crate::prompts::base::{PromptArguments, PromptTemplate};
 2use std::fmt::Write;
 3use std::{ops::Range, path::PathBuf};
 4
 5use gpui::{AsyncAppContext, Model};
 6use language::{Anchor, Buffer};
 7
 /// A piece of source text captured from a buffer, ready to be rendered into a prompt.
#[derive(Clone, Debug)]
pub struct PromptCodeSnippet {
    /// Path of the file the text came from, if the buffer is backed by a file.
    path: Option<PathBuf>,
    /// Lowercased name of the buffer's language, if the buffer has one.
    language_name: Option<String>,
    /// The captured text.
    content: String,
}
14
15impl PromptCodeSnippet {
16    pub fn new(
17        buffer: Model<Buffer>,
18        range: Range<Anchor>,
19        cx: &mut AsyncAppContext,
20    ) -> anyhow::Result<Self> {
21        let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
22            let snapshot = buffer.snapshot();
23            let content = snapshot.text_for_range(range.clone()).collect::<String>();
24
25            let language_name = buffer
26                .language()
27                .and_then(|language| Some(language.name().to_string().to_lowercase()));
28
29            let file_path = buffer
30                .file()
31                .and_then(|file| Some(file.path().to_path_buf()));
32
33            (content, language_name, file_path)
34        })?;
35
36        anyhow::Ok(PromptCodeSnippet {
37            path: file_path,
38            language_name,
39            content,
40        })
41    }
42}
43
44impl ToString for PromptCodeSnippet {
45    fn to_string(&self) -> String {
46        let path = self
47            .path
48            .as_ref()
49            .and_then(|path| Some(path.to_string_lossy().to_string()))
50            .unwrap_or("".to_string());
51        let language_name = self.language_name.clone().unwrap_or("".to_string());
52        let content = self.content.clone();
53
54        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
55    }
56}
57
/// Prompt template that lists repository code snippets which may be relevant to the request.
///
/// Carries no state; all inputs come from the `PromptArguments` passed to `generate`.
pub struct RepositoryContext;
59
60impl PromptTemplate for RepositoryContext {
61    fn generate(
62        &self,
63        args: &PromptArguments,
64        max_token_length: Option<usize>,
65    ) -> anyhow::Result<(String, usize)> {
66        const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
67        let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
68        let mut prompt = String::new();
69
70        let mut remaining_tokens = max_token_length.clone();
71        let separator_token_length = args.model.count_tokens("\n")?;
72        for snippet in &args.snippets {
73            let mut snippet_prompt = template.to_string();
74            let content = snippet.to_string();
75            writeln!(snippet_prompt, "{content}").unwrap();
76
77            let token_count = args.model.count_tokens(&snippet_prompt)?;
78            if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
79                if let Some(tokens_left) = remaining_tokens {
80                    if tokens_left >= token_count {
81                        writeln!(prompt, "{snippet_prompt}").unwrap();
82                        remaining_tokens = if tokens_left >= (token_count + separator_token_length)
83                        {
84                            Some(tokens_left - token_count - separator_token_length)
85                        } else {
86                            Some(0)
87                        };
88                    }
89                } else {
90                    writeln!(prompt, "{snippet_prompt}").unwrap();
91                }
92            }
93        }
94
95        let total_token_count = args.model.count_tokens(&prompt)?;
96        anyhow::Ok((prompt, total_token_count))
97    }
98}