predict.rs

  1use crate::PromptFormat;
  2use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
  3use crate::headless::ZetaCliAppState;
  4use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
  5use ::serde::Serialize;
  6use anyhow::{Context, Result, anyhow};
  7use clap::{Args, ValueEnum};
  8use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
  9use collections::HashMap;
 10use futures::StreamExt as _;
 11use gpui::{AppContext, AsyncApp, Entity};
 12use language::{Anchor, Buffer, Point};
 13use project::Project;
 14use serde::Deserialize;
 15use std::cell::Cell;
 16use std::fs;
 17use std::io::Write;
 18use std::ops::Range;
 19use std::path::PathBuf;
 20use std::sync::Arc;
 21use std::sync::Mutex;
 22use std::time::{Duration, Instant};
 23use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey};
 24
// Arguments for the `predict` subcommand (entry point: `run_zeta2_predict`).
// Plain `//` comments are used on fields deliberately: `///` doc comments on
// clap-derived fields become `--help` text and would change CLI output.
#[derive(Debug, Args)]
pub struct PredictArguments {
    // Prompt format used when building the model prompt.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // When set, feed the example's hand-written expected context to Zeta
    // instead of running search-based context retrieval.
    #[arg(long)]
    use_expected_context: bool,
    // Output format for the prediction report written to stdout.
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
    // Path to the example file to load and predict against.
    example_path: PathBuf,
    // How cached LLM/search artifacts from previous runs are reused.
    #[clap(long, value_enum, default_value_t = CacheMode::default())]
    cache: CacheMode,
}
 37
// Controls reuse of on-disk cached LLM requests/responses and search results.
// The `///` docs on variants double as the `--cache` value help via `ValueEnum`,
// so they are left as-is.
#[derive(Debug, ValueEnum, Default, Clone, Copy)]
pub enum CacheMode {
    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
    #[default]
    #[value(alias = "request")]
    Requests,
    /// Ignore existing cache entries for both LLM and search.
    Skip,
    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
    /// Useful for reproducing results and fixing bugs outside of search queries
    Force,
}
 50
 51impl CacheMode {
 52    fn use_cached_llm_responses(&self) -> bool {
 53        matches!(self, CacheMode::Requests | CacheMode::Force)
 54    }
 55
 56    fn use_cached_search_results(&self) -> bool {
 57        matches!(self, CacheMode::Force)
 58    }
 59}
 60
// Output format for the prediction report (see `PredictionDetails::write`).
// `//` comments are intentional: `///` docs on `ValueEnum` variants would be
// surfaced as CLI help text.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    // Pretty-printed JSON of the full `PredictionDetails`.
    Json,
    // Human-readable markdown report (excerpts, diff, timings).
    Md,
    // Only the predicted unified diff.
    Diff,
}
 67
 68pub async fn run_zeta2_predict(
 69    args: PredictArguments,
 70    app_state: &Arc<ZetaCliAppState>,
 71    cx: &mut AsyncApp,
 72) {
 73    let example = NamedExample::load(args.example_path).unwrap();
 74    let result = zeta2_predict(
 75        example,
 76        args.prompt_format,
 77        args.use_expected_context,
 78        args.cache,
 79        &app_state,
 80        cx,
 81    )
 82    .await
 83    .unwrap();
 84    result.write(args.format, std::io::stdout()).unwrap();
 85
 86    print_run_data_dir();
 87}
 88
// Per-thread flag recording that client sign-in has already been performed,
// so repeated predictions in one process skip the sign-in round-trip.
thread_local! {
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}
 92
 93pub async fn zeta2_predict(
 94    example: NamedExample,
 95    prompt_format: PromptFormat,
 96    use_expected_context: bool,
 97    cache_mode: CacheMode,
 98    app_state: &Arc<ZetaCliAppState>,
 99    cx: &mut AsyncApp,
100) -> Result<PredictionDetails> {
101    let worktree_path = example.setup_worktree().await?;
102
103    if !AUTHENTICATED.get() {
104        AUTHENTICATED.set(true);
105
106        app_state
107            .client
108            .sign_in_with_optional_connect(true, cx)
109            .await?;
110    }
111
112    let project = cx.update(|cx| {
113        Project::local(
114            app_state.client.clone(),
115            app_state.node_runtime.clone(),
116            app_state.user_store.clone(),
117            app_state.languages.clone(),
118            app_state.fs.clone(),
119            None,
120            cx,
121        )
122    })?;
123
124    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
125
126    let worktree = project
127        .update(cx, |project, cx| {
128            project.create_worktree(&worktree_path, true, cx)
129        })?
130        .await?;
131    worktree
132        .read_with(cx, |worktree, _cx| {
133            worktree.as_local().unwrap().scan_complete()
134        })?
135        .await;
136
137    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
138
139    let example_run_dir = RUN_DIR.join(&example.file_name());
140    fs::create_dir_all(&example_run_dir)?;
141    if LATEST_EXAMPLE_RUN_DIR.exists() {
142        fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
143    }
144
145    #[cfg(unix)]
146    std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
147        .context("creating latest link")?;
148
149    #[cfg(windows)]
150    std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
151        .context("creating latest link")?;
152
153    zeta.update(cx, |zeta, _cx| {
154        zeta.with_eval_cache(Arc::new(RunCache {
155            example_run_dir: example_run_dir.clone(),
156            cache_mode,
157        }));
158    })?;
159
160    cx.subscribe(&buffer_store, {
161        let project = project.clone();
162        move |_, event, cx| match event {
163            project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
164                zeta2::Zeta::try_global(cx)
165                    .unwrap()
166                    .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
167            }
168            _ => {}
169        }
170    })?
171    .detach();
172
173    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
174    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
175
176    let result = Arc::new(Mutex::new(PredictionDetails::default()));
177    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
178
179    let debug_task = cx.background_spawn({
180        let result = result.clone();
181        async move {
182            let mut start_time = None;
183            let mut search_queries_generated_at = None;
184            let mut search_queries_executed_at = None;
185            while let Some(event) = debug_rx.next().await {
186                match event {
187                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
188                        start_time = Some(info.timestamp);
189                        fs::write(
190                            example_run_dir.join("search_prompt.md"),
191                            &info.search_prompt,
192                        )?;
193                    }
194                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
195                        search_queries_generated_at = Some(info.timestamp);
196                        fs::write(
197                            example_run_dir.join("search_queries.json"),
198                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
199                        )?;
200                    }
201                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
202                        search_queries_executed_at = Some(info.timestamp);
203                    }
204                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
205                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
206                        let prediction_started_at = Instant::now();
207                        start_time.get_or_insert(prediction_started_at);
208                        fs::write(
209                            example_run_dir.join("prediction_prompt.md"),
210                            &request.local_prompt.unwrap_or_default(),
211                        )?;
212
213                        {
214                            let mut result = result.lock().unwrap();
215
216                            for included_file in request.request.included_files {
217                                let insertions =
218                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
219                                result.excerpts.extend(included_file.excerpts.iter().map(
220                                    |excerpt| ActualExcerpt {
221                                        path: included_file.path.components().skip(1).collect(),
222                                        text: String::from(excerpt.text.as_ref()),
223                                    },
224                                ));
225                                write_codeblock(
226                                    &included_file.path,
227                                    included_file.excerpts.iter(),
228                                    if included_file.path == request.request.excerpt_path {
229                                        &insertions
230                                    } else {
231                                        &[]
232                                    },
233                                    included_file.max_row,
234                                    false,
235                                    &mut result.excerpts_text,
236                                );
237                            }
238                        }
239
240                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
241                        let response = zeta2::text_from_response(response).unwrap_or_default();
242                        let prediction_finished_at = Instant::now();
243                        fs::write(example_run_dir.join("prediction_response.md"), &response)?;
244
245                        let mut result = result.lock().unwrap();
246
247                        if !use_expected_context {
248                            result.planning_search_time =
249                                Some(search_queries_generated_at.unwrap() - start_time.unwrap());
250                            result.running_search_time = Some(
251                                search_queries_executed_at.unwrap()
252                                    - search_queries_generated_at.unwrap(),
253                            );
254                        }
255                        result.prediction_time = prediction_finished_at - prediction_started_at;
256                        result.total_time = prediction_finished_at - start_time.unwrap();
257
258                        break;
259                    }
260                }
261            }
262            anyhow::Ok(())
263        }
264    });
265
266    zeta.update(cx, |zeta, _cx| {
267        let mut options = zeta.options().clone();
268        options.prompt_format = prompt_format.into();
269        zeta.set_options(options);
270    })?;
271
272    if use_expected_context {
273        let context_excerpts_tasks = example
274            .example
275            .expected_context
276            .iter()
277            .flat_map(|section| {
278                section.alternatives[0].excerpts.iter().map(|excerpt| {
279                    resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
280                })
281            })
282            .collect::<Vec<_>>();
283        let context_excerpts_vec = futures::future::try_join_all(context_excerpts_tasks).await?;
284
285        let mut context_excerpts = HashMap::default();
286        for (buffer, mut excerpts) in context_excerpts_vec {
287            context_excerpts
288                .entry(buffer)
289                .or_insert(Vec::new())
290                .append(&mut excerpts);
291        }
292
293        zeta.update(cx, |zeta, _cx| {
294            zeta.set_context(project.clone(), context_excerpts)
295        })?;
296    } else {
297        zeta.update(cx, |zeta, cx| {
298            zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
299        })?
300        .await?;
301    }
302
303    let prediction = zeta
304        .update(cx, |zeta, cx| {
305            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
306        })?
307        .await?;
308
309    debug_task.await?;
310
311    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
312    result.diff = prediction
313        .map(|prediction| {
314            let old_text = prediction.snapshot.text();
315            let new_text = prediction
316                .buffer
317                .update(cx, |buffer, cx| {
318                    buffer.edit(prediction.edits.iter().cloned(), None, cx);
319                    buffer.text()
320                })
321                .unwrap();
322            language::unified_diff(&old_text, &new_text)
323        })
324        .unwrap_or_default();
325
326    anyhow::Ok(result)
327}
328
329async fn resolve_context_entry(
330    project: Entity<Project>,
331    excerpt: ExpectedExcerpt,
332    mut cx: AsyncApp,
333) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
334    let buffer = project
335        .update(&mut cx, |project, cx| {
336            let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
337            project.open_buffer(project_path, cx)
338        })?
339        .await?;
340
341    let ranges = buffer.read_with(&mut cx, |buffer, _| {
342        let full_text = buffer.text();
343        let offset = full_text
344            .find(&excerpt.text)
345            .expect("Expected context not found");
346        let point = buffer.offset_to_point(offset);
347        excerpt
348            .required_lines
349            .iter()
350            .map(|line| {
351                let row = point.row + line.0;
352                let range = Point::new(row, 0)..Point::new(row + 1, 0);
353                buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
354            })
355            .collect()
356    })?;
357
358    Ok((buffer, ranges))
359}
360
/// File-backed eval cache that also hard-links every cache entry it touches
/// into the current example's run directory, so each run records exactly
/// which cached inputs/outputs it used.
struct RunCache {
    // Controls which cached entry kinds may be read back (see `CacheMode`).
    cache_mode: CacheMode,
    // Per-example run directory receiving hard links to used cache entries.
    example_run_dir: PathBuf,
}
365
366impl RunCache {
367    fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
368        CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
369    }
370
371    fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
372        CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
373    }
374
375    fn link_to_run(&self, key: &EvalCacheKey) {
376        let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
377        fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
378
379        let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
380        fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
381    }
382}
383
384impl EvalCache for RunCache {
385    fn read(&self, key: EvalCacheKey) -> Option<String> {
386        let path = RunCache::output_cache_path(&key);
387
388        if path.exists() {
389            let use_cache = match key.0 {
390                EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
391                EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
392                    self.cache_mode.use_cached_llm_responses()
393                }
394            };
395            if use_cache {
396                log::info!("Using cache entry: {}", path.display());
397                self.link_to_run(&key);
398                Some(fs::read_to_string(path).unwrap())
399            } else {
400                log::info!("Skipping cached entry: {}", path.display());
401                None
402            }
403        } else if matches!(self.cache_mode, CacheMode::Force) {
404            panic!(
405                "No cached entry found for {:?}. Run without `--cache force` at least once.",
406                key.0
407            );
408        } else {
409            None
410        }
411    }
412
413    fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
414        fs::create_dir_all(&*CACHE_DIR).unwrap();
415
416        let input_path = RunCache::input_cache_path(&key);
417        fs::write(&input_path, input).unwrap();
418
419        let output_path = RunCache::output_cache_path(&key);
420        log::info!("Writing cache entry: {}", output_path.display());
421        fs::write(&output_path, output).unwrap();
422
423        self.link_to_run(&key);
424    }
425}
426
/// Artifacts and timings collected from a single prediction run.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    /// Unified diff of the predicted edits (empty when no prediction
    /// was returned).
    pub diff: String,
    /// Context excerpts that were included in the prediction request.
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    /// Time spent generating search queries; `None` when search was skipped
    /// (expected context was supplied).
    pub planning_search_time: Option<Duration>,
    /// Time spent executing the generated search queries; `None` when search
    /// was skipped.
    pub running_search_time: Option<Duration>,
    /// Time from sending the prediction request to receiving the response.
    pub prediction_time: Duration,
    /// Wall-clock time from the start of the run's first stage to the
    /// prediction response.
    pub total_time: Duration,
}
437
438impl PredictionDetails {
439    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
440        let formatted = match format {
441            PredictionsOutputFormat::Md => self.to_markdown(),
442            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
443            PredictionsOutputFormat::Diff => self.diff.clone(),
444        };
445
446        Ok(out.write_all(formatted.as_bytes())?)
447    }
448
449    pub fn to_markdown(&self) -> String {
450        let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
451
452        format!(
453            "## Excerpts\n\n\
454            {}\n\n\
455            ## Prediction\n\n\
456            {}\n\n\
457            ## Time\n\n\
458            Planning searches: {}ms\n\
459            Running searches: {}ms\n\
460            Making Prediction: {}ms\n\n\
461            -------------------\n\n\
462            Total: {}ms\n\
463            Inference: {}ms ({:.2}%)\n",
464            self.excerpts_text,
465            self.diff,
466            self.planning_search_time.unwrap_or_default().as_millis(),
467            self.running_search_time.unwrap_or_default().as_millis(),
468            self.prediction_time.as_millis(),
469            self.total_time.as_millis(),
470            inference_time.as_millis(),
471            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
472        )
473    }
474}