predict.rs

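//! `predict` subcommand: runs a single zeta2 edit prediction against a stored
//! example, logging prompts, responses, and per-phase timings along the way.
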
use crate::PromptFormat;
use crate::example::{ActualExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::{CACHE_DIR, LOGS_DIR};
use ::serde::Serialize;
use anyhow::{Result, anyhow};
use clap::Args;
use gpui::http_client::Url;
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use futures::StreamExt as _;
use gpui::{AppContext, AsyncApp};
use project::Project;
use serde::Deserialize;
use std::cell::Cell;
use std::fs;
use std::io::Write;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use zeta2::LlmResponseCache;

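/// Command-line arguments for the `predict` subcommand.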
#[derive(Debug, Args)]
pub struct PredictArguments {
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
    example_path: PathBuf,
    #[clap(long)]
    skip_cache: bool,
}

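/// How the prediction result is rendered to stdout.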
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    Json,
    Md,
    Diff,
}
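
/// CLI entry point: loads the example, runs the prediction, and writes the
/// result to stdout in the requested format.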
pub async fn run_zeta2_predict(
    args: PredictArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    let example = NamedExample::load(args.example_path).unwrap();
    let result = zeta2_predict(example, args.skip_cache, args.prompt_format, &app_state, cx)
        .await
        .unwrap();
    result.write(args.format, std::io::stdout()).unwrap();
}

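// Sign in at most once per thread; `zeta2_predict` may run many times in one
// process (e.g. across a batch of examples).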
thread_local! {
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}

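/// Runs context retrieval and edit prediction for a single example, collecting
/// the prompts, per-phase timings, and predicted diff into a [`PredictionDetails`].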
pub async fn zeta2_predict(
    example: NamedExample,
    skip_cache: bool,
    prompt_format: PromptFormat,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    fs::create_dir_all(&*LOGS_DIR)?;
    let worktree_path = example.setup_worktree().await?;

    if !AUTHENTICATED.get() {
        AUTHENTICATED.set(true);

        app_state
            .client
            .sign_in_with_optional_connect(true, cx)
            .await?;
    }

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;

    zeta.update(cx, |zeta, _cx| {
        zeta.with_llm_response_cache(Arc::new(Cache { skip_cache }));
    })?;

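    // Register each buffer with zeta as it is added to the buffer store.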
    cx.subscribe(&buffer_store, {
        let project = project.clone();
        move |_, event, cx| match event {
            project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
                zeta2::Zeta::try_global(cx)
                    .unwrap()
                    .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
            }
            _ => {}
        }
    })?
    .detach();

    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    let result = Arc::new(Mutex::new(PredictionDetails::default()));
    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

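    // Drain zeta's debug stream in the background, writing prompts and responses
    // to LOGS_DIR and recording when each phase of the request starts and finishes.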
    let debug_task = cx.background_spawn({
        let result = result.clone();
        async move {
            let mut context_retrieval_started_at = None;
            let mut context_retrieval_finished_at = None;
            let mut search_queries_generated_at = None;
            let mut search_queries_executed_at = None;
            while let Some(event) = debug_rx.next().await {
                match event {
                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                        context_retrieval_started_at = Some(info.timestamp);
                        fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                        search_queries_generated_at = Some(info.timestamp);
                        fs::write(
                            LOGS_DIR.join("search_queries.json"),
                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
                        )?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                        search_queries_executed_at = Some(info.timestamp);
                    }
                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
                        context_retrieval_finished_at = Some(info.timestamp);
                    }
                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                        let prediction_started_at = Instant::now();
                        fs::write(
                            LOGS_DIR.join("prediction_prompt.md"),
                            &request.local_prompt.unwrap_or_default(),
                        )?;

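                        // Record the excerpts sent to the model, rendering them
                        // (with a cursor marker in the excerpt file) into
                        // `excerpts_text`.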
                        {
                            let mut result = result.lock().unwrap();

                            for included_file in request.request.included_files {
                                let insertions =
                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
                                result.excerpts.extend(included_file.excerpts.iter().map(
                                    |excerpt| ActualExcerpt {
                                        path: included_file.path.components().skip(1).collect(),
                                        text: String::from(excerpt.text.as_ref()),
                                    },
                                ));
                                write_codeblock(
                                    &included_file.path,
                                    included_file.excerpts.iter(),
                                    if included_file.path == request.request.excerpt_path {
                                        &insertions
                                    } else {
                                        &[]
                                    },
                                    included_file.max_row,
                                    false,
                                    &mut result.excerpts_text,
                                );
                            }
                        }

                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                        let response = zeta2::text_from_response(response).unwrap_or_default();
                        let prediction_finished_at = Instant::now();
                        fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;

                        let mut result = result.lock().unwrap();

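                        // Every earlier phase event has fired by the time the
                        // prediction request goes out, so these unwraps are
                        // expected to hold.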
                        result.planning_search_time = search_queries_generated_at.unwrap()
                            - context_retrieval_started_at.unwrap();
                        result.running_search_time = search_queries_executed_at.unwrap()
                            - search_queries_generated_at.unwrap();
                        result.filtering_search_time = context_retrieval_finished_at.unwrap()
                            - search_queries_executed_at.unwrap();
                        result.prediction_time = prediction_finished_at - prediction_started_at;
                        result.total_time =
                            prediction_finished_at - context_retrieval_started_at.unwrap();

                        break;
                    }
                }
            }
            anyhow::Ok(())
        }
    });

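    // Retrieve context for the cursor position, using the requested prompt format.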
    zeta.update(cx, |zeta, cx| {
        let mut options = zeta.options().clone();
        options.prompt_format = prompt_format.into();
        zeta.set_options(options);
        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
    })?
    .await?;

    let prediction = zeta
        .update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
        })?
        .await?;

    debug_task.await?;

    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
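    // Produce a unified diff by applying the predicted edits to the buffer and
    // comparing against the pre-edit snapshot.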
    result.diff = prediction
        .map(|prediction| {
            let old_text = prediction.snapshot.text();
            let new_text = prediction
                .buffer
                .update(cx, |buffer, cx| {
                    buffer.edit(prediction.edits.iter().cloned(), None, cx);
                    buffer.text()
                })
                .unwrap();
            language::unified_diff(&old_text, &new_text)
        })
        .unwrap_or_default();

    anyhow::Ok(result)
}

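/// Disk-backed cache of LLM responses, keyed by a hash of the request URL and body.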
struct Cache {
    skip_cache: bool,
}

impl Cache {
    fn path(key: u64) -> PathBuf {
        CACHE_DIR.join(format!("{key:x}.json"))
    }
}

impl LlmResponseCache for Cache {
    fn get_key(&self, url: &Url, body: &str) -> u64 {
        use collections::FxHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = FxHasher::default();
        url.hash(&mut hasher);
        body.hash(&mut hasher);
        hasher.finish()
    }

    fn read_response(&self, key: u64) -> Option<String> {
        let path = Cache::path(key);
        if path.exists() {
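            // --skip-cache ignores existing entries; fresh responses are still
            // written back to the cache.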
            if self.skip_cache {
                log::info!("Skipping existing cached LLM response: {}", path.display());
                None
            } else {
                log::info!("Using LLM response from cache: {}", path.display());
                Some(fs::read_to_string(path).unwrap())
            }
        } else {
            None
        }
    }

    fn write_response(&self, key: u64, value: &str) {
        fs::create_dir_all(&*CACHE_DIR).unwrap();

        let path = Cache::path(key);
        log::info!("Writing LLM response to cache: {}", path.display());
        fs::write(path, value).unwrap();
    }
}

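/// Everything captured from a prediction run: the predicted diff, the excerpts
/// sent to the model, and per-phase timings.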
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    pub diff: String,
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    pub planning_search_time: Duration,
    pub filtering_search_time: Duration,
    pub running_search_time: Duration,
    pub prediction_time: Duration,
    pub total_time: Duration,
}

impl PredictionDetails {
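    /// Writes the prediction details to `out` in the given output format.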
    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
        let formatted = match format {
            PredictionsOutputFormat::Md => self.to_markdown(),
            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
            PredictionsOutputFormat::Diff => self.diff.clone(),
        };

        Ok(out.write_all(formatted.as_bytes())?)
    }

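    /// Renders the excerpts, predicted diff, and timing breakdown as markdown.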
    pub fn to_markdown(&self) -> String {
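        // Running the searches is excluded from inference time, since executing
        // them is local work rather than model time.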
        let inference_time =
            self.planning_search_time + self.filtering_search_time + self.prediction_time;

        format!(
            "## Excerpts\n\n\
            {}\n\n\
            ## Prediction\n\n\
            {}\n\n\
            ## Time\n\n\
            Planning searches: {}ms\n\
            Running searches: {}ms\n\
            Filtering context results: {}ms\n\
            Making prediction: {}ms\n\n\
            -------------------\n\n\
            Total: {}ms\n\
            Inference: {}ms ({:.2}%)\n",
            self.excerpts_text,
            self.diff,
            self.planning_search_time.as_millis(),
            self.running_search_time.as_millis(),
            self.filtering_search_time.as_millis(),
            self.prediction_time.as_millis(),
            self.total_time.as_millis(),
            inference_time.as_millis(),
            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
        )
    }
}