example.rs

  1use std::{
  2    error::Error,
  3    fmt::{self, Debug},
  4    sync::{Arc, Mutex},
  5    time::Duration,
  6};
  7
  8use crate::{
  9    ToolMetrics,
 10    assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
 11};
 12use acp_thread::UserMessageId;
 13use agent::{Thread, ThreadEvent, UserMessageContent};
 14use agent_client_protocol as acp;
 15use agent_settings::AgentProfileId;
 16use anyhow::{Result, anyhow};
 17use async_trait::async_trait;
 18use buffer_diff::DiffHunkStatus;
 19use cloud_llm_client::CompletionIntent;
 20use collections::HashMap;
 21use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
 22use gpui::{App, AppContext, AsyncApp, Entity};
 23use language_model::{LanguageModel, Role, StopReason};
 24use util::rel_path::RelPath;
 25
 26pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
 27
 28#[async_trait(?Send)]
 29pub trait Example {
 30    fn meta(&self) -> ExampleMetadata;
 31    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
 32    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
 33        Vec::new()
 34    }
 35    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
 36        Vec::new()
 37    }
 38}
 39
 40#[derive(Clone, Debug)]
 41pub struct JudgeAssertion {
 42    pub id: String,
 43    pub description: String,
 44}
 45
 46#[derive(Clone, Debug)]
 47pub struct ExampleMetadata {
 48    pub name: String,
 49    pub url: String,
 50    pub revision: String,
 51    pub language_server: Option<LanguageServer>,
 52    pub max_assertions: Option<usize>,
 53    pub profile_id: AgentProfileId,
 54    pub existing_thread_json: Option<String>,
 55    pub max_turns: Option<u32>,
 56}
 57
 58#[derive(Clone, Debug)]
 59pub struct LanguageServer {
 60    pub file_extension: String,
 61    pub allow_preexisting_diagnostics: bool,
 62}
 63
 64impl ExampleMetadata {
 65    pub fn repo_name(&self) -> String {
 66        self.url
 67            .split('/')
 68            .next_back()
 69            .unwrap_or("")
 70            .trim_end_matches(".git")
 71            .into()
 72    }
 73}
 74
 75pub struct FailedAssertion(pub String);
 76
 77impl fmt::Debug for FailedAssertion {
 78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 79        write!(f, "Assertion failure: {}", self.0)
 80    }
 81}
 82
 83impl fmt::Display for FailedAssertion {
 84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 85        write!(f, "{}", self.0)
 86    }
 87}
 88
 89impl Error for FailedAssertion {}
 90
 91pub struct ExampleContext {
 92    meta: ExampleMetadata,
 93    log_prefix: String,
 94    agent_thread: Entity<agent::Thread>,
 95    app: AsyncApp,
 96    model: Arc<dyn LanguageModel>,
 97    pub assertions: AssertionsReport,
 98    pub tool_metrics: Arc<Mutex<ToolMetrics>>,
 99}
100
101impl ExampleContext {
102    pub fn new(
103        meta: ExampleMetadata,
104        log_prefix: String,
105        agent_thread: Entity<Thread>,
106        model: Arc<dyn LanguageModel>,
107        app: AsyncApp,
108    ) -> Self {
109        let assertions = AssertionsReport::new(meta.max_assertions);
110
111        Self {
112            meta,
113            log_prefix,
114            agent_thread,
115            assertions,
116            model,
117            app,
118            tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
119        }
120    }
121
122    pub fn push_user_message(&mut self, text: impl ToString) {
123        self.app
124            .update_entity(&self.agent_thread, |thread, cx| {
125                thread.insert_user_message(
126                    text.to_string(),
127                    ContextLoadResult::default(),
128                    None,
129                    Vec::new(),
130                    cx,
131                );
132            })
133            .unwrap();
134    }
135
136    pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
137        let message = message.to_string();
138        self.log_assertion(
139            if expected {
140                Ok(())
141            } else {
142                Err(anyhow::Error::from(FailedAssertion(message.clone())))
143            },
144            message,
145        )
146    }
147
148    pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
149        let message = message.to_string();
150        self.log_assertion(
151            match option {
152                Some(value) => Ok(value),
153                None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
154            },
155            message,
156        )
157    }
158
159    #[allow(dead_code)]
160    pub fn assert_eq<T: PartialEq + Debug>(
161        &mut self,
162        left: T,
163        right: T,
164        message: impl ToString,
165    ) -> Result<()> {
166        let message = message.to_string();
167        self.log_assertion(
168            if left == right {
169                Ok(())
170            } else {
171                println!(
172                    "{}{}",
173                    self.log_prefix,
174                    pretty_assertions::Comparison::new(&left, &right)
175                );
176                Err(anyhow::Error::from(FailedAssertion(message.clone())))
177            },
178            message,
179        )
180    }
181
182    fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
183        if let Some(max) = self.meta.max_assertions {
184            anyhow::ensure!(
185                self.assertions.run_count() <= max,
186                "More assertions were run than the stated max_assertions of {max}"
187            );
188        }
189
190        self.assertions.ran.push(RanAssertion {
191            id: message.clone(),
192            result: Ok(RanAssertionResult {
193                analysis: None,
194                passed: result.is_ok(),
195            }),
196        });
197
198        if result.is_ok() {
199            println!("{}{}", self.log_prefix, message);
200        } else {
201            println!("{}{}", self.log_prefix, message);
202        }
203
204        result
205    }
206
207    pub async fn run_to_end(&mut self) -> Result<Response> {
208        self.run_turns(u32::MAX).await
209    }
210
211    pub async fn run_turn(&mut self) -> Result<Response> {
212        self.run_turns(1).await
213    }
214
215    pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
216        let (mut tx, mut rx) = mpsc::channel(1);
217
218        let tool_metrics = self.tool_metrics.clone();
219        let log_prefix = self.log_prefix.clone();
220        let _subscription = self.app.subscribe(
221            &self.agent_thread,
222            move |thread, event: &ThreadEvent, cx| match event {
223                ThreadEvent::ShowError(thread_error) => {
224                    tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
225                }
226                ThreadEvent::Stopped(reason) => match reason {
227                    Ok(StopReason::EndTurn) => {
228                        tx.close_channel();
229                    }
230                    Ok(StopReason::ToolUse) => {
231                        if thread.read(cx).remaining_turns() == 0 {
232                            tx.close_channel();
233                        }
234                    }
235                    Ok(StopReason::MaxTokens) => {
236                        tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
237                    }
238                    Ok(StopReason::Refusal) => {
239                        tx.try_send(Err(anyhow!("Model refused to generate content")))
240                            .ok();
241                    }
242                    Err(err) => {
243                        tx.try_send(Err(anyhow!(err.clone()))).ok();
244                    }
245                },
246                ThreadEvent::NewRequest
247                | ThreadEvent::StreamedAssistantText(_, _)
248                | ThreadEvent::StreamedAssistantThinking(_, _)
249                | ThreadEvent::UsePendingTools { .. }
250                | ThreadEvent::CompletionCanceled => {}
251                ThreadEvent::ToolUseLimitReached => {}
252                ThreadEvent::ToolFinished {
253                    tool_use_id,
254                    pending_tool_use,
255                    ..
256                } => {
257                    thread.update(cx, |thread, _cx| {
258                        if let Some(tool_use) = pending_tool_use {
259                            let mut tool_metrics = tool_metrics.lock().unwrap();
260                            if let Some(tool_result) = thread.tool_result(tool_use_id) {
261                                let message = if tool_result.is_error {
262                                    format!("✖︎ {}", tool_use.name)
263                                } else {
264                                    format!("✔︎ {}", tool_use.name)
265                                };
266                                println!("{log_prefix}{message}");
267                                tool_metrics
268                                    .insert(tool_result.tool_name.clone(), !tool_result.is_error);
269                            } else {
270                                let message =
271                                    format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
272                                println!("{log_prefix}{message}");
273                                tool_metrics.insert(tool_use.name.clone(), true);
274                            }
275                        }
276                    });
277                }
278                ThreadEvent::InvalidToolInput { .. } => {
279                    println!("{log_prefix} invalid tool input");
280                }
281                ThreadEvent::MissingToolUse {
282                    tool_use_id: _,
283                    ui_text,
284                } => {
285                    println!("{log_prefix} {ui_text}");
286                }
287                ThreadEvent::ToolConfirmationNeeded => {
288                    panic!(
289                        "{}Bug: Tool confirmation should not be required in eval",
290                        log_prefix
291                    );
292                }
293                ThreadEvent::StreamedCompletion
294                | ThreadEvent::MessageAdded(_)
295                | ThreadEvent::MessageEdited(_)
296                | ThreadEvent::MessageDeleted(_)
297                | ThreadEvent::SummaryChanged
298                | ThreadEvent::SummaryGenerated
299                | ThreadEvent::ProfileChanged
300                | ThreadEvent::ReceivedTextChunk
301                | ThreadEvent::StreamedToolUse { .. }
302                | ThreadEvent::CheckpointChanged
303                | ThreadEvent::CancelEditing => {
304                    tx.try_send(Ok(())).ok();
305                    if std::env::var("ZED_EVAL_DEBUG").is_ok() {
306                        println!("{}Event: {:#?}", log_prefix, event);
307                    }
308                }
309            },
310        );
311
312        let model = self.model.clone();
313
314        let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
315            thread.set_remaining_turns(iterations);
316            thread.send_to_model(model, CompletionIntent::UserPrompt, None, cx);
317            thread.messages().len()
318        })?;
319
320        loop {
321            select_biased! {
322                result = rx.next() => {
323                    if let Some(result) = result {
324                        result?;
325                    } else {
326                        break;
327                    }
328                }
329                _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
330                    anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events");
331                }
332            }
333        }
334
335        let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
336            let mut messages = Vec::new();
337            for message in thread.messages().skip(message_count_before) {
338                messages.push(Message {
339                    _role: message.role,
340                    text: message.to_message_content(),
341                    tool_use: thread
342                        .tool_uses_for_message(message.id, cx)
343                        .into_iter()
344                        .map(|tool_use| ToolUse {
345                            name: tool_use.name.to_string(),
346                            value: tool_use.input,
347                        })
348                        .collect(),
349                });
350            }
351            messages
352        })?;
353
354        let response = Response::new(messages);
355
356        Ok(response)
357    }
358
359    pub fn edits(&self) -> HashMap<Arc<RelPath>, FileEdits> {
360        self.agent_thread
361            .read_with(&self.app, |thread, cx| {
362                let action_log = thread.action_log().read(cx);
363                HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
364                    |(buffer, diff)| {
365                        let snapshot = buffer.read(cx).snapshot();
366
367                        let file = snapshot.file().unwrap();
368                        let diff = diff.read(cx);
369                        let base_text = diff.base_text().text();
370
371                        let hunks = diff
372                            .hunks(&snapshot, cx)
373                            .map(|hunk| FileEditHunk {
374                                base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
375                                text: snapshot
376                                    .text_for_range(hunk.range.clone())
377                                    .collect::<String>(),
378                                status: hunk.status(),
379                            })
380                            .collect();
381
382                        (file.path().clone(), FileEdits { hunks })
383                    },
384                ))
385            })
386            .unwrap()
387    }
388
389    pub fn agent_thread(&self) -> Entity<Thread> {
390        self.agent_thread.clone()
391    }
392}
393
394impl AppContext for ExampleContext {
395    type Result<T> = anyhow::Result<T>;
396
397    fn new<T: 'static>(
398        &mut self,
399        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
400    ) -> Self::Result<Entity<T>> {
401        self.app.new(build_entity)
402    }
403
404    fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
405        self.app.reserve_entity()
406    }
407
408    fn insert_entity<T: 'static>(
409        &mut self,
410        reservation: gpui::Reservation<T>,
411        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
412    ) -> Self::Result<Entity<T>> {
413        self.app.insert_entity(reservation, build_entity)
414    }
415
416    fn update_entity<T, R>(
417        &mut self,
418        handle: &Entity<T>,
419        update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
420    ) -> Self::Result<R>
421    where
422        T: 'static,
423    {
424        self.app.update_entity(handle, update)
425    }
426
427    fn as_mut<'a, T>(&'a mut self, handle: &Entity<T>) -> Self::Result<gpui::GpuiBorrow<'a, T>>
428    where
429        T: 'static,
430    {
431        self.app.as_mut(handle)
432    }
433
434    fn read_entity<T, R>(
435        &self,
436        handle: &Entity<T>,
437        read: impl FnOnce(&T, &App) -> R,
438    ) -> Self::Result<R>
439    where
440        T: 'static,
441    {
442        self.app.read_entity(handle, read)
443    }
444
445    fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
446    where
447        F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
448    {
449        self.app.update_window(window, f)
450    }
451
452    fn read_window<T, R>(
453        &self,
454        window: &gpui::WindowHandle<T>,
455        read: impl FnOnce(Entity<T>, &App) -> R,
456    ) -> Result<R>
457    where
458        T: 'static,
459    {
460        self.app.read_window(window, read)
461    }
462
463    fn background_spawn<R>(
464        &self,
465        future: impl std::future::Future<Output = R> + Send + 'static,
466    ) -> gpui::Task<R>
467    where
468        R: Send + 'static,
469    {
470        self.app.background_spawn(future)
471    }
472
473    fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
474    where
475        G: gpui::Global,
476    {
477        self.app.read_global(callback)
478    }
479}
480
481#[derive(Debug)]
482pub struct Response {
483    messages: Vec<Message>,
484}
485
486impl Response {
487    pub fn new(messages: Vec<Message>) -> Self {
488        Self { messages }
489    }
490
491    pub fn expect_tool(
492        &self,
493        tool_name: &'static str,
494        cx: &mut ExampleContext,
495    ) -> Result<&ToolUse> {
496        let result = self.find_tool_call(tool_name);
497        cx.assert_some(result, format!("called `{}`", tool_name))
498    }
499
500    pub fn find_tool_call(&self, tool_name: &str) -> Option<&ToolUse> {
501        self.messages.iter().rev().find_map(|msg| {
502            msg.tool_use
503                .iter()
504                .find(|tool_use| tool_use.name == tool_name)
505        })
506    }
507
508    #[allow(dead_code)]
509    pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
510        self.messages.iter().flat_map(|msg| &msg.tool_use)
511    }
512
513    pub fn texts(&self) -> impl Iterator<Item = String> {
514        self.messages.iter().map(|message| message.text.clone())
515    }
516}
517
518#[derive(Debug)]
519pub struct Message {
520    _role: Role,
521    text: String,
522    tool_use: Vec<ToolUse>,
523}
524
525#[derive(Debug)]
526pub struct ToolUse {
527    pub name: String,
528    value: serde_json::Value,
529}
530
531impl ToolUse {
532    pub fn parse_input<Input>(&self) -> Result<Input>
533    where
534        Input: for<'de> serde::Deserialize<'de>,
535    {
536        serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
537    }
538}
539
540#[derive(Debug, Eq, PartialEq)]
541pub struct FileEdits {
542    pub hunks: Vec<FileEditHunk>,
543}
544
545#[derive(Debug, Eq, PartialEq)]
546pub struct FileEditHunk {
547    pub base_text: String,
548    pub text: String,
549    pub status: DiffHunkStatus,
550}
551
552impl FileEdits {
553    pub fn has_added_line(&self, line: &str) -> bool {
554        self.hunks.iter().any(|hunk| {
555            hunk.status == DiffHunkStatus::added_none()
556                && hunk.base_text.is_empty()
557                && hunk.text.contains(line)
558        })
559    }
560}