cloud_zeta2_prompt.rs

  1//! Zeta2 prompt planning and generation code shared with cloud.
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, ReferencedDeclaration};
  5use indoc::indoc;
  6use ordered_float::OrderedFloat;
  7use rustc_hash::{FxHashMap, FxHashSet};
  8use std::fmt::Write;
  9use std::sync::Arc;
 10use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
 11use strum::{EnumIter, IntoEnumIterator};
 12
/// Prompt size budget, in bytes, used by `PlannedPrompt::populate` when the request does not
/// provide `prompt_max_bytes`.
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;

/// Marker inserted into the excerpt at the cursor position (for the `MarkedExcerpt` and
/// `LabeledSections` prompt formats).
pub const CURSOR_MARKER: &str = "<|cursor_position|>";
/// Marker inserted at the start of the excerpt for the `MarkedExcerpt` prompt format.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
/// Marker inserted at the end of the excerpt for the `MarkedExcerpt` prompt format.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
 20
// TODO: use constants for markers?
/// System prompt returned by `system_prompt` for `PromptFormat::MarkedExcerpt`.
///
/// The marker names spelled out in this text duplicate `CURSOR_MARKER` and the editable-region
/// marker constants; keep them in sync by hand if any of them change.
const MARKED_EXCERPT_SYSTEM_PROMPT: &str = indoc! {"
    You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

    The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|cursor_position|>.  Please respond with edited code for that region.

    Other code is provided for context, and `…` indicates when code has been skipped.
"};
 29
/// System prompt returned by `system_prompt` for `PromptFormat::LabeledSections`.
///
/// The `<|section_N|>` and `<|current_section|>` labels described here are the ones written by
/// `PlannedPrompt::push_file_snippets` for this format.
const LABELED_SECTIONS_SYSTEM_PROMPT: &str = indoc! {r#"
    You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.

    Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).

    The cursor position is marked with `<|cursor_position|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.

    Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:

    <|current_section|>
    for i in 0..16 {
        println!("{i}");
    }
"#};
 44
/// A plan of which context snippets to render into a prompt for a single
/// `PredictEditsRequest`. Built by `PlannedPrompt::populate` and rendered by
/// `PlannedPrompt::to_prompt_string`.
pub struct PlannedPrompt<'a> {
    /// The request this plan was built for; snippet text is borrowed from it.
    request: &'a predict_edits_v3::PredictEditsRequest,
    /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
    /// `to_prompt_string`.
    snippets: Vec<PlannedSnippet<'a>>,
    /// Running total of bytes charged against the prompt budget. Starts at the excerpt length
    /// and grows as snippets are added to the plan.
    budget_used: usize,
}
 52
 53pub fn system_prompt(format: PromptFormat) -> &'static str {
 54    match format {
 55        PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT,
 56        PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT,
 57        // only intended for use via zeta_cli
 58        PromptFormat::OnlySnippets => "",
 59    }
 60}
 61
/// A piece of file text selected for inclusion in the prompt.
#[derive(Clone, Debug)]
pub struct PlannedSnippet<'a> {
    /// File the snippet comes from; snippets are grouped by this path when rendered.
    path: Arc<Path>,
    /// Byte range of the snippet within the file. Rendering indexes `text` by
    /// `offset - range.start`, so `text` is expected to correspond to this range.
    range: Range<usize>,
    /// The snippet text, borrowed from the request.
    text: &'a str,
    /// Copied from the `text_is_truncated` flag of the signature / declaration this snippet was
    /// built from.
    // TODO: Indicate this in the output
    #[allow(dead_code)]
    text_is_truncated: bool,
}
 71
/// How much of a referenced declaration is included in the prompt.
///
/// NOTE: the variant order matters - it drives both `SnippetStyle::iter()` and the derived
/// `Ord`, which `PlannedPrompt::populate` relies on for tie-breaking in its priority queue.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum SnippetStyle {
    /// Only the `signature_range` portion of the declaration's text.
    Signature,
    /// The declaration's full text.
    Declaration,
}
 77
/// Describes the sections that were written into a rendered prompt, returned alongside the
/// prompt text by `PlannedPrompt::to_prompt_string`.
#[derive(Clone, Debug)]
pub struct SectionLabels {
    /// Index into `section_ranges` of the section that received the excerpt markers. Always `0`
    /// for `PromptFormat::OnlySnippets`, where no markers are inserted.
    pub excerpt_index: usize,
    /// File path and (merged) byte range of every section, in the order they were written.
    pub section_ranges: Vec<(Arc<Path>, Range<usize>)>,
}
 83
 84impl<'a> PlannedPrompt<'a> {
 85    /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
 86    ///
 87    /// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle
 88    /// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature"
 89    /// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of
 90    /// upgrade.
 91    ///
 92    /// TODO: Implement an early halting condition. One option might be to have another priority
 93    /// queue where the score is the size, and update it accordingly. Another option might be to
 94    /// have some simpler heuristic like bailing after N failed insertions, or based on how much
 95    /// budget is left.
 96    ///
 97    /// TODO: Has the current known sources of imprecision:
 98    ///
 99    /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
100    /// plan even though the containing struct is already included.
101    ///
102    /// * Does not consider cost of signatures when ranking snippets - this is tricky since
103    /// signatures may be shared by multiple snippets.
104    ///
105    /// * Does not include file paths / other text when considering max_bytes.
106    pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
107        let mut this = PlannedPrompt {
108            request,
109            snippets: Vec::new(),
110            budget_used: request.excerpt.len(),
111        };
112        let mut included_parents = FxHashSet::default();
113        let additional_parents = this.additional_parent_signatures(
114            &request.excerpt_path,
115            request.excerpt_parent,
116            &included_parents,
117        )?;
118        this.add_parents(&mut included_parents, additional_parents);
119
120        let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
121
122        if this.budget_used > max_bytes {
123            return Err(anyhow!(
124                "Excerpt + signatures size of {} already exceeds budget of {}",
125                this.budget_used,
126                max_bytes
127            ));
128        }
129
130        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
131        struct QueueEntry {
132            score_density: OrderedFloat<f32>,
133            declaration_index: usize,
134            style: SnippetStyle,
135        }
136
137        // Initialize priority queue with the best score for each snippet.
138        let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
139        for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
140            let (style, score_density) = SnippetStyle::iter()
141                .map(|style| {
142                    (
143                        style,
144                        OrderedFloat(declaration_score_density(&declaration, style)),
145                    )
146                })
147                .max_by_key(|(_, score_density)| *score_density)
148                .unwrap();
149            queue.push(QueueEntry {
150                score_density,
151                declaration_index,
152                style,
153            });
154        }
155
156        // Knapsack selection loop
157        while let Some(queue_entry) = queue.pop() {
158            let Some(declaration) = request
159                .referenced_declarations
160                .get(queue_entry.declaration_index)
161            else {
162                return Err(anyhow!(
163                    "Invalid declaration index {}",
164                    queue_entry.declaration_index
165                ));
166            };
167
168            let mut additional_bytes = declaration_size(declaration, queue_entry.style);
169            if this.budget_used + additional_bytes > max_bytes {
170                continue;
171            }
172
173            let additional_parents = this.additional_parent_signatures(
174                &declaration.path,
175                declaration.parent_index,
176                &mut included_parents,
177            )?;
178            additional_bytes += additional_parents
179                .iter()
180                .map(|(_, snippet)| snippet.text.len())
181                .sum::<usize>();
182            if this.budget_used + additional_bytes > max_bytes {
183                continue;
184            }
185
186            this.budget_used += additional_bytes;
187            this.add_parents(&mut included_parents, additional_parents);
188            let planned_snippet = match queue_entry.style {
189                SnippetStyle::Signature => {
190                    let Some(text) = declaration.text.get(declaration.signature_range.clone())
191                    else {
192                        return Err(anyhow!(
193                            "Invalid declaration signature_range {:?} with text.len() = {}",
194                            declaration.signature_range,
195                            declaration.text.len()
196                        ));
197                    };
198                    PlannedSnippet {
199                        path: declaration.path.clone(),
200                        range: (declaration.signature_range.start + declaration.range.start)
201                            ..(declaration.signature_range.end + declaration.range.start),
202                        text,
203                        text_is_truncated: declaration.text_is_truncated,
204                    }
205                }
206                SnippetStyle::Declaration => PlannedSnippet {
207                    path: declaration.path.clone(),
208                    range: declaration.range.clone(),
209                    text: &declaration.text,
210                    text_is_truncated: declaration.text_is_truncated,
211                },
212            };
213            this.snippets.push(planned_snippet);
214
215            // When a Signature is consumed, insert an entry for Definition style.
216            if queue_entry.style == SnippetStyle::Signature {
217                let signature_size = declaration_size(&declaration, SnippetStyle::Signature);
218                let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration);
219                let signature_score = declaration_score(&declaration, SnippetStyle::Signature);
220                let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration);
221
222                let score_diff = declaration_score - signature_score;
223                let size_diff = declaration_size.saturating_sub(signature_size);
224                if score_diff > 0.0001 && size_diff > 0 {
225                    queue.push(QueueEntry {
226                        declaration_index: queue_entry.declaration_index,
227                        score_density: OrderedFloat(score_diff / (size_diff as f32)),
228                        style: SnippetStyle::Declaration,
229                    });
230                }
231            }
232        }
233
234        anyhow::Ok(this)
235    }
236
237    fn add_parents(
238        &mut self,
239        included_parents: &mut FxHashSet<usize>,
240        snippets: Vec<(usize, PlannedSnippet<'a>)>,
241    ) {
242        for (parent_index, snippet) in snippets {
243            included_parents.insert(parent_index);
244            self.budget_used += snippet.text.len();
245            self.snippets.push(snippet);
246        }
247    }
248
249    fn additional_parent_signatures(
250        &self,
251        path: &Arc<Path>,
252        parent_index: Option<usize>,
253        included_parents: &FxHashSet<usize>,
254    ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
255        let mut results = Vec::new();
256        self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
257        Ok(results)
258    }
259
260    fn additional_parent_signatures_impl(
261        &self,
262        path: &Arc<Path>,
263        parent_index: Option<usize>,
264        included_parents: &FxHashSet<usize>,
265        results: &mut Vec<(usize, PlannedSnippet<'a>)>,
266    ) -> Result<()> {
267        let Some(parent_index) = parent_index else {
268            return Ok(());
269        };
270        if included_parents.contains(&parent_index) {
271            return Ok(());
272        }
273        let Some(parent_signature) = self.request.signatures.get(parent_index) else {
274            return Err(anyhow!("Invalid parent index {}", parent_index));
275        };
276        results.push((
277            parent_index,
278            PlannedSnippet {
279                path: path.clone(),
280                range: parent_signature.range.clone(),
281                text: &parent_signature.text,
282                text_is_truncated: parent_signature.text_is_truncated,
283            },
284        ));
285        self.additional_parent_signatures_impl(
286            path,
287            parent_signature.parent_index,
288            included_parents,
289            results,
290        )
291    }
292
293    /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
294    /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
295    /// chunks.
296    pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> {
297        let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
298            FxHashMap::default();
299        for snippet in &self.snippets {
300            file_to_snippets
301                .entry(&snippet.path)
302                .or_default()
303                .push(snippet);
304        }
305
306        // Reorder so that file with cursor comes last
307        let mut file_snippets = Vec::new();
308        let mut excerpt_file_snippets = Vec::new();
309        for (file_path, snippets) in file_to_snippets {
310            if file_path == self.request.excerpt_path.as_ref() {
311                excerpt_file_snippets = snippets;
312            } else {
313                file_snippets.push((file_path, snippets, false));
314            }
315        }
316        let excerpt_snippet = PlannedSnippet {
317            path: self.request.excerpt_path.clone(),
318            range: self.request.excerpt_range.clone(),
319            text: &self.request.excerpt,
320            text_is_truncated: false,
321        };
322        excerpt_file_snippets.push(&excerpt_snippet);
323        file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
324
325        let mut excerpt_file_insertions = match self.request.prompt_format {
326            PromptFormat::MarkedExcerpt => vec![
327                (
328                    self.request.excerpt_range.start,
329                    EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
330                ),
331                (
332                    self.request.excerpt_range.start + self.request.cursor_offset,
333                    CURSOR_MARKER,
334                ),
335                (
336                    self.request
337                        .excerpt_range
338                        .end
339                        .saturating_sub(0)
340                        .max(self.request.excerpt_range.start),
341                    EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
342                ),
343            ],
344            PromptFormat::LabeledSections => vec![(
345                self.request.excerpt_range.start + self.request.cursor_offset,
346                CURSOR_MARKER,
347            )],
348            PromptFormat::OnlySnippets => vec![],
349        };
350
351        let mut prompt = String::new();
352        prompt.push_str("## User Edits\n\n");
353        Self::push_events(&mut prompt, &self.request.events);
354
355        prompt.push_str("\n## Code\n\n");
356        let section_labels =
357            self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?;
358        Ok((prompt, section_labels))
359    }
360
361    fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
362        for event in events {
363            match event {
364                Event::BufferChange {
365                    path,
366                    old_path,
367                    diff,
368                    predicted,
369                } => {
370                    if let Some(old_path) = &old_path
371                        && let Some(new_path) = &path
372                    {
373                        if old_path != new_path {
374                            writeln!(
375                                output,
376                                "User renamed {} to {}\n\n",
377                                old_path.display(),
378                                new_path.display()
379                            )
380                            .unwrap();
381                        }
382                    }
383
384                    let path = path
385                        .as_ref()
386                        .map_or_else(|| "untitled".to_string(), |path| path.display().to_string());
387
388                    if *predicted {
389                        writeln!(
390                            output,
391                            "User accepted prediction {:?}:\n```diff\n{}\n```\n",
392                            path, diff
393                        )
394                        .unwrap();
395                    } else {
396                        writeln!(output, "User edited {:?}:\n```diff\n{}\n```\n", path, diff)
397                            .unwrap();
398                    }
399                }
400            }
401        }
402    }
403
404    fn push_file_snippets(
405        &self,
406        output: &mut String,
407        excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
408        file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
409    ) -> Result<SectionLabels> {
410        let mut section_ranges = Vec::new();
411        let mut excerpt_index = None;
412
413        for (file_path, mut snippets, is_excerpt_file) in file_snippets {
414            snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
415
416            // TODO: What if the snippets get expanded too large to be editable?
417            let mut current_snippet: Option<(&PlannedSnippet, Range<usize>)> = None;
418            let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<usize>)> = Vec::new();
419            for snippet in snippets {
420                if let Some((_, current_snippet_range)) = current_snippet.as_mut()
421                    && snippet.range.start < current_snippet_range.end
422                {
423                    if snippet.range.end > current_snippet_range.end {
424                        current_snippet_range.end = snippet.range.end;
425                    }
426                    continue;
427                }
428                if let Some(current_snippet) = current_snippet.take() {
429                    disjoint_snippets.push(current_snippet);
430                }
431                current_snippet = Some((snippet, snippet.range.clone()));
432            }
433            if let Some(current_snippet) = current_snippet.take() {
434                disjoint_snippets.push(current_snippet);
435            }
436
437            writeln!(output, "```{}", file_path.display()).ok();
438            let mut skipped_last_snippet = false;
439            for (snippet, range) in disjoint_snippets {
440                let section_index = section_ranges.len();
441
442                match self.request.prompt_format {
443                    PromptFormat::MarkedExcerpt | PromptFormat::OnlySnippets => {
444                        if range.start > 0 && !skipped_last_snippet {
445                            output.push_str("\n");
446                        }
447                    }
448                    PromptFormat::LabeledSections => {
449                        if is_excerpt_file
450                            && range.start <= self.request.excerpt_range.start
451                            && range.end >= self.request.excerpt_range.end
452                        {
453                            writeln!(output, "<|current_section|>").ok();
454                        } else {
455                            writeln!(output, "<|section_{}|>", section_index).ok();
456                        }
457                    }
458                }
459
460                if is_excerpt_file {
461                    if self.request.prompt_format == PromptFormat::OnlySnippets {
462                        if range.start >= self.request.excerpt_range.start
463                            && range.end <= self.request.excerpt_range.end
464                        {
465                            skipped_last_snippet = true;
466                        } else {
467                            skipped_last_snippet = false;
468                            output.push_str(snippet.text);
469                        }
470                    } else {
471                        let mut last_offset = range.start;
472                        let mut i = 0;
473                        while i < excerpt_file_insertions.len() {
474                            let (offset, insertion) = &excerpt_file_insertions[i];
475                            let found = *offset >= range.start && *offset <= range.end;
476                            if found {
477                                excerpt_index = Some(section_index);
478                                output.push_str(
479                                    &snippet.text[last_offset - range.start..offset - range.start],
480                                );
481                                output.push_str(insertion);
482                                last_offset = *offset;
483                                excerpt_file_insertions.remove(i);
484                                continue;
485                            }
486                            i += 1;
487                        }
488                        skipped_last_snippet = false;
489                        output.push_str(&snippet.text[last_offset - range.start..]);
490                    }
491                } else {
492                    skipped_last_snippet = false;
493                    output.push_str(snippet.text);
494                }
495
496                section_ranges.push((snippet.path.clone(), range));
497            }
498
499            output.push_str("```\n\n");
500        }
501
502        Ok(SectionLabels {
503            // TODO: Clean this up
504            excerpt_index: match self.request.prompt_format {
505                PromptFormat::OnlySnippets => 0,
506                _ => excerpt_index.context("bug: no snippet found for excerpt")?,
507            },
508            section_ranges,
509        })
510    }
511}
512
513fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
514    declaration_score(declaration, style) / declaration_size(declaration, style) as f32
515}
516
517fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
518    match style {
519        SnippetStyle::Signature => declaration.signature_score,
520        SnippetStyle::Declaration => declaration.declaration_score,
521    }
522}
523
524fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize {
525    match style {
526        SnippetStyle::Signature => declaration.signature_range.len(),
527        SnippetStyle::Declaration => declaration.text.len(),
528    }
529}