claude.rs

   1mod edit_tool;
   2mod mcp_server;
   3mod permission_tool;
   4mod read_tool;
   5pub mod tools;
   6mod write_tool;
   7
   8use action_log::ActionLog;
   9use collections::HashMap;
  10use context_server::listener::McpServerTool;
  11use language_models::provider::anthropic::AnthropicLanguageModelProvider;
  12use project::Project;
  13use settings::SettingsStore;
  14use smol::process::Child;
  15use std::any::Any;
  16use std::cell::RefCell;
  17use std::fmt::Display;
  18use std::path::Path;
  19use std::rc::Rc;
  20use uuid::Uuid;
  21
  22use agent_client_protocol as acp;
  23use anyhow::{Context as _, Result, anyhow};
  24use futures::channel::oneshot;
  25use futures::{AsyncBufReadExt, AsyncWriteExt};
  26use futures::{
  27    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  28    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  29    io::BufReader,
  30    select_biased,
  31};
  32use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  33use serde::{Deserialize, Serialize};
  34use util::{ResultExt, debug_panic};
  35
  36use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  37use crate::claude::tools::ClaudeTool;
  38use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  39use acp_thread::{AcpThread, AgentConnection, AuthRequired, MentionUri};
  40
  41#[derive(Clone)]
  42pub struct ClaudeCode;
  43
  44impl AgentServer for ClaudeCode {
  45    fn name(&self) -> &'static str {
  46        "Claude Code"
  47    }
  48
  49    fn empty_state_headline(&self) -> &'static str {
  50        self.name()
  51    }
  52
  53    fn empty_state_message(&self) -> &'static str {
  54        "How can I help you today?"
  55    }
  56
  57    fn logo(&self) -> ui::IconName {
  58        ui::IconName::AiClaude
  59    }
  60
  61    fn connect(
  62        &self,
  63        _root_dir: &Path,
  64        _project: &Entity<Project>,
  65        _cx: &mut App,
  66    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  67        let connection = ClaudeAgentConnection {
  68            sessions: Default::default(),
  69        };
  70
  71        Task::ready(Ok(Rc::new(connection) as _))
  72    }
  73
  74    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
  75        self
  76    }
  77}
  78
  79struct ClaudeAgentConnection {
  80    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
  81}
  82
  83impl AgentConnection for ClaudeAgentConnection {
  84    fn new_thread(
  85        self: Rc<Self>,
  86        project: Entity<Project>,
  87        cwd: &Path,
  88        cx: &mut App,
  89    ) -> Task<Result<Entity<AcpThread>>> {
  90        let cwd = cwd.to_owned();
  91        cx.spawn(async move |cx| {
  92            let settings = cx.read_global(|settings: &SettingsStore, _| {
  93                settings.get::<AllAgentServersSettings>(None).claude.clone()
  94            })?;
  95
  96            let Some(command) = AgentServerCommand::resolve(
  97                "claude",
  98                &[],
  99                Some(&util::paths::home_dir().join(".claude/local/claude")),
 100                settings,
 101                &project,
 102                cx,
 103            )
 104            .await
 105            else {
 106                anyhow::bail!("Failed to find claude binary");
 107            };
 108
 109            let api_key =
 110                cx.update(AnthropicLanguageModelProvider::api_key)?
 111                    .await
 112                    .map_err(|err| {
 113                        if err.is::<language_model::AuthenticateError>() {
 114                            anyhow!(AuthRequired::new().with_language_model_provider(
 115                                language_model::ANTHROPIC_PROVIDER_ID
 116                            ))
 117                        } else {
 118                            anyhow!(err)
 119                        }
 120                    })?;
 121
 122            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
 123            let fs = project.read_with(cx, |project, _cx| project.fs().clone())?;
 124            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), fs, cx).await?;
 125
 126            let mut mcp_servers = HashMap::default();
 127            mcp_servers.insert(
 128                mcp_server::SERVER_NAME.to_string(),
 129                permission_mcp_server.server_config()?,
 130            );
 131            let mcp_config = McpConfig { mcp_servers };
 132
 133            let mcp_config_file = tempfile::NamedTempFile::new()?;
 134            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
 135
 136            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
 137            mcp_config_file
 138                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
 139                .await?;
 140            mcp_config_file.flush().await?;
 141
 142            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
 143            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 144
 145            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
 146
 147            log::trace!("Starting session with id: {}", session_id);
 148
 149            let mut child = spawn_claude(
 150                &command,
 151                ClaudeSessionMode::Start,
 152                session_id.clone(),
 153                api_key,
 154                &mcp_config_path,
 155                &cwd,
 156            )?;
 157
 158            let stdout = child.stdout.take().context("Failed to take stdout")?;
 159            let stdin = child.stdin.take().context("Failed to take stdin")?;
 160            let stderr = child.stderr.take().context("Failed to take stderr")?;
 161
 162            let pid = child.id();
 163            log::trace!("Spawned (pid: {})", pid);
 164
 165            cx.background_spawn(async move {
 166                let mut stderr = BufReader::new(stderr);
 167                let mut line = String::new();
 168                while let Ok(n) = stderr.read_line(&mut line).await
 169                    && n > 0
 170                {
 171                    log::warn!("agent stderr: {}", &line);
 172                    line.clear();
 173                }
 174            })
 175            .detach();
 176
 177            cx.background_spawn(async move {
 178                let mut outgoing_rx = Some(outgoing_rx);
 179
 180                ClaudeAgentSession::handle_io(
 181                    outgoing_rx.take().unwrap(),
 182                    incoming_message_tx.clone(),
 183                    stdin,
 184                    stdout,
 185                )
 186                .await?;
 187
 188                log::trace!("Stopped (pid: {})", pid);
 189
 190                drop(mcp_config_path);
 191                anyhow::Ok(())
 192            })
 193            .detach();
 194
 195            let turn_state = Rc::new(RefCell::new(TurnState::None));
 196
 197            let handler_task = cx.spawn({
 198                let turn_state = turn_state.clone();
 199                let mut thread_rx = thread_rx.clone();
 200                async move |cx| {
 201                    while let Some(message) = incoming_message_rx.next().await {
 202                        ClaudeAgentSession::handle_message(
 203                            thread_rx.clone(),
 204                            message,
 205                            turn_state.clone(),
 206                            cx,
 207                        )
 208                        .await
 209                    }
 210
 211                    if let Some(status) = child.status().await.log_err()
 212                        && let Some(thread) = thread_rx.recv().await.ok()
 213                    {
 214                        thread
 215                            .update(cx, |thread, cx| {
 216                                thread.emit_server_exited(status, cx);
 217                            })
 218                            .ok();
 219                    }
 220                }
 221            });
 222
 223            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
 224            let thread = cx.new(|_cx| {
 225                AcpThread::new(
 226                    "Claude Code",
 227                    self.clone(),
 228                    project,
 229                    action_log,
 230                    session_id.clone(),
 231                )
 232            })?;
 233
 234            thread_tx.send(thread.downgrade())?;
 235
 236            let session = ClaudeAgentSession {
 237                outgoing_tx,
 238                turn_state,
 239                _handler_task: handler_task,
 240                _mcp_server: Some(permission_mcp_server),
 241            };
 242
 243            self.sessions.borrow_mut().insert(session_id, session);
 244
 245            Ok(thread)
 246        })
 247    }
 248
 249    fn auth_methods(&self) -> &[acp::AuthMethod] {
 250        &[]
 251    }
 252
 253    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 254        Task::ready(Err(anyhow!("Authentication not supported")))
 255    }
 256
 257    fn prompt(
 258        &self,
 259        _id: Option<acp_thread::UserMessageId>,
 260        params: acp::PromptRequest,
 261        cx: &mut App,
 262    ) -> Task<Result<acp::PromptResponse>> {
 263        let sessions = self.sessions.borrow();
 264        let Some(session) = sessions.get(&params.session_id) else {
 265            return Task::ready(Err(anyhow!(
 266                "Attempted to send message to nonexistent session {}",
 267                params.session_id
 268            )));
 269        };
 270
 271        let (end_tx, end_rx) = oneshot::channel();
 272        session.turn_state.replace(TurnState::InProgress { end_tx });
 273
 274        let content = acp_content_to_claude(params.prompt);
 275
 276        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 277            message: Message {
 278                role: Role::User,
 279                content: Content::Chunks(content),
 280                id: None,
 281                model: None,
 282                stop_reason: None,
 283                stop_sequence: None,
 284                usage: None,
 285            },
 286            session_id: Some(params.session_id.to_string()),
 287        }) {
 288            return Task::ready(Err(anyhow!(err)));
 289        }
 290
 291        cx.foreground_executor().spawn(async move { end_rx.await? })
 292    }
 293
 294    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 295        let sessions = self.sessions.borrow();
 296        let Some(session) = sessions.get(session_id) else {
 297            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 298            return;
 299        };
 300
 301        let request_id = new_request_id();
 302
 303        let turn_state = session.turn_state.take();
 304        let TurnState::InProgress { end_tx } = turn_state else {
 305            // Already canceled or idle, put it back
 306            session.turn_state.replace(turn_state);
 307            return;
 308        };
 309
 310        session.turn_state.replace(TurnState::CancelRequested {
 311            end_tx,
 312            request_id: request_id.clone(),
 313        });
 314
 315        session
 316            .outgoing_tx
 317            .unbounded_send(SdkMessage::ControlRequest {
 318                request_id,
 319                request: ControlRequest::Interrupt,
 320            })
 321            .log_err();
 322    }
 323
 324    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 325        self
 326    }
 327}
 328
 329#[derive(Clone, Copy)]
 330enum ClaudeSessionMode {
 331    Start,
 332    #[expect(dead_code)]
 333    Resume,
 334}
 335
 336fn spawn_claude(
 337    command: &AgentServerCommand,
 338    mode: ClaudeSessionMode,
 339    session_id: acp::SessionId,
 340    api_key: language_models::provider::anthropic::ApiKey,
 341    mcp_config_path: &Path,
 342    root_dir: &Path,
 343) -> Result<Child> {
 344    let child = util::command::new_smol_command(&command.path)
 345        .args([
 346            "--input-format",
 347            "stream-json",
 348            "--output-format",
 349            "stream-json",
 350            "--print",
 351            "--verbose",
 352            "--mcp-config",
 353            mcp_config_path.to_string_lossy().as_ref(),
 354            "--permission-prompt-tool",
 355            &format!(
 356                "mcp__{}__{}",
 357                mcp_server::SERVER_NAME,
 358                permission_tool::PermissionTool::NAME,
 359            ),
 360            "--allowedTools",
 361            &format!(
 362                "mcp__{}__{}",
 363                mcp_server::SERVER_NAME,
 364                read_tool::ReadTool::NAME
 365            ),
 366            "--disallowedTools",
 367            "Read,Write,Edit,MultiEdit",
 368        ])
 369        .args(match mode {
 370            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 371            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 372        })
 373        .args(command.args.iter().map(|arg| arg.as_str()))
 374        .envs(command.env.iter().flatten())
 375        .env("ANTHROPIC_API_KEY", api_key.key)
 376        .current_dir(root_dir)
 377        .stdin(std::process::Stdio::piped())
 378        .stdout(std::process::Stdio::piped())
 379        .stderr(std::process::Stdio::piped())
 380        .kill_on_drop(true)
 381        .spawn()?;
 382
 383    Ok(child)
 384}
 385
 386struct ClaudeAgentSession {
 387    outgoing_tx: UnboundedSender<SdkMessage>,
 388    turn_state: Rc<RefCell<TurnState>>,
 389    _mcp_server: Option<ClaudeZedMcpServer>,
 390    _handler_task: Task<()>,
 391}
 392
 393#[derive(Debug, Default)]
 394enum TurnState {
 395    #[default]
 396    None,
 397    InProgress {
 398        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 399    },
 400    CancelRequested {
 401        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 402        request_id: String,
 403    },
 404    CancelConfirmed {
 405        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 406    },
 407}
 408
 409impl TurnState {
 410    fn is_canceled(&self) -> bool {
 411        matches!(self, TurnState::CancelConfirmed { .. })
 412    }
 413
 414    fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
 415        match self {
 416            TurnState::None => None,
 417            TurnState::InProgress { end_tx, .. } => Some(end_tx),
 418            TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
 419            TurnState::CancelConfirmed { end_tx } => Some(end_tx),
 420        }
 421    }
 422
 423    fn confirm_cancellation(self, id: &str) -> Self {
 424        match self {
 425            TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
 426                TurnState::CancelConfirmed { end_tx }
 427            }
 428            _ => self,
 429        }
 430    }
 431}
 432
 433impl ClaudeAgentSession {
 434    async fn handle_message(
 435        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 436        message: SdkMessage,
 437        turn_state: Rc<RefCell<TurnState>>,
 438        cx: &mut AsyncApp,
 439    ) {
 440        match message {
 441            // we should only be sending these out, they don't need to be in the thread
 442            SdkMessage::ControlRequest { .. } => {}
 443            SdkMessage::User {
 444                message,
 445                session_id: _,
 446            } => {
 447                let Some(thread) = thread_rx
 448                    .recv()
 449                    .await
 450                    .log_err()
 451                    .and_then(|entity| entity.upgrade())
 452                else {
 453                    log::error!("Received an SDK message but thread is gone");
 454                    return;
 455                };
 456
 457                for chunk in message.content.chunks() {
 458                    match chunk {
 459                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 460                            if !turn_state.borrow().is_canceled() {
 461                                thread
 462                                    .update(cx, |thread, cx| {
 463                                        thread.push_user_content_block(None, text.into(), cx)
 464                                    })
 465                                    .log_err();
 466                            }
 467                        }
 468                        ContentChunk::ToolResult {
 469                            content,
 470                            tool_use_id,
 471                        } => {
 472                            let content = content.to_string();
 473                            thread
 474                                .update(cx, |thread, cx| {
 475                                    let id = acp::ToolCallId(tool_use_id.into());
 476                                    let set_new_content = !content.is_empty()
 477                                        && thread.tool_call(&id).is_none_or(|(_, tool_call)| {
 478                                            // preserve rich diff if we have one
 479                                            tool_call.diffs().next().is_none()
 480                                        });
 481
 482                                    thread.update_tool_call(
 483                                        acp::ToolCallUpdate {
 484                                            id,
 485                                            fields: acp::ToolCallUpdateFields {
 486                                                status: if turn_state.borrow().is_canceled() {
 487                                                    // Do not set to completed if turn was canceled
 488                                                    None
 489                                                } else {
 490                                                    Some(acp::ToolCallStatus::Completed)
 491                                                },
 492                                                content: set_new_content
 493                                                    .then(|| vec![content.into()]),
 494                                                ..Default::default()
 495                                            },
 496                                        },
 497                                        cx,
 498                                    )
 499                                })
 500                                .log_err();
 501                        }
 502                        ContentChunk::Thinking { .. }
 503                        | ContentChunk::RedactedThinking
 504                        | ContentChunk::ToolUse { .. } => {
 505                            debug_panic!(
 506                                "Should not get {:?} with role: assistant. should we handle this?",
 507                                chunk
 508                            );
 509                        }
 510                        ContentChunk::Image { source } => {
 511                            if !turn_state.borrow().is_canceled() {
 512                                thread
 513                                    .update(cx, |thread, cx| {
 514                                        thread.push_user_content_block(None, source.into(), cx)
 515                                    })
 516                                    .log_err();
 517                            }
 518                        }
 519
 520                        ContentChunk::Document | ContentChunk::WebSearchToolResult => {
 521                            thread
 522                                .update(cx, |thread, cx| {
 523                                    thread.push_assistant_content_block(
 524                                        format!("Unsupported content: {:?}", chunk).into(),
 525                                        false,
 526                                        cx,
 527                                    )
 528                                })
 529                                .log_err();
 530                        }
 531                    }
 532                }
 533            }
 534            SdkMessage::Assistant {
 535                message,
 536                session_id: _,
 537            } => {
 538                let Some(thread) = thread_rx
 539                    .recv()
 540                    .await
 541                    .log_err()
 542                    .and_then(|entity| entity.upgrade())
 543                else {
 544                    log::error!("Received an SDK message but thread is gone");
 545                    return;
 546                };
 547
 548                for chunk in message.content.chunks() {
 549                    match chunk {
 550                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 551                            thread
 552                                .update(cx, |thread, cx| {
 553                                    thread.push_assistant_content_block(text.into(), false, cx)
 554                                })
 555                                .log_err();
 556                        }
 557                        ContentChunk::Thinking { thinking } => {
 558                            thread
 559                                .update(cx, |thread, cx| {
 560                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 561                                })
 562                                .log_err();
 563                        }
 564                        ContentChunk::RedactedThinking => {
 565                            thread
 566                                .update(cx, |thread, cx| {
 567                                    thread.push_assistant_content_block(
 568                                        "[REDACTED]".into(),
 569                                        true,
 570                                        cx,
 571                                    )
 572                                })
 573                                .log_err();
 574                        }
 575                        ContentChunk::ToolUse { id, name, input } => {
 576                            let claude_tool = ClaudeTool::infer(&name, input);
 577
 578                            thread
 579                                .update(cx, |thread, cx| {
 580                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 581                                        thread.update_plan(
 582                                            acp::Plan {
 583                                                entries: params
 584                                                    .todos
 585                                                    .into_iter()
 586                                                    .map(Into::into)
 587                                                    .collect(),
 588                                            },
 589                                            cx,
 590                                        )
 591                                    } else {
 592                                        thread.upsert_tool_call(
 593                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 594                                            cx,
 595                                        )?;
 596                                    }
 597                                    anyhow::Ok(())
 598                                })
 599                                .log_err();
 600                        }
 601                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 602                            debug_panic!(
 603                                "Should not get tool results with role: assistant. should we handle this?"
 604                            );
 605                        }
 606                        ContentChunk::Image { source } => {
 607                            thread
 608                                .update(cx, |thread, cx| {
 609                                    thread.push_assistant_content_block(source.into(), false, cx)
 610                                })
 611                                .log_err();
 612                        }
 613                        ContentChunk::Document => {
 614                            thread
 615                                .update(cx, |thread, cx| {
 616                                    thread.push_assistant_content_block(
 617                                        format!("Unsupported content: {:?}", chunk).into(),
 618                                        false,
 619                                        cx,
 620                                    )
 621                                })
 622                                .log_err();
 623                        }
 624                    }
 625                }
 626            }
 627            SdkMessage::Result {
 628                is_error,
 629                subtype,
 630                result,
 631                ..
 632            } => {
 633                let turn_state = turn_state.take();
 634                let was_canceled = turn_state.is_canceled();
 635                let Some(end_turn_tx) = turn_state.end_tx() else {
 636                    debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
 637                    return;
 638                };
 639
 640                if is_error || (!was_canceled && subtype == ResultErrorType::ErrorDuringExecution) {
 641                    end_turn_tx
 642                        .send(Err(anyhow!(
 643                            "Error: {}",
 644                            result.unwrap_or_else(|| subtype.to_string())
 645                        )))
 646                        .ok();
 647                } else {
 648                    let stop_reason = match subtype {
 649                        ResultErrorType::Success => acp::StopReason::EndTurn,
 650                        ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 651                        ResultErrorType::ErrorDuringExecution => acp::StopReason::Canceled,
 652                    };
 653                    end_turn_tx
 654                        .send(Ok(acp::PromptResponse { stop_reason }))
 655                        .ok();
 656                }
 657            }
 658            SdkMessage::ControlResponse { response } => {
 659                if matches!(response.subtype, ResultErrorType::Success) {
 660                    let new_state = turn_state.take().confirm_cancellation(&response.request_id);
 661                    turn_state.replace(new_state);
 662                } else {
 663                    log::error!("Control response error: {:?}", response);
 664                }
 665            }
 666            SdkMessage::System { .. } => {}
 667        }
 668    }
 669
 670    async fn handle_io(
 671        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 672        incoming_tx: UnboundedSender<SdkMessage>,
 673        mut outgoing_bytes: impl Unpin + AsyncWrite,
 674        incoming_bytes: impl Unpin + AsyncRead,
 675    ) -> Result<UnboundedReceiver<SdkMessage>> {
 676        let mut output_reader = BufReader::new(incoming_bytes);
 677        let mut outgoing_line = Vec::new();
 678        let mut incoming_line = String::new();
 679        loop {
 680            select_biased! {
 681                message = outgoing_rx.next() => {
 682                    if let Some(message) = message {
 683                        outgoing_line.clear();
 684                        serde_json::to_writer(&mut outgoing_line, &message)?;
 685                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 686                        outgoing_line.push(b'\n');
 687                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 688                    } else {
 689                        break;
 690                    }
 691                }
 692                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
 693                    if bytes_read? == 0 {
 694                        break
 695                    }
 696                    log::trace!("recv: {}", &incoming_line);
 697                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 698                        Ok(message) => {
 699                            incoming_tx.unbounded_send(message).log_err();
 700                        }
 701                        Err(error) => {
 702                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 703                        }
 704                    }
 705                    incoming_line.clear();
 706                }
 707            }
 708        }
 709
 710        Ok(outgoing_rx)
 711    }
 712}
 713
 714#[derive(Debug, Clone, Serialize, Deserialize)]
 715struct Message {
 716    role: Role,
 717    content: Content,
 718    #[serde(skip_serializing_if = "Option::is_none")]
 719    id: Option<String>,
 720    #[serde(skip_serializing_if = "Option::is_none")]
 721    model: Option<String>,
 722    #[serde(skip_serializing_if = "Option::is_none")]
 723    stop_reason: Option<String>,
 724    #[serde(skip_serializing_if = "Option::is_none")]
 725    stop_sequence: Option<String>,
 726    #[serde(skip_serializing_if = "Option::is_none")]
 727    usage: Option<Usage>,
 728}
 729
 730#[derive(Debug, Clone, Serialize, Deserialize)]
 731#[serde(untagged)]
 732enum Content {
 733    UntaggedText(String),
 734    Chunks(Vec<ContentChunk>),
 735}
 736
 737impl Content {
 738    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 739        match self {
 740            Self::Chunks(chunks) => chunks.into_iter(),
 741            Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
 742        }
 743    }
 744}
 745
 746impl Display for Content {
 747    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 748        match self {
 749            Content::UntaggedText(txt) => write!(f, "{}", txt),
 750            Content::Chunks(chunks) => {
 751                for chunk in chunks {
 752                    write!(f, "{}", chunk)?;
 753                }
 754                Ok(())
 755            }
 756        }
 757    }
 758}
 759
 760#[derive(Debug, Clone, Serialize, Deserialize)]
 761#[serde(tag = "type", rename_all = "snake_case")]
 762enum ContentChunk {
 763    Text {
 764        text: String,
 765    },
 766    ToolUse {
 767        id: String,
 768        name: String,
 769        input: serde_json::Value,
 770    },
 771    ToolResult {
 772        content: Content,
 773        tool_use_id: String,
 774    },
 775    Thinking {
 776        thinking: String,
 777    },
 778    RedactedThinking,
 779    Image {
 780        source: ImageSource,
 781    },
 782    // TODO
 783    Document,
 784    WebSearchToolResult,
 785    #[serde(untagged)]
 786    UntaggedText(String),
 787}
 788
 789#[derive(Debug, Clone, Serialize, Deserialize)]
 790#[serde(tag = "type", rename_all = "snake_case")]
 791enum ImageSource {
 792    Base64 { data: String, media_type: String },
 793    Url { url: String },
 794}
 795
 796impl Into<acp::ContentBlock> for ImageSource {
 797    fn into(self) -> acp::ContentBlock {
 798        match self {
 799            ImageSource::Base64 { data, media_type } => {
 800                acp::ContentBlock::Image(acp::ImageContent {
 801                    annotations: None,
 802                    data,
 803                    mime_type: media_type,
 804                    uri: None,
 805                })
 806            }
 807            ImageSource::Url { url } => acp::ContentBlock::Image(acp::ImageContent {
 808                annotations: None,
 809                data: "".to_string(),
 810                mime_type: "".to_string(),
 811                uri: Some(url),
 812            }),
 813        }
 814    }
 815}
 816
 817impl Display for ContentChunk {
 818    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 819        match self {
 820            ContentChunk::Text { text } => write!(f, "{}", text),
 821            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 822            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 823            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 824            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 825            ContentChunk::Image { .. }
 826            | ContentChunk::Document
 827            | ContentChunk::ToolUse { .. }
 828            | ContentChunk::WebSearchToolResult => {
 829                write!(f, "\n{:?}\n", &self)
 830            }
 831        }
 832    }
 833}
 834
/// Token usage counters plus the service tier.
///
/// None of the fields is optional or defaulted, so all of them must be present
/// when deserializing.
// NOTE(review): presumably mirrors the `usage` object reported by the
// Anthropic API — confirm against the SDK.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Usage {
    input_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
    output_tokens: u32,
    service_tier: String,
}
 843
/// The author of a message; serialized in snake_case (`system`, `assistant`,
/// `user`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}
 851
/// A message made of a [`Role`] and plain-string content.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    role: Role,
    content: String,
}
 857
/// A message in the SDK's JSON protocol.
///
/// Internally tagged on `type`; variant names are written in snake_case
/// (`assistant`, `user`, `result`, `system`, `control_request`,
/// `control_response`). Field names are not affected by `rename_all` here,
/// which is why the camelCase fields of `System` carry explicit renames.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    /// An assistant message
    Assistant {
        message: Message, // from Anthropic SDK
        // Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// A user message
    User {
        message: Message, // from Anthropic SDK
        // Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// Emitted as the last message in a conversation
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        // Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    /// Emitted as the first message at the start of a conversation
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to the model
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request
    ControlResponse { response: ControlResponse },
}
 905
/// The payload of a control request.
///
/// Internally tagged on `subtype`, with snake_case variant names (`interrupt`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation
    Interrupt,
}
 912
/// The payload of a control response: a request id and an outcome.
// NOTE(review): `request_id` presumably echoes the id of the control request
// being answered — confirm against the SDK.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
    request_id: String,
    subtype: ResultErrorType,
}
 918
/// Outcome subtype used by result messages and control responses.
///
/// Serialized in snake_case: `success`, `error_max_turns`,
/// `error_during_execution`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
 926
 927impl Display for ResultErrorType {
 928    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 929        match self {
 930            ResultErrorType::Success => write!(f, "success"),
 931            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 932            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 933        }
 934    }
 935}
 936
 937fn acp_content_to_claude(prompt: Vec<acp::ContentBlock>) -> Vec<ContentChunk> {
 938    let mut content = Vec::with_capacity(prompt.len());
 939    let mut context = Vec::with_capacity(prompt.len());
 940
 941    for chunk in prompt {
 942        match chunk {
 943            acp::ContentBlock::Text(text_content) => {
 944                content.push(ContentChunk::Text {
 945                    text: text_content.text,
 946                });
 947            }
 948            acp::ContentBlock::ResourceLink(resource_link) => {
 949                match MentionUri::parse(&resource_link.uri) {
 950                    Ok(uri) => {
 951                        content.push(ContentChunk::Text {
 952                            text: format!("{}", uri.as_link()),
 953                        });
 954                    }
 955                    Err(_) => {
 956                        content.push(ContentChunk::Text {
 957                            text: resource_link.uri,
 958                        });
 959                    }
 960                }
 961            }
 962            acp::ContentBlock::Resource(resource) => match resource.resource {
 963                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
 964                    match MentionUri::parse(&resource.uri) {
 965                        Ok(uri) => {
 966                            content.push(ContentChunk::Text {
 967                                text: format!("{}", uri.as_link()),
 968                            });
 969                        }
 970                        Err(_) => {
 971                            content.push(ContentChunk::Text {
 972                                text: resource.uri.clone(),
 973                            });
 974                        }
 975                    }
 976
 977                    context.push(ContentChunk::Text {
 978                        text: format!(
 979                            "\n<context ref=\"{}\">\n{}\n</context>",
 980                            resource.uri, resource.text
 981                        ),
 982                    });
 983                }
 984                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
 985                    // Unsupported by SDK
 986                }
 987            },
 988            acp::ContentBlock::Image(acp::ImageContent {
 989                data, mime_type, ..
 990            }) => content.push(ContentChunk::Image {
 991                source: ImageSource::Base64 {
 992                    data,
 993                    media_type: mime_type,
 994                },
 995            }),
 996            acp::ContentBlock::Audio(_) => {
 997                // Unsupported by SDK
 998            }
 999        }
1000    }
1001
1002    content.extend(context);
1003    content
1004}
1005
1006fn new_request_id() -> String {
1007    use rand::Rng;
1008    // In the Claude Code TS SDK they just generate a random 12 character string,
1009    // `Math.random().toString(36).substring(2, 15)`
1010    rand::thread_rng()
1011        .sample_iter(&rand::distributions::Alphanumeric)
1012        .take(12)
1013        .map(char::from)
1014        .collect()
1015}
1016
/// An MCP server entry as listed in the `System` message: its name and a
/// free-form status string.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    name: String,
    status: String,
}
1022
/// Permission mode reported in the `System` message.
///
/// Serialized in camelCase: `default`, `acceptEdits`, `bypassPermissions`,
/// `plan`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum PermissionMode {
    Default,
    AcceptEdits,
    BypassPermissions,
    Plan,
}
1031
1032#[cfg(test)]
1033pub(crate) mod tests {
1034    use super::*;
1035    use crate::e2e_tests;
1036    use gpui::TestAppContext;
1037    use serde_json::json;
1038
1039    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
1040
1041    pub fn local_command() -> AgentServerCommand {
1042        AgentServerCommand {
1043            path: "claude".into(),
1044            args: vec![],
1045            env: None,
1046        }
1047    }
1048
1049    #[gpui::test]
1050    #[cfg_attr(not(feature = "e2e"), ignore)]
1051    async fn test_todo_plan(cx: &mut TestAppContext) {
1052        let fs = e2e_tests::init_test(cx).await;
1053        let project = Project::test(fs, [], cx).await;
1054        let thread =
1055            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
1056
1057        thread
1058            .update(cx, |thread, cx| {
1059                thread.send_raw(
1060                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
1061                    cx,
1062                )
1063            })
1064            .await
1065            .unwrap();
1066
1067        let mut entries_len = 0;
1068
1069        thread.read_with(cx, |thread, _| {
1070            entries_len = thread.plan().entries.len();
1071            assert!(thread.plan().entries.len() > 0, "Empty plan");
1072        });
1073
1074        thread
1075            .update(cx, |thread, cx| {
1076                thread.send_raw(
1077                    "Mark the first entry status as in progress without acting on it.",
1078                    cx,
1079                )
1080            })
1081            .await
1082            .unwrap();
1083
1084        thread.read_with(cx, |thread, _| {
1085            assert!(matches!(
1086                thread.plan().entries[0].status,
1087                acp::PlanEntryStatus::InProgress
1088            ));
1089            assert_eq!(thread.plan().entries.len(), entries_len);
1090        });
1091
1092        thread
1093            .update(cx, |thread, cx| {
1094                thread.send_raw(
1095                    "Now mark the first entry as completed without acting on it.",
1096                    cx,
1097                )
1098            })
1099            .await
1100            .unwrap();
1101
1102        thread.read_with(cx, |thread, _| {
1103            assert!(matches!(
1104                thread.plan().entries[0].status,
1105                acp::PlanEntryStatus::Completed
1106            ));
1107            assert_eq!(thread.plan().entries.len(), entries_len);
1108        });
1109    }
1110
1111    #[test]
1112    fn test_deserialize_content_untagged_text() {
1113        let json = json!("Hello, world!");
1114        let content: Content = serde_json::from_value(json).unwrap();
1115        match content {
1116            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
1117            _ => panic!("Expected UntaggedText variant"),
1118        }
1119    }
1120
1121    #[test]
1122    fn test_deserialize_content_chunks() {
1123        let json = json!([
1124            {
1125                "type": "text",
1126                "text": "Hello"
1127            },
1128            {
1129                "type": "tool_use",
1130                "id": "tool_123",
1131                "name": "calculator",
1132                "input": {"operation": "add", "a": 1, "b": 2}
1133            }
1134        ]);
1135        let content: Content = serde_json::from_value(json).unwrap();
1136        match content {
1137            Content::Chunks(chunks) => {
1138                assert_eq!(chunks.len(), 2);
1139                match &chunks[0] {
1140                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
1141                    _ => panic!("Expected Text chunk"),
1142                }
1143                match &chunks[1] {
1144                    ContentChunk::ToolUse { id, name, input } => {
1145                        assert_eq!(id, "tool_123");
1146                        assert_eq!(name, "calculator");
1147                        assert_eq!(input["operation"], "add");
1148                        assert_eq!(input["a"], 1);
1149                        assert_eq!(input["b"], 2);
1150                    }
1151                    _ => panic!("Expected ToolUse chunk"),
1152                }
1153            }
1154            _ => panic!("Expected Chunks variant"),
1155        }
1156    }
1157
1158    #[test]
1159    fn test_deserialize_tool_result_untagged_text() {
1160        let json = json!({
1161            "type": "tool_result",
1162            "content": "Result content",
1163            "tool_use_id": "tool_456"
1164        });
1165        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1166        match chunk {
1167            ContentChunk::ToolResult {
1168                content,
1169                tool_use_id,
1170            } => {
1171                match content {
1172                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
1173                    _ => panic!("Expected UntaggedText content"),
1174                }
1175                assert_eq!(tool_use_id, "tool_456");
1176            }
1177            _ => panic!("Expected ToolResult variant"),
1178        }
1179    }
1180
1181    #[test]
1182    fn test_deserialize_tool_result_chunks() {
1183        let json = json!({
1184            "type": "tool_result",
1185            "content": [
1186                {
1187                    "type": "text",
1188                    "text": "Processing complete"
1189                },
1190                {
1191                    "type": "text",
1192                    "text": "Result: 42"
1193                }
1194            ],
1195            "tool_use_id": "tool_789"
1196        });
1197        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1198        match chunk {
1199            ContentChunk::ToolResult {
1200                content,
1201                tool_use_id,
1202            } => {
1203                match content {
1204                    Content::Chunks(chunks) => {
1205                        assert_eq!(chunks.len(), 2);
1206                        match &chunks[0] {
1207                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
1208                            _ => panic!("Expected Text chunk"),
1209                        }
1210                        match &chunks[1] {
1211                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
1212                            _ => panic!("Expected Text chunk"),
1213                        }
1214                    }
1215                    _ => panic!("Expected Chunks content"),
1216                }
1217                assert_eq!(tool_use_id, "tool_789");
1218            }
1219            _ => panic!("Expected ToolResult variant"),
1220        }
1221    }
1222
1223    #[test]
1224    fn test_acp_content_to_claude() {
1225        let acp_content = vec![
1226            acp::ContentBlock::Text(acp::TextContent {
1227                text: "Hello world".to_string(),
1228                annotations: None,
1229            }),
1230            acp::ContentBlock::Image(acp::ImageContent {
1231                data: "base64data".to_string(),
1232                mime_type: "image/png".to_string(),
1233                annotations: None,
1234                uri: None,
1235            }),
1236            acp::ContentBlock::ResourceLink(acp::ResourceLink {
1237                uri: "file:///path/to/example.rs".to_string(),
1238                name: "example.rs".to_string(),
1239                annotations: None,
1240                description: None,
1241                mime_type: None,
1242                size: None,
1243                title: None,
1244            }),
1245            acp::ContentBlock::Resource(acp::EmbeddedResource {
1246                annotations: None,
1247                resource: acp::EmbeddedResourceResource::TextResourceContents(
1248                    acp::TextResourceContents {
1249                        mime_type: None,
1250                        text: "fn main() { println!(\"Hello!\"); }".to_string(),
1251                        uri: "file:///path/to/code.rs".to_string(),
1252                    },
1253                ),
1254            }),
1255            acp::ContentBlock::ResourceLink(acp::ResourceLink {
1256                uri: "invalid_uri_format".to_string(),
1257                name: "invalid.txt".to_string(),
1258                annotations: None,
1259                description: None,
1260                mime_type: None,
1261                size: None,
1262                title: None,
1263            }),
1264        ];
1265
1266        let claude_content = acp_content_to_claude(acp_content);
1267
1268        assert_eq!(claude_content.len(), 6);
1269
1270        match &claude_content[0] {
1271            ContentChunk::Text { text } => assert_eq!(text, "Hello world"),
1272            _ => panic!("Expected Text chunk"),
1273        }
1274
1275        match &claude_content[1] {
1276            ContentChunk::Image { source } => match source {
1277                ImageSource::Base64 { data, media_type } => {
1278                    assert_eq!(data, "base64data");
1279                    assert_eq!(media_type, "image/png");
1280                }
1281                _ => panic!("Expected Base64 image source"),
1282            },
1283            _ => panic!("Expected Image chunk"),
1284        }
1285
1286        match &claude_content[2] {
1287            ContentChunk::Text { text } => {
1288                assert!(text.contains("example.rs"));
1289                assert!(text.contains("file:///path/to/example.rs"));
1290            }
1291            _ => panic!("Expected Text chunk for ResourceLink"),
1292        }
1293
1294        match &claude_content[3] {
1295            ContentChunk::Text { text } => {
1296                assert!(text.contains("code.rs"));
1297                assert!(text.contains("file:///path/to/code.rs"));
1298            }
1299            _ => panic!("Expected Text chunk for Resource"),
1300        }
1301
1302        match &claude_content[4] {
1303            ContentChunk::Text { text } => {
1304                assert_eq!(text, "invalid_uri_format");
1305            }
1306            _ => panic!("Expected Text chunk for invalid URI"),
1307        }
1308
1309        match &claude_content[5] {
1310            ContentChunk::Text { text } => {
1311                assert!(text.contains("<context ref=\"file:///path/to/code.rs\">"));
1312                assert!(text.contains("fn main() { println!(\"Hello!\"); }"));
1313                assert!(text.contains("</context>"));
1314            }
1315            _ => panic!("Expected Text chunk for context"),
1316        }
1317    }
1318}