claude.rs

   1mod mcp_server;
   2pub mod tools;
   3
   4use collections::HashMap;
   5use context_server::listener::McpServerTool;
   6use project::Project;
   7use settings::SettingsStore;
   8use smol::process::Child;
   9use std::any::Any;
  10use std::cell::RefCell;
  11use std::fmt::Display;
  12use std::path::Path;
  13use std::rc::Rc;
  14use uuid::Uuid;
  15
  16use agent_client_protocol as acp;
  17use anyhow::{Context as _, Result, anyhow};
  18use futures::channel::oneshot;
  19use futures::{AsyncBufReadExt, AsyncWriteExt};
  20use futures::{
  21    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  22    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  23    io::BufReader,
  24    select_biased,
  25};
  26use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  27use serde::{Deserialize, Serialize};
  28use util::{ResultExt, debug_panic};
  29
  30use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  31use crate::claude::tools::ClaudeTool;
  32use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  33use acp_thread::{AcpThread, AgentConnection};
  34
  35#[derive(Clone)]
  36pub struct ClaudeCode;
  37
  38impl AgentServer for ClaudeCode {
  39    fn name(&self) -> &'static str {
  40        "Claude Code"
  41    }
  42
  43    fn empty_state_headline(&self) -> &'static str {
  44        self.name()
  45    }
  46
  47    fn empty_state_message(&self) -> &'static str {
  48        "How can I help you today?"
  49    }
  50
  51    fn logo(&self) -> ui::IconName {
  52        ui::IconName::AiClaude
  53    }
  54
  55    fn connect(
  56        &self,
  57        _root_dir: &Path,
  58        _project: &Entity<Project>,
  59        _cx: &mut App,
  60    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  61        let connection = ClaudeAgentConnection {
  62            sessions: Default::default(),
  63        };
  64
  65        Task::ready(Ok(Rc::new(connection) as _))
  66    }
  67}
  68
  69struct ClaudeAgentConnection {
  70    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
  71}
  72
  73impl AgentConnection for ClaudeAgentConnection {
  74    fn new_thread(
  75        self: Rc<Self>,
  76        project: Entity<Project>,
  77        cwd: &Path,
  78        cx: &mut App,
  79    ) -> Task<Result<Entity<AcpThread>>> {
  80        let cwd = cwd.to_owned();
  81        cx.spawn(async move |cx| {
  82            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
  83            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
  84
  85            let mut mcp_servers = HashMap::default();
  86            mcp_servers.insert(
  87                mcp_server::SERVER_NAME.to_string(),
  88                permission_mcp_server.server_config()?,
  89            );
  90            let mcp_config = McpConfig { mcp_servers };
  91
  92            let mcp_config_file = tempfile::NamedTempFile::new()?;
  93            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
  94
  95            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
  96            mcp_config_file
  97                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
  98                .await?;
  99            mcp_config_file.flush().await?;
 100
 101            let settings = cx.read_global(|settings: &SettingsStore, _| {
 102                settings.get::<AllAgentServersSettings>(None).claude.clone()
 103            })?;
 104
 105            let Some(command) = AgentServerCommand::resolve(
 106                "claude",
 107                &[],
 108                Some(&util::paths::home_dir().join(".claude/local/claude")),
 109                settings,
 110                &project,
 111                cx,
 112            )
 113            .await
 114            else {
 115                anyhow::bail!("Failed to find claude binary");
 116            };
 117
 118            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
 119            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 120
 121            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
 122
 123            log::trace!("Starting session with id: {}", session_id);
 124
 125            let mut child = spawn_claude(
 126                &command,
 127                ClaudeSessionMode::Start,
 128                session_id.clone(),
 129                &mcp_config_path,
 130                &cwd,
 131            )?;
 132
 133            let stdout = child.stdout.take().context("Failed to take stdout")?;
 134            let stdin = child.stdin.take().context("Failed to take stdin")?;
 135            let stderr = child.stderr.take().context("Failed to take stderr")?;
 136
 137            let pid = child.id();
 138            log::trace!("Spawned (pid: {})", pid);
 139
 140            cx.background_spawn(async move {
 141                let mut stderr = BufReader::new(stderr);
 142                let mut line = String::new();
 143                while let Ok(n) = stderr.read_line(&mut line).await
 144                    && n > 0
 145                {
 146                    log::warn!("agent stderr: {}", &line);
 147                    line.clear();
 148                }
 149            })
 150            .detach();
 151
 152            cx.background_spawn(async move {
 153                let mut outgoing_rx = Some(outgoing_rx);
 154
 155                ClaudeAgentSession::handle_io(
 156                    outgoing_rx.take().unwrap(),
 157                    incoming_message_tx.clone(),
 158                    stdin,
 159                    stdout,
 160                )
 161                .await?;
 162
 163                log::trace!("Stopped (pid: {})", pid);
 164
 165                drop(mcp_config_path);
 166                anyhow::Ok(())
 167            })
 168            .detach();
 169
 170            let turn_state = Rc::new(RefCell::new(TurnState::None));
 171
 172            let handler_task = cx.spawn({
 173                let turn_state = turn_state.clone();
 174                let mut thread_rx = thread_rx.clone();
 175                async move |cx| {
 176                    while let Some(message) = incoming_message_rx.next().await {
 177                        ClaudeAgentSession::handle_message(
 178                            thread_rx.clone(),
 179                            message,
 180                            turn_state.clone(),
 181                            cx,
 182                        )
 183                        .await
 184                    }
 185
 186                    if let Some(status) = child.status().await.log_err() {
 187                        if let Some(thread) = thread_rx.recv().await.ok() {
 188                            thread
 189                                .update(cx, |thread, cx| {
 190                                    thread.emit_server_exited(status, cx);
 191                                })
 192                                .ok();
 193                        }
 194                    }
 195                }
 196            });
 197
 198            let thread = cx.new(|cx| {
 199                AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
 200            })?;
 201
 202            thread_tx.send(thread.downgrade())?;
 203
 204            let session = ClaudeAgentSession {
 205                outgoing_tx,
 206                turn_state,
 207                _handler_task: handler_task,
 208                _mcp_server: Some(permission_mcp_server),
 209            };
 210
 211            self.sessions.borrow_mut().insert(session_id, session);
 212
 213            Ok(thread)
 214        })
 215    }
 216
 217    fn auth_methods(&self) -> &[acp::AuthMethod] {
 218        &[]
 219    }
 220
 221    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 222        Task::ready(Err(anyhow!("Authentication not supported")))
 223    }
 224
 225    fn prompt(
 226        &self,
 227        _id: Option<acp_thread::UserMessageId>,
 228        params: acp::PromptRequest,
 229        cx: &mut App,
 230    ) -> Task<Result<acp::PromptResponse>> {
 231        let sessions = self.sessions.borrow();
 232        let Some(session) = sessions.get(&params.session_id) else {
 233            return Task::ready(Err(anyhow!(
 234                "Attempted to send message to nonexistent session {}",
 235                params.session_id
 236            )));
 237        };
 238
 239        let (end_tx, end_rx) = oneshot::channel();
 240        session.turn_state.replace(TurnState::InProgress { end_tx });
 241
 242        let mut content = String::new();
 243        for chunk in params.prompt {
 244            match chunk {
 245                acp::ContentBlock::Text(text_content) => {
 246                    content.push_str(&text_content.text);
 247                }
 248                acp::ContentBlock::ResourceLink(resource_link) => {
 249                    content.push_str(&format!("@{}", resource_link.uri));
 250                }
 251                acp::ContentBlock::Audio(_)
 252                | acp::ContentBlock::Image(_)
 253                | acp::ContentBlock::Resource(_) => {
 254                    // TODO
 255                }
 256            }
 257        }
 258
 259        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 260            message: Message {
 261                role: Role::User,
 262                content: Content::UntaggedText(content),
 263                id: None,
 264                model: None,
 265                stop_reason: None,
 266                stop_sequence: None,
 267                usage: None,
 268            },
 269            session_id: Some(params.session_id.to_string()),
 270        }) {
 271            return Task::ready(Err(anyhow!(err)));
 272        }
 273
 274        cx.foreground_executor().spawn(async move { end_rx.await? })
 275    }
 276
 277    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 278        let sessions = self.sessions.borrow();
 279        let Some(session) = sessions.get(&session_id) else {
 280            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 281            return;
 282        };
 283
 284        let request_id = new_request_id();
 285
 286        let turn_state = session.turn_state.take();
 287        let TurnState::InProgress { end_tx } = turn_state else {
 288            // Already cancelled or idle, put it back
 289            session.turn_state.replace(turn_state);
 290            return;
 291        };
 292
 293        session.turn_state.replace(TurnState::CancelRequested {
 294            end_tx,
 295            request_id: request_id.clone(),
 296        });
 297
 298        session
 299            .outgoing_tx
 300            .unbounded_send(SdkMessage::ControlRequest {
 301                request_id,
 302                request: ControlRequest::Interrupt,
 303            })
 304            .log_err();
 305    }
 306
 307    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 308        self
 309    }
 310}
 311
 312#[derive(Clone, Copy)]
 313enum ClaudeSessionMode {
 314    Start,
 315    #[expect(dead_code)]
 316    Resume,
 317}
 318
 319fn spawn_claude(
 320    command: &AgentServerCommand,
 321    mode: ClaudeSessionMode,
 322    session_id: acp::SessionId,
 323    mcp_config_path: &Path,
 324    root_dir: &Path,
 325) -> Result<Child> {
 326    let child = util::command::new_smol_command(&command.path)
 327        .args([
 328            "--input-format",
 329            "stream-json",
 330            "--output-format",
 331            "stream-json",
 332            "--print",
 333            "--verbose",
 334            "--mcp-config",
 335            mcp_config_path.to_string_lossy().as_ref(),
 336            "--permission-prompt-tool",
 337            &format!(
 338                "mcp__{}__{}",
 339                mcp_server::SERVER_NAME,
 340                mcp_server::PermissionTool::NAME,
 341            ),
 342            "--allowedTools",
 343            &format!(
 344                "mcp__{}__{},mcp__{}__{}",
 345                mcp_server::SERVER_NAME,
 346                mcp_server::EditTool::NAME,
 347                mcp_server::SERVER_NAME,
 348                mcp_server::ReadTool::NAME
 349            ),
 350            "--disallowedTools",
 351            "Read,Edit",
 352        ])
 353        .args(match mode {
 354            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 355            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 356        })
 357        .args(command.args.iter().map(|arg| arg.as_str()))
 358        .current_dir(root_dir)
 359        .stdin(std::process::Stdio::piped())
 360        .stdout(std::process::Stdio::piped())
 361        .stderr(std::process::Stdio::piped())
 362        .kill_on_drop(true)
 363        .spawn()?;
 364
 365    Ok(child)
 366}
 367
 368struct ClaudeAgentSession {
 369    outgoing_tx: UnboundedSender<SdkMessage>,
 370    turn_state: Rc<RefCell<TurnState>>,
 371    _mcp_server: Option<ClaudeZedMcpServer>,
 372    _handler_task: Task<()>,
 373}
 374
 375#[derive(Debug, Default)]
 376enum TurnState {
 377    #[default]
 378    None,
 379    InProgress {
 380        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 381    },
 382    CancelRequested {
 383        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 384        request_id: String,
 385    },
 386    CancelConfirmed {
 387        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 388    },
 389}
 390
 391impl TurnState {
 392    fn is_cancelled(&self) -> bool {
 393        matches!(self, TurnState::CancelConfirmed { .. })
 394    }
 395
 396    fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
 397        match self {
 398            TurnState::None => None,
 399            TurnState::InProgress { end_tx, .. } => Some(end_tx),
 400            TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
 401            TurnState::CancelConfirmed { end_tx } => Some(end_tx),
 402        }
 403    }
 404
 405    fn confirm_cancellation(self, id: &str) -> Self {
 406        match self {
 407            TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
 408                TurnState::CancelConfirmed { end_tx }
 409            }
 410            _ => self,
 411        }
 412    }
 413}
 414
 415impl ClaudeAgentSession {
 416    async fn handle_message(
 417        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 418        message: SdkMessage,
 419        turn_state: Rc<RefCell<TurnState>>,
 420        cx: &mut AsyncApp,
 421    ) {
 422        match message {
 423            // we should only be sending these out, they don't need to be in the thread
 424            SdkMessage::ControlRequest { .. } => {}
 425            SdkMessage::User {
 426                message,
 427                session_id: _,
 428            } => {
 429                let Some(thread) = thread_rx
 430                    .recv()
 431                    .await
 432                    .log_err()
 433                    .and_then(|entity| entity.upgrade())
 434                else {
 435                    log::error!("Received an SDK message but thread is gone");
 436                    return;
 437                };
 438
 439                for chunk in message.content.chunks() {
 440                    match chunk {
 441                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 442                            if !turn_state.borrow().is_cancelled() {
 443                                thread
 444                                    .update(cx, |thread, cx| {
 445                                        thread.push_user_content_block(None, text.into(), cx)
 446                                    })
 447                                    .log_err();
 448                            }
 449                        }
 450                        ContentChunk::ToolResult {
 451                            content,
 452                            tool_use_id,
 453                        } => {
 454                            let content = content.to_string();
 455                            thread
 456                                .update(cx, |thread, cx| {
 457                                    thread.update_tool_call(
 458                                        acp::ToolCallUpdate {
 459                                            id: acp::ToolCallId(tool_use_id.into()),
 460                                            fields: acp::ToolCallUpdateFields {
 461                                                status: if turn_state.borrow().is_cancelled() {
 462                                                    // Do not set to completed if turn was cancelled
 463                                                    None
 464                                                } else {
 465                                                    Some(acp::ToolCallStatus::Completed)
 466                                                },
 467                                                content: (!content.is_empty())
 468                                                    .then(|| vec![content.into()]),
 469                                                ..Default::default()
 470                                            },
 471                                        },
 472                                        cx,
 473                                    )
 474                                })
 475                                .log_err();
 476                        }
 477                        ContentChunk::Thinking { .. }
 478                        | ContentChunk::RedactedThinking
 479                        | ContentChunk::ToolUse { .. } => {
 480                            debug_panic!(
 481                                "Should not get {:?} with role: assistant. should we handle this?",
 482                                chunk
 483                            );
 484                        }
 485
 486                        ContentChunk::Image
 487                        | ContentChunk::Document
 488                        | ContentChunk::WebSearchToolResult => {
 489                            thread
 490                                .update(cx, |thread, cx| {
 491                                    thread.push_assistant_content_block(
 492                                        format!("Unsupported content: {:?}", chunk).into(),
 493                                        false,
 494                                        cx,
 495                                    )
 496                                })
 497                                .log_err();
 498                        }
 499                    }
 500                }
 501            }
 502            SdkMessage::Assistant {
 503                message,
 504                session_id: _,
 505            } => {
 506                let Some(thread) = thread_rx
 507                    .recv()
 508                    .await
 509                    .log_err()
 510                    .and_then(|entity| entity.upgrade())
 511                else {
 512                    log::error!("Received an SDK message but thread is gone");
 513                    return;
 514                };
 515
 516                for chunk in message.content.chunks() {
 517                    match chunk {
 518                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 519                            thread
 520                                .update(cx, |thread, cx| {
 521                                    thread.push_assistant_content_block(text.into(), false, cx)
 522                                })
 523                                .log_err();
 524                        }
 525                        ContentChunk::Thinking { thinking } => {
 526                            thread
 527                                .update(cx, |thread, cx| {
 528                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 529                                })
 530                                .log_err();
 531                        }
 532                        ContentChunk::RedactedThinking => {
 533                            thread
 534                                .update(cx, |thread, cx| {
 535                                    thread.push_assistant_content_block(
 536                                        "[REDACTED]".into(),
 537                                        true,
 538                                        cx,
 539                                    )
 540                                })
 541                                .log_err();
 542                        }
 543                        ContentChunk::ToolUse { id, name, input } => {
 544                            let claude_tool = ClaudeTool::infer(&name, input);
 545
 546                            thread
 547                                .update(cx, |thread, cx| {
 548                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 549                                        thread.update_plan(
 550                                            acp::Plan {
 551                                                entries: params
 552                                                    .todos
 553                                                    .into_iter()
 554                                                    .map(Into::into)
 555                                                    .collect(),
 556                                            },
 557                                            cx,
 558                                        )
 559                                    } else {
 560                                        thread.upsert_tool_call(
 561                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 562                                            cx,
 563                                        )?;
 564                                    }
 565                                    anyhow::Ok(())
 566                                })
 567                                .log_err();
 568                        }
 569                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 570                            debug_panic!(
 571                                "Should not get tool results with role: assistant. should we handle this?"
 572                            );
 573                        }
 574                        ContentChunk::Image | ContentChunk::Document => {
 575                            thread
 576                                .update(cx, |thread, cx| {
 577                                    thread.push_assistant_content_block(
 578                                        format!("Unsupported content: {:?}", chunk).into(),
 579                                        false,
 580                                        cx,
 581                                    )
 582                                })
 583                                .log_err();
 584                        }
 585                    }
 586                }
 587            }
 588            SdkMessage::Result {
 589                is_error,
 590                subtype,
 591                result,
 592                ..
 593            } => {
 594                let turn_state = turn_state.take();
 595                let was_cancelled = turn_state.is_cancelled();
 596                let Some(end_turn_tx) = turn_state.end_tx() else {
 597                    debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
 598                    return;
 599                };
 600
 601                if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution)
 602                {
 603                    end_turn_tx
 604                        .send(Err(anyhow!(
 605                            "Error: {}",
 606                            result.unwrap_or_else(|| subtype.to_string())
 607                        )))
 608                        .ok();
 609                } else {
 610                    let stop_reason = match subtype {
 611                        ResultErrorType::Success => acp::StopReason::EndTurn,
 612                        ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 613                        ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
 614                    };
 615                    end_turn_tx
 616                        .send(Ok(acp::PromptResponse { stop_reason }))
 617                        .ok();
 618                }
 619            }
 620            SdkMessage::ControlResponse { response } => {
 621                if matches!(response.subtype, ResultErrorType::Success) {
 622                    let new_state = turn_state.take().confirm_cancellation(&response.request_id);
 623                    turn_state.replace(new_state);
 624                } else {
 625                    log::error!("Control response error: {:?}", response);
 626                }
 627            }
 628            SdkMessage::System { .. } => {}
 629        }
 630    }
 631
 632    async fn handle_io(
 633        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 634        incoming_tx: UnboundedSender<SdkMessage>,
 635        mut outgoing_bytes: impl Unpin + AsyncWrite,
 636        incoming_bytes: impl Unpin + AsyncRead,
 637    ) -> Result<UnboundedReceiver<SdkMessage>> {
 638        let mut output_reader = BufReader::new(incoming_bytes);
 639        let mut outgoing_line = Vec::new();
 640        let mut incoming_line = String::new();
 641        loop {
 642            select_biased! {
 643                message = outgoing_rx.next() => {
 644                    if let Some(message) = message {
 645                        outgoing_line.clear();
 646                        serde_json::to_writer(&mut outgoing_line, &message)?;
 647                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 648                        outgoing_line.push(b'\n');
 649                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 650                    } else {
 651                        break;
 652                    }
 653                }
 654                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
 655                    if bytes_read? == 0 {
 656                        break
 657                    }
 658                    log::trace!("recv: {}", &incoming_line);
 659                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 660                        Ok(message) => {
 661                            incoming_tx.unbounded_send(message).log_err();
 662                        }
 663                        Err(error) => {
 664                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 665                        }
 666                    }
 667                    incoming_line.clear();
 668                }
 669            }
 670        }
 671
 672        Ok(outgoing_rx)
 673    }
 674}
 675
 676#[derive(Debug, Clone, Serialize, Deserialize)]
 677struct Message {
 678    role: Role,
 679    content: Content,
 680    #[serde(skip_serializing_if = "Option::is_none")]
 681    id: Option<String>,
 682    #[serde(skip_serializing_if = "Option::is_none")]
 683    model: Option<String>,
 684    #[serde(skip_serializing_if = "Option::is_none")]
 685    stop_reason: Option<String>,
 686    #[serde(skip_serializing_if = "Option::is_none")]
 687    stop_sequence: Option<String>,
 688    #[serde(skip_serializing_if = "Option::is_none")]
 689    usage: Option<Usage>,
 690}
 691
 692#[derive(Debug, Clone, Serialize, Deserialize)]
 693#[serde(untagged)]
 694enum Content {
 695    UntaggedText(String),
 696    Chunks(Vec<ContentChunk>),
 697}
 698
 699impl Content {
 700    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 701        match self {
 702            Self::Chunks(chunks) => chunks.into_iter(),
 703            Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
 704        }
 705    }
 706}
 707
 708impl Display for Content {
 709    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 710        match self {
 711            Content::UntaggedText(txt) => write!(f, "{}", txt),
 712            Content::Chunks(chunks) => {
 713                for chunk in chunks {
 714                    write!(f, "{}", chunk)?;
 715                }
 716                Ok(())
 717            }
 718        }
 719    }
 720}
 721
 722#[derive(Debug, Clone, Serialize, Deserialize)]
 723#[serde(tag = "type", rename_all = "snake_case")]
 724enum ContentChunk {
 725    Text {
 726        text: String,
 727    },
 728    ToolUse {
 729        id: String,
 730        name: String,
 731        input: serde_json::Value,
 732    },
 733    ToolResult {
 734        content: Content,
 735        tool_use_id: String,
 736    },
 737    Thinking {
 738        thinking: String,
 739    },
 740    RedactedThinking,
 741    // TODO
 742    Image,
 743    Document,
 744    WebSearchToolResult,
 745    #[serde(untagged)]
 746    UntaggedText(String),
 747}
 748
 749impl Display for ContentChunk {
 750    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 751        match self {
 752            ContentChunk::Text { text } => write!(f, "{}", text),
 753            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 754            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 755            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 756            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 757            ContentChunk::Image
 758            | ContentChunk::Document
 759            | ContentChunk::ToolUse { .. }
 760            | ContentChunk::WebSearchToolResult => {
 761                write!(f, "\n{:?}\n", &self)
 762            }
 763        }
 764    }
 765}
 766
 767#[derive(Debug, Clone, Serialize, Deserialize)]
 768struct Usage {
 769    input_tokens: u32,
 770    cache_creation_input_tokens: u32,
 771    cache_read_input_tokens: u32,
 772    output_tokens: u32,
 773    service_tier: String,
 774}
 775
 776#[derive(Debug, Clone, Serialize, Deserialize)]
 777#[serde(rename_all = "snake_case")]
 778enum Role {
 779    System,
 780    Assistant,
 781    User,
 782}
 783
 784#[derive(Debug, Clone, Serialize, Deserialize)]
 785struct MessageParam {
 786    role: Role,
 787    content: String,
 788}
 789
 790#[derive(Debug, Clone, Serialize, Deserialize)]
 791#[serde(tag = "type", rename_all = "snake_case")]
 792enum SdkMessage {
 793    // An assistant message
 794    Assistant {
 795        message: Message, // from Anthropic SDK
 796        #[serde(skip_serializing_if = "Option::is_none")]
 797        session_id: Option<String>,
 798    },
 799    // A user message
 800    User {
 801        message: Message, // from Anthropic SDK
 802        #[serde(skip_serializing_if = "Option::is_none")]
 803        session_id: Option<String>,
 804    },
 805    // Emitted as the last message in a conversation
 806    Result {
 807        subtype: ResultErrorType,
 808        duration_ms: f64,
 809        duration_api_ms: f64,
 810        is_error: bool,
 811        num_turns: i32,
 812        #[serde(skip_serializing_if = "Option::is_none")]
 813        result: Option<String>,
 814        session_id: String,
 815        total_cost_usd: f64,
 816    },
 817    // Emitted as the first message at the start of a conversation
 818    System {
 819        cwd: String,
 820        session_id: String,
 821        tools: Vec<String>,
 822        model: String,
 823        mcp_servers: Vec<McpServer>,
 824        #[serde(rename = "apiKeySource")]
 825        api_key_source: String,
 826        #[serde(rename = "permissionMode")]
 827        permission_mode: PermissionMode,
 828    },
 829    /// Messages used to control the conversation, outside of chat messages to the model
 830    ControlRequest {
 831        request_id: String,
 832        request: ControlRequest,
 833    },
 834    /// Response to a control request
 835    ControlResponse { response: ControlResponse },
 836}
 837
 838#[derive(Debug, Clone, Serialize, Deserialize)]
 839#[serde(tag = "subtype", rename_all = "snake_case")]
 840enum ControlRequest {
 841    /// Cancel the current conversation
 842    Interrupt,
 843}
 844
 845#[derive(Debug, Clone, Serialize, Deserialize)]
 846struct ControlResponse {
 847    request_id: String,
 848    subtype: ResultErrorType,
 849}
 850
 851#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
 852#[serde(rename_all = "snake_case")]
 853enum ResultErrorType {
 854    Success,
 855    ErrorMaxTurns,
 856    ErrorDuringExecution,
 857}
 858
 859impl Display for ResultErrorType {
 860    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 861        match self {
 862            ResultErrorType::Success => write!(f, "success"),
 863            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 864            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 865        }
 866    }
 867}
 868
 869fn new_request_id() -> String {
 870    use rand::Rng;
 871    // In the Claude Code TS SDK they just generate a random 12 character string,
 872    // `Math.random().toString(36).substring(2, 15)`
 873    rand::thread_rng()
 874        .sample_iter(&rand::distributions::Alphanumeric)
 875        .take(12)
 876        .map(char::from)
 877        .collect()
 878}
 879
 880#[derive(Debug, Clone, Serialize, Deserialize)]
 881struct McpServer {
 882    name: String,
 883    status: String,
 884}
 885
 886#[derive(Debug, Clone, Serialize, Deserialize)]
 887#[serde(rename_all = "camelCase")]
 888enum PermissionMode {
 889    Default,
 890    AcceptEdits,
 891    BypassPermissions,
 892    Plan,
 893}
 894
 895#[cfg(test)]
 896pub(crate) mod tests {
 897    use super::*;
 898    use crate::e2e_tests;
 899    use gpui::TestAppContext;
 900    use serde_json::json;
 901
 902    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
 903
 904    pub fn local_command() -> AgentServerCommand {
 905        AgentServerCommand {
 906            path: "claude".into(),
 907            args: vec![],
 908            env: None,
 909        }
 910    }
 911
 912    #[gpui::test]
 913    #[cfg_attr(not(feature = "e2e"), ignore)]
 914    async fn test_todo_plan(cx: &mut TestAppContext) {
 915        let fs = e2e_tests::init_test(cx).await;
 916        let project = Project::test(fs, [], cx).await;
 917        let thread =
 918            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
 919
 920        thread
 921            .update(cx, |thread, cx| {
 922                thread.send_raw(
 923                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
 924                    cx,
 925                )
 926            })
 927            .await
 928            .unwrap();
 929
 930        let mut entries_len = 0;
 931
 932        thread.read_with(cx, |thread, _| {
 933            entries_len = thread.plan().entries.len();
 934            assert!(thread.plan().entries.len() > 0, "Empty plan");
 935        });
 936
 937        thread
 938            .update(cx, |thread, cx| {
 939                thread.send_raw(
 940                    "Mark the first entry status as in progress without acting on it.",
 941                    cx,
 942                )
 943            })
 944            .await
 945            .unwrap();
 946
 947        thread.read_with(cx, |thread, _| {
 948            assert!(matches!(
 949                thread.plan().entries[0].status,
 950                acp::PlanEntryStatus::InProgress
 951            ));
 952            assert_eq!(thread.plan().entries.len(), entries_len);
 953        });
 954
 955        thread
 956            .update(cx, |thread, cx| {
 957                thread.send_raw(
 958                    "Now mark the first entry as completed without acting on it.",
 959                    cx,
 960                )
 961            })
 962            .await
 963            .unwrap();
 964
 965        thread.read_with(cx, |thread, _| {
 966            assert!(matches!(
 967                thread.plan().entries[0].status,
 968                acp::PlanEntryStatus::Completed
 969            ));
 970            assert_eq!(thread.plan().entries.len(), entries_len);
 971        });
 972    }
 973
 974    #[test]
 975    fn test_deserialize_content_untagged_text() {
 976        let json = json!("Hello, world!");
 977        let content: Content = serde_json::from_value(json).unwrap();
 978        match content {
 979            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
 980            _ => panic!("Expected UntaggedText variant"),
 981        }
 982    }
 983
 984    #[test]
 985    fn test_deserialize_content_chunks() {
 986        let json = json!([
 987            {
 988                "type": "text",
 989                "text": "Hello"
 990            },
 991            {
 992                "type": "tool_use",
 993                "id": "tool_123",
 994                "name": "calculator",
 995                "input": {"operation": "add", "a": 1, "b": 2}
 996            }
 997        ]);
 998        let content: Content = serde_json::from_value(json).unwrap();
 999        match content {
1000            Content::Chunks(chunks) => {
1001                assert_eq!(chunks.len(), 2);
1002                match &chunks[0] {
1003                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
1004                    _ => panic!("Expected Text chunk"),
1005                }
1006                match &chunks[1] {
1007                    ContentChunk::ToolUse { id, name, input } => {
1008                        assert_eq!(id, "tool_123");
1009                        assert_eq!(name, "calculator");
1010                        assert_eq!(input["operation"], "add");
1011                        assert_eq!(input["a"], 1);
1012                        assert_eq!(input["b"], 2);
1013                    }
1014                    _ => panic!("Expected ToolUse chunk"),
1015                }
1016            }
1017            _ => panic!("Expected Chunks variant"),
1018        }
1019    }
1020
1021    #[test]
1022    fn test_deserialize_tool_result_untagged_text() {
1023        let json = json!({
1024            "type": "tool_result",
1025            "content": "Result content",
1026            "tool_use_id": "tool_456"
1027        });
1028        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1029        match chunk {
1030            ContentChunk::ToolResult {
1031                content,
1032                tool_use_id,
1033            } => {
1034                match content {
1035                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
1036                    _ => panic!("Expected UntaggedText content"),
1037                }
1038                assert_eq!(tool_use_id, "tool_456");
1039            }
1040            _ => panic!("Expected ToolResult variant"),
1041        }
1042    }
1043
1044    #[test]
1045    fn test_deserialize_tool_result_chunks() {
1046        let json = json!({
1047            "type": "tool_result",
1048            "content": [
1049                {
1050                    "type": "text",
1051                    "text": "Processing complete"
1052                },
1053                {
1054                    "type": "text",
1055                    "text": "Result: 42"
1056                }
1057            ],
1058            "tool_use_id": "tool_789"
1059        });
1060        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1061        match chunk {
1062            ContentChunk::ToolResult {
1063                content,
1064                tool_use_id,
1065            } => {
1066                match content {
1067                    Content::Chunks(chunks) => {
1068                        assert_eq!(chunks.len(), 2);
1069                        match &chunks[0] {
1070                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
1071                            _ => panic!("Expected Text chunk"),
1072                        }
1073                        match &chunks[1] {
1074                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
1075                            _ => panic!("Expected Text chunk"),
1076                        }
1077                    }
1078                    _ => panic!("Expected Chunks content"),
1079                }
1080                assert_eq!(tool_use_id, "tool_789");
1081            }
1082            _ => panic!("Expected ToolResult variant"),
1083        }
1084    }
1085}