predict.rs

  1use crate::PromptFormat;
  2use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
  3use crate::headless::ZetaCliAppState;
  4use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
  5use ::serde::Serialize;
  6use anyhow::{Context, Result, anyhow};
  7use clap::{Args, ValueEnum};
  8use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
  9use collections::HashMap;
 10use futures::StreamExt as _;
 11use gpui::{AppContext, AsyncApp, Entity};
 12use language::{Anchor, Buffer, Point};
 13use project::Project;
 14use serde::Deserialize;
 15use std::fs;
 16use std::io::{IsTerminal, Write};
 17use std::ops::Range;
 18use std::path::PathBuf;
 19use std::sync::Arc;
 20use std::sync::Mutex;
 21use std::time::{Duration, Instant};
 22use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
 23
 24#[derive(Debug, Args)]
 25pub struct PredictArguments {
 26    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
 27    prompt_format: PromptFormat,
 28    #[arg(long)]
 29    use_expected_context: bool,
 30    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
 31    format: PredictionsOutputFormat,
 32    example_path: PathBuf,
 33    #[clap(long, value_enum, default_value_t = CacheMode::default())]
 34    cache: CacheMode,
 35}
 36
// Policy for reusing cached LLM requests/responses and search results between runs.
// NOTE: the `///` comments on the variants are also the `--cache` value help text shown by
// clap, and clap lists the values in declaration order, so edit both with care.
// `Auto` must be resolved to a concrete mode before the `CacheMode` predicates are used.
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
    /// Use cached LLM requests and responses, except when multiple repetitions are requested
    #[default]
    Auto,
    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
    // Also accepted on the command line as `request`.
    #[value(alias = "request")]
    Requests,
    /// Ignore existing cache entries for both LLM and search.
    Skip,
    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
    /// Useful for reproducing results and fixing bugs outside of search queries
    Force,
}
 51
 52impl CacheMode {
 53    fn use_cached_llm_responses(&self) -> bool {
 54        self.assert_not_auto();
 55        matches!(self, CacheMode::Requests | CacheMode::Force)
 56    }
 57
 58    fn use_cached_search_results(&self) -> bool {
 59        self.assert_not_auto();
 60        matches!(self, CacheMode::Force)
 61    }
 62
 63    fn assert_not_auto(&self) {
 64        assert_ne!(
 65            *self,
 66            CacheMode::Auto,
 67            "Cache mode should not be auto at this point!"
 68        );
 69    }
 70}
 71
 72#[derive(clap::ValueEnum, Debug, Clone)]
 73pub enum PredictionsOutputFormat {
 74    Json,
 75    Md,
 76    Diff,
 77}
 78
 79pub async fn run_zeta2_predict(
 80    args: PredictArguments,
 81    app_state: &Arc<ZetaCliAppState>,
 82    cx: &mut AsyncApp,
 83) {
 84    let example = NamedExample::load(args.example_path).unwrap();
 85    let (project, mut zetas, _edited_buffers) =
 86        example.setup_project(app_state, 1, cx).await.unwrap();
 87    let result = zeta2_predict(
 88        example,
 89        project,
 90        zetas.remove(0),
 91        None,
 92        args.prompt_format,
 93        args.use_expected_context,
 94        args.cache,
 95        cx,
 96    )
 97    .await
 98    .unwrap();
 99    result.write(args.format, std::io::stdout()).unwrap();
100
101    print_run_data_dir(true, std::io::stdout().is_terminal());
102}
103
104pub async fn zeta2_predict(
105    example: NamedExample,
106    project: Entity<Project>,
107    zeta: Entity<Zeta>,
108    repetition_ix: Option<u16>,
109    prompt_format: PromptFormat,
110    use_expected_context: bool,
111    mut cache_mode: CacheMode,
112    cx: &mut AsyncApp,
113) -> Result<PredictionDetails> {
114    if repetition_ix.is_some() {
115        if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
116            panic!("Repetitions are not supported in Auto cache mode");
117        } else {
118            cache_mode = CacheMode::Skip;
119        }
120    } else if cache_mode == CacheMode::Auto {
121        cache_mode = CacheMode::Requests;
122    }
123
124    let mut example_run_dir = RUN_DIR.join(&example.file_name());
125    if let Some(repetition_ix) = repetition_ix {
126        example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
127    }
128    fs::create_dir_all(&example_run_dir)?;
129    if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
130        fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
131    }
132
133    #[cfg(unix)]
134    std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
135        .context("creating latest link")?;
136
137    #[cfg(windows)]
138    std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
139        .context("creating latest link")?;
140
141    zeta.update(cx, |zeta, _cx| {
142        zeta.with_eval_cache(Arc::new(RunCache {
143            example_run_dir: example_run_dir.clone(),
144            cache_mode,
145        }));
146    })?;
147
148    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
149
150    let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
151    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
152
153    let debug_task = cx.background_spawn({
154        let result = result.clone();
155        async move {
156            let mut start_time = None;
157            let mut search_queries_generated_at = None;
158            let mut search_queries_executed_at = None;
159            while let Some(event) = debug_rx.next().await {
160                match event {
161                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
162                        start_time = Some(info.timestamp);
163                        fs::write(
164                            example_run_dir.join("search_prompt.md"),
165                            &info.search_prompt,
166                        )?;
167                    }
168                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
169                        search_queries_generated_at = Some(info.timestamp);
170                        fs::write(
171                            example_run_dir.join("search_queries.json"),
172                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
173                        )?;
174                    }
175                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
176                        search_queries_executed_at = Some(info.timestamp);
177                    }
178                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
179                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
180                        let prediction_started_at = Instant::now();
181                        start_time.get_or_insert(prediction_started_at);
182                        fs::write(
183                            example_run_dir.join("prediction_prompt.md"),
184                            &request.local_prompt.unwrap_or_default(),
185                        )?;
186
187                        {
188                            let mut result = result.lock().unwrap();
189
190                            for included_file in request.request.included_files {
191                                let insertions =
192                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
193                                result.excerpts.extend(included_file.excerpts.iter().map(
194                                    |excerpt| ActualExcerpt {
195                                        path: included_file.path.components().skip(1).collect(),
196                                        text: String::from(excerpt.text.as_ref()),
197                                    },
198                                ));
199                                write_codeblock(
200                                    &included_file.path,
201                                    included_file.excerpts.iter(),
202                                    if included_file.path == request.request.excerpt_path {
203                                        &insertions
204                                    } else {
205                                        &[]
206                                    },
207                                    included_file.max_row,
208                                    false,
209                                    &mut result.excerpts_text,
210                                );
211                            }
212                        }
213
214                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
215                        let response = zeta2::text_from_response(response).unwrap_or_default();
216                        let prediction_finished_at = Instant::now();
217                        fs::write(example_run_dir.join("prediction_response.md"), &response)?;
218
219                        let mut result = result.lock().unwrap();
220
221                        if !use_expected_context {
222                            result.planning_search_time =
223                                Some(search_queries_generated_at.unwrap() - start_time.unwrap());
224                            result.running_search_time = Some(
225                                search_queries_executed_at.unwrap()
226                                    - search_queries_generated_at.unwrap(),
227                            );
228                        }
229                        result.prediction_time = prediction_finished_at - prediction_started_at;
230                        result.total_time = prediction_finished_at - start_time.unwrap();
231
232                        break;
233                    }
234                }
235            }
236            anyhow::Ok(())
237        }
238    });
239
240    zeta.update(cx, |zeta, _cx| {
241        let mut options = zeta.options().clone();
242        options.prompt_format = prompt_format.into();
243        zeta.set_options(options);
244    })?;
245
246    if use_expected_context {
247        let context_excerpts_tasks = example
248            .example
249            .expected_context
250            .iter()
251            .flat_map(|section| {
252                section.alternatives[0].excerpts.iter().map(|excerpt| {
253                    resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
254                })
255            })
256            .collect::<Vec<_>>();
257        let context_excerpts_vec = futures::future::try_join_all(context_excerpts_tasks).await?;
258
259        let mut context_excerpts = HashMap::default();
260        for (buffer, mut excerpts) in context_excerpts_vec {
261            context_excerpts
262                .entry(buffer)
263                .or_insert(Vec::new())
264                .append(&mut excerpts);
265        }
266
267        zeta.update(cx, |zeta, _cx| {
268            zeta.set_context(project.clone(), context_excerpts)
269        })?;
270    } else {
271        zeta.update(cx, |zeta, cx| {
272            zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
273        })?
274        .await?;
275    }
276
277    let prediction = zeta
278        .update(cx, |zeta, cx| {
279            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
280        })?
281        .await?;
282
283    debug_task.await?;
284
285    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
286    result.diff = prediction
287        .map(|prediction| {
288            let old_text = prediction.snapshot.text();
289            let new_text = prediction
290                .buffer
291                .update(cx, |buffer, cx| {
292                    let branch = buffer.branch(cx);
293                    branch.update(cx, |branch, cx| {
294                        branch.edit(prediction.edits.iter().cloned(), None, cx);
295                        branch.text()
296                    })
297                })
298                .unwrap();
299            language::unified_diff(&old_text, &new_text)
300        })
301        .unwrap_or_default();
302
303    anyhow::Ok(result)
304}
305
306async fn resolve_context_entry(
307    project: Entity<Project>,
308    excerpt: ExpectedExcerpt,
309    mut cx: AsyncApp,
310) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
311    let buffer = project
312        .update(&mut cx, |project, cx| {
313            let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
314            project.open_buffer(project_path, cx)
315        })?
316        .await?;
317
318    let ranges = buffer.read_with(&mut cx, |buffer, _| {
319        let full_text = buffer.text();
320        let offset = full_text
321            .find(&excerpt.text)
322            .expect("Expected context not found");
323        let point = buffer.offset_to_point(offset);
324        excerpt
325            .required_lines
326            .iter()
327            .map(|line| {
328                let row = point.row + line.0;
329                let range = Point::new(row, 0)..Point::new(row + 1, 0);
330                buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
331            })
332            .collect()
333    })?;
334
335    Ok((buffer, ranges))
336}
337
338struct RunCache {
339    cache_mode: CacheMode,
340    example_run_dir: PathBuf,
341}
342
343impl RunCache {
344    fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
345        CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
346    }
347
348    fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
349        CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
350    }
351
352    fn link_to_run(&self, key: &EvalCacheKey) {
353        let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
354        fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
355
356        let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
357        fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
358    }
359}
360
361impl EvalCache for RunCache {
362    fn read(&self, key: EvalCacheKey) -> Option<String> {
363        let path = RunCache::output_cache_path(&key);
364
365        if path.exists() {
366            let use_cache = match key.0 {
367                EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
368                EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
369                    self.cache_mode.use_cached_llm_responses()
370                }
371            };
372            if use_cache {
373                log::info!("Using cache entry: {}", path.display());
374                self.link_to_run(&key);
375                Some(fs::read_to_string(path).unwrap())
376            } else {
377                log::trace!("Skipping cached entry: {}", path.display());
378                None
379            }
380        } else if matches!(self.cache_mode, CacheMode::Force) {
381            panic!(
382                "No cached entry found for {:?}. Run without `--cache force` at least once.",
383                key.0
384            );
385        } else {
386            None
387        }
388    }
389
390    fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
391        fs::create_dir_all(&*CACHE_DIR).unwrap();
392
393        let input_path = RunCache::input_cache_path(&key);
394        fs::write(&input_path, input).unwrap();
395
396        let output_path = RunCache::output_cache_path(&key);
397        log::trace!("Writing cache entry: {}", output_path.display());
398        fs::write(&output_path, output).unwrap();
399
400        self.link_to_run(&key);
401    }
402}
403
/// Everything recorded about a single prediction run: the resulting diff, the context that
/// was sent to the model, and timings.
// NOTE: field order determines the key order of the JSON output format.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictionDetails {
    /// Unified diff from the buffer snapshot to the buffer with the predicted edits applied;
    /// empty when no prediction was returned.
    pub diff: String,
    /// Excerpts of the files included in the prediction request.
    pub excerpts: Vec<ActualExcerpt>,
    /// The included files rendered as code blocks, with the cursor marker inserted in the
    /// file that holds the cursor.
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    /// Time from the start of the run until search queries were generated; `None` when the
    /// expected context was used instead of context retrieval.
    pub planning_search_time: Option<Duration>,
    /// Time from search query generation until the queries finished executing; `None` when
    /// the expected context was used.
    pub running_search_time: Option<Duration>,
    /// Time from the prediction request until its response was received.
    pub prediction_time: Duration,
    /// Time from the start of the run until the prediction response was received.
    pub total_time: Duration,
    /// Directory where this run's debug artifacts are written.
    pub run_example_dir: PathBuf,
}
415
416impl PredictionDetails {
417    pub fn new(run_example_dir: PathBuf) -> Self {
418        Self {
419            diff: Default::default(),
420            excerpts: Default::default(),
421            excerpts_text: Default::default(),
422            planning_search_time: Default::default(),
423            running_search_time: Default::default(),
424            prediction_time: Default::default(),
425            total_time: Default::default(),
426            run_example_dir,
427        }
428    }
429
430    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
431        let formatted = match format {
432            PredictionsOutputFormat::Md => self.to_markdown(),
433            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
434            PredictionsOutputFormat::Diff => self.diff.clone(),
435        };
436
437        Ok(out.write_all(formatted.as_bytes())?)
438    }
439
440    pub fn to_markdown(&self) -> String {
441        let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
442
443        format!(
444            "## Excerpts\n\n\
445            {}\n\n\
446            ## Prediction\n\n\
447            {}\n\n\
448            ## Time\n\n\
449            Planning searches: {}ms\n\
450            Running searches: {}ms\n\
451            Making Prediction: {}ms\n\n\
452            -------------------\n\n\
453            Total: {}ms\n\
454            Inference: {}ms ({:.2}%)\n",
455            self.excerpts_text,
456            self.diff,
457            self.planning_search_time.unwrap_or_default().as_millis(),
458            self.running_search_time.unwrap_or_default().as_millis(),
459            self.prediction_time.as_millis(),
460            self.total_time.as_millis(),
461            inference_time.as_millis(),
462            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
463        )
464    }
465}