1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use context_server::{ContextServer, ContextServerCommand, ContextServerId};
8use fs::{FakeFs, Fs};
9use futures::{
10 StreamExt,
11 channel::{
12 mpsc::{self, UnboundedReceiver},
13 oneshot,
14 },
15};
16use gpui::{
17 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
18};
19use indoc::indoc;
20use language_model::{
21 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
22 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
23 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
24 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
25};
26use pretty_assertions::assert_eq;
27use project::{
28 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
29};
30use prompt_store::ProjectContext;
31use reqwest_client::ReqwestClient;
32use schemars::JsonSchema;
33use serde::{Deserialize, Serialize};
34use serde_json::json;
35use settings::{Settings, SettingsStore};
36use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
37use util::path;
38
39mod test_tools;
40use test_tools::*;
41
42#[gpui::test]
43async fn test_echo(cx: &mut TestAppContext) {
44 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
45 let fake_model = model.as_fake();
46
47 let events = thread
48 .update(cx, |thread, cx| {
49 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
50 })
51 .unwrap();
52 cx.run_until_parked();
53 fake_model.send_last_completion_stream_text_chunk("Hello");
54 fake_model
55 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
56 fake_model.end_last_completion_stream();
57
58 let events = events.collect().await;
59 thread.update(cx, |thread, _cx| {
60 assert_eq!(
61 thread.last_message().unwrap().to_markdown(),
62 indoc! {"
63 ## Assistant
64
65 Hello
66 "}
67 )
68 });
69 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
70}
71
72#[gpui::test]
73async fn test_thinking(cx: &mut TestAppContext) {
74 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
75 let fake_model = model.as_fake();
76
77 let events = thread
78 .update(cx, |thread, cx| {
79 thread.send(
80 UserMessageId::new(),
81 [indoc! {"
82 Testing:
83
84 Generate a thinking step where you just think the word 'Think',
85 and have your final answer be 'Hello'
86 "}],
87 cx,
88 )
89 })
90 .unwrap();
91 cx.run_until_parked();
92 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
93 text: "Think".to_string(),
94 signature: None,
95 });
96 fake_model.send_last_completion_stream_text_chunk("Hello");
97 fake_model
98 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
99 fake_model.end_last_completion_stream();
100
101 let events = events.collect().await;
102 thread.update(cx, |thread, _cx| {
103 assert_eq!(
104 thread.last_message().unwrap().to_markdown(),
105 indoc! {"
106 ## Assistant
107
108 <think>Think</think>
109 Hello
110 "}
111 )
112 });
113 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
114}
115
116#[gpui::test]
117async fn test_system_prompt(cx: &mut TestAppContext) {
118 let ThreadTest {
119 model,
120 thread,
121 project_context,
122 ..
123 } = setup(cx, TestModel::Fake).await;
124 let fake_model = model.as_fake();
125
126 project_context.update(cx, |project_context, _cx| {
127 project_context.shell = "test-shell".into()
128 });
129 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
130 thread
131 .update(cx, |thread, cx| {
132 thread.send(UserMessageId::new(), ["abc"], cx)
133 })
134 .unwrap();
135 cx.run_until_parked();
136 let mut pending_completions = fake_model.pending_completions();
137 assert_eq!(
138 pending_completions.len(),
139 1,
140 "unexpected pending completions: {:?}",
141 pending_completions
142 );
143
144 let pending_completion = pending_completions.pop().unwrap();
145 assert_eq!(pending_completion.messages[0].role, Role::System);
146
147 let system_message = &pending_completion.messages[0];
148 let system_prompt = system_message.content[0].to_str().unwrap();
149 assert!(
150 system_prompt.contains("test-shell"),
151 "unexpected system message: {:?}",
152 system_message
153 );
154 assert!(
155 system_prompt.contains("## Fixing Diagnostics"),
156 "unexpected system message: {:?}",
157 system_message
158 );
159}
160
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies incremental prompt caching: in every request sent to the model,
    // only the *latest* message should carry `cache: true`; all earlier
    // messages are re-sent uncached. Note that `messages[0]` is the system
    // prompt, so every assertion below compares `messages[1..]`.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // The model responds with a tool use; the thread runs the tool and
    // issues a follow-up request containing the tool result.
    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    // Only the trailing tool-result message should be cacheable.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
294
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test against a real model (Sonnet 4); only runs with the
    // "e2e" feature enabled. Exercises tool calls that finish both before
    // and after the model's streaming completes.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // The model should have echoed the delay tool's "Ding" output somewhere
    // in the final assistant message's text content.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
354
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (requires the "e2e" feature): verifies that tool-call
    // input streams incrementally — i.e. we observe at least one tool use
    // whose input is still partial before it completes.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the input may still be streaming in:
                        // "g" is one of the later keys, so its absence marks a
                        // genuinely partial input.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
406
407#[gpui::test]
408async fn test_tool_authorization(cx: &mut TestAppContext) {
409 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
410 let fake_model = model.as_fake();
411
412 let mut events = thread
413 .update(cx, |thread, cx| {
414 thread.add_tool(ToolRequiringPermission);
415 thread.send(UserMessageId::new(), ["abc"], cx)
416 })
417 .unwrap();
418 cx.run_until_parked();
419 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
420 LanguageModelToolUse {
421 id: "tool_id_1".into(),
422 name: ToolRequiringPermission::name().into(),
423 raw_input: "{}".into(),
424 input: json!({}),
425 is_input_complete: true,
426 },
427 ));
428 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
429 LanguageModelToolUse {
430 id: "tool_id_2".into(),
431 name: ToolRequiringPermission::name().into(),
432 raw_input: "{}".into(),
433 input: json!({}),
434 is_input_complete: true,
435 },
436 ));
437 fake_model.end_last_completion_stream();
438 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
439 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
440
441 // Approve the first
442 tool_call_auth_1
443 .response
444 .send(tool_call_auth_1.options[1].id.clone())
445 .unwrap();
446 cx.run_until_parked();
447
448 // Reject the second
449 tool_call_auth_2
450 .response
451 .send(tool_call_auth_1.options[2].id.clone())
452 .unwrap();
453 cx.run_until_parked();
454
455 let completion = fake_model.pending_completions().pop().unwrap();
456 let message = completion.messages.last().unwrap();
457 assert_eq!(
458 message.content,
459 vec![
460 language_model::MessageContent::ToolResult(LanguageModelToolResult {
461 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
462 tool_name: ToolRequiringPermission::name().into(),
463 is_error: false,
464 content: "Allowed".into(),
465 output: Some("Allowed".into())
466 }),
467 language_model::MessageContent::ToolResult(LanguageModelToolResult {
468 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
469 tool_name: ToolRequiringPermission::name().into(),
470 is_error: true,
471 content: "Permission to run tool denied by user".into(),
472 output: None
473 })
474 ]
475 );
476
477 // Simulate yet another tool call.
478 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
479 LanguageModelToolUse {
480 id: "tool_id_3".into(),
481 name: ToolRequiringPermission::name().into(),
482 raw_input: "{}".into(),
483 input: json!({}),
484 is_input_complete: true,
485 },
486 ));
487 fake_model.end_last_completion_stream();
488
489 // Respond by always allowing tools.
490 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
491 tool_call_auth_3
492 .response
493 .send(tool_call_auth_3.options[0].id.clone())
494 .unwrap();
495 cx.run_until_parked();
496 let completion = fake_model.pending_completions().pop().unwrap();
497 let message = completion.messages.last().unwrap();
498 assert_eq!(
499 message.content,
500 vec![language_model::MessageContent::ToolResult(
501 LanguageModelToolResult {
502 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
503 tool_name: ToolRequiringPermission::name().into(),
504 is_error: false,
505 content: "Allowed".into(),
506 output: Some("Allowed".into())
507 }
508 )]
509 );
510
511 // Simulate a final tool call, ensuring we don't trigger authorization.
512 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
513 LanguageModelToolUse {
514 id: "tool_id_4".into(),
515 name: ToolRequiringPermission::name().into(),
516 raw_input: "{}".into(),
517 input: json!({}),
518 is_input_complete: true,
519 },
520 ));
521 fake_model.end_last_completion_stream();
522 cx.run_until_parked();
523 let completion = fake_model.pending_completions().pop().unwrap();
524 let message = completion.messages.last().unwrap();
525 assert_eq!(
526 message.content,
527 vec![language_model::MessageContent::ToolResult(
528 LanguageModelToolResult {
529 tool_use_id: "tool_id_4".into(),
530 tool_name: ToolRequiringPermission::name().into(),
531 is_error: false,
532 content: "Allowed".into(),
533 output: Some("Allowed".into())
534 }
535 )]
536 );
537}
538
539#[gpui::test]
540async fn test_tool_hallucination(cx: &mut TestAppContext) {
541 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
542 let fake_model = model.as_fake();
543
544 let mut events = thread
545 .update(cx, |thread, cx| {
546 thread.send(UserMessageId::new(), ["abc"], cx)
547 })
548 .unwrap();
549 cx.run_until_parked();
550 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
551 LanguageModelToolUse {
552 id: "tool_id_1".into(),
553 name: "nonexistent_tool".into(),
554 raw_input: "{}".into(),
555 input: json!({}),
556 is_input_complete: true,
557 },
558 ));
559 fake_model.end_last_completion_stream();
560
561 let tool_call = expect_tool_call(&mut events).await;
562 assert_eq!(tool_call.title, "nonexistent_tool");
563 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
564 let update = expect_tool_call_update_fields(&mut events).await;
565 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
566}
567
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // Exercises `Thread::resume`: after the provider reports that the tool-use
    // limit was reached, `resume` re-sends the conversation with an appended
    // "Continue where you left off" user message; calling `resume` at any
    // other time must error.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The model invokes the echo tool.
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The follow-up request carries the tool use + result, with only the
    // trailing tool-result message marked cacheable. messages[0] is the
    // system prompt, hence the [1..] slice.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The turn's final event is the ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a synthetic "Continue where you left off" user message
    // and re-sends the whole conversation.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });

    // Ensure we error if calling resume when tool use limit was *not* reached.
    let error = thread
        .update(cx, |thread, cx| thread.resume(cx))
        .unwrap_err();
    assert_eq!(
        error.to_string(),
        "can only resume after tool use limit is reached"
    )
}
685
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After the tool-use limit is reached, sending a *new* user message
    // (rather than resuming) should still work: the next request contains
    // the prior tool use/result followed by the new message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // The model invokes the echo tool, then the provider reports the limit.
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The turn ends with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Sending a fresh user message is allowed even after hitting the limit.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; only the new trailing message is
    // marked cacheable.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
759
760async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
761 let event = events
762 .next()
763 .await
764 .expect("no tool call authorization event received")
765 .unwrap();
766 match event {
767 ThreadEvent::ToolCall(tool_call) => tool_call,
768 event => {
769 panic!("Unexpected event {event:?}");
770 }
771 }
772}
773
774async fn expect_tool_call_update_fields(
775 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
776) -> acp::ToolCallUpdate {
777 let event = events
778 .next()
779 .await
780 .expect("no tool call authorization event received")
781 .unwrap();
782 match event {
783 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
784 event => {
785 panic!("Unexpected event {event:?}");
786 }
787 }
788}
789
790async fn next_tool_call_authorization(
791 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
792) -> ToolCallAuthorization {
793 loop {
794 let event = events
795 .next()
796 .await
797 .expect("no tool call authorization event received")
798 .unwrap();
799 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
800 let permission_kinds = tool_call_authorization
801 .options
802 .iter()
803 .map(|o| o.kind)
804 .collect::<Vec<_>>();
805 assert_eq!(
806 permission_kinds,
807 vec![
808 acp::PermissionOptionKind::AllowAlways,
809 acp::PermissionOptionKind::AllowOnce,
810 acp::PermissionOptionKind::RejectOnce,
811 ]
812 );
813 return tool_call_authorization;
814 }
815 }
816}
817
818#[gpui::test]
819#[cfg_attr(not(feature = "e2e"), ignore)]
820async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
821 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
822
823 // Test concurrent tool calls with different delay times
824 let events = thread
825 .update(cx, |thread, cx| {
826 thread.add_tool(DelayTool);
827 thread.send(
828 UserMessageId::new(),
829 [
830 "Call the delay tool twice in the same message.",
831 "Once with 100ms. Once with 300ms.",
832 "When both timers are complete, describe the outputs.",
833 ],
834 cx,
835 )
836 })
837 .unwrap()
838 .collect()
839 .await;
840
841 let stop_reasons = stop_events(events);
842 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
843
844 thread.update(cx, |thread, _cx| {
845 let last_message = thread.last_message().unwrap();
846 let agent_message = last_message.as_agent_message().unwrap();
847 let text = agent_message
848 .content
849 .iter()
850 .filter_map(|content| {
851 if let AgentMessageContent::Text(text) = content {
852 Some(text.as_str())
853 } else {
854 None
855 }
856 })
857 .collect::<String>();
858
859 assert!(text.contains("Ding"));
860 });
861}
862
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that switching agent profiles changes which registered tools
    // are advertised to the model in subsequent requests.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
942
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Verifies that tools exposed by an MCP context server are advertised to
    // the model, that calls are routed to the server, and that a name
    // collision with a native tool is resolved by prefixing the MCP tool
    // with its server name ("test_server_echo").
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()))
    });

    // Register an MCP server exposing a single "echo" tool that reuses the
    // native EchoTool's input schema.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(
                EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
            )
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The call must be forwarded to the MCP server with the raw tool name.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this re-asserts the `completion` captured before the tool
    // call rather than popping the new pending completion — presumably meant
    // to check the follow-up request's tools; verify intent.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    // The model calls both the (renamed) MCP tool and the native tool.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server still receives the un-prefixed name "echo".
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1104
#[gpui::test]
/// Verifies how MCP (context-server) tool names are disambiguated and
/// truncated: native tools keep their names, conflicting server tools get a
/// server-name prefix, and names at/over `MAX_TOOL_NAME_LENGTH` are shortened
/// (presumably by truncating the prefix — confirmed only via the final assert).
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Long names that collide with server "zzz" below, forcing
            // prefixing even though the prefixed result must still fit.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Over the limit even without a prefix; must be truncated.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The completion request exposes the final, deduplicated tool list.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1270
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
/// End-to-end test (runs against a real model, hence the `e2e` gate):
/// cancelling mid-turn closes the event stream with `Cancelled` even while a
/// tool is still running, and the thread remains usable afterwards.
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track completion of the echo call so we cancel while only the
            // infinite tool is still in flight.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1355
#[gpui::test]
/// Sending a new message while a previous send is still streaming cancels the
/// first send: its stream ends with `Cancelled`, the second ends with `EndTurn`.
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // First completion is left unfinished (no Stop event) when the second
    // send arrives below.
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
1386
#[gpui::test]
/// Counterpart to the cancellation test above: when the first send finishes
/// cleanly before the second begins, neither stream reports `Cancelled`.
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    // Drain the first turn completely before starting the second.
    let events_1 = events_1.collect::<Vec<_>>().await;

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
1419
#[gpui::test]
/// When the model stops with `Refusal`, the thread discards everything after
/// (and including) the last user message — here, the entire conversation.
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1468
#[gpui::test]
/// Truncating at the first user message empties the thread and resets token
/// usage; the thread then accepts new messages and tracks usage again.
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage has been reported yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // used_tokens reflects input + output from the UsageUpdate above.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message clears the whole thread.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1584
#[gpui::test]
/// Truncating at the second user message restores the thread to the state it
/// had after the first exchange, including the earlier token usage.
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Reused below to assert the thread both before the second send and after
    // truncating back to it.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate away the second exchange; the first must be intact.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1693
#[gpui::test]
/// The summarization model generates a title once after the first exchange
/// (only the first line of its output is used) and is not invoked again for
/// subsequent messages.
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model so title-generation traffic is distinguishable from
    // the main completion traffic.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // Chunks deliberately split mid-word and across the newline; everything
    // after "\n" ("Goodnight Moon") must be discarded.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1739
#[gpui::test]
/// Exercises the `NativeAgentConnection` surface end to end: model listing and
/// selection, thread creation, prompting, cancellation, and session teardown
/// when the `AcpThread` is dropped.
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        Project::init_settings(cx);
        agent_settings::init(cx);
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Round-trip a prompt through the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1872
#[gpui::test]
/// A tool use streamed in two stages (partial input, then complete input)
/// yields the expected ToolCall event followed by field updates: raw_input,
/// InProgress, content, and finally Completed with raw_output.
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Initial event from the partial tool use: Pending, empty input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // Update carrying the completed raw input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
1974
#[gpui::test]
/// A completion that succeeds on the first attempt must emit zero Retry
/// events, even in Burn mode (where retries are enabled).
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Collect Retry events until the turn stops.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
2018
#[gpui::test]
/// An overloaded-server error triggers a single retry (after the provider's
/// `retry_after` delay), and the retried completion succeeds transparently.
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Fast-forward past the retry_after delay so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // Second attempt succeeds.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
2075
#[gpui::test]
/// After `MAX_RETRY_ATTEMPTS` consecutive overloaded-server errors, the thread
/// gives up: it emits one Retry event per attempt (with increasing attempt
/// numbers) and then surfaces the final error to the event stream.
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial attempt plus every retry (hence MAX + 1 iterations).
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    // Attempt numbers are 1-based and strictly increasing.
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2130
2131/// Filters out the stop events for asserting against in tests
2132fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2133 result_events
2134 .into_iter()
2135 .filter_map(|event| match event.unwrap() {
2136 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2137 _ => None,
2138 })
2139 .collect()
2140}
2141
/// Everything a test needs after `setup`: the (possibly fake) model, the
/// thread under test, and handles for mutating project context, context
/// servers, and the fake filesystem (e.g. to rewrite settings).
struct ThreadTest {
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    project_context: Entity<ProjectContext>,
    context_server_store: Entity<ContextServerStore>,
    fs: Arc<FakeFs>,
}
2149
/// Which model `setup` should wire into the thread: a real Anthropic model
/// (for `e2e`-gated tests) or the in-process fake.
enum TestModel {
    Sonnet4,
    Fake,
}
2154
2155impl TestModel {
2156 fn id(&self) -> LanguageModelId {
2157 match self {
2158 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2159 TestModel::Fake => unreachable!(),
2160 }
2161 }
2162}
2163
/// Builds the full test harness: fake filesystem with a settings file
/// enabling every test tool, the global settings/client/language-model
/// systems, a test project, the requested model, and a fresh `Thread`.
///
/// Initialization order matters here (settings before project, language-model
/// init before model lookup), so keep the statement order intact.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Default profile enables all native test tools by name.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        // Real HTTP client: e2e tests authenticate against a live provider.
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        // Keep the global SettingsStore in sync with the fake settings file,
        // so tests can change settings by rewriting it (see fs in ThreadTest).
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Real model: resolve it from the registry and authenticate
                // its provider before returning.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2263
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Runs once at test-binary startup (via ctor); logging stays disabled
    // unless the caller exported RUST_LOG.
    match std::env::var("RUST_LOG") {
        Ok(_) => env_logger::init(),
        Err(_) => {}
    }
}
2271
2272fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2273 let fs = fs.clone();
2274 cx.spawn({
2275 async move |cx| {
2276 let mut new_settings_content_rx = settings::watch_config_file(
2277 cx.background_executor(),
2278 fs,
2279 paths::settings_file().clone(),
2280 );
2281
2282 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2283 cx.update(|cx| {
2284 SettingsStore::update_global(cx, |settings, cx| {
2285 settings.set_user_settings(&new_settings_content, cx)
2286 })
2287 })
2288 .ok();
2289 }
2290 }
2291 })
2292 .detach();
2293}
2294
2295fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2296 completion
2297 .tools
2298 .iter()
2299 .map(|tool| tool.name.clone())
2300 .collect()
2301}
2302
/// Registers and starts a fake MCP context server named `name` that exposes
/// `tools`.
///
/// Adds a custom server entry to the project settings, starts a server in
/// `context_server_store` backed by a fake transport, and returns a channel
/// yielding each incoming `CallTool` request paired with a oneshot sender the
/// test must use to deliver the response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server as an enabled custom server in project settings.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    // Never actually spawned; the fake transport below
                    // intercepts all traffic.
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    // The fake transport answers the MCP handshake and tool listing itself,
    // and forwards tool invocations to the test through the channel.
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block until it responds via
                // the oneshot sender.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server finish its startup handshake before tests proceed.
    cx.run_until_parked();
    mcp_tool_calls_rx
}