prompts.rs

  1use ai::models::LanguageModel;
  2use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
  3use ai::prompts::file_context::FileContext;
  4use ai::prompts::generate::GenerateInlineContent;
  5use ai::prompts::preamble::EngineerPreamble;
  6use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
  7use ai::providers::open_ai::OpenAiLanguageModel;
  8use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
  9use std::cmp::{self, Reverse};
 10use std::ops::Range;
 11use std::sync::Arc;
 12
 13#[allow(dead_code)]
 14fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
 15    #[derive(Debug)]
 16    struct Match {
 17        collapse: Range<usize>,
 18        keep: Vec<Range<usize>>,
 19    }
 20
 21    let selected_range = selected_range.to_offset(buffer);
 22    let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
 23        Some(&grammar.embedding_config.as_ref()?.query)
 24    });
 25    let configs = ts_matches
 26        .grammars()
 27        .iter()
 28        .map(|g| g.embedding_config.as_ref().unwrap())
 29        .collect::<Vec<_>>();
 30    let mut matches = Vec::new();
 31    while let Some(mat) = ts_matches.peek() {
 32        let config = &configs[mat.grammar_index];
 33        if let Some(collapse) = mat.captures.iter().find_map(|cap| {
 34            if Some(cap.index) == config.collapse_capture_ix {
 35                Some(cap.node.byte_range())
 36            } else {
 37                None
 38            }
 39        }) {
 40            let mut keep = Vec::new();
 41            for capture in mat.captures.iter() {
 42                if Some(capture.index) == config.keep_capture_ix {
 43                    keep.push(capture.node.byte_range());
 44                } else {
 45                    continue;
 46                }
 47            }
 48            ts_matches.advance();
 49            matches.push(Match { collapse, keep });
 50        } else {
 51            ts_matches.advance();
 52        }
 53    }
 54    matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
 55    let mut matches = matches.into_iter().peekable();
 56
 57    let mut summary = String::new();
 58    let mut offset = 0;
 59    let mut flushed_selection = false;
 60    while let Some(mat) = matches.next() {
 61        // Keep extending the collapsed range if the next match surrounds
 62        // the current one.
 63        while let Some(next_mat) = matches.peek() {
 64            if mat.collapse.start <= next_mat.collapse.start
 65                && mat.collapse.end >= next_mat.collapse.end
 66            {
 67                matches.next().unwrap();
 68            } else {
 69                break;
 70            }
 71        }
 72
 73        if offset > mat.collapse.start {
 74            // Skip collapsed nodes that have already been summarized.
 75            offset = cmp::max(offset, mat.collapse.end);
 76            continue;
 77        }
 78
 79        if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
 80            if !flushed_selection {
 81                // The collapsed node ends after the selection starts, so we'll flush the selection first.
 82                summary.extend(buffer.text_for_range(offset..selected_range.start));
 83                summary.push_str("<|S|");
 84                if selected_range.end == selected_range.start {
 85                    summary.push_str(">");
 86                } else {
 87                    summary.extend(buffer.text_for_range(selected_range.clone()));
 88                    summary.push_str("|E|>");
 89                }
 90                offset = selected_range.end;
 91                flushed_selection = true;
 92            }
 93
 94            // If the selection intersects the collapsed node, we won't collapse it.
 95            if selected_range.end >= mat.collapse.start {
 96                continue;
 97            }
 98        }
 99
100        summary.extend(buffer.text_for_range(offset..mat.collapse.start));
101        for keep in mat.keep {
102            summary.extend(buffer.text_for_range(keep));
103        }
104        offset = mat.collapse.end;
105    }
106
107    // Flush selection if we haven't already done so.
108    if !flushed_selection && offset <= selected_range.start {
109        summary.extend(buffer.text_for_range(offset..selected_range.start));
110        summary.push_str("<|S|");
111        if selected_range.end == selected_range.start {
112            summary.push_str(">");
113        } else {
114            summary.extend(buffer.text_for_range(selected_range.clone()));
115            summary.push_str("|E|>");
116        }
117        offset = selected_range.end;
118    }
119
120    summary.extend(buffer.text_for_range(offset..buffer.len()));
121    summary
122}
123
124pub fn generate_content_prompt(
125    user_prompt: String,
126    language_name: Option<&str>,
127    buffer: BufferSnapshot,
128    range: Range<usize>,
129    search_results: Vec<PromptCodeSnippet>,
130    model: &str,
131    project_name: Option<String>,
132) -> anyhow::Result<String> {
133    // Using new Prompt Templates
134    let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAiLanguageModel::load(model));
135    let lang_name = if let Some(language_name) = language_name {
136        Some(language_name.to_string())
137    } else {
138        None
139    };
140
141    let args = PromptArguments {
142        model: openai_model,
143        language_name: lang_name.clone(),
144        project_name,
145        snippets: search_results.clone(),
146        reserved_tokens: 1000,
147        buffer: Some(buffer),
148        selected_range: Some(range),
149        user_prompt: Some(user_prompt.clone()),
150    };
151
152    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
153        (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
154        (
155            PromptPriority::Ordered { order: 1 },
156            Box::new(RepositoryContext {}),
157        ),
158        (
159            PromptPriority::Ordered { order: 0 },
160            Box::new(FileContext {}),
161        ),
162        (
163            PromptPriority::Mandatory,
164            Box::new(GenerateInlineContent {}),
165        ),
166    ];
167    let chain = PromptChain::new(args, templates);
168    let (prompt, _) = chain.generate(true)?;
169
170    anyhow::Ok(prompt)
171}
172
#[cfg(test)]
pub(crate) mod tests {

    use super::*;
    use std::sync::Arc;

    use gpui::{AppContext, Context};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, BufferId, Language, LanguageConfig, Point,
    };
    use settings::SettingsStore;

    /// Rust language definition with an embedding query for the prompt tests.
    ///
    /// In the query, a function body is captured as `@collapse` and its `{` and
    /// `}` tokens as `@keep`. As a result, `summarize` reduces every function
    /// body to `{}`. Structs, enums, impls, traits and macros are captured as
    /// items but are not collapsible.
    pub(crate) fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_embedding_query(
            r#"
            (
                [(line_comment) (attribute_item)]* @context
                .
                [
                    (struct_item
                        name: (_) @name)

                    (enum_item
                        name: (_) @name)

                    (impl_item
                        trait: (_)? @name
                        "for"? @name
                        type: (_) @name)

                    (trait_item
                        name: (_) @name)

                    (function_item
                        name: (_) @name
                        body: (block
                            "{" @keep
                            "}" @keep) @collapse)

                    (macro_definition
                        name: (_) @name)
                    ] @item
                )
            "#,
        )
        .unwrap()
    }

    /// Checks `summarize` output for empty and non-empty selections at
    /// several positions, including inside a function body (which must then
    /// stay expanded) and at the very end of the buffer.
    #[gpui::test]
    fn test_outline_for_prompt(cx: &mut AppContext) {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        language_settings::init(cx);
        let text = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    self.a
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        let buffer = cx.new_model(|cx| {
            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
        });
        let snapshot = buffer.read(cx).snapshot();

        // Empty selection on a struct field: the struct is not collapsible and
        // every function body collapses to `{}`.
        assert_eq!(
            summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
            indoc! {"
                struct X {
                    <|S|>a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Non-empty selection (`a `) inside `new`'s body: that body is kept
        // verbatim, and the selection is wrapped in `<|S|` … `|E|>`.
        assert_eq!(
            summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {
                        let <|S|a |E|>= 1;
                        let b = 2;
                        Self { a, b }
                    }

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Empty selection on the blank line that opens `impl X`, before any
        // collapsible node.
        assert_eq!(
            summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {
                <|S|>
                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );

        // Empty selection at the end of the buffer (row 21 is the empty line
        // after the trailing newline), i.e. after every collapsible node.
        assert_eq!(
            summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
            indoc! {"
                struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
                <|S|>"}
        );

        // Ensure nested functions get collapsed properly.
        let text = indoc! {"
            struct X {
                a: usize,
                b: usize,
            }

            impl X {

                fn new() -> Self {
                    let a = 1;
                    let b = 2;
                    Self { a, b }
                }

                pub fn a(&self, param: bool) -> usize {
                    let a = 30;
                    fn nested() -> usize {
                        3
                    }
                    self.a + nested()
                }

                pub fn b(&self) -> usize {
                    self.b
                }
            }
        "};
        buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
        let snapshot = buffer.read(cx).snapshot();
        // `nested` lies inside `a`'s body, so it disappears together with it
        // rather than being summarized separately.
        assert_eq!(
            summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
            indoc! {"
                <|S|>struct X {
                    a: usize,
                    b: usize,
                }

                impl X {

                    fn new() -> Self {}

                    pub fn a(&self, param: bool) -> usize {}

                    pub fn b(&self) -> usize {}
                }
            "}
        );
    }
}