claude.rs

   1mod mcp_server;
   2pub mod tools;
   3
   4use collections::HashMap;
   5use context_server::listener::McpServerTool;
   6use language_models::provider::anthropic::AnthropicLanguageModelProvider;
   7use project::Project;
   8use settings::SettingsStore;
   9use smol::process::Child;
  10use std::any::Any;
  11use std::cell::RefCell;
  12use std::fmt::Display;
  13use std::path::Path;
  14use std::rc::Rc;
  15use uuid::Uuid;
  16
  17use agent_client_protocol as acp;
  18use anyhow::{Context as _, Result, anyhow};
  19use futures::channel::oneshot;
  20use futures::{AsyncBufReadExt, AsyncWriteExt};
  21use futures::{
  22    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  23    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  24    io::BufReader,
  25    select_biased,
  26};
  27use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  28use serde::{Deserialize, Serialize};
  29use util::{ResultExt, debug_panic};
  30
  31use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  32use crate::claude::tools::ClaudeTool;
  33use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  34use acp_thread::{AcpThread, AgentConnection, AuthRequired};
  35
    /// Agent server entry for Claude Code. Its `AgentServer` impl supplies the
    /// display name, logo and the connection used to create threads.
  36#[derive(Clone)]
  37pub struct ClaudeCode;
  38
  39impl AgentServer for ClaudeCode {
  40    fn name(&self) -> &'static str {
  41        "Claude Code"
  42    }
  43
  44    fn empty_state_headline(&self) -> &'static str {
  45        self.name()
  46    }
  47
  48    fn empty_state_message(&self) -> &'static str {
  49        "How can I help you today?"
  50    }
  51
  52    fn logo(&self) -> ui::IconName {
  53        ui::IconName::AiClaude
  54    }
  55
  56    fn connect(
  57        &self,
  58        _root_dir: &Path,
  59        _project: &Entity<Project>,
  60        _cx: &mut App,
  61    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  62        let connection = ClaudeAgentConnection {
  63            sessions: Default::default(),
  64        };
  65
  66        Task::ready(Ok(Rc::new(connection) as _))
  67    }
  68
  69    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
  70        self
  71    }
  72}
  73
    /// Connection object handed out by `ClaudeCode::connect`; owns the table of
    /// live sessions.
  74struct ClaudeAgentConnection {
        // Keyed by session id. Entries are inserted by `new_thread` and looked up by
        // `prompt` and `cancel`.
  75    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
  76}
  77
  78impl AgentConnection for ClaudeAgentConnection {
  79    fn new_thread(
  80        self: Rc<Self>,
  81        project: Entity<Project>,
  82        cwd: &Path,
  83        cx: &mut App,
  84    ) -> Task<Result<Entity<AcpThread>>> {
  85        let cwd = cwd.to_owned();
  86        cx.spawn(async move |cx| {
  87            let settings = cx.read_global(|settings: &SettingsStore, _| {
  88                settings.get::<AllAgentServersSettings>(None).claude.clone()
  89            })?;
  90
  91            let Some(command) = AgentServerCommand::resolve(
  92                "claude",
  93                &[],
  94                Some(&util::paths::home_dir().join(".claude/local/claude")),
  95                settings,
  96                &project,
  97                cx,
  98            )
  99            .await
 100            else {
 101                anyhow::bail!("Failed to find claude binary");
 102            };
 103
 104            let api_key =
 105                cx.update(AnthropicLanguageModelProvider::api_key)?
 106                    .await
 107                    .map_err(|err| {
 108                        if err.is::<language_model::AuthenticateError>() {
 109                            anyhow!(AuthRequired::new().with_language_model_provider(
 110                                language_model::ANTHROPIC_PROVIDER_ID
 111                            ))
 112                        } else {
 113                            anyhow!(err)
 114                        }
 115                    })?;
 116
 117            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
 118            let fs = project.read_with(cx, |project, _cx| project.fs().clone())?;
 119            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), fs, cx).await?;
 120
 121            let mut mcp_servers = HashMap::default();
 122            mcp_servers.insert(
 123                mcp_server::SERVER_NAME.to_string(),
 124                permission_mcp_server.server_config()?,
 125            );
 126            let mcp_config = McpConfig { mcp_servers };
 127
 128            let mcp_config_file = tempfile::NamedTempFile::new()?;
 129            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
 130
 131            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
 132            mcp_config_file
 133                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
 134                .await?;
 135            mcp_config_file.flush().await?;
 136
 137            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
 138            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 139
 140            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
 141
 142            log::trace!("Starting session with id: {}", session_id);
 143
 144            let mut child = spawn_claude(
 145                &command,
 146                ClaudeSessionMode::Start,
 147                session_id.clone(),
 148                api_key,
 149                &mcp_config_path,
 150                &cwd,
 151            )?;
 152
 153            let stdout = child.stdout.take().context("Failed to take stdout")?;
 154            let stdin = child.stdin.take().context("Failed to take stdin")?;
 155            let stderr = child.stderr.take().context("Failed to take stderr")?;
 156
 157            let pid = child.id();
 158            log::trace!("Spawned (pid: {})", pid);
 159
 160            cx.background_spawn(async move {
 161                let mut stderr = BufReader::new(stderr);
 162                let mut line = String::new();
 163                while let Ok(n) = stderr.read_line(&mut line).await
 164                    && n > 0
 165                {
 166                    log::warn!("agent stderr: {}", &line);
 167                    line.clear();
 168                }
 169            })
 170            .detach();
 171
 172            cx.background_spawn(async move {
 173                let mut outgoing_rx = Some(outgoing_rx);
 174
 175                ClaudeAgentSession::handle_io(
 176                    outgoing_rx.take().unwrap(),
 177                    incoming_message_tx.clone(),
 178                    stdin,
 179                    stdout,
 180                )
 181                .await?;
 182
 183                log::trace!("Stopped (pid: {})", pid);
 184
 185                drop(mcp_config_path);
 186                anyhow::Ok(())
 187            })
 188            .detach();
 189
 190            let turn_state = Rc::new(RefCell::new(TurnState::None));
 191
 192            let handler_task = cx.spawn({
 193                let turn_state = turn_state.clone();
 194                let mut thread_rx = thread_rx.clone();
 195                async move |cx| {
 196                    while let Some(message) = incoming_message_rx.next().await {
 197                        ClaudeAgentSession::handle_message(
 198                            thread_rx.clone(),
 199                            message,
 200                            turn_state.clone(),
 201                            cx,
 202                        )
 203                        .await
 204                    }
 205
 206                    if let Some(status) = child.status().await.log_err()
 207                        && let Some(thread) = thread_rx.recv().await.ok() {
 208                            thread
 209                                .update(cx, |thread, cx| {
 210                                    thread.emit_server_exited(status, cx);
 211                                })
 212                                .ok();
 213                        }
 214                }
 215            });
 216
 217            let thread = cx.new(|cx| {
 218                AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
 219            })?;
 220
 221            thread_tx.send(thread.downgrade())?;
 222
 223            let session = ClaudeAgentSession {
 224                outgoing_tx,
 225                turn_state,
 226                _handler_task: handler_task,
 227                _mcp_server: Some(permission_mcp_server),
 228            };
 229
 230            self.sessions.borrow_mut().insert(session_id, session);
 231
 232            Ok(thread)
 233        })
 234    }
 235
        /// Returns an empty list: this connection advertises no auth methods.
 236    fn auth_methods(&self) -> &[acp::AuthMethod] {
 237        &[]
 238    }
 239
        /// Always resolves to an "Authentication not supported" error; both
        /// arguments are ignored.
 240    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 241        Task::ready(Err(anyhow!("Authentication not supported")))
 242    }
 243
 244    fn prompt(
 245        &self,
 246        _id: Option<acp_thread::UserMessageId>,
 247        params: acp::PromptRequest,
 248        cx: &mut App,
 249    ) -> Task<Result<acp::PromptResponse>> {
 250        let sessions = self.sessions.borrow();
 251        let Some(session) = sessions.get(&params.session_id) else {
 252            return Task::ready(Err(anyhow!(
 253                "Attempted to send message to nonexistent session {}",
 254                params.session_id
 255            )));
 256        };
 257
 258        let (end_tx, end_rx) = oneshot::channel();
 259        session.turn_state.replace(TurnState::InProgress { end_tx });
 260
 261        let mut content = String::new();
 262        for chunk in params.prompt {
 263            match chunk {
 264                acp::ContentBlock::Text(text_content) => {
 265                    content.push_str(&text_content.text);
 266                }
 267                acp::ContentBlock::ResourceLink(resource_link) => {
 268                    content.push_str(&format!("@{}", resource_link.uri));
 269                }
 270                acp::ContentBlock::Audio(_)
 271                | acp::ContentBlock::Image(_)
 272                | acp::ContentBlock::Resource(_) => {
 273                    // TODO
 274                }
 275            }
 276        }
 277
 278        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 279            message: Message {
 280                role: Role::User,
 281                content: Content::UntaggedText(content),
 282                id: None,
 283                model: None,
 284                stop_reason: None,
 285                stop_sequence: None,
 286                usage: None,
 287            },
 288            session_id: Some(params.session_id.to_string()),
 289        }) {
 290            return Task::ready(Err(anyhow!(err)));
 291        }
 292
 293        cx.foreground_executor().spawn(async move { end_rx.await? })
 294    }
 295
 296    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 297        let sessions = self.sessions.borrow();
 298        let Some(session) = sessions.get(session_id) else {
 299            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 300            return;
 301        };
 302
 303        let request_id = new_request_id();
 304
 305        let turn_state = session.turn_state.take();
 306        let TurnState::InProgress { end_tx } = turn_state else {
 307            // Already canceled or idle, put it back
 308            session.turn_state.replace(turn_state);
 309            return;
 310        };
 311
 312        session.turn_state.replace(TurnState::CancelRequested {
 313            end_tx,
 314            request_id: request_id.clone(),
 315        });
 316
 317        session
 318            .outgoing_tx
 319            .unbounded_send(SdkMessage::ControlRequest {
 320                request_id,
 321                request: ControlRequest::Interrupt,
 322            })
 323            .log_err();
 324    }
 325
        /// Upcasts the connection to `Rc<dyn Any>`.
 326    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 327        self
 328    }
 329}
 330
    /// Selects which session flag `spawn_claude` passes to the CLI.
 331#[derive(Clone, Copy)]
 332enum ClaudeSessionMode {
        /// Passes `--session-id <id>`.
 333    Start,
        /// Passes `--resume <id>`. Not constructed anywhere yet, hence the
        /// `dead_code` expectation.
 334    #[expect(dead_code)]
 335    Resume,
 336}
 337
 338fn spawn_claude(
 339    command: &AgentServerCommand,
 340    mode: ClaudeSessionMode,
 341    session_id: acp::SessionId,
 342    api_key: language_models::provider::anthropic::ApiKey,
 343    mcp_config_path: &Path,
 344    root_dir: &Path,
 345) -> Result<Child> {
 346    let child = util::command::new_smol_command(&command.path)
 347        .args([
 348            "--input-format",
 349            "stream-json",
 350            "--output-format",
 351            "stream-json",
 352            "--print",
 353            "--verbose",
 354            "--mcp-config",
 355            mcp_config_path.to_string_lossy().as_ref(),
 356            "--permission-prompt-tool",
 357            &format!(
 358                "mcp__{}__{}",
 359                mcp_server::SERVER_NAME,
 360                mcp_server::PermissionTool::NAME,
 361            ),
 362            "--allowedTools",
 363            &format!(
 364                "mcp__{}__{},mcp__{}__{}",
 365                mcp_server::SERVER_NAME,
 366                mcp_server::EditTool::NAME,
 367                mcp_server::SERVER_NAME,
 368                mcp_server::ReadTool::NAME
 369            ),
 370            "--disallowedTools",
 371            "Read,Edit",
 372        ])
 373        .args(match mode {
 374            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 375            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 376        })
 377        .args(command.args.iter().map(|arg| arg.as_str()))
 378        .envs(command.env.iter().flatten())
 379        .env("ANTHROPIC_API_KEY", api_key.key)
 380        .current_dir(root_dir)
 381        .stdin(std::process::Stdio::piped())
 382        .stdout(std::process::Stdio::piped())
 383        .stderr(std::process::Stdio::piped())
 384        .kill_on_drop(true)
 385        .spawn()?;
 386
 387    Ok(child)
 388}
 389
    /// Per-session state stored in `ClaudeAgentConnection::sessions`.
 390struct ClaudeAgentSession {
        // Messages queued here are serialized to the agent process by `handle_io`.
 391    outgoing_tx: UnboundedSender<SdkMessage>,
        // Shared with the handler task, which updates it in `handle_message`.
 392    turn_state: Rc<RefCell<TurnState>>,
        // Never read (underscore-prefixed); presumably stored only to keep the MCP
        // server and the handler task alive with the session. TODO confirm.
 393    _mcp_server: Option<ClaudeZedMcpServer>,
 394    _handler_task: Task<()>,
 395}
 396
    /// State of the current prompt turn. Every active variant carries the sender
    /// that resolves the task returned by `prompt`.
 397#[derive(Debug, Default)]
 398enum TurnState {
        /// No turn is active.
 399    #[default]
 400    None,
        /// A prompt was sent and its result has not arrived yet.
 401    InProgress {
 402        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 403    },
        /// `cancel` sent an interrupt with `request_id`; no matching successful
        /// control response has been seen yet.
 404    CancelRequested {
 405        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 406        request_id: String,
 407    },
        /// A successful control response matched the interrupt's request id.
 408    CancelConfirmed {
 409        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 410    },
 411}
 412
 413impl TurnState {
 414    fn is_canceled(&self) -> bool {
 415        matches!(self, TurnState::CancelConfirmed { .. })
 416    }
 417
 418    fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
 419        match self {
 420            TurnState::None => None,
 421            TurnState::InProgress { end_tx, .. } => Some(end_tx),
 422            TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
 423            TurnState::CancelConfirmed { end_tx } => Some(end_tx),
 424        }
 425    }
 426
 427    fn confirm_cancellation(self, id: &str) -> Self {
 428        match self {
 429            TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
 430                TurnState::CancelConfirmed { end_tx }
 431            }
 432            _ => self,
 433        }
 434    }
 435}
 436
 437impl ClaudeAgentSession {
 438    async fn handle_message(
 439        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 440        message: SdkMessage,
 441        turn_state: Rc<RefCell<TurnState>>,
 442        cx: &mut AsyncApp,
 443    ) {
 444        match message {
 445            // we should only be sending these out, they don't need to be in the thread
 446            SdkMessage::ControlRequest { .. } => {}
 447            SdkMessage::User {
 448                message,
 449                session_id: _,
 450            } => {
 451                let Some(thread) = thread_rx
 452                    .recv()
 453                    .await
 454                    .log_err()
 455                    .and_then(|entity| entity.upgrade())
 456                else {
 457                    log::error!("Received an SDK message but thread is gone");
 458                    return;
 459                };
 460
 461                for chunk in message.content.chunks() {
 462                    match chunk {
 463                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 464                            if !turn_state.borrow().is_canceled() {
 465                                thread
 466                                    .update(cx, |thread, cx| {
 467                                        thread.push_user_content_block(None, text.into(), cx)
 468                                    })
 469                                    .log_err();
 470                            }
 471                        }
 472                        ContentChunk::ToolResult {
 473                            content,
 474                            tool_use_id,
 475                        } => {
 476                            let content = content.to_string();
 477                            thread
 478                                .update(cx, |thread, cx| {
 479                                    thread.update_tool_call(
 480                                        acp::ToolCallUpdate {
 481                                            id: acp::ToolCallId(tool_use_id.into()),
 482                                            fields: acp::ToolCallUpdateFields {
 483                                                status: if turn_state.borrow().is_canceled() {
 484                                                    // Do not set to completed if turn was canceled
 485                                                    None
 486                                                } else {
 487                                                    Some(acp::ToolCallStatus::Completed)
 488                                                },
 489                                                content: (!content.is_empty())
 490                                                    .then(|| vec![content.into()]),
 491                                                ..Default::default()
 492                                            },
 493                                        },
 494                                        cx,
 495                                    )
 496                                })
 497                                .log_err();
 498                        }
 499                        ContentChunk::Thinking { .. }
 500                        | ContentChunk::RedactedThinking
 501                        | ContentChunk::ToolUse { .. } => {
 502                            debug_panic!(
 503                                "Should not get {:?} with role: assistant. should we handle this?",
 504                                chunk
 505                            );
 506                        }
 507
 508                        ContentChunk::Image
 509                        | ContentChunk::Document
 510                        | ContentChunk::WebSearchToolResult => {
 511                            thread
 512                                .update(cx, |thread, cx| {
 513                                    thread.push_assistant_content_block(
 514                                        format!("Unsupported content: {:?}", chunk).into(),
 515                                        false,
 516                                        cx,
 517                                    )
 518                                })
 519                                .log_err();
 520                        }
 521                    }
 522                }
 523            }
 524            SdkMessage::Assistant {
 525                message,
 526                session_id: _,
 527            } => {
 528                let Some(thread) = thread_rx
 529                    .recv()
 530                    .await
 531                    .log_err()
 532                    .and_then(|entity| entity.upgrade())
 533                else {
 534                    log::error!("Received an SDK message but thread is gone");
 535                    return;
 536                };
 537
 538                for chunk in message.content.chunks() {
 539                    match chunk {
 540                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 541                            thread
 542                                .update(cx, |thread, cx| {
 543                                    thread.push_assistant_content_block(text.into(), false, cx)
 544                                })
 545                                .log_err();
 546                        }
 547                        ContentChunk::Thinking { thinking } => {
 548                            thread
 549                                .update(cx, |thread, cx| {
 550                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 551                                })
 552                                .log_err();
 553                        }
 554                        ContentChunk::RedactedThinking => {
 555                            thread
 556                                .update(cx, |thread, cx| {
 557                                    thread.push_assistant_content_block(
 558                                        "[REDACTED]".into(),
 559                                        true,
 560                                        cx,
 561                                    )
 562                                })
 563                                .log_err();
 564                        }
 565                        ContentChunk::ToolUse { id, name, input } => {
 566                            let claude_tool = ClaudeTool::infer(&name, input);
 567
 568                            thread
 569                                .update(cx, |thread, cx| {
 570                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 571                                        thread.update_plan(
 572                                            acp::Plan {
 573                                                entries: params
 574                                                    .todos
 575                                                    .into_iter()
 576                                                    .map(Into::into)
 577                                                    .collect(),
 578                                            },
 579                                            cx,
 580                                        )
 581                                    } else {
 582                                        thread.upsert_tool_call(
 583                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 584                                            cx,
 585                                        )?;
 586                                    }
 587                                    anyhow::Ok(())
 588                                })
 589                                .log_err();
 590                        }
 591                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 592                            debug_panic!(
 593                                "Should not get tool results with role: assistant. should we handle this?"
 594                            );
 595                        }
 596                        ContentChunk::Image | ContentChunk::Document => {
 597                            thread
 598                                .update(cx, |thread, cx| {
 599                                    thread.push_assistant_content_block(
 600                                        format!("Unsupported content: {:?}", chunk).into(),
 601                                        false,
 602                                        cx,
 603                                    )
 604                                })
 605                                .log_err();
 606                        }
 607                    }
 608                }
 609            }
 610            SdkMessage::Result {
 611                is_error,
 612                subtype,
 613                result,
 614                ..
 615            } => {
 616                let turn_state = turn_state.take();
 617                let was_canceled = turn_state.is_canceled();
 618                let Some(end_turn_tx) = turn_state.end_tx() else {
 619                    debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
 620                    return;
 621                };
 622
 623                if is_error || (!was_canceled && subtype == ResultErrorType::ErrorDuringExecution) {
 624                    end_turn_tx
 625                        .send(Err(anyhow!(
 626                            "Error: {}",
 627                            result.unwrap_or_else(|| subtype.to_string())
 628                        )))
 629                        .ok();
 630                } else {
 631                    let stop_reason = match subtype {
 632                        ResultErrorType::Success => acp::StopReason::EndTurn,
 633                        ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 634                        ResultErrorType::ErrorDuringExecution => acp::StopReason::Canceled,
 635                    };
 636                    end_turn_tx
 637                        .send(Ok(acp::PromptResponse { stop_reason }))
 638                        .ok();
 639                }
 640            }
 641            SdkMessage::ControlResponse { response } => {
 642                if matches!(response.subtype, ResultErrorType::Success) {
 643                    let new_state = turn_state.take().confirm_cancellation(&response.request_id);
 644                    turn_state.replace(new_state);
 645                } else {
 646                    log::error!("Control response error: {:?}", response);
 647                }
 648            }
 649            SdkMessage::System { .. } => {}
 650        }
 651    }
 652
        /// Pumps messages between the session channels and the agent's byte streams
        /// as newline-delimited JSON.
        ///
        /// Each message from `outgoing_rx` is serialized onto one line of
        /// `outgoing_bytes`; each line read from `incoming_bytes` is parsed as an
        /// `SdkMessage` and forwarded to `incoming_tx`. Unparseable lines are logged
        /// and skipped. The loop ends when `outgoing_rx` closes or the input reaches
        /// end of stream, and the receiver is handed back to the caller.
        ///
        /// # Errors
        /// Returns an error if a message cannot be serialized or if reading a line
        /// fails. Write failures are ignored.
 653    async fn handle_io(
 654        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 655        incoming_tx: UnboundedSender<SdkMessage>,
 656        mut outgoing_bytes: impl Unpin + AsyncWrite,
 657        incoming_bytes: impl Unpin + AsyncRead,
 658    ) -> Result<UnboundedReceiver<SdkMessage>> {
 659        let mut output_reader = BufReader::new(incoming_bytes);
            // Both buffers are reused across iterations and cleared after use.
 660        let mut outgoing_line = Vec::new();
 661        let mut incoming_line = String::new();
 662        loop {
                // NOTE(review): the `read_line` future is recreated on every iteration
                // and dropped whenever the outgoing branch completes first — confirm a
                // partially read line cannot be lost in that case.
 663            select_biased! {
 664                message = outgoing_rx.next() => {
 665                    if let Some(message) = message {
 666                        outgoing_line.clear();
 667                        serde_json::to_writer(&mut outgoing_line, &message)?;
 668                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 669                        outgoing_line.push(b'\n');
                            // Best effort: a failed write is dropped here.
 670                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 671                    } else {
 672                        break;
 673                    }
 674                }
 675                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
                        // Zero bytes read means end of stream.
 676                    if bytes_read? == 0 {
 677                        break
 678                    }
 679                    log::trace!("recv: {}", &incoming_line);
 680                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 681                        Ok(message) => {
 682                            incoming_tx.unbounded_send(message).log_err();
 683                        }
 684                        Err(error) => {
 685                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 686                        }
 687                    }
 688                    incoming_line.clear();
 689                }
 690            }
 691        }
 692
 693        Ok(outgoing_rx)
 694    }
 695}
 696
    /// A chat message carried inside `SdkMessage::User` / `SdkMessage::Assistant`.
    /// Optional fields are omitted from the serialized JSON when `None`.
 697#[derive(Debug, Clone, Serialize, Deserialize)]
 698struct Message {
 699    role: Role,
 700    content: Content,
 701    #[serde(skip_serializing_if = "Option::is_none")]
 702    id: Option<String>,
 703    #[serde(skip_serializing_if = "Option::is_none")]
 704    model: Option<String>,
 705    #[serde(skip_serializing_if = "Option::is_none")]
 706    stop_reason: Option<String>,
 707    #[serde(skip_serializing_if = "Option::is_none")]
 708    stop_sequence: Option<String>,
 709    #[serde(skip_serializing_if = "Option::is_none")]
 710    usage: Option<Usage>,
 711}
 712
    /// Message content: either a bare string or a list of chunks. The enum is
    /// untagged, so the JSON shape alone decides the variant.
 713#[derive(Debug, Clone, Serialize, Deserialize)]
 714#[serde(untagged)]
 715enum Content {
        /// A plain JSON string.
 716    UntaggedText(String),
        /// A JSON array of content chunks.
 717    Chunks(Vec<ContentChunk>),
 718}
 719
 720impl Content {
 721    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 722        match self {
 723            Self::Chunks(chunks) => chunks.into_iter(),
 724            Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
 725        }
 726    }
 727}
 728
 729impl Display for Content {
 730    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 731        match self {
 732            Content::UntaggedText(txt) => write!(f, "{}", txt),
 733            Content::Chunks(chunks) => {
 734                for chunk in chunks {
 735                    write!(f, "{}", chunk)?;
 736                }
 737                Ok(())
 738            }
 739        }
 740    }
 741}
 742
    /// One element of a chunked message body. Variants are selected by the JSON
    /// `type` field (snake_case), except `UntaggedText`, which is a bare string.
 743#[derive(Debug, Clone, Serialize, Deserialize)]
 744#[serde(tag = "type", rename_all = "snake_case")]
 745enum ContentChunk {
        /// Plain text.
 746    Text {
 747        text: String,
 748    },
        /// A tool invocation.
 749    ToolUse {
 750        id: String,
 751        name: String,
 752        input: serde_json::Value,
 753    },
        /// The result for the tool invocation identified by `tool_use_id`.
 754    ToolResult {
 755        content: Content,
 756        tool_use_id: String,
 757    },
        /// Thinking text.
 758    Thinking {
 759        thinking: String,
 760    },
        /// A thinking block with no payload.
 761    RedactedThinking,
        // TODO: the next three variants carry no payload yet.
 762    // TODO
 763    Image,
 764    Document,
 765    WebSearchToolResult,
        /// A chunk that is a bare string instead of a tagged object.
 766    #[serde(untagged)]
 767    UntaggedText(String),
 768}
 769
 770impl Display for ContentChunk {
 771    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 772        match self {
 773            ContentChunk::Text { text } => write!(f, "{}", text),
 774            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 775            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 776            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 777            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 778            ContentChunk::Image
 779            | ContentChunk::Document
 780            | ContentChunk::ToolUse { .. }
 781            | ContentChunk::WebSearchToolResult => {
 782                write!(f, "\n{:?}\n", &self)
 783            }
 784        }
 785    }
 786}
 787
    /// Usage figures optionally attached to a `Message`. Every field is required
    /// when the object is present.
 788#[derive(Debug, Clone, Serialize, Deserialize)]
 789struct Usage {
 790    input_tokens: u32,
 791    cache_creation_input_tokens: u32,
 792    cache_read_input_tokens: u32,
 793    output_tokens: u32,
 794    service_tier: String,
 795}
 796
    /// Author of a `Message`; serialized in snake_case (`system`, `assistant`,
    /// `user`).
 797#[derive(Debug, Clone, Serialize, Deserialize)]
 798#[serde(rename_all = "snake_case")]
 799enum Role {
 800    System,
 801    Assistant,
 802    User,
 803}
 804
    /// A role paired with plain-string content.
    // NOTE(review): not referenced in the visible part of this file — confirm it
    // is still needed.
 805#[derive(Debug, Clone, Serialize, Deserialize)]
 806struct MessageParam {
 807    role: Role,
 808    content: String,
 809}
 810
 811#[derive(Debug, Clone, Serialize, Deserialize)]
 812#[serde(tag = "type", rename_all = "snake_case")]
 813enum SdkMessage {
 814    // An assistant message
 815    Assistant {
 816        message: Message, // from Anthropic SDK
 817        #[serde(skip_serializing_if = "Option::is_none")]
 818        session_id: Option<String>,
 819    },
 820    // A user message
 821    User {
 822        message: Message, // from Anthropic SDK
 823        #[serde(skip_serializing_if = "Option::is_none")]
 824        session_id: Option<String>,
 825    },
 826    // Emitted as the last message in a conversation
 827    Result {
 828        subtype: ResultErrorType,
 829        duration_ms: f64,
 830        duration_api_ms: f64,
 831        is_error: bool,
 832        num_turns: i32,
 833        #[serde(skip_serializing_if = "Option::is_none")]
 834        result: Option<String>,
 835        session_id: String,
 836        total_cost_usd: f64,
 837    },
 838    // Emitted as the first message at the start of a conversation
 839    System {
 840        cwd: String,
 841        session_id: String,
 842        tools: Vec<String>,
 843        model: String,
 844        mcp_servers: Vec<McpServer>,
 845        #[serde(rename = "apiKeySource")]
 846        api_key_source: String,
 847        #[serde(rename = "permissionMode")]
 848        permission_mode: PermissionMode,
 849    },
 850    /// Messages used to control the conversation, outside of chat messages to the model
 851    ControlRequest {
 852        request_id: String,
 853        request: ControlRequest,
 854    },
 855    /// Response to a control request
 856    ControlResponse { response: ControlResponse },
 857}
 858
 859#[derive(Debug, Clone, Serialize, Deserialize)]
 860#[serde(tag = "subtype", rename_all = "snake_case")]
 861enum ControlRequest {
 862    /// Cancel the current conversation
 863    Interrupt,
 864}
 865
 866#[derive(Debug, Clone, Serialize, Deserialize)]
 867struct ControlResponse {
 868    request_id: String,
 869    subtype: ResultErrorType,
 870}
 871
 872#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
 873#[serde(rename_all = "snake_case")]
 874enum ResultErrorType {
 875    Success,
 876    ErrorMaxTurns,
 877    ErrorDuringExecution,
 878}
 879
 880impl Display for ResultErrorType {
 881    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 882        match self {
 883            ResultErrorType::Success => write!(f, "success"),
 884            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 885            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 886        }
 887    }
 888}
 889
 890fn new_request_id() -> String {
 891    use rand::Rng;
 892    // In the Claude Code TS SDK they just generate a random 12 character string,
 893    // `Math.random().toString(36).substring(2, 15)`
 894    rand::thread_rng()
 895        .sample_iter(&rand::distributions::Alphanumeric)
 896        .take(12)
 897        .map(char::from)
 898        .collect()
 899}
 900
 901#[derive(Debug, Clone, Serialize, Deserialize)]
 902struct McpServer {
 903    name: String,
 904    status: String,
 905}
 906
 907#[derive(Debug, Clone, Serialize, Deserialize)]
 908#[serde(rename_all = "camelCase")]
 909enum PermissionMode {
 910    Default,
 911    AcceptEdits,
 912    BypassPermissions,
 913    Plan,
 914}
 915
 916#[cfg(test)]
 917pub(crate) mod tests {
 918    use super::*;
 919    use crate::e2e_tests;
 920    use gpui::TestAppContext;
 921    use serde_json::json;
 922
 923    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
 924
 925    pub fn local_command() -> AgentServerCommand {
 926        AgentServerCommand {
 927            path: "claude".into(),
 928            args: vec![],
 929            env: None,
 930        }
 931    }
 932
 933    #[gpui::test]
 934    #[cfg_attr(not(feature = "e2e"), ignore)]
 935    async fn test_todo_plan(cx: &mut TestAppContext) {
 936        let fs = e2e_tests::init_test(cx).await;
 937        let project = Project::test(fs, [], cx).await;
 938        let thread =
 939            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
 940
 941        thread
 942            .update(cx, |thread, cx| {
 943                thread.send_raw(
 944                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
 945                    cx,
 946                )
 947            })
 948            .await
 949            .unwrap();
 950
 951        let mut entries_len = 0;
 952
 953        thread.read_with(cx, |thread, _| {
 954            entries_len = thread.plan().entries.len();
 955            assert!(thread.plan().entries.len() > 0, "Empty plan");
 956        });
 957
 958        thread
 959            .update(cx, |thread, cx| {
 960                thread.send_raw(
 961                    "Mark the first entry status as in progress without acting on it.",
 962                    cx,
 963                )
 964            })
 965            .await
 966            .unwrap();
 967
 968        thread.read_with(cx, |thread, _| {
 969            assert!(matches!(
 970                thread.plan().entries[0].status,
 971                acp::PlanEntryStatus::InProgress
 972            ));
 973            assert_eq!(thread.plan().entries.len(), entries_len);
 974        });
 975
 976        thread
 977            .update(cx, |thread, cx| {
 978                thread.send_raw(
 979                    "Now mark the first entry as completed without acting on it.",
 980                    cx,
 981                )
 982            })
 983            .await
 984            .unwrap();
 985
 986        thread.read_with(cx, |thread, _| {
 987            assert!(matches!(
 988                thread.plan().entries[0].status,
 989                acp::PlanEntryStatus::Completed
 990            ));
 991            assert_eq!(thread.plan().entries.len(), entries_len);
 992        });
 993    }
 994
 995    #[test]
 996    fn test_deserialize_content_untagged_text() {
 997        let json = json!("Hello, world!");
 998        let content: Content = serde_json::from_value(json).unwrap();
 999        match content {
1000            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
1001            _ => panic!("Expected UntaggedText variant"),
1002        }
1003    }
1004
1005    #[test]
1006    fn test_deserialize_content_chunks() {
1007        let json = json!([
1008            {
1009                "type": "text",
1010                "text": "Hello"
1011            },
1012            {
1013                "type": "tool_use",
1014                "id": "tool_123",
1015                "name": "calculator",
1016                "input": {"operation": "add", "a": 1, "b": 2}
1017            }
1018        ]);
1019        let content: Content = serde_json::from_value(json).unwrap();
1020        match content {
1021            Content::Chunks(chunks) => {
1022                assert_eq!(chunks.len(), 2);
1023                match &chunks[0] {
1024                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
1025                    _ => panic!("Expected Text chunk"),
1026                }
1027                match &chunks[1] {
1028                    ContentChunk::ToolUse { id, name, input } => {
1029                        assert_eq!(id, "tool_123");
1030                        assert_eq!(name, "calculator");
1031                        assert_eq!(input["operation"], "add");
1032                        assert_eq!(input["a"], 1);
1033                        assert_eq!(input["b"], 2);
1034                    }
1035                    _ => panic!("Expected ToolUse chunk"),
1036                }
1037            }
1038            _ => panic!("Expected Chunks variant"),
1039        }
1040    }
1041
1042    #[test]
1043    fn test_deserialize_tool_result_untagged_text() {
1044        let json = json!({
1045            "type": "tool_result",
1046            "content": "Result content",
1047            "tool_use_id": "tool_456"
1048        });
1049        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1050        match chunk {
1051            ContentChunk::ToolResult {
1052                content,
1053                tool_use_id,
1054            } => {
1055                match content {
1056                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
1057                    _ => panic!("Expected UntaggedText content"),
1058                }
1059                assert_eq!(tool_use_id, "tool_456");
1060            }
1061            _ => panic!("Expected ToolResult variant"),
1062        }
1063    }
1064
1065    #[test]
1066    fn test_deserialize_tool_result_chunks() {
1067        let json = json!({
1068            "type": "tool_result",
1069            "content": [
1070                {
1071                    "type": "text",
1072                    "text": "Processing complete"
1073                },
1074                {
1075                    "type": "text",
1076                    "text": "Result: 42"
1077                }
1078            ],
1079            "tool_use_id": "tool_789"
1080        });
1081        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1082        match chunk {
1083            ContentChunk::ToolResult {
1084                content,
1085                tool_use_id,
1086            } => {
1087                match content {
1088                    Content::Chunks(chunks) => {
1089                        assert_eq!(chunks.len(), 2);
1090                        match &chunks[0] {
1091                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
1092                            _ => panic!("Expected Text chunk"),
1093                        }
1094                        match &chunks[1] {
1095                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
1096                            _ => panic!("Expected Text chunk"),
1097                        }
1098                    }
1099                    _ => panic!("Expected Chunks content"),
1100                }
1101                assert_eq!(tool_use_id, "tool_789");
1102            }
1103            _ => panic!("Expected ToolResult variant"),
1104        }
1105    }
1106}