1use crate::example::{ActualExcerpt, NamedExample};
2use crate::headless::ZetaCliAppState;
3use crate::paths::LOGS_DIR;
4use ::serde::Serialize;
5use anyhow::{Result, anyhow};
6use clap::Args;
7use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
8use futures::StreamExt as _;
9use gpui::{AppContext, AsyncApp};
10use project::Project;
11use serde::Deserialize;
12use std::cell::Cell;
13use std::fs;
14use std::io::Write;
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::sync::Mutex;
18use std::time::{Duration, Instant};
19
// Command-line arguments for the `predict` subcommand.
// NOTE: plain `//` comments are used on purpose — `///` doc comments on a
// clap derive struct/fields would change the generated `--help` output.
#[derive(Debug, Args)]
pub struct PredictArguments {
    // Path to the example file describing the worktree, edit history, and
    // cursor position to predict from.
    example_path: PathBuf,
    // Output format for the prediction result; defaults to Markdown.
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
}
26
// How `PredictionDetails` is rendered to stdout.
// NOTE: plain `//` comments on purpose — `///` on a ValueEnum's variants
// would become CLI help text and change program output.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    // Pretty-printed JSON of the full `PredictionDetails`.
    Json,
    // Human-readable Markdown report (excerpts, diff, timings).
    Md,
    // Only the unified diff of the predicted edit.
    Diff,
}
33pub async fn run_zeta2_predict(
34 args: PredictArguments,
35 app_state: &Arc<ZetaCliAppState>,
36 cx: &mut AsyncApp,
37) {
38 let example = NamedExample::load(args.example_path).unwrap();
39 let result = zeta2_predict(example, &app_state, cx).await.unwrap();
40 result.write(args.format, std::io::stdout()).unwrap();
41}
42
thread_local! {
    // Tracks whether sign-in has already been attempted on this thread, so
    // repeated predictions (e.g. from an eval loop) skip the sign-in flow.
    // NOTE(review): `zeta2_predict` sets this *before* awaiting the sign-in,
    // so a failed sign-in is never retried on this thread; and being
    // thread-local, another thread will sign in again — confirm intended.
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}
46
/// Runs one end-to-end edit prediction for `example` and collects the
/// predicted diff, the context excerpts sent to the model, and per-stage
/// timings into a [`PredictionDetails`].
///
/// Side effects: creates `LOGS_DIR` and writes debug artifacts into it
/// (`search_prompt.md`, `search_queries.json`, `prediction_prompt.md`,
/// `prediction_response.md`), signs the client in on first use per thread,
/// and materializes the example's worktree on disk.
pub async fn zeta2_predict(
    example: NamedExample,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    fs::create_dir_all(&*LOGS_DIR)?;
    let worktree_path = example.setup_worktree().await?;

    // Sign in at most once per thread (see `AUTHENTICATED` above).
    // NOTE(review): the flag is set before the sign-in future resolves, so a
    // failed sign-in is not retried on this thread — confirm intended.
    if !AUTHENTICATED.get() {
        AUTHENTICATED.set(true);

        app_state
            .client
            .sign_in_with_optional_connect(true, cx)
            .await?;
    }

    // Create a local (non-collab) project backed by the app state's services.
    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;

    // Add the example's worktree and wait for the initial file scan so that
    // subsequent path lookups see the full tree.
    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;

    // Register every buffer with Zeta as it is opened, so edit history and
    // context retrieval can observe it. Must be wired up before the edit
    // history below opens any buffers.
    cx.subscribe(&buffer_store, {
        let project = project.clone();
        move |_, event, cx| match event {
            project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
                zeta2::Zeta::try_global(cx)
                    .unwrap()
                    .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
            }
            _ => {}
        }
    })?
    .detach();

    // Replay the example's edit history and locate the cursor position the
    // prediction should be made at.
    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    // Shared accumulator filled in by the debug task below and finalized at
    // the end of this function.
    let result = Arc::new(Mutex::new(PredictionDetails::default()));
    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

    // Background task: consume Zeta's debug event stream, dump prompts and
    // queries to LOGS_DIR, capture the included context excerpts, and record
    // per-stage timestamps. It exits after the first prediction completes.
    let debug_task = cx.background_spawn({
        let result = result.clone();
        async move {
            let mut context_retrieval_started_at = None;
            let mut context_retrieval_finished_at = None;
            let mut search_queries_generated_at = None;
            let mut search_queries_executed_at = None;
            while let Some(event) = debug_rx.next().await {
                match event {
                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                        context_retrieval_started_at = Some(info.timestamp);
                        fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                        search_queries_generated_at = Some(info.timestamp);
                        fs::write(
                            LOGS_DIR.join("search_queries.json"),
                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
                        )?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                        search_queries_executed_at = Some(info.timestamp);
                    }
                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
                        context_retrieval_finished_at = Some(info.timestamp);
                    }
                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                        let prediction_started_at = Instant::now();
                        fs::write(
                            LOGS_DIR.join("prediction_prompt.md"),
                            &request.local_prompt.unwrap_or_default(),
                        )?;

                        // Record the context excerpts that were sent with the
                        // request, both structurally and as rendered text.
                        {
                            let mut result = result.lock().unwrap();

                            for included_file in request.request.included_files {
                                // Cursor marker is only inserted into the file
                                // containing the excerpt under the cursor.
                                let insertions =
                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
                                result.excerpts.extend(included_file.excerpts.iter().map(
                                    // `skip(1)` drops the first path component —
                                    // presumably the worktree root; see the TODO
                                    // on `excerpts_text`.
                                    |excerpt| ActualExcerpt {
                                        path: included_file.path.components().skip(1).collect(),
                                        text: String::from(excerpt.text.as_ref()),
                                    },
                                ));
                                write_codeblock(
                                    &included_file.path,
                                    included_file.excerpts.iter(),
                                    if included_file.path == request.request.excerpt_path {
                                        &insertions
                                    } else {
                                        &[]
                                    },
                                    included_file.max_row,
                                    false,
                                    &mut result.excerpts_text,
                                );
                            }
                        }

                        // Wait for the model's response, then log it.
                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                        let response = zeta2::text_from_response(response).unwrap_or_default();
                        let prediction_finished_at = Instant::now();
                        fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;

                        let mut result = result.lock().unwrap();

                        // The unwraps assume the events above always arrive in
                        // retrieval -> generated -> executed -> finished order
                        // before a prediction is requested.
                        result.planning_search_time = search_queries_generated_at.unwrap()
                            - context_retrieval_started_at.unwrap();
                        result.running_search_time = search_queries_executed_at.unwrap()
                            - search_queries_generated_at.unwrap();
                        result.filtering_search_time = context_retrieval_finished_at.unwrap()
                            - search_queries_executed_at.unwrap();
                        result.prediction_time = prediction_finished_at - prediction_started_at;
                        result.total_time =
                            prediction_finished_at - context_retrieval_started_at.unwrap();

                        // One prediction per run: stop consuming debug events.
                        break;
                    }
                }
            }
            anyhow::Ok(())
        }
    });

    // Retrieve context for the cursor position, then request the prediction.
    zeta.update(cx, |zeta, cx| {
        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
    })?
    .await?;

    let prediction = zeta
        .update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
        })?
        .await?;

    // Propagate any I/O or channel error from the debug task.
    debug_task.await?;

    // The debug task has finished, so this is the only remaining Arc handle.
    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
    // Apply the predicted edits to the buffer and diff against the snapshot
    // taken at request time; empty diff when no prediction was produced.
    result.diff = prediction
        .map(|prediction| {
            let old_text = prediction.snapshot.text();
            let new_text = prediction
                .buffer
                .update(cx, |buffer, cx| {
                    buffer.edit(prediction.edits.iter().cloned(), None, cx);
                    buffer.text()
                })
                .unwrap();
            language::unified_diff(&old_text, &new_text)
        })
        .unwrap_or_default();

    anyhow::Ok(result)
}
225
/// Collected output of a single prediction run: the predicted diff, the
/// context excerpts sent to the model, and per-stage timings.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    // Unified diff of the predicted edit; empty when no prediction was made.
    pub diff: String,
    // Context excerpts included in the prediction request.
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    // Time spent generating search queries.
    pub planning_search_time: Duration,
    // Time spent filtering search results into the final context.
    pub filtering_search_time: Duration,
    // Time spent executing the generated searches.
    pub running_search_time: Duration,
    // Time spent waiting for the model's prediction.
    pub prediction_time: Duration,
    // Wall-clock time from context-retrieval start to prediction finish.
    pub total_time: Duration,
}
237
238impl PredictionDetails {
239 pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
240 let formatted = match format {
241 PredictionsOutputFormat::Md => self.to_markdown(),
242 PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
243 PredictionsOutputFormat::Diff => self.diff.clone(),
244 };
245
246 Ok(out.write_all(formatted.as_bytes())?)
247 }
248
249 pub fn to_markdown(&self) -> String {
250 let inference_time =
251 self.planning_search_time + self.filtering_search_time + self.prediction_time;
252
253 format!(
254 "## Excerpts\n\n\
255 {}\n\n\
256 ## Prediction\n\n\
257 {}\n\n\
258 ## Time\n\n\
259 Planning searches: {}ms\n\
260 Running searches: {}ms\n\
261 Filtering context results: {}ms\n\
262 Making Prediction: {}ms\n\n\
263 -------------------\n\n\
264 Total: {}ms\n\
265 Inference: {}ms ({:.2}%)\n",
266 self.excerpts_text,
267 self.diff,
268 self.planning_search_time.as_millis(),
269 self.running_search_time.as_millis(),
270 self.filtering_search_time.as_millis(),
271 self.prediction_time.as_millis(),
272 self.total_time.as_millis(),
273 inference_time.as_millis(),
274 (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
275 )
276 }
277}