// main.rs

  1mod headless;
  2
  3use anyhow::{Result, anyhow};
  4use clap::{Args, Parser, Subcommand};
  5use cloud_llm_client::predict_edits_v3;
  6use edit_prediction_context::EditPredictionExcerptOptions;
  7use futures::channel::mpsc;
  8use futures::{FutureExt as _, StreamExt as _};
  9use gpui::{AppContext, Application, AsyncApp};
 10use gpui::{Entity, Task};
 11use language::Bias;
 12use language::Buffer;
 13use language::Point;
 14use language_model::LlmApiToken;
 15use project::{Project, ProjectPath, Worktree};
 16use release_channel::AppVersion;
 17use reqwest_client::ReqwestClient;
 18use serde_json::json;
 19use std::path::{Path, PathBuf};
 20use std::process::exit;
 21use std::str::FromStr;
 22use std::sync::Arc;
 23use std::time::Duration;
 24use util::paths::PathStyle;
 25use util::rel_path::RelPath;
 26use zeta::{PerformPredictEditsParams, Zeta};
 27
 28use crate::headless::ZetaCliAppState;
 29
 30#[derive(Parser, Debug)]
 31#[command(name = "zeta")]
 32struct ZetaCliArgs {
 33    #[command(subcommand)]
 34    command: Commands,
 35}
 36
 37#[derive(Subcommand, Debug)]
 38enum Commands {
 39    Context(ContextArgs),
 40    Zeta2Context {
 41        #[clap(flatten)]
 42        zeta2_args: Zeta2Args,
 43        #[clap(flatten)]
 44        context_args: ContextArgs,
 45    },
 46    Predict {
 47        #[arg(long)]
 48        predict_edits_body: Option<FileOrStdin>,
 49        #[clap(flatten)]
 50        context_args: Option<ContextArgs>,
 51    },
 52}
 53
 54#[derive(Debug, Args)]
 55#[group(requires = "worktree")]
 56struct ContextArgs {
 57    #[arg(long)]
 58    worktree: PathBuf,
 59    #[arg(long)]
 60    cursor: CursorPosition,
 61    #[arg(long)]
 62    use_language_server: bool,
 63    #[arg(long)]
 64    events: Option<FileOrStdin>,
 65}
 66
 67#[derive(Debug, Args)]
 68struct Zeta2Args {
 69    #[arg(long, default_value_t = 8192)]
 70    max_prompt_bytes: usize,
 71    #[arg(long, default_value_t = 2048)]
 72    max_excerpt_bytes: usize,
 73    #[arg(long, default_value_t = 1024)]
 74    min_excerpt_bytes: usize,
 75    #[arg(long, default_value_t = 0.66)]
 76    target_before_cursor_over_total_bytes: f32,
 77    #[arg(long, default_value_t = 1024)]
 78    max_diagnostic_bytes: usize,
 79    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
 80    prompt_format: PromptFormat,
 81    #[arg(long, value_enum, default_value_t = Default::default())]
 82    output_format: OutputFormat,
 83}
 84
 85#[derive(clap::ValueEnum, Default, Debug, Clone)]
 86enum PromptFormat {
 87    #[default]
 88    MarkedExcerpt,
 89    LabeledSections,
 90    OnlySnippets,
 91}
 92
 93impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
 94    fn into(self) -> predict_edits_v3::PromptFormat {
 95        match self {
 96            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
 97            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
 98            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
 99        }
100    }
101}
102
103#[derive(clap::ValueEnum, Default, Debug, Clone)]
104enum OutputFormat {
105    #[default]
106    Prompt,
107    Request,
108    Both,
109}
110
/// An input source given on the command line: either a file path or stdin.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from standard input (spelled `-` on the command line).
    Stdin,
}
116
117impl FileOrStdin {
118    async fn read_to_string(&self) -> Result<String, std::io::Error> {
119        match self {
120            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
121            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
122        }
123    }
124}
125
126impl FromStr for FileOrStdin {
127    type Err = <PathBuf as FromStr>::Err;
128
129    fn from_str(s: &str) -> Result<Self, Self::Err> {
130        match s {
131            "-" => Ok(Self::Stdin),
132            _ => Ok(Self::File(PathBuf::from_str(s)?)),
133        }
134    }
135}
136
137#[derive(Debug, Clone)]
138struct CursorPosition {
139    path: Arc<RelPath>,
140    point: Point,
141}
142
143impl FromStr for CursorPosition {
144    type Err = anyhow::Error;
145
146    fn from_str(s: &str) -> Result<Self> {
147        let parts: Vec<&str> = s.split(':').collect();
148        if parts.len() != 3 {
149            return Err(anyhow!(
150                "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
151                s
152            ));
153        }
154
155        let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
156        let line: u32 = parts[1]
157            .parse()
158            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
159        let column: u32 = parts[2]
160            .parse()
161            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
162
163        // Convert from 1-based to 0-based indexing
164        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
165
166        Ok(CursorPosition { path, point })
167    }
168}
169
170enum GetContextOutput {
171    Zeta1(zeta::GatherContextOutput),
172    Zeta2(String),
173}
174
175async fn get_context(
176    zeta2_args: Option<Zeta2Args>,
177    args: ContextArgs,
178    app_state: &Arc<ZetaCliAppState>,
179    cx: &mut AsyncApp,
180) -> Result<GetContextOutput> {
181    let ContextArgs {
182        worktree: worktree_path,
183        cursor,
184        use_language_server,
185        events,
186    } = args;
187
188    let worktree_path = worktree_path.canonicalize()?;
189
190    let project = cx.update(|cx| {
191        Project::local(
192            app_state.client.clone(),
193            app_state.node_runtime.clone(),
194            app_state.user_store.clone(),
195            app_state.languages.clone(),
196            app_state.fs.clone(),
197            None,
198            cx,
199        )
200    })?;
201
202    let worktree = project
203        .update(cx, |project, cx| {
204            project.create_worktree(&worktree_path, true, cx)
205        })?
206        .await?;
207
208    let (_lsp_open_handle, buffer) = if use_language_server {
209        let (lsp_open_handle, buffer) =
210            open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
211        (Some(lsp_open_handle), buffer)
212    } else {
213        let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
214        (None, buffer)
215    };
216
217    let full_path_str = worktree
218        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
219        .display(PathStyle::local())
220        .to_string();
221
222    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
223    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
224    if clipped_cursor != cursor.point {
225        let max_row = snapshot.max_point().row;
226        if cursor.point.row < max_row {
227            return Err(anyhow!(
228                "Cursor position {:?} is out of bounds (line length is {})",
229                cursor.point,
230                snapshot.line_len(cursor.point.row)
231            ));
232        } else {
233            return Err(anyhow!(
234                "Cursor position {:?} is out of bounds (max row is {})",
235                cursor.point,
236                max_row
237            ));
238        }
239    }
240
241    let events = match events {
242        Some(events) => events.read_to_string().await?,
243        None => String::new(),
244    };
245
246    if let Some(zeta2_args) = zeta2_args {
247        Ok(GetContextOutput::Zeta2(
248            cx.update(|cx| {
249                let zeta = cx.new(|cx| {
250                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
251                });
252                zeta.update(cx, |zeta, cx| {
253                    zeta.register_buffer(&buffer, &project, cx);
254                    zeta.set_options(zeta2::ZetaOptions {
255                        excerpt: EditPredictionExcerptOptions {
256                            max_bytes: zeta2_args.max_excerpt_bytes,
257                            min_bytes: zeta2_args.min_excerpt_bytes,
258                            target_before_cursor_over_total_bytes: zeta2_args
259                                .target_before_cursor_over_total_bytes,
260                        },
261                        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
262                        max_prompt_bytes: zeta2_args.max_prompt_bytes,
263                        prompt_format: zeta2_args.prompt_format.into(),
264                    })
265                });
266                // TODO: Actually wait for indexing.
267                let timer = cx.background_executor().timer(Duration::from_secs(5));
268                cx.spawn(async move |cx| {
269                    timer.await;
270                    let request = zeta
271                        .update(cx, |zeta, cx| {
272                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
273                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
274                        })?
275                        .await?;
276
277                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
278                    let prompt_string = planned_prompt.to_prompt_string()?.0;
279                    match zeta2_args.output_format {
280                        OutputFormat::Prompt => anyhow::Ok(prompt_string),
281                        OutputFormat::Request => {
282                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
283                        }
284                        OutputFormat::Both => anyhow::Ok(serde_json::to_string_pretty(&json!({
285                            "request": request,
286                            "prompt": prompt_string,
287                        }))?),
288                    }
289                })
290            })?
291            .await?,
292        ))
293    } else {
294        let prompt_for_events = move || (events, 0);
295        Ok(GetContextOutput::Zeta1(
296            cx.update(|cx| {
297                zeta::gather_context(
298                    full_path_str,
299                    &snapshot,
300                    clipped_cursor,
301                    prompt_for_events,
302                    cx,
303                )
304            })?
305            .await?,
306        ))
307    }
308}
309
310pub async fn open_buffer(
311    project: &Entity<Project>,
312    worktree: &Entity<Worktree>,
313    path: &RelPath,
314    cx: &mut AsyncApp,
315) -> Result<Entity<Buffer>> {
316    let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
317        worktree_id: worktree.id(),
318        path: path.into(),
319    })?;
320
321    project
322        .update(cx, |project, cx| project.open_buffer(project_path, cx))?
323        .await
324}
325
326pub async fn open_buffer_with_language_server(
327    project: &Entity<Project>,
328    worktree: &Entity<Worktree>,
329    path: &RelPath,
330    cx: &mut AsyncApp,
331) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
332    let buffer = open_buffer(project, worktree, path, cx).await?;
333
334    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
335        (
336            project.register_buffer_with_language_servers(&buffer, cx),
337            project.path_style(cx),
338        )
339    })?;
340
341    let log_prefix = path.display(path_style);
342    wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
343
344    Ok((lsp_open_handle, buffer))
345}
346
347// TODO: Dedupe with similar function in crates/eval/src/instance.rs
348pub fn wait_for_lang_server(
349    project: &Entity<Project>,
350    buffer: &Entity<Buffer>,
351    log_prefix: String,
352    cx: &mut AsyncApp,
353) -> Task<Result<()>> {
354    println!("{}⏵ Waiting for language server", log_prefix);
355
356    let (mut tx, mut rx) = mpsc::channel(1);
357
358    let lsp_store = project
359        .read_with(cx, |project, _| project.lsp_store())
360        .unwrap();
361
362    let has_lang_server = buffer
363        .update(cx, |buffer, cx| {
364            lsp_store.update(cx, |lsp_store, cx| {
365                lsp_store
366                    .language_servers_for_local_buffer(buffer, cx)
367                    .next()
368                    .is_some()
369            })
370        })
371        .unwrap_or(false);
372
373    if has_lang_server {
374        project
375            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
376            .unwrap()
377            .detach();
378    }
379
380    let subscriptions = [
381        cx.subscribe(&lsp_store, {
382            let log_prefix = log_prefix.clone();
383            move |_, event, _| {
384                if let project::LspStoreEvent::LanguageServerUpdate {
385                    message:
386                        client::proto::update_language_server::Variant::WorkProgress(
387                            client::proto::LspWorkProgress {
388                                message: Some(message),
389                                ..
390                            },
391                        ),
392                    ..
393                } = event
394                {
395                    println!("{}{message}", log_prefix)
396                }
397            }
398        }),
399        cx.subscribe(project, {
400            let buffer = buffer.clone();
401            move |project, event, cx| match event {
402                project::Event::LanguageServerAdded(_, _, _) => {
403                    let buffer = buffer.clone();
404                    project
405                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
406                        .detach();
407                }
408                project::Event::DiskBasedDiagnosticsFinished { .. } => {
409                    tx.try_send(()).ok();
410                }
411                _ => {}
412            }
413        }),
414    ];
415
416    cx.spawn(async move |cx| {
417        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
418        let result = futures::select! {
419            _ = rx.next() => {
420                println!("{}⚑ Language server idle", log_prefix);
421                anyhow::Ok(())
422            },
423            _ = timeout.fuse() => {
424                anyhow::bail!("LSP wait timed out after 5 minutes");
425            }
426        };
427        drop(subscriptions);
428        result
429    })
430}
431
432fn main() {
433    let args = ZetaCliArgs::parse();
434    let http_client = Arc::new(ReqwestClient::new());
435    let app = Application::headless().with_http_client(http_client);
436
437    app.run(move |cx| {
438        let app_state = Arc::new(headless::init(cx));
439        let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
440        cx.spawn(async move |cx| {
441            let result = match args.command {
442                Commands::Zeta2Context {
443                    zeta2_args,
444                    context_args,
445                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
446                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
447                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
448                    Err(err) => Err(err),
449                },
450                Commands::Context(context_args) => {
451                    match get_context(None, context_args, &app_state, cx).await {
452                        Ok(GetContextOutput::Zeta1(output)) => {
453                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
454                        }
455                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
456                        Err(err) => Err(err),
457                    }
458                }
459                Commands::Predict {
460                    predict_edits_body,
461                    context_args,
462                } => {
463                    cx.spawn(async move |cx| {
464                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
465                        app_state.client.sign_in(true, cx).await?;
466                        let llm_token = LlmApiToken::default();
467                        llm_token.refresh(&app_state.client).await?;
468
469                        let predict_edits_body =
470                            if let Some(predict_edits_body) = predict_edits_body {
471                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
472                            } else if let Some(context_args) = context_args {
473                                match get_context(None, context_args, &app_state, cx).await? {
474                                    GetContextOutput::Zeta1(output) => output.body,
475                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
476                                }
477                            } else {
478                                return Err(anyhow!(
479                                    "Expected either --predict-edits-body-file \
480                                    or the required args of the `context` command."
481                                ));
482                            };
483
484                        let (response, _usage) =
485                            Zeta::perform_predict_edits(PerformPredictEditsParams {
486                                client: app_state.client.clone(),
487                                llm_token,
488                                app_version,
489                                body: predict_edits_body,
490                            })
491                            .await?;
492
493                        Ok(response.output_excerpt)
494                    })
495                    .await
496                }
497            };
498            match result {
499                Ok(output) => {
500                    println!("{}", output);
501                    // TODO: Remove this once the 5 second delay is properly replaced.
502                    if is_zeta2_context_command {
503                        eprintln!("Note that zeta_cli doesn't yet wait for indexing, instead waits 5 seconds.");
504                    }
505                    let _ = cx.update(|cx| cx.quit());
506                }
507                Err(e) => {
508                    eprintln!("Failed: {:?}", e);
509                    exit(1);
510                }
511            }
512        })
513        .detach();
514    });
515}