1use crate::example::{ActualExcerpt, NamedExample};
2
3use crate::headless::ZetaCliAppState;
4use ::serde::Serialize;
5use ::util::paths::PathStyle;
6use anyhow::{Context as _, Result, anyhow};
7use clap::Args;
8use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
9use futures::StreamExt as _;
10use gpui::AsyncApp;
11use language_model::LanguageModelRegistry;
12use project::{Project, ProjectPath};
13use serde::Deserialize;
14use std::cell::Cell;
15use std::io::Write;
16use std::path::PathBuf;
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19use util::rel_path::RelPath;
20
// Command-line arguments for the `predict` subcommand.
// NOTE: comments here are deliberately `//` rather than `///` — clap derives
// user-facing help text from doc comments, and we don't want to change the
// CLI's help output.
#[derive(Debug, Args)]
pub struct PredictArguments {
    // Path to the example file describing the worktree state, edit history,
    // and cursor position to predict from.
    example_path: PathBuf,
    // Output format for the prediction result; defaults to Markdown.
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
}
27
// Supported output formats for prediction results (see
// `PredictionDetails::write`). Plain `//` comments are used so clap's derived
// help text is unchanged.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    // Pretty-printed JSON of the full `PredictionDetails`.
    Json,
    // Human-readable Markdown report (the default).
    Md,
    // Only the unified diff of the predicted edit.
    Diff,
}
34pub async fn run_zeta2_predict(
35 args: PredictArguments,
36 app_state: &Arc<ZetaCliAppState>,
37 cx: &mut AsyncApp,
38) {
39 let example = NamedExample::load(args.example_path).unwrap();
40 let result = zeta2_predict(example, &app_state, cx).await.unwrap();
41 result.write(args.format, std::io::stdout()).unwrap();
42}
43
thread_local! {
    // Tracks whether this thread has already authenticated the language-model
    // provider and signed in the client, so repeated predictions on the same
    // thread can skip the handshake in `zeta2_predict`.
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}
47
48pub async fn zeta2_predict(
49 example: NamedExample,
50 app_state: &Arc<ZetaCliAppState>,
51 cx: &mut AsyncApp,
52) -> Result<PredictionDetails> {
53 let worktree_path = example.setup_worktree().await?;
54
55 if !AUTHENTICATED.get() {
56 AUTHENTICATED.set(true);
57
58 cx.update(|cx| {
59 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
60 registry
61 .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
62 .unwrap()
63 .authenticate(cx)
64 })
65 })?
66 .await?;
67
68 app_state
69 .client
70 .sign_in_with_optional_connect(true, cx)
71 .await?;
72 }
73
74 let project = cx.update(|cx| {
75 Project::local(
76 app_state.client.clone(),
77 app_state.node_runtime.clone(),
78 app_state.user_store.clone(),
79 app_state.languages.clone(),
80 app_state.fs.clone(),
81 None,
82 cx,
83 )
84 })?;
85
86 let worktree = project
87 .update(cx, |project, cx| {
88 project.create_worktree(&worktree_path, true, cx)
89 })?
90 .await?;
91 worktree
92 .read_with(cx, |worktree, _cx| {
93 worktree.as_local().unwrap().scan_complete()
94 })?
95 .await;
96
97 let _edited_buffers = example.apply_edit_history(&project, cx).await?;
98
99 let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc();
100
101 let cursor_buffer = project
102 .update(cx, |project, cx| {
103 project.open_buffer(
104 ProjectPath {
105 worktree_id: worktree.read(cx).id(),
106 path: cursor_path,
107 },
108 cx,
109 )
110 })?
111 .await?;
112
113 let cursor_offset_within_excerpt = example
114 .example
115 .cursor_position
116 .find(CURSOR_MARKER)
117 .ok_or_else(|| anyhow!("missing cursor marker"))?;
118 let mut cursor_excerpt = example.example.cursor_position.clone();
119 cursor_excerpt.replace_range(
120 cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
121 "",
122 );
123 let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
124 let text = buffer.text();
125
126 let mut matches = text.match_indices(&cursor_excerpt);
127 let Some((excerpt_offset, _)) = matches.next() else {
128 anyhow::bail!(
129 "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
130 );
131 };
132 assert!(matches.next().is_none());
133
134 Ok(excerpt_offset)
135 })??;
136
137 let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
138 let cursor_anchor =
139 cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
140
141 let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
142
143 let refresh_task = zeta.update(cx, |zeta, cx| {
144 zeta.register_buffer(&cursor_buffer, &project, cx);
145 zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
146 })?;
147
148 let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
149 let mut context_retrieval_started_at = None;
150 let mut context_retrieval_finished_at = None;
151 let mut search_queries_generated_at = None;
152 let mut search_queries_executed_at = None;
153 let mut prediction_started_at = None;
154 let mut prediction_finished_at = None;
155 let mut excerpts_text = String::new();
156 let mut prediction_task = None;
157 let mut result = PredictionDetails::default();
158 while let Some(event) = debug_rx.next().await {
159 match event {
160 zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
161 context_retrieval_started_at = Some(info.timestamp);
162 }
163 zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
164 search_queries_generated_at = Some(info.timestamp);
165 }
166 zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
167 search_queries_executed_at = Some(info.timestamp);
168 }
169 zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
170 context_retrieval_finished_at = Some(info.timestamp);
171
172 prediction_task = Some(zeta.update(cx, |zeta, cx| {
173 zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
174 })?);
175 }
176 zeta2::ZetaDebugInfo::EditPredicted(request) => {
177 prediction_started_at = Some(Instant::now());
178 request.response_rx.await?.map_err(|err| anyhow!(err))?;
179 prediction_finished_at = Some(Instant::now());
180
181 for included_file in request.request.included_files {
182 let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
183 result
184 .excerpts
185 .extend(included_file.excerpts.iter().map(|excerpt| ActualExcerpt {
186 path: included_file.path.components().skip(1).collect(),
187 text: String::from(excerpt.text.as_ref()),
188 }));
189 write_codeblock(
190 &included_file.path,
191 included_file.excerpts.iter(),
192 if included_file.path == request.request.excerpt_path {
193 &insertions
194 } else {
195 &[]
196 },
197 included_file.max_row,
198 false,
199 &mut excerpts_text,
200 );
201 }
202 break;
203 }
204 _ => {}
205 }
206 }
207
208 refresh_task.await.context("context retrieval failed")?;
209 let prediction = prediction_task.unwrap().await?;
210
211 result.diff = prediction
212 .map(|prediction| {
213 let old_text = prediction.snapshot.text();
214 let new_text = prediction.buffer.update(cx, |buffer, cx| {
215 buffer.edit(prediction.edits.iter().cloned(), None, cx);
216 buffer.text()
217 })?;
218 anyhow::Ok(language::unified_diff(&old_text, &new_text))
219 })
220 .transpose()?
221 .unwrap_or_default();
222 result.excerpts_text = excerpts_text;
223
224 result.planning_search_time =
225 search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
226 result.running_search_time =
227 search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap();
228 result.filtering_search_time =
229 context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
230 result.prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
231 result.total_time = prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap();
232
233 anyhow::Ok(result)
234}
235
/// The outcome of a single zeta2 prediction run: the predicted edit, the
/// context that was retrieved to make it, and a timing breakdown of each
/// phase.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    /// Unified diff of the predicted edit applied to the cursor buffer
    /// (empty when no prediction was made).
    pub diff: String,
    /// Context excerpts included in the prediction request.
    pub excerpts: Vec<ActualExcerpt>,
    /// Pre-rendered codeblock text of the excerpts, as shown in the Markdown
    /// report.
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    /// Time spent generating search queries for context retrieval.
    pub planning_search_time: Duration,
    /// Time spent filtering search results into the final context.
    pub filtering_search_time: Duration,
    /// Time spent executing the generated search queries.
    pub running_search_time: Duration,
    /// Time spent waiting for the edit-prediction response.
    pub prediction_time: Duration,
    /// Wall-clock time from the start of context retrieval to the end of
    /// prediction.
    pub total_time: Duration,
}
247
248impl PredictionDetails {
249 pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
250 let formatted = match format {
251 PredictionsOutputFormat::Md => self.to_markdown(),
252 PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
253 PredictionsOutputFormat::Diff => self.diff.clone(),
254 };
255
256 Ok(out.write_all(formatted.as_bytes())?)
257 }
258
259 pub fn to_markdown(&self) -> String {
260 let inference_time =
261 self.planning_search_time + self.filtering_search_time + self.prediction_time;
262
263 format!(
264 "## Excerpts\n\n\
265 {}\n\n\
266 ## Prediction\n\n\
267 {}\n\n\
268 ## Time\n\n\
269 Planning searches: {}ms\n\
270 Running searches: {}ms\n\
271 Filtering context results: {}ms\n\
272 Making Prediction: {}ms\n\n\
273 -------------------\n\n\
274 Total: {}ms\n\
275 Inference: {}ms ({:.2}%)\n",
276 self.excerpts_text,
277 self.diff,
278 self.planning_search_time.as_millis(),
279 self.running_search_time.as_millis(),
280 self.filtering_search_time.as_millis(),
281 self.prediction_time.as_millis(),
282 self.total_time.as_millis(),
283 inference_time.as_millis(),
284 (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
285 )
286 }
287}