repository_context.rs

 1use crate::templates::base::{PromptArguments, PromptTemplate};
 2use std::fmt::Write;
 3use std::{ops::Range, path::PathBuf};
 4
 5use gpui::{AsyncAppContext, ModelHandle};
 6use language::{Anchor, Buffer};
 7
 8#[derive(Clone)]
 9pub struct PromptCodeSnippet {
10    path: Option<PathBuf>,
11    language_name: Option<String>,
12    content: String,
13}
14
15impl PromptCodeSnippet {
16    pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
17        let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
18            let snapshot = buffer.snapshot();
19            let content = snapshot.text_for_range(range.clone()).collect::<String>();
20
21            let language_name = buffer
22                .language()
23                .and_then(|language| Some(language.name().to_string().to_lowercase()));
24
25            let file_path = buffer
26                .file()
27                .and_then(|file| Some(file.path().to_path_buf()));
28
29            (content, language_name, file_path)
30        });
31
32        PromptCodeSnippet {
33            path: file_path,
34            language_name,
35            content,
36        }
37    }
38}
39
40impl ToString for PromptCodeSnippet {
41    fn to_string(&self) -> String {
42        let path = self
43            .path
44            .as_ref()
45            .and_then(|path| Some(path.to_string_lossy().to_string()))
46            .unwrap_or("".to_string());
47        let language_name = self.language_name.clone().unwrap_or("".to_string());
48        let content = self.content.clone();
49
50        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
51    }
52}
53
/// Prompt template that lists repository code snippets (from
/// `PromptArguments::snippets`) which may be relevant to the request.
///
/// The template carries no state of its own.
#[derive(Clone, Copy, Debug, Default)]
pub struct RepositoryContext {}
55
56impl PromptTemplate for RepositoryContext {
57    fn generate(
58        &self,
59        args: &PromptArguments,
60        max_token_length: Option<usize>,
61    ) -> anyhow::Result<(String, usize)> {
62        const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
63        let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
64        let mut prompt = String::new();
65
66        let mut remaining_tokens = max_token_length.clone();
67        let seperator_token_length = args.model.count_tokens("\n")?;
68        for snippet in &args.snippets {
69            let mut snippet_prompt = template.to_string();
70            let content = snippet.to_string();
71            writeln!(snippet_prompt, "{content}").unwrap();
72
73            let token_count = args.model.count_tokens(&snippet_prompt)?;
74            if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
75                if let Some(tokens_left) = remaining_tokens {
76                    if tokens_left >= token_count {
77                        writeln!(prompt, "{snippet_prompt}").unwrap();
78                        remaining_tokens = if tokens_left >= (token_count + seperator_token_length)
79                        {
80                            Some(tokens_left - token_count - seperator_token_length)
81                        } else {
82                            Some(0)
83                        };
84                    }
85                } else {
86                    writeln!(prompt, "{snippet_prompt}").unwrap();
87                }
88            }
89        }
90
91        let total_token_count = args.model.count_tokens(&prompt)?;
92        anyhow::Ok((prompt, total_token_count))
93    }
94}