1use crate::example::{ActualExcerpt, NamedExample};
2use crate::headless::ZetaCliAppState;
3use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
4use crate::{
5 CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
6};
7use ::serde::Serialize;
8use anyhow::{Context, Result, anyhow};
9use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
10use futures::StreamExt as _;
11use gpui::{AppContext, AsyncApp, Entity};
12use project::Project;
13use project::buffer_store::BufferStoreEvent;
14use serde::Deserialize;
15use std::fs;
16use std::io::{IsTerminal, Write};
17use std::path::PathBuf;
18use std::sync::Arc;
19use std::sync::Mutex;
20use std::time::{Duration, Instant};
21use zeta::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
22
23pub async fn run_predict(
24 args: PredictArguments,
25 app_state: &Arc<ZetaCliAppState>,
26 cx: &mut AsyncApp,
27) {
28 let example = NamedExample::load(args.example_path).unwrap();
29 let project = example.setup_project(app_state, cx).await.unwrap();
30 let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap();
31 let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
32 let result = perform_predict(example, project, zeta, None, args.options, cx)
33 .await
34 .unwrap();
35 result.write(args.format, std::io::stdout()).unwrap();
36
37 print_run_data_dir(true, std::io::stdout().is_terminal());
38}
39
40pub fn setup_zeta(
41 provider: PredictionProvider,
42 project: &Entity<Project>,
43 app_state: &Arc<ZetaCliAppState>,
44 cx: &mut AsyncApp,
45) -> Result<Entity<Zeta>> {
46 let zeta =
47 cx.new(|cx| zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
48
49 zeta.update(cx, |zeta, _cx| {
50 let model = match provider {
51 PredictionProvider::Zeta1 => zeta::ZetaEditPredictionModel::Zeta1,
52 PredictionProvider::Zeta2 => zeta::ZetaEditPredictionModel::Zeta2,
53 PredictionProvider::Sweep => zeta::ZetaEditPredictionModel::Sweep,
54 };
55 zeta.set_edit_prediction_model(model);
56 })?;
57
58 let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
59
60 cx.subscribe(&buffer_store, {
61 let project = project.clone();
62 let zeta = zeta.clone();
63 move |_, event, cx| match event {
64 BufferStoreEvent::BufferAdded(buffer) => {
65 zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
66 }
67 _ => {}
68 }
69 })?
70 .detach();
71
72 anyhow::Ok(zeta)
73}
74
75pub async fn perform_predict(
76 example: NamedExample,
77 project: Entity<Project>,
78 zeta: Entity<Zeta>,
79 repetition_ix: Option<u16>,
80 options: PredictionOptions,
81 cx: &mut AsyncApp,
82) -> Result<PredictionDetails> {
83 let mut cache_mode = options.cache;
84 if repetition_ix.is_some() {
85 if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
86 panic!("Repetitions are not supported in Auto cache mode");
87 } else {
88 cache_mode = CacheMode::Skip;
89 }
90 } else if cache_mode == CacheMode::Auto {
91 cache_mode = CacheMode::Requests;
92 }
93
94 let mut example_run_dir = RUN_DIR.join(&example.file_name());
95 if let Some(repetition_ix) = repetition_ix {
96 example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
97 }
98 fs::create_dir_all(&example_run_dir)?;
99 if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
100 fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
101 }
102
103 #[cfg(unix)]
104 std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
105 .context("creating latest link")?;
106
107 #[cfg(windows)]
108 std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
109 .context("creating latest link")?;
110
111 zeta.update(cx, |zeta, _cx| {
112 zeta.with_eval_cache(Arc::new(RunCache {
113 example_run_dir: example_run_dir.clone(),
114 cache_mode,
115 }));
116 })?;
117
118 let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
119
120 let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
121
122 let prompt_format = options.zeta2.prompt_format;
123
124 zeta.update(cx, |zeta, _cx| {
125 let mut options = zeta.options().clone();
126 options.prompt_format = prompt_format.into();
127 zeta.set_options(options);
128 })?;
129
130 let mut debug_task = gpui::Task::ready(Ok(()));
131
132 if options.provider == crate::PredictionProvider::Zeta2 {
133 let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
134
135 debug_task = cx.background_spawn({
136 let result = result.clone();
137 async move {
138 let mut start_time = None;
139 let mut search_queries_generated_at = None;
140 let mut search_queries_executed_at = None;
141 while let Some(event) = debug_rx.next().await {
142 match event {
143 zeta::ZetaDebugInfo::ContextRetrievalStarted(info) => {
144 start_time = Some(info.timestamp);
145 fs::write(
146 example_run_dir.join("search_prompt.md"),
147 &info.search_prompt,
148 )?;
149 }
150 zeta::ZetaDebugInfo::SearchQueriesGenerated(info) => {
151 search_queries_generated_at = Some(info.timestamp);
152 fs::write(
153 example_run_dir.join("search_queries.json"),
154 serde_json::to_string_pretty(&info.search_queries).unwrap(),
155 )?;
156 }
157 zeta::ZetaDebugInfo::SearchQueriesExecuted(info) => {
158 search_queries_executed_at = Some(info.timestamp);
159 }
160 zeta::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
161 zeta::ZetaDebugInfo::EditPredictionRequested(request) => {
162 let prediction_started_at = Instant::now();
163 start_time.get_or_insert(prediction_started_at);
164 let prompt = request.local_prompt.unwrap_or_default();
165 fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
166
167 {
168 let mut result = result.lock().unwrap();
169 result.prompt_len = prompt.chars().count();
170
171 for included_file in request.inputs.included_files {
172 let insertions =
173 vec![(request.inputs.cursor_point, CURSOR_MARKER)];
174 result.excerpts.extend(included_file.excerpts.iter().map(
175 |excerpt| ActualExcerpt {
176 path: included_file.path.components().skip(1).collect(),
177 text: String::from(excerpt.text.as_ref()),
178 },
179 ));
180 write_codeblock(
181 &included_file.path,
182 included_file.excerpts.iter(),
183 if included_file.path == request.inputs.cursor_path {
184 &insertions
185 } else {
186 &[]
187 },
188 included_file.max_row,
189 false,
190 &mut result.excerpts_text,
191 );
192 }
193 }
194
195 let response =
196 request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
197 let response = zeta::text_from_response(response).unwrap_or_default();
198 let prediction_finished_at = Instant::now();
199 fs::write(example_run_dir.join("prediction_response.md"), &response)?;
200
201 let mut result = result.lock().unwrap();
202 result.generated_len = response.chars().count();
203
204 result.planning_search_time =
205 Some(search_queries_generated_at.unwrap() - start_time.unwrap());
206 result.running_search_time = Some(
207 search_queries_executed_at.unwrap()
208 - search_queries_generated_at.unwrap(),
209 );
210 result.prediction_time = prediction_finished_at - prediction_started_at;
211 result.total_time = prediction_finished_at - start_time.unwrap();
212
213 break;
214 }
215 }
216 }
217 anyhow::Ok(())
218 }
219 });
220
221 zeta.update(cx, |zeta, cx| {
222 zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
223 })?
224 .await?;
225 }
226
227 let prediction = zeta
228 .update(cx, |zeta, cx| {
229 zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
230 })?
231 .await?;
232
233 debug_task.await?;
234
235 let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
236
237 result.diff = prediction
238 .and_then(|prediction| prediction.edit_preview.as_unified_diff(&prediction.edits))
239 .unwrap_or_default();
240
241 anyhow::Ok(result)
242}
243
244struct RunCache {
245 cache_mode: CacheMode,
246 example_run_dir: PathBuf,
247}
248
249impl RunCache {
250 fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
251 CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
252 }
253
254 fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
255 CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
256 }
257
258 fn link_to_run(&self, key: &EvalCacheKey) {
259 let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
260 fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
261
262 let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
263 fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
264 }
265}
266
267impl EvalCache for RunCache {
268 fn read(&self, key: EvalCacheKey) -> Option<String> {
269 let path = RunCache::output_cache_path(&key);
270
271 if path.exists() {
272 let use_cache = match key.0 {
273 EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
274 EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
275 self.cache_mode.use_cached_llm_responses()
276 }
277 };
278 if use_cache {
279 log::info!("Using cache entry: {}", path.display());
280 self.link_to_run(&key);
281 Some(fs::read_to_string(path).unwrap())
282 } else {
283 log::trace!("Skipping cached entry: {}", path.display());
284 None
285 }
286 } else if matches!(self.cache_mode, CacheMode::Force) {
287 panic!(
288 "No cached entry found for {:?}. Run without `--cache force` at least once.",
289 key.0
290 );
291 } else {
292 None
293 }
294 }
295
296 fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
297 fs::create_dir_all(&*CACHE_DIR).unwrap();
298
299 let input_path = RunCache::input_cache_path(&key);
300 fs::write(&input_path, input).unwrap();
301
302 let output_path = RunCache::output_cache_path(&key);
303 log::trace!("Writing cache entry: {}", output_path.display());
304 fs::write(&output_path, output).unwrap();
305
306 self.link_to_run(&key);
307 }
308}
309
/// Outcome of a single prediction run, printed by the CLI as Markdown, JSON or a bare diff.
///
/// NOTE: field order is part of the serialized JSON output; keep it stable.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictionDetails {
    /// Unified diff of the predicted edits; empty when no prediction was produced.
    pub diff: String,
    /// Excerpts that were included in the prediction prompt.
    pub excerpts: Vec<ActualExcerpt>,
    /// The included excerpts rendered as code blocks, with the cursor marker in the cursor's file.
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    /// Time from the start of context retrieval until search queries were generated, if measured.
    pub planning_search_time: Option<Duration>,
    /// Time from search-query generation until the queries finished executing, if measured.
    pub running_search_time: Option<Duration>,
    /// Time from the prediction request until its response was received.
    pub prediction_time: Duration,
    /// Time from the start of the run's measurements until the prediction response was received.
    pub total_time: Duration,
    /// Directory holding this run's artifacts (prompts, responses, linked cache entries).
    pub run_example_dir: PathBuf,
    /// Length of the prediction prompt, in `char`s.
    pub prompt_len: usize,
    /// Length of the model's response text, in `char`s.
    pub generated_len: usize,
}
323
324impl PredictionDetails {
325 pub fn new(run_example_dir: PathBuf) -> Self {
326 Self {
327 diff: Default::default(),
328 excerpts: Default::default(),
329 excerpts_text: Default::default(),
330 planning_search_time: Default::default(),
331 running_search_time: Default::default(),
332 prediction_time: Default::default(),
333 total_time: Default::default(),
334 run_example_dir,
335 prompt_len: 0,
336 generated_len: 0,
337 }
338 }
339
340 pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
341 let formatted = match format {
342 PredictionsOutputFormat::Md => self.to_markdown(),
343 PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
344 PredictionsOutputFormat::Diff => self.diff.clone(),
345 };
346
347 Ok(out.write_all(formatted.as_bytes())?)
348 }
349
350 pub fn to_markdown(&self) -> String {
351 let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
352
353 format!(
354 "## Excerpts\n\n\
355 {}\n\n\
356 ## Prediction\n\n\
357 {}\n\n\
358 ## Time\n\n\
359 Planning searches: {}ms\n\
360 Running searches: {}ms\n\
361 Making Prediction: {}ms\n\n\
362 -------------------\n\n\
363 Total: {}ms\n\
364 Inference: {}ms ({:.2}%)\n",
365 self.excerpts_text,
366 self.diff,
367 self.planning_search_time.unwrap_or_default().as_millis(),
368 self.running_search_time.unwrap_or_default().as_millis(),
369 self.prediction_time.as_millis(),
370 self.total_time.as_millis(),
371 inference_time.as_millis(),
372 (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
373 )
374 }
375}