1use std::{
2 error::Error,
3 fmt::{self, Debug},
4 path::Path,
5 sync::{Arc, Mutex},
6 time::Duration,
7};
8
9use crate::{
10 ToolMetrics,
11 assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
12};
13use agent::{ContextLoadResult, Thread, ThreadEvent};
14use anyhow::{Result, anyhow};
15use assistant_settings::AgentProfileId;
16use async_trait::async_trait;
17use buffer_diff::DiffHunkStatus;
18use collections::HashMap;
19use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
20use gpui::{App, AppContext, AsyncApp, Entity};
21use language_model::{LanguageModel, Role, StopReason};
22
/// How long `ExampleContext::run_turns` waits for the next thread event before
/// declaring the agentic loop stalled (two minutes).
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(120);
24
25#[async_trait(?Send)]
26pub trait Example {
27 fn meta(&self) -> ExampleMetadata;
28 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
29 fn diff_assertions(&self) -> Vec<JudgeAssertion> {
30 Vec::new()
31 }
32 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
33 Vec::new()
34 }
35}
36
/// A named assertion described in natural language, returned by the
/// `Example::diff_assertions` / `Example::thread_assertions` hooks.
#[derive(Debug, Clone)]
pub struct JudgeAssertion {
    /// Identifier of the assertion.
    pub id: String,
    /// Human-readable statement of what should hold.
    pub description: String,
}
42
43#[derive(Clone, Debug)]
44pub struct ExampleMetadata {
45 pub name: String,
46 pub url: String,
47 pub revision: String,
48 pub language_server: Option<LanguageServer>,
49 pub max_assertions: Option<usize>,
50 pub profile_id: AgentProfileId,
51 pub existing_thread_json: Option<String>,
52}
53
/// Language-server settings attached to an example's metadata.
#[derive(Debug, Clone)]
pub struct LanguageServer {
    /// File extension the language server applies to (e.g. `rs`).
    pub file_extension: String,
    /// Whether diagnostics that existed before the run are tolerated.
    pub allow_preexisting_diagnostics: bool,
}
59
60impl ExampleMetadata {
61 pub fn repo_name(&self) -> String {
62 self.url
63 .split('/')
64 .next_back()
65 .unwrap_or(&"")
66 .trim_end_matches(".git")
67 .into()
68 }
69}
70
/// Error value produced when an example's assertion does not hold; wraps the
/// assertion's message.
pub struct FailedAssertion(pub String);

impl fmt::Debug for FailedAssertion {
    /// Writes the message prefixed with `Assertion failure: `.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Assertion failure: ")?;
        f.write_str(&self.0)
    }
}

impl fmt::Display for FailedAssertion {
    /// Writes the bare message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl Error for FailedAssertion {}
86
87pub struct ExampleContext {
88 meta: ExampleMetadata,
89 log_prefix: String,
90 agent_thread: Entity<agent::Thread>,
91 app: AsyncApp,
92 model: Arc<dyn LanguageModel>,
93 pub assertions: AssertionsReport,
94 pub tool_metrics: Arc<Mutex<ToolMetrics>>,
95}
96
97impl ExampleContext {
98 pub fn new(
99 meta: ExampleMetadata,
100 log_prefix: String,
101 agent_thread: Entity<agent::Thread>,
102 model: Arc<dyn LanguageModel>,
103 app: AsyncApp,
104 ) -> Self {
105 let assertions = AssertionsReport::new(meta.max_assertions);
106
107 Self {
108 meta,
109 log_prefix,
110 agent_thread,
111 assertions,
112 model,
113 app,
114 tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
115 }
116 }
117
118 pub fn push_user_message(&mut self, text: impl ToString) {
119 self.app
120 .update_entity(&self.agent_thread, |thread, cx| {
121 thread.insert_user_message(
122 text.to_string(),
123 ContextLoadResult::default(),
124 None,
125 Vec::new(),
126 cx,
127 );
128 })
129 .unwrap();
130 }
131
132 pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
133 let message = message.to_string();
134 self.log_assertion(
135 if expected {
136 Ok(())
137 } else {
138 Err(anyhow::Error::from(FailedAssertion(message.clone())))
139 },
140 message,
141 )
142 }
143
144 pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
145 let message = message.to_string();
146 self.log_assertion(
147 match option {
148 Some(value) => Ok(value),
149 None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
150 },
151 message,
152 )
153 }
154
155 #[allow(dead_code)]
156 pub fn assert_eq<T: PartialEq + Debug>(
157 &mut self,
158 left: T,
159 right: T,
160 message: impl ToString,
161 ) -> Result<()> {
162 let message = message.to_string();
163 self.log_assertion(
164 if left == right {
165 Ok(())
166 } else {
167 println!(
168 "{}{}",
169 self.log_prefix,
170 pretty_assertions::Comparison::new(&left, &right)
171 );
172 Err(anyhow::Error::from(FailedAssertion(message.clone())))
173 },
174 message,
175 )
176 }
177
178 fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
179 if let Some(max) = self.meta.max_assertions {
180 anyhow::ensure!(
181 self.assertions.run_count() <= max,
182 "More assertions were run than the stated max_assertions of {max}"
183 );
184 }
185
186 self.assertions.ran.push(RanAssertion {
187 id: message.clone(),
188 result: Ok(RanAssertionResult {
189 analysis: None,
190 passed: result.is_ok(),
191 }),
192 });
193
194 if result.is_ok() {
195 println!("{}✅ {}", self.log_prefix, message);
196 } else {
197 println!("{}❌ {}", self.log_prefix, message);
198 }
199
200 result
201 }
202
203 pub async fn run_to_end(&mut self) -> Result<Response> {
204 self.run_turns(u32::MAX).await
205 }
206
207 pub async fn run_turn(&mut self) -> Result<Response> {
208 self.run_turns(1).await
209 }
210
211 pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
212 let (mut tx, mut rx) = mpsc::channel(1);
213
214 let tool_metrics = self.tool_metrics.clone();
215 let log_prefix = self.log_prefix.clone();
216 let _subscription = self.app.subscribe(
217 &self.agent_thread,
218 move |thread, event: &ThreadEvent, cx| match event {
219 ThreadEvent::ShowError(thread_error) => {
220 tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
221 }
222 ThreadEvent::Stopped(reason) => match reason {
223 Ok(StopReason::EndTurn) => {
224 tx.close_channel();
225 }
226 Ok(StopReason::ToolUse) => {
227 if thread.read(cx).remaining_turns() == 0 {
228 tx.close_channel();
229 }
230 }
231 Ok(StopReason::MaxTokens) => {
232 tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
233 }
234 Ok(StopReason::Refusal) => {
235 tx.try_send(Err(anyhow!("Model refused to generate content")))
236 .ok();
237 }
238 Err(err) => {
239 tx.try_send(Err(anyhow!(err.clone()))).ok();
240 }
241 },
242 ThreadEvent::NewRequest
243 | ThreadEvent::StreamedAssistantText(_, _)
244 | ThreadEvent::StreamedAssistantThinking(_, _)
245 | ThreadEvent::UsePendingTools { .. }
246 | ThreadEvent::CompletionCanceled => {}
247 ThreadEvent::ToolFinished {
248 tool_use_id,
249 pending_tool_use,
250 ..
251 } => {
252 thread.update(cx, |thread, _cx| {
253 if let Some(tool_use) = pending_tool_use {
254 let mut tool_metrics = tool_metrics.lock().unwrap();
255 if let Some(tool_result) = thread.tool_result(&tool_use_id) {
256 let message = if tool_result.is_error {
257 format!("✖︎ {}", tool_use.name)
258 } else {
259 format!("✔︎ {}", tool_use.name)
260 };
261 println!("{log_prefix}{message}");
262 tool_metrics
263 .insert(tool_result.tool_name.clone(), !tool_result.is_error);
264 } else {
265 let message =
266 format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
267 println!("{log_prefix}{message}");
268 tool_metrics.insert(tool_use.name.clone(), true);
269 }
270 }
271 });
272 }
273 ThreadEvent::InvalidToolInput { .. } => {
274 println!("{log_prefix} invalid tool input");
275 }
276 ThreadEvent::MissingToolUse {
277 tool_use_id: _,
278 ui_text,
279 } => {
280 println!("{log_prefix} {ui_text}");
281 }
282 ThreadEvent::ToolConfirmationNeeded => {
283 panic!(
284 "{}Bug: Tool confirmation should not be required in eval",
285 log_prefix
286 );
287 }
288 ThreadEvent::StreamedCompletion
289 | ThreadEvent::MessageAdded(_)
290 | ThreadEvent::MessageEdited(_)
291 | ThreadEvent::MessageDeleted(_)
292 | ThreadEvent::SummaryChanged
293 | ThreadEvent::SummaryGenerated
294 | ThreadEvent::ReceivedTextChunk
295 | ThreadEvent::StreamedToolUse { .. }
296 | ThreadEvent::CheckpointChanged
297 | ThreadEvent::CancelEditing => {
298 tx.try_send(Ok(())).ok();
299 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
300 println!("{}Event: {:#?}", log_prefix, event);
301 }
302 }
303 },
304 );
305
306 let model = self.model.clone();
307
308 let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
309 thread.set_remaining_turns(iterations);
310 thread.send_to_model(model, None, cx);
311 thread.messages().len()
312 })?;
313
314 loop {
315 select_biased! {
316 result = rx.next() => {
317 if let Some(result) = result {
318 result?;
319 } else {
320 break;
321 }
322 }
323 _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
324 anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events");
325 }
326 }
327 }
328
329 let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
330 let mut messages = Vec::new();
331 for message in thread.messages().skip(message_count_before) {
332 messages.push(Message {
333 _role: message.role,
334 text: message.to_string(),
335 tool_use: thread
336 .tool_uses_for_message(message.id, cx)
337 .into_iter()
338 .map(|tool_use| ToolUse {
339 name: tool_use.name.to_string(),
340 value: tool_use.input,
341 })
342 .collect(),
343 });
344 }
345 messages
346 })?;
347
348 let response = Response::new(messages);
349
350 Ok(response)
351 }
352
353 pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
354 self.agent_thread
355 .read_with(&self.app, |thread, cx| {
356 let action_log = thread.action_log().read(cx);
357 HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
358 |(buffer, diff)| {
359 let snapshot = buffer.read(cx).snapshot();
360
361 let file = snapshot.file().unwrap();
362 let diff = diff.read(cx);
363 let base_text = diff.base_text().text();
364
365 let hunks = diff
366 .hunks(&snapshot, cx)
367 .map(|hunk| FileEditHunk {
368 base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
369 text: snapshot
370 .text_for_range(hunk.range.clone())
371 .collect::<String>(),
372 status: hunk.status(),
373 })
374 .collect();
375
376 (file.path().clone(), FileEdits { hunks })
377 },
378 ))
379 })
380 .unwrap()
381 }
382
383 pub fn agent_thread(&self) -> Entity<Thread> {
384 self.agent_thread.clone()
385 }
386}
387
388impl AppContext for ExampleContext {
389 type Result<T> = anyhow::Result<T>;
390
391 fn new<T: 'static>(
392 &mut self,
393 build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
394 ) -> Self::Result<Entity<T>> {
395 self.app.new(build_entity)
396 }
397
398 fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
399 self.app.reserve_entity()
400 }
401
402 fn insert_entity<T: 'static>(
403 &mut self,
404 reservation: gpui::Reservation<T>,
405 build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
406 ) -> Self::Result<Entity<T>> {
407 self.app.insert_entity(reservation, build_entity)
408 }
409
410 fn update_entity<T, R>(
411 &mut self,
412 handle: &Entity<T>,
413 update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
414 ) -> Self::Result<R>
415 where
416 T: 'static,
417 {
418 self.app.update_entity(handle, update)
419 }
420
421 fn read_entity<T, R>(
422 &self,
423 handle: &Entity<T>,
424 read: impl FnOnce(&T, &App) -> R,
425 ) -> Self::Result<R>
426 where
427 T: 'static,
428 {
429 self.app.read_entity(handle, read)
430 }
431
432 fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
433 where
434 F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
435 {
436 self.app.update_window(window, f)
437 }
438
439 fn read_window<T, R>(
440 &self,
441 window: &gpui::WindowHandle<T>,
442 read: impl FnOnce(Entity<T>, &App) -> R,
443 ) -> Result<R>
444 where
445 T: 'static,
446 {
447 self.app.read_window(window, read)
448 }
449
450 fn background_spawn<R>(
451 &self,
452 future: impl std::future::Future<Output = R> + Send + 'static,
453 ) -> gpui::Task<R>
454 where
455 R: Send + 'static,
456 {
457 self.app.background_spawn(future)
458 }
459
460 fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
461 where
462 G: gpui::Global,
463 {
464 self.app.read_global(callback)
465 }
466}
467
468#[derive(Debug)]
469pub struct Response {
470 messages: Vec<Message>,
471}
472
473impl Response {
474 pub fn new(messages: Vec<Message>) -> Self {
475 Self { messages }
476 }
477
478 pub fn expect_tool(
479 &self,
480 tool_name: &'static str,
481 cx: &mut ExampleContext,
482 ) -> Result<&ToolUse> {
483 let result = self.find_tool_call(tool_name);
484 cx.assert_some(result, format!("called `{}`", tool_name))
485 }
486
487 pub fn find_tool_call(&self, tool_name: &str) -> Option<&ToolUse> {
488 self.messages.iter().rev().find_map(|msg| {
489 msg.tool_use
490 .iter()
491 .find(|tool_use| tool_use.name == tool_name)
492 })
493 }
494
495 #[allow(dead_code)]
496 pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
497 self.messages.iter().flat_map(|msg| &msg.tool_use)
498 }
499
500 pub fn texts(&self) -> impl Iterator<Item = String> {
501 self.messages.iter().map(|message| message.text.clone())
502 }
503}
504
505#[derive(Debug)]
506pub struct Message {
507 _role: Role,
508 text: String,
509 tool_use: Vec<ToolUse>,
510}
511
512#[derive(Debug)]
513pub struct ToolUse {
514 pub name: String,
515 value: serde_json::Value,
516}
517
518impl ToolUse {
519 pub fn parse_input<Input>(&self) -> Result<Input>
520 where
521 Input: for<'de> serde::Deserialize<'de>,
522 {
523 serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
524 }
525}
526
527#[derive(Debug, Eq, PartialEq)]
528pub struct FileEdits {
529 pub hunks: Vec<FileEditHunk>,
530}
531
532#[derive(Debug, Eq, PartialEq)]
533pub struct FileEditHunk {
534 pub base_text: String,
535 pub text: String,
536 pub status: DiffHunkStatus,
537}
538
539impl FileEdits {
540 pub fn has_added_line(&self, line: &str) -> bool {
541 self.hunks.iter().any(|hunk| {
542 hunk.status == DiffHunkStatus::added_none()
543 && hunk.base_text.is_empty()
544 && hunk.text.contains(line)
545 })
546 }
547}