// file_context.rs

  1use anyhow::anyhow;
  2use language::BufferSnapshot;
  3use language::ToOffset;
  4
  5use crate::models::LanguageModel;
  6use crate::models::TruncationDirection;
  7use crate::prompts::base::PromptArguments;
  8use crate::prompts::base::PromptTemplate;
  9use std::fmt::Write;
 10use std::ops::Range;
 11use std::sync::Arc;
 12
 13fn retrieve_context(
 14    buffer: &BufferSnapshot,
 15    selected_range: &Option<Range<usize>>,
 16    model: Arc<dyn LanguageModel>,
 17    max_token_count: Option<usize>,
 18) -> anyhow::Result<(String, usize, bool)> {
 19    let mut prompt = String::new();
 20    let mut truncated = false;
 21    if let Some(selected_range) = selected_range {
 22        let start = selected_range.start.to_offset(buffer);
 23        let end = selected_range.end.to_offset(buffer);
 24
 25        let start_window = buffer.text_for_range(0..start).collect::<String>();
 26
 27        let mut selected_window = String::new();
 28        if start == end {
 29            write!(selected_window, "<|START|>").unwrap();
 30        } else {
 31            write!(selected_window, "<|START|").unwrap();
 32        }
 33
 34        write!(
 35            selected_window,
 36            "{}",
 37            buffer.text_for_range(start..end).collect::<String>()
 38        )
 39        .unwrap();
 40
 41        if start != end {
 42            write!(selected_window, "|END|>").unwrap();
 43        }
 44
 45        let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
 46
 47        if let Some(max_token_count) = max_token_count {
 48            let selected_tokens = model.count_tokens(&selected_window)?;
 49            if selected_tokens > max_token_count {
 50                return Err(anyhow!(
 51                    "selected range is greater than model context window, truncation not possible"
 52                ));
 53            };
 54
 55            let mut remaining_tokens = max_token_count - selected_tokens;
 56            let start_window_tokens = model.count_tokens(&start_window)?;
 57            let end_window_tokens = model.count_tokens(&end_window)?;
 58            let outside_tokens = start_window_tokens + end_window_tokens;
 59            if outside_tokens > remaining_tokens {
 60                let (start_goal_tokens, end_goal_tokens) =
 61                    if start_window_tokens < end_window_tokens {
 62                        let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
 63                        remaining_tokens -= start_goal_tokens;
 64                        let end_goal_tokens = remaining_tokens.min(end_window_tokens);
 65                        (start_goal_tokens, end_goal_tokens)
 66                    } else {
 67                        let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
 68                        remaining_tokens -= end_goal_tokens;
 69                        let start_goal_tokens = remaining_tokens.min(start_window_tokens);
 70                        (start_goal_tokens, end_goal_tokens)
 71                    };
 72
 73                let truncated_start_window =
 74                    model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
 75                let truncated_end_window =
 76                    model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
 77                writeln!(
 78                    prompt,
 79                    "{truncated_start_window}{selected_window}{truncated_end_window}"
 80                )
 81                .unwrap();
 82                truncated = true;
 83            } else {
 84                writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
 85            }
 86        } else {
 87            // If we dont have a selected range, include entire file.
 88            writeln!(prompt, "{}", &buffer.text()).unwrap();
 89
 90            // Dumb truncation strategy
 91            if let Some(max_token_count) = max_token_count {
 92                if model.count_tokens(&prompt)? > max_token_count {
 93                    truncated = true;
 94                    prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
 95                }
 96            }
 97        }
 98    }
 99
100    let token_count = model.count_tokens(&prompt)?;
101    anyhow::Ok((prompt, token_count, truncated))
102}
103
/// Prompt template that renders the active buffer's content — with the user's
/// selection or cursor position marked — into the assistant prompt.
pub struct FileContext {}
105
106impl PromptTemplate for FileContext {
107    fn generate(
108        &self,
109        args: &PromptArguments,
110        max_token_length: Option<usize>,
111    ) -> anyhow::Result<(String, usize)> {
112        if let Some(buffer) = &args.buffer {
113            let mut prompt = String::new();
114            // Add Initial Preamble
115            // TODO: Do we want to add the path in here?
116            writeln!(
117                prompt,
118                "The file you are currently working on has the following content:"
119            )
120            .unwrap();
121
122            let language_name = args
123                .language_name
124                .clone()
125                .unwrap_or("".to_string())
126                .to_lowercase();
127
128            let (context, _, truncated) = retrieve_context(
129                buffer,
130                &args.selected_range,
131                args.model.clone(),
132                max_token_length,
133            )?;
134            writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
135
136            if truncated {
137                writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
138            }
139
140            if let Some(selected_range) = &args.selected_range {
141                let start = selected_range.start.to_offset(buffer);
142                let end = selected_range.end.to_offset(buffer);
143
144                if start == end {
145                    writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
146                } else {
147                    writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
148                }
149            }
150
151            // Really dumb truncation strategy
152            if let Some(max_tokens) = max_token_length {
153                prompt = args
154                    .model
155                    .truncate(&prompt, max_tokens, TruncationDirection::End)?;
156            }
157
158            let token_count = args.model.count_tokens(&prompt)?;
159            anyhow::Ok((prompt, token_count))
160        } else {
161            Err(anyhow!("no buffer provided to retrieve file context from"))
162        }
163    }
164}