1use std::{
2 error::Error,
3 fmt::{self, Debug},
4 path::Path,
5 sync::{Arc, Mutex},
6 time::Duration,
7};
8
9use crate::{
10 ToolMetrics,
11 assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
12};
13use agent::{ThreadEvent, ZedAgentThread};
14use agent_settings::AgentProfileId;
15use anyhow::{Result, anyhow};
16use async_trait::async_trait;
17use buffer_diff::DiffHunkStatus;
18use collections::HashMap;
19use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
20use gpui::{App, AppContext, AsyncApp, Entity};
21use language_model::{LanguageModel, Role, StopReason};
22use zed_llm_client::CompletionIntent;
23
/// Longest time `ExampleContext::run_turns` waits for the next forwarded
/// thread event before it reports the agentic loop as stalled (two minutes).
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
25
26#[async_trait(?Send)]
27pub trait Example {
28 fn meta(&self) -> ExampleMetadata;
29 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
30 fn diff_assertions(&self) -> Vec<JudgeAssertion> {
31 Vec::new()
32 }
33 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
34 Vec::new()
35 }
36}
37
/// An assertion identified by `id` and explained by a free-form `description`.
#[derive(Debug, Clone)]
pub struct JudgeAssertion {
    /// Identifier of the assertion.
    pub id: String,
    /// Human-readable statement of what is being asserted.
    pub description: String,
}
43
44#[derive(Clone, Debug)]
45pub struct ExampleMetadata {
46 pub name: String,
47 pub url: String,
48 pub revision: String,
49 pub language_server: Option<LanguageServer>,
50 pub max_assertions: Option<usize>,
51 pub profile_id: AgentProfileId,
52 pub existing_thread_json: Option<String>,
53 pub max_turns: Option<u32>,
54}
55
/// Language-server settings attached to an [`ExampleMetadata`].
#[derive(Debug, Clone)]
pub struct LanguageServer {
    /// File extension associated with the language server.
    // NOTE(review): presumably used to pick a file that triggers the server — confirm.
    pub file_extension: String,
    /// Whether diagnostics that already exist are permitted.
    // NOTE(review): presumably refers to diagnostics present before the agent runs — confirm.
    pub allow_preexisting_diagnostics: bool,
}
61
62impl ExampleMetadata {
63 pub fn repo_name(&self) -> String {
64 self.url
65 .split('/')
66 .next_back()
67 .unwrap_or(&"")
68 .trim_end_matches(".git")
69 .into()
70 }
71}
72
/// Error describing an assertion that did not hold; wraps the assertion message.
pub struct FailedAssertion(pub String);

/// Debug output is the message prefixed with `Assertion failure: `.
impl fmt::Debug for FailedAssertion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Assertion failure: ")?;
        f.write_str(&self.0)
    }
}

/// Display output is the bare message.
impl fmt::Display for FailedAssertion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Written straight to the formatter: like the `write!` form, this
        // does not apply any width/fill flags of the outer format spec.
        f.write_str(&self.0)
    }
}

impl Error for FailedAssertion {}
88
89pub struct ExampleContext {
90 meta: ExampleMetadata,
91 log_prefix: String,
92 agent_thread: Entity<agent::ZedAgentThread>,
93 app: AsyncApp,
94 model: Arc<dyn LanguageModel>,
95 pub assertions: AssertionsReport,
96 pub tool_metrics: Arc<Mutex<ToolMetrics>>,
97}
98
99impl ExampleContext {
100 pub fn new(
101 meta: ExampleMetadata,
102 log_prefix: String,
103 agent_thread: Entity<ZedAgentThread>,
104 model: Arc<dyn LanguageModel>,
105 app: AsyncApp,
106 ) -> Self {
107 let assertions = AssertionsReport::new(meta.max_assertions);
108
109 Self {
110 meta,
111 log_prefix,
112 agent_thread,
113 assertions,
114 model,
115 app,
116 tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
117 }
118 }
119
120 pub fn push_user_message(&mut self, text: impl ToString) {
121 self.app
122 .update_entity(&self.agent_thread, |thread, cx| {
123 thread.insert_user_message(text.to_string(), cx);
124 })
125 .unwrap();
126 }
127
128 pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
129 let message = message.to_string();
130 self.log_assertion(
131 if expected {
132 Ok(())
133 } else {
134 Err(anyhow::Error::from(FailedAssertion(message.clone())))
135 },
136 message,
137 )
138 }
139
140 pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
141 let message = message.to_string();
142 self.log_assertion(
143 match option {
144 Some(value) => Ok(value),
145 None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
146 },
147 message,
148 )
149 }
150
151 #[allow(dead_code)]
152 pub fn assert_eq<T: PartialEq + Debug>(
153 &mut self,
154 left: T,
155 right: T,
156 message: impl ToString,
157 ) -> Result<()> {
158 let message = message.to_string();
159 self.log_assertion(
160 if left == right {
161 Ok(())
162 } else {
163 println!(
164 "{}{}",
165 self.log_prefix,
166 pretty_assertions::Comparison::new(&left, &right)
167 );
168 Err(anyhow::Error::from(FailedAssertion(message.clone())))
169 },
170 message,
171 )
172 }
173
174 fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
175 if let Some(max) = self.meta.max_assertions {
176 anyhow::ensure!(
177 self.assertions.run_count() <= max,
178 "More assertions were run than the stated max_assertions of {max}"
179 );
180 }
181
182 self.assertions.ran.push(RanAssertion {
183 id: message.clone(),
184 result: Ok(RanAssertionResult {
185 analysis: None,
186 passed: result.is_ok(),
187 }),
188 });
189
190 if result.is_ok() {
191 println!("{}✅ {}", self.log_prefix, message);
192 } else {
193 println!("{}❌ {}", self.log_prefix, message);
194 }
195
196 result
197 }
198
199 pub async fn run_to_end(&mut self) -> Result<Response> {
200 self.run_turns(u32::MAX).await
201 }
202
203 pub async fn run_turn(&mut self) -> Result<Response> {
204 self.run_turns(1).await
205 }
206
207 pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
208 let (mut tx, mut rx) = mpsc::channel(1);
209
210 let tool_metrics = self.tool_metrics.clone();
211 let log_prefix = self.log_prefix.clone();
212 let _subscription = self.app.subscribe(
213 &self.agent_thread,
214 move |thread, event: &ThreadEvent, cx| match event {
215 ThreadEvent::ShowError(thread_error) => {
216 tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
217 }
218 ThreadEvent::RetriesFailed { .. } => {
219 // Ignore retries failed events
220 }
221 ThreadEvent::Stopped(reason) => match reason {
222 Ok(StopReason::EndTurn) => {
223 tx.close_channel();
224 }
225 Ok(StopReason::ToolUse) => {
226 if thread.read(cx).remaining_turns() == 0 {
227 tx.close_channel();
228 }
229 }
230 Ok(StopReason::MaxTokens) => {
231 tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
232 }
233 Ok(StopReason::Refusal) => {
234 tx.try_send(Err(anyhow!("Model refused to generate content")))
235 .ok();
236 }
237 Err(err) => {
238 tx.try_send(Err(anyhow!(err.clone()))).ok();
239 }
240 },
241 ThreadEvent::NewRequest
242 | ThreadEvent::StreamedAssistantText(_, _)
243 | ThreadEvent::StreamedAssistantThinking(_, _)
244 | ThreadEvent::UsePendingTools { .. }
245 | ThreadEvent::CompletionCanceled => {}
246 ThreadEvent::ToolUseLimitReached => {}
247 ThreadEvent::StreamedToolUse2 { .. } => {}
248 ThreadEvent::ToolFinished {
249 tool_use_id,
250 pending_tool_use,
251 ..
252 } => {
253 thread.update(cx, |thread, _cx| {
254 if let Some(tool_use) = pending_tool_use {
255 let mut tool_metrics = tool_metrics.lock().unwrap();
256 if let Some(tool_result) = thread.tool_result(&tool_use_id) {
257 let message = if tool_result.is_error {
258 format!("✖︎ {}", tool_use.name)
259 } else {
260 format!("✔︎ {}", tool_use.name)
261 };
262 println!("{log_prefix}{message}");
263 tool_metrics
264 .insert(tool_result.tool_name.clone(), !tool_result.is_error);
265 } else {
266 let message =
267 format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
268 println!("{log_prefix}{message}");
269 tool_metrics.insert(tool_use.name.clone(), true);
270 }
271 }
272 });
273 }
274 ThreadEvent::InvalidToolInput { .. } => {
275 println!("{log_prefix} invalid tool input");
276 }
277 ThreadEvent::MissingToolUse {
278 tool_use_id: _,
279 ui_text,
280 } => {
281 println!("{log_prefix} {ui_text}");
282 }
283 ThreadEvent::ToolConfirmationNeeded => {
284 panic!(
285 "{}Bug: Tool confirmation should not be required in eval",
286 log_prefix
287 );
288 }
289 ThreadEvent::StreamedCompletion
290 | ThreadEvent::MessageAdded(_)
291 | ThreadEvent::MessageEdited(_)
292 | ThreadEvent::MessageDeleted(_)
293 | ThreadEvent::SummaryChanged
294 | ThreadEvent::SummaryGenerated
295 | ThreadEvent::ProfileChanged
296 | ThreadEvent::ReceivedTextChunk
297 | ThreadEvent::StreamedToolUse { .. }
298 | ThreadEvent::CheckpointChanged
299 | ThreadEvent::CancelEditing => {
300 tx.try_send(Ok(())).ok();
301 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
302 println!("{}Event: {:#?}", log_prefix, event);
303 }
304 }
305 },
306 );
307
308 let model = self.model.clone();
309
310 let message_count_before = self.app.update_entity(&self.agent_thread, |agent, cx| {
311 agent.set_remaining_turns(iterations);
312 agent.send_to_model(model, CompletionIntent::UserPrompt, None, cx);
313 agent.messages().len()
314 })?;
315
316 loop {
317 select_biased! {
318 result = rx.next() => {
319 if let Some(result) = result {
320 result?;
321 } else {
322 break;
323 }
324 }
325 _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
326 anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events");
327 }
328 }
329 }
330
331 let messages = self.app.read_entity(&self.agent_thread, |agent, cx| {
332 let mut messages = Vec::new();
333 for message in agent.messages().skip(message_count_before) {
334 messages.push(Message {
335 _role: message.role,
336 text: message.to_string(),
337 tool_use: agent
338 .tool_uses_for_message(message.id, cx)
339 .into_iter()
340 .map(|tool_use| ToolUse {
341 name: tool_use.name.to_string(),
342 value: tool_use.input,
343 })
344 .collect(),
345 });
346 }
347 messages
348 })?;
349
350 let response = Response::new(messages);
351
352 Ok(response)
353 }
354
355 pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
356 self.agent_thread
357 .read_with(&self.app, |thread, cx| {
358 let action_log = thread.action_log().read(cx);
359 HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
360 |(buffer, diff)| {
361 let snapshot = buffer.read(cx).snapshot();
362
363 let file = snapshot.file().unwrap();
364 let diff = diff.read(cx);
365 let base_text = diff.base_text().text();
366
367 let hunks = diff
368 .hunks(&snapshot, cx)
369 .map(|hunk| FileEditHunk {
370 base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
371 text: snapshot
372 .text_for_range(hunk.range.clone())
373 .collect::<String>(),
374 status: hunk.status(),
375 })
376 .collect();
377
378 (file.path().clone(), FileEdits { hunks })
379 },
380 ))
381 })
382 .unwrap()
383 }
384
385 pub fn agent_thread(&self) -> Entity<ZedAgentThread> {
386 self.agent_thread.clone()
387 }
388}
389
390impl AppContext for ExampleContext {
391 type Result<T> = anyhow::Result<T>;
392
393 fn new<T: 'static>(
394 &mut self,
395 build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
396 ) -> Self::Result<Entity<T>> {
397 self.app.new(build_entity)
398 }
399
400 fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
401 self.app.reserve_entity()
402 }
403
404 fn insert_entity<T: 'static>(
405 &mut self,
406 reservation: gpui::Reservation<T>,
407 build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
408 ) -> Self::Result<Entity<T>> {
409 self.app.insert_entity(reservation, build_entity)
410 }
411
412 fn update_entity<T, R>(
413 &mut self,
414 handle: &Entity<T>,
415 update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
416 ) -> Self::Result<R>
417 where
418 T: 'static,
419 {
420 self.app.update_entity(handle, update)
421 }
422
423 fn read_entity<T, R>(
424 &self,
425 handle: &Entity<T>,
426 read: impl FnOnce(&T, &App) -> R,
427 ) -> Self::Result<R>
428 where
429 T: 'static,
430 {
431 self.app.read_entity(handle, read)
432 }
433
434 fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
435 where
436 F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
437 {
438 self.app.update_window(window, f)
439 }
440
441 fn read_window<T, R>(
442 &self,
443 window: &gpui::WindowHandle<T>,
444 read: impl FnOnce(Entity<T>, &App) -> R,
445 ) -> Result<R>
446 where
447 T: 'static,
448 {
449 self.app.read_window(window, read)
450 }
451
452 fn background_spawn<R>(
453 &self,
454 future: impl std::future::Future<Output = R> + Send + 'static,
455 ) -> gpui::Task<R>
456 where
457 R: Send + 'static,
458 {
459 self.app.background_spawn(future)
460 }
461
462 fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
463 where
464 G: gpui::Global,
465 {
466 self.app.read_global(callback)
467 }
468}
469
470#[derive(Debug)]
471pub struct Response {
472 messages: Vec<Message>,
473}
474
475impl Response {
476 pub fn new(messages: Vec<Message>) -> Self {
477 Self { messages }
478 }
479
480 pub fn expect_tool(
481 &self,
482 tool_name: &'static str,
483 cx: &mut ExampleContext,
484 ) -> Result<&ToolUse> {
485 let result = self.find_tool_call(tool_name);
486 cx.assert_some(result, format!("called `{}`", tool_name))
487 }
488
489 pub fn find_tool_call(&self, tool_name: &str) -> Option<&ToolUse> {
490 self.messages.iter().rev().find_map(|msg| {
491 msg.tool_use
492 .iter()
493 .find(|tool_use| tool_use.name == tool_name)
494 })
495 }
496
497 #[allow(dead_code)]
498 pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
499 self.messages.iter().flat_map(|msg| &msg.tool_use)
500 }
501
502 pub fn texts(&self) -> impl Iterator<Item = String> {
503 self.messages.iter().map(|message| message.text.clone())
504 }
505}
506
507#[derive(Debug)]
508pub struct Message {
509 _role: Role,
510 text: String,
511 tool_use: Vec<ToolUse>,
512}
513
514#[derive(Debug)]
515pub struct ToolUse {
516 pub name: String,
517 value: serde_json::Value,
518}
519
520impl ToolUse {
521 pub fn parse_input<Input>(&self) -> Result<Input>
522 where
523 Input: for<'de> serde::Deserialize<'de>,
524 {
525 serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
526 }
527}
528
529#[derive(Debug, Eq, PartialEq)]
530pub struct FileEdits {
531 pub hunks: Vec<FileEditHunk>,
532}
533
534#[derive(Debug, Eq, PartialEq)]
535pub struct FileEditHunk {
536 pub base_text: String,
537 pub text: String,
538 pub status: DiffHunkStatus,
539}
540
541impl FileEdits {
542 pub fn has_added_line(&self, line: &str) -> bool {
543 self.hunks.iter().any(|hunk| {
544 hunk.status == DiffHunkStatus::added_none()
545 && hunk.base_text.is_empty()
546 && hunk.text.contains(line)
547 })
548 }
549}