use crate::PromptFormat;
use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::{
    CACHE_DIR, LOGS_DIR, LOGS_PREDICTION_PROMPT, LOGS_PREDICTION_RESPONSE, LOGS_SEARCH_PROMPT,
    LOGS_SEARCH_QUERIES,
};
use anyhow::{Result, anyhow};
use clap::Args;
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use collections::HashMap;
use futures::StreamExt as _;
use gpui::http_client::Url;
use gpui::{AppContext, AsyncApp, Entity};
use language::{Anchor, Buffer, Point};
use project::Project;
use serde::{Deserialize, Serialize};
use std::cell::Cell;
use std::fs;
use std::io::Write;
use std::ops::Range;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use zeta2::LlmResponseCache;

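/// Arguments for the `predict` command, which runs a single Zeta2 edit
/// prediction against a stored example.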
#[derive(Debug, Args)]
pub struct PredictArguments {
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    #[arg(long)]
    use_expected_context: bool,
    #[arg(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
    example_path: PathBuf,
    #[arg(long)]
    skip_cache: bool,
}

#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    Json,
    Md,
    Diff,
}

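/// CLI entry point: loads the example, runs the prediction, writes the
/// formatted result to stdout, and prints where the prompt and response logs
/// were written.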
pub async fn run_zeta2_predict(
    args: PredictArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    let example = NamedExample::load(args.example_path).unwrap();
    let result = zeta2_predict(
        example,
        args.skip_cache,
        args.prompt_format,
        args.use_expected_context,
        app_state,
        cx,
    )
    .await
    .unwrap();
    result.write(args.format, std::io::stdout()).unwrap();

    println!("## Logs\n");
    println!("Search prompt: {}", LOGS_SEARCH_PROMPT.display());
    println!("Search queries: {}", LOGS_SEARCH_QUERIES.display());
    println!("Prediction prompt: {}", LOGS_PREDICTION_PROMPT.display());
    println!(
        "Prediction response: {}",
        LOGS_PREDICTION_RESPONSE.display()
    );
}

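// Tracks whether we've already signed in, so repeated predictions on the same
// thread don't re-authenticate.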
thread_local! {
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}

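/// Runs a single prediction for `example`: sets up a local project and
/// worktree, replays the example's edit history, gathers context (either the
/// example's expected context or a live context retrieval from the cursor),
/// and requests an edit prediction, collecting timings and logs along the way.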
pub async fn zeta2_predict(
    example: NamedExample,
    skip_cache: bool,
    prompt_format: PromptFormat,
    use_expected_context: bool,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    fs::create_dir_all(&*LOGS_DIR)?;
    let worktree_path = example.setup_worktree().await?;

    if !AUTHENTICATED.get() {
        AUTHENTICATED.set(true);

        app_state
            .client
            .sign_in_with_optional_connect(true, cx)
            .await?;
    }

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;

    zeta.update(cx, |zeta, _cx| {
        zeta.with_llm_response_cache(Arc::new(Cache { skip_cache }));
    })?;

    cx.subscribe(&buffer_store, {
        let project = project.clone();
        move |_, event, cx| match event {
            project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
                zeta2::Zeta::try_global(cx)
                    .unwrap()
                    .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
            }
            _ => {}
        }
    })?
    .detach();

    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    let result = Arc::new(Mutex::new(PredictionDetails::default()));
    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

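    // Listen to Zeta's debug events in the background to capture the prompts,
    // search queries, timings, and the excerpts that were sent to the model.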
    let debug_task = cx.background_spawn({
        let result = result.clone();
        async move {
            let mut start_time = None;
            let mut search_queries_generated_at = None;
            let mut search_queries_executed_at = None;
            while let Some(event) = debug_rx.next().await {
                match event {
                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                        start_time = Some(info.timestamp);
                        fs::write(&*LOGS_SEARCH_PROMPT, &info.search_prompt)?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                        search_queries_generated_at = Some(info.timestamp);
                        fs::write(
                            &*LOGS_SEARCH_QUERIES,
                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
                        )?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                        search_queries_executed_at = Some(info.timestamp);
                    }
                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                        let prediction_started_at = Instant::now();
                        start_time.get_or_insert(prediction_started_at);
                        fs::write(
                            &*LOGS_PREDICTION_PROMPT,
                            &request.local_prompt.unwrap_or_default(),
                        )?;

                        {
                            let mut result = result.lock().unwrap();

                            for included_file in request.request.included_files {
                                let insertions =
                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
                                result.excerpts.extend(included_file.excerpts.iter().map(
                                    |excerpt| ActualExcerpt {
                                        path: included_file.path.components().skip(1).collect(),
                                        text: String::from(excerpt.text.as_ref()),
                                    },
                                ));
                                write_codeblock(
                                    &included_file.path,
                                    included_file.excerpts.iter(),
                                    if included_file.path == request.request.excerpt_path {
                                        &insertions
                                    } else {
                                        &[]
                                    },
                                    included_file.max_row,
                                    false,
                                    &mut result.excerpts_text,
                                );
                            }
                        }

                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                        let response = zeta2::text_from_response(response).unwrap_or_default();
                        let prediction_finished_at = Instant::now();
                        fs::write(&*LOGS_PREDICTION_RESPONSE, &response)?;

                        let mut result = result.lock().unwrap();

                        if !use_expected_context {
                            result.planning_search_time =
                                Some(search_queries_generated_at.unwrap() - start_time.unwrap());
                            result.running_search_time = Some(
                                search_queries_executed_at.unwrap()
                                    - search_queries_generated_at.unwrap(),
                            );
                        }
                        result.prediction_time = prediction_finished_at - prediction_started_at;
                        result.total_time = prediction_finished_at - start_time.unwrap();

                        break;
                    }
                }
            }
            anyhow::Ok(())
        }
    });

    zeta.update(cx, |zeta, _cx| {
        let mut options = zeta.options().clone();
        options.prompt_format = prompt_format.into();
        zeta.set_options(options);
    })?;

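    // Either seed the context directly from the example's expected excerpts,
    // or run the real context-retrieval pipeline from the cursor position.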
    if use_expected_context {
        let context_excerpts_tasks = example
            .example
            .expected_context
            .iter()
            .flat_map(|section| {
                section.alternatives[0].excerpts.iter().map(|excerpt| {
                    resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
                })
            })
            .collect::<Vec<_>>();
        let context_excerpts_vec = futures::future::try_join_all(context_excerpts_tasks).await?;

        let mut context_excerpts = HashMap::default();
        for (buffer, mut excerpts) in context_excerpts_vec {
            context_excerpts
                .entry(buffer)
                .or_insert(Vec::new())
                .append(&mut excerpts);
        }

        zeta.update(cx, |zeta, _cx| {
            zeta.set_context(project.clone(), context_excerpts)
        })?;
    } else {
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
        })?
        .await?;
    }

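    // With context in place, request the actual edit prediction.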
    let prediction = zeta
        .update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
        })?
        .await?;

    debug_task.await?;

    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
    result.diff = prediction
        .map(|prediction| {
            let old_text = prediction.snapshot.text();
            let new_text = prediction
                .buffer
                .update(cx, |buffer, cx| {
                    buffer.edit(prediction.edits.iter().cloned(), None, cx);
                    buffer.text()
                })
                .unwrap();
            language::unified_diff(&old_text, &new_text)
        })
        .unwrap_or_default();

    anyhow::Ok(result)
}

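/// Resolves an expected excerpt from the example into anchor ranges in the
/// corresponding buffer, by locating the excerpt's text and offsetting its
/// required line numbers from there.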
async fn resolve_context_entry(
    project: Entity<Project>,
    excerpt: ExpectedExcerpt,
    mut cx: AsyncApp,
) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
    let buffer = project
        .update(&mut cx, |project, cx| {
            let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
            project.open_buffer(project_path, cx)
        })?
        .await?;

    let ranges = buffer.read_with(&mut cx, |buffer, _| {
        let full_text = buffer.text();
        let offset = full_text
            .find(&excerpt.text)
            .expect("Expected context not found");
        let point = buffer.offset_to_point(offset);
        excerpt
            .required_lines
            .iter()
            .map(|line| {
                let row = point.row + line.0;
                let range = Point::new(row, 0)..Point::new(row + 1, 0);
                buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
            })
            .collect()
    })?;

    Ok((buffer, ranges))
}

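/// On-disk cache of LLM responses, keyed by a hash of the request URL and
/// body, so repeated runs of the same example can skip the model round-trip.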
struct Cache {
    skip_cache: bool,
}

impl Cache {
    fn path(key: u64) -> PathBuf {
        CACHE_DIR.join(format!("{key:x}.json"))
    }
}

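// Note: keys are derived with `FxHasher`, whose output is stable within a
// build but not guaranteed across compiler or dependency upgrades, so old
// cache entries may simply stop being read after an upgrade.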
impl LlmResponseCache for Cache {
    fn get_key(&self, url: &Url, body: &str) -> u64 {
        use collections::FxHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = FxHasher::default();
        url.hash(&mut hasher);
        body.hash(&mut hasher);
        hasher.finish()
    }

    fn read_response(&self, key: u64) -> Option<String> {
        let path = Cache::path(key);
        if path.exists() {
            if self.skip_cache {
                log::info!("Skipping existing cached LLM response: {}", path.display());
                None
            } else {
                log::info!("Using LLM response from cache: {}", path.display());
                Some(fs::read_to_string(path).unwrap())
            }
        } else {
            None
        }
    }

    fn write_response(&self, key: u64, value: &str) {
        fs::create_dir_all(&*CACHE_DIR).unwrap();

        let path = Cache::path(key);
        log::info!("Writing LLM response to cache: {}", path.display());
        fs::write(path, value).unwrap();
    }
}

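/// The collected output of a prediction run: the predicted diff, the context
/// excerpts that were sent to the model, and timing breakdowns.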
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    pub diff: String,
    pub excerpts: Vec<ActualExcerpt>,
    // TODO: This contains the worktree root path. Drop the field and compute it on the fly.
    pub excerpts_text: String,
    pub planning_search_time: Option<Duration>,
    pub running_search_time: Option<Duration>,
    pub prediction_time: Duration,
    pub total_time: Duration,
}

impl PredictionDetails {
    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
        let formatted = match format {
            PredictionsOutputFormat::Md => self.to_markdown(),
            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
            PredictionsOutputFormat::Diff => self.diff.clone(),
        };

        Ok(out.write_all(formatted.as_bytes())?)
    }

    pub fn to_markdown(&self) -> String {
        // Inference time covers the LLM calls (planning the searches and
        // making the prediction); running the searches happens locally.
        let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;

        format!(
            "## Excerpts\n\n\
             {}\n\n\
             ## Prediction\n\n\
             {}\n\n\
             ## Time\n\n\
             Planning searches: {}ms\n\
             Running searches: {}ms\n\
             Making prediction: {}ms\n\n\
             -------------------\n\n\
             Total: {}ms\n\
             Inference: {}ms ({:.2}%)\n",
            self.excerpts_text,
            self.diff,
            self.planning_search_time.unwrap_or_default().as_millis(),
            self.running_search_time.unwrap_or_default().as_millis(),
            self.prediction_time.as_millis(),
            self.total_time.as_millis(),
            inference_time.as_millis(),
            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
        )
    }
}