claude.rs

   1mod edit_tool;
   2mod mcp_server;
   3mod permission_tool;
   4mod read_tool;
   5pub mod tools;
   6mod write_tool;
   7
   8use action_log::ActionLog;
   9use collections::HashMap;
  10use context_server::listener::McpServerTool;
  11use language_models::provider::anthropic::AnthropicLanguageModelProvider;
  12use project::Project;
  13use settings::SettingsStore;
  14use smol::process::Child;
  15use std::any::Any;
  16use std::cell::RefCell;
  17use std::fmt::Display;
  18use std::path::{Path, PathBuf};
  19use std::rc::Rc;
  20use util::command::new_smol_command;
  21use uuid::Uuid;
  22
  23use agent_client_protocol as acp;
  24use anyhow::{Context as _, Result, anyhow};
  25use futures::channel::oneshot;
  26use futures::{AsyncBufReadExt, AsyncWriteExt};
  27use futures::{
  28    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  29    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  30    io::BufReader,
  31    select_biased,
  32};
  33use gpui::{App, AppContext, AsyncApp, Entity, SharedString, Task, WeakEntity};
  34use serde::{Deserialize, Serialize};
  35use util::{ResultExt, debug_panic};
  36
  37use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  38use crate::claude::tools::ClaudeTool;
  39use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  40use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError, MentionUri};
  41
  42#[derive(Clone)]
  43pub struct ClaudeCode;
  44
  45impl AgentServer for ClaudeCode {
  46    fn name(&self) -> SharedString {
  47        "Claude Code".into()
  48    }
  49
  50    fn empty_state_headline(&self) -> SharedString {
  51        self.name()
  52    }
  53
  54    fn empty_state_message(&self) -> SharedString {
  55        "How can I help you today?".into()
  56    }
  57
  58    fn logo(&self) -> ui::IconName {
  59        ui::IconName::AiClaude
  60    }
  61
  62    fn connect(
  63        &self,
  64        _root_dir: &Path,
  65        _project: &Entity<Project>,
  66        _cx: &mut App,
  67    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  68        let connection = ClaudeAgentConnection {
  69            sessions: Default::default(),
  70        };
  71
  72        Task::ready(Ok(Rc::new(connection) as _))
  73    }
  74
  75    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
  76        self
  77    }
  78}
  79
/// Connection to Claude Code; each thread created by `new_thread` gets its own
/// CLI process and session entry.
struct ClaudeAgentConnection {
    /// Live sessions keyed by ACP session id. Populated by `new_thread` and
    /// looked up by `prompt` and `cancel`.
    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
}
  83
  84impl AgentConnection for ClaudeAgentConnection {
  85    fn new_thread(
  86        self: Rc<Self>,
  87        project: Entity<Project>,
  88        cwd: &Path,
  89        cx: &mut App,
  90    ) -> Task<Result<Entity<AcpThread>>> {
  91        let cwd = cwd.to_owned();
  92        cx.spawn(async move |cx| {
  93            let settings = cx.read_global(|settings: &SettingsStore, _| {
  94                settings.get::<AllAgentServersSettings>(None).claude.clone()
  95            })?;
  96
  97            let Some(command) = AgentServerCommand::resolve(
  98                "claude",
  99                &[],
 100                Some(&util::paths::home_dir().join(".claude/local/claude")),
 101                settings,
 102                &project,
 103                cx,
 104            )
 105            .await
 106            else {
 107                return Err(LoadError::NotInstalled {
 108                    error_message: "Failed to find Claude Code binary".into(),
 109                    install_message: "Install Claude Code".into(),
 110                    install_command: "npm install -g @anthropic-ai/claude-code@latest".into(),
 111                }.into());
 112            };
 113
 114            let api_key =
 115                cx.update(AnthropicLanguageModelProvider::api_key)?
 116                    .await
 117                    .map_err(|err| {
 118                        if err.is::<language_model::AuthenticateError>() {
 119                            anyhow!(AuthRequired::new().with_language_model_provider(
 120                                language_model::ANTHROPIC_PROVIDER_ID
 121                            ))
 122                        } else {
 123                            anyhow!(err)
 124                        }
 125                    })?;
 126
 127            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
 128            let fs = project.read_with(cx, |project, _cx| project.fs().clone())?;
 129            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), fs, cx).await?;
 130
 131            let mut mcp_servers = HashMap::default();
 132            mcp_servers.insert(
 133                mcp_server::SERVER_NAME.to_string(),
 134                permission_mcp_server.server_config()?,
 135            );
 136            let mcp_config = McpConfig { mcp_servers };
 137
 138            let mcp_config_file = tempfile::NamedTempFile::new()?;
 139            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
 140
 141            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
 142            mcp_config_file
 143                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
 144                .await?;
 145            mcp_config_file.flush().await?;
 146
 147            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
 148            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 149
 150            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
 151
 152            log::trace!("Starting session with id: {}", session_id);
 153
 154            let mut child = spawn_claude(
 155                &command,
 156                ClaudeSessionMode::Start,
 157                session_id.clone(),
 158                api_key,
 159                &mcp_config_path,
 160                &cwd,
 161            )?;
 162
 163            let stdout = child.stdout.take().context("Failed to take stdout")?;
 164            let stdin = child.stdin.take().context("Failed to take stdin")?;
 165            let stderr = child.stderr.take().context("Failed to take stderr")?;
 166
 167            let pid = child.id();
 168            log::trace!("Spawned (pid: {})", pid);
 169
 170            cx.background_spawn(async move {
 171                let mut stderr = BufReader::new(stderr);
 172                let mut line = String::new();
 173                while let Ok(n) = stderr.read_line(&mut line).await
 174                    && n > 0
 175                {
 176                    log::warn!("agent stderr: {}", &line);
 177                    line.clear();
 178                }
 179            })
 180            .detach();
 181
 182            cx.background_spawn(async move {
 183                let mut outgoing_rx = Some(outgoing_rx);
 184
 185                ClaudeAgentSession::handle_io(
 186                    outgoing_rx.take().unwrap(),
 187                    incoming_message_tx.clone(),
 188                    stdin,
 189                    stdout,
 190                )
 191                .await?;
 192
 193                log::trace!("Stopped (pid: {})", pid);
 194
 195                drop(mcp_config_path);
 196                anyhow::Ok(())
 197            })
 198            .detach();
 199
 200            let turn_state = Rc::new(RefCell::new(TurnState::None));
 201
 202            let handler_task = cx.spawn({
 203                let turn_state = turn_state.clone();
 204                let mut thread_rx = thread_rx.clone();
 205                async move |cx| {
 206                    while let Some(message) = incoming_message_rx.next().await {
 207                        ClaudeAgentSession::handle_message(
 208                            thread_rx.clone(),
 209                            message,
 210                            turn_state.clone(),
 211                            cx,
 212                        )
 213                        .await
 214                    }
 215
 216                    if let Some(status) = child.status().await.log_err()
 217                        && let Some(thread) = thread_rx.recv().await.ok()
 218                    {
 219                        let version = claude_version(command.path.clone(), cx).await.log_err();
 220                        let help = claude_help(command.path.clone(), cx).await.log_err();
 221                        thread
 222                            .update(cx, |thread, cx| {
 223                                let error = if let Some(version) = version
 224                                    && let Some(help) = help
 225                                    && (!help.contains("--input-format")
 226                                        || !help.contains("--session-id"))
 227                                {
 228                                    LoadError::Unsupported {
 229                                    error_message: format!(
 230                                            "Your installed version of Claude Code ({}, version {}) does not have required features for use with Zed.",
 231                                            command.path.to_string_lossy(),
 232                                            version,
 233                                        )
 234                                        .into(),
 235                                        upgrade_message: "Upgrade Claude Code to latest".into(),
 236                                        upgrade_command: format!(
 237                                            "{} update",
 238                                            command.path.to_string_lossy()
 239                                        ),
 240                                    }
 241                                } else {
 242                                    LoadError::Exited { status }
 243                                };
 244                                thread.emit_load_error(error, cx);
 245                            })
 246                            .ok();
 247                    }
 248                }
 249            });
 250
 251            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
 252            let thread = cx.new(|cx| {
 253                AcpThread::new(
 254                    "Claude Code",
 255                    self.clone(),
 256                    project,
 257                    action_log,
 258                    session_id.clone(),
 259                    watch::Receiver::constant(acp::PromptCapabilities {
 260                        image: true,
 261                        audio: false,
 262                        embedded_context: true,
 263                    }),
 264                    cx,
 265                )
 266            })?;
 267
 268            thread_tx.send(thread.downgrade())?;
 269
 270            let session = ClaudeAgentSession {
 271                outgoing_tx,
 272                turn_state,
 273                _handler_task: handler_task,
 274                _mcp_server: Some(permission_mcp_server),
 275            };
 276
 277            self.sessions.borrow_mut().insert(session_id, session);
 278
 279            Ok(thread)
 280        })
 281    }
 282
 283    fn auth_methods(&self) -> &[acp::AuthMethod] {
 284        &[]
 285    }
 286
 287    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 288        Task::ready(Err(anyhow!("Authentication not supported")))
 289    }
 290
 291    fn prompt(
 292        &self,
 293        _id: Option<acp_thread::UserMessageId>,
 294        params: acp::PromptRequest,
 295        cx: &mut App,
 296    ) -> Task<Result<acp::PromptResponse>> {
 297        let sessions = self.sessions.borrow();
 298        let Some(session) = sessions.get(&params.session_id) else {
 299            return Task::ready(Err(anyhow!(
 300                "Attempted to send message to nonexistent session {}",
 301                params.session_id
 302            )));
 303        };
 304
 305        let (end_tx, end_rx) = oneshot::channel();
 306        session.turn_state.replace(TurnState::InProgress { end_tx });
 307
 308        let content = acp_content_to_claude(params.prompt);
 309
 310        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 311            message: Message {
 312                role: Role::User,
 313                content: Content::Chunks(content),
 314                id: None,
 315                model: None,
 316                stop_reason: None,
 317                stop_sequence: None,
 318                usage: None,
 319            },
 320            session_id: Some(params.session_id.to_string()),
 321        }) {
 322            return Task::ready(Err(anyhow!(err)));
 323        }
 324
 325        cx.foreground_executor().spawn(async move { end_rx.await? })
 326    }
 327
 328    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 329        let sessions = self.sessions.borrow();
 330        let Some(session) = sessions.get(session_id) else {
 331            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 332            return;
 333        };
 334
 335        let request_id = new_request_id();
 336
 337        let turn_state = session.turn_state.take();
 338        let TurnState::InProgress { end_tx } = turn_state else {
 339            // Already canceled or idle, put it back
 340            session.turn_state.replace(turn_state);
 341            return;
 342        };
 343
 344        session.turn_state.replace(TurnState::CancelRequested {
 345            end_tx,
 346            request_id: request_id.clone(),
 347        });
 348
 349        session
 350            .outgoing_tx
 351            .unbounded_send(SdkMessage::ControlRequest {
 352                request_id,
 353                request: ControlRequest::Interrupt,
 354            })
 355            .log_err();
 356    }
 357
 358    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 359        self
 360    }
 361}
 362
/// How `spawn_claude` attaches the CLI to a session id.
#[derive(Clone, Copy)]
enum ClaudeSessionMode {
    /// Start a new session (passes `--session-id <id>`).
    Start,
    /// Resume an existing session (passes `--resume <id>`); currently unused.
    #[expect(dead_code)]
    Resume,
}
 369
 370fn spawn_claude(
 371    command: &AgentServerCommand,
 372    mode: ClaudeSessionMode,
 373    session_id: acp::SessionId,
 374    api_key: language_models::provider::anthropic::ApiKey,
 375    mcp_config_path: &Path,
 376    root_dir: &Path,
 377) -> Result<Child> {
 378    let child = util::command::new_smol_command(&command.path)
 379        .args([
 380            "--input-format",
 381            "stream-json",
 382            "--output-format",
 383            "stream-json",
 384            "--print",
 385            "--verbose",
 386            "--mcp-config",
 387            mcp_config_path.to_string_lossy().as_ref(),
 388            "--permission-prompt-tool",
 389            &format!(
 390                "mcp__{}__{}",
 391                mcp_server::SERVER_NAME,
 392                permission_tool::PermissionTool::NAME,
 393            ),
 394            "--allowedTools",
 395            &format!(
 396                "mcp__{}__{}",
 397                mcp_server::SERVER_NAME,
 398                read_tool::ReadTool::NAME
 399            ),
 400            "--disallowedTools",
 401            "Read,Write,Edit,MultiEdit",
 402        ])
 403        .args(match mode {
 404            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 405            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 406        })
 407        .args(command.args.iter().map(|arg| arg.as_str()))
 408        .envs(command.env.iter().flatten())
 409        .env("ANTHROPIC_API_KEY", api_key.key)
 410        .current_dir(root_dir)
 411        .stdin(std::process::Stdio::piped())
 412        .stdout(std::process::Stdio::piped())
 413        .stderr(std::process::Stdio::piped())
 414        .kill_on_drop(true)
 415        .spawn()?;
 416
 417    Ok(child)
 418}
 419
 420fn claude_version(path: PathBuf, cx: &mut AsyncApp) -> Task<Result<semver::Version>> {
 421    cx.background_spawn(async move {
 422        let output = new_smol_command(path).arg("--version").output().await?;
 423        let output = String::from_utf8(output.stdout)?;
 424        let version = output
 425            .trim()
 426            .strip_suffix(" (Claude Code)")
 427            .context("parsing Claude version")?;
 428        let version = semver::Version::parse(version)?;
 429        anyhow::Ok(version)
 430    })
 431}
 432
 433fn claude_help(path: PathBuf, cx: &mut AsyncApp) -> Task<Result<String>> {
 434    cx.background_spawn(async move {
 435        let output = new_smol_command(path).arg("--help").output().await?;
 436        let output = String::from_utf8(output.stdout)?;
 437        anyhow::Ok(output)
 438    })
 439}
 440
/// Per-thread state for one running Claude Code process.
struct ClaudeAgentSession {
    /// Messages sent here are serialized by the IO task and written to the process's stdin.
    outgoing_tx: UnboundedSender<SdkMessage>,
    /// State of the current prompt turn, shared with the message handler task.
    turn_state: Rc<RefCell<TurnState>>,
    /// Zed's MCP server for this session, held for the session's lifetime.
    _mcp_server: Option<ClaudeZedMcpServer>,
    /// Task applying incoming SDK messages to the thread; held so it lives as long
    /// as the session (NOTE(review): presumably dropping it cancels the task — confirm).
    _handler_task: Task<()>,
}
 447
/// Lifecycle of a single prompt turn.
///
/// Every non-idle variant carries the sender that resolves the pending `prompt` task.
#[derive(Debug, Default)]
enum TurnState {
    /// No turn in progress.
    #[default]
    None,
    /// A prompt was sent and its `Result` message hasn't arrived yet.
    InProgress {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
    },
    /// `cancel` sent an interrupt control request with `request_id` and is waiting
    /// for a matching successful control response.
    CancelRequested {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
        request_id: String,
    },
    /// The agent acknowledged the interrupt. While in this state the message
    /// handler skips user content and doesn't mark tool results completed.
    CancelConfirmed {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
    },
}
 463
 464impl TurnState {
 465    fn is_canceled(&self) -> bool {
 466        matches!(self, TurnState::CancelConfirmed { .. })
 467    }
 468
 469    fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
 470        match self {
 471            TurnState::None => None,
 472            TurnState::InProgress { end_tx, .. } => Some(end_tx),
 473            TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
 474            TurnState::CancelConfirmed { end_tx } => Some(end_tx),
 475        }
 476    }
 477
 478    fn confirm_cancellation(self, id: &str) -> Self {
 479        match self {
 480            TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
 481                TurnState::CancelConfirmed { end_tx }
 482            }
 483            _ => self,
 484        }
 485    }
 486}
 487
 488impl ClaudeAgentSession {
 489    async fn handle_message(
 490        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 491        message: SdkMessage,
 492        turn_state: Rc<RefCell<TurnState>>,
 493        cx: &mut AsyncApp,
 494    ) {
 495        match message {
 496            // we should only be sending these out, they don't need to be in the thread
 497            SdkMessage::ControlRequest { .. } => {}
 498            SdkMessage::User {
 499                message,
 500                session_id: _,
 501            } => {
 502                let Some(thread) = thread_rx
 503                    .recv()
 504                    .await
 505                    .log_err()
 506                    .and_then(|entity| entity.upgrade())
 507                else {
 508                    log::error!("Received an SDK message but thread is gone");
 509                    return;
 510                };
 511
 512                for chunk in message.content.chunks() {
 513                    match chunk {
 514                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 515                            if !turn_state.borrow().is_canceled() {
 516                                thread
 517                                    .update(cx, |thread, cx| {
 518                                        thread.push_user_content_block(None, text.into(), cx)
 519                                    })
 520                                    .log_err();
 521                            }
 522                        }
 523                        ContentChunk::ToolResult {
 524                            content,
 525                            tool_use_id,
 526                        } => {
 527                            let content = content.to_string();
 528                            thread
 529                                .update(cx, |thread, cx| {
 530                                    let id = acp::ToolCallId(tool_use_id.into());
 531                                    let set_new_content = !content.is_empty()
 532                                        && thread.tool_call(&id).is_none_or(|(_, tool_call)| {
 533                                            // preserve rich diff if we have one
 534                                            tool_call.diffs().next().is_none()
 535                                        });
 536
 537                                    thread.update_tool_call(
 538                                        acp::ToolCallUpdate {
 539                                            id,
 540                                            fields: acp::ToolCallUpdateFields {
 541                                                status: if turn_state.borrow().is_canceled() {
 542                                                    // Do not set to completed if turn was canceled
 543                                                    None
 544                                                } else {
 545                                                    Some(acp::ToolCallStatus::Completed)
 546                                                },
 547                                                content: set_new_content
 548                                                    .then(|| vec![content.into()]),
 549                                                ..Default::default()
 550                                            },
 551                                        },
 552                                        cx,
 553                                    )
 554                                })
 555                                .log_err();
 556                        }
 557                        ContentChunk::Thinking { .. }
 558                        | ContentChunk::RedactedThinking
 559                        | ContentChunk::ToolUse { .. } => {
 560                            debug_panic!(
 561                                "Should not get {:?} with role: assistant. should we handle this?",
 562                                chunk
 563                            );
 564                        }
 565                        ContentChunk::Image { source } => {
 566                            if !turn_state.borrow().is_canceled() {
 567                                thread
 568                                    .update(cx, |thread, cx| {
 569                                        thread.push_user_content_block(None, source.into(), cx)
 570                                    })
 571                                    .log_err();
 572                            }
 573                        }
 574
 575                        ContentChunk::Document | ContentChunk::WebSearchToolResult => {
 576                            thread
 577                                .update(cx, |thread, cx| {
 578                                    thread.push_assistant_content_block(
 579                                        format!("Unsupported content: {:?}", chunk).into(),
 580                                        false,
 581                                        cx,
 582                                    )
 583                                })
 584                                .log_err();
 585                        }
 586                    }
 587                }
 588            }
 589            SdkMessage::Assistant {
 590                message,
 591                session_id: _,
 592            } => {
 593                let Some(thread) = thread_rx
 594                    .recv()
 595                    .await
 596                    .log_err()
 597                    .and_then(|entity| entity.upgrade())
 598                else {
 599                    log::error!("Received an SDK message but thread is gone");
 600                    return;
 601                };
 602
 603                for chunk in message.content.chunks() {
 604                    match chunk {
 605                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 606                            thread
 607                                .update(cx, |thread, cx| {
 608                                    thread.push_assistant_content_block(text.into(), false, cx)
 609                                })
 610                                .log_err();
 611                        }
 612                        ContentChunk::Thinking { thinking } => {
 613                            thread
 614                                .update(cx, |thread, cx| {
 615                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 616                                })
 617                                .log_err();
 618                        }
 619                        ContentChunk::RedactedThinking => {
 620                            thread
 621                                .update(cx, |thread, cx| {
 622                                    thread.push_assistant_content_block(
 623                                        "[REDACTED]".into(),
 624                                        true,
 625                                        cx,
 626                                    )
 627                                })
 628                                .log_err();
 629                        }
 630                        ContentChunk::ToolUse { id, name, input } => {
 631                            let claude_tool = ClaudeTool::infer(&name, input);
 632
 633                            thread
 634                                .update(cx, |thread, cx| {
 635                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 636                                        thread.update_plan(
 637                                            acp::Plan {
 638                                                entries: params
 639                                                    .todos
 640                                                    .into_iter()
 641                                                    .map(Into::into)
 642                                                    .collect(),
 643                                            },
 644                                            cx,
 645                                        )
 646                                    } else {
 647                                        thread.upsert_tool_call(
 648                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 649                                            cx,
 650                                        )?;
 651                                    }
 652                                    anyhow::Ok(())
 653                                })
 654                                .log_err();
 655                        }
 656                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 657                            debug_panic!(
 658                                "Should not get tool results with role: assistant. should we handle this?"
 659                            );
 660                        }
 661                        ContentChunk::Image { source } => {
 662                            thread
 663                                .update(cx, |thread, cx| {
 664                                    thread.push_assistant_content_block(source.into(), false, cx)
 665                                })
 666                                .log_err();
 667                        }
 668                        ContentChunk::Document => {
 669                            thread
 670                                .update(cx, |thread, cx| {
 671                                    thread.push_assistant_content_block(
 672                                        format!("Unsupported content: {:?}", chunk).into(),
 673                                        false,
 674                                        cx,
 675                                    )
 676                                })
 677                                .log_err();
 678                        }
 679                    }
 680                }
 681            }
 682            SdkMessage::Result {
 683                is_error,
 684                subtype,
 685                result,
 686                ..
 687            } => {
 688                let turn_state = turn_state.take();
 689                let was_canceled = turn_state.is_canceled();
 690                let Some(end_turn_tx) = turn_state.end_tx() else {
 691                    debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
 692                    return;
 693                };
 694
 695                if is_error || (!was_canceled && subtype == ResultErrorType::ErrorDuringExecution) {
 696                    end_turn_tx
 697                        .send(Err(anyhow!(
 698                            "Error: {}",
 699                            result.unwrap_or_else(|| subtype.to_string())
 700                        )))
 701                        .ok();
 702                } else {
 703                    let stop_reason = match subtype {
 704                        ResultErrorType::Success => acp::StopReason::EndTurn,
 705                        ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 706                        ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
 707                    };
 708                    end_turn_tx
 709                        .send(Ok(acp::PromptResponse { stop_reason }))
 710                        .ok();
 711                }
 712            }
 713            SdkMessage::ControlResponse { response } => {
 714                if matches!(response.subtype, ResultErrorType::Success) {
 715                    let new_state = turn_state.take().confirm_cancellation(&response.request_id);
 716                    turn_state.replace(new_state);
 717                } else {
 718                    log::error!("Control response error: {:?}", response);
 719                }
 720            }
 721            SdkMessage::System { .. } => {}
 722        }
 723    }
 724
 725    async fn handle_io(
 726        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 727        incoming_tx: UnboundedSender<SdkMessage>,
 728        mut outgoing_bytes: impl Unpin + AsyncWrite,
 729        incoming_bytes: impl Unpin + AsyncRead,
 730    ) -> Result<UnboundedReceiver<SdkMessage>> {
 731        let mut output_reader = BufReader::new(incoming_bytes);
 732        let mut outgoing_line = Vec::new();
 733        let mut incoming_line = String::new();
 734        loop {
 735            select_biased! {
 736                message = outgoing_rx.next() => {
 737                    if let Some(message) = message {
 738                        outgoing_line.clear();
 739                        serde_json::to_writer(&mut outgoing_line, &message)?;
 740                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 741                        outgoing_line.push(b'\n');
 742                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 743                    } else {
 744                        break;
 745                    }
 746                }
 747                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
 748                    if bytes_read? == 0 {
 749                        break
 750                    }
 751                    log::trace!("recv: {}", &incoming_line);
 752                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 753                        Ok(message) => {
 754                            incoming_tx.unbounded_send(message).log_err();
 755                        }
 756                        Err(error) => {
 757                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 758                        }
 759                    }
 760                    incoming_line.clear();
 761                }
 762            }
 763        }
 764
 765        Ok(outgoing_rx)
 766    }
 767}
 768
/// A chat message as carried inside the `assistant` and `user` SDK messages (its shape
/// comes from the Anthropic SDK).
///
/// Every optional field is omitted from the serialized JSON when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    /// Who authored the message.
    role: Role,
    /// The message body: a bare string or a list of typed chunks.
    content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequence: Option<String>,
    /// Token accounting for the message, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Usage>,
}
 784
/// Message content, which arrives either as a bare JSON string or as a list of typed chunks.
///
/// Because the enum is `#[serde(untagged)]`, serde tries `UntaggedText` first and falls back
/// to `Chunks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Content {
    /// Content given as a plain JSON string.
    UntaggedText(String),
    /// Content given as a JSON array of [`ContentChunk`]s.
    Chunks(Vec<ContentChunk>),
}
 791
 792impl Content {
 793    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 794        match self {
 795            Self::Chunks(chunks) => chunks.into_iter(),
 796            Self::UntaggedText(text) => vec![ContentChunk::Text { text }].into_iter(),
 797        }
 798    }
 799}
 800
 801impl Display for Content {
 802    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 803        match self {
 804            Content::UntaggedText(txt) => write!(f, "{}", txt),
 805            Content::Chunks(chunks) => {
 806                for chunk in chunks {
 807                    write!(f, "{}", chunk)?;
 808                }
 809                Ok(())
 810            }
 811        }
 812    }
 813}
 814
/// A single content block, discriminated by its `type` field (snake_case variant names, e.g.
/// `"tool_use"`, `"redacted_thinking"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ContentChunk {
    /// Plain text.
    Text {
        text: String,
    },
    /// A tool invocation requested by the model; `input` is kept as raw JSON.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Output of a tool call. `tool_use_id` presumably matches the `id` of the corresponding
    /// `ToolUse` chunk.
    ToolResult {
        content: Content,
        tool_use_id: String,
    },
    /// Model reasoning text.
    Thinking {
        thinking: String,
    },
    /// Reasoning whose text is not exposed; no fields are modeled for it here.
    RedactedThinking,
    /// An image and where its bytes come from.
    Image {
        source: ImageSource,
    },
    // TODO: `Document` and `WebSearchToolResult` are recognized by their tag only; their
    // payloads are not modeled yet.
    Document,
    WebSearchToolResult,
    /// Fallback for chunks that arrive as bare JSON strings rather than tagged objects.
    #[serde(untagged)]
    UntaggedText(String),
}
 843
/// Where an image's bytes come from, discriminated by its `type` field (`"base64"` or
/// `"url"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ImageSource {
    /// Inline image data together with its media (MIME) type.
    Base64 { data: String, media_type: String },
    /// An image referenced by URL.
    Url { url: String },
}
 850
 851impl Into<acp::ContentBlock> for ImageSource {
 852    fn into(self) -> acp::ContentBlock {
 853        match self {
 854            ImageSource::Base64 { data, media_type } => {
 855                acp::ContentBlock::Image(acp::ImageContent {
 856                    annotations: None,
 857                    data,
 858                    mime_type: media_type,
 859                    uri: None,
 860                })
 861            }
 862            ImageSource::Url { url } => acp::ContentBlock::Image(acp::ImageContent {
 863                annotations: None,
 864                data: "".to_string(),
 865                mime_type: "".to_string(),
 866                uri: Some(url),
 867            }),
 868        }
 869    }
 870}
 871
 872impl Display for ContentChunk {
 873    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 874        match self {
 875            ContentChunk::Text { text } => write!(f, "{}", text),
 876            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 877            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 878            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 879            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 880            ContentChunk::Image { .. }
 881            | ContentChunk::Document
 882            | ContentChunk::ToolUse { .. }
 883            | ContentChunk::WebSearchToolResult => {
 884                write!(f, "\n{:?}\n", &self)
 885            }
 886        }
 887    }
 888}
 889
 890#[derive(Debug, Clone, Serialize, Deserialize)]
 891struct Usage {
 892    input_tokens: u32,
 893    cache_creation_input_tokens: u32,
 894    cache_read_input_tokens: u32,
 895    output_tokens: u32,
 896    service_tier: String,
 897}
 898
/// The author of a message; serialized in snake_case (`"system"`, `"assistant"`, `"user"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}
 906
/// A plain role/text message pair.
// NOTE(review): not referenced in this part of the file — confirm it is still used.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    role: Role,
    content: String,
}
 912
/// A message exchanged with the Claude Code process. Each one is written or read as a single
/// line of JSON, discriminated by its `type` field (snake_case variant name).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    /// An assistant message.
    Assistant {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// A user message.
    User {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// Emitted as the last message in a conversation. `subtype` reports success or which
    /// kind of failure ended the turn.
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    /// Emitted as the first message at the start of a conversation. Note that
    /// `apiKeySource` and `permissionMode` are camelCase on the wire.
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to the model.
    /// `request_id` presumably gets echoed in the matching `ControlResponse` — confirm.
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request
    ControlResponse { response: ControlResponse },
}
 960
/// Payload of an `SdkMessage::ControlRequest`, discriminated by a `subtype` field in
/// snake_case (e.g. `{"subtype": "interrupt"}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation
    Interrupt,
}
 967
/// Payload of an `SdkMessage::ControlResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
    /// Id of the control request being answered. On success, the session uses it to
    /// confirm a pending cancellation.
    request_id: String,
    /// Whether the request succeeded.
    subtype: ResultErrorType,
}
 973
/// Outcome subtype used by `SdkMessage::Result` and by control responses. Serialized in
/// snake_case: `"success"`, `"error_max_turns"`, `"error_during_execution"`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
 981
 982impl Display for ResultErrorType {
 983    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 984        match self {
 985            ResultErrorType::Success => write!(f, "success"),
 986            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 987            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 988        }
 989    }
 990}
 991
 992fn acp_content_to_claude(prompt: Vec<acp::ContentBlock>) -> Vec<ContentChunk> {
 993    let mut content = Vec::with_capacity(prompt.len());
 994    let mut context = Vec::with_capacity(prompt.len());
 995
 996    for chunk in prompt {
 997        match chunk {
 998            acp::ContentBlock::Text(text_content) => {
 999                content.push(ContentChunk::Text {
1000                    text: text_content.text,
1001                });
1002            }
1003            acp::ContentBlock::ResourceLink(resource_link) => {
1004                match MentionUri::parse(&resource_link.uri) {
1005                    Ok(uri) => {
1006                        content.push(ContentChunk::Text {
1007                            text: format!("{}", uri.as_link()),
1008                        });
1009                    }
1010                    Err(_) => {
1011                        content.push(ContentChunk::Text {
1012                            text: resource_link.uri,
1013                        });
1014                    }
1015                }
1016            }
1017            acp::ContentBlock::Resource(resource) => match resource.resource {
1018                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1019                    match MentionUri::parse(&resource.uri) {
1020                        Ok(uri) => {
1021                            content.push(ContentChunk::Text {
1022                                text: format!("{}", uri.as_link()),
1023                            });
1024                        }
1025                        Err(_) => {
1026                            content.push(ContentChunk::Text {
1027                                text: resource.uri.clone(),
1028                            });
1029                        }
1030                    }
1031
1032                    context.push(ContentChunk::Text {
1033                        text: format!(
1034                            "\n<context ref=\"{}\">\n{}\n</context>",
1035                            resource.uri, resource.text
1036                        ),
1037                    });
1038                }
1039                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1040                    // Unsupported by SDK
1041                }
1042            },
1043            acp::ContentBlock::Image(acp::ImageContent {
1044                data, mime_type, ..
1045            }) => content.push(ContentChunk::Image {
1046                source: ImageSource::Base64 {
1047                    data,
1048                    media_type: mime_type,
1049                },
1050            }),
1051            acp::ContentBlock::Audio(_) => {
1052                // Unsupported by SDK
1053            }
1054        }
1055    }
1056
1057    content.extend(context);
1058    content
1059}
1060
1061fn new_request_id() -> String {
1062    use rand::Rng;
1063    // In the Claude Code TS SDK they just generate a random 12 character string,
1064    // `Math.random().toString(36).substring(2, 15)`
1065    rand::thread_rng()
1066        .sample_iter(&rand::distributions::Alphanumeric)
1067        .take(12)
1068        .map(char::from)
1069        .collect()
1070}
1071
/// An MCP server entry, as listed in the `mcp_servers` field of the `System` message.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    name: String,
    /// Server status, kept as a free-form string.
    status: String,
}
1077
/// The permission mode reported in the `System` message. Variants serialize in camelCase:
/// `"default"`, `"acceptEdits"`, `"bypassPermissions"`, `"plan"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum PermissionMode {
    Default,
    AcceptEdits,
    BypassPermissions,
    Plan,
}
1086
// Tests for the Claude Code agent server. They cover serde parsing of the SDK content types,
// ACP -> Claude prompt conversion, and end-to-end scenarios that drive a real `claude`
// process (`test_todo_plan` is ignored unless the `e2e` feature is enabled).
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::e2e_tests;
    use gpui::TestAppContext;
    use serde_json::json;

    // Instantiates the shared end-to-end suite for Claude Code. `allow_option_id` presumably
    // names the permission option the suite selects to approve tool calls — confirm in
    // `common_e2e_tests!`.
    crate::common_e2e_tests!(async |_, _, _| ClaudeCode, allow_option_id = "allow");

    /// Command for a locally installed Claude Code: a bare `claude` path (presumably
    /// resolved via `PATH`) with no extra arguments or environment.
    pub fn local_command() -> AgentServerCommand {
        AgentServerCommand {
            path: "claude".into(),
            args: vec![],
            env: None,
        }
    }

    // Asks Claude for a todo plan, then asks it to mark the first entry in progress and then
    // completed. Each status change is checked, and the number of plan entries must stay
    // the same throughout.
    #[gpui::test]
    #[cfg_attr(not(feature = "e2e"), ignore)]
    async fn test_todo_plan(cx: &mut TestAppContext) {
        let fs = e2e_tests::init_test(cx).await;
        let project = Project::test(fs, [], cx).await;
        let thread =
            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        // Plan size after the first turn; later turns must not add or remove entries.
        let mut entries_len = 0;

        thread.read_with(cx, |thread, _| {
            entries_len = thread.plan().entries.len();
            assert!(!thread.plan().entries.is_empty(), "Empty plan");
        });

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Mark the first entry status as in progress without acting on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.plan().entries[0].status,
                acp::PlanEntryStatus::InProgress
            ));
            assert_eq!(thread.plan().entries.len(), entries_len);
        });

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Now mark the first entry as completed without acting on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.plan().entries[0].status,
                acp::PlanEntryStatus::Completed
            ));
            assert_eq!(thread.plan().entries.len(), entries_len);
        });
    }

    // A bare JSON string deserializes as `Content::UntaggedText`.
    #[test]
    fn test_deserialize_content_untagged_text() {
        let json = json!("Hello, world!");
        let content: Content = serde_json::from_value(json).unwrap();
        match content {
            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected UntaggedText variant"),
        }
    }

    // A JSON array of tagged objects deserializes as `Content::Chunks`, preserving order
    // and each chunk's fields (including raw JSON tool input).
    #[test]
    fn test_deserialize_content_chunks() {
        let json = json!([
            {
                "type": "text",
                "text": "Hello"
            },
            {
                "type": "tool_use",
                "id": "tool_123",
                "name": "calculator",
                "input": {"operation": "add", "a": 1, "b": 2}
            }
        ]);
        let content: Content = serde_json::from_value(json).unwrap();
        match content {
            Content::Chunks(chunks) => {
                assert_eq!(chunks.len(), 2);
                match &chunks[0] {
                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
                    _ => panic!("Expected Text chunk"),
                }
                match &chunks[1] {
                    ContentChunk::ToolUse { id, name, input } => {
                        assert_eq!(id, "tool_123");
                        assert_eq!(name, "calculator");
                        assert_eq!(input["operation"], "add");
                        assert_eq!(input["a"], 1);
                        assert_eq!(input["b"], 2);
                    }
                    _ => panic!("Expected ToolUse chunk"),
                }
            }
            _ => panic!("Expected Chunks variant"),
        }
    }

    // A `tool_result` whose `content` is a bare string yields `Content::UntaggedText`.
    #[test]
    fn test_deserialize_tool_result_untagged_text() {
        let json = json!({
            "type": "tool_result",
            "content": "Result content",
            "tool_use_id": "tool_456"
        });
        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
        match chunk {
            ContentChunk::ToolResult {
                content,
                tool_use_id,
            } => {
                match content {
                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
                    _ => panic!("Expected UntaggedText content"),
                }
                assert_eq!(tool_use_id, "tool_456");
            }
            _ => panic!("Expected ToolResult variant"),
        }
    }

    // A `tool_result` whose `content` is an array yields `Content::Chunks`.
    #[test]
    fn test_deserialize_tool_result_chunks() {
        let json = json!({
            "type": "tool_result",
            "content": [
                {
                    "type": "text",
                    "text": "Processing complete"
                },
                {
                    "type": "text",
                    "text": "Result: 42"
                }
            ],
            "tool_use_id": "tool_789"
        });
        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
        match chunk {
            ContentChunk::ToolResult {
                content,
                tool_use_id,
            } => {
                match content {
                    Content::Chunks(chunks) => {
                        assert_eq!(chunks.len(), 2);
                        match &chunks[0] {
                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
                            _ => panic!("Expected Text chunk"),
                        }
                        match &chunks[1] {
                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
                            _ => panic!("Expected Text chunk"),
                        }
                    }
                    _ => panic!("Expected Chunks content"),
                }
                assert_eq!(tool_use_id, "tool_789");
            }
            _ => panic!("Expected ToolResult variant"),
        }
    }

    // Converts one block of each supported ACP kind. Inline chunks keep their input order,
    // and the embedded resource's `<context>` body is appended last (index 5).
    #[test]
    fn test_acp_content_to_claude() {
        let acp_content = vec![
            acp::ContentBlock::Text(acp::TextContent {
                text: "Hello world".to_string(),
                annotations: None,
            }),
            acp::ContentBlock::Image(acp::ImageContent {
                data: "base64data".to_string(),
                mime_type: "image/png".to_string(),
                annotations: None,
                uri: None,
            }),
            acp::ContentBlock::ResourceLink(acp::ResourceLink {
                uri: "file:///path/to/example.rs".to_string(),
                name: "example.rs".to_string(),
                annotations: None,
                description: None,
                mime_type: None,
                size: None,
                title: None,
            }),
            acp::ContentBlock::Resource(acp::EmbeddedResource {
                annotations: None,
                resource: acp::EmbeddedResourceResource::TextResourceContents(
                    acp::TextResourceContents {
                        mime_type: None,
                        text: "fn main() { println!(\"Hello!\"); }".to_string(),
                        uri: "file:///path/to/code.rs".to_string(),
                    },
                ),
            }),
            acp::ContentBlock::ResourceLink(acp::ResourceLink {
                uri: "invalid_uri_format".to_string(),
                name: "invalid.txt".to_string(),
                annotations: None,
                description: None,
                mime_type: None,
                size: None,
                title: None,
            }),
        ];

        let claude_content = acp_content_to_claude(acp_content);

        // Five inline chunks plus one trailing context chunk for the embedded resource.
        assert_eq!(claude_content.len(), 6);

        match &claude_content[0] {
            ContentChunk::Text { text } => assert_eq!(text, "Hello world"),
            _ => panic!("Expected Text chunk"),
        }

        match &claude_content[1] {
            ContentChunk::Image { source } => match source {
                ImageSource::Base64 { data, media_type } => {
                    assert_eq!(data, "base64data");
                    assert_eq!(media_type, "image/png");
                }
                _ => panic!("Expected Base64 image source"),
            },
            _ => panic!("Expected Image chunk"),
        }

        // A parseable file URI is rendered as a mention link.
        match &claude_content[2] {
            ContentChunk::Text { text } => {
                assert!(text.contains("example.rs"));
                assert!(text.contains("file:///path/to/example.rs"));
            }
            _ => panic!("Expected Text chunk for ResourceLink"),
        }

        match &claude_content[3] {
            ContentChunk::Text { text } => {
                assert!(text.contains("code.rs"));
                assert!(text.contains("file:///path/to/code.rs"));
            }
            _ => panic!("Expected Text chunk for Resource"),
        }

        // An unparseable URI is passed through verbatim.
        match &claude_content[4] {
            ContentChunk::Text { text } => {
                assert_eq!(text, "invalid_uri_format");
            }
            _ => panic!("Expected Text chunk for invalid URI"),
        }

        match &claude_content[5] {
            ContentChunk::Text { text } => {
                assert!(text.contains("<context ref=\"file:///path/to/code.rs\">"));
                assert!(text.contains("fn main() { println!(\"Hello!\"); }"));
                assert!(text.contains("</context>"));
            }
            _ => panic!("Expected Text chunk for context"),
        }
    }
}