prompts.rs

  1use gpui::AppContext;
  2use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
  3use std::cmp;
  4use std::ops::Range;
  5use std::{fmt::Write, iter};
  6
  7use crate::codegen::CodegenKind;
  8
  9fn outline_for_prompt(
 10    buffer: &BufferSnapshot,
 11    range: Range<language::Anchor>,
 12    cx: &AppContext,
 13) -> Option<String> {
 14    let indent = buffer
 15        .language_indent_size_at(0, cx)
 16        .chars()
 17        .collect::<String>();
 18    let outline = buffer.outline(None)?;
 19    let range = range.to_offset(buffer);
 20
 21    let mut text = String::new();
 22    let mut items = outline.items.into_iter().peekable();
 23
 24    let mut intersected = false;
 25    let mut intersection_indent = 0;
 26    let mut extended_range = range.clone();
 27
 28    while let Some(item) = items.next() {
 29        let item_range = item.range.to_offset(buffer);
 30        if item_range.end < range.start || item_range.start > range.end {
 31            text.extend(iter::repeat(indent.as_str()).take(item.depth));
 32            text.push_str(&item.text);
 33            text.push('\n');
 34        } else {
 35            intersected = true;
 36            let is_terminal = items
 37                .peek()
 38                .map_or(true, |next_item| next_item.depth <= item.depth);
 39            if is_terminal {
 40                if item_range.start <= extended_range.start {
 41                    extended_range.start = item_range.start;
 42                    intersection_indent = item.depth;
 43                }
 44                extended_range.end = cmp::max(extended_range.end, item_range.end);
 45            } else {
 46                let name_start = item_range.start + item.name_ranges.first().unwrap().start;
 47                let name_end = item_range.start + item.name_ranges.last().unwrap().end;
 48
 49                if range.start > name_end {
 50                    text.extend(iter::repeat(indent.as_str()).take(item.depth));
 51                    text.push_str(&item.text);
 52                    text.push('\n');
 53                } else {
 54                    if name_start <= extended_range.start {
 55                        extended_range.start = item_range.start;
 56                        intersection_indent = item.depth;
 57                    }
 58                    extended_range.end = cmp::max(extended_range.end, name_end);
 59                }
 60            }
 61        }
 62
 63        if intersected
 64            && items.peek().map_or(true, |next_item| {
 65                next_item.range.start.to_offset(buffer) > range.end
 66            })
 67        {
 68            intersected = false;
 69            text.extend(iter::repeat(indent.as_str()).take(intersection_indent));
 70            text.extend(buffer.text_for_range(extended_range.start..range.start));
 71            text.push_str("<|START|");
 72            text.extend(buffer.text_for_range(range.clone()));
 73            if range.start != range.end {
 74                text.push_str("|END|>");
 75            } else {
 76                text.push_str(">");
 77            }
 78            text.extend(buffer.text_for_range(range.end..extended_range.end));
 79            text.push('\n');
 80        }
 81    }
 82
 83    Some(text)
 84}
 85
 86pub fn generate_content_prompt(
 87    user_prompt: String,
 88    language_name: Option<&str>,
 89    buffer: &BufferSnapshot,
 90    range: Range<language::Anchor>,
 91    cx: &AppContext,
 92    kind: CodegenKind,
 93) -> String {
 94    let mut prompt = String::new();
 95
 96    // General Preamble
 97    if let Some(language_name) = language_name {
 98        writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
 99    } else {
100        writeln!(prompt, "You're an expert engineer.\n").unwrap();
101    }
102
103    let outline = outline_for_prompt(buffer, range.clone(), cx);
104    if let Some(outline) = outline {
105        writeln!(
106            prompt,
107            "The file you are currently working on has the following outline:"
108        )
109        .unwrap();
110        if let Some(language_name) = language_name {
111            let language_name = language_name.to_lowercase();
112            writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
113        } else {
114            writeln!(prompt, "```\n{outline}\n```").unwrap();
115        }
116    }
117
118    // Assume for now that we are just generating
119    if range.clone().start == range.end {
120        writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
121    } else {
122        writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
123    }
124
125    match kind {
126        CodegenKind::Generate { position: _ } => {
127            writeln!(
128                prompt,
129                "Assume the cursor is located where the `<|START|` marker is."
130            )
131            .unwrap();
132            writeln!(
133                prompt,
134                "Text can't be replaced, so assume your answer will be inserted at the cursor."
135            )
136            .unwrap();
137            writeln!(
138                prompt,
139                "Generate text based on the users prompt: {user_prompt}"
140            )
141            .unwrap();
142        }
143        CodegenKind::Transform { range: _ } => {
144            writeln!(
145                prompt,
146                "Modify the users code selected text based upon the users prompt: {user_prompt}"
147            )
148            .unwrap();
149            writeln!(
150                prompt,
151                "You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file."
152            )
153            .unwrap();
154        }
155    }
156
157    if let Some(language_name) = language_name {
158        writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap();
159    }
160    writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap();
161    writeln!(prompt, "Never make remarks about the output.").unwrap();
162
163    prompt
164}
165
// Unit tests for `outline_for_prompt`, plus a minimal Rust language fixture.
166#[cfg(test)]
167pub(crate) mod tests {
168
169    use super::*;
170    use std::sync::Arc;
171
172    use gpui::AppContext;
173    use indoc::indoc;
174    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
175    use settings::SettingsStore;
176
       // Builds a `Language` named "Rust" (path suffix "rs") backed by the
       // tree-sitter Rust grammar, with an indents query and an outline query.
       // The outline query captures structs, enums, enum variants, fields,
       // impls, functions and modules as outline items.
177    pub(crate) fn rust_lang() -> Language {
178        Language::new(
179            LanguageConfig {
180                name: "Rust".into(),
181                path_suffixes: vec!["rs".to_string()],
182                ..Default::default()
183            },
184            Some(tree_sitter_rust::language()),
185        )
186        .with_indents_query(
187            r#"
188                (call_expression) @indent
189                (field_expression) @indent
190                (_ "(" ")" @end) @indent
191                (_ "{" "}" @end) @indent
192                "#,
193        )
194        .unwrap()
195        .with_outline_query(
196            r#"
197                (struct_item
198                    "struct" @context
199                    name: (_) @name) @item
200                (enum_item
201                    "enum" @context
202                    name: (_) @name) @item
203                (enum_variant
204                    name: (_) @name) @item
205                (field_declaration
206                    name: (_) @name) @item
207                (impl_item
208                    "impl" @context
209                    trait: (_)? @name
210                    "for"? @context
211                    type: (_) @name) @item
212                (function_item
213                    "fn" @context
214                    name: (_) @name) @item
215                (mod_item
216                    "mod" @context
217                    name: (_) @name) @item
218                "#,
219        )
220        .unwrap()
221    }
222
       // Checks the text produced by `outline_for_prompt` for several cursor
       // positions and selections inside one small Rust buffer.
223    #[gpui::test]
224    fn test_outline_for_prompt(cx: &mut AppContext) {
           // Settings must be installed before the buffer is given a language.
225        cx.set_global(SettingsStore::test(cx));
226        language_settings::init(cx);
227        let text = indoc! {"
228            struct X {
229                a: usize,
230                b: usize,
231            }
232
233            impl X {
234
235                fn new() -> Self {
236                    let a = 1;
237                    let b = 2;
238                    Self { a, b }
239                }
240
241                pub fn a(&self, param: bool) -> usize {
242                    self.a
243                }
244
245                pub fn b(&self) -> usize {
246                    self.b
247                }
248            }
249        "};
250        let buffer =
251            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
252        let snapshot = buffer.read(cx).snapshot();
253
           // Case 1: empty range at Point(1, 4). The expected output puts the
           // `<|START|>` marker right before field `a`; other items stay collapsed.
254        let outline = outline_for_prompt(
255            &snapshot,
256            snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_before(Point::new(1, 4)),
257            cx,
258        );
259        assert_eq!(
260            outline.as_deref(),
261            Some(indoc! {"
262                struct X
263                    <|START|>a: usize
264                    b
265                impl X
266                    fn new
267                    fn a
268                    fn b
269            "})
270        );
271
           // Case 2: a two-character selection inside `fn new`. The whole
           // function is expanded, with the markers wrapped around `a `.
272        let outline = outline_for_prompt(
273            &snapshot,
274            snapshot.anchor_before(Point::new(8, 12))..snapshot.anchor_before(Point::new(8, 14)),
275            cx,
276        );
277        assert_eq!(
278            outline.as_deref(),
279            Some(indoc! {"
280                struct X
281                    a
282                    b
283                impl X
284                    fn new() -> Self {
285                        let <|START|a |END|>= 1;
286                        let b = 2;
287                        Self { a, b }
288                    }
289                    fn a
290                    fn b
291            "})
292        );
293
           // Case 3: empty range at Point(6, 0), the blank line between the
           // `impl X {` header and `fn new`. The marker gets a line of its own
           // after the collapsed `impl X` entry.
294        let outline = outline_for_prompt(
295            &snapshot,
296            snapshot.anchor_before(Point::new(6, 0))..snapshot.anchor_before(Point::new(6, 0)),
297            cx,
298        );
299        assert_eq!(
300            outline.as_deref(),
301            Some(indoc! {"
302                struct X
303                    a
304                    b
305                impl X
306                <|START|>
307                    fn new
308                    fn a
309                    fn b
310            "})
311        );
312
           // Case 4: a selection that starts inside `fn new` and ends inside the
           // signature of `fn a`. Both functions are expanded in full.
313        let outline = outline_for_prompt(
314            &snapshot,
315            snapshot.anchor_before(Point::new(8, 12))..snapshot.anchor_before(Point::new(13, 9)),
316            cx,
317        );
318        assert_eq!(
319            outline.as_deref(),
320            Some(indoc! {"
321                struct X
322                    a
323                    b
324                impl X
325                    fn new() -> Self {
326                        let <|START|a = 1;
327                        let b = 2;
328                        Self { a, b }
329                    }
330
331                    pub f|END|>n a(&self, param: bool) -> usize {
332                        self.a
333                    }
334                    fn b
335            "})
336        );
337
           // Case 5: a selection from just after `impl X` to the blank line
           // following `fn new`. The expansion starts at the `impl` item itself.
338        let outline = outline_for_prompt(
339            &snapshot,
340            snapshot.anchor_before(Point::new(5, 6))..snapshot.anchor_before(Point::new(12, 0)),
341            cx,
342        );
343        assert_eq!(
344            outline.as_deref(),
345            Some(indoc! {"
346                struct X
347                    a
348                    b
349                impl X<|START| {
350
351                    fn new() -> Self {
352                        let a = 1;
353                        let b = 2;
354                        Self { a, b }
355                    }
356                |END|>
357                    fn a
358                    fn b
359            "})
360        );
361
           // Case 6: empty range inside the body of the last function, `fn b`.
           // This covers the flush that happens when no outline item follows.
362        let outline = outline_for_prompt(
363            &snapshot,
364            snapshot.anchor_before(Point::new(18, 8))..snapshot.anchor_before(Point::new(18, 8)),
365            cx,
366        );
367        assert_eq!(
368            outline.as_deref(),
369            Some(indoc! {"
370                struct X
371                    a
372                    b
373                impl X
374                    fn new
375                    fn a
376                    pub fn b(&self) -> usize {
377                        <|START|>self.b
378                    }
379            "})
380        );
381    }
382}