cloud_zeta2_prompt.rs

  1//! Zeta2 prompt planning and generation code shared with cloud.
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, ReferencedDeclaration};
  5use indoc::indoc;
  6use ordered_float::OrderedFloat;
  7use rustc_hash::{FxHashMap, FxHashSet};
  8use std::fmt::Write;
  9use std::sync::Arc;
 10use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
 11use strum::{EnumIter, IntoEnumIterator};
 12
 13pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
 14
 15pub const CURSOR_MARKER: &str = "<|cursor_position|>";
 16/// NOTE: Differs from zed version of constant - includes a newline
 17pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
 18/// NOTE: Differs from zed version of constant - includes a newline
 19pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
 20
 21// TODO: use constants for markers?
 22const MARKED_EXCERPT_SYSTEM_PROMPT: &str = indoc! {"
 23    You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
 24
 25    The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|cursor_position|>.  Please respond with edited code for that region.
 26
 27    Other code is provided for context, and `…` indicates when code has been skipped.
 28"};
 29
 30const LABELED_SECTIONS_SYSTEM_PROMPT: &str = indoc! {r#"
 31    You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.
 32
 33    Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).
 34
 35    The cursor position is marked with `<|cursor_position|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.
 36
 37    Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:
 38
 39    <|current_section|>
 40    for i in 0..16 {
 41        println!("{i}");
 42    }
 43"#};
 44
 45pub struct PlannedPrompt<'a> {
 46    request: &'a predict_edits_v3::PredictEditsRequest,
 47    /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
 48    /// `to_prompt_string`.
 49    snippets: Vec<PlannedSnippet<'a>>,
 50    budget_used: usize,
 51}
 52
 53pub fn system_prompt(format: PromptFormat) -> &'static str {
 54    match format {
 55        PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT,
 56        PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT,
 57        // only intended for use via zeta_cli
 58        PromptFormat::OnlySnippets => "",
 59    }
 60}
 61
 62#[derive(Clone, Debug)]
 63pub struct PlannedSnippet<'a> {
 64    path: Arc<Path>,
 65    range: Range<usize>,
 66    text: &'a str,
 67    // TODO: Indicate this in the output
 68    #[allow(dead_code)]
 69    text_is_truncated: bool,
 70}
 71
 72#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
 73pub enum DeclarationStyle {
 74    Signature,
 75    Declaration,
 76}
 77
 78#[derive(Clone, Debug)]
 79pub struct SectionLabels {
 80    pub excerpt_index: usize,
 81    pub section_ranges: Vec<(Arc<Path>, Range<usize>)>,
 82}
 83
 84impl<'a> PlannedPrompt<'a> {
 85    /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
 86    ///
 87    /// Initializes a priority queue by populating it with each snippet, finding the
 88    /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
 89    /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
 90    /// the cost of upgrade.
 91    ///
 92    /// TODO: Implement an early halting condition. One option might be to have another priority
 93    /// queue where the score is the size, and update it accordingly. Another option might be to
 94    /// have some simpler heuristic like bailing after N failed insertions, or based on how much
 95    /// budget is left.
 96    ///
 97    /// TODO: Has the current known sources of imprecision:
 98    ///
 99    /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
100    /// plan even though the containing struct is already included.
101    ///
102    /// * Does not consider cost of signatures when ranking snippets - this is tricky since
103    /// signatures may be shared by multiple snippets.
104    ///
105    /// * Does not include file paths / other text when considering max_bytes.
106    pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
107        let mut this = PlannedPrompt {
108            request,
109            snippets: Vec::new(),
110            budget_used: request.excerpt.len(),
111        };
112        let mut included_parents = FxHashSet::default();
113        let additional_parents = this.additional_parent_signatures(
114            &request.excerpt_path,
115            request.excerpt_parent,
116            &included_parents,
117        )?;
118        this.add_parents(&mut included_parents, additional_parents);
119
120        let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
121
122        if this.budget_used > max_bytes {
123            return Err(anyhow!(
124                "Excerpt + signatures size of {} already exceeds budget of {}",
125                this.budget_used,
126                max_bytes
127            ));
128        }
129
130        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
131        struct QueueEntry {
132            score_density: OrderedFloat<f32>,
133            declaration_index: usize,
134            style: DeclarationStyle,
135        }
136
137        // Initialize priority queue with the best score for each snippet.
138        let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
139        for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
140            let (style, score_density) = DeclarationStyle::iter()
141                .map(|style| {
142                    (
143                        style,
144                        OrderedFloat(declaration_score_density(&declaration, style)),
145                    )
146                })
147                .max_by_key(|(_, score_density)| *score_density)
148                .unwrap();
149            queue.push(QueueEntry {
150                score_density,
151                declaration_index,
152                style,
153            });
154        }
155
156        // Knapsack selection loop
157        while let Some(queue_entry) = queue.pop() {
158            let Some(declaration) = request
159                .referenced_declarations
160                .get(queue_entry.declaration_index)
161            else {
162                return Err(anyhow!(
163                    "Invalid declaration index {}",
164                    queue_entry.declaration_index
165                ));
166            };
167
168            let mut additional_bytes = declaration_size(declaration, queue_entry.style);
169            if this.budget_used + additional_bytes > max_bytes {
170                continue;
171            }
172
173            let additional_parents = this.additional_parent_signatures(
174                &declaration.path,
175                declaration.parent_index,
176                &mut included_parents,
177            )?;
178            additional_bytes += additional_parents
179                .iter()
180                .map(|(_, snippet)| snippet.text.len())
181                .sum::<usize>();
182            if this.budget_used + additional_bytes > max_bytes {
183                continue;
184            }
185
186            this.budget_used += additional_bytes;
187            this.add_parents(&mut included_parents, additional_parents);
188            let planned_snippet = match queue_entry.style {
189                DeclarationStyle::Signature => {
190                    let Some(text) = declaration.text.get(declaration.signature_range.clone())
191                    else {
192                        return Err(anyhow!(
193                            "Invalid declaration signature_range {:?} with text.len() = {}",
194                            declaration.signature_range,
195                            declaration.text.len()
196                        ));
197                    };
198                    PlannedSnippet {
199                        path: declaration.path.clone(),
200                        range: (declaration.signature_range.start + declaration.range.start)
201                            ..(declaration.signature_range.end + declaration.range.start),
202                        text,
203                        text_is_truncated: declaration.text_is_truncated,
204                    }
205                }
206                DeclarationStyle::Declaration => PlannedSnippet {
207                    path: declaration.path.clone(),
208                    range: declaration.range.clone(),
209                    text: &declaration.text,
210                    text_is_truncated: declaration.text_is_truncated,
211                },
212            };
213            this.snippets.push(planned_snippet);
214
215            // When a Signature is consumed, insert an entry for Definition style.
216            if queue_entry.style == DeclarationStyle::Signature {
217                let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
218                let declaration_size =
219                    declaration_size(&declaration, DeclarationStyle::Declaration);
220                let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
221                let declaration_score =
222                    declaration_score(&declaration, DeclarationStyle::Declaration);
223
224                let score_diff = declaration_score - signature_score;
225                let size_diff = declaration_size.saturating_sub(signature_size);
226                if score_diff > 0.0001 && size_diff > 0 {
227                    queue.push(QueueEntry {
228                        declaration_index: queue_entry.declaration_index,
229                        score_density: OrderedFloat(score_diff / (size_diff as f32)),
230                        style: DeclarationStyle::Declaration,
231                    });
232                }
233            }
234        }
235
236        anyhow::Ok(this)
237    }
238
239    fn add_parents(
240        &mut self,
241        included_parents: &mut FxHashSet<usize>,
242        snippets: Vec<(usize, PlannedSnippet<'a>)>,
243    ) {
244        for (parent_index, snippet) in snippets {
245            included_parents.insert(parent_index);
246            self.budget_used += snippet.text.len();
247            self.snippets.push(snippet);
248        }
249    }
250
251    fn additional_parent_signatures(
252        &self,
253        path: &Arc<Path>,
254        parent_index: Option<usize>,
255        included_parents: &FxHashSet<usize>,
256    ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
257        let mut results = Vec::new();
258        self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
259        Ok(results)
260    }
261
262    fn additional_parent_signatures_impl(
263        &self,
264        path: &Arc<Path>,
265        parent_index: Option<usize>,
266        included_parents: &FxHashSet<usize>,
267        results: &mut Vec<(usize, PlannedSnippet<'a>)>,
268    ) -> Result<()> {
269        let Some(parent_index) = parent_index else {
270            return Ok(());
271        };
272        if included_parents.contains(&parent_index) {
273            return Ok(());
274        }
275        let Some(parent_signature) = self.request.signatures.get(parent_index) else {
276            return Err(anyhow!("Invalid parent index {}", parent_index));
277        };
278        results.push((
279            parent_index,
280            PlannedSnippet {
281                path: path.clone(),
282                range: parent_signature.range.clone(),
283                text: &parent_signature.text,
284                text_is_truncated: parent_signature.text_is_truncated,
285            },
286        ));
287        self.additional_parent_signatures_impl(
288            path,
289            parent_signature.parent_index,
290            included_parents,
291            results,
292        )
293    }
294
295    /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
296    /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
297    /// chunks.
298    pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> {
299        let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
300            FxHashMap::default();
301        for snippet in &self.snippets {
302            file_to_snippets
303                .entry(&snippet.path)
304                .or_default()
305                .push(snippet);
306        }
307
308        // Reorder so that file with cursor comes last
309        let mut file_snippets = Vec::new();
310        let mut excerpt_file_snippets = Vec::new();
311        for (file_path, snippets) in file_to_snippets {
312            if file_path == self.request.excerpt_path.as_ref() {
313                excerpt_file_snippets = snippets;
314            } else {
315                file_snippets.push((file_path, snippets, false));
316            }
317        }
318        let excerpt_snippet = PlannedSnippet {
319            path: self.request.excerpt_path.clone(),
320            range: self.request.excerpt_range.clone(),
321            text: &self.request.excerpt,
322            text_is_truncated: false,
323        };
324        excerpt_file_snippets.push(&excerpt_snippet);
325        file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
326
327        let mut excerpt_file_insertions = match self.request.prompt_format {
328            PromptFormat::MarkedExcerpt => vec![
329                (
330                    self.request.excerpt_range.start,
331                    EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
332                ),
333                (
334                    self.request.excerpt_range.start + self.request.cursor_offset,
335                    CURSOR_MARKER,
336                ),
337                (
338                    self.request
339                        .excerpt_range
340                        .end
341                        .saturating_sub(0)
342                        .max(self.request.excerpt_range.start),
343                    EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
344                ),
345            ],
346            PromptFormat::LabeledSections => vec![(
347                self.request.excerpt_range.start + self.request.cursor_offset,
348                CURSOR_MARKER,
349            )],
350            PromptFormat::OnlySnippets => vec![],
351        };
352
353        let mut prompt = String::new();
354        prompt.push_str("## User Edits\n\n");
355        Self::push_events(&mut prompt, &self.request.events);
356
357        prompt.push_str("\n## Code\n\n");
358        let section_labels =
359            self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?;
360        Ok((prompt, section_labels))
361    }
362
363    fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
364        for event in events {
365            match event {
366                Event::BufferChange {
367                    path,
368                    old_path,
369                    diff,
370                    predicted,
371                } => {
372                    if let Some(old_path) = &old_path
373                        && let Some(new_path) = &path
374                    {
375                        if old_path != new_path {
376                            writeln!(
377                                output,
378                                "User renamed {} to {}\n\n",
379                                old_path.display(),
380                                new_path.display()
381                            )
382                            .unwrap();
383                        }
384                    }
385
386                    let path = path
387                        .as_ref()
388                        .map_or_else(|| "untitled".to_string(), |path| path.display().to_string());
389
390                    if *predicted {
391                        writeln!(
392                            output,
393                            "User accepted prediction {:?}:\n```diff\n{}\n```\n",
394                            path, diff
395                        )
396                        .unwrap();
397                    } else {
398                        writeln!(output, "User edited {:?}:\n```diff\n{}\n```\n", path, diff)
399                            .unwrap();
400                    }
401                }
402            }
403        }
404    }
405
406    fn push_file_snippets(
407        &self,
408        output: &mut String,
409        excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
410        file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
411    ) -> Result<SectionLabels> {
412        let mut section_ranges = Vec::new();
413        let mut excerpt_index = None;
414
415        for (file_path, mut snippets, is_excerpt_file) in file_snippets {
416            snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
417
418            // TODO: What if the snippets get expanded too large to be editable?
419            let mut current_snippet: Option<(&PlannedSnippet, Range<usize>)> = None;
420            let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<usize>)> = Vec::new();
421            for snippet in snippets {
422                if let Some((_, current_snippet_range)) = current_snippet.as_mut()
423                    && snippet.range.start < current_snippet_range.end
424                {
425                    if snippet.range.end > current_snippet_range.end {
426                        current_snippet_range.end = snippet.range.end;
427                    }
428                    continue;
429                }
430                if let Some(current_snippet) = current_snippet.take() {
431                    disjoint_snippets.push(current_snippet);
432                }
433                current_snippet = Some((snippet, snippet.range.clone()));
434            }
435            if let Some(current_snippet) = current_snippet.take() {
436                disjoint_snippets.push(current_snippet);
437            }
438
439            writeln!(output, "```{}", file_path.display()).ok();
440            let mut skipped_last_snippet = false;
441            for (snippet, range) in disjoint_snippets {
442                let section_index = section_ranges.len();
443
444                match self.request.prompt_format {
445                    PromptFormat::MarkedExcerpt | PromptFormat::OnlySnippets => {
446                        if range.start > 0 && !skipped_last_snippet {
447                            output.push_str("\n");
448                        }
449                    }
450                    PromptFormat::LabeledSections => {
451                        if is_excerpt_file
452                            && range.start <= self.request.excerpt_range.start
453                            && range.end >= self.request.excerpt_range.end
454                        {
455                            writeln!(output, "<|current_section|>").ok();
456                        } else {
457                            writeln!(output, "<|section_{}|>", section_index).ok();
458                        }
459                    }
460                }
461
462                if is_excerpt_file {
463                    if self.request.prompt_format == PromptFormat::OnlySnippets {
464                        if range.start >= self.request.excerpt_range.start
465                            && range.end <= self.request.excerpt_range.end
466                        {
467                            skipped_last_snippet = true;
468                        } else {
469                            skipped_last_snippet = false;
470                            output.push_str(snippet.text);
471                        }
472                    } else {
473                        let mut last_offset = range.start;
474                        let mut i = 0;
475                        while i < excerpt_file_insertions.len() {
476                            let (offset, insertion) = &excerpt_file_insertions[i];
477                            let found = *offset >= range.start && *offset <= range.end;
478                            if found {
479                                excerpt_index = Some(section_index);
480                                output.push_str(
481                                    &snippet.text[last_offset - range.start..offset - range.start],
482                                );
483                                output.push_str(insertion);
484                                last_offset = *offset;
485                                excerpt_file_insertions.remove(i);
486                                continue;
487                            }
488                            i += 1;
489                        }
490                        skipped_last_snippet = false;
491                        output.push_str(&snippet.text[last_offset - range.start..]);
492                    }
493                } else {
494                    skipped_last_snippet = false;
495                    output.push_str(snippet.text);
496                }
497
498                section_ranges.push((snippet.path.clone(), range));
499            }
500
501            output.push_str("```\n\n");
502        }
503
504        Ok(SectionLabels {
505            // TODO: Clean this up
506            excerpt_index: match self.request.prompt_format {
507                PromptFormat::OnlySnippets => 0,
508                _ => excerpt_index.context("bug: no snippet found for excerpt")?,
509            },
510            section_ranges,
511        })
512    }
513}
514
515fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
516    declaration_score(declaration, style) / declaration_size(declaration, style) as f32
517}
518
519fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
520    match style {
521        DeclarationStyle::Signature => declaration.signature_score,
522        DeclarationStyle::Declaration => declaration.declaration_score,
523    }
524}
525
526fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
527    match style {
528        DeclarationStyle::Signature => declaration.signature_range.len(),
529        DeclarationStyle::Declaration => declaration.text.len(),
530    }
531}