1use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
2use crate::headless::ZetaCliAppState;
3use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
4use crate::{
5 CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
6};
7use ::serde::Serialize;
8use anyhow::{Context, Result, anyhow};
9use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
10use collections::HashMap;
11use futures::StreamExt as _;
12use gpui::{AppContext, AsyncApp, Entity};
13use language::{Anchor, Buffer, Point};
14use project::Project;
15use project::buffer_store::BufferStoreEvent;
16use serde::Deserialize;
17use std::fs;
18use std::io::{IsTerminal, Write};
19use std::ops::Range;
20use std::path::PathBuf;
21use std::sync::Arc;
22use std::sync::Mutex;
23use std::time::{Duration, Instant};
24use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
25
26pub async fn run_predict(
27 args: PredictArguments,
28 app_state: &Arc<ZetaCliAppState>,
29 cx: &mut AsyncApp,
30) {
31 let example = NamedExample::load(args.example_path).unwrap();
32 let project = example.setup_project(app_state, cx).await.unwrap();
33 let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap();
34 let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
35 let result = perform_predict(example, project, zeta, None, args.options, cx)
36 .await
37 .unwrap();
38 result.write(args.format, std::io::stdout()).unwrap();
39
40 print_run_data_dir(true, std::io::stdout().is_terminal());
41}
42
43pub fn setup_zeta(
44 provider: PredictionProvider,
45 project: &Entity<Project>,
46 app_state: &Arc<ZetaCliAppState>,
47 cx: &mut AsyncApp,
48) -> Result<Entity<Zeta>> {
49 let zeta =
50 cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
51
52 zeta.update(cx, |zeta, _cx| {
53 let model = match provider {
54 PredictionProvider::Zeta2 => zeta2::ZetaEditPredictionModel::ZedCloud,
55 PredictionProvider::Sweep => zeta2::ZetaEditPredictionModel::Sweep,
56 };
57 zeta.set_edit_prediction_model(model);
58 })?;
59
60 let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
61
62 cx.subscribe(&buffer_store, {
63 let project = project.clone();
64 let zeta = zeta.clone();
65 move |_, event, cx| match event {
66 BufferStoreEvent::BufferAdded(buffer) => {
67 zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
68 }
69 _ => {}
70 }
71 })?
72 .detach();
73
74 anyhow::Ok(zeta)
75}
76
/// Runs one edit prediction for `example` against `project`, recording the
/// prompts, responses, timings, and context excerpts for the run under
/// `RUN_DIR/<example>/[repetition]`, and returns the collected details.
///
/// When `repetition_ix` is set, output goes into a zero-padded subdirectory
/// and cached responses are skipped so each repetition issues real requests.
pub async fn perform_predict(
    example: NamedExample,
    project: Entity<Project>,
    zeta: Entity<Zeta>,
    repetition_ix: Option<u16>,
    options: PredictionOptions,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    // Resolve the effective cache mode before wiring up the run cache.
    let mut cache_mode = options.cache;
    if repetition_ix.is_some() {
        if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
            // NOTE(review): this branch fires for modes other than Auto/Skip
            // (e.g. Requests, Force), so the message wording looks inverted —
            // confirm the intended text.
            panic!("Repetitions are not supported in Auto cache mode");
        } else {
            cache_mode = CacheMode::Skip;
        }
    } else if cache_mode == CacheMode::Auto {
        cache_mode = CacheMode::Requests;
    }

    // Per-run output directory; repetitions nest under a zero-padded index.
    let mut example_run_dir = RUN_DIR.join(&example.file_name());
    if let Some(repetition_ix) = repetition_ix {
        example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
    }
    fs::create_dir_all(&example_run_dir)?;
    // Re-point the "latest run" symlink at this run's directory.
    if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
        fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
    }

    #[cfg(unix)]
    std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
        .context("creating latest link")?;

    #[cfg(windows)]
    std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
        .context("creating latest link")?;

    // Install the run cache so search/LLM requests can be served from disk
    // (per `cache_mode`) and used entries get linked into the run directory.
    zeta.update(cx, |zeta, _cx| {
        zeta.with_eval_cache(Arc::new(RunCache {
            example_run_dir: example_run_dir.clone(),
            cache_mode,
        }));
    })?;

    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    // Shared between this function and the background debug task below; the
    // task's clone is dropped when it finishes, allowing `Arc::into_inner`
    // at the end.
    let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));

    let prompt_format = options.zeta2.prompt_format;

    zeta.update(cx, |zeta, _cx| {
        let mut options = zeta.options().clone();
        options.prompt_format = prompt_format.into();
        zeta.set_options(options);
    })?;

    let mut debug_task = gpui::Task::ready(Ok(()));

    // Only the Zeta2 provider exposes the debug-event stream used to capture
    // prompts, search queries, and timing information.
    if options.provider == crate::PredictionProvider::Zeta2 {
        let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

        debug_task = cx.background_spawn({
            let result = result.clone();
            async move {
                let mut start_time = None;
                let mut search_queries_generated_at = None;
                let mut search_queries_executed_at = None;
                // Consume debug events until the prediction response arrives;
                // each arm persists its artifact into the run directory.
                while let Some(event) = debug_rx.next().await {
                    match event {
                        zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                            start_time = Some(info.timestamp);
                            fs::write(
                                example_run_dir.join("search_prompt.md"),
                                &info.search_prompt,
                            )?;
                        }
                        zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                            search_queries_generated_at = Some(info.timestamp);
                            fs::write(
                                example_run_dir.join("search_queries.json"),
                                serde_json::to_string_pretty(&info.search_queries).unwrap(),
                            )?;
                        }
                        zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                            search_queries_executed_at = Some(info.timestamp);
                        }
                        zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
                        zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                            let prediction_started_at = Instant::now();
                            // If context retrieval was skipped (expected
                            // context), the prediction request is the first
                            // event, so it also sets the start time.
                            start_time.get_or_insert(prediction_started_at);
                            let prompt = request.local_prompt.unwrap_or_default();
                            fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;

                            {
                                let mut result = result.lock().unwrap();
                                result.prompt_len = prompt.chars().count();

                                for included_file in request.request.included_files {
                                    let insertions =
                                        vec![(request.request.cursor_point, CURSOR_MARKER)];
                                    result.excerpts.extend(included_file.excerpts.iter().map(
                                        |excerpt| ActualExcerpt {
                                            // skip(1) drops the first path
                                            // component — presumably the
                                            // worktree root; confirm.
                                            path: included_file.path.components().skip(1).collect(),
                                            text: String::from(excerpt.text.as_ref()),
                                        },
                                    ));
                                    // Render the excerpt, inserting the cursor
                                    // marker only into the file that contains
                                    // the cursor.
                                    write_codeblock(
                                        &included_file.path,
                                        included_file.excerpts.iter(),
                                        if included_file.path == request.request.excerpt_path {
                                            &insertions
                                        } else {
                                            &[]
                                        },
                                        included_file.max_row,
                                        false,
                                        &mut result.excerpts_text,
                                    );
                                }
                            }

                            let response =
                                request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                            let response = zeta2::text_from_response(response).unwrap_or_default();
                            let prediction_finished_at = Instant::now();
                            fs::write(example_run_dir.join("prediction_response.md"), &response)?;

                            let mut result = result.lock().unwrap();
                            result.generated_len = response.chars().count();

                            // Search timings only exist when a real context
                            // search ran; the unwraps assume the earlier
                            // search events were observed in order.
                            if !options.use_expected_context {
                                result.planning_search_time = Some(
                                    search_queries_generated_at.unwrap() - start_time.unwrap(),
                                );
                                result.running_search_time = Some(
                                    search_queries_executed_at.unwrap()
                                        - search_queries_generated_at.unwrap(),
                                );
                            }
                            result.prediction_time = prediction_finished_at - prediction_started_at;
                            result.total_time = prediction_finished_at - start_time.unwrap();

                            // The prediction response is the final event of
                            // interest for this run.
                            break;
                        }
                    }
                }
                anyhow::Ok(())
            }
        });

        if options.use_expected_context {
            // Resolve the example's hand-written expected context into buffer
            // anchor ranges and hand it to Zeta directly (no search pass).
            let context_excerpts_tasks = example
                .example
                .expected_context
                .iter()
                .flat_map(|section| {
                    section.alternatives[0].excerpts.iter().map(|excerpt| {
                        resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
                    })
                })
                .collect::<Vec<_>>();
            let context_excerpts_vec =
                futures::future::try_join_all(context_excerpts_tasks).await?;

            // Group the resolved ranges by buffer.
            let mut context_excerpts = HashMap::default();
            for (buffer, mut excerpts) in context_excerpts_vec {
                context_excerpts
                    .entry(buffer)
                    .or_insert(Vec::new())
                    .append(&mut excerpts);
            }

            zeta.update(cx, |zeta, _cx| {
                zeta.set_context(project.clone(), context_excerpts)
            })?;
        } else {
            // Otherwise run Zeta's own context search around the cursor.
            zeta.update(cx, |zeta, cx| {
                zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
            })?
            .await?;
        }
    }

    let prediction = zeta
        .update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
        })?
        .await?;

    // Wait for the debug task so all artifacts are written and its clone of
    // `result` is dropped before we unwrap the Arc below.
    debug_task.await?;

    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();

    // Apply the predicted edits on a branch buffer (leaving the real buffer
    // untouched) and diff against the pre-edit snapshot.
    result.diff = prediction
        .map(|prediction| {
            let old_text = prediction.snapshot.text();
            let new_text = prediction
                .buffer
                .update(cx, |buffer, cx| {
                    let branch = buffer.branch(cx);
                    branch.update(cx, |branch, cx| {
                        branch.edit(prediction.edits.iter().cloned(), None, cx);
                        branch.text()
                    })
                })
                .unwrap();
            language::unified_diff(&old_text, &new_text)
        })
        .unwrap_or_default();

    anyhow::Ok(result)
}
288
289async fn resolve_context_entry(
290 project: Entity<Project>,
291 excerpt: ExpectedExcerpt,
292 mut cx: AsyncApp,
293) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
294 let buffer = project
295 .update(&mut cx, |project, cx| {
296 let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
297 project.open_buffer(project_path, cx)
298 })?
299 .await?;
300
301 let ranges = buffer.read_with(&mut cx, |buffer, _| {
302 let full_text = buffer.text();
303 let offset = full_text
304 .find(&excerpt.text)
305 .expect("Expected context not found");
306 let point = buffer.offset_to_point(offset);
307 excerpt
308 .required_lines
309 .iter()
310 .map(|line| {
311 let row = point.row + line.0;
312 let range = Point::new(row, 0)..Point::new(row + 1, 0);
313 buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
314 })
315 .collect()
316 })?;
317
318 Ok((buffer, ranges))
319}
320
/// `EvalCache` backend that persists request inputs and outputs as JSON files
/// under `CACHE_DIR` and hard-links every entry a run actually uses into that
/// run's output directory.
struct RunCache {
    // Governs which entry kinds may be served from cache (see `CacheMode`).
    cache_mode: CacheMode,
    // Per-run directory that used cache entries are hard-linked into.
    example_run_dir: PathBuf,
}
325
326impl RunCache {
327 fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
328 CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
329 }
330
331 fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
332 CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
333 }
334
335 fn link_to_run(&self, key: &EvalCacheKey) {
336 let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
337 fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
338
339 let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
340 fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
341 }
342}
343
344impl EvalCache for RunCache {
345 fn read(&self, key: EvalCacheKey) -> Option<String> {
346 let path = RunCache::output_cache_path(&key);
347
348 if path.exists() {
349 let use_cache = match key.0 {
350 EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
351 EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
352 self.cache_mode.use_cached_llm_responses()
353 }
354 };
355 if use_cache {
356 log::info!("Using cache entry: {}", path.display());
357 self.link_to_run(&key);
358 Some(fs::read_to_string(path).unwrap())
359 } else {
360 log::trace!("Skipping cached entry: {}", path.display());
361 None
362 }
363 } else if matches!(self.cache_mode, CacheMode::Force) {
364 panic!(
365 "No cached entry found for {:?}. Run without `--cache force` at least once.",
366 key.0
367 );
368 } else {
369 None
370 }
371 }
372
373 fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
374 fs::create_dir_all(&*CACHE_DIR).unwrap();
375
376 let input_path = RunCache::input_cache_path(&key);
377 fs::write(&input_path, input).unwrap();
378
379 let output_path = RunCache::output_cache_path(&key);
380 log::trace!("Writing cache entry: {}", output_path.display());
381 fs::write(&output_path, output).unwrap();
382
383 self.link_to_run(&key);
384 }
385}
386
/// Everything captured about a single prediction run: the resulting diff,
/// the context excerpts that were sent, and per-phase timing/size metrics.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictionDetails {
    // Unified diff between the buffer snapshot and the predicted edit.
    pub diff: String,
    // Context excerpts included in the prediction request.
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    // Time spent generating search queries; `None` when expected context was
    // supplied instead of running a search.
    pub planning_search_time: Option<Duration>,
    // Time spent executing the generated searches; `None` as above.
    pub running_search_time: Option<Duration>,
    // Wall time of the prediction request itself.
    pub prediction_time: Duration,
    // Wall time from context-retrieval start to prediction response.
    pub total_time: Duration,
    // Directory where this run's artifacts (prompts, responses) were written.
    pub run_example_dir: PathBuf,
    // Character counts of the prediction prompt and the model response.
    pub prompt_len: usize,
    pub generated_len: usize,
}
400
401impl PredictionDetails {
402 pub fn new(run_example_dir: PathBuf) -> Self {
403 Self {
404 diff: Default::default(),
405 excerpts: Default::default(),
406 excerpts_text: Default::default(),
407 planning_search_time: Default::default(),
408 running_search_time: Default::default(),
409 prediction_time: Default::default(),
410 total_time: Default::default(),
411 run_example_dir,
412 prompt_len: 0,
413 generated_len: 0,
414 }
415 }
416
417 pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
418 let formatted = match format {
419 PredictionsOutputFormat::Md => self.to_markdown(),
420 PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
421 PredictionsOutputFormat::Diff => self.diff.clone(),
422 };
423
424 Ok(out.write_all(formatted.as_bytes())?)
425 }
426
427 pub fn to_markdown(&self) -> String {
428 let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
429
430 format!(
431 "## Excerpts\n\n\
432 {}\n\n\
433 ## Prediction\n\n\
434 {}\n\n\
435 ## Time\n\n\
436 Planning searches: {}ms\n\
437 Running searches: {}ms\n\
438 Making Prediction: {}ms\n\n\
439 -------------------\n\n\
440 Total: {}ms\n\
441 Inference: {}ms ({:.2}%)\n",
442 self.excerpts_text,
443 self.diff,
444 self.planning_search_time.unwrap_or_default().as_millis(),
445 self.running_search_time.unwrap_or_default().as_millis(),
446 self.prediction_time.as_millis(),
447 self.total_time.as_millis(),
448 inference_time.as_millis(),
449 (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
450 )
451 }
452}