prompts.rs

  1use crate::codegen::CodegenKind;
  2use gpui::AsyncAppContext;
  3use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
  4use semantic_index::SearchResult;
  5use std::cmp::{self, Reverse};
  6use std::fmt::Write;
  7use std::ops::Range;
  8use std::path::PathBuf;
  9use tiktoken_rs::ChatCompletionRequestMessage;
 10
 11pub struct PromptCodeSnippet {
 12    path: Option<PathBuf>,
 13    language_name: Option<String>,
 14    content: String,
 15}
 16
 17impl PromptCodeSnippet {
 18    pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
 19        let (content, language_name, file_path) =
 20            search_result.buffer.read_with(cx, |buffer, _| {
 21                let snapshot = buffer.snapshot();
 22                let content = snapshot
 23                    .text_for_range(search_result.range.clone())
 24                    .collect::<String>();
 25
 26                let language_name = buffer
 27                    .language()
 28                    .and_then(|language| Some(language.name().to_string()));
 29
 30                let file_path = buffer
 31                    .file()
 32                    .and_then(|file| Some(file.path().to_path_buf()));
 33
 34                (content, language_name, file_path)
 35            });
 36
 37        PromptCodeSnippet {
 38            path: file_path,
 39            language_name,
 40            content,
 41        }
 42    }
 43}
 44
 45impl ToString for PromptCodeSnippet {
 46    fn to_string(&self) -> String {
 47        let path = self
 48            .path
 49            .as_ref()
 50            .and_then(|path| Some(path.to_string_lossy().to_string()))
 51            .unwrap_or("".to_string());
 52        let language_name = self.language_name.clone().unwrap_or("".to_string());
 53        let content = self.content.clone();
 54
 55        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
 56    }
 57}
 58
 59#[allow(dead_code)]
 60fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
 61    #[derive(Debug)]
 62    struct Match {
 63        collapse: Range<usize>,
 64        keep: Vec<Range<usize>>,
 65    }
 66
 67    let selected_range = selected_range.to_offset(buffer);
 68    let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
 69        Some(&grammar.embedding_config.as_ref()?.query)
 70    });
 71    let configs = ts_matches
 72        .grammars()
 73        .iter()
 74        .map(|g| g.embedding_config.as_ref().unwrap())
 75        .collect::<Vec<_>>();
 76    let mut matches = Vec::new();
 77    while let Some(mat) = ts_matches.peek() {
 78        let config = &configs[mat.grammar_index];
 79        if let Some(collapse) = mat.captures.iter().find_map(|cap| {
 80            if Some(cap.index) == config.collapse_capture_ix {
 81                Some(cap.node.byte_range())
 82            } else {
 83                None
 84            }
 85        }) {
 86            let mut keep = Vec::new();
 87            for capture in mat.captures.iter() {
 88                if Some(capture.index) == config.keep_capture_ix {
 89                    keep.push(capture.node.byte_range());
 90                } else {
 91                    continue;
 92                }
 93            }
 94            ts_matches.advance();
 95            matches.push(Match { collapse, keep });
 96        } else {
 97            ts_matches.advance();
 98        }
 99    }
100    matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
101    let mut matches = matches.into_iter().peekable();
102
103    let mut summary = String::new();
104    let mut offset = 0;
105    let mut flushed_selection = false;
106    while let Some(mat) = matches.next() {
107        // Keep extending the collapsed range if the next match surrounds
108        // the current one.
109        while let Some(next_mat) = matches.peek() {
110            if mat.collapse.start <= next_mat.collapse.start
111                && mat.collapse.end >= next_mat.collapse.end
112            {
113                matches.next().unwrap();
114            } else {
115                break;
116            }
117        }
118
119        if offset > mat.collapse.start {
120            // Skip collapsed nodes that have already been summarized.
121            offset = cmp::max(offset, mat.collapse.end);
122            continue;
123        }
124
125        if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
126            if !flushed_selection {
127                // The collapsed node ends after the selection starts, so we'll flush the selection first.
128                summary.extend(buffer.text_for_range(offset..selected_range.start));
129                summary.push_str("<|START|");
130                if selected_range.end == selected_range.start {
131                    summary.push_str(">");
132                } else {
133                    summary.extend(buffer.text_for_range(selected_range.clone()));
134                    summary.push_str("|END|>");
135                }
136                offset = selected_range.end;
137                flushed_selection = true;
138            }
139
140            // If the selection intersects the collapsed node, we won't collapse it.
141            if selected_range.end >= mat.collapse.start {
142                continue;
143            }
144        }
145
146        summary.extend(buffer.text_for_range(offset..mat.collapse.start));
147        for keep in mat.keep {
148            summary.extend(buffer.text_for_range(keep));
149        }
150        offset = mat.collapse.end;
151    }
152
153    // Flush selection if we haven't already done so.
154    if !flushed_selection && offset <= selected_range.start {
155        summary.extend(buffer.text_for_range(offset..selected_range.start));
156        summary.push_str("<|START|");
157        if selected_range.end == selected_range.start {
158            summary.push_str(">");
159        } else {
160            summary.extend(buffer.text_for_range(selected_range.clone()));
161            summary.push_str("|END|>");
162        }
163        offset = selected_range.end;
164    }
165
166    summary.extend(buffer.text_for_range(offset..buffer.len()));
167    summary
168}
169
170pub fn generate_content_prompt(
171    user_prompt: String,
172    language_name: Option<&str>,
173    buffer: &BufferSnapshot,
174    range: Range<impl ToOffset>,
175    kind: CodegenKind,
176    search_results: Vec<PromptCodeSnippet>,
177    model: &str,
178) -> String {
179    const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
180    const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
181
182    let mut prompts = Vec::new();
183    let range = range.to_offset(buffer);
184
185    // General Preamble
186    if let Some(language_name) = language_name {
187        prompts.push(format!("You're an expert {language_name} engineer.\n"));
188    } else {
189        prompts.push("You're an expert engineer.\n".to_string());
190    }
191
192    // Snippets
193    let mut snippet_position = prompts.len() - 1;
194
195    let mut content = String::new();
196    content.extend(buffer.text_for_range(0..range.start));
197    if range.start == range.end {
198        content.push_str("<|START|>");
199    } else {
200        content.push_str("<|START|");
201    }
202    content.extend(buffer.text_for_range(range.clone()));
203    if range.start != range.end {
204        content.push_str("|END|>");
205    }
206    content.extend(buffer.text_for_range(range.end..buffer.len()));
207
208    prompts.push("The file you are currently working on has the following content:\n".to_string());
209
210    if let Some(language_name) = language_name {
211        let language_name = language_name.to_lowercase();
212        prompts.push(format!("```{language_name}\n{content}\n```"));
213    } else {
214        prompts.push(format!("```\n{content}\n```"));
215    }
216
217    match kind {
218        CodegenKind::Generate { position: _ } => {
219            prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
220            prompts
221                .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
222            prompts.push(
223                "Text can't be replaced, so assume your answer will be inserted at the cursor."
224                    .to_string(),
225            );
226            prompts.push(format!(
227                "Generate text based on the users prompt: {user_prompt}"
228            ));
229        }
230        CodegenKind::Transform { range: _ } => {
231            prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
232            prompts.push(format!(
233                "Modify the users code selected text based upon the users prompt: '{user_prompt}'"
234            ));
235            prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
236        }
237    }
238
239    if let Some(language_name) = language_name {
240        prompts.push(format!(
241            "Your answer MUST always and only be valid {language_name}"
242        ));
243    }
244    prompts.push("Never make remarks about the output.".to_string());
245    prompts.push("Do not return any text, except the generated code.".to_string());
246    prompts.push("Always wrap your code in a Markdown block".to_string());
247
248    let current_messages = [ChatCompletionRequestMessage {
249        role: "user".to_string(),
250        content: Some(prompts.join("\n")),
251        function_call: None,
252        name: None,
253    }];
254
255    let mut remaining_token_count = if let Ok(current_token_count) =
256        tiktoken_rs::num_tokens_from_messages(model, &current_messages)
257    {
258        let max_token_count = tiktoken_rs::model::get_context_size(model);
259        let intermediate_token_count = if max_token_count > current_token_count {
260            max_token_count - current_token_count
261        } else {
262            0
263        };
264
265        if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
266            0
267        } else {
268            intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
269        }
270    } else {
271        // If tiktoken fails to count token count, assume we have no space remaining.
272        0
273    };
274
275    // TODO:
276    //   - add repository name to snippet
277    //   - add file path
278    //   - add language
279    if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
280        let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
281
282        for search_result in search_results {
283            let mut snippet_prompt = template.to_string();
284            let snippet = search_result.to_string();
285            writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
286
287            let token_count = encoding
288                .encode_with_special_tokens(snippet_prompt.as_str())
289                .len();
290            if token_count <= remaining_token_count {
291                if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
292                    prompts.insert(snippet_position, snippet_prompt);
293                    snippet_position += 1;
294                    remaining_token_count -= token_count;
295                    // If you have already added the template to the prompt, remove the template.
296                    template = "";
297                }
298            } else {
299                break;
300            }
301        }
302    }
303
304    prompts.join("\n")
305}
306
#[cfg(test)]
pub(crate) mod tests {

    use super::*;
    use std::sync::Arc;

    use gpui::AppContext;
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use settings::SettingsStore;

    /// Rust language fixture whose embedding query marks function bodies as collapsible
    /// while keeping their braces.
    pub(crate) fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        };
        let embedding_query = r#"
            (
                [(line_comment) (attribute_item)]* @context
                .
                [
                    (struct_item
                        name: (_) @name)

                    (enum_item
                        name: (_) @name)

                    (impl_item
                        trait: (_)? @name
                        "for"? @name
                        type: (_) @name)

                    (trait_item
                        name: (_) @name)

                    (function_item
                        name: (_) @name
                        body: (block
                            "{" @keep
                            "}" @keep) @collapse)

                    (macro_definition
                        name: (_) @name)
                    ] @item
                )
            "#;

        Language::new(config, Some(tree_sitter_rust::language()))
            .with_embedding_query(embedding_query)
            .unwrap()
    }

    /// Checks `summarize` output for several selections over a small Rust buffer.
    #[gpui::test]
    fn test_outline_for_prompt(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        language_settings::init(cx);

        let flat_source = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    self.a
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        let buffer = cx.add_model(|cx| {
            Buffer::new(0, 0, flat_source).with_language(Arc::new(rust_lang()), cx)
        });
        let snapshot = buffer.read(cx).snapshot();

        // Empty selection inside the struct: every function body is collapsed.
        let selection = Point::new(1, 4)..Point::new(1, 4);
        let expected = indoc! {"
            struct X {
                <|START|>a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {}

                pub fn a(&self, param: bool) -> usize {}

                pub fn b(&self) -> usize {}
            }
        "};
        assert_eq!(summarize(&snapshot, selection), expected);

        // Non-empty selection inside `new`: that body stays expanded.
        let selection = Point::new(8, 12)..Point::new(8, 14);
        let expected = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let <|START|a |END|>= 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {}

                pub fn b(&self) -> usize {}
            }
        "};
        assert_eq!(summarize(&snapshot, selection), expected);

        // Empty selection on the blank line just before `new`.
        let selection = Point::new(6, 0)..Point::new(6, 0);
        let expected = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {
            <|START|>
                fn new() -> Self {}

                pub fn a(&self, param: bool) -> usize {}

                pub fn b(&self) -> usize {}
            }
        "};
        assert_eq!(summarize(&snapshot, selection), expected);

        // Empty selection at the very end of the buffer.
        let selection = Point::new(21, 0)..Point::new(21, 0);
        let expected = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {}

                pub fn a(&self, param: bool) -> usize {}

                pub fn b(&self) -> usize {}
            }
            <|START|>"};
        assert_eq!(summarize(&snapshot, selection), expected);

        // Ensure nested functions get collapsed properly.
        let nested_source = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    let a = 30;
                    fn nested() -> usize {
                        3
                    }
                    self.a + nested()
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        buffer.update(cx, |buffer, cx| buffer.set_text(nested_source, cx));
        let snapshot = buffer.read(cx).snapshot();

        let selection = Point::new(0, 0)..Point::new(0, 0);
        let expected = indoc! {"
            <|START|>struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {}

                pub fn a(&self, param: bool) -> usize {}

                pub fn b(&self) -> usize {}
            }
        "};
        assert_eq!(summarize(&snapshot, selection), expected);
    }
}