cloud_zeta2_prompt.rs

  1//! Zeta2 prompt planning and generation code shared with cloud.
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, ReferencedDeclaration};
  5use indoc::indoc;
  6use ordered_float::OrderedFloat;
  7use rustc_hash::{FxHashMap, FxHashSet};
  8use std::fmt::Write;
  9use std::sync::Arc;
 10use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
 11use strum::{EnumIter, IntoEnumIterator};
 12
 13pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
 14
 15pub const CURSOR_MARKER: &str = "<|cursor_position|>";
 16/// NOTE: Differs from zed version of constant - includes a newline
 17pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
 18/// NOTE: Differs from zed version of constant - includes a newline
 19pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
 20
 21// TODO: use constants for markers?
 22const MARKED_EXCERPT_SYSTEM_PROMPT: &str = indoc! {"
 23    You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
 24
 25    The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|cursor_position|>.  Please respond with edited code for that region.
 26
 27    Other code is provided for context, and `…` indicates when code has been skipped.
 28"};
 29
 30const LABELED_SECTIONS_SYSTEM_PROMPT: &str = indoc! {r#"
 31    You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.
 32
 33    Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).
 34
 35    The cursor position is marked with `<|cursor_position|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.
 36
 37    Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:
 38
 39    <|current_section|>
 40    for i in 0..16 {
 41        println!("{i}");
 42    }
 43"#};
 44
 45pub struct PlannedPrompt<'a> {
 46    request: &'a predict_edits_v3::PredictEditsRequest,
 47    /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
 48    /// `to_prompt_string`.
 49    snippets: Vec<PlannedSnippet<'a>>,
 50    budget_used: usize,
 51}
 52
 53pub fn system_prompt(format: PromptFormat) -> &'static str {
 54    match format {
 55        PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT,
 56        PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT,
 57    }
 58}
 59
 60#[derive(Clone, Debug)]
 61pub struct PlannedSnippet<'a> {
 62    path: Arc<Path>,
 63    range: Range<usize>,
 64    text: &'a str,
 65    // TODO: Indicate this in the output
 66    #[allow(dead_code)]
 67    text_is_truncated: bool,
 68}
 69
 70#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
 71pub enum SnippetStyle {
 72    Signature,
 73    Declaration,
 74}
 75
 76#[derive(Clone, Debug)]
 77pub struct SectionLabels {
 78    pub excerpt_index: usize,
 79    pub section_ranges: Vec<(Arc<Path>, Range<usize>)>,
 80}
 81
 82impl<'a> PlannedPrompt<'a> {
 83    /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
 84    ///
 85    /// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle
 86    /// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature"
 87    /// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of
 88    /// upgrade.
 89    ///
 90    /// TODO: Implement an early halting condition. One option might be to have another priority
 91    /// queue where the score is the size, and update it accordingly. Another option might be to
 92    /// have some simpler heuristic like bailing after N failed insertions, or based on how much
 93    /// budget is left.
 94    ///
 95    /// TODO: Has the current known sources of imprecision:
 96    ///
 97    /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
 98    /// plan even though the containing struct is already included.
 99    ///
100    /// * Does not consider cost of signatures when ranking snippets - this is tricky since
101    /// signatures may be shared by multiple snippets.
102    ///
103    /// * Does not include file paths / other text when considering max_bytes.
104    pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
105        let mut this = PlannedPrompt {
106            request,
107            snippets: Vec::new(),
108            budget_used: request.excerpt.len(),
109        };
110        let mut included_parents = FxHashSet::default();
111        let additional_parents = this.additional_parent_signatures(
112            &request.excerpt_path,
113            request.excerpt_parent,
114            &included_parents,
115        )?;
116        this.add_parents(&mut included_parents, additional_parents);
117
118        let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
119
120        if this.budget_used > max_bytes {
121            return Err(anyhow!(
122                "Excerpt + signatures size of {} already exceeds budget of {}",
123                this.budget_used,
124                max_bytes
125            ));
126        }
127
128        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
129        struct QueueEntry {
130            score_density: OrderedFloat<f32>,
131            declaration_index: usize,
132            style: SnippetStyle,
133        }
134
135        // Initialize priority queue with the best score for each snippet.
136        let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
137        for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
138            let (style, score_density) = SnippetStyle::iter()
139                .map(|style| {
140                    (
141                        style,
142                        OrderedFloat(declaration_score_density(&declaration, style)),
143                    )
144                })
145                .max_by_key(|(_, score_density)| *score_density)
146                .unwrap();
147            queue.push(QueueEntry {
148                score_density,
149                declaration_index,
150                style,
151            });
152        }
153
154        // Knapsack selection loop
155        while let Some(queue_entry) = queue.pop() {
156            let Some(declaration) = request
157                .referenced_declarations
158                .get(queue_entry.declaration_index)
159            else {
160                return Err(anyhow!(
161                    "Invalid declaration index {}",
162                    queue_entry.declaration_index
163                ));
164            };
165
166            let mut additional_bytes = declaration_size(declaration, queue_entry.style);
167            if this.budget_used + additional_bytes > max_bytes {
168                continue;
169            }
170
171            let additional_parents = this.additional_parent_signatures(
172                &declaration.path,
173                declaration.parent_index,
174                &mut included_parents,
175            )?;
176            additional_bytes += additional_parents
177                .iter()
178                .map(|(_, snippet)| snippet.text.len())
179                .sum::<usize>();
180            if this.budget_used + additional_bytes > max_bytes {
181                continue;
182            }
183
184            this.budget_used += additional_bytes;
185            this.add_parents(&mut included_parents, additional_parents);
186            let planned_snippet = match queue_entry.style {
187                SnippetStyle::Signature => {
188                    let Some(text) = declaration.text.get(declaration.signature_range.clone())
189                    else {
190                        return Err(anyhow!(
191                            "Invalid declaration signature_range {:?} with text.len() = {}",
192                            declaration.signature_range,
193                            declaration.text.len()
194                        ));
195                    };
196                    PlannedSnippet {
197                        path: declaration.path.clone(),
198                        range: (declaration.signature_range.start + declaration.range.start)
199                            ..(declaration.signature_range.end + declaration.range.start),
200                        text,
201                        text_is_truncated: declaration.text_is_truncated,
202                    }
203                }
204                SnippetStyle::Declaration => PlannedSnippet {
205                    path: declaration.path.clone(),
206                    range: declaration.range.clone(),
207                    text: &declaration.text,
208                    text_is_truncated: declaration.text_is_truncated,
209                },
210            };
211            this.snippets.push(planned_snippet);
212
213            // When a Signature is consumed, insert an entry for Definition style.
214            if queue_entry.style == SnippetStyle::Signature {
215                let signature_size = declaration_size(&declaration, SnippetStyle::Signature);
216                let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration);
217                let signature_score = declaration_score(&declaration, SnippetStyle::Signature);
218                let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration);
219
220                let score_diff = declaration_score - signature_score;
221                let size_diff = declaration_size.saturating_sub(signature_size);
222                if score_diff > 0.0001 && size_diff > 0 {
223                    queue.push(QueueEntry {
224                        declaration_index: queue_entry.declaration_index,
225                        score_density: OrderedFloat(score_diff / (size_diff as f32)),
226                        style: SnippetStyle::Declaration,
227                    });
228                }
229            }
230        }
231
232        anyhow::Ok(this)
233    }
234
235    fn add_parents(
236        &mut self,
237        included_parents: &mut FxHashSet<usize>,
238        snippets: Vec<(usize, PlannedSnippet<'a>)>,
239    ) {
240        for (parent_index, snippet) in snippets {
241            included_parents.insert(parent_index);
242            self.budget_used += snippet.text.len();
243            self.snippets.push(snippet);
244        }
245    }
246
247    fn additional_parent_signatures(
248        &self,
249        path: &Arc<Path>,
250        parent_index: Option<usize>,
251        included_parents: &FxHashSet<usize>,
252    ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
253        let mut results = Vec::new();
254        self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
255        Ok(results)
256    }
257
258    fn additional_parent_signatures_impl(
259        &self,
260        path: &Arc<Path>,
261        parent_index: Option<usize>,
262        included_parents: &FxHashSet<usize>,
263        results: &mut Vec<(usize, PlannedSnippet<'a>)>,
264    ) -> Result<()> {
265        let Some(parent_index) = parent_index else {
266            return Ok(());
267        };
268        if included_parents.contains(&parent_index) {
269            return Ok(());
270        }
271        let Some(parent_signature) = self.request.signatures.get(parent_index) else {
272            return Err(anyhow!("Invalid parent index {}", parent_index));
273        };
274        results.push((
275            parent_index,
276            PlannedSnippet {
277                path: path.clone(),
278                range: parent_signature.range.clone(),
279                text: &parent_signature.text,
280                text_is_truncated: parent_signature.text_is_truncated,
281            },
282        ));
283        self.additional_parent_signatures_impl(
284            path,
285            parent_signature.parent_index,
286            included_parents,
287            results,
288        )
289    }
290
291    /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
292    /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
293    /// chunks.
294    pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> {
295        let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
296            FxHashMap::default();
297        for snippet in &self.snippets {
298            file_to_snippets
299                .entry(&snippet.path)
300                .or_default()
301                .push(snippet);
302        }
303
304        // Reorder so that file with cursor comes last
305        let mut file_snippets = Vec::new();
306        let mut excerpt_file_snippets = Vec::new();
307        for (file_path, snippets) in file_to_snippets {
308            if file_path == self.request.excerpt_path.as_ref() {
309                excerpt_file_snippets = snippets;
310            } else {
311                file_snippets.push((file_path, snippets, false));
312            }
313        }
314        let excerpt_snippet = PlannedSnippet {
315            path: self.request.excerpt_path.clone(),
316            range: self.request.excerpt_range.clone(),
317            text: &self.request.excerpt,
318            text_is_truncated: false,
319        };
320        excerpt_file_snippets.push(&excerpt_snippet);
321        file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
322
323        let mut excerpt_file_insertions = match self.request.prompt_format {
324            PromptFormat::MarkedExcerpt => vec![
325                (
326                    self.request.excerpt_range.start,
327                    EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
328                ),
329                (
330                    self.request.excerpt_range.start + self.request.cursor_offset,
331                    CURSOR_MARKER,
332                ),
333                (
334                    self.request
335                        .excerpt_range
336                        .end
337                        .saturating_sub(0)
338                        .max(self.request.excerpt_range.start),
339                    EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
340                ),
341            ],
342            PromptFormat::LabeledSections => vec![(
343                self.request.excerpt_range.start + self.request.cursor_offset,
344                CURSOR_MARKER,
345            )],
346        };
347
348        let mut prompt = String::new();
349        prompt.push_str("## User Edits\n\n");
350        Self::push_events(&mut prompt, &self.request.events);
351
352        prompt.push_str("\n## Code\n\n");
353        let section_labels =
354            self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?;
355        Ok((prompt, section_labels))
356    }
357
358    fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
359        for event in events {
360            match event {
361                Event::BufferChange {
362                    path,
363                    old_path,
364                    diff,
365                    predicted,
366                } => {
367                    if let Some(old_path) = &old_path
368                        && let Some(new_path) = &path
369                    {
370                        if old_path != new_path {
371                            writeln!(
372                                output,
373                                "User renamed {} to {}\n\n",
374                                old_path.display(),
375                                new_path.display()
376                            )
377                            .unwrap();
378                        }
379                    }
380
381                    let path = path
382                        .as_ref()
383                        .map_or_else(|| "untitled".to_string(), |path| path.display().to_string());
384
385                    if *predicted {
386                        writeln!(
387                            output,
388                            "User accepted prediction {:?}:\n```diff\n{}\n```\n",
389                            path, diff
390                        )
391                        .unwrap();
392                    } else {
393                        writeln!(output, "User edited {:?}:\n```diff\n{}\n```\n", path, diff)
394                            .unwrap();
395                    }
396                }
397            }
398        }
399    }
400
401    fn push_file_snippets(
402        &self,
403        output: &mut String,
404        excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
405        file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
406    ) -> Result<SectionLabels> {
407        let mut section_ranges = Vec::new();
408        let mut excerpt_index = None;
409
410        for (file_path, mut snippets, is_excerpt_file) in file_snippets {
411            snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
412
413            // TODO: What if the snippets get expanded too large to be editable?
414            let mut current_snippet: Option<(&PlannedSnippet, Range<usize>)> = None;
415            let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<usize>)> = Vec::new();
416            for snippet in snippets {
417                if let Some((_, current_snippet_range)) = current_snippet.as_mut()
418                    && snippet.range.start < current_snippet_range.end
419                {
420                    if snippet.range.end > current_snippet_range.end {
421                        current_snippet_range.end = snippet.range.end;
422                    }
423                    continue;
424                }
425                if let Some(current_snippet) = current_snippet.take() {
426                    disjoint_snippets.push(current_snippet);
427                }
428                current_snippet = Some((snippet, snippet.range.clone()));
429            }
430            if let Some(current_snippet) = current_snippet.take() {
431                disjoint_snippets.push(current_snippet);
432            }
433
434            writeln!(output, "```{}", file_path.display()).ok();
435            for (snippet, range) in disjoint_snippets {
436                let section_index = section_ranges.len();
437
438                match self.request.prompt_format {
439                    PromptFormat::MarkedExcerpt => {
440                        if range.start > 0 {
441                            output.push_str("\n");
442                        }
443                    }
444                    PromptFormat::LabeledSections => {
445                        if is_excerpt_file
446                            && range.start <= self.request.excerpt_range.start
447                            && range.end >= self.request.excerpt_range.end
448                        {
449                            writeln!(output, "<|current_section|>").ok();
450                        } else {
451                            writeln!(output, "<|section_{}|>", section_index).ok();
452                        }
453                    }
454                }
455
456                if is_excerpt_file {
457                    excerpt_index = Some(section_index);
458                    let mut last_offset = range.start;
459                    let mut i = 0;
460                    while i < excerpt_file_insertions.len() {
461                        let (offset, insertion) = &excerpt_file_insertions[i];
462                        let found = *offset >= range.start && *offset <= range.end;
463                        if found {
464                            output.push_str(
465                                &snippet.text[last_offset - range.start..offset - range.start],
466                            );
467                            output.push_str(insertion);
468                            last_offset = *offset;
469                            excerpt_file_insertions.remove(i);
470                            continue;
471                        }
472                        i += 1;
473                    }
474                    output.push_str(&snippet.text[last_offset - range.start..]);
475                } else {
476                    output.push_str(snippet.text);
477                }
478
479                section_ranges.push((snippet.path.clone(), range));
480            }
481
482            output.push_str("```\n\n");
483        }
484
485        Ok(SectionLabels {
486            excerpt_index: excerpt_index.context("bug: no snippet found for excerpt")?,
487            section_ranges,
488        })
489    }
490}
491
492fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
493    declaration_score(declaration, style) / declaration_size(declaration, style) as f32
494}
495
496fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
497    match style {
498        SnippetStyle::Signature => declaration.signature_score,
499        SnippetStyle::Declaration => declaration.declaration_score,
500    }
501}
502
503fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize {
504    match style {
505        SnippetStyle::Signature => declaration.signature_range.len(),
506        SnippetStyle::Declaration => declaration.text.len(),
507    }
508}