1//! Zeta2 prompt planning and generation code shared with cloud.
2
3use anyhow::{Result, anyhow};
4use cloud_llm_client::predict_edits_v3::{self, ReferencedDeclaration};
5use ordered_float::OrderedFloat;
6use rustc_hash::{FxHashMap, FxHashSet};
7use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
8use strum::{EnumIter, IntoEnumIterator};
9
/// Spliced into the rendered excerpt file at `excerpt_range.start + cursor_offset`.
pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Spliced into the rendered excerpt file at the start of the excerpt range.
///
/// NOTE: Unlike the zed version of this constant, it ends with a trailing newline.
pub const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>\n";
/// Spliced into the rendered excerpt file at the end of the excerpt range.
///
/// NOTE: Unlike the zed version of this constant, it ends with a trailing newline.
pub const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>\n";
15
16pub struct PlannedPrompt<'a> {
17 request: &'a predict_edits_v3::PredictEditsRequest,
18 /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
19 /// `to_prompt_string`.
20 snippets: Vec<PlannedSnippet<'a>>,
21 budget_used: usize,
22}
23
/// Options controlling `PlannedPrompt::populate`.
#[derive(Debug, Clone, Copy)]
pub struct PlanOptions {
    /// Byte budget that the excerpt, planned snippets, and parent signatures must fit within.
    /// File paths and other rendering overhead are not counted.
    pub max_bytes: usize,
}
27
/// One contiguous piece of text planned for inclusion in the prompt.
#[derive(Debug, Clone)]
pub struct PlannedSnippet<'a> {
    /// File the snippet is rendered under.
    path: &'a Path,
    /// Byte offsets of the snippet in its file. Rendering slices `text` relative to
    /// `range.start`, so `text` is expected to cover this range.
    range: Range<usize>,
    /// The snippet's text, borrowed from the request.
    text: &'a str,
    /// Whether the request marked `text` as truncated.
    // TODO: Indicate this in the output
    #[allow(dead_code)] // Carried through planning but not yet read when rendering.
    text_is_truncated: bool,
}
37
38#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
39pub enum SnippetStyle {
40 Signature,
41 Declaration,
42}
43
44impl<'a> PlannedPrompt<'a> {
45 /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
46 ///
47 /// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle
48 /// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature"
49 /// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of
50 /// upgrade.
51 ///
52 /// TODO: Implement an early halting condition. One option might be to have another priority
53 /// queue where the score is the size, and update it accordingly. Another option might be to
54 /// have some simpler heuristic like bailing after N failed insertions, or based on how much
55 /// budget is left.
56 ///
57 /// TODO: Has the current known sources of imprecision:
58 ///
59 /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
60 /// plan even though the containing struct is already included.
61 ///
62 /// * Does not consider cost of signatures when ranking snippets - this is tricky since
63 /// signatures may be shared by multiple snippets.
64 ///
65 /// * Does not include file paths / other text when considering max_bytes.
66 pub fn populate(
67 request: &'a predict_edits_v3::PredictEditsRequest,
68 options: &PlanOptions,
69 ) -> Result<Self> {
70 let mut this = PlannedPrompt {
71 request,
72 snippets: Vec::new(),
73 budget_used: request.excerpt.len(),
74 };
75 let mut included_parents = FxHashSet::default();
76 let additional_parents = this.additional_parent_signatures(
77 &request.excerpt_path,
78 request.excerpt_parent,
79 &included_parents,
80 )?;
81 this.add_parents(&mut included_parents, additional_parents);
82
83 if this.budget_used > options.max_bytes {
84 return Err(anyhow!(
85 "Excerpt + signatures size of {} already exceeds budget of {}",
86 this.budget_used,
87 options.max_bytes
88 ));
89 }
90
91 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
92 struct QueueEntry {
93 score_density: OrderedFloat<f32>,
94 declaration_index: usize,
95 style: SnippetStyle,
96 }
97
98 // Initialize priority queue with the best score for each snippet.
99 let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
100 for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
101 let (style, score_density) = SnippetStyle::iter()
102 .map(|style| {
103 (
104 style,
105 OrderedFloat(declaration_score_density(&declaration, style)),
106 )
107 })
108 .max_by_key(|(_, score_density)| *score_density)
109 .unwrap();
110 queue.push(QueueEntry {
111 score_density,
112 declaration_index,
113 style,
114 });
115 }
116
117 // Knapsack selection loop
118 while let Some(queue_entry) = queue.pop() {
119 let Some(declaration) = request
120 .referenced_declarations
121 .get(queue_entry.declaration_index)
122 else {
123 return Err(anyhow!(
124 "Invalid declaration index {}",
125 queue_entry.declaration_index
126 ));
127 };
128
129 let mut additional_bytes = declaration_size(declaration, queue_entry.style);
130 if this.budget_used + additional_bytes > options.max_bytes {
131 continue;
132 }
133
134 let additional_parents = this.additional_parent_signatures(
135 &declaration.path,
136 declaration.parent_index,
137 &mut included_parents,
138 )?;
139 additional_bytes += additional_parents
140 .iter()
141 .map(|(_, snippet)| snippet.text.len())
142 .sum::<usize>();
143 if this.budget_used + additional_bytes > options.max_bytes {
144 continue;
145 }
146
147 this.budget_used += additional_bytes;
148 this.add_parents(&mut included_parents, additional_parents);
149 let planned_snippet = match queue_entry.style {
150 SnippetStyle::Signature => {
151 let Some(text) = declaration.text.get(declaration.signature_range.clone())
152 else {
153 return Err(anyhow!(
154 "Invalid declaration signature_range {:?} with text.len() = {}",
155 declaration.signature_range,
156 declaration.text.len()
157 ));
158 };
159 PlannedSnippet {
160 path: &declaration.path,
161 range: (declaration.signature_range.start + declaration.range.start)
162 ..(declaration.signature_range.end + declaration.range.start),
163 text,
164 text_is_truncated: declaration.text_is_truncated,
165 }
166 }
167 SnippetStyle::Declaration => PlannedSnippet {
168 path: &declaration.path,
169 range: declaration.range.clone(),
170 text: &declaration.text,
171 text_is_truncated: declaration.text_is_truncated,
172 },
173 };
174 this.snippets.push(planned_snippet);
175
176 // When a Signature is consumed, insert an entry for Definition style.
177 if queue_entry.style == SnippetStyle::Signature {
178 let signature_size = declaration_size(&declaration, SnippetStyle::Signature);
179 let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration);
180 let signature_score = declaration_score(&declaration, SnippetStyle::Signature);
181 let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration);
182
183 let score_diff = declaration_score - signature_score;
184 let size_diff = declaration_size.saturating_sub(signature_size);
185 if score_diff > 0.0001 && size_diff > 0 {
186 queue.push(QueueEntry {
187 declaration_index: queue_entry.declaration_index,
188 score_density: OrderedFloat(score_diff / (size_diff as f32)),
189 style: SnippetStyle::Declaration,
190 });
191 }
192 }
193 }
194
195 anyhow::Ok(this)
196 }
197
198 fn add_parents(
199 &mut self,
200 included_parents: &mut FxHashSet<usize>,
201 snippets: Vec<(usize, PlannedSnippet<'a>)>,
202 ) {
203 for (parent_index, snippet) in snippets {
204 included_parents.insert(parent_index);
205 self.budget_used += snippet.text.len();
206 self.snippets.push(snippet);
207 }
208 }
209
210 fn additional_parent_signatures(
211 &self,
212 path: &'a Path,
213 parent_index: Option<usize>,
214 included_parents: &FxHashSet<usize>,
215 ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
216 let mut results = Vec::new();
217 self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
218 Ok(results)
219 }
220
221 fn additional_parent_signatures_impl(
222 &self,
223 path: &'a Path,
224 parent_index: Option<usize>,
225 included_parents: &FxHashSet<usize>,
226 results: &mut Vec<(usize, PlannedSnippet<'a>)>,
227 ) -> Result<()> {
228 let Some(parent_index) = parent_index else {
229 return Ok(());
230 };
231 if included_parents.contains(&parent_index) {
232 return Ok(());
233 }
234 let Some(parent_signature) = self.request.signatures.get(parent_index) else {
235 return Err(anyhow!("Invalid parent index {}", parent_index));
236 };
237 results.push((
238 parent_index,
239 PlannedSnippet {
240 path,
241 range: parent_signature.range.clone(),
242 text: &parent_signature.text,
243 text_is_truncated: parent_signature.text_is_truncated,
244 },
245 ));
246 self.additional_parent_signatures_impl(
247 path,
248 parent_signature.parent_index,
249 included_parents,
250 results,
251 )
252 }
253
254 /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
255 /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
256 /// chunks.
257 pub fn to_prompt_string(&self) -> String {
258 let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
259 FxHashMap::default();
260 for snippet in &self.snippets {
261 file_to_snippets
262 .entry(&snippet.path)
263 .or_default()
264 .push(snippet);
265 }
266
267 // Reorder so that file with cursor comes last
268 let mut file_snippets = Vec::new();
269 let mut excerpt_file_snippets = Vec::new();
270 for (file_path, snippets) in file_to_snippets {
271 if file_path == &self.request.excerpt_path {
272 excerpt_file_snippets = snippets;
273 } else {
274 file_snippets.push((file_path, snippets, false));
275 }
276 }
277 let excerpt_snippet = PlannedSnippet {
278 path: &self.request.excerpt_path,
279 range: self.request.excerpt_range.clone(),
280 text: &self.request.excerpt,
281 text_is_truncated: false,
282 };
283 excerpt_file_snippets.push(&excerpt_snippet);
284 file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
285
286 let mut excerpt_file_insertions = vec![
287 (
288 self.request.excerpt_range.start,
289 EDITABLE_REGION_START_MARKER,
290 ),
291 (
292 self.request.excerpt_range.start + self.request.cursor_offset,
293 CURSOR_MARKER,
294 ),
295 (
296 self.request
297 .excerpt_range
298 .end
299 .saturating_sub(0)
300 .max(self.request.excerpt_range.start),
301 EDITABLE_REGION_END_MARKER,
302 ),
303 ];
304
305 fn push_excerpt_file_range(
306 range: Range<usize>,
307 text: &str,
308 excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
309 output: &mut String,
310 ) {
311 let mut last_offset = range.start;
312 let mut i = 0;
313 while i < excerpt_file_insertions.len() {
314 let (offset, insertion) = &excerpt_file_insertions[i];
315 let found = *offset >= range.start && *offset <= range.end;
316 if found {
317 output.push_str(&text[last_offset - range.start..offset - range.start]);
318 output.push_str(insertion);
319 last_offset = *offset;
320 excerpt_file_insertions.remove(i);
321 continue;
322 }
323 i += 1;
324 }
325 output.push_str(&text[last_offset - range.start..]);
326 }
327
328 let mut output = String::new();
329 for (file_path, mut snippets, is_excerpt_file) in file_snippets {
330 output.push_str(&format!("```{}\n", file_path.display()));
331
332 let mut last_included_range: Option<Range<usize>> = None;
333 snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
334 for snippet in snippets {
335 if let Some(last_range) = &last_included_range
336 && snippet.range.start < last_range.end
337 {
338 if snippet.range.end <= last_range.end {
339 continue;
340 }
341 // TODO: Should probably also handle case where there is just one char (newline)
342 // between snippets - assume it's a newline.
343 let text = &snippet.text[last_range.end - snippet.range.start..];
344 if is_excerpt_file {
345 push_excerpt_file_range(
346 last_range.end..snippet.range.end,
347 text,
348 &mut excerpt_file_insertions,
349 &mut output,
350 );
351 } else {
352 output.push_str(text);
353 }
354 last_included_range = Some(last_range.start..snippet.range.end);
355 continue;
356 }
357 if last_included_range.is_some() {
358 output.push_str("…\n");
359 }
360 if is_excerpt_file {
361 push_excerpt_file_range(
362 snippet.range.clone(),
363 snippet.text,
364 &mut excerpt_file_insertions,
365 &mut output,
366 );
367 } else {
368 output.push_str(snippet.text);
369 }
370 last_included_range = Some(snippet.range.clone());
371 }
372
373 output.push_str("```\n\n");
374 }
375
376 output
377 }
378}
379
380fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
381 declaration_score(declaration, style) / declaration_size(declaration, style) as f32
382}
383
384fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
385 match style {
386 SnippetStyle::Signature => declaration.signature_score,
387 SnippetStyle::Declaration => declaration.declaration_score,
388 }
389}
390
391fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize {
392 match style {
393 SnippetStyle::Signature => declaration.signature_range.len(),
394 SnippetStyle::Declaration => declaration.text.len(),
395 }
396}