predict.rs

  1use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
  2use crate::headless::ZetaCliAppState;
  3use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
  4use crate::{
  5    CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
  6};
  7use ::serde::Serialize;
  8use anyhow::{Context, Result, anyhow};
  9use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 10use collections::HashMap;
 11use futures::StreamExt as _;
 12use gpui::{AppContext, AsyncApp, Entity};
 13use language::{Anchor, Buffer, Point};
 14use project::Project;
 15use project::buffer_store::BufferStoreEvent;
 16use serde::Deserialize;
 17use std::fs;
 18use std::io::{IsTerminal, Write};
 19use std::ops::Range;
 20use std::path::PathBuf;
 21use std::sync::Arc;
 22use std::sync::Mutex;
 23use std::time::{Duration, Instant};
 24use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
 25
 26pub async fn run_predict(
 27    args: PredictArguments,
 28    app_state: &Arc<ZetaCliAppState>,
 29    cx: &mut AsyncApp,
 30) {
 31    let example = NamedExample::load(args.example_path).unwrap();
 32    let project = example.setup_project(app_state, cx).await.unwrap();
 33    let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap();
 34    let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
 35    let result = perform_predict(example, project, zeta, None, args.options, cx)
 36        .await
 37        .unwrap();
 38    result.write(args.format, std::io::stdout()).unwrap();
 39
 40    print_run_data_dir(true, std::io::stdout().is_terminal());
 41}
 42
 43pub fn setup_zeta(
 44    provider: PredictionProvider,
 45    project: &Entity<Project>,
 46    app_state: &Arc<ZetaCliAppState>,
 47    cx: &mut AsyncApp,
 48) -> Result<Entity<Zeta>> {
 49    let zeta =
 50        cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
 51
 52    zeta.update(cx, |zeta, _cx| {
 53        let model = match provider {
 54            PredictionProvider::Zeta2 => zeta2::ZetaEditPredictionModel::ZedCloud,
 55            PredictionProvider::Sweep => zeta2::ZetaEditPredictionModel::Sweep,
 56        };
 57        zeta.set_edit_prediction_model(model);
 58    })?;
 59
 60    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
 61
 62    cx.subscribe(&buffer_store, {
 63        let project = project.clone();
 64        let zeta = zeta.clone();
 65        move |_, event, cx| match event {
 66            BufferStoreEvent::BufferAdded(buffer) => {
 67                zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
 68            }
 69            _ => {}
 70        }
 71    })?
 72    .detach();
 73
 74    anyhow::Ok(zeta)
 75}
 76
/// Runs one edit prediction for `example` and collects prompt, timing, and
/// diff details into a `PredictionDetails`.
///
/// `repetition_ix` is `Some` when the same example is run multiple times; in
/// that case cached entries are skipped so repetitions stay independent, and
/// the run data is nested under a per-repetition subdirectory.
///
/// Side effects: creates the per-run data directory, refreshes the
/// "latest run" symlink, installs a `RunCache` on `zeta`, and (for the Zeta2
/// provider) writes prompt/response artifacts into the run directory.
pub async fn perform_predict(
    example: NamedExample,
    project: Entity<Project>,
    zeta: Entity<Zeta>,
    repetition_ix: Option<u16>,
    options: PredictionOptions,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    // Resolve `Auto` cache mode into a concrete mode before creating the cache.
    let mut cache_mode = options.cache;
    if repetition_ix.is_some() {
        if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
            // NOTE(review): this panics when the mode is NOT Auto/Skip, yet the
            // message blames Auto mode — message and condition look inconsistent;
            // confirm the intended wording.
            panic!("Repetitions are not supported in Auto cache mode");
        } else {
            // Repetitions must not reuse cached results.
            cache_mode = CacheMode::Skip;
        }
    } else if cache_mode == CacheMode::Auto {
        // Plain runs default to reusing cached requests.
        cache_mode = CacheMode::Requests;
    }

    // Per-run output directory: RUN_DIR/<example>[/<repetition>].
    let mut example_run_dir = RUN_DIR.join(&example.file_name());
    if let Some(repetition_ix) = repetition_ix {
        example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
    }
    fs::create_dir_all(&example_run_dir)?;
    // Refresh the "latest run" symlink to point at this run's directory.
    if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
        fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
    }

    #[cfg(unix)]
    std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
        .context("creating latest link")?;

    #[cfg(windows)]
    std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
        .context("creating latest link")?;

    // Install the run-scoped cache so LLM/search requests can be reused
    // and the entries used by this run get linked into its directory.
    zeta.update(cx, |zeta, _cx| {
        zeta.with_eval_cache(Arc::new(RunCache {
            example_run_dir: example_run_dir.clone(),
            cache_mode,
        }));
    })?;

    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    // Shared with the background debug task, which fills in timing/prompt data.
    let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));

    let prompt_format = options.zeta2.prompt_format;

    zeta.update(cx, |zeta, _cx| {
        let mut options = zeta.options().clone();
        options.prompt_format = prompt_format.into();
        zeta.set_options(options);
    })?;

    // Placeholder so `debug_task.await` below is valid even when the
    // provider is not Zeta2 and no debug listener is spawned.
    let mut debug_task = gpui::Task::ready(Ok(()));

    if options.provider == crate::PredictionProvider::Zeta2 {
        let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

        // Listen to Zeta debug events in the background: persist prompts and
        // responses to the run dir and record timing into `result`.
        debug_task = cx.background_spawn({
            let result = result.clone();
            async move {
                let mut start_time = None;
                let mut search_queries_generated_at = None;
                let mut search_queries_executed_at = None;
                while let Some(event) = debug_rx.next().await {
                    match event {
                        zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                            start_time = Some(info.timestamp);
                            fs::write(
                                example_run_dir.join("search_prompt.md"),
                                &info.search_prompt,
                            )?;
                        }
                        zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                            search_queries_generated_at = Some(info.timestamp);
                            fs::write(
                                example_run_dir.join("search_queries.json"),
                                serde_json::to_string_pretty(&info.search_queries).unwrap(),
                            )?;
                        }
                        zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                            search_queries_executed_at = Some(info.timestamp);
                        }
                        zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
                        zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                            let prediction_started_at = Instant::now();
                            // With expected context there is no retrieval phase,
                            // so the prediction request may be the first event.
                            start_time.get_or_insert(prediction_started_at);
                            let prompt = request.local_prompt.unwrap_or_default();
                            fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;

                            // Record prompt size and the excerpts that were
                            // included in the request, both structured and as
                            // rendered codeblock text.
                            {
                                let mut result = result.lock().unwrap();
                                result.prompt_len = prompt.chars().count();

                                for included_file in request.request.included_files {
                                    // Cursor marker is only inserted into the
                                    // file that contains the excerpt cursor.
                                    let insertions =
                                        vec![(request.request.cursor_point, CURSOR_MARKER)];
                                    result.excerpts.extend(included_file.excerpts.iter().map(
                                        |excerpt| ActualExcerpt {
                                            // Skip the first component, presumably the
                                            // worktree root — TODO confirm.
                                            path: included_file.path.components().skip(1).collect(),
                                            text: String::from(excerpt.text.as_ref()),
                                        },
                                    ));
                                    write_codeblock(
                                        &included_file.path,
                                        included_file.excerpts.iter(),
                                        if included_file.path == request.request.excerpt_path {
                                            &insertions
                                        } else {
                                            &[]
                                        },
                                        included_file.max_row,
                                        false,
                                        &mut result.excerpts_text,
                                    );
                                }
                            }

                            // Wait for the model's response, persist it, and
                            // fill in the timing breakdown.
                            let response =
                                request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                            let response = zeta2::text_from_response(response).unwrap_or_default();
                            let prediction_finished_at = Instant::now();
                            fs::write(example_run_dir.join("prediction_response.md"), &response)?;

                            let mut result = result.lock().unwrap();
                            result.generated_len = response.chars().count();

                            if !options.use_expected_context {
                                // Search phases only happen when context is
                                // retrieved rather than supplied by the example.
                                result.planning_search_time = Some(
                                    search_queries_generated_at.unwrap() - start_time.unwrap(),
                                );
                                result.running_search_time = Some(
                                    search_queries_executed_at.unwrap()
                                        - search_queries_generated_at.unwrap(),
                                );
                            }
                            result.prediction_time = prediction_finished_at - prediction_started_at;
                            result.total_time = prediction_finished_at - start_time.unwrap();

                            // One prediction per run: stop listening.
                            break;
                        }
                    }
                }
                anyhow::Ok(())
            }
        });

        if options.use_expected_context {
            // Resolve the example's expected context excerpts to anchor
            // ranges and hand them to Zeta directly (no search phase).
            let context_excerpts_tasks = example
                .example
                .expected_context
                .iter()
                .flat_map(|section| {
                    section.alternatives[0].excerpts.iter().map(|excerpt| {
                        resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
                    })
                })
                .collect::<Vec<_>>();
            let context_excerpts_vec =
                futures::future::try_join_all(context_excerpts_tasks).await?;

            // Group resolved ranges by buffer.
            let mut context_excerpts = HashMap::default();
            for (buffer, mut excerpts) in context_excerpts_vec {
                context_excerpts
                    .entry(buffer)
                    .or_insert(Vec::new())
                    .append(&mut excerpts);
            }

            zeta.update(cx, |zeta, _cx| {
                zeta.set_context(project.clone(), context_excerpts)
            })?;
        } else {
            // Let Zeta retrieve context itself (search planning + execution).
            zeta.update(cx, |zeta, cx| {
                zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
            })?
            .await?;
        }
    }

    let prediction = zeta
        .update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
        })?
        .await?;

    // Make sure the debug listener finished recording before reading `result`.
    debug_task.await?;

    // All other Arc clones are gone once debug_task completes, so unwrapping
    // the Arc/Mutex here recovers the owned details.
    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();

    // Render the predicted edits as a unified diff against the snapshot,
    // applying them on a branch buffer so the real buffer is untouched.
    result.diff = prediction
        .map(|prediction| {
            let old_text = prediction.snapshot.text();
            let new_text = prediction
                .buffer
                .update(cx, |buffer, cx| {
                    let branch = buffer.branch(cx);
                    branch.update(cx, |branch, cx| {
                        branch.edit(prediction.edits.iter().cloned(), None, cx);
                        branch.text()
                    })
                })
                .unwrap();
            language::unified_diff(&old_text, &new_text)
        })
        .unwrap_or_default();

    anyhow::Ok(result)
}
288
289async fn resolve_context_entry(
290    project: Entity<Project>,
291    excerpt: ExpectedExcerpt,
292    mut cx: AsyncApp,
293) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
294    let buffer = project
295        .update(&mut cx, |project, cx| {
296            let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
297            project.open_buffer(project_path, cx)
298        })?
299        .await?;
300
301    let ranges = buffer.read_with(&mut cx, |buffer, _| {
302        let full_text = buffer.text();
303        let offset = full_text
304            .find(&excerpt.text)
305            .expect("Expected context not found");
306        let point = buffer.offset_to_point(offset);
307        excerpt
308            .required_lines
309            .iter()
310            .map(|line| {
311                let row = point.row + line.0;
312                let range = Point::new(row, 0)..Point::new(row + 1, 0);
313                buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
314            })
315            .collect()
316    })?;
317
318    Ok((buffer, ranges))
319}
320
/// `EvalCache` implementation that persists entries in the global `CACHE_DIR`
/// and hard-links the entries used by a run into that run's data directory.
struct RunCache {
    // Controls which kinds of cached entries may be reused (search results
    // vs. LLM responses); see `CacheMode`.
    cache_mode: CacheMode,
    // This example run's data directory; used cache entries are linked here.
    example_run_dir: PathBuf,
}
325
326impl RunCache {
327    fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
328        CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
329    }
330
331    fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
332        CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
333    }
334
335    fn link_to_run(&self, key: &EvalCacheKey) {
336        let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
337        fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
338
339        let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
340        fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
341    }
342}
343
344impl EvalCache for RunCache {
345    fn read(&self, key: EvalCacheKey) -> Option<String> {
346        let path = RunCache::output_cache_path(&key);
347
348        if path.exists() {
349            let use_cache = match key.0 {
350                EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
351                EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
352                    self.cache_mode.use_cached_llm_responses()
353                }
354            };
355            if use_cache {
356                log::info!("Using cache entry: {}", path.display());
357                self.link_to_run(&key);
358                Some(fs::read_to_string(path).unwrap())
359            } else {
360                log::trace!("Skipping cached entry: {}", path.display());
361                None
362            }
363        } else if matches!(self.cache_mode, CacheMode::Force) {
364            panic!(
365                "No cached entry found for {:?}. Run without `--cache force` at least once.",
366                key.0
367            );
368        } else {
369            None
370        }
371    }
372
373    fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
374        fs::create_dir_all(&*CACHE_DIR).unwrap();
375
376        let input_path = RunCache::input_cache_path(&key);
377        fs::write(&input_path, input).unwrap();
378
379        let output_path = RunCache::output_cache_path(&key);
380        log::trace!("Writing cache entry: {}", output_path.display());
381        fs::write(&output_path, output).unwrap();
382
383        self.link_to_run(&key);
384    }
385}
386
/// Everything recorded about a single prediction run: the resulting diff,
/// the context excerpts that went into the prompt, and timing/size metrics.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictionDetails {
    // Unified diff of the predicted edits against the buffer snapshot.
    pub diff: String,
    // Structured excerpts included in the prediction request.
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    // Time spent generating search queries; `None` when expected context
    // was supplied instead of retrieved.
    pub planning_search_time: Option<Duration>,
    // Time spent executing the search queries; `None` as above.
    pub running_search_time: Option<Duration>,
    // Time from sending the prediction request to receiving the response.
    pub prediction_time: Duration,
    // Total wall-clock time from the first debug event to the response.
    pub total_time: Duration,
    // Data directory for this run (prompts, responses, linked cache entries).
    pub run_example_dir: PathBuf,
    // Prompt length in characters.
    pub prompt_len: usize,
    // Response length in characters.
    pub generated_len: usize,
}
400
401impl PredictionDetails {
402    pub fn new(run_example_dir: PathBuf) -> Self {
403        Self {
404            diff: Default::default(),
405            excerpts: Default::default(),
406            excerpts_text: Default::default(),
407            planning_search_time: Default::default(),
408            running_search_time: Default::default(),
409            prediction_time: Default::default(),
410            total_time: Default::default(),
411            run_example_dir,
412            prompt_len: 0,
413            generated_len: 0,
414        }
415    }
416
417    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
418        let formatted = match format {
419            PredictionsOutputFormat::Md => self.to_markdown(),
420            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
421            PredictionsOutputFormat::Diff => self.diff.clone(),
422        };
423
424        Ok(out.write_all(formatted.as_bytes())?)
425    }
426
427    pub fn to_markdown(&self) -> String {
428        let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
429
430        format!(
431            "## Excerpts\n\n\
432            {}\n\n\
433            ## Prediction\n\n\
434            {}\n\n\
435            ## Time\n\n\
436            Planning searches: {}ms\n\
437            Running searches: {}ms\n\
438            Making Prediction: {}ms\n\n\
439            -------------------\n\n\
440            Total: {}ms\n\
441            Inference: {}ms ({:.2}%)\n",
442            self.excerpts_text,
443            self.diff,
444            self.planning_search_time.unwrap_or_default().as_millis(),
445            self.running_search_time.unwrap_or_default().as_millis(),
446            self.prediction_time.as_millis(),
447            self.total_time.as_millis(),
448            inference_time.as_millis(),
449            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
450        )
451    }
452}