use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
use crate::{
    CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
};
use ::serde::Serialize;
use anyhow::{Context, Result, anyhow};
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use collections::HashMap;
use futures::StreamExt as _;
use gpui::{AppContext, AsyncApp, Entity};
use language::{Anchor, Buffer, Point};
use project::Project;
use project::buffer_store::BufferStoreEvent;
use serde::Deserialize;
use std::fs;
use std::io::{IsTerminal, Write};
use std::ops::Range;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use zeta::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};

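/// Loads the named example, applies its edit history, requests a single prediction, and writes
/// the result to stdout in the requested output format.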
pub async fn run_predict(
    args: PredictArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    let example = NamedExample::load(args.example_path).unwrap();
    let project = example.setup_project(app_state, cx).await.unwrap();
    let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap();
    let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
    let result = perform_predict(example, project, zeta, None, args.options, cx)
        .await
        .unwrap();
    result.write(args.format, std::io::stdout()).unwrap();

    print_run_data_dir(true, std::io::stdout().is_terminal());
}

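/// Creates a `Zeta` entity configured for the selected prediction provider and subscribes to the
/// project's buffer store so newly added buffers are registered with it.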
pub fn setup_zeta(
    provider: PredictionProvider,
    project: &Entity<Project>,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<Entity<Zeta>> {
    let zeta =
        cx.new(|cx| zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;

    zeta.update(cx, |zeta, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => zeta::ZetaEditPredictionModel::Zeta1,
            PredictionProvider::Zeta2 => zeta::ZetaEditPredictionModel::Zeta2,
            PredictionProvider::Sweep => zeta::ZetaEditPredictionModel::Sweep,
        };
        zeta.set_edit_prediction_model(model);
    })?;

    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;

    cx.subscribe(&buffer_store, {
        let project = project.clone();
        let zeta = zeta.clone();
        move |_, event, cx| match event {
            BufferStoreEvent::BufferAdded(buffer) => {
                zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
            }
            _ => {}
        }
    })?
    .detach();

    anyhow::Ok(zeta)
}

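/// Runs a single prediction for `example` (optionally as repetition `repetition_ix`), recording
/// prompts, search queries, and timings under the example's run directory.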
pub async fn perform_predict(
    example: NamedExample,
    project: Entity<Project>,
    zeta: Entity<Zeta>,
    repetition_ix: Option<u16>,
    options: PredictionOptions,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    // Repetitions must not replay cached responses: only the Auto and Skip cache modes are
    // accepted, and both resolve to Skip. A single run resolves Auto to Requests.
    let mut cache_mode = options.cache;
    if repetition_ix.is_some() {
        if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
            panic!("Repetitions are only supported with the Auto or Skip cache modes");
        } else {
            cache_mode = CacheMode::Skip;
        }
    } else if cache_mode == CacheMode::Auto {
        cache_mode = CacheMode::Requests;
    }

    let mut example_run_dir = RUN_DIR.join(&example.file_name());
    if let Some(repetition_ix) = repetition_ix {
        example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
    }
    fs::create_dir_all(&example_run_dir)?;
    if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
        fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
    }

    #[cfg(unix)]
    std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
        .context("creating latest link")?;

    #[cfg(windows)]
    std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
        .context("creating latest link")?;

    zeta.update(cx, |zeta, _cx| {
        zeta.with_eval_cache(Arc::new(RunCache {
            example_run_dir: example_run_dir.clone(),
            cache_mode,
        }));
    })?;

    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));

    let prompt_format = options.zeta2.prompt_format;

    zeta.update(cx, |zeta, _cx| {
        let mut options = zeta.options().clone();
        options.prompt_format = prompt_format.into();
        zeta.set_options(options);
    })?;

    let mut debug_task = gpui::Task::ready(Ok(()));

    // For Zeta2, listen for debug events so the search prompt, generated queries, and the
    // prediction prompt/response are written to the run directory along with timing data.
    if options.provider == PredictionProvider::Zeta2 {
        let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

        debug_task = cx.background_spawn({
            let result = result.clone();
            async move {
                let mut start_time = None;
                let mut search_queries_generated_at = None;
                let mut search_queries_executed_at = None;
                while let Some(event) = debug_rx.next().await {
                    match event {
                        zeta::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                            start_time = Some(info.timestamp);
                            fs::write(
                                example_run_dir.join("search_prompt.md"),
                                &info.search_prompt,
                            )?;
                        }
                        zeta::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                            search_queries_generated_at = Some(info.timestamp);
                            fs::write(
                                example_run_dir.join("search_queries.json"),
                                serde_json::to_string_pretty(&info.search_queries).unwrap(),
                            )?;
                        }
                        zeta::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                            search_queries_executed_at = Some(info.timestamp);
                        }
                        zeta::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
                        zeta::ZetaDebugInfo::EditPredictionRequested(request) => {
                            let prediction_started_at = Instant::now();
                            start_time.get_or_insert(prediction_started_at);
                            let prompt = request.local_prompt.unwrap_or_default();
                            fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;

                            {
                                let mut result = result.lock().unwrap();
                                result.prompt_len = prompt.chars().count();

                                for included_file in request.inputs.included_files {
                                    let insertions =
                                        vec![(request.inputs.cursor_point, CURSOR_MARKER)];
                                    result.excerpts.extend(included_file.excerpts.iter().map(
                                        |excerpt| ActualExcerpt {
                                            path: included_file.path.components().skip(1).collect(),
                                            text: String::from(excerpt.text.as_ref()),
                                        },
                                    ));
                                    write_codeblock(
                                        &included_file.path,
                                        included_file.excerpts.iter(),
                                        if included_file.path == request.inputs.cursor_path {
                                            &insertions
                                        } else {
                                            &[]
                                        },
                                        included_file.max_row,
                                        false,
                                        &mut result.excerpts_text,
                                    );
                                }
                            }

                            let response =
                                request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                            let response = zeta::text_from_response(response).unwrap_or_default();
                            let prediction_finished_at = Instant::now();
                            fs::write(example_run_dir.join("prediction_response.md"), &response)?;

                            let mut result = result.lock().unwrap();
                            result.generated_len = response.chars().count();

                            if !options.use_expected_context {
                                result.planning_search_time = Some(
                                    search_queries_generated_at.unwrap() - start_time.unwrap(),
                                );
                                result.running_search_time = Some(
                                    search_queries_executed_at.unwrap()
                                        - search_queries_generated_at.unwrap(),
                                );
                            }
                            result.prediction_time = prediction_finished_at - prediction_started_at;
                            result.total_time = prediction_finished_at - start_time.unwrap();

                            break;
                        }
                    }
                }
                anyhow::Ok(())
            }
        });

        // Either seed Zeta's context directly from the example's expected excerpts, or let it
        // run its own context retrieval from the cursor position.
        if options.use_expected_context {
            let context_excerpts_tasks = example
                .example
                .expected_context
                .iter()
                .flat_map(|section| {
                    section.alternatives[0].excerpts.iter().map(|excerpt| {
                        resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
                    })
                })
                .collect::<Vec<_>>();
            let context_excerpts_vec =
                futures::future::try_join_all(context_excerpts_tasks).await?;

            let mut context_excerpts = HashMap::default();
            for (buffer, mut excerpts) in context_excerpts_vec {
                context_excerpts
                    .entry(buffer)
                    .or_insert(Vec::new())
                    .append(&mut excerpts);
            }

            zeta.update(cx, |zeta, _cx| {
                zeta.set_context(project.clone(), context_excerpts)
            })?;
        } else {
            zeta.update(cx, |zeta, cx| {
                zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
            })?
            .await?;
        }
    }

    let prediction = zeta
        .update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
        })?
        .await?;

    debug_task.await?;

    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();

    result.diff = prediction
        .and_then(|prediction| prediction.edit_preview.as_unified_diff(&prediction.edits))
        .unwrap_or_default();

    anyhow::Ok(result)
}

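/// Opens the buffer referenced by an expected excerpt, locates the excerpt text within it, and
/// resolves the excerpt's required lines to anchor ranges.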
async fn resolve_context_entry(
    project: Entity<Project>,
    excerpt: ExpectedExcerpt,
    mut cx: AsyncApp,
) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
    let buffer = project
        .update(&mut cx, |project, cx| {
            let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
            project.open_buffer(project_path, cx)
        })?
        .await?;

    let ranges = buffer.read_with(&mut cx, |buffer, _| {
        let full_text = buffer.text();
        let offset = full_text
            .find(&excerpt.text)
            .expect("Expected context not found");
        let point = buffer.offset_to_point(offset);
        excerpt
            .required_lines
            .iter()
            .map(|line| {
                let row = point.row + line.0;
                let range = Point::new(row, 0)..Point::new(row + 1, 0);
                buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
            })
            .collect()
    })?;

    Ok((buffer, ranges))
}

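/// Disk-backed cache of eval request inputs and outputs; entries are stored under `CACHE_DIR`
/// and hard-linked into the current example's run directory for inspection.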
struct RunCache {
    cache_mode: CacheMode,
    example_run_dir: PathBuf,
}

impl RunCache {
    fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
        CACHE_DIR.join(format!("{kind}_out_{key:x}.json"))
    }

    fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
        CACHE_DIR.join(format!("{kind}_in_{key:x}.json"))
    }

    fn link_to_run(&self, key: &EvalCacheKey) {
        let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
        fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();

        let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
        fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
    }
}

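/// Reads honor the configured `CacheMode` per entry kind; writes always persist both the request
/// input and the response output and link them into the run directory.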
impl EvalCache for RunCache {
    fn read(&self, key: EvalCacheKey) -> Option<String> {
        let path = RunCache::output_cache_path(&key);

        if path.exists() {
            let use_cache = match key.0 {
                EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
                EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
                    self.cache_mode.use_cached_llm_responses()
                }
            };
            if use_cache {
                log::info!("Using cache entry: {}", path.display());
                self.link_to_run(&key);
                Some(fs::read_to_string(path).unwrap())
            } else {
                log::trace!("Skipping cached entry: {}", path.display());
                None
            }
        } else if matches!(self.cache_mode, CacheMode::Force) {
            panic!(
                "No cached entry found for {:?}. Run without `--cache force` at least once.",
                key.0
            );
        } else {
            None
        }
    }

    fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
        fs::create_dir_all(&*CACHE_DIR).unwrap();

        let input_path = RunCache::input_cache_path(&key);
        fs::write(&input_path, input).unwrap();

        let output_path = RunCache::output_cache_path(&key);
        log::trace!("Writing cache entry: {}", output_path.display());
        fs::write(&output_path, output).unwrap();

        self.link_to_run(&key);
    }
}

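/// Data collected for a single prediction run: the resulting diff, the excerpts included in the
/// prompt, timing breakdowns, and prompt/response sizes.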
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictionDetails {
    pub diff: String,
    pub excerpts: Vec<ActualExcerpt>,
    // TODO: contains the worktree root path. Drop this field and compute it on the fly.
    pub excerpts_text: String,
    pub planning_search_time: Option<Duration>,
    pub running_search_time: Option<Duration>,
    pub prediction_time: Duration,
    pub total_time: Duration,
    pub run_example_dir: PathBuf,
    pub prompt_len: usize,
    pub generated_len: usize,
}

impl PredictionDetails {
    pub fn new(run_example_dir: PathBuf) -> Self {
        Self {
            diff: Default::default(),
            excerpts: Default::default(),
            excerpts_text: Default::default(),
            planning_search_time: Default::default(),
            running_search_time: Default::default(),
            prediction_time: Default::default(),
            total_time: Default::default(),
            run_example_dir,
            prompt_len: 0,
            generated_len: 0,
        }
    }

    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
        let formatted = match format {
            PredictionsOutputFormat::Md => self.to_markdown(),
            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
            PredictionsOutputFormat::Diff => self.diff.clone(),
        };

        Ok(out.write_all(formatted.as_bytes())?)
    }

    pub fn to_markdown(&self) -> String {
        let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;

        format!(
            "## Excerpts\n\n\
             {}\n\n\
             ## Prediction\n\n\
             {}\n\n\
             ## Time\n\n\
             Planning searches: {}ms\n\
             Running searches: {}ms\n\
             Making prediction: {}ms\n\n\
             -------------------\n\n\
             Total: {}ms\n\
             Inference: {}ms ({:.2}%)\n",
            self.excerpts_text,
            self.diff,
            self.planning_search_time.unwrap_or_default().as_millis(),
            self.running_search_time.unwrap_or_default().as_millis(),
            self.prediction_time.as_millis(),
            self.total_time.as_millis(),
            inference_time.as_millis(),
            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
        )
    }
}