claude.rs

   1mod mcp_server;
   2pub mod tools;
   3
   4use collections::HashMap;
   5use context_server::listener::McpServerTool;
   6use language_models::provider::anthropic::AnthropicLanguageModelProvider;
   7use project::Project;
   8use settings::SettingsStore;
   9use smol::process::Child;
  10use std::any::Any;
  11use std::cell::RefCell;
  12use std::fmt::Display;
  13use std::path::Path;
  14use std::rc::Rc;
  15use uuid::Uuid;
  16
  17use agent_client_protocol as acp;
  18use anyhow::{Context as _, Result, anyhow};
  19use futures::channel::oneshot;
  20use futures::{AsyncBufReadExt, AsyncWriteExt};
  21use futures::{
  22    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  23    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  24    io::BufReader,
  25    select_biased,
  26};
  27use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  28use serde::{Deserialize, Serialize};
  29use util::{ResultExt, debug_panic};
  30
  31use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  32use crate::claude::tools::ClaudeTool;
  33use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  34use acp_thread::{AcpThread, AgentConnection, AuthRequired};
  35
/// The Claude Code agent server, backed by the external `claude` CLI.
#[derive(Clone)]
pub struct ClaudeCode;
  38
  39impl AgentServer for ClaudeCode {
  40    fn name(&self) -> &'static str {
  41        "Claude Code"
  42    }
  43
  44    fn empty_state_headline(&self) -> &'static str {
  45        self.name()
  46    }
  47
  48    fn empty_state_message(&self) -> &'static str {
  49        "How can I help you today?"
  50    }
  51
  52    fn logo(&self) -> ui::IconName {
  53        ui::IconName::AiClaude
  54    }
  55
  56    fn connect(
  57        &self,
  58        _root_dir: &Path,
  59        _project: &Entity<Project>,
  60        _cx: &mut App,
  61    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  62        let connection = ClaudeAgentConnection {
  63            sessions: Default::default(),
  64        };
  65
  66        Task::ready(Ok(Rc::new(connection) as _))
  67    }
  68}
  69
/// [`AgentConnection`] for Claude Code; every `new_thread` call spawns its own `claude`
/// process and registers a session here.
struct ClaudeAgentConnection {
    /// Live sessions keyed by ACP session id.
    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
}
  73
  74impl AgentConnection for ClaudeAgentConnection {
  75    fn new_thread(
  76        self: Rc<Self>,
  77        project: Entity<Project>,
  78        cwd: &Path,
  79        cx: &mut App,
  80    ) -> Task<Result<Entity<AcpThread>>> {
  81        let cwd = cwd.to_owned();
  82        cx.spawn(async move |cx| {
  83            let settings = cx.read_global(|settings: &SettingsStore, _| {
  84                settings.get::<AllAgentServersSettings>(None).claude.clone()
  85            })?;
  86
  87            let Some(command) = AgentServerCommand::resolve(
  88                "claude",
  89                &[],
  90                Some(&util::paths::home_dir().join(".claude/local/claude")),
  91                settings,
  92                &project,
  93                cx,
  94            )
  95            .await
  96            else {
  97                anyhow::bail!("Failed to find claude binary");
  98            };
  99
 100            let api_key =
 101                cx.update(AnthropicLanguageModelProvider::api_key)?
 102                    .await
 103                    .map_err(|err| {
 104                        if err.is::<language_model::AuthenticateError>() {
 105                            anyhow!(AuthRequired::new().with_language_model_provider(
 106                                language_model::ANTHROPIC_PROVIDER_ID
 107                            ))
 108                        } else {
 109                            anyhow!(err)
 110                        }
 111                    })?;
 112
 113            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
 114            let fs = project.read_with(cx, |project, _cx| project.fs().clone())?;
 115            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), fs, cx).await?;
 116
 117            let mut mcp_servers = HashMap::default();
 118            mcp_servers.insert(
 119                mcp_server::SERVER_NAME.to_string(),
 120                permission_mcp_server.server_config()?,
 121            );
 122            let mcp_config = McpConfig { mcp_servers };
 123
 124            let mcp_config_file = tempfile::NamedTempFile::new()?;
 125            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
 126
 127            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
 128            mcp_config_file
 129                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
 130                .await?;
 131            mcp_config_file.flush().await?;
 132
 133            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
 134            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 135
 136            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
 137
 138            log::trace!("Starting session with id: {}", session_id);
 139
 140            let mut child = spawn_claude(
 141                &command,
 142                ClaudeSessionMode::Start,
 143                session_id.clone(),
 144                api_key,
 145                &mcp_config_path,
 146                &cwd,
 147            )?;
 148
 149            let stdout = child.stdout.take().context("Failed to take stdout")?;
 150            let stdin = child.stdin.take().context("Failed to take stdin")?;
 151            let stderr = child.stderr.take().context("Failed to take stderr")?;
 152
 153            let pid = child.id();
 154            log::trace!("Spawned (pid: {})", pid);
 155
 156            cx.background_spawn(async move {
 157                let mut stderr = BufReader::new(stderr);
 158                let mut line = String::new();
 159                while let Ok(n) = stderr.read_line(&mut line).await
 160                    && n > 0
 161                {
 162                    log::warn!("agent stderr: {}", &line);
 163                    line.clear();
 164                }
 165            })
 166            .detach();
 167
 168            cx.background_spawn(async move {
 169                let mut outgoing_rx = Some(outgoing_rx);
 170
 171                ClaudeAgentSession::handle_io(
 172                    outgoing_rx.take().unwrap(),
 173                    incoming_message_tx.clone(),
 174                    stdin,
 175                    stdout,
 176                )
 177                .await?;
 178
 179                log::trace!("Stopped (pid: {})", pid);
 180
 181                drop(mcp_config_path);
 182                anyhow::Ok(())
 183            })
 184            .detach();
 185
 186            let turn_state = Rc::new(RefCell::new(TurnState::None));
 187
 188            let handler_task = cx.spawn({
 189                let turn_state = turn_state.clone();
 190                let mut thread_rx = thread_rx.clone();
 191                async move |cx| {
 192                    while let Some(message) = incoming_message_rx.next().await {
 193                        ClaudeAgentSession::handle_message(
 194                            thread_rx.clone(),
 195                            message,
 196                            turn_state.clone(),
 197                            cx,
 198                        )
 199                        .await
 200                    }
 201
 202                    if let Some(status) = child.status().await.log_err() {
 203                        if let Some(thread) = thread_rx.recv().await.ok() {
 204                            thread
 205                                .update(cx, |thread, cx| {
 206                                    thread.emit_server_exited(status, cx);
 207                                })
 208                                .ok();
 209                        }
 210                    }
 211                }
 212            });
 213
 214            let thread = cx.new(|cx| {
 215                AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
 216            })?;
 217
 218            thread_tx.send(thread.downgrade())?;
 219
 220            let session = ClaudeAgentSession {
 221                outgoing_tx,
 222                turn_state,
 223                _handler_task: handler_task,
 224                _mcp_server: Some(permission_mcp_server),
 225            };
 226
 227            self.sessions.borrow_mut().insert(session_id, session);
 228
 229            Ok(thread)
 230        })
 231    }
 232
 233    fn auth_methods(&self) -> &[acp::AuthMethod] {
 234        &[]
 235    }
 236
 237    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 238        Task::ready(Err(anyhow!("Authentication not supported")))
 239    }
 240
 241    fn prompt(
 242        &self,
 243        _id: Option<acp_thread::UserMessageId>,
 244        params: acp::PromptRequest,
 245        cx: &mut App,
 246    ) -> Task<Result<acp::PromptResponse>> {
 247        let sessions = self.sessions.borrow();
 248        let Some(session) = sessions.get(&params.session_id) else {
 249            return Task::ready(Err(anyhow!(
 250                "Attempted to send message to nonexistent session {}",
 251                params.session_id
 252            )));
 253        };
 254
 255        let (end_tx, end_rx) = oneshot::channel();
 256        session.turn_state.replace(TurnState::InProgress { end_tx });
 257
 258        let mut content = String::new();
 259        for chunk in params.prompt {
 260            match chunk {
 261                acp::ContentBlock::Text(text_content) => {
 262                    content.push_str(&text_content.text);
 263                }
 264                acp::ContentBlock::ResourceLink(resource_link) => {
 265                    content.push_str(&format!("@{}", resource_link.uri));
 266                }
 267                acp::ContentBlock::Audio(_)
 268                | acp::ContentBlock::Image(_)
 269                | acp::ContentBlock::Resource(_) => {
 270                    // TODO
 271                }
 272            }
 273        }
 274
 275        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 276            message: Message {
 277                role: Role::User,
 278                content: Content::UntaggedText(content),
 279                id: None,
 280                model: None,
 281                stop_reason: None,
 282                stop_sequence: None,
 283                usage: None,
 284            },
 285            session_id: Some(params.session_id.to_string()),
 286        }) {
 287            return Task::ready(Err(anyhow!(err)));
 288        }
 289
 290        cx.foreground_executor().spawn(async move { end_rx.await? })
 291    }
 292
 293    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 294        let sessions = self.sessions.borrow();
 295        let Some(session) = sessions.get(session_id) else {
 296            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 297            return;
 298        };
 299
 300        let request_id = new_request_id();
 301
 302        let turn_state = session.turn_state.take();
 303        let TurnState::InProgress { end_tx } = turn_state else {
 304            // Already canceled or idle, put it back
 305            session.turn_state.replace(turn_state);
 306            return;
 307        };
 308
 309        session.turn_state.replace(TurnState::CancelRequested {
 310            end_tx,
 311            request_id: request_id.clone(),
 312        });
 313
 314        session
 315            .outgoing_tx
 316            .unbounded_send(SdkMessage::ControlRequest {
 317                request_id,
 318                request: ControlRequest::Interrupt,
 319            })
 320            .log_err();
 321    }
 322
 323    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 324        self
 325    }
 326}
 327
/// How `spawn_claude` attaches the `claude` process to a session id.
#[derive(Clone, Copy)]
enum ClaudeSessionMode {
    /// Start a new session (`--session-id <id>`).
    Start,
    /// Resume an existing session (`--resume <id>`); not constructed yet.
    #[expect(dead_code)]
    Resume,
}
 334
 335fn spawn_claude(
 336    command: &AgentServerCommand,
 337    mode: ClaudeSessionMode,
 338    session_id: acp::SessionId,
 339    api_key: language_models::provider::anthropic::ApiKey,
 340    mcp_config_path: &Path,
 341    root_dir: &Path,
 342) -> Result<Child> {
 343    let child = util::command::new_smol_command(&command.path)
 344        .args([
 345            "--input-format",
 346            "stream-json",
 347            "--output-format",
 348            "stream-json",
 349            "--print",
 350            "--verbose",
 351            "--mcp-config",
 352            mcp_config_path.to_string_lossy().as_ref(),
 353            "--permission-prompt-tool",
 354            &format!(
 355                "mcp__{}__{}",
 356                mcp_server::SERVER_NAME,
 357                mcp_server::PermissionTool::NAME,
 358            ),
 359            "--allowedTools",
 360            &format!(
 361                "mcp__{}__{},mcp__{}__{}",
 362                mcp_server::SERVER_NAME,
 363                mcp_server::EditTool::NAME,
 364                mcp_server::SERVER_NAME,
 365                mcp_server::ReadTool::NAME
 366            ),
 367            "--disallowedTools",
 368            "Read,Edit",
 369        ])
 370        .args(match mode {
 371            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 372            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 373        })
 374        .args(command.args.iter().map(|arg| arg.as_str()))
 375        .envs(command.env.iter().flatten())
 376        .env("ANTHROPIC_API_KEY", api_key.key)
 377        .current_dir(root_dir)
 378        .stdin(std::process::Stdio::piped())
 379        .stdout(std::process::Stdio::piped())
 380        .stderr(std::process::Stdio::piped())
 381        .kill_on_drop(true)
 382        .spawn()?;
 383
 384    Ok(child)
 385}
 386
/// Per-session state stored in `ClaudeAgentConnection::sessions`.
///
/// NOTE: fields drop in declaration order; keep that in mind before reordering.
struct ClaudeAgentSession {
    /// Messages queued here are serialized as JSON lines onto the `claude` process's stdin.
    outgoing_tx: UnboundedSender<SdkMessage>,
    /// State of the current prompt turn, shared with the message handler task.
    turn_state: Rc<RefCell<TurnState>>,
    /// MCP server serving the permission-prompt and file tools; held to keep it alive.
    _mcp_server: Option<ClaudeZedMcpServer>,
    /// Foreground task dispatching incoming messages; held to keep it alive.
    _handler_task: Task<()>,
}
 393
/// Lifecycle of a single prompt turn.
///
/// Transitions:
/// - `prompt` sets `InProgress`.
/// - `cancel` moves `InProgress` to `CancelRequested` after sending an interrupt.
/// - A successful control response with the matching `request_id` moves it to
///   `CancelConfirmed`.
/// - The turn's `Result` message takes the state (leaving `None`) and resolves `end_tx`.
#[derive(Debug, Default)]
enum TurnState {
    /// No turn is active.
    #[default]
    None,
    /// A prompt is running; `end_tx` resolves the pending `prompt` task.
    InProgress {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
    },
    /// An interrupt with `request_id` was sent but not yet acknowledged.
    CancelRequested {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
        request_id: String,
    },
    /// Claude acknowledged the interrupt; waiting for the turn's `Result`.
    CancelConfirmed {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
    },
}
 409
 410impl TurnState {
 411    fn is_canceled(&self) -> bool {
 412        matches!(self, TurnState::CancelConfirmed { .. })
 413    }
 414
 415    fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
 416        match self {
 417            TurnState::None => None,
 418            TurnState::InProgress { end_tx, .. } => Some(end_tx),
 419            TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
 420            TurnState::CancelConfirmed { end_tx } => Some(end_tx),
 421        }
 422    }
 423
 424    fn confirm_cancellation(self, id: &str) -> Self {
 425        match self {
 426            TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
 427                TurnState::CancelConfirmed { end_tx }
 428            }
 429            _ => self,
 430        }
 431    }
 432}
 433
 434impl ClaudeAgentSession {
 435    async fn handle_message(
 436        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 437        message: SdkMessage,
 438        turn_state: Rc<RefCell<TurnState>>,
 439        cx: &mut AsyncApp,
 440    ) {
 441        match message {
 442            // we should only be sending these out, they don't need to be in the thread
 443            SdkMessage::ControlRequest { .. } => {}
 444            SdkMessage::User {
 445                message,
 446                session_id: _,
 447            } => {
 448                let Some(thread) = thread_rx
 449                    .recv()
 450                    .await
 451                    .log_err()
 452                    .and_then(|entity| entity.upgrade())
 453                else {
 454                    log::error!("Received an SDK message but thread is gone");
 455                    return;
 456                };
 457
 458                for chunk in message.content.chunks() {
 459                    match chunk {
 460                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 461                            if !turn_state.borrow().is_canceled() {
 462                                thread
 463                                    .update(cx, |thread, cx| {
 464                                        thread.push_user_content_block(None, text.into(), cx)
 465                                    })
 466                                    .log_err();
 467                            }
 468                        }
 469                        ContentChunk::ToolResult {
 470                            content,
 471                            tool_use_id,
 472                        } => {
 473                            let content = content.to_string();
 474                            thread
 475                                .update(cx, |thread, cx| {
 476                                    thread.update_tool_call(
 477                                        acp::ToolCallUpdate {
 478                                            id: acp::ToolCallId(tool_use_id.into()),
 479                                            fields: acp::ToolCallUpdateFields {
 480                                                status: if turn_state.borrow().is_canceled() {
 481                                                    // Do not set to completed if turn was canceled
 482                                                    None
 483                                                } else {
 484                                                    Some(acp::ToolCallStatus::Completed)
 485                                                },
 486                                                content: (!content.is_empty())
 487                                                    .then(|| vec![content.into()]),
 488                                                ..Default::default()
 489                                            },
 490                                        },
 491                                        cx,
 492                                    )
 493                                })
 494                                .log_err();
 495                        }
 496                        ContentChunk::Thinking { .. }
 497                        | ContentChunk::RedactedThinking
 498                        | ContentChunk::ToolUse { .. } => {
 499                            debug_panic!(
 500                                "Should not get {:?} with role: assistant. should we handle this?",
 501                                chunk
 502                            );
 503                        }
 504
 505                        ContentChunk::Image
 506                        | ContentChunk::Document
 507                        | ContentChunk::WebSearchToolResult => {
 508                            thread
 509                                .update(cx, |thread, cx| {
 510                                    thread.push_assistant_content_block(
 511                                        format!("Unsupported content: {:?}", chunk).into(),
 512                                        false,
 513                                        cx,
 514                                    )
 515                                })
 516                                .log_err();
 517                        }
 518                    }
 519                }
 520            }
 521            SdkMessage::Assistant {
 522                message,
 523                session_id: _,
 524            } => {
 525                let Some(thread) = thread_rx
 526                    .recv()
 527                    .await
 528                    .log_err()
 529                    .and_then(|entity| entity.upgrade())
 530                else {
 531                    log::error!("Received an SDK message but thread is gone");
 532                    return;
 533                };
 534
 535                for chunk in message.content.chunks() {
 536                    match chunk {
 537                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 538                            thread
 539                                .update(cx, |thread, cx| {
 540                                    thread.push_assistant_content_block(text.into(), false, cx)
 541                                })
 542                                .log_err();
 543                        }
 544                        ContentChunk::Thinking { thinking } => {
 545                            thread
 546                                .update(cx, |thread, cx| {
 547                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 548                                })
 549                                .log_err();
 550                        }
 551                        ContentChunk::RedactedThinking => {
 552                            thread
 553                                .update(cx, |thread, cx| {
 554                                    thread.push_assistant_content_block(
 555                                        "[REDACTED]".into(),
 556                                        true,
 557                                        cx,
 558                                    )
 559                                })
 560                                .log_err();
 561                        }
 562                        ContentChunk::ToolUse { id, name, input } => {
 563                            let claude_tool = ClaudeTool::infer(&name, input);
 564
 565                            thread
 566                                .update(cx, |thread, cx| {
 567                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 568                                        thread.update_plan(
 569                                            acp::Plan {
 570                                                entries: params
 571                                                    .todos
 572                                                    .into_iter()
 573                                                    .map(Into::into)
 574                                                    .collect(),
 575                                            },
 576                                            cx,
 577                                        )
 578                                    } else {
 579                                        thread.upsert_tool_call(
 580                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 581                                            cx,
 582                                        )?;
 583                                    }
 584                                    anyhow::Ok(())
 585                                })
 586                                .log_err();
 587                        }
 588                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 589                            debug_panic!(
 590                                "Should not get tool results with role: assistant. should we handle this?"
 591                            );
 592                        }
 593                        ContentChunk::Image | ContentChunk::Document => {
 594                            thread
 595                                .update(cx, |thread, cx| {
 596                                    thread.push_assistant_content_block(
 597                                        format!("Unsupported content: {:?}", chunk).into(),
 598                                        false,
 599                                        cx,
 600                                    )
 601                                })
 602                                .log_err();
 603                        }
 604                    }
 605                }
 606            }
 607            SdkMessage::Result {
 608                is_error,
 609                subtype,
 610                result,
 611                ..
 612            } => {
 613                let turn_state = turn_state.take();
 614                let was_canceled = turn_state.is_canceled();
 615                let Some(end_turn_tx) = turn_state.end_tx() else {
 616                    debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
 617                    return;
 618                };
 619
 620                if is_error || (!was_canceled && subtype == ResultErrorType::ErrorDuringExecution) {
 621                    end_turn_tx
 622                        .send(Err(anyhow!(
 623                            "Error: {}",
 624                            result.unwrap_or_else(|| subtype.to_string())
 625                        )))
 626                        .ok();
 627                } else {
 628                    let stop_reason = match subtype {
 629                        ResultErrorType::Success => acp::StopReason::EndTurn,
 630                        ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 631                        ResultErrorType::ErrorDuringExecution => acp::StopReason::Canceled,
 632                    };
 633                    end_turn_tx
 634                        .send(Ok(acp::PromptResponse { stop_reason }))
 635                        .ok();
 636                }
 637            }
 638            SdkMessage::ControlResponse { response } => {
 639                if matches!(response.subtype, ResultErrorType::Success) {
 640                    let new_state = turn_state.take().confirm_cancellation(&response.request_id);
 641                    turn_state.replace(new_state);
 642                } else {
 643                    log::error!("Control response error: {:?}", response);
 644                }
 645            }
 646            SdkMessage::System { .. } => {}
 647        }
 648    }
 649
 650    async fn handle_io(
 651        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 652        incoming_tx: UnboundedSender<SdkMessage>,
 653        mut outgoing_bytes: impl Unpin + AsyncWrite,
 654        incoming_bytes: impl Unpin + AsyncRead,
 655    ) -> Result<UnboundedReceiver<SdkMessage>> {
 656        let mut output_reader = BufReader::new(incoming_bytes);
 657        let mut outgoing_line = Vec::new();
 658        let mut incoming_line = String::new();
 659        loop {
 660            select_biased! {
 661                message = outgoing_rx.next() => {
 662                    if let Some(message) = message {
 663                        outgoing_line.clear();
 664                        serde_json::to_writer(&mut outgoing_line, &message)?;
 665                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 666                        outgoing_line.push(b'\n');
 667                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 668                    } else {
 669                        break;
 670                    }
 671                }
 672                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
 673                    if bytes_read? == 0 {
 674                        break
 675                    }
 676                    log::trace!("recv: {}", &incoming_line);
 677                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 678                        Ok(message) => {
 679                            incoming_tx.unbounded_send(message).log_err();
 680                        }
 681                        Err(error) => {
 682                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 683                        }
 684                    }
 685                    incoming_line.clear();
 686                }
 687            }
 688        }
 689
 690        Ok(outgoing_rx)
 691    }
 692}
 693
/// A message in Anthropic Messages API shape, as carried inside `SdkMessage::User` /
/// `SdkMessage::Assistant`.
///
/// Field order is the serialized JSON field order; `None` optionals are omitted when
/// serializing (prompts set them all to `None`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    role: Role,
    content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequence: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Usage>,
}
 709
/// Message content: either a bare JSON string or an array of typed chunks.
///
/// `untagged`: serde tries the variants in declaration order, so keep the order as is.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Content {
    /// Content given as a plain string.
    UntaggedText(String),
    /// Content given as an array of chunks.
    Chunks(Vec<ContentChunk>),
}
 716
 717impl Content {
 718    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 719        match self {
 720            Self::Chunks(chunks) => chunks.into_iter(),
 721            Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
 722        }
 723    }
 724}
 725
 726impl Display for Content {
 727    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 728        match self {
 729            Content::UntaggedText(txt) => write!(f, "{}", txt),
 730            Content::Chunks(chunks) => {
 731                for chunk in chunks {
 732                    write!(f, "{}", chunk)?;
 733                }
 734                Ok(())
 735            }
 736        }
 737    }
 738}
 739
/// One typed block of message content.
///
/// Internally tagged by `"type"` with snake_case names (`"text"`, `"tool_use"`,
/// `"tool_result"`, `"thinking"`, `"redacted_thinking"`, ...). The trailing `untagged`
/// variant accepts a bare string in place of a tagged object.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ContentChunk {
    Text {
        text: String,
    },
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    ToolResult {
        content: Content,
        tool_use_id: String,
    },
    Thinking {
        thinking: String,
    },
    RedactedThinking,
    // TODO: payloads of these variants are not modeled yet.
    Image,
    Document,
    WebSearchToolResult,
    #[serde(untagged)]
    UntaggedText(String),
}
 766
 767impl Display for ContentChunk {
 768    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 769        match self {
 770            ContentChunk::Text { text } => write!(f, "{}", text),
 771            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 772            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 773            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 774            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 775            ContentChunk::Image
 776            | ContentChunk::Document
 777            | ContentChunk::ToolUse { .. }
 778            | ContentChunk::WebSearchToolResult => {
 779                write!(f, "\n{:?}\n", &self)
 780            }
 781        }
 782    }
 783}
 784
/// Token usage reported on a message.
///
/// NOTE(review): every field is required. If the CLI ever omits one (e.g. `service_tier`),
/// the enclosing message fails to deserialize and `handle_io` logs and skips it. Confirm
/// against real CLI output.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Usage {
    input_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
    output_tokens: u32,
    service_tier: String,
}
 793
/// Author of a message, serialized as `"system"`, `"assistant"` or `"user"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}
 801
/// A minimal role + text message.
///
/// NOTE(review): not referenced in the visible part of this file; confirm whether it is still
/// needed.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    role: Role,
    content: String,
}
 807
/// One newline-delimited JSON message of the `claude` CLI's stream-json protocol.
///
/// Internally tagged by `"type"` with snake_case names: `"assistant"`, `"user"`, `"result"`,
/// `"system"`, `"control_request"`, `"control_response"`. Zed sends `User` (prompts) and
/// `ControlRequest` (interrupts); the other variants are received from the CLI.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    /// An assistant message.
    Assistant {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// A user message (also used to send prompts).
    User {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// Emitted as the last message in a conversation; the handler uses it to end the
    /// current turn.
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    /// Emitted as the first message at the start of a conversation.
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to the model
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request
    ControlResponse { response: ControlResponse },
}
 855
/// Payload of `SdkMessage::ControlRequest`, internally tagged by `"subtype"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation; serialized as `{"subtype": "interrupt"}`.
    Interrupt,
}
 862
/// Payload of `SdkMessage::ControlResponse`.
///
/// `request_id` presumably echoes the id of the `SdkMessage::ControlRequest` being
/// answered (TODO confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
    request_id: String,
    subtype: ResultErrorType,
}
 868
/// Outcome subtype shared by `SdkMessage::Result` and `ControlResponse`.
///
/// Serialized in snake_case (`"success"`, `"error_max_turns"`,
/// `"error_during_execution"`). The `Display` impl for this type spells out the
/// same strings by hand, so keep the two in sync.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
 876
 877impl Display for ResultErrorType {
 878    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 879        match self {
 880            ResultErrorType::Success => write!(f, "success"),
 881            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 882            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 883        }
 884    }
 885}
 886
 887fn new_request_id() -> String {
 888    use rand::Rng;
 889    // In the Claude Code TS SDK they just generate a random 12 character string,
 890    // `Math.random().toString(36).substring(2, 15)`
 891    rand::thread_rng()
 892        .sample_iter(&rand::distributions::Alphanumeric)
 893        .take(12)
 894        .map(char::from)
 895        .collect()
 896}
 897
/// An MCP server entry as listed in the `SdkMessage::System` message.
///
/// `status` is kept as a free-form string.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    name: String,
    status: String,
}
 903
/// Permission mode reported as `permissionMode` in the `SdkMessage::System` message.
///
/// Serialized in camelCase: `"default"`, `"acceptEdits"`, `"bypassPermissions"`,
/// `"plan"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum PermissionMode {
    Default,
    AcceptEdits,
    BypassPermissions,
    Plan,
}
 912
 913#[cfg(test)]
 914pub(crate) mod tests {
 915    use super::*;
 916    use crate::e2e_tests;
 917    use gpui::TestAppContext;
 918    use serde_json::json;
 919
 920    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
 921
 922    pub fn local_command() -> AgentServerCommand {
 923        AgentServerCommand {
 924            path: "claude".into(),
 925            args: vec![],
 926            env: None,
 927        }
 928    }
 929
 930    #[gpui::test]
 931    #[cfg_attr(not(feature = "e2e"), ignore)]
 932    async fn test_todo_plan(cx: &mut TestAppContext) {
 933        let fs = e2e_tests::init_test(cx).await;
 934        let project = Project::test(fs, [], cx).await;
 935        let thread =
 936            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
 937
 938        thread
 939            .update(cx, |thread, cx| {
 940                thread.send_raw(
 941                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
 942                    cx,
 943                )
 944            })
 945            .await
 946            .unwrap();
 947
 948        let mut entries_len = 0;
 949
 950        thread.read_with(cx, |thread, _| {
 951            entries_len = thread.plan().entries.len();
 952            assert!(thread.plan().entries.len() > 0, "Empty plan");
 953        });
 954
 955        thread
 956            .update(cx, |thread, cx| {
 957                thread.send_raw(
 958                    "Mark the first entry status as in progress without acting on it.",
 959                    cx,
 960                )
 961            })
 962            .await
 963            .unwrap();
 964
 965        thread.read_with(cx, |thread, _| {
 966            assert!(matches!(
 967                thread.plan().entries[0].status,
 968                acp::PlanEntryStatus::InProgress
 969            ));
 970            assert_eq!(thread.plan().entries.len(), entries_len);
 971        });
 972
 973        thread
 974            .update(cx, |thread, cx| {
 975                thread.send_raw(
 976                    "Now mark the first entry as completed without acting on it.",
 977                    cx,
 978                )
 979            })
 980            .await
 981            .unwrap();
 982
 983        thread.read_with(cx, |thread, _| {
 984            assert!(matches!(
 985                thread.plan().entries[0].status,
 986                acp::PlanEntryStatus::Completed
 987            ));
 988            assert_eq!(thread.plan().entries.len(), entries_len);
 989        });
 990    }
 991
 992    #[test]
 993    fn test_deserialize_content_untagged_text() {
 994        let json = json!("Hello, world!");
 995        let content: Content = serde_json::from_value(json).unwrap();
 996        match content {
 997            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
 998            _ => panic!("Expected UntaggedText variant"),
 999        }
1000    }
1001
1002    #[test]
1003    fn test_deserialize_content_chunks() {
1004        let json = json!([
1005            {
1006                "type": "text",
1007                "text": "Hello"
1008            },
1009            {
1010                "type": "tool_use",
1011                "id": "tool_123",
1012                "name": "calculator",
1013                "input": {"operation": "add", "a": 1, "b": 2}
1014            }
1015        ]);
1016        let content: Content = serde_json::from_value(json).unwrap();
1017        match content {
1018            Content::Chunks(chunks) => {
1019                assert_eq!(chunks.len(), 2);
1020                match &chunks[0] {
1021                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
1022                    _ => panic!("Expected Text chunk"),
1023                }
1024                match &chunks[1] {
1025                    ContentChunk::ToolUse { id, name, input } => {
1026                        assert_eq!(id, "tool_123");
1027                        assert_eq!(name, "calculator");
1028                        assert_eq!(input["operation"], "add");
1029                        assert_eq!(input["a"], 1);
1030                        assert_eq!(input["b"], 2);
1031                    }
1032                    _ => panic!("Expected ToolUse chunk"),
1033                }
1034            }
1035            _ => panic!("Expected Chunks variant"),
1036        }
1037    }
1038
1039    #[test]
1040    fn test_deserialize_tool_result_untagged_text() {
1041        let json = json!({
1042            "type": "tool_result",
1043            "content": "Result content",
1044            "tool_use_id": "tool_456"
1045        });
1046        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1047        match chunk {
1048            ContentChunk::ToolResult {
1049                content,
1050                tool_use_id,
1051            } => {
1052                match content {
1053                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
1054                    _ => panic!("Expected UntaggedText content"),
1055                }
1056                assert_eq!(tool_use_id, "tool_456");
1057            }
1058            _ => panic!("Expected ToolResult variant"),
1059        }
1060    }
1061
1062    #[test]
1063    fn test_deserialize_tool_result_chunks() {
1064        let json = json!({
1065            "type": "tool_result",
1066            "content": [
1067                {
1068                    "type": "text",
1069                    "text": "Processing complete"
1070                },
1071                {
1072                    "type": "text",
1073                    "text": "Result: 42"
1074                }
1075            ],
1076            "tool_use_id": "tool_789"
1077        });
1078        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1079        match chunk {
1080            ContentChunk::ToolResult {
1081                content,
1082                tool_use_id,
1083            } => {
1084                match content {
1085                    Content::Chunks(chunks) => {
1086                        assert_eq!(chunks.len(), 2);
1087                        match &chunks[0] {
1088                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
1089                            _ => panic!("Expected Text chunk"),
1090                        }
1091                        match &chunks[1] {
1092                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
1093                            _ => panic!("Expected Text chunk"),
1094                        }
1095                    }
1096                    _ => panic!("Expected Chunks content"),
1097                }
1098                assert_eq!(tool_use_id, "tool_789");
1099            }
1100            _ => panic!("Expected ToolResult variant"),
1101        }
1102    }
1103}