1mod mcp_server;
2mod tools;
3
4use collections::HashMap;
5use project::Project;
6use settings::SettingsStore;
7use std::cell::RefCell;
8use std::fmt::Display;
9use std::path::Path;
10use std::rc::Rc;
11
12use agentic_coding_protocol::{
13 self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion,
14 StreamAssistantMessageChunkParams, ToolCallContent, UpdateToolCallParams,
15};
16use anyhow::{Result, anyhow};
17use futures::channel::oneshot;
18use futures::future::LocalBoxFuture;
19use futures::{AsyncBufReadExt, AsyncWriteExt};
20use futures::{
21 AsyncRead, AsyncWrite, FutureExt, StreamExt,
22 channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
23 io::BufReader,
24 select_biased,
25};
26use gpui::{App, AppContext, Entity, Task};
27use serde::{Deserialize, Serialize};
28use util::ResultExt;
29
30use crate::claude::mcp_server::ClaudeMcpServer;
31use crate::claude::tools::ClaudeTool;
32use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
33use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
34
/// Agent server backed by the Claude Code CLI (`claude`), driven through its
/// `stream-json` input/output format.
#[derive(Clone)]
pub struct ClaudeCode;
37
38impl AgentServer for ClaudeCode {
39 fn name(&self) -> &'static str {
40 "Claude Code"
41 }
42
43 fn empty_state_headline(&self) -> &'static str {
44 self.name()
45 }
46
47 fn empty_state_message(&self) -> &'static str {
48 ""
49 }
50
51 fn logo(&self) -> ui::IconName {
52 ui::IconName::AiClaude
53 }
54
55 fn supports_always_allow(&self) -> bool {
56 false
57 }
58
59 fn new_thread(
60 &self,
61 root_dir: &Path,
62 project: &Entity<Project>,
63 cx: &mut App,
64 ) -> Task<Result<Entity<AcpThread>>> {
65 let project = project.clone();
66 let root_dir = root_dir.to_path_buf();
67 let title = self.name().into();
68 let context_server_store = project.read(cx).context_server_store().read(cx);
69 let mut mcp_servers = HashMap::default();
70 for id in context_server_store.all_server_ids() {
71 let Some(configuration) = context_server_store.configuration_for_server(&id) else {
72 continue;
73 };
74 let command = configuration.command();
75 mcp_servers.insert(
76 id.0.to_string(),
77 McpServerConfig {
78 command: command.path.clone(),
79 args: command.args.clone(),
80 env: command.env.clone(),
81 },
82 );
83 }
84 cx.spawn(async move |cx| {
85 let (mut delegate_tx, delegate_rx) = watch::channel(None);
86 let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
87
88 let permission_mcp_server =
89 ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
90
91 mcp_servers.insert(
92 mcp_server::SERVER_NAME.to_string(),
93 permission_mcp_server.server_config()?,
94 );
95 dbg!(&mcp_servers);
96 let mcp_config = McpConfig { mcp_servers };
97
98 let mcp_config_file = tempfile::NamedTempFile::new()?;
99 let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
100
101 let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
102 mcp_config_file
103 .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
104 .await?;
105 mcp_config_file.flush().await?;
106
107 let settings = cx.read_global(|settings: &SettingsStore, _| {
108 settings.get::<AllAgentServersSettings>(None).claude.clone()
109 })?;
110
111 let Some(command) =
112 AgentServerCommand::resolve("claude", &[], settings, &project, cx).await
113 else {
114 anyhow::bail!("Failed to find claude binary");
115 };
116
117 let mut child = util::command::new_smol_command(&command.path)
118 .args(
119 [
120 "--input-format",
121 "stream-json",
122 "--output-format",
123 "stream-json",
124 "--print",
125 "--verbose",
126 "--mcp-config",
127 mcp_config_path.to_string_lossy().as_ref(),
128 "--permission-prompt-tool",
129 &format!(
130 "mcp__{}__{}",
131 mcp_server::SERVER_NAME,
132 mcp_server::PERMISSION_TOOL
133 ),
134 "--allowedTools",
135 "mcp__zed__Read,mcp__zed__Edit",
136 "--disallowedTools",
137 "Read,Edit",
138 ]
139 .into_iter()
140 .chain(command.args.iter().map(|arg| arg.as_str())),
141 )
142 .current_dir(root_dir)
143 .stdin(std::process::Stdio::piped())
144 .stdout(std::process::Stdio::piped())
145 .stderr(std::process::Stdio::inherit())
146 .kill_on_drop(true)
147 .spawn()?;
148
149 let stdin = child.stdin.take().unwrap();
150 let stdout = child.stdout.take().unwrap();
151
152 let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
153 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
154
155 let io_task =
156 ClaudeAgentConnection::handle_io(outgoing_rx, incoming_message_tx, stdin, stdout);
157 cx.background_spawn(async move {
158 io_task.await.log_err();
159 drop(mcp_config_path);
160 drop(child);
161 })
162 .detach();
163
164 cx.new(|cx| {
165 let end_turn_tx = Rc::new(RefCell::new(None));
166 let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
167 delegate_tx.send(Some(delegate.clone())).log_err();
168
169 let handler_task = cx.foreground_executor().spawn({
170 let end_turn_tx = end_turn_tx.clone();
171 let tool_id_map = tool_id_map.clone();
172 async move {
173 while let Some(message) = incoming_message_rx.next().await {
174 ClaudeAgentConnection::handle_message(
175 delegate.clone(),
176 message,
177 end_turn_tx.clone(),
178 tool_id_map.clone(),
179 )
180 .await
181 }
182 }
183 });
184
185 let mut connection = ClaudeAgentConnection {
186 outgoing_tx,
187 end_turn_tx,
188 _handler_task: handler_task,
189 _mcp_server: None,
190 };
191
192 connection._mcp_server = Some(permission_mcp_server);
193 acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
194 })
195 })
196 }
197}
198
199impl AgentConnection for ClaudeAgentConnection {
200 /// Send a request to the agent and wait for a response.
201 fn request_any(
202 &self,
203 params: AnyAgentRequest,
204 ) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
205 let end_turn_tx = self.end_turn_tx.clone();
206 let outgoing_tx = self.outgoing_tx.clone();
207 async move {
208 match params {
209 // todo: consider sending an empty request so we get the init response?
210 AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse(
211 acp::InitializeResponse {
212 is_authenticated: true,
213 protocol_version: ProtocolVersion::latest(),
214 },
215 )),
216 AnyAgentRequest::AuthenticateParams(_) => {
217 Err(anyhow!("Authentication not supported"))
218 }
219 AnyAgentRequest::SendUserMessageParams(message) => {
220 let (tx, rx) = oneshot::channel();
221 end_turn_tx.borrow_mut().replace(tx);
222 let mut content = String::new();
223 for chunk in message.chunks {
224 match chunk {
225 agentic_coding_protocol::UserMessageChunk::Text { text } => {
226 content.push_str(&text)
227 }
228 agentic_coding_protocol::UserMessageChunk::Path { path } => {
229 content.push_str(&format!("@{path:?}"))
230 }
231 }
232 }
233 outgoing_tx.unbounded_send(SdkMessage::User {
234 message: Message {
235 role: Role::User,
236 content: Content::UntaggedText(content),
237 id: None,
238 model: None,
239 stop_reason: None,
240 stop_sequence: None,
241 usage: None,
242 },
243 session_id: None,
244 })?;
245 rx.await??;
246 Ok(AnyAgentResult::SendUserMessageResponse(
247 acp::SendUserMessageResponse,
248 ))
249 }
250 AnyAgentRequest::CancelSendMessageParams(_) => Ok(
251 AnyAgentResult::CancelSendMessageResponse(acp::CancelSendMessageResponse),
252 ),
253 }
254 }
255 .boxed_local()
256 }
257}
258
/// Connection to a running Claude Code process. It is handed to
/// [`AcpThread::new`] when a thread is created.
///
/// NOTE(review): fields are dropped in declaration order; keep that in mind
/// before reordering them.
struct ClaudeAgentConnection {
    /// Messages sent here are serialized and written to the process's stdin
    /// by `handle_io`.
    outgoing_tx: UnboundedSender<SdkMessage>,
    /// Completion channel for the in-flight user turn. `request_any` installs
    /// it, and `handle_message` resolves it when a `result` message arrives.
    end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
    /// Permission MCP server passed to the CLI. Never read in this file;
    /// presumably held so it lives as long as the connection.
    _mcp_server: Option<ClaudeMcpServer>,
    /// Task that feeds incoming messages to `handle_message`. Never read;
    /// presumably held so that dropping the connection stops it (gpui task
    /// drop semantics, which should be confirmed).
    _handler_task: Task<()>,
}
265
266impl ClaudeAgentConnection {
267 async fn handle_message(
268 delegate: AcpClientDelegate,
269 message: SdkMessage,
270 end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
271 tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
272 ) {
273 match message {
274 SdkMessage::Assistant { message, .. } | SdkMessage::User { message, .. } => {
275 for chunk in message.content.chunks() {
276 match chunk {
277 ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
278 delegate
279 .stream_assistant_message_chunk(StreamAssistantMessageChunkParams {
280 chunk: acp::AssistantMessageChunk::Text { text },
281 })
282 .await
283 .log_err();
284 }
285 ContentChunk::ToolUse { id, name, input } => {
286 if let Some(resp) = delegate
287 .push_tool_call(ClaudeTool::infer(&name, input).as_acp())
288 .await
289 .log_err()
290 {
291 tool_id_map.borrow_mut().insert(id, resp.id);
292 }
293 }
294 ContentChunk::ToolResult {
295 content,
296 tool_use_id,
297 } => {
298 let id = tool_id_map.borrow_mut().remove(&tool_use_id);
299 if let Some(id) = id {
300 let content = content.to_string();
301 delegate
302 .update_tool_call(UpdateToolCallParams {
303 tool_call_id: id,
304 status: acp::ToolCallStatus::Finished,
305 // Don't unset existing content
306 content: (!content.is_empty()).then_some(
307 ToolCallContent::Markdown {
308 // For now we only include text content
309 markdown: content,
310 },
311 ),
312 })
313 .await
314 .log_err();
315 }
316 }
317 ContentChunk::Image
318 | ContentChunk::Document
319 | ContentChunk::Thinking
320 | ContentChunk::RedactedThinking
321 | ContentChunk::WebSearchToolResult => {
322 delegate
323 .stream_assistant_message_chunk(StreamAssistantMessageChunkParams {
324 chunk: acp::AssistantMessageChunk::Text {
325 text: format!("Unsupported content: {:?}", chunk),
326 },
327 })
328 .await
329 .log_err();
330 }
331 }
332 }
333 }
334 SdkMessage::Result {
335 is_error, subtype, ..
336 } => {
337 if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
338 if is_error {
339 end_turn_tx.send(Err(anyhow!("Error: {subtype}"))).ok();
340 } else {
341 end_turn_tx.send(Ok(())).ok();
342 }
343 }
344 }
345 SdkMessage::System { .. } => {}
346 }
347 }
348
349 async fn handle_io(
350 mut outgoing_rx: UnboundedReceiver<SdkMessage>,
351 incoming_tx: UnboundedSender<SdkMessage>,
352 mut outgoing_bytes: impl Unpin + AsyncWrite,
353 incoming_bytes: impl Unpin + AsyncRead,
354 ) -> Result<()> {
355 let mut output_reader = BufReader::new(incoming_bytes);
356 let mut outgoing_line = Vec::new();
357 let mut incoming_line = String::new();
358 loop {
359 select_biased! {
360 message = outgoing_rx.next() => {
361 if let Some(message) = message {
362 outgoing_line.clear();
363 serde_json::to_writer(&mut outgoing_line, &message)?;
364 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
365 outgoing_line.push(b'\n');
366 outgoing_bytes.write_all(&outgoing_line).await.ok();
367 } else {
368 break;
369 }
370 }
371 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
372 if bytes_read? == 0 {
373 break
374 }
375 log::trace!("recv: {}", &incoming_line);
376 match serde_json::from_str::<SdkMessage>(&incoming_line) {
377 Ok(message) => {
378 incoming_tx.unbounded_send(message).log_err();
379 }
380 Err(error) => {
381 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
382 }
383 }
384 incoming_line.clear();
385 }
386 }
387 }
388 Ok(())
389 }
390}
391
/// A message in the Anthropic SDK shape. It is carried by the `assistant`
/// and `user` variants of [`SdkMessage`].
///
/// Unset metadata fields are omitted on serialization. Outgoing user messages
/// built in `request_any` leave all of them `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    role: Role,
    content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequence: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Usage>,
}
407
/// Body of a message or tool result: either a bare JSON string or an array of
/// [`ContentChunk`]s.
///
/// `#[serde(untagged)]` tries the variants in declaration order, so a JSON
/// string becomes `UntaggedText` and a JSON array becomes `Chunks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Content {
    UntaggedText(String),
    Chunks(Vec<ContentChunk>),
}
414
415impl Content {
416 pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
417 match self {
418 Self::Chunks(chunks) => chunks.into_iter(),
419 Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
420 }
421 }
422}
423
424impl Display for Content {
425 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426 match self {
427 Content::UntaggedText(txt) => write!(f, "{}", txt),
428 Content::Chunks(chunks) => {
429 for chunk in chunks {
430 write!(f, "{}", chunk)?;
431 }
432 Ok(())
433 }
434 }
435 }
436}
437
/// One element of a [`Content::Chunks`] array.
///
/// The enum is internally tagged by the JSON `type` field, with snake_case
/// variant names such as `"text"`, `"tool_use"` and `"tool_result"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ContentChunk {
    Text {
        text: String,
    },
    /// A tool invocation requested by the model; `id` links it to a later
    /// `ToolResult` via `tool_use_id`.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    ToolResult {
        content: Content,
        tool_use_id: String,
    },
    // TODO
    // Payloads of the following kinds are not modeled yet; `handle_message`
    // reports them as unsupported. NOTE(review): confirm these still
    // deserialize when the JSON object carries extra fields.
    Image,
    Document,
    Thinking,
    RedactedThinking,
    WebSearchToolResult,
    /// Untagged fallback, tried when no tagged variant matches (e.g. a bare
    /// JSON string inside a chunk array).
    #[serde(untagged)]
    UntaggedText(String),
}
462
463impl Display for ContentChunk {
464 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
465 match self {
466 ContentChunk::Text { text } => write!(f, "{}", text),
467 ContentChunk::UntaggedText(text) => write!(f, "{}", text),
468 ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
469 ContentChunk::Image
470 | ContentChunk::Document
471 | ContentChunk::Thinking
472 | ContentChunk::RedactedThinking
473 | ContentChunk::ToolUse { .. }
474 | ContentChunk::WebSearchToolResult => {
475 write!(f, "\n{:?}\n", &self)
476 }
477 }
478 }
479}
480
481#[derive(Debug, Clone, Serialize, Deserialize)]
482struct Usage {
483 input_tokens: u32,
484 cache_creation_input_tokens: u32,
485 cache_read_input_tokens: u32,
486 output_tokens: u32,
487 service_tier: String,
488}
489
/// Author of a [`Message`], serialized in snake_case (`"system"`,
/// `"assistant"`, `"user"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}
497
/// A role plus plain-string content.
///
/// NOTE(review): not referenced anywhere in this file; it may be unused.
/// Confirm before relying on it or removing it.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    role: Role,
    content: String,
}
503
/// One line of Claude Code's `stream-json` protocol.
///
/// The enum is internally tagged by `type` (`"assistant"`, `"user"`,
/// `"result"`, `"system"`). `handle_io` parses incoming lines into this type
/// and serializes outgoing ones from it; `request_any` only sends `User`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    // An assistant message
    Assistant {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },

    // A user message
    User {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },

    // Emitted as the last message in a conversation.
    // `handle_message` uses `is_error`/`subtype` to complete the pending turn.
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    // Emitted as the first message at the start of a conversation.
    // It is ignored by `handle_message`.
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
}
546
/// `subtype` of a `result` message, serialized in snake_case
/// (`"success"`, `"error_max_turns"`, `"error_during_execution"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
554
555impl Display for ResultErrorType {
556 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
557 match self {
558 ResultErrorType::Success => write!(f, "success"),
559 ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
560 ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
561 }
562 }
563}
564
/// An MCP server entry reported in the `system` message.
///
/// `status` is kept as a free-form string; its possible values are not
/// modeled here.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    name: String,
    status: String,
}
570
/// Permission mode reported in the `system` message (`permissionMode`),
/// serialized in camelCase (`"default"`, `"acceptEdits"`,
/// `"bypassPermissions"`, `"plan"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum PermissionMode {
    Default,
    AcceptEdits,
    BypassPermissions,
    Plan,
}
579
/// Contents of the temporary file passed to the CLI via `--mcp-config`.
///
/// It serializes as `{"mcpServers": {<name>: <config>, ...}}`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct McpConfig {
    mcp_servers: HashMap<String, McpServerConfig>,
}
585
/// How the CLI should launch one MCP server: an executable, its arguments,
/// and optional extra environment variables (omitted from the JSON when
/// `None`).
///
/// `env` may hold credentials, so avoid logging this type wholesale.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct McpServerConfig {
    command: String,
    args: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    env: Option<HashMap<String, String>>,
}
594
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use serde_json::json;

    crate::common_e2e_tests!(ClaudeCode);

    /// Command that runs a `claude` binary found on `PATH`, with no extra
    /// arguments or environment.
    pub fn local_command() -> AgentServerCommand {
        AgentServerCommand {
            path: "claude".into(),
            args: Vec::new(),
            env: None,
        }
    }

    #[test]
    fn test_deserialize_content_untagged_text() {
        let content: Content = serde_json::from_value(json!("Hello, world!")).unwrap();
        let Content::UntaggedText(text) = content else {
            panic!("Expected UntaggedText variant");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_deserialize_content_chunks() {
        let raw = json!([
            {
                "type": "text",
                "text": "Hello"
            },
            {
                "type": "tool_use",
                "id": "tool_123",
                "name": "calculator",
                "input": {"operation": "add", "a": 1, "b": 2}
            }
        ]);
        let content: Content = serde_json::from_value(raw).unwrap();
        let Content::Chunks(chunks) = content else {
            panic!("Expected Chunks variant");
        };
        assert_eq!(chunks.len(), 2);

        let ContentChunk::Text { text } = &chunks[0] else {
            panic!("Expected Text chunk");
        };
        assert_eq!(text, "Hello");

        let ContentChunk::ToolUse { id, name, input } = &chunks[1] else {
            panic!("Expected ToolUse chunk");
        };
        assert_eq!(id, "tool_123");
        assert_eq!(name, "calculator");
        assert_eq!(input["operation"], "add");
        assert_eq!(input["a"], 1);
        assert_eq!(input["b"], 2);
    }

    #[test]
    fn test_deserialize_tool_result_untagged_text() {
        let raw = json!({
            "type": "tool_result",
            "content": "Result content",
            "tool_use_id": "tool_456"
        });
        let chunk: ContentChunk = serde_json::from_value(raw).unwrap();
        let ContentChunk::ToolResult {
            content,
            tool_use_id,
        } = chunk
        else {
            panic!("Expected ToolResult variant");
        };
        let Content::UntaggedText(text) = content else {
            panic!("Expected UntaggedText content");
        };
        assert_eq!(text, "Result content");
        assert_eq!(tool_use_id, "tool_456");
    }

    #[test]
    fn test_deserialize_tool_result_chunks() {
        let raw = json!({
            "type": "tool_result",
            "content": [
                {
                    "type": "text",
                    "text": "Processing complete"
                },
                {
                    "type": "text",
                    "text": "Result: 42"
                }
            ],
            "tool_use_id": "tool_789"
        });
        let chunk: ContentChunk = serde_json::from_value(raw).unwrap();
        let ContentChunk::ToolResult {
            content,
            tool_use_id,
        } = chunk
        else {
            panic!("Expected ToolResult variant");
        };
        let Content::Chunks(chunks) = content else {
            panic!("Expected Chunks content");
        };
        assert_eq!(chunks.len(), 2);

        let ContentChunk::Text { text } = &chunks[0] else {
            panic!("Expected Text chunk");
        };
        assert_eq!(text, "Processing complete");

        let ContentChunk::Text { text } = &chunks[1] else {
            panic!("Expected Text chunk");
        };
        assert_eq!(text, "Result: 42");

        assert_eq!(tool_use_id, "tool_789");
    }
}