// predict.rs

  1use crate::PromptFormat;
  2use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
  3use crate::headless::ZetaCliAppState;
  4use crate::paths::{CACHE_DIR, LOGS_DIR};
  5use ::serde::Serialize;
  6use anyhow::{Result, anyhow};
  7use clap::Args;
  8use collections::HashMap;
  9use gpui::http_client::Url;
 10use language::{Anchor, Buffer, Point};
 11// use cloud_llm_client::predict_edits_v3::PromptFormat;
 12use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 13use futures::StreamExt as _;
 14use gpui::{AppContext, AsyncApp, Entity};
 15use project::Project;
 16use serde::Deserialize;
 17use std::cell::Cell;
 18use std::fs;
 19use std::io::Write;
 20use std::ops::Range;
 21use std::path::PathBuf;
 22use std::sync::Arc;
 23use std::sync::Mutex;
 24use std::time::{Duration, Instant};
 25use zeta2::LlmResponseCache;
 26
 27#[derive(Debug, Args)]
 28pub struct PredictArguments {
 29    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
 30    prompt_format: PromptFormat,
 31    #[arg(long)]
 32    use_expected_context: bool,
 33    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
 34    format: PredictionsOutputFormat,
 35    example_path: PathBuf,
 36    #[clap(long)]
 37    skip_cache: bool,
 38}
 39
/// Output formats supported by the `predict` subcommand.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    /// Pretty-printed JSON of the full `PredictionDetails`.
    Json,
    /// Human-readable markdown report (the default).
    Md,
    /// Only the unified diff of the predicted edits.
    Diff,
}
 46
 47pub async fn run_zeta2_predict(
 48    args: PredictArguments,
 49    app_state: &Arc<ZetaCliAppState>,
 50    cx: &mut AsyncApp,
 51) {
 52    let example = NamedExample::load(args.example_path).unwrap();
 53    let result = zeta2_predict(
 54        example,
 55        args.skip_cache,
 56        args.prompt_format,
 57        args.use_expected_context,
 58        &app_state,
 59        cx,
 60    )
 61    .await
 62    .unwrap();
 63    result.write(args.format, std::io::stdout()).unwrap();
 64}
 65
// Tracks whether we have already signed in on this thread, so repeated
// predictions within one process only authenticate once.
thread_local! {
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}
 69
/// Run a full edit-prediction cycle for `example`: build a local project over
/// the example's worktree, replay its edit history, gather context (either
/// retrieved via search or taken from the example's expected context when
/// `use_expected_context` is set), request a prediction, and return the
/// resulting diff, context excerpts, and timing breakdown.
///
/// As a side effect, the search prompt, search queries, prediction prompt,
/// and model response are written to files under `LOGS_DIR`.
pub async fn zeta2_predict(
    example: NamedExample,
    skip_cache: bool,
    prompt_format: PromptFormat,
    use_expected_context: bool,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    fs::create_dir_all(&*LOGS_DIR)?;
    let worktree_path = example.setup_worktree().await?;

    // Sign in only once per thread; later calls reuse the session.
    if !AUTHENTICATED.get() {
        AUTHENTICATED.set(true);

        app_state
            .client
            .sign_in_with_optional_connect(true, cx)
            .await?;
    }

    // Create a local project backed by the example's worktree.
    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;
    // Wait for the initial scan so path lookups below see the full worktree.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;

    // Route LLM requests through the on-disk response cache (honoring skip_cache).
    zeta.update(cx, |zeta, _cx| {
        zeta.with_llm_response_cache(Arc::new(Cache { skip_cache }));
    })?;

    // Register every buffer opened during the run with zeta so it can track edits.
    cx.subscribe(&buffer_store, {
        let project = project.clone();
        move |_, event, cx| match event {
            project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
                zeta2::Zeta::try_global(cx)
                    .unwrap()
                    .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
            }
            _ => {}
        }
    })?
    .detach();

    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    // Shared accumulator for details filled in by the debug-event task below.
    let result = Arc::new(Mutex::new(PredictionDetails::default()));
    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

    // Consume zeta's debug event stream: write prompts/queries/responses to
    // LOGS_DIR, record the context excerpts included in the request, and
    // capture per-phase timing. The task exits after the first prediction.
    let debug_task = cx.background_spawn({
        let result = result.clone();
        async move {
            let mut start_time = None;
            let mut search_queries_generated_at = None;
            let mut search_queries_executed_at = None;
            while let Some(event) = debug_rx.next().await {
                match event {
                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                        start_time = Some(info.timestamp);
                        fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                        search_queries_generated_at = Some(info.timestamp);
                        fs::write(
                            LOGS_DIR.join("search_queries.json"),
                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
                        )?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                        search_queries_executed_at = Some(info.timestamp);
                    }
                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                        let prediction_started_at = Instant::now();
                        // When context retrieval was skipped (expected
                        // context), this is the first event observed, so it
                        // also becomes the overall start time.
                        start_time.get_or_insert(prediction_started_at);
                        fs::write(
                            LOGS_DIR.join("prediction_prompt.md"),
                            &request.local_prompt.unwrap_or_default(),
                        )?;

                        {
                            let mut result = result.lock().unwrap();

                            for included_file in request.request.included_files {
                                // Mark the cursor position only in the file
                                // containing the excerpt being predicted.
                                let insertions =
                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
                                // skip(1) drops the worktree-root component
                                // from the stored path.
                                result.excerpts.extend(included_file.excerpts.iter().map(
                                    |excerpt| ActualExcerpt {
                                        path: included_file.path.components().skip(1).collect(),
                                        text: String::from(excerpt.text.as_ref()),
                                    },
                                ));
                                write_codeblock(
                                    &included_file.path,
                                    included_file.excerpts.iter(),
                                    if included_file.path == request.request.excerpt_path {
                                        &insertions
                                    } else {
                                        &[]
                                    },
                                    included_file.max_row,
                                    false,
                                    &mut result.excerpts_text,
                                );
                            }
                        }

                        // Wait for the model response before stamping timings.
                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                        let response = zeta2::text_from_response(response).unwrap_or_default();
                        let prediction_finished_at = Instant::now();
                        fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;

                        let mut result = result.lock().unwrap();

                        // Search timings only exist when retrieval actually ran.
                        if !use_expected_context {
                            result.planning_search_time =
                                Some(search_queries_generated_at.unwrap() - start_time.unwrap());
                            result.running_search_time = Some(
                                search_queries_executed_at.unwrap()
                                    - search_queries_generated_at.unwrap(),
                            );
                        }
                        result.prediction_time = prediction_finished_at - prediction_started_at;
                        result.total_time = prediction_finished_at - start_time.unwrap();

                        break;
                    }
                }
            }
            anyhow::Ok(())
        }
    });

    zeta.update(cx, |zeta, _cx| {
        let mut options = zeta.options().clone();
        options.prompt_format = prompt_format.into();
        zeta.set_options(options);
    })?;

    if use_expected_context {
        // Resolve the example's expected excerpts into buffer anchor ranges
        // and hand them to zeta directly, bypassing context retrieval.
        let context_excerpts_tasks = example
            .example
            .expected_context
            .iter()
            .flat_map(|section| {
                section.alternatives[0].excerpts.iter().map(|excerpt| {
                    resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
                })
            })
            .collect::<Vec<_>>();
        let context_excerpts_vec = futures::future::try_join_all(context_excerpts_tasks).await?;

        // Group resolved ranges by buffer.
        let mut context_excerpts = HashMap::default();
        for (buffer, mut excerpts) in context_excerpts_vec {
            context_excerpts
                .entry(buffer)
                .or_insert(Vec::new())
                .append(&mut excerpts);
        }

        zeta.update(cx, |zeta, _cx| {
            zeta.set_context(project.clone(), context_excerpts)
        })?;
    } else {
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
        })?
        .await?;
    }

    let prediction = zeta
        .update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
        })?
        .await?;

    // Ensure the debug task has recorded excerpts and timings before reading.
    debug_task.await?;

    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
    // Apply the predicted edits to the buffer and diff against the pre-edit
    // snapshot; no prediction yields an empty diff.
    result.diff = prediction
        .map(|prediction| {
            let old_text = prediction.snapshot.text();
            let new_text = prediction
                .buffer
                .update(cx, |buffer, cx| {
                    buffer.edit(prediction.edits.iter().cloned(), None, cx);
                    buffer.text()
                })
                .unwrap();
            language::unified_diff(&old_text, &new_text)
        })
        .unwrap_or_default();

    anyhow::Ok(result)
}
286
287async fn resolve_context_entry(
288    project: Entity<Project>,
289    excerpt: ExpectedExcerpt,
290    mut cx: AsyncApp,
291) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
292    let buffer = project
293        .update(&mut cx, |project, cx| {
294            let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
295            project.open_buffer(project_path, cx)
296        })?
297        .await?;
298
299    let ranges = buffer.read_with(&mut cx, |buffer, _| {
300        let full_text = buffer.text();
301        let offset = full_text
302            .find(&excerpt.text)
303            .expect("Expected context not found");
304        let point = buffer.offset_to_point(offset);
305        excerpt
306            .required_lines
307            .iter()
308            .map(|line| {
309                let row = point.row + line.0;
310                let range = Point::new(row, 0)..Point::new(row + 1, 0);
311                buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
312            })
313            .collect()
314    })?;
315
316    Ok((buffer, ranges))
317}
318
/// On-disk cache of LLM responses, keyed by a hash of the request URL + body.
struct Cache {
    // When true, cached responses are ignored on read; fresh responses are
    // still written to the cache.
    skip_cache: bool,
}
322
323impl Cache {
324    fn path(key: u64) -> PathBuf {
325        CACHE_DIR.join(format!("{key:x}.json"))
326    }
327}
328
329impl LlmResponseCache for Cache {
330    fn get_key(&self, url: &Url, body: &str) -> u64 {
331        use collections::FxHasher;
332        use std::hash::{Hash, Hasher};
333
334        let mut hasher = FxHasher::default();
335        url.hash(&mut hasher);
336        body.hash(&mut hasher);
337        hasher.finish()
338    }
339
340    fn read_response(&self, key: u64) -> Option<String> {
341        let path = Cache::path(key);
342        if path.exists() {
343            if self.skip_cache {
344                log::info!("Skipping existing cached LLM response: {}", path.display());
345                None
346            } else {
347                log::info!("Using LLM response from cache: {}", path.display());
348                Some(fs::read_to_string(path).unwrap())
349            }
350        } else {
351            None
352        }
353    }
354
355    fn write_response(&self, key: u64, value: &str) {
356        fs::create_dir_all(&*CACHE_DIR).unwrap();
357
358        let path = Cache::path(key);
359        log::info!("Writing LLM response to cache: {}", path.display());
360        fs::write(path, value).unwrap();
361    }
362}
363
/// Outcome of a single prediction run: the predicted diff, the context
/// excerpts sent to the model, and a timing breakdown of each phase.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    // Unified diff of the predicted edits; empty when no prediction was made.
    pub diff: String,
    // Context excerpts included in the prediction request.
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    // Time spent generating search queries; None when retrieval was skipped.
    pub planning_search_time: Option<Duration>,
    // Time spent executing search queries; None when retrieval was skipped.
    pub running_search_time: Option<Duration>,
    // Time from sending the prediction request to receiving the response.
    pub prediction_time: Duration,
    // Total wall-clock time from the first observed event to the response.
    pub total_time: Duration,
}
374
375impl PredictionDetails {
376    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
377        let formatted = match format {
378            PredictionsOutputFormat::Md => self.to_markdown(),
379            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
380            PredictionsOutputFormat::Diff => self.diff.clone(),
381        };
382
383        Ok(out.write_all(formatted.as_bytes())?)
384    }
385
386    pub fn to_markdown(&self) -> String {
387        let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
388
389        format!(
390            "## Excerpts\n\n\
391            {}\n\n\
392            ## Prediction\n\n\
393            {}\n\n\
394            ## Time\n\n\
395            Planning searches: {}ms\n\
396            Running searches: {}ms\n\
397            Making Prediction: {}ms\n\n\
398            -------------------\n\n\
399            Total: {}ms\n\
400            Inference: {}ms ({:.2}%)\n",
401            self.excerpts_text,
402            self.diff,
403            self.planning_search_time.unwrap_or_default().as_millis(),
404            self.running_search_time.unwrap_or_default().as_millis(),
405            self.prediction_time.as_millis(),
406            self.total_time.as_millis(),
407            inference_time.as_millis(),
408            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
409        )
410    }
411}