1use crate::PromptFormat;
2use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
3use crate::headless::ZetaCliAppState;
4use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
5use ::serde::Serialize;
6use anyhow::{Context, Result, anyhow};
7use clap::{Args, ValueEnum};
8use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
9use collections::HashMap;
10use futures::StreamExt as _;
11use gpui::{AppContext, AsyncApp, Entity};
12use language::{Anchor, Buffer, Point};
13use project::Project;
14use serde::Deserialize;
15use std::cell::Cell;
16use std::fs;
17use std::io::Write;
18use std::ops::Range;
19use std::path::PathBuf;
20use std::sync::Arc;
21use std::sync::Mutex;
22use std::time::{Duration, Instant};
23use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey};
24
// CLI arguments for the `predict` subcommand.
// NOTE: use `//` (not `///`) for internal notes here — clap turns doc comments
// on Args structs/fields into user-visible help text.
#[derive(Debug, Args)]
pub struct PredictArguments {
    // Prompt format to use when building the prediction prompt.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // When set, feed the example's expected context to zeta directly instead
    // of running search-based context retrieval.
    #[arg(long)]
    use_expected_context: bool,
    // How to render the prediction result on stdout (markdown by default).
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
    // Path to the example file to run the prediction against.
    example_path: PathBuf,
    // Cache reuse policy for LLM requests and search results.
    #[clap(long, value_enum, default_value_t = CacheMode::default())]
    cache: CacheMode,
}
37
// Cache reuse policy for eval runs. The `///` doc comments below double as
// the `--cache` help text shown by clap — keep them user-facing.
#[derive(Debug, ValueEnum, Default, Clone, Copy)]
pub enum CacheMode {
    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
    #[default]
    #[value(alias = "request")]
    Requests,
    /// Ignore existing cache entries for both LLM and search.
    Skip,
    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
    /// Useful for reproducing results and fixing bugs outside of search queries
    Force,
}
50
51impl CacheMode {
52 fn use_cached_llm_responses(&self) -> bool {
53 matches!(self, CacheMode::Requests | CacheMode::Force)
54 }
55
56 fn use_cached_search_results(&self) -> bool {
57 matches!(self, CacheMode::Force)
58 }
59}
60
// Output formats for `--format`. Using `//` comments (not `///`) so clap's
// generated help text is unchanged.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    // Pretty-printed JSON of the full PredictionDetails.
    Json,
    // Human-readable markdown report (excerpts, diff, timings).
    Md,
    // Only the unified diff of the predicted edits.
    Diff,
}
67
68pub async fn run_zeta2_predict(
69 args: PredictArguments,
70 app_state: &Arc<ZetaCliAppState>,
71 cx: &mut AsyncApp,
72) {
73 let example = NamedExample::load(args.example_path).unwrap();
74 let result = zeta2_predict(
75 example,
76 args.prompt_format,
77 args.use_expected_context,
78 args.cache,
79 &app_state,
80 cx,
81 )
82 .await
83 .unwrap();
84 result.write(args.format, std::io::stdout()).unwrap();
85
86 print_run_data_dir();
87}
88
// Tracks whether this thread has already signed in, so repeated predictions
// (e.g. when driven by an eval harness) authenticate only once.
thread_local! {
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}
92
/// Runs a single edit prediction for `example` and collects the retrieved
/// context, the resulting diff, and timing data into a [`PredictionDetails`].
///
/// Sets up a local project over the example's worktree, replays the example's
/// edit history, retrieves context (either the example's expected context or
/// zeta2's search-based retrieval depending on `use_expected_context`), and
/// requests one prediction. Intermediate artifacts (prompts, search queries,
/// responses) are written into a per-example run directory.
pub async fn zeta2_predict(
    example: NamedExample,
    prompt_format: PromptFormat,
    use_expected_context: bool,
    cache_mode: CacheMode,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    let worktree_path = example.setup_worktree().await?;

    // Sign in at most once per thread (see AUTHENTICATED above).
    if !AUTHENTICATED.get() {
        AUTHENTICATED.set(true);

        app_state
            .client
            .sign_in_with_optional_connect(true, cx)
            .await?;
    }

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;

    // Create the worktree and wait for the initial scan so all example files
    // are visible before edits are applied.
    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;

    // Per-example output directory; also maintain a "latest" symlink to it.
    let example_run_dir = RUN_DIR.join(&example.file_name());
    fs::create_dir_all(&example_run_dir)?;
    // NOTE(review): `Path::exists()` follows symlinks, so if the latest link is
    // dangling (its target run dir was deleted) this check is false, the stale
    // link is kept, and the symlink creation below fails with AlreadyExists.
    // `symlink_metadata().is_ok()` would detect the link itself — TODO confirm.
    if LATEST_EXAMPLE_RUN_DIR.exists() {
        fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
    }

    #[cfg(unix)]
    std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
        .context("creating latest link")?;

    #[cfg(windows)]
    std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
        .context("creating latest link")?;

    // Route zeta's LLM/search caching through this run's cache policy.
    zeta.update(cx, |zeta, _cx| {
        zeta.with_eval_cache(Arc::new(RunCache {
            example_run_dir: example_run_dir.clone(),
            cache_mode,
        }));
    })?;

    // Register every buffer opened by the project with zeta so it can track
    // edits for prediction.
    cx.subscribe(&buffer_store, {
        let project = project.clone();
        move |_, event, cx| match event {
            project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
                zeta2::Zeta::try_global(cx)
                    .unwrap()
                    .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
            }
            _ => {}
        }
    })?
    .detach();

    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    let result = Arc::new(Mutex::new(PredictionDetails::default()));
    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

    // Background task: consume zeta's debug events, dump prompts/queries to
    // the run dir, and record timings + included excerpts into `result`.
    let debug_task = cx.background_spawn({
        let result = result.clone();
        async move {
            let mut start_time = None;
            let mut search_queries_generated_at = None;
            let mut search_queries_executed_at = None;
            while let Some(event) = debug_rx.next().await {
                match event {
                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                        start_time = Some(info.timestamp);
                        fs::write(
                            example_run_dir.join("search_prompt.md"),
                            &info.search_prompt,
                        )?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                        search_queries_generated_at = Some(info.timestamp);
                        fs::write(
                            example_run_dir.join("search_queries.json"),
                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
                        )?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                        search_queries_executed_at = Some(info.timestamp);
                    }
                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                        let prediction_started_at = Instant::now();
                        // With expected context there is no retrieval phase, so
                        // the prediction request marks the start of the run.
                        start_time.get_or_insert(prediction_started_at);
                        fs::write(
                            example_run_dir.join("prediction_prompt.md"),
                            &request.local_prompt.unwrap_or_default(),
                        )?;

                        {
                            let mut result = result.lock().unwrap();

                            for included_file in request.request.included_files {
                                let insertions =
                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
                                // Drop the first path component (worktree root)
                                // so excerpt paths are worktree-relative.
                                result.excerpts.extend(included_file.excerpts.iter().map(
                                    |excerpt| ActualExcerpt {
                                        path: included_file.path.components().skip(1).collect(),
                                        text: String::from(excerpt.text.as_ref()),
                                    },
                                ));
                                write_codeblock(
                                    &included_file.path,
                                    included_file.excerpts.iter(),
                                    // Only mark the cursor in the excerpt's own file.
                                    if included_file.path == request.request.excerpt_path {
                                        &insertions
                                    } else {
                                        &[]
                                    },
                                    included_file.max_row,
                                    false,
                                    &mut result.excerpts_text,
                                );
                            }
                        }

                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                        let response = zeta2::text_from_response(response).unwrap_or_default();
                        let prediction_finished_at = Instant::now();
                        fs::write(example_run_dir.join("prediction_response.md"), &response)?;

                        let mut result = result.lock().unwrap();

                        if !use_expected_context {
                            // NOTE(review): these unwraps assume the search
                            // events above always arrive before the prediction
                            // request when retrieval ran — TODO confirm.
                            result.planning_search_time =
                                Some(search_queries_generated_at.unwrap() - start_time.unwrap());
                            result.running_search_time = Some(
                                search_queries_executed_at.unwrap()
                                    - search_queries_generated_at.unwrap(),
                            );
                        }
                        result.prediction_time = prediction_finished_at - prediction_started_at;
                        result.total_time = prediction_finished_at - start_time.unwrap();

                        // One prediction per run: stop consuming events.
                        break;
                    }
                }
            }
            anyhow::Ok(())
        }
    });

    zeta.update(cx, |zeta, _cx| {
        let mut options = zeta.options().clone();
        options.prompt_format = prompt_format.into();
        zeta.set_options(options);
    })?;

    if use_expected_context {
        // Resolve the example's expected excerpts (first alternative of each
        // section) to buffer anchor ranges and hand them to zeta directly.
        let context_excerpts_tasks = example
            .example
            .expected_context
            .iter()
            .flat_map(|section| {
                section.alternatives[0].excerpts.iter().map(|excerpt| {
                    resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
                })
            })
            .collect::<Vec<_>>();
        let context_excerpts_vec = futures::future::try_join_all(context_excerpts_tasks).await?;

        // Group resolved ranges by buffer.
        let mut context_excerpts = HashMap::default();
        for (buffer, mut excerpts) in context_excerpts_vec {
            context_excerpts
                .entry(buffer)
                .or_insert(Vec::new())
                .append(&mut excerpts);
        }

        zeta.update(cx, |zeta, _cx| {
            zeta.set_context(project.clone(), context_excerpts)
        })?;
    } else {
        // Run zeta's own search-based context retrieval around the cursor.
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
        })?
        .await?;
    }

    let prediction = zeta
        .update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
        })?
        .await?;

    // Wait for the debug task so `result` is fully populated and uniquely owned.
    debug_task.await?;

    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
    // Apply the predicted edits to the buffer and diff against the snapshot to
    // produce the unified diff (empty when no prediction was returned).
    result.diff = prediction
        .map(|prediction| {
            let old_text = prediction.snapshot.text();
            let new_text = prediction
                .buffer
                .update(cx, |buffer, cx| {
                    buffer.edit(prediction.edits.iter().cloned(), None, cx);
                    buffer.text()
                })
                .unwrap();
            language::unified_diff(&old_text, &new_text)
        })
        .unwrap_or_default();

    anyhow::Ok(result)
}
328
329async fn resolve_context_entry(
330 project: Entity<Project>,
331 excerpt: ExpectedExcerpt,
332 mut cx: AsyncApp,
333) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
334 let buffer = project
335 .update(&mut cx, |project, cx| {
336 let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
337 project.open_buffer(project_path, cx)
338 })?
339 .await?;
340
341 let ranges = buffer.read_with(&mut cx, |buffer, _| {
342 let full_text = buffer.text();
343 let offset = full_text
344 .find(&excerpt.text)
345 .expect("Expected context not found");
346 let point = buffer.offset_to_point(offset);
347 excerpt
348 .required_lines
349 .iter()
350 .map(|line| {
351 let row = point.row + line.0;
352 let range = Point::new(row, 0)..Point::new(row + 1, 0);
353 buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
354 })
355 .collect()
356 })?;
357
358 Ok((buffer, ranges))
359}
360
/// Eval cache that persists entries in the shared CACHE_DIR and hard-links the
/// entries used by a run into that run's directory for later inspection.
struct RunCache {
    // Reuse policy for cached LLM responses and search results.
    cache_mode: CacheMode,
    // Per-example output directory of this run; receives hard links to the
    // cache files that were read or written.
    example_run_dir: PathBuf,
}
365
366impl RunCache {
367 fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
368 CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
369 }
370
371 fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
372 CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
373 }
374
375 fn link_to_run(&self, key: &EvalCacheKey) {
376 let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
377 fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
378
379 let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
380 fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
381 }
382}
383
384impl EvalCache for RunCache {
385 fn read(&self, key: EvalCacheKey) -> Option<String> {
386 let path = RunCache::output_cache_path(&key);
387
388 if path.exists() {
389 let use_cache = match key.0 {
390 EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
391 EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
392 self.cache_mode.use_cached_llm_responses()
393 }
394 };
395 if use_cache {
396 log::info!("Using cache entry: {}", path.display());
397 self.link_to_run(&key);
398 Some(fs::read_to_string(path).unwrap())
399 } else {
400 log::info!("Skipping cached entry: {}", path.display());
401 None
402 }
403 } else if matches!(self.cache_mode, CacheMode::Force) {
404 panic!(
405 "No cached entry found for {:?}. Run without `--cache force` at least once.",
406 key.0
407 );
408 } else {
409 None
410 }
411 }
412
413 fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
414 fs::create_dir_all(&*CACHE_DIR).unwrap();
415
416 let input_path = RunCache::input_cache_path(&key);
417 fs::write(&input_path, input).unwrap();
418
419 let output_path = RunCache::output_cache_path(&key);
420 log::info!("Writing cache entry: {}", output_path.display());
421 fs::write(&output_path, output).unwrap();
422
423 self.link_to_run(&key);
424 }
425}
426
/// Everything captured about a single prediction run: the retrieved context,
/// the resulting diff, and timing breakdowns.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    // Unified diff of the predicted edits (empty when no prediction was made).
    pub diff: String,
    // Context excerpts included in the prediction request.
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    // Time spent generating search queries (None when using expected context).
    pub planning_search_time: Option<Duration>,
    // Time spent executing the searches (None when using expected context).
    pub running_search_time: Option<Duration>,
    // Time from prediction request to model response.
    pub prediction_time: Duration,
    // Wall-clock time from the start of the run to the prediction response.
    pub total_time: Duration,
}
437
438impl PredictionDetails {
439 pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
440 let formatted = match format {
441 PredictionsOutputFormat::Md => self.to_markdown(),
442 PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
443 PredictionsOutputFormat::Diff => self.diff.clone(),
444 };
445
446 Ok(out.write_all(formatted.as_bytes())?)
447 }
448
449 pub fn to_markdown(&self) -> String {
450 let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
451
452 format!(
453 "## Excerpts\n\n\
454 {}\n\n\
455 ## Prediction\n\n\
456 {}\n\n\
457 ## Time\n\n\
458 Planning searches: {}ms\n\
459 Running searches: {}ms\n\
460 Making Prediction: {}ms\n\n\
461 -------------------\n\n\
462 Total: {}ms\n\
463 Inference: {}ms ({:.2}%)\n",
464 self.excerpts_text,
465 self.diff,
466 self.planning_search_time.unwrap_or_default().as_millis(),
467 self.running_search_time.unwrap_or_default().as_millis(),
468 self.prediction_time.as_millis(),
469 self.total_time.as_millis(),
470 inference_time.as_millis(),
471 (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
472 )
473 }
474}