main.rs

  1mod headless;
  2
  3use anyhow::{Result, anyhow};
  4use clap::{Args, Parser, Subcommand};
  5use cloud_llm_client::predict_edits_v3;
  6use edit_prediction_context::EditPredictionExcerptOptions;
  7use futures::channel::mpsc;
  8use futures::{FutureExt as _, StreamExt as _};
  9use gpui::{AppContext, Application, AsyncApp};
 10use gpui::{Entity, Task};
 11use language::Bias;
 12use language::Buffer;
 13use language::Point;
 14use language_model::LlmApiToken;
 15use project::{Project, ProjectPath, Worktree};
 16use release_channel::AppVersion;
 17use reqwest_client::ReqwestClient;
 18use std::path::{Path, PathBuf};
 19use std::process::exit;
 20use std::str::FromStr;
 21use std::sync::Arc;
 22use std::time::Duration;
 23use util::paths::PathStyle;
 24use util::rel_path::RelPath;
 25use zeta::{PerformPredictEditsParams, Zeta};
 26
 27use crate::headless::ZetaCliAppState;
 28
// Top-level command-line arguments of the `zeta` binary, parsed by clap in `main`.
// (Plain `//` comments: on a clap-derived type, `///` doc comments would become `--help` text.)
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand to run; see `Commands`.
    #[command(subcommand)]
    command: Commands,
}
 35
 36#[derive(Subcommand, Debug)]
 37enum Commands {
 38    Context(ContextArgs),
 39    Zeta2Context {
 40        #[clap(flatten)]
 41        zeta2_args: Zeta2Args,
 42        #[clap(flatten)]
 43        context_args: ContextArgs,
 44    },
 45    Predict {
 46        #[arg(long)]
 47        predict_edits_body: Option<FileOrStdin>,
 48        #[clap(flatten)]
 49        context_args: Option<ContextArgs>,
 50    },
 51}
 52
 53#[derive(Debug, Args)]
 54#[group(requires = "worktree")]
 55struct ContextArgs {
 56    #[arg(long)]
 57    worktree: PathBuf,
 58    #[arg(long)]
 59    cursor: CursorPosition,
 60    #[arg(long)]
 61    use_language_server: bool,
 62    #[arg(long)]
 63    events: Option<FileOrStdin>,
 64}
 65
// Options for the `zeta2-context` subcommand. `get_context` copies each numeric option into
// `zeta2::ZetaOptions`; the excerpt-related values go into its `EditPredictionExcerptOptions`.
// (Plain `//` comments: on a clap-derived type, `///` doc comments would become `--help` text.)
#[derive(Debug, Args)]
struct Zeta2Args {
    // -> `ZetaOptions::max_prompt_bytes`
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // -> `EditPredictionExcerptOptions::max_bytes`
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // -> `EditPredictionExcerptOptions::min_bytes`
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // -> `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // -> `ZetaOptions::max_diagnostic_bytes`
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // -> `ZetaOptions::prompt_format`, converted to `predict_edits_v3::PromptFormat`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Selects whether the rendered prompt text or the request JSON is printed.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
}
 83
// Prompt layout selectable on the command line via `--prompt-format`; converted into
// `predict_edits_v3::PromptFormat` when building `ZetaOptions`. `MarkedExcerpt` is the default.
// (Plain `//` comments: on a `ValueEnum`, `///` doc comments would become `--help` text.)
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    #[default]
    MarkedExcerpt,
    LabeledSections,
}
 90
 91impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
 92    fn into(self) -> predict_edits_v3::PromptFormat {
 93        match self {
 94            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
 95            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
 96        }
 97    }
 98}
 99
// What `zeta2-context` prints (`--output-format`): `Prompt` renders the planned prompt text,
// `Request` prints the cloud request as pretty-printed JSON. `Prompt` is the default.
// (Plain `//` comments: on a `ValueEnum`, `///` doc comments would become `--help` text.)
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
}
106
/// An input source named on the command line: a file path, or `-` for standard input.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
112
113impl FileOrStdin {
114    async fn read_to_string(&self) -> Result<String, std::io::Error> {
115        match self {
116            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
117            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
118        }
119    }
120}
121
122impl FromStr for FileOrStdin {
123    type Err = <PathBuf as FromStr>::Err;
124
125    fn from_str(s: &str) -> Result<Self, Self::Err> {
126        match s {
127            "-" => Ok(Self::Stdin),
128            _ => Ok(Self::File(PathBuf::from_str(s)?)),
129        }
130    }
131}
132
/// A cursor location parsed from `path:line:column`.
///
/// `path` is relative to the worktree root. `point` holds the 0-based row/column converted from
/// the 1-based command-line input.
#[derive(Debug, Clone)]
struct CursorPosition {
    path: Arc<RelPath>,
    point: Point,
}
138
139impl FromStr for CursorPosition {
140    type Err = anyhow::Error;
141
142    fn from_str(s: &str) -> Result<Self> {
143        let parts: Vec<&str> = s.split(':').collect();
144        if parts.len() != 3 {
145            return Err(anyhow!(
146                "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
147                s
148            ));
149        }
150
151        let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
152        let line: u32 = parts[1]
153            .parse()
154            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
155        let column: u32 = parts[2]
156            .parse()
157            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
158
159        // Convert from 1-based to 0-based indexing
160        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
161
162        Ok(CursorPosition { path, point })
163    }
164}
165
/// Result of `get_context`.
///
/// `Zeta1` holds the gathered Zeta 1 context (its `body` is what `predict` sends). `Zeta2` holds
/// the rendered Zeta 2 prompt or request JSON, depending on the requested output format.
enum GetContextOutput {
    Zeta1(zeta::GatherContextOutput),
    Zeta2(String),
}
170
171async fn get_context(
172    zeta2_args: Option<Zeta2Args>,
173    args: ContextArgs,
174    app_state: &Arc<ZetaCliAppState>,
175    cx: &mut AsyncApp,
176) -> Result<GetContextOutput> {
177    let ContextArgs {
178        worktree: worktree_path,
179        cursor,
180        use_language_server,
181        events,
182    } = args;
183
184    let worktree_path = worktree_path.canonicalize()?;
185
186    let project = cx.update(|cx| {
187        Project::local(
188            app_state.client.clone(),
189            app_state.node_runtime.clone(),
190            app_state.user_store.clone(),
191            app_state.languages.clone(),
192            app_state.fs.clone(),
193            None,
194            cx,
195        )
196    })?;
197
198    let worktree = project
199        .update(cx, |project, cx| {
200            project.create_worktree(&worktree_path, true, cx)
201        })?
202        .await?;
203
204    let (_lsp_open_handle, buffer) = if use_language_server {
205        let (lsp_open_handle, buffer) =
206            open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
207        (Some(lsp_open_handle), buffer)
208    } else {
209        let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
210        (None, buffer)
211    };
212
213    let full_path_str = worktree
214        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
215        .display(PathStyle::local())
216        .to_string();
217
218    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
219    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
220    if clipped_cursor != cursor.point {
221        let max_row = snapshot.max_point().row;
222        if cursor.point.row < max_row {
223            return Err(anyhow!(
224                "Cursor position {:?} is out of bounds (line length is {})",
225                cursor.point,
226                snapshot.line_len(cursor.point.row)
227            ));
228        } else {
229            return Err(anyhow!(
230                "Cursor position {:?} is out of bounds (max row is {})",
231                cursor.point,
232                max_row
233            ));
234        }
235    }
236
237    let events = match events {
238        Some(events) => events.read_to_string().await?,
239        None => String::new(),
240    };
241
242    if let Some(zeta2_args) = zeta2_args {
243        Ok(GetContextOutput::Zeta2(
244            cx.update(|cx| {
245                let zeta = cx.new(|cx| {
246                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
247                });
248                zeta.update(cx, |zeta, cx| {
249                    zeta.register_buffer(&buffer, &project, cx);
250                    zeta.set_options(zeta2::ZetaOptions {
251                        excerpt: EditPredictionExcerptOptions {
252                            max_bytes: zeta2_args.max_excerpt_bytes,
253                            min_bytes: zeta2_args.min_excerpt_bytes,
254                            target_before_cursor_over_total_bytes: zeta2_args
255                                .target_before_cursor_over_total_bytes,
256                        },
257                        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
258                        max_prompt_bytes: zeta2_args.max_prompt_bytes,
259                        prompt_format: zeta2_args.prompt_format.into(),
260                    })
261                });
262                // TODO: Actually wait for indexing.
263                let timer = cx.background_executor().timer(Duration::from_secs(5));
264                cx.spawn(async move |cx| {
265                    timer.await;
266                    let request = zeta
267                        .update(cx, |zeta, cx| {
268                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
269                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
270                        })?
271                        .await?;
272                    match zeta2_args.output_format {
273                        OutputFormat::Prompt => {
274                            let planned_prompt =
275                                cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
276                            // TODO: Output the section label ranges
277                            anyhow::Ok(planned_prompt.to_prompt_string()?.0)
278                        }
279                        OutputFormat::Request => {
280                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
281                        }
282                    }
283                })
284            })?
285            .await?,
286        ))
287    } else {
288        let prompt_for_events = move || (events, 0);
289        Ok(GetContextOutput::Zeta1(
290            cx.update(|cx| {
291                zeta::gather_context(
292                    full_path_str,
293                    &snapshot,
294                    clipped_cursor,
295                    prompt_for_events,
296                    cx,
297                )
298            })?
299            .await?,
300        ))
301    }
302}
303
304pub async fn open_buffer(
305    project: &Entity<Project>,
306    worktree: &Entity<Worktree>,
307    path: &RelPath,
308    cx: &mut AsyncApp,
309) -> Result<Entity<Buffer>> {
310    let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
311        worktree_id: worktree.id(),
312        path: path.into(),
313    })?;
314
315    project
316        .update(cx, |project, cx| project.open_buffer(project_path, cx))?
317        .await
318}
319
320pub async fn open_buffer_with_language_server(
321    project: &Entity<Project>,
322    worktree: &Entity<Worktree>,
323    path: &RelPath,
324    cx: &mut AsyncApp,
325) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
326    let buffer = open_buffer(project, worktree, path, cx).await?;
327
328    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
329        (
330            project.register_buffer_with_language_servers(&buffer, cx),
331            project.path_style(cx),
332        )
333    })?;
334
335    let log_prefix = path.display(path_style);
336    wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
337
338    Ok((lsp_open_handle, buffer))
339}
340
341// TODO: Dedupe with similar function in crates/eval/src/instance.rs
342pub fn wait_for_lang_server(
343    project: &Entity<Project>,
344    buffer: &Entity<Buffer>,
345    log_prefix: String,
346    cx: &mut AsyncApp,
347) -> Task<Result<()>> {
348    println!("{}⏵ Waiting for language server", log_prefix);
349
350    let (mut tx, mut rx) = mpsc::channel(1);
351
352    let lsp_store = project
353        .read_with(cx, |project, _| project.lsp_store())
354        .unwrap();
355
356    let has_lang_server = buffer
357        .update(cx, |buffer, cx| {
358            lsp_store.update(cx, |lsp_store, cx| {
359                lsp_store
360                    .language_servers_for_local_buffer(buffer, cx)
361                    .next()
362                    .is_some()
363            })
364        })
365        .unwrap_or(false);
366
367    if has_lang_server {
368        project
369            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
370            .unwrap()
371            .detach();
372    }
373
374    let subscriptions = [
375        cx.subscribe(&lsp_store, {
376            let log_prefix = log_prefix.clone();
377            move |_, event, _| {
378                if let project::LspStoreEvent::LanguageServerUpdate {
379                    message:
380                        client::proto::update_language_server::Variant::WorkProgress(
381                            client::proto::LspWorkProgress {
382                                message: Some(message),
383                                ..
384                            },
385                        ),
386                    ..
387                } = event
388                {
389                    println!("{}{message}", log_prefix)
390                }
391            }
392        }),
393        cx.subscribe(project, {
394            let buffer = buffer.clone();
395            move |project, event, cx| match event {
396                project::Event::LanguageServerAdded(_, _, _) => {
397                    let buffer = buffer.clone();
398                    project
399                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
400                        .detach();
401                }
402                project::Event::DiskBasedDiagnosticsFinished { .. } => {
403                    tx.try_send(()).ok();
404                }
405                _ => {}
406            }
407        }),
408    ];
409
410    cx.spawn(async move |cx| {
411        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
412        let result = futures::select! {
413            _ = rx.next() => {
414                println!("{}⚑ Language server idle", log_prefix);
415                anyhow::Ok(())
416            },
417            _ = timeout.fuse() => {
418                anyhow::bail!("LSP wait timed out after 5 minutes");
419            }
420        };
421        drop(subscriptions);
422        result
423    })
424}
425
426fn main() {
427    let args = ZetaCliArgs::parse();
428    let http_client = Arc::new(ReqwestClient::new());
429    let app = Application::headless().with_http_client(http_client);
430
431    app.run(move |cx| {
432        let app_state = Arc::new(headless::init(cx));
433        let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
434        cx.spawn(async move |cx| {
435            let result = match args.command {
436                Commands::Zeta2Context {
437                    zeta2_args,
438                    context_args,
439                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
440                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
441                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
442                    Err(err) => Err(err),
443                },
444                Commands::Context(context_args) => {
445                    match get_context(None, context_args, &app_state, cx).await {
446                        Ok(GetContextOutput::Zeta1(output)) => {
447                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
448                        }
449                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
450                        Err(err) => Err(err),
451                    }
452                }
453                Commands::Predict {
454                    predict_edits_body,
455                    context_args,
456                } => {
457                    cx.spawn(async move |cx| {
458                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
459                        app_state.client.sign_in(true, cx).await?;
460                        let llm_token = LlmApiToken::default();
461                        llm_token.refresh(&app_state.client).await?;
462
463                        let predict_edits_body =
464                            if let Some(predict_edits_body) = predict_edits_body {
465                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
466                            } else if let Some(context_args) = context_args {
467                                match get_context(None, context_args, &app_state, cx).await? {
468                                    GetContextOutput::Zeta1(output) => output.body,
469                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
470                                }
471                            } else {
472                                return Err(anyhow!(
473                                    "Expected either --predict-edits-body-file \
474                                    or the required args of the `context` command."
475                                ));
476                            };
477
478                        let (response, _usage) =
479                            Zeta::perform_predict_edits(PerformPredictEditsParams {
480                                client: app_state.client.clone(),
481                                llm_token,
482                                app_version,
483                                body: predict_edits_body,
484                            })
485                            .await?;
486
487                        Ok(response.output_excerpt)
488                    })
489                    .await
490                }
491            };
492            match result {
493                Ok(output) => {
494                    println!("{}", output);
495                    // TODO: Remove this once the 5 second delay is properly replaced.
496                    if is_zeta2_context_command {
497                        eprintln!("Note that zeta_cli doesn't yet wait for indexing, instead waits 5 seconds.");
498                    }
499                    let _ = cx.update(|cx| cx.quit());
500                }
501                Err(e) => {
502                    eprintln!("Failed: {:?}", e);
503                    exit(1);
504                }
505            }
506        })
507        .detach();
508    });
509}