1mod headless;
2
3use anyhow::{Result, anyhow};
4use clap::{Args, Parser, Subcommand};
5use futures::channel::mpsc;
6use futures::{FutureExt as _, StreamExt as _};
7use gpui::{AppContext, Application, AsyncApp};
8use gpui::{Entity, Task};
9use language::Bias;
10use language::Buffer;
11use language::Point;
12use language_model::LlmApiToken;
13use project::{Project, ProjectPath};
14use release_channel::AppVersion;
15use reqwest_client::ReqwestClient;
16use std::path::{Path, PathBuf};
17use std::process::exit;
18use std::str::FromStr;
19use std::sync::Arc;
20use std::time::Duration;
21use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context};
22
23use crate::headless::ZetaCliAppState;
24
// Top-level command-line arguments for the `zeta` CLI.
// NOTE: plain `//` comments are used on the clap-derived types in this file because clap's
// derive turns `///` doc comments into `--help` text.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand to run; dispatched in `main`.
    #[command(subcommand)]
    command: Commands,
}
31
// Subcommands supported by the CLI.
#[derive(Subcommand, Debug)]
enum Commands {
    // Gather edit-prediction context for a cursor position and print the request body as
    // pretty-printed JSON.
    Context(ContextArgs),
    // Sign in, send a prediction request to Zeta, and print the returned `output_excerpt`.
    Predict {
        // JSON request body to send as-is, read from a file or from stdin when given `-`.
        // When omitted, the body is gathered from `context_args` instead.
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        // The same arguments as the `context` subcommand; used only when no
        // `--predict-edits-body` is supplied.
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
}
42
// Arguments locating the file and cursor to gather edit-prediction context for.
// NOTE(review): `requires = "worktree"` appears intended to make `--worktree` mandatory whenever
// any argument of this group is present (relevant when flattened as `Option<ContextArgs>`
// under `predict`) — confirm against clap's `ArgGroup::requires` semantics.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root folder of the worktree. It is canonicalized, and its final component becomes the
    // first component of the path passed to `gather_context`.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor as `path:line:column` (1-based line/column); `path` must be relative to
    // `--worktree`.
    #[arg(long)]
    cursor: CursorPosition,
    // Open the file through a local project and wait for its language server instead of
    // reading it straight from disk.
    #[arg(long)]
    use_language_server: bool,
    // Optional events text, read from a file or from stdin when given `-`; an empty string is
    // used when omitted.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
55
56#[derive(Debug, Clone)]
57enum FileOrStdin {
58 File(PathBuf),
59 Stdin,
60}
61
62impl FileOrStdin {
63 async fn read_to_string(&self) -> Result<String, std::io::Error> {
64 match self {
65 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
66 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
67 }
68 }
69}
70
71impl FromStr for FileOrStdin {
72 type Err = <PathBuf as FromStr>::Err;
73
74 fn from_str(s: &str) -> Result<Self, Self::Err> {
75 match s {
76 "-" => Ok(Self::Stdin),
77 _ => Ok(Self::File(PathBuf::from_str(s)?)),
78 }
79 }
80}
81
82#[derive(Debug, Clone)]
83struct CursorPosition {
84 path: PathBuf,
85 point: Point,
86}
87
88impl FromStr for CursorPosition {
89 type Err = anyhow::Error;
90
91 fn from_str(s: &str) -> Result<Self> {
92 let parts: Vec<&str> = s.split(':').collect();
93 if parts.len() != 3 {
94 return Err(anyhow!(
95 "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
96 s
97 ));
98 }
99
100 let path = PathBuf::from(parts[0]);
101 let line: u32 = parts[1]
102 .parse()
103 .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
104 let column: u32 = parts[2]
105 .parse()
106 .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
107
108 // Convert from 1-based to 0-based indexing
109 let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
110
111 Ok(CursorPosition { path, point })
112 }
113}
114
115async fn get_context(
116 args: ContextArgs,
117 app_state: &Arc<ZetaCliAppState>,
118 cx: &mut AsyncApp,
119) -> Result<GatherContextOutput> {
120 let ContextArgs {
121 worktree: worktree_path,
122 cursor,
123 use_language_server,
124 events,
125 } = args;
126
127 let worktree_path = worktree_path.canonicalize()?;
128 if cursor.path.is_absolute() {
129 return Err(anyhow!("Absolute paths are not supported in --cursor"));
130 }
131
132 let (project, _lsp_open_handle, buffer) = if use_language_server {
133 let (project, lsp_open_handle, buffer) =
134 open_buffer_with_language_server(&worktree_path, &cursor.path, app_state, cx).await?;
135 (Some(project), Some(lsp_open_handle), buffer)
136 } else {
137 let abs_path = worktree_path.join(&cursor.path);
138 let content = smol::fs::read_to_string(&abs_path).await?;
139 let buffer = cx.new(|cx| Buffer::local(content, cx))?;
140 (None, None, buffer)
141 };
142
143 let worktree_name = worktree_path
144 .file_name()
145 .ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?;
146 let full_path_str = PathBuf::from(worktree_name)
147 .join(&cursor.path)
148 .to_string_lossy()
149 .to_string();
150
151 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
152 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
153 if clipped_cursor != cursor.point {
154 let max_row = snapshot.max_point().row;
155 if cursor.point.row < max_row {
156 return Err(anyhow!(
157 "Cursor position {:?} is out of bounds (line length is {})",
158 cursor.point,
159 snapshot.line_len(cursor.point.row)
160 ));
161 } else {
162 return Err(anyhow!(
163 "Cursor position {:?} is out of bounds (max row is {})",
164 cursor.point,
165 max_row
166 ));
167 }
168 }
169
170 let events = match events {
171 Some(events) => events.read_to_string().await?,
172 None => String::new(),
173 };
174 // Enable gathering extra data not currently needed for edit predictions
175 let can_collect_data = true;
176 let git_info = None;
177 let recent_files = None;
178 let mut gather_context_output = cx
179 .update(|cx| {
180 gather_context(
181 project.as_ref(),
182 full_path_str,
183 &snapshot,
184 clipped_cursor,
185 move || events,
186 can_collect_data,
187 git_info,
188 recent_files,
189 cx,
190 )
191 })?
192 .await;
193
194 // Disable data collection for these requests, as this is currently just used for evals
195 if let Ok(gather_context_output) = gather_context_output.as_mut() {
196 gather_context_output.body.can_collect_data = false
197 }
198
199 gather_context_output
200}
201
202pub async fn open_buffer_with_language_server(
203 worktree_path: &Path,
204 path: &Path,
205 app_state: &Arc<ZetaCliAppState>,
206 cx: &mut AsyncApp,
207) -> Result<(Entity<Project>, Entity<Entity<Buffer>>, Entity<Buffer>)> {
208 let project = cx.update(|cx| {
209 Project::local(
210 app_state.client.clone(),
211 app_state.node_runtime.clone(),
212 app_state.user_store.clone(),
213 app_state.languages.clone(),
214 app_state.fs.clone(),
215 None,
216 cx,
217 )
218 })?;
219
220 let worktree = project
221 .update(cx, |project, cx| {
222 project.create_worktree(worktree_path, true, cx)
223 })?
224 .await?;
225
226 let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
227 worktree_id: worktree.id(),
228 path: path.to_path_buf().into(),
229 })?;
230
231 let buffer = project
232 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
233 .await?;
234
235 let lsp_open_handle = project.update(cx, |project, cx| {
236 project.register_buffer_with_language_servers(&buffer, cx)
237 })?;
238
239 let log_prefix = path.to_string_lossy().to_string();
240 wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
241
242 Ok((project, lsp_open_handle, buffer))
243}
244
245// TODO: Dedupe with similar function in crates/eval/src/instance.rs
246pub fn wait_for_lang_server(
247 project: &Entity<Project>,
248 buffer: &Entity<Buffer>,
249 log_prefix: String,
250 cx: &mut AsyncApp,
251) -> Task<Result<()>> {
252 println!("{}⏵ Waiting for language server", log_prefix);
253
254 let (mut tx, mut rx) = mpsc::channel(1);
255
256 let lsp_store = project
257 .read_with(cx, |project, _| project.lsp_store())
258 .unwrap();
259
260 let has_lang_server = buffer
261 .update(cx, |buffer, cx| {
262 lsp_store.update(cx, |lsp_store, cx| {
263 lsp_store
264 .language_servers_for_local_buffer(buffer, cx)
265 .next()
266 .is_some()
267 })
268 })
269 .unwrap_or(false);
270
271 if has_lang_server {
272 project
273 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
274 .unwrap()
275 .detach();
276 }
277
278 let subscriptions = [
279 cx.subscribe(&lsp_store, {
280 let log_prefix = log_prefix.clone();
281 move |_, event, _| {
282 if let project::LspStoreEvent::LanguageServerUpdate {
283 message:
284 client::proto::update_language_server::Variant::WorkProgress(
285 client::proto::LspWorkProgress {
286 message: Some(message),
287 ..
288 },
289 ),
290 ..
291 } = event
292 {
293 println!("{}⟲ {message}", log_prefix)
294 }
295 }
296 }),
297 cx.subscribe(project, {
298 let buffer = buffer.clone();
299 move |project, event, cx| match event {
300 project::Event::LanguageServerAdded(_, _, _) => {
301 let buffer = buffer.clone();
302 project
303 .update(cx, |project, cx| project.save_buffer(buffer, cx))
304 .detach();
305 }
306 project::Event::DiskBasedDiagnosticsFinished { .. } => {
307 tx.try_send(()).ok();
308 }
309 _ => {}
310 }
311 }),
312 ];
313
314 cx.spawn(async move |cx| {
315 let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
316 let result = futures::select! {
317 _ = rx.next() => {
318 println!("{}⚑ Language server idle", log_prefix);
319 anyhow::Ok(())
320 },
321 _ = timeout.fuse() => {
322 anyhow::bail!("LSP wait timed out after 5 minutes");
323 }
324 };
325 drop(subscriptions);
326 result
327 })
328}
329
330fn main() {
331 let args = ZetaCliArgs::parse();
332 let http_client = Arc::new(ReqwestClient::new());
333 let app = Application::headless().with_http_client(http_client);
334
335 app.run(move |cx| {
336 let app_state = Arc::new(headless::init(cx));
337 cx.spawn(async move |cx| {
338 let result = match args.command {
339 Commands::Context(context_args) => get_context(context_args, &app_state, cx)
340 .await
341 .map(|output| serde_json::to_string_pretty(&output.body).unwrap()),
342 Commands::Predict {
343 predict_edits_body,
344 context_args,
345 } => {
346 cx.spawn(async move |cx| {
347 let app_version = cx.update(|cx| AppVersion::global(cx))?;
348 app_state.client.sign_in(true, cx).await?;
349 let llm_token = LlmApiToken::default();
350 llm_token.refresh(&app_state.client).await?;
351
352 let predict_edits_body =
353 if let Some(predict_edits_body) = predict_edits_body {
354 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
355 } else if let Some(context_args) = context_args {
356 get_context(context_args, &app_state, cx).await?.body
357 } else {
358 return Err(anyhow!(
359 "Expected either --predict-edits-body-file \
360 or the required args of the `context` command."
361 ));
362 };
363
364 let (response, _usage) =
365 Zeta::perform_predict_edits(PerformPredictEditsParams {
366 client: app_state.client.clone(),
367 llm_token,
368 app_version,
369 body: predict_edits_body,
370 })
371 .await?;
372
373 Ok(response.output_excerpt)
374 })
375 .await
376 }
377 };
378 match result {
379 Ok(output) => {
380 println!("{}", output);
381 let _ = cx.update(|cx| cx.quit());
382 }
383 Err(e) => {
384 eprintln!("Failed: {:?}", e);
385 exit(1);
386 }
387 }
388 })
389 .detach();
390 });
391}