prompts.rs

  1use ai::models::LanguageModel;
  2use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
  3use ai::prompts::file_context::FileContext;
  4use ai::prompts::generate::GenerateInlineContent;
  5use ai::prompts::preamble::EngineerPreamble;
  6use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
  7use ai::providers::open_ai::OpenAILanguageModel;
  8use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
  9use std::cmp::{self, Reverse};
 10use std::ops::Range;
 11use std::sync::Arc;
 12
 13#[allow(dead_code)]
 14fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
 15    #[derive(Debug)]
 16    struct Match {
 17        collapse: Range<usize>,
 18        keep: Vec<Range<usize>>,
 19    }
 20
 21    let selected_range = selected_range.to_offset(buffer);
 22    let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
 23        Some(&grammar.embedding_config.as_ref()?.query)
 24    });
 25    let configs = ts_matches
 26        .grammars()
 27        .iter()
 28        .map(|g| g.embedding_config.as_ref().unwrap())
 29        .collect::<Vec<_>>();
 30    let mut matches = Vec::new();
 31    while let Some(mat) = ts_matches.peek() {
 32        let config = &configs[mat.grammar_index];
 33        if let Some(collapse) = mat.captures.iter().find_map(|cap| {
 34            if Some(cap.index) == config.collapse_capture_ix {
 35                Some(cap.node.byte_range())
 36            } else {
 37                None
 38            }
 39        }) {
 40            let mut keep = Vec::new();
 41            for capture in mat.captures.iter() {
 42                if Some(capture.index) == config.keep_capture_ix {
 43                    keep.push(capture.node.byte_range());
 44                } else {
 45                    continue;
 46                }
 47            }
 48            ts_matches.advance();
 49            matches.push(Match { collapse, keep });
 50        } else {
 51            ts_matches.advance();
 52        }
 53    }
 54    matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
 55    let mut matches = matches.into_iter().peekable();
 56
 57    let mut summary = String::new();
 58    let mut offset = 0;
 59    let mut flushed_selection = false;
 60    while let Some(mat) = matches.next() {
 61        // Keep extending the collapsed range if the next match surrounds
 62        // the current one.
 63        while let Some(next_mat) = matches.peek() {
 64            if mat.collapse.start <= next_mat.collapse.start
 65                && mat.collapse.end >= next_mat.collapse.end
 66            {
 67                matches.next().unwrap();
 68            } else {
 69                break;
 70            }
 71        }
 72
 73        if offset > mat.collapse.start {
 74            // Skip collapsed nodes that have already been summarized.
 75            offset = cmp::max(offset, mat.collapse.end);
 76            continue;
 77        }
 78
 79        if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
 80            if !flushed_selection {
 81                // The collapsed node ends after the selection starts, so we'll flush the selection first.
 82                summary.extend(buffer.text_for_range(offset..selected_range.start));
 83                summary.push_str("<|S|");
 84                if selected_range.end == selected_range.start {
 85                    summary.push_str(">");
 86                } else {
 87                    summary.extend(buffer.text_for_range(selected_range.clone()));
 88                    summary.push_str("|E|>");
 89                }
 90                offset = selected_range.end;
 91                flushed_selection = true;
 92            }
 93
 94            // If the selection intersects the collapsed node, we won't collapse it.
 95            if selected_range.end >= mat.collapse.start {
 96                continue;
 97            }
 98        }
 99
100        summary.extend(buffer.text_for_range(offset..mat.collapse.start));
101        for keep in mat.keep {
102            summary.extend(buffer.text_for_range(keep));
103        }
104        offset = mat.collapse.end;
105    }
106
107    // Flush selection if we haven't already done so.
108    if !flushed_selection && offset <= selected_range.start {
109        summary.extend(buffer.text_for_range(offset..selected_range.start));
110        summary.push_str("<|S|");
111        if selected_range.end == selected_range.start {
112            summary.push_str(">");
113        } else {
114            summary.extend(buffer.text_for_range(selected_range.clone()));
115            summary.push_str("|E|>");
116        }
117        offset = selected_range.end;
118    }
119
120    summary.extend(buffer.text_for_range(offset..buffer.len()));
121    summary
122}
123
124pub fn generate_content_prompt(
125    user_prompt: String,
126    language_name: Option<&str>,
127    buffer: BufferSnapshot,
128    range: Range<usize>,
129    search_results: Vec<PromptCodeSnippet>,
130    model: &str,
131    project_name: Option<String>,
132) -> anyhow::Result<String> {
133    // Using new Prompt Templates
134    let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
135    let lang_name = if let Some(language_name) = language_name {
136        Some(language_name.to_string())
137    } else {
138        None
139    };
140
141    let args = PromptArguments {
142        model: openai_model,
143        language_name: lang_name.clone(),
144        project_name,
145        snippets: search_results.clone(),
146        reserved_tokens: 1000,
147        buffer: Some(buffer),
148        selected_range: Some(range),
149        user_prompt: Some(user_prompt.clone()),
150    };
151
152    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
153        (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
154        (
155            PromptPriority::Ordered { order: 1 },
156            Box::new(RepositoryContext {}),
157        ),
158        (
159            PromptPriority::Ordered { order: 0 },
160            Box::new(FileContext {}),
161        ),
162        (
163            PromptPriority::Mandatory,
164            Box::new(GenerateInlineContent {}),
165        ),
166    ];
167    let chain = PromptChain::new(args, templates);
168    let (prompt, _) = chain.generate(true)?;
169
170    anyhow::Ok(prompt)
171}
172
#[cfg(test)]
pub(crate) mod tests {

    use super::*;
    use std::sync::Arc;

    use gpui::AppContext;
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use settings::SettingsStore;

    /// Embedding query for the test Rust grammar. Function bodies are captured
    /// as `@collapse` and their braces as `@keep`, so `summarize` reduces each
    /// function body to `{}`.
    const RUST_EMBEDDING_QUERY: &str = r#"
            (
                [(line_comment) (attribute_item)]* @context
                .
                [
                    (struct_item
                        name: (_) @name)

                    (enum_item
                        name: (_) @name)

                    (impl_item
                        trait: (_)? @name
                        "for"? @name
                        type: (_) @name)

                    (trait_item
                        name: (_) @name)

                    (function_item
                        name: (_) @name
                        body: (block
                            "{" @keep
                            "}" @keep) @collapse)

                    (macro_definition
                        name: (_) @name)
                    ] @item
                )
            "#;

    /// A Rust [`Language`] carrying the embedding query that drives `summarize`.
    pub(crate) fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::language()))
            .with_embedding_query(RUST_EMBEDDING_QUERY)
            .unwrap()
    }

    /// An empty range (a bare cursor) at `row`/`column`.
    fn cursor(row: u32, column: u32) -> std::ops::Range<Point> {
        let point = Point::new(row, column);
        point..point
    }

    /// Checks that `summarize` collapses function bodies to `{}`, except the
    /// one containing the selection, and places the selection marker at the
    /// cursor or around the selected text.
    #[gpui::test]
    fn test_outline_for_prompt(cx: &mut AppContext) {
        cx.set_global(SettingsStore::test(cx));
        language_settings::init(cx);

        let source = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    self.a
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        let buffer = cx.add_model(|cx| {
            Buffer::new(0, 0, source).with_language(Arc::new(rust_lang()), cx)
        });
        let snapshot = buffer.read(cx).snapshot();

        // Cursor inside the struct: every function body collapses.
        assert_eq!(
            summarize(&snapshot, cursor(1, 4)),
            indoc! {"
                struct X {
                    <|S|>a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Selection inside `new`: that body stays intact.
        assert_eq!(
            summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {
                        let <|S|a |E|>= 1;
                        let b = 2;
                        Self { a, b }
                    }

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Cursor on the blank line opening the impl block.
        assert_eq!(
            summarize(&snapshot, cursor(6, 0)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {
                <|S|>
                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Cursor at the very end of the buffer.
        assert_eq!(
            summarize(&snapshot, cursor(21, 0)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
                <|S|>"}
        );

        // A function nested inside another collapses along with its parent.
        let nested_source = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    let a = 30;
                    fn nested() -> usize {
                        3
                    }
                    self.a + nested()
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        buffer.update(cx, |buffer, cx| buffer.set_text(nested_source, cx));
        let nested_snapshot = buffer.read(cx).snapshot();
        assert_eq!(
            summarize(&nested_snapshot, cursor(0, 0)),
            indoc! {"
                <|S|>struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );
    }
}