cloud_zeta2_prompt.rs

  1//! Zeta2 prompt planning and generation code shared with cloud.
  2
  3use anyhow::{Result, anyhow};
  4use cloud_llm_client::predict_edits_v3::{self, Event, ReferencedDeclaration};
  5use indoc::indoc;
  6use ordered_float::OrderedFloat;
  7use rustc_hash::{FxHashMap, FxHashSet};
  8use std::fmt::Write;
  9use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
 10use strum::{EnumIter, IntoEnumIterator};
 11
 12pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
 13
 14pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
 15/// NOTE: Differs from zed version of constant - includes a newline
 16pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
 17/// NOTE: Differs from zed version of constant - includes a newline
 18pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
 19
 20// TODO: use constants for markers?
 21pub const SYSTEM_PROMPT: &str = indoc! {"
 22    You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
 23
 24    The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|user_cursor_is_here|>. Please respond with edited code for that region.
 25"};
 26
 27pub struct PlannedPrompt<'a> {
 28    request: &'a predict_edits_v3::PredictEditsRequest,
 29    /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
 30    /// `to_prompt_string`.
 31    snippets: Vec<PlannedSnippet<'a>>,
 32    budget_used: usize,
 33}
 34
 35pub struct PlanOptions {
 36    pub max_bytes: usize,
 37}
 38
 39#[derive(Clone, Debug)]
 40pub struct PlannedSnippet<'a> {
 41    path: &'a Path,
 42    range: Range<usize>,
 43    text: &'a str,
 44    // TODO: Indicate this in the output
 45    #[allow(dead_code)]
 46    text_is_truncated: bool,
 47}
 48
 49#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
 50pub enum SnippetStyle {
 51    Signature,
 52    Declaration,
 53}
 54
 55impl<'a> PlannedPrompt<'a> {
 56    /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
 57    ///
 58    /// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle
 59    /// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature"
 60    /// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of
 61    /// upgrade.
 62    ///
 63    /// TODO: Implement an early halting condition. One option might be to have another priority
 64    /// queue where the score is the size, and update it accordingly. Another option might be to
 65    /// have some simpler heuristic like bailing after N failed insertions, or based on how much
 66    /// budget is left.
 67    ///
 68    /// TODO: Has the current known sources of imprecision:
 69    ///
 70    /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
 71    /// plan even though the containing struct is already included.
 72    ///
 73    /// * Does not consider cost of signatures when ranking snippets - this is tricky since
 74    /// signatures may be shared by multiple snippets.
 75    ///
 76    /// * Does not include file paths / other text when considering max_bytes.
 77    pub fn populate(
 78        request: &'a predict_edits_v3::PredictEditsRequest,
 79        options: &PlanOptions,
 80    ) -> Result<Self> {
 81        let mut this = PlannedPrompt {
 82            request,
 83            snippets: Vec::new(),
 84            budget_used: request.excerpt.len(),
 85        };
 86        let mut included_parents = FxHashSet::default();
 87        let additional_parents = this.additional_parent_signatures(
 88            &request.excerpt_path,
 89            request.excerpt_parent,
 90            &included_parents,
 91        )?;
 92        this.add_parents(&mut included_parents, additional_parents);
 93
 94        if this.budget_used > options.max_bytes {
 95            return Err(anyhow!(
 96                "Excerpt + signatures size of {} already exceeds budget of {}",
 97                this.budget_used,
 98                options.max_bytes
 99            ));
100        }
101
102        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
103        struct QueueEntry {
104            score_density: OrderedFloat<f32>,
105            declaration_index: usize,
106            style: SnippetStyle,
107        }
108
109        // Initialize priority queue with the best score for each snippet.
110        let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
111        for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
112            let (style, score_density) = SnippetStyle::iter()
113                .map(|style| {
114                    (
115                        style,
116                        OrderedFloat(declaration_score_density(&declaration, style)),
117                    )
118                })
119                .max_by_key(|(_, score_density)| *score_density)
120                .unwrap();
121            queue.push(QueueEntry {
122                score_density,
123                declaration_index,
124                style,
125            });
126        }
127
128        // Knapsack selection loop
129        while let Some(queue_entry) = queue.pop() {
130            let Some(declaration) = request
131                .referenced_declarations
132                .get(queue_entry.declaration_index)
133            else {
134                return Err(anyhow!(
135                    "Invalid declaration index {}",
136                    queue_entry.declaration_index
137                ));
138            };
139
140            let mut additional_bytes = declaration_size(declaration, queue_entry.style);
141            if this.budget_used + additional_bytes > options.max_bytes {
142                continue;
143            }
144
145            let additional_parents = this.additional_parent_signatures(
146                &declaration.path,
147                declaration.parent_index,
148                &mut included_parents,
149            )?;
150            additional_bytes += additional_parents
151                .iter()
152                .map(|(_, snippet)| snippet.text.len())
153                .sum::<usize>();
154            if this.budget_used + additional_bytes > options.max_bytes {
155                continue;
156            }
157
158            this.budget_used += additional_bytes;
159            this.add_parents(&mut included_parents, additional_parents);
160            let planned_snippet = match queue_entry.style {
161                SnippetStyle::Signature => {
162                    let Some(text) = declaration.text.get(declaration.signature_range.clone())
163                    else {
164                        return Err(anyhow!(
165                            "Invalid declaration signature_range {:?} with text.len() = {}",
166                            declaration.signature_range,
167                            declaration.text.len()
168                        ));
169                    };
170                    PlannedSnippet {
171                        path: &declaration.path,
172                        range: (declaration.signature_range.start + declaration.range.start)
173                            ..(declaration.signature_range.end + declaration.range.start),
174                        text,
175                        text_is_truncated: declaration.text_is_truncated,
176                    }
177                }
178                SnippetStyle::Declaration => PlannedSnippet {
179                    path: &declaration.path,
180                    range: declaration.range.clone(),
181                    text: &declaration.text,
182                    text_is_truncated: declaration.text_is_truncated,
183                },
184            };
185            this.snippets.push(planned_snippet);
186
187            // When a Signature is consumed, insert an entry for Definition style.
188            if queue_entry.style == SnippetStyle::Signature {
189                let signature_size = declaration_size(&declaration, SnippetStyle::Signature);
190                let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration);
191                let signature_score = declaration_score(&declaration, SnippetStyle::Signature);
192                let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration);
193
194                let score_diff = declaration_score - signature_score;
195                let size_diff = declaration_size.saturating_sub(signature_size);
196                if score_diff > 0.0001 && size_diff > 0 {
197                    queue.push(QueueEntry {
198                        declaration_index: queue_entry.declaration_index,
199                        score_density: OrderedFloat(score_diff / (size_diff as f32)),
200                        style: SnippetStyle::Declaration,
201                    });
202                }
203            }
204        }
205
206        anyhow::Ok(this)
207    }
208
209    fn add_parents(
210        &mut self,
211        included_parents: &mut FxHashSet<usize>,
212        snippets: Vec<(usize, PlannedSnippet<'a>)>,
213    ) {
214        for (parent_index, snippet) in snippets {
215            included_parents.insert(parent_index);
216            self.budget_used += snippet.text.len();
217            self.snippets.push(snippet);
218        }
219    }
220
221    fn additional_parent_signatures(
222        &self,
223        path: &'a Path,
224        parent_index: Option<usize>,
225        included_parents: &FxHashSet<usize>,
226    ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
227        let mut results = Vec::new();
228        self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
229        Ok(results)
230    }
231
232    fn additional_parent_signatures_impl(
233        &self,
234        path: &'a Path,
235        parent_index: Option<usize>,
236        included_parents: &FxHashSet<usize>,
237        results: &mut Vec<(usize, PlannedSnippet<'a>)>,
238    ) -> Result<()> {
239        let Some(parent_index) = parent_index else {
240            return Ok(());
241        };
242        if included_parents.contains(&parent_index) {
243            return Ok(());
244        }
245        let Some(parent_signature) = self.request.signatures.get(parent_index) else {
246            return Err(anyhow!("Invalid parent index {}", parent_index));
247        };
248        results.push((
249            parent_index,
250            PlannedSnippet {
251                path,
252                range: parent_signature.range.clone(),
253                text: &parent_signature.text,
254                text_is_truncated: parent_signature.text_is_truncated,
255            },
256        ));
257        self.additional_parent_signatures_impl(
258            path,
259            parent_signature.parent_index,
260            included_parents,
261            results,
262        )
263    }
264
265    /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
266    /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
267    /// chunks.
268    pub fn to_prompt_string(&self) -> String {
269        let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
270            FxHashMap::default();
271        for snippet in &self.snippets {
272            file_to_snippets
273                .entry(&snippet.path)
274                .or_default()
275                .push(snippet);
276        }
277
278        // Reorder so that file with cursor comes last
279        let mut file_snippets = Vec::new();
280        let mut excerpt_file_snippets = Vec::new();
281        for (file_path, snippets) in file_to_snippets {
282            if file_path == &self.request.excerpt_path {
283                excerpt_file_snippets = snippets;
284            } else {
285                file_snippets.push((file_path, snippets, false));
286            }
287        }
288        let excerpt_snippet = PlannedSnippet {
289            path: &self.request.excerpt_path,
290            range: self.request.excerpt_range.clone(),
291            text: &self.request.excerpt,
292            text_is_truncated: false,
293        };
294        excerpt_file_snippets.push(&excerpt_snippet);
295        file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
296
297        let mut excerpt_file_insertions = vec![
298            (
299                self.request.excerpt_range.start,
300                EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
301            ),
302            (
303                self.request.excerpt_range.start + self.request.cursor_offset,
304                CURSOR_MARKER,
305            ),
306            (
307                self.request
308                    .excerpt_range
309                    .end
310                    .saturating_sub(0)
311                    .max(self.request.excerpt_range.start),
312                EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
313            ),
314        ];
315
316        let mut output = String::new();
317        output.push_str("## User Edits\n\n");
318        Self::push_events(&mut output, &self.request.events);
319
320        output.push_str("\n## Code\n\n");
321        Self::push_file_snippets(&mut output, &mut excerpt_file_insertions, file_snippets);
322        output
323    }
324
325    fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
326        for event in events {
327            match event {
328                Event::BufferChange {
329                    path,
330                    old_path,
331                    diff,
332                    predicted,
333                } => {
334                    if let Some(old_path) = &old_path
335                        && let Some(new_path) = &path
336                    {
337                        if old_path != new_path {
338                            writeln!(
339                                output,
340                                "User renamed {} to {}\n\n",
341                                old_path.display(),
342                                new_path.display()
343                            )
344                            .unwrap();
345                        }
346                    }
347
348                    let path = path
349                        .as_ref()
350                        .map_or_else(|| "untitled".to_string(), |path| path.display().to_string());
351
352                    if *predicted {
353                        writeln!(
354                            output,
355                            "User accepted prediction {:?}:\n```diff\n{}\n```\n",
356                            path, diff
357                        )
358                        .unwrap();
359                    } else {
360                        writeln!(output, "User edited {:?}:\n```diff\n{}\n```\n", path, diff)
361                            .unwrap();
362                    }
363                }
364            }
365        }
366    }
367
368    fn push_file_snippets(
369        output: &mut String,
370        excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
371        file_snippets: Vec<(&Path, Vec<&PlannedSnippet>, bool)>,
372    ) {
373        fn push_excerpt_file_range(
374            range: Range<usize>,
375            text: &str,
376            excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
377            output: &mut String,
378        ) {
379            let mut last_offset = range.start;
380            let mut i = 0;
381            while i < excerpt_file_insertions.len() {
382                let (offset, insertion) = &excerpt_file_insertions[i];
383                let found = *offset >= range.start && *offset <= range.end;
384                if found {
385                    output.push_str(&text[last_offset - range.start..offset - range.start]);
386                    output.push_str(insertion);
387                    last_offset = *offset;
388                    excerpt_file_insertions.remove(i);
389                    continue;
390                }
391                i += 1;
392            }
393            output.push_str(&text[last_offset - range.start..]);
394        }
395
396        for (file_path, mut snippets, is_excerpt_file) in file_snippets {
397            output.push_str(&format!("```{}\n", file_path.display()));
398
399            let mut last_included_range: Option<Range<usize>> = None;
400            snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
401            for snippet in snippets {
402                if let Some(last_range) = &last_included_range
403                    && snippet.range.start < last_range.end
404                {
405                    if snippet.range.end <= last_range.end {
406                        continue;
407                    }
408                    // TODO: Should probably also handle case where there is just one char (newline)
409                    // between snippets - assume it's a newline.
410                    let text = &snippet.text[last_range.end - snippet.range.start..];
411                    if is_excerpt_file {
412                        push_excerpt_file_range(
413                            last_range.end..snippet.range.end,
414                            text,
415                            excerpt_file_insertions,
416                            output,
417                        );
418                    } else {
419                        output.push_str(text);
420                    }
421                    last_included_range = Some(last_range.start..snippet.range.end);
422                    continue;
423                }
424                if last_included_range.is_some() {
425                    output.push_str("\n");
426                }
427                if is_excerpt_file {
428                    push_excerpt_file_range(
429                        snippet.range.clone(),
430                        snippet.text,
431                        excerpt_file_insertions,
432                        output,
433                    );
434                } else {
435                    output.push_str(snippet.text);
436                }
437                last_included_range = Some(snippet.range.clone());
438            }
439
440            output.push_str("```\n\n");
441        }
442    }
443}
444
445fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
446    declaration_score(declaration, style) / declaration_size(declaration, style) as f32
447}
448
449fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
450    match style {
451        SnippetStyle::Signature => declaration.signature_score,
452        SnippetStyle::Declaration => declaration.declaration_score,
453    }
454}
455
456fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize {
457    match style {
458        SnippetStyle::Signature => declaration.signature_range.len(),
459        SnippetStyle::Declaration => declaration.text.len(),
460    }
461}