1mod headless;
2
3use anyhow::{Result, anyhow};
4use clap::{Args, Parser, Subcommand};
5use cloud_llm_client::predict_edits_v3;
6use edit_prediction_context::EditPredictionExcerptOptions;
7use futures::channel::mpsc;
8use futures::{FutureExt as _, StreamExt as _};
9use gpui::{AppContext, Application, AsyncApp};
10use gpui::{Entity, Task};
11use language::Bias;
12use language::Buffer;
13use language::Point;
14use language_model::LlmApiToken;
15use project::{Project, ProjectPath, Worktree};
16use release_channel::AppVersion;
17use reqwest_client::ReqwestClient;
18use serde_json::json;
19use std::path::{Path, PathBuf};
20use std::process::exit;
21use std::str::FromStr;
22use std::sync::Arc;
23use std::time::Duration;
24use util::paths::PathStyle;
25use util::rel_path::RelPath;
26use zeta::{PerformPredictEditsParams, Zeta};
27
28use crate::headless::ZetaCliAppState;
29
// Top-level command-line arguments for the `zeta` binary.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments into help text, which would change the `--help` output.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand selected on the command line; dispatched on in `main`.
    #[command(subcommand)]
    command: Commands,
}
36
// Subcommands understood by the CLI. `main` matches on these.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments into help text, which would change the `--help` output.
#[derive(Subcommand, Debug)]
enum Commands {
    // Runs `get_context` without zeta2 options; `main` prints the resulting
    // request body as pretty-printed JSON.
    Context(ContextArgs),
    // Runs `get_context` with zeta2 options; `main` prints the string it returns.
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Performs an edit prediction request and prints the returned excerpt.
    // The request body is read from `predict_edits_body` when given; otherwise
    // it is gathered from `context_args`. `main` reports an error when neither
    // is available.
    Predict {
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
}
53
// Arguments describing where to gather edit-prediction context from.
// Consumed by `get_context`.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments into help text, which would change the `--help` output.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Directory to open as the worktree; canonicalized before use.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor location, parsed by `CursorPosition::from_str`
    // (`file.rs:line:column`, with the path relative to the worktree).
    #[arg(long)]
    cursor: CursorPosition,
    // When set, the buffer is registered with language servers and
    // `get_context` waits for them before gathering context.
    #[arg(long)]
    use_language_server: bool,
    // Optional source of the events text (`-` selects stdin). An empty string
    // is used when omitted. Only read on the non-zeta2 path of `get_context`.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
66
// Options for the zeta2 context path. All fields except `output_format` are
// copied into `zeta2::ZetaOptions` by `get_context`.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments into help text, which would change the `--help` output.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Forwarded as `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Forwarded as `EditPredictionExcerptOptions::max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Forwarded as `EditPredictionExcerptOptions::min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Forwarded as `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Forwarded as `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Converted into `predict_edits_v3::PromptFormat` and forwarded.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Selects what `get_context` returns: the prompt, the request JSON, or both.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Forwarded as `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
}
86
// CLI-facing prompt format selector. Each variant is mapped onto the variant
// of the same name in `predict_edits_v3::PromptFormat`.
//
// NOTE: plain `//` comments are used on purpose. clap's `ValueEnum` derive
// turns `///` doc comments into help text for the possible values.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    // Default when `--prompt-format` is not given.
    #[default]
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
}
94
95impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
96 fn into(self) -> predict_edits_v3::PromptFormat {
97 match self {
98 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
99 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
100 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
101 }
102 }
103}
104
// Selects what the zeta2 path of `get_context` returns.
//
// NOTE: plain `//` comments are used on purpose. clap's `ValueEnum` derive
// turns `///` doc comments into help text for the possible values.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    // The prompt string only (default).
    #[default]
    Prompt,
    // The request, pretty-printed as JSON.
    Request,
    // A JSON object with a "request" key and a "prompt" key.
    Both,
}
112
/// An input source given on the command line: a file path or standard input.
///
/// Parsed from a string by the `FromStr` impl, where `-` selects [`FileOrStdin::Stdin`].
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read the contents of the file at this path.
    File(PathBuf),
    /// Read everything from standard input.
    Stdin,
}
118
119impl FileOrStdin {
120 async fn read_to_string(&self) -> Result<String, std::io::Error> {
121 match self {
122 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
123 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
124 }
125 }
126}
127
128impl FromStr for FileOrStdin {
129 type Err = <PathBuf as FromStr>::Err;
130
131 fn from_str(s: &str) -> Result<Self, Self::Err> {
132 match s {
133 "-" => Ok(Self::Stdin),
134 _ => Ok(Self::File(PathBuf::from_str(s)?)),
135 }
136 }
137}
138
/// A cursor location inside a worktree, as parsed from the `--cursor` argument.
#[derive(Debug, Clone)]
struct CursorPosition {
    /// Path of the file containing the cursor (a relative path; `get_context`
    /// opens it inside the worktree).
    path: Arc<RelPath>,
    /// Zero-based row/column of the cursor (the parser converts from the
    /// one-based values given on the command line).
    point: Point,
}
144
145impl FromStr for CursorPosition {
146 type Err = anyhow::Error;
147
148 fn from_str(s: &str) -> Result<Self> {
149 let parts: Vec<&str> = s.split(':').collect();
150 if parts.len() != 3 {
151 return Err(anyhow!(
152 "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
153 s
154 ));
155 }
156
157 let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
158 let line: u32 = parts[1]
159 .parse()
160 .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
161 let column: u32 = parts[2]
162 .parse()
163 .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
164
165 // Convert from 1-based to 0-based indexing
166 let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
167
168 Ok(CursorPosition { path, point })
169 }
170}
171
/// Result of `get_context`; the variant depends on whether zeta2 options were supplied.
enum GetContextOutput {
    /// Output of `zeta::gather_context` (returned when no zeta2 options were given).
    Zeta1(zeta::GatherContextOutput),
    /// Already-formatted text for the zeta2 path: the prompt, the request JSON,
    /// or both, depending on the selected `OutputFormat`.
    Zeta2(String),
}
176
177async fn get_context(
178 zeta2_args: Option<Zeta2Args>,
179 args: ContextArgs,
180 app_state: &Arc<ZetaCliAppState>,
181 cx: &mut AsyncApp,
182) -> Result<GetContextOutput> {
183 let ContextArgs {
184 worktree: worktree_path,
185 cursor,
186 use_language_server,
187 events,
188 } = args;
189
190 let worktree_path = worktree_path.canonicalize()?;
191
192 let project = cx.update(|cx| {
193 Project::local(
194 app_state.client.clone(),
195 app_state.node_runtime.clone(),
196 app_state.user_store.clone(),
197 app_state.languages.clone(),
198 app_state.fs.clone(),
199 None,
200 cx,
201 )
202 })?;
203
204 let worktree = project
205 .update(cx, |project, cx| {
206 project.create_worktree(&worktree_path, true, cx)
207 })?
208 .await?;
209
210 let (_lsp_open_handle, buffer) = if use_language_server {
211 let (lsp_open_handle, buffer) =
212 open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
213 (Some(lsp_open_handle), buffer)
214 } else {
215 let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
216 (None, buffer)
217 };
218
219 let full_path_str = worktree
220 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
221 .display(PathStyle::local())
222 .to_string();
223
224 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
225 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
226 if clipped_cursor != cursor.point {
227 let max_row = snapshot.max_point().row;
228 if cursor.point.row < max_row {
229 return Err(anyhow!(
230 "Cursor position {:?} is out of bounds (line length is {})",
231 cursor.point,
232 snapshot.line_len(cursor.point.row)
233 ));
234 } else {
235 return Err(anyhow!(
236 "Cursor position {:?} is out of bounds (max row is {})",
237 cursor.point,
238 max_row
239 ));
240 }
241 }
242
243 let events = match events {
244 Some(events) => events.read_to_string().await?,
245 None => String::new(),
246 };
247
248 if let Some(zeta2_args) = zeta2_args {
249 // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
250 // the whole worktree.
251 worktree
252 .read_with(cx, |worktree, _cx| {
253 worktree.as_local().unwrap().scan_complete()
254 })?
255 .await;
256 let output = cx
257 .update(|cx| {
258 let zeta = cx.new(|cx| {
259 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
260 });
261 let indexing_done_task = zeta.update(cx, |zeta, cx| {
262 zeta.set_options(zeta2::ZetaOptions {
263 excerpt: EditPredictionExcerptOptions {
264 max_bytes: zeta2_args.max_excerpt_bytes,
265 min_bytes: zeta2_args.min_excerpt_bytes,
266 target_before_cursor_over_total_bytes: zeta2_args
267 .target_before_cursor_over_total_bytes,
268 },
269 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
270 max_prompt_bytes: zeta2_args.max_prompt_bytes,
271 prompt_format: zeta2_args.prompt_format.into(),
272 file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
273 });
274 zeta.register_buffer(&buffer, &project, cx);
275 zeta.wait_for_initial_indexing(&project, cx)
276 });
277 cx.spawn(async move |cx| {
278 indexing_done_task.await?;
279 let request = zeta
280 .update(cx, |zeta, cx| {
281 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
282 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
283 })?
284 .await?;
285
286 let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
287 let prompt_string = planned_prompt.to_prompt_string()?.0;
288 match zeta2_args.output_format {
289 OutputFormat::Prompt => anyhow::Ok(prompt_string),
290 OutputFormat::Request => {
291 anyhow::Ok(serde_json::to_string_pretty(&request)?)
292 }
293 OutputFormat::Both => anyhow::Ok(serde_json::to_string_pretty(&json!({
294 "request": request,
295 "prompt": prompt_string,
296 }))?),
297 }
298 })
299 })?
300 .await?;
301 Ok(GetContextOutput::Zeta2(output))
302 } else {
303 let prompt_for_events = move || (events, 0);
304 Ok(GetContextOutput::Zeta1(
305 cx.update(|cx| {
306 zeta::gather_context(
307 full_path_str,
308 &snapshot,
309 clipped_cursor,
310 prompt_for_events,
311 cx,
312 )
313 })?
314 .await?,
315 ))
316 }
317}
318
319pub async fn open_buffer(
320 project: &Entity<Project>,
321 worktree: &Entity<Worktree>,
322 path: &RelPath,
323 cx: &mut AsyncApp,
324) -> Result<Entity<Buffer>> {
325 let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
326 worktree_id: worktree.id(),
327 path: path.into(),
328 })?;
329
330 project
331 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
332 .await
333}
334
335pub async fn open_buffer_with_language_server(
336 project: &Entity<Project>,
337 worktree: &Entity<Worktree>,
338 path: &RelPath,
339 cx: &mut AsyncApp,
340) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
341 let buffer = open_buffer(project, worktree, path, cx).await?;
342
343 let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
344 (
345 project.register_buffer_with_language_servers(&buffer, cx),
346 project.path_style(cx),
347 )
348 })?;
349
350 let log_prefix = path.display(path_style);
351 wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
352
353 Ok((lsp_open_handle, buffer))
354}
355
// TODO: Dedupe with similar function in crates/eval/src/instance.rs
/// Returns a task that resolves once the project reports
/// `DiskBasedDiagnosticsFinished`, or fails after a five-minute timeout.
///
/// Progress lines are printed to stdout, each prefixed with `log_prefix`.
///
/// # Panics
///
/// Panics if the project entity cannot be read or updated when this function
/// is called (the `unwrap` calls below).
pub fn wait_for_lang_server(
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    log_prefix: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    println!("{}⏵ Waiting for language server", log_prefix);

    // Signals "diagnostics finished" from the project subscription to the waiting task.
    let (mut tx, mut rx) = mpsc::channel(1);

    let lsp_store = project
        .read_with(cx, |project, _| project.lsp_store())
        .unwrap();

    // Whether at least one language server is already attached to this buffer.
    // A failed update is treated as "no language server".
    let has_lang_server = buffer
        .update(cx, |buffer, cx| {
            lsp_store.update(cx, |lsp_store, cx| {
                lsp_store
                    .language_servers_for_local_buffer(buffer, cx)
                    .next()
                    .is_some()
            })
        })
        .unwrap_or(false);

    // NOTE(review): the save is presumably what prompts the language server to
    // run disk-based diagnostics — confirm. The save task is detached, so its
    // result is not observed here.
    if has_lang_server {
        project
            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
            .unwrap()
            .detach();
    }

    let subscriptions = [
        // Echo language-server work-progress messages to stdout.
        cx.subscribe(&lsp_store, {
            let log_prefix = log_prefix.clone();
            move |_, event, _| {
                if let project::LspStoreEvent::LanguageServerUpdate {
                    message:
                        client::proto::update_language_server::Variant::WorkProgress(
                            client::proto::LspWorkProgress {
                                message: Some(message),
                                ..
                            },
                        ),
                    ..
                } = event
                {
                    println!("{}⟲ {message}", log_prefix)
                }
            }
        }),
        cx.subscribe(project, {
            let buffer = buffer.clone();
            move |project, event, cx| match event {
                // A server that shows up after the check above: save the buffer now,
                // mirroring the `has_lang_server` branch.
                project::Event::LanguageServerAdded(_, _, _) => {
                    let buffer = buffer.clone();
                    project
                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
                        .detach();
                }
                // Wake the waiting task. A failed send (e.g. channel already full)
                // is ignored.
                project::Event::DiskBasedDiagnosticsFinished { .. } => {
                    tx.try_send(()).ok();
                }
                _ => {}
            }
        }),
    ];

    cx.spawn(async move |cx| {
        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
        // Whichever completes first wins: the finished signal or the timeout.
        let result = futures::select! {
            _ = rx.next() => {
                println!("{}⚑ Language server idle", log_prefix);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                anyhow::bail!("LSP wait timed out after 5 minutes");
            }
        };
        // The subscriptions are moved into this task and dropped only after the wait ends.
        drop(subscriptions);
        result
    })
}
440
441fn main() {
442 zlog::init();
443 zlog::init_output_stderr();
444 let args = ZetaCliArgs::parse();
445 let http_client = Arc::new(ReqwestClient::new());
446 let app = Application::headless().with_http_client(http_client);
447
448 app.run(move |cx| {
449 let app_state = Arc::new(headless::init(cx));
450 cx.spawn(async move |cx| {
451 let result = match args.command {
452 Commands::Zeta2Context {
453 zeta2_args,
454 context_args,
455 } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
456 Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
457 Ok(GetContextOutput::Zeta2(output)) => Ok(output),
458 Err(err) => Err(err),
459 },
460 Commands::Context(context_args) => {
461 match get_context(None, context_args, &app_state, cx).await {
462 Ok(GetContextOutput::Zeta1(output)) => {
463 Ok(serde_json::to_string_pretty(&output.body).unwrap())
464 }
465 Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
466 Err(err) => Err(err),
467 }
468 }
469 Commands::Predict {
470 predict_edits_body,
471 context_args,
472 } => {
473 cx.spawn(async move |cx| {
474 let app_version = cx.update(|cx| AppVersion::global(cx))?;
475 app_state.client.sign_in(true, cx).await?;
476 let llm_token = LlmApiToken::default();
477 llm_token.refresh(&app_state.client).await?;
478
479 let predict_edits_body =
480 if let Some(predict_edits_body) = predict_edits_body {
481 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
482 } else if let Some(context_args) = context_args {
483 match get_context(None, context_args, &app_state, cx).await? {
484 GetContextOutput::Zeta1(output) => output.body,
485 GetContextOutput::Zeta2 { .. } => unreachable!(),
486 }
487 } else {
488 return Err(anyhow!(
489 "Expected either --predict-edits-body-file \
490 or the required args of the `context` command."
491 ));
492 };
493
494 let (response, _usage) =
495 Zeta::perform_predict_edits(PerformPredictEditsParams {
496 client: app_state.client.clone(),
497 llm_token,
498 app_version,
499 body: predict_edits_body,
500 })
501 .await?;
502
503 Ok(response.output_excerpt)
504 })
505 .await
506 }
507 };
508 match result {
509 Ok(output) => {
510 println!("{}", output);
511 let _ = cx.update(|cx| cx.quit());
512 }
513 Err(e) => {
514 eprintln!("Failed: {:?}", e);
515 exit(1);
516 }
517 }
518 })
519 .detach();
520 });
521}