predict.rs

  1use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
  2use crate::headless::ZetaCliAppState;
  3use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
  4use crate::{
  5    CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
  6};
  7use ::serde::Serialize;
  8use anyhow::{Context, Result, anyhow};
  9use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 10use collections::HashMap;
 11use futures::StreamExt as _;
 12use gpui::{AppContext, AsyncApp, Entity};
 13use language::{Anchor, Buffer, Point};
 14use project::Project;
 15use project::buffer_store::BufferStoreEvent;
 16use serde::Deserialize;
 17use std::fs;
 18use std::io::{IsTerminal, Write};
 19use std::ops::Range;
 20use std::path::PathBuf;
 21use std::sync::Arc;
 22use std::sync::Mutex;
 23use std::time::{Duration, Instant};
 24use sweep_ai::SweepAi;
 25use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
 26
 27pub async fn run_predict(
 28    args: PredictArguments,
 29    app_state: &Arc<ZetaCliAppState>,
 30    cx: &mut AsyncApp,
 31) {
 32    let example = NamedExample::load(args.example_path).unwrap();
 33    let project = example.setup_project(app_state, cx).await.unwrap();
 34    let zeta = setup_zeta(&project, app_state, cx).unwrap();
 35    let sweep = if matches!(args.options.provider, PredictionProvider::Sweep) {
 36        Some(setup_sweep(&project, cx).unwrap())
 37    } else {
 38        None
 39    };
 40    let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
 41    let result = perform_predict(example, project, zeta, sweep, None, args.options, cx)
 42        .await
 43        .unwrap();
 44    result.write(args.format, std::io::stdout()).unwrap();
 45
 46    print_run_data_dir(true, std::io::stdout().is_terminal());
 47}
 48
 49pub fn setup_zeta(
 50    project: &Entity<Project>,
 51    app_state: &Arc<ZetaCliAppState>,
 52    cx: &mut AsyncApp,
 53) -> Result<Entity<Zeta>> {
 54    let zeta =
 55        cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
 56
 57    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
 58
 59    cx.subscribe(&buffer_store, {
 60        let project = project.clone();
 61        let zeta = zeta.clone();
 62        move |_, event, cx| match event {
 63            BufferStoreEvent::BufferAdded(buffer) => {
 64                zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
 65            }
 66            _ => {}
 67        }
 68    })?
 69    .detach();
 70
 71    anyhow::Ok(zeta)
 72}
 73
 74pub fn setup_sweep(project: &Entity<Project>, cx: &mut AsyncApp) -> Result<Entity<SweepAi>> {
 75    let sweep = cx.new(|cx| SweepAi::new(cx))?;
 76
 77    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
 78
 79    cx.subscribe(&buffer_store, {
 80        let project = project.clone();
 81        let sweep = sweep.clone();
 82        move |_, event, cx| match event {
 83            BufferStoreEvent::BufferAdded(buffer) => {
 84                sweep.update(cx, |sweep, cx| sweep.register_buffer(&buffer, &project, cx));
 85            }
 86            _ => {}
 87        }
 88    })?
 89    .detach();
 90
 91    anyhow::Ok(sweep)
 92}
 93
/// Runs a single edit prediction for `example` and collects prompt, timing,
/// and diff details into a `PredictionDetails`.
///
/// Side effects: creates a per-example run directory under `RUN_DIR` (with a
/// zero-padded subdirectory per repetition), refreshes the
/// `LATEST_EXAMPLE_RUN_DIR` symlink, installs a `RunCache` on `zeta`, and
/// writes prompt/response artifacts into the run directory as debug events
/// stream in.
///
/// `sweep` must be `Some` when `options.provider` is `Sweep` (it is
/// unwrapped below). `repetition_ix` identifies repeated runs of the same
/// example and forces `CacheMode::Skip` so repetitions never reuse cached
/// responses.
pub async fn perform_predict(
    example: NamedExample,
    project: Entity<Project>,
    zeta: Entity<Zeta>,
    sweep: Option<Entity<SweepAi>>,
    repetition_ix: Option<u16>,
    options: PredictionOptions,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    // Resolve the effective cache mode: repetitions require skipping cached
    // responses, and plain `Auto` degrades to request-level caching.
    let mut cache_mode = options.cache;
    if repetition_ix.is_some() {
        if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
            // NOTE(review): this arm fires when the mode is *not* Auto/Skip,
            // yet the message blames "Auto cache mode" — looks inverted;
            // confirm the intended wording/condition.
            panic!("Repetitions are not supported in Auto cache mode");
        } else {
            cache_mode = CacheMode::Skip;
        }
    } else if cache_mode == CacheMode::Auto {
        cache_mode = CacheMode::Requests;
    }

    // Per-example run directory; repetitions get `000`-style subdirectories.
    let mut example_run_dir = RUN_DIR.join(&example.file_name());
    if let Some(repetition_ix) = repetition_ix {
        example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
    }
    fs::create_dir_all(&example_run_dir)?;
    // Re-point the "latest run" symlink at this run's directory.
    if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
        fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
    }

    #[cfg(unix)]
    std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
        .context("creating latest link")?;

    #[cfg(windows)]
    std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
        .context("creating latest link")?;

    // Install the per-run cache so search/LLM requests are reused or recorded
    // according to `cache_mode`, and linked into the run directory.
    zeta.update(cx, |zeta, _cx| {
        zeta.with_eval_cache(Arc::new(RunCache {
            example_run_dir: example_run_dir.clone(),
            cache_mode,
        }));
    })?;

    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    // Shared with the background debug task below, which fills in timing,
    // prompt, and excerpt details as events arrive.
    let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));

    let prompt_format = options.zeta2.prompt_format;

    zeta.update(cx, |zeta, _cx| {
        let mut options = zeta.options().clone();
        options.prompt_format = prompt_format.into();
        zeta.set_options(options);
    })?;

    let prediction = match options.provider {
        crate::PredictionProvider::Zeta2 => {
            let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

            // Drain Zeta's debug event stream in the background: persist the
            // search/prediction prompts and responses to the run directory and
            // record timings into `result`. The loop ends after the first
            // completed prediction request.
            let debug_task = cx.background_spawn({
                let result = result.clone();
                async move {
                    let mut start_time = None;
                    let mut search_queries_generated_at = None;
                    let mut search_queries_executed_at = None;
                    while let Some(event) = debug_rx.next().await {
                        match event {
                            zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                                start_time = Some(info.timestamp);
                                fs::write(
                                    example_run_dir.join("search_prompt.md"),
                                    &info.search_prompt,
                                )?;
                            }
                            zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                                search_queries_generated_at = Some(info.timestamp);
                                fs::write(
                                    example_run_dir.join("search_queries.json"),
                                    serde_json::to_string_pretty(&info.search_queries).unwrap(),
                                )?;
                            }
                            zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                                search_queries_executed_at = Some(info.timestamp);
                            }
                            zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
                            zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                                let prediction_started_at = Instant::now();
                                // With expected context, retrieval is skipped,
                                // so the prediction request is the first event.
                                start_time.get_or_insert(prediction_started_at);
                                let prompt = request.local_prompt.unwrap_or_default();
                                fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;

                                // Scoped lock: record prompt length and the
                                // included excerpts (with the cursor marker
                                // inserted into the excerpt file) before
                                // awaiting the response.
                                {
                                    let mut result = result.lock().unwrap();
                                    result.prompt_len = prompt.chars().count();

                                    for included_file in request.request.included_files {
                                        let insertions =
                                            vec![(request.request.cursor_point, CURSOR_MARKER)];
                                        result.excerpts.extend(included_file.excerpts.iter().map(
                                            |excerpt| {
                                                ActualExcerpt {
                                                    // Drop the first path component
                                                    // (presumably the worktree root —
                                                    // confirm against NamedExample).
                                                    path: included_file
                                                        .path
                                                        .components()
                                                        .skip(1)
                                                        .collect(),
                                                    text: String::from(excerpt.text.as_ref()),
                                                }
                                            },
                                        ));
                                        write_codeblock(
                                            &included_file.path,
                                            included_file.excerpts.iter(),
                                            if included_file.path == request.request.excerpt_path {
                                                &insertions
                                            } else {
                                                &[]
                                            },
                                            included_file.max_row,
                                            false,
                                            &mut result.excerpts_text,
                                        );
                                    }
                                }

                                let response =
                                    request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                                let response =
                                    zeta2::text_from_response(response).unwrap_or_default();
                                let prediction_finished_at = Instant::now();
                                fs::write(
                                    example_run_dir.join("prediction_response.md"),
                                    &response,
                                )?;

                                let mut result = result.lock().unwrap();
                                result.generated_len = response.chars().count();

                                // Search timings only exist when the context
                                // was actually retrieved (not pre-supplied).
                                if !options.use_expected_context {
                                    result.planning_search_time = Some(
                                        search_queries_generated_at.unwrap() - start_time.unwrap(),
                                    );
                                    result.running_search_time = Some(
                                        search_queries_executed_at.unwrap()
                                            - search_queries_generated_at.unwrap(),
                                    );
                                }
                                result.prediction_time =
                                    prediction_finished_at - prediction_started_at;
                                result.total_time = prediction_finished_at - start_time.unwrap();

                                break;
                            }
                        }
                    }
                    anyhow::Ok(())
                }
            });

            if options.use_expected_context {
                // Resolve the expected-context excerpts (first alternative of
                // each section) to concrete buffers and anchor ranges.
                let context_excerpts_tasks = example
                    .example
                    .expected_context
                    .iter()
                    .flat_map(|section| {
                        section.alternatives[0].excerpts.iter().map(|excerpt| {
                            resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
                        })
                    })
                    .collect::<Vec<_>>();
                let context_excerpts_vec =
                    futures::future::try_join_all(context_excerpts_tasks).await?;

                // Group resolved ranges by buffer.
                let mut context_excerpts = HashMap::default();
                for (buffer, mut excerpts) in context_excerpts_vec {
                    context_excerpts
                        .entry(buffer)
                        .or_insert(Vec::new())
                        .append(&mut excerpts);
                }

                zeta.update(cx, |zeta, _cx| {
                    zeta.set_context(project.clone(), context_excerpts)
                })?;
            } else {
                // No pre-supplied context: let Zeta search for it around the
                // cursor position.
                zeta.update(cx, |zeta, cx| {
                    zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
                })?
                .await?;
            }

            let prediction = zeta
                .update(cx, |zeta, cx| {
                    zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
                })?
                .await?
                .map(|prediction| (prediction.buffer, prediction.snapshot, prediction.edits));

            // Wait for the debug task so `result` has its timings filled in
            // and the Arc below has a single owner again.
            debug_task.await?;

            prediction
        }
        crate::PredictionProvider::Sweep => sweep
            .unwrap()
            .update(cx, |sweep, cx| {
                // Most-recently-edited project paths, deduplicated, newest first.
                let mut recent_paths = Vec::new();
                for path in zeta
                    .read(cx)
                    .history_for_project(&project)
                    .rev()
                    .filter_map(|event| event.project_path(cx))
                {
                    if !recent_paths.contains(&path) {
                        recent_paths.push(path);
                    }
                }

                sweep.request_completion(
                    &project,
                    recent_paths.into_iter(),
                    &cursor_buffer,
                    cursor_anchor,
                    cx,
                )
            })?
            .await?
            .map(
                |sweep_ai::EditPrediction {
                     edits, snapshot, ..
                 }| { (cursor_buffer.clone(), snapshot, edits) },
            ),
    };

    // The debug task has finished (or was never created), so this is the only
    // remaining reference to `result`.
    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();

    // Apply the predicted edits to a throwaway branch of the buffer and diff
    // against the snapshot the prediction was made from.
    result.diff = prediction
        .map(|(buffer, snapshot, edits)| {
            let old_text = snapshot.text();
            let new_text = buffer
                .update(cx, |buffer, cx| {
                    let branch = buffer.branch(cx);
                    branch.update(cx, |branch, cx| {
                        branch.edit(edits.iter().cloned(), None, cx);
                        branch.text()
                    })
                })
                .unwrap();
            language::unified_diff(&old_text, &new_text)
        })
        .unwrap_or_default();

    anyhow::Ok(result)
}
348
349async fn resolve_context_entry(
350    project: Entity<Project>,
351    excerpt: ExpectedExcerpt,
352    mut cx: AsyncApp,
353) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
354    let buffer = project
355        .update(&mut cx, |project, cx| {
356            let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
357            project.open_buffer(project_path, cx)
358        })?
359        .await?;
360
361    let ranges = buffer.read_with(&mut cx, |buffer, _| {
362        let full_text = buffer.text();
363        let offset = full_text
364            .find(&excerpt.text)
365            .expect("Expected context not found");
366        let point = buffer.offset_to_point(offset);
367        excerpt
368            .required_lines
369            .iter()
370            .map(|line| {
371                let row = point.row + line.0;
372                let range = Point::new(row, 0)..Point::new(row + 1, 0);
373                buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
374            })
375            .collect()
376    })?;
377
378    Ok((buffer, ranges))
379}
380
/// Per-run `EvalCache`: entries are stored in the global `CACHE_DIR` and
/// hard-linked into the example's run directory for later inspection.
struct RunCache {
    // Decides whether cached search results / LLM responses may be reused.
    cache_mode: CacheMode,
    // Run directory that cache entries get hard-linked into.
    example_run_dir: PathBuf,
}
385
386impl RunCache {
387    fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
388        CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
389    }
390
391    fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
392        CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
393    }
394
395    fn link_to_run(&self, key: &EvalCacheKey) {
396        let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
397        fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
398
399        let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
400        fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
401    }
402}
403
404impl EvalCache for RunCache {
405    fn read(&self, key: EvalCacheKey) -> Option<String> {
406        let path = RunCache::output_cache_path(&key);
407
408        if path.exists() {
409            let use_cache = match key.0 {
410                EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
411                EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
412                    self.cache_mode.use_cached_llm_responses()
413                }
414            };
415            if use_cache {
416                log::info!("Using cache entry: {}", path.display());
417                self.link_to_run(&key);
418                Some(fs::read_to_string(path).unwrap())
419            } else {
420                log::trace!("Skipping cached entry: {}", path.display());
421                None
422            }
423        } else if matches!(self.cache_mode, CacheMode::Force) {
424            panic!(
425                "No cached entry found for {:?}. Run without `--cache force` at least once.",
426                key.0
427            );
428        } else {
429            None
430        }
431    }
432
433    fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
434        fs::create_dir_all(&*CACHE_DIR).unwrap();
435
436        let input_path = RunCache::input_cache_path(&key);
437        fs::write(&input_path, input).unwrap();
438
439        let output_path = RunCache::output_cache_path(&key);
440        log::trace!("Writing cache entry: {}", output_path.display());
441        fs::write(&output_path, output).unwrap();
442
443        self.link_to_run(&key);
444    }
445}
446
/// Collected artifacts and timings from a single prediction run.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictionDetails {
    // Unified diff of the predicted edits against the pre-edit snapshot;
    // empty when the provider returned no prediction.
    pub diff: String,
    // Context excerpts that were included in the prediction prompt.
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    // Time spent generating search queries; None when expected context was
    // supplied and no search ran.
    pub planning_search_time: Option<Duration>,
    // Time spent executing the generated searches; None when no search ran.
    pub running_search_time: Option<Duration>,
    // Time from the prediction request to its response.
    pub prediction_time: Duration,
    // Time from the first debug event to the prediction response.
    pub total_time: Duration,
    // Run directory this prediction's artifacts were written to.
    pub run_example_dir: PathBuf,
    // Prompt and response sizes, in characters.
    pub prompt_len: usize,
    pub generated_len: usize,
}
460
461impl PredictionDetails {
462    pub fn new(run_example_dir: PathBuf) -> Self {
463        Self {
464            diff: Default::default(),
465            excerpts: Default::default(),
466            excerpts_text: Default::default(),
467            planning_search_time: Default::default(),
468            running_search_time: Default::default(),
469            prediction_time: Default::default(),
470            total_time: Default::default(),
471            run_example_dir,
472            prompt_len: 0,
473            generated_len: 0,
474        }
475    }
476
477    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
478        let formatted = match format {
479            PredictionsOutputFormat::Md => self.to_markdown(),
480            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
481            PredictionsOutputFormat::Diff => self.diff.clone(),
482        };
483
484        Ok(out.write_all(formatted.as_bytes())?)
485    }
486
487    pub fn to_markdown(&self) -> String {
488        let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
489
490        format!(
491            "## Excerpts\n\n\
492            {}\n\n\
493            ## Prediction\n\n\
494            {}\n\n\
495            ## Time\n\n\
496            Planning searches: {}ms\n\
497            Running searches: {}ms\n\
498            Making Prediction: {}ms\n\n\
499            -------------------\n\n\
500            Total: {}ms\n\
501            Inference: {}ms ({:.2}%)\n",
502            self.excerpts_text,
503            self.diff,
504            self.planning_search_time.unwrap_or_default().as_millis(),
505            self.running_search_time.unwrap_or_default().as_millis(),
506            self.prediction_time.as_millis(),
507            self.total_time.as_millis(),
508            inference_time.as_millis(),
509            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
510        )
511    }
512}