claude.rs

   1mod mcp_server;
   2pub mod tools;
   3
   4use collections::HashMap;
   5use context_server::listener::McpServerTool;
   6use project::Project;
   7use settings::SettingsStore;
   8use smol::process::Child;
   9use std::cell::RefCell;
  10use std::fmt::Display;
  11use std::path::Path;
  12use std::rc::Rc;
  13use uuid::Uuid;
  14
  15use agent_client_protocol as acp;
  16use anyhow::{Result, anyhow};
  17use futures::channel::oneshot;
  18use futures::{AsyncBufReadExt, AsyncWriteExt};
  19use futures::{
  20    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  21    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  22    io::BufReader,
  23    select_biased,
  24};
  25use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  26use serde::{Deserialize, Serialize};
  27use util::{ResultExt, debug_panic};
  28
  29use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  30use crate::claude::tools::ClaudeTool;
  31use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  32use acp_thread::{AcpThread, AgentConnection};
  33
/// [`AgentServer`] implementation backed by the `claude` CLI ("Claude Code").
///
/// Connecting is cheap; a separate `claude` process is spawned for every new
/// thread (see `ClaudeAgentConnection::new_thread` and `spawn_claude`).
#[derive(Clone)]
pub struct ClaudeCode;
  36
  37impl AgentServer for ClaudeCode {
  38    fn name(&self) -> &'static str {
  39        "Claude Code"
  40    }
  41
  42    fn empty_state_headline(&self) -> &'static str {
  43        self.name()
  44    }
  45
  46    fn empty_state_message(&self) -> &'static str {
  47        "How can I help you today?"
  48    }
  49
  50    fn logo(&self) -> ui::IconName {
  51        ui::IconName::AiClaude
  52    }
  53
  54    fn connect(
  55        &self,
  56        _root_dir: &Path,
  57        _project: &Entity<Project>,
  58        _cx: &mut App,
  59    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  60        let connection = ClaudeAgentConnection {
  61            sessions: Default::default(),
  62        };
  63
  64        Task::ready(Ok(Rc::new(connection) as _))
  65    }
  66}
  67
/// Connection returned by `ClaudeCode::connect`.
///
/// Holds one [`ClaudeAgentSession`] per thread, keyed by the session id that
/// `new_thread` generates and passes to the `claude` process.
struct ClaudeAgentConnection {
    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
}
  71
  72impl AgentConnection for ClaudeAgentConnection {
  73    fn new_thread(
  74        self: Rc<Self>,
  75        project: Entity<Project>,
  76        cwd: &Path,
  77        cx: &mut AsyncApp,
  78    ) -> Task<Result<Entity<AcpThread>>> {
  79        let cwd = cwd.to_owned();
  80        cx.spawn(async move |cx| {
  81            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
  82            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
  83
  84            let mut mcp_servers = HashMap::default();
  85            mcp_servers.insert(
  86                mcp_server::SERVER_NAME.to_string(),
  87                permission_mcp_server.server_config()?,
  88            );
  89            let mcp_config = McpConfig { mcp_servers };
  90
  91            let mcp_config_file = tempfile::NamedTempFile::new()?;
  92            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
  93
  94            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
  95            mcp_config_file
  96                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
  97                .await?;
  98            mcp_config_file.flush().await?;
  99
 100            let settings = cx.read_global(|settings: &SettingsStore, _| {
 101                settings.get::<AllAgentServersSettings>(None).claude.clone()
 102            })?;
 103
 104            let Some(command) = AgentServerCommand::resolve(
 105                "claude",
 106                &[],
 107                Some(&util::paths::home_dir().join(".claude/local/claude")),
 108                settings,
 109                &project,
 110                cx,
 111            )
 112            .await
 113            else {
 114                anyhow::bail!("Failed to find claude binary");
 115            };
 116
 117            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
 118            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 119
 120            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
 121
 122            log::trace!("Starting session with id: {}", session_id);
 123
 124            let mut child = spawn_claude(
 125                &command,
 126                ClaudeSessionMode::Start,
 127                session_id.clone(),
 128                &mcp_config_path,
 129                &cwd,
 130            )?;
 131
 132            let stdin = child.stdin.take().unwrap();
 133            let stdout = child.stdout.take().unwrap();
 134
 135            let pid = child.id();
 136            log::trace!("Spawned (pid: {})", pid);
 137
 138            cx.background_spawn(async move {
 139                let mut outgoing_rx = Some(outgoing_rx);
 140
 141                ClaudeAgentSession::handle_io(
 142                    outgoing_rx.take().unwrap(),
 143                    incoming_message_tx.clone(),
 144                    stdin,
 145                    stdout,
 146                )
 147                .await?;
 148
 149                log::trace!("Stopped (pid: {})", pid);
 150
 151                drop(mcp_config_path);
 152                anyhow::Ok(())
 153            })
 154            .detach();
 155
 156            let turn_state = Rc::new(RefCell::new(TurnState::None));
 157
 158            let handler_task = cx.spawn({
 159                let turn_state = turn_state.clone();
 160                let mut thread_rx = thread_rx.clone();
 161                async move |cx| {
 162                    while let Some(message) = incoming_message_rx.next().await {
 163                        ClaudeAgentSession::handle_message(
 164                            thread_rx.clone(),
 165                            message,
 166                            turn_state.clone(),
 167                            cx,
 168                        )
 169                        .await
 170                    }
 171
 172                    if let Some(status) = child.status().await.log_err() {
 173                        if let Some(thread) = thread_rx.recv().await.ok() {
 174                            thread
 175                                .update(cx, |thread, cx| {
 176                                    thread.emit_server_exited(status, cx);
 177                                })
 178                                .ok();
 179                        }
 180                    }
 181                }
 182            });
 183
 184            let thread = cx.new(|cx| {
 185                AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
 186            })?;
 187
 188            thread_tx.send(thread.downgrade())?;
 189
 190            let session = ClaudeAgentSession {
 191                outgoing_tx,
 192                turn_state,
 193                _handler_task: handler_task,
 194                _mcp_server: Some(permission_mcp_server),
 195            };
 196
 197            self.sessions.borrow_mut().insert(session_id, session);
 198
 199            Ok(thread)
 200        })
 201    }
 202
 203    fn auth_methods(&self) -> &[acp::AuthMethod] {
 204        &[]
 205    }
 206
 207    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 208        Task::ready(Err(anyhow!("Authentication not supported")))
 209    }
 210
 211    fn prompt(
 212        &self,
 213        _id: Option<acp_thread::UserMessageId>,
 214        params: acp::PromptRequest,
 215        cx: &mut App,
 216    ) -> Task<Result<acp::PromptResponse>> {
 217        let sessions = self.sessions.borrow();
 218        let Some(session) = sessions.get(&params.session_id) else {
 219            return Task::ready(Err(anyhow!(
 220                "Attempted to send message to nonexistent session {}",
 221                params.session_id
 222            )));
 223        };
 224
 225        let (end_tx, end_rx) = oneshot::channel();
 226        session.turn_state.replace(TurnState::InProgress { end_tx });
 227
 228        let mut content = String::new();
 229        for chunk in params.prompt {
 230            match chunk {
 231                acp::ContentBlock::Text(text_content) => {
 232                    content.push_str(&text_content.text);
 233                }
 234                acp::ContentBlock::ResourceLink(resource_link) => {
 235                    content.push_str(&format!("@{}", resource_link.uri));
 236                }
 237                acp::ContentBlock::Audio(_)
 238                | acp::ContentBlock::Image(_)
 239                | acp::ContentBlock::Resource(_) => {
 240                    // TODO
 241                }
 242            }
 243        }
 244
 245        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 246            message: Message {
 247                role: Role::User,
 248                content: Content::UntaggedText(content),
 249                id: None,
 250                model: None,
 251                stop_reason: None,
 252                stop_sequence: None,
 253                usage: None,
 254            },
 255            session_id: Some(params.session_id.to_string()),
 256        }) {
 257            return Task::ready(Err(anyhow!(err)));
 258        }
 259
 260        cx.foreground_executor().spawn(async move { end_rx.await? })
 261    }
 262
 263    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 264        let sessions = self.sessions.borrow();
 265        let Some(session) = sessions.get(&session_id) else {
 266            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 267            return;
 268        };
 269
 270        let request_id = new_request_id();
 271
 272        let turn_state = session.turn_state.take();
 273        let TurnState::InProgress { end_tx } = turn_state else {
 274            // Already cancelled or idle, put it back
 275            session.turn_state.replace(turn_state);
 276            return;
 277        };
 278
 279        session.turn_state.replace(TurnState::CancelRequested {
 280            end_tx,
 281            request_id: request_id.clone(),
 282        });
 283
 284        session
 285            .outgoing_tx
 286            .unbounded_send(SdkMessage::ControlRequest {
 287                request_id,
 288                request: ControlRequest::Interrupt,
 289            })
 290            .log_err();
 291    }
 292}
 293
/// How `spawn_claude` attaches the process to a session id.
#[derive(Clone, Copy)]
enum ClaudeSessionMode {
    /// Start a new session: passes `--session-id <id>`.
    Start,
    /// Resume an existing session: passes `--resume <id>`. Not used yet.
    #[expect(dead_code)]
    Resume,
}
 300
 301fn spawn_claude(
 302    command: &AgentServerCommand,
 303    mode: ClaudeSessionMode,
 304    session_id: acp::SessionId,
 305    mcp_config_path: &Path,
 306    root_dir: &Path,
 307) -> Result<Child> {
 308    let child = util::command::new_smol_command(&command.path)
 309        .args([
 310            "--input-format",
 311            "stream-json",
 312            "--output-format",
 313            "stream-json",
 314            "--print",
 315            "--verbose",
 316            "--mcp-config",
 317            mcp_config_path.to_string_lossy().as_ref(),
 318            "--permission-prompt-tool",
 319            &format!(
 320                "mcp__{}__{}",
 321                mcp_server::SERVER_NAME,
 322                mcp_server::PermissionTool::NAME,
 323            ),
 324            "--allowedTools",
 325            &format!(
 326                "mcp__{}__{},mcp__{}__{}",
 327                mcp_server::SERVER_NAME,
 328                mcp_server::EditTool::NAME,
 329                mcp_server::SERVER_NAME,
 330                mcp_server::ReadTool::NAME
 331            ),
 332            "--disallowedTools",
 333            "Read,Edit",
 334        ])
 335        .args(match mode {
 336            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 337            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 338        })
 339        .args(command.args.iter().map(|arg| arg.as_str()))
 340        .current_dir(root_dir)
 341        .stdin(std::process::Stdio::piped())
 342        .stdout(std::process::Stdio::piped())
 343        .stderr(std::process::Stdio::inherit())
 344        .kill_on_drop(true)
 345        .spawn()?;
 346
 347    Ok(child)
 348}
 349
/// Per-thread state for one running `claude` process.
struct ClaudeAgentSession {
    /// Messages queued here are written to the process's stdin by the I/O
    /// task started in `new_thread`.
    outgoing_tx: UnboundedSender<SdkMessage>,
    /// State of the current prompt turn, shared with the message-handler task.
    turn_state: Rc<RefCell<TurnState>>,
    /// Kept alive for the session's lifetime; not read otherwise.
    _mcp_server: Option<ClaudeZedMcpServer>,
    /// Task that applies incoming messages to the thread. Held so it stays
    /// alive with the session (presumably dropping a gpui `Task` cancels it).
    _handler_task: Task<()>,
}
 356
/// Lifecycle of a single prompt turn.
///
/// Every variant except `None` carries the sender that resolves the future
/// returned by `prompt`.
#[derive(Debug, Default)]
enum TurnState {
    /// No turn is in flight.
    #[default]
    None,
    /// A prompt was sent and its `Result` message has not arrived yet.
    InProgress {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
    },
    /// `cancel` sent an interrupt control request with `request_id` and is
    /// waiting for the matching control response.
    CancelRequested {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
        request_id: String,
    },
    /// Claude acknowledged the interrupt. `handle_message` then suppresses
    /// user text, leaves tool-call status unchanged, and maps
    /// `error_during_execution` to a cancelled stop reason.
    CancelConfirmed {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
    },
}
 372
 373impl TurnState {
 374    fn is_cancelled(&self) -> bool {
 375        matches!(self, TurnState::CancelConfirmed { .. })
 376    }
 377
 378    fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
 379        match self {
 380            TurnState::None => None,
 381            TurnState::InProgress { end_tx, .. } => Some(end_tx),
 382            TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
 383            TurnState::CancelConfirmed { end_tx } => Some(end_tx),
 384        }
 385    }
 386
 387    fn confirm_cancellation(self, id: &str) -> Self {
 388        match self {
 389            TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
 390                TurnState::CancelConfirmed { end_tx }
 391            }
 392            _ => self,
 393        }
 394    }
 395}
 396
 397impl ClaudeAgentSession {
 398    async fn handle_message(
 399        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 400        message: SdkMessage,
 401        turn_state: Rc<RefCell<TurnState>>,
 402        cx: &mut AsyncApp,
 403    ) {
 404        match message {
 405            // we should only be sending these out, they don't need to be in the thread
 406            SdkMessage::ControlRequest { .. } => {}
 407            SdkMessage::User {
 408                message,
 409                session_id: _,
 410            } => {
 411                let Some(thread) = thread_rx
 412                    .recv()
 413                    .await
 414                    .log_err()
 415                    .and_then(|entity| entity.upgrade())
 416                else {
 417                    log::error!("Received an SDK message but thread is gone");
 418                    return;
 419                };
 420
 421                for chunk in message.content.chunks() {
 422                    match chunk {
 423                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 424                            if !turn_state.borrow().is_cancelled() {
 425                                thread
 426                                    .update(cx, |thread, cx| {
 427                                        thread.push_user_content_block(None, text.into(), cx)
 428                                    })
 429                                    .log_err();
 430                            }
 431                        }
 432                        ContentChunk::ToolResult {
 433                            content,
 434                            tool_use_id,
 435                        } => {
 436                            let content = content.to_string();
 437                            thread
 438                                .update(cx, |thread, cx| {
 439                                    thread.update_tool_call(
 440                                        acp::ToolCallUpdate {
 441                                            id: acp::ToolCallId(tool_use_id.into()),
 442                                            fields: acp::ToolCallUpdateFields {
 443                                                status: if turn_state.borrow().is_cancelled() {
 444                                                    // Do not set to completed if turn was cancelled
 445                                                    None
 446                                                } else {
 447                                                    Some(acp::ToolCallStatus::Completed)
 448                                                },
 449                                                content: (!content.is_empty())
 450                                                    .then(|| vec![content.into()]),
 451                                                ..Default::default()
 452                                            },
 453                                        },
 454                                        cx,
 455                                    )
 456                                })
 457                                .log_err();
 458                        }
 459                        ContentChunk::Thinking { .. }
 460                        | ContentChunk::RedactedThinking
 461                        | ContentChunk::ToolUse { .. } => {
 462                            debug_panic!(
 463                                "Should not get {:?} with role: assistant. should we handle this?",
 464                                chunk
 465                            );
 466                        }
 467
 468                        ContentChunk::Image
 469                        | ContentChunk::Document
 470                        | ContentChunk::WebSearchToolResult => {
 471                            thread
 472                                .update(cx, |thread, cx| {
 473                                    thread.push_assistant_content_block(
 474                                        format!("Unsupported content: {:?}", chunk).into(),
 475                                        false,
 476                                        cx,
 477                                    )
 478                                })
 479                                .log_err();
 480                        }
 481                    }
 482                }
 483            }
 484            SdkMessage::Assistant {
 485                message,
 486                session_id: _,
 487            } => {
 488                let Some(thread) = thread_rx
 489                    .recv()
 490                    .await
 491                    .log_err()
 492                    .and_then(|entity| entity.upgrade())
 493                else {
 494                    log::error!("Received an SDK message but thread is gone");
 495                    return;
 496                };
 497
 498                for chunk in message.content.chunks() {
 499                    match chunk {
 500                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 501                            thread
 502                                .update(cx, |thread, cx| {
 503                                    thread.push_assistant_content_block(text.into(), false, cx)
 504                                })
 505                                .log_err();
 506                        }
 507                        ContentChunk::Thinking { thinking } => {
 508                            thread
 509                                .update(cx, |thread, cx| {
 510                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 511                                })
 512                                .log_err();
 513                        }
 514                        ContentChunk::RedactedThinking => {
 515                            thread
 516                                .update(cx, |thread, cx| {
 517                                    thread.push_assistant_content_block(
 518                                        "[REDACTED]".into(),
 519                                        true,
 520                                        cx,
 521                                    )
 522                                })
 523                                .log_err();
 524                        }
 525                        ContentChunk::ToolUse { id, name, input } => {
 526                            let claude_tool = ClaudeTool::infer(&name, input);
 527
 528                            thread
 529                                .update(cx, |thread, cx| {
 530                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 531                                        thread.update_plan(
 532                                            acp::Plan {
 533                                                entries: params
 534                                                    .todos
 535                                                    .into_iter()
 536                                                    .map(Into::into)
 537                                                    .collect(),
 538                                            },
 539                                            cx,
 540                                        )
 541                                    } else {
 542                                        thread.upsert_tool_call(
 543                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 544                                            cx,
 545                                        );
 546                                    }
 547                                })
 548                                .log_err();
 549                        }
 550                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 551                            debug_panic!(
 552                                "Should not get tool results with role: assistant. should we handle this?"
 553                            );
 554                        }
 555                        ContentChunk::Image | ContentChunk::Document => {
 556                            thread
 557                                .update(cx, |thread, cx| {
 558                                    thread.push_assistant_content_block(
 559                                        format!("Unsupported content: {:?}", chunk).into(),
 560                                        false,
 561                                        cx,
 562                                    )
 563                                })
 564                                .log_err();
 565                        }
 566                    }
 567                }
 568            }
 569            SdkMessage::Result {
 570                is_error,
 571                subtype,
 572                result,
 573                ..
 574            } => {
 575                let turn_state = turn_state.take();
 576                let was_cancelled = turn_state.is_cancelled();
 577                let Some(end_turn_tx) = turn_state.end_tx() else {
 578                    debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
 579                    return;
 580                };
 581
 582                if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution)
 583                {
 584                    end_turn_tx
 585                        .send(Err(anyhow!(
 586                            "Error: {}",
 587                            result.unwrap_or_else(|| subtype.to_string())
 588                        )))
 589                        .ok();
 590                } else {
 591                    let stop_reason = match subtype {
 592                        ResultErrorType::Success => acp::StopReason::EndTurn,
 593                        ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 594                        ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
 595                    };
 596                    end_turn_tx
 597                        .send(Ok(acp::PromptResponse { stop_reason }))
 598                        .ok();
 599                }
 600            }
 601            SdkMessage::ControlResponse { response } => {
 602                if matches!(response.subtype, ResultErrorType::Success) {
 603                    let new_state = turn_state.take().confirm_cancellation(&response.request_id);
 604                    turn_state.replace(new_state);
 605                } else {
 606                    log::error!("Control response error: {:?}", response);
 607                }
 608            }
 609            SdkMessage::System { .. } => {}
 610        }
 611    }
 612
 613    async fn handle_io(
 614        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 615        incoming_tx: UnboundedSender<SdkMessage>,
 616        mut outgoing_bytes: impl Unpin + AsyncWrite,
 617        incoming_bytes: impl Unpin + AsyncRead,
 618    ) -> Result<UnboundedReceiver<SdkMessage>> {
 619        let mut output_reader = BufReader::new(incoming_bytes);
 620        let mut outgoing_line = Vec::new();
 621        let mut incoming_line = String::new();
 622        loop {
 623            select_biased! {
 624                message = outgoing_rx.next() => {
 625                    if let Some(message) = message {
 626                        outgoing_line.clear();
 627                        serde_json::to_writer(&mut outgoing_line, &message)?;
 628                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 629                        outgoing_line.push(b'\n');
 630                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 631                    } else {
 632                        break;
 633                    }
 634                }
 635                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
 636                    if bytes_read? == 0 {
 637                        break
 638                    }
 639                    log::trace!("recv: {}", &incoming_line);
 640                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 641                        Ok(message) => {
 642                            incoming_tx.unbounded_send(message).log_err();
 643                        }
 644                        Err(error) => {
 645                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 646                        }
 647                    }
 648                    incoming_line.clear();
 649                }
 650            }
 651        }
 652
 653        Ok(outgoing_rx)
 654    }
 655}
 656
/// A chat message in the Anthropic SDK shape, as carried inside
/// `SdkMessage::User` / `SdkMessage::Assistant`.
///
/// Optional metadata fields are omitted from the JSON when `None`, so
/// outgoing user messages only carry `role` and `content`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    role: Role,
    content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequence: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Usage>,
}
 672
/// Message content: either a bare string or a list of typed chunks.
///
/// Untagged, so a JSON string maps to `UntaggedText` and a JSON array of
/// chunk objects maps to `Chunks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Content {
    UntaggedText(String),
    Chunks(Vec<ContentChunk>),
}
 679
 680impl Content {
 681    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 682        match self {
 683            Self::Chunks(chunks) => chunks.into_iter(),
 684            Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
 685        }
 686    }
 687}
 688
 689impl Display for Content {
 690    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 691        match self {
 692            Content::UntaggedText(txt) => write!(f, "{}", txt),
 693            Content::Chunks(chunks) => {
 694                for chunk in chunks {
 695                    write!(f, "{}", chunk)?;
 696                }
 697                Ok(())
 698            }
 699        }
 700    }
 701}
 702
/// One typed block of message content, discriminated by the JSON `type`
/// field (snake_case variant names).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ContentChunk {
    /// Plain text.
    Text {
        text: String,
    },
    /// A tool invocation requested by the assistant.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Output of a tool call, matched to its `ToolUse` by `tool_use_id`.
    ToolResult {
        content: Content,
        tool_use_id: String,
    },
    /// Visible reasoning text.
    Thinking {
        thinking: String,
    },
    /// Reasoning whose text is not exposed; any payload fields are ignored.
    RedactedThinking,
    // TODO: parse the payloads of these; for now only the tag is matched.
    Image,
    Document,
    WebSearchToolResult,
    /// A bare JSON string in a chunk list.
    #[serde(untagged)]
    UntaggedText(String),
}
 729
 730impl Display for ContentChunk {
 731    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 732        match self {
 733            ContentChunk::Text { text } => write!(f, "{}", text),
 734            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 735            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 736            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 737            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 738            ContentChunk::Image
 739            | ContentChunk::Document
 740            | ContentChunk::ToolUse { .. }
 741            | ContentChunk::WebSearchToolResult => {
 742                write!(f, "\n{:?}\n", &self)
 743            }
 744        }
 745    }
 746}
 747
 748#[derive(Debug, Clone, Serialize, Deserialize)]
 749struct Usage {
 750    input_tokens: u32,
 751    cache_creation_input_tokens: u32,
 752    cache_read_input_tokens: u32,
 753    output_tokens: u32,
 754    service_tier: String,
 755}
 756
/// Author of a [`Message`]; serialized in snake_case (`"system"`,
/// `"assistant"`, `"user"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}
 764
/// A minimal role + plain-text message.
///
/// NOTE(review): not referenced by the code visible in this file; confirm
/// whether it is still needed.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    role: Role,
    content: String,
}
 770
/// One line of Claude Code's `stream-json` protocol, discriminated by the
/// JSON `type` field (snake_case variant names).
///
/// `User` and `ControlRequest` are sent to the process; the other variants
/// are received from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    /// An assistant message.
    Assistant {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// A user message. `prompt` sends one to start a turn; tool results also
    /// arrive in this form.
    User {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// Emitted as the last message in a conversation; `handle_message` uses
    /// it to finish the current turn.
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    /// Emitted as the first message at the start of a conversation.
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to the model
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request
    ControlResponse { response: ControlResponse },
}
 818
/// Payload of `SdkMessage::ControlRequest`, discriminated by `subtype`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation (serialized as `"interrupt"`).
    Interrupt,
}
 825
 826#[derive(Debug, Clone, Serialize, Deserialize)]
 827struct ControlResponse {
 828    request_id: String,
 829    subtype: ResultErrorType,
 830}
 831
 832#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
 833#[serde(rename_all = "snake_case")]
 834enum ResultErrorType {
 835    Success,
 836    ErrorMaxTurns,
 837    ErrorDuringExecution,
 838}
 839
 840impl Display for ResultErrorType {
 841    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 842        match self {
 843            ResultErrorType::Success => write!(f, "success"),
 844            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 845            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 846        }
 847    }
 848}
 849
 850fn new_request_id() -> String {
 851    use rand::Rng;
 852    // In the Claude Code TS SDK they just generate a random 12 character string,
 853    // `Math.random().toString(36).substring(2, 15)`
 854    rand::thread_rng()
 855        .sample_iter(&rand::distributions::Alphanumeric)
 856        .take(12)
 857        .map(char::from)
 858        .collect()
 859}
 860
 861#[derive(Debug, Clone, Serialize, Deserialize)]
 862struct McpServer {
 863    name: String,
 864    status: String,
 865}
 866
 867#[derive(Debug, Clone, Serialize, Deserialize)]
 868#[serde(rename_all = "camelCase")]
 869enum PermissionMode {
 870    Default,
 871    AcceptEdits,
 872    BypassPermissions,
 873    Plan,
 874}
 875
 876#[cfg(test)]
 877pub(crate) mod tests {
 878    use super::*;
 879    use crate::e2e_tests;
 880    use gpui::TestAppContext;
 881    use serde_json::json;
 882
 883    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
 884
 885    pub fn local_command() -> AgentServerCommand {
 886        AgentServerCommand {
 887            path: "claude".into(),
 888            args: vec![],
 889            env: None,
 890        }
 891    }
 892
 893    #[gpui::test]
 894    #[cfg_attr(not(feature = "e2e"), ignore)]
 895    async fn test_todo_plan(cx: &mut TestAppContext) {
 896        let fs = e2e_tests::init_test(cx).await;
 897        let project = Project::test(fs, [], cx).await;
 898        let thread =
 899            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
 900
 901        thread
 902            .update(cx, |thread, cx| {
 903                thread.send_raw(
 904                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
 905                    cx,
 906                )
 907            })
 908            .await
 909            .unwrap();
 910
 911        let mut entries_len = 0;
 912
 913        thread.read_with(cx, |thread, _| {
 914            entries_len = thread.plan().entries.len();
 915            assert!(thread.plan().entries.len() > 0, "Empty plan");
 916        });
 917
 918        thread
 919            .update(cx, |thread, cx| {
 920                thread.send_raw(
 921                    "Mark the first entry status as in progress without acting on it.",
 922                    cx,
 923                )
 924            })
 925            .await
 926            .unwrap();
 927
 928        thread.read_with(cx, |thread, _| {
 929            assert!(matches!(
 930                thread.plan().entries[0].status,
 931                acp::PlanEntryStatus::InProgress
 932            ));
 933            assert_eq!(thread.plan().entries.len(), entries_len);
 934        });
 935
 936        thread
 937            .update(cx, |thread, cx| {
 938                thread.send_raw(
 939                    "Now mark the first entry as completed without acting on it.",
 940                    cx,
 941                )
 942            })
 943            .await
 944            .unwrap();
 945
 946        thread.read_with(cx, |thread, _| {
 947            assert!(matches!(
 948                thread.plan().entries[0].status,
 949                acp::PlanEntryStatus::Completed
 950            ));
 951            assert_eq!(thread.plan().entries.len(), entries_len);
 952        });
 953    }
 954
 955    #[test]
 956    fn test_deserialize_content_untagged_text() {
 957        let json = json!("Hello, world!");
 958        let content: Content = serde_json::from_value(json).unwrap();
 959        match content {
 960            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
 961            _ => panic!("Expected UntaggedText variant"),
 962        }
 963    }
 964
 965    #[test]
 966    fn test_deserialize_content_chunks() {
 967        let json = json!([
 968            {
 969                "type": "text",
 970                "text": "Hello"
 971            },
 972            {
 973                "type": "tool_use",
 974                "id": "tool_123",
 975                "name": "calculator",
 976                "input": {"operation": "add", "a": 1, "b": 2}
 977            }
 978        ]);
 979        let content: Content = serde_json::from_value(json).unwrap();
 980        match content {
 981            Content::Chunks(chunks) => {
 982                assert_eq!(chunks.len(), 2);
 983                match &chunks[0] {
 984                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
 985                    _ => panic!("Expected Text chunk"),
 986                }
 987                match &chunks[1] {
 988                    ContentChunk::ToolUse { id, name, input } => {
 989                        assert_eq!(id, "tool_123");
 990                        assert_eq!(name, "calculator");
 991                        assert_eq!(input["operation"], "add");
 992                        assert_eq!(input["a"], 1);
 993                        assert_eq!(input["b"], 2);
 994                    }
 995                    _ => panic!("Expected ToolUse chunk"),
 996                }
 997            }
 998            _ => panic!("Expected Chunks variant"),
 999        }
1000    }
1001
1002    #[test]
1003    fn test_deserialize_tool_result_untagged_text() {
1004        let json = json!({
1005            "type": "tool_result",
1006            "content": "Result content",
1007            "tool_use_id": "tool_456"
1008        });
1009        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1010        match chunk {
1011            ContentChunk::ToolResult {
1012                content,
1013                tool_use_id,
1014            } => {
1015                match content {
1016                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
1017                    _ => panic!("Expected UntaggedText content"),
1018                }
1019                assert_eq!(tool_use_id, "tool_456");
1020            }
1021            _ => panic!("Expected ToolResult variant"),
1022        }
1023    }
1024
1025    #[test]
1026    fn test_deserialize_tool_result_chunks() {
1027        let json = json!({
1028            "type": "tool_result",
1029            "content": [
1030                {
1031                    "type": "text",
1032                    "text": "Processing complete"
1033                },
1034                {
1035                    "type": "text",
1036                    "text": "Result: 42"
1037                }
1038            ],
1039            "tool_use_id": "tool_789"
1040        });
1041        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1042        match chunk {
1043            ContentChunk::ToolResult {
1044                content,
1045                tool_use_id,
1046            } => {
1047                match content {
1048                    Content::Chunks(chunks) => {
1049                        assert_eq!(chunks.len(), 2);
1050                        match &chunks[0] {
1051                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
1052                            _ => panic!("Expected Text chunk"),
1053                        }
1054                        match &chunks[1] {
1055                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
1056                            _ => panic!("Expected Text chunk"),
1057                        }
1058                    }
1059                    _ => panic!("Expected Chunks content"),
1060                }
1061                assert_eq!(tool_use_id, "tool_789");
1062            }
1063            _ => panic!("Expected ToolResult variant"),
1064        }
1065    }
1066}