main.rs

  1//! Headless CLI binary for running Zed's agent in evaluation/benchmark environments.
  2//!
  3//! Designed to work inside containerized environments (like Harbor/termbench) where:
  4//! - The repository is already checked out at the working directory
  5//! - The model API key is provided via environment variables
//! - Results are written to the `--output-dir` directory (default: the current directory)
  7//!
  8//! ## Usage
  9//!
 10//! ```text
 11//! eval-cli --workdir /testbed --model anthropic/claude-sonnet-4-6-latest \
 12//!          --instruction "Fix the bug described in..." --timeout 600
 13//! ```
 14//!
 15//! ## Output
 16//!
//! Writes to `--output-dir` (default: the current directory):
 18//!   - `result.json`  — structured result with status, timing, and token usage
 19//!   - `thread.md`    — full conversation as markdown
 20//!   - `thread.json`  — raw thread state as JSON
 21//!
 22//! ## Exit codes
 23//!
 24//! | Code | Meaning |
 25//! |------|---------|
 26//! | 0    | Agent finished |
 27//! | 1    | Error (model/auth/runtime failure) |
 28//! | 2    | Timeout |
 29//! | 3    | Interrupted (SIGTERM/SIGINT) |
 30
 31mod headless;
 32
 33use std::path::PathBuf;
 34use std::process;
 35use std::rc::Rc;
 36use std::str::FromStr;
 37use std::sync::Arc;
 38use std::sync::atomic::{AtomicBool, Ordering};
 39use std::time::{Duration, Instant};
 40
 41use acp_thread::AgentConnection as _;
 42use agent::{NativeAgent, NativeAgentConnection, Templates, ThreadStore};
 43use agent_client_protocol as acp;
 44use anyhow::{Context, Result};
 45use clap::Parser;
 46use feature_flags::FeatureFlagAppExt as _;
 47
 48use futures::{FutureExt, select_biased};
 49use gpui::{AppContext as _, AsyncApp, Entity, UpdateGlobal};
 50use language_model::{LanguageModelRegistry, SelectedModel};
 51use project::Project;
 52use settings::SettingsStore;
 53use util::path_list::PathList;
 54
 55use crate::headless::AgentCliAppState;
 56
 57#[derive(Parser, Debug)]
 58#[command(
 59    name = "eval-cli",
 60    about = "Run Zed's agent headlessly in evaluation/benchmark environments"
 61)]
 62struct Args {
 63    /// Output current environment variables as JSON to stdout.
 64    /// Used internally by Zed's shell environment capture.
 65    #[arg(long, hide = true)]
 66    printenv: bool,
 67
 68    /// Path to the repository working directory. Defaults to the current directory.
 69    #[arg(long, default_value = ".")]
 70    workdir: PathBuf,
 71
 72    /// Instruction/prompt text. If omitted, read from --instruction-file or stdin.
 73    #[arg(long)]
 74    instruction: Option<String>,
 75
 76    /// Language model to use, in `provider/model` format.
 77    #[arg(long, default_value = "anthropic/claude-sonnet-4-6-latest")]
 78    model: String,
 79
 80    /// Maximum wall-clock time in seconds for the agent run.
 81    #[arg(long)]
 82    timeout: Option<u64>,
 83
 84    /// Directory for output artifacts (result.json, thread.md, thread.json).
 85    #[arg(long, default_value = ".")]
 86    output_dir: PathBuf,
 87
 88    /// Disable staff mode (staff mode is enabled by default).
 89    #[arg(long)]
 90    no_staff: bool,
 91
 92    /// Reasoning effort level for models that support thinking (low, medium, high).
 93    /// Defaults to "high" for thinking-capable models.
 94    #[arg(long)]
 95    reasoning_effort: Option<String>,
 96
 97    /// Enable or disable extended thinking. Defaults to model auto-detection if omitted.
 98    #[arg(long)]
 99    thinking: Option<bool>,
100}
101
/// How an agent run ended when no setup or runtime error occurred.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AgentOutcome {
    /// The agent stopped on its own (any stop reason other than the max-token limit).
    Completed,
    /// The `--timeout` limit elapsed and the thread was cancelled.
    Timeout {
        /// The configured limit, in seconds.
        seconds: u64,
    },
    /// A termination signal was received and the thread was cancelled.
    Interrupted,
}
107
108#[derive(serde::Serialize)]
109struct EvalResult {
110    status: String,
111    #[serde(skip_serializing_if = "Option::is_none")]
112    error: Option<String>,
113    duration_secs: f64,
114    #[serde(skip_serializing_if = "Option::is_none")]
115    timeout_secs: Option<u64>,
116    model: String,
117    #[serde(skip_serializing_if = "Option::is_none")]
118    input_tokens: Option<u64>,
119    #[serde(skip_serializing_if = "Option::is_none")]
120    output_tokens: Option<u64>,
121    #[serde(skip_serializing_if = "Option::is_none")]
122    cache_creation_input_tokens: Option<u64>,
123    #[serde(skip_serializing_if = "Option::is_none")]
124    cache_read_input_tokens: Option<u64>,
125}
126
/// Exit code: the agent ran to completion.
const EXIT_OK: i32 = 0;
/// Exit code: setup, model, authentication, or runtime failure.
const EXIT_ERROR: i32 = 1;
/// Exit code: the `--timeout` limit was reached before the agent finished.
const EXIT_TIMEOUT: i32 = 2;
/// Exit code: a termination signal arrived and the run was cancelled.
const EXIT_INTERRUPTED: i32 = 3;

/// Raised by the signal handler that `main` installs through `ctrlc::set_handler`;
/// `run_agent` polls it to decide when to cancel the running thread.
static TERMINATED: AtomicBool = AtomicBool::new(false);
133
134fn main() {
135    let args = Args::parse();
136
137    if args.printenv {
138        util::shell_env::print_env();
139        return;
140    }
141
142    env_logger::init();
143
144    ctrlc::set_handler(|| {
145        TERMINATED.store(true, Ordering::SeqCst);
146    })
147    .expect("failed to set signal handler");
148
149    let instruction = read_instruction(&args).unwrap_or_else(|e| {
150        eprintln!("Error reading instruction: {e}");
151        process::exit(EXIT_ERROR);
152    });
153
154    let workdir = args.workdir.canonicalize().unwrap_or_else(|e| {
155        eprintln!("Invalid --workdir {:?}: {e}", args.workdir);
156        process::exit(EXIT_ERROR);
157    });
158
159    let output_dir = args.output_dir.clone();
160    if let Err(e) = std::fs::create_dir_all(&output_dir) {
161        eprintln!("Error creating output dir {}: {e}", output_dir.display());
162        process::exit(EXIT_ERROR);
163    }
164
165    let http_client = Arc::new(reqwest_client::ReqwestClient::new());
166    let app = gpui_platform::headless().with_http_client(http_client);
167
168    app.run(move |cx| {
169        let app_state = headless::init(cx);
170        cx.set_staff(!args.no_staff);
171
172        let auth_tasks = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
173            registry
174                .providers()
175                .iter()
176                .map(|p| p.authenticate(cx))
177                .collect::<Vec<_>>()
178        });
179
180        let model_name = args.model.clone();
181        let timeout = args.timeout;
182        let thinking_override = args.thinking;
183        let reasoning_effort = args.reasoning_effort.clone();
184
185        cx.spawn(async move |cx| {
186            futures::future::join_all(auth_tasks).await;
187
188            let start = Instant::now();
189
190            let (outcome, token_usage) = run_agent(
191                &app_state,
192                &workdir,
193                &instruction,
194                &model_name,
195                timeout,
196                thinking_override,
197                reasoning_effort.as_deref(),
198                Some(&output_dir),
199                cx,
200            )
201            .await;
202
203            let duration = start.elapsed();
204
205            let (status, error, exit_code) = match &outcome {
206                Ok(AgentOutcome::Completed) => ("completed".to_string(), None, EXIT_OK),
207                Ok(AgentOutcome::Timeout { seconds }) => {
208                    eprintln!("Timeout: agent exceeded {seconds}s time limit");
209                    ("timeout".to_string(), None, EXIT_TIMEOUT)
210                }
211                Ok(AgentOutcome::Interrupted) => {
212                    eprintln!("Interrupted: received SIGTERM, saved partial output");
213                    ("interrupted".to_string(), None, EXIT_INTERRUPTED)
214                }
215                Err(e) => {
216                    eprintln!("Error: {e:#}");
217                    ("error".to_string(), Some(format!("{e:#}")), EXIT_ERROR)
218                }
219            };
220
221            let result = EvalResult {
222                status,
223                error,
224                duration_secs: duration.as_secs_f64(),
225                timeout_secs: timeout,
226                model: model_name.clone(),
227                input_tokens: token_usage.as_ref().map(|u| u.input_tokens),
228                output_tokens: token_usage.as_ref().map(|u| u.output_tokens),
229                cache_creation_input_tokens: token_usage
230                    .as_ref()
231                    .filter(|u| u.cache_creation_input_tokens > 0)
232                    .map(|u| u.cache_creation_input_tokens),
233                cache_read_input_tokens: token_usage
234                    .as_ref()
235                    .filter(|u| u.cache_read_input_tokens > 0)
236                    .map(|u| u.cache_read_input_tokens),
237            };
238
239            match serde_json::to_string_pretty(&result) {
240                Ok(json) => {
241                    if let Err(e) = std::fs::write(output_dir.join("result.json"), &json) {
242                        eprintln!("Error writing result.json: {e:#}");
243                    }
244                    eprintln!("[eval-cli] result: {json}");
245                }
246                Err(e) => eprintln!("Error serializing result: {e:#}"),
247            }
248
249            cx.update(|cx| cx.quit());
250            process::exit(exit_code);
251        })
252        .detach();
253    });
254}
255
256fn read_instruction(args: &Args) -> Result<String> {
257    let text = if let Some(text) = &args.instruction {
258        text.clone()
259    } else {
260        use std::io::Read;
261        let mut buf = String::new();
262        std::io::stdin()
263            .read_to_string(&mut buf)
264            .context("reading instruction from stdin")?;
265        buf
266    };
267    anyhow::ensure!(!text.trim().is_empty(), "instruction is empty");
268    Ok(text)
269}
270
271async fn run_agent(
272    app_state: &Arc<AgentCliAppState>,
273    workdir: &std::path::Path,
274    instruction: &str,
275    model_name: &str,
276    timeout: Option<u64>,
277    thinking_override: Option<bool>,
278    reasoning_effort: Option<&str>,
279    output_dir: Option<&std::path::Path>,
280    cx: &mut AsyncApp,
281) -> (Result<AgentOutcome>, Option<language_model::TokenUsage>) {
282    let setup_result: Result<()> = cx.update(|cx| {
283        let selected = SelectedModel::from_str(model_name).map_err(|e| anyhow::anyhow!("{e}"))?;
284        let registry = LanguageModelRegistry::global(cx);
285        let model = registry
286            .read(cx)
287            .available_models(cx)
288            .find(|m| m.id() == selected.model && m.provider_id() == selected.provider)
289            .ok_or_else(|| {
290                let available = registry
291                    .read(cx)
292                    .available_models(cx)
293                    .map(|m| format!("{}/{}", m.provider_id().0, m.id().0))
294                    .collect::<Vec<_>>()
295                    .join(", ");
296                anyhow::anyhow!("Model {model_name} not found. Available: {available}")
297            })?;
298
299        let supports_thinking = model.supports_thinking();
300
301        registry.update(cx, |registry, cx| {
302            registry.set_default_model(
303                Some(language_model::ConfiguredModel {
304                    provider: registry
305                        .provider(&model.provider_id())
306                        .context("Provider not found")?,
307                    model,
308                }),
309                cx,
310            );
311            anyhow::Ok(())
312        })?;
313
314        let enable_thinking = thinking_override.unwrap_or(supports_thinking);
315        let effort = if enable_thinking {
316            match reasoning_effort {
317                Some(level) => format!("\"{level}\""),
318                None => "\"high\"".to_string(),
319            }
320        } else {
321            "null".to_string()
322        };
323        let provider_id = selected.provider.0.to_string();
324        let model_id = selected.model.0.to_string();
325        SettingsStore::update_global(cx, |store, cx| {
326            let settings = format!(
327                r#"{{
328                    "agent": {{
329                        "tool_permissions": {{"default": "allow"}},
330                        "default_model": {{
331                            "provider": "{provider_id}",
332                            "model": "{model_id}",
333                            "enable_thinking": {enable_thinking},
334                            "effort": {effort}
335                        }}
336                    }},
337                    "autosave": "off",
338                    "format_on_save": "off"
339                }}"
340                "#
341            );
342            store.set_user_settings(&settings, cx).ok();
343        });
344
345        anyhow::Ok(())
346    });
347
348    if let Err(e) = setup_result {
349        return (Err(e), None);
350    }
351
352    let project = cx.update(|cx| {
353        Project::local(
354            app_state.client.clone(),
355            app_state.node_runtime.clone(),
356            app_state.user_store.clone(),
357            app_state.languages.clone(),
358            app_state.fs.clone(),
359            None,
360            project::LocalProjectFlags {
361                init_worktree_trust: false,
362                ..Default::default()
363            },
364            cx,
365        )
366    });
367
368    let worktree = project.update(cx, |project, cx| project.create_worktree(workdir, true, cx));
369    let worktree = match worktree.await {
370        Ok(w) => w,
371        Err(e) => return (Err(e).context("creating worktree"), None),
372    };
373
374    let scan_result = worktree.update(cx, |tree, _cx| {
375        tree.as_local()
376            .context("expected local worktree")
377            .map(|local| local.scan_complete())
378    });
379    match scan_result {
380        Ok(future) => future.await,
381        Err(e) => return (Err(e), None),
382    };
383
384    let agent = cx.update(|cx| {
385        let thread_store = cx.new(|cx| ThreadStore::new(cx));
386        NativeAgent::new(
387            thread_store,
388            Templates::new(),
389            None,
390            app_state.fs.clone(),
391            cx,
392        )
393    });
394
395    let connection = Rc::new(NativeAgentConnection(agent.clone()));
396    let acp_thread = match cx
397        .update(|cx| {
398            connection
399                .clone()
400                .new_session(project, PathList::new(&[workdir]), cx)
401        })
402        .await
403    {
404        Ok(t) => t,
405        Err(e) => return (Err(e).context("creating ACP session"), None),
406    };
407
408    let _subscription = cx.subscribe(&acp_thread, |acp_thread, event, cx| {
409        log_acp_thread_event(&acp_thread, event, cx);
410    });
411
412    let message = vec![acp::ContentBlock::Text(acp::TextContent::new(
413        instruction.to_string(),
414    ))];
415
416    let send_future = acp_thread.update(cx, |acp_thread: &mut acp_thread::AcpThread, cx| {
417        acp_thread.send(message, cx)
418    });
419
420    let timeout_future = if let Some(timeout_secs) = timeout {
421        futures::future::Either::Left(
422            cx.background_executor()
423                .timer(Duration::from_secs(timeout_secs)),
424        )
425    } else {
426        futures::future::Either::Right(futures::future::pending::<()>())
427    };
428
429    let sigterm_future = {
430        let executor = cx.background_executor().clone();
431        async move {
432            while !TERMINATED.load(Ordering::Relaxed) {
433                executor.timer(Duration::from_millis(100)).await;
434            }
435        }
436    };
437
438    let outcome = select_biased! {
439        result = send_future.fuse() => match result {
440            Ok(Some(response)) => {
441                eprintln!("[eval-cli] stopped: {:?}", response.stop_reason);
442                if response.stop_reason == acp::StopReason::MaxTokens {
443                    Err(anyhow::anyhow!("Model hit maximum token limit"))
444                } else {
445                    Ok(AgentOutcome::Completed)
446                }
447            }
448            Ok(None) => {
449                eprintln!("[eval-cli] completed (no response)");
450                Ok(AgentOutcome::Completed)
451            }
452            Err(e) => Err(e).context("agent run failed"),
453        },
454        _ = sigterm_future.fuse() => {
455            eprintln!("[eval-cli] received SIGTERM, cancelling...");
456            acp_thread.update(cx, |t: &mut acp_thread::AcpThread, cx| t.cancel(cx)).await;
457            Ok(AgentOutcome::Interrupted)
458        },
459        _ = timeout_future.fuse() => {
460            acp_thread.update(cx, |t: &mut acp_thread::AcpThread, cx| t.cancel(cx)).await;
461            Ok(AgentOutcome::Timeout { seconds: timeout.unwrap_or(0) })
462        }
463    };
464
465    let thread = cx.update(|cx| {
466        let session_id = acp_thread.read(cx).session_id().clone();
467        connection.thread(&session_id, cx)
468    });
469
470    let cumulative_usage = if let Some(thread) = &thread {
471        let db_thread = thread.read_with(cx, |thread, cx| thread.to_db(cx));
472        let db_thread = db_thread.await;
473        let usage = db_thread.cumulative_token_usage;
474        if usage.input_tokens > 0 || usage.output_tokens > 0 {
475            Some(usage)
476        } else {
477            None
478        }
479    } else {
480        None
481    };
482
483    let acp_usage = cx.update(|cx| {
484        acp_thread
485            .read(cx)
486            .token_usage()
487            .map(|usage| language_model::TokenUsage {
488                input_tokens: usage.input_tokens,
489                output_tokens: usage.output_tokens,
490                ..Default::default()
491            })
492    });
493
494    let final_usage = cumulative_usage.or(acp_usage);
495
496    if let (Some(thread), Some(dir)) = (&thread, output_dir) {
497        let markdown = thread.read_with(cx, |thread, _cx| thread.to_markdown());
498        if let Err(e) = std::fs::write(dir.join("thread.md"), markdown) {
499            eprintln!("Error writing thread.md: {e:#}");
500        }
501
502        let db_thread = thread.read_with(cx, |thread, cx| thread.to_db(cx));
503        let db_thread = db_thread.await;
504        match serde_json::to_string_pretty(&db_thread) {
505            Ok(json) => {
506                if let Err(e) = std::fs::write(dir.join("thread.json"), json) {
507                    eprintln!("Error writing thread.json: {e:#}");
508                }
509            }
510            Err(e) => eprintln!("Error serializing thread.json: {e:#}"),
511        }
512    }
513
514    (outcome, final_usage)
515}
516
517fn log_acp_thread_event(
518    acp_thread: &Entity<acp_thread::AcpThread>,
519    event: &acp_thread::AcpThreadEvent,
520    cx: &mut gpui::App,
521) {
522    match event {
523        acp_thread::AcpThreadEvent::NewEntry => {
524            let entries = acp_thread.read(cx).entries();
525            if let Some(acp_thread::AgentThreadEntry::AssistantMessage(message)) = entries.last() {
526                for chunk in &message.chunks {
527                    if let acp_thread::AssistantMessageChunk::Message { block } = chunk {
528                        if let acp_thread::ContentBlock::Markdown { markdown } = block {
529                            let text = markdown.read(cx).source().to_string();
530                            if !text.is_empty() {
531                                eprint!("{text}");
532                            }
533                        }
534                    }
535                }
536            }
537        }
538        acp_thread::AcpThreadEvent::EntryUpdated(index) => {
539            let entries = acp_thread.read(cx).entries();
540            if let Some(acp_thread::AgentThreadEntry::ToolCall(tool_call)) = entries.get(*index) {
541                if let Some(name) = &tool_call.tool_name {
542                    match &tool_call.status {
543                        acp_thread::ToolCallStatus::Completed => {
544                            eprintln!("[tool] {name}");
545                        }
546                        acp_thread::ToolCallStatus::Failed => {
547                            eprintln!("[tool] {name}");
548                        }
549                        acp_thread::ToolCallStatus::Rejected => {
550                            eprintln!("[tool] {name} rejected");
551                        }
552                        acp_thread::ToolCallStatus::Canceled => {
553                            eprintln!("[tool] {name} canceled");
554                        }
555                        _ => {}
556                    }
557                }
558            }
559        }
560        acp_thread::AcpThreadEvent::Stopped(reason) => {
561            eprintln!("\n[eval-cli] stopped: {reason:?}");
562        }
563        acp_thread::AcpThreadEvent::Error => {
564            eprintln!("[eval-cli] error event");
565        }
566        acp_thread::AcpThreadEvent::Retry(status) => {
567            eprintln!("[eval-cli] retry: {status:?}");
568        }
569        acp_thread::AcpThreadEvent::SubagentSpawned(session_id) => {
570            eprintln!("[eval-cli] subagent spawned: {session_id}");
571        }
572        _ => {}
573    }
574}