1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
9use fs::{FakeFs, Fs};
10use futures::{
11 StreamExt,
12 channel::{
13 mpsc::{self, UnboundedReceiver},
14 oneshot,
15 },
16};
17use gpui::{
18 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
19};
20use indoc::indoc;
21use language_model::{
22 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
23 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
25 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
26};
27use pretty_assertions::assert_eq;
28use project::{
29 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
30};
31use prompt_store::ProjectContext;
32use reqwest_client::ReqwestClient;
33use schemars::JsonSchema;
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36use settings::{Settings, SettingsStore};
37use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
38use util::path;
39
40mod test_tools;
41use test_tools::*;
42
43#[gpui::test]
44async fn test_echo(cx: &mut TestAppContext) {
45 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
46 let fake_model = model.as_fake();
47
48 let events = thread
49 .update(cx, |thread, cx| {
50 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
51 })
52 .unwrap();
53 cx.run_until_parked();
54 fake_model.send_last_completion_stream_text_chunk("Hello");
55 fake_model
56 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
57 fake_model.end_last_completion_stream();
58
59 let events = events.collect().await;
60 thread.update(cx, |thread, _cx| {
61 assert_eq!(
62 thread.last_message().unwrap().to_markdown(),
63 indoc! {"
64 ## Assistant
65
66 Hello
67 "}
68 )
69 });
70 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
71}
72
73#[gpui::test]
74async fn test_thinking(cx: &mut TestAppContext) {
75 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
76 let fake_model = model.as_fake();
77
78 let events = thread
79 .update(cx, |thread, cx| {
80 thread.send(
81 UserMessageId::new(),
82 [indoc! {"
83 Testing:
84
85 Generate a thinking step where you just think the word 'Think',
86 and have your final answer be 'Hello'
87 "}],
88 cx,
89 )
90 })
91 .unwrap();
92 cx.run_until_parked();
93 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
94 text: "Think".to_string(),
95 signature: None,
96 });
97 fake_model.send_last_completion_stream_text_chunk("Hello");
98 fake_model
99 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
100 fake_model.end_last_completion_stream();
101
102 let events = events.collect().await;
103 thread.update(cx, |thread, _cx| {
104 assert_eq!(
105 thread.last_message().unwrap().to_markdown(),
106 indoc! {"
107 ## Assistant
108
109 <think>Think</think>
110 Hello
111 "}
112 )
113 });
114 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
115}
116
117#[gpui::test]
118async fn test_system_prompt(cx: &mut TestAppContext) {
119 let ThreadTest {
120 model,
121 thread,
122 project_context,
123 ..
124 } = setup(cx, TestModel::Fake).await;
125 let fake_model = model.as_fake();
126
127 project_context.update(cx, |project_context, _cx| {
128 project_context.shell = "test-shell".into()
129 });
130 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
131 thread
132 .update(cx, |thread, cx| {
133 thread.send(UserMessageId::new(), ["abc"], cx)
134 })
135 .unwrap();
136 cx.run_until_parked();
137 let mut pending_completions = fake_model.pending_completions();
138 assert_eq!(
139 pending_completions.len(),
140 1,
141 "unexpected pending completions: {:?}",
142 pending_completions
143 );
144
145 let pending_completion = pending_completions.pop().unwrap();
146 assert_eq!(pending_completion.messages[0].role, Role::System);
147
148 let system_message = &pending_completion.messages[0];
149 let system_prompt = system_message.content[0].to_str().unwrap();
150 assert!(
151 system_prompt.contains("test-shell"),
152 "unexpected system message: {:?}",
153 system_message
154 );
155 assert!(
156 system_prompt.contains("## Fixing Diagnostics"),
157 "unexpected system message: {:?}",
158 system_message
159 );
160}
161
// Verifies prompt-caching behavior: only the most recent message in each
// request should be sent with `cache: true`, while all earlier turns
// (including tool uses and results) are sent with `cache: false`.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, so comparisons start at index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request (sent after the tool runs) should mark only the
    // trailing tool-result message as cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
295
// End-to-end (real model) test of basic tool calling, covering tool calls
// that finish both before and after the model stops streaming.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The model should have echoed the delay tool's output, which this test
    // expects to contain "Ding".
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
355
// End-to-end (real model) test that tool-call input streams incrementally:
// while a call is still Pending we should observe a partially-streamed
// input, and once it completes all expected keys must be present.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Kick off a request that should make the model call the word_list tool.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may still be streaming in; a
                        // missing "g" key is treated as evidence of a
                        // partially-streamed input.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
407
408#[gpui::test]
409async fn test_tool_authorization(cx: &mut TestAppContext) {
410 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
411 let fake_model = model.as_fake();
412
413 let mut events = thread
414 .update(cx, |thread, cx| {
415 thread.add_tool(ToolRequiringPermission);
416 thread.send(UserMessageId::new(), ["abc"], cx)
417 })
418 .unwrap();
419 cx.run_until_parked();
420 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
421 LanguageModelToolUse {
422 id: "tool_id_1".into(),
423 name: ToolRequiringPermission::name().into(),
424 raw_input: "{}".into(),
425 input: json!({}),
426 is_input_complete: true,
427 },
428 ));
429 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
430 LanguageModelToolUse {
431 id: "tool_id_2".into(),
432 name: ToolRequiringPermission::name().into(),
433 raw_input: "{}".into(),
434 input: json!({}),
435 is_input_complete: true,
436 },
437 ));
438 fake_model.end_last_completion_stream();
439 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
440 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
441
442 // Approve the first
443 tool_call_auth_1
444 .response
445 .send(tool_call_auth_1.options[1].id.clone())
446 .unwrap();
447 cx.run_until_parked();
448
449 // Reject the second
450 tool_call_auth_2
451 .response
452 .send(tool_call_auth_1.options[2].id.clone())
453 .unwrap();
454 cx.run_until_parked();
455
456 let completion = fake_model.pending_completions().pop().unwrap();
457 let message = completion.messages.last().unwrap();
458 assert_eq!(
459 message.content,
460 vec![
461 language_model::MessageContent::ToolResult(LanguageModelToolResult {
462 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
463 tool_name: ToolRequiringPermission::name().into(),
464 is_error: false,
465 content: "Allowed".into(),
466 output: Some("Allowed".into())
467 }),
468 language_model::MessageContent::ToolResult(LanguageModelToolResult {
469 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
470 tool_name: ToolRequiringPermission::name().into(),
471 is_error: true,
472 content: "Permission to run tool denied by user".into(),
473 output: None
474 })
475 ]
476 );
477
478 // Simulate yet another tool call.
479 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
480 LanguageModelToolUse {
481 id: "tool_id_3".into(),
482 name: ToolRequiringPermission::name().into(),
483 raw_input: "{}".into(),
484 input: json!({}),
485 is_input_complete: true,
486 },
487 ));
488 fake_model.end_last_completion_stream();
489
490 // Respond by always allowing tools.
491 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
492 tool_call_auth_3
493 .response
494 .send(tool_call_auth_3.options[0].id.clone())
495 .unwrap();
496 cx.run_until_parked();
497 let completion = fake_model.pending_completions().pop().unwrap();
498 let message = completion.messages.last().unwrap();
499 assert_eq!(
500 message.content,
501 vec![language_model::MessageContent::ToolResult(
502 LanguageModelToolResult {
503 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
504 tool_name: ToolRequiringPermission::name().into(),
505 is_error: false,
506 content: "Allowed".into(),
507 output: Some("Allowed".into())
508 }
509 )]
510 );
511
512 // Simulate a final tool call, ensuring we don't trigger authorization.
513 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
514 LanguageModelToolUse {
515 id: "tool_id_4".into(),
516 name: ToolRequiringPermission::name().into(),
517 raw_input: "{}".into(),
518 input: json!({}),
519 is_input_complete: true,
520 },
521 ));
522 fake_model.end_last_completion_stream();
523 cx.run_until_parked();
524 let completion = fake_model.pending_completions().pop().unwrap();
525 let message = completion.messages.last().unwrap();
526 assert_eq!(
527 message.content,
528 vec![language_model::MessageContent::ToolResult(
529 LanguageModelToolResult {
530 tool_use_id: "tool_id_4".into(),
531 tool_name: ToolRequiringPermission::name().into(),
532 is_error: false,
533 content: "Allowed".into(),
534 output: Some("Allowed".into())
535 }
536 )]
537 );
538}
539
540#[gpui::test]
541async fn test_tool_hallucination(cx: &mut TestAppContext) {
542 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
543 let fake_model = model.as_fake();
544
545 let mut events = thread
546 .update(cx, |thread, cx| {
547 thread.send(UserMessageId::new(), ["abc"], cx)
548 })
549 .unwrap();
550 cx.run_until_parked();
551 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
552 LanguageModelToolUse {
553 id: "tool_id_1".into(),
554 name: "nonexistent_tool".into(),
555 raw_input: "{}".into(),
556 input: json!({}),
557 is_input_complete: true,
558 },
559 ));
560 fake_model.end_last_completion_stream();
561
562 let tool_call = expect_tool_call(&mut events).await;
563 assert_eq!(tool_call.title, "nonexistent_tool");
564 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
565 let update = expect_tool_call_update_fields(&mut events).await;
566 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
567}
568
// After the provider reports ToolUseLimitReached, `resume` should re-send the
// conversation with an extra "Continue where you left off" user message.
// Calling `resume` when the limit was *not* reached must return an error.
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // The tool ran, so the follow-up request ends with its (cached) result.
    // messages[0] is the system prompt; comparisons start at index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream should terminate with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a synthetic user message asking the model to continue.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });

    // Ensure we error if calling resume when tool use limit was *not* reached.
    let error = thread
        .update(cx, |thread, cx| thread.resume(cx))
        .unwrap_err();
    assert_eq!(
        error.to_string(),
        "can only resume after tool use limit is reached"
    )
}
686
// After hitting the tool-use limit, sending a brand-new user message should
// still work: the preserved tool use/result pair stays in the history and the
// new message is appended (and cached) after it.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Stream a tool use and immediately report the tool-use limit.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream should terminate with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // A fresh send should succeed despite the earlier limit error.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; comparisons start at index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
760
761async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
762 let event = events
763 .next()
764 .await
765 .expect("no tool call authorization event received")
766 .unwrap();
767 match event {
768 ThreadEvent::ToolCall(tool_call) => tool_call,
769 event => {
770 panic!("Unexpected event {event:?}");
771 }
772 }
773}
774
775async fn expect_tool_call_update_fields(
776 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
777) -> acp::ToolCallUpdate {
778 let event = events
779 .next()
780 .await
781 .expect("no tool call authorization event received")
782 .unwrap();
783 match event {
784 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
785 event => {
786 panic!("Unexpected event {event:?}");
787 }
788 }
789}
790
791async fn next_tool_call_authorization(
792 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
793) -> ToolCallAuthorization {
794 loop {
795 let event = events
796 .next()
797 .await
798 .expect("no tool call authorization event received")
799 .unwrap();
800 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
801 let permission_kinds = tool_call_authorization
802 .options
803 .iter()
804 .map(|o| o.kind)
805 .collect::<Vec<_>>();
806 assert_eq!(
807 permission_kinds,
808 vec![
809 acp::PermissionOptionKind::AllowAlways,
810 acp::PermissionOptionKind::AllowOnce,
811 acp::PermissionOptionKind::RejectOnce,
812 ]
813 );
814 return tool_call_authorization;
815 }
816 }
817}
818
819#[gpui::test]
820#[cfg_attr(not(feature = "e2e"), ignore)]
821async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
822 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
823
824 // Test concurrent tool calls with different delay times
825 let events = thread
826 .update(cx, |thread, cx| {
827 thread.add_tool(DelayTool);
828 thread.send(
829 UserMessageId::new(),
830 [
831 "Call the delay tool twice in the same message.",
832 "Once with 100ms. Once with 300ms.",
833 "When both timers are complete, describe the outputs.",
834 ],
835 cx,
836 )
837 })
838 .unwrap()
839 .collect()
840 .await;
841
842 let stop_reasons = stop_events(events);
843 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
844
845 thread.update(cx, |thread, _cx| {
846 let last_message = thread.last_message().unwrap();
847 let agent_message = last_message.as_agent_message().unwrap();
848 let text = agent_message
849 .content
850 .iter()
851 .filter_map(|content| {
852 if let AgentMessageContent::Text(text) = content {
853 Some(text.as_str())
854 } else {
855 None
856 }
857 })
858 .collect::<String>();
859
860 assert!(text.contains("Ding"));
861 });
862}
863
// Verifies that switching agent profiles changes which registered tools are
// offered to the model in subsequent completion requests.
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools up front; profiles decide which are exposed.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
943
// Verifies that tools provided by an MCP context server are exposed to the
// model, and that a name collision with a native tool is resolved by
// prefixing the MCP tool with its server name ("test_server_echo").
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()))
    });

    // Register an MCP server exposing a single "echo" tool whose schema
    // mirrors the native EchoTool's.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(
                EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
            )
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server should receive the call; answer it with "test".
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): this asserts on the `completion` popped *before* the tool
    // result was sent — it re-checks stale data. Presumably the intent was to
    // pop the new pending completion here; verify and update if so.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    // Call both the (renamed) MCP tool and the native tool in one turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server still sees its tool under the original name "echo".
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1105
// Verifies how MCP (context server) tool names are reconciled against native
// tools and against each other: colliding names get disambiguated with a
// server-derived prefix, and names at or over MAX_TOOL_NAME_LENGTH are
// shortened so the final set fits the length budget (see the expected-name
// assertion at the end).
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Long names just under the limit; these also collide with
            // server "zzz" below, forcing prefixed disambiguation.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // One name strictly over the limit; must be truncated.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The completion request must contain the final, de-duplicated,
    // length-limited tool names (sorted), e.g. "xxx_echo"/"yyy_echo" for the
    // colliding "echo" tools and "y_…"/"z_…" prefixes for the colliding long
    // names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1271
// E2E test (real model; ignored without the "e2e" feature): cancelling a turn
// while a tool is still running must close the event stream with a Cancelled
// stop reason, and the thread must remain usable afterwards.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track completion of the echo call specifically, so we cancel
            // while only the infinite tool is still running.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1356
1357#[gpui::test]
1358async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1359 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1360 let fake_model = model.as_fake();
1361
1362 let events_1 = thread
1363 .update(cx, |thread, cx| {
1364 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1365 })
1366 .unwrap();
1367 cx.run_until_parked();
1368 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1369 cx.run_until_parked();
1370
1371 let events_2 = thread
1372 .update(cx, |thread, cx| {
1373 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1374 })
1375 .unwrap();
1376 cx.run_until_parked();
1377 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1378 fake_model
1379 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1380 fake_model.end_last_completion_stream();
1381
1382 let events_1 = events_1.collect::<Vec<_>>().await;
1383 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1384 let events_2 = events_2.collect::<Vec<_>>().await;
1385 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1386}
1387
1388#[gpui::test]
1389async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1390 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1391 let fake_model = model.as_fake();
1392
1393 let events_1 = thread
1394 .update(cx, |thread, cx| {
1395 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1396 })
1397 .unwrap();
1398 cx.run_until_parked();
1399 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1400 fake_model
1401 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1402 fake_model.end_last_completion_stream();
1403 let events_1 = events_1.collect::<Vec<_>>().await;
1404
1405 let events_2 = thread
1406 .update(cx, |thread, cx| {
1407 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1408 })
1409 .unwrap();
1410 cx.run_until_parked();
1411 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1412 fake_model
1413 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1414 fake_model.end_last_completion_stream();
1415 let events_2 = events_2.collect::<Vec<_>>().await;
1416
1417 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1418 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1419}
1420
// When the model refuses (StopReason::Refusal), the thread must roll back all
// messages after (and including) the last user message, leaving the thread
// empty in this single-exchange case.
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The user message is recorded immediately, before any model output.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Partial assistant output is appended as it streams in.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1469
// Truncating at the first (and only) user message must clear the whole
// thread — including the reported token usage — and the thread must remain
// usable for a fresh exchange afterwards.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // No usage has been reported yet, only the user message exists.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    // Usage is input + output tokens against the model's context window.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message empties the thread and resets usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1585
// Truncating at the second user message must restore the thread to the exact
// state it was in after the first exchange, including the first exchange's
// token usage.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: "Message 1" -> "Message 1 response" with usage 32k/16k.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion reused twice: once after the first exchange, once
    // after truncating the second — both must observe the identical state.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange, whose message id we keep in order to truncate at it.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at message 2; the thread must roll back to the first-exchange
    // state, including the earlier 32k+16k usage figure.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1694
// A separate summarization model generates the thread title after the first
// exchange; only the first line of its output is used, and no title is
// regenerated on subsequent messages.
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // The title is produced by a dedicated summarization model, distinct
    // from the main completion model.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary model responds, the default title is kept.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // Everything from the first newline on ("Goodnight Moon") is discarded.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1740
// When a completion request is rebuilt mid-turn, tool uses that are still
// awaiting permission must be omitted, while completed tool uses (and their
// results) are included.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool stays pending (its permission prompt is never answered).
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
    };
    // This tool runs to completion without needing permission.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    // messages[0] is the system prompt, hence the [1..] slice.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false
            },
        ],
    );
}
1815
// Integration test of the ACP connection surface (NativeAgentConnection):
// model listing/selection, thread creation, prompting, cancellation, and
// session cleanup when the ACP thread is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        Project::init_settings(cx);
        agent_settings::init(cx);
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Round-trip a prompt through the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1948
// Verifies the exact sequence of ACP tool-call events emitted while a tool's
// input streams in and the tool then runs: Pending -> input update ->
// InProgress -> content -> Completed with raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1) Initial tool call created in Pending state with the partial input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // 2) Update carrying the completed input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    // 3) Status flips to InProgress when the tool starts running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    // 4) Tool content streams through as an update.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    // 5) Final update: Completed status plus the tool's raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
2050
2051#[gpui::test]
2052async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2053 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2054 let fake_model = model.as_fake();
2055
2056 let mut events = thread
2057 .update(cx, |thread, cx| {
2058 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2059 thread.send(UserMessageId::new(), ["Hello!"], cx)
2060 })
2061 .unwrap();
2062 cx.run_until_parked();
2063
2064 fake_model.send_last_completion_stream_text_chunk("Hey!");
2065 fake_model.end_last_completion_stream();
2066
2067 let mut retry_events = Vec::new();
2068 while let Some(Ok(event)) = events.next().await {
2069 match event {
2070 ThreadEvent::Retry(retry_status) => {
2071 retry_events.push(retry_status);
2072 }
2073 ThreadEvent::Stop(..) => break,
2074 _ => {}
2075 }
2076 }
2077
2078 assert_eq!(retry_events.len(), 0);
2079 thread.read_with(cx, |thread, _cx| {
2080 assert_eq!(
2081 thread.to_markdown(),
2082 indoc! {"
2083 ## User
2084
2085 Hello!
2086
2087 ## Assistant
2088
2089 Hey!
2090 "}
2091 )
2092 });
2093}
2094
// A ServerOverloaded error must trigger exactly one retry (after the
// provider-supplied retry_after delay), and the retried completion's output
// must land in the thread as if nothing had failed.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // First attempt fails with a retryable overload error.
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past retry_after so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried attempt succeeds.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry, reported as attempt 1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
2151
2152#[gpui::test]
2153async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
2154 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2155 let fake_model = model.as_fake();
2156
2157 let mut events = thread
2158 .update(cx, |thread, cx| {
2159 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2160 thread.send(UserMessageId::new(), ["Hello!"], cx)
2161 })
2162 .unwrap();
2163 cx.run_until_parked();
2164
2165 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
2166 fake_model.send_last_completion_stream_error(
2167 LanguageModelCompletionError::ServerOverloaded {
2168 provider: LanguageModelProviderName::new("Anthropic"),
2169 retry_after: Some(Duration::from_secs(3)),
2170 },
2171 );
2172 fake_model.end_last_completion_stream();
2173 cx.executor().advance_clock(Duration::from_secs(3));
2174 cx.run_until_parked();
2175 }
2176
2177 let mut errors = Vec::new();
2178 let mut retry_events = Vec::new();
2179 while let Some(event) = events.next().await {
2180 match event {
2181 Ok(ThreadEvent::Retry(retry_status)) => {
2182 retry_events.push(retry_status);
2183 }
2184 Ok(ThreadEvent::Stop(..)) => break,
2185 Err(error) => errors.push(error),
2186 _ => {}
2187 }
2188 }
2189
2190 assert_eq!(
2191 retry_events.len(),
2192 crate::thread::MAX_RETRY_ATTEMPTS as usize
2193 );
2194 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
2195 assert_eq!(retry_events[i].attempt, i + 1);
2196 }
2197 assert_eq!(errors.len(), 1);
2198 let error = errors[0]
2199 .downcast_ref::<LanguageModelCompletionError>()
2200 .unwrap();
2201 assert!(matches!(
2202 error,
2203 LanguageModelCompletionError::ServerOverloaded { .. }
2204 ));
2205}
2206
2207/// Filters out the stop events for asserting against in tests
2208fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2209 result_events
2210 .into_iter()
2211 .filter_map(|event| match event.unwrap() {
2212 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2213 _ => None,
2214 })
2215 .collect()
2216}
2217
/// Bundle of handles returned by [`setup`] that tests use to drive a thread.
struct ThreadTest {
    // Model the thread is wired to: a `FakeLanguageModel` or a registry model,
    // depending on the `TestModel` passed to `setup`.
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context handed to `Thread::new`; tests can mutate it directly.
    project_context: Entity<ProjectContext>,
    // Store used by `setup_context_server` to register fake MCP servers.
    context_server_store: Entity<ContextServerStore>,
    // Fake filesystem backing the test project and the settings file.
    fs: Arc<FakeFs>,
}
2225
/// Which language model a test runs against.
enum TestModel {
    // A real model resolved from the registry by id (requires authentication).
    Sonnet4,
    // An in-process `FakeLanguageModel` driven manually by the test.
    Fake,
}
2230
2231impl TestModel {
2232 fn id(&self) -> LanguageModelId {
2233 match self {
2234 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2235 TestModel::Fake => unreachable!(),
2236 }
2237 }
2238}
2239
/// Builds the shared test fixture: a fake filesystem seeded with a settings
/// file, global app/client/model initialization, a test project rooted at
/// `/test`, and a `Thread` connected to the requested `model`.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Seed a settings file whose "test-profile" enables every test tool, so
    // threads created below have them available by default.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    // Initialize global state (settings, HTTP client, language-model
    // registry) that `Thread` and the model providers depend on. Order
    // matters here: settings first, then client, then model registration.
    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        // Keep the global SettingsStore in sync with edits tests make to the
        // fake settings file.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Either construct a fake model directly, or resolve the model by id from
    // the registry and authenticate its provider before returning it.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2339
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only install env_logger when RUST_LOG is explicitly set, so test output
    // stays quiet by default.
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
2347
2348fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2349 let fs = fs.clone();
2350 cx.spawn({
2351 async move |cx| {
2352 let mut new_settings_content_rx = settings::watch_config_file(
2353 cx.background_executor(),
2354 fs,
2355 paths::settings_file().clone(),
2356 );
2357
2358 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2359 cx.update(|cx| {
2360 SettingsStore::update_global(cx, |settings, cx| {
2361 settings.set_user_settings(&new_settings_content, cx)
2362 })
2363 })
2364 .ok();
2365 }
2366 }
2367 })
2368 .detach();
2369}
2370
2371fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2372 completion
2373 .tools
2374 .iter()
2375 .map(|tool| tool.name.clone())
2376 .collect()
2377}
2378
/// Registers and starts a fake MCP context server named `name` exposing
/// `tools` in `context_server_store`.
///
/// Returns a channel that yields one `(CallToolParams, responder)` pair per
/// tool call the server receives; the test answers a call by sending a
/// `CallToolResponse` on the paired oneshot sender.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store recognizes it as an
    // enabled custom server.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                command: ContextServerCommand {
                    // Never actually executed; the fake transport below
                    // handles all protocol traffic.
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport speaking just enough of the protocol: initialize,
    // tools/list (serving `tools`), and tools/call (forwarded to the
    // returned channel so the test controls each response).
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block this request until the
                // test replies on the oneshot channel.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the store finish the initialize/list handshake before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}