1use crate::codegen::CodegenKind;
2use gpui::AsyncAppContext;
3use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
4use semantic_index::SearchResult;
5use std::cmp::{self, Reverse};
6use std::fmt::Write;
7use std::ops::Range;
8use std::path::PathBuf;
9use tiktoken_rs::ChatCompletionRequestMessage;
10
/// A code snippet extracted from a semantic-index search result, carried
/// together with enough metadata (file path, language name) to render it
/// as context in an LLM prompt.
pub struct PromptCodeSnippet {
    // Path of the file the snippet came from, when the buffer has one.
    path: Option<PathBuf>,
    // Name of the buffer's language, when one is assigned.
    language_name: Option<String>,
    // The snippet text itself, copied out of the buffer.
    content: String,
}
16
17impl PromptCodeSnippet {
18 pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
19 let (content, language_name, file_path) =
20 search_result.buffer.read_with(cx, |buffer, _| {
21 let snapshot = buffer.snapshot();
22 let content = snapshot
23 .text_for_range(search_result.range.clone())
24 .collect::<String>();
25
26 let language_name = buffer
27 .language()
28 .and_then(|language| Some(language.name().to_string()));
29
30 let file_path = buffer
31 .file()
32 .and_then(|file| Some(file.path().to_path_buf()));
33
34 (content, language_name, file_path)
35 });
36
37 PromptCodeSnippet {
38 path: file_path,
39 language_name,
40 content,
41 }
42 }
43}
44
45impl ToString for PromptCodeSnippet {
46 fn to_string(&self) -> String {
47 let path = self
48 .path
49 .as_ref()
50 .and_then(|path| Some(path.to_string_lossy().to_string()))
51 .unwrap_or("".to_string());
52 let language_name = self.language_name.clone().unwrap_or("".to_string());
53 let content = self.content.clone();
54
55 format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
56 }
57}
58
/// Produces a condensed "outline" of `buffer`: regions captured as
/// `@collapse` by each grammar's embedding query (typically function bodies)
/// are collapsed down to their `@keep` captures, while `selected_range` is
/// emitted verbatim and bracketed with `<|START|` / `|END|>` markers
/// (`<|START|>` alone when the selection is empty).
#[allow(dead_code)]
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
    // One collapsible region: the byte range to collapse plus the sub-ranges
    // inside it that must survive (e.g. the body's braces).
    #[derive(Debug)]
    struct Match {
        collapse: Range<usize>,
        keep: Vec<Range<usize>>,
    }

    let selected_range = selected_range.to_offset(buffer);
    // Run each grammar's embedding query over the entire buffer. Grammars
    // without an embedding config contribute no matches.
    let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
        Some(&grammar.embedding_config.as_ref()?.query)
    });
    // unwrap is safe here: only grammars with an embedding config could have
    // produced matches above — TODO confirm `grammars()` is filtered that way.
    let configs = ts_matches
        .grammars()
        .iter()
        .map(|g| g.embedding_config.as_ref().unwrap())
        .collect::<Vec<_>>();
    // Gather every match that has a collapse capture, along with any keep
    // captures from the same match.
    let mut matches = Vec::new();
    while let Some(mat) = ts_matches.peek() {
        let config = &configs[mat.grammar_index];
        if let Some(collapse) = mat.captures.iter().find_map(|cap| {
            if Some(cap.index) == config.collapse_capture_ix {
                Some(cap.node.byte_range())
            } else {
                None
            }
        }) {
            let mut keep = Vec::new();
            for capture in mat.captures.iter() {
                if Some(capture.index) == config.keep_capture_ix {
                    keep.push(capture.node.byte_range());
                } else {
                    continue;
                }
            }
            ts_matches.advance();
            matches.push(Match { collapse, keep });
        } else {
            ts_matches.advance();
        }
    }
    // Sort so that for identical starts, the *outer* (larger) range comes
    // first; the dedup loop below relies on this ordering.
    matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
    let mut matches = matches.into_iter().peekable();

    let mut summary = String::new();
    // Byte offset up to which the buffer has already been emitted.
    let mut offset = 0;
    // Whether the selection markers have been written yet.
    let mut flushed_selection = false;
    while let Some(mat) = matches.next() {
        // Keep extending the collapsed range if the next match surrounds
        // the current one.
        while let Some(next_mat) = matches.peek() {
            if mat.collapse.start <= next_mat.collapse.start
                && mat.collapse.end >= next_mat.collapse.end
            {
                matches.next().unwrap();
            } else {
                break;
            }
        }

        if offset > mat.collapse.start {
            // Skip collapsed nodes that have already been summarized.
            offset = cmp::max(offset, mat.collapse.end);
            continue;
        }

        if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
            if !flushed_selection {
                // The collapsed node ends after the selection starts, so we'll flush the selection first.
                summary.extend(buffer.text_for_range(offset..selected_range.start));
                summary.push_str("<|START|");
                if selected_range.end == selected_range.start {
                    summary.push_str(">");
                } else {
                    summary.extend(buffer.text_for_range(selected_range.clone()));
                    summary.push_str("|END|>");
                }
                offset = selected_range.end;
                flushed_selection = true;
            }

            // If the selection intersects the collapsed node, we won't collapse it.
            if selected_range.end >= mat.collapse.start {
                continue;
            }
        }

        // Emit everything before the collapsed node, then only its kept ranges.
        summary.extend(buffer.text_for_range(offset..mat.collapse.start));
        for keep in mat.keep {
            summary.extend(buffer.text_for_range(keep));
        }
        offset = mat.collapse.end;
    }

    // Flush selection if we haven't already done so.
    if !flushed_selection && offset <= selected_range.start {
        summary.extend(buffer.text_for_range(offset..selected_range.start));
        summary.push_str("<|START|");
        if selected_range.end == selected_range.start {
            summary.push_str(">");
        } else {
            summary.extend(buffer.text_for_range(selected_range.clone()));
            summary.push_str("|END|>");
        }
        offset = selected_range.end;
    }

    // Emit whatever remains after the last collapsed node / selection.
    summary.extend(buffer.text_for_range(offset..buffer.len()));
    summary
}
169
170pub fn generate_content_prompt(
171 user_prompt: String,
172 language_name: Option<&str>,
173 buffer: &BufferSnapshot,
174 range: Range<impl ToOffset>,
175 kind: CodegenKind,
176 search_results: Vec<PromptCodeSnippet>,
177 model: &str,
178) -> String {
179 const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
180 const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
181
182 let mut prompts = Vec::new();
183 let range = range.to_offset(buffer);
184
185 // General Preamble
186 if let Some(language_name) = language_name {
187 prompts.push(format!("You're an expert {language_name} engineer.\n"));
188 } else {
189 prompts.push("You're an expert engineer.\n".to_string());
190 }
191
192 // Snippets
193 let mut snippet_position = prompts.len() - 1;
194
195 let mut content = String::new();
196 content.extend(buffer.text_for_range(0..range.start));
197 if range.start == range.end {
198 content.push_str("<|START|>");
199 } else {
200 content.push_str("<|START|");
201 }
202 content.extend(buffer.text_for_range(range.clone()));
203 if range.start != range.end {
204 content.push_str("|END|>");
205 }
206 content.extend(buffer.text_for_range(range.end..buffer.len()));
207
208 prompts.push("The file you are currently working on has the following content:\n".to_string());
209
210 if let Some(language_name) = language_name {
211 let language_name = language_name.to_lowercase();
212 prompts.push(format!("```{language_name}\n{content}\n```"));
213 } else {
214 prompts.push(format!("```\n{content}\n```"));
215 }
216
217 match kind {
218 CodegenKind::Generate { position: _ } => {
219 prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
220 prompts
221 .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
222 prompts.push(
223 "Text can't be replaced, so assume your answer will be inserted at the cursor."
224 .to_string(),
225 );
226 prompts.push(format!(
227 "Generate text based on the users prompt: {user_prompt}"
228 ));
229 }
230 CodegenKind::Transform { range: _ } => {
231 prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
232 prompts.push(format!(
233 "Modify the users code selected text based upon the users prompt: '{user_prompt}'"
234 ));
235 prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
236 }
237 }
238
239 if let Some(language_name) = language_name {
240 prompts.push(format!(
241 "Your answer MUST always and only be valid {language_name}"
242 ));
243 }
244 prompts.push("Never make remarks about the output.".to_string());
245 prompts.push("Do not return any text, except the generated code.".to_string());
246 prompts.push("Always wrap your code in a Markdown block".to_string());
247
248 let current_messages = [ChatCompletionRequestMessage {
249 role: "user".to_string(),
250 content: Some(prompts.join("\n")),
251 function_call: None,
252 name: None,
253 }];
254
255 let mut remaining_token_count = if let Ok(current_token_count) =
256 tiktoken_rs::num_tokens_from_messages(model, ¤t_messages)
257 {
258 let max_token_count = tiktoken_rs::model::get_context_size(model);
259 let intermediate_token_count = if max_token_count > current_token_count {
260 max_token_count - current_token_count
261 } else {
262 0
263 };
264
265 if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
266 0
267 } else {
268 intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
269 }
270 } else {
271 // If tiktoken fails to count token count, assume we have no space remaining.
272 0
273 };
274
275 // TODO:
276 // - add repository name to snippet
277 // - add file path
278 // - add language
279 if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
280 let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
281
282 for search_result in search_results {
283 let mut snippet_prompt = template.to_string();
284 let snippet = search_result.to_string();
285 writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
286
287 let token_count = encoding
288 .encode_with_special_tokens(snippet_prompt.as_str())
289 .len();
290 if token_count <= remaining_token_count {
291 if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
292 prompts.insert(snippet_position, snippet_prompt);
293 snippet_position += 1;
294 remaining_token_count -= token_count;
295 // If you have already added the template to the prompt, remove the template.
296 template = "";
297 }
298 } else {
299 break;
300 }
301 }
302 }
303
304 prompts.join("\n")
305}
306
#[cfg(test)]
pub(crate) mod tests {

    use super::*;
    use std::sync::Arc;

    use gpui::AppContext;
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use settings::SettingsStore;

    /// Builds a Rust `Language` with an embedding query whose `@collapse`
    /// capture targets function bodies and whose `@keep` captures retain the
    /// surrounding braces — the shape `summarize` relies on.
    pub(crate) fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_embedding_query(
            r#"
            (
                [(line_comment) (attribute_item)]* @context
                .
                [
                    (struct_item
                        name: (_) @name)

                    (enum_item
                        name: (_) @name)

                    (impl_item
                        trait: (_)? @name
                        "for"? @name
                        type: (_) @name)

                    (trait_item
                        name: (_) @name)

                    (function_item
                        name: (_) @name
                        body: (block
                            "{" @keep
                            "}" @keep) @collapse)

                    (macro_definition
                        name: (_) @name)
                ] @item
            )
            "#,
        )
        .unwrap()
    }

    /// End-to-end check of `summarize` on a small Rust buffer: cursor and
    /// selection placement, `<|START|`/`|END|>` marker emission, and the
    /// collapsing of function bodies (including nested ones).
    #[gpui::test]
    fn test_outline_for_prompt(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        language_settings::init(cx);
        let text = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    self.a
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Empty selection on a struct field: lone `<|START|>` marker, with
        // every function body collapsed to `{}`.
        assert_eq!(
            summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
            indoc! {"
                struct X {
                    <|START|>a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Non-empty selection inside `fn new`: that body is kept verbatim and
        // bracketed with both markers; the other bodies still collapse.
        assert_eq!(
            summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {
                        let <|START|a |END|>= 1;
                        let b = 2;
                        Self { a, b }
                    }

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Empty selection on the blank line between items.
        assert_eq!(
            summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {
                <|START|>
                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Empty selection at the very end of the buffer: the marker is
        // flushed after all content.
        assert_eq!(
            summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
                <|START|>"}
        );

        // Ensure nested functions get collapsed properly.
        let text = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    let a = 30;
                    fn nested() -> usize {
                        3
                    }
                    self.a + nested()
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
        let snapshot = buffer.read(cx).snapshot();
        assert_eq!(
            summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
            indoc! {"
                <|START|>struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );
    }
}