prompts.rs

  1use crate::codegen::CodegenKind;
  2use gpui::AsyncAppContext;
  3use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
  4use semantic_index::SearchResult;
  5use std::cmp::{self, Reverse};
  6use std::fmt::Write;
  7use std::ops::Range;
  8use std::path::PathBuf;
  9use tiktoken_rs::ChatCompletionRequestMessage;
 10
 11pub struct PromptCodeSnippet {
 12    path: Option<PathBuf>,
 13    language_name: Option<String>,
 14    content: String,
 15}
 16
 17impl PromptCodeSnippet {
 18    pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
 19        let (content, language_name, file_path) =
 20            search_result.buffer.read_with(cx, |buffer, _| {
 21                let snapshot = buffer.snapshot();
 22                let content = snapshot
 23                    .text_for_range(search_result.range.clone())
 24                    .collect::<String>();
 25
 26                let language_name = buffer
 27                    .language()
 28                    .and_then(|language| Some(language.name().to_string()));
 29
 30                let file_path = buffer
 31                    .file()
 32                    .and_then(|file| Some(file.path().to_path_buf()));
 33
 34                (content, language_name, file_path)
 35            });
 36
 37        PromptCodeSnippet {
 38            path: file_path,
 39            language_name,
 40            content,
 41        }
 42    }
 43}
 44
 45impl ToString for PromptCodeSnippet {
 46    fn to_string(&self) -> String {
 47        let path = self
 48            .path
 49            .as_ref()
 50            .and_then(|path| Some(path.to_string_lossy().to_string()))
 51            .unwrap_or("".to_string());
 52        let language_name = self.language_name.clone().unwrap_or("".to_string());
 53        let content = self.content.clone();
 54
 55        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
 56    }
 57}
 58
 59#[allow(dead_code)]
 60fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
 61    #[derive(Debug)]
 62    struct Match {
 63        collapse: Range<usize>,
 64        keep: Vec<Range<usize>>,
 65    }
 66
 67    let selected_range = selected_range.to_offset(buffer);
 68    let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
 69        Some(&grammar.embedding_config.as_ref()?.query)
 70    });
 71    let configs = ts_matches
 72        .grammars()
 73        .iter()
 74        .map(|g| g.embedding_config.as_ref().unwrap())
 75        .collect::<Vec<_>>();
 76    let mut matches = Vec::new();
 77    while let Some(mat) = ts_matches.peek() {
 78        let config = &configs[mat.grammar_index];
 79        if let Some(collapse) = mat.captures.iter().find_map(|cap| {
 80            if Some(cap.index) == config.collapse_capture_ix {
 81                Some(cap.node.byte_range())
 82            } else {
 83                None
 84            }
 85        }) {
 86            let mut keep = Vec::new();
 87            for capture in mat.captures.iter() {
 88                if Some(capture.index) == config.keep_capture_ix {
 89                    keep.push(capture.node.byte_range());
 90                } else {
 91                    continue;
 92                }
 93            }
 94            ts_matches.advance();
 95            matches.push(Match { collapse, keep });
 96        } else {
 97            ts_matches.advance();
 98        }
 99    }
100    matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
101    let mut matches = matches.into_iter().peekable();
102
103    let mut summary = String::new();
104    let mut offset = 0;
105    let mut flushed_selection = false;
106    while let Some(mat) = matches.next() {
107        // Keep extending the collapsed range if the next match surrounds
108        // the current one.
109        while let Some(next_mat) = matches.peek() {
110            if mat.collapse.start <= next_mat.collapse.start
111                && mat.collapse.end >= next_mat.collapse.end
112            {
113                matches.next().unwrap();
114            } else {
115                break;
116            }
117        }
118
119        if offset > mat.collapse.start {
120            // Skip collapsed nodes that have already been summarized.
121            offset = cmp::max(offset, mat.collapse.end);
122            continue;
123        }
124
125        if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
126            if !flushed_selection {
127                // The collapsed node ends after the selection starts, so we'll flush the selection first.
128                summary.extend(buffer.text_for_range(offset..selected_range.start));
129                summary.push_str("<|START|");
130                if selected_range.end == selected_range.start {
131                    summary.push_str(">");
132                } else {
133                    summary.extend(buffer.text_for_range(selected_range.clone()));
134                    summary.push_str("|END|>");
135                }
136                offset = selected_range.end;
137                flushed_selection = true;
138            }
139
140            // If the selection intersects the collapsed node, we won't collapse it.
141            if selected_range.end >= mat.collapse.start {
142                continue;
143            }
144        }
145
146        summary.extend(buffer.text_for_range(offset..mat.collapse.start));
147        for keep in mat.keep {
148            summary.extend(buffer.text_for_range(keep));
149        }
150        offset = mat.collapse.end;
151    }
152
153    // Flush selection if we haven't already done so.
154    if !flushed_selection && offset <= selected_range.start {
155        summary.extend(buffer.text_for_range(offset..selected_range.start));
156        summary.push_str("<|START|");
157        if selected_range.end == selected_range.start {
158            summary.push_str(">");
159        } else {
160            summary.extend(buffer.text_for_range(selected_range.clone()));
161            summary.push_str("|END|>");
162        }
163        offset = selected_range.end;
164    }
165
166    summary.extend(buffer.text_for_range(offset..buffer.len()));
167    summary
168}
169
170pub fn generate_content_prompt(
171    user_prompt: String,
172    language_name: Option<&str>,
173    buffer: &BufferSnapshot,
174    range: Range<impl ToOffset>,
175    kind: CodegenKind,
176    search_results: Vec<PromptCodeSnippet>,
177    model: &str,
178) -> String {
179    const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
180    const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
181
182    let mut prompts = Vec::new();
183    let range = range.to_offset(buffer);
184
185    // General Preamble
186    if let Some(language_name) = language_name {
187        prompts.push(format!("You're an expert {language_name} engineer.\n"));
188    } else {
189        prompts.push("You're an expert engineer.\n".to_string());
190    }
191
192    // Snippets
193    let mut snippet_position = prompts.len() - 1;
194
195    let mut content = String::new();
196    content.extend(buffer.text_for_range(0..range.start));
197    if range.start == range.end {
198        content.push_str("<|START|>");
199    } else {
200        content.push_str("<|START|");
201    }
202    content.extend(buffer.text_for_range(range.clone()));
203    if range.start != range.end {
204        content.push_str("|END|>");
205    }
206    content.extend(buffer.text_for_range(range.end..buffer.len()));
207
208    prompts.push("The file you are currently working on has the following content:\n".to_string());
209
210    if let Some(language_name) = language_name {
211        let language_name = language_name.to_lowercase();
212        prompts.push(format!("```{language_name}\n{content}\n```"));
213    } else {
214        prompts.push(format!("```\n{content}\n```"));
215    }
216
217    match kind {
218        CodegenKind::Generate { position: _ } => {
219            prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
220            prompts
221                .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
222            prompts.push(
223                "Text can't be replaced, so assume your answer will be inserted at the cursor."
224                    .to_string(),
225            );
226            prompts.push(format!(
227                "Generate text based on the users prompt: {user_prompt}"
228            ));
229        }
230        CodegenKind::Transform { range: _ } => {
231            prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
232            prompts.push(format!(
233                "Modify the users code selected text based upon the users prompt: '{user_prompt}'"
234            ));
235            prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
236        }
237    }
238
239    if let Some(language_name) = language_name {
240        prompts.push(format!(
241            "Your answer MUST always and only be valid {language_name}"
242        ));
243    }
244    prompts.push("Never make remarks about the output.".to_string());
245    prompts.push("Do not return any text, except the generated code.".to_string());
246    prompts.push("Do not wrap your text in a Markdown block".to_string());
247
248    let current_messages = [ChatCompletionRequestMessage {
249        role: "user".to_string(),
250        content: Some(prompts.join("\n")),
251        function_call: None,
252        name: None,
253    }];
254
255    let mut remaining_token_count = if let Ok(current_token_count) =
256        tiktoken_rs::num_tokens_from_messages(model, &current_messages)
257    {
258        let max_token_count = tiktoken_rs::model::get_context_size(model);
259        let intermediate_token_count = max_token_count - current_token_count;
260
261        if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
262            0
263        } else {
264            intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
265        }
266    } else {
267        // If tiktoken fails to count token count, assume we have no space remaining.
268        0
269    };
270
271    // TODO:
272    //   - add repository name to snippet
273    //   - add file path
274    //   - add language
275    if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
276        let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
277
278        for search_result in search_results {
279            let mut snippet_prompt = template.to_string();
280            let snippet = search_result.to_string();
281            writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
282
283            let token_count = encoding
284                .encode_with_special_tokens(snippet_prompt.as_str())
285                .len();
286            if token_count <= remaining_token_count {
287                if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
288                    prompts.insert(snippet_position, snippet_prompt);
289                    snippet_position += 1;
290                    remaining_token_count -= token_count;
291                    // If you have already added the template to the prompt, remove the template.
292                    template = "";
293                }
294            } else {
295                break;
296            }
297        }
298    }
299
300    prompts.join("\n")
301}
302
#[cfg(test)]
pub(crate) mod tests {

    use super::*;

    use gpui::AppContext;
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use settings::SettingsStore;
    use std::sync::Arc;

    /// Builds a Rust `Language` whose embedding query captures function
    /// bodies as `@collapse`, with their braces as `@keep`.
    pub(crate) fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        };
        let embedding_query = r#"
            (
                [(line_comment) (attribute_item)]* @context
                .
                [
                    (struct_item
                        name: (_) @name)

                    (enum_item
                        name: (_) @name)

                    (impl_item
                        trait: (_)? @name
                        "for"? @name
                        type: (_) @name)

                    (trait_item
                        name: (_) @name)

                    (function_item
                        name: (_) @name
                        body: (block
                            "{" @keep
                            "}" @keep) @collapse)

                    (macro_definition
                        name: (_) @name)
                    ] @item
                )
            "#;

        Language::new(config, Some(tree_sitter_rust::language()))
            .with_embedding_query(embedding_query)
            .unwrap()
    }

    /// Checks `summarize` with the cursor/selection at several positions, and
    /// with a function nested inside another function.
    #[gpui::test]
    fn test_outline_for_prompt(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        language_settings::init(cx);

        let flat_source = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    self.a
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        let buffer = cx.add_model(|cx| {
            Buffer::new(0, 0, flat_source).with_language(Arc::new(rust_lang()), cx)
        });
        let flat_snapshot = buffer.read(cx).snapshot();

        // Empty selection inside the struct: every function body collapses.
        let cursor_in_struct = summarize(&flat_snapshot, Point::new(1, 4)..Point::new(1, 4));
        assert_eq!(
            cursor_in_struct,
            indoc! {"
                struct X {
                    <|START|>a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Non-empty selection inside `new`: that body stays expanded.
        let selection_in_new = summarize(&flat_snapshot, Point::new(8, 12)..Point::new(8, 14));
        assert_eq!(
            selection_in_new,
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {
                        let <|START|a |END|>= 1;
                        let b = 2;
                        Self { a, b }
                    }

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Cursor on the blank line between `impl X {` and `fn new`.
        let cursor_before_new = summarize(&flat_snapshot, Point::new(6, 0)..Point::new(6, 0));
        assert_eq!(
            cursor_before_new,
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {
                <|START|>
                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Cursor at the very end of the buffer.
        let cursor_at_end = summarize(&flat_snapshot, Point::new(21, 0)..Point::new(21, 0));
        assert_eq!(
            cursor_at_end,
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
                <|START|>"}
        );

        // Ensure nested functions get collapsed properly.
        let nested_source = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    let a = 30;
                    fn nested() -> usize {
                        3
                    }
                    self.a + nested()
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        buffer.update(cx, |buffer, cx| buffer.set_text(nested_source, cx));
        let nested_snapshot = buffer.read(cx).snapshot();

        let cursor_at_start = summarize(&nested_snapshot, Point::new(0, 0)..Point::new(0, 0));
        assert_eq!(
            cursor_at_start,
            indoc! {"
                <|START|>struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );
    }
}