prompts.rs

  1use ai::models::LanguageModel;
  2use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
  3use ai::prompts::file_context::FileContext;
  4use ai::prompts::generate::GenerateInlineContent;
  5use ai::prompts::preamble::EngineerPreamble;
  6use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
  7use ai::providers::open_ai::OpenAILanguageModel;
  8use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
  9use std::cmp::{self, Reverse};
 10use std::ops::Range;
 11use std::sync::Arc;
 12
 13#[allow(dead_code)]
 14fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
 15    #[derive(Debug)]
 16    struct Match {
 17        collapse: Range<usize>,
 18        keep: Vec<Range<usize>>,
 19    }
 20
 21    let selected_range = selected_range.to_offset(buffer);
 22    let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
 23        Some(&grammar.embedding_config.as_ref()?.query)
 24    });
 25    let configs = ts_matches
 26        .grammars()
 27        .iter()
 28        .map(|g| g.embedding_config.as_ref().unwrap())
 29        .collect::<Vec<_>>();
 30    let mut matches = Vec::new();
 31    while let Some(mat) = ts_matches.peek() {
 32        let config = &configs[mat.grammar_index];
 33        if let Some(collapse) = mat.captures.iter().find_map(|cap| {
 34            if Some(cap.index) == config.collapse_capture_ix {
 35                Some(cap.node.byte_range())
 36            } else {
 37                None
 38            }
 39        }) {
 40            let mut keep = Vec::new();
 41            for capture in mat.captures.iter() {
 42                if Some(capture.index) == config.keep_capture_ix {
 43                    keep.push(capture.node.byte_range());
 44                } else {
 45                    continue;
 46                }
 47            }
 48            ts_matches.advance();
 49            matches.push(Match { collapse, keep });
 50        } else {
 51            ts_matches.advance();
 52        }
 53    }
 54    matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
 55    let mut matches = matches.into_iter().peekable();
 56
 57    let mut summary = String::new();
 58    let mut offset = 0;
 59    let mut flushed_selection = false;
 60    while let Some(mat) = matches.next() {
 61        // Keep extending the collapsed range if the next match surrounds
 62        // the current one.
 63        while let Some(next_mat) = matches.peek() {
 64            if mat.collapse.start <= next_mat.collapse.start
 65                && mat.collapse.end >= next_mat.collapse.end
 66            {
 67                matches.next().unwrap();
 68            } else {
 69                break;
 70            }
 71        }
 72
 73        if offset > mat.collapse.start {
 74            // Skip collapsed nodes that have already been summarized.
 75            offset = cmp::max(offset, mat.collapse.end);
 76            continue;
 77        }
 78
 79        if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
 80            if !flushed_selection {
 81                // The collapsed node ends after the selection starts, so we'll flush the selection first.
 82                summary.extend(buffer.text_for_range(offset..selected_range.start));
 83                summary.push_str("<|S|");
 84                if selected_range.end == selected_range.start {
 85                    summary.push_str(">");
 86                } else {
 87                    summary.extend(buffer.text_for_range(selected_range.clone()));
 88                    summary.push_str("|E|>");
 89                }
 90                offset = selected_range.end;
 91                flushed_selection = true;
 92            }
 93
 94            // If the selection intersects the collapsed node, we won't collapse it.
 95            if selected_range.end >= mat.collapse.start {
 96                continue;
 97            }
 98        }
 99
100        summary.extend(buffer.text_for_range(offset..mat.collapse.start));
101        for keep in mat.keep {
102            summary.extend(buffer.text_for_range(keep));
103        }
104        offset = mat.collapse.end;
105    }
106
107    // Flush selection if we haven't already done so.
108    if !flushed_selection && offset <= selected_range.start {
109        summary.extend(buffer.text_for_range(offset..selected_range.start));
110        summary.push_str("<|S|");
111        if selected_range.end == selected_range.start {
112            summary.push_str(">");
113        } else {
114            summary.extend(buffer.text_for_range(selected_range.clone()));
115            summary.push_str("|E|>");
116        }
117        offset = selected_range.end;
118    }
119
120    summary.extend(buffer.text_for_range(offset..buffer.len()));
121    summary
122}
123
124pub fn generate_content_prompt(
125    user_prompt: String,
126    language_name: Option<&str>,
127    buffer: BufferSnapshot,
128    range: Range<usize>,
129    search_results: Vec<PromptCodeSnippet>,
130    model: &str,
131    project_name: Option<String>,
132) -> anyhow::Result<String> {
133    // Using new Prompt Templates
134    let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
135    let lang_name = if let Some(language_name) = language_name {
136        Some(language_name.to_string())
137    } else {
138        None
139    };
140
141    let args = PromptArguments {
142        model: openai_model,
143        language_name: lang_name.clone(),
144        project_name,
145        snippets: search_results.clone(),
146        reserved_tokens: 1000,
147        buffer: Some(buffer),
148        selected_range: Some(range),
149        user_prompt: Some(user_prompt.clone()),
150    };
151
152    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
153        (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
154        (
155            PromptPriority::Ordered { order: 1 },
156            Box::new(RepositoryContext {}),
157        ),
158        (
159            PromptPriority::Ordered { order: 0 },
160            Box::new(FileContext {}),
161        ),
162        (
163            PromptPriority::Mandatory,
164            Box::new(GenerateInlineContent {}),
165        ),
166    ];
167    let chain = PromptChain::new(args, templates);
168    let (prompt, _) = chain.generate(true)?;
169
170    anyhow::Ok(prompt)
171}
172
#[cfg(test)]
pub(crate) mod tests {

    use super::*;
    use std::sync::Arc;

    use gpui::{AppContext, Context};
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use settings::SettingsStore;

    /// Builds a Rust [`Language`] whose embedding query drives `summarize`.
    ///
    /// In the query, each `function_item`'s body `block` is captured as
    /// `@collapse`, and its `{` and `}` tokens are captured as `@keep`. A
    /// summarized function body therefore renders as `{}`.
    ///
    /// The other item kinds (struct, enum, impl, trait, macro) are captured only
    /// as `@item` / `@name` / `@context`. They have no `@collapse` capture, so
    /// they are never collapsed.
    pub(crate) fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_embedding_query(
            r#"
            (
                [(line_comment) (attribute_item)]* @context
                .
                [
                    (struct_item
                        name: (_) @name)

                    (enum_item
                        name: (_) @name)

                    (impl_item
                        trait: (_)? @name
                        "for"? @name
                        type: (_) @name)

                    (trait_item
                        name: (_) @name)

                    (function_item
                        name: (_) @name
                        body: (block
                            "{" @keep
                            "}" @keep) @collapse)

                    (macro_definition
                        name: (_) @name)
                    ] @item
                )
            "#,
        )
        .unwrap()
    }

    /// Checks `summarize` output for several cursor/selection positions.
    ///
    /// It then repeats the check on a buffer where a function is nested inside
    /// another function's body.
    #[gpui::test]
    fn test_outline_for_prompt(cx: &mut AppContext) {
        // Install the global test `SettingsStore` and initialize language
        // settings before the buffer is created.
        // NOTE(review): presumably `Buffer::with_language` reads these settings;
        // confirm.
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        language_settings::init(cx);
        let text = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    self.a
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Empty selection at the start of field `a` (row 1, column 4).
        // `struct X` has no `@collapse` capture, so it is emitted verbatim, while
        // every function body collapses to `{}`.
        assert_eq!(
            summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
            indoc! {"
                struct X {
                    <|S|>a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Non-empty selection (row 8, columns 12..14) inside the body of `new`.
        // That body intersects the selection, so it stays expanded, and the
        // selected text is wrapped in `<|S|` ... `|E|>`.
        assert_eq!(
            summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {
                        let <|S|a |E|>= 1;
                        let b = 2;
                        Self { a, b }
                    }

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Empty selection at the start of the blank line inside `impl X` (row 6).
        assert_eq!(
            summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {
                <|S|>
                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Empty selection at the very end of the buffer. Row 21 is the empty
        // line after the trailing newline, so the marker comes last.
        assert_eq!(
            summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
                <|S|>"}
        );

        // Ensure nested functions get collapsed properly: the body of `a` now
        // contains `fn nested`, and collapsing that body hides the nested
        // function entirely.
        let text = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    let a = 30;
                    fn nested() -> usize {
                        3
                    }
                    self.a + nested()
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
        // Re-snapshot: the earlier `snapshot` still reflects the old text.
        let snapshot = buffer.read(cx).snapshot();
        assert_eq!(
            summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
            indoc! {"
                <|S|>struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );
    }
}