1use gh_workflow::{Event, Expression, Job, Run, Schedule, Step, Use, Workflow, WorkflowDispatch};
2
3use crate::tasks::workflows::{
4 runners::{self, Platform},
5 steps::{self, FluentBuilder as _, NamedJob, named, setup_cargo_config},
6 vars::{self, Input},
7};
8
9pub(crate) fn run_agent_evals() -> Workflow {
10 let agent_evals = agent_evals();
11 let model_name = Input::string("model_name", None);
12
13 named::workflow()
14 .on(Event::default().workflow_dispatch(
15 WorkflowDispatch::default().add_input(model_name.name, model_name.input()),
16 ))
17 .concurrency(vars::one_workflow_per_non_main_branch())
18 .add_env(("CARGO_TERM_COLOR", "always"))
19 .add_env(("CARGO_INCREMENTAL", 0))
20 .add_env(("RUST_BACKTRACE", 1))
21 .add_env(("ANTHROPIC_API_KEY", vars::ANTHROPIC_API_KEY))
22 .add_env(("OPENAI_API_KEY", vars::OPENAI_API_KEY))
23 .add_env(("GOOGLE_AI_API_KEY", vars::GOOGLE_AI_API_KEY))
24 .add_env(("GOOGLE_CLOUD_PROJECT", vars::GOOGLE_CLOUD_PROJECT))
25 .add_env(("ZED_CLIENT_CHECKSUM_SEED", vars::ZED_CLIENT_CHECKSUM_SEED))
26 .add_env(("ZED_EVAL_TELEMETRY", 1))
27 .add_env(("MODEL_NAME", model_name.to_string()))
28 .add_job(agent_evals.name, agent_evals.job)
29}
30
31pub(crate) fn run_unit_evals() -> Workflow {
32 let model_name = Input::string("model_name", None);
33 let commit_sha = Input::string("commit_sha", None);
34
35 let unit_evals = named::job(unit_evals(Some(&commit_sha)));
36
37 named::workflow()
38 .name("run_unit_evals")
39 .on(Event::default().workflow_dispatch(
40 WorkflowDispatch::default()
41 .add_input(model_name.name, model_name.input())
42 .add_input(commit_sha.name, commit_sha.input()),
43 ))
44 .concurrency(vars::one_workflow_per_non_main_branch())
45 .add_env(("CARGO_TERM_COLOR", "always"))
46 .add_env(("CARGO_INCREMENTAL", 0))
47 .add_env(("RUST_BACKTRACE", 1))
48 .add_env(("ZED_CLIENT_CHECKSUM_SEED", vars::ZED_CLIENT_CHECKSUM_SEED))
49 .add_env(("ZED_EVAL_TELEMETRY", 1))
50 .add_env(("MODEL_NAME", model_name.to_string()))
51 .add_job(unit_evals.name, unit_evals.job)
52}
53
54fn add_api_keys(step: Step<Run>) -> Step<Run> {
55 step.add_env(("ANTHROPIC_API_KEY", vars::ANTHROPIC_API_KEY))
56 .add_env(("OPENAI_API_KEY", vars::OPENAI_API_KEY))
57 .add_env(("GOOGLE_AI_API_KEY", vars::GOOGLE_AI_API_KEY))
58 .add_env(("GOOGLE_CLOUD_PROJECT", vars::GOOGLE_CLOUD_PROJECT))
59}
60
61fn agent_evals() -> NamedJob {
62 fn run_eval() -> Step<Run> {
63 named::bash(
64 "cargo run --package=eval -- --repetitions=8 --concurrency=1 --model \"${MODEL_NAME}\"",
65 )
66 }
67
68 named::job(
69 Job::default()
70 .runs_on(runners::LINUX_DEFAULT)
71 .timeout_minutes(60_u32 * 10)
72 .add_step(steps::checkout_repo())
73 .add_step(steps::cache_rust_dependencies_namespace())
74 .map(steps::install_linux_dependencies)
75 .add_step(setup_cargo_config(Platform::Linux))
76 .add_step(steps::script("cargo build --package=eval"))
77 .add_step(add_api_keys(run_eval()))
78 .add_step(steps::cleanup_cargo_config(Platform::Linux)),
79 )
80}
81
82pub(crate) fn run_cron_unit_evals() -> Workflow {
83 let unit_evals = cron_unit_evals();
84
85 named::workflow()
86 .name("run_cron_unit_evals")
87 .on(Event::default()
88 .schedule([
89 // GitHub might drop jobs at busy times, so we choose a random time in the middle of the night.
90 Schedule::default().cron("47 1 * * 2"),
91 ])
92 .workflow_dispatch(WorkflowDispatch::default()))
93 .concurrency(vars::one_workflow_per_non_main_branch())
94 .add_env(("CARGO_TERM_COLOR", "always"))
95 .add_env(("CARGO_INCREMENTAL", 0))
96 .add_env(("RUST_BACKTRACE", 1))
97 .add_env(("ZED_CLIENT_CHECKSUM_SEED", vars::ZED_CLIENT_CHECKSUM_SEED))
98 .add_job(unit_evals.name, unit_evals.job)
99}
100
101fn cron_unit_evals() -> NamedJob {
102 fn send_failure_to_slack() -> Step<Use> {
103 named::uses(
104 "slackapi",
105 "slack-github-action",
106 "b0fa283ad8fea605de13dc3f449259339835fc52",
107 )
108 .if_condition(Expression::new("${{ failure() }}"))
109 .add_with(("method", "chat.postMessage"))
110 .add_with(("token", vars::SLACK_APP_ZED_UNIT_EVALS_BOT_TOKEN))
111 .add_with(("payload", indoc::indoc!{r#"
112 channel: C04UDRNNJFQ
113 text: "Unit Evals Failed: https://github.com/zed-industries/zed/actions/runs/${{ github.run_id }}"
114 "#}))
115 }
116
117 named::job(unit_evals(None).add_step(send_failure_to_slack()))
118}
119
120fn unit_evals(commit: Option<&Input>) -> Job {
121 fn send_failure_to_slack() -> Step<Use> {
122 named::uses(
123 "slackapi",
124 "slack-github-action",
125 "b0fa283ad8fea605de13dc3f449259339835fc52",
126 )
127 .if_condition(Expression::new("${{ failure() }}"))
128 .add_with(("method", "chat.postMessage"))
129 .add_with(("token", vars::SLACK_APP_ZED_UNIT_EVALS_BOT_TOKEN))
130 .add_with(("payload", indoc::indoc!{r#"
131 channel: C04UDRNNJFQ
132 text: "Unit Evals Failed: https://github.com/zed-industries/zed/actions/runs/${{ github.run_id }}"
133 "#}))
134 }
135
136 let script_step = add_api_keys(steps::script("./script/run-unit-evals"));
137
138 Job::default()
139 .runs_on(runners::LINUX_DEFAULT)
140 .add_step(steps::checkout_repo())
141 .add_step(steps::setup_cargo_config(Platform::Linux))
142 .add_step(steps::cache_rust_dependencies_namespace())
143 .map(steps::install_linux_dependencies)
144 .add_step(steps::cargo_install_nextest(Platform::Linux))
145 .add_step(steps::clear_target_dir_if_large(Platform::Linux))
146 .add_step(match commit {
147 Some(commit) => script_step.add_env(("UNIT_EVAL_COMMIT", commit)),
148 None => script_step,
149 })
150 .add_step(send_failure_to_slack())
151 .add_step(steps::cleanup_cargo_config(Platform::Linux))
152}