claude.rs

   1mod edit_tool;
   2mod mcp_server;
   3mod permission_tool;
   4mod read_tool;
   5pub mod tools;
   6mod write_tool;
   7
   8use action_log::ActionLog;
   9use collections::HashMap;
  10use context_server::listener::McpServerTool;
  11use language_models::provider::anthropic::AnthropicLanguageModelProvider;
  12use project::Project;
  13use settings::SettingsStore;
  14use smol::process::Child;
  15use std::any::Any;
  16use std::cell::RefCell;
  17use std::fmt::Display;
  18use std::path::{Path, PathBuf};
  19use std::rc::Rc;
  20use util::command::new_smol_command;
  21use uuid::Uuid;
  22
  23use agent_client_protocol as acp;
  24use anyhow::{Context as _, Result, anyhow};
  25use futures::channel::oneshot;
  26use futures::{AsyncBufReadExt, AsyncWriteExt};
  27use futures::{
  28    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  29    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  30    io::BufReader,
  31    select_biased,
  32};
  33use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  34use serde::{Deserialize, Serialize};
  35use util::{ResultExt, debug_panic};
  36
  37use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  38use crate::claude::tools::ClaudeTool;
  39use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  40use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError, MentionUri};
  41
  42#[derive(Clone)]
  43pub struct ClaudeCode;
  44
  45impl AgentServer for ClaudeCode {
  46    fn name(&self) -> &'static str {
  47        "Claude Code"
  48    }
  49
  50    fn empty_state_headline(&self) -> &'static str {
  51        self.name()
  52    }
  53
  54    fn empty_state_message(&self) -> &'static str {
  55        "How can I help you today?"
  56    }
  57
  58    fn logo(&self) -> ui::IconName {
  59        ui::IconName::AiClaude
  60    }
  61
  62    fn connect(
  63        &self,
  64        _root_dir: &Path,
  65        _project: &Entity<Project>,
  66        _cx: &mut App,
  67    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  68        let connection = ClaudeAgentConnection {
  69            sessions: Default::default(),
  70        };
  71
  72        Task::ready(Ok(Rc::new(connection) as _))
  73    }
  74
  75    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
  76        self
  77    }
  78}
  79
  80struct ClaudeAgentConnection {
  81    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
  82}
  83
  84impl AgentConnection for ClaudeAgentConnection {
  85    fn new_thread(
  86        self: Rc<Self>,
  87        project: Entity<Project>,
  88        cwd: &Path,
  89        cx: &mut App,
  90    ) -> Task<Result<Entity<AcpThread>>> {
  91        let cwd = cwd.to_owned();
  92        cx.spawn(async move |cx| {
  93            let settings = cx.read_global(|settings: &SettingsStore, _| {
  94                settings.get::<AllAgentServersSettings>(None).claude.clone()
  95            })?;
  96
  97            let Some(command) = AgentServerCommand::resolve(
  98                "claude",
  99                &[],
 100                Some(&util::paths::home_dir().join(".claude/local/claude")),
 101                settings,
 102                &project,
 103                cx,
 104            )
 105            .await
 106            else {
 107                return Err(LoadError::NotInstalled {
 108                    error_message: "Failed to find Claude Code binary".into(),
 109                    install_message: "Install Claude Code".into(),
 110                    install_command: "npm install -g @anthropic-ai/claude-code@latest".into(),
 111                }.into());
 112            };
 113
 114            let api_key =
 115                cx.update(AnthropicLanguageModelProvider::api_key)?
 116                    .await
 117                    .map_err(|err| {
 118                        if err.is::<language_model::AuthenticateError>() {
 119                            anyhow!(AuthRequired::new().with_language_model_provider(
 120                                language_model::ANTHROPIC_PROVIDER_ID
 121                            ))
 122                        } else {
 123                            anyhow!(err)
 124                        }
 125                    })?;
 126
 127            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
 128            let fs = project.read_with(cx, |project, _cx| project.fs().clone())?;
 129            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), fs, cx).await?;
 130
 131            let mut mcp_servers = HashMap::default();
 132            mcp_servers.insert(
 133                mcp_server::SERVER_NAME.to_string(),
 134                permission_mcp_server.server_config()?,
 135            );
 136            let mcp_config = McpConfig { mcp_servers };
 137
 138            let mcp_config_file = tempfile::NamedTempFile::new()?;
 139            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
 140
 141            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
 142            mcp_config_file
 143                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
 144                .await?;
 145            mcp_config_file.flush().await?;
 146
 147            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
 148            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 149
 150            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
 151
 152            log::trace!("Starting session with id: {}", session_id);
 153
 154            let mut child = spawn_claude(
 155                &command,
 156                ClaudeSessionMode::Start,
 157                session_id.clone(),
 158                api_key,
 159                &mcp_config_path,
 160                &cwd,
 161            )?;
 162
 163            let stdout = child.stdout.take().context("Failed to take stdout")?;
 164            let stdin = child.stdin.take().context("Failed to take stdin")?;
 165            let stderr = child.stderr.take().context("Failed to take stderr")?;
 166
 167            let pid = child.id();
 168            log::trace!("Spawned (pid: {})", pid);
 169
 170            cx.background_spawn(async move {
 171                let mut stderr = BufReader::new(stderr);
 172                let mut line = String::new();
 173                while let Ok(n) = stderr.read_line(&mut line).await
 174                    && n > 0
 175                {
 176                    log::warn!("agent stderr: {}", &line);
 177                    line.clear();
 178                }
 179            })
 180            .detach();
 181
 182            cx.background_spawn(async move {
 183                let mut outgoing_rx = Some(outgoing_rx);
 184
 185                ClaudeAgentSession::handle_io(
 186                    outgoing_rx.take().unwrap(),
 187                    incoming_message_tx.clone(),
 188                    stdin,
 189                    stdout,
 190                )
 191                .await?;
 192
 193                log::trace!("Stopped (pid: {})", pid);
 194
 195                drop(mcp_config_path);
 196                anyhow::Ok(())
 197            })
 198            .detach();
 199
 200            let turn_state = Rc::new(RefCell::new(TurnState::None));
 201
 202            let handler_task = cx.spawn({
 203                let turn_state = turn_state.clone();
 204                let mut thread_rx = thread_rx.clone();
 205                async move |cx| {
 206                    while let Some(message) = incoming_message_rx.next().await {
 207                        ClaudeAgentSession::handle_message(
 208                            thread_rx.clone(),
 209                            message,
 210                            turn_state.clone(),
 211                            cx,
 212                        )
 213                        .await
 214                    }
 215
 216                    if let Some(status) = child.status().await.log_err()
 217                        && let Some(thread) = thread_rx.recv().await.ok()
 218                    {
 219                        let version = claude_version(command.path.clone(), cx).await.log_err();
 220                        let help = claude_help(command.path.clone(), cx).await.log_err();
 221                        thread
 222                            .update(cx, |thread, cx| {
 223                                let error = if let Some(version) = version
 224                                    && let Some(help) = help
 225                                    && (!help.contains("--input-format")
 226                                        || !help.contains("--session-id"))
 227                                {
 228                                    LoadError::Unsupported {
 229                                    error_message: format!(
 230                                            "Your installed version of Claude Code ({}, version {}) does not have required features for use with Zed.",
 231                                            command.path.to_string_lossy(),
 232                                            version,
 233                                        )
 234                                        .into(),
 235                                        upgrade_message: "Upgrade Claude Code to latest".into(),
 236                                        upgrade_command: format!(
 237                                            "{} update",
 238                                            command.path.to_string_lossy()
 239                                        ),
 240                                    }
 241                                } else {
 242                                    LoadError::Exited { status }
 243                                };
 244                                thread.emit_load_error(error, cx);
 245                            })
 246                            .ok();
 247                    }
 248                }
 249            });
 250
 251            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
 252            let thread = cx.new(|_cx| {
 253                AcpThread::new(
 254                    "Claude Code",
 255                    self.clone(),
 256                    project,
 257                    action_log,
 258                    session_id.clone(),
 259                )
 260            })?;
 261
 262            thread_tx.send(thread.downgrade())?;
 263
 264            let session = ClaudeAgentSession {
 265                outgoing_tx,
 266                turn_state,
 267                _handler_task: handler_task,
 268                _mcp_server: Some(permission_mcp_server),
 269            };
 270
 271            self.sessions.borrow_mut().insert(session_id, session);
 272
 273            Ok(thread)
 274        })
 275    }
 276
 277    fn auth_methods(&self) -> &[acp::AuthMethod] {
 278        &[]
 279    }
 280
 281    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 282        Task::ready(Err(anyhow!("Authentication not supported")))
 283    }
 284
 285    fn prompt(
 286        &self,
 287        _id: Option<acp_thread::UserMessageId>,
 288        params: acp::PromptRequest,
 289        cx: &mut App,
 290    ) -> Task<Result<acp::PromptResponse>> {
 291        let sessions = self.sessions.borrow();
 292        let Some(session) = sessions.get(&params.session_id) else {
 293            return Task::ready(Err(anyhow!(
 294                "Attempted to send message to nonexistent session {}",
 295                params.session_id
 296            )));
 297        };
 298
 299        let (end_tx, end_rx) = oneshot::channel();
 300        session.turn_state.replace(TurnState::InProgress { end_tx });
 301
 302        let content = acp_content_to_claude(params.prompt);
 303
 304        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 305            message: Message {
 306                role: Role::User,
 307                content: Content::Chunks(content),
 308                id: None,
 309                model: None,
 310                stop_reason: None,
 311                stop_sequence: None,
 312                usage: None,
 313            },
 314            session_id: Some(params.session_id.to_string()),
 315        }) {
 316            return Task::ready(Err(anyhow!(err)));
 317        }
 318
 319        cx.foreground_executor().spawn(async move { end_rx.await? })
 320    }
 321
 322    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 323        let sessions = self.sessions.borrow();
 324        let Some(session) = sessions.get(session_id) else {
 325            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 326            return;
 327        };
 328
 329        let request_id = new_request_id();
 330
 331        let turn_state = session.turn_state.take();
 332        let TurnState::InProgress { end_tx } = turn_state else {
 333            // Already canceled or idle, put it back
 334            session.turn_state.replace(turn_state);
 335            return;
 336        };
 337
 338        session.turn_state.replace(TurnState::CancelRequested {
 339            end_tx,
 340            request_id: request_id.clone(),
 341        });
 342
 343        session
 344            .outgoing_tx
 345            .unbounded_send(SdkMessage::ControlRequest {
 346                request_id,
 347                request: ControlRequest::Interrupt,
 348            })
 349            .log_err();
 350    }
 351
 352    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 353        self
 354    }
 355}
 356
 357#[derive(Clone, Copy)]
 358enum ClaudeSessionMode {
 359    Start,
 360    #[expect(dead_code)]
 361    Resume,
 362}
 363
 364fn spawn_claude(
 365    command: &AgentServerCommand,
 366    mode: ClaudeSessionMode,
 367    session_id: acp::SessionId,
 368    api_key: language_models::provider::anthropic::ApiKey,
 369    mcp_config_path: &Path,
 370    root_dir: &Path,
 371) -> Result<Child> {
 372    let child = util::command::new_smol_command(&command.path)
 373        .args([
 374            "--input-format",
 375            "stream-json",
 376            "--output-format",
 377            "stream-json",
 378            "--print",
 379            "--verbose",
 380            "--mcp-config",
 381            mcp_config_path.to_string_lossy().as_ref(),
 382            "--permission-prompt-tool",
 383            &format!(
 384                "mcp__{}__{}",
 385                mcp_server::SERVER_NAME,
 386                permission_tool::PermissionTool::NAME,
 387            ),
 388            "--allowedTools",
 389            &format!(
 390                "mcp__{}__{}",
 391                mcp_server::SERVER_NAME,
 392                read_tool::ReadTool::NAME
 393            ),
 394            "--disallowedTools",
 395            "Read,Write,Edit,MultiEdit",
 396        ])
 397        .args(match mode {
 398            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 399            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 400        })
 401        .args(command.args.iter().map(|arg| arg.as_str()))
 402        .envs(command.env.iter().flatten())
 403        .env("ANTHROPIC_API_KEY", api_key.key)
 404        .current_dir(root_dir)
 405        .stdin(std::process::Stdio::piped())
 406        .stdout(std::process::Stdio::piped())
 407        .stderr(std::process::Stdio::piped())
 408        .kill_on_drop(true)
 409        .spawn()?;
 410
 411    Ok(child)
 412}
 413
 414fn claude_version(path: PathBuf, cx: &mut AsyncApp) -> Task<Result<semver::Version>> {
 415    cx.background_spawn(async move {
 416        let output = new_smol_command(path).arg("--version").output().await?;
 417        let output = String::from_utf8(output.stdout)?;
 418        let version = output
 419            .trim()
 420            .strip_suffix(" (Claude Code)")
 421            .context("parsing Claude version")?;
 422        let version = semver::Version::parse(version)?;
 423        anyhow::Ok(version)
 424    })
 425}
 426
 427fn claude_help(path: PathBuf, cx: &mut AsyncApp) -> Task<Result<String>> {
 428    cx.background_spawn(async move {
 429        let output = new_smol_command(path).arg("--help").output().await?;
 430        let output = String::from_utf8(output.stdout)?;
 431        anyhow::Ok(output)
 432    })
 433}
 434
 435struct ClaudeAgentSession {
 436    outgoing_tx: UnboundedSender<SdkMessage>,
 437    turn_state: Rc<RefCell<TurnState>>,
 438    _mcp_server: Option<ClaudeZedMcpServer>,
 439    _handler_task: Task<()>,
 440}
 441
 442#[derive(Debug, Default)]
 443enum TurnState {
 444    #[default]
 445    None,
 446    InProgress {
 447        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 448    },
 449    CancelRequested {
 450        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 451        request_id: String,
 452    },
 453    CancelConfirmed {
 454        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 455    },
 456}
 457
 458impl TurnState {
 459    fn is_canceled(&self) -> bool {
 460        matches!(self, TurnState::CancelConfirmed { .. })
 461    }
 462
 463    fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
 464        match self {
 465            TurnState::None => None,
 466            TurnState::InProgress { end_tx, .. } => Some(end_tx),
 467            TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
 468            TurnState::CancelConfirmed { end_tx } => Some(end_tx),
 469        }
 470    }
 471
 472    fn confirm_cancellation(self, id: &str) -> Self {
 473        match self {
 474            TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
 475                TurnState::CancelConfirmed { end_tx }
 476            }
 477            _ => self,
 478        }
 479    }
 480}
 481
 482impl ClaudeAgentSession {
 483    async fn handle_message(
 484        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 485        message: SdkMessage,
 486        turn_state: Rc<RefCell<TurnState>>,
 487        cx: &mut AsyncApp,
 488    ) {
 489        match message {
 490            // we should only be sending these out, they don't need to be in the thread
 491            SdkMessage::ControlRequest { .. } => {}
 492            SdkMessage::User {
 493                message,
 494                session_id: _,
 495            } => {
 496                let Some(thread) = thread_rx
 497                    .recv()
 498                    .await
 499                    .log_err()
 500                    .and_then(|entity| entity.upgrade())
 501                else {
 502                    log::error!("Received an SDK message but thread is gone");
 503                    return;
 504                };
 505
 506                for chunk in message.content.chunks() {
 507                    match chunk {
 508                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 509                            if !turn_state.borrow().is_canceled() {
 510                                thread
 511                                    .update(cx, |thread, cx| {
 512                                        thread.push_user_content_block(None, text.into(), cx)
 513                                    })
 514                                    .log_err();
 515                            }
 516                        }
 517                        ContentChunk::ToolResult {
 518                            content,
 519                            tool_use_id,
 520                        } => {
 521                            let content = content.to_string();
 522                            thread
 523                                .update(cx, |thread, cx| {
 524                                    let id = acp::ToolCallId(tool_use_id.into());
 525                                    let set_new_content = !content.is_empty()
 526                                        && thread.tool_call(&id).is_none_or(|(_, tool_call)| {
 527                                            // preserve rich diff if we have one
 528                                            tool_call.diffs().next().is_none()
 529                                        });
 530
 531                                    thread.update_tool_call(
 532                                        acp::ToolCallUpdate {
 533                                            id,
 534                                            fields: acp::ToolCallUpdateFields {
 535                                                status: if turn_state.borrow().is_canceled() {
 536                                                    // Do not set to completed if turn was canceled
 537                                                    None
 538                                                } else {
 539                                                    Some(acp::ToolCallStatus::Completed)
 540                                                },
 541                                                content: set_new_content
 542                                                    .then(|| vec![content.into()]),
 543                                                ..Default::default()
 544                                            },
 545                                        },
 546                                        cx,
 547                                    )
 548                                })
 549                                .log_err();
 550                        }
 551                        ContentChunk::Thinking { .. }
 552                        | ContentChunk::RedactedThinking
 553                        | ContentChunk::ToolUse { .. } => {
 554                            debug_panic!(
 555                                "Should not get {:?} with role: assistant. should we handle this?",
 556                                chunk
 557                            );
 558                        }
 559                        ContentChunk::Image { source } => {
 560                            if !turn_state.borrow().is_canceled() {
 561                                thread
 562                                    .update(cx, |thread, cx| {
 563                                        thread.push_user_content_block(None, source.into(), cx)
 564                                    })
 565                                    .log_err();
 566                            }
 567                        }
 568
 569                        ContentChunk::Document | ContentChunk::WebSearchToolResult => {
 570                            thread
 571                                .update(cx, |thread, cx| {
 572                                    thread.push_assistant_content_block(
 573                                        format!("Unsupported content: {:?}", chunk).into(),
 574                                        false,
 575                                        cx,
 576                                    )
 577                                })
 578                                .log_err();
 579                        }
 580                    }
 581                }
 582            }
 583            SdkMessage::Assistant {
 584                message,
 585                session_id: _,
 586            } => {
 587                let Some(thread) = thread_rx
 588                    .recv()
 589                    .await
 590                    .log_err()
 591                    .and_then(|entity| entity.upgrade())
 592                else {
 593                    log::error!("Received an SDK message but thread is gone");
 594                    return;
 595                };
 596
 597                for chunk in message.content.chunks() {
 598                    match chunk {
 599                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 600                            thread
 601                                .update(cx, |thread, cx| {
 602                                    thread.push_assistant_content_block(text.into(), false, cx)
 603                                })
 604                                .log_err();
 605                        }
 606                        ContentChunk::Thinking { thinking } => {
 607                            thread
 608                                .update(cx, |thread, cx| {
 609                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 610                                })
 611                                .log_err();
 612                        }
 613                        ContentChunk::RedactedThinking => {
 614                            thread
 615                                .update(cx, |thread, cx| {
 616                                    thread.push_assistant_content_block(
 617                                        "[REDACTED]".into(),
 618                                        true,
 619                                        cx,
 620                                    )
 621                                })
 622                                .log_err();
 623                        }
 624                        ContentChunk::ToolUse { id, name, input } => {
 625                            let claude_tool = ClaudeTool::infer(&name, input);
 626
 627                            thread
 628                                .update(cx, |thread, cx| {
 629                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 630                                        thread.update_plan(
 631                                            acp::Plan {
 632                                                entries: params
 633                                                    .todos
 634                                                    .into_iter()
 635                                                    .map(Into::into)
 636                                                    .collect(),
 637                                            },
 638                                            cx,
 639                                        )
 640                                    } else {
 641                                        thread.upsert_tool_call(
 642                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 643                                            cx,
 644                                        )?;
 645                                    }
 646                                    anyhow::Ok(())
 647                                })
 648                                .log_err();
 649                        }
 650                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 651                            debug_panic!(
 652                                "Should not get tool results with role: assistant. should we handle this?"
 653                            );
 654                        }
 655                        ContentChunk::Image { source } => {
 656                            thread
 657                                .update(cx, |thread, cx| {
 658                                    thread.push_assistant_content_block(source.into(), false, cx)
 659                                })
 660                                .log_err();
 661                        }
 662                        ContentChunk::Document => {
 663                            thread
 664                                .update(cx, |thread, cx| {
 665                                    thread.push_assistant_content_block(
 666                                        format!("Unsupported content: {:?}", chunk).into(),
 667                                        false,
 668                                        cx,
 669                                    )
 670                                })
 671                                .log_err();
 672                        }
 673                    }
 674                }
 675            }
 676            SdkMessage::Result {
 677                is_error,
 678                subtype,
 679                result,
 680                ..
 681            } => {
 682                let turn_state = turn_state.take();
 683                let was_canceled = turn_state.is_canceled();
 684                let Some(end_turn_tx) = turn_state.end_tx() else {
 685                    debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
 686                    return;
 687                };
 688
 689                if is_error || (!was_canceled && subtype == ResultErrorType::ErrorDuringExecution) {
 690                    end_turn_tx
 691                        .send(Err(anyhow!(
 692                            "Error: {}",
 693                            result.unwrap_or_else(|| subtype.to_string())
 694                        )))
 695                        .ok();
 696                } else {
 697                    let stop_reason = match subtype {
 698                        ResultErrorType::Success => acp::StopReason::EndTurn,
 699                        ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 700                        ResultErrorType::ErrorDuringExecution => acp::StopReason::Canceled,
 701                    };
 702                    end_turn_tx
 703                        .send(Ok(acp::PromptResponse { stop_reason }))
 704                        .ok();
 705                }
 706            }
 707            SdkMessage::ControlResponse { response } => {
 708                if matches!(response.subtype, ResultErrorType::Success) {
 709                    let new_state = turn_state.take().confirm_cancellation(&response.request_id);
 710                    turn_state.replace(new_state);
 711                } else {
 712                    log::error!("Control response error: {:?}", response);
 713                }
 714            }
 715            SdkMessage::System { .. } => {}
 716        }
 717    }
 718
 719    async fn handle_io(
 720        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 721        incoming_tx: UnboundedSender<SdkMessage>,
 722        mut outgoing_bytes: impl Unpin + AsyncWrite,
 723        incoming_bytes: impl Unpin + AsyncRead,
 724    ) -> Result<UnboundedReceiver<SdkMessage>> {
 725        let mut output_reader = BufReader::new(incoming_bytes);
 726        let mut outgoing_line = Vec::new();
 727        let mut incoming_line = String::new();
 728        loop {
 729            select_biased! {
 730                message = outgoing_rx.next() => {
 731                    if let Some(message) = message {
 732                        outgoing_line.clear();
 733                        serde_json::to_writer(&mut outgoing_line, &message)?;
 734                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 735                        outgoing_line.push(b'\n');
 736                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 737                    } else {
 738                        break;
 739                    }
 740                }
 741                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
 742                    if bytes_read? == 0 {
 743                        break
 744                    }
 745                    log::trace!("recv: {}", &incoming_line);
 746                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 747                        Ok(message) => {
 748                            incoming_tx.unbounded_send(message).log_err();
 749                        }
 750                        Err(error) => {
 751                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 752                        }
 753                    }
 754                    incoming_line.clear();
 755                }
 756            }
 757        }
 758
 759        Ok(outgoing_rx)
 760    }
 761}
 762
/// A chat message in the shape produced by the Anthropic SDK, as carried inside
/// the `Assistant` and `User` variants of [`SdkMessage`].
///
/// Only `role` and `content` are required when deserializing; the optional
/// metadata fields default to `None` and are omitted from serialized JSON when
/// they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    role: Role,
    content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequence: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Usage>,
}
 778
/// Message content, which may arrive either as a bare JSON string or as a list
/// of typed chunks; `#[serde(untagged)]` selects whichever shape matches.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Content {
    /// Content given as a plain JSON string.
    UntaggedText(String),
    /// Content given as a JSON array of [`ContentChunk`]s.
    Chunks(Vec<ContentChunk>),
}
 785
 786impl Content {
 787    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 788        match self {
 789            Self::Chunks(chunks) => chunks.into_iter(),
 790            Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
 791        }
 792    }
 793}
 794
 795impl Display for Content {
 796    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 797        match self {
 798            Content::UntaggedText(txt) => write!(f, "{}", txt),
 799            Content::Chunks(chunks) => {
 800                for chunk in chunks {
 801                    write!(f, "{}", chunk)?;
 802                }
 803                Ok(())
 804            }
 805        }
 806    }
 807}
 808
/// One typed block of message content.
///
/// Serialized as an object whose `type` field is the snake_case variant name
/// (e.g. `"text"`, `"tool_use"`, `"tool_result"`). A bare JSON string is
/// accepted through the `#[serde(untagged)]` `UntaggedText` fallback.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ContentChunk {
    /// Plain text.
    Text {
        text: String,
    },
    /// A tool invocation: its `id`, the tool `name`, and the JSON `input`.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Result content for the tool use identified by `tool_use_id`.
    ToolResult {
        content: Content,
        tool_use_id: String,
    },
    /// Thinking text; displayed with a `Thinking: ` prefix.
    Thinking {
        thinking: String,
    },
    /// Thinking whose text is not available; displayed as `[REDACTED]`.
    RedactedThinking,
    /// An image and where its data comes from.
    Image {
        source: ImageSource,
    },
    // TODO: `Document` and `WebSearchToolResult` are recognised by their `type`
    // tag only; their payloads are not modelled yet.
    Document,
    WebSearchToolResult,
    /// Fallback for a chunk given as a bare JSON string.
    #[serde(untagged)]
    UntaggedText(String),
}
 837
/// Where an image's data comes from, tagged by a snake_case `type` field
/// (`"base64"` or `"url"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ImageSource {
    /// Inline image data together with its MIME type (`media_type`).
    Base64 { data: String, media_type: String },
    /// An image referenced by URL.
    Url { url: String },
}
 844
 845impl Into<acp::ContentBlock> for ImageSource {
 846    fn into(self) -> acp::ContentBlock {
 847        match self {
 848            ImageSource::Base64 { data, media_type } => {
 849                acp::ContentBlock::Image(acp::ImageContent {
 850                    annotations: None,
 851                    data,
 852                    mime_type: media_type,
 853                    uri: None,
 854                })
 855            }
 856            ImageSource::Url { url } => acp::ContentBlock::Image(acp::ImageContent {
 857                annotations: None,
 858                data: "".to_string(),
 859                mime_type: "".to_string(),
 860                uri: Some(url),
 861            }),
 862        }
 863    }
 864}
 865
 866impl Display for ContentChunk {
 867    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 868        match self {
 869            ContentChunk::Text { text } => write!(f, "{}", text),
 870            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 871            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 872            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 873            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 874            ContentChunk::Image { .. }
 875            | ContentChunk::Document
 876            | ContentChunk::ToolUse { .. }
 877            | ContentChunk::WebSearchToolResult => {
 878                write!(f, "\n{:?}\n", &self)
 879            }
 880        }
 881    }
 882}
 883
 884#[derive(Debug, Clone, Serialize, Deserialize)]
 885struct Usage {
 886    input_tokens: u32,
 887    cache_creation_input_tokens: u32,
 888    cache_read_input_tokens: u32,
 889    output_tokens: u32,
 890    service_tier: String,
 891}
 892
/// Author of a message, serialized in snake_case (`"system"`, `"assistant"`,
/// `"user"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}
 900
/// A minimal message consisting of a role and plain-string content.
// NOTE(review): no use of this type is visible in this part of the file —
// confirm it is still needed.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    role: Role,
    content: String,
}
 906
/// One message of Claude Code's newline-delimited stream-JSON protocol.
///
/// The variant is selected by the snake_case `type` field (`"assistant"`,
/// `"user"`, `"result"`, `"system"`, `"control_request"`, `"control_response"`).
/// The same type is used for messages written to and read from the agent
/// process.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    /// An assistant message.
    Assistant {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// A user message.
    User {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// Emitted as the last message in a conversation.
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    /// Emitted as the first message at the start of a conversation.
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        // These two fields are camelCase on the wire, unlike the other fields.
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to the model
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request
    ControlResponse { response: ControlResponse },
}
 954
/// Body of a control request, selected by its snake_case `subtype` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation
    Interrupt,
}
 961
/// Body of a control response.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
    // Presumably the `request_id` of the control request being answered —
    // confirm against the Claude Code SDK.
    request_id: String,
    /// Outcome of the request.
    subtype: ResultErrorType,
}
 967
/// Outcome reported by `result` messages and control responses.
///
/// Serialized in snake_case (`"success"`, `"error_max_turns"`,
/// `"error_during_execution"`); the `Display` impl produces the same strings.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
 975
 976impl Display for ResultErrorType {
 977    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 978        match self {
 979            ResultErrorType::Success => write!(f, "success"),
 980            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 981            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 982        }
 983    }
 984}
 985
 986fn acp_content_to_claude(prompt: Vec<acp::ContentBlock>) -> Vec<ContentChunk> {
 987    let mut content = Vec::with_capacity(prompt.len());
 988    let mut context = Vec::with_capacity(prompt.len());
 989
 990    for chunk in prompt {
 991        match chunk {
 992            acp::ContentBlock::Text(text_content) => {
 993                content.push(ContentChunk::Text {
 994                    text: text_content.text,
 995                });
 996            }
 997            acp::ContentBlock::ResourceLink(resource_link) => {
 998                match MentionUri::parse(&resource_link.uri) {
 999                    Ok(uri) => {
1000                        content.push(ContentChunk::Text {
1001                            text: format!("{}", uri.as_link()),
1002                        });
1003                    }
1004                    Err(_) => {
1005                        content.push(ContentChunk::Text {
1006                            text: resource_link.uri,
1007                        });
1008                    }
1009                }
1010            }
1011            acp::ContentBlock::Resource(resource) => match resource.resource {
1012                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1013                    match MentionUri::parse(&resource.uri) {
1014                        Ok(uri) => {
1015                            content.push(ContentChunk::Text {
1016                                text: format!("{}", uri.as_link()),
1017                            });
1018                        }
1019                        Err(_) => {
1020                            content.push(ContentChunk::Text {
1021                                text: resource.uri.clone(),
1022                            });
1023                        }
1024                    }
1025
1026                    context.push(ContentChunk::Text {
1027                        text: format!(
1028                            "\n<context ref=\"{}\">\n{}\n</context>",
1029                            resource.uri, resource.text
1030                        ),
1031                    });
1032                }
1033                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1034                    // Unsupported by SDK
1035                }
1036            },
1037            acp::ContentBlock::Image(acp::ImageContent {
1038                data, mime_type, ..
1039            }) => content.push(ContentChunk::Image {
1040                source: ImageSource::Base64 {
1041                    data,
1042                    media_type: mime_type,
1043                },
1044            }),
1045            acp::ContentBlock::Audio(_) => {
1046                // Unsupported by SDK
1047            }
1048        }
1049    }
1050
1051    content.extend(context);
1052    content
1053}
1054
1055fn new_request_id() -> String {
1056    use rand::Rng;
1057    // In the Claude Code TS SDK they just generate a random 12 character string,
1058    // `Math.random().toString(36).substring(2, 15)`
1059    rand::thread_rng()
1060        .sample_iter(&rand::distributions::Alphanumeric)
1061        .take(12)
1062        .map(char::from)
1063        .collect()
1064}
1065
/// An MCP server entry, as listed in the `mcp_servers` field of the `System`
/// variant of [`SdkMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    name: String,
    // Kept as a free-form string; the set of possible values is not modelled here.
    status: String,
}
1071
/// Permission mode reported in the `permissionMode` field of the `System`
/// variant of [`SdkMessage`].
///
/// Serialized in camelCase: `"default"`, `"acceptEdits"`, `"bypassPermissions"`,
/// `"plan"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum PermissionMode {
    Default,
    AcceptEdits,
    BypassPermissions,
    Plan,
}
1080
// Unit tests for the stream-JSON (de)serialization and the ACP prompt
// conversion, plus end-to-end tests that drive a Claude Code agent.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::e2e_tests;
    use gpui::TestAppContext;
    use serde_json::json;

    // Instantiates the shared agent-server e2e suite for `ClaudeCode`.
    // NOTE(review): `allow_option_id` is presumably the permission option id the
    // shared tests pick to approve tool calls — confirm against `common_e2e_tests!`.
    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");

    /// Command that runs the `claude` executable with no extra arguments or
    /// environment overrides (presumably resolved via `PATH`).
    pub fn local_command() -> AgentServerCommand {
        AgentServerCommand {
            path: "claude".into(),
            args: vec![],
            env: None,
        }
    }

    // End-to-end: asks the agent for a todo plan, then has it mark the first
    // entry in progress and later completed, checking that the status changes
    // while the number of entries stays the same. Ignored unless the `e2e`
    // feature is enabled.
    #[gpui::test]
    #[cfg_attr(not(feature = "e2e"), ignore)]
    async fn test_todo_plan(cx: &mut TestAppContext) {
        let fs = e2e_tests::init_test(cx).await;
        let project = Project::test(fs, [], cx).await;
        let thread =
            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        // Remember the plan size so later turns can check no entries were added or removed.
        let mut entries_len = 0;

        thread.read_with(cx, |thread, _| {
            entries_len = thread.plan().entries.len();
            assert!(thread.plan().entries.len() > 0, "Empty plan");
        });

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Mark the first entry status as in progress without acting on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.plan().entries[0].status,
                acp::PlanEntryStatus::InProgress
            ));
            assert_eq!(thread.plan().entries.len(), entries_len);
        });

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Now mark the first entry as completed without acting on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.plan().entries[0].status,
                acp::PlanEntryStatus::Completed
            ));
            assert_eq!(thread.plan().entries.len(), entries_len);
        });
    }

    // A bare JSON string deserializes as `Content::UntaggedText`.
    #[test]
    fn test_deserialize_content_untagged_text() {
        let json = json!("Hello, world!");
        let content: Content = serde_json::from_value(json).unwrap();
        match content {
            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected UntaggedText variant"),
        }
    }

    // A JSON array deserializes as `Content::Chunks`, each element dispatched on
    // its `type` tag.
    #[test]
    fn test_deserialize_content_chunks() {
        let json = json!([
            {
                "type": "text",
                "text": "Hello"
            },
            {
                "type": "tool_use",
                "id": "tool_123",
                "name": "calculator",
                "input": {"operation": "add", "a": 1, "b": 2}
            }
        ]);
        let content: Content = serde_json::from_value(json).unwrap();
        match content {
            Content::Chunks(chunks) => {
                assert_eq!(chunks.len(), 2);
                match &chunks[0] {
                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
                    _ => panic!("Expected Text chunk"),
                }
                match &chunks[1] {
                    ContentChunk::ToolUse { id, name, input } => {
                        assert_eq!(id, "tool_123");
                        assert_eq!(name, "calculator");
                        assert_eq!(input["operation"], "add");
                        assert_eq!(input["a"], 1);
                        assert_eq!(input["b"], 2);
                    }
                    _ => panic!("Expected ToolUse chunk"),
                }
            }
            _ => panic!("Expected Chunks variant"),
        }
    }

    // A `tool_result` whose `content` is a plain string.
    #[test]
    fn test_deserialize_tool_result_untagged_text() {
        let json = json!({
            "type": "tool_result",
            "content": "Result content",
            "tool_use_id": "tool_456"
        });
        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
        match chunk {
            ContentChunk::ToolResult {
                content,
                tool_use_id,
            } => {
                match content {
                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
                    _ => panic!("Expected UntaggedText content"),
                }
                assert_eq!(tool_use_id, "tool_456");
            }
            _ => panic!("Expected ToolResult variant"),
        }
    }

    // A `tool_result` whose `content` is an array of typed chunks.
    #[test]
    fn test_deserialize_tool_result_chunks() {
        let json = json!({
            "type": "tool_result",
            "content": [
                {
                    "type": "text",
                    "text": "Processing complete"
                },
                {
                    "type": "text",
                    "text": "Result: 42"
                }
            ],
            "tool_use_id": "tool_789"
        });
        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
        match chunk {
            ContentChunk::ToolResult {
                content,
                tool_use_id,
            } => {
                match content {
                    Content::Chunks(chunks) => {
                        assert_eq!(chunks.len(), 2);
                        match &chunks[0] {
                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
                            _ => panic!("Expected Text chunk"),
                        }
                        match &chunks[1] {
                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
                            _ => panic!("Expected Text chunk"),
                        }
                    }
                    _ => panic!("Expected Chunks content"),
                }
                assert_eq!(tool_use_id, "tool_789");
            }
            _ => panic!("Expected ToolResult variant"),
        }
    }

    // Converts one block of each supported ACP kind and checks the output order:
    // inline chunks follow prompt order (an unparseable URI passes through
    // verbatim), and the embedded resource's `<context>` block comes last.
    #[test]
    fn test_acp_content_to_claude() {
        let acp_content = vec![
            acp::ContentBlock::Text(acp::TextContent {
                text: "Hello world".to_string(),
                annotations: None,
            }),
            acp::ContentBlock::Image(acp::ImageContent {
                data: "base64data".to_string(),
                mime_type: "image/png".to_string(),
                annotations: None,
                uri: None,
            }),
            acp::ContentBlock::ResourceLink(acp::ResourceLink {
                uri: "file:///path/to/example.rs".to_string(),
                name: "example.rs".to_string(),
                annotations: None,
                description: None,
                mime_type: None,
                size: None,
                title: None,
            }),
            acp::ContentBlock::Resource(acp::EmbeddedResource {
                annotations: None,
                resource: acp::EmbeddedResourceResource::TextResourceContents(
                    acp::TextResourceContents {
                        mime_type: None,
                        text: "fn main() { println!(\"Hello!\"); }".to_string(),
                        uri: "file:///path/to/code.rs".to_string(),
                    },
                ),
            }),
            acp::ContentBlock::ResourceLink(acp::ResourceLink {
                uri: "invalid_uri_format".to_string(),
                name: "invalid.txt".to_string(),
                annotations: None,
                description: None,
                mime_type: None,
                size: None,
                title: None,
            }),
        ];

        let claude_content = acp_content_to_claude(acp_content);

        // Five inline chunks plus one trailing context block.
        assert_eq!(claude_content.len(), 6);

        match &claude_content[0] {
            ContentChunk::Text { text } => assert_eq!(text, "Hello world"),
            _ => panic!("Expected Text chunk"),
        }

        match &claude_content[1] {
            ContentChunk::Image { source } => match source {
                ImageSource::Base64 { data, media_type } => {
                    assert_eq!(data, "base64data");
                    assert_eq!(media_type, "image/png");
                }
                _ => panic!("Expected Base64 image source"),
            },
            _ => panic!("Expected Image chunk"),
        }

        match &claude_content[2] {
            ContentChunk::Text { text } => {
                assert!(text.contains("example.rs"));
                assert!(text.contains("file:///path/to/example.rs"));
            }
            _ => panic!("Expected Text chunk for ResourceLink"),
        }

        match &claude_content[3] {
            ContentChunk::Text { text } => {
                assert!(text.contains("code.rs"));
                assert!(text.contains("file:///path/to/code.rs"));
            }
            _ => panic!("Expected Text chunk for Resource"),
        }

        match &claude_content[4] {
            ContentChunk::Text { text } => {
                assert_eq!(text, "invalid_uri_format");
            }
            _ => panic!("Expected Text chunk for invalid URI"),
        }

        match &claude_content[5] {
            ContentChunk::Text { text } => {
                assert!(text.contains("<context ref=\"file:///path/to/code.rs\">"));
                assert!(text.contains("fn main() { println!(\"Hello!\"); }"));
                assert!(text.contains("</context>"));
            }
            _ => panic!("Expected Text chunk for context"),
        }
    }
}