claude.rs

   1mod mcp_server;
   2pub mod tools;
   3
   4use action_log::ActionLog;
   5use collections::HashMap;
   6use context_server::listener::McpServerTool;
   7use language_models::provider::anthropic::AnthropicLanguageModelProvider;
   8use project::Project;
   9use settings::SettingsStore;
  10use smol::process::Child;
  11use std::any::Any;
  12use std::cell::RefCell;
  13use std::fmt::Display;
  14use std::path::Path;
  15use std::rc::Rc;
  16use uuid::Uuid;
  17
  18use agent_client_protocol as acp;
  19use anyhow::{Context as _, Result, anyhow};
  20use futures::channel::oneshot;
  21use futures::{AsyncBufReadExt, AsyncWriteExt};
  22use futures::{
  23    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  24    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  25    io::BufReader,
  26    select_biased,
  27};
  28use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  29use serde::{Deserialize, Serialize};
  30use util::{ResultExt, debug_panic};
  31
  32use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  33use crate::claude::tools::ClaudeTool;
  34use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  35use acp_thread::{AcpThread, AgentConnection, AuthRequired};
  36
/// Agent server entry for Claude Code.
///
/// A field-less unit type: all per-connection state lives in the
/// `ClaudeAgentConnection` created by its `AgentServer::connect` impl.
#[derive(Clone)]
pub struct ClaudeCode;
  39
  40impl AgentServer for ClaudeCode {
  41    fn name(&self) -> &'static str {
  42        "Claude Code"
  43    }
  44
  45    fn empty_state_headline(&self) -> &'static str {
  46        self.name()
  47    }
  48
  49    fn empty_state_message(&self) -> &'static str {
  50        "How can I help you today?"
  51    }
  52
  53    fn logo(&self) -> ui::IconName {
  54        ui::IconName::AiClaude
  55    }
  56
  57    fn connect(
  58        &self,
  59        _root_dir: &Path,
  60        _project: &Entity<Project>,
  61        _cx: &mut App,
  62    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  63        let connection = ClaudeAgentConnection {
  64            sessions: Default::default(),
  65        };
  66
  67        Task::ready(Ok(Rc::new(connection) as _))
  68    }
  69
  70    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
  71        self
  72    }
  73}
  74
/// Connection handed out by `ClaudeCode::connect`; owns every live session.
struct ClaudeAgentConnection {
    /// Sessions keyed by session id. Entries are inserted by `new_thread` and
    /// looked up by `prompt` and `cancel`.
    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
}
  78
  79impl AgentConnection for ClaudeAgentConnection {
    /// Starts a new Claude Code session and returns the `AcpThread` bound to it.
    ///
    /// The returned task performs, in order:
    /// 1. resolve the `claude` command from settings (with
    ///    `~/.claude/local/claude` passed as a fallback path),
    /// 2. fetch the Anthropic API key, mapping an authentication failure to
    ///    `AuthRequired`,
    /// 3. start the permission MCP server and write its config to a temp file,
    /// 4. spawn the `claude` child process and the tasks that pump its
    ///    stderr, stdin/stdout, and decoded messages,
    /// 5. create the thread entity, publish it on the watch channel, and
    ///    register the session under a fresh UUID.
    ///
    /// Any failing step aborts the task with that error.
    fn new_thread(
        self: Rc<Self>,
        project: Entity<Project>,
        cwd: &Path,
        cx: &mut App,
    ) -> Task<Result<Entity<AcpThread>>> {
        // Own the path so it can move into the async block.
        let cwd = cwd.to_owned();
        cx.spawn(async move |cx| {
            let settings = cx.read_global(|settings: &SettingsStore, _| {
                settings.get::<AllAgentServersSettings>(None).claude.clone()
            })?;

            let Some(command) = AgentServerCommand::resolve(
                "claude",
                &[],
                Some(&util::paths::home_dir().join(".claude/local/claude")),
                settings,
                &project,
                cx,
            )
            .await
            else {
                anyhow::bail!("Failed to find claude binary");
            };

            // An `AuthenticateError` is turned into `AuthRequired` tagged with
            // the Anthropic provider id; other errors pass through unchanged.
            let api_key =
                cx.update(AnthropicLanguageModelProvider::api_key)?
                    .await
                    .map_err(|err| {
                        if err.is::<language_model::AuthenticateError>() {
                            anyhow!(AuthRequired::new().with_language_model_provider(
                                language_model::ANTHROPIC_PROVIDER_ID
                            ))
                        } else {
                            anyhow!(err)
                        }
                    })?;

            // The thread entity does not exist yet, so the channel starts with
            // an invalid weak handle; the real one is sent further below.
            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
            let fs = project.read_with(cx, |project, _cx| project.fs().clone())?;
            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), fs, cx).await?;

            let mut mcp_servers = HashMap::default();
            mcp_servers.insert(
                mcp_server::SERVER_NAME.to_string(),
                permission_mcp_server.server_config()?,
            );
            let mcp_config = McpConfig { mcp_servers };

            // Serialize the MCP config as JSON into a temp file whose path is
            // passed to the CLI via `--mcp-config` (see `spawn_claude`).
            let mcp_config_file = tempfile::NamedTempFile::new()?;
            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();

            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
            mcp_config_file
                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
                .await?;
            mcp_config_file.flush().await?;

            // incoming: messages parsed from the child's stdout.
            // outgoing: messages to serialize onto the child's stdin.
            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();

            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());

            log::trace!("Starting session with id: {}", session_id);

            let mut child = spawn_claude(
                &command,
                ClaudeSessionMode::Start,
                session_id.clone(),
                api_key,
                &mcp_config_path,
                &cwd,
            )?;

            let stdout = child.stdout.take().context("Failed to take stdout")?;
            let stdin = child.stdin.take().context("Failed to take stdin")?;
            let stderr = child.stderr.take().context("Failed to take stderr")?;

            let pid = child.id();
            log::trace!("Spawned (pid: {})", pid);

            // Forward the child's stderr to the log, one line at a time, until
            // a read fails or returns zero bytes.
            cx.background_spawn(async move {
                let mut stderr = BufReader::new(stderr);
                let mut line = String::new();
                while let Ok(n) = stderr.read_line(&mut line).await
                    && n > 0
                {
                    log::warn!("agent stderr: {}", &line);
                    line.clear();
                }
            })
            .detach();

            // Pump stdin/stdout until `handle_io` returns.
            cx.background_spawn(async move {
                // NOTE(review): wrapping in `Some` and immediately calling
                // `take().unwrap()` looks like a leftover; the receiver could
                // be passed directly.
                let mut outgoing_rx = Some(outgoing_rx);

                ClaudeAgentSession::handle_io(
                    outgoing_rx.take().unwrap(),
                    incoming_message_tx.clone(),
                    stdin,
                    stdout,
                )
                .await?;

                log::trace!("Stopped (pid: {})", pid);

                // The temp path was moved into this task so it outlives the IO
                // loop; dropping it here is presumably what removes the file.
                drop(mcp_config_path);
                anyhow::Ok(())
            })
            .detach();

            let turn_state = Rc::new(RefCell::new(TurnState::None));

            // Dispatch each decoded message to `handle_message`; once the
            // incoming channel closes, wait for the child to exit and report
            // its status to the thread.
            let handler_task = cx.spawn({
                let turn_state = turn_state.clone();
                let mut thread_rx = thread_rx.clone();
                async move |cx| {
                    while let Some(message) = incoming_message_rx.next().await {
                        ClaudeAgentSession::handle_message(
                            thread_rx.clone(),
                            message,
                            turn_state.clone(),
                            cx,
                        )
                        .await
                    }

                    if let Some(status) = child.status().await.log_err()
                        && let Some(thread) = thread_rx.recv().await.ok()
                    {
                        thread
                            .update(cx, |thread, cx| {
                                thread.emit_server_exited(status, cx);
                            })
                            .ok();
                    }
                }
            });

            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
            let thread = cx.new(|_cx| {
                AcpThread::new(
                    "Claude Code",
                    self.clone(),
                    project,
                    action_log,
                    session_id.clone(),
                )
            })?;

            // Publish the real thread to everything holding a `thread_rx`
            // clone (the MCP server and the handler task).
            thread_tx.send(thread.downgrade())?;

            let session = ClaudeAgentSession {
                outgoing_tx,
                turn_state,
                _handler_task: handler_task,
                _mcp_server: Some(permission_mcp_server),
            };

            self.sessions.borrow_mut().insert(session_id, session);

            Ok(thread)
        })
    }
 244
 245    fn auth_methods(&self) -> &[acp::AuthMethod] {
 246        &[]
 247    }
 248
 249    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 250        Task::ready(Err(anyhow!("Authentication not supported")))
 251    }
 252
 253    fn prompt(
 254        &self,
 255        _id: Option<acp_thread::UserMessageId>,
 256        params: acp::PromptRequest,
 257        cx: &mut App,
 258    ) -> Task<Result<acp::PromptResponse>> {
 259        let sessions = self.sessions.borrow();
 260        let Some(session) = sessions.get(&params.session_id) else {
 261            return Task::ready(Err(anyhow!(
 262                "Attempted to send message to nonexistent session {}",
 263                params.session_id
 264            )));
 265        };
 266
 267        let (end_tx, end_rx) = oneshot::channel();
 268        session.turn_state.replace(TurnState::InProgress { end_tx });
 269
 270        let mut content = String::new();
 271        for chunk in params.prompt {
 272            match chunk {
 273                acp::ContentBlock::Text(text_content) => {
 274                    content.push_str(&text_content.text);
 275                }
 276                acp::ContentBlock::ResourceLink(resource_link) => {
 277                    content.push_str(&format!("@{}", resource_link.uri));
 278                }
 279                acp::ContentBlock::Audio(_)
 280                | acp::ContentBlock::Image(_)
 281                | acp::ContentBlock::Resource(_) => {
 282                    // TODO
 283                }
 284            }
 285        }
 286
 287        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 288            message: Message {
 289                role: Role::User,
 290                content: Content::UntaggedText(content),
 291                id: None,
 292                model: None,
 293                stop_reason: None,
 294                stop_sequence: None,
 295                usage: None,
 296            },
 297            session_id: Some(params.session_id.to_string()),
 298        }) {
 299            return Task::ready(Err(anyhow!(err)));
 300        }
 301
 302        cx.foreground_executor().spawn(async move { end_rx.await? })
 303    }
 304
 305    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 306        let sessions = self.sessions.borrow();
 307        let Some(session) = sessions.get(session_id) else {
 308            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 309            return;
 310        };
 311
 312        let request_id = new_request_id();
 313
 314        let turn_state = session.turn_state.take();
 315        let TurnState::InProgress { end_tx } = turn_state else {
 316            // Already canceled or idle, put it back
 317            session.turn_state.replace(turn_state);
 318            return;
 319        };
 320
 321        session.turn_state.replace(TurnState::CancelRequested {
 322            end_tx,
 323            request_id: request_id.clone(),
 324        });
 325
 326        session
 327            .outgoing_tx
 328            .unbounded_send(SdkMessage::ControlRequest {
 329                request_id,
 330                request: ControlRequest::Interrupt,
 331            })
 332            .log_err();
 333    }
 334
 335    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 336        self
 337    }
 338}
 339
/// Selects which session flag `spawn_claude` passes to the CLI.
#[derive(Clone, Copy)]
enum ClaudeSessionMode {
    /// Pass `--session-id <id>`.
    Start,
    /// Pass `--resume <id>`. Not constructed anywhere yet, hence the
    /// `dead_code` expectation.
    #[expect(dead_code)]
    Resume,
}
 346
 347fn spawn_claude(
 348    command: &AgentServerCommand,
 349    mode: ClaudeSessionMode,
 350    session_id: acp::SessionId,
 351    api_key: language_models::provider::anthropic::ApiKey,
 352    mcp_config_path: &Path,
 353    root_dir: &Path,
 354) -> Result<Child> {
 355    let child = util::command::new_smol_command(&command.path)
 356        .args([
 357            "--input-format",
 358            "stream-json",
 359            "--output-format",
 360            "stream-json",
 361            "--print",
 362            "--verbose",
 363            "--mcp-config",
 364            mcp_config_path.to_string_lossy().as_ref(),
 365            "--permission-prompt-tool",
 366            &format!(
 367                "mcp__{}__{}",
 368                mcp_server::SERVER_NAME,
 369                mcp_server::PermissionTool::NAME,
 370            ),
 371            "--allowedTools",
 372            &format!(
 373                "mcp__{}__{},mcp__{}__{}",
 374                mcp_server::SERVER_NAME,
 375                mcp_server::EditTool::NAME,
 376                mcp_server::SERVER_NAME,
 377                mcp_server::ReadTool::NAME
 378            ),
 379            "--disallowedTools",
 380            "Read,Edit",
 381        ])
 382        .args(match mode {
 383            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 384            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 385        })
 386        .args(command.args.iter().map(|arg| arg.as_str()))
 387        .envs(command.env.iter().flatten())
 388        .env("ANTHROPIC_API_KEY", api_key.key)
 389        .current_dir(root_dir)
 390        .stdin(std::process::Stdio::piped())
 391        .stdout(std::process::Stdio::piped())
 392        .stderr(std::process::Stdio::piped())
 393        .kill_on_drop(true)
 394        .spawn()?;
 395
 396    Ok(child)
 397}
 398
/// Per-session state stored in `ClaudeAgentConnection::sessions`.
struct ClaudeAgentSession {
    /// Queue of messages to be written to the agent process's stdin; the
    /// receiving end is drained by `handle_io`.
    outgoing_tx: UnboundedSender<SdkMessage>,
    /// State of the current turn, shared with the message handler task.
    turn_state: Rc<RefCell<TurnState>>,
    /// Permission MCP server for this session. Never read in the visible
    /// code; presumably held only to keep the server alive.
    _mcp_server: Option<ClaudeZedMcpServer>,
    /// Task that dispatches incoming messages. Never read; presumably held so
    /// the task lives as long as the session.
    _handler_task: Task<()>,
}
 405
/// Lifecycle of a single prompt turn.
///
/// Every non-idle variant carries `end_tx`, the sender used to resolve the
/// task returned by `prompt`.
#[derive(Debug, Default)]
enum TurnState {
    /// No turn is active. This is the default, and what `RefCell::take`
    /// leaves behind.
    #[default]
    None,
    /// A prompt was sent and the turn has not ended yet.
    InProgress {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
    },
    /// An interrupt with `request_id` was sent and has not been confirmed yet.
    CancelRequested {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
        request_id: String,
    },
    /// The interrupt request was confirmed (see `confirm_cancellation`).
    CancelConfirmed {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
    },
}
 421
 422impl TurnState {
 423    fn is_canceled(&self) -> bool {
 424        matches!(self, TurnState::CancelConfirmed { .. })
 425    }
 426
 427    fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
 428        match self {
 429            TurnState::None => None,
 430            TurnState::InProgress { end_tx, .. } => Some(end_tx),
 431            TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
 432            TurnState::CancelConfirmed { end_tx } => Some(end_tx),
 433        }
 434    }
 435
 436    fn confirm_cancellation(self, id: &str) -> Self {
 437        match self {
 438            TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
 439                TurnState::CancelConfirmed { end_tx }
 440            }
 441            _ => self,
 442        }
 443    }
 444}
 445
 446impl ClaudeAgentSession {
 447    async fn handle_message(
 448        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 449        message: SdkMessage,
 450        turn_state: Rc<RefCell<TurnState>>,
 451        cx: &mut AsyncApp,
 452    ) {
 453        match message {
 454            // we should only be sending these out, they don't need to be in the thread
 455            SdkMessage::ControlRequest { .. } => {}
 456            SdkMessage::User {
 457                message,
 458                session_id: _,
 459            } => {
 460                let Some(thread) = thread_rx
 461                    .recv()
 462                    .await
 463                    .log_err()
 464                    .and_then(|entity| entity.upgrade())
 465                else {
 466                    log::error!("Received an SDK message but thread is gone");
 467                    return;
 468                };
 469
 470                for chunk in message.content.chunks() {
 471                    match chunk {
 472                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 473                            if !turn_state.borrow().is_canceled() {
 474                                thread
 475                                    .update(cx, |thread, cx| {
 476                                        thread.push_user_content_block(None, text.into(), cx)
 477                                    })
 478                                    .log_err();
 479                            }
 480                        }
 481                        ContentChunk::ToolResult {
 482                            content,
 483                            tool_use_id,
 484                        } => {
 485                            let content = content.to_string();
 486                            thread
 487                                .update(cx, |thread, cx| {
 488                                    thread.update_tool_call(
 489                                        acp::ToolCallUpdate {
 490                                            id: acp::ToolCallId(tool_use_id.into()),
 491                                            fields: acp::ToolCallUpdateFields {
 492                                                status: if turn_state.borrow().is_canceled() {
 493                                                    // Do not set to completed if turn was canceled
 494                                                    None
 495                                                } else {
 496                                                    Some(acp::ToolCallStatus::Completed)
 497                                                },
 498                                                content: (!content.is_empty())
 499                                                    .then(|| vec![content.into()]),
 500                                                ..Default::default()
 501                                            },
 502                                        },
 503                                        cx,
 504                                    )
 505                                })
 506                                .log_err();
 507                        }
 508                        ContentChunk::Thinking { .. }
 509                        | ContentChunk::RedactedThinking
 510                        | ContentChunk::ToolUse { .. } => {
 511                            debug_panic!(
 512                                "Should not get {:?} with role: assistant. should we handle this?",
 513                                chunk
 514                            );
 515                        }
 516
 517                        ContentChunk::Image
 518                        | ContentChunk::Document
 519                        | ContentChunk::WebSearchToolResult => {
 520                            thread
 521                                .update(cx, |thread, cx| {
 522                                    thread.push_assistant_content_block(
 523                                        format!("Unsupported content: {:?}", chunk).into(),
 524                                        false,
 525                                        cx,
 526                                    )
 527                                })
 528                                .log_err();
 529                        }
 530                    }
 531                }
 532            }
 533            SdkMessage::Assistant {
 534                message,
 535                session_id: _,
 536            } => {
 537                let Some(thread) = thread_rx
 538                    .recv()
 539                    .await
 540                    .log_err()
 541                    .and_then(|entity| entity.upgrade())
 542                else {
 543                    log::error!("Received an SDK message but thread is gone");
 544                    return;
 545                };
 546
 547                for chunk in message.content.chunks() {
 548                    match chunk {
 549                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 550                            thread
 551                                .update(cx, |thread, cx| {
 552                                    thread.push_assistant_content_block(text.into(), false, cx)
 553                                })
 554                                .log_err();
 555                        }
 556                        ContentChunk::Thinking { thinking } => {
 557                            thread
 558                                .update(cx, |thread, cx| {
 559                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 560                                })
 561                                .log_err();
 562                        }
 563                        ContentChunk::RedactedThinking => {
 564                            thread
 565                                .update(cx, |thread, cx| {
 566                                    thread.push_assistant_content_block(
 567                                        "[REDACTED]".into(),
 568                                        true,
 569                                        cx,
 570                                    )
 571                                })
 572                                .log_err();
 573                        }
 574                        ContentChunk::ToolUse { id, name, input } => {
 575                            let claude_tool = ClaudeTool::infer(&name, input);
 576
 577                            thread
 578                                .update(cx, |thread, cx| {
 579                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 580                                        thread.update_plan(
 581                                            acp::Plan {
 582                                                entries: params
 583                                                    .todos
 584                                                    .into_iter()
 585                                                    .map(Into::into)
 586                                                    .collect(),
 587                                            },
 588                                            cx,
 589                                        )
 590                                    } else {
 591                                        thread.upsert_tool_call(
 592                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 593                                            cx,
 594                                        )?;
 595                                    }
 596                                    anyhow::Ok(())
 597                                })
 598                                .log_err();
 599                        }
 600                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 601                            debug_panic!(
 602                                "Should not get tool results with role: assistant. should we handle this?"
 603                            );
 604                        }
 605                        ContentChunk::Image | ContentChunk::Document => {
 606                            thread
 607                                .update(cx, |thread, cx| {
 608                                    thread.push_assistant_content_block(
 609                                        format!("Unsupported content: {:?}", chunk).into(),
 610                                        false,
 611                                        cx,
 612                                    )
 613                                })
 614                                .log_err();
 615                        }
 616                    }
 617                }
 618            }
 619            SdkMessage::Result {
 620                is_error,
 621                subtype,
 622                result,
 623                ..
 624            } => {
 625                let turn_state = turn_state.take();
 626                let was_canceled = turn_state.is_canceled();
 627                let Some(end_turn_tx) = turn_state.end_tx() else {
 628                    debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
 629                    return;
 630                };
 631
 632                if is_error || (!was_canceled && subtype == ResultErrorType::ErrorDuringExecution) {
 633                    end_turn_tx
 634                        .send(Err(anyhow!(
 635                            "Error: {}",
 636                            result.unwrap_or_else(|| subtype.to_string())
 637                        )))
 638                        .ok();
 639                } else {
 640                    let stop_reason = match subtype {
 641                        ResultErrorType::Success => acp::StopReason::EndTurn,
 642                        ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 643                        ResultErrorType::ErrorDuringExecution => acp::StopReason::Canceled,
 644                    };
 645                    end_turn_tx
 646                        .send(Ok(acp::PromptResponse { stop_reason }))
 647                        .ok();
 648                }
 649            }
 650            SdkMessage::ControlResponse { response } => {
 651                if matches!(response.subtype, ResultErrorType::Success) {
 652                    let new_state = turn_state.take().confirm_cancellation(&response.request_id);
 653                    turn_state.replace(new_state);
 654                } else {
 655                    log::error!("Control response error: {:?}", response);
 656                }
 657            }
 658            SdkMessage::System { .. } => {}
 659        }
 660    }
 661
 662    async fn handle_io(
 663        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 664        incoming_tx: UnboundedSender<SdkMessage>,
 665        mut outgoing_bytes: impl Unpin + AsyncWrite,
 666        incoming_bytes: impl Unpin + AsyncRead,
 667    ) -> Result<UnboundedReceiver<SdkMessage>> {
 668        let mut output_reader = BufReader::new(incoming_bytes);
 669        let mut outgoing_line = Vec::new();
 670        let mut incoming_line = String::new();
 671        loop {
 672            select_biased! {
 673                message = outgoing_rx.next() => {
 674                    if let Some(message) = message {
 675                        outgoing_line.clear();
 676                        serde_json::to_writer(&mut outgoing_line, &message)?;
 677                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 678                        outgoing_line.push(b'\n');
 679                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 680                    } else {
 681                        break;
 682                    }
 683                }
 684                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
 685                    if bytes_read? == 0 {
 686                        break
 687                    }
 688                    log::trace!("recv: {}", &incoming_line);
 689                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 690                        Ok(message) => {
 691                            incoming_tx.unbounded_send(message).log_err();
 692                        }
 693                        Err(error) => {
 694                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 695                        }
 696                    }
 697                    incoming_line.clear();
 698                }
 699            }
 700        }
 701
 702        Ok(outgoing_rx)
 703    }
 704}
 705
/// Message payload carried by `SdkMessage::User` and `SdkMessage::Assistant`.
///
/// Every optional field is omitted from the serialized form when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    role: Role,
    content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequence: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Usage>,
}
 721
/// Message content: either a bare string or a list of chunks.
///
/// The enum is untagged, so the two forms are told apart by the shape of the
/// serialized value rather than by a tag field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Content {
    /// Content given as a plain string.
    UntaggedText(String),
    /// Content given as a sequence of chunks.
    Chunks(Vec<ContentChunk>),
}
 728
 729impl Content {
 730    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 731        match self {
 732            Self::Chunks(chunks) => chunks.into_iter(),
 733            Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
 734        }
 735    }
 736}
 737
 738impl Display for Content {
 739    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 740        match self {
 741            Content::UntaggedText(txt) => write!(f, "{}", txt),
 742            Content::Chunks(chunks) => {
 743                for chunk in chunks {
 744                    write!(f, "{}", chunk)?;
 745                }
 746                Ok(())
 747            }
 748        }
 749    }
 750}
 751
/// One element of chunked message content.
///
/// Variants are selected by a `type` field holding the snake_case variant
/// name, except `UntaggedText`, which is matched without a tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ContentChunk {
    /// Plain text.
    Text {
        text: String,
    },
    /// A tool invocation, with its id, tool name and raw JSON input.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result for the tool call identified by `tool_use_id`.
    ToolResult {
        content: Content,
        tool_use_id: String,
    },
    /// Thinking text.
    Thinking {
        thinking: String,
    },
    /// Thinking whose text is not carried.
    RedactedThinking,
    // TODO: the three variants below carry no payload yet.
    Image,
    Document,
    WebSearchToolResult,
    /// A bare string appearing where a chunk is expected.
    #[serde(untagged)]
    UntaggedText(String),
}
 778
 779impl Display for ContentChunk {
 780    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 781        match self {
 782            ContentChunk::Text { text } => write!(f, "{}", text),
 783            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 784            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 785            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 786            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 787            ContentChunk::Image
 788            | ContentChunk::Document
 789            | ContentChunk::ToolUse { .. }
 790            | ContentChunk::WebSearchToolResult => {
 791                write!(f, "\n{:?}\n", &self)
 792            }
 793        }
 794    }
 795}
 796
/// Value of the optional `usage` field on `Message`.
///
/// NOTE(review): the fields look like token counters plus a service tier
/// label; confirm their exact meaning against the Anthropic API.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Usage {
    input_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
    output_tokens: u32,
    service_tier: String,
}
 805
/// Author of a message.
///
/// Serialized in snake_case: `"system"`, `"assistant"`, `"user"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}
 813
/// A message reduced to its role and plain-string content.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    /// Who authored the message.
    role: Role,
    /// The message body as a single string.
    content: String,
}
 819
/// Top-level message envelope.
///
/// Serde selects the variant from the `type` field, using the snake_case
/// form of the variant name (e.g. `"control_request"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    /// An assistant message
    Assistant {
        message: Message, // from Anthropic SDK
        /// Omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// A user message
    User {
        message: Message, // from Anthropic SDK
        /// Omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// Emitted as the last message in a conversation
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        /// Omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    /// Emitted as the first message at the start of a conversation
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        // The next two keys are camelCase on the wire, unlike the rest of this variant.
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to the model
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request
    ControlResponse { response: ControlResponse },
}
 867
/// Payload of a control request.
///
/// Serde selects the variant from the `subtype` field, using the snake_case
/// form of the variant name (e.g. `"interrupt"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation
    Interrupt,
}
 874
/// Payload of a control response.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
    // Presumably echoes the `request_id` of the control request being answered — confirm.
    request_id: String,
    /// Outcome reported for the request.
    subtype: ResultErrorType,
}
 880
/// Outcome tag used by result messages and control responses.
///
/// Serialized in snake_case: `"success"`, `"error_max_turns"`,
/// `"error_during_execution"`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
 888
 889impl Display for ResultErrorType {
 890    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 891        match self {
 892            ResultErrorType::Success => write!(f, "success"),
 893            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 894            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 895        }
 896    }
 897}
 898
 899fn new_request_id() -> String {
 900    use rand::Rng;
 901    // In the Claude Code TS SDK they just generate a random 12 character string,
 902    // `Math.random().toString(36).substring(2, 15)`
 903    rand::thread_rng()
 904        .sample_iter(&rand::distributions::Alphanumeric)
 905        .take(12)
 906        .map(char::from)
 907        .collect()
 908}
 909
/// An MCP server entry as listed in the `mcp_servers` field of `SdkMessage::System`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    /// Server name.
    name: String,
    /// Free-form status string; the set of possible values is not modeled here.
    status: String,
}
 915
/// Permission mode carried in the `permissionMode` field of `SdkMessage::System`.
///
/// Serialized in camelCase: `"default"`, `"acceptEdits"`,
/// `"bypassPermissions"`, `"plan"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum PermissionMode {
    Default,
    AcceptEdits,
    BypassPermissions,
    Plan,
}
 924
 925#[cfg(test)]
 926pub(crate) mod tests {
 927    use super::*;
 928    use crate::e2e_tests;
 929    use gpui::TestAppContext;
 930    use serde_json::json;
 931
 932    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
 933
 934    pub fn local_command() -> AgentServerCommand {
 935        AgentServerCommand {
 936            path: "claude".into(),
 937            args: vec![],
 938            env: None,
 939        }
 940    }
 941
 942    #[gpui::test]
 943    #[cfg_attr(not(feature = "e2e"), ignore)]
 944    async fn test_todo_plan(cx: &mut TestAppContext) {
 945        let fs = e2e_tests::init_test(cx).await;
 946        let project = Project::test(fs, [], cx).await;
 947        let thread =
 948            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
 949
 950        thread
 951            .update(cx, |thread, cx| {
 952                thread.send_raw(
 953                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
 954                    cx,
 955                )
 956            })
 957            .await
 958            .unwrap();
 959
 960        let mut entries_len = 0;
 961
 962        thread.read_with(cx, |thread, _| {
 963            entries_len = thread.plan().entries.len();
 964            assert!(thread.plan().entries.len() > 0, "Empty plan");
 965        });
 966
 967        thread
 968            .update(cx, |thread, cx| {
 969                thread.send_raw(
 970                    "Mark the first entry status as in progress without acting on it.",
 971                    cx,
 972                )
 973            })
 974            .await
 975            .unwrap();
 976
 977        thread.read_with(cx, |thread, _| {
 978            assert!(matches!(
 979                thread.plan().entries[0].status,
 980                acp::PlanEntryStatus::InProgress
 981            ));
 982            assert_eq!(thread.plan().entries.len(), entries_len);
 983        });
 984
 985        thread
 986            .update(cx, |thread, cx| {
 987                thread.send_raw(
 988                    "Now mark the first entry as completed without acting on it.",
 989                    cx,
 990                )
 991            })
 992            .await
 993            .unwrap();
 994
 995        thread.read_with(cx, |thread, _| {
 996            assert!(matches!(
 997                thread.plan().entries[0].status,
 998                acp::PlanEntryStatus::Completed
 999            ));
1000            assert_eq!(thread.plan().entries.len(), entries_len);
1001        });
1002    }
1003
1004    #[test]
1005    fn test_deserialize_content_untagged_text() {
1006        let json = json!("Hello, world!");
1007        let content: Content = serde_json::from_value(json).unwrap();
1008        match content {
1009            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
1010            _ => panic!("Expected UntaggedText variant"),
1011        }
1012    }
1013
1014    #[test]
1015    fn test_deserialize_content_chunks() {
1016        let json = json!([
1017            {
1018                "type": "text",
1019                "text": "Hello"
1020            },
1021            {
1022                "type": "tool_use",
1023                "id": "tool_123",
1024                "name": "calculator",
1025                "input": {"operation": "add", "a": 1, "b": 2}
1026            }
1027        ]);
1028        let content: Content = serde_json::from_value(json).unwrap();
1029        match content {
1030            Content::Chunks(chunks) => {
1031                assert_eq!(chunks.len(), 2);
1032                match &chunks[0] {
1033                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
1034                    _ => panic!("Expected Text chunk"),
1035                }
1036                match &chunks[1] {
1037                    ContentChunk::ToolUse { id, name, input } => {
1038                        assert_eq!(id, "tool_123");
1039                        assert_eq!(name, "calculator");
1040                        assert_eq!(input["operation"], "add");
1041                        assert_eq!(input["a"], 1);
1042                        assert_eq!(input["b"], 2);
1043                    }
1044                    _ => panic!("Expected ToolUse chunk"),
1045                }
1046            }
1047            _ => panic!("Expected Chunks variant"),
1048        }
1049    }
1050
1051    #[test]
1052    fn test_deserialize_tool_result_untagged_text() {
1053        let json = json!({
1054            "type": "tool_result",
1055            "content": "Result content",
1056            "tool_use_id": "tool_456"
1057        });
1058        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1059        match chunk {
1060            ContentChunk::ToolResult {
1061                content,
1062                tool_use_id,
1063            } => {
1064                match content {
1065                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
1066                    _ => panic!("Expected UntaggedText content"),
1067                }
1068                assert_eq!(tool_use_id, "tool_456");
1069            }
1070            _ => panic!("Expected ToolResult variant"),
1071        }
1072    }
1073
1074    #[test]
1075    fn test_deserialize_tool_result_chunks() {
1076        let json = json!({
1077            "type": "tool_result",
1078            "content": [
1079                {
1080                    "type": "text",
1081                    "text": "Processing complete"
1082                },
1083                {
1084                    "type": "text",
1085                    "text": "Result: 42"
1086                }
1087            ],
1088            "tool_use_id": "tool_789"
1089        });
1090        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1091        match chunk {
1092            ContentChunk::ToolResult {
1093                content,
1094                tool_use_id,
1095            } => {
1096                match content {
1097                    Content::Chunks(chunks) => {
1098                        assert_eq!(chunks.len(), 2);
1099                        match &chunks[0] {
1100                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
1101                            _ => panic!("Expected Text chunk"),
1102                        }
1103                        match &chunks[1] {
1104                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
1105                            _ => panic!("Expected Text chunk"),
1106                        }
1107                    }
1108                    _ => panic!("Expected Chunks content"),
1109                }
1110                assert_eq!(tool_use_id, "tool_789");
1111            }
1112            _ => panic!("Expected ToolResult variant"),
1113        }
1114    }
1115}