1mod headless;
2
3use anyhow::{Result, anyhow};
4use clap::{Args, Parser, Subcommand};
5use cloud_llm_client::predict_edits_v3;
6use edit_prediction_context::EditPredictionExcerptOptions;
7use futures::channel::mpsc;
8use futures::{FutureExt as _, StreamExt as _};
9use gpui::{AppContext, Application, AsyncApp};
10use gpui::{Entity, Task};
11use language::Bias;
12use language::Buffer;
13use language::Point;
14use language_model::LlmApiToken;
15use project::{Project, ProjectPath, Worktree};
16use release_channel::AppVersion;
17use reqwest_client::ReqwestClient;
18use serde_json::json;
19use std::path::{Path, PathBuf};
20use std::process::exit;
21use std::str::FromStr;
22use std::sync::Arc;
23use std::time::Duration;
24use util::paths::PathStyle;
25use util::rel_path::RelPath;
26use zeta::{PerformPredictEditsParams, Zeta};
27
28use crate::headless::ZetaCliAppState;
29
// Top-level command-line interface of the `zeta` binary; parsed once at the start of `main`.
//
// NOTE: the clap-derived types in this file are documented with plain `//` comments on
// purpose — clap turns `///` doc comments into `--help` text, which would change CLI output.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand to execute; dispatched in `main`.
    #[command(subcommand)]
    command: Commands,
}
36
// Subcommands of the `zeta` CLI. Each one prints its result to stdout on success and prints
// `Failed: ...` to stderr (exit code 1) on error — see `main`.
#[derive(Subcommand, Debug)]
enum Commands {
    // Gathers zeta1 edit-prediction context at `--cursor` and prints the resulting request
    // body as pretty-printed JSON.
    Context(ContextArgs),
    // Builds a zeta2 cloud request for `--cursor` and prints the rendered prompt, the request
    // as JSON, or both, depending on `--output-format`.
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Signs in, sends a predict-edits request and prints the returned `output_excerpt`.
    // The request body is read as JSON from `--predict-edits-body` (a file path, or `-` for
    // stdin) when given; otherwise it is gathered from the context args exactly like the
    // `context` subcommand. Supplying neither is an error.
    Predict {
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
}
53
// Arguments that locate the project, file and cursor position to gather context for.
//
// `requires = "worktree"` makes `--worktree` mandatory as soon as any argument of this group
// is given; this matters where the group is optional (`Commands::Predict`).
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the project; canonicalized before a worktree is created for it.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor as `path:line:column` with 1-based line/column; the path is relative to the
    // worktree root (parsed by `CursorPosition`'s `FromStr` impl).
    #[arg(long)]
    cursor: CursorPosition,
    // When set, the buffer is registered with language servers and we wait for their
    // diagnostics before gathering context.
    #[arg(long)]
    use_language_server: bool,
    // Events input (a file path, or `-` for stdin). Only the zeta1 path uses its contents;
    // an empty string is used when omitted.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
66
// Tuning knobs for the `zeta2-context` subcommand. Most fields are copied one-to-one into
// `zeta2::ZetaOptions` (the excerpt fields into `EditPredictionExcerptOptions`) in `get_context`.
#[derive(Debug, Args)]
struct Zeta2Args {
    // -> `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // -> `EditPredictionExcerptOptions::max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // -> `EditPredictionExcerptOptions::min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // -> `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`; presumably the
    // desired fraction of excerpt bytes placed before the cursor — TODO confirm in zeta2.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // -> `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // -> `ZetaOptions::prompt_format` (converted to `predict_edits_v3::PromptFormat`).
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Selects what `zeta2-context` prints; not passed to zeta2 itself.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
}
84
// CLI-facing mirror of `predict_edits_v3::PromptFormat` (one variant per upstream variant),
// selectable with `--prompt-format`; defaults to `marked-excerpt`.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    #[default]
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
}
92
93impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
94 fn into(self) -> predict_edits_v3::PromptFormat {
95 match self {
96 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
97 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
98 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
99 }
100 }
101}
102
// What `zeta2-context` prints on success (see the zeta2 branch of `get_context`):
// - `Prompt`: the rendered prompt string (default).
// - `Request`: the cloud request, pretty-printed as JSON.
// - `Both`: a pretty-printed JSON object with `"request"` and `"prompt"` keys.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Both,
}
110
/// An input source given on the command line: a file path, or `-` meaning stdin.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
116
117impl FileOrStdin {
118 async fn read_to_string(&self) -> Result<String, std::io::Error> {
119 match self {
120 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
121 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
122 }
123 }
124}
125
126impl FromStr for FileOrStdin {
127 type Err = <PathBuf as FromStr>::Err;
128
129 fn from_str(s: &str) -> Result<Self, Self::Err> {
130 match s {
131 "-" => Ok(Self::Stdin),
132 _ => Ok(Self::File(PathBuf::from_str(s)?)),
133 }
134 }
135}
136
/// A cursor location inside a worktree, parsed from `path:line:column` on the command line.
#[derive(Debug, Clone)]
struct CursorPosition {
    /// Path of the file, relative to the worktree root.
    path: Arc<RelPath>,
    /// 0-based row/column; the `FromStr` impl converts from the 1-based CLI form.
    point: Point,
}
142
143impl FromStr for CursorPosition {
144 type Err = anyhow::Error;
145
146 fn from_str(s: &str) -> Result<Self> {
147 let parts: Vec<&str> = s.split(':').collect();
148 if parts.len() != 3 {
149 return Err(anyhow!(
150 "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
151 s
152 ));
153 }
154
155 let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
156 let line: u32 = parts[1]
157 .parse()
158 .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
159 let column: u32 = parts[2]
160 .parse()
161 .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
162
163 // Convert from 1-based to 0-based indexing
164 let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
165
166 Ok(CursorPosition { path, point })
167 }
168}
169
/// Result of [`get_context`]; which variant is produced depends on whether zeta2 args were given.
enum GetContextOutput {
    /// Context gathered by `zeta::gather_context` (no zeta2 args).
    Zeta1(zeta::GatherContextOutput),
    /// Text to print for `zeta2-context`, already formatted per `--output-format`.
    Zeta2(String),
}
174
175async fn get_context(
176 zeta2_args: Option<Zeta2Args>,
177 args: ContextArgs,
178 app_state: &Arc<ZetaCliAppState>,
179 cx: &mut AsyncApp,
180) -> Result<GetContextOutput> {
181 let ContextArgs {
182 worktree: worktree_path,
183 cursor,
184 use_language_server,
185 events,
186 } = args;
187
188 let worktree_path = worktree_path.canonicalize()?;
189
190 let project = cx.update(|cx| {
191 Project::local(
192 app_state.client.clone(),
193 app_state.node_runtime.clone(),
194 app_state.user_store.clone(),
195 app_state.languages.clone(),
196 app_state.fs.clone(),
197 None,
198 cx,
199 )
200 })?;
201
202 let worktree = project
203 .update(cx, |project, cx| {
204 project.create_worktree(&worktree_path, true, cx)
205 })?
206 .await?;
207
208 let (_lsp_open_handle, buffer) = if use_language_server {
209 let (lsp_open_handle, buffer) =
210 open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
211 (Some(lsp_open_handle), buffer)
212 } else {
213 let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
214 (None, buffer)
215 };
216
217 let full_path_str = worktree
218 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
219 .display(PathStyle::local())
220 .to_string();
221
222 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
223 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
224 if clipped_cursor != cursor.point {
225 let max_row = snapshot.max_point().row;
226 if cursor.point.row < max_row {
227 return Err(anyhow!(
228 "Cursor position {:?} is out of bounds (line length is {})",
229 cursor.point,
230 snapshot.line_len(cursor.point.row)
231 ));
232 } else {
233 return Err(anyhow!(
234 "Cursor position {:?} is out of bounds (max row is {})",
235 cursor.point,
236 max_row
237 ));
238 }
239 }
240
241 let events = match events {
242 Some(events) => events.read_to_string().await?,
243 None => String::new(),
244 };
245
246 if let Some(zeta2_args) = zeta2_args {
247 Ok(GetContextOutput::Zeta2(
248 cx.update(|cx| {
249 let zeta = cx.new(|cx| {
250 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
251 });
252 zeta.update(cx, |zeta, cx| {
253 zeta.register_buffer(&buffer, &project, cx);
254 zeta.set_options(zeta2::ZetaOptions {
255 excerpt: EditPredictionExcerptOptions {
256 max_bytes: zeta2_args.max_excerpt_bytes,
257 min_bytes: zeta2_args.min_excerpt_bytes,
258 target_before_cursor_over_total_bytes: zeta2_args
259 .target_before_cursor_over_total_bytes,
260 },
261 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
262 max_prompt_bytes: zeta2_args.max_prompt_bytes,
263 prompt_format: zeta2_args.prompt_format.into(),
264 })
265 });
266 // TODO: Actually wait for indexing.
267 let timer = cx.background_executor().timer(Duration::from_secs(5));
268 cx.spawn(async move |cx| {
269 timer.await;
270 let request = zeta
271 .update(cx, |zeta, cx| {
272 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
273 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
274 })?
275 .await?;
276
277 let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
278 let prompt_string = planned_prompt.to_prompt_string()?.0;
279 match zeta2_args.output_format {
280 OutputFormat::Prompt => anyhow::Ok(prompt_string),
281 OutputFormat::Request => {
282 anyhow::Ok(serde_json::to_string_pretty(&request)?)
283 }
284 OutputFormat::Both => anyhow::Ok(serde_json::to_string_pretty(&json!({
285 "request": request,
286 "prompt": prompt_string,
287 }))?),
288 }
289 })
290 })?
291 .await?,
292 ))
293 } else {
294 let prompt_for_events = move || (events, 0);
295 Ok(GetContextOutput::Zeta1(
296 cx.update(|cx| {
297 zeta::gather_context(
298 full_path_str,
299 &snapshot,
300 clipped_cursor,
301 prompt_for_events,
302 cx,
303 )
304 })?
305 .await?,
306 ))
307 }
308}
309
310pub async fn open_buffer(
311 project: &Entity<Project>,
312 worktree: &Entity<Worktree>,
313 path: &RelPath,
314 cx: &mut AsyncApp,
315) -> Result<Entity<Buffer>> {
316 let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
317 worktree_id: worktree.id(),
318 path: path.into(),
319 })?;
320
321 project
322 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
323 .await
324}
325
326pub async fn open_buffer_with_language_server(
327 project: &Entity<Project>,
328 worktree: &Entity<Worktree>,
329 path: &RelPath,
330 cx: &mut AsyncApp,
331) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
332 let buffer = open_buffer(project, worktree, path, cx).await?;
333
334 let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
335 (
336 project.register_buffer_with_language_servers(&buffer, cx),
337 project.path_style(cx),
338 )
339 })?;
340
341 let log_prefix = path.display(path_style);
342 wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
343
344 Ok((lsp_open_handle, buffer))
345}
346
347// TODO: Dedupe with similar function in crates/eval/src/instance.rs
348pub fn wait_for_lang_server(
349 project: &Entity<Project>,
350 buffer: &Entity<Buffer>,
351 log_prefix: String,
352 cx: &mut AsyncApp,
353) -> Task<Result<()>> {
354 println!("{}⏵ Waiting for language server", log_prefix);
355
356 let (mut tx, mut rx) = mpsc::channel(1);
357
358 let lsp_store = project
359 .read_with(cx, |project, _| project.lsp_store())
360 .unwrap();
361
362 let has_lang_server = buffer
363 .update(cx, |buffer, cx| {
364 lsp_store.update(cx, |lsp_store, cx| {
365 lsp_store
366 .language_servers_for_local_buffer(buffer, cx)
367 .next()
368 .is_some()
369 })
370 })
371 .unwrap_or(false);
372
373 if has_lang_server {
374 project
375 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
376 .unwrap()
377 .detach();
378 }
379
380 let subscriptions = [
381 cx.subscribe(&lsp_store, {
382 let log_prefix = log_prefix.clone();
383 move |_, event, _| {
384 if let project::LspStoreEvent::LanguageServerUpdate {
385 message:
386 client::proto::update_language_server::Variant::WorkProgress(
387 client::proto::LspWorkProgress {
388 message: Some(message),
389 ..
390 },
391 ),
392 ..
393 } = event
394 {
395 println!("{}⟲ {message}", log_prefix)
396 }
397 }
398 }),
399 cx.subscribe(project, {
400 let buffer = buffer.clone();
401 move |project, event, cx| match event {
402 project::Event::LanguageServerAdded(_, _, _) => {
403 let buffer = buffer.clone();
404 project
405 .update(cx, |project, cx| project.save_buffer(buffer, cx))
406 .detach();
407 }
408 project::Event::DiskBasedDiagnosticsFinished { .. } => {
409 tx.try_send(()).ok();
410 }
411 _ => {}
412 }
413 }),
414 ];
415
416 cx.spawn(async move |cx| {
417 let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
418 let result = futures::select! {
419 _ = rx.next() => {
420 println!("{}⚑ Language server idle", log_prefix);
421 anyhow::Ok(())
422 },
423 _ = timeout.fuse() => {
424 anyhow::bail!("LSP wait timed out after 5 minutes");
425 }
426 };
427 drop(subscriptions);
428 result
429 })
430}
431
432fn main() {
433 let args = ZetaCliArgs::parse();
434 let http_client = Arc::new(ReqwestClient::new());
435 let app = Application::headless().with_http_client(http_client);
436
437 app.run(move |cx| {
438 let app_state = Arc::new(headless::init(cx));
439 let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
440 cx.spawn(async move |cx| {
441 let result = match args.command {
442 Commands::Zeta2Context {
443 zeta2_args,
444 context_args,
445 } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
446 Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
447 Ok(GetContextOutput::Zeta2(output)) => Ok(output),
448 Err(err) => Err(err),
449 },
450 Commands::Context(context_args) => {
451 match get_context(None, context_args, &app_state, cx).await {
452 Ok(GetContextOutput::Zeta1(output)) => {
453 Ok(serde_json::to_string_pretty(&output.body).unwrap())
454 }
455 Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
456 Err(err) => Err(err),
457 }
458 }
459 Commands::Predict {
460 predict_edits_body,
461 context_args,
462 } => {
463 cx.spawn(async move |cx| {
464 let app_version = cx.update(|cx| AppVersion::global(cx))?;
465 app_state.client.sign_in(true, cx).await?;
466 let llm_token = LlmApiToken::default();
467 llm_token.refresh(&app_state.client).await?;
468
469 let predict_edits_body =
470 if let Some(predict_edits_body) = predict_edits_body {
471 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
472 } else if let Some(context_args) = context_args {
473 match get_context(None, context_args, &app_state, cx).await? {
474 GetContextOutput::Zeta1(output) => output.body,
475 GetContextOutput::Zeta2 { .. } => unreachable!(),
476 }
477 } else {
478 return Err(anyhow!(
479 "Expected either --predict-edits-body-file \
480 or the required args of the `context` command."
481 ));
482 };
483
484 let (response, _usage) =
485 Zeta::perform_predict_edits(PerformPredictEditsParams {
486 client: app_state.client.clone(),
487 llm_token,
488 app_version,
489 body: predict_edits_body,
490 })
491 .await?;
492
493 Ok(response.output_excerpt)
494 })
495 .await
496 }
497 };
498 match result {
499 Ok(output) => {
500 println!("{}", output);
501 // TODO: Remove this once the 5 second delay is properly replaced.
502 if is_zeta2_context_command {
503 eprintln!("Note that zeta_cli doesn't yet wait for indexing, instead waits 5 seconds.");
504 }
505 let _ = cx.update(|cx| cx.quit());
506 }
507 Err(e) => {
508 eprintln!("Failed: {:?}", e);
509 exit(1);
510 }
511 }
512 })
513 .detach();
514 });
515}