prompts.rs

  1use ai::models::LanguageModel;
  2use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
  3use ai::prompts::file_context::FileContext;
  4use ai::prompts::generate::GenerateInlineContent;
  5use ai::prompts::preamble::EngineerPreamble;
  6use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
  7use ai::providers::open_ai::OpenAiLanguageModel;
  8use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
  9use std::cmp::{self, Reverse};
 10use std::ops::Range;
 11use std::sync::Arc;
 12
 13#[allow(dead_code)]
 14fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
 15    #[derive(Debug)]
 16    struct Match {
 17        collapse: Range<usize>,
 18        keep: Vec<Range<usize>>,
 19    }
 20
 21    let selected_range = selected_range.to_offset(buffer);
 22    let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
 23        Some(&grammar.embedding_config.as_ref()?.query)
 24    });
 25    let configs = ts_matches
 26        .grammars()
 27        .iter()
 28        .map(|g| g.embedding_config.as_ref().unwrap())
 29        .collect::<Vec<_>>();
 30    let mut matches = Vec::new();
 31    while let Some(mat) = ts_matches.peek() {
 32        let config = &configs[mat.grammar_index];
 33        if let Some(collapse) = mat.captures.iter().find_map(|cap| {
 34            if Some(cap.index) == config.collapse_capture_ix {
 35                Some(cap.node.byte_range())
 36            } else {
 37                None
 38            }
 39        }) {
 40            let mut keep = Vec::new();
 41            for capture in mat.captures.iter() {
 42                if Some(capture.index) == config.keep_capture_ix {
 43                    keep.push(capture.node.byte_range());
 44                } else {
 45                    continue;
 46                }
 47            }
 48            ts_matches.advance();
 49            matches.push(Match { collapse, keep });
 50        } else {
 51            ts_matches.advance();
 52        }
 53    }
 54    matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
 55    let mut matches = matches.into_iter().peekable();
 56
 57    let mut summary = String::new();
 58    let mut offset = 0;
 59    let mut flushed_selection = false;
 60    while let Some(mat) = matches.next() {
 61        // Keep extending the collapsed range if the next match surrounds
 62        // the current one.
 63        while let Some(next_mat) = matches.peek() {
 64            if mat.collapse.start <= next_mat.collapse.start
 65                && mat.collapse.end >= next_mat.collapse.end
 66            {
 67                matches.next().unwrap();
 68            } else {
 69                break;
 70            }
 71        }
 72
 73        if offset > mat.collapse.start {
 74            // Skip collapsed nodes that have already been summarized.
 75            offset = cmp::max(offset, mat.collapse.end);
 76            continue;
 77        }
 78
 79        if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
 80            if !flushed_selection {
 81                // The collapsed node ends after the selection starts, so we'll flush the selection first.
 82                summary.extend(buffer.text_for_range(offset..selected_range.start));
 83                summary.push_str("<|S|");
 84                if selected_range.end == selected_range.start {
 85                    summary.push_str(">");
 86                } else {
 87                    summary.extend(buffer.text_for_range(selected_range.clone()));
 88                    summary.push_str("|E|>");
 89                }
 90                offset = selected_range.end;
 91                flushed_selection = true;
 92            }
 93
 94            // If the selection intersects the collapsed node, we won't collapse it.
 95            if selected_range.end >= mat.collapse.start {
 96                continue;
 97            }
 98        }
 99
100        summary.extend(buffer.text_for_range(offset..mat.collapse.start));
101        for keep in mat.keep {
102            summary.extend(buffer.text_for_range(keep));
103        }
104        offset = mat.collapse.end;
105    }
106
107    // Flush selection if we haven't already done so.
108    if !flushed_selection && offset <= selected_range.start {
109        summary.extend(buffer.text_for_range(offset..selected_range.start));
110        summary.push_str("<|S|");
111        if selected_range.end == selected_range.start {
112            summary.push_str(">");
113        } else {
114            summary.extend(buffer.text_for_range(selected_range.clone()));
115            summary.push_str("|E|>");
116        }
117        offset = selected_range.end;
118    }
119
120    summary.extend(buffer.text_for_range(offset..buffer.len()));
121    summary
122}
123
124pub fn generate_content_prompt(
125    user_prompt: String,
126    language_name: Option<&str>,
127    buffer: BufferSnapshot,
128    range: Range<usize>,
129    search_results: Vec<PromptCodeSnippet>,
130    model: &str,
131    project_name: Option<String>,
132) -> anyhow::Result<String> {
133    // Using new Prompt Templates
134    let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAiLanguageModel::load(model));
135    let lang_name = if let Some(language_name) = language_name {
136        Some(language_name.to_string())
137    } else {
138        None
139    };
140
141    let args = PromptArguments {
142        model: openai_model,
143        language_name: lang_name.clone(),
144        project_name,
145        snippets: search_results.clone(),
146        reserved_tokens: 1000,
147        buffer: Some(buffer),
148        selected_range: Some(range),
149        user_prompt: Some(user_prompt.clone()),
150    };
151
152    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
153        (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
154        (
155            PromptPriority::Ordered { order: 1 },
156            Box::new(RepositoryContext {}),
157        ),
158        (
159            PromptPriority::Ordered { order: 0 },
160            Box::new(FileContext {}),
161        ),
162        (
163            PromptPriority::Mandatory,
164            Box::new(GenerateInlineContent {}),
165        ),
166    ];
167    let chain = PromptChain::new(args, templates);
168    let (prompt, _) = chain.generate(true)?;
169
170    anyhow::Ok(prompt)
171}
172
/// Tests for `summarize`, plus a Rust `Language` fixture (`rust_lang`) exported to the crate.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use gpui::{AppContext, Context};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, BufferId, Language, LanguageConfig,
        LanguageMatcher, Point,
    };
    use settings::SettingsStore;
    use std::sync::Arc;

    /// Builds a Rust `Language` (matched by the `rs` suffix) with an embedding query.
    ///
    /// In the query, each `function_item` body is captured as `@collapse` and its `{` / `}`
    /// tokens as `@keep`. As a result, `summarize` reduces function bodies to `{}`.
    pub(crate) fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_embedding_query(
            r#"
            (
                [(line_comment) (attribute_item)]* @context
                .
                [
                    (struct_item
                        name: (_) @name)

                    (enum_item
                        name: (_) @name)

                    (impl_item
                        trait: (_)? @name
                        "for"? @name
                        type: (_) @name)

                    (trait_item
                        name: (_) @name)

                    (function_item
                        name: (_) @name
                        body: (block
                            "{" @keep
                            "}" @keep) @collapse)

                    (macro_definition
                        name: (_) @name)
                    ] @item
                )
            "#,
        )
        .unwrap()
    }

    /// Checks `summarize` output for cursors and selections at several positions, including
    /// a function that contains a nested function.
    ///
    /// NOTE(review): no case here covers a selection that spans two collapsible nodes.
    #[gpui::test]
    fn test_outline_for_prompt(cx: &mut AppContext) {
        // Install a test settings store and language settings before creating the buffer
        // (presumably required by `with_language` — confirm).
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        language_settings::init(cx);
        let text = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    self.a
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        let buffer = cx.new_model(|cx| {
            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
        });
        let snapshot = buffer.read(cx).snapshot();

        // Cursor inside the struct: every function body is collapsed.
        assert_eq!(
            summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
            indoc! {"
                struct X {
                    <|S|>a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Non-empty selection inside `new`: its body stays expanded, the others collapse.
        assert_eq!(
            summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {
                        let <|S|a |E|>= 1;
                        let b = 2;
                        Self { a, b }
                    }

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Cursor on the blank line inside `impl X`, before any function.
        assert_eq!(
            summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {
                <|S|>
                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Cursor at the very end of the buffer: the marker is emitted after all text.
        assert_eq!(
            summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
                <|S|>"}
        );

        // Ensure nested functions get collapsed properly.
        let text = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    let a = 30;
                    fn nested() -> usize {
                        3
                    }
                    self.a + nested()
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
        let snapshot = buffer.read(cx).snapshot();
        // The outer body of `a` collapses to `{}`, taking the nested function with it.
        assert_eq!(
            summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
            indoc! {"
                <|S|>struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );
    }
}