1use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
2use crate::headless::ZetaCliAppState;
3use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
4use crate::{
5 CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
6};
7use ::serde::Serialize;
8use anyhow::{Context, Result, anyhow};
9use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
10use collections::HashMap;
11use futures::StreamExt as _;
12use gpui::{AppContext, AsyncApp, Entity};
13use language::{Anchor, Buffer, Point};
14use project::Project;
15use project::buffer_store::BufferStoreEvent;
16use serde::Deserialize;
17use std::fs;
18use std::io::{IsTerminal, Write};
19use std::ops::Range;
20use std::path::PathBuf;
21use std::sync::Arc;
22use std::sync::Mutex;
23use std::time::{Duration, Instant};
24use sweep_ai::SweepAi;
25use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
26
27pub async fn run_predict(
28 args: PredictArguments,
29 app_state: &Arc<ZetaCliAppState>,
30 cx: &mut AsyncApp,
31) {
32 let example = NamedExample::load(args.example_path).unwrap();
33 let project = example.setup_project(app_state, cx).await.unwrap();
34 let zeta = setup_zeta(&project, app_state, cx).unwrap();
35 let sweep = if matches!(args.options.provider, PredictionProvider::Sweep) {
36 Some(setup_sweep(&project, cx).unwrap())
37 } else {
38 None
39 };
40 let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
41 let result = perform_predict(example, project, zeta, sweep, None, args.options, cx)
42 .await
43 .unwrap();
44 result.write(args.format, std::io::stdout()).unwrap();
45
46 print_run_data_dir(true, std::io::stdout().is_terminal());
47}
48
49pub fn setup_zeta(
50 project: &Entity<Project>,
51 app_state: &Arc<ZetaCliAppState>,
52 cx: &mut AsyncApp,
53) -> Result<Entity<Zeta>> {
54 let zeta =
55 cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
56
57 let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
58
59 cx.subscribe(&buffer_store, {
60 let project = project.clone();
61 let zeta = zeta.clone();
62 move |_, event, cx| match event {
63 BufferStoreEvent::BufferAdded(buffer) => {
64 zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
65 }
66 _ => {}
67 }
68 })?
69 .detach();
70
71 anyhow::Ok(zeta)
72}
73
74pub fn setup_sweep(project: &Entity<Project>, cx: &mut AsyncApp) -> Result<Entity<SweepAi>> {
75 let sweep = cx.new(|cx| SweepAi::new(cx))?;
76
77 let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
78
79 cx.subscribe(&buffer_store, {
80 let project = project.clone();
81 let sweep = sweep.clone();
82 move |_, event, cx| match event {
83 BufferStoreEvent::BufferAdded(buffer) => {
84 sweep.update(cx, |sweep, cx| sweep.register_buffer(&buffer, &project, cx));
85 }
86 _ => {}
87 }
88 })?
89 .detach();
90
91 anyhow::Ok(sweep)
92}
93
94pub async fn perform_predict(
95 example: NamedExample,
96 project: Entity<Project>,
97 zeta: Entity<Zeta>,
98 sweep: Option<Entity<SweepAi>>,
99 repetition_ix: Option<u16>,
100 options: PredictionOptions,
101 cx: &mut AsyncApp,
102) -> Result<PredictionDetails> {
103 let mut cache_mode = options.cache;
104 if repetition_ix.is_some() {
105 if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
106 panic!("Repetitions are not supported in Auto cache mode");
107 } else {
108 cache_mode = CacheMode::Skip;
109 }
110 } else if cache_mode == CacheMode::Auto {
111 cache_mode = CacheMode::Requests;
112 }
113
114 let mut example_run_dir = RUN_DIR.join(&example.file_name());
115 if let Some(repetition_ix) = repetition_ix {
116 example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
117 }
118 fs::create_dir_all(&example_run_dir)?;
119 if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
120 fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
121 }
122
123 #[cfg(unix)]
124 std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
125 .context("creating latest link")?;
126
127 #[cfg(windows)]
128 std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
129 .context("creating latest link")?;
130
131 zeta.update(cx, |zeta, _cx| {
132 zeta.with_eval_cache(Arc::new(RunCache {
133 example_run_dir: example_run_dir.clone(),
134 cache_mode,
135 }));
136 })?;
137
138 let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
139
140 let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
141
142 let prompt_format = options.zeta2.prompt_format;
143
144 zeta.update(cx, |zeta, _cx| {
145 let mut options = zeta.options().clone();
146 options.prompt_format = prompt_format.into();
147 zeta.set_options(options);
148 })?;
149
150 let prediction = match options.provider {
151 crate::PredictionProvider::Zeta2 => {
152 let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
153
154 let debug_task = cx.background_spawn({
155 let result = result.clone();
156 async move {
157 let mut start_time = None;
158 let mut search_queries_generated_at = None;
159 let mut search_queries_executed_at = None;
160 while let Some(event) = debug_rx.next().await {
161 match event {
162 zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
163 start_time = Some(info.timestamp);
164 fs::write(
165 example_run_dir.join("search_prompt.md"),
166 &info.search_prompt,
167 )?;
168 }
169 zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
170 search_queries_generated_at = Some(info.timestamp);
171 fs::write(
172 example_run_dir.join("search_queries.json"),
173 serde_json::to_string_pretty(&info.search_queries).unwrap(),
174 )?;
175 }
176 zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
177 search_queries_executed_at = Some(info.timestamp);
178 }
179 zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
180 zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
181 let prediction_started_at = Instant::now();
182 start_time.get_or_insert(prediction_started_at);
183 let prompt = request.local_prompt.unwrap_or_default();
184 fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
185
186 {
187 let mut result = result.lock().unwrap();
188 result.prompt_len = prompt.chars().count();
189
190 for included_file in request.request.included_files {
191 let insertions =
192 vec![(request.request.cursor_point, CURSOR_MARKER)];
193 result.excerpts.extend(included_file.excerpts.iter().map(
194 |excerpt| {
195 ActualExcerpt {
196 path: included_file
197 .path
198 .components()
199 .skip(1)
200 .collect(),
201 text: String::from(excerpt.text.as_ref()),
202 }
203 },
204 ));
205 write_codeblock(
206 &included_file.path,
207 included_file.excerpts.iter(),
208 if included_file.path == request.request.excerpt_path {
209 &insertions
210 } else {
211 &[]
212 },
213 included_file.max_row,
214 false,
215 &mut result.excerpts_text,
216 );
217 }
218 }
219
220 let response =
221 request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
222 let response =
223 zeta2::text_from_response(response).unwrap_or_default();
224 let prediction_finished_at = Instant::now();
225 fs::write(
226 example_run_dir.join("prediction_response.md"),
227 &response,
228 )?;
229
230 let mut result = result.lock().unwrap();
231 result.generated_len = response.chars().count();
232
233 if !options.use_expected_context {
234 result.planning_search_time = Some(
235 search_queries_generated_at.unwrap() - start_time.unwrap(),
236 );
237 result.running_search_time = Some(
238 search_queries_executed_at.unwrap()
239 - search_queries_generated_at.unwrap(),
240 );
241 }
242 result.prediction_time =
243 prediction_finished_at - prediction_started_at;
244 result.total_time = prediction_finished_at - start_time.unwrap();
245
246 break;
247 }
248 }
249 }
250 anyhow::Ok(())
251 }
252 });
253
254 if options.use_expected_context {
255 let context_excerpts_tasks = example
256 .example
257 .expected_context
258 .iter()
259 .flat_map(|section| {
260 section.alternatives[0].excerpts.iter().map(|excerpt| {
261 resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
262 })
263 })
264 .collect::<Vec<_>>();
265 let context_excerpts_vec =
266 futures::future::try_join_all(context_excerpts_tasks).await?;
267
268 let mut context_excerpts = HashMap::default();
269 for (buffer, mut excerpts) in context_excerpts_vec {
270 context_excerpts
271 .entry(buffer)
272 .or_insert(Vec::new())
273 .append(&mut excerpts);
274 }
275
276 zeta.update(cx, |zeta, _cx| {
277 zeta.set_context(project.clone(), context_excerpts)
278 })?;
279 } else {
280 zeta.update(cx, |zeta, cx| {
281 zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
282 })?
283 .await?;
284 }
285
286 let prediction = zeta
287 .update(cx, |zeta, cx| {
288 zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
289 })?
290 .await?
291 .map(|prediction| (prediction.buffer, prediction.snapshot, prediction.edits));
292
293 debug_task.await?;
294
295 prediction
296 }
297 crate::PredictionProvider::Sweep => sweep
298 .unwrap()
299 .update(cx, |sweep, cx| {
300 let mut recent_paths = Vec::new();
301 for path in zeta
302 .read(cx)
303 .history_for_project(&project)
304 .rev()
305 .filter_map(|event| event.project_path(cx))
306 {
307 if !recent_paths.contains(&path) {
308 recent_paths.push(path);
309 }
310 }
311
312 sweep.request_completion(
313 &project,
314 recent_paths.into_iter(),
315 &cursor_buffer,
316 cursor_anchor,
317 cx,
318 )
319 })?
320 .await?
321 .map(
322 |sweep_ai::EditPrediction {
323 edits, snapshot, ..
324 }| { (cursor_buffer.clone(), snapshot, edits) },
325 ),
326 };
327
328 let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
329
330 result.diff = prediction
331 .map(|(buffer, snapshot, edits)| {
332 let old_text = snapshot.text();
333 let new_text = buffer
334 .update(cx, |buffer, cx| {
335 let branch = buffer.branch(cx);
336 branch.update(cx, |branch, cx| {
337 branch.edit(edits.iter().cloned(), None, cx);
338 branch.text()
339 })
340 })
341 .unwrap();
342 language::unified_diff(&old_text, &new_text)
343 })
344 .unwrap_or_default();
345
346 anyhow::Ok(result)
347}
348
349async fn resolve_context_entry(
350 project: Entity<Project>,
351 excerpt: ExpectedExcerpt,
352 mut cx: AsyncApp,
353) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
354 let buffer = project
355 .update(&mut cx, |project, cx| {
356 let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
357 project.open_buffer(project_path, cx)
358 })?
359 .await?;
360
361 let ranges = buffer.read_with(&mut cx, |buffer, _| {
362 let full_text = buffer.text();
363 let offset = full_text
364 .find(&excerpt.text)
365 .expect("Expected context not found");
366 let point = buffer.offset_to_point(offset);
367 excerpt
368 .required_lines
369 .iter()
370 .map(|line| {
371 let row = point.row + line.0;
372 let range = Point::new(row, 0)..Point::new(row + 1, 0);
373 buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
374 })
375 .collect()
376 })?;
377
378 Ok((buffer, ranges))
379}
380
/// Run-scoped `EvalCache` backend: entries live in the shared `CACHE_DIR`,
/// and every entry read or written during a run is hard-linked into that
/// run's output directory for inspection.
struct RunCache {
    // Decides when cached search results / LLM responses may be reused.
    cache_mode: CacheMode,
    // Output directory of the current example run; used entries are linked here.
    example_run_dir: PathBuf,
}
385
386impl RunCache {
387 fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
388 CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
389 }
390
391 fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
392 CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
393 }
394
395 fn link_to_run(&self, key: &EvalCacheKey) {
396 let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
397 fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
398
399 let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
400 fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
401 }
402}
403
404impl EvalCache for RunCache {
405 fn read(&self, key: EvalCacheKey) -> Option<String> {
406 let path = RunCache::output_cache_path(&key);
407
408 if path.exists() {
409 let use_cache = match key.0 {
410 EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
411 EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
412 self.cache_mode.use_cached_llm_responses()
413 }
414 };
415 if use_cache {
416 log::info!("Using cache entry: {}", path.display());
417 self.link_to_run(&key);
418 Some(fs::read_to_string(path).unwrap())
419 } else {
420 log::trace!("Skipping cached entry: {}", path.display());
421 None
422 }
423 } else if matches!(self.cache_mode, CacheMode::Force) {
424 panic!(
425 "No cached entry found for {:?}. Run without `--cache force` at least once.",
426 key.0
427 );
428 } else {
429 None
430 }
431 }
432
433 fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
434 fs::create_dir_all(&*CACHE_DIR).unwrap();
435
436 let input_path = RunCache::input_cache_path(&key);
437 fs::write(&input_path, input).unwrap();
438
439 let output_path = RunCache::output_cache_path(&key);
440 log::trace!("Writing cache entry: {}", output_path.display());
441 fs::write(&output_path, output).unwrap();
442
443 self.link_to_run(&key);
444 }
445}
446
/// Artifacts and timings accumulated over a single prediction run.
///
/// NOTE(review): field order is visible in the serialized JSON output
/// (`serde` emits fields in declaration order), so keep it stable.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictionDetails {
    /// Unified diff of the predicted edits against the pre-edit snapshot.
    pub diff: String,
    /// Context excerpts that were included in the prediction prompt.
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    /// Time spent generating search queries; `None` when expected context was
    /// supplied and retrieval was skipped.
    pub planning_search_time: Option<Duration>,
    /// Time spent executing search queries; `None` when retrieval was skipped.
    pub running_search_time: Option<Duration>,
    /// Time from the prediction request to its response.
    pub prediction_time: Duration,
    /// Time from the first recorded debug event to the prediction response.
    pub total_time: Duration,
    /// Directory where this run's artifacts were written.
    pub run_example_dir: PathBuf,
    /// Prediction prompt length, in characters.
    pub prompt_len: usize,
    /// Model response length, in characters.
    pub generated_len: usize,
}
460
461impl PredictionDetails {
462 pub fn new(run_example_dir: PathBuf) -> Self {
463 Self {
464 diff: Default::default(),
465 excerpts: Default::default(),
466 excerpts_text: Default::default(),
467 planning_search_time: Default::default(),
468 running_search_time: Default::default(),
469 prediction_time: Default::default(),
470 total_time: Default::default(),
471 run_example_dir,
472 prompt_len: 0,
473 generated_len: 0,
474 }
475 }
476
477 pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
478 let formatted = match format {
479 PredictionsOutputFormat::Md => self.to_markdown(),
480 PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
481 PredictionsOutputFormat::Diff => self.diff.clone(),
482 };
483
484 Ok(out.write_all(formatted.as_bytes())?)
485 }
486
487 pub fn to_markdown(&self) -> String {
488 let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
489
490 format!(
491 "## Excerpts\n\n\
492 {}\n\n\
493 ## Prediction\n\n\
494 {}\n\n\
495 ## Time\n\n\
496 Planning searches: {}ms\n\
497 Running searches: {}ms\n\
498 Making Prediction: {}ms\n\n\
499 -------------------\n\n\
500 Total: {}ms\n\
501 Inference: {}ms ({:.2}%)\n",
502 self.excerpts_text,
503 self.diff,
504 self.planning_search_time.unwrap_or_default().as_millis(),
505 self.running_search_time.unwrap_or_default().as_millis(),
506 self.prediction_time.as_millis(),
507 self.total_time.as_millis(),
508 inference_time.as_millis(),
509 (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
510 )
511 }
512}