1use crate::codegen::CodegenKind;
2use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
3use std::cmp::{self, Reverse};
4use std::fmt::Write;
5use std::ops::Range;
6use tiktoken_rs::ChatCompletionRequestMessage;
7
8fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
9 #[derive(Debug)]
10 struct Match {
11 collapse: Range<usize>,
12 keep: Vec<Range<usize>>,
13 }
14
15 let selected_range = selected_range.to_offset(buffer);
16 let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
17 Some(&grammar.embedding_config.as_ref()?.query)
18 });
19 let configs = ts_matches
20 .grammars()
21 .iter()
22 .map(|g| g.embedding_config.as_ref().unwrap())
23 .collect::<Vec<_>>();
24 let mut matches = Vec::new();
25 while let Some(mat) = ts_matches.peek() {
26 let config = &configs[mat.grammar_index];
27 if let Some(collapse) = mat.captures.iter().find_map(|cap| {
28 if Some(cap.index) == config.collapse_capture_ix {
29 Some(cap.node.byte_range())
30 } else {
31 None
32 }
33 }) {
34 let mut keep = Vec::new();
35 for capture in mat.captures.iter() {
36 if Some(capture.index) == config.keep_capture_ix {
37 keep.push(capture.node.byte_range());
38 } else {
39 continue;
40 }
41 }
42 ts_matches.advance();
43 matches.push(Match { collapse, keep });
44 } else {
45 ts_matches.advance();
46 }
47 }
48 matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
49 let mut matches = matches.into_iter().peekable();
50
51 let mut summary = String::new();
52 let mut offset = 0;
53 let mut flushed_selection = false;
54 while let Some(mat) = matches.next() {
55 // Keep extending the collapsed range if the next match surrounds
56 // the current one.
57 while let Some(next_mat) = matches.peek() {
58 if mat.collapse.start <= next_mat.collapse.start
59 && mat.collapse.end >= next_mat.collapse.end
60 {
61 matches.next().unwrap();
62 } else {
63 break;
64 }
65 }
66
67 if offset > mat.collapse.start {
68 // Skip collapsed nodes that have already been summarized.
69 offset = cmp::max(offset, mat.collapse.end);
70 continue;
71 }
72
73 if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
74 if !flushed_selection {
75 // The collapsed node ends after the selection starts, so we'll flush the selection first.
76 summary.extend(buffer.text_for_range(offset..selected_range.start));
77 summary.push_str("<|START|");
78 if selected_range.end == selected_range.start {
79 summary.push_str(">");
80 } else {
81 summary.extend(buffer.text_for_range(selected_range.clone()));
82 summary.push_str("|END|>");
83 }
84 offset = selected_range.end;
85 flushed_selection = true;
86 }
87
88 // If the selection intersects the collapsed node, we won't collapse it.
89 if selected_range.end >= mat.collapse.start {
90 continue;
91 }
92 }
93
94 summary.extend(buffer.text_for_range(offset..mat.collapse.start));
95 for keep in mat.keep {
96 summary.extend(buffer.text_for_range(keep));
97 }
98 offset = mat.collapse.end;
99 }
100
101 // Flush selection if we haven't already done so.
102 if !flushed_selection && offset <= selected_range.start {
103 summary.extend(buffer.text_for_range(offset..selected_range.start));
104 summary.push_str("<|START|");
105 if selected_range.end == selected_range.start {
106 summary.push_str(">");
107 } else {
108 summary.extend(buffer.text_for_range(selected_range.clone()));
109 summary.push_str("|END|>");
110 }
111 offset = selected_range.end;
112 }
113
114 summary.extend(buffer.text_for_range(offset..buffer.len()));
115 summary
116}
117
118pub fn generate_content_prompt(
119 user_prompt: String,
120 language_name: Option<&str>,
121 buffer: &BufferSnapshot,
122 range: Range<impl ToOffset>,
123 kind: CodegenKind,
124 search_results: Vec<String>,
125 model: &str,
126) -> String {
127 const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
128
129 let mut prompts = Vec::new();
130
131 // General Preamble
132 if let Some(language_name) = language_name {
133 prompts.push(format!("You're an expert {language_name} engineer.\n"));
134 } else {
135 prompts.push("You're an expert engineer.\n".to_string());
136 }
137
138 // Snippets
139 let mut snippet_position = prompts.len() - 1;
140
141 let outline = summarize(buffer, range);
142 prompts.push("The file you are currently working on has the following outline:".to_string());
143 if let Some(language_name) = language_name {
144 let language_name = language_name.to_lowercase();
145 prompts.push(format!("```{language_name}\n{outline}\n```"));
146 } else {
147 prompts.push(format!("```\n{outline}\n```"));
148 }
149
150 match kind {
151 CodegenKind::Generate { position: _ } => {
152 prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
153 prompts
154 .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
155 prompts.push(
156 "Text can't be replaced, so assume your answer will be inserted at the cursor."
157 .to_string(),
158 );
159 prompts.push(format!(
160 "Generate text based on the users prompt: {user_prompt}"
161 ));
162 }
163 CodegenKind::Transform { range: _ } => {
164 prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
165 prompts.push(format!(
166 "Modify the users code selected text based upon the users prompt: {user_prompt}"
167 ));
168 prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
169 }
170 }
171
172 if let Some(language_name) = language_name {
173 prompts.push(format!("Your answer MUST always be valid {language_name}"));
174 }
175 prompts.push("Always wrap your response in a Markdown codeblock".to_string());
176 prompts.push("Never make remarks about the output.".to_string());
177
178 let current_messages = [ChatCompletionRequestMessage {
179 role: "user".to_string(),
180 content: Some(prompts.join("\n")),
181 function_call: None,
182 name: None,
183 }];
184
185 let remaining_token_count = if let Ok(current_token_count) =
186 tiktoken_rs::num_tokens_from_messages(model, ¤t_messages)
187 {
188 let max_token_count = tiktoken_rs::model::get_context_size(model);
189 max_token_count - current_token_count
190 } else {
191 // If tiktoken fails to count token count, assume we have no space remaining.
192 0
193 };
194
195 // TODO:
196 // - add repository name to snippet
197 // - add file path
198 // - add language
199 if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
200 let template = "You are working inside a large repository, here are a few code snippets that may be useful";
201
202 for search_result in search_results {
203 let mut snippet_prompt = template.to_string();
204 writeln!(snippet_prompt, "```\n{search_result}\n```").unwrap();
205
206 let token_count = encoding
207 .encode_with_special_tokens(snippet_prompt.as_str())
208 .len();
209 if token_count <= remaining_token_count {
210 if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
211 prompts.insert(snippet_position, snippet_prompt);
212 snippet_position += 1;
213 }
214 } else {
215 break;
216 }
217 }
218 }
219
220 let prompt = prompts.join("\n");
221 println!("PROMPT: {:?}", prompt);
222 prompt
223}
224
#[cfg(test)]
pub(crate) mod tests {

    use super::*;
    use std::sync::Arc;

    use gpui::AppContext;
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use settings::SettingsStore;

    /// Builds a Rust `Language` carrying an embedding query, for use in tests.
    ///
    /// In the query, a function's body block is captured as `@collapse` and the
    /// body's `{` / `}` tokens as `@keep`. As a result, `summarize` renders each
    /// function body it collapses as `{}`.
    pub(crate) fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_embedding_query(
            r#"
            (
                [(line_comment) (attribute_item)]* @context
                .
                [
                    (struct_item
                        name: (_) @name)

                    (enum_item
                        name: (_) @name)

                    (impl_item
                        trait: (_)? @name
                        "for"? @name
                        type: (_) @name)

                    (trait_item
                        name: (_) @name)

                    (function_item
                        name: (_) @name
                        body: (block
                            "{" @keep
                            "}" @keep) @collapse)

                    (macro_definition
                        name: (_) @name)
                ] @item
            )
            "#,
        )
        .unwrap()
    }

    #[gpui::test]
    fn test_outline_for_prompt(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        language_settings::init(cx);
        let text = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    self.a
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Empty selection (cursor) inside the struct, which has no `@collapse`
        // capture: every function body is collapsed to `{}`.
        assert_eq!(
            summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
            indoc! {"
                struct X {
                    <|START|>a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Non-empty selection inside `new`'s body: that body is emitted verbatim
        // with the selection wrapped in `<|START|` … `|END|>`, while the other
        // bodies stay collapsed.
        assert_eq!(
            summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {
                        let <|START|a |END|>= 1;
                        let b = 2;
                        Self { a, b }
                    }

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Cursor on the blank line right after `impl X {`.
        assert_eq!(
            summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {
                <|START|>
                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Cursor at the very end of the buffer (the empty row after the trailing
        // newline): the marker is emitted after all collapsed nodes.
        assert_eq!(
            summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
                <|START|>"}
        );

        // Ensure nested functions get collapsed properly.
        // `nested` lives inside `a`'s body, so it must disappear along with that
        // body rather than leaving its own `{}` behind.
        let text = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    let a = 30;
                    fn nested() -> usize {
                        3
                    }
                    self.a + nested()
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
        let snapshot = buffer.read(cx).snapshot();
        assert_eq!(
            summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
            indoc! {"
                <|START|>struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );
    }
}