predict.rs

  1use crate::example::{ActualExcerpt, NamedExample};
  2use crate::headless::ZetaCliAppState;
  3use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
  4use crate::{
  5    CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
  6};
  7use ::serde::Serialize;
  8use anyhow::{Context, Result, anyhow};
  9use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 10use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey};
 11use futures::StreamExt as _;
 12use gpui::{AppContext, AsyncApp, Entity};
 13use project::Project;
 14use project::buffer_store::BufferStoreEvent;
 15use serde::Deserialize;
 16use std::fs;
 17use std::io::{IsTerminal, Write};
 18use std::path::PathBuf;
 19use std::sync::Arc;
 20use std::sync::Mutex;
 21use std::time::{Duration, Instant};
 22
/// Entry point for the `predict` CLI subcommand: loads a named example,
/// replays its edit history into a project, requests a single edit
/// prediction, and writes the result to stdout in the requested format.
///
/// Panics (via `unwrap`) on any failure — this is the top-level flow of a
/// CLI tool, so aborting with the error is the intended behavior.
pub async fn run_predict(
    args: PredictArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    // Load the example definition and materialize it as an in-memory project.
    let example = NamedExample::load(args.example_path).unwrap();
    let project = example.setup_project(app_state, cx).await.unwrap();
    let store = setup_store(args.options.provider, &project, app_state, cx).unwrap();
    // Replay the recorded edit history so the prediction sees the same
    // buffer state the example was captured with.
    let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
    let result = perform_predict(example, project, store, None, args.options, cx)
        .await
        .unwrap();
    result.write(args.format, std::io::stdout()).unwrap();

    // Tell the user where the run artifacts were written (pretty-printed
    // only when stdout is a terminal).
    print_run_data_dir(true, std::io::stdout().is_terminal());
}
 39
 40pub fn setup_store(
 41    provider: PredictionProvider,
 42    project: &Entity<Project>,
 43    app_state: &Arc<ZetaCliAppState>,
 44    cx: &mut AsyncApp,
 45) -> Result<Entity<EditPredictionStore>> {
 46    let store = cx.new(|cx| {
 47        edit_prediction::EditPredictionStore::new(
 48            app_state.client.clone(),
 49            app_state.user_store.clone(),
 50            cx,
 51        )
 52    })?;
 53
 54    store.update(cx, |store, _cx| {
 55        let model = match provider {
 56            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
 57            PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
 58            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
 59        };
 60        store.set_edit_prediction_model(model);
 61    })?;
 62
 63    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
 64
 65    cx.subscribe(&buffer_store, {
 66        let project = project.clone();
 67        let store = store.clone();
 68        move |_, event, cx| match event {
 69            BufferStoreEvent::BufferAdded(buffer) => {
 70                store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
 71            }
 72            _ => {}
 73        }
 74    })?
 75    .detach();
 76
 77    anyhow::Ok(store)
 78}
 79
 80pub async fn perform_predict(
 81    example: NamedExample,
 82    project: Entity<Project>,
 83    store: Entity<EditPredictionStore>,
 84    repetition_ix: Option<u16>,
 85    options: PredictionOptions,
 86    cx: &mut AsyncApp,
 87) -> Result<PredictionDetails> {
 88    let mut cache_mode = options.cache;
 89    if repetition_ix.is_some() {
 90        if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
 91            panic!("Repetitions are not supported in Auto cache mode");
 92        } else {
 93            cache_mode = CacheMode::Skip;
 94        }
 95    } else if cache_mode == CacheMode::Auto {
 96        cache_mode = CacheMode::Requests;
 97    }
 98
 99    let mut example_run_dir = RUN_DIR.join(&example.file_name());
100    if let Some(repetition_ix) = repetition_ix {
101        example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
102    }
103    fs::create_dir_all(&example_run_dir)?;
104    if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
105        fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
106    }
107
108    #[cfg(unix)]
109    std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
110        .context("creating latest link")?;
111
112    #[cfg(windows)]
113    std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
114        .context("creating latest link")?;
115
116    store.update(cx, |store, _cx| {
117        store.with_eval_cache(Arc::new(RunCache {
118            example_run_dir: example_run_dir.clone(),
119            cache_mode,
120        }));
121    })?;
122
123    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
124
125    let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
126
127    let prompt_format = options.zeta2.prompt_format;
128
129    store.update(cx, |store, _cx| {
130        let mut options = store.options().clone();
131        options.prompt_format = prompt_format.into();
132        store.set_options(options);
133    })?;
134
135    let mut debug_task = gpui::Task::ready(Ok(()));
136
137    if options.provider == crate::PredictionProvider::Zeta2 {
138        let mut debug_rx = store.update(cx, |store, _| store.debug_info())?;
139
140        debug_task = cx.background_spawn({
141            let result = result.clone();
142            async move {
143                let mut start_time = None;
144                let mut retrieval_finished_at = None;
145                while let Some(event) = debug_rx.next().await {
146                    match event {
147                        edit_prediction::DebugEvent::ContextRetrievalStarted(info) => {
148                            start_time = Some(info.timestamp);
149                            fs::write(
150                                example_run_dir.join("search_prompt.md"),
151                                &info.search_prompt,
152                            )?;
153                        }
154                        edit_prediction::DebugEvent::ContextRetrievalFinished(info) => {
155                            retrieval_finished_at = Some(info.timestamp);
156                            for (key, value) in &info.metadata {
157                                if *key == "search_queries" {
158                                    fs::write(
159                                        example_run_dir.join("search_queries.json"),
160                                        value.as_bytes(),
161                                    )?;
162                                }
163                            }
164                        }
165                        edit_prediction::DebugEvent::EditPredictionRequested(request) => {
166                            let prediction_started_at = Instant::now();
167                            start_time.get_or_insert(prediction_started_at);
168                            let prompt = request.local_prompt.unwrap_or_default();
169                            fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
170
171                            {
172                                let mut result = result.lock().unwrap();
173                                result.prompt_len = prompt.chars().count();
174
175                                for included_file in request.inputs.included_files {
176                                    let insertions =
177                                        vec![(request.inputs.cursor_point, CURSOR_MARKER)];
178                                    result.excerpts.extend(included_file.excerpts.iter().map(
179                                        |excerpt| ActualExcerpt {
180                                            path: included_file.path.components().skip(1).collect(),
181                                            text: String::from(excerpt.text.as_ref()),
182                                        },
183                                    ));
184                                    write_codeblock(
185                                        &included_file.path,
186                                        included_file.excerpts.iter(),
187                                        if included_file.path == request.inputs.cursor_path {
188                                            &insertions
189                                        } else {
190                                            &[]
191                                        },
192                                        included_file.max_row,
193                                        false,
194                                        &mut result.excerpts_text,
195                                    );
196                                }
197                            }
198
199                            let response =
200                                request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
201                            let response =
202                                edit_prediction::open_ai_response::text_from_response(response)
203                                    .unwrap_or_default();
204                            let prediction_finished_at = Instant::now();
205                            fs::write(example_run_dir.join("prediction_response.md"), &response)?;
206
207                            let mut result = result.lock().unwrap();
208                            result.generated_len = response.chars().count();
209                            result.retrieval_time =
210                                retrieval_finished_at.unwrap() - start_time.unwrap();
211                            result.prediction_time = prediction_finished_at - prediction_started_at;
212                            result.total_time = prediction_finished_at - start_time.unwrap();
213
214                            break;
215                        }
216                    }
217                }
218                anyhow::Ok(())
219            }
220        });
221
222        store.update(cx, |store, cx| {
223            store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx)
224        })?;
225    }
226
227    let prediction = store
228        .update(cx, |store, cx| {
229            store.request_prediction(
230                &project,
231                &cursor_buffer,
232                cursor_anchor,
233                cloud_llm_client::PredictEditsRequestTrigger::Cli,
234                cx,
235            )
236        })?
237        .await?;
238
239    debug_task.await?;
240
241    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
242
243    result.diff = prediction
244        .and_then(|prediction| {
245            let prediction = prediction.prediction.ok()?;
246            prediction.edit_preview.as_unified_diff(&prediction.edits)
247        })
248        .unwrap_or_default();
249
250    anyhow::Ok(result)
251}
252
253struct RunCache {
254    cache_mode: CacheMode,
255    example_run_dir: PathBuf,
256}
257
258impl RunCache {
259    fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
260        CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
261    }
262
263    fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
264        CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
265    }
266
267    fn link_to_run(&self, key: &EvalCacheKey) {
268        let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
269        fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
270
271        let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
272        fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
273    }
274}
275
276impl EvalCache for RunCache {
277    fn read(&self, key: EvalCacheKey) -> Option<String> {
278        let path = RunCache::output_cache_path(&key);
279
280        if path.exists() {
281            let use_cache = match key.0 {
282                EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
283                EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
284                    self.cache_mode.use_cached_llm_responses()
285                }
286            };
287            if use_cache {
288                log::info!("Using cache entry: {}", path.display());
289                self.link_to_run(&key);
290                Some(fs::read_to_string(path).unwrap())
291            } else {
292                log::trace!("Skipping cached entry: {}", path.display());
293                None
294            }
295        } else if matches!(self.cache_mode, CacheMode::Force) {
296            panic!(
297                "No cached entry found for {:?}. Run without `--cache force` at least once.",
298                key.0
299            );
300        } else {
301            None
302        }
303    }
304
305    fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
306        fs::create_dir_all(&*CACHE_DIR).unwrap();
307
308        let input_path = RunCache::input_cache_path(&key);
309        fs::write(&input_path, input).unwrap();
310
311        let output_path = RunCache::output_cache_path(&key);
312        log::trace!("Writing cache entry: {}", output_path.display());
313        fs::write(&output_path, output).unwrap();
314
315        self.link_to_run(&key);
316    }
317}
318
/// Aggregated artifacts and metrics from a single prediction run; serialized
/// as JSON or rendered as markdown/diff by [`PredictionDetails::write`].
///
/// Field order is significant: serde emits JSON keys in declaration order.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictionDetails {
    // Unified diff of the predicted edits; empty when no prediction was made.
    pub diff: String,
    // Context excerpts that were included in the prediction prompt.
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    // Timing fields are only populated by the Zeta2 debug-event stream;
    // other providers leave them at zero.
    pub retrieval_time: Duration,
    pub prediction_time: Duration,
    pub total_time: Duration,
    // Directory where this run's artifacts (prompts, responses) were written.
    pub run_example_dir: PathBuf,
    // Character counts of the prediction prompt and the raw model response.
    pub prompt_len: usize,
    pub generated_len: usize,
}
331
332impl PredictionDetails {
333    pub fn new(run_example_dir: PathBuf) -> Self {
334        Self {
335            diff: Default::default(),
336            excerpts: Default::default(),
337            excerpts_text: Default::default(),
338            retrieval_time: Default::default(),
339            prediction_time: Default::default(),
340            total_time: Default::default(),
341            run_example_dir,
342            prompt_len: 0,
343            generated_len: 0,
344        }
345    }
346
347    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
348        let formatted = match format {
349            PredictionsOutputFormat::Md => self.to_markdown(),
350            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
351            PredictionsOutputFormat::Diff => self.diff.clone(),
352        };
353
354        Ok(out.write_all(formatted.as_bytes())?)
355    }
356
357    pub fn to_markdown(&self) -> String {
358        format!(
359            "## Excerpts\n\n\
360            {}\n\n\
361            ## Prediction\n\n\
362            {}\n\n\
363            ## Time\n\n\
364            Retrieval: {}ms\n\
365            Prediction: {}ms\n\n\
366            Total: {}ms\n",
367            self.excerpts_text,
368            self.diff,
369            self.retrieval_time.as_millis(),
370            self.prediction_time.as_millis(),
371            self.total_time.as_millis(),
372        )
373    }
374}