cloud_zeta2_prompt.rs

  1//! Zeta2 prompt planning and generation code shared with cloud.
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use cloud_llm_client::predict_edits_v3::{self, Line, Point, PromptFormat, ReferencedDeclaration};
  5use indoc::indoc;
  6use ordered_float::OrderedFloat;
  7use rustc_hash::{FxHashMap, FxHashSet};
  8use serde::Serialize;
  9use std::fmt::Write;
 10use std::sync::Arc;
 11use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
 12use strum::{EnumIter, IntoEnumIterator};
 13
/// Byte budget used by `PlannedPrompt::populate` when the request does not specify
/// `prompt_max_bytes`.
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;

/// Marker inserted into the excerpt text at the request's `cursor_point`.
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
/// Marker inserted at column 0 of the first excerpt line for `PromptFormat::MarkedExcerpt`.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
/// Marker inserted at column 0 of the excerpt's end line for `PromptFormat::MarkedExcerpt`.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
 21
// TODO: use constants for markers?
/// Instruction preamble that starts the prompt for `PromptFormat::MarkedExcerpt`.
///
/// The text ends with the "Edit History" heading, so the edit history (or a placeholder) is
/// expected to be appended directly after it.
const MARKED_EXCERPT_INSTRUCTIONS: &str = indoc! {"
    You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

    The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|user_cursor|>.  Please respond with edited code for that region.

    Other code is provided for context, and `…` indicates when code has been skipped.

    # Edit History:

"};
 33
/// Instruction preamble that starts the prompt for `PromptFormat::LabeledSections`.
///
/// Describes the `<|section_N|>` / `<|current_section|>` labels that are written before each
/// section of code, and ends with the "Edit History" heading so the edit history can follow.
const LABELED_SECTIONS_INSTRUCTIONS: &str = indoc! {r#"
    You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.

    Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).

    The cursor position is marked with `<|user_cursor|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.

    Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:

    <|current_section|>
    for i in 0..16 {
        println!("{i}");
    }

    # Edit History:

"#};
 51
/// Instruction preamble that starts the prompt for `PromptFormat::NumLinesUniDiff`.
///
/// Asks the model to answer with a unified diff and ends with the "Edit History" heading so the
/// edit history can follow. In this format the code lines are emitted with a `N|` line-number
/// prefix.
const NUMBERED_LINES_INSTRUCTIONS: &str = indoc! {r#"
    # Instructions

    You are a code completion assistant helping a programmer finish their work. Your task is to:

    1. Analyze the edit history to understand what the programmer is trying to achieve
    2. Identify any incomplete refactoring or changes that need to be finished
    3. Make the remaining edits that a human programmer would logically make next
    4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.

    Focus on:
    - Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
    - Completing any partially-applied changes across the codebase
    - Ensuring consistency with the programming style and patterns already established
    - Making edits that maintain or improve code quality
    - If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
    - Don't write a lot of code if you're not sure what to do

    Rules:
    - Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
    - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
    - Write the edits in the unified diff format as shown in the example.

    # Example output:

    ```
    --- a/src/myapp/cli.py
    +++ b/src/myapp/cli.py
    @@ -1,3 +1,3 @@
    -
    -
    -import sys
    +import json
    ```

    # Edit History:

"#};
 90
/// Closing reminder appended after the code sections when the prompt format is
/// `PromptFormat::NumLinesUniDiff`.
const UNIFIED_DIFF_REMINDER: &str = indoc! {"
    ---

    Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
    Do not include the cursor marker in your output.
    If you're editing multiple files, be sure to reflect filename in the hunk's header.
"};
 98
/// The set of context snippets selected for one prediction request, plus the bookkeeping needed
/// to stay within the prompt's byte budget.
///
/// Built by `PlannedPrompt::populate` and rendered by `PlannedPrompt::to_prompt_string`.
pub struct PlannedPrompt<'a> {
    /// The request being planned for; all snippet text is borrowed from it.
    request: &'a predict_edits_v3::PredictEditsRequest,
    /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
    /// `to_prompt_string`.
    snippets: Vec<PlannedSnippet<'a>>,
    /// Bytes charged against the budget so far. Starts at the excerpt's length and grows as
    /// snippets are added.
    budget_used: usize,
}
106
/// A contiguous piece of text from one file that has been selected for inclusion in the prompt.
#[derive(Clone, Debug)]
pub struct PlannedSnippet<'a> {
    /// File the snippet comes from.
    path: Arc<Path>,
    /// Line range the snippet covers within that file.
    range: Range<Line>,
    /// The snippet's text, borrowed from the request.
    text: &'a str,
    /// Copied from the declaration / signature the snippet was built from.
    // TODO: Indicate this in the output
    #[allow(dead_code)]
    text_is_truncated: bool,
}
116
/// How much of a referenced declaration to include in the prompt.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum DeclarationStyle {
    /// Only the `signature_range` slice of the declaration's text.
    Signature,
    /// The declaration's full text.
    Declaration,
}
122
/// Describes the sections that were written into a prompt, in output order.
#[derive(Clone, Debug, Serialize)]
pub struct SectionLabels {
    /// Index into `section_ranges` of the section in which a marker insertion (e.g. the cursor)
    /// was placed. Always `0` for `PromptFormat::OnlySnippets`.
    pub excerpt_index: usize,
    /// File path and (merged) line range of every section that was emitted.
    pub section_ranges: Vec<(Arc<Path>, Range<Line>)>,
}
128
129impl<'a> PlannedPrompt<'a> {
130    /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
131    ///
132    /// Initializes a priority queue by populating it with each snippet, finding the
133    /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
134    /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
135    /// the cost of upgrade.
136    ///
137    /// TODO: Implement an early halting condition. One option might be to have another priority
138    /// queue where the score is the size, and update it accordingly. Another option might be to
139    /// have some simpler heuristic like bailing after N failed insertions, or based on how much
140    /// budget is left.
141    ///
142    /// TODO: Has the current known sources of imprecision:
143    ///
144    /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
145    /// plan even though the containing struct is already included.
146    ///
147    /// * Does not consider cost of signatures when ranking snippets - this is tricky since
148    /// signatures may be shared by multiple snippets.
149    ///
150    /// * Does not include file paths / other text when considering max_bytes.
151    pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
152        let mut this = PlannedPrompt {
153            request,
154            snippets: Vec::new(),
155            budget_used: request.excerpt.len(),
156        };
157        let mut included_parents = FxHashSet::default();
158        let additional_parents = this.additional_parent_signatures(
159            &request.excerpt_path,
160            request.excerpt_parent,
161            &included_parents,
162        )?;
163        this.add_parents(&mut included_parents, additional_parents);
164
165        let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
166
167        if this.budget_used > max_bytes {
168            return Err(anyhow!(
169                "Excerpt + signatures size of {} already exceeds budget of {}",
170                this.budget_used,
171                max_bytes
172            ));
173        }
174
175        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
176        struct QueueEntry {
177            score_density: OrderedFloat<f32>,
178            declaration_index: usize,
179            style: DeclarationStyle,
180        }
181
182        // Initialize priority queue with the best score for each snippet.
183        let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
184        for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
185            let (style, score_density) = DeclarationStyle::iter()
186                .map(|style| {
187                    (
188                        style,
189                        OrderedFloat(declaration_score_density(&declaration, style)),
190                    )
191                })
192                .max_by_key(|(_, score_density)| *score_density)
193                .unwrap();
194            queue.push(QueueEntry {
195                score_density,
196                declaration_index,
197                style,
198            });
199        }
200
201        // Knapsack selection loop
202        while let Some(queue_entry) = queue.pop() {
203            let Some(declaration) = request
204                .referenced_declarations
205                .get(queue_entry.declaration_index)
206            else {
207                return Err(anyhow!(
208                    "Invalid declaration index {}",
209                    queue_entry.declaration_index
210                ));
211            };
212
213            let mut additional_bytes = declaration_size(declaration, queue_entry.style);
214            if this.budget_used + additional_bytes > max_bytes {
215                continue;
216            }
217
218            let additional_parents = this.additional_parent_signatures(
219                &declaration.path,
220                declaration.parent_index,
221                &mut included_parents,
222            )?;
223            additional_bytes += additional_parents
224                .iter()
225                .map(|(_, snippet)| snippet.text.len())
226                .sum::<usize>();
227            if this.budget_used + additional_bytes > max_bytes {
228                continue;
229            }
230
231            this.budget_used += additional_bytes;
232            this.add_parents(&mut included_parents, additional_parents);
233            let planned_snippet = match queue_entry.style {
234                DeclarationStyle::Signature => {
235                    let Some(text) = declaration.text.get(declaration.signature_range.clone())
236                    else {
237                        return Err(anyhow!(
238                            "Invalid declaration signature_range {:?} with text.len() = {}",
239                            declaration.signature_range,
240                            declaration.text.len()
241                        ));
242                    };
243                    let signature_start_line = declaration.range.start
244                        + Line(
245                            declaration.text[..declaration.signature_range.start]
246                                .lines()
247                                .count() as u32,
248                        );
249                    let signature_end_line = signature_start_line
250                        + Line(
251                            declaration.text
252                                [declaration.signature_range.start..declaration.signature_range.end]
253                                .lines()
254                                .count() as u32,
255                        );
256                    let range = signature_start_line..signature_end_line;
257
258                    PlannedSnippet {
259                        path: declaration.path.clone(),
260                        range,
261                        text,
262                        text_is_truncated: declaration.text_is_truncated,
263                    }
264                }
265                DeclarationStyle::Declaration => PlannedSnippet {
266                    path: declaration.path.clone(),
267                    range: declaration.range.clone(),
268                    text: &declaration.text,
269                    text_is_truncated: declaration.text_is_truncated,
270                },
271            };
272            this.snippets.push(planned_snippet);
273
274            // When a Signature is consumed, insert an entry for Definition style.
275            if queue_entry.style == DeclarationStyle::Signature {
276                let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
277                let declaration_size =
278                    declaration_size(&declaration, DeclarationStyle::Declaration);
279                let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
280                let declaration_score =
281                    declaration_score(&declaration, DeclarationStyle::Declaration);
282
283                let score_diff = declaration_score - signature_score;
284                let size_diff = declaration_size.saturating_sub(signature_size);
285                if score_diff > 0.0001 && size_diff > 0 {
286                    queue.push(QueueEntry {
287                        declaration_index: queue_entry.declaration_index,
288                        score_density: OrderedFloat(score_diff / (size_diff as f32)),
289                        style: DeclarationStyle::Declaration,
290                    });
291                }
292            }
293        }
294
295        anyhow::Ok(this)
296    }
297
298    fn add_parents(
299        &mut self,
300        included_parents: &mut FxHashSet<usize>,
301        snippets: Vec<(usize, PlannedSnippet<'a>)>,
302    ) {
303        for (parent_index, snippet) in snippets {
304            included_parents.insert(parent_index);
305            self.budget_used += snippet.text.len();
306            self.snippets.push(snippet);
307        }
308    }
309
310    fn additional_parent_signatures(
311        &self,
312        path: &Arc<Path>,
313        parent_index: Option<usize>,
314        included_parents: &FxHashSet<usize>,
315    ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
316        let mut results = Vec::new();
317        self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
318        Ok(results)
319    }
320
321    fn additional_parent_signatures_impl(
322        &self,
323        path: &Arc<Path>,
324        parent_index: Option<usize>,
325        included_parents: &FxHashSet<usize>,
326        results: &mut Vec<(usize, PlannedSnippet<'a>)>,
327    ) -> Result<()> {
328        let Some(parent_index) = parent_index else {
329            return Ok(());
330        };
331        if included_parents.contains(&parent_index) {
332            return Ok(());
333        }
334        let Some(parent_signature) = self.request.signatures.get(parent_index) else {
335            return Err(anyhow!("Invalid parent index {}", parent_index));
336        };
337        results.push((
338            parent_index,
339            PlannedSnippet {
340                path: path.clone(),
341                range: parent_signature.range.clone(),
342                text: &parent_signature.text,
343                text_is_truncated: parent_signature.text_is_truncated,
344            },
345        ));
346        self.additional_parent_signatures_impl(
347            path,
348            parent_signature.parent_index,
349            included_parents,
350            results,
351        )
352    }
353
354    /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
355    /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
356    /// chunks.
357    pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> {
358        let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
359            FxHashMap::default();
360        for snippet in &self.snippets {
361            file_to_snippets
362                .entry(&snippet.path)
363                .or_default()
364                .push(snippet);
365        }
366
367        // Reorder so that file with cursor comes last
368        let mut file_snippets = Vec::new();
369        let mut excerpt_file_snippets = Vec::new();
370        for (file_path, snippets) in file_to_snippets {
371            if file_path == self.request.excerpt_path.as_ref() {
372                excerpt_file_snippets = snippets;
373            } else {
374                file_snippets.push((file_path, snippets, false));
375            }
376        }
377        let excerpt_snippet = PlannedSnippet {
378            path: self.request.excerpt_path.clone(),
379            range: self.request.excerpt_line_range.clone(),
380            text: &self.request.excerpt,
381            text_is_truncated: false,
382        };
383        excerpt_file_snippets.push(&excerpt_snippet);
384        file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
385
386        let mut excerpt_file_insertions = match self.request.prompt_format {
387            PromptFormat::MarkedExcerpt => vec![
388                (
389                    Point {
390                        line: self.request.excerpt_line_range.start,
391                        column: 0,
392                    },
393                    EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
394                ),
395                (self.request.cursor_point, CURSOR_MARKER),
396                (
397                    Point {
398                        line: self.request.excerpt_line_range.end,
399                        column: 0,
400                    },
401                    EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
402                ),
403            ],
404            PromptFormat::LabeledSections => vec![(self.request.cursor_point, CURSOR_MARKER)],
405            PromptFormat::NumLinesUniDiff => {
406                vec![(self.request.cursor_point, CURSOR_MARKER)]
407            }
408            PromptFormat::OnlySnippets => vec![],
409        };
410
411        let mut prompt = match self.request.prompt_format {
412            PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(),
413            PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(),
414            PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(),
415            // only intended for use via zeta_cli
416            PromptFormat::OnlySnippets => String::new(),
417        };
418
419        if self.request.events.is_empty() {
420            prompt.push_str("(No edit history)\n\n");
421        } else {
422            prompt.push_str(
423                "The following are the latest edits made by the user, from earlier to later.\n\n",
424            );
425            Self::push_events(&mut prompt, &self.request.events);
426        }
427
428        if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
429            if self.request.referenced_declarations.is_empty() {
430                prompt.push_str(indoc! {"
431                    # File under the cursor:
432
433                    The cursor marker <|user_cursor|> indicates the current user cursor position.
434                    The file is in current state, edits from edit history have been applied.
435                    We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
436
437                "});
438            } else {
439                // Note: This hasn't been trained on yet
440                prompt.push_str(indoc! {"
441                    # Code Excerpts:
442
443                    The cursor marker <|user_cursor|> indicates the current user cursor position.
444                    Other excerpts of code from the project have been included as context based on their similarity to the code under the cursor.
445                    Context excerpts are not guaranteed to be relevant, so use your own judgement.
446                    Files are in their current state, edits from edit history have been applied.
447                    We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
448
449                "});
450            }
451        } else {
452            prompt.push_str("\n## Code\n\n");
453        }
454
455        let section_labels =
456            self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?;
457
458        if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
459            prompt.push_str(UNIFIED_DIFF_REMINDER);
460        }
461
462        Ok((prompt, section_labels))
463    }
464
465    fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
466        if events.is_empty() {
467            return;
468        };
469
470        writeln!(output, "`````diff").unwrap();
471        for event in events {
472            writeln!(output, "{}", event).unwrap();
473        }
474        writeln!(output, "`````\n").unwrap();
475    }
476
477    fn push_file_snippets(
478        &self,
479        output: &mut String,
480        excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
481        file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
482    ) -> Result<SectionLabels> {
483        let mut section_ranges = Vec::new();
484        let mut excerpt_index = None;
485
486        for (file_path, mut snippets, is_excerpt_file) in file_snippets {
487            snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
488
489            // TODO: What if the snippets get expanded too large to be editable?
490            let mut current_snippet: Option<(&PlannedSnippet, Range<Line>)> = None;
491            let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<Line>)> = Vec::new();
492            for snippet in snippets {
493                if let Some((_, current_snippet_range)) = current_snippet.as_mut()
494                    && snippet.range.start <= current_snippet_range.end
495                {
496                    current_snippet_range.end = current_snippet_range.end.max(snippet.range.end);
497                    continue;
498                }
499                if let Some(current_snippet) = current_snippet.take() {
500                    disjoint_snippets.push(current_snippet);
501                }
502                current_snippet = Some((snippet, snippet.range.clone()));
503            }
504            if let Some(current_snippet) = current_snippet.take() {
505                disjoint_snippets.push(current_snippet);
506            }
507
508            // TODO: remove filename=?
509            writeln!(output, "`````filename={}", file_path.display()).ok();
510            let mut skipped_last_snippet = false;
511            for (snippet, range) in disjoint_snippets {
512                let section_index = section_ranges.len();
513
514                match self.request.prompt_format {
515                    PromptFormat::MarkedExcerpt
516                    | PromptFormat::OnlySnippets
517                    | PromptFormat::NumLinesUniDiff => {
518                        if range.start.0 > 0 && !skipped_last_snippet {
519                            output.push_str("\n");
520                        }
521                    }
522                    PromptFormat::LabeledSections => {
523                        if is_excerpt_file
524                            && range.start <= self.request.excerpt_line_range.start
525                            && range.end >= self.request.excerpt_line_range.end
526                        {
527                            writeln!(output, "<|current_section|>").ok();
528                        } else {
529                            writeln!(output, "<|section_{}|>", section_index).ok();
530                        }
531                    }
532                }
533
534                let push_full_snippet = |output: &mut String| {
535                    if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
536                        for (i, line) in snippet.text.lines().enumerate() {
537                            writeln!(output, "{}|{}", i as u32 + range.start.0 + 1, line)?;
538                        }
539                    } else {
540                        output.push_str(&snippet.text);
541                    }
542                    anyhow::Ok(())
543                };
544
545                if is_excerpt_file {
546                    if self.request.prompt_format == PromptFormat::OnlySnippets {
547                        if range.start >= self.request.excerpt_line_range.start
548                            && range.end <= self.request.excerpt_line_range.end
549                        {
550                            skipped_last_snippet = true;
551                        } else {
552                            skipped_last_snippet = false;
553                            output.push_str(snippet.text);
554                        }
555                    } else if !excerpt_file_insertions.is_empty() {
556                        let lines = snippet.text.lines().collect::<Vec<_>>();
557                        let push_line = |output: &mut String, line_ix: usize| {
558                            if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
559                                write!(output, "{}|", line_ix as u32 + range.start.0 + 1)?;
560                            }
561                            anyhow::Ok(writeln!(output, "{}", lines[line_ix])?)
562                        };
563                        let mut last_line_ix = 0;
564                        let mut insertion_ix = 0;
565                        while insertion_ix < excerpt_file_insertions.len() {
566                            let (point, insertion) = &excerpt_file_insertions[insertion_ix];
567                            let found = point.line >= range.start && point.line <= range.end;
568                            if found {
569                                excerpt_index = Some(section_index);
570                                let insertion_line_ix = (point.line.0 - range.start.0) as usize;
571                                for line_ix in last_line_ix..insertion_line_ix {
572                                    push_line(output, line_ix)?;
573                                }
574                                if let Some(next_line) = lines.get(insertion_line_ix) {
575                                    if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
576                                        write!(
577                                            output,
578                                            "{}|",
579                                            insertion_line_ix as u32 + range.start.0 + 1
580                                        )?
581                                    }
582                                    output.push_str(&next_line[..point.column as usize]);
583                                    output.push_str(insertion);
584                                    writeln!(output, "{}", &next_line[point.column as usize..])?;
585                                } else {
586                                    writeln!(output, "{}", insertion)?;
587                                }
588                                last_line_ix = insertion_line_ix + 1;
589                                excerpt_file_insertions.remove(insertion_ix);
590                                continue;
591                            }
592                            insertion_ix += 1;
593                        }
594                        skipped_last_snippet = false;
595                        for line_ix in last_line_ix..lines.len() {
596                            push_line(output, line_ix)?;
597                        }
598                    } else {
599                        skipped_last_snippet = false;
600                        push_full_snippet(output)?;
601                    }
602                } else {
603                    skipped_last_snippet = false;
604                    push_full_snippet(output)?;
605                }
606
607                section_ranges.push((snippet.path.clone(), range));
608            }
609
610            output.push_str("`````\n\n");
611        }
612
613        Ok(SectionLabels {
614            // TODO: Clean this up
615            excerpt_index: match self.request.prompt_format {
616                PromptFormat::OnlySnippets => 0,
617                _ => excerpt_index.context("bug: no snippet found for excerpt")?,
618            },
619            section_ranges,
620        })
621    }
622}
623
624fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
625    declaration_score(declaration, style) / declaration_size(declaration, style) as f32
626}
627
628fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
629    match style {
630        DeclarationStyle::Signature => declaration.signature_score,
631        DeclarationStyle::Declaration => declaration.declaration_score,
632    }
633}
634
635fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
636    match style {
637        DeclarationStyle::Signature => declaration.signature_range.len(),
638        DeclarationStyle::Declaration => declaration.text.len(),
639    }
640}