claude.rs

   1mod mcp_server;
   2pub mod tools;
   3
   4use collections::HashMap;
   5use context_server::listener::McpServerTool;
   6use language_models::provider::anthropic::AnthropicLanguageModelProvider;
   7use project::Project;
   8use settings::SettingsStore;
   9use smol::process::Child;
  10use std::any::Any;
  11use std::cell::RefCell;
  12use std::fmt::Display;
  13use std::path::Path;
  14use std::rc::Rc;
  15use uuid::Uuid;
  16
  17use agent_client_protocol as acp;
  18use anyhow::{Context as _, Result, anyhow};
  19use futures::channel::oneshot;
  20use futures::{AsyncBufReadExt, AsyncWriteExt};
  21use futures::{
  22    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  23    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  24    io::BufReader,
  25    select_biased,
  26};
  27use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  28use serde::{Deserialize, Serialize};
  29use util::{ResultExt, debug_panic};
  30
  31use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  32use crate::claude::tools::ClaudeTool;
  33use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  34use acp_thread::{AcpThread, AgentConnection, AuthRequired};
  35
  36#[derive(Clone)]
  37pub struct ClaudeCode;
  38
  39impl AgentServer for ClaudeCode {
  40    fn name(&self) -> &'static str {
  41        "Claude Code"
  42    }
  43
  44    fn empty_state_headline(&self) -> &'static str {
  45        self.name()
  46    }
  47
  48    fn empty_state_message(&self) -> &'static str {
  49        "How can I help you today?"
  50    }
  51
  52    fn logo(&self) -> ui::IconName {
  53        ui::IconName::AiClaude
  54    }
  55
  56    fn connect(
  57        &self,
  58        _root_dir: &Path,
  59        _project: &Entity<Project>,
  60        _cx: &mut App,
  61    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  62        let connection = ClaudeAgentConnection {
  63            sessions: Default::default(),
  64        };
  65
  66        Task::ready(Ok(Rc::new(connection) as _))
  67    }
  68
  69    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
  70        self
  71    }
  72}
  73
  74struct ClaudeAgentConnection {
  75    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
  76}
  77
  78impl AgentConnection for ClaudeAgentConnection {
  79    fn new_thread(
  80        self: Rc<Self>,
  81        project: Entity<Project>,
  82        cwd: &Path,
  83        cx: &mut App,
  84    ) -> Task<Result<Entity<AcpThread>>> {
  85        let cwd = cwd.to_owned();
  86        cx.spawn(async move |cx| {
  87            let settings = cx.read_global(|settings: &SettingsStore, _| {
  88                settings.get::<AllAgentServersSettings>(None).claude.clone()
  89            })?;
  90
  91            let Some(command) = AgentServerCommand::resolve(
  92                "claude",
  93                &[],
  94                Some(&util::paths::home_dir().join(".claude/local/claude")),
  95                settings,
  96                &project,
  97                cx,
  98            )
  99            .await
 100            else {
 101                anyhow::bail!("Failed to find claude binary");
 102            };
 103
 104            let api_key =
 105                cx.update(AnthropicLanguageModelProvider::api_key)?
 106                    .await
 107                    .map_err(|err| {
 108                        if err.is::<language_model::AuthenticateError>() {
 109                            anyhow!(AuthRequired::new().with_language_model_provider(
 110                                language_model::ANTHROPIC_PROVIDER_ID
 111                            ))
 112                        } else {
 113                            anyhow!(err)
 114                        }
 115                    })?;
 116
 117            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
 118            let fs = project.read_with(cx, |project, _cx| project.fs().clone())?;
 119            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), fs, cx).await?;
 120
 121            let mut mcp_servers = HashMap::default();
 122            mcp_servers.insert(
 123                mcp_server::SERVER_NAME.to_string(),
 124                permission_mcp_server.server_config()?,
 125            );
 126            let mcp_config = McpConfig { mcp_servers };
 127
 128            let mcp_config_file = tempfile::NamedTempFile::new()?;
 129            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
 130
 131            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
 132            mcp_config_file
 133                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
 134                .await?;
 135            mcp_config_file.flush().await?;
 136
 137            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
 138            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 139
 140            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
 141
 142            log::trace!("Starting session with id: {}", session_id);
 143
 144            let mut child = spawn_claude(
 145                &command,
 146                ClaudeSessionMode::Start,
 147                session_id.clone(),
 148                api_key,
 149                &mcp_config_path,
 150                &cwd,
 151            )?;
 152
 153            let stdout = child.stdout.take().context("Failed to take stdout")?;
 154            let stdin = child.stdin.take().context("Failed to take stdin")?;
 155            let stderr = child.stderr.take().context("Failed to take stderr")?;
 156
 157            let pid = child.id();
 158            log::trace!("Spawned (pid: {})", pid);
 159
 160            cx.background_spawn(async move {
 161                let mut stderr = BufReader::new(stderr);
 162                let mut line = String::new();
 163                while let Ok(n) = stderr.read_line(&mut line).await
 164                    && n > 0
 165                {
 166                    log::warn!("agent stderr: {}", &line);
 167                    line.clear();
 168                }
 169            })
 170            .detach();
 171
 172            cx.background_spawn(async move {
 173                let mut outgoing_rx = Some(outgoing_rx);
 174
 175                ClaudeAgentSession::handle_io(
 176                    outgoing_rx.take().unwrap(),
 177                    incoming_message_tx.clone(),
 178                    stdin,
 179                    stdout,
 180                )
 181                .await?;
 182
 183                log::trace!("Stopped (pid: {})", pid);
 184
 185                drop(mcp_config_path);
 186                anyhow::Ok(())
 187            })
 188            .detach();
 189
 190            let turn_state = Rc::new(RefCell::new(TurnState::None));
 191
 192            let handler_task = cx.spawn({
 193                let turn_state = turn_state.clone();
 194                let mut thread_rx = thread_rx.clone();
 195                async move |cx| {
 196                    while let Some(message) = incoming_message_rx.next().await {
 197                        ClaudeAgentSession::handle_message(
 198                            thread_rx.clone(),
 199                            message,
 200                            turn_state.clone(),
 201                            cx,
 202                        )
 203                        .await
 204                    }
 205
 206                    if let Some(status) = child.status().await.log_err() {
 207                        if let Some(thread) = thread_rx.recv().await.ok() {
 208                            thread
 209                                .update(cx, |thread, cx| {
 210                                    thread.emit_server_exited(status, cx);
 211                                })
 212                                .ok();
 213                        }
 214                    }
 215                }
 216            });
 217
 218            let thread = cx.new(|cx| {
 219                AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
 220            })?;
 221
 222            thread_tx.send(thread.downgrade())?;
 223
 224            let session = ClaudeAgentSession {
 225                outgoing_tx,
 226                turn_state,
 227                _handler_task: handler_task,
 228                _mcp_server: Some(permission_mcp_server),
 229            };
 230
 231            self.sessions.borrow_mut().insert(session_id, session);
 232
 233            Ok(thread)
 234        })
 235    }
 236
 237    fn auth_methods(&self) -> &[acp::AuthMethod] {
 238        &[]
 239    }
 240
 241    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 242        Task::ready(Err(anyhow!("Authentication not supported")))
 243    }
 244
 245    fn prompt(
 246        &self,
 247        _id: Option<acp_thread::UserMessageId>,
 248        params: acp::PromptRequest,
 249        cx: &mut App,
 250    ) -> Task<Result<acp::PromptResponse>> {
 251        let sessions = self.sessions.borrow();
 252        let Some(session) = sessions.get(&params.session_id) else {
 253            return Task::ready(Err(anyhow!(
 254                "Attempted to send message to nonexistent session {}",
 255                params.session_id
 256            )));
 257        };
 258
 259        let (end_tx, end_rx) = oneshot::channel();
 260        session.turn_state.replace(TurnState::InProgress { end_tx });
 261
 262        let mut content = String::new();
 263        for chunk in params.prompt {
 264            match chunk {
 265                acp::ContentBlock::Text(text_content) => {
 266                    content.push_str(&text_content.text);
 267                }
 268                acp::ContentBlock::ResourceLink(resource_link) => {
 269                    content.push_str(&format!("@{}", resource_link.uri));
 270                }
 271                acp::ContentBlock::Audio(_)
 272                | acp::ContentBlock::Image(_)
 273                | acp::ContentBlock::Resource(_) => {
 274                    // TODO
 275                }
 276            }
 277        }
 278
 279        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 280            message: Message {
 281                role: Role::User,
 282                content: Content::UntaggedText(content),
 283                id: None,
 284                model: None,
 285                stop_reason: None,
 286                stop_sequence: None,
 287                usage: None,
 288            },
 289            session_id: Some(params.session_id.to_string()),
 290        }) {
 291            return Task::ready(Err(anyhow!(err)));
 292        }
 293
 294        cx.foreground_executor().spawn(async move { end_rx.await? })
 295    }
 296
 297    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 298        let sessions = self.sessions.borrow();
 299        let Some(session) = sessions.get(session_id) else {
 300            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 301            return;
 302        };
 303
 304        let request_id = new_request_id();
 305
 306        let turn_state = session.turn_state.take();
 307        let TurnState::InProgress { end_tx } = turn_state else {
 308            // Already canceled or idle, put it back
 309            session.turn_state.replace(turn_state);
 310            return;
 311        };
 312
 313        session.turn_state.replace(TurnState::CancelRequested {
 314            end_tx,
 315            request_id: request_id.clone(),
 316        });
 317
 318        session
 319            .outgoing_tx
 320            .unbounded_send(SdkMessage::ControlRequest {
 321                request_id,
 322                request: ControlRequest::Interrupt,
 323            })
 324            .log_err();
 325    }
 326
 327    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 328        self
 329    }
 330}
 331
 332#[derive(Clone, Copy)]
 333enum ClaudeSessionMode {
 334    Start,
 335    #[expect(dead_code)]
 336    Resume,
 337}
 338
 339fn spawn_claude(
 340    command: &AgentServerCommand,
 341    mode: ClaudeSessionMode,
 342    session_id: acp::SessionId,
 343    api_key: language_models::provider::anthropic::ApiKey,
 344    mcp_config_path: &Path,
 345    root_dir: &Path,
 346) -> Result<Child> {
 347    let child = util::command::new_smol_command(&command.path)
 348        .args([
 349            "--input-format",
 350            "stream-json",
 351            "--output-format",
 352            "stream-json",
 353            "--print",
 354            "--verbose",
 355            "--mcp-config",
 356            mcp_config_path.to_string_lossy().as_ref(),
 357            "--permission-prompt-tool",
 358            &format!(
 359                "mcp__{}__{}",
 360                mcp_server::SERVER_NAME,
 361                mcp_server::PermissionTool::NAME,
 362            ),
 363            "--allowedTools",
 364            &format!(
 365                "mcp__{}__{},mcp__{}__{}",
 366                mcp_server::SERVER_NAME,
 367                mcp_server::EditTool::NAME,
 368                mcp_server::SERVER_NAME,
 369                mcp_server::ReadTool::NAME
 370            ),
 371            "--disallowedTools",
 372            "Read,Edit",
 373        ])
 374        .args(match mode {
 375            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 376            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 377        })
 378        .args(command.args.iter().map(|arg| arg.as_str()))
 379        .envs(command.env.iter().flatten())
 380        .env("ANTHROPIC_API_KEY", api_key.key)
 381        .current_dir(root_dir)
 382        .stdin(std::process::Stdio::piped())
 383        .stdout(std::process::Stdio::piped())
 384        .stderr(std::process::Stdio::piped())
 385        .kill_on_drop(true)
 386        .spawn()?;
 387
 388    Ok(child)
 389}
 390
 391struct ClaudeAgentSession {
 392    outgoing_tx: UnboundedSender<SdkMessage>,
 393    turn_state: Rc<RefCell<TurnState>>,
 394    _mcp_server: Option<ClaudeZedMcpServer>,
 395    _handler_task: Task<()>,
 396}
 397
 398#[derive(Debug, Default)]
 399enum TurnState {
 400    #[default]
 401    None,
 402    InProgress {
 403        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 404    },
 405    CancelRequested {
 406        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 407        request_id: String,
 408    },
 409    CancelConfirmed {
 410        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 411    },
 412}
 413
 414impl TurnState {
 415    fn is_canceled(&self) -> bool {
 416        matches!(self, TurnState::CancelConfirmed { .. })
 417    }
 418
 419    fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
 420        match self {
 421            TurnState::None => None,
 422            TurnState::InProgress { end_tx, .. } => Some(end_tx),
 423            TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
 424            TurnState::CancelConfirmed { end_tx } => Some(end_tx),
 425        }
 426    }
 427
 428    fn confirm_cancellation(self, id: &str) -> Self {
 429        match self {
 430            TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
 431                TurnState::CancelConfirmed { end_tx }
 432            }
 433            _ => self,
 434        }
 435    }
 436}
 437
 438impl ClaudeAgentSession {
 439    async fn handle_message(
 440        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 441        message: SdkMessage,
 442        turn_state: Rc<RefCell<TurnState>>,
 443        cx: &mut AsyncApp,
 444    ) {
 445        match message {
 446            // we should only be sending these out, they don't need to be in the thread
 447            SdkMessage::ControlRequest { .. } => {}
 448            SdkMessage::User {
 449                message,
 450                session_id: _,
 451            } => {
 452                let Some(thread) = thread_rx
 453                    .recv()
 454                    .await
 455                    .log_err()
 456                    .and_then(|entity| entity.upgrade())
 457                else {
 458                    log::error!("Received an SDK message but thread is gone");
 459                    return;
 460                };
 461
 462                for chunk in message.content.chunks() {
 463                    match chunk {
 464                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 465                            if !turn_state.borrow().is_canceled() {
 466                                thread
 467                                    .update(cx, |thread, cx| {
 468                                        thread.push_user_content_block(None, text.into(), cx)
 469                                    })
 470                                    .log_err();
 471                            }
 472                        }
 473                        ContentChunk::ToolResult {
 474                            content,
 475                            tool_use_id,
 476                        } => {
 477                            let content = content.to_string();
 478                            thread
 479                                .update(cx, |thread, cx| {
 480                                    thread.update_tool_call(
 481                                        acp::ToolCallUpdate {
 482                                            id: acp::ToolCallId(tool_use_id.into()),
 483                                            fields: acp::ToolCallUpdateFields {
 484                                                status: if turn_state.borrow().is_canceled() {
 485                                                    // Do not set to completed if turn was canceled
 486                                                    None
 487                                                } else {
 488                                                    Some(acp::ToolCallStatus::Completed)
 489                                                },
 490                                                content: (!content.is_empty())
 491                                                    .then(|| vec![content.into()]),
 492                                                ..Default::default()
 493                                            },
 494                                        },
 495                                        cx,
 496                                    )
 497                                })
 498                                .log_err();
 499                        }
 500                        ContentChunk::Thinking { .. }
 501                        | ContentChunk::RedactedThinking
 502                        | ContentChunk::ToolUse { .. } => {
 503                            debug_panic!(
 504                                "Should not get {:?} with role: assistant. should we handle this?",
 505                                chunk
 506                            );
 507                        }
 508
 509                        ContentChunk::Image
 510                        | ContentChunk::Document
 511                        | ContentChunk::WebSearchToolResult => {
 512                            thread
 513                                .update(cx, |thread, cx| {
 514                                    thread.push_assistant_content_block(
 515                                        format!("Unsupported content: {:?}", chunk).into(),
 516                                        false,
 517                                        cx,
 518                                    )
 519                                })
 520                                .log_err();
 521                        }
 522                    }
 523                }
 524            }
 525            SdkMessage::Assistant {
 526                message,
 527                session_id: _,
 528            } => {
 529                let Some(thread) = thread_rx
 530                    .recv()
 531                    .await
 532                    .log_err()
 533                    .and_then(|entity| entity.upgrade())
 534                else {
 535                    log::error!("Received an SDK message but thread is gone");
 536                    return;
 537                };
 538
 539                for chunk in message.content.chunks() {
 540                    match chunk {
 541                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 542                            thread
 543                                .update(cx, |thread, cx| {
 544                                    thread.push_assistant_content_block(text.into(), false, cx)
 545                                })
 546                                .log_err();
 547                        }
 548                        ContentChunk::Thinking { thinking } => {
 549                            thread
 550                                .update(cx, |thread, cx| {
 551                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 552                                })
 553                                .log_err();
 554                        }
 555                        ContentChunk::RedactedThinking => {
 556                            thread
 557                                .update(cx, |thread, cx| {
 558                                    thread.push_assistant_content_block(
 559                                        "[REDACTED]".into(),
 560                                        true,
 561                                        cx,
 562                                    )
 563                                })
 564                                .log_err();
 565                        }
 566                        ContentChunk::ToolUse { id, name, input } => {
 567                            let claude_tool = ClaudeTool::infer(&name, input);
 568
 569                            thread
 570                                .update(cx, |thread, cx| {
 571                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 572                                        thread.update_plan(
 573                                            acp::Plan {
 574                                                entries: params
 575                                                    .todos
 576                                                    .into_iter()
 577                                                    .map(Into::into)
 578                                                    .collect(),
 579                                            },
 580                                            cx,
 581                                        )
 582                                    } else {
 583                                        thread.upsert_tool_call(
 584                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 585                                            cx,
 586                                        )?;
 587                                    }
 588                                    anyhow::Ok(())
 589                                })
 590                                .log_err();
 591                        }
 592                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 593                            debug_panic!(
 594                                "Should not get tool results with role: assistant. should we handle this?"
 595                            );
 596                        }
 597                        ContentChunk::Image | ContentChunk::Document => {
 598                            thread
 599                                .update(cx, |thread, cx| {
 600                                    thread.push_assistant_content_block(
 601                                        format!("Unsupported content: {:?}", chunk).into(),
 602                                        false,
 603                                        cx,
 604                                    )
 605                                })
 606                                .log_err();
 607                        }
 608                    }
 609                }
 610            }
 611            SdkMessage::Result {
 612                is_error,
 613                subtype,
 614                result,
 615                ..
 616            } => {
 617                let turn_state = turn_state.take();
 618                let was_canceled = turn_state.is_canceled();
 619                let Some(end_turn_tx) = turn_state.end_tx() else {
 620                    debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
 621                    return;
 622                };
 623
 624                if is_error || (!was_canceled && subtype == ResultErrorType::ErrorDuringExecution) {
 625                    end_turn_tx
 626                        .send(Err(anyhow!(
 627                            "Error: {}",
 628                            result.unwrap_or_else(|| subtype.to_string())
 629                        )))
 630                        .ok();
 631                } else {
 632                    let stop_reason = match subtype {
 633                        ResultErrorType::Success => acp::StopReason::EndTurn,
 634                        ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 635                        ResultErrorType::ErrorDuringExecution => acp::StopReason::Canceled,
 636                    };
 637                    end_turn_tx
 638                        .send(Ok(acp::PromptResponse { stop_reason }))
 639                        .ok();
 640                }
 641            }
 642            SdkMessage::ControlResponse { response } => {
 643                if matches!(response.subtype, ResultErrorType::Success) {
 644                    let new_state = turn_state.take().confirm_cancellation(&response.request_id);
 645                    turn_state.replace(new_state);
 646                } else {
 647                    log::error!("Control response error: {:?}", response);
 648                }
 649            }
 650            SdkMessage::System { .. } => {}
 651        }
 652    }
 653
 654    async fn handle_io(
 655        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 656        incoming_tx: UnboundedSender<SdkMessage>,
 657        mut outgoing_bytes: impl Unpin + AsyncWrite,
 658        incoming_bytes: impl Unpin + AsyncRead,
 659    ) -> Result<UnboundedReceiver<SdkMessage>> {
 660        let mut output_reader = BufReader::new(incoming_bytes);
 661        let mut outgoing_line = Vec::new();
 662        let mut incoming_line = String::new();
 663        loop {
 664            select_biased! {
 665                message = outgoing_rx.next() => {
 666                    if let Some(message) = message {
 667                        outgoing_line.clear();
 668                        serde_json::to_writer(&mut outgoing_line, &message)?;
 669                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 670                        outgoing_line.push(b'\n');
 671                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 672                    } else {
 673                        break;
 674                    }
 675                }
 676                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
 677                    if bytes_read? == 0 {
 678                        break
 679                    }
 680                    log::trace!("recv: {}", &incoming_line);
 681                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 682                        Ok(message) => {
 683                            incoming_tx.unbounded_send(message).log_err();
 684                        }
 685                        Err(error) => {
 686                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 687                        }
 688                    }
 689                    incoming_line.clear();
 690                }
 691            }
 692        }
 693
 694        Ok(outgoing_rx)
 695    }
 696}
 697
 698#[derive(Debug, Clone, Serialize, Deserialize)]
 699struct Message {
 700    role: Role,
 701    content: Content,
 702    #[serde(skip_serializing_if = "Option::is_none")]
 703    id: Option<String>,
 704    #[serde(skip_serializing_if = "Option::is_none")]
 705    model: Option<String>,
 706    #[serde(skip_serializing_if = "Option::is_none")]
 707    stop_reason: Option<String>,
 708    #[serde(skip_serializing_if = "Option::is_none")]
 709    stop_sequence: Option<String>,
 710    #[serde(skip_serializing_if = "Option::is_none")]
 711    usage: Option<Usage>,
 712}
 713
 714#[derive(Debug, Clone, Serialize, Deserialize)]
 715#[serde(untagged)]
 716enum Content {
 717    UntaggedText(String),
 718    Chunks(Vec<ContentChunk>),
 719}
 720
 721impl Content {
 722    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 723        match self {
 724            Self::Chunks(chunks) => chunks.into_iter(),
 725            Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
 726        }
 727    }
 728}
 729
 730impl Display for Content {
 731    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 732        match self {
 733            Content::UntaggedText(txt) => write!(f, "{}", txt),
 734            Content::Chunks(chunks) => {
 735                for chunk in chunks {
 736                    write!(f, "{}", chunk)?;
 737                }
 738                Ok(())
 739            }
 740        }
 741    }
 742}
 743
 744#[derive(Debug, Clone, Serialize, Deserialize)]
 745#[serde(tag = "type", rename_all = "snake_case")]
 746enum ContentChunk {
 747    Text {
 748        text: String,
 749    },
 750    ToolUse {
 751        id: String,
 752        name: String,
 753        input: serde_json::Value,
 754    },
 755    ToolResult {
 756        content: Content,
 757        tool_use_id: String,
 758    },
 759    Thinking {
 760        thinking: String,
 761    },
 762    RedactedThinking,
 763    // TODO
 764    Image,
 765    Document,
 766    WebSearchToolResult,
 767    #[serde(untagged)]
 768    UntaggedText(String),
 769}
 770
 771impl Display for ContentChunk {
 772    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 773        match self {
 774            ContentChunk::Text { text } => write!(f, "{}", text),
 775            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 776            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 777            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 778            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 779            ContentChunk::Image
 780            | ContentChunk::Document
 781            | ContentChunk::ToolUse { .. }
 782            | ContentChunk::WebSearchToolResult => {
 783                write!(f, "\n{:?}\n", &self)
 784            }
 785        }
 786    }
 787}
 788
 789#[derive(Debug, Clone, Serialize, Deserialize)]
 790struct Usage {
 791    input_tokens: u32,
 792    cache_creation_input_tokens: u32,
 793    cache_read_input_tokens: u32,
 794    output_tokens: u32,
 795    service_tier: String,
 796}
 797
 798#[derive(Debug, Clone, Serialize, Deserialize)]
 799#[serde(rename_all = "snake_case")]
 800enum Role {
 801    System,
 802    Assistant,
 803    User,
 804}
 805
 806#[derive(Debug, Clone, Serialize, Deserialize)]
 807struct MessageParam {
 808    role: Role,
 809    content: String,
 810}
 811
 812#[derive(Debug, Clone, Serialize, Deserialize)]
 813#[serde(tag = "type", rename_all = "snake_case")]
 814enum SdkMessage {
 815    // An assistant message
 816    Assistant {
 817        message: Message, // from Anthropic SDK
 818        #[serde(skip_serializing_if = "Option::is_none")]
 819        session_id: Option<String>,
 820    },
 821    // A user message
 822    User {
 823        message: Message, // from Anthropic SDK
 824        #[serde(skip_serializing_if = "Option::is_none")]
 825        session_id: Option<String>,
 826    },
 827    // Emitted as the last message in a conversation
 828    Result {
 829        subtype: ResultErrorType,
 830        duration_ms: f64,
 831        duration_api_ms: f64,
 832        is_error: bool,
 833        num_turns: i32,
 834        #[serde(skip_serializing_if = "Option::is_none")]
 835        result: Option<String>,
 836        session_id: String,
 837        total_cost_usd: f64,
 838    },
 839    // Emitted as the first message at the start of a conversation
 840    System {
 841        cwd: String,
 842        session_id: String,
 843        tools: Vec<String>,
 844        model: String,
 845        mcp_servers: Vec<McpServer>,
 846        #[serde(rename = "apiKeySource")]
 847        api_key_source: String,
 848        #[serde(rename = "permissionMode")]
 849        permission_mode: PermissionMode,
 850    },
 851    /// Messages used to control the conversation, outside of chat messages to the model
 852    ControlRequest {
 853        request_id: String,
 854        request: ControlRequest,
 855    },
 856    /// Response to a control request
 857    ControlResponse { response: ControlResponse },
 858}
 859
 860#[derive(Debug, Clone, Serialize, Deserialize)]
 861#[serde(tag = "subtype", rename_all = "snake_case")]
 862enum ControlRequest {
 863    /// Cancel the current conversation
 864    Interrupt,
 865}
 866
 867#[derive(Debug, Clone, Serialize, Deserialize)]
 868struct ControlResponse {
 869    request_id: String,
 870    subtype: ResultErrorType,
 871}
 872
 873#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
 874#[serde(rename_all = "snake_case")]
 875enum ResultErrorType {
 876    Success,
 877    ErrorMaxTurns,
 878    ErrorDuringExecution,
 879}
 880
 881impl Display for ResultErrorType {
 882    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 883        match self {
 884            ResultErrorType::Success => write!(f, "success"),
 885            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 886            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 887        }
 888    }
 889}
 890
 891fn new_request_id() -> String {
 892    use rand::Rng;
 893    // In the Claude Code TS SDK they just generate a random 12 character string,
 894    // `Math.random().toString(36).substring(2, 15)`
 895    rand::thread_rng()
 896        .sample_iter(&rand::distributions::Alphanumeric)
 897        .take(12)
 898        .map(char::from)
 899        .collect()
 900}
 901
 902#[derive(Debug, Clone, Serialize, Deserialize)]
 903struct McpServer {
 904    name: String,
 905    status: String,
 906}
 907
 908#[derive(Debug, Clone, Serialize, Deserialize)]
 909#[serde(rename_all = "camelCase")]
 910enum PermissionMode {
 911    Default,
 912    AcceptEdits,
 913    BypassPermissions,
 914    Plan,
 915}
 916
 917#[cfg(test)]
 918pub(crate) mod tests {
 919    use super::*;
 920    use crate::e2e_tests;
 921    use gpui::TestAppContext;
 922    use serde_json::json;
 923
 924    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
 925
 926    pub fn local_command() -> AgentServerCommand {
 927        AgentServerCommand {
 928            path: "claude".into(),
 929            args: vec![],
 930            env: None,
 931        }
 932    }
 933
 934    #[gpui::test]
 935    #[cfg_attr(not(feature = "e2e"), ignore)]
 936    async fn test_todo_plan(cx: &mut TestAppContext) {
 937        let fs = e2e_tests::init_test(cx).await;
 938        let project = Project::test(fs, [], cx).await;
 939        let thread =
 940            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
 941
 942        thread
 943            .update(cx, |thread, cx| {
 944                thread.send_raw(
 945                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
 946                    cx,
 947                )
 948            })
 949            .await
 950            .unwrap();
 951
 952        let mut entries_len = 0;
 953
 954        thread.read_with(cx, |thread, _| {
 955            entries_len = thread.plan().entries.len();
 956            assert!(thread.plan().entries.len() > 0, "Empty plan");
 957        });
 958
 959        thread
 960            .update(cx, |thread, cx| {
 961                thread.send_raw(
 962                    "Mark the first entry status as in progress without acting on it.",
 963                    cx,
 964                )
 965            })
 966            .await
 967            .unwrap();
 968
 969        thread.read_with(cx, |thread, _| {
 970            assert!(matches!(
 971                thread.plan().entries[0].status,
 972                acp::PlanEntryStatus::InProgress
 973            ));
 974            assert_eq!(thread.plan().entries.len(), entries_len);
 975        });
 976
 977        thread
 978            .update(cx, |thread, cx| {
 979                thread.send_raw(
 980                    "Now mark the first entry as completed without acting on it.",
 981                    cx,
 982                )
 983            })
 984            .await
 985            .unwrap();
 986
 987        thread.read_with(cx, |thread, _| {
 988            assert!(matches!(
 989                thread.plan().entries[0].status,
 990                acp::PlanEntryStatus::Completed
 991            ));
 992            assert_eq!(thread.plan().entries.len(), entries_len);
 993        });
 994    }
 995
 996    #[test]
 997    fn test_deserialize_content_untagged_text() {
 998        let json = json!("Hello, world!");
 999        let content: Content = serde_json::from_value(json).unwrap();
1000        match content {
1001            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
1002            _ => panic!("Expected UntaggedText variant"),
1003        }
1004    }
1005
1006    #[test]
1007    fn test_deserialize_content_chunks() {
1008        let json = json!([
1009            {
1010                "type": "text",
1011                "text": "Hello"
1012            },
1013            {
1014                "type": "tool_use",
1015                "id": "tool_123",
1016                "name": "calculator",
1017                "input": {"operation": "add", "a": 1, "b": 2}
1018            }
1019        ]);
1020        let content: Content = serde_json::from_value(json).unwrap();
1021        match content {
1022            Content::Chunks(chunks) => {
1023                assert_eq!(chunks.len(), 2);
1024                match &chunks[0] {
1025                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
1026                    _ => panic!("Expected Text chunk"),
1027                }
1028                match &chunks[1] {
1029                    ContentChunk::ToolUse { id, name, input } => {
1030                        assert_eq!(id, "tool_123");
1031                        assert_eq!(name, "calculator");
1032                        assert_eq!(input["operation"], "add");
1033                        assert_eq!(input["a"], 1);
1034                        assert_eq!(input["b"], 2);
1035                    }
1036                    _ => panic!("Expected ToolUse chunk"),
1037                }
1038            }
1039            _ => panic!("Expected Chunks variant"),
1040        }
1041    }
1042
1043    #[test]
1044    fn test_deserialize_tool_result_untagged_text() {
1045        let json = json!({
1046            "type": "tool_result",
1047            "content": "Result content",
1048            "tool_use_id": "tool_456"
1049        });
1050        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1051        match chunk {
1052            ContentChunk::ToolResult {
1053                content,
1054                tool_use_id,
1055            } => {
1056                match content {
1057                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
1058                    _ => panic!("Expected UntaggedText content"),
1059                }
1060                assert_eq!(tool_use_id, "tool_456");
1061            }
1062            _ => panic!("Expected ToolResult variant"),
1063        }
1064    }
1065
1066    #[test]
1067    fn test_deserialize_tool_result_chunks() {
1068        let json = json!({
1069            "type": "tool_result",
1070            "content": [
1071                {
1072                    "type": "text",
1073                    "text": "Processing complete"
1074                },
1075                {
1076                    "type": "text",
1077                    "text": "Result: 42"
1078                }
1079            ],
1080            "tool_use_id": "tool_789"
1081        });
1082        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1083        match chunk {
1084            ContentChunk::ToolResult {
1085                content,
1086                tool_use_id,
1087            } => {
1088                match content {
1089                    Content::Chunks(chunks) => {
1090                        assert_eq!(chunks.len(), 2);
1091                        match &chunks[0] {
1092                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
1093                            _ => panic!("Expected Text chunk"),
1094                        }
1095                        match &chunks[1] {
1096                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
1097                            _ => panic!("Expected Text chunk"),
1098                        }
1099                    }
1100                    _ => panic!("Expected Chunks content"),
1101                }
1102                assert_eq!(tool_use_id, "tool_789");
1103            }
1104            _ => panic!("Expected ToolResult variant"),
1105        }
1106    }
1107}