1use super::*;
2use acp_thread::AgentConnection;
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp};
5use anyhow::Result;
6use client::{Client, UserStore};
7use fs::{FakeFs, Fs};
8use futures::channel::mpsc::UnboundedReceiver;
9use gpui::{
10 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
11};
12use indoc::indoc;
13use language_model::{
14 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
15 LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
16 StopReason, fake_provider::FakeLanguageModel,
17};
18use project::Project;
19use prompt_store::ProjectContext;
20use reqwest_client::ReqwestClient;
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24use settings::SettingsStore;
25use smol::stream::StreamExt;
26use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
27use util::path;
28
29mod test_tools;
30use test_tools::*;
31
32#[gpui::test]
33#[ignore = "can't run on CI yet"]
34async fn test_echo(cx: &mut TestAppContext) {
35 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
36
37 let events = thread
38 .update(cx, |thread, cx| {
39 thread.send("Testing: Reply with 'Hello'", cx)
40 })
41 .collect()
42 .await;
43 thread.update(cx, |thread, _cx| {
44 assert_eq!(
45 thread.messages().last().unwrap().content,
46 vec![MessageContent::Text("Hello".to_string())]
47 );
48 });
49 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
50}
51
52#[gpui::test]
53#[ignore = "can't run on CI yet"]
54async fn test_thinking(cx: &mut TestAppContext) {
55 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
56
57 let events = thread
58 .update(cx, |thread, cx| {
59 thread.send(
60 indoc! {"
61 Testing:
62
63 Generate a thinking step where you just think the word 'Think',
64 and have your final answer be 'Hello'
65 "},
66 cx,
67 )
68 })
69 .collect()
70 .await;
71 thread.update(cx, |thread, _cx| {
72 assert_eq!(
73 thread.messages().last().unwrap().to_markdown(),
74 indoc! {"
75 ## assistant
76 <think>Think</think>
77 Hello
78 "}
79 )
80 });
81 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
82}
83
84#[gpui::test]
85async fn test_system_prompt(cx: &mut TestAppContext) {
86 let ThreadTest {
87 model,
88 thread,
89 project_context,
90 ..
91 } = setup(cx, TestModel::Fake).await;
92 let fake_model = model.as_fake();
93
94 project_context.borrow_mut().shell = "test-shell".into();
95 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
96 thread.update(cx, |thread, cx| thread.send("abc", cx));
97 cx.run_until_parked();
98 let mut pending_completions = fake_model.pending_completions();
99 assert_eq!(
100 pending_completions.len(),
101 1,
102 "unexpected pending completions: {:?}",
103 pending_completions
104 );
105
106 let pending_completion = pending_completions.pop().unwrap();
107 assert_eq!(pending_completion.messages[0].role, Role::System);
108
109 let system_message = &pending_completion.messages[0];
110 let system_prompt = system_message.content[0].to_str().unwrap();
111 assert!(
112 system_prompt.contains("test-shell"),
113 "unexpected system message: {:?}",
114 system_message
115 );
116 assert!(
117 system_prompt.contains("## Fixing Diagnostics"),
118 "unexpected system message: {:?}",
119 system_message
120 );
121}
122
123#[gpui::test]
124#[ignore = "can't run on CI yet"]
125async fn test_basic_tool_calls(cx: &mut TestAppContext) {
126 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
127
128 // Test a tool call that's likely to complete *before* streaming stops.
129 let events = thread
130 .update(cx, |thread, cx| {
131 thread.add_tool(EchoTool);
132 thread.send(
133 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
134 cx,
135 )
136 })
137 .collect()
138 .await;
139 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
140
141 // Test a tool calls that's likely to complete *after* streaming stops.
142 let events = thread
143 .update(cx, |thread, cx| {
144 thread.remove_tool(&AgentTool::name(&EchoTool));
145 thread.add_tool(DelayTool);
146 thread.send(
147 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
148 cx,
149 )
150 })
151 .collect()
152 .await;
153 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
154 thread.update(cx, |thread, _cx| {
155 assert!(
156 thread
157 .messages()
158 .last()
159 .unwrap()
160 .content
161 .iter()
162 .any(|content| {
163 if let MessageContent::Text(text) = content {
164 text.contains("Ding")
165 } else {
166 false
167 }
168 })
169 );
170 });
171}
172
173#[gpui::test]
174#[ignore = "can't run on CI yet"]
175async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
176 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
177
178 // Test a tool call that's likely to complete *before* streaming stops.
179 let mut events = thread.update(cx, |thread, cx| {
180 thread.add_tool(WordListTool);
181 thread.send("Test the word_list tool.", cx)
182 });
183
184 let mut saw_partial_tool_use = false;
185 while let Some(event) = events.next().await {
186 if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
187 thread.update(cx, |thread, _cx| {
188 // Look for a tool use in the thread's last message
189 let last_content = thread.messages().last().unwrap().content.last().unwrap();
190 if let MessageContent::ToolUse(last_tool_use) = last_content {
191 assert_eq!(last_tool_use.name.as_ref(), "word_list");
192 if tool_call.status == acp::ToolCallStatus::Pending {
193 if !last_tool_use.is_input_complete
194 && last_tool_use.input.get("g").is_none()
195 {
196 saw_partial_tool_use = true;
197 }
198 } else {
199 last_tool_use
200 .input
201 .get("a")
202 .expect("'a' has streamed because input is now complete");
203 last_tool_use
204 .input
205 .get("g")
206 .expect("'g' has streamed because input is now complete");
207 }
208 } else {
209 panic!("last content should be a tool use");
210 }
211 });
212 }
213 }
214
215 assert!(
216 saw_partial_tool_use,
217 "should see at least one partially streamed tool use in the history"
218 );
219}
220
221#[gpui::test]
222async fn test_tool_authorization(cx: &mut TestAppContext) {
223 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
224 let fake_model = model.as_fake();
225
226 let mut events = thread.update(cx, |thread, cx| {
227 thread.add_tool(ToolRequiringPermission);
228 thread.send("abc", cx)
229 });
230 cx.run_until_parked();
231 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
232 LanguageModelToolUse {
233 id: "tool_id_1".into(),
234 name: ToolRequiringPermission.name().into(),
235 raw_input: "{}".into(),
236 input: json!({}),
237 is_input_complete: true,
238 },
239 ));
240 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
241 LanguageModelToolUse {
242 id: "tool_id_2".into(),
243 name: ToolRequiringPermission.name().into(),
244 raw_input: "{}".into(),
245 input: json!({}),
246 is_input_complete: true,
247 },
248 ));
249 fake_model.end_last_completion_stream();
250 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
251 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
252
253 // Approve the first
254 tool_call_auth_1
255 .response
256 .send(tool_call_auth_1.options[1].id.clone())
257 .unwrap();
258 cx.run_until_parked();
259
260 // Reject the second
261 tool_call_auth_2
262 .response
263 .send(tool_call_auth_1.options[2].id.clone())
264 .unwrap();
265 cx.run_until_parked();
266
267 let completion = fake_model.pending_completions().pop().unwrap();
268 let message = completion.messages.last().unwrap();
269 assert_eq!(
270 message.content,
271 vec![
272 MessageContent::ToolResult(LanguageModelToolResult {
273 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
274 tool_name: ToolRequiringPermission.name().into(),
275 is_error: false,
276 content: "Allowed".into(),
277 output: Some("Allowed".into())
278 }),
279 MessageContent::ToolResult(LanguageModelToolResult {
280 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
281 tool_name: ToolRequiringPermission.name().into(),
282 is_error: true,
283 content: "Permission to run tool denied by user".into(),
284 output: None
285 })
286 ]
287 );
288
289 // Simulate yet another tool call.
290 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
291 LanguageModelToolUse {
292 id: "tool_id_3".into(),
293 name: ToolRequiringPermission.name().into(),
294 raw_input: "{}".into(),
295 input: json!({}),
296 is_input_complete: true,
297 },
298 ));
299 fake_model.end_last_completion_stream();
300
301 // Respond by always allowing tools.
302 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
303 tool_call_auth_3
304 .response
305 .send(tool_call_auth_3.options[0].id.clone())
306 .unwrap();
307 cx.run_until_parked();
308 let completion = fake_model.pending_completions().pop().unwrap();
309 let message = completion.messages.last().unwrap();
310 assert_eq!(
311 message.content,
312 vec![MessageContent::ToolResult(LanguageModelToolResult {
313 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
314 tool_name: ToolRequiringPermission.name().into(),
315 is_error: false,
316 content: "Allowed".into(),
317 output: Some("Allowed".into())
318 })]
319 );
320
321 // Simulate a final tool call, ensuring we don't trigger authorization.
322 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
323 LanguageModelToolUse {
324 id: "tool_id_4".into(),
325 name: ToolRequiringPermission.name().into(),
326 raw_input: "{}".into(),
327 input: json!({}),
328 is_input_complete: true,
329 },
330 ));
331 fake_model.end_last_completion_stream();
332 cx.run_until_parked();
333 let completion = fake_model.pending_completions().pop().unwrap();
334 let message = completion.messages.last().unwrap();
335 assert_eq!(
336 message.content,
337 vec![MessageContent::ToolResult(LanguageModelToolResult {
338 tool_use_id: "tool_id_4".into(),
339 tool_name: ToolRequiringPermission.name().into(),
340 is_error: false,
341 content: "Allowed".into(),
342 output: Some("Allowed".into())
343 })]
344 );
345}
346
347#[gpui::test]
348async fn test_tool_hallucination(cx: &mut TestAppContext) {
349 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
350 let fake_model = model.as_fake();
351
352 let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
353 cx.run_until_parked();
354 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
355 LanguageModelToolUse {
356 id: "tool_id_1".into(),
357 name: "nonexistent_tool".into(),
358 raw_input: "{}".into(),
359 input: json!({}),
360 is_input_complete: true,
361 },
362 ));
363 fake_model.end_last_completion_stream();
364
365 let tool_call = expect_tool_call(&mut events).await;
366 assert_eq!(tool_call.title, "nonexistent_tool");
367 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
368 let update = expect_tool_call_update_fields(&mut events).await;
369 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
370}
371
372async fn expect_tool_call(
373 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
374) -> acp::ToolCall {
375 let event = events
376 .next()
377 .await
378 .expect("no tool call authorization event received")
379 .unwrap();
380 match event {
381 AgentResponseEvent::ToolCall(tool_call) => return tool_call,
382 event => {
383 panic!("Unexpected event {event:?}");
384 }
385 }
386}
387
388async fn expect_tool_call_update_fields(
389 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
390) -> acp::ToolCallUpdate {
391 let event = events
392 .next()
393 .await
394 .expect("no tool call authorization event received")
395 .unwrap();
396 match event {
397 AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
398 return update;
399 }
400 event => {
401 panic!("Unexpected event {event:?}");
402 }
403 }
404}
405
406async fn next_tool_call_authorization(
407 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
408) -> ToolCallAuthorization {
409 loop {
410 let event = events
411 .next()
412 .await
413 .expect("no tool call authorization event received")
414 .unwrap();
415 if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
416 let permission_kinds = tool_call_authorization
417 .options
418 .iter()
419 .map(|o| o.kind)
420 .collect::<Vec<_>>();
421 assert_eq!(
422 permission_kinds,
423 vec![
424 acp::PermissionOptionKind::AllowAlways,
425 acp::PermissionOptionKind::AllowOnce,
426 acp::PermissionOptionKind::RejectOnce,
427 ]
428 );
429 return tool_call_authorization;
430 }
431 }
432}
433
434#[gpui::test]
435#[ignore = "can't run on CI yet"]
436async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
437 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
438
439 // Test concurrent tool calls with different delay times
440 let events = thread
441 .update(cx, |thread, cx| {
442 thread.add_tool(DelayTool);
443 thread.send(
444 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
445 cx,
446 )
447 })
448 .collect()
449 .await;
450
451 let stop_reasons = stop_events(events);
452 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
453
454 thread.update(cx, |thread, _cx| {
455 let last_message = thread.messages().last().unwrap();
456 let text = last_message
457 .content
458 .iter()
459 .filter_map(|content| {
460 if let MessageContent::Text(text) = content {
461 Some(text.as_str())
462 } else {
463 None
464 }
465 })
466 .collect::<String>();
467
468 assert!(text.contains("Ding"));
469 });
470}
471
/// Live-model test: cancelling a turn while a tool is still running must close
/// the event stream, and the thread must still accept a new message afterwards.
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread.update(cx, |thread, cx| {
        thread.add_tool(InfiniteTool);
        thread.add_tool(EchoTool);
        thread.send(
            "Call the echo tool and then call the infinite tool, then explain their output",
            cx,
        )
    });

    // Wait until both tools are called.
    // Tool calls are expected in exactly this order: each `ToolCall` event pops
    // the next expected title. We also wait for the echo call to complete, so
    // the only tool left running at cancellation time is the infinite one.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            AgentResponseEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Only a `Completed` status update for the echo call's id counts.
            AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    // `collect` only resolves once the stream ends, so a hang here means the
    // stream was not closed.
    thread.update(cx, |thread, _cx| thread.cancel());
    events.collect::<Vec<_>>().await;

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send("Testing: reply with 'Hello' then stop.", cx)
        })
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.messages().last().unwrap().content,
            vec![MessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
538
539#[gpui::test]
540async fn test_refusal(cx: &mut TestAppContext) {
541 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
542 let fake_model = model.as_fake();
543
544 let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
545 cx.run_until_parked();
546 thread.read_with(cx, |thread, _| {
547 assert_eq!(
548 thread.to_markdown(),
549 indoc! {"
550 ## user
551 Hello
552 "}
553 );
554 });
555
556 fake_model.send_last_completion_stream_text_chunk("Hey!");
557 cx.run_until_parked();
558 thread.read_with(cx, |thread, _| {
559 assert_eq!(
560 thread.to_markdown(),
561 indoc! {"
562 ## user
563 Hello
564 ## assistant
565 Hey!
566 "}
567 );
568 });
569
570 // If the model refuses to continue, the thread should remove all the messages after the last user message.
571 fake_model
572 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
573 let events = events.collect::<Vec<_>>().await;
574 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
575 thread.read_with(cx, |thread, _| {
576 assert_eq!(thread.to_markdown(), "");
577 });
578}
579
580#[gpui::test]
581async fn test_agent_connection(cx: &mut TestAppContext) {
582 cx.update(settings::init);
583 let templates = Templates::new();
584
585 // Initialize language model system with test provider
586 cx.update(|cx| {
587 gpui_tokio::init(cx);
588 client::init_settings(cx);
589
590 let http_client = FakeHttpClient::with_404_response();
591 let clock = Arc::new(clock::FakeSystemClock::new());
592 let client = Client::new(clock, http_client, cx);
593 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
594 language_model::init(client.clone(), cx);
595 language_models::init(user_store.clone(), client.clone(), cx);
596 Project::init_settings(cx);
597 LanguageModelRegistry::test(cx);
598 });
599 cx.executor().forbid_parking();
600
601 // Create a project for new_thread
602 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
603 fake_fs.insert_tree(path!("/test"), json!({})).await;
604 let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
605 let cwd = Path::new("/test");
606
607 // Create agent and connection
608 let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
609 .await
610 .unwrap();
611 let connection = NativeAgentConnection(agent.clone());
612
613 // Test model_selector returns Some
614 let selector_opt = connection.model_selector();
615 assert!(
616 selector_opt.is_some(),
617 "agent2 should always support ModelSelector"
618 );
619 let selector = selector_opt.unwrap();
620
621 // Test list_models
622 let listed_models = cx
623 .update(|cx| {
624 let mut async_cx = cx.to_async();
625 selector.list_models(&mut async_cx)
626 })
627 .await
628 .expect("list_models should succeed");
629 assert!(!listed_models.is_empty(), "should have at least one model");
630 assert_eq!(listed_models[0].id().0, "fake");
631
632 // Create a thread using new_thread
633 let connection_rc = Rc::new(connection.clone());
634 let acp_thread = cx
635 .update(|cx| {
636 let mut async_cx = cx.to_async();
637 connection_rc.new_thread(project, cwd, &mut async_cx)
638 })
639 .await
640 .expect("new_thread should succeed");
641
642 // Get the session_id from the AcpThread
643 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
644
645 // Test selected_model returns the default
646 let model = cx
647 .update(|cx| {
648 let mut async_cx = cx.to_async();
649 selector.selected_model(&session_id, &mut async_cx)
650 })
651 .await
652 .expect("selected_model should succeed");
653 let model = model.as_fake();
654 assert_eq!(model.id().0, "fake", "should return default model");
655
656 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
657 cx.run_until_parked();
658 model.send_last_completion_stream_text_chunk("def");
659 cx.run_until_parked();
660 acp_thread.read_with(cx, |thread, cx| {
661 assert_eq!(
662 thread.to_markdown(cx),
663 indoc! {"
664 ## User
665
666 abc
667
668 ## Assistant
669
670 def
671
672 "}
673 )
674 });
675
676 // Test cancel
677 cx.update(|cx| connection.cancel(&session_id, cx));
678 request.await.expect("prompt should fail gracefully");
679
680 // Ensure that dropping the ACP thread causes the native thread to be
681 // dropped as well.
682 cx.update(|_| drop(acp_thread));
683 let result = cx
684 .update(|cx| {
685 connection.prompt(
686 acp::PromptRequest {
687 session_id: session_id.clone(),
688 prompt: vec!["ghi".into()],
689 },
690 cx,
691 )
692 })
693 .await;
694 assert_eq!(
695 result.as_ref().unwrap_err().to_string(),
696 "Session not found",
697 "unexpected result: {:?}",
698 result
699 );
700}
701
702#[gpui::test]
703async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
704 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
705 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
706 let fake_model = model.as_fake();
707
708 let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
709 cx.run_until_parked();
710
711 // Simulate streaming partial input.
712 let input = json!({});
713 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
714 LanguageModelToolUse {
715 id: "1".into(),
716 name: ThinkingTool.name().into(),
717 raw_input: input.to_string(),
718 input,
719 is_input_complete: false,
720 },
721 ));
722
723 // Input streaming completed
724 let input = json!({ "content": "Thinking hard!" });
725 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
726 LanguageModelToolUse {
727 id: "1".into(),
728 name: "thinking".into(),
729 raw_input: input.to_string(),
730 input,
731 is_input_complete: true,
732 },
733 ));
734 fake_model.end_last_completion_stream();
735 cx.run_until_parked();
736
737 let tool_call = expect_tool_call(&mut events).await;
738 assert_eq!(
739 tool_call,
740 acp::ToolCall {
741 id: acp::ToolCallId("1".into()),
742 title: "Thinking".into(),
743 kind: acp::ToolKind::Think,
744 status: acp::ToolCallStatus::Pending,
745 content: vec![],
746 locations: vec![],
747 raw_input: Some(json!({})),
748 raw_output: None,
749 }
750 );
751 let update = expect_tool_call_update_fields(&mut events).await;
752 assert_eq!(
753 update,
754 acp::ToolCallUpdate {
755 id: acp::ToolCallId("1".into()),
756 fields: acp::ToolCallUpdateFields {
757 title: Some("Thinking".into()),
758 kind: Some(acp::ToolKind::Think),
759 raw_input: Some(json!({ "content": "Thinking hard!" })),
760 ..Default::default()
761 },
762 }
763 );
764 let update = expect_tool_call_update_fields(&mut events).await;
765 assert_eq!(
766 update,
767 acp::ToolCallUpdate {
768 id: acp::ToolCallId("1".into()),
769 fields: acp::ToolCallUpdateFields {
770 status: Some(acp::ToolCallStatus::InProgress),
771 ..Default::default()
772 },
773 }
774 );
775 let update = expect_tool_call_update_fields(&mut events).await;
776 assert_eq!(
777 update,
778 acp::ToolCallUpdate {
779 id: acp::ToolCallId("1".into()),
780 fields: acp::ToolCallUpdateFields {
781 content: Some(vec!["Thinking hard!".into()]),
782 ..Default::default()
783 },
784 }
785 );
786 let update = expect_tool_call_update_fields(&mut events).await;
787 assert_eq!(
788 update,
789 acp::ToolCallUpdate {
790 id: acp::ToolCallId("1".into()),
791 fields: acp::ToolCallUpdateFields {
792 status: Some(acp::ToolCallStatus::Completed),
793 ..Default::default()
794 },
795 }
796 );
797}
798
799/// Filters out the stop events for asserting against in tests
800fn stop_events(
801 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
802) -> Vec<acp::StopReason> {
803 result_events
804 .into_iter()
805 .filter_map(|event| match event.unwrap() {
806 AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
807 _ => None,
808 })
809 .collect()
810}
811
/// Fixture returned by `setup`: a thread plus the handles tests use to drive
/// and inspect it.
struct ThreadTest {
    // The model the thread was created with; tests using `TestModel::Fake`
    // downcast it with `as_fake()` to script completions.
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Shared with the thread, so tests can mutate it (e.g. `shell`) before sending.
    project_context: Rc<RefCell<ProjectContext>>,
}
817
/// Which language model `setup` should build the test thread around.
enum TestModel {
    /// Real Claude Sonnet 4, resolved from the model registry.
    Sonnet4,
    /// Real Claude Sonnet 4 with thinking enabled, resolved from the model registry.
    Sonnet4Thinking,
    /// An in-process fake model whose completions are scripted by the test.
    Fake,
}
823
824impl TestModel {
825 fn id(&self) -> LanguageModelId {
826 match self {
827 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
828 TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
829 TestModel::Fake => unreachable!(),
830 }
831 }
832}
833
/// Builds a [`Thread`] over an empty `/test` project on a fake filesystem,
/// using the requested `model`, and returns it with the handles tests need.
///
/// For [`TestModel::Fake`] a fresh [`FakeLanguageModel`] is used. For the other
/// variants the model is looked up in the [`LanguageModelRegistry`] by id, and
/// its provider is authenticated first. That presumably requires real
/// credentials and network access, which is why those tests are `#[ignore]`d.
///
/// # Panics
///
/// Panics if a real model is not found in the registry, if its provider cannot
/// be found, or if authentication fails.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): presumably needed because real-model tests block on I/O.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());

    // Settings must be initialized before the project and agent settings read them.
    cx.update(|cx| {
        settings::init(cx);
        watch_settings(fs.clone(), cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
    });
    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

    // Language-model infrastructure is initialized with a real HTTP client and
    // production client even for the fake model. Only real models resolve and
    // authenticate through the registry.
    let model = cx
        .update(|cx| {
            gpui_tokio::init(cx);
            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
            cx.set_http_client(Arc::new(http_client));

            client::init_settings(cx);
            let client = Client::production(cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            language_model::init(client.clone(), cx);
            language_models::init(user_store.clone(), client.clone(), cx);

            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                // Hand the model back only once its provider has authenticated.
                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    // The same project context is shared with the thread and returned to the
    // caller, so tests can mutate it after construction.
    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let thread = cx.new(|_| {
        Thread::new(
            project,
            project_context.clone(),
            action_log,
            templates,
            model.clone(),
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
    }
}
900
/// Installs `env_logger` before any test runs, but only when `RUST_LOG` is set
/// to a valid value, so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
908
909fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
910 let fs = fs.clone();
911 cx.spawn({
912 async move |cx| {
913 let mut new_settings_content_rx = settings::watch_config_file(
914 cx.background_executor(),
915 fs,
916 paths::settings_file().clone(),
917 );
918
919 while let Some(new_settings_content) = new_settings_content_rx.next().await {
920 cx.update(|cx| {
921 SettingsStore::update_global(cx, |settings, cx| {
922 settings.set_user_settings(&new_settings_content, cx)
923 })
924 })
925 .ok();
926 }
927 }
928 })
929 .detach();
930}