1use std::{
2 error::Error,
3 fmt::{self, Debug},
4 sync::{Arc, Mutex},
5 time::Duration,
6};
7
8use crate::{
9 ToolMetrics,
10 assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
11};
12use agent::ThreadEvent;
13use anyhow::{Result, anyhow};
14use async_trait::async_trait;
15use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
16use gpui::{AppContext, AsyncApp, Entity};
17use language_model::{LanguageModel, Role, StopReason};
18
19pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
20
21#[async_trait(?Send)]
22pub trait Example {
23 fn meta(&self) -> ExampleMetadata;
24 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
25 fn diff_assertions(&self) -> Vec<JudgeAssertion> {
26 Vec::new()
27 }
28 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
29 Vec::new()
30 }
31}
32
/// A single assertion, made up of an identifier and a free-text description.
#[derive(Debug, Clone)]
pub struct JudgeAssertion {
    /// Identifier of the assertion.
    pub id: String,
    /// Human-readable statement of what is being asserted.
    pub description: String,
}
38
/// Static description of an example.
#[derive(Clone, Debug)]
pub struct ExampleMetadata {
    /// Name of the example.
    pub name: String,
    /// URL of the repository; [`ExampleMetadata::repo_name`] is derived from its last
    /// path segment.
    pub url: String,
    /// Revision of the repository the example refers to.
    pub revision: String,
    /// Optional language-server settings for the example.
    pub language_server: Option<LanguageServer>,
    /// Optional upper bound on the number of assertions the example may run.
    pub max_assertions: Option<usize>,
}

/// Language-server settings attached to an example.
#[derive(Clone, Debug)]
pub struct LanguageServer {
    /// File extension associated with the language server.
    pub file_extension: String,
    /// Whether diagnostics that already exist are allowed.
    pub allow_preexisting_diagnostics: bool,
}

impl ExampleMetadata {
    /// Returns the repository name: the last path segment of `url` with any trailing
    /// `.git` removed.
    ///
    /// Trailing slashes are ignored, so `https://host/org/repo/` yields `repo` rather
    /// than an empty string. An empty URL yields an empty name.
    pub fn repo_name(&self) -> String {
        self.url
            // A trailing `/` would otherwise make the last segment empty.
            .trim_end_matches('/')
            .rsplit('/')
            .next()
            .unwrap_or_default()
            .trim_end_matches(".git")
            .to_owned()
    }
}
64
/// Error value carrying the message of an assertion that did not hold.
pub struct FailedAssertion(pub String);

impl fmt::Debug for FailedAssertion {
    /// Writes the message prefixed with `Assertion failure: `.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("Assertion failure: ")?;
        formatter.write_str(&self.0)
    }
}

impl fmt::Display for FailedAssertion {
    /// Writes the bare message.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str(&self.0)
    }
}

impl Error for FailedAssertion {}
80
/// State used by the assertion helpers and turn runners of a single example run.
81pub struct ExampleContext {
 /// Description of the example; its `max_assertions` bounds how many assertions are logged.
82 meta: ExampleMetadata,
 /// Text prepended to every line this context prints.
83 log_prefix: String,
 /// The agent thread that receives user messages and is sent to the model.
84 agent_thread: Entity<agent::Thread>,
 /// App handle used to update, read and subscribe to the thread, and to create timers.
85 app: AsyncApp,
 /// The model the thread is sent to when running turns.
86 model: Arc<dyn LanguageModel>,
 /// Record of every assertion logged through this context.
87 pub assertions: AssertionsReport,
 /// Per-tool results, updated whenever the thread reports a finished tool.
88 pub tool_metrics: Arc<Mutex<ToolMetrics>>,
89}
90
91impl ExampleContext {
92 pub fn new(
93 meta: ExampleMetadata,
94 log_prefix: String,
95 agent_thread: Entity<agent::Thread>,
96 model: Arc<dyn LanguageModel>,
97 app: AsyncApp,
98 ) -> Self {
99 let assertions = AssertionsReport::new(meta.max_assertions);
100
101 Self {
102 meta,
103 log_prefix,
104 agent_thread,
105 assertions,
106 model,
107 app,
108 tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
109 }
110 }
111
112 pub fn push_user_message(&mut self, text: impl ToString) {
113 self.app
114 .update_entity(&self.agent_thread, |thread, cx| {
115 thread.insert_user_message(text.to_string(), vec![], None, cx);
116 })
117 .unwrap();
118 }
119
120 pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
121 let message = message.to_string();
122 self.log_assertion(
123 if expected {
124 Ok(())
125 } else {
126 Err(anyhow::Error::from(FailedAssertion(message.clone())))
127 },
128 message,
129 )
130 }
131
132 pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
133 let message = message.to_string();
134 self.log_assertion(
135 match option {
136 Some(value) => Ok(value),
137 None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
138 },
139 message,
140 )
141 }
142
143 #[allow(dead_code)]
144 pub fn assert_eq<T: PartialEq + Debug>(
145 &mut self,
146 left: T,
147 right: T,
148 message: impl ToString,
149 ) -> Result<()> {
150 let message = message.to_string();
151 self.log_assertion(
152 if left == right {
153 Ok(())
154 } else {
155 println!("{}{:#?} != {:#?}", self.log_prefix, left, right);
156 Err(anyhow::Error::from(FailedAssertion(message.clone())))
157 },
158 message,
159 )
160 }
161
162 fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
163 if let Some(max) = self.meta.max_assertions {
164 if self.assertions.run_count() > max {
165 return Err(anyhow!(
166 "More assertions were run than the stated max_assertions of {}",
167 max
168 ));
169 }
170 }
171
172 self.assertions.ran.push(RanAssertion {
173 id: message.clone(),
174 result: Ok(RanAssertionResult {
175 analysis: None,
176 passed: result.is_ok(),
177 }),
178 });
179
180 if result.is_ok() {
181 println!("{}✅ {}", self.log_prefix, message);
182 } else {
183 println!("{}❌ {}", self.log_prefix, message);
184 }
185
186 result
187 }
188
189 pub async fn run_to_end(&mut self) -> Result<Response> {
190 self.run_turns(u32::MAX).await
191 }
192
193 pub async fn run_turn(&mut self) -> Result<Response> {
194 self.run_turns(1).await
195 }
196
197 pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
198 let (mut tx, mut rx) = mpsc::channel(1);
199
200 let tool_metrics = self.tool_metrics.clone();
201 let log_prefix = self.log_prefix.clone();
202 let _subscription = self.app.subscribe(
203 &self.agent_thread,
204 move |thread, event: &ThreadEvent, cx| match event {
205 ThreadEvent::ShowError(thread_error) => {
206 tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
207 }
208 ThreadEvent::Stopped(reason) => match reason {
209 Ok(StopReason::EndTurn) => {
210 tx.close_channel();
211 }
212 Ok(StopReason::ToolUse) => {
213 if thread.read(cx).remaining_turns() == 0 {
214 tx.close_channel();
215 }
216 }
217 Ok(StopReason::MaxTokens) => {
218 tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
219 }
220 Err(err) => {
221 tx.try_send(Err(anyhow!(err.clone()))).ok();
222 }
223 },
224 ThreadEvent::StreamedAssistantText(_, _)
225 | ThreadEvent::StreamedAssistantThinking(_, _)
226 | ThreadEvent::UsePendingTools { .. } => {}
227 ThreadEvent::ToolFinished {
228 tool_use_id,
229 pending_tool_use,
230 ..
231 } => {
232 thread.update(cx, |thread, _cx| {
233 if let Some(tool_use) = pending_tool_use {
234 let mut tool_metrics = tool_metrics.lock().unwrap();
235 if let Some(tool_result) = thread.tool_result(&tool_use_id) {
236 let message = if tool_result.is_error {
237 format!("TOOL FAILED: {}", tool_use.name)
238 } else {
239 format!("TOOL FINISHED: {}", tool_use.name)
240 };
241 println!("{log_prefix}{message}");
242 tool_metrics
243 .insert(tool_result.tool_name.clone(), !tool_result.is_error);
244 } else {
245 let message =
246 format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
247 println!("{log_prefix}{message}");
248 tool_metrics.insert(tool_use.name.clone(), true);
249 }
250 }
251 });
252 }
253 ThreadEvent::ToolConfirmationNeeded => {
254 panic!(
255 "{}Bug: Tool confirmation should not be required in eval",
256 log_prefix
257 );
258 }
259 ThreadEvent::StreamedCompletion
260 | ThreadEvent::MessageAdded(_)
261 | ThreadEvent::MessageEdited(_)
262 | ThreadEvent::MessageDeleted(_)
263 | ThreadEvent::SummaryChanged
264 | ThreadEvent::SummaryGenerated
265 | ThreadEvent::ReceivedTextChunk
266 | ThreadEvent::StreamedToolUse { .. }
267 | ThreadEvent::CheckpointChanged
268 | ThreadEvent::UsageUpdated(_) => {
269 tx.try_send(Ok(())).ok();
270 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
271 println!("{}Event: {:#?}", log_prefix, event);
272 }
273 }
274 },
275 );
276
277 let model = self.model.clone();
278
279 let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
280 thread.set_remaining_turns(iterations);
281 thread.send_to_model(model, None, cx);
282 thread.messages().len()
283 })?;
284
285 loop {
286 select_biased! {
287 result = rx.next() => {
288 if let Some(result) = result {
289 result?;
290 } else {
291 break;
292 }
293 }
294 _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
295 return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
296 }
297 }
298 }
299
300 let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
301 let mut messages = Vec::new();
302 for message in thread.messages().skip(message_count_before) {
303 messages.push(Message {
304 _role: message.role,
305 _text: message.to_string(),
306 tool_use: thread
307 .tool_uses_for_message(message.id, cx)
308 .into_iter()
309 .map(|tool_use| ToolUse {
310 name: tool_use.name.to_string(),
311 value: tool_use.input,
312 })
313 .collect(),
314 });
315 }
316 messages
317 })?;
318
319 let response = Response::new(messages);
320
321 Ok(response)
322 }
323}
324
325#[derive(Debug)]
326pub struct Response {
327 messages: Vec<Message>,
328}
329
330impl Response {
331 pub fn new(messages: Vec<Message>) -> Self {
332 Self { messages }
333 }
334
335 pub fn expect_tool(
336 &self,
337 tool_name: &'static str,
338 cx: &mut ExampleContext,
339 ) -> Result<&ToolUse> {
340 let result = self.messages.iter().find_map(|msg| {
341 msg.tool_use
342 .iter()
343 .find(|tool_use| tool_use.name == tool_name)
344 });
345 cx.assert_some(result, format!("called `{}`", tool_name))
346 }
347}
348
349#[derive(Debug)]
350pub struct Message {
351 _role: Role,
352 _text: String,
353 tool_use: Vec<ToolUse>,
354}
355
356#[derive(Debug)]
357pub struct ToolUse {
358 name: String,
359 value: serde_json::Value,
360}
361
362impl ToolUse {
363 pub fn expect_input<Input>(&self, cx: &mut ExampleContext) -> Result<Input>
364 where
365 Input: for<'de> serde::Deserialize<'de>,
366 {
367 let result =
368 serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err));
369 cx.log_assertion(result, format!("valid `{}` input", &self.name))
370 }
371}