cloud_zeta2_prompt.rs

  1//! Zeta2 prompt planning and generation code shared with cloud.
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use cloud_llm_client::predict_edits_v3::{
  5    self, Event, Line, Point, PromptFormat, ReferencedDeclaration,
  6};
  7use indoc::indoc;
  8use ordered_float::OrderedFloat;
  9use rustc_hash::{FxHashMap, FxHashSet};
 10use serde::Serialize;
 11use std::fmt::Write;
 12use std::sync::Arc;
 13use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
 14use strum::{EnumIter, IntoEnumIterator};
 15
/// Prompt byte budget used by `PlannedPrompt::populate` when the request does not set
/// `prompt_max_bytes`.
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;

/// Marker inserted into the rendered excerpt at the request's cursor point.
pub const CURSOR_MARKER: &str = "<|cursor_position|>";
/// Marker inserted at column 0 of the excerpt's first line for `PromptFormat::MarkedExcerpt`.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
/// Marker inserted at column 0 of the excerpt range's end line for
/// `PromptFormat::MarkedExcerpt`.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
 23
// TODO: use constants for markers?
/// System prompt returned by `system_prompt` for `PromptFormat::MarkedExcerpt`.
///
/// The marker spellings in this text are written out literally; they must stay in sync with
/// `CURSOR_MARKER` and the editable-region marker constants.
const MARKED_EXCERPT_SYSTEM_PROMPT: &str = indoc! {"
    You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

    The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|cursor_position|>.  Please respond with edited code for that region.

    Other code is provided for context, and `…` indicates when code has been skipped.
"};
 32
/// System prompt returned by `system_prompt` for `PromptFormat::LabeledSections`.
///
/// The section label spellings (`<|section_N|>`, `<|current_section|>`) match the labels that
/// `PlannedPrompt::push_file_snippets` writes for this format.
const LABELED_SECTIONS_SYSTEM_PROMPT: &str = indoc! {r#"
    You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.

    Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).

    The cursor position is marked with `<|cursor_position|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.

    Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:

    <|current_section|>
    for i in 0..16 {
        println!("{i}");
    }
"#};
 47
/// System prompt returned by `system_prompt` for `PromptFormat::NumberedLines`.
///
/// Asks the model to answer with a unified diff; the matching prompt body prefixes each code line
/// with its 1-based line number followed by `|`.
const NUMBERED_LINES_SYSTEM_PROMPT: &str = indoc! {r#"
    # Instructions

    You are a code completion assistant helping a programmer finish their work. Your task is to:

    1. Analyze the edit history to understand what the programmer is trying to achieve
    2. Identify any incomplete refactoring or changes that need to be finished
    3. Make the remaining edits that a human programmer would logically make next
    4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.

    Focus on:
    - Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
    - Completing any partially-applied changes across the codebase
    - Ensuring consistency with the programming style and patterns already established
    - Making edits that maintain or improve code quality
    - If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
    - Don't write a lot of code if you're not sure what to do

    Rules:
    - Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
    - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
    - Write the edits in the unified diff format as shown in the example.

    # Example output:

    ```
    --- a/distill-claude/tmp-outs/edits_history.txt
    +++ b/distill-claude/tmp-outs/edits_history.txt
    @@ -1,3 +1,3 @@
    -
    -
    -import sys
    +import json
    ```
"#};
 83
/// The set of snippets chosen for one request's prompt, together with the byte budget they use.
///
/// Built by `PlannedPrompt::populate` and rendered by `PlannedPrompt::to_prompt_string`.
pub struct PlannedPrompt<'a> {
    /// The request this plan was built from; snippet text is borrowed from it.
    request: &'a predict_edits_v3::PredictEditsRequest,
    /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
    /// `to_prompt_string`.
    snippets: Vec<PlannedSnippet<'a>>,
    /// Bytes counted against the prompt budget so far. Starts at the excerpt's length and grows
    /// as parent signatures and declarations are added.
    budget_used: usize,
}
 91
 92pub fn system_prompt(format: PromptFormat) -> &'static str {
 93    match format {
 94        PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT,
 95        PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT,
 96        PromptFormat::NumberedLines => NUMBERED_LINES_SYSTEM_PROMPT,
 97        // only intended for use via zeta_cli
 98        PromptFormat::OnlySnippets => "",
 99    }
100}
101
/// A piece of source text selected for inclusion in the prompt.
#[derive(Clone, Debug)]
pub struct PlannedSnippet<'a> {
    /// Path of the file the snippet belongs to.
    path: Arc<Path>,
    /// Line range associated with `text`.
    range: Range<Line>,
    /// The snippet's text, borrowed from the request.
    text: &'a str,
    /// Copied from the declaration or signature the snippet was built from; always `false` for
    /// the excerpt itself.
    // TODO: Indicate this in the output
    #[allow(dead_code)]
    text_is_truncated: bool,
}
111
/// How much of a referenced declaration is included in the prompt.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum DeclarationStyle {
    /// Only the part of the declaration's text covered by its `signature_range`.
    Signature,
    /// The declaration's full text.
    Declaration,
}
117
/// Describes the sections written by `PlannedPrompt::to_prompt_string`.
#[derive(Clone, Debug, Serialize)]
pub struct SectionLabels {
    /// Index into `section_ranges` of the section that received the excerpt's insertion markers.
    /// Always `0` for `PromptFormat::OnlySnippets`.
    pub excerpt_index: usize,
    /// File path and (merged) line range of every rendered section, in output order.
    pub section_ranges: Vec<(Arc<Path>, Range<Line>)>,
}
123
124impl<'a> PlannedPrompt<'a> {
125    /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
126    ///
127    /// Initializes a priority queue by populating it with each snippet, finding the
128    /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
129    /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
130    /// the cost of upgrade.
131    ///
132    /// TODO: Implement an early halting condition. One option might be to have another priority
133    /// queue where the score is the size, and update it accordingly. Another option might be to
134    /// have some simpler heuristic like bailing after N failed insertions, or based on how much
135    /// budget is left.
136    ///
137    /// TODO: Has the current known sources of imprecision:
138    ///
139    /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
140    /// plan even though the containing struct is already included.
141    ///
142    /// * Does not consider cost of signatures when ranking snippets - this is tricky since
143    /// signatures may be shared by multiple snippets.
144    ///
145    /// * Does not include file paths / other text when considering max_bytes.
146    pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
147        let mut this = PlannedPrompt {
148            request,
149            snippets: Vec::new(),
150            budget_used: request.excerpt.len(),
151        };
152        let mut included_parents = FxHashSet::default();
153        let additional_parents = this.additional_parent_signatures(
154            &request.excerpt_path,
155            request.excerpt_parent,
156            &included_parents,
157        )?;
158        this.add_parents(&mut included_parents, additional_parents);
159
160        let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
161
162        if this.budget_used > max_bytes {
163            return Err(anyhow!(
164                "Excerpt + signatures size of {} already exceeds budget of {}",
165                this.budget_used,
166                max_bytes
167            ));
168        }
169
170        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
171        struct QueueEntry {
172            score_density: OrderedFloat<f32>,
173            declaration_index: usize,
174            style: DeclarationStyle,
175        }
176
177        // Initialize priority queue with the best score for each snippet.
178        let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
179        for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
180            let (style, score_density) = DeclarationStyle::iter()
181                .map(|style| {
182                    (
183                        style,
184                        OrderedFloat(declaration_score_density(&declaration, style)),
185                    )
186                })
187                .max_by_key(|(_, score_density)| *score_density)
188                .unwrap();
189            queue.push(QueueEntry {
190                score_density,
191                declaration_index,
192                style,
193            });
194        }
195
196        // Knapsack selection loop
197        while let Some(queue_entry) = queue.pop() {
198            let Some(declaration) = request
199                .referenced_declarations
200                .get(queue_entry.declaration_index)
201            else {
202                return Err(anyhow!(
203                    "Invalid declaration index {}",
204                    queue_entry.declaration_index
205                ));
206            };
207
208            let mut additional_bytes = declaration_size(declaration, queue_entry.style);
209            if this.budget_used + additional_bytes > max_bytes {
210                continue;
211            }
212
213            let additional_parents = this.additional_parent_signatures(
214                &declaration.path,
215                declaration.parent_index,
216                &mut included_parents,
217            )?;
218            additional_bytes += additional_parents
219                .iter()
220                .map(|(_, snippet)| snippet.text.len())
221                .sum::<usize>();
222            if this.budget_used + additional_bytes > max_bytes {
223                continue;
224            }
225
226            this.budget_used += additional_bytes;
227            this.add_parents(&mut included_parents, additional_parents);
228            let planned_snippet = match queue_entry.style {
229                DeclarationStyle::Signature => {
230                    let Some(text) = declaration.text.get(declaration.signature_range.clone())
231                    else {
232                        return Err(anyhow!(
233                            "Invalid declaration signature_range {:?} with text.len() = {}",
234                            declaration.signature_range,
235                            declaration.text.len()
236                        ));
237                    };
238                    let signature_start_line = declaration.range.start
239                        + Line(
240                            declaration.text[..declaration.signature_range.start]
241                                .lines()
242                                .count() as u32,
243                        );
244                    let signature_end_line = signature_start_line
245                        + Line(
246                            declaration.text
247                                [declaration.signature_range.start..declaration.signature_range.end]
248                                .lines()
249                                .count() as u32,
250                        );
251                    let range = signature_start_line..signature_end_line;
252
253                    PlannedSnippet {
254                        path: declaration.path.clone(),
255                        range,
256                        text,
257                        text_is_truncated: declaration.text_is_truncated,
258                    }
259                }
260                DeclarationStyle::Declaration => PlannedSnippet {
261                    path: declaration.path.clone(),
262                    range: declaration.range.clone(),
263                    text: &declaration.text,
264                    text_is_truncated: declaration.text_is_truncated,
265                },
266            };
267            this.snippets.push(planned_snippet);
268
269            // When a Signature is consumed, insert an entry for Definition style.
270            if queue_entry.style == DeclarationStyle::Signature {
271                let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
272                let declaration_size =
273                    declaration_size(&declaration, DeclarationStyle::Declaration);
274                let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
275                let declaration_score =
276                    declaration_score(&declaration, DeclarationStyle::Declaration);
277
278                let score_diff = declaration_score - signature_score;
279                let size_diff = declaration_size.saturating_sub(signature_size);
280                if score_diff > 0.0001 && size_diff > 0 {
281                    queue.push(QueueEntry {
282                        declaration_index: queue_entry.declaration_index,
283                        score_density: OrderedFloat(score_diff / (size_diff as f32)),
284                        style: DeclarationStyle::Declaration,
285                    });
286                }
287            }
288        }
289
290        anyhow::Ok(this)
291    }
292
293    fn add_parents(
294        &mut self,
295        included_parents: &mut FxHashSet<usize>,
296        snippets: Vec<(usize, PlannedSnippet<'a>)>,
297    ) {
298        for (parent_index, snippet) in snippets {
299            included_parents.insert(parent_index);
300            self.budget_used += snippet.text.len();
301            self.snippets.push(snippet);
302        }
303    }
304
305    fn additional_parent_signatures(
306        &self,
307        path: &Arc<Path>,
308        parent_index: Option<usize>,
309        included_parents: &FxHashSet<usize>,
310    ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
311        let mut results = Vec::new();
312        self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
313        Ok(results)
314    }
315
316    fn additional_parent_signatures_impl(
317        &self,
318        path: &Arc<Path>,
319        parent_index: Option<usize>,
320        included_parents: &FxHashSet<usize>,
321        results: &mut Vec<(usize, PlannedSnippet<'a>)>,
322    ) -> Result<()> {
323        let Some(parent_index) = parent_index else {
324            return Ok(());
325        };
326        if included_parents.contains(&parent_index) {
327            return Ok(());
328        }
329        let Some(parent_signature) = self.request.signatures.get(parent_index) else {
330            return Err(anyhow!("Invalid parent index {}", parent_index));
331        };
332        results.push((
333            parent_index,
334            PlannedSnippet {
335                path: path.clone(),
336                range: parent_signature.range.clone(),
337                text: &parent_signature.text,
338                text_is_truncated: parent_signature.text_is_truncated,
339            },
340        ));
341        self.additional_parent_signatures_impl(
342            path,
343            parent_signature.parent_index,
344            included_parents,
345            results,
346        )
347    }
348
349    /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
350    /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
351    /// chunks.
352    pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> {
353        let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
354            FxHashMap::default();
355        for snippet in &self.snippets {
356            file_to_snippets
357                .entry(&snippet.path)
358                .or_default()
359                .push(snippet);
360        }
361
362        // Reorder so that file with cursor comes last
363        let mut file_snippets = Vec::new();
364        let mut excerpt_file_snippets = Vec::new();
365        for (file_path, snippets) in file_to_snippets {
366            if file_path == self.request.excerpt_path.as_ref() {
367                excerpt_file_snippets = snippets;
368            } else {
369                file_snippets.push((file_path, snippets, false));
370            }
371        }
372        let excerpt_snippet = PlannedSnippet {
373            path: self.request.excerpt_path.clone(),
374            range: self.request.excerpt_line_range.clone(),
375            text: &self.request.excerpt,
376            text_is_truncated: false,
377        };
378        excerpt_file_snippets.push(&excerpt_snippet);
379        file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
380
381        let mut excerpt_file_insertions = match self.request.prompt_format {
382            PromptFormat::MarkedExcerpt => vec![
383                (
384                    Point {
385                        line: self.request.excerpt_line_range.start,
386                        column: 0,
387                    },
388                    EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
389                ),
390                (self.request.cursor_point, CURSOR_MARKER),
391                (
392                    Point {
393                        line: self.request.excerpt_line_range.end,
394                        column: 0,
395                    },
396                    EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
397                ),
398            ],
399            PromptFormat::LabeledSections => vec![(self.request.cursor_point, CURSOR_MARKER)],
400            PromptFormat::NumberedLines => vec![(self.request.cursor_point, CURSOR_MARKER)],
401            PromptFormat::OnlySnippets => vec![],
402        };
403
404        let mut prompt = String::new();
405        prompt.push_str("## User Edits\n\n");
406        if self.request.events.is_empty() {
407            prompt.push_str("No edits yet.\n");
408        } else {
409            Self::push_events(&mut prompt, &self.request.events);
410        }
411
412        prompt.push_str("\n## Code\n\n");
413        let section_labels =
414            self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?;
415        Ok((prompt, section_labels))
416    }
417
418    fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
419        for event in events {
420            match event {
421                Event::BufferChange {
422                    path,
423                    old_path,
424                    diff,
425                    predicted,
426                } => {
427                    if let Some(old_path) = &old_path
428                        && let Some(new_path) = &path
429                    {
430                        if old_path != new_path {
431                            writeln!(
432                                output,
433                                "User renamed {} to {}\n\n",
434                                old_path.display(),
435                                new_path.display()
436                            )
437                            .unwrap();
438                        }
439                    }
440
441                    let path = path
442                        .as_ref()
443                        .map_or_else(|| "untitled".to_string(), |path| path.display().to_string());
444
445                    if *predicted {
446                        writeln!(
447                            output,
448                            "User accepted prediction {:?}:\n`````diff\n{}\n`````\n",
449                            path, diff
450                        )
451                        .unwrap();
452                    } else {
453                        writeln!(
454                            output,
455                            "User edited {:?}:\n`````diff\n{}\n`````\n",
456                            path, diff
457                        )
458                        .unwrap();
459                    }
460                }
461            }
462        }
463    }
464
465    fn push_file_snippets(
466        &self,
467        output: &mut String,
468        excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
469        file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
470    ) -> Result<SectionLabels> {
471        let mut section_ranges = Vec::new();
472        let mut excerpt_index = None;
473
474        for (file_path, mut snippets, is_excerpt_file) in file_snippets {
475            snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
476
477            // TODO: What if the snippets get expanded too large to be editable?
478            let mut current_snippet: Option<(&PlannedSnippet, Range<Line>)> = None;
479            let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<Line>)> = Vec::new();
480            for snippet in snippets {
481                if let Some((_, current_snippet_range)) = current_snippet.as_mut()
482                    && snippet.range.start <= current_snippet_range.end
483                {
484                    current_snippet_range.end = current_snippet_range.end.max(snippet.range.end);
485                    continue;
486                }
487                if let Some(current_snippet) = current_snippet.take() {
488                    disjoint_snippets.push(current_snippet);
489                }
490                current_snippet = Some((snippet, snippet.range.clone()));
491            }
492            if let Some(current_snippet) = current_snippet.take() {
493                disjoint_snippets.push(current_snippet);
494            }
495
496            // TODO: remove filename=?
497            writeln!(output, "`````filename={}", file_path.display()).ok();
498            let mut skipped_last_snippet = false;
499            for (snippet, range) in disjoint_snippets {
500                let section_index = section_ranges.len();
501
502                match self.request.prompt_format {
503                    PromptFormat::MarkedExcerpt
504                    | PromptFormat::OnlySnippets
505                    | PromptFormat::NumberedLines => {
506                        if range.start.0 > 0 && !skipped_last_snippet {
507                            output.push_str("\n");
508                        }
509                    }
510                    PromptFormat::LabeledSections => {
511                        if is_excerpt_file
512                            && range.start <= self.request.excerpt_line_range.start
513                            && range.end >= self.request.excerpt_line_range.end
514                        {
515                            writeln!(output, "<|current_section|>").ok();
516                        } else {
517                            writeln!(output, "<|section_{}|>", section_index).ok();
518                        }
519                    }
520                }
521
522                let push_full_snippet = |output: &mut String| {
523                    if self.request.prompt_format == PromptFormat::NumberedLines {
524                        for (i, line) in snippet.text.lines().enumerate() {
525                            writeln!(output, "{}|{}", i as u32 + range.start.0 + 1, line)?;
526                        }
527                    } else {
528                        output.push_str(&snippet.text);
529                    }
530                    anyhow::Ok(())
531                };
532
533                if is_excerpt_file {
534                    if self.request.prompt_format == PromptFormat::OnlySnippets {
535                        if range.start >= self.request.excerpt_line_range.start
536                            && range.end <= self.request.excerpt_line_range.end
537                        {
538                            skipped_last_snippet = true;
539                        } else {
540                            skipped_last_snippet = false;
541                            output.push_str(snippet.text);
542                        }
543                    } else if !excerpt_file_insertions.is_empty() {
544                        let lines = snippet.text.lines().collect::<Vec<_>>();
545                        let push_line = |output: &mut String, line_ix: usize| {
546                            if self.request.prompt_format == PromptFormat::NumberedLines {
547                                write!(output, "{}|", line_ix as u32 + range.start.0 + 1)?;
548                            }
549                            anyhow::Ok(writeln!(output, "{}", lines[line_ix])?)
550                        };
551                        let mut last_line_ix = 0;
552                        let mut insertion_ix = 0;
553                        while insertion_ix < excerpt_file_insertions.len() {
554                            let (point, insertion) = &excerpt_file_insertions[insertion_ix];
555                            let found = point.line >= range.start && point.line <= range.end;
556                            if found {
557                                excerpt_index = Some(section_index);
558                                let insertion_line_ix = (point.line.0 - range.start.0) as usize;
559                                for line_ix in last_line_ix..insertion_line_ix {
560                                    push_line(output, line_ix)?;
561                                }
562                                if let Some(next_line) = lines.get(insertion_line_ix) {
563                                    if self.request.prompt_format == PromptFormat::NumberedLines {
564                                        write!(
565                                            output,
566                                            "{}|",
567                                            insertion_line_ix as u32 + range.start.0 + 1
568                                        )?
569                                    }
570                                    output.push_str(&next_line[..point.column as usize]);
571                                    output.push_str(insertion);
572                                    writeln!(output, "{}", &next_line[point.column as usize..])?;
573                                } else {
574                                    writeln!(output, "{}", insertion)?;
575                                }
576                                last_line_ix = insertion_line_ix + 1;
577                                excerpt_file_insertions.remove(insertion_ix);
578                                continue;
579                            }
580                            insertion_ix += 1;
581                        }
582                        skipped_last_snippet = false;
583                        for line_ix in last_line_ix..lines.len() {
584                            push_line(output, line_ix)?;
585                        }
586                    } else {
587                        skipped_last_snippet = false;
588                        push_full_snippet(output)?;
589                    }
590                } else {
591                    skipped_last_snippet = false;
592                    push_full_snippet(output)?;
593                }
594
595                section_ranges.push((snippet.path.clone(), range));
596            }
597
598            output.push_str("`````\n\n");
599        }
600
601        Ok(SectionLabels {
602            // TODO: Clean this up
603            excerpt_index: match self.request.prompt_format {
604                PromptFormat::OnlySnippets => 0,
605                _ => excerpt_index.context("bug: no snippet found for excerpt")?,
606            },
607            section_ranges,
608        })
609    }
610}
611
612fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
613    declaration_score(declaration, style) / declaration_size(declaration, style) as f32
614}
615
616fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
617    match style {
618        DeclarationStyle::Signature => declaration.signature_score,
619        DeclarationStyle::Declaration => declaration.declaration_score,
620    }
621}
622
623fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
624    match style {
625        DeclarationStyle::Signature => declaration.signature_range.len(),
626        DeclarationStyle::Declaration => declaration.text.len(),
627    }
628}