main.rs

  1//! Headless CLI binary for running Zed's agent in evaluation/benchmark environments.
  2//!
  3//! Designed to work inside containerized environments (like Harbor/termbench) where:
  4//! - The repository is already checked out at the working directory
  5//! - The model API key is provided via environment variables
  6//! - Results are written to an output directory (default: `/logs/agent/`)
  7//!
  8//! ## Usage
  9//!
 10//! ```text
 11//! eval-cli --workdir /testbed --model anthropic/claude-sonnet-4-6-latest \
 12//!          --instruction "Fix the bug described in..." --timeout 600
 13//! ```
 14//!
 15//! ## Output
 16//!
 17//! Writes to `--output-dir` (default `/logs/agent/`):
 18//!   - `result.json`  — structured result with status, timing, and token usage
 19//!   - `thread.md`    — full conversation as markdown
 20//!   - `thread.json`  — raw thread state as JSON
 21//!
 22//! ## Exit codes
 23//!
 24//! | Code | Meaning |
 25//! |------|---------|
 26//! | 0    | Agent finished |
 27//! | 1    | Error (model/auth/runtime failure) |
 28//! | 2    | Timeout |
 29//! | 3    | Interrupted (SIGTERM/SIGINT) |
 30
 31mod headless;
 32
 33use std::path::PathBuf;
 34use std::process;
 35use std::rc::Rc;
 36use std::str::FromStr;
 37use std::sync::Arc;
 38use std::sync::atomic::{AtomicBool, Ordering};
 39use std::time::{Duration, Instant};
 40
 41use acp_thread::AgentConnection as _;
 42use agent::{NativeAgent, NativeAgentConnection, Templates, ThreadStore};
 43use agent_client_protocol as acp;
 44use anyhow::{Context, Result};
 45use clap::Parser;
 46use feature_flags::FeatureFlagAppExt as _;
 47
 48use futures::{FutureExt, select_biased};
 49use gpui::{AppContext as _, AsyncApp, Entity, UpdateGlobal};
 50use language_model::{LanguageModelRegistry, SelectedModel};
 51use project::Project;
 52use settings::SettingsStore;
 53
 54use crate::headless::AgentCliAppState;
 55
 56#[derive(Parser, Debug)]
 57#[command(
 58    name = "eval-cli",
 59    about = "Run Zed's agent headlessly in evaluation/benchmark environments"
 60)]
 61struct Args {
 62    /// Output current environment variables as JSON to stdout.
 63    /// Used internally by Zed's shell environment capture.
 64    #[arg(long, hide = true)]
 65    printenv: bool,
 66
 67    /// Path to the repository working directory. Defaults to the current directory.
 68    #[arg(long, default_value = ".")]
 69    workdir: PathBuf,
 70
 71    /// Instruction/prompt text. If omitted, read from --instruction-file or stdin.
 72    #[arg(long)]
 73    instruction: Option<String>,
 74
 75    /// Language model to use, in `provider/model` format.
 76    #[arg(long, default_value = "anthropic/claude-sonnet-4-6-latest")]
 77    model: String,
 78
 79    /// Maximum wall-clock time in seconds for the agent run.
 80    #[arg(long)]
 81    timeout: Option<u64>,
 82
 83    /// Directory for output artifacts (result.json, thread.md, thread.json).
 84    #[arg(long, default_value = "/logs/agent")]
 85    output_dir: PathBuf,
 86}
 87
 88enum AgentOutcome {
 89    Completed,
 90    Timeout { seconds: u64 },
 91    Interrupted,
 92}
 93
/// Structured summary of one run, written to `result.json` and echoed to stderr.
///
/// Keys are emitted in field declaration order; `Option` fields marked with
/// `skip_serializing_if` are omitted entirely when `None`.
#[derive(serde::Serialize)]
struct EvalResult {
    /// One of "completed", "timeout", "interrupted", or "error".
    status: String,
    /// Error message; only set when `status` is "error".
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
    /// Wall-clock seconds for the agent run (the timer starts after provider
    /// authentication has finished).
    duration_secs: f64,
    /// The `--timeout` value, if one was given.
    #[serde(skip_serializing_if = "Option::is_none")]
    timeout_secs: Option<u64>,
    /// The `--model` string as passed on the command line.
    model: String,
    /// Input tokens consumed; omitted when no usage was recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    input_tokens: Option<u64>,
    /// Output tokens produced; omitted when no usage was recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    output_tokens: Option<u64>,
    /// Prompt-cache write tokens; omitted when unavailable or zero.
    #[serde(skip_serializing_if = "Option::is_none")]
    cache_creation_input_tokens: Option<u64>,
    /// Prompt-cache read tokens; omitted when unavailable or zero.
    #[serde(skip_serializing_if = "Option::is_none")]
    cache_read_input_tokens: Option<u64>,
}
112
/// Process exit code: the agent finished its turn.
const EXIT_OK: i32 = 0;
/// Process exit code: setup, model, auth, or runtime failure.
const EXIT_ERROR: i32 = 1;
/// Process exit code: the `--timeout` limit was reached.
const EXIT_TIMEOUT: i32 = 2;
/// Process exit code: the run was cancelled after a termination signal.
const EXIT_INTERRUPTED: i32 = 3;

/// Raised by the signal handler installed in `main`; polled by the agent loop
/// so it can cancel the run and still write partial output.
static TERMINATED: AtomicBool = AtomicBool::new(false);
119
120fn main() {
121    let args = Args::parse();
122
123    if args.printenv {
124        util::shell_env::print_env();
125        return;
126    }
127
128    env_logger::init();
129
130    ctrlc::set_handler(|| {
131        TERMINATED.store(true, Ordering::SeqCst);
132    })
133    .expect("failed to set signal handler");
134
135    let instruction = read_instruction(&args).unwrap_or_else(|e| {
136        eprintln!("Error reading instruction: {e}");
137        process::exit(EXIT_ERROR);
138    });
139
140    let workdir = args.workdir.canonicalize().unwrap_or_else(|e| {
141        eprintln!("Invalid --workdir {:?}: {e}", args.workdir);
142        process::exit(EXIT_ERROR);
143    });
144
145    let output_dir = args.output_dir.clone();
146    if let Err(e) = std::fs::create_dir_all(&output_dir) {
147        eprintln!("Error creating output dir {}: {e}", output_dir.display());
148        process::exit(EXIT_ERROR);
149    }
150
151    let http_client = Arc::new(reqwest_client::ReqwestClient::new());
152    let app = gpui_platform::headless().with_http_client(http_client);
153
154    app.run(move |cx| {
155        let app_state = headless::init(cx);
156        cx.set_staff(true);
157
158        let auth_tasks = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
159            registry
160                .providers()
161                .iter()
162                .map(|p| p.authenticate(cx))
163                .collect::<Vec<_>>()
164        });
165
166        let model_name = args.model.clone();
167        let timeout = args.timeout;
168
169        cx.spawn(async move |cx| {
170            futures::future::join_all(auth_tasks).await;
171
172            let start = Instant::now();
173
174            let (outcome, token_usage) = run_agent(
175                &app_state,
176                &workdir,
177                &instruction,
178                &model_name,
179                timeout,
180                Some(&output_dir),
181                cx,
182            )
183            .await;
184
185            let duration = start.elapsed();
186
187            let (status, error, exit_code) = match &outcome {
188                Ok(AgentOutcome::Completed) => ("completed".to_string(), None, EXIT_OK),
189                Ok(AgentOutcome::Timeout { seconds }) => {
190                    eprintln!("Timeout: agent exceeded {seconds}s time limit");
191                    ("timeout".to_string(), None, EXIT_TIMEOUT)
192                }
193                Ok(AgentOutcome::Interrupted) => {
194                    eprintln!("Interrupted: received SIGTERM, saved partial output");
195                    ("interrupted".to_string(), None, EXIT_INTERRUPTED)
196                }
197                Err(e) => {
198                    eprintln!("Error: {e:#}");
199                    ("error".to_string(), Some(format!("{e:#}")), EXIT_ERROR)
200                }
201            };
202
203            let result = EvalResult {
204                status,
205                error,
206                duration_secs: duration.as_secs_f64(),
207                timeout_secs: timeout,
208                model: model_name.clone(),
209                input_tokens: token_usage.as_ref().map(|u| u.input_tokens),
210                output_tokens: token_usage.as_ref().map(|u| u.output_tokens),
211                cache_creation_input_tokens: token_usage
212                    .as_ref()
213                    .filter(|u| u.cache_creation_input_tokens > 0)
214                    .map(|u| u.cache_creation_input_tokens),
215                cache_read_input_tokens: token_usage
216                    .as_ref()
217                    .filter(|u| u.cache_read_input_tokens > 0)
218                    .map(|u| u.cache_read_input_tokens),
219            };
220
221            match serde_json::to_string_pretty(&result) {
222                Ok(json) => {
223                    if let Err(e) = std::fs::write(output_dir.join("result.json"), &json) {
224                        eprintln!("Error writing result.json: {e:#}");
225                    }
226                    eprintln!("[eval-cli] result: {json}");
227                }
228                Err(e) => eprintln!("Error serializing result: {e:#}"),
229            }
230
231            cx.update(|cx| cx.quit());
232            process::exit(exit_code);
233        })
234        .detach();
235    });
236}
237
238fn read_instruction(args: &Args) -> Result<String> {
239    let text = if let Some(text) = &args.instruction {
240        text.clone()
241    } else {
242        use std::io::Read;
243        let mut buf = String::new();
244        std::io::stdin()
245            .read_to_string(&mut buf)
246            .context("reading instruction from stdin")?;
247        buf
248    };
249    anyhow::ensure!(!text.trim().is_empty(), "instruction is empty");
250    Ok(text)
251}
252
253async fn run_agent(
254    app_state: &Arc<AgentCliAppState>,
255    workdir: &std::path::Path,
256    instruction: &str,
257    model_name: &str,
258    timeout: Option<u64>,
259    output_dir: Option<&std::path::Path>,
260    cx: &mut AsyncApp,
261) -> (Result<AgentOutcome>, Option<language_model::TokenUsage>) {
262    let setup_result: Result<()> = cx.update(|cx| {
263        let selected = SelectedModel::from_str(model_name).map_err(|e| anyhow::anyhow!("{e}"))?;
264        let registry = LanguageModelRegistry::global(cx);
265        let model = registry
266            .read(cx)
267            .available_models(cx)
268            .find(|m| m.id() == selected.model && m.provider_id() == selected.provider)
269            .ok_or_else(|| {
270                let available = registry
271                    .read(cx)
272                    .available_models(cx)
273                    .map(|m| format!("{}/{}", m.provider_id().0, m.id().0))
274                    .collect::<Vec<_>>()
275                    .join(", ");
276                anyhow::anyhow!("Model {model_name} not found. Available: {available}")
277            })?;
278
279        let supports_thinking = model.supports_thinking();
280
281        registry.update(cx, |registry, cx| {
282            registry.set_default_model(
283                Some(language_model::ConfiguredModel {
284                    provider: registry
285                        .provider(&model.provider_id())
286                        .context("Provider not found")?,
287                    model,
288                }),
289                cx,
290            );
291            anyhow::Ok(())
292        })?;
293
294        let (enable_thinking, effort) = if supports_thinking {
295            (true, "\"high\"")
296        } else {
297            (false, "null")
298        };
299        let provider_id = selected.provider.0.to_string();
300        let model_id = selected.model.0.to_string();
301        SettingsStore::update_global(cx, |store, cx| {
302            let settings = format!(
303                r#"{{
304                    "agent": {{
305                        "tool_permissions": {{"default": "allow"}},
306                        "default_model": {{
307                            "provider": "{provider_id}",
308                            "model": "{model_id}",
309                            "enable_thinking": {enable_thinking},
310                            "effort": {effort}
311                        }}
312                    }},
313                    "autosave": "off",
314                    "format_on_save": "off"
315                }}"
316                "#
317            );
318            store.set_user_settings(&settings, cx).ok();
319        });
320
321        anyhow::Ok(())
322    });
323
324    if let Err(e) = setup_result {
325        return (Err(e), None);
326    }
327
328    let project = cx.update(|cx| {
329        Project::local(
330            app_state.client.clone(),
331            app_state.node_runtime.clone(),
332            app_state.user_store.clone(),
333            app_state.languages.clone(),
334            app_state.fs.clone(),
335            None,
336            project::LocalProjectFlags {
337                init_worktree_trust: false,
338                ..Default::default()
339            },
340            cx,
341        )
342    });
343
344    let worktree = project.update(cx, |project, cx| project.create_worktree(workdir, true, cx));
345    let worktree = match worktree.await {
346        Ok(w) => w,
347        Err(e) => return (Err(e).context("creating worktree"), None),
348    };
349
350    let scan_result = worktree.update(cx, |tree, _cx| {
351        tree.as_local()
352            .context("expected local worktree")
353            .map(|local| local.scan_complete())
354    });
355    match scan_result {
356        Ok(future) => future.await,
357        Err(e) => return (Err(e), None),
358    };
359
360    let thread_store = cx.new(|cx| ThreadStore::new(cx));
361    let agent = match NativeAgent::new(
362        project.clone(),
363        thread_store,
364        Templates::new(),
365        None,
366        app_state.fs.clone(),
367        cx,
368    )
369    .await
370    {
371        Ok(a) => a,
372        Err(e) => return (Err(e).context("creating agent"), None),
373    };
374
375    let connection = Rc::new(NativeAgentConnection(agent.clone()));
376    let acp_thread = match cx
377        .update(|cx| connection.clone().new_session(project, workdir, cx))
378        .await
379    {
380        Ok(t) => t,
381        Err(e) => return (Err(e).context("creating ACP session"), None),
382    };
383
384    let _subscription = cx.subscribe(&acp_thread, |acp_thread, event, cx| {
385        log_acp_thread_event(&acp_thread, event, cx);
386    });
387
388    let message = vec![acp::ContentBlock::Text(acp::TextContent::new(
389        instruction.to_string(),
390    ))];
391
392    let send_future = acp_thread.update(cx, |acp_thread: &mut acp_thread::AcpThread, cx| {
393        acp_thread.send(message, cx)
394    });
395
396    let timeout_future = if let Some(timeout_secs) = timeout {
397        futures::future::Either::Left(
398            cx.background_executor()
399                .timer(Duration::from_secs(timeout_secs)),
400        )
401    } else {
402        futures::future::Either::Right(futures::future::pending::<()>())
403    };
404
405    let sigterm_future = {
406        let executor = cx.background_executor().clone();
407        async move {
408            while !TERMINATED.load(Ordering::Relaxed) {
409                executor.timer(Duration::from_millis(100)).await;
410            }
411        }
412    };
413
414    let outcome = select_biased! {
415        result = send_future.fuse() => match result {
416            Ok(Some(response)) => {
417                eprintln!("[eval-cli] stopped: {:?}", response.stop_reason);
418                if response.stop_reason == acp::StopReason::MaxTokens {
419                    Err(anyhow::anyhow!("Model hit maximum token limit"))
420                } else {
421                    Ok(AgentOutcome::Completed)
422                }
423            }
424            Ok(None) => {
425                eprintln!("[eval-cli] completed (no response)");
426                Ok(AgentOutcome::Completed)
427            }
428            Err(e) => Err(e).context("agent run failed"),
429        },
430        _ = sigterm_future.fuse() => {
431            eprintln!("[eval-cli] received SIGTERM, cancelling...");
432            acp_thread.update(cx, |t: &mut acp_thread::AcpThread, cx| t.cancel(cx)).await;
433            Ok(AgentOutcome::Interrupted)
434        },
435        _ = timeout_future.fuse() => {
436            acp_thread.update(cx, |t: &mut acp_thread::AcpThread, cx| t.cancel(cx)).await;
437            Ok(AgentOutcome::Timeout { seconds: timeout.unwrap_or(0) })
438        }
439    };
440
441    let thread = cx.update(|cx| {
442        let session_id = acp_thread.read(cx).session_id().clone();
443        connection.thread(&session_id, cx)
444    });
445
446    let cumulative_usage = if let Some(thread) = &thread {
447        let db_thread = thread.read_with(cx, |thread, cx| thread.to_db(cx));
448        let db_thread = db_thread.await;
449        let usage = db_thread.cumulative_token_usage;
450        if usage.input_tokens > 0 || usage.output_tokens > 0 {
451            Some(usage)
452        } else {
453            None
454        }
455    } else {
456        None
457    };
458
459    let acp_usage = cx.update(|cx| {
460        acp_thread
461            .read(cx)
462            .token_usage()
463            .map(|usage| language_model::TokenUsage {
464                input_tokens: usage.input_tokens,
465                output_tokens: usage.output_tokens,
466                ..Default::default()
467            })
468    });
469
470    let final_usage = cumulative_usage.or(acp_usage);
471
472    if let (Some(thread), Some(dir)) = (&thread, output_dir) {
473        let markdown = thread.read_with(cx, |thread, _cx| thread.to_markdown());
474        if let Err(e) = std::fs::write(dir.join("thread.md"), markdown) {
475            eprintln!("Error writing thread.md: {e:#}");
476        }
477
478        let db_thread = thread.read_with(cx, |thread, cx| thread.to_db(cx));
479        let db_thread = db_thread.await;
480        match serde_json::to_string_pretty(&db_thread) {
481            Ok(json) => {
482                if let Err(e) = std::fs::write(dir.join("thread.json"), json) {
483                    eprintln!("Error writing thread.json: {e:#}");
484                }
485            }
486            Err(e) => eprintln!("Error serializing thread.json: {e:#}"),
487        }
488    }
489
490    (outcome, final_usage)
491}
492
493fn log_acp_thread_event(
494    acp_thread: &Entity<acp_thread::AcpThread>,
495    event: &acp_thread::AcpThreadEvent,
496    cx: &mut gpui::App,
497) {
498    match event {
499        acp_thread::AcpThreadEvent::NewEntry => {
500            let entries = acp_thread.read(cx).entries();
501            if let Some(acp_thread::AgentThreadEntry::AssistantMessage(message)) = entries.last() {
502                for chunk in &message.chunks {
503                    if let acp_thread::AssistantMessageChunk::Message { block } = chunk {
504                        if let acp_thread::ContentBlock::Markdown { markdown } = block {
505                            let text = markdown.read(cx).source().to_string();
506                            if !text.is_empty() {
507                                eprint!("{text}");
508                            }
509                        }
510                    }
511                }
512            }
513        }
514        acp_thread::AcpThreadEvent::EntryUpdated(index) => {
515            let entries = acp_thread.read(cx).entries();
516            if let Some(acp_thread::AgentThreadEntry::ToolCall(tool_call)) = entries.get(*index) {
517                if let Some(name) = &tool_call.tool_name {
518                    match &tool_call.status {
519                        acp_thread::ToolCallStatus::Completed => {
520                            eprintln!("[tool] {name}");
521                        }
522                        acp_thread::ToolCallStatus::Failed => {
523                            eprintln!("[tool] {name}");
524                        }
525                        acp_thread::ToolCallStatus::Rejected => {
526                            eprintln!("[tool] {name} rejected");
527                        }
528                        acp_thread::ToolCallStatus::Canceled => {
529                            eprintln!("[tool] {name} canceled");
530                        }
531                        _ => {}
532                    }
533                }
534            }
535        }
536        acp_thread::AcpThreadEvent::Stopped(reason) => {
537            eprintln!("\n[eval-cli] stopped: {reason:?}");
538        }
539        acp_thread::AcpThreadEvent::Error => {
540            eprintln!("[eval-cli] error event");
541        }
542        acp_thread::AcpThreadEvent::Retry(status) => {
543            eprintln!("[eval-cli] retry: {status:?}");
544        }
545        acp_thread::AcpThreadEvent::SubagentSpawned(session_id) => {
546            eprintln!("[eval-cli] subagent spawned: {session_id}");
547        }
548        _ => {}
549    }
550}