1use std::{
2 error::Error,
3 fmt::{self, Debug},
4 path::Path,
5 sync::{Arc, Mutex},
6 time::Duration,
7};
8
9use crate::{
10 ToolMetrics,
11 assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
12};
13use agent::{ContextLoadResult, Thread, ThreadEvent};
14use anyhow::{Result, anyhow};
15use assistant_settings::AgentProfileId;
16use async_trait::async_trait;
17use buffer_diff::DiffHunkStatus;
18use collections::HashMap;
19use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
20use gpui::{App, AppContext, AsyncApp, Entity};
21use language_model::{LanguageModel, Role, StopReason};
22
/// How long [`ExampleContext::run_turns`] waits for the next progress event from the
/// agent thread before giving up with an "Agentic loop stalled" error (two minutes).
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(120);
24
25#[async_trait(?Send)]
26pub trait Example {
27 fn meta(&self) -> ExampleMetadata;
28 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
29 fn diff_assertions(&self) -> Vec<JudgeAssertion> {
30 Vec::new()
31 }
32 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
33 Vec::new()
34 }
35}
36
/// An assertion returned by [`Example::diff_assertions`] or [`Example::thread_assertions`].
#[derive(Clone, Debug)]
pub struct JudgeAssertion {
    /// Identifier of the assertion.
    pub id: String,
    /// Description of what the assertion checks.
    pub description: String,
}
42
43#[derive(Clone, Debug)]
44pub struct ExampleMetadata {
45 pub name: String,
46 pub url: String,
47 pub revision: String,
48 pub language_server: Option<LanguageServer>,
49 pub max_assertions: Option<usize>,
50 pub profile_id: AgentProfileId,
51}
52
/// Language-server requirements for an example.
#[derive(Clone, Debug)]
pub struct LanguageServer {
    /// Extension of the files the language server applies to.
    ///
    /// Whether a leading dot is included isn't shown here — confirm against callers.
    pub file_extension: String,
    /// Whether diagnostics that already exist before the run are tolerated.
    pub allow_preexisting_diagnostics: bool,
}
58
59impl ExampleMetadata {
60 pub fn repo_name(&self) -> String {
61 self.url
62 .split('/')
63 .next_back()
64 .unwrap_or(&"")
65 .trim_end_matches(".git")
66 .into()
67 }
68}
69
/// Error produced when a programmatic assertion in an example fails.
///
/// Wraps the assertion message. `Display` prints the bare message, and `Debug` prints it
/// prefixed with `Assertion failure: `. Formatter flags such as width are ignored by both.
pub struct FailedAssertion(pub String);

impl fmt::Debug for FailedAssertion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Assertion failure: ")?;
        f.write_str(&self.0)
    }
}

impl fmt::Display for FailedAssertion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.0.as_str())
    }
}

impl Error for FailedAssertion {}
85
86pub struct ExampleContext {
87 meta: ExampleMetadata,
88 log_prefix: String,
89 agent_thread: Entity<agent::Thread>,
90 app: AsyncApp,
91 model: Arc<dyn LanguageModel>,
92 pub assertions: AssertionsReport,
93 pub tool_metrics: Arc<Mutex<ToolMetrics>>,
94}
95
96impl ExampleContext {
97 pub fn new(
98 meta: ExampleMetadata,
99 log_prefix: String,
100 agent_thread: Entity<agent::Thread>,
101 model: Arc<dyn LanguageModel>,
102 app: AsyncApp,
103 ) -> Self {
104 let assertions = AssertionsReport::new(meta.max_assertions);
105
106 Self {
107 meta,
108 log_prefix,
109 agent_thread,
110 assertions,
111 model,
112 app,
113 tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
114 }
115 }
116
117 pub fn push_user_message(&mut self, text: impl ToString) {
118 self.app
119 .update_entity(&self.agent_thread, |thread, cx| {
120 thread.insert_user_message(
121 text.to_string(),
122 ContextLoadResult::default(),
123 None,
124 Vec::new(),
125 cx,
126 );
127 })
128 .unwrap();
129 }
130
131 pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
132 let message = message.to_string();
133 self.log_assertion(
134 if expected {
135 Ok(())
136 } else {
137 Err(anyhow::Error::from(FailedAssertion(message.clone())))
138 },
139 message,
140 )
141 }
142
143 pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
144 let message = message.to_string();
145 self.log_assertion(
146 match option {
147 Some(value) => Ok(value),
148 None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
149 },
150 message,
151 )
152 }
153
154 #[allow(dead_code)]
155 pub fn assert_eq<T: PartialEq + Debug>(
156 &mut self,
157 left: T,
158 right: T,
159 message: impl ToString,
160 ) -> Result<()> {
161 let message = message.to_string();
162 self.log_assertion(
163 if left == right {
164 Ok(())
165 } else {
166 println!(
167 "{}{}",
168 self.log_prefix,
169 pretty_assertions::Comparison::new(&left, &right)
170 );
171 Err(anyhow::Error::from(FailedAssertion(message.clone())))
172 },
173 message,
174 )
175 }
176
177 fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
178 if let Some(max) = self.meta.max_assertions {
179 if self.assertions.run_count() > max {
180 return Err(anyhow!(
181 "More assertions were run than the stated max_assertions of {}",
182 max
183 ));
184 }
185 }
186
187 self.assertions.ran.push(RanAssertion {
188 id: message.clone(),
189 result: Ok(RanAssertionResult {
190 analysis: None,
191 passed: result.is_ok(),
192 }),
193 });
194
195 if result.is_ok() {
196 println!("{}✅ {}", self.log_prefix, message);
197 } else {
198 println!("{}❌ {}", self.log_prefix, message);
199 }
200
201 result
202 }
203
204 pub async fn run_to_end(&mut self) -> Result<Response> {
205 self.run_turns(u32::MAX).await
206 }
207
208 pub async fn run_turn(&mut self) -> Result<Response> {
209 self.run_turns(1).await
210 }
211
212 pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
213 let (mut tx, mut rx) = mpsc::channel(1);
214
215 let tool_metrics = self.tool_metrics.clone();
216 let log_prefix = self.log_prefix.clone();
217 let _subscription = self.app.subscribe(
218 &self.agent_thread,
219 move |thread, event: &ThreadEvent, cx| match event {
220 ThreadEvent::ShowError(thread_error) => {
221 tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
222 }
223 ThreadEvent::Stopped(reason) => match reason {
224 Ok(StopReason::EndTurn) => {
225 tx.close_channel();
226 }
227 Ok(StopReason::ToolUse) => {
228 if thread.read(cx).remaining_turns() == 0 {
229 tx.close_channel();
230 }
231 }
232 Ok(StopReason::MaxTokens) => {
233 tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
234 }
235 Err(err) => {
236 tx.try_send(Err(anyhow!(err.clone()))).ok();
237 }
238 },
239 ThreadEvent::NewRequest
240 | ThreadEvent::StreamedAssistantText(_, _)
241 | ThreadEvent::StreamedAssistantThinking(_, _)
242 | ThreadEvent::UsePendingTools { .. }
243 | ThreadEvent::CompletionCanceled => {}
244 ThreadEvent::ToolFinished {
245 tool_use_id,
246 pending_tool_use,
247 ..
248 } => {
249 thread.update(cx, |thread, _cx| {
250 if let Some(tool_use) = pending_tool_use {
251 let mut tool_metrics = tool_metrics.lock().unwrap();
252 if let Some(tool_result) = thread.tool_result(&tool_use_id) {
253 let message = if tool_result.is_error {
254 format!("✖︎ {}", tool_use.name)
255 } else {
256 format!("✔︎ {}", tool_use.name)
257 };
258 println!("{log_prefix}{message}");
259 tool_metrics
260 .insert(tool_result.tool_name.clone(), !tool_result.is_error);
261 } else {
262 let message =
263 format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
264 println!("{log_prefix}{message}");
265 tool_metrics.insert(tool_use.name.clone(), true);
266 }
267 }
268 });
269 }
270 ThreadEvent::InvalidToolInput { .. } => {
271 println!("{log_prefix} invalid tool input");
272 }
273 ThreadEvent::MissingToolUse {
274 tool_use_id: _,
275 ui_text,
276 } => {
277 println!("{log_prefix} {ui_text}");
278 }
279 ThreadEvent::ToolConfirmationNeeded => {
280 panic!(
281 "{}Bug: Tool confirmation should not be required in eval",
282 log_prefix
283 );
284 }
285 ThreadEvent::StreamedCompletion
286 | ThreadEvent::MessageAdded(_)
287 | ThreadEvent::MessageEdited(_)
288 | ThreadEvent::MessageDeleted(_)
289 | ThreadEvent::SummaryChanged
290 | ThreadEvent::SummaryGenerated
291 | ThreadEvent::ReceivedTextChunk
292 | ThreadEvent::StreamedToolUse { .. }
293 | ThreadEvent::CheckpointChanged
294 | ThreadEvent::CancelEditing => {
295 tx.try_send(Ok(())).ok();
296 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
297 println!("{}Event: {:#?}", log_prefix, event);
298 }
299 }
300 },
301 );
302
303 let model = self.model.clone();
304
305 let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
306 thread.set_remaining_turns(iterations);
307 thread.send_to_model(model, None, cx);
308 thread.messages().len()
309 })?;
310
311 loop {
312 select_biased! {
313 result = rx.next() => {
314 if let Some(result) = result {
315 result?;
316 } else {
317 break;
318 }
319 }
320 _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
321 return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
322 }
323 }
324 }
325
326 let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
327 let mut messages = Vec::new();
328 for message in thread.messages().skip(message_count_before) {
329 messages.push(Message {
330 _role: message.role,
331 text: message.to_string(),
332 tool_use: thread
333 .tool_uses_for_message(message.id, cx)
334 .into_iter()
335 .map(|tool_use| ToolUse {
336 name: tool_use.name.to_string(),
337 value: tool_use.input,
338 })
339 .collect(),
340 });
341 }
342 messages
343 })?;
344
345 let response = Response::new(messages);
346
347 Ok(response)
348 }
349
350 pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
351 self.agent_thread
352 .read_with(&self.app, |thread, cx| {
353 let action_log = thread.action_log().read(cx);
354 HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
355 |(buffer, diff)| {
356 let snapshot = buffer.read(cx).snapshot();
357
358 let file = snapshot.file().unwrap();
359 let diff = diff.read(cx);
360 let base_text = diff.base_text().text();
361
362 let hunks = diff
363 .hunks(&snapshot, cx)
364 .map(|hunk| FileEditHunk {
365 base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
366 text: snapshot
367 .text_for_range(hunk.range.clone())
368 .collect::<String>(),
369 status: hunk.status(),
370 })
371 .collect();
372
373 (file.path().clone(), FileEdits { hunks })
374 },
375 ))
376 })
377 .unwrap()
378 }
379
380 pub fn agent_thread(&self) -> Entity<Thread> {
381 self.agent_thread.clone()
382 }
383}
384
/// Lets an `ExampleContext` be used wherever a gpui `AppContext` is expected.
///
/// Every method forwards unchanged to the wrapped `AsyncApp` stored in `self.app`.
impl AppContext for ExampleContext {
    type Result<T> = anyhow::Result<T>;

    /// Forwards entity creation to `self.app`.
    fn new<T: 'static>(
        &mut self,
        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
    ) -> Self::Result<Entity<T>> {
        self.app.new(build_entity)
    }

    /// Forwards entity reservation to `self.app`.
    fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
        self.app.reserve_entity()
    }

    /// Forwards insertion of a previously reserved entity to `self.app`.
    fn insert_entity<T: 'static>(
        &mut self,
        reservation: gpui::Reservation<T>,
        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
    ) -> Self::Result<Entity<T>> {
        self.app.insert_entity(reservation, build_entity)
    }

    /// Forwards a mutable entity update to `self.app`.
    fn update_entity<T, R>(
        &mut self,
        handle: &Entity<T>,
        update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
    ) -> Self::Result<R>
    where
        T: 'static,
    {
        self.app.update_entity(handle, update)
    }

    /// Forwards a read-only entity access to `self.app`.
    fn read_entity<T, R>(
        &self,
        handle: &Entity<T>,
        read: impl FnOnce(&T, &App) -> R,
    ) -> Self::Result<R>
    where
        T: 'static,
    {
        self.app.read_entity(handle, read)
    }

    /// Forwards a window update to `self.app`.
    fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
    where
        F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
    {
        self.app.update_window(window, f)
    }

    /// Forwards a read-only window access to `self.app`.
    fn read_window<T, R>(
        &self,
        window: &gpui::WindowHandle<T>,
        read: impl FnOnce(Entity<T>, &App) -> R,
    ) -> Result<R>
    where
        T: 'static,
    {
        self.app.read_window(window, read)
    }

    /// Forwards spawning of a `Send` future on the background executor to `self.app`.
    fn background_spawn<R>(
        &self,
        future: impl std::future::Future<Output = R> + Send + 'static,
    ) -> gpui::Task<R>
    where
        R: Send + 'static,
    {
        self.app.background_spawn(future)
    }

    /// Forwards a read of global state `G` to `self.app`.
    fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
    where
        G: gpui::Global,
    {
        self.app.read_global(callback)
    }
}
464
465#[derive(Debug)]
466pub struct Response {
467 messages: Vec<Message>,
468}
469
470impl Response {
471 pub fn new(messages: Vec<Message>) -> Self {
472 Self { messages }
473 }
474
475 pub fn expect_tool(
476 &self,
477 tool_name: &'static str,
478 cx: &mut ExampleContext,
479 ) -> Result<&ToolUse> {
480 let result = self.messages.iter().find_map(|msg| {
481 msg.tool_use
482 .iter()
483 .find(|tool_use| tool_use.name == tool_name)
484 });
485 cx.assert_some(result, format!("called `{}`", tool_name))
486 }
487
488 #[allow(dead_code)]
489 pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
490 self.messages.iter().flat_map(|msg| &msg.tool_use)
491 }
492
493 pub fn texts(&self) -> impl Iterator<Item = String> {
494 self.messages.iter().map(|message| message.text.clone())
495 }
496}
497
498#[derive(Debug)]
499pub struct Message {
500 _role: Role,
501 text: String,
502 tool_use: Vec<ToolUse>,
503}
504
505#[derive(Debug)]
506pub struct ToolUse {
507 pub name: String,
508 value: serde_json::Value,
509}
510
511impl ToolUse {
512 pub fn parse_input<Input>(&self) -> Result<Input>
513 where
514 Input: for<'de> serde::Deserialize<'de>,
515 {
516 serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
517 }
518}
519
520#[derive(Debug, Eq, PartialEq)]
521pub struct FileEdits {
522 pub hunks: Vec<FileEditHunk>,
523}
524
525#[derive(Debug, Eq, PartialEq)]
526pub struct FileEditHunk {
527 pub base_text: String,
528 pub text: String,
529 pub status: DiffHunkStatus,
530}
531
532impl FileEdits {
533 pub fn has_added_line(&self, line: &str) -> bool {
534 self.hunks.iter().any(|hunk| {
535 hunk.status == DiffHunkStatus::added_none()
536 && hunk.base_text.is_empty()
537 && hunk.text.contains(line)
538 })
539 }
540}