predict.rs

  1use crate::example::{ActualExcerpt, NamedExample};
  2use crate::headless::ZetaCliAppState;
  3use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
  4use crate::{
  5    CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
  6};
  7use ::serde::Serialize;
  8use anyhow::{Context, Result, anyhow};
  9use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 10use futures::StreamExt as _;
 11use gpui::{AppContext, AsyncApp, Entity};
 12use project::Project;
 13use project::buffer_store::BufferStoreEvent;
 14use serde::Deserialize;
 15use std::fs;
 16use std::io::{IsTerminal, Write};
 17use std::path::PathBuf;
 18use std::sync::Arc;
 19use std::sync::Mutex;
 20use std::time::{Duration, Instant};
 21use zeta::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
 22
 23pub async fn run_predict(
 24    args: PredictArguments,
 25    app_state: &Arc<ZetaCliAppState>,
 26    cx: &mut AsyncApp,
 27) {
 28    let example = NamedExample::load(args.example_path).unwrap();
 29    let project = example.setup_project(app_state, cx).await.unwrap();
 30    let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap();
 31    let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
 32    let result = perform_predict(example, project, zeta, None, args.options, cx)
 33        .await
 34        .unwrap();
 35    result.write(args.format, std::io::stdout()).unwrap();
 36
 37    print_run_data_dir(true, std::io::stdout().is_terminal());
 38}
 39
 40pub fn setup_zeta(
 41    provider: PredictionProvider,
 42    project: &Entity<Project>,
 43    app_state: &Arc<ZetaCliAppState>,
 44    cx: &mut AsyncApp,
 45) -> Result<Entity<Zeta>> {
 46    let zeta =
 47        cx.new(|cx| zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
 48
 49    zeta.update(cx, |zeta, _cx| {
 50        let model = match provider {
 51            PredictionProvider::Zeta1 => zeta::ZetaEditPredictionModel::Zeta1,
 52            PredictionProvider::Zeta2 => zeta::ZetaEditPredictionModel::Zeta2,
 53            PredictionProvider::Sweep => zeta::ZetaEditPredictionModel::Sweep,
 54        };
 55        zeta.set_edit_prediction_model(model);
 56    })?;
 57
 58    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
 59
 60    cx.subscribe(&buffer_store, {
 61        let project = project.clone();
 62        let zeta = zeta.clone();
 63        move |_, event, cx| match event {
 64            BufferStoreEvent::BufferAdded(buffer) => {
 65                zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
 66            }
 67            _ => {}
 68        }
 69    })?
 70    .detach();
 71
 72    anyhow::Ok(zeta)
 73}
 74
 75pub async fn perform_predict(
 76    example: NamedExample,
 77    project: Entity<Project>,
 78    zeta: Entity<Zeta>,
 79    repetition_ix: Option<u16>,
 80    options: PredictionOptions,
 81    cx: &mut AsyncApp,
 82) -> Result<PredictionDetails> {
 83    let mut cache_mode = options.cache;
 84    if repetition_ix.is_some() {
 85        if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
 86            panic!("Repetitions are not supported in Auto cache mode");
 87        } else {
 88            cache_mode = CacheMode::Skip;
 89        }
 90    } else if cache_mode == CacheMode::Auto {
 91        cache_mode = CacheMode::Requests;
 92    }
 93
 94    let mut example_run_dir = RUN_DIR.join(&example.file_name());
 95    if let Some(repetition_ix) = repetition_ix {
 96        example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
 97    }
 98    fs::create_dir_all(&example_run_dir)?;
 99    if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
100        fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
101    }
102
103    #[cfg(unix)]
104    std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
105        .context("creating latest link")?;
106
107    #[cfg(windows)]
108    std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
109        .context("creating latest link")?;
110
111    zeta.update(cx, |zeta, _cx| {
112        zeta.with_eval_cache(Arc::new(RunCache {
113            example_run_dir: example_run_dir.clone(),
114            cache_mode,
115        }));
116    })?;
117
118    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
119
120    let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
121
122    let prompt_format = options.zeta2.prompt_format;
123
124    zeta.update(cx, |zeta, _cx| {
125        let mut options = zeta.options().clone();
126        options.prompt_format = prompt_format.into();
127        zeta.set_options(options);
128    })?;
129
130    let mut debug_task = gpui::Task::ready(Ok(()));
131
132    if options.provider == crate::PredictionProvider::Zeta2 {
133        let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
134
135        debug_task = cx.background_spawn({
136            let result = result.clone();
137            async move {
138                let mut start_time = None;
139                let mut search_queries_generated_at = None;
140                let mut search_queries_executed_at = None;
141                while let Some(event) = debug_rx.next().await {
142                    match event {
143                        zeta::ZetaDebugInfo::ContextRetrievalStarted(info) => {
144                            start_time = Some(info.timestamp);
145                            fs::write(
146                                example_run_dir.join("search_prompt.md"),
147                                &info.search_prompt,
148                            )?;
149                        }
150                        zeta::ZetaDebugInfo::SearchQueriesGenerated(info) => {
151                            search_queries_generated_at = Some(info.timestamp);
152                            fs::write(
153                                example_run_dir.join("search_queries.json"),
154                                serde_json::to_string_pretty(&info.search_queries).unwrap(),
155                            )?;
156                        }
157                        zeta::ZetaDebugInfo::SearchQueriesExecuted(info) => {
158                            search_queries_executed_at = Some(info.timestamp);
159                        }
160                        zeta::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
161                        zeta::ZetaDebugInfo::EditPredictionRequested(request) => {
162                            let prediction_started_at = Instant::now();
163                            start_time.get_or_insert(prediction_started_at);
164                            let prompt = request.local_prompt.unwrap_or_default();
165                            fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
166
167                            {
168                                let mut result = result.lock().unwrap();
169                                result.prompt_len = prompt.chars().count();
170
171                                for included_file in request.inputs.included_files {
172                                    let insertions =
173                                        vec![(request.inputs.cursor_point, CURSOR_MARKER)];
174                                    result.excerpts.extend(included_file.excerpts.iter().map(
175                                        |excerpt| ActualExcerpt {
176                                            path: included_file.path.components().skip(1).collect(),
177                                            text: String::from(excerpt.text.as_ref()),
178                                        },
179                                    ));
180                                    write_codeblock(
181                                        &included_file.path,
182                                        included_file.excerpts.iter(),
183                                        if included_file.path == request.inputs.cursor_path {
184                                            &insertions
185                                        } else {
186                                            &[]
187                                        },
188                                        included_file.max_row,
189                                        false,
190                                        &mut result.excerpts_text,
191                                    );
192                                }
193                            }
194
195                            let response =
196                                request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
197                            let response = zeta::text_from_response(response).unwrap_or_default();
198                            let prediction_finished_at = Instant::now();
199                            fs::write(example_run_dir.join("prediction_response.md"), &response)?;
200
201                            let mut result = result.lock().unwrap();
202                            result.generated_len = response.chars().count();
203
204                            result.planning_search_time =
205                                Some(search_queries_generated_at.unwrap() - start_time.unwrap());
206                            result.running_search_time = Some(
207                                search_queries_executed_at.unwrap()
208                                    - search_queries_generated_at.unwrap(),
209                            );
210                            result.prediction_time = prediction_finished_at - prediction_started_at;
211                            result.total_time = prediction_finished_at - start_time.unwrap();
212
213                            break;
214                        }
215                    }
216                }
217                anyhow::Ok(())
218            }
219        });
220
221        zeta.update(cx, |zeta, cx| {
222            zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
223        })?
224        .await?;
225    }
226
227    let prediction = zeta
228        .update(cx, |zeta, cx| {
229            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
230        })?
231        .await?;
232
233    debug_task.await?;
234
235    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
236
237    result.diff = prediction
238        .and_then(|prediction| {
239            let prediction = prediction.prediction.ok()?;
240            prediction.edit_preview.as_unified_diff(&prediction.edits)
241        })
242        .unwrap_or_default();
243
244    anyhow::Ok(result)
245}
246
247struct RunCache {
248    cache_mode: CacheMode,
249    example_run_dir: PathBuf,
250}
251
252impl RunCache {
253    fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
254        CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
255    }
256
257    fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
258        CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
259    }
260
261    fn link_to_run(&self, key: &EvalCacheKey) {
262        let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
263        fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
264
265        let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
266        fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
267    }
268}
269
270impl EvalCache for RunCache {
271    fn read(&self, key: EvalCacheKey) -> Option<String> {
272        let path = RunCache::output_cache_path(&key);
273
274        if path.exists() {
275            let use_cache = match key.0 {
276                EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
277                EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
278                    self.cache_mode.use_cached_llm_responses()
279                }
280            };
281            if use_cache {
282                log::info!("Using cache entry: {}", path.display());
283                self.link_to_run(&key);
284                Some(fs::read_to_string(path).unwrap())
285            } else {
286                log::trace!("Skipping cached entry: {}", path.display());
287                None
288            }
289        } else if matches!(self.cache_mode, CacheMode::Force) {
290            panic!(
291                "No cached entry found for {:?}. Run without `--cache force` at least once.",
292                key.0
293            );
294        } else {
295            None
296        }
297    }
298
299    fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
300        fs::create_dir_all(&*CACHE_DIR).unwrap();
301
302        let input_path = RunCache::input_cache_path(&key);
303        fs::write(&input_path, input).unwrap();
304
305        let output_path = RunCache::output_cache_path(&key);
306        log::trace!("Writing cache entry: {}", output_path.display());
307        fs::write(&output_path, output).unwrap();
308
309        self.link_to_run(&key);
310    }
311}
312
313#[derive(Clone, Debug, Serialize, Deserialize)]
314pub struct PredictionDetails {
315    pub diff: String,
316    pub excerpts: Vec<ActualExcerpt>,
317    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
318    pub planning_search_time: Option<Duration>,
319    pub running_search_time: Option<Duration>,
320    pub prediction_time: Duration,
321    pub total_time: Duration,
322    pub run_example_dir: PathBuf,
323    pub prompt_len: usize,
324    pub generated_len: usize,
325}
326
327impl PredictionDetails {
328    pub fn new(run_example_dir: PathBuf) -> Self {
329        Self {
330            diff: Default::default(),
331            excerpts: Default::default(),
332            excerpts_text: Default::default(),
333            planning_search_time: Default::default(),
334            running_search_time: Default::default(),
335            prediction_time: Default::default(),
336            total_time: Default::default(),
337            run_example_dir,
338            prompt_len: 0,
339            generated_len: 0,
340        }
341    }
342
343    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
344        let formatted = match format {
345            PredictionsOutputFormat::Md => self.to_markdown(),
346            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
347            PredictionsOutputFormat::Diff => self.diff.clone(),
348        };
349
350        Ok(out.write_all(formatted.as_bytes())?)
351    }
352
353    pub fn to_markdown(&self) -> String {
354        let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
355
356        format!(
357            "## Excerpts\n\n\
358            {}\n\n\
359            ## Prediction\n\n\
360            {}\n\n\
361            ## Time\n\n\
362            Planning searches: {}ms\n\
363            Running searches: {}ms\n\
364            Making Prediction: {}ms\n\n\
365            -------------------\n\n\
366            Total: {}ms\n\
367            Inference: {}ms ({:.2}%)\n",
368            self.excerpts_text,
369            self.diff,
370            self.planning_search_time.unwrap_or_default().as_millis(),
371            self.running_search_time.unwrap_or_default().as_millis(),
372            self.prediction_time.as_millis(),
373            self.total_time.as_millis(),
374            inference_time.as_millis(),
375            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
376        )
377    }
378}