run_agent_evals.rs

 1use gh_workflow::{Event, Expression, Job, Run, Schedule, Step, Use, Workflow, WorkflowDispatch};
 2
 3use crate::tasks::workflows::{
 4    runners::{self, Platform},
 5    steps::{self, FluentBuilder as _, NamedJob, named, setup_cargo_config},
 6    vars::{self, Input},
 7};
 8
 9pub(crate) fn run_agent_evals() -> Workflow {
10    let agent_evals = agent_evals();
11    let model_name = Input::string("model_name", None);
12
13    named::workflow()
14        .on(Event::default().workflow_dispatch(
15            WorkflowDispatch::default().add_input(model_name.name, model_name.input()),
16        ))
17        .concurrency(vars::one_workflow_per_non_main_branch())
18        .add_env(("CARGO_TERM_COLOR", "always"))
19        .add_env(("CARGO_INCREMENTAL", 0))
20        .add_env(("RUST_BACKTRACE", 1))
21        .add_env(("ANTHROPIC_API_KEY", vars::ANTHROPIC_API_KEY))
22        .add_env(("ZED_CLIENT_CHECKSUM_SEED", vars::ZED_CLIENT_CHECKSUM_SEED))
23        .add_env(("ZED_EVAL_TELEMETRY", 1))
24        .add_env(("MODEL_NAME", model_name.to_string()))
25        .add_job(agent_evals.name, agent_evals.job)
26}
27
28fn agent_evals() -> NamedJob {
29    fn run_eval() -> Step<Run> {
30        named::bash(
31            "cargo run --package=eval -- --repetitions=8 --concurrency=1 --model \"${MODEL_NAME}\"",
32        )
33    }
34
35    named::job(
36        Job::default()
37            .runs_on(runners::LINUX_DEFAULT)
38            .timeout_minutes(60_u32 * 10)
39            .add_step(steps::checkout_repo())
40            .add_step(steps::cache_rust_dependencies_namespace())
41            .map(steps::install_linux_dependencies)
42            .add_step(setup_cargo_config(Platform::Linux))
43            .add_step(steps::script("cargo build --package=eval"))
44            .add_step(run_eval())
45            .add_step(steps::cleanup_cargo_config(Platform::Linux)),
46    )
47}
48
49pub(crate) fn run_unit_evals() -> Workflow {
50    let unit_evals = unit_evals();
51
52    named::workflow()
53        .on(Event::default()
54            .schedule([
55                // GitHub might drop jobs at busy times, so we choose a random time in the middle of the night.
56                Schedule::default().cron("47 1 * * 2"),
57            ])
58            .workflow_dispatch(WorkflowDispatch::default()))
59        .concurrency(vars::one_workflow_per_non_main_branch())
60        .add_env(("CARGO_TERM_COLOR", "always"))
61        .add_env(("CARGO_INCREMENTAL", 0))
62        .add_env(("RUST_BACKTRACE", 1))
63        .add_env(("ZED_CLIENT_CHECKSUM_SEED", vars::ZED_CLIENT_CHECKSUM_SEED))
64        .add_job(unit_evals.name, unit_evals.job)
65}
66
67fn unit_evals() -> NamedJob {
68    fn send_failure_to_slack() -> Step<Use> {
69        named::uses(
70            "slackapi",
71            "slack-github-action",
72            "b0fa283ad8fea605de13dc3f449259339835fc52",
73        )
74        .if_condition(Expression::new("${{ failure() }}"))
75        .add_with(("method", "chat.postMessage"))
76        .add_with(("token", vars::SLACK_APP_ZED_UNIT_EVALS_BOT_TOKEN))
77        .add_with(("payload", indoc::indoc!{r#"
78            channel: C04UDRNNJFQ
79            text: "Unit Evals Failed: https://github.com/zed-industries/zed/actions/runs/${{ github.run_id }}"
80        "#}))
81    }
82
83    named::job(
84        Job::default()
85            .runs_on(runners::LINUX_DEFAULT)
86            .add_step(steps::checkout_repo())
87            .add_step(steps::setup_cargo_config(Platform::Linux))
88            .add_step(steps::cache_rust_dependencies_namespace())
89            .map(steps::install_linux_dependencies)
90            .add_step(steps::cargo_install_nextest(Platform::Linux))
91            .add_step(steps::clear_target_dir_if_large(Platform::Linux))
92            .add_step(
93                steps::script("./script/run-unit-evals")
94                    .add_env(("ANTHROPIC_API_KEY", vars::ANTHROPIC_API_KEY)),
95            )
96            .add_step(send_failure_to_slack())
97            .add_step(steps::cleanup_cargo_config(Platform::Linux)),
98    )
99}