use crate::PromptFormat;
use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
use ::serde::Serialize;
use anyhow::{Context, Result, anyhow};
use clap::{Args, ValueEnum};
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use collections::HashMap;
use futures::StreamExt as _;
use gpui::{AppContext, AsyncApp, Entity};
use language::{Anchor, Buffer, Point};
use project::Project;
use serde::Deserialize;
use std::fs;
use std::io::{IsTerminal, Write};
use std::ops::Range;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};

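/// Arguments to the `predict` subcommand, which runs a single edit prediction
/// against one example file.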
#[derive(Debug, Args)]
pub struct PredictArguments {
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    #[arg(long)]
    use_expected_context: bool,
    #[arg(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
    example_path: PathBuf,
    #[arg(long, value_enum, default_value_t = CacheMode::default())]
    cache: CacheMode,
}

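/// Controls how cached LLM responses and search results are reused between
/// runs.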
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
    /// Use cached LLM requests and responses, except when multiple repetitions are requested.
    #[default]
    Auto,
    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
    #[value(alias = "request")]
    Requests,
    /// Ignore existing cache entries for both LLM and search.
    Skip,
    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
    /// Useful for reproducing results and fixing bugs outside of search queries.
    Force,
}

impl CacheMode {
    fn use_cached_llm_responses(&self) -> bool {
        self.assert_not_auto();
        matches!(self, CacheMode::Requests | CacheMode::Force)
    }

    fn use_cached_search_results(&self) -> bool {
        self.assert_not_auto();
        matches!(self, CacheMode::Force)
    }

    fn assert_not_auto(&self) {
        assert_ne!(
            *self,
            CacheMode::Auto,
            "Cache mode should not be auto at this point!"
        );
    }
}

#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    Json,
    Md,
    Diff,
}

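/// Entry point for the `predict` subcommand: loads the example, sets up the
/// project, runs a single prediction, and writes the result to stdout.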
pub async fn run_zeta2_predict(
    args: PredictArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    let example = NamedExample::load(args.example_path).unwrap();
    let (project, mut zetas, _edited_buffers) =
        example.setup_project(app_state, 1, cx).await.unwrap();
    let result = zeta2_predict(
        example,
        project,
        zetas.remove(0),
        None,
        args.prompt_format,
        args.use_expected_context,
        args.cache,
        cx,
    )
    .await
    .unwrap();
    result.write(args.format, std::io::stdout()).unwrap();

    print_run_data_dir(true, std::io::stdout().is_terminal());
}

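/// Runs context retrieval and a single edit prediction for `example`,
/// saving prompts, responses, and timings under the example's run directory.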
pub async fn zeta2_predict(
    example: NamedExample,
    project: Entity<Project>,
    zeta: Entity<Zeta>,
    repetition_ix: Option<u16>,
    prompt_format: PromptFormat,
    use_expected_context: bool,
    mut cache_mode: CacheMode,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    if repetition_ix.is_some() {
        if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
            panic!("Repetitions are only supported with the `auto` or `skip` cache modes");
        } else {
            cache_mode = CacheMode::Skip;
        }
    } else if cache_mode == CacheMode::Auto {
        cache_mode = CacheMode::Requests;
    }

    let mut example_run_dir = RUN_DIR.join(&example.file_name());
    if let Some(repetition_ix) = repetition_ix {
        example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
    }
    fs::create_dir_all(&example_run_dir)?;

    // Repoint the `latest` symlink at this run's directory.
    if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
        fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
    }

    #[cfg(unix)]
    std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
        .context("creating latest link")?;

    #[cfg(windows)]
    std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
        .context("creating latest link")?;

    zeta.update(cx, |zeta, _cx| {
        zeta.with_eval_cache(Arc::new(RunCache {
            example_run_dir: example_run_dir.clone(),
            cache_mode,
        }));
    })?;

    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

    // Drain Zeta's debug stream in the background, saving search and
    // prediction artifacts to the run directory and recording timings.
    let debug_task = cx.background_spawn({
        let result = result.clone();
        async move {
            let mut start_time = None;
            let mut search_queries_generated_at = None;
            let mut search_queries_executed_at = None;
            while let Some(event) = debug_rx.next().await {
                match event {
                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                        start_time = Some(info.timestamp);
                        fs::write(
                            example_run_dir.join("search_prompt.md"),
                            &info.search_prompt,
                        )?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                        search_queries_generated_at = Some(info.timestamp);
                        fs::write(
                            example_run_dir.join("search_queries.json"),
                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
                        )?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                        search_queries_executed_at = Some(info.timestamp);
                    }
                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                        let prediction_started_at = Instant::now();
                        start_time.get_or_insert(prediction_started_at);
                        let prompt = request.local_prompt.unwrap_or_default();
                        fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;

                        {
                            let mut result = result.lock().unwrap();
                            result.prompt_len = prompt.chars().count();

                            for included_file in request.request.included_files {
                                let insertions =
                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
                                result.excerpts.extend(included_file.excerpts.iter().map(
                                    |excerpt| ActualExcerpt {
                                        path: included_file.path.components().skip(1).collect(),
                                        text: String::from(excerpt.text.as_ref()),
                                    },
                                ));
                                write_codeblock(
                                    &included_file.path,
                                    included_file.excerpts.iter(),
                                    if included_file.path == request.request.excerpt_path {
                                        &insertions
                                    } else {
                                        &[]
                                    },
                                    included_file.max_row,
                                    false,
                                    &mut result.excerpts_text,
                                );
                            }
                        }

                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                        let response = zeta2::text_from_response(response).unwrap_or_default();
                        let prediction_finished_at = Instant::now();
                        fs::write(example_run_dir.join("prediction_response.md"), &response)?;

                        let mut result = result.lock().unwrap();
                        result.generated_len = response.chars().count();

                        if !use_expected_context {
                            result.planning_search_time =
                                Some(search_queries_generated_at.unwrap() - start_time.unwrap());
                            result.running_search_time = Some(
                                search_queries_executed_at.unwrap()
                                    - search_queries_generated_at.unwrap(),
                            );
                        }
                        result.prediction_time = prediction_finished_at - prediction_started_at;
                        result.total_time = prediction_finished_at - start_time.unwrap();

                        break;
                    }
                }
            }
            anyhow::Ok(())
        }
    });

    zeta.update(cx, |zeta, _cx| {
        let mut options = zeta.options().clone();
        options.prompt_format = prompt_format.into();
        zeta.set_options(options);
    })?;

    if use_expected_context {
        // Skip the search phase entirely and seed Zeta's context with the
        // example's expected excerpts.
        let context_excerpts_tasks = example
            .example
            .expected_context
            .iter()
            .flat_map(|section| {
                section.alternatives[0].excerpts.iter().map(|excerpt| {
                    resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
                })
            })
            .collect::<Vec<_>>();
        let context_excerpts_vec = futures::future::try_join_all(context_excerpts_tasks).await?;

        let mut context_excerpts = HashMap::default();
        for (buffer, mut excerpts) in context_excerpts_vec {
            context_excerpts
                .entry(buffer)
                .or_insert(Vec::new())
                .append(&mut excerpts);
        }

        zeta.update(cx, |zeta, _cx| {
            zeta.set_context(project.clone(), context_excerpts)
        })?;
    } else {
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
        })?
        .await?;
    }

    let prediction = zeta
        .update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
        })?
        .await?;

    debug_task.await?;

    // Apply the predicted edits to a branch of the buffer so the diff can be
    // computed without mutating the original.
    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
    result.diff = prediction
        .map(|prediction| {
            let old_text = prediction.snapshot.text();
            let new_text = prediction
                .buffer
                .update(cx, |buffer, cx| {
                    let branch = buffer.branch(cx);
                    branch.update(cx, |branch, cx| {
                        branch.edit(prediction.edits.iter().cloned(), None, cx);
                        branch.text()
                    })
                })
                .unwrap();
            language::unified_diff(&old_text, &new_text)
        })
        .unwrap_or_default();

    Ok(result)
}

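/// Opens the buffer for an expected excerpt and resolves the excerpt's
/// required lines to anchor ranges within that buffer.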
async fn resolve_context_entry(
    project: Entity<Project>,
    excerpt: ExpectedExcerpt,
    mut cx: AsyncApp,
) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
    let buffer = project
        .update(&mut cx, |project, cx| {
            let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
            project.open_buffer(project_path, cx)
        })?
        .await?;

    let ranges = buffer.read_with(&mut cx, |buffer, _| {
        let full_text = buffer.text();
        let offset = full_text
            .find(&excerpt.text)
            .expect("Expected context not found");
        let point = buffer.offset_to_point(offset);
        excerpt
            .required_lines
            .iter()
            .map(|line| {
                // Each required line is relative to where the excerpt starts;
                // anchor the full line it lands on.
                let row = point.row + line.0;
                let range = Point::new(row, 0)..Point::new(row + 1, 0);
                buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
            })
            .collect()
    })?;

    Ok((buffer, ranges))
}

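/// An on-disk `EvalCache` implementation that stores entries in the shared
/// cache directory and hard-links each entry a run uses into that run's
/// directory.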
struct RunCache {
    cache_mode: CacheMode,
    example_run_dir: PathBuf,
}

impl RunCache {
    fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
        CACHE_DIR.join(format!("{kind}_out_{key:x}.json"))
    }

    fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
        CACHE_DIR.join(format!("{kind}_in_{key:x}.json"))
    }

    fn link_to_run(&self, key: &EvalCacheKey) {
        let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
        fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();

        let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
        fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
    }
}

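// Reads consult the shared cache according to the active `CacheMode`; writes
// always land in the shared cache and are hard-linked into the run directory.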
impl EvalCache for RunCache {
    fn read(&self, key: EvalCacheKey) -> Option<String> {
        let path = RunCache::output_cache_path(&key);

        if path.exists() {
            let use_cache = match key.0 {
                EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
                EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
                    self.cache_mode.use_cached_llm_responses()
                }
            };
            if use_cache {
                log::info!("Using cache entry: {}", path.display());
                self.link_to_run(&key);
                Some(fs::read_to_string(path).unwrap())
            } else {
                log::trace!("Skipping cached entry: {}", path.display());
                None
            }
        } else if matches!(self.cache_mode, CacheMode::Force) {
            panic!(
                "No cached entry found for {:?}. Run without `--cache force` at least once.",
                key.0
            );
        } else {
            None
        }
    }

    fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
        fs::create_dir_all(&*CACHE_DIR).unwrap();

        let input_path = RunCache::input_cache_path(&key);
        fs::write(&input_path, input).unwrap();

        let output_path = RunCache::output_cache_path(&key);
        log::trace!("Writing cache entry: {}", output_path.display());
        fs::write(&output_path, output).unwrap();

        self.link_to_run(&key);
    }
}

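/// The artifacts and timings collected from a single prediction run.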
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictionDetails {
    pub diff: String,
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly.
    pub planning_search_time: Option<Duration>,
    pub running_search_time: Option<Duration>,
    pub prediction_time: Duration,
    pub total_time: Duration,
    pub run_example_dir: PathBuf,
    pub prompt_len: usize,
    pub generated_len: usize,
}

impl PredictionDetails {
    pub fn new(run_example_dir: PathBuf) -> Self {
        Self {
            diff: Default::default(),
            excerpts: Default::default(),
            excerpts_text: Default::default(),
            planning_search_time: Default::default(),
            running_search_time: Default::default(),
            prediction_time: Default::default(),
            total_time: Default::default(),
            run_example_dir,
            prompt_len: 0,
            generated_len: 0,
        }
    }

    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
        let formatted = match format {
            PredictionsOutputFormat::Md => self.to_markdown(),
            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
            PredictionsOutputFormat::Diff => self.diff.clone(),
        };

        Ok(out.write_all(formatted.as_bytes())?)
    }

    pub fn to_markdown(&self) -> String {
        let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;

        format!(
            "## Excerpts\n\n\
             {}\n\n\
             ## Prediction\n\n\
             {}\n\n\
             ## Time\n\n\
             Planning searches: {}ms\n\
             Running searches: {}ms\n\
             Making prediction: {}ms\n\n\
             -------------------\n\n\
             Total: {}ms\n\
             Inference: {}ms ({:.2}%)\n",
            self.excerpts_text,
            self.diff,
            self.planning_search_time.unwrap_or_default().as_millis(),
            self.running_search_time.unwrap_or_default().as_millis(),
            self.prediction_time.as_millis(),
            self.total_time.as_millis(),
            inference_time.as_millis(),
            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
        )
    }
}
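
// A minimal sanity check for the cache-mode matrix above (illustrative, not
// exhaustive): `Requests` reuses only LLM responses, `Force` reuses both LLM
// responses and search results, and `Skip` reuses neither. `Auto` is excluded
// because it must be resolved to another mode before these methods are called.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cache_mode_matrix() {
        assert!(CacheMode::Requests.use_cached_llm_responses());
        assert!(!CacheMode::Requests.use_cached_search_results());

        assert!(CacheMode::Force.use_cached_llm_responses());
        assert!(CacheMode::Force.use_cached_search_results());

        assert!(!CacheMode::Skip.use_cached_llm_responses());
        assert!(!CacheMode::Skip.use_cached_search_results());
    }
}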