example.rs

  1use std::{
  2    error::Error,
  3    fmt::{self, Debug},
  4    path::Path,
  5    sync::{Arc, Mutex},
  6    time::Duration,
  7};
  8
  9use crate::{
 10    ToolMetrics,
 11    assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
 12};
 13use agent::{ContextLoadResult, Thread, ThreadEvent};
 14use agent_settings::AgentProfileId;
 15use anyhow::{Result, anyhow};
 16use async_trait::async_trait;
 17use buffer_diff::DiffHunkStatus;
 18use collections::HashMap;
 19use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
 20use gpui::{App, AppContext, AsyncApp, Entity};
 21use language_model::{LanguageModel, Role, StopReason};
 22
 23pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
 24
 25#[async_trait(?Send)]
 26pub trait Example {
 27    fn meta(&self) -> ExampleMetadata;
 28    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
 29    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
 30        Vec::new()
 31    }
 32    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
 33        Vec::new()
 34    }
 35}
 36
 37#[derive(Clone, Debug)]
 38pub struct JudgeAssertion {
 39    pub id: String,
 40    pub description: String,
 41}
 42
 43#[derive(Clone, Debug)]
 44pub struct ExampleMetadata {
 45    pub name: String,
 46    pub url: String,
 47    pub revision: String,
 48    pub language_server: Option<LanguageServer>,
 49    pub max_assertions: Option<usize>,
 50    pub profile_id: AgentProfileId,
 51    pub existing_thread_json: Option<String>,
 52    pub max_turns: Option<u32>,
 53}
 54
 55#[derive(Clone, Debug)]
 56pub struct LanguageServer {
 57    pub file_extension: String,
 58    pub allow_preexisting_diagnostics: bool,
 59}
 60
 61impl ExampleMetadata {
 62    pub fn repo_name(&self) -> String {
 63        self.url
 64            .split('/')
 65            .next_back()
 66            .unwrap_or(&"")
 67            .trim_end_matches(".git")
 68            .into()
 69    }
 70}
 71
 72pub struct FailedAssertion(pub String);
 73
 74impl fmt::Debug for FailedAssertion {
 75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 76        write!(f, "Assertion failure: {}", self.0)
 77    }
 78}
 79
 80impl fmt::Display for FailedAssertion {
 81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 82        write!(f, "{}", self.0)
 83    }
 84}
 85
 86impl Error for FailedAssertion {}
 87
 88pub struct ExampleContext {
 89    meta: ExampleMetadata,
 90    log_prefix: String,
 91    agent_thread: Entity<agent::Thread>,
 92    app: AsyncApp,
 93    model: Arc<dyn LanguageModel>,
 94    pub assertions: AssertionsReport,
 95    pub tool_metrics: Arc<Mutex<ToolMetrics>>,
 96}
 97
 98impl ExampleContext {
 99    pub fn new(
100        meta: ExampleMetadata,
101        log_prefix: String,
102        agent_thread: Entity<agent::Thread>,
103        model: Arc<dyn LanguageModel>,
104        app: AsyncApp,
105    ) -> Self {
106        let assertions = AssertionsReport::new(meta.max_assertions);
107
108        Self {
109            meta,
110            log_prefix,
111            agent_thread,
112            assertions,
113            model,
114            app,
115            tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
116        }
117    }
118
119    pub fn push_user_message(&mut self, text: impl ToString) {
120        self.app
121            .update_entity(&self.agent_thread, |thread, cx| {
122                thread.insert_user_message(
123                    text.to_string(),
124                    ContextLoadResult::default(),
125                    None,
126                    Vec::new(),
127                    cx,
128                );
129            })
130            .unwrap();
131    }
132
133    pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
134        let message = message.to_string();
135        self.log_assertion(
136            if expected {
137                Ok(())
138            } else {
139                Err(anyhow::Error::from(FailedAssertion(message.clone())))
140            },
141            message,
142        )
143    }
144
145    pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
146        let message = message.to_string();
147        self.log_assertion(
148            match option {
149                Some(value) => Ok(value),
150                None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
151            },
152            message,
153        )
154    }
155
156    #[allow(dead_code)]
157    pub fn assert_eq<T: PartialEq + Debug>(
158        &mut self,
159        left: T,
160        right: T,
161        message: impl ToString,
162    ) -> Result<()> {
163        let message = message.to_string();
164        self.log_assertion(
165            if left == right {
166                Ok(())
167            } else {
168                println!(
169                    "{}{}",
170                    self.log_prefix,
171                    pretty_assertions::Comparison::new(&left, &right)
172                );
173                Err(anyhow::Error::from(FailedAssertion(message.clone())))
174            },
175            message,
176        )
177    }
178
179    fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
180        if let Some(max) = self.meta.max_assertions {
181            anyhow::ensure!(
182                self.assertions.run_count() <= max,
183                "More assertions were run than the stated max_assertions of {max}"
184            );
185        }
186
187        self.assertions.ran.push(RanAssertion {
188            id: message.clone(),
189            result: Ok(RanAssertionResult {
190                analysis: None,
191                passed: result.is_ok(),
192            }),
193        });
194
195        if result.is_ok() {
196            println!("{}{}", self.log_prefix, message);
197        } else {
198            println!("{}{}", self.log_prefix, message);
199        }
200
201        result
202    }
203
204    pub async fn run_to_end(&mut self) -> Result<Response> {
205        self.run_turns(u32::MAX).await
206    }
207
208    pub async fn run_turn(&mut self) -> Result<Response> {
209        self.run_turns(1).await
210    }
211
212    pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
213        let (mut tx, mut rx) = mpsc::channel(1);
214
215        let tool_metrics = self.tool_metrics.clone();
216        let log_prefix = self.log_prefix.clone();
217        let _subscription = self.app.subscribe(
218            &self.agent_thread,
219            move |thread, event: &ThreadEvent, cx| match event {
220                ThreadEvent::ShowError(thread_error) => {
221                    tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
222                }
223                ThreadEvent::Stopped(reason) => match reason {
224                    Ok(StopReason::EndTurn) => {
225                        tx.close_channel();
226                    }
227                    Ok(StopReason::ToolUse) => {
228                        if thread.read(cx).remaining_turns() == 0 {
229                            tx.close_channel();
230                        }
231                    }
232                    Ok(StopReason::MaxTokens) => {
233                        tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
234                    }
235                    Ok(StopReason::Refusal) => {
236                        tx.try_send(Err(anyhow!("Model refused to generate content")))
237                            .ok();
238                    }
239                    Err(err) => {
240                        tx.try_send(Err(anyhow!(err.clone()))).ok();
241                    }
242                },
243                ThreadEvent::NewRequest
244                | ThreadEvent::StreamedAssistantText(_, _)
245                | ThreadEvent::StreamedAssistantThinking(_, _)
246                | ThreadEvent::UsePendingTools { .. }
247                | ThreadEvent::CompletionCanceled => {}
248                ThreadEvent::ToolFinished {
249                    tool_use_id,
250                    pending_tool_use,
251                    ..
252                } => {
253                    thread.update(cx, |thread, _cx| {
254                        if let Some(tool_use) = pending_tool_use {
255                            let mut tool_metrics = tool_metrics.lock().unwrap();
256                            if let Some(tool_result) = thread.tool_result(&tool_use_id) {
257                                let message = if tool_result.is_error {
258                                    format!("✖︎ {}", tool_use.name)
259                                } else {
260                                    format!("✔︎ {}", tool_use.name)
261                                };
262                                println!("{log_prefix}{message}");
263                                tool_metrics
264                                    .insert(tool_result.tool_name.clone(), !tool_result.is_error);
265                            } else {
266                                let message =
267                                    format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
268                                println!("{log_prefix}{message}");
269                                tool_metrics.insert(tool_use.name.clone(), true);
270                            }
271                        }
272                    });
273                }
274                ThreadEvent::InvalidToolInput { .. } => {
275                    println!("{log_prefix} invalid tool input");
276                }
277                ThreadEvent::MissingToolUse {
278                    tool_use_id: _,
279                    ui_text,
280                } => {
281                    println!("{log_prefix} {ui_text}");
282                }
283                ThreadEvent::ToolConfirmationNeeded => {
284                    panic!(
285                        "{}Bug: Tool confirmation should not be required in eval",
286                        log_prefix
287                    );
288                }
289                ThreadEvent::StreamedCompletion
290                | ThreadEvent::MessageAdded(_)
291                | ThreadEvent::MessageEdited(_)
292                | ThreadEvent::MessageDeleted(_)
293                | ThreadEvent::SummaryChanged
294                | ThreadEvent::SummaryGenerated
295                | ThreadEvent::ReceivedTextChunk
296                | ThreadEvent::StreamedToolUse { .. }
297                | ThreadEvent::CheckpointChanged
298                | ThreadEvent::CancelEditing => {
299                    tx.try_send(Ok(())).ok();
300                    if std::env::var("ZED_EVAL_DEBUG").is_ok() {
301                        println!("{}Event: {:#?}", log_prefix, event);
302                    }
303                }
304            },
305        );
306
307        let model = self.model.clone();
308
309        let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
310            thread.set_remaining_turns(iterations);
311            thread.send_to_model(model, None, cx);
312            thread.messages().len()
313        })?;
314
315        loop {
316            select_biased! {
317                result = rx.next() => {
318                    if let Some(result) = result {
319                        result?;
320                    } else {
321                        break;
322                    }
323                }
324                _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
325                    anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events");
326                }
327            }
328        }
329
330        let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
331            let mut messages = Vec::new();
332            for message in thread.messages().skip(message_count_before) {
333                messages.push(Message {
334                    _role: message.role,
335                    text: message.to_string(),
336                    tool_use: thread
337                        .tool_uses_for_message(message.id, cx)
338                        .into_iter()
339                        .map(|tool_use| ToolUse {
340                            name: tool_use.name.to_string(),
341                            value: tool_use.input,
342                        })
343                        .collect(),
344                });
345            }
346            messages
347        })?;
348
349        let response = Response::new(messages);
350
351        Ok(response)
352    }
353
354    pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
355        self.agent_thread
356            .read_with(&self.app, |thread, cx| {
357                let action_log = thread.action_log().read(cx);
358                HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
359                    |(buffer, diff)| {
360                        let snapshot = buffer.read(cx).snapshot();
361
362                        let file = snapshot.file().unwrap();
363                        let diff = diff.read(cx);
364                        let base_text = diff.base_text().text();
365
366                        let hunks = diff
367                            .hunks(&snapshot, cx)
368                            .map(|hunk| FileEditHunk {
369                                base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
370                                text: snapshot
371                                    .text_for_range(hunk.range.clone())
372                                    .collect::<String>(),
373                                status: hunk.status(),
374                            })
375                            .collect();
376
377                        (file.path().clone(), FileEdits { hunks })
378                    },
379                ))
380            })
381            .unwrap()
382    }
383
384    pub fn agent_thread(&self) -> Entity<Thread> {
385        self.agent_thread.clone()
386    }
387}
388
389impl AppContext for ExampleContext {
390    type Result<T> = anyhow::Result<T>;
391
392    fn new<T: 'static>(
393        &mut self,
394        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
395    ) -> Self::Result<Entity<T>> {
396        self.app.new(build_entity)
397    }
398
399    fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
400        self.app.reserve_entity()
401    }
402
403    fn insert_entity<T: 'static>(
404        &mut self,
405        reservation: gpui::Reservation<T>,
406        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
407    ) -> Self::Result<Entity<T>> {
408        self.app.insert_entity(reservation, build_entity)
409    }
410
411    fn update_entity<T, R>(
412        &mut self,
413        handle: &Entity<T>,
414        update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
415    ) -> Self::Result<R>
416    where
417        T: 'static,
418    {
419        self.app.update_entity(handle, update)
420    }
421
422    fn read_entity<T, R>(
423        &self,
424        handle: &Entity<T>,
425        read: impl FnOnce(&T, &App) -> R,
426    ) -> Self::Result<R>
427    where
428        T: 'static,
429    {
430        self.app.read_entity(handle, read)
431    }
432
433    fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
434    where
435        F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
436    {
437        self.app.update_window(window, f)
438    }
439
440    fn read_window<T, R>(
441        &self,
442        window: &gpui::WindowHandle<T>,
443        read: impl FnOnce(Entity<T>, &App) -> R,
444    ) -> Result<R>
445    where
446        T: 'static,
447    {
448        self.app.read_window(window, read)
449    }
450
451    fn background_spawn<R>(
452        &self,
453        future: impl std::future::Future<Output = R> + Send + 'static,
454    ) -> gpui::Task<R>
455    where
456        R: Send + 'static,
457    {
458        self.app.background_spawn(future)
459    }
460
461    fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
462    where
463        G: gpui::Global,
464    {
465        self.app.read_global(callback)
466    }
467}
468
469#[derive(Debug)]
470pub struct Response {
471    messages: Vec<Message>,
472}
473
474impl Response {
475    pub fn new(messages: Vec<Message>) -> Self {
476        Self { messages }
477    }
478
479    pub fn expect_tool(
480        &self,
481        tool_name: &'static str,
482        cx: &mut ExampleContext,
483    ) -> Result<&ToolUse> {
484        let result = self.find_tool_call(tool_name);
485        cx.assert_some(result, format!("called `{}`", tool_name))
486    }
487
488    pub fn find_tool_call(&self, tool_name: &str) -> Option<&ToolUse> {
489        self.messages.iter().rev().find_map(|msg| {
490            msg.tool_use
491                .iter()
492                .find(|tool_use| tool_use.name == tool_name)
493        })
494    }
495
496    #[allow(dead_code)]
497    pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
498        self.messages.iter().flat_map(|msg| &msg.tool_use)
499    }
500
501    pub fn texts(&self) -> impl Iterator<Item = String> {
502        self.messages.iter().map(|message| message.text.clone())
503    }
504}
505
506#[derive(Debug)]
507pub struct Message {
508    _role: Role,
509    text: String,
510    tool_use: Vec<ToolUse>,
511}
512
513#[derive(Debug)]
514pub struct ToolUse {
515    pub name: String,
516    value: serde_json::Value,
517}
518
519impl ToolUse {
520    pub fn parse_input<Input>(&self) -> Result<Input>
521    where
522        Input: for<'de> serde::Deserialize<'de>,
523    {
524        serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
525    }
526}
527
528#[derive(Debug, Eq, PartialEq)]
529pub struct FileEdits {
530    pub hunks: Vec<FileEditHunk>,
531}
532
533#[derive(Debug, Eq, PartialEq)]
534pub struct FileEditHunk {
535    pub base_text: String,
536    pub text: String,
537    pub status: DiffHunkStatus,
538}
539
540impl FileEdits {
541    pub fn has_added_line(&self, line: &str) -> bool {
542        self.hunks.iter().any(|hunk| {
543            hunk.status == DiffHunkStatus::added_none()
544                && hunk.base_text.is_empty()
545                && hunk.text.contains(line)
546        })
547    }
548}