file_context.rs

use anyhow::anyhow;
use language::BufferSnapshot;
use language::ToOffset;

use crate::models::LanguageModel;
use crate::templates::base::PromptArguments;
use crate::templates::base::PromptTemplate;
use std::fmt::Write;
use std::ops::Range;
use std::sync::Arc;

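/// Builds the file portion of the prompt. When a selection is present, it is
/// wrapped in `<|START|`/`|END|>` markers (a bare cursor gets the single
/// `<|START|>` token) and any remaining token budget is split between the
/// text before and after it. Returns the rendered text, its token count, and
/// whether truncation occurred.
///
/// For example, for a buffer containing `let x = 42;` with `42` selected, the
/// rendered context looks roughly like:
///
/// ```text
/// let x = <|START|42|END|>;
/// ```
///
/// while a bare cursor before `42` renders as:
///
/// ```text
/// let x = <|START|>42;
/// ```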
fn retrieve_context(
    buffer: &BufferSnapshot,
    selected_range: &Option<Range<usize>>,
    model: Arc<dyn LanguageModel>,
    max_token_count: Option<usize>,
) -> anyhow::Result<(String, usize, bool)> {
    let mut prompt = String::new();
    let mut truncated = false;
    if let Some(selected_range) = selected_range {
        let start = selected_range.start.to_offset(buffer);
        let end = selected_range.end.to_offset(buffer);

        let start_window = buffer.text_for_range(0..start).collect::<String>();

        let mut selected_window = String::new();
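        // A bare cursor (empty selection) is marked with the single
        // '<|START|>' token; a non-empty selection is wrapped between
        // '<|START|' and '|END|>'.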
        if start == end {
            write!(selected_window, "<|START|>").unwrap();
        } else {
            write!(selected_window, "<|START|").unwrap();
        }

        write!(
            selected_window,
            "{}",
            buffer.text_for_range(start..end).collect::<String>()
        )
        .unwrap();

        if start != end {
            write!(selected_window, "|END|>").unwrap();
        }

        let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();

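        // With a token budget, fit the selection first, then share whatever
        // is left between the text before and after it.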
        if let Some(max_token_count) = max_token_count {
            let selected_tokens = model.count_tokens(&selected_window)?;
            if selected_tokens > max_token_count {
                return Err(anyhow!(
                    "selected range is greater than model context window, truncation not possible"
                ));
            }

            let mut remaining_tokens = max_token_count - selected_tokens;
            let start_window_tokens = model.count_tokens(&start_window)?;
            let end_window_tokens = model.count_tokens(&end_window)?;
            let outside_tokens = start_window_tokens + end_window_tokens;
            if outside_tokens > remaining_tokens {
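                // Both windows can't fit in full. Give the smaller window up
                // to half the remaining budget; whatever it leaves unused goes
                // to the larger one.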
                let (start_goal_tokens, end_goal_tokens) =
                    if start_window_tokens < end_window_tokens {
                        let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
                        remaining_tokens -= start_goal_tokens;
                        let end_goal_tokens = remaining_tokens.min(end_window_tokens);
                        (start_goal_tokens, end_goal_tokens)
                    } else {
                        let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
                        remaining_tokens -= end_goal_tokens;
                        let start_goal_tokens = remaining_tokens.min(start_window_tokens);
                        (start_goal_tokens, end_goal_tokens)
                    };

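                // Keep the text adjacent to the selection: trim the start
                // window from its beginning and the end window from its end.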
                let truncated_start_window =
                    model.truncate_start(&start_window, start_goal_tokens)?;
                let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
                writeln!(
                    prompt,
                    "{truncated_start_window}{selected_window}{truncated_end_window}"
                )
                .unwrap();
                truncated = true;
            } else {
                writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
            }
        } else {
            // No token budget: include the surrounding windows untruncated.
            writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
        }
    } else {
        // If we don't have a selected range, include the entire file.
        writeln!(prompt, "{}", &buffer.text()).unwrap();

        // Dumb truncation strategy: cut tokens off the end until the file fits.
        if let Some(max_token_count) = max_token_count {
            if model.count_tokens(&prompt)? > max_token_count {
                truncated = true;
                prompt = model.truncate(&prompt, max_token_count)?;
            }
        }
    }

    let token_count = model.count_tokens(&prompt)?;
    anyhow::Ok((prompt, token_count, truncated))
}

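/// A `PromptTemplate` that renders the current file's contents into the
/// prompt, marking the user's cursor or selection when one is provided.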
pub struct FileContext {}

impl PromptTemplate for FileContext {
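    /// Renders a preamble, the (possibly truncated) file contents in a fenced
    /// code block, and a note describing the cursor or selection, then clamps
    /// the result to `max_token_length`.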
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)> {
        if let Some(buffer) = &args.buffer {
            let mut prompt = String::new();
            // Add the initial preamble.
            // TODO: Do we want to add the path in here?
            writeln!(
                prompt,
                "The file you are currently working on has the following content:"
            )
            .unwrap();

            let language_name = args
                .language_name
                .clone()
                .unwrap_or_default()
                .to_lowercase();

            let (context, _, truncated) = retrieve_context(
                buffer,
                &args.selected_range,
                args.model.clone(),
                max_token_length,
            )?;
            writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();

            if truncated {
                writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
            }

            if let Some(selected_range) = &args.selected_range {
                let start = selected_range.start.to_offset(buffer);
                let end = selected_range.end.to_offset(buffer);

                if start == end {
                    writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
                } else {
                    writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
                }
            }

            // Really dumb truncation strategy: clamp the finished prompt
            // (preamble and notes included) to the token budget.
            if let Some(max_tokens) = max_token_length {
                prompt = args.model.truncate(&prompt, max_tokens)?;
            }

            let token_count = args.model.count_tokens(&prompt)?;
            anyhow::Ok((prompt, token_count))
        } else {
            Err(anyhow!("no buffer provided to retrieve file context from"))
        }
    }
}
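
// Sketch of what `FileContext::generate` produces for a small Rust buffer with
// a selection (assuming the token budget is not exceeded; exact whitespace may
// differ):
//
//     The file you are currently working on has the following content:
//     ```rust
//     fn main() {
//         let x = <|START|42|END|>;
//     }
//     ```
//     In particular, the user has selected a section of the text between the
//     '<|START|' and '|END|>' spans.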