claude.rs

   1mod mcp_server;
   2pub mod tools;
   3
   4use collections::HashMap;
   5use context_server::listener::McpServerTool;
   6use project::Project;
   7use settings::SettingsStore;
   8use smol::process::Child;
   9use std::cell::RefCell;
  10use std::fmt::Display;
  11use std::path::Path;
  12use std::rc::Rc;
  13use uuid::Uuid;
  14
  15use agent_client_protocol as acp;
  16use anyhow::{Result, anyhow};
  17use futures::channel::oneshot;
  18use futures::{AsyncBufReadExt, AsyncWriteExt};
  19use futures::{
  20    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  21    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  22    io::BufReader,
  23    select_biased,
  24};
  25use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  26use serde::{Deserialize, Serialize};
  27use util::{ResultExt, debug_panic};
  28
  29use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  30use crate::claude::tools::ClaudeTool;
  31use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  32use acp_thread::{AcpThread, AgentConnection};
  33
  34#[derive(Clone)]
  35pub struct ClaudeCode;
  36
  37impl AgentServer for ClaudeCode {
  38    fn name(&self) -> &'static str {
  39        "Claude Code"
  40    }
  41
  42    fn empty_state_headline(&self) -> &'static str {
  43        self.name()
  44    }
  45
  46    fn empty_state_message(&self) -> &'static str {
  47        "How can I help you today?"
  48    }
  49
  50    fn logo(&self) -> ui::IconName {
  51        ui::IconName::AiClaude
  52    }
  53
  54    fn connect(
  55        &self,
  56        _root_dir: &Path,
  57        _project: &Entity<Project>,
  58        _cx: &mut App,
  59    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  60        let connection = ClaudeAgentConnection {
  61            sessions: Default::default(),
  62        };
  63
  64        Task::ready(Ok(Rc::new(connection) as _))
  65    }
  66}
  67
  68struct ClaudeAgentConnection {
  69    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
  70}
  71
  72impl AgentConnection for ClaudeAgentConnection {
  73    fn new_thread(
  74        self: Rc<Self>,
  75        project: Entity<Project>,
  76        cwd: &Path,
  77        cx: &mut AsyncApp,
  78    ) -> Task<Result<Entity<AcpThread>>> {
  79        let cwd = cwd.to_owned();
  80        cx.spawn(async move |cx| {
  81            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
  82            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
  83
  84            let mut mcp_servers = HashMap::default();
  85            mcp_servers.insert(
  86                mcp_server::SERVER_NAME.to_string(),
  87                permission_mcp_server.server_config()?,
  88            );
  89            let mcp_config = McpConfig { mcp_servers };
  90
  91            let mcp_config_file = tempfile::NamedTempFile::new()?;
  92            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
  93
  94            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
  95            mcp_config_file
  96                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
  97                .await?;
  98            mcp_config_file.flush().await?;
  99
 100            let settings = cx.read_global(|settings: &SettingsStore, _| {
 101                settings.get::<AllAgentServersSettings>(None).claude.clone()
 102            })?;
 103
 104            let Some(command) = AgentServerCommand::resolve(
 105                "claude",
 106                &[],
 107                Some(&util::paths::home_dir().join(".claude/local/claude")),
 108                settings,
 109                &project,
 110                cx,
 111            )
 112            .await
 113            else {
 114                anyhow::bail!("Failed to find claude binary");
 115            };
 116
 117            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
 118            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 119
 120            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
 121
 122            log::trace!("Starting session with id: {}", session_id);
 123
 124            let mut child = spawn_claude(
 125                &command,
 126                ClaudeSessionMode::Start,
 127                session_id.clone(),
 128                &mcp_config_path,
 129                &cwd,
 130            )?;
 131
 132            let stdin = child.stdin.take().unwrap();
 133            let stdout = child.stdout.take().unwrap();
 134
 135            let pid = child.id();
 136            log::trace!("Spawned (pid: {})", pid);
 137
 138            cx.background_spawn(async move {
 139                let mut outgoing_rx = Some(outgoing_rx);
 140
 141                ClaudeAgentSession::handle_io(
 142                    outgoing_rx.take().unwrap(),
 143                    incoming_message_tx.clone(),
 144                    stdin,
 145                    stdout,
 146                )
 147                .await?;
 148
 149                log::trace!("Stopped (pid: {})", pid);
 150
 151                drop(mcp_config_path);
 152                anyhow::Ok(())
 153            })
 154            .detach();
 155
 156            let turn_state = Rc::new(RefCell::new(TurnState::None));
 157
 158            let handler_task = cx.spawn({
 159                let turn_state = turn_state.clone();
 160                let mut thread_rx = thread_rx.clone();
 161                async move |cx| {
 162                    while let Some(message) = incoming_message_rx.next().await {
 163                        ClaudeAgentSession::handle_message(
 164                            thread_rx.clone(),
 165                            message,
 166                            turn_state.clone(),
 167                            cx,
 168                        )
 169                        .await
 170                    }
 171
 172                    if let Some(status) = child.status().await.log_err() {
 173                        if let Some(thread) = thread_rx.recv().await.ok() {
 174                            thread
 175                                .update(cx, |thread, cx| {
 176                                    thread.emit_server_exited(status, cx);
 177                                })
 178                                .ok();
 179                        }
 180                    }
 181                }
 182            });
 183
 184            let thread = cx.new(|cx| {
 185                AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
 186            })?;
 187
 188            thread_tx.send(thread.downgrade())?;
 189
 190            let session = ClaudeAgentSession {
 191                outgoing_tx,
 192                turn_state,
 193                _handler_task: handler_task,
 194                _mcp_server: Some(permission_mcp_server),
 195            };
 196
 197            self.sessions.borrow_mut().insert(session_id, session);
 198
 199            Ok(thread)
 200        })
 201    }
 202
 203    fn auth_methods(&self) -> &[acp::AuthMethod] {
 204        &[]
 205    }
 206
 207    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 208        Task::ready(Err(anyhow!("Authentication not supported")))
 209    }
 210
 211    fn prompt(
 212        &self,
 213        params: acp::PromptRequest,
 214        cx: &mut App,
 215    ) -> Task<Result<acp::PromptResponse>> {
 216        let sessions = self.sessions.borrow();
 217        let Some(session) = sessions.get(&params.session_id) else {
 218            return Task::ready(Err(anyhow!(
 219                "Attempted to send message to nonexistent session {}",
 220                params.session_id
 221            )));
 222        };
 223
 224        let (end_tx, end_rx) = oneshot::channel();
 225        session.turn_state.replace(TurnState::InProgress { end_tx });
 226
 227        let mut content = String::new();
 228        for chunk in params.prompt {
 229            match chunk {
 230                acp::ContentBlock::Text(text_content) => {
 231                    content.push_str(&text_content.text);
 232                }
 233                acp::ContentBlock::ResourceLink(resource_link) => {
 234                    content.push_str(&format!("@{}", resource_link.uri));
 235                }
 236                acp::ContentBlock::Audio(_)
 237                | acp::ContentBlock::Image(_)
 238                | acp::ContentBlock::Resource(_) => {
 239                    // TODO
 240                }
 241            }
 242        }
 243
 244        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 245            message: Message {
 246                role: Role::User,
 247                content: Content::UntaggedText(content),
 248                id: None,
 249                model: None,
 250                stop_reason: None,
 251                stop_sequence: None,
 252                usage: None,
 253            },
 254            session_id: Some(params.session_id.to_string()),
 255        }) {
 256            return Task::ready(Err(anyhow!(err)));
 257        }
 258
 259        cx.foreground_executor().spawn(async move { end_rx.await? })
 260    }
 261
 262    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 263        let sessions = self.sessions.borrow();
 264        let Some(session) = sessions.get(&session_id) else {
 265            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 266            return;
 267        };
 268
 269        let request_id = new_request_id();
 270
 271        let turn_state = session.turn_state.take();
 272        let TurnState::InProgress { end_tx } = turn_state else {
 273            // Already cancelled or idle, put it back
 274            session.turn_state.replace(turn_state);
 275            return;
 276        };
 277
 278        session.turn_state.replace(TurnState::CancelRequested {
 279            end_tx,
 280            request_id: request_id.clone(),
 281        });
 282
 283        session
 284            .outgoing_tx
 285            .unbounded_send(SdkMessage::ControlRequest {
 286                request_id,
 287                request: ControlRequest::Interrupt,
 288            })
 289            .log_err();
 290    }
 291}
 292
 293#[derive(Clone, Copy)]
 294enum ClaudeSessionMode {
 295    Start,
 296    #[expect(dead_code)]
 297    Resume,
 298}
 299
 300fn spawn_claude(
 301    command: &AgentServerCommand,
 302    mode: ClaudeSessionMode,
 303    session_id: acp::SessionId,
 304    mcp_config_path: &Path,
 305    root_dir: &Path,
 306) -> Result<Child> {
 307    let child = util::command::new_smol_command(&command.path)
 308        .args([
 309            "--input-format",
 310            "stream-json",
 311            "--output-format",
 312            "stream-json",
 313            "--print",
 314            "--verbose",
 315            "--mcp-config",
 316            mcp_config_path.to_string_lossy().as_ref(),
 317            "--permission-prompt-tool",
 318            &format!(
 319                "mcp__{}__{}",
 320                mcp_server::SERVER_NAME,
 321                mcp_server::PermissionTool::NAME,
 322            ),
 323            "--allowedTools",
 324            &format!(
 325                "mcp__{}__{},mcp__{}__{}",
 326                mcp_server::SERVER_NAME,
 327                mcp_server::EditTool::NAME,
 328                mcp_server::SERVER_NAME,
 329                mcp_server::ReadTool::NAME
 330            ),
 331            "--disallowedTools",
 332            "Read,Edit",
 333        ])
 334        .args(match mode {
 335            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 336            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 337        })
 338        .args(command.args.iter().map(|arg| arg.as_str()))
 339        .current_dir(root_dir)
 340        .stdin(std::process::Stdio::piped())
 341        .stdout(std::process::Stdio::piped())
 342        .stderr(std::process::Stdio::inherit())
 343        .kill_on_drop(true)
 344        .spawn()?;
 345
 346    Ok(child)
 347}
 348
 349struct ClaudeAgentSession {
 350    outgoing_tx: UnboundedSender<SdkMessage>,
 351    turn_state: Rc<RefCell<TurnState>>,
 352    _mcp_server: Option<ClaudeZedMcpServer>,
 353    _handler_task: Task<()>,
 354}
 355
 356#[derive(Debug, Default)]
 357enum TurnState {
 358    #[default]
 359    None,
 360    InProgress {
 361        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 362    },
 363    CancelRequested {
 364        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 365        request_id: String,
 366    },
 367    CancelConfirmed {
 368        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 369    },
 370}
 371
 372impl TurnState {
 373    fn is_cancelled(&self) -> bool {
 374        matches!(self, TurnState::CancelConfirmed { .. })
 375    }
 376
 377    fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
 378        match self {
 379            TurnState::None => None,
 380            TurnState::InProgress { end_tx, .. } => Some(end_tx),
 381            TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
 382            TurnState::CancelConfirmed { end_tx } => Some(end_tx),
 383        }
 384    }
 385
 386    fn confirm_cancellation(self, id: &str) -> Self {
 387        match self {
 388            TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
 389                TurnState::CancelConfirmed { end_tx }
 390            }
 391            _ => self,
 392        }
 393    }
 394}
 395
 396impl ClaudeAgentSession {
 397    async fn handle_message(
 398        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 399        message: SdkMessage,
 400        turn_state: Rc<RefCell<TurnState>>,
 401        cx: &mut AsyncApp,
 402    ) {
 403        match message {
 404            // we should only be sending these out, they don't need to be in the thread
 405            SdkMessage::ControlRequest { .. } => {}
 406            SdkMessage::User {
 407                message,
 408                session_id: _,
 409            } => {
 410                let Some(thread) = thread_rx
 411                    .recv()
 412                    .await
 413                    .log_err()
 414                    .and_then(|entity| entity.upgrade())
 415                else {
 416                    log::error!("Received an SDK message but thread is gone");
 417                    return;
 418                };
 419
 420                for chunk in message.content.chunks() {
 421                    match chunk {
 422                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 423                            if !turn_state.borrow().is_cancelled() {
 424                                thread
 425                                    .update(cx, |thread, cx| {
 426                                        thread.push_user_content_block(text.into(), cx)
 427                                    })
 428                                    .log_err();
 429                            }
 430                        }
 431                        ContentChunk::ToolResult {
 432                            content,
 433                            tool_use_id,
 434                        } => {
 435                            let content = content.to_string();
 436                            thread
 437                                .update(cx, |thread, cx| {
 438                                    thread.update_tool_call(
 439                                        acp::ToolCallUpdate {
 440                                            id: acp::ToolCallId(tool_use_id.into()),
 441                                            fields: acp::ToolCallUpdateFields {
 442                                                status: if turn_state.borrow().is_cancelled() {
 443                                                    // Do not set to completed if turn was cancelled
 444                                                    None
 445                                                } else {
 446                                                    Some(acp::ToolCallStatus::Completed)
 447                                                },
 448                                                content: (!content.is_empty())
 449                                                    .then(|| vec![content.into()]),
 450                                                ..Default::default()
 451                                            },
 452                                        },
 453                                        cx,
 454                                    )
 455                                })
 456                                .log_err();
 457                        }
 458                        ContentChunk::Thinking { .. }
 459                        | ContentChunk::RedactedThinking
 460                        | ContentChunk::ToolUse { .. } => {
 461                            debug_panic!(
 462                                "Should not get {:?} with role: assistant. should we handle this?",
 463                                chunk
 464                            );
 465                        }
 466
 467                        ContentChunk::Image
 468                        | ContentChunk::Document
 469                        | ContentChunk::WebSearchToolResult => {
 470                            thread
 471                                .update(cx, |thread, cx| {
 472                                    thread.push_assistant_content_block(
 473                                        format!("Unsupported content: {:?}", chunk).into(),
 474                                        false,
 475                                        cx,
 476                                    )
 477                                })
 478                                .log_err();
 479                        }
 480                    }
 481                }
 482            }
 483            SdkMessage::Assistant {
 484                message,
 485                session_id: _,
 486            } => {
 487                let Some(thread) = thread_rx
 488                    .recv()
 489                    .await
 490                    .log_err()
 491                    .and_then(|entity| entity.upgrade())
 492                else {
 493                    log::error!("Received an SDK message but thread is gone");
 494                    return;
 495                };
 496
 497                for chunk in message.content.chunks() {
 498                    match chunk {
 499                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 500                            thread
 501                                .update(cx, |thread, cx| {
 502                                    thread.push_assistant_content_block(text.into(), false, cx)
 503                                })
 504                                .log_err();
 505                        }
 506                        ContentChunk::Thinking { thinking } => {
 507                            thread
 508                                .update(cx, |thread, cx| {
 509                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 510                                })
 511                                .log_err();
 512                        }
 513                        ContentChunk::RedactedThinking => {
 514                            thread
 515                                .update(cx, |thread, cx| {
 516                                    thread.push_assistant_content_block(
 517                                        "[REDACTED]".into(),
 518                                        true,
 519                                        cx,
 520                                    )
 521                                })
 522                                .log_err();
 523                        }
 524                        ContentChunk::ToolUse { id, name, input } => {
 525                            let claude_tool = ClaudeTool::infer(&name, input);
 526
 527                            thread
 528                                .update(cx, |thread, cx| {
 529                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 530                                        thread.update_plan(
 531                                            acp::Plan {
 532                                                entries: params
 533                                                    .todos
 534                                                    .into_iter()
 535                                                    .map(Into::into)
 536                                                    .collect(),
 537                                            },
 538                                            cx,
 539                                        )
 540                                    } else {
 541                                        thread.upsert_tool_call(
 542                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 543                                            cx,
 544                                        );
 545                                    }
 546                                })
 547                                .log_err();
 548                        }
 549                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 550                            debug_panic!(
 551                                "Should not get tool results with role: assistant. should we handle this?"
 552                            );
 553                        }
 554                        ContentChunk::Image | ContentChunk::Document => {
 555                            thread
 556                                .update(cx, |thread, cx| {
 557                                    thread.push_assistant_content_block(
 558                                        format!("Unsupported content: {:?}", chunk).into(),
 559                                        false,
 560                                        cx,
 561                                    )
 562                                })
 563                                .log_err();
 564                        }
 565                    }
 566                }
 567            }
 568            SdkMessage::Result {
 569                is_error,
 570                subtype,
 571                result,
 572                ..
 573            } => {
 574                let turn_state = turn_state.take();
 575                let was_cancelled = turn_state.is_cancelled();
 576                let Some(end_turn_tx) = turn_state.end_tx() else {
 577                    debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
 578                    return;
 579                };
 580
 581                if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution)
 582                {
 583                    end_turn_tx
 584                        .send(Err(anyhow!(
 585                            "Error: {}",
 586                            result.unwrap_or_else(|| subtype.to_string())
 587                        )))
 588                        .ok();
 589                } else {
 590                    let stop_reason = match subtype {
 591                        ResultErrorType::Success => acp::StopReason::EndTurn,
 592                        ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 593                        ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
 594                    };
 595                    end_turn_tx
 596                        .send(Ok(acp::PromptResponse { stop_reason }))
 597                        .ok();
 598                }
 599            }
 600            SdkMessage::ControlResponse { response } => {
 601                if matches!(response.subtype, ResultErrorType::Success) {
 602                    let new_state = turn_state.take().confirm_cancellation(&response.request_id);
 603                    turn_state.replace(new_state);
 604                } else {
 605                    log::error!("Control response error: {:?}", response);
 606                }
 607            }
 608            SdkMessage::System { .. } => {}
 609        }
 610    }
 611
 612    async fn handle_io(
 613        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 614        incoming_tx: UnboundedSender<SdkMessage>,
 615        mut outgoing_bytes: impl Unpin + AsyncWrite,
 616        incoming_bytes: impl Unpin + AsyncRead,
 617    ) -> Result<UnboundedReceiver<SdkMessage>> {
 618        let mut output_reader = BufReader::new(incoming_bytes);
 619        let mut outgoing_line = Vec::new();
 620        let mut incoming_line = String::new();
 621        loop {
 622            select_biased! {
 623                message = outgoing_rx.next() => {
 624                    if let Some(message) = message {
 625                        outgoing_line.clear();
 626                        serde_json::to_writer(&mut outgoing_line, &message)?;
 627                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 628                        outgoing_line.push(b'\n');
 629                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 630                    } else {
 631                        break;
 632                    }
 633                }
 634                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
 635                    if bytes_read? == 0 {
 636                        break
 637                    }
 638                    log::trace!("recv: {}", &incoming_line);
 639                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 640                        Ok(message) => {
 641                            incoming_tx.unbounded_send(message).log_err();
 642                        }
 643                        Err(error) => {
 644                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 645                        }
 646                    }
 647                    incoming_line.clear();
 648                }
 649            }
 650        }
 651
 652        Ok(outgoing_rx)
 653    }
 654}
 655
 656#[derive(Debug, Clone, Serialize, Deserialize)]
 657struct Message {
 658    role: Role,
 659    content: Content,
 660    #[serde(skip_serializing_if = "Option::is_none")]
 661    id: Option<String>,
 662    #[serde(skip_serializing_if = "Option::is_none")]
 663    model: Option<String>,
 664    #[serde(skip_serializing_if = "Option::is_none")]
 665    stop_reason: Option<String>,
 666    #[serde(skip_serializing_if = "Option::is_none")]
 667    stop_sequence: Option<String>,
 668    #[serde(skip_serializing_if = "Option::is_none")]
 669    usage: Option<Usage>,
 670}
 671
 672#[derive(Debug, Clone, Serialize, Deserialize)]
 673#[serde(untagged)]
 674enum Content {
 675    UntaggedText(String),
 676    Chunks(Vec<ContentChunk>),
 677}
 678
 679impl Content {
 680    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 681        match self {
 682            Self::Chunks(chunks) => chunks.into_iter(),
 683            Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
 684        }
 685    }
 686}
 687
 688impl Display for Content {
 689    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 690        match self {
 691            Content::UntaggedText(txt) => write!(f, "{}", txt),
 692            Content::Chunks(chunks) => {
 693                for chunk in chunks {
 694                    write!(f, "{}", chunk)?;
 695                }
 696                Ok(())
 697            }
 698        }
 699    }
 700}
 701
 702#[derive(Debug, Clone, Serialize, Deserialize)]
 703#[serde(tag = "type", rename_all = "snake_case")]
 704enum ContentChunk {
 705    Text {
 706        text: String,
 707    },
 708    ToolUse {
 709        id: String,
 710        name: String,
 711        input: serde_json::Value,
 712    },
 713    ToolResult {
 714        content: Content,
 715        tool_use_id: String,
 716    },
 717    Thinking {
 718        thinking: String,
 719    },
 720    RedactedThinking,
 721    // TODO
 722    Image,
 723    Document,
 724    WebSearchToolResult,
 725    #[serde(untagged)]
 726    UntaggedText(String),
 727}
 728
 729impl Display for ContentChunk {
 730    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 731        match self {
 732            ContentChunk::Text { text } => write!(f, "{}", text),
 733            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 734            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 735            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 736            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 737            ContentChunk::Image
 738            | ContentChunk::Document
 739            | ContentChunk::ToolUse { .. }
 740            | ContentChunk::WebSearchToolResult => {
 741                write!(f, "\n{:?}\n", &self)
 742            }
 743        }
 744    }
 745}
 746
 747#[derive(Debug, Clone, Serialize, Deserialize)]
 748struct Usage {
 749    input_tokens: u32,
 750    cache_creation_input_tokens: u32,
 751    cache_read_input_tokens: u32,
 752    output_tokens: u32,
 753    service_tier: String,
 754}
 755
 756#[derive(Debug, Clone, Serialize, Deserialize)]
 757#[serde(rename_all = "snake_case")]
 758enum Role {
 759    System,
 760    Assistant,
 761    User,
 762}
 763
 764#[derive(Debug, Clone, Serialize, Deserialize)]
 765struct MessageParam {
 766    role: Role,
 767    content: String,
 768}
 769
 770#[derive(Debug, Clone, Serialize, Deserialize)]
 771#[serde(tag = "type", rename_all = "snake_case")]
 772enum SdkMessage {
 773    // An assistant message
 774    Assistant {
 775        message: Message, // from Anthropic SDK
 776        #[serde(skip_serializing_if = "Option::is_none")]
 777        session_id: Option<String>,
 778    },
 779    // A user message
 780    User {
 781        message: Message, // from Anthropic SDK
 782        #[serde(skip_serializing_if = "Option::is_none")]
 783        session_id: Option<String>,
 784    },
 785    // Emitted as the last message in a conversation
 786    Result {
 787        subtype: ResultErrorType,
 788        duration_ms: f64,
 789        duration_api_ms: f64,
 790        is_error: bool,
 791        num_turns: i32,
 792        #[serde(skip_serializing_if = "Option::is_none")]
 793        result: Option<String>,
 794        session_id: String,
 795        total_cost_usd: f64,
 796    },
 797    // Emitted as the first message at the start of a conversation
 798    System {
 799        cwd: String,
 800        session_id: String,
 801        tools: Vec<String>,
 802        model: String,
 803        mcp_servers: Vec<McpServer>,
 804        #[serde(rename = "apiKeySource")]
 805        api_key_source: String,
 806        #[serde(rename = "permissionMode")]
 807        permission_mode: PermissionMode,
 808    },
 809    /// Messages used to control the conversation, outside of chat messages to the model
 810    ControlRequest {
 811        request_id: String,
 812        request: ControlRequest,
 813    },
 814    /// Response to a control request
 815    ControlResponse { response: ControlResponse },
 816}
 817
 818#[derive(Debug, Clone, Serialize, Deserialize)]
 819#[serde(tag = "subtype", rename_all = "snake_case")]
 820enum ControlRequest {
 821    /// Cancel the current conversation
 822    Interrupt,
 823}
 824
 825#[derive(Debug, Clone, Serialize, Deserialize)]
 826struct ControlResponse {
 827    request_id: String,
 828    subtype: ResultErrorType,
 829}
 830
 831#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
 832#[serde(rename_all = "snake_case")]
 833enum ResultErrorType {
 834    Success,
 835    ErrorMaxTurns,
 836    ErrorDuringExecution,
 837}
 838
 839impl Display for ResultErrorType {
 840    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 841        match self {
 842            ResultErrorType::Success => write!(f, "success"),
 843            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 844            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 845        }
 846    }
 847}
 848
 849fn new_request_id() -> String {
 850    use rand::Rng;
 851    // In the Claude Code TS SDK they just generate a random 12 character string,
 852    // `Math.random().toString(36).substring(2, 15)`
 853    rand::thread_rng()
 854        .sample_iter(&rand::distributions::Alphanumeric)
 855        .take(12)
 856        .map(char::from)
 857        .collect()
 858}
 859
 860#[derive(Debug, Clone, Serialize, Deserialize)]
 861struct McpServer {
 862    name: String,
 863    status: String,
 864}
 865
 866#[derive(Debug, Clone, Serialize, Deserialize)]
 867#[serde(rename_all = "camelCase")]
 868enum PermissionMode {
 869    Default,
 870    AcceptEdits,
 871    BypassPermissions,
 872    Plan,
 873}
 874
 875#[cfg(test)]
 876pub(crate) mod tests {
 877    use super::*;
 878    use crate::e2e_tests;
 879    use gpui::TestAppContext;
 880    use serde_json::json;
 881
 882    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
 883
 884    pub fn local_command() -> AgentServerCommand {
 885        AgentServerCommand {
 886            path: "claude".into(),
 887            args: vec![],
 888            env: None,
 889        }
 890    }
 891
 892    #[gpui::test]
 893    #[cfg_attr(not(feature = "e2e"), ignore)]
 894    async fn test_todo_plan(cx: &mut TestAppContext) {
 895        let fs = e2e_tests::init_test(cx).await;
 896        let project = Project::test(fs, [], cx).await;
 897        let thread =
 898            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
 899
 900        thread
 901            .update(cx, |thread, cx| {
 902                thread.send_raw(
 903                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
 904                    cx,
 905                )
 906            })
 907            .await
 908            .unwrap();
 909
 910        let mut entries_len = 0;
 911
 912        thread.read_with(cx, |thread, _| {
 913            entries_len = thread.plan().entries.len();
 914            assert!(thread.plan().entries.len() > 0, "Empty plan");
 915        });
 916
 917        thread
 918            .update(cx, |thread, cx| {
 919                thread.send_raw(
 920                    "Mark the first entry status as in progress without acting on it.",
 921                    cx,
 922                )
 923            })
 924            .await
 925            .unwrap();
 926
 927        thread.read_with(cx, |thread, _| {
 928            assert!(matches!(
 929                thread.plan().entries[0].status,
 930                acp::PlanEntryStatus::InProgress
 931            ));
 932            assert_eq!(thread.plan().entries.len(), entries_len);
 933        });
 934
 935        thread
 936            .update(cx, |thread, cx| {
 937                thread.send_raw(
 938                    "Now mark the first entry as completed without acting on it.",
 939                    cx,
 940                )
 941            })
 942            .await
 943            .unwrap();
 944
 945        thread.read_with(cx, |thread, _| {
 946            assert!(matches!(
 947                thread.plan().entries[0].status,
 948                acp::PlanEntryStatus::Completed
 949            ));
 950            assert_eq!(thread.plan().entries.len(), entries_len);
 951        });
 952    }
 953
 954    #[test]
 955    fn test_deserialize_content_untagged_text() {
 956        let json = json!("Hello, world!");
 957        let content: Content = serde_json::from_value(json).unwrap();
 958        match content {
 959            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
 960            _ => panic!("Expected UntaggedText variant"),
 961        }
 962    }
 963
 964    #[test]
 965    fn test_deserialize_content_chunks() {
 966        let json = json!([
 967            {
 968                "type": "text",
 969                "text": "Hello"
 970            },
 971            {
 972                "type": "tool_use",
 973                "id": "tool_123",
 974                "name": "calculator",
 975                "input": {"operation": "add", "a": 1, "b": 2}
 976            }
 977        ]);
 978        let content: Content = serde_json::from_value(json).unwrap();
 979        match content {
 980            Content::Chunks(chunks) => {
 981                assert_eq!(chunks.len(), 2);
 982                match &chunks[0] {
 983                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
 984                    _ => panic!("Expected Text chunk"),
 985                }
 986                match &chunks[1] {
 987                    ContentChunk::ToolUse { id, name, input } => {
 988                        assert_eq!(id, "tool_123");
 989                        assert_eq!(name, "calculator");
 990                        assert_eq!(input["operation"], "add");
 991                        assert_eq!(input["a"], 1);
 992                        assert_eq!(input["b"], 2);
 993                    }
 994                    _ => panic!("Expected ToolUse chunk"),
 995                }
 996            }
 997            _ => panic!("Expected Chunks variant"),
 998        }
 999    }
1000
1001    #[test]
1002    fn test_deserialize_tool_result_untagged_text() {
1003        let json = json!({
1004            "type": "tool_result",
1005            "content": "Result content",
1006            "tool_use_id": "tool_456"
1007        });
1008        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1009        match chunk {
1010            ContentChunk::ToolResult {
1011                content,
1012                tool_use_id,
1013            } => {
1014                match content {
1015                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
1016                    _ => panic!("Expected UntaggedText content"),
1017                }
1018                assert_eq!(tool_use_id, "tool_456");
1019            }
1020            _ => panic!("Expected ToolResult variant"),
1021        }
1022    }
1023
1024    #[test]
1025    fn test_deserialize_tool_result_chunks() {
1026        let json = json!({
1027            "type": "tool_result",
1028            "content": [
1029                {
1030                    "type": "text",
1031                    "text": "Processing complete"
1032                },
1033                {
1034                    "type": "text",
1035                    "text": "Result: 42"
1036                }
1037            ],
1038            "tool_use_id": "tool_789"
1039        });
1040        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1041        match chunk {
1042            ContentChunk::ToolResult {
1043                content,
1044                tool_use_id,
1045            } => {
1046                match content {
1047                    Content::Chunks(chunks) => {
1048                        assert_eq!(chunks.len(), 2);
1049                        match &chunks[0] {
1050                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
1051                            _ => panic!("Expected Text chunk"),
1052                        }
1053                        match &chunks[1] {
1054                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
1055                            _ => panic!("Expected Text chunk"),
1056                        }
1057                    }
1058                    _ => panic!("Expected Chunks content"),
1059                }
1060                assert_eq!(tool_use_id, "tool_789");
1061            }
1062            _ => panic!("Expected ToolResult variant"),
1063        }
1064    }
1065}