1mod headless;
2
3use anyhow::{Result, anyhow};
4use clap::{Args, Parser, Subcommand};
5use edit_prediction_context::EditPredictionExcerptOptions;
6use futures::channel::mpsc;
7use futures::{FutureExt as _, StreamExt as _};
8use gpui::{AppContext, Application, AsyncApp};
9use gpui::{Entity, Task};
10use language::Bias;
11use language::Buffer;
12use language::Point;
13use language_model::LlmApiToken;
14use project::{Project, ProjectPath, Worktree};
15use release_channel::AppVersion;
16use reqwest_client::ReqwestClient;
17use std::path::{Path, PathBuf};
18use std::process::exit;
19use std::str::FromStr;
20use std::sync::Arc;
21use std::time::Duration;
22use zeta::{PerformPredictEditsParams, Zeta};
23
24use crate::headless::ZetaCliAppState;
25
26#[derive(Parser, Debug)]
27#[command(name = "zeta")]
28struct ZetaCliArgs {
29 #[command(subcommand)]
30 command: Commands,
31}
32
33#[derive(Subcommand, Debug)]
34enum Commands {
35 Context(ContextArgs),
36 Zeta2Context {
37 #[clap(flatten)]
38 zeta2_args: Zeta2Args,
39 #[clap(flatten)]
40 context_args: ContextArgs,
41 },
42 Predict {
43 #[arg(long)]
44 predict_edits_body: Option<FileOrStdin>,
45 #[clap(flatten)]
46 context_args: Option<ContextArgs>,
47 },
48}
49
50#[derive(Debug, Args)]
51#[group(requires = "worktree")]
52struct ContextArgs {
53 #[arg(long)]
54 worktree: PathBuf,
55 #[arg(long)]
56 cursor: CursorPosition,
57 #[arg(long)]
58 use_language_server: bool,
59 #[arg(long)]
60 events: Option<FileOrStdin>,
61}
62
63#[derive(Debug, Args)]
64struct Zeta2Args {
65 #[arg(long, default_value_t = 8192)]
66 max_prompt_bytes: usize,
67 #[arg(long, default_value_t = 2048)]
68 max_excerpt_bytes: usize,
69 #[arg(long, default_value_t = 1024)]
70 min_excerpt_bytes: usize,
71 #[arg(long, default_value_t = 0.66)]
72 target_before_cursor_over_total_bytes: f32,
73 #[arg(long, default_value_t = 1024)]
74 max_diagnostic_bytes: usize,
75}
76
/// Where an argument's contents come from: a file on disk, or standard input (`-`).
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read the file at this path.
    File(PathBuf),
    /// Read everything from standard input.
    Stdin,
}
82
83impl FileOrStdin {
84 async fn read_to_string(&self) -> Result<String, std::io::Error> {
85 match self {
86 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
87 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
88 }
89 }
90}
91
92impl FromStr for FileOrStdin {
93 type Err = <PathBuf as FromStr>::Err;
94
95 fn from_str(s: &str) -> Result<Self, Self::Err> {
96 match s {
97 "-" => Ok(Self::Stdin),
98 _ => Ok(Self::File(PathBuf::from_str(s)?)),
99 }
100 }
101}
102
103#[derive(Debug, Clone)]
104struct CursorPosition {
105 path: PathBuf,
106 point: Point,
107}
108
109impl FromStr for CursorPosition {
110 type Err = anyhow::Error;
111
112 fn from_str(s: &str) -> Result<Self> {
113 let parts: Vec<&str> = s.split(':').collect();
114 if parts.len() != 3 {
115 return Err(anyhow!(
116 "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
117 s
118 ));
119 }
120
121 let path = PathBuf::from(parts[0]);
122 let line: u32 = parts[1]
123 .parse()
124 .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
125 let column: u32 = parts[2]
126 .parse()
127 .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
128
129 // Convert from 1-based to 0-based indexing
130 let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
131
132 Ok(CursorPosition { path, point })
133 }
134}
135
136enum GetContextOutput {
137 Zeta1(zeta::GatherContextOutput),
138 Zeta2(String),
139}
140
141async fn get_context(
142 zeta2_args: Option<Zeta2Args>,
143 args: ContextArgs,
144 app_state: &Arc<ZetaCliAppState>,
145 cx: &mut AsyncApp,
146) -> Result<GetContextOutput> {
147 let ContextArgs {
148 worktree: worktree_path,
149 cursor,
150 use_language_server,
151 events,
152 } = args;
153
154 let worktree_path = worktree_path.canonicalize()?;
155 if cursor.path.is_absolute() {
156 return Err(anyhow!("Absolute paths are not supported in --cursor"));
157 }
158
159 let project = cx.update(|cx| {
160 Project::local(
161 app_state.client.clone(),
162 app_state.node_runtime.clone(),
163 app_state.user_store.clone(),
164 app_state.languages.clone(),
165 app_state.fs.clone(),
166 None,
167 cx,
168 )
169 })?;
170
171 let worktree = project
172 .update(cx, |project, cx| {
173 project.create_worktree(&worktree_path, true, cx)
174 })?
175 .await?;
176
177 let (_lsp_open_handle, buffer) = if use_language_server {
178 let (lsp_open_handle, buffer) =
179 open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
180 (Some(lsp_open_handle), buffer)
181 } else {
182 let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
183 (None, buffer)
184 };
185
186 let worktree_name = worktree_path
187 .file_name()
188 .ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?;
189 let full_path_str = PathBuf::from(worktree_name)
190 .join(&cursor.path)
191 .to_string_lossy()
192 .to_string();
193
194 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
195 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
196 if clipped_cursor != cursor.point {
197 let max_row = snapshot.max_point().row;
198 if cursor.point.row < max_row {
199 return Err(anyhow!(
200 "Cursor position {:?} is out of bounds (line length is {})",
201 cursor.point,
202 snapshot.line_len(cursor.point.row)
203 ));
204 } else {
205 return Err(anyhow!(
206 "Cursor position {:?} is out of bounds (max row is {})",
207 cursor.point,
208 max_row
209 ));
210 }
211 }
212
213 let events = match events {
214 Some(events) => events.read_to_string().await?,
215 None => String::new(),
216 };
217
218 if let Some(zeta2_args) = zeta2_args {
219 Ok(GetContextOutput::Zeta2(
220 cx.update(|cx| {
221 let zeta = cx.new(|cx| {
222 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
223 });
224 zeta.update(cx, |zeta, cx| {
225 zeta.register_buffer(&buffer, &project, cx);
226 zeta.set_options(zeta2::ZetaOptions {
227 excerpt: EditPredictionExcerptOptions {
228 max_bytes: zeta2_args.max_excerpt_bytes,
229 min_bytes: zeta2_args.min_excerpt_bytes,
230 target_before_cursor_over_total_bytes: zeta2_args
231 .target_before_cursor_over_total_bytes,
232 },
233 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
234 max_prompt_bytes: zeta2_args.max_prompt_bytes,
235 })
236 });
237 // TODO: Actually wait for indexing.
238 let timer = cx.background_executor().timer(Duration::from_secs(5));
239 cx.spawn(async move |cx| {
240 timer.await;
241 let request = zeta
242 .update(cx, |zeta, cx| {
243 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
244 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
245 })?
246 .await?;
247 let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(
248 &request,
249 &cloud_zeta2_prompt::PlanOptions {
250 max_bytes: zeta2_args.max_prompt_bytes,
251 },
252 )?;
253 anyhow::Ok(planned_prompt.to_prompt_string())
254 })
255 })?
256 .await?,
257 ))
258 } else {
259 let prompt_for_events = move || (events, 0);
260 Ok(GetContextOutput::Zeta1(
261 cx.update(|cx| {
262 zeta::gather_context(
263 full_path_str,
264 &snapshot,
265 clipped_cursor,
266 prompt_for_events,
267 cx,
268 )
269 })?
270 .await?,
271 ))
272 }
273}
274
275pub async fn open_buffer(
276 project: &Entity<Project>,
277 worktree: &Entity<Worktree>,
278 path: &Path,
279 cx: &mut AsyncApp,
280) -> Result<Entity<Buffer>> {
281 let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
282 worktree_id: worktree.id(),
283 path: path.to_path_buf().into(),
284 })?;
285
286 project
287 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
288 .await
289}
290
291pub async fn open_buffer_with_language_server(
292 project: &Entity<Project>,
293 worktree: &Entity<Worktree>,
294 path: &Path,
295 cx: &mut AsyncApp,
296) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
297 let buffer = open_buffer(project, worktree, path, cx).await?;
298
299 let lsp_open_handle = project.update(cx, |project, cx| {
300 project.register_buffer_with_language_servers(&buffer, cx)
301 })?;
302
303 let log_prefix = path.to_string_lossy().to_string();
304 wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
305
306 Ok((lsp_open_handle, buffer))
307}
308
309// TODO: Dedupe with similar function in crates/eval/src/instance.rs
310pub fn wait_for_lang_server(
311 project: &Entity<Project>,
312 buffer: &Entity<Buffer>,
313 log_prefix: String,
314 cx: &mut AsyncApp,
315) -> Task<Result<()>> {
316 println!("{}⏵ Waiting for language server", log_prefix);
317
318 let (mut tx, mut rx) = mpsc::channel(1);
319
320 let lsp_store = project
321 .read_with(cx, |project, _| project.lsp_store())
322 .unwrap();
323
324 let has_lang_server = buffer
325 .update(cx, |buffer, cx| {
326 lsp_store.update(cx, |lsp_store, cx| {
327 lsp_store
328 .language_servers_for_local_buffer(buffer, cx)
329 .next()
330 .is_some()
331 })
332 })
333 .unwrap_or(false);
334
335 if has_lang_server {
336 project
337 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
338 .unwrap()
339 .detach();
340 }
341
342 let subscriptions = [
343 cx.subscribe(&lsp_store, {
344 let log_prefix = log_prefix.clone();
345 move |_, event, _| {
346 if let project::LspStoreEvent::LanguageServerUpdate {
347 message:
348 client::proto::update_language_server::Variant::WorkProgress(
349 client::proto::LspWorkProgress {
350 message: Some(message),
351 ..
352 },
353 ),
354 ..
355 } = event
356 {
357 println!("{}⟲ {message}", log_prefix)
358 }
359 }
360 }),
361 cx.subscribe(project, {
362 let buffer = buffer.clone();
363 move |project, event, cx| match event {
364 project::Event::LanguageServerAdded(_, _, _) => {
365 let buffer = buffer.clone();
366 project
367 .update(cx, |project, cx| project.save_buffer(buffer, cx))
368 .detach();
369 }
370 project::Event::DiskBasedDiagnosticsFinished { .. } => {
371 tx.try_send(()).ok();
372 }
373 _ => {}
374 }
375 }),
376 ];
377
378 cx.spawn(async move |cx| {
379 let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
380 let result = futures::select! {
381 _ = rx.next() => {
382 println!("{}⚑ Language server idle", log_prefix);
383 anyhow::Ok(())
384 },
385 _ = timeout.fuse() => {
386 anyhow::bail!("LSP wait timed out after 5 minutes");
387 }
388 };
389 drop(subscriptions);
390 result
391 })
392}
393
394fn main() {
395 let args = ZetaCliArgs::parse();
396 let http_client = Arc::new(ReqwestClient::new());
397 let app = Application::headless().with_http_client(http_client);
398
399 app.run(move |cx| {
400 let app_state = Arc::new(headless::init(cx));
401 let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
402 cx.spawn(async move |cx| {
403 let result = match args.command {
404 Commands::Zeta2Context {
405 zeta2_args,
406 context_args,
407 } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
408 Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
409 Ok(GetContextOutput::Zeta2(output)) => Ok(output),
410 Err(err) => Err(err),
411 },
412 Commands::Context(context_args) => {
413 match get_context(None, context_args, &app_state, cx).await {
414 Ok(GetContextOutput::Zeta1(output)) => {
415 Ok(serde_json::to_string_pretty(&output.body).unwrap())
416 }
417 Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
418 Err(err) => Err(err),
419 }
420 }
421 Commands::Predict {
422 predict_edits_body,
423 context_args,
424 } => {
425 cx.spawn(async move |cx| {
426 let app_version = cx.update(|cx| AppVersion::global(cx))?;
427 app_state.client.sign_in(true, cx).await?;
428 let llm_token = LlmApiToken::default();
429 llm_token.refresh(&app_state.client).await?;
430
431 let predict_edits_body =
432 if let Some(predict_edits_body) = predict_edits_body {
433 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
434 } else if let Some(context_args) = context_args {
435 match get_context(None, context_args, &app_state, cx).await? {
436 GetContextOutput::Zeta1(output) => output.body,
437 GetContextOutput::Zeta2 { .. } => unreachable!(),
438 }
439 } else {
440 return Err(anyhow!(
441 "Expected either --predict-edits-body-file \
442 or the required args of the `context` command."
443 ));
444 };
445
446 let (response, _usage) =
447 Zeta::perform_predict_edits(PerformPredictEditsParams {
448 client: app_state.client.clone(),
449 llm_token,
450 app_version,
451 body: predict_edits_body,
452 })
453 .await?;
454
455 Ok(response.output_excerpt)
456 })
457 .await
458 }
459 };
460 match result {
461 Ok(output) => {
462 println!("{}", output);
463 // TODO: Remove this once the 5 second delay is properly replaced.
464 if is_zeta2_context_command {
465 eprintln!("Note that zeta2-context doesn't yet wait for indexing, instead waits 5 seconds.");
466 }
467 let _ = cx.update(|cx| cx.quit());
468 }
469 Err(e) => {
470 eprintln!("Failed: {:?}", e);
471 exit(1);
472 }
473 }
474 })
475 .detach();
476 });
477}