1mod headless;
2
3use anyhow::{Result, anyhow};
4use clap::{Args, Parser, Subcommand};
5use cloud_llm_client::predict_edits_v3;
6use edit_prediction_context::EditPredictionExcerptOptions;
7use futures::channel::mpsc;
8use futures::{FutureExt as _, StreamExt as _};
9use gpui::{AppContext, Application, AsyncApp};
10use gpui::{Entity, Task};
11use language::Bias;
12use language::Buffer;
13use language::Point;
14use language_model::LlmApiToken;
15use project::{Project, ProjectPath, Worktree};
16use release_channel::AppVersion;
17use reqwest_client::ReqwestClient;
18use std::path::{Path, PathBuf};
19use std::process::exit;
20use std::str::FromStr;
21use std::sync::Arc;
22use std::time::Duration;
23use util::paths::PathStyle;
24use util::rel_path::RelPath;
25use zeta::{PerformPredictEditsParams, Zeta};
26
27use crate::headless::ZetaCliAppState;
28
// Top-level command-line arguments for the `zeta` binary. All work is selected by a subcommand;
// see `Commands`.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    #[command(subcommand)]
    command: Commands,
}
35
// Subcommands of the `zeta` CLI. `main` runs the selected one and prints its result to stdout.
#[derive(Subcommand, Debug)]
enum Commands {
    // Gather zeta1 context for the cursor and print the resulting request body as pretty JSON.
    Context(ContextArgs),
    // Gather zeta2 context and print either the rendered prompt or the serialized cloud request,
    // depending on `Zeta2Args::output_format`.
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Send a predict-edits request and print the returned `output_excerpt`. The request body is
    // read as JSON from `predict_edits_body` (a path, or `-` for stdin) when given; otherwise it
    // is gathered from `context_args` the same way `Context` does.
    Predict {
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
}
52
// Arguments shared by every command that gathers context around a cursor position.
// NOTE(review): the group `requires = "worktree"` presumably makes `--worktree` mandatory as soon
// as any of these args is given, which matters when this is flattened as `Option<ContextArgs>`
// in `Commands::Predict` — confirm against clap's group semantics.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory opened as a local project worktree (canonicalized before use).
    #[arg(long)]
    worktree: PathBuf,
    // Cursor as `path:line:column`, path relative to the worktree, line/column 1-based.
    #[arg(long)]
    cursor: CursorPosition,
    // When set, the buffer is registered with language servers and we wait for their
    // disk-based diagnostics before gathering context.
    #[arg(long)]
    use_language_server: bool,
    // Optional edit-history text (a path, or `-` for stdin). Only the zeta1 path consumes it.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
65
// Tuning knobs for the `zeta2-context` command; these are copied into `zeta2::ZetaOptions`
// inside `get_context`.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Upper bound on the size of the whole prompt (`ZetaOptions::max_prompt_bytes`).
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Upper bound on the excerpt around the cursor (`EditPredictionExcerptOptions::max_bytes`).
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Lower bound on the excerpt around the cursor (`EditPredictionExcerptOptions::min_bytes`).
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Share of excerpt bytes aimed to fall before the cursor.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Budget for diagnostics included in the prompt (`ZetaOptions::max_diagnostic_bytes`).
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // How the prompt is laid out; converted into `predict_edits_v3::PromptFormat`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Whether to print the rendered prompt or the raw cloud request.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
}
83
// CLI-facing mirror of `predict_edits_v3::PromptFormat`; kept local so it can derive
// `clap::ValueEnum`. Defaults to `MarkedExcerpt`.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    #[default]
    MarkedExcerpt,
    LabeledSections,
}
90
91impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
92 fn into(self) -> predict_edits_v3::PromptFormat {
93 match self {
94 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
95 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
96 }
97 }
98}
99
// What `zeta2-context` prints: the rendered prompt text (default), or the cloud request
// serialized as pretty JSON.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
}
106
/// An input source given on the command line: either a filesystem path or standard input
/// (spelled `-`).
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
112
113impl FileOrStdin {
114 async fn read_to_string(&self) -> Result<String, std::io::Error> {
115 match self {
116 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
117 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
118 }
119 }
120}
121
122impl FromStr for FileOrStdin {
123 type Err = <PathBuf as FromStr>::Err;
124
125 fn from_str(s: &str) -> Result<Self, Self::Err> {
126 match s {
127 "-" => Ok(Self::Stdin),
128 _ => Ok(Self::File(PathBuf::from_str(s)?)),
129 }
130 }
131}
132
133#[derive(Debug, Clone)]
134struct CursorPosition {
135 path: Arc<RelPath>,
136 point: Point,
137}
138
139impl FromStr for CursorPosition {
140 type Err = anyhow::Error;
141
142 fn from_str(s: &str) -> Result<Self> {
143 let parts: Vec<&str> = s.split(':').collect();
144 if parts.len() != 3 {
145 return Err(anyhow!(
146 "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
147 s
148 ));
149 }
150
151 let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
152 let line: u32 = parts[1]
153 .parse()
154 .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
155 let column: u32 = parts[2]
156 .parse()
157 .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
158
159 // Convert from 1-based to 0-based indexing
160 let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
161
162 Ok(CursorPosition { path, point })
163 }
164}
165
// Result of `get_context`, tagged by which pipeline produced it.
enum GetContextOutput {
    // Output of `zeta::gather_context` (returned when no zeta2 args are given).
    Zeta1(zeta::GatherContextOutput),
    // Rendered zeta2 prompt or pretty-printed cloud request, depending on `OutputFormat`.
    Zeta2(String),
}
170
171async fn get_context(
172 zeta2_args: Option<Zeta2Args>,
173 args: ContextArgs,
174 app_state: &Arc<ZetaCliAppState>,
175 cx: &mut AsyncApp,
176) -> Result<GetContextOutput> {
177 let ContextArgs {
178 worktree: worktree_path,
179 cursor,
180 use_language_server,
181 events,
182 } = args;
183
184 let worktree_path = worktree_path.canonicalize()?;
185
186 let project = cx.update(|cx| {
187 Project::local(
188 app_state.client.clone(),
189 app_state.node_runtime.clone(),
190 app_state.user_store.clone(),
191 app_state.languages.clone(),
192 app_state.fs.clone(),
193 None,
194 cx,
195 )
196 })?;
197
198 let worktree = project
199 .update(cx, |project, cx| {
200 project.create_worktree(&worktree_path, true, cx)
201 })?
202 .await?;
203
204 let (_lsp_open_handle, buffer) = if use_language_server {
205 let (lsp_open_handle, buffer) =
206 open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
207 (Some(lsp_open_handle), buffer)
208 } else {
209 let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
210 (None, buffer)
211 };
212
213 let full_path_str = worktree
214 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
215 .display(PathStyle::local())
216 .to_string();
217
218 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
219 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
220 if clipped_cursor != cursor.point {
221 let max_row = snapshot.max_point().row;
222 if cursor.point.row < max_row {
223 return Err(anyhow!(
224 "Cursor position {:?} is out of bounds (line length is {})",
225 cursor.point,
226 snapshot.line_len(cursor.point.row)
227 ));
228 } else {
229 return Err(anyhow!(
230 "Cursor position {:?} is out of bounds (max row is {})",
231 cursor.point,
232 max_row
233 ));
234 }
235 }
236
237 let events = match events {
238 Some(events) => events.read_to_string().await?,
239 None => String::new(),
240 };
241
242 if let Some(zeta2_args) = zeta2_args {
243 Ok(GetContextOutput::Zeta2(
244 cx.update(|cx| {
245 let zeta = cx.new(|cx| {
246 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
247 });
248 zeta.update(cx, |zeta, cx| {
249 zeta.register_buffer(&buffer, &project, cx);
250 zeta.set_options(zeta2::ZetaOptions {
251 excerpt: EditPredictionExcerptOptions {
252 max_bytes: zeta2_args.max_excerpt_bytes,
253 min_bytes: zeta2_args.min_excerpt_bytes,
254 target_before_cursor_over_total_bytes: zeta2_args
255 .target_before_cursor_over_total_bytes,
256 },
257 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
258 max_prompt_bytes: zeta2_args.max_prompt_bytes,
259 prompt_format: zeta2_args.prompt_format.into(),
260 })
261 });
262 // TODO: Actually wait for indexing.
263 let timer = cx.background_executor().timer(Duration::from_secs(5));
264 cx.spawn(async move |cx| {
265 timer.await;
266 let request = zeta
267 .update(cx, |zeta, cx| {
268 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
269 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
270 })?
271 .await?;
272 match zeta2_args.output_format {
273 OutputFormat::Prompt => {
274 let planned_prompt =
275 cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
276 // TODO: Output the section label ranges
277 anyhow::Ok(planned_prompt.to_prompt_string()?.0)
278 }
279 OutputFormat::Request => {
280 anyhow::Ok(serde_json::to_string_pretty(&request)?)
281 }
282 }
283 })
284 })?
285 .await?,
286 ))
287 } else {
288 let prompt_for_events = move || (events, 0);
289 Ok(GetContextOutput::Zeta1(
290 cx.update(|cx| {
291 zeta::gather_context(
292 full_path_str,
293 &snapshot,
294 clipped_cursor,
295 prompt_for_events,
296 cx,
297 )
298 })?
299 .await?,
300 ))
301 }
302}
303
304pub async fn open_buffer(
305 project: &Entity<Project>,
306 worktree: &Entity<Worktree>,
307 path: &RelPath,
308 cx: &mut AsyncApp,
309) -> Result<Entity<Buffer>> {
310 let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
311 worktree_id: worktree.id(),
312 path: path.into(),
313 })?;
314
315 project
316 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
317 .await
318}
319
320pub async fn open_buffer_with_language_server(
321 project: &Entity<Project>,
322 worktree: &Entity<Worktree>,
323 path: &RelPath,
324 cx: &mut AsyncApp,
325) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
326 let buffer = open_buffer(project, worktree, path, cx).await?;
327
328 let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
329 (
330 project.register_buffer_with_language_servers(&buffer, cx),
331 project.path_style(cx),
332 )
333 })?;
334
335 let log_prefix = path.display(path_style);
336 wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
337
338 Ok((lsp_open_handle, buffer))
339}
340
341// TODO: Dedupe with similar function in crates/eval/src/instance.rs
342pub fn wait_for_lang_server(
343 project: &Entity<Project>,
344 buffer: &Entity<Buffer>,
345 log_prefix: String,
346 cx: &mut AsyncApp,
347) -> Task<Result<()>> {
348 println!("{}⏵ Waiting for language server", log_prefix);
349
350 let (mut tx, mut rx) = mpsc::channel(1);
351
352 let lsp_store = project
353 .read_with(cx, |project, _| project.lsp_store())
354 .unwrap();
355
356 let has_lang_server = buffer
357 .update(cx, |buffer, cx| {
358 lsp_store.update(cx, |lsp_store, cx| {
359 lsp_store
360 .language_servers_for_local_buffer(buffer, cx)
361 .next()
362 .is_some()
363 })
364 })
365 .unwrap_or(false);
366
367 if has_lang_server {
368 project
369 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
370 .unwrap()
371 .detach();
372 }
373
374 let subscriptions = [
375 cx.subscribe(&lsp_store, {
376 let log_prefix = log_prefix.clone();
377 move |_, event, _| {
378 if let project::LspStoreEvent::LanguageServerUpdate {
379 message:
380 client::proto::update_language_server::Variant::WorkProgress(
381 client::proto::LspWorkProgress {
382 message: Some(message),
383 ..
384 },
385 ),
386 ..
387 } = event
388 {
389 println!("{}⟲ {message}", log_prefix)
390 }
391 }
392 }),
393 cx.subscribe(project, {
394 let buffer = buffer.clone();
395 move |project, event, cx| match event {
396 project::Event::LanguageServerAdded(_, _, _) => {
397 let buffer = buffer.clone();
398 project
399 .update(cx, |project, cx| project.save_buffer(buffer, cx))
400 .detach();
401 }
402 project::Event::DiskBasedDiagnosticsFinished { .. } => {
403 tx.try_send(()).ok();
404 }
405 _ => {}
406 }
407 }),
408 ];
409
410 cx.spawn(async move |cx| {
411 let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
412 let result = futures::select! {
413 _ = rx.next() => {
414 println!("{}⚑ Language server idle", log_prefix);
415 anyhow::Ok(())
416 },
417 _ = timeout.fuse() => {
418 anyhow::bail!("LSP wait timed out after 5 minutes");
419 }
420 };
421 drop(subscriptions);
422 result
423 })
424}
425
426fn main() {
427 let args = ZetaCliArgs::parse();
428 let http_client = Arc::new(ReqwestClient::new());
429 let app = Application::headless().with_http_client(http_client);
430
431 app.run(move |cx| {
432 let app_state = Arc::new(headless::init(cx));
433 let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
434 cx.spawn(async move |cx| {
435 let result = match args.command {
436 Commands::Zeta2Context {
437 zeta2_args,
438 context_args,
439 } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
440 Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
441 Ok(GetContextOutput::Zeta2(output)) => Ok(output),
442 Err(err) => Err(err),
443 },
444 Commands::Context(context_args) => {
445 match get_context(None, context_args, &app_state, cx).await {
446 Ok(GetContextOutput::Zeta1(output)) => {
447 Ok(serde_json::to_string_pretty(&output.body).unwrap())
448 }
449 Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
450 Err(err) => Err(err),
451 }
452 }
453 Commands::Predict {
454 predict_edits_body,
455 context_args,
456 } => {
457 cx.spawn(async move |cx| {
458 let app_version = cx.update(|cx| AppVersion::global(cx))?;
459 app_state.client.sign_in(true, cx).await?;
460 let llm_token = LlmApiToken::default();
461 llm_token.refresh(&app_state.client).await?;
462
463 let predict_edits_body =
464 if let Some(predict_edits_body) = predict_edits_body {
465 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
466 } else if let Some(context_args) = context_args {
467 match get_context(None, context_args, &app_state, cx).await? {
468 GetContextOutput::Zeta1(output) => output.body,
469 GetContextOutput::Zeta2 { .. } => unreachable!(),
470 }
471 } else {
472 return Err(anyhow!(
473 "Expected either --predict-edits-body-file \
474 or the required args of the `context` command."
475 ));
476 };
477
478 let (response, _usage) =
479 Zeta::perform_predict_edits(PerformPredictEditsParams {
480 client: app_state.client.clone(),
481 llm_token,
482 app_version,
483 body: predict_edits_body,
484 })
485 .await?;
486
487 Ok(response.output_excerpt)
488 })
489 .await
490 }
491 };
492 match result {
493 Ok(output) => {
494 println!("{}", output);
495 // TODO: Remove this once the 5 second delay is properly replaced.
496 if is_zeta2_context_command {
497 eprintln!("Note that zeta_cli doesn't yet wait for indexing, instead waits 5 seconds.");
498 }
499 let _ = cx.update(|cx| cx.quit());
500 }
501 Err(e) => {
502 eprintln!("Failed: {:?}", e);
503 exit(1);
504 }
505 }
506 })
507 .detach();
508 });
509}