prompts.rs

  1use ai::models::{LanguageModel, OpenAILanguageModel};
  2use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
  3use ai::templates::file_context::FileContext;
  4use ai::templates::generate::GenerateInlineContent;
  5use ai::templates::preamble::EngineerPreamble;
  6use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
  7use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
  8use std::cmp::{self, Reverse};
  9use std::ops::Range;
 10use std::sync::Arc;
 11
 12#[allow(dead_code)]
 13fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
 14    #[derive(Debug)]
 15    struct Match {
 16        collapse: Range<usize>,
 17        keep: Vec<Range<usize>>,
 18    }
 19
 20    let selected_range = selected_range.to_offset(buffer);
 21    let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
 22        Some(&grammar.embedding_config.as_ref()?.query)
 23    });
 24    let configs = ts_matches
 25        .grammars()
 26        .iter()
 27        .map(|g| g.embedding_config.as_ref().unwrap())
 28        .collect::<Vec<_>>();
 29    let mut matches = Vec::new();
 30    while let Some(mat) = ts_matches.peek() {
 31        let config = &configs[mat.grammar_index];
 32        if let Some(collapse) = mat.captures.iter().find_map(|cap| {
 33            if Some(cap.index) == config.collapse_capture_ix {
 34                Some(cap.node.byte_range())
 35            } else {
 36                None
 37            }
 38        }) {
 39            let mut keep = Vec::new();
 40            for capture in mat.captures.iter() {
 41                if Some(capture.index) == config.keep_capture_ix {
 42                    keep.push(capture.node.byte_range());
 43                } else {
 44                    continue;
 45                }
 46            }
 47            ts_matches.advance();
 48            matches.push(Match { collapse, keep });
 49        } else {
 50            ts_matches.advance();
 51        }
 52    }
 53    matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
 54    let mut matches = matches.into_iter().peekable();
 55
 56    let mut summary = String::new();
 57    let mut offset = 0;
 58    let mut flushed_selection = false;
 59    while let Some(mat) = matches.next() {
 60        // Keep extending the collapsed range if the next match surrounds
 61        // the current one.
 62        while let Some(next_mat) = matches.peek() {
 63            if mat.collapse.start <= next_mat.collapse.start
 64                && mat.collapse.end >= next_mat.collapse.end
 65            {
 66                matches.next().unwrap();
 67            } else {
 68                break;
 69            }
 70        }
 71
 72        if offset > mat.collapse.start {
 73            // Skip collapsed nodes that have already been summarized.
 74            offset = cmp::max(offset, mat.collapse.end);
 75            continue;
 76        }
 77
 78        if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
 79            if !flushed_selection {
 80                // The collapsed node ends after the selection starts, so we'll flush the selection first.
 81                summary.extend(buffer.text_for_range(offset..selected_range.start));
 82                summary.push_str("<|START|");
 83                if selected_range.end == selected_range.start {
 84                    summary.push_str(">");
 85                } else {
 86                    summary.extend(buffer.text_for_range(selected_range.clone()));
 87                    summary.push_str("|END|>");
 88                }
 89                offset = selected_range.end;
 90                flushed_selection = true;
 91            }
 92
 93            // If the selection intersects the collapsed node, we won't collapse it.
 94            if selected_range.end >= mat.collapse.start {
 95                continue;
 96            }
 97        }
 98
 99        summary.extend(buffer.text_for_range(offset..mat.collapse.start));
100        for keep in mat.keep {
101            summary.extend(buffer.text_for_range(keep));
102        }
103        offset = mat.collapse.end;
104    }
105
106    // Flush selection if we haven't already done so.
107    if !flushed_selection && offset <= selected_range.start {
108        summary.extend(buffer.text_for_range(offset..selected_range.start));
109        summary.push_str("<|START|");
110        if selected_range.end == selected_range.start {
111            summary.push_str(">");
112        } else {
113            summary.extend(buffer.text_for_range(selected_range.clone()));
114            summary.push_str("|END|>");
115        }
116        offset = selected_range.end;
117    }
118
119    summary.extend(buffer.text_for_range(offset..buffer.len()));
120    summary
121}
122
123pub fn generate_content_prompt(
124    user_prompt: String,
125    language_name: Option<&str>,
126    buffer: BufferSnapshot,
127    range: Range<usize>,
128    search_results: Vec<PromptCodeSnippet>,
129    model: &str,
130    project_name: Option<String>,
131) -> anyhow::Result<String> {
132    // Using new Prompt Templates
133    let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
134    let lang_name = if let Some(language_name) = language_name {
135        Some(language_name.to_string())
136    } else {
137        None
138    };
139
140    let args = PromptArguments {
141        model: openai_model,
142        language_name: lang_name.clone(),
143        project_name,
144        snippets: search_results.clone(),
145        reserved_tokens: 1000,
146        buffer: Some(buffer),
147        selected_range: Some(range),
148        user_prompt: Some(user_prompt.clone()),
149    };
150
151    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
152        (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
153        (
154            PromptPriority::Ordered { order: 1 },
155            Box::new(RepositoryContext {}),
156        ),
157        (
158            PromptPriority::Ordered { order: 0 },
159            Box::new(FileContext {}),
160        ),
161        (
162            PromptPriority::Mandatory,
163            Box::new(GenerateInlineContent {}),
164        ),
165    ];
166    let chain = PromptChain::new(args, templates);
167    let (prompt, _) = chain.generate(true)?;
168
169    anyhow::Ok(prompt)
170}
171
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::AppContext;
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use settings::SettingsStore;

    /// Builds a Rust `Language` whose embedding query captures function
    /// bodies as `@collapse` and their braces as `@keep`.
    pub(crate) fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        };
        let grammar = Some(tree_sitter_rust::language());
        Language::new(config, grammar)
            .with_embedding_query(
                r#"
            (
                [(line_comment) (attribute_item)]* @context
                .
                [
                    (struct_item
                        name: (_) @name)

                    (enum_item
                        name: (_) @name)

                    (impl_item
                        trait: (_)? @name
                        "for"? @name
                        type: (_) @name)

                    (trait_item
                        name: (_) @name)

                    (function_item
                        name: (_) @name
                        body: (block
                            "{" @keep
                            "}" @keep) @collapse)

                    (macro_definition
                        name: (_) @name)
                    ] @item
                )
            "#,
            )
            .unwrap()
    }

    #[gpui::test]
    fn test_outline_for_prompt(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        language_settings::init(cx);

        // An empty selection (a bare cursor) at the given row and column.
        let cursor = |row, column| Point::new(row, column)..Point::new(row, column);

        let text = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    self.a
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Cursor inside the struct: every function body is collapsed.
        assert_eq!(
            summarize(&snapshot, cursor(1, 4)),
            indoc! {"
                struct X {
                    <|START|>a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Non-empty selection inside `new`: that body stays expanded.
        assert_eq!(
            summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {
                        let <|START|a |END|>= 1;
                        let b = 2;
                        Self { a, b }
                    }

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Cursor on the blank line between `impl X {` and the first method.
        assert_eq!(
            summarize(&snapshot, cursor(6, 0)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {
                <|START|>
                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Cursor at the very end of the buffer.
        assert_eq!(
            summarize(&snapshot, cursor(21, 0)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
                <|START|>"}
        );

        // A function nested inside a collapsed body must vanish with it.
        let nested_text = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    let a = 30;
                    fn nested() -> usize {
                        3
                    }
                    self.a + nested()
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        buffer.update(cx, |buffer, cx| buffer.set_text(nested_text, cx));
        let snapshot = buffer.read(cx).snapshot();
        assert_eq!(
            summarize(&snapshot, cursor(0, 0)),
            indoc! {"
                <|START|>struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );
    }
}