1use anyhow::anyhow;
2use language::BufferSnapshot;
3use language::ToOffset;
4
5use crate::models::LanguageModel;
6use crate::templates::base::PromptArguments;
7use crate::templates::base::PromptTemplate;
8use std::fmt::Write;
9use std::ops::Range;
10use std::sync::Arc;
11
12fn retrieve_context(
13 buffer: &BufferSnapshot,
14 selected_range: &Option<Range<usize>>,
15 model: Arc<dyn LanguageModel>,
16 max_token_count: Option<usize>,
17) -> anyhow::Result<(String, usize, bool)> {
18 let mut prompt = String::new();
19 let mut truncated = false;
20 if let Some(selected_range) = selected_range {
21 let start = selected_range.start.to_offset(buffer);
22 let end = selected_range.end.to_offset(buffer);
23
24 let start_window = buffer.text_for_range(0..start).collect::<String>();
25
26 let mut selected_window = String::new();
27 if start == end {
28 write!(selected_window, "<|START|>").unwrap();
29 } else {
30 write!(selected_window, "<|START|").unwrap();
31 }
32
33 write!(
34 selected_window,
35 "{}",
36 buffer.text_for_range(start..end).collect::<String>()
37 )
38 .unwrap();
39
40 if start != end {
41 write!(selected_window, "|END|>").unwrap();
42 }
43
44 let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
45
46 if let Some(max_token_count) = max_token_count {
47 let selected_tokens = model.count_tokens(&selected_window)?;
48 if selected_tokens > max_token_count {
49 return Err(anyhow!(
50 "selected range is greater than model context window, truncation not possible"
51 ));
52 };
53
54 let mut remaining_tokens = max_token_count - selected_tokens;
55 let start_window_tokens = model.count_tokens(&start_window)?;
56 let end_window_tokens = model.count_tokens(&end_window)?;
57 let outside_tokens = start_window_tokens + end_window_tokens;
58 if outside_tokens > remaining_tokens {
59 let (start_goal_tokens, end_goal_tokens) =
60 if start_window_tokens < end_window_tokens {
61 let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
62 remaining_tokens -= start_goal_tokens;
63 let end_goal_tokens = remaining_tokens.min(end_window_tokens);
64 (start_goal_tokens, end_goal_tokens)
65 } else {
66 let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
67 remaining_tokens -= end_goal_tokens;
68 let start_goal_tokens = remaining_tokens.min(start_window_tokens);
69 (start_goal_tokens, end_goal_tokens)
70 };
71
72 let truncated_start_window =
73 model.truncate_start(&start_window, start_goal_tokens)?;
74 let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
75 writeln!(
76 prompt,
77 "{truncated_start_window}{selected_window}{truncated_end_window}"
78 )
79 .unwrap();
80 truncated = true;
81 } else {
82 writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
83 }
84 } else {
85 // If we dont have a selected range, include entire file.
86 writeln!(prompt, "{}", &buffer.text()).unwrap();
87
88 // Dumb truncation strategy
89 if let Some(max_token_count) = max_token_count {
90 if model.count_tokens(&prompt)? > max_token_count {
91 truncated = true;
92 prompt = model.truncate(&prompt, max_token_count)?;
93 }
94 }
95 }
96 }
97
98 let token_count = model.count_tokens(&prompt)?;
99 anyhow::Ok((prompt, token_count, truncated))
100}
101
/// Unit-struct [`PromptTemplate`] that renders the buffer the user is
/// currently working on into a prompt section (see `generate` below).
pub struct FileContext {}
103
104impl PromptTemplate for FileContext {
105 fn generate(
106 &self,
107 args: &PromptArguments,
108 max_token_length: Option<usize>,
109 ) -> anyhow::Result<(String, usize)> {
110 if let Some(buffer) = &args.buffer {
111 let mut prompt = String::new();
112 // Add Initial Preamble
113 // TODO: Do we want to add the path in here?
114 writeln!(
115 prompt,
116 "The file you are currently working on has the following content:"
117 )
118 .unwrap();
119
120 let language_name = args
121 .language_name
122 .clone()
123 .unwrap_or("".to_string())
124 .to_lowercase();
125
126 let (context, _, truncated) = retrieve_context(
127 buffer,
128 &args.selected_range,
129 args.model.clone(),
130 max_token_length,
131 )?;
132 writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
133
134 if truncated {
135 writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
136 }
137
138 if let Some(selected_range) = &args.selected_range {
139 let start = selected_range.start.to_offset(buffer);
140 let end = selected_range.end.to_offset(buffer);
141
142 if start == end {
143 writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
144 } else {
145 writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
146 }
147 }
148
149 // Really dumb truncation strategy
150 if let Some(max_tokens) = max_token_length {
151 prompt = args.model.truncate(&prompt, max_tokens)?;
152 }
153
154 let token_count = args.model.count_tokens(&prompt)?;
155 anyhow::Ok((prompt, token_count))
156 } else {
157 Err(anyhow!("no buffer provided to retrieve file context from"))
158 }
159 }
160}