cloud_zeta2_prompt.rs

  1//! Zeta2 prompt planning and generation code shared with cloud.
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, ReferencedDeclaration};
  5use indoc::indoc;
  6use ordered_float::OrderedFloat;
  7use rustc_hash::{FxHashMap, FxHashSet};
  8use serde::Serialize;
  9use std::fmt::Write;
 10use std::sync::Arc;
 11use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
 12use strum::{EnumIter, IntoEnumIterator};
 13
/// Prompt byte budget used by `PlannedPrompt::populate` when the request does not provide
/// `prompt_max_bytes`.
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;

/// Marker inserted into the rendered excerpt at the cursor position.
pub const CURSOR_MARKER: &str = "<|cursor_position|>";
/// Marker inserted at the start of the excerpt range for the `MarkedExcerpt` prompt format.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
/// Marker inserted at the end of the excerpt range for the `MarkedExcerpt` prompt format.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
 21
// TODO: use constants for markers?
/// System prompt returned by `system_prompt` for `PromptFormat::MarkedExcerpt`.
///
/// The marker names spelled out in this text are written by hand and must be kept in sync with
/// `CURSOR_MARKER` and the editable-region marker constants.
const MARKED_EXCERPT_SYSTEM_PROMPT: &str = indoc! {"
    You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

    The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|cursor_position|>.  Please respond with edited code for that region.

    Other code is provided for context, and `…` indicates when code has been skipped.
"};
 30
/// System prompt returned by `system_prompt` for `PromptFormat::LabeledSections`.
///
/// The `<|section_N|>` and `<|current_section|>` labels it describes are the ones written by
/// `PlannedPrompt::push_file_snippets` for this format.
const LABELED_SECTIONS_SYSTEM_PROMPT: &str = indoc! {r#"
    You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.

    Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).

    The cursor position is marked with `<|cursor_position|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.

    Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:

    <|current_section|>
    for i in 0..16 {
        println!("{i}");
    }
"#};
 45
/// A plan of which context snippets to render into a prompt for a given request.
///
/// Built by `PlannedPrompt::populate` and rendered by `PlannedPrompt::to_prompt_string`.
pub struct PlannedPrompt<'a> {
    /// The request this plan was built for; snippet text borrows from it.
    request: &'a predict_edits_v3::PredictEditsRequest,
    /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
    /// `to_prompt_string`.
    snippets: Vec<PlannedSnippet<'a>>,
    /// Running total of text bytes charged against the prompt budget. Starts at the excerpt
    /// length and grows as parent signatures and declarations are added.
    budget_used: usize,
}
 53
 54pub fn system_prompt(format: PromptFormat) -> &'static str {
 55    match format {
 56        PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT,
 57        PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT,
 58        // only intended for use via zeta_cli
 59        PromptFormat::OnlySnippets => "",
 60    }
 61}
 62
/// A piece of text from one file that is planned for inclusion in the prompt.
#[derive(Clone, Debug)]
pub struct PlannedSnippet<'a> {
    /// Path of the file the snippet comes from; snippets are grouped by this when rendering.
    path: Arc<Path>,
    /// Offset range the snippet occupies within its file. Used to sort and merge overlapping
    /// snippets and to position marker insertions.
    /// NOTE(review): presumably byte offsets, since they are compared with `text.len()` - confirm.
    range: Range<usize>,
    /// The snippet text, borrowed from the request.
    text: &'a str,
    // TODO: Indicate this in the output
    /// Copied from the source declaration / signature; not yet used when rendering.
    #[allow(dead_code)]
    text_is_truncated: bool,
}
 72
/// How much of a referenced declaration is included in the prompt.
///
/// Variant order matters: `Ord` is derived (it is part of the priority-queue entry ordering) and
/// `EnumIter` yields the variants in declaration order.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum DeclarationStyle {
    /// Only the `signature_range` portion of the declaration text.
    Signature,
    /// The full declaration text.
    Declaration,
}
 78
/// Describes the sections that were written into a rendered prompt.
#[derive(Clone, Debug, Serialize)]
pub struct SectionLabels {
    /// Index into `section_ranges` of the section that received the excerpt marker insertions.
    /// Always `0` for the `OnlySnippets` format.
    pub excerpt_index: usize,
    /// File path and (merged) range of every rendered section, in output order.
    pub section_ranges: Vec<(Arc<Path>, Range<usize>)>,
}
 84
 85impl<'a> PlannedPrompt<'a> {
 86    /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
 87    ///
 88    /// Initializes a priority queue by populating it with each snippet, finding the
 89    /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
 90    /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
 91    /// the cost of upgrade.
 92    ///
 93    /// TODO: Implement an early halting condition. One option might be to have another priority
 94    /// queue where the score is the size, and update it accordingly. Another option might be to
 95    /// have some simpler heuristic like bailing after N failed insertions, or based on how much
 96    /// budget is left.
 97    ///
 98    /// TODO: Has the current known sources of imprecision:
 99    ///
100    /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
101    /// plan even though the containing struct is already included.
102    ///
103    /// * Does not consider cost of signatures when ranking snippets - this is tricky since
104    /// signatures may be shared by multiple snippets.
105    ///
106    /// * Does not include file paths / other text when considering max_bytes.
107    pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
108        let mut this = PlannedPrompt {
109            request,
110            snippets: Vec::new(),
111            budget_used: request.excerpt.len(),
112        };
113        let mut included_parents = FxHashSet::default();
114        let additional_parents = this.additional_parent_signatures(
115            &request.excerpt_path,
116            request.excerpt_parent,
117            &included_parents,
118        )?;
119        this.add_parents(&mut included_parents, additional_parents);
120
121        let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
122
123        if this.budget_used > max_bytes {
124            return Err(anyhow!(
125                "Excerpt + signatures size of {} already exceeds budget of {}",
126                this.budget_used,
127                max_bytes
128            ));
129        }
130
131        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
132        struct QueueEntry {
133            score_density: OrderedFloat<f32>,
134            declaration_index: usize,
135            style: DeclarationStyle,
136        }
137
138        // Initialize priority queue with the best score for each snippet.
139        let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
140        for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
141            let (style, score_density) = DeclarationStyle::iter()
142                .map(|style| {
143                    (
144                        style,
145                        OrderedFloat(declaration_score_density(&declaration, style)),
146                    )
147                })
148                .max_by_key(|(_, score_density)| *score_density)
149                .unwrap();
150            queue.push(QueueEntry {
151                score_density,
152                declaration_index,
153                style,
154            });
155        }
156
157        // Knapsack selection loop
158        while let Some(queue_entry) = queue.pop() {
159            let Some(declaration) = request
160                .referenced_declarations
161                .get(queue_entry.declaration_index)
162            else {
163                return Err(anyhow!(
164                    "Invalid declaration index {}",
165                    queue_entry.declaration_index
166                ));
167            };
168
169            let mut additional_bytes = declaration_size(declaration, queue_entry.style);
170            if this.budget_used + additional_bytes > max_bytes {
171                continue;
172            }
173
174            let additional_parents = this.additional_parent_signatures(
175                &declaration.path,
176                declaration.parent_index,
177                &mut included_parents,
178            )?;
179            additional_bytes += additional_parents
180                .iter()
181                .map(|(_, snippet)| snippet.text.len())
182                .sum::<usize>();
183            if this.budget_used + additional_bytes > max_bytes {
184                continue;
185            }
186
187            this.budget_used += additional_bytes;
188            this.add_parents(&mut included_parents, additional_parents);
189            let planned_snippet = match queue_entry.style {
190                DeclarationStyle::Signature => {
191                    let Some(text) = declaration.text.get(declaration.signature_range.clone())
192                    else {
193                        return Err(anyhow!(
194                            "Invalid declaration signature_range {:?} with text.len() = {}",
195                            declaration.signature_range,
196                            declaration.text.len()
197                        ));
198                    };
199                    PlannedSnippet {
200                        path: declaration.path.clone(),
201                        range: (declaration.signature_range.start + declaration.range.start)
202                            ..(declaration.signature_range.end + declaration.range.start),
203                        text,
204                        text_is_truncated: declaration.text_is_truncated,
205                    }
206                }
207                DeclarationStyle::Declaration => PlannedSnippet {
208                    path: declaration.path.clone(),
209                    range: declaration.range.clone(),
210                    text: &declaration.text,
211                    text_is_truncated: declaration.text_is_truncated,
212                },
213            };
214            this.snippets.push(planned_snippet);
215
216            // When a Signature is consumed, insert an entry for Definition style.
217            if queue_entry.style == DeclarationStyle::Signature {
218                let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
219                let declaration_size =
220                    declaration_size(&declaration, DeclarationStyle::Declaration);
221                let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
222                let declaration_score =
223                    declaration_score(&declaration, DeclarationStyle::Declaration);
224
225                let score_diff = declaration_score - signature_score;
226                let size_diff = declaration_size.saturating_sub(signature_size);
227                if score_diff > 0.0001 && size_diff > 0 {
228                    queue.push(QueueEntry {
229                        declaration_index: queue_entry.declaration_index,
230                        score_density: OrderedFloat(score_diff / (size_diff as f32)),
231                        style: DeclarationStyle::Declaration,
232                    });
233                }
234            }
235        }
236
237        anyhow::Ok(this)
238    }
239
240    fn add_parents(
241        &mut self,
242        included_parents: &mut FxHashSet<usize>,
243        snippets: Vec<(usize, PlannedSnippet<'a>)>,
244    ) {
245        for (parent_index, snippet) in snippets {
246            included_parents.insert(parent_index);
247            self.budget_used += snippet.text.len();
248            self.snippets.push(snippet);
249        }
250    }
251
252    fn additional_parent_signatures(
253        &self,
254        path: &Arc<Path>,
255        parent_index: Option<usize>,
256        included_parents: &FxHashSet<usize>,
257    ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
258        let mut results = Vec::new();
259        self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
260        Ok(results)
261    }
262
263    fn additional_parent_signatures_impl(
264        &self,
265        path: &Arc<Path>,
266        parent_index: Option<usize>,
267        included_parents: &FxHashSet<usize>,
268        results: &mut Vec<(usize, PlannedSnippet<'a>)>,
269    ) -> Result<()> {
270        let Some(parent_index) = parent_index else {
271            return Ok(());
272        };
273        if included_parents.contains(&parent_index) {
274            return Ok(());
275        }
276        let Some(parent_signature) = self.request.signatures.get(parent_index) else {
277            return Err(anyhow!("Invalid parent index {}", parent_index));
278        };
279        results.push((
280            parent_index,
281            PlannedSnippet {
282                path: path.clone(),
283                range: parent_signature.range.clone(),
284                text: &parent_signature.text,
285                text_is_truncated: parent_signature.text_is_truncated,
286            },
287        ));
288        self.additional_parent_signatures_impl(
289            path,
290            parent_signature.parent_index,
291            included_parents,
292            results,
293        )
294    }
295
296    /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
297    /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
298    /// chunks.
299    pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> {
300        let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
301            FxHashMap::default();
302        for snippet in &self.snippets {
303            file_to_snippets
304                .entry(&snippet.path)
305                .or_default()
306                .push(snippet);
307        }
308
309        // Reorder so that file with cursor comes last
310        let mut file_snippets = Vec::new();
311        let mut excerpt_file_snippets = Vec::new();
312        for (file_path, snippets) in file_to_snippets {
313            if file_path == self.request.excerpt_path.as_ref() {
314                excerpt_file_snippets = snippets;
315            } else {
316                file_snippets.push((file_path, snippets, false));
317            }
318        }
319        let excerpt_snippet = PlannedSnippet {
320            path: self.request.excerpt_path.clone(),
321            range: self.request.excerpt_range.clone(),
322            text: &self.request.excerpt,
323            text_is_truncated: false,
324        };
325        excerpt_file_snippets.push(&excerpt_snippet);
326        file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
327
328        let mut excerpt_file_insertions = match self.request.prompt_format {
329            PromptFormat::MarkedExcerpt => vec![
330                (
331                    self.request.excerpt_range.start,
332                    EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
333                ),
334                (
335                    self.request.excerpt_range.start + self.request.cursor_offset,
336                    CURSOR_MARKER,
337                ),
338                (
339                    self.request
340                        .excerpt_range
341                        .end
342                        .saturating_sub(0)
343                        .max(self.request.excerpt_range.start),
344                    EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
345                ),
346            ],
347            PromptFormat::LabeledSections => vec![(
348                self.request.excerpt_range.start + self.request.cursor_offset,
349                CURSOR_MARKER,
350            )],
351            PromptFormat::OnlySnippets => vec![],
352        };
353
354        let mut prompt = String::new();
355        prompt.push_str("## User Edits\n\n");
356        Self::push_events(&mut prompt, &self.request.events);
357
358        prompt.push_str("\n## Code\n\n");
359        let section_labels =
360            self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?;
361        Ok((prompt, section_labels))
362    }
363
364    fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
365        for event in events {
366            match event {
367                Event::BufferChange {
368                    path,
369                    old_path,
370                    diff,
371                    predicted,
372                } => {
373                    if let Some(old_path) = &old_path
374                        && let Some(new_path) = &path
375                    {
376                        if old_path != new_path {
377                            writeln!(
378                                output,
379                                "User renamed {} to {}\n\n",
380                                old_path.display(),
381                                new_path.display()
382                            )
383                            .unwrap();
384                        }
385                    }
386
387                    let path = path
388                        .as_ref()
389                        .map_or_else(|| "untitled".to_string(), |path| path.display().to_string());
390
391                    if *predicted {
392                        writeln!(
393                            output,
394                            "User accepted prediction {:?}:\n```diff\n{}\n```\n",
395                            path, diff
396                        )
397                        .unwrap();
398                    } else {
399                        writeln!(output, "User edited {:?}:\n```diff\n{}\n```\n", path, diff)
400                            .unwrap();
401                    }
402                }
403            }
404        }
405    }
406
407    fn push_file_snippets(
408        &self,
409        output: &mut String,
410        excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
411        file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
412    ) -> Result<SectionLabels> {
413        let mut section_ranges = Vec::new();
414        let mut excerpt_index = None;
415
416        for (file_path, mut snippets, is_excerpt_file) in file_snippets {
417            snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
418
419            // TODO: What if the snippets get expanded too large to be editable?
420            let mut current_snippet: Option<(&PlannedSnippet, Range<usize>)> = None;
421            let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<usize>)> = Vec::new();
422            for snippet in snippets {
423                if let Some((_, current_snippet_range)) = current_snippet.as_mut()
424                    && snippet.range.start < current_snippet_range.end
425                {
426                    if snippet.range.end > current_snippet_range.end {
427                        current_snippet_range.end = snippet.range.end;
428                    }
429                    continue;
430                }
431                if let Some(current_snippet) = current_snippet.take() {
432                    disjoint_snippets.push(current_snippet);
433                }
434                current_snippet = Some((snippet, snippet.range.clone()));
435            }
436            if let Some(current_snippet) = current_snippet.take() {
437                disjoint_snippets.push(current_snippet);
438            }
439
440            writeln!(output, "```{}", file_path.display()).ok();
441            let mut skipped_last_snippet = false;
442            for (snippet, range) in disjoint_snippets {
443                let section_index = section_ranges.len();
444
445                match self.request.prompt_format {
446                    PromptFormat::MarkedExcerpt | PromptFormat::OnlySnippets => {
447                        if range.start > 0 && !skipped_last_snippet {
448                            output.push_str("\n");
449                        }
450                    }
451                    PromptFormat::LabeledSections => {
452                        if is_excerpt_file
453                            && range.start <= self.request.excerpt_range.start
454                            && range.end >= self.request.excerpt_range.end
455                        {
456                            writeln!(output, "<|current_section|>").ok();
457                        } else {
458                            writeln!(output, "<|section_{}|>", section_index).ok();
459                        }
460                    }
461                }
462
463                if is_excerpt_file {
464                    if self.request.prompt_format == PromptFormat::OnlySnippets {
465                        if range.start >= self.request.excerpt_range.start
466                            && range.end <= self.request.excerpt_range.end
467                        {
468                            skipped_last_snippet = true;
469                        } else {
470                            skipped_last_snippet = false;
471                            output.push_str(snippet.text);
472                        }
473                    } else {
474                        let mut last_offset = range.start;
475                        let mut i = 0;
476                        while i < excerpt_file_insertions.len() {
477                            let (offset, insertion) = &excerpt_file_insertions[i];
478                            let found = *offset >= range.start && *offset <= range.end;
479                            if found {
480                                excerpt_index = Some(section_index);
481                                output.push_str(
482                                    &snippet.text[last_offset - range.start..offset - range.start],
483                                );
484                                output.push_str(insertion);
485                                last_offset = *offset;
486                                excerpt_file_insertions.remove(i);
487                                continue;
488                            }
489                            i += 1;
490                        }
491                        skipped_last_snippet = false;
492                        output.push_str(&snippet.text[last_offset - range.start..]);
493                    }
494                } else {
495                    skipped_last_snippet = false;
496                    output.push_str(snippet.text);
497                }
498
499                section_ranges.push((snippet.path.clone(), range));
500            }
501
502            output.push_str("```\n\n");
503        }
504
505        Ok(SectionLabels {
506            // TODO: Clean this up
507            excerpt_index: match self.request.prompt_format {
508                PromptFormat::OnlySnippets => 0,
509                _ => excerpt_index.context("bug: no snippet found for excerpt")?,
510            },
511            section_ranges,
512        })
513    }
514}
515
516fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
517    declaration_score(declaration, style) / declaration_size(declaration, style) as f32
518}
519
520fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
521    match style {
522        DeclarationStyle::Signature => declaration.signature_score,
523        DeclarationStyle::Declaration => declaration.declaration_score,
524    }
525}
526
527fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
528    match style {
529        DeclarationStyle::Signature => declaration.signature_range.len(),
530        DeclarationStyle::Declaration => declaration.text.len(),
531    }
532}