1mod headless;
2
3use anyhow::{Result, anyhow};
4use clap::{Args, Parser, Subcommand};
5use edit_prediction_context::EditPredictionExcerptOptions;
6use futures::channel::mpsc;
7use futures::{FutureExt as _, StreamExt as _};
8use gpui::{AppContext, Application, AsyncApp};
9use gpui::{Entity, Task};
10use language::Bias;
11use language::Buffer;
12use language::Point;
13use language_model::LlmApiToken;
14use project::{Project, ProjectPath, Worktree};
15use release_channel::AppVersion;
16use reqwest_client::ReqwestClient;
17use std::path::{Path, PathBuf};
18use std::process::exit;
19use std::str::FromStr;
20use std::sync::Arc;
21use std::time::Duration;
22use zeta::{PerformPredictEditsParams, Zeta};
23
24use crate::headless::ZetaCliAppState;
25
// Top-level arguments of the `zeta` CLI. All functionality lives in the
// subcommands of `Commands`; this struct only selects one of them.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    #[command(subcommand)]
    command: Commands,
}
32
// Subcommands of the `zeta` CLI. Each one prints its result to stdout.
#[derive(Subcommand, Debug)]
enum Commands {
    // Gather zeta1 edit-prediction context and print the request body as
    // pretty-printed JSON.
    Context(ContextArgs),
    // Build a zeta2 request for the cursor and print the planned prompt string.
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Request an edit prediction and print the returned output excerpt.
    // The request body is read from `--predict-edits-body` (a file path, or `-`
    // for stdin) or, if that flag is absent, gathered from the same arguments
    // the `context` command takes. If both are given, the body flag wins.
    Predict {
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
}
49
// Arguments shared by the commands that gather context from a local worktree.
// The `requires = "worktree"` group setting makes `--worktree` mandatory as soon
// as any argument of this group is supplied, which matters where the group is
// optional (the `predict` command).
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root folder of the project to open; canonicalized before use and must end
    // with a folder name.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor as `path:line:column` (1-based), with `path` relative to the worktree.
    #[arg(long)]
    cursor: CursorPosition,
    // Open the buffer through the language servers and wait for them before
    // gathering context.
    #[arg(long)]
    use_language_server: bool,
    // Optional events text (file path or `-` for stdin). It is read in every
    // mode but only passed on to zeta1 context gathering.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
62
// Tuning options for the `zeta2-context` command.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Passed as `cloud_zeta2_prompt::PlanOptions::max_bytes` when planning the prompt.
    #[arg(long, default_value_t = 8192)]
    prompt_max_bytes: usize,
    // Passed as `EditPredictionExcerptOptions::max_bytes`.
    #[arg(long, default_value_t = 2048)]
    excerpt_max_bytes: usize,
    // Passed as `EditPredictionExcerptOptions::min_bytes`.
    #[arg(long, default_value_t = 1024)]
    excerpt_min_bytes: usize,
    // Passed as `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Passed as `zeta2::ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
}
76
/// Input source for CLI arguments that accept either a file path or `-` for stdin.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from standard input (selected on the command line with `-`).
    Stdin,
}
82
83impl FileOrStdin {
84 async fn read_to_string(&self) -> Result<String, std::io::Error> {
85 match self {
86 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
87 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
88 }
89 }
90}
91
92impl FromStr for FileOrStdin {
93 type Err = <PathBuf as FromStr>::Err;
94
95 fn from_str(s: &str) -> Result<Self, Self::Err> {
96 match s {
97 "-" => Ok(Self::Stdin),
98 _ => Ok(Self::File(PathBuf::from_str(s)?)),
99 }
100 }
101}
102
/// A cursor location parsed from the `--cursor` argument.
#[derive(Debug, Clone)]
struct CursorPosition {
    /// File path as given on the command line; absolute paths are rejected
    /// when context is gathered.
    path: PathBuf,
    /// 0-based row/column, converted from the 1-based command-line input.
    point: Point,
}
108
109impl FromStr for CursorPosition {
110 type Err = anyhow::Error;
111
112 fn from_str(s: &str) -> Result<Self> {
113 let parts: Vec<&str> = s.split(':').collect();
114 if parts.len() != 3 {
115 return Err(anyhow!(
116 "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
117 s
118 ));
119 }
120
121 let path = PathBuf::from(parts[0]);
122 let line: u32 = parts[1]
123 .parse()
124 .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
125 let column: u32 = parts[2]
126 .parse()
127 .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
128
129 // Convert from 1-based to 0-based indexing
130 let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
131
132 Ok(CursorPosition { path, point })
133 }
134}
135
/// Result of `get_context`, depending on which model's context was requested.
enum GetContextOutput {
    /// Context gathered by `zeta::gather_context` (zeta1).
    Zeta1(zeta::GatherContextOutput),
    /// The planned zeta2 prompt, already rendered to a string.
    Zeta2(String),
}
140
141async fn get_context(
142 zeta2_args: Option<Zeta2Args>,
143 args: ContextArgs,
144 app_state: &Arc<ZetaCliAppState>,
145 cx: &mut AsyncApp,
146) -> Result<GetContextOutput> {
147 let ContextArgs {
148 worktree: worktree_path,
149 cursor,
150 use_language_server,
151 events,
152 } = args;
153
154 let worktree_path = worktree_path.canonicalize()?;
155 if cursor.path.is_absolute() {
156 return Err(anyhow!("Absolute paths are not supported in --cursor"));
157 }
158
159 let project = cx.update(|cx| {
160 Project::local(
161 app_state.client.clone(),
162 app_state.node_runtime.clone(),
163 app_state.user_store.clone(),
164 app_state.languages.clone(),
165 app_state.fs.clone(),
166 None,
167 cx,
168 )
169 })?;
170
171 let worktree = project
172 .update(cx, |project, cx| {
173 project.create_worktree(&worktree_path, true, cx)
174 })?
175 .await?;
176
177 let (_lsp_open_handle, buffer) = if use_language_server {
178 let (lsp_open_handle, buffer) =
179 open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
180 (Some(lsp_open_handle), buffer)
181 } else {
182 let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
183 (None, buffer)
184 };
185
186 let worktree_name = worktree_path
187 .file_name()
188 .ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?;
189 let full_path_str = PathBuf::from(worktree_name)
190 .join(&cursor.path)
191 .to_string_lossy()
192 .to_string();
193
194 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
195 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
196 if clipped_cursor != cursor.point {
197 let max_row = snapshot.max_point().row;
198 if cursor.point.row < max_row {
199 return Err(anyhow!(
200 "Cursor position {:?} is out of bounds (line length is {})",
201 cursor.point,
202 snapshot.line_len(cursor.point.row)
203 ));
204 } else {
205 return Err(anyhow!(
206 "Cursor position {:?} is out of bounds (max row is {})",
207 cursor.point,
208 max_row
209 ));
210 }
211 }
212
213 let events = match events {
214 Some(events) => events.read_to_string().await?,
215 None => String::new(),
216 };
217
218 if let Some(zeta2_args) = zeta2_args {
219 Ok(GetContextOutput::Zeta2(
220 cx.update(|cx| {
221 let zeta = cx.new(|cx| {
222 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
223 });
224 zeta.update(cx, |zeta, cx| {
225 zeta.register_buffer(&buffer, &project, cx);
226 zeta.set_options(zeta2::ZetaOptions {
227 excerpt: EditPredictionExcerptOptions {
228 max_bytes: zeta2_args.excerpt_max_bytes,
229 min_bytes: zeta2_args.excerpt_min_bytes,
230 target_before_cursor_over_total_bytes: zeta2_args
231 .target_before_cursor_over_total_bytes,
232 },
233 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
234 })
235 });
236 // TODO: Actually wait for indexing.
237 let timer = cx.background_executor().timer(Duration::from_secs(5));
238 cx.spawn(async move |cx| {
239 timer.await;
240 let request = zeta
241 .update(cx, |zeta, cx| {
242 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
243 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
244 })?
245 .await?;
246 let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(
247 &request,
248 &cloud_zeta2_prompt::PlanOptions {
249 max_bytes: zeta2_args.prompt_max_bytes,
250 },
251 )?;
252 anyhow::Ok(planned_prompt.to_prompt_string())
253 })
254 })?
255 .await?,
256 ))
257 } else {
258 let prompt_for_events = move || (events, 0);
259 Ok(GetContextOutput::Zeta1(
260 cx.update(|cx| {
261 zeta::gather_context(
262 full_path_str,
263 &snapshot,
264 clipped_cursor,
265 prompt_for_events,
266 cx,
267 )
268 })?
269 .await?,
270 ))
271 }
272}
273
274pub async fn open_buffer(
275 project: &Entity<Project>,
276 worktree: &Entity<Worktree>,
277 path: &Path,
278 cx: &mut AsyncApp,
279) -> Result<Entity<Buffer>> {
280 let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
281 worktree_id: worktree.id(),
282 path: path.to_path_buf().into(),
283 })?;
284
285 project
286 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
287 .await
288}
289
290pub async fn open_buffer_with_language_server(
291 project: &Entity<Project>,
292 worktree: &Entity<Worktree>,
293 path: &Path,
294 cx: &mut AsyncApp,
295) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
296 let buffer = open_buffer(project, worktree, path, cx).await?;
297
298 let lsp_open_handle = project.update(cx, |project, cx| {
299 project.register_buffer_with_language_servers(&buffer, cx)
300 })?;
301
302 let log_prefix = path.to_string_lossy().to_string();
303 wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
304
305 Ok((lsp_open_handle, buffer))
306}
307
308// TODO: Dedupe with similar function in crates/eval/src/instance.rs
309pub fn wait_for_lang_server(
310 project: &Entity<Project>,
311 buffer: &Entity<Buffer>,
312 log_prefix: String,
313 cx: &mut AsyncApp,
314) -> Task<Result<()>> {
315 println!("{}⏵ Waiting for language server", log_prefix);
316
317 let (mut tx, mut rx) = mpsc::channel(1);
318
319 let lsp_store = project
320 .read_with(cx, |project, _| project.lsp_store())
321 .unwrap();
322
323 let has_lang_server = buffer
324 .update(cx, |buffer, cx| {
325 lsp_store.update(cx, |lsp_store, cx| {
326 lsp_store
327 .language_servers_for_local_buffer(buffer, cx)
328 .next()
329 .is_some()
330 })
331 })
332 .unwrap_or(false);
333
334 if has_lang_server {
335 project
336 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
337 .unwrap()
338 .detach();
339 }
340
341 let subscriptions = [
342 cx.subscribe(&lsp_store, {
343 let log_prefix = log_prefix.clone();
344 move |_, event, _| {
345 if let project::LspStoreEvent::LanguageServerUpdate {
346 message:
347 client::proto::update_language_server::Variant::WorkProgress(
348 client::proto::LspWorkProgress {
349 message: Some(message),
350 ..
351 },
352 ),
353 ..
354 } = event
355 {
356 println!("{}⟲ {message}", log_prefix)
357 }
358 }
359 }),
360 cx.subscribe(project, {
361 let buffer = buffer.clone();
362 move |project, event, cx| match event {
363 project::Event::LanguageServerAdded(_, _, _) => {
364 let buffer = buffer.clone();
365 project
366 .update(cx, |project, cx| project.save_buffer(buffer, cx))
367 .detach();
368 }
369 project::Event::DiskBasedDiagnosticsFinished { .. } => {
370 tx.try_send(()).ok();
371 }
372 _ => {}
373 }
374 }),
375 ];
376
377 cx.spawn(async move |cx| {
378 let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
379 let result = futures::select! {
380 _ = rx.next() => {
381 println!("{}⚑ Language server idle", log_prefix);
382 anyhow::Ok(())
383 },
384 _ = timeout.fuse() => {
385 anyhow::bail!("LSP wait timed out after 5 minutes");
386 }
387 };
388 drop(subscriptions);
389 result
390 })
391}
392
393fn main() {
394 let args = ZetaCliArgs::parse();
395 let http_client = Arc::new(ReqwestClient::new());
396 let app = Application::headless().with_http_client(http_client);
397
398 app.run(move |cx| {
399 let app_state = Arc::new(headless::init(cx));
400 let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
401 cx.spawn(async move |cx| {
402 let result = match args.command {
403 Commands::Zeta2Context {
404 zeta2_args,
405 context_args,
406 } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
407 Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
408 Ok(GetContextOutput::Zeta2(output)) => Ok(output),
409 Err(err) => Err(err),
410 },
411 Commands::Context(context_args) => {
412 match get_context(None, context_args, &app_state, cx).await {
413 Ok(GetContextOutput::Zeta1(output)) => {
414 Ok(serde_json::to_string_pretty(&output.body).unwrap())
415 }
416 Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
417 Err(err) => Err(err),
418 }
419 }
420 Commands::Predict {
421 predict_edits_body,
422 context_args,
423 } => {
424 cx.spawn(async move |cx| {
425 let app_version = cx.update(|cx| AppVersion::global(cx))?;
426 app_state.client.sign_in(true, cx).await?;
427 let llm_token = LlmApiToken::default();
428 llm_token.refresh(&app_state.client).await?;
429
430 let predict_edits_body =
431 if let Some(predict_edits_body) = predict_edits_body {
432 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
433 } else if let Some(context_args) = context_args {
434 match get_context(None, context_args, &app_state, cx).await? {
435 GetContextOutput::Zeta1(output) => output.body,
436 GetContextOutput::Zeta2 { .. } => unreachable!(),
437 }
438 } else {
439 return Err(anyhow!(
440 "Expected either --predict-edits-body-file \
441 or the required args of the `context` command."
442 ));
443 };
444
445 let (response, _usage) =
446 Zeta::perform_predict_edits(PerformPredictEditsParams {
447 client: app_state.client.clone(),
448 llm_token,
449 app_version,
450 body: predict_edits_body,
451 })
452 .await?;
453
454 Ok(response.output_excerpt)
455 })
456 .await
457 }
458 };
459 match result {
460 Ok(output) => {
461 println!("{}", output);
462 // TODO: Remove this once the 5 second delay is properly replaced.
463 if is_zeta2_context_command {
464 eprintln!("Note that zeta2-context doesn't yet wait for indexing, instead waits 5 seconds.");
465 }
466 let _ = cx.update(|cx| cx.quit());
467 }
468 Err(e) => {
469 eprintln!("Failed: {:?}", e);
470 exit(1);
471 }
472 }
473 })
474 .detach();
475 });
476}