1use anyhow::anyhow;
2use language::BufferSnapshot;
3use language::ToOffset;
4
5use crate::models::LanguageModel;
6use crate::models::TruncationDirection;
7use crate::prompts::base::PromptArguments;
8use crate::prompts::base::PromptTemplate;
9use std::fmt::Write;
10use std::ops::Range;
11use std::sync::Arc;
12
13fn retrieve_context(
14 buffer: &BufferSnapshot,
15 selected_range: &Option<Range<usize>>,
16 model: Arc<dyn LanguageModel>,
17 max_token_count: Option<usize>,
18) -> anyhow::Result<(String, usize, bool)> {
19 let mut prompt = String::new();
20 let mut truncated = false;
21 if let Some(selected_range) = selected_range {
22 let start = selected_range.start.to_offset(buffer);
23 let end = selected_range.end.to_offset(buffer);
24
25 let start_window = buffer.text_for_range(0..start).collect::<String>();
26
27 let mut selected_window = String::new();
28 if start == end {
29 write!(selected_window, "<|START|>").unwrap();
30 } else {
31 write!(selected_window, "<|START|").unwrap();
32 }
33
34 write!(
35 selected_window,
36 "{}",
37 buffer.text_for_range(start..end).collect::<String>()
38 )
39 .unwrap();
40
41 if start != end {
42 write!(selected_window, "|END|>").unwrap();
43 }
44
45 let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
46
47 if let Some(max_token_count) = max_token_count {
48 let selected_tokens = model.count_tokens(&selected_window)?;
49 if selected_tokens > max_token_count {
50 return Err(anyhow!(
51 "selected range is greater than model context window, truncation not possible"
52 ));
53 };
54
55 let mut remaining_tokens = max_token_count - selected_tokens;
56 let start_window_tokens = model.count_tokens(&start_window)?;
57 let end_window_tokens = model.count_tokens(&end_window)?;
58 let outside_tokens = start_window_tokens + end_window_tokens;
59 if outside_tokens > remaining_tokens {
60 let (start_goal_tokens, end_goal_tokens) =
61 if start_window_tokens < end_window_tokens {
62 let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
63 remaining_tokens -= start_goal_tokens;
64 let end_goal_tokens = remaining_tokens.min(end_window_tokens);
65 (start_goal_tokens, end_goal_tokens)
66 } else {
67 let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
68 remaining_tokens -= end_goal_tokens;
69 let start_goal_tokens = remaining_tokens.min(start_window_tokens);
70 (start_goal_tokens, end_goal_tokens)
71 };
72
73 let truncated_start_window =
74 model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
75 let truncated_end_window =
76 model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
77 writeln!(
78 prompt,
79 "{truncated_start_window}{selected_window}{truncated_end_window}"
80 )
81 .unwrap();
82 truncated = true;
83 } else {
84 writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
85 }
86 } else {
87 // If we dont have a selected range, include entire file.
88 writeln!(prompt, "{}", &buffer.text()).unwrap();
89
90 // Dumb truncation strategy
91 if let Some(max_token_count) = max_token_count {
92 if model.count_tokens(&prompt)? > max_token_count {
93 truncated = true;
94 prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
95 }
96 }
97 }
98 }
99
100 let token_count = model.count_tokens(&prompt)?;
101 anyhow::Ok((prompt, token_count, truncated))
102}
103
/// Prompt template that renders the current buffer's contents (with optional
/// selection markers) into a fenced code block, via [`retrieve_context`].
pub struct FileContext {}
105
106impl PromptTemplate for FileContext {
107 fn generate(
108 &self,
109 args: &PromptArguments,
110 max_token_length: Option<usize>,
111 ) -> anyhow::Result<(String, usize)> {
112 if let Some(buffer) = &args.buffer {
113 let mut prompt = String::new();
114 // Add Initial Preamble
115 // TODO: Do we want to add the path in here?
116 writeln!(
117 prompt,
118 "The file you are currently working on has the following content:"
119 )
120 .unwrap();
121
122 let language_name = args
123 .language_name
124 .clone()
125 .unwrap_or("".to_string())
126 .to_lowercase();
127
128 let (context, _, truncated) = retrieve_context(
129 buffer,
130 &args.selected_range,
131 args.model.clone(),
132 max_token_length,
133 )?;
134 writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
135
136 if truncated {
137 writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
138 }
139
140 if let Some(selected_range) = &args.selected_range {
141 let start = selected_range.start.to_offset(buffer);
142 let end = selected_range.end.to_offset(buffer);
143
144 if start == end {
145 writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
146 } else {
147 writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
148 }
149 }
150
151 // Really dumb truncation strategy
152 if let Some(max_tokens) = max_token_length {
153 prompt = args
154 .model
155 .truncate(&prompt, max_tokens, TruncationDirection::End)?;
156 }
157
158 let token_count = args.model.count_tokens(&prompt)?;
159 anyhow::Ok((prompt, token_count))
160 } else {
161 Err(anyhow!("no buffer provided to retrieve file context from"))
162 }
163 }
164}