1mod headless;
2
3use anyhow::{Result, anyhow};
4use clap::{Args, Parser, Subcommand};
5use edit_prediction_context::EditPredictionExcerptOptions;
6use futures::channel::mpsc;
7use futures::{FutureExt as _, StreamExt as _};
8use gpui::{AppContext, Application, AsyncApp};
9use gpui::{Entity, Task};
10use language::Bias;
11use language::Buffer;
12use language::Point;
13use language_model::LlmApiToken;
14use project::{Project, ProjectPath, Worktree};
15use release_channel::AppVersion;
16use reqwest_client::ReqwestClient;
17use std::path::{Path, PathBuf};
18use std::process::exit;
19use std::str::FromStr;
20use std::sync::Arc;
21use std::time::Duration;
22use zeta::{PerformPredictEditsParams, Zeta};
23
24use crate::headless::ZetaCliAppState;
25
26#[derive(Parser, Debug)]
27#[command(name = "zeta")]
28struct ZetaCliArgs {
29 #[command(subcommand)]
30 command: Commands,
31}
32
33#[derive(Subcommand, Debug)]
34enum Commands {
35 Context(ContextArgs),
36 Zeta2Context {
37 #[clap(flatten)]
38 zeta2_args: Zeta2Args,
39 #[clap(flatten)]
40 context_args: ContextArgs,
41 },
42 Predict {
43 #[arg(long)]
44 predict_edits_body: Option<FileOrStdin>,
45 #[clap(flatten)]
46 context_args: Option<ContextArgs>,
47 },
48}
49
50#[derive(Debug, Args)]
51#[group(requires = "worktree")]
52struct ContextArgs {
53 #[arg(long)]
54 worktree: PathBuf,
55 #[arg(long)]
56 cursor: CursorPosition,
57 #[arg(long)]
58 use_language_server: bool,
59 #[arg(long)]
60 events: Option<FileOrStdin>,
61}
62
63#[derive(Debug, Args)]
64struct Zeta2Args {
65 #[arg(long, default_value_t = 8192)]
66 prompt_max_bytes: usize,
67 #[arg(long, default_value_t = 2048)]
68 excerpt_max_bytes: usize,
69 #[arg(long, default_value_t = 1024)]
70 excerpt_min_bytes: usize,
71 #[arg(long, default_value_t = 0.66)]
72 target_before_cursor_over_total_bytes: f32,
73}
74
/// Where a textual command-line input should be read from.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read the contents of the file at this path.
    File(PathBuf),
    /// Read from standard input (spelled `-` on the command line).
    Stdin,
}
80
81impl FileOrStdin {
82 async fn read_to_string(&self) -> Result<String, std::io::Error> {
83 match self {
84 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
85 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
86 }
87 }
88}
89
90impl FromStr for FileOrStdin {
91 type Err = <PathBuf as FromStr>::Err;
92
93 fn from_str(s: &str) -> Result<Self, Self::Err> {
94 match s {
95 "-" => Ok(Self::Stdin),
96 _ => Ok(Self::File(PathBuf::from_str(s)?)),
97 }
98 }
99}
100
101#[derive(Debug, Clone)]
102struct CursorPosition {
103 path: PathBuf,
104 point: Point,
105}
106
107impl FromStr for CursorPosition {
108 type Err = anyhow::Error;
109
110 fn from_str(s: &str) -> Result<Self> {
111 let parts: Vec<&str> = s.split(':').collect();
112 if parts.len() != 3 {
113 return Err(anyhow!(
114 "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
115 s
116 ));
117 }
118
119 let path = PathBuf::from(parts[0]);
120 let line: u32 = parts[1]
121 .parse()
122 .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
123 let column: u32 = parts[2]
124 .parse()
125 .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
126
127 // Convert from 1-based to 0-based indexing
128 let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
129
130 Ok(CursorPosition { path, point })
131 }
132}
133
134enum GetContextOutput {
135 Zeta1(zeta::GatherContextOutput),
136 Zeta2(String),
137}
138
139async fn get_context(
140 zeta2_args: Option<Zeta2Args>,
141 args: ContextArgs,
142 app_state: &Arc<ZetaCliAppState>,
143 cx: &mut AsyncApp,
144) -> Result<GetContextOutput> {
145 let ContextArgs {
146 worktree: worktree_path,
147 cursor,
148 use_language_server,
149 events,
150 } = args;
151
152 let worktree_path = worktree_path.canonicalize()?;
153 if cursor.path.is_absolute() {
154 return Err(anyhow!("Absolute paths are not supported in --cursor"));
155 }
156
157 let project = cx.update(|cx| {
158 Project::local(
159 app_state.client.clone(),
160 app_state.node_runtime.clone(),
161 app_state.user_store.clone(),
162 app_state.languages.clone(),
163 app_state.fs.clone(),
164 None,
165 cx,
166 )
167 })?;
168
169 let worktree = project
170 .update(cx, |project, cx| {
171 project.create_worktree(&worktree_path, true, cx)
172 })?
173 .await?;
174
175 let (_lsp_open_handle, buffer) = if use_language_server {
176 let (lsp_open_handle, buffer) =
177 open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
178 (Some(lsp_open_handle), buffer)
179 } else {
180 let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
181 (None, buffer)
182 };
183
184 let worktree_name = worktree_path
185 .file_name()
186 .ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?;
187 let full_path_str = PathBuf::from(worktree_name)
188 .join(&cursor.path)
189 .to_string_lossy()
190 .to_string();
191
192 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
193 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
194 if clipped_cursor != cursor.point {
195 let max_row = snapshot.max_point().row;
196 if cursor.point.row < max_row {
197 return Err(anyhow!(
198 "Cursor position {:?} is out of bounds (line length is {})",
199 cursor.point,
200 snapshot.line_len(cursor.point.row)
201 ));
202 } else {
203 return Err(anyhow!(
204 "Cursor position {:?} is out of bounds (max row is {})",
205 cursor.point,
206 max_row
207 ));
208 }
209 }
210
211 let events = match events {
212 Some(events) => events.read_to_string().await?,
213 None => String::new(),
214 };
215
216 if let Some(zeta2_args) = zeta2_args {
217 Ok(GetContextOutput::Zeta2(
218 cx.update(|cx| {
219 let zeta = cx.new(|cx| {
220 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
221 });
222 zeta.update(cx, |zeta, cx| {
223 zeta.register_buffer(&buffer, &project, cx);
224 zeta.excerpt_options = EditPredictionExcerptOptions {
225 max_bytes: zeta2_args.excerpt_max_bytes,
226 min_bytes: zeta2_args.excerpt_min_bytes,
227 target_before_cursor_over_total_bytes: zeta2_args
228 .target_before_cursor_over_total_bytes,
229 }
230 });
231 // TODO: Actually wait for indexing.
232 let timer = cx.background_executor().timer(Duration::from_secs(5));
233 cx.spawn(async move |cx| {
234 timer.await;
235 let request = zeta
236 .update(cx, |zeta, cx| {
237 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
238 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
239 })?
240 .await?;
241 let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(
242 &request,
243 &cloud_zeta2_prompt::PlanOptions {
244 max_bytes: zeta2_args.prompt_max_bytes,
245 },
246 )?;
247 anyhow::Ok(planned_prompt.to_prompt_string())
248 })
249 })?
250 .await?,
251 ))
252 } else {
253 let prompt_for_events = move || (events, 0);
254 Ok(GetContextOutput::Zeta1(
255 cx.update(|cx| {
256 zeta::gather_context(
257 full_path_str,
258 &snapshot,
259 clipped_cursor,
260 prompt_for_events,
261 cx,
262 )
263 })?
264 .await?,
265 ))
266 }
267}
268
269pub async fn open_buffer(
270 project: &Entity<Project>,
271 worktree: &Entity<Worktree>,
272 path: &Path,
273 cx: &mut AsyncApp,
274) -> Result<Entity<Buffer>> {
275 let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
276 worktree_id: worktree.id(),
277 path: path.to_path_buf().into(),
278 })?;
279
280 project
281 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
282 .await
283}
284
285pub async fn open_buffer_with_language_server(
286 project: &Entity<Project>,
287 worktree: &Entity<Worktree>,
288 path: &Path,
289 cx: &mut AsyncApp,
290) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
291 let buffer = open_buffer(project, worktree, path, cx).await?;
292
293 let lsp_open_handle = project.update(cx, |project, cx| {
294 project.register_buffer_with_language_servers(&buffer, cx)
295 })?;
296
297 let log_prefix = path.to_string_lossy().to_string();
298 wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
299
300 Ok((lsp_open_handle, buffer))
301}
302
303// TODO: Dedupe with similar function in crates/eval/src/instance.rs
304pub fn wait_for_lang_server(
305 project: &Entity<Project>,
306 buffer: &Entity<Buffer>,
307 log_prefix: String,
308 cx: &mut AsyncApp,
309) -> Task<Result<()>> {
310 println!("{}⏵ Waiting for language server", log_prefix);
311
312 let (mut tx, mut rx) = mpsc::channel(1);
313
314 let lsp_store = project
315 .read_with(cx, |project, _| project.lsp_store())
316 .unwrap();
317
318 let has_lang_server = buffer
319 .update(cx, |buffer, cx| {
320 lsp_store.update(cx, |lsp_store, cx| {
321 lsp_store
322 .language_servers_for_local_buffer(buffer, cx)
323 .next()
324 .is_some()
325 })
326 })
327 .unwrap_or(false);
328
329 if has_lang_server {
330 project
331 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
332 .unwrap()
333 .detach();
334 }
335
336 let subscriptions = [
337 cx.subscribe(&lsp_store, {
338 let log_prefix = log_prefix.clone();
339 move |_, event, _| {
340 if let project::LspStoreEvent::LanguageServerUpdate {
341 message:
342 client::proto::update_language_server::Variant::WorkProgress(
343 client::proto::LspWorkProgress {
344 message: Some(message),
345 ..
346 },
347 ),
348 ..
349 } = event
350 {
351 println!("{}⟲ {message}", log_prefix)
352 }
353 }
354 }),
355 cx.subscribe(project, {
356 let buffer = buffer.clone();
357 move |project, event, cx| match event {
358 project::Event::LanguageServerAdded(_, _, _) => {
359 let buffer = buffer.clone();
360 project
361 .update(cx, |project, cx| project.save_buffer(buffer, cx))
362 .detach();
363 }
364 project::Event::DiskBasedDiagnosticsFinished { .. } => {
365 tx.try_send(()).ok();
366 }
367 _ => {}
368 }
369 }),
370 ];
371
372 cx.spawn(async move |cx| {
373 let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
374 let result = futures::select! {
375 _ = rx.next() => {
376 println!("{}⚑ Language server idle", log_prefix);
377 anyhow::Ok(())
378 },
379 _ = timeout.fuse() => {
380 anyhow::bail!("LSP wait timed out after 5 minutes");
381 }
382 };
383 drop(subscriptions);
384 result
385 })
386}
387
388fn main() {
389 let args = ZetaCliArgs::parse();
390 let http_client = Arc::new(ReqwestClient::new());
391 let app = Application::headless().with_http_client(http_client);
392
393 app.run(move |cx| {
394 let app_state = Arc::new(headless::init(cx));
395 let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
396 cx.spawn(async move |cx| {
397 let result = match args.command {
398 Commands::Zeta2Context {
399 zeta2_args,
400 context_args,
401 } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
402 Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
403 Ok(GetContextOutput::Zeta2(output)) => Ok(output),
404 Err(err) => Err(err),
405 },
406 Commands::Context(context_args) => {
407 match get_context(None, context_args, &app_state, cx).await {
408 Ok(GetContextOutput::Zeta1(output)) => {
409 Ok(serde_json::to_string_pretty(&output.body).unwrap())
410 }
411 Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
412 Err(err) => Err(err),
413 }
414 }
415 Commands::Predict {
416 predict_edits_body,
417 context_args,
418 } => {
419 cx.spawn(async move |cx| {
420 let app_version = cx.update(|cx| AppVersion::global(cx))?;
421 app_state.client.sign_in(true, cx).await?;
422 let llm_token = LlmApiToken::default();
423 llm_token.refresh(&app_state.client).await?;
424
425 let predict_edits_body =
426 if let Some(predict_edits_body) = predict_edits_body {
427 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
428 } else if let Some(context_args) = context_args {
429 match get_context(None, context_args, &app_state, cx).await? {
430 GetContextOutput::Zeta1(output) => output.body,
431 GetContextOutput::Zeta2 { .. } => unreachable!(),
432 }
433 } else {
434 return Err(anyhow!(
435 "Expected either --predict-edits-body-file \
436 or the required args of the `context` command."
437 ));
438 };
439
440 let (response, _usage) =
441 Zeta::perform_predict_edits(PerformPredictEditsParams {
442 client: app_state.client.clone(),
443 llm_token,
444 app_version,
445 body: predict_edits_body,
446 })
447 .await?;
448
449 Ok(response.output_excerpt)
450 })
451 .await
452 }
453 };
454 match result {
455 Ok(output) => {
456 println!("{}", output);
457 // TODO: Remove this once the 5 second delay is properly replaced.
458 if is_zeta2_context_command {
459 eprintln!("Note that zeta2-context doesn't yet wait for indexing, instead waits 5 seconds.");
460 }
461 let _ = cx.update(|cx| cx.quit());
462 }
463 Err(e) => {
464 eprintln!("Failed: {:?}", e);
465 exit(1);
466 }
467 }
468 })
469 .detach();
470 });
471}