claude.rs

   1mod mcp_server;
   2pub mod tools;
   3
   4use collections::HashMap;
   5use context_server::listener::McpServerTool;
   6use language_models::provider::anthropic::AnthropicLanguageModelProvider;
   7use project::Project;
   8use settings::SettingsStore;
   9use smol::process::Child;
  10use std::any::Any;
  11use std::cell::RefCell;
  12use std::fmt::Display;
  13use std::path::Path;
  14use std::rc::Rc;
  15use uuid::Uuid;
  16
  17use agent_client_protocol as acp;
  18use anyhow::{Context as _, Result, anyhow};
  19use futures::channel::oneshot;
  20use futures::{AsyncBufReadExt, AsyncWriteExt};
  21use futures::{
  22    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  23    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  24    io::BufReader,
  25    select_biased,
  26};
  27use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  28use serde::{Deserialize, Serialize};
  29use util::{ResultExt, debug_panic};
  30
  31use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  32use crate::claude::tools::ClaudeTool;
  33use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  34use acp_thread::{AcpThread, AgentConnection, AuthRequired};
  35
  36#[derive(Clone)]
  37pub struct ClaudeCode;
  38
  39impl AgentServer for ClaudeCode {
  40    fn name(&self) -> &'static str {
  41        "Claude Code"
  42    }
  43
  44    fn empty_state_headline(&self) -> &'static str {
  45        self.name()
  46    }
  47
  48    fn empty_state_message(&self) -> &'static str {
  49        "How can I help you today?"
  50    }
  51
  52    fn logo(&self) -> ui::IconName {
  53        ui::IconName::AiClaude
  54    }
  55
  56    fn connect(
  57        &self,
  58        _root_dir: &Path,
  59        _project: &Entity<Project>,
  60        _cx: &mut App,
  61    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  62        let connection = ClaudeAgentConnection {
  63            sessions: Default::default(),
  64        };
  65
  66        Task::ready(Ok(Rc::new(connection) as _))
  67    }
  68}
  69
  70struct ClaudeAgentConnection {
  71    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
  72}
  73
  74impl AgentConnection for ClaudeAgentConnection {
  75    fn new_thread(
  76        self: Rc<Self>,
  77        project: Entity<Project>,
  78        cwd: &Path,
  79        cx: &mut App,
  80    ) -> Task<Result<Entity<AcpThread>>> {
  81        let cwd = cwd.to_owned();
  82        cx.spawn(async move |cx| {
  83            let settings = cx.read_global(|settings: &SettingsStore, _| {
  84                settings.get::<AllAgentServersSettings>(None).claude.clone()
  85            })?;
  86
  87            let Some(command) = AgentServerCommand::resolve(
  88                "claude",
  89                &[],
  90                Some(&util::paths::home_dir().join(".claude/local/claude")),
  91                settings,
  92                &project,
  93                cx,
  94            )
  95            .await
  96            else {
  97                anyhow::bail!("Failed to find claude binary");
  98            };
  99
 100            let api_key =
 101                cx.update(AnthropicLanguageModelProvider::api_key)?
 102                    .await
 103                    .map_err(|err| {
 104                        if err.is::<language_model::AuthenticateError>() {
 105                            anyhow!(AuthRequired::new().with_language_model_provider(
 106                                language_model::ANTHROPIC_PROVIDER_ID
 107                            ))
 108                        } else {
 109                            anyhow!(err)
 110                        }
 111                    })?;
 112
 113            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
 114            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
 115
 116            let mut mcp_servers = HashMap::default();
 117            mcp_servers.insert(
 118                mcp_server::SERVER_NAME.to_string(),
 119                permission_mcp_server.server_config()?,
 120            );
 121            let mcp_config = McpConfig { mcp_servers };
 122
 123            let mcp_config_file = tempfile::NamedTempFile::new()?;
 124            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
 125
 126            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
 127            mcp_config_file
 128                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
 129                .await?;
 130            mcp_config_file.flush().await?;
 131
 132            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
 133            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 134
 135            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
 136
 137            log::trace!("Starting session with id: {}", session_id);
 138
 139            let mut child = spawn_claude(
 140                &command,
 141                ClaudeSessionMode::Start,
 142                session_id.clone(),
 143                api_key,
 144                &mcp_config_path,
 145                &cwd,
 146            )?;
 147
 148            let stdout = child.stdout.take().context("Failed to take stdout")?;
 149            let stdin = child.stdin.take().context("Failed to take stdin")?;
 150            let stderr = child.stderr.take().context("Failed to take stderr")?;
 151
 152            let pid = child.id();
 153            log::trace!("Spawned (pid: {})", pid);
 154
 155            cx.background_spawn(async move {
 156                let mut stderr = BufReader::new(stderr);
 157                let mut line = String::new();
 158                while let Ok(n) = stderr.read_line(&mut line).await
 159                    && n > 0
 160                {
 161                    log::warn!("agent stderr: {}", &line);
 162                    line.clear();
 163                }
 164            })
 165            .detach();
 166
 167            cx.background_spawn(async move {
 168                let mut outgoing_rx = Some(outgoing_rx);
 169
 170                ClaudeAgentSession::handle_io(
 171                    outgoing_rx.take().unwrap(),
 172                    incoming_message_tx.clone(),
 173                    stdin,
 174                    stdout,
 175                )
 176                .await?;
 177
 178                log::trace!("Stopped (pid: {})", pid);
 179
 180                drop(mcp_config_path);
 181                anyhow::Ok(())
 182            })
 183            .detach();
 184
 185            let turn_state = Rc::new(RefCell::new(TurnState::None));
 186
 187            let handler_task = cx.spawn({
 188                let turn_state = turn_state.clone();
 189                let mut thread_rx = thread_rx.clone();
 190                async move |cx| {
 191                    while let Some(message) = incoming_message_rx.next().await {
 192                        ClaudeAgentSession::handle_message(
 193                            thread_rx.clone(),
 194                            message,
 195                            turn_state.clone(),
 196                            cx,
 197                        )
 198                        .await
 199                    }
 200
 201                    if let Some(status) = child.status().await.log_err() {
 202                        if let Some(thread) = thread_rx.recv().await.ok() {
 203                            thread
 204                                .update(cx, |thread, cx| {
 205                                    thread.emit_server_exited(status, cx);
 206                                })
 207                                .ok();
 208                        }
 209                    }
 210                }
 211            });
 212
 213            let thread = cx.new(|cx| {
 214                AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
 215            })?;
 216
 217            thread_tx.send(thread.downgrade())?;
 218
 219            let session = ClaudeAgentSession {
 220                outgoing_tx,
 221                turn_state,
 222                _handler_task: handler_task,
 223                _mcp_server: Some(permission_mcp_server),
 224            };
 225
 226            self.sessions.borrow_mut().insert(session_id, session);
 227
 228            Ok(thread)
 229        })
 230    }
 231
 232    fn auth_methods(&self) -> &[acp::AuthMethod] {
 233        &[]
 234    }
 235
 236    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 237        Task::ready(Err(anyhow!("Authentication not supported")))
 238    }
 239
 240    fn prompt(
 241        &self,
 242        _id: Option<acp_thread::UserMessageId>,
 243        params: acp::PromptRequest,
 244        cx: &mut App,
 245    ) -> Task<Result<acp::PromptResponse>> {
 246        let sessions = self.sessions.borrow();
 247        let Some(session) = sessions.get(&params.session_id) else {
 248            return Task::ready(Err(anyhow!(
 249                "Attempted to send message to nonexistent session {}",
 250                params.session_id
 251            )));
 252        };
 253
 254        let (end_tx, end_rx) = oneshot::channel();
 255        session.turn_state.replace(TurnState::InProgress { end_tx });
 256
 257        let mut content = String::new();
 258        for chunk in params.prompt {
 259            match chunk {
 260                acp::ContentBlock::Text(text_content) => {
 261                    content.push_str(&text_content.text);
 262                }
 263                acp::ContentBlock::ResourceLink(resource_link) => {
 264                    content.push_str(&format!("@{}", resource_link.uri));
 265                }
 266                acp::ContentBlock::Audio(_)
 267                | acp::ContentBlock::Image(_)
 268                | acp::ContentBlock::Resource(_) => {
 269                    // TODO
 270                }
 271            }
 272        }
 273
 274        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 275            message: Message {
 276                role: Role::User,
 277                content: Content::UntaggedText(content),
 278                id: None,
 279                model: None,
 280                stop_reason: None,
 281                stop_sequence: None,
 282                usage: None,
 283            },
 284            session_id: Some(params.session_id.to_string()),
 285        }) {
 286            return Task::ready(Err(anyhow!(err)));
 287        }
 288
 289        cx.foreground_executor().spawn(async move { end_rx.await? })
 290    }
 291
 292    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 293        let sessions = self.sessions.borrow();
 294        let Some(session) = sessions.get(session_id) else {
 295            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 296            return;
 297        };
 298
 299        let request_id = new_request_id();
 300
 301        let turn_state = session.turn_state.take();
 302        let TurnState::InProgress { end_tx } = turn_state else {
 303            // Already canceled or idle, put it back
 304            session.turn_state.replace(turn_state);
 305            return;
 306        };
 307
 308        session.turn_state.replace(TurnState::CancelRequested {
 309            end_tx,
 310            request_id: request_id.clone(),
 311        });
 312
 313        session
 314            .outgoing_tx
 315            .unbounded_send(SdkMessage::ControlRequest {
 316                request_id,
 317                request: ControlRequest::Interrupt,
 318            })
 319            .log_err();
 320    }
 321
 322    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 323        self
 324    }
 325}
 326
 327#[derive(Clone, Copy)]
 328enum ClaudeSessionMode {
 329    Start,
 330    #[expect(dead_code)]
 331    Resume,
 332}
 333
 334fn spawn_claude(
 335    command: &AgentServerCommand,
 336    mode: ClaudeSessionMode,
 337    session_id: acp::SessionId,
 338    api_key: language_models::provider::anthropic::ApiKey,
 339    mcp_config_path: &Path,
 340    root_dir: &Path,
 341) -> Result<Child> {
 342    let child = util::command::new_smol_command(&command.path)
 343        .args([
 344            "--input-format",
 345            "stream-json",
 346            "--output-format",
 347            "stream-json",
 348            "--print",
 349            "--verbose",
 350            "--mcp-config",
 351            mcp_config_path.to_string_lossy().as_ref(),
 352            "--permission-prompt-tool",
 353            &format!(
 354                "mcp__{}__{}",
 355                mcp_server::SERVER_NAME,
 356                mcp_server::PermissionTool::NAME,
 357            ),
 358            "--allowedTools",
 359            &format!(
 360                "mcp__{}__{},mcp__{}__{}",
 361                mcp_server::SERVER_NAME,
 362                mcp_server::EditTool::NAME,
 363                mcp_server::SERVER_NAME,
 364                mcp_server::ReadTool::NAME
 365            ),
 366            "--disallowedTools",
 367            "Read,Edit",
 368        ])
 369        .args(match mode {
 370            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 371            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 372        })
 373        .args(command.args.iter().map(|arg| arg.as_str()))
 374        .envs(command.env.iter().flatten())
 375        .env("ANTHROPIC_API_KEY", api_key.key)
 376        .current_dir(root_dir)
 377        .stdin(std::process::Stdio::piped())
 378        .stdout(std::process::Stdio::piped())
 379        .stderr(std::process::Stdio::piped())
 380        .kill_on_drop(true)
 381        .spawn()?;
 382
 383    Ok(child)
 384}
 385
 386struct ClaudeAgentSession {
 387    outgoing_tx: UnboundedSender<SdkMessage>,
 388    turn_state: Rc<RefCell<TurnState>>,
 389    _mcp_server: Option<ClaudeZedMcpServer>,
 390    _handler_task: Task<()>,
 391}
 392
 393#[derive(Debug, Default)]
 394enum TurnState {
 395    #[default]
 396    None,
 397    InProgress {
 398        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 399    },
 400    CancelRequested {
 401        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 402        request_id: String,
 403    },
 404    CancelConfirmed {
 405        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 406    },
 407}
 408
 409impl TurnState {
 410    fn is_canceled(&self) -> bool {
 411        matches!(self, TurnState::CancelConfirmed { .. })
 412    }
 413
 414    fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
 415        match self {
 416            TurnState::None => None,
 417            TurnState::InProgress { end_tx, .. } => Some(end_tx),
 418            TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
 419            TurnState::CancelConfirmed { end_tx } => Some(end_tx),
 420        }
 421    }
 422
 423    fn confirm_cancellation(self, id: &str) -> Self {
 424        match self {
 425            TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
 426                TurnState::CancelConfirmed { end_tx }
 427            }
 428            _ => self,
 429        }
 430    }
 431}
 432
 433impl ClaudeAgentSession {
 434    async fn handle_message(
 435        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 436        message: SdkMessage,
 437        turn_state: Rc<RefCell<TurnState>>,
 438        cx: &mut AsyncApp,
 439    ) {
 440        match message {
 441            // we should only be sending these out, they don't need to be in the thread
 442            SdkMessage::ControlRequest { .. } => {}
 443            SdkMessage::User {
 444                message,
 445                session_id: _,
 446            } => {
 447                let Some(thread) = thread_rx
 448                    .recv()
 449                    .await
 450                    .log_err()
 451                    .and_then(|entity| entity.upgrade())
 452                else {
 453                    log::error!("Received an SDK message but thread is gone");
 454                    return;
 455                };
 456
 457                for chunk in message.content.chunks() {
 458                    match chunk {
 459                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 460                            if !turn_state.borrow().is_canceled() {
 461                                thread
 462                                    .update(cx, |thread, cx| {
 463                                        thread.push_user_content_block(None, text.into(), cx)
 464                                    })
 465                                    .log_err();
 466                            }
 467                        }
 468                        ContentChunk::ToolResult {
 469                            content,
 470                            tool_use_id,
 471                        } => {
 472                            let content = content.to_string();
 473                            thread
 474                                .update(cx, |thread, cx| {
 475                                    thread.update_tool_call(
 476                                        acp::ToolCallUpdate {
 477                                            id: acp::ToolCallId(tool_use_id.into()),
 478                                            fields: acp::ToolCallUpdateFields {
 479                                                status: if turn_state.borrow().is_canceled() {
 480                                                    // Do not set to completed if turn was canceled
 481                                                    None
 482                                                } else {
 483                                                    Some(acp::ToolCallStatus::Completed)
 484                                                },
 485                                                content: (!content.is_empty())
 486                                                    .then(|| vec![content.into()]),
 487                                                ..Default::default()
 488                                            },
 489                                        },
 490                                        cx,
 491                                    )
 492                                })
 493                                .log_err();
 494                        }
 495                        ContentChunk::Thinking { .. }
 496                        | ContentChunk::RedactedThinking
 497                        | ContentChunk::ToolUse { .. } => {
 498                            debug_panic!(
 499                                "Should not get {:?} with role: assistant. should we handle this?",
 500                                chunk
 501                            );
 502                        }
 503
 504                        ContentChunk::Image
 505                        | ContentChunk::Document
 506                        | ContentChunk::WebSearchToolResult => {
 507                            thread
 508                                .update(cx, |thread, cx| {
 509                                    thread.push_assistant_content_block(
 510                                        format!("Unsupported content: {:?}", chunk).into(),
 511                                        false,
 512                                        cx,
 513                                    )
 514                                })
 515                                .log_err();
 516                        }
 517                    }
 518                }
 519            }
 520            SdkMessage::Assistant {
 521                message,
 522                session_id: _,
 523            } => {
 524                let Some(thread) = thread_rx
 525                    .recv()
 526                    .await
 527                    .log_err()
 528                    .and_then(|entity| entity.upgrade())
 529                else {
 530                    log::error!("Received an SDK message but thread is gone");
 531                    return;
 532                };
 533
 534                for chunk in message.content.chunks() {
 535                    match chunk {
 536                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 537                            thread
 538                                .update(cx, |thread, cx| {
 539                                    thread.push_assistant_content_block(text.into(), false, cx)
 540                                })
 541                                .log_err();
 542                        }
 543                        ContentChunk::Thinking { thinking } => {
 544                            thread
 545                                .update(cx, |thread, cx| {
 546                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 547                                })
 548                                .log_err();
 549                        }
 550                        ContentChunk::RedactedThinking => {
 551                            thread
 552                                .update(cx, |thread, cx| {
 553                                    thread.push_assistant_content_block(
 554                                        "[REDACTED]".into(),
 555                                        true,
 556                                        cx,
 557                                    )
 558                                })
 559                                .log_err();
 560                        }
 561                        ContentChunk::ToolUse { id, name, input } => {
 562                            let claude_tool = ClaudeTool::infer(&name, input);
 563
 564                            thread
 565                                .update(cx, |thread, cx| {
 566                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 567                                        thread.update_plan(
 568                                            acp::Plan {
 569                                                entries: params
 570                                                    .todos
 571                                                    .into_iter()
 572                                                    .map(Into::into)
 573                                                    .collect(),
 574                                            },
 575                                            cx,
 576                                        )
 577                                    } else {
 578                                        thread.upsert_tool_call(
 579                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 580                                            cx,
 581                                        )?;
 582                                    }
 583                                    anyhow::Ok(())
 584                                })
 585                                .log_err();
 586                        }
 587                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 588                            debug_panic!(
 589                                "Should not get tool results with role: assistant. should we handle this?"
 590                            );
 591                        }
 592                        ContentChunk::Image | ContentChunk::Document => {
 593                            thread
 594                                .update(cx, |thread, cx| {
 595                                    thread.push_assistant_content_block(
 596                                        format!("Unsupported content: {:?}", chunk).into(),
 597                                        false,
 598                                        cx,
 599                                    )
 600                                })
 601                                .log_err();
 602                        }
 603                    }
 604                }
 605            }
 606            SdkMessage::Result {
 607                is_error,
 608                subtype,
 609                result,
 610                ..
 611            } => {
 612                let turn_state = turn_state.take();
 613                let was_canceled = turn_state.is_canceled();
 614                let Some(end_turn_tx) = turn_state.end_tx() else {
 615                    debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
 616                    return;
 617                };
 618
 619                if is_error || (!was_canceled && subtype == ResultErrorType::ErrorDuringExecution) {
 620                    end_turn_tx
 621                        .send(Err(anyhow!(
 622                            "Error: {}",
 623                            result.unwrap_or_else(|| subtype.to_string())
 624                        )))
 625                        .ok();
 626                } else {
 627                    let stop_reason = match subtype {
 628                        ResultErrorType::Success => acp::StopReason::EndTurn,
 629                        ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 630                        ResultErrorType::ErrorDuringExecution => acp::StopReason::Canceled,
 631                    };
 632                    end_turn_tx
 633                        .send(Ok(acp::PromptResponse { stop_reason }))
 634                        .ok();
 635                }
 636            }
 637            SdkMessage::ControlResponse { response } => {
 638                if matches!(response.subtype, ResultErrorType::Success) {
 639                    let new_state = turn_state.take().confirm_cancellation(&response.request_id);
 640                    turn_state.replace(new_state);
 641                } else {
 642                    log::error!("Control response error: {:?}", response);
 643                }
 644            }
 645            SdkMessage::System { .. } => {}
 646        }
 647    }
 648
 649    async fn handle_io(
 650        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 651        incoming_tx: UnboundedSender<SdkMessage>,
 652        mut outgoing_bytes: impl Unpin + AsyncWrite,
 653        incoming_bytes: impl Unpin + AsyncRead,
 654    ) -> Result<UnboundedReceiver<SdkMessage>> {
 655        let mut output_reader = BufReader::new(incoming_bytes);
 656        let mut outgoing_line = Vec::new();
 657        let mut incoming_line = String::new();
 658        loop {
 659            select_biased! {
 660                message = outgoing_rx.next() => {
 661                    if let Some(message) = message {
 662                        outgoing_line.clear();
 663                        serde_json::to_writer(&mut outgoing_line, &message)?;
 664                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 665                        outgoing_line.push(b'\n');
 666                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 667                    } else {
 668                        break;
 669                    }
 670                }
 671                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
 672                    if bytes_read? == 0 {
 673                        break
 674                    }
 675                    log::trace!("recv: {}", &incoming_line);
 676                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 677                        Ok(message) => {
 678                            incoming_tx.unbounded_send(message).log_err();
 679                        }
 680                        Err(error) => {
 681                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 682                        }
 683                    }
 684                    incoming_line.clear();
 685                }
 686            }
 687        }
 688
 689        Ok(outgoing_rx)
 690    }
 691}
 692
 693#[derive(Debug, Clone, Serialize, Deserialize)]
 694struct Message {
 695    role: Role,
 696    content: Content,
 697    #[serde(skip_serializing_if = "Option::is_none")]
 698    id: Option<String>,
 699    #[serde(skip_serializing_if = "Option::is_none")]
 700    model: Option<String>,
 701    #[serde(skip_serializing_if = "Option::is_none")]
 702    stop_reason: Option<String>,
 703    #[serde(skip_serializing_if = "Option::is_none")]
 704    stop_sequence: Option<String>,
 705    #[serde(skip_serializing_if = "Option::is_none")]
 706    usage: Option<Usage>,
 707}
 708
 709#[derive(Debug, Clone, Serialize, Deserialize)]
 710#[serde(untagged)]
 711enum Content {
 712    UntaggedText(String),
 713    Chunks(Vec<ContentChunk>),
 714}
 715
 716impl Content {
 717    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 718        match self {
 719            Self::Chunks(chunks) => chunks.into_iter(),
 720            Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
 721        }
 722    }
 723}
 724
 725impl Display for Content {
 726    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 727        match self {
 728            Content::UntaggedText(txt) => write!(f, "{}", txt),
 729            Content::Chunks(chunks) => {
 730                for chunk in chunks {
 731                    write!(f, "{}", chunk)?;
 732                }
 733                Ok(())
 734            }
 735        }
 736    }
 737}
 738
 739#[derive(Debug, Clone, Serialize, Deserialize)]
 740#[serde(tag = "type", rename_all = "snake_case")]
 741enum ContentChunk {
 742    Text {
 743        text: String,
 744    },
 745    ToolUse {
 746        id: String,
 747        name: String,
 748        input: serde_json::Value,
 749    },
 750    ToolResult {
 751        content: Content,
 752        tool_use_id: String,
 753    },
 754    Thinking {
 755        thinking: String,
 756    },
 757    RedactedThinking,
 758    // TODO
 759    Image,
 760    Document,
 761    WebSearchToolResult,
 762    #[serde(untagged)]
 763    UntaggedText(String),
 764}
 765
 766impl Display for ContentChunk {
 767    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 768        match self {
 769            ContentChunk::Text { text } => write!(f, "{}", text),
 770            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 771            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 772            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 773            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 774            ContentChunk::Image
 775            | ContentChunk::Document
 776            | ContentChunk::ToolUse { .. }
 777            | ContentChunk::WebSearchToolResult => {
 778                write!(f, "\n{:?}\n", &self)
 779            }
 780        }
 781    }
 782}
 783
 784#[derive(Debug, Clone, Serialize, Deserialize)]
 785struct Usage {
 786    input_tokens: u32,
 787    cache_creation_input_tokens: u32,
 788    cache_read_input_tokens: u32,
 789    output_tokens: u32,
 790    service_tier: String,
 791}
 792
 793#[derive(Debug, Clone, Serialize, Deserialize)]
 794#[serde(rename_all = "snake_case")]
 795enum Role {
 796    System,
 797    Assistant,
 798    User,
 799}
 800
 801#[derive(Debug, Clone, Serialize, Deserialize)]
 802struct MessageParam {
 803    role: Role,
 804    content: String,
 805}
 806
/// One message of the Claude Code SDK protocol.
///
/// Discriminated by its JSON `"type"` field, which holds the snake_case variant
/// name: `assistant`, `user`, `result`, `system`, `control_request` or
/// `control_response`. `rename_all` applies to variant names only; fields keep
/// their Rust names unless explicitly renamed.
///
/// NOTE(review): presumably each value corresponds to one line of the CLI's
/// JSON message stream — confirm against the reader/writer code.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    /// An assistant message.
    Assistant {
        message: Message, // from Anthropic SDK
        // Omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// A user message.
    User {
        message: Message, // from Anthropic SDK
        // Omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// Emitted as the last message in a conversation.
    ///
    /// `subtype` reports success or the kind of failure; `result` is omitted
    /// from serialized output when `None`.
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    /// Emitted as the first message at the start of a conversation.
    ///
    /// `apiKeySource` and `permissionMode` keep their camelCase JSON names via
    /// explicit renames; the remaining keys match the field names.
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to the model
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request
    ControlResponse { response: ControlResponse },
}
 854
/// Body of a `control_request` [`SdkMessage`] (its `request` field).
///
/// Discriminated by the JSON `"subtype"` field, holding the snake_case variant
/// name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation. Serializes as `{"subtype": "interrupt"}`.
    Interrupt,
}
 861
/// Body of a `control_response` [`SdkMessage`].
///
/// `request_id` presumably echoes the id of the control request being answered,
/// and `subtype` carries its outcome. NOTE(review): because `subtype` is a
/// [`ResultErrorType`], only `success`, `error_max_turns` and
/// `error_during_execution` deserialize — confirm that this covers every
/// subtype the CLI sends.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
    request_id: String,
    subtype: ResultErrorType,
}
 867
/// Outcome reported in the `subtype` of `result` messages and control responses.
///
/// Serialized as snake_case: `success`, `error_max_turns`,
/// `error_during_execution`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
 875
 876impl Display for ResultErrorType {
 877    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 878        match self {
 879            ResultErrorType::Success => write!(f, "success"),
 880            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 881            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 882        }
 883    }
 884}
 885
 886fn new_request_id() -> String {
 887    use rand::Rng;
 888    // In the Claude Code TS SDK they just generate a random 12 character string,
 889    // `Math.random().toString(36).substring(2, 15)`
 890    rand::thread_rng()
 891        .sample_iter(&rand::distributions::Alphanumeric)
 892        .take(12)
 893        .map(char::from)
 894        .collect()
 895}
 896
/// An MCP server entry as listed in the `mcp_servers` field of a `system`
/// [`SdkMessage`].
///
/// `status` is kept as an uninterpreted string.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    name: String,
    status: String,
}
 902
/// Permission mode reported in `system` messages (JSON key `permissionMode`).
///
/// Variants serialize in camelCase: `default`, `acceptEdits`,
/// `bypassPermissions`, `plan`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum PermissionMode {
    Default,
    AcceptEdits,
    BypassPermissions,
    Plan,
}
 911
 912#[cfg(test)]
 913pub(crate) mod tests {
 914    use super::*;
 915    use crate::e2e_tests;
 916    use gpui::TestAppContext;
 917    use serde_json::json;
 918
 919    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
 920
 921    pub fn local_command() -> AgentServerCommand {
 922        AgentServerCommand {
 923            path: "claude".into(),
 924            args: vec![],
 925            env: None,
 926        }
 927    }
 928
 929    #[gpui::test]
 930    #[cfg_attr(not(feature = "e2e"), ignore)]
 931    async fn test_todo_plan(cx: &mut TestAppContext) {
 932        let fs = e2e_tests::init_test(cx).await;
 933        let project = Project::test(fs, [], cx).await;
 934        let thread =
 935            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
 936
 937        thread
 938            .update(cx, |thread, cx| {
 939                thread.send_raw(
 940                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
 941                    cx,
 942                )
 943            })
 944            .await
 945            .unwrap();
 946
 947        let mut entries_len = 0;
 948
 949        thread.read_with(cx, |thread, _| {
 950            entries_len = thread.plan().entries.len();
 951            assert!(thread.plan().entries.len() > 0, "Empty plan");
 952        });
 953
 954        thread
 955            .update(cx, |thread, cx| {
 956                thread.send_raw(
 957                    "Mark the first entry status as in progress without acting on it.",
 958                    cx,
 959                )
 960            })
 961            .await
 962            .unwrap();
 963
 964        thread.read_with(cx, |thread, _| {
 965            assert!(matches!(
 966                thread.plan().entries[0].status,
 967                acp::PlanEntryStatus::InProgress
 968            ));
 969            assert_eq!(thread.plan().entries.len(), entries_len);
 970        });
 971
 972        thread
 973            .update(cx, |thread, cx| {
 974                thread.send_raw(
 975                    "Now mark the first entry as completed without acting on it.",
 976                    cx,
 977                )
 978            })
 979            .await
 980            .unwrap();
 981
 982        thread.read_with(cx, |thread, _| {
 983            assert!(matches!(
 984                thread.plan().entries[0].status,
 985                acp::PlanEntryStatus::Completed
 986            ));
 987            assert_eq!(thread.plan().entries.len(), entries_len);
 988        });
 989    }
 990
 991    #[test]
 992    fn test_deserialize_content_untagged_text() {
 993        let json = json!("Hello, world!");
 994        let content: Content = serde_json::from_value(json).unwrap();
 995        match content {
 996            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
 997            _ => panic!("Expected UntaggedText variant"),
 998        }
 999    }
1000
1001    #[test]
1002    fn test_deserialize_content_chunks() {
1003        let json = json!([
1004            {
1005                "type": "text",
1006                "text": "Hello"
1007            },
1008            {
1009                "type": "tool_use",
1010                "id": "tool_123",
1011                "name": "calculator",
1012                "input": {"operation": "add", "a": 1, "b": 2}
1013            }
1014        ]);
1015        let content: Content = serde_json::from_value(json).unwrap();
1016        match content {
1017            Content::Chunks(chunks) => {
1018                assert_eq!(chunks.len(), 2);
1019                match &chunks[0] {
1020                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
1021                    _ => panic!("Expected Text chunk"),
1022                }
1023                match &chunks[1] {
1024                    ContentChunk::ToolUse { id, name, input } => {
1025                        assert_eq!(id, "tool_123");
1026                        assert_eq!(name, "calculator");
1027                        assert_eq!(input["operation"], "add");
1028                        assert_eq!(input["a"], 1);
1029                        assert_eq!(input["b"], 2);
1030                    }
1031                    _ => panic!("Expected ToolUse chunk"),
1032                }
1033            }
1034            _ => panic!("Expected Chunks variant"),
1035        }
1036    }
1037
1038    #[test]
1039    fn test_deserialize_tool_result_untagged_text() {
1040        let json = json!({
1041            "type": "tool_result",
1042            "content": "Result content",
1043            "tool_use_id": "tool_456"
1044        });
1045        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1046        match chunk {
1047            ContentChunk::ToolResult {
1048                content,
1049                tool_use_id,
1050            } => {
1051                match content {
1052                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
1053                    _ => panic!("Expected UntaggedText content"),
1054                }
1055                assert_eq!(tool_use_id, "tool_456");
1056            }
1057            _ => panic!("Expected ToolResult variant"),
1058        }
1059    }
1060
1061    #[test]
1062    fn test_deserialize_tool_result_chunks() {
1063        let json = json!({
1064            "type": "tool_result",
1065            "content": [
1066                {
1067                    "type": "text",
1068                    "text": "Processing complete"
1069                },
1070                {
1071                    "type": "text",
1072                    "text": "Result: 42"
1073                }
1074            ],
1075            "tool_use_id": "tool_789"
1076        });
1077        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1078        match chunk {
1079            ContentChunk::ToolResult {
1080                content,
1081                tool_use_id,
1082            } => {
1083                match content {
1084                    Content::Chunks(chunks) => {
1085                        assert_eq!(chunks.len(), 2);
1086                        match &chunks[0] {
1087                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
1088                            _ => panic!("Expected Text chunk"),
1089                        }
1090                        match &chunks[1] {
1091                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
1092                            _ => panic!("Expected Text chunk"),
1093                        }
1094                    }
1095                    _ => panic!("Expected Chunks content"),
1096                }
1097                assert_eq!(tool_use_id, "tool_789");
1098            }
1099            _ => panic!("Expected ToolResult variant"),
1100        }
1101    }
1102}