mod.rs

  1use super::*;
  2use crate::templates::Templates;
  3use client::{Client, UserStore};
  4use gpui::{AppContext, Entity, TestAppContext};
  5use language_model::{
  6    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  7    LanguageModelRegistry, MessageContent, StopReason,
  8};
  9use reqwest_client::ReqwestClient;
 10use schemars::JsonSchema;
 11use serde::{Deserialize, Serialize};
 12use smol::stream::StreamExt;
 13use std::{sync::Arc, time::Duration};
 14
 15mod test_tools;
 16use test_tools::*;
 17
 18#[gpui::test]
 19async fn test_echo(cx: &mut TestAppContext) {
 20    let ThreadTest { model, thread, .. } = setup(cx).await;
 21
 22    let events = thread
 23        .update(cx, |thread, cx| {
 24            thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
 25        })
 26        .collect()
 27        .await;
 28    thread.update(cx, |thread, _cx| {
 29        assert_eq!(
 30            thread.messages().last().unwrap().content,
 31            vec![MessageContent::Text("Hello".to_string())]
 32        );
 33    });
 34    assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
 35}
 36
 37#[gpui::test]
 38async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 39    let ThreadTest { model, thread, .. } = setup(cx).await;
 40
 41    // Test a tool call that's likely to complete *before* streaming stops.
 42    let events = thread
 43        .update(cx, |thread, cx| {
 44            thread.add_tool(EchoTool);
 45            thread.send(
 46                model.clone(),
 47                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
 48                cx,
 49            )
 50        })
 51        .collect()
 52        .await;
 53    assert_eq!(
 54        stop_events(events),
 55        vec![StopReason::ToolUse, StopReason::EndTurn]
 56    );
 57
 58    // Test a tool calls that's likely to complete *after* streaming stops.
 59    let events = thread
 60        .update(cx, |thread, cx| {
 61            thread.remove_tool(&AgentTool::name(&EchoTool));
 62            thread.add_tool(DelayTool);
 63            thread.send(
 64                model.clone(),
 65                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
 66                cx,
 67            )
 68        })
 69        .collect()
 70        .await;
 71    assert_eq!(
 72        stop_events(events),
 73        vec![StopReason::ToolUse, StopReason::EndTurn]
 74    );
 75    thread.update(cx, |thread, _cx| {
 76        assert!(thread
 77            .messages()
 78            .last()
 79            .unwrap()
 80            .content
 81            .iter()
 82            .any(|content| {
 83                if let MessageContent::Text(text) = content {
 84                    text.contains("Ding")
 85                } else {
 86                    false
 87                }
 88            }));
 89    });
 90}
 91
 92#[gpui::test]
 93async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
 94    let ThreadTest { model, thread, .. } = setup(cx).await;
 95
 96    // Test a tool call that's likely to complete *before* streaming stops.
 97    let mut events = thread.update(cx, |thread, cx| {
 98        thread.add_tool(WordListTool);
 99        thread.send(model.clone(), "Test the word_list tool.", cx)
100    });
101
102    let mut saw_partial_tool_use = false;
103    while let Some(event) = events.next().await {
104        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
105            thread.update(cx, |thread, _cx| {
106                // Look for a tool use in the thread's last message
107                let last_content = thread.messages().last().unwrap().content.last().unwrap();
108                if let MessageContent::ToolUse(last_tool_use) = last_content {
109                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
110                    if tool_use_event.is_input_complete {
111                        last_tool_use
112                            .input
113                            .get("a")
114                            .expect("'a' has streamed because input is now complete");
115                        last_tool_use
116                            .input
117                            .get("g")
118                            .expect("'g' has streamed because input is now complete");
119                    } else {
120                        if !last_tool_use.is_input_complete
121                            && last_tool_use.input.get("g").is_none()
122                        {
123                            saw_partial_tool_use = true;
124                        }
125                    }
126                } else {
127                    panic!("last content should be a tool use");
128                }
129            });
130        }
131    }
132
133    assert!(
134        saw_partial_tool_use,
135        "should see at least one partially streamed tool use in the history"
136    );
137}
138
139#[gpui::test]
140async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
141    let ThreadTest { model, thread, .. } = setup(cx).await;
142
143    // Test concurrent tool calls with different delay times
144    let events = thread
145        .update(cx, |thread, cx| {
146            thread.add_tool(DelayTool);
147            thread.send(
148                model.clone(),
149                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
150                cx,
151            )
152        })
153        .collect()
154        .await;
155
156    let stop_reasons = stop_events(events);
157    if stop_reasons.len() == 2 {
158        assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
159    } else if stop_reasons.len() == 3 {
160        assert_eq!(
161            stop_reasons,
162            vec![
163                StopReason::ToolUse,
164                StopReason::ToolUse,
165                StopReason::EndTurn
166            ]
167        );
168    } else {
169        panic!("Expected either 1 or 2 tool uses followed by end turn");
170    }
171
172    thread.update(cx, |thread, _cx| {
173        let last_message = thread.messages().last().unwrap();
174        let text = last_message
175            .content
176            .iter()
177            .filter_map(|content| {
178                if let MessageContent::Text(text) = content {
179                    Some(text.as_str())
180                } else {
181                    None
182                }
183            })
184            .collect::<String>();
185
186        assert!(text.contains("Ding"));
187    });
188}
189
190/// Filters out the stop events for asserting against in tests
191fn stop_events(
192    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
193) -> Vec<StopReason> {
194    result_events
195        .into_iter()
196        .filter_map(|event| match event.unwrap() {
197            LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
198            _ => None,
199        })
200        .collect()
201}
202
203struct ThreadTest {
204    model: Arc<dyn LanguageModel>,
205    thread: Entity<Thread>,
206}
207
208async fn setup(cx: &mut TestAppContext) -> ThreadTest {
209    cx.executor().allow_parking();
210    cx.update(settings::init);
211    let templates = Templates::new();
212    let thread = cx.new(|_| Thread::new(templates));
213
214    let model = cx
215        .update(|cx| {
216            gpui_tokio::init(cx);
217            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
218            cx.set_http_client(Arc::new(http_client));
219
220            client::init_settings(cx);
221            let client = Client::production(cx);
222            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
223            language_model::init(client.clone(), cx);
224            language_models::init(user_store.clone(), client.clone(), cx);
225
226            let models = LanguageModelRegistry::read_global(cx);
227            let model = models
228                .available_models(cx)
229                .find(|model| model.id().0 == "claude-3-7-sonnet-latest")
230                .unwrap();
231
232            let provider = models.provider(&model.provider_id()).unwrap();
233            let authenticated = provider.authenticate(cx);
234
235            cx.spawn(async move |_cx| {
236                authenticated.await.unwrap();
237                model
238            })
239        })
240        .await;
241
242    ThreadTest { model, thread }
243}
244
/// Installs `env_logger` at load time (via `ctor`), but only when `RUST_LOG`
/// is set to a readable value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
251}