prompts.rs

  1use crate::codegen::CodegenKind;
  2use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
  3use std::cmp::{self, Reverse};
  4use std::fmt::Write;
  5use std::ops::Range;
  6use tiktoken_rs::ChatCompletionRequestMessage;
  7
  8fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
  9    #[derive(Debug)]
 10    struct Match {
 11        collapse: Range<usize>,
 12        keep: Vec<Range<usize>>,
 13    }
 14
 15    let selected_range = selected_range.to_offset(buffer);
 16    let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
 17        Some(&grammar.embedding_config.as_ref()?.query)
 18    });
 19    let configs = ts_matches
 20        .grammars()
 21        .iter()
 22        .map(|g| g.embedding_config.as_ref().unwrap())
 23        .collect::<Vec<_>>();
 24    let mut matches = Vec::new();
 25    while let Some(mat) = ts_matches.peek() {
 26        let config = &configs[mat.grammar_index];
 27        if let Some(collapse) = mat.captures.iter().find_map(|cap| {
 28            if Some(cap.index) == config.collapse_capture_ix {
 29                Some(cap.node.byte_range())
 30            } else {
 31                None
 32            }
 33        }) {
 34            let mut keep = Vec::new();
 35            for capture in mat.captures.iter() {
 36                if Some(capture.index) == config.keep_capture_ix {
 37                    keep.push(capture.node.byte_range());
 38                } else {
 39                    continue;
 40                }
 41            }
 42            ts_matches.advance();
 43            matches.push(Match { collapse, keep });
 44        } else {
 45            ts_matches.advance();
 46        }
 47    }
 48    matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
 49    let mut matches = matches.into_iter().peekable();
 50
 51    let mut summary = String::new();
 52    let mut offset = 0;
 53    let mut flushed_selection = false;
 54    while let Some(mat) = matches.next() {
 55        // Keep extending the collapsed range if the next match surrounds
 56        // the current one.
 57        while let Some(next_mat) = matches.peek() {
 58            if mat.collapse.start <= next_mat.collapse.start
 59                && mat.collapse.end >= next_mat.collapse.end
 60            {
 61                matches.next().unwrap();
 62            } else {
 63                break;
 64            }
 65        }
 66
 67        if offset > mat.collapse.start {
 68            // Skip collapsed nodes that have already been summarized.
 69            offset = cmp::max(offset, mat.collapse.end);
 70            continue;
 71        }
 72
 73        if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
 74            if !flushed_selection {
 75                // The collapsed node ends after the selection starts, so we'll flush the selection first.
 76                summary.extend(buffer.text_for_range(offset..selected_range.start));
 77                summary.push_str("<|START|");
 78                if selected_range.end == selected_range.start {
 79                    summary.push_str(">");
 80                } else {
 81                    summary.extend(buffer.text_for_range(selected_range.clone()));
 82                    summary.push_str("|END|>");
 83                }
 84                offset = selected_range.end;
 85                flushed_selection = true;
 86            }
 87
 88            // If the selection intersects the collapsed node, we won't collapse it.
 89            if selected_range.end >= mat.collapse.start {
 90                continue;
 91            }
 92        }
 93
 94        summary.extend(buffer.text_for_range(offset..mat.collapse.start));
 95        for keep in mat.keep {
 96            summary.extend(buffer.text_for_range(keep));
 97        }
 98        offset = mat.collapse.end;
 99    }
100
101    // Flush selection if we haven't already done so.
102    if !flushed_selection && offset <= selected_range.start {
103        summary.extend(buffer.text_for_range(offset..selected_range.start));
104        summary.push_str("<|START|");
105        if selected_range.end == selected_range.start {
106            summary.push_str(">");
107        } else {
108            summary.extend(buffer.text_for_range(selected_range.clone()));
109            summary.push_str("|END|>");
110        }
111        offset = selected_range.end;
112    }
113
114    summary.extend(buffer.text_for_range(offset..buffer.len()));
115    summary
116}
117
118pub fn generate_content_prompt(
119    user_prompt: String,
120    language_name: Option<&str>,
121    buffer: &BufferSnapshot,
122    range: Range<impl ToOffset>,
123    kind: CodegenKind,
124    search_results: Vec<String>,
125    model: &str,
126) -> String {
127    const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
128
129    let mut prompts = Vec::new();
130
131    // General Preamble
132    if let Some(language_name) = language_name {
133        prompts.push(format!("You're an expert {language_name} engineer.\n"));
134    } else {
135        prompts.push("You're an expert engineer.\n".to_string());
136    }
137
138    // Snippets
139    let mut snippet_position = prompts.len() - 1;
140
141    let outline = summarize(buffer, range);
142    prompts.push("The file you are currently working on has the following outline:".to_string());
143    if let Some(language_name) = language_name {
144        let language_name = language_name.to_lowercase();
145        prompts.push(format!("```{language_name}\n{outline}\n```"));
146    } else {
147        prompts.push(format!("```\n{outline}\n```"));
148    }
149
150    match kind {
151        CodegenKind::Generate { position: _ } => {
152            prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
153            prompts
154                .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
155            prompts.push(
156                "Text can't be replaced, so assume your answer will be inserted at the cursor."
157                    .to_string(),
158            );
159            prompts.push(format!(
160                "Generate text based on the users prompt: {user_prompt}"
161            ));
162        }
163        CodegenKind::Transform { range: _ } => {
164            prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
165            prompts.push(format!(
166                "Modify the users code selected text based upon the users prompt: {user_prompt}"
167            ));
168            prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
169        }
170    }
171
172    if let Some(language_name) = language_name {
173        prompts.push(format!("Your answer MUST always be valid {language_name}"));
174    }
175    prompts.push("Always wrap your response in a Markdown codeblock".to_string());
176    prompts.push("Never make remarks about the output.".to_string());
177
178    let current_messages = [ChatCompletionRequestMessage {
179        role: "user".to_string(),
180        content: Some(prompts.join("\n")),
181        function_call: None,
182        name: None,
183    }];
184
185    let remaining_token_count = if let Ok(current_token_count) =
186        tiktoken_rs::num_tokens_from_messages(model, &current_messages)
187    {
188        let max_token_count = tiktoken_rs::model::get_context_size(model);
189        max_token_count - current_token_count
190    } else {
191        // If tiktoken fails to count token count, assume we have no space remaining.
192        0
193    };
194
195    // TODO:
196    //   - add repository name to snippet
197    //   - add file path
198    //   - add language
199    if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
200        let template = "You are working inside a large repository, here are a few code snippets that may be useful";
201
202        for search_result in search_results {
203            let mut snippet_prompt = template.to_string();
204            writeln!(snippet_prompt, "```\n{search_result}\n```").unwrap();
205
206            let token_count = encoding
207                .encode_with_special_tokens(snippet_prompt.as_str())
208                .len();
209            if token_count <= remaining_token_count {
210                if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
211                    prompts.insert(snippet_position, snippet_prompt);
212                    snippet_position += 1;
213                }
214            } else {
215                break;
216            }
217        }
218    }
219
220    let prompt = prompts.join("\n");
221    println!("PROMPT: {:?}", prompt);
222    prompt
223}
224
#[cfg(test)]
pub(crate) mod tests {

    use super::*;
    use std::sync::Arc;

    use gpui::AppContext;
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use settings::SettingsStore;

    /// Builds a Rust `Language` whose embedding query captures each function
    /// body as `@collapse` and its braces as `@keep`, which is what `summarize`
    /// uses to elide bodies down to `{}`.
    pub(crate) fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        };
        let embedding_query = r#"
            (
                [(line_comment) (attribute_item)]* @context
                .
                [
                    (struct_item
                        name: (_) @name)

                    (enum_item
                        name: (_) @name)

                    (impl_item
                        trait: (_)? @name
                        "for"? @name
                        type: (_) @name)

                    (trait_item
                        name: (_) @name)

                    (function_item
                        name: (_) @name
                        body: (block
                            "{" @keep
                            "}" @keep) @collapse)

                    (macro_definition
                        name: (_) @name)
                    ] @item
                )
            "#;
        Language::new(config, Some(tree_sitter_rust::language()))
            .with_embedding_query(embedding_query)
            .unwrap()
    }

    #[gpui::test]
    fn test_outline_for_prompt(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        language_settings::init(cx);

        let source = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    self.a
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        let buffer = cx.add_model(|cx| {
            Buffer::new(0, 0, source).with_language(Arc::new(rust_lang()), cx)
        });
        let snapshot = buffer.read(cx).snapshot();

        // Each case pairs a selection with the outline `summarize` must produce.
        let cases = [
            // Cursor inside the struct; struct bodies have no `@collapse` capture.
            (
                Point::new(1, 4)..Point::new(1, 4),
                indoc! {"
                    struct X {
                        <|START|>a: usize,
                        b: usize,
                    }

                    impl X {

                        fn new() -> Self {}

                        pub fn a(&self, param: bool) -> usize {}

                        pub fn b(&self) -> usize {}
                    }
                "},
            ),
            // Selection inside `new`: its body is kept, the other bodies collapse.
            (
                Point::new(8, 12)..Point::new(8, 14),
                indoc! {"
                    struct X {
                        a: usize,
                        b: usize,
                    }

                    impl X {

                        fn new() -> Self {
                            let <|START|a |END|>= 1;
                            let b = 2;
                            Self { a, b }
                        }

                        pub fn a(&self, param: bool) -> usize {}

                        pub fn b(&self) -> usize {}
                    }
                "},
            ),
            // Cursor on the blank line that opens the `impl` block.
            (
                Point::new(6, 0)..Point::new(6, 0),
                indoc! {"
                    struct X {
                        a: usize,
                        b: usize,
                    }

                    impl X {
                    <|START|>
                        fn new() -> Self {}

                        pub fn a(&self, param: bool) -> usize {}

                        pub fn b(&self) -> usize {}
                    }
                "},
            ),
            // Cursor at the very end of the buffer.
            (
                Point::new(21, 0)..Point::new(21, 0),
                indoc! {"
                    struct X {
                        a: usize,
                        b: usize,
                    }

                    impl X {

                        fn new() -> Self {}

                        pub fn a(&self, param: bool) -> usize {}

                        pub fn b(&self) -> usize {}
                    }
                    <|START|>"},
            ),
        ];
        for (selection, expected) in cases {
            assert_eq!(summarize(&snapshot, selection), expected);
        }

        // Ensure nested functions get collapsed properly.
        let nested_source = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    let a = 30;
                    fn nested() -> usize {
                        3
                    }
                    self.a + nested()
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        buffer.update(cx, |buf, cx| buf.set_text(nested_source, cx));
        let snapshot = buffer.read(cx).snapshot();
        let expected_nested = indoc! {"
            <|START|>struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {}

                pub fn a(&self, param: bool) -> usize {}

                pub fn b(&self) -> usize {}
            }
        "};
        assert_eq!(
            summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
            expected_nested
        );
    }
}