main.rs

  1//! Headless CLI binary for running Zed's agent in evaluation/benchmark environments.
  2//!
  3//! Designed to work inside containerized environments (like Harbor/termbench) where:
  4//! - The repository is already checked out at the working directory
  5//! - The model API key is provided via environment variables
  6//! - Results are written to an output directory (default: `/logs/agent/`)
  7//!
  8//! ## Usage
  9//!
 10//! ```text
 11//! eval-cli --workdir /testbed --model anthropic/claude-sonnet-4-6-latest \
 12//!          --instruction "Fix the bug described in..." --timeout 600
 13//! ```
 14//!
 15//! ## Output
 16//!
 17//! Writes to `--output-dir` (default `/logs/agent/`):
 18//!   - `result.json`  — structured result with status, timing, and token usage
 19//!   - `thread.md`    — full conversation as markdown
 20//!   - `thread.json`  — raw thread state as JSON
 21//!
 22//! ## Exit codes
 23//!
 24//! | Code | Meaning |
 25//! |------|---------|
 26//! | 0    | Agent finished |
 27//! | 1    | Error (model/auth/runtime failure) |
 28//! | 2    | Timeout |
 29//! | 3    | Interrupted (SIGTERM/SIGINT) |
 30
 31mod headless;
 32
 33use std::path::PathBuf;
 34use std::process;
 35use std::rc::Rc;
 36use std::str::FromStr;
 37use std::sync::Arc;
 38use std::sync::atomic::{AtomicBool, Ordering};
 39use std::time::{Duration, Instant};
 40
 41use acp_thread::AgentConnection as _;
 42use agent::{NativeAgent, NativeAgentConnection, Templates, ThreadStore};
 43use agent_client_protocol as acp;
 44use anyhow::{Context, Result};
 45use clap::Parser;
 46use feature_flags::FeatureFlagAppExt as _;
 47
 48use futures::{FutureExt, select_biased};
 49use gpui::{AppContext as _, AsyncApp, Entity, UpdateGlobal};
 50use language_model::{LanguageModelRegistry, SelectedModel};
 51use project::Project;
 52use settings::SettingsStore;
 53use util::path_list::PathList;
 54
 55use crate::headless::AgentCliAppState;
 56
 57#[derive(Parser, Debug)]
 58#[command(
 59    name = "eval-cli",
 60    about = "Run Zed's agent headlessly in evaluation/benchmark environments"
 61)]
 62struct Args {
 63    /// Output current environment variables as JSON to stdout.
 64    /// Used internally by Zed's shell environment capture.
 65    #[arg(long, hide = true)]
 66    printenv: bool,
 67
 68    /// Path to the repository working directory. Defaults to the current directory.
 69    #[arg(long, default_value = ".")]
 70    workdir: PathBuf,
 71
 72    /// Instruction/prompt text. If omitted, read from --instruction-file or stdin.
 73    #[arg(long)]
 74    instruction: Option<String>,
 75
 76    /// Language model to use, in `provider/model` format.
 77    #[arg(long, default_value = "anthropic/claude-sonnet-4-6-latest")]
 78    model: String,
 79
 80    /// Maximum wall-clock time in seconds for the agent run.
 81    #[arg(long)]
 82    timeout: Option<u64>,
 83
 84    /// Directory for output artifacts (result.json, thread.md, thread.json).
 85    #[arg(long, default_value = "/logs/agent")]
 86    output_dir: PathBuf,
 87}
 88
 89enum AgentOutcome {
 90    Completed,
 91    Timeout { seconds: u64 },
 92    Interrupted,
 93}
 94
/// Contents of `result.json`, which is also echoed to stderr at the end of a run.
///
/// Fields serialize in declaration order; `Option` fields marked with
/// `skip_serializing_if` are omitted entirely when `None`.
#[derive(serde::Serialize)]
struct EvalResult {
    /// One of `"completed"`, `"timeout"`, `"interrupted"` or `"error"`.
    status: String,
    /// Full error chain; only set when `status` is `"error"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
    /// Wall-clock seconds spent in the agent run, measured from after provider
    /// authentication until `run_agent` returned (including artifact writing).
    duration_secs: f64,
    /// The `--timeout` value, when one was given.
    #[serde(skip_serializing_if = "Option::is_none")]
    timeout_secs: Option<u64>,
    /// The `--model` value as given on the command line (`provider/model`).
    model: String,
    /// Input token count; omitted when no usage was reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    input_tokens: Option<u64>,
    /// Output token count; omitted when no usage was reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    output_tokens: Option<u64>,
    /// Cache-write input tokens; omitted when unknown or zero.
    #[serde(skip_serializing_if = "Option::is_none")]
    cache_creation_input_tokens: Option<u64>,
    /// Cache-read input tokens; omitted when unknown or zero.
    #[serde(skip_serializing_if = "Option::is_none")]
    cache_read_input_tokens: Option<u64>,
}
113
/// Process exit code: the agent run completed.
const EXIT_OK: i32 = 0;
/// Process exit code: any error (bad arguments/paths, model lookup, runtime failure).
const EXIT_ERROR: i32 = 1;
/// Process exit code: the run exceeded `--timeout` and was cancelled.
const EXIT_TIMEOUT: i32 = 2;
/// Process exit code: the run was cancelled by a termination signal.
const EXIT_INTERRUPTED: i32 = 3;

/// Set by the signal handler installed in `main`; `run_agent` polls it to cancel
/// the running thread gracefully.
static TERMINATED: AtomicBool = AtomicBool::new(false);
120
121fn main() {
122    let args = Args::parse();
123
124    if args.printenv {
125        util::shell_env::print_env();
126        return;
127    }
128
129    env_logger::init();
130
131    ctrlc::set_handler(|| {
132        TERMINATED.store(true, Ordering::SeqCst);
133    })
134    .expect("failed to set signal handler");
135
136    let instruction = read_instruction(&args).unwrap_or_else(|e| {
137        eprintln!("Error reading instruction: {e}");
138        process::exit(EXIT_ERROR);
139    });
140
141    let workdir = args.workdir.canonicalize().unwrap_or_else(|e| {
142        eprintln!("Invalid --workdir {:?}: {e}", args.workdir);
143        process::exit(EXIT_ERROR);
144    });
145
146    let output_dir = args.output_dir.clone();
147    if let Err(e) = std::fs::create_dir_all(&output_dir) {
148        eprintln!("Error creating output dir {}: {e}", output_dir.display());
149        process::exit(EXIT_ERROR);
150    }
151
152    let http_client = Arc::new(reqwest_client::ReqwestClient::new());
153    let app = gpui_platform::headless().with_http_client(http_client);
154
155    app.run(move |cx| {
156        let app_state = headless::init(cx);
157        cx.set_staff(true);
158
159        let auth_tasks = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
160            registry
161                .providers()
162                .iter()
163                .map(|p| p.authenticate(cx))
164                .collect::<Vec<_>>()
165        });
166
167        let model_name = args.model.clone();
168        let timeout = args.timeout;
169
170        cx.spawn(async move |cx| {
171            futures::future::join_all(auth_tasks).await;
172
173            let start = Instant::now();
174
175            let (outcome, token_usage) = run_agent(
176                &app_state,
177                &workdir,
178                &instruction,
179                &model_name,
180                timeout,
181                Some(&output_dir),
182                cx,
183            )
184            .await;
185
186            let duration = start.elapsed();
187
188            let (status, error, exit_code) = match &outcome {
189                Ok(AgentOutcome::Completed) => ("completed".to_string(), None, EXIT_OK),
190                Ok(AgentOutcome::Timeout { seconds }) => {
191                    eprintln!("Timeout: agent exceeded {seconds}s time limit");
192                    ("timeout".to_string(), None, EXIT_TIMEOUT)
193                }
194                Ok(AgentOutcome::Interrupted) => {
195                    eprintln!("Interrupted: received SIGTERM, saved partial output");
196                    ("interrupted".to_string(), None, EXIT_INTERRUPTED)
197                }
198                Err(e) => {
199                    eprintln!("Error: {e:#}");
200                    ("error".to_string(), Some(format!("{e:#}")), EXIT_ERROR)
201                }
202            };
203
204            let result = EvalResult {
205                status,
206                error,
207                duration_secs: duration.as_secs_f64(),
208                timeout_secs: timeout,
209                model: model_name.clone(),
210                input_tokens: token_usage.as_ref().map(|u| u.input_tokens),
211                output_tokens: token_usage.as_ref().map(|u| u.output_tokens),
212                cache_creation_input_tokens: token_usage
213                    .as_ref()
214                    .filter(|u| u.cache_creation_input_tokens > 0)
215                    .map(|u| u.cache_creation_input_tokens),
216                cache_read_input_tokens: token_usage
217                    .as_ref()
218                    .filter(|u| u.cache_read_input_tokens > 0)
219                    .map(|u| u.cache_read_input_tokens),
220            };
221
222            match serde_json::to_string_pretty(&result) {
223                Ok(json) => {
224                    if let Err(e) = std::fs::write(output_dir.join("result.json"), &json) {
225                        eprintln!("Error writing result.json: {e:#}");
226                    }
227                    eprintln!("[eval-cli] result: {json}");
228                }
229                Err(e) => eprintln!("Error serializing result: {e:#}"),
230            }
231
232            cx.update(|cx| cx.quit());
233            process::exit(exit_code);
234        })
235        .detach();
236    });
237}
238
239fn read_instruction(args: &Args) -> Result<String> {
240    let text = if let Some(text) = &args.instruction {
241        text.clone()
242    } else {
243        use std::io::Read;
244        let mut buf = String::new();
245        std::io::stdin()
246            .read_to_string(&mut buf)
247            .context("reading instruction from stdin")?;
248        buf
249    };
250    anyhow::ensure!(!text.trim().is_empty(), "instruction is empty");
251    Ok(text)
252}
253
254async fn run_agent(
255    app_state: &Arc<AgentCliAppState>,
256    workdir: &std::path::Path,
257    instruction: &str,
258    model_name: &str,
259    timeout: Option<u64>,
260    output_dir: Option<&std::path::Path>,
261    cx: &mut AsyncApp,
262) -> (Result<AgentOutcome>, Option<language_model::TokenUsage>) {
263    let setup_result: Result<()> = cx.update(|cx| {
264        let selected = SelectedModel::from_str(model_name).map_err(|e| anyhow::anyhow!("{e}"))?;
265        let registry = LanguageModelRegistry::global(cx);
266        let model = registry
267            .read(cx)
268            .available_models(cx)
269            .find(|m| m.id() == selected.model && m.provider_id() == selected.provider)
270            .ok_or_else(|| {
271                let available = registry
272                    .read(cx)
273                    .available_models(cx)
274                    .map(|m| format!("{}/{}", m.provider_id().0, m.id().0))
275                    .collect::<Vec<_>>()
276                    .join(", ");
277                anyhow::anyhow!("Model {model_name} not found. Available: {available}")
278            })?;
279
280        let supports_thinking = model.supports_thinking();
281
282        registry.update(cx, |registry, cx| {
283            registry.set_default_model(
284                Some(language_model::ConfiguredModel {
285                    provider: registry
286                        .provider(&model.provider_id())
287                        .context("Provider not found")?,
288                    model,
289                }),
290                cx,
291            );
292            anyhow::Ok(())
293        })?;
294
295        let (enable_thinking, effort) = if supports_thinking {
296            (true, "\"high\"")
297        } else {
298            (false, "null")
299        };
300        let provider_id = selected.provider.0.to_string();
301        let model_id = selected.model.0.to_string();
302        SettingsStore::update_global(cx, |store, cx| {
303            let settings = format!(
304                r#"{{
305                    "agent": {{
306                        "tool_permissions": {{"default": "allow"}},
307                        "default_model": {{
308                            "provider": "{provider_id}",
309                            "model": "{model_id}",
310                            "enable_thinking": {enable_thinking},
311                            "effort": {effort}
312                        }}
313                    }},
314                    "autosave": "off",
315                    "format_on_save": "off"
316                }}"
317                "#
318            );
319            store.set_user_settings(&settings, cx).ok();
320        });
321
322        anyhow::Ok(())
323    });
324
325    if let Err(e) = setup_result {
326        return (Err(e), None);
327    }
328
329    let project = cx.update(|cx| {
330        Project::local(
331            app_state.client.clone(),
332            app_state.node_runtime.clone(),
333            app_state.user_store.clone(),
334            app_state.languages.clone(),
335            app_state.fs.clone(),
336            None,
337            project::LocalProjectFlags {
338                init_worktree_trust: false,
339                ..Default::default()
340            },
341            cx,
342        )
343    });
344
345    let worktree = project.update(cx, |project, cx| project.create_worktree(workdir, true, cx));
346    let worktree = match worktree.await {
347        Ok(w) => w,
348        Err(e) => return (Err(e).context("creating worktree"), None),
349    };
350
351    let scan_result = worktree.update(cx, |tree, _cx| {
352        tree.as_local()
353            .context("expected local worktree")
354            .map(|local| local.scan_complete())
355    });
356    match scan_result {
357        Ok(future) => future.await,
358        Err(e) => return (Err(e), None),
359    };
360
361    let agent = cx.update(|cx| {
362        let thread_store = cx.new(|cx| ThreadStore::new(cx));
363        NativeAgent::new(
364            thread_store,
365            Templates::new(),
366            None,
367            app_state.fs.clone(),
368            cx,
369        )
370    });
371
372    let connection = Rc::new(NativeAgentConnection(agent.clone()));
373    let acp_thread = match cx
374        .update(|cx| {
375            connection
376                .clone()
377                .new_session(project, PathList::new(&[workdir]), cx)
378        })
379        .await
380    {
381        Ok(t) => t,
382        Err(e) => return (Err(e).context("creating ACP session"), None),
383    };
384
385    let _subscription = cx.subscribe(&acp_thread, |acp_thread, event, cx| {
386        log_acp_thread_event(&acp_thread, event, cx);
387    });
388
389    let message = vec![acp::ContentBlock::Text(acp::TextContent::new(
390        instruction.to_string(),
391    ))];
392
393    let send_future = acp_thread.update(cx, |acp_thread: &mut acp_thread::AcpThread, cx| {
394        acp_thread.send(message, cx)
395    });
396
397    let timeout_future = if let Some(timeout_secs) = timeout {
398        futures::future::Either::Left(
399            cx.background_executor()
400                .timer(Duration::from_secs(timeout_secs)),
401        )
402    } else {
403        futures::future::Either::Right(futures::future::pending::<()>())
404    };
405
406    let sigterm_future = {
407        let executor = cx.background_executor().clone();
408        async move {
409            while !TERMINATED.load(Ordering::Relaxed) {
410                executor.timer(Duration::from_millis(100)).await;
411            }
412        }
413    };
414
415    let outcome = select_biased! {
416        result = send_future.fuse() => match result {
417            Ok(Some(response)) => {
418                eprintln!("[eval-cli] stopped: {:?}", response.stop_reason);
419                if response.stop_reason == acp::StopReason::MaxTokens {
420                    Err(anyhow::anyhow!("Model hit maximum token limit"))
421                } else {
422                    Ok(AgentOutcome::Completed)
423                }
424            }
425            Ok(None) => {
426                eprintln!("[eval-cli] completed (no response)");
427                Ok(AgentOutcome::Completed)
428            }
429            Err(e) => Err(e).context("agent run failed"),
430        },
431        _ = sigterm_future.fuse() => {
432            eprintln!("[eval-cli] received SIGTERM, cancelling...");
433            acp_thread.update(cx, |t: &mut acp_thread::AcpThread, cx| t.cancel(cx)).await;
434            Ok(AgentOutcome::Interrupted)
435        },
436        _ = timeout_future.fuse() => {
437            acp_thread.update(cx, |t: &mut acp_thread::AcpThread, cx| t.cancel(cx)).await;
438            Ok(AgentOutcome::Timeout { seconds: timeout.unwrap_or(0) })
439        }
440    };
441
442    let thread = cx.update(|cx| {
443        let session_id = acp_thread.read(cx).session_id().clone();
444        connection.thread(&session_id, cx)
445    });
446
447    let cumulative_usage = if let Some(thread) = &thread {
448        let db_thread = thread.read_with(cx, |thread, cx| thread.to_db(cx));
449        let db_thread = db_thread.await;
450        let usage = db_thread.cumulative_token_usage;
451        if usage.input_tokens > 0 || usage.output_tokens > 0 {
452            Some(usage)
453        } else {
454            None
455        }
456    } else {
457        None
458    };
459
460    let acp_usage = cx.update(|cx| {
461        acp_thread
462            .read(cx)
463            .token_usage()
464            .map(|usage| language_model::TokenUsage {
465                input_tokens: usage.input_tokens,
466                output_tokens: usage.output_tokens,
467                ..Default::default()
468            })
469    });
470
471    let final_usage = cumulative_usage.or(acp_usage);
472
473    if let (Some(thread), Some(dir)) = (&thread, output_dir) {
474        let markdown = thread.read_with(cx, |thread, _cx| thread.to_markdown());
475        if let Err(e) = std::fs::write(dir.join("thread.md"), markdown) {
476            eprintln!("Error writing thread.md: {e:#}");
477        }
478
479        let db_thread = thread.read_with(cx, |thread, cx| thread.to_db(cx));
480        let db_thread = db_thread.await;
481        match serde_json::to_string_pretty(&db_thread) {
482            Ok(json) => {
483                if let Err(e) = std::fs::write(dir.join("thread.json"), json) {
484                    eprintln!("Error writing thread.json: {e:#}");
485                }
486            }
487            Err(e) => eprintln!("Error serializing thread.json: {e:#}"),
488        }
489    }
490
491    (outcome, final_usage)
492}
493
494fn log_acp_thread_event(
495    acp_thread: &Entity<acp_thread::AcpThread>,
496    event: &acp_thread::AcpThreadEvent,
497    cx: &mut gpui::App,
498) {
499    match event {
500        acp_thread::AcpThreadEvent::NewEntry => {
501            let entries = acp_thread.read(cx).entries();
502            if let Some(acp_thread::AgentThreadEntry::AssistantMessage(message)) = entries.last() {
503                for chunk in &message.chunks {
504                    if let acp_thread::AssistantMessageChunk::Message { block } = chunk {
505                        if let acp_thread::ContentBlock::Markdown { markdown } = block {
506                            let text = markdown.read(cx).source().to_string();
507                            if !text.is_empty() {
508                                eprint!("{text}");
509                            }
510                        }
511                    }
512                }
513            }
514        }
515        acp_thread::AcpThreadEvent::EntryUpdated(index) => {
516            let entries = acp_thread.read(cx).entries();
517            if let Some(acp_thread::AgentThreadEntry::ToolCall(tool_call)) = entries.get(*index) {
518                if let Some(name) = &tool_call.tool_name {
519                    match &tool_call.status {
520                        acp_thread::ToolCallStatus::Completed => {
521                            eprintln!("[tool] {name}");
522                        }
523                        acp_thread::ToolCallStatus::Failed => {
524                            eprintln!("[tool] {name}");
525                        }
526                        acp_thread::ToolCallStatus::Rejected => {
527                            eprintln!("[tool] {name} rejected");
528                        }
529                        acp_thread::ToolCallStatus::Canceled => {
530                            eprintln!("[tool] {name} canceled");
531                        }
532                        _ => {}
533                    }
534                }
535            }
536        }
537        acp_thread::AcpThreadEvent::Stopped(reason) => {
538            eprintln!("\n[eval-cli] stopped: {reason:?}");
539        }
540        acp_thread::AcpThreadEvent::Error => {
541            eprintln!("[eval-cli] error event");
542        }
543        acp_thread::AcpThreadEvent::Retry(status) => {
544            eprintln!("[eval-cli] retry: {status:?}");
545        }
546        acp_thread::AcpThreadEvent::SubagentSpawned(session_id) => {
547            eprintln!("[eval-cli] subagent spawned: {session_id}");
548        }
549        _ => {}
550    }
551}