cloud_zeta2_prompt.rs

  1//! Zeta2 prompt planning and generation code shared with cloud.
  2
  3use anyhow::{Result, anyhow};
  4use cloud_llm_client::predict_edits_v3::{self, ReferencedDeclaration};
  5use ordered_float::OrderedFloat;
  6use rustc_hash::{FxHashMap, FxHashSet};
  7use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
  8use strum::{EnumIter, IntoEnumIterator};
  9
 10pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
 11/// NOTE: Differs from zed version of constant - includes a newline
 12pub const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>\n";
 13/// NOTE: Differs from zed version of constant - includes a newline
 14pub const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>\n";
 15
 16pub struct PlannedPrompt<'a> {
 17    request: &'a predict_edits_v3::PredictEditsRequest,
 18    /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
 19    /// `to_prompt_string`.
 20    snippets: Vec<PlannedSnippet<'a>>,
 21    budget_used: usize,
 22}
 23
 24pub struct PlanOptions {
 25    pub max_bytes: usize,
 26}
 27
/// A piece of source text selected for the prompt, identified by the file it comes from and its
/// byte range within that file.
#[derive(Clone, Debug)]
pub struct PlannedSnippet<'a> {
    /// File the snippet was taken from; `to_prompt_string` groups snippets by this path.
    path: &'a Path,
    /// Byte offsets of the snippet within its file; used to sort, merge and skip overlapping
    /// snippets when rendering.
    range: Range<usize>,
    /// The snippet's text. NOTE(review): presumably shorter than `range` when
    /// `text_is_truncated` is set — confirm against the request producer.
    text: &'a str,
    /// Whether the producer of the request truncated `text`. Currently recorded but never read.
    // TODO: Indicate this in the output
    #[allow(dead_code)]
    text_is_truncated: bool,
}
 37
/// How much of a referenced declaration to include in the prompt.
///
/// Variant order is significant: it is both the iteration order of `SnippetStyle::iter()` and
/// the derived `Ord`, and `PlannedPrompt::populate` relies on each of these to break ties.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum SnippetStyle {
    /// Only the declaration's signature (the `signature_range` slice of its text).
    Signature,
    /// The declaration's full text.
    Declaration,
}
 43
 44impl<'a> PlannedPrompt<'a> {
 45    /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
 46    ///
 47    /// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle
 48    /// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature"
 49    /// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of
 50    /// upgrade.
 51    ///
 52    /// TODO: Implement an early halting condition. One option might be to have another priority
 53    /// queue where the score is the size, and update it accordingly. Another option might be to
 54    /// have some simpler heuristic like bailing after N failed insertions, or based on how much
 55    /// budget is left.
 56    ///
 57    /// TODO: Has the current known sources of imprecision:
 58    ///
 59    /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
 60    /// plan even though the containing struct is already included.
 61    ///
 62    /// * Does not consider cost of signatures when ranking snippets - this is tricky since
 63    /// signatures may be shared by multiple snippets.
 64    ///
 65    /// * Does not include file paths / other text when considering max_bytes.
 66    pub fn populate(
 67        request: &'a predict_edits_v3::PredictEditsRequest,
 68        options: &PlanOptions,
 69    ) -> Result<Self> {
 70        let mut this = PlannedPrompt {
 71            request,
 72            snippets: Vec::new(),
 73            budget_used: request.excerpt.len(),
 74        };
 75        let mut included_parents = FxHashSet::default();
 76        let additional_parents = this.additional_parent_signatures(
 77            &request.excerpt_path,
 78            request.excerpt_parent,
 79            &included_parents,
 80        )?;
 81        this.add_parents(&mut included_parents, additional_parents);
 82
 83        if this.budget_used > options.max_bytes {
 84            return Err(anyhow!(
 85                "Excerpt + signatures size of {} already exceeds budget of {}",
 86                this.budget_used,
 87                options.max_bytes
 88            ));
 89        }
 90
 91        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
 92        struct QueueEntry {
 93            score_density: OrderedFloat<f32>,
 94            declaration_index: usize,
 95            style: SnippetStyle,
 96        }
 97
 98        // Initialize priority queue with the best score for each snippet.
 99        let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
100        for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
101            let (style, score_density) = SnippetStyle::iter()
102                .map(|style| {
103                    (
104                        style,
105                        OrderedFloat(declaration_score_density(&declaration, style)),
106                    )
107                })
108                .max_by_key(|(_, score_density)| *score_density)
109                .unwrap();
110            queue.push(QueueEntry {
111                score_density,
112                declaration_index,
113                style,
114            });
115        }
116
117        // Knapsack selection loop
118        while let Some(queue_entry) = queue.pop() {
119            let Some(declaration) = request
120                .referenced_declarations
121                .get(queue_entry.declaration_index)
122            else {
123                return Err(anyhow!(
124                    "Invalid declaration index {}",
125                    queue_entry.declaration_index
126                ));
127            };
128
129            let mut additional_bytes = declaration_size(declaration, queue_entry.style);
130            if this.budget_used + additional_bytes > options.max_bytes {
131                continue;
132            }
133
134            let additional_parents = this.additional_parent_signatures(
135                &declaration.path,
136                declaration.parent_index,
137                &mut included_parents,
138            )?;
139            additional_bytes += additional_parents
140                .iter()
141                .map(|(_, snippet)| snippet.text.len())
142                .sum::<usize>();
143            if this.budget_used + additional_bytes > options.max_bytes {
144                continue;
145            }
146
147            this.budget_used += additional_bytes;
148            this.add_parents(&mut included_parents, additional_parents);
149            let planned_snippet = match queue_entry.style {
150                SnippetStyle::Signature => {
151                    let Some(text) = declaration.text.get(declaration.signature_range.clone())
152                    else {
153                        return Err(anyhow!(
154                            "Invalid declaration signature_range {:?} with text.len() = {}",
155                            declaration.signature_range,
156                            declaration.text.len()
157                        ));
158                    };
159                    PlannedSnippet {
160                        path: &declaration.path,
161                        range: (declaration.signature_range.start + declaration.range.start)
162                            ..(declaration.signature_range.end + declaration.range.start),
163                        text,
164                        text_is_truncated: declaration.text_is_truncated,
165                    }
166                }
167                SnippetStyle::Declaration => PlannedSnippet {
168                    path: &declaration.path,
169                    range: declaration.range.clone(),
170                    text: &declaration.text,
171                    text_is_truncated: declaration.text_is_truncated,
172                },
173            };
174            this.snippets.push(planned_snippet);
175
176            // When a Signature is consumed, insert an entry for Definition style.
177            if queue_entry.style == SnippetStyle::Signature {
178                let signature_size = declaration_size(&declaration, SnippetStyle::Signature);
179                let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration);
180                let signature_score = declaration_score(&declaration, SnippetStyle::Signature);
181                let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration);
182
183                let score_diff = declaration_score - signature_score;
184                let size_diff = declaration_size.saturating_sub(signature_size);
185                if score_diff > 0.0001 && size_diff > 0 {
186                    queue.push(QueueEntry {
187                        declaration_index: queue_entry.declaration_index,
188                        score_density: OrderedFloat(score_diff / (size_diff as f32)),
189                        style: SnippetStyle::Declaration,
190                    });
191                }
192            }
193        }
194
195        anyhow::Ok(this)
196    }
197
198    fn add_parents(
199        &mut self,
200        included_parents: &mut FxHashSet<usize>,
201        snippets: Vec<(usize, PlannedSnippet<'a>)>,
202    ) {
203        for (parent_index, snippet) in snippets {
204            included_parents.insert(parent_index);
205            self.budget_used += snippet.text.len();
206            self.snippets.push(snippet);
207        }
208    }
209
210    fn additional_parent_signatures(
211        &self,
212        path: &'a Path,
213        parent_index: Option<usize>,
214        included_parents: &FxHashSet<usize>,
215    ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
216        let mut results = Vec::new();
217        self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
218        Ok(results)
219    }
220
221    fn additional_parent_signatures_impl(
222        &self,
223        path: &'a Path,
224        parent_index: Option<usize>,
225        included_parents: &FxHashSet<usize>,
226        results: &mut Vec<(usize, PlannedSnippet<'a>)>,
227    ) -> Result<()> {
228        let Some(parent_index) = parent_index else {
229            return Ok(());
230        };
231        if included_parents.contains(&parent_index) {
232            return Ok(());
233        }
234        let Some(parent_signature) = self.request.signatures.get(parent_index) else {
235            return Err(anyhow!("Invalid parent index {}", parent_index));
236        };
237        results.push((
238            parent_index,
239            PlannedSnippet {
240                path,
241                range: parent_signature.range.clone(),
242                text: &parent_signature.text,
243                text_is_truncated: parent_signature.text_is_truncated,
244            },
245        ));
246        self.additional_parent_signatures_impl(
247            path,
248            parent_signature.parent_index,
249            included_parents,
250            results,
251        )
252    }
253
254    /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
255    /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
256    /// chunks.
257    pub fn to_prompt_string(&self) -> String {
258        let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
259            FxHashMap::default();
260        for snippet in &self.snippets {
261            file_to_snippets
262                .entry(&snippet.path)
263                .or_default()
264                .push(snippet);
265        }
266
267        // Reorder so that file with cursor comes last
268        let mut file_snippets = Vec::new();
269        let mut excerpt_file_snippets = Vec::new();
270        for (file_path, snippets) in file_to_snippets {
271            if file_path == &self.request.excerpt_path {
272                excerpt_file_snippets = snippets;
273            } else {
274                file_snippets.push((file_path, snippets, false));
275            }
276        }
277        let excerpt_snippet = PlannedSnippet {
278            path: &self.request.excerpt_path,
279            range: self.request.excerpt_range.clone(),
280            text: &self.request.excerpt,
281            text_is_truncated: false,
282        };
283        excerpt_file_snippets.push(&excerpt_snippet);
284        file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
285
286        let mut excerpt_file_insertions = vec![
287            (
288                self.request.excerpt_range.start,
289                EDITABLE_REGION_START_MARKER,
290            ),
291            (
292                self.request.excerpt_range.start + self.request.cursor_offset,
293                CURSOR_MARKER,
294            ),
295            (
296                self.request
297                    .excerpt_range
298                    .end
299                    .saturating_sub(0)
300                    .max(self.request.excerpt_range.start),
301                EDITABLE_REGION_END_MARKER,
302            ),
303        ];
304
305        fn push_excerpt_file_range(
306            range: Range<usize>,
307            text: &str,
308            excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
309            output: &mut String,
310        ) {
311            let mut last_offset = range.start;
312            let mut i = 0;
313            while i < excerpt_file_insertions.len() {
314                let (offset, insertion) = &excerpt_file_insertions[i];
315                let found = *offset >= range.start && *offset <= range.end;
316                if found {
317                    output.push_str(&text[last_offset - range.start..offset - range.start]);
318                    output.push_str(insertion);
319                    last_offset = *offset;
320                    excerpt_file_insertions.remove(i);
321                    continue;
322                }
323                i += 1;
324            }
325            output.push_str(&text[last_offset - range.start..]);
326        }
327
328        let mut output = String::new();
329        for (file_path, mut snippets, is_excerpt_file) in file_snippets {
330            output.push_str(&format!("```{}\n", file_path.display()));
331
332            let mut last_included_range: Option<Range<usize>> = None;
333            snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
334            for snippet in snippets {
335                if let Some(last_range) = &last_included_range
336                    && snippet.range.start < last_range.end
337                {
338                    if snippet.range.end <= last_range.end {
339                        continue;
340                    }
341                    // TODO: Should probably also handle case where there is just one char (newline)
342                    // between snippets - assume it's a newline.
343                    let text = &snippet.text[last_range.end - snippet.range.start..];
344                    if is_excerpt_file {
345                        push_excerpt_file_range(
346                            last_range.end..snippet.range.end,
347                            text,
348                            &mut excerpt_file_insertions,
349                            &mut output,
350                        );
351                    } else {
352                        output.push_str(text);
353                    }
354                    last_included_range = Some(last_range.start..snippet.range.end);
355                    continue;
356                }
357                if last_included_range.is_some() {
358                    output.push_str("\n");
359                }
360                if is_excerpt_file {
361                    push_excerpt_file_range(
362                        snippet.range.clone(),
363                        snippet.text,
364                        &mut excerpt_file_insertions,
365                        &mut output,
366                    );
367                } else {
368                    output.push_str(snippet.text);
369                }
370                last_included_range = Some(snippet.range.clone());
371            }
372
373            output.push_str("```\n\n");
374        }
375
376        output
377    }
378}
379
380fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
381    declaration_score(declaration, style) / declaration_size(declaration, style) as f32
382}
383
384fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
385    match style {
386        SnippetStyle::Signature => declaration.signature_score,
387        SnippetStyle::Declaration => declaration.declaration_score,
388    }
389}
390
391fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize {
392    match style {
393        SnippetStyle::Signature => declaration.signature_range.len(),
394        SnippetStyle::Declaration => declaration.text.len(),
395    }
396}