claude.rs

   1mod mcp_server;
   2pub mod tools;
   3
   4use action_log::ActionLog;
   5use collections::HashMap;
   6use context_server::listener::McpServerTool;
   7use language_models::provider::anthropic::AnthropicLanguageModelProvider;
   8use project::Project;
   9use settings::SettingsStore;
  10use smol::process::Child;
  11use std::any::Any;
  12use std::cell::RefCell;
  13use std::fmt::Display;
  14use std::path::Path;
  15use std::rc::Rc;
  16use uuid::Uuid;
  17
  18use agent_client_protocol as acp;
  19use anyhow::{Context as _, Result, anyhow};
  20use futures::channel::oneshot;
  21use futures::{AsyncBufReadExt, AsyncWriteExt};
  22use futures::{
  23    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  24    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  25    io::BufReader,
  26    select_biased,
  27};
  28use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  29use serde::{Deserialize, Serialize};
  30use util::{ResultExt, debug_panic};
  31
  32use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  33use crate::claude::tools::ClaudeTool;
  34use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  35use acp_thread::{AcpThread, AgentConnection, AuthRequired, MentionUri};
  36
  37#[derive(Clone)]
  38pub struct ClaudeCode;
  39
  40impl AgentServer for ClaudeCode {
  41    fn name(&self) -> &'static str {
  42        "Claude Code"
  43    }
  44
  45    fn empty_state_headline(&self) -> &'static str {
  46        self.name()
  47    }
  48
  49    fn empty_state_message(&self) -> &'static str {
  50        "How can I help you today?"
  51    }
  52
  53    fn logo(&self) -> ui::IconName {
  54        ui::IconName::AiClaude
  55    }
  56
  57    fn connect(
  58        &self,
  59        _root_dir: &Path,
  60        _project: &Entity<Project>,
  61        _cx: &mut App,
  62    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  63        let connection = ClaudeAgentConnection {
  64            sessions: Default::default(),
  65        };
  66
  67        Task::ready(Ok(Rc::new(connection) as _))
  68    }
  69
  70    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
  71        self
  72    }
  73}
  74
  75struct ClaudeAgentConnection {
  76    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
  77}
  78
  79impl AgentConnection for ClaudeAgentConnection {
  80    fn new_thread(
  81        self: Rc<Self>,
  82        project: Entity<Project>,
  83        cwd: &Path,
  84        cx: &mut App,
  85    ) -> Task<Result<Entity<AcpThread>>> {
  86        let cwd = cwd.to_owned();
  87        cx.spawn(async move |cx| {
  88            let settings = cx.read_global(|settings: &SettingsStore, _| {
  89                settings.get::<AllAgentServersSettings>(None).claude.clone()
  90            })?;
  91
  92            let Some(command) = AgentServerCommand::resolve(
  93                "claude",
  94                &[],
  95                Some(&util::paths::home_dir().join(".claude/local/claude")),
  96                settings,
  97                &project,
  98                cx,
  99            )
 100            .await
 101            else {
 102                anyhow::bail!("Failed to find claude binary");
 103            };
 104
 105            let api_key =
 106                cx.update(AnthropicLanguageModelProvider::api_key)?
 107                    .await
 108                    .map_err(|err| {
 109                        if err.is::<language_model::AuthenticateError>() {
 110                            anyhow!(AuthRequired::new().with_language_model_provider(
 111                                language_model::ANTHROPIC_PROVIDER_ID
 112                            ))
 113                        } else {
 114                            anyhow!(err)
 115                        }
 116                    })?;
 117
 118            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
 119            let fs = project.read_with(cx, |project, _cx| project.fs().clone())?;
 120            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), fs, cx).await?;
 121
 122            let mut mcp_servers = HashMap::default();
 123            mcp_servers.insert(
 124                mcp_server::SERVER_NAME.to_string(),
 125                permission_mcp_server.server_config()?,
 126            );
 127            let mcp_config = McpConfig { mcp_servers };
 128
 129            let mcp_config_file = tempfile::NamedTempFile::new()?;
 130            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
 131
 132            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
 133            mcp_config_file
 134                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
 135                .await?;
 136            mcp_config_file.flush().await?;
 137
 138            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
 139            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 140
 141            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
 142
 143            log::trace!("Starting session with id: {}", session_id);
 144
 145            let mut child = spawn_claude(
 146                &command,
 147                ClaudeSessionMode::Start,
 148                session_id.clone(),
 149                api_key,
 150                &mcp_config_path,
 151                &cwd,
 152            )?;
 153
 154            let stdout = child.stdout.take().context("Failed to take stdout")?;
 155            let stdin = child.stdin.take().context("Failed to take stdin")?;
 156            let stderr = child.stderr.take().context("Failed to take stderr")?;
 157
 158            let pid = child.id();
 159            log::trace!("Spawned (pid: {})", pid);
 160
 161            cx.background_spawn(async move {
 162                let mut stderr = BufReader::new(stderr);
 163                let mut line = String::new();
 164                while let Ok(n) = stderr.read_line(&mut line).await
 165                    && n > 0
 166                {
 167                    log::warn!("agent stderr: {}", &line);
 168                    line.clear();
 169                }
 170            })
 171            .detach();
 172
 173            cx.background_spawn(async move {
 174                let mut outgoing_rx = Some(outgoing_rx);
 175
 176                ClaudeAgentSession::handle_io(
 177                    outgoing_rx.take().unwrap(),
 178                    incoming_message_tx.clone(),
 179                    stdin,
 180                    stdout,
 181                )
 182                .await?;
 183
 184                log::trace!("Stopped (pid: {})", pid);
 185
 186                drop(mcp_config_path);
 187                anyhow::Ok(())
 188            })
 189            .detach();
 190
 191            let turn_state = Rc::new(RefCell::new(TurnState::None));
 192
 193            let handler_task = cx.spawn({
 194                let turn_state = turn_state.clone();
 195                let mut thread_rx = thread_rx.clone();
 196                async move |cx| {
 197                    while let Some(message) = incoming_message_rx.next().await {
 198                        ClaudeAgentSession::handle_message(
 199                            thread_rx.clone(),
 200                            message,
 201                            turn_state.clone(),
 202                            cx,
 203                        )
 204                        .await
 205                    }
 206
 207                    if let Some(status) = child.status().await.log_err()
 208                        && let Some(thread) = thread_rx.recv().await.ok()
 209                    {
 210                        thread
 211                            .update(cx, |thread, cx| {
 212                                thread.emit_server_exited(status, cx);
 213                            })
 214                            .ok();
 215                    }
 216                }
 217            });
 218
 219            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
 220            let thread = cx.new(|_cx| {
 221                AcpThread::new(
 222                    "Claude Code",
 223                    self.clone(),
 224                    project,
 225                    action_log,
 226                    session_id.clone(),
 227                )
 228            })?;
 229
 230            thread_tx.send(thread.downgrade())?;
 231
 232            let session = ClaudeAgentSession {
 233                outgoing_tx,
 234                turn_state,
 235                _handler_task: handler_task,
 236                _mcp_server: Some(permission_mcp_server),
 237            };
 238
 239            self.sessions.borrow_mut().insert(session_id, session);
 240
 241            Ok(thread)
 242        })
 243    }
 244
 245    fn auth_methods(&self) -> &[acp::AuthMethod] {
 246        &[]
 247    }
 248
 249    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 250        Task::ready(Err(anyhow!("Authentication not supported")))
 251    }
 252
 253    fn prompt(
 254        &self,
 255        _id: Option<acp_thread::UserMessageId>,
 256        params: acp::PromptRequest,
 257        cx: &mut App,
 258    ) -> Task<Result<acp::PromptResponse>> {
 259        let sessions = self.sessions.borrow();
 260        let Some(session) = sessions.get(&params.session_id) else {
 261            return Task::ready(Err(anyhow!(
 262                "Attempted to send message to nonexistent session {}",
 263                params.session_id
 264            )));
 265        };
 266
 267        let (end_tx, end_rx) = oneshot::channel();
 268        session.turn_state.replace(TurnState::InProgress { end_tx });
 269
 270        let content = acp_content_to_claude(params.prompt);
 271
 272        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 273            message: Message {
 274                role: Role::User,
 275                content: Content::Chunks(content),
 276                id: None,
 277                model: None,
 278                stop_reason: None,
 279                stop_sequence: None,
 280                usage: None,
 281            },
 282            session_id: Some(params.session_id.to_string()),
 283        }) {
 284            return Task::ready(Err(anyhow!(err)));
 285        }
 286
 287        cx.foreground_executor().spawn(async move { end_rx.await? })
 288    }
 289
 290    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 291        let sessions = self.sessions.borrow();
 292        let Some(session) = sessions.get(session_id) else {
 293            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 294            return;
 295        };
 296
 297        let request_id = new_request_id();
 298
 299        let turn_state = session.turn_state.take();
 300        let TurnState::InProgress { end_tx } = turn_state else {
 301            // Already canceled or idle, put it back
 302            session.turn_state.replace(turn_state);
 303            return;
 304        };
 305
 306        session.turn_state.replace(TurnState::CancelRequested {
 307            end_tx,
 308            request_id: request_id.clone(),
 309        });
 310
 311        session
 312            .outgoing_tx
 313            .unbounded_send(SdkMessage::ControlRequest {
 314                request_id,
 315                request: ControlRequest::Interrupt,
 316            })
 317            .log_err();
 318    }
 319
 320    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 321        self
 322    }
 323}
 324
 325#[derive(Clone, Copy)]
 326enum ClaudeSessionMode {
 327    Start,
 328    #[expect(dead_code)]
 329    Resume,
 330}
 331
 332fn spawn_claude(
 333    command: &AgentServerCommand,
 334    mode: ClaudeSessionMode,
 335    session_id: acp::SessionId,
 336    api_key: language_models::provider::anthropic::ApiKey,
 337    mcp_config_path: &Path,
 338    root_dir: &Path,
 339) -> Result<Child> {
 340    let child = util::command::new_smol_command(&command.path)
 341        .args([
 342            "--input-format",
 343            "stream-json",
 344            "--output-format",
 345            "stream-json",
 346            "--print",
 347            "--verbose",
 348            "--mcp-config",
 349            mcp_config_path.to_string_lossy().as_ref(),
 350            "--permission-prompt-tool",
 351            &format!(
 352                "mcp__{}__{}",
 353                mcp_server::SERVER_NAME,
 354                mcp_server::PermissionTool::NAME,
 355            ),
 356            "--allowedTools",
 357            &format!(
 358                "mcp__{}__{},mcp__{}__{}",
 359                mcp_server::SERVER_NAME,
 360                mcp_server::EditTool::NAME,
 361                mcp_server::SERVER_NAME,
 362                mcp_server::ReadTool::NAME
 363            ),
 364            "--disallowedTools",
 365            "Read,Edit",
 366        ])
 367        .args(match mode {
 368            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 369            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 370        })
 371        .args(command.args.iter().map(|arg| arg.as_str()))
 372        .envs(command.env.iter().flatten())
 373        .env("ANTHROPIC_API_KEY", api_key.key)
 374        .current_dir(root_dir)
 375        .stdin(std::process::Stdio::piped())
 376        .stdout(std::process::Stdio::piped())
 377        .stderr(std::process::Stdio::piped())
 378        .kill_on_drop(true)
 379        .spawn()?;
 380
 381    Ok(child)
 382}
 383
/// Per-thread state for a running `claude` process, stored in
/// `ClaudeAgentConnection::sessions`.
///
/// NOTE: fields are dropped in declaration order, so reordering them changes teardown order.
struct ClaudeAgentSession {
    // Sends SDK messages (user prompts, interrupt control requests) to the IO task, which
    // writes them to the child's stdin.
    outgoing_tx: UnboundedSender<SdkMessage>,
    // State of the in-flight prompt turn; shared with the message-handler task.
    turn_state: Rc<RefCell<TurnState>>,
    // Never read; held so the MCP server stays alive as long as the session.
    _mcp_server: Option<ClaudeZedMcpServer>,
    // Message-handler task. It owns the child process handle (spawned with `kill_on_drop`),
    // so dropping the session should also tear the child down.
    _handler_task: Task<()>,
}
 390
 391#[derive(Debug, Default)]
 392enum TurnState {
 393    #[default]
 394    None,
 395    InProgress {
 396        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 397    },
 398    CancelRequested {
 399        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 400        request_id: String,
 401    },
 402    CancelConfirmed {
 403        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
 404    },
 405}
 406
 407impl TurnState {
 408    fn is_canceled(&self) -> bool {
 409        matches!(self, TurnState::CancelConfirmed { .. })
 410    }
 411
 412    fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
 413        match self {
 414            TurnState::None => None,
 415            TurnState::InProgress { end_tx, .. } => Some(end_tx),
 416            TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
 417            TurnState::CancelConfirmed { end_tx } => Some(end_tx),
 418        }
 419    }
 420
 421    fn confirm_cancellation(self, id: &str) -> Self {
 422        match self {
 423            TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
 424                TurnState::CancelConfirmed { end_tx }
 425            }
 426            _ => self,
 427        }
 428    }
 429}
 430
 431impl ClaudeAgentSession {
 432    async fn handle_message(
 433        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 434        message: SdkMessage,
 435        turn_state: Rc<RefCell<TurnState>>,
 436        cx: &mut AsyncApp,
 437    ) {
 438        match message {
 439            // we should only be sending these out, they don't need to be in the thread
 440            SdkMessage::ControlRequest { .. } => {}
 441            SdkMessage::User {
 442                message,
 443                session_id: _,
 444            } => {
 445                let Some(thread) = thread_rx
 446                    .recv()
 447                    .await
 448                    .log_err()
 449                    .and_then(|entity| entity.upgrade())
 450                else {
 451                    log::error!("Received an SDK message but thread is gone");
 452                    return;
 453                };
 454
 455                for chunk in message.content.chunks() {
 456                    match chunk {
 457                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 458                            if !turn_state.borrow().is_canceled() {
 459                                thread
 460                                    .update(cx, |thread, cx| {
 461                                        thread.push_user_content_block(None, text.into(), cx)
 462                                    })
 463                                    .log_err();
 464                            }
 465                        }
 466                        ContentChunk::ToolResult {
 467                            content,
 468                            tool_use_id,
 469                        } => {
 470                            let content = content.to_string();
 471                            thread
 472                                .update(cx, |thread, cx| {
 473                                    thread.update_tool_call(
 474                                        acp::ToolCallUpdate {
 475                                            id: acp::ToolCallId(tool_use_id.into()),
 476                                            fields: acp::ToolCallUpdateFields {
 477                                                status: if turn_state.borrow().is_canceled() {
 478                                                    // Do not set to completed if turn was canceled
 479                                                    None
 480                                                } else {
 481                                                    Some(acp::ToolCallStatus::Completed)
 482                                                },
 483                                                content: (!content.is_empty())
 484                                                    .then(|| vec![content.into()]),
 485                                                ..Default::default()
 486                                            },
 487                                        },
 488                                        cx,
 489                                    )
 490                                })
 491                                .log_err();
 492                        }
 493                        ContentChunk::Thinking { .. }
 494                        | ContentChunk::RedactedThinking
 495                        | ContentChunk::ToolUse { .. } => {
 496                            debug_panic!(
 497                                "Should not get {:?} with role: assistant. should we handle this?",
 498                                chunk
 499                            );
 500                        }
 501                        ContentChunk::Image { source } => {
 502                            if !turn_state.borrow().is_canceled() {
 503                                thread
 504                                    .update(cx, |thread, cx| {
 505                                        thread.push_user_content_block(None, source.into(), cx)
 506                                    })
 507                                    .log_err();
 508                            }
 509                        }
 510
 511                        ContentChunk::Document | ContentChunk::WebSearchToolResult => {
 512                            thread
 513                                .update(cx, |thread, cx| {
 514                                    thread.push_assistant_content_block(
 515                                        format!("Unsupported content: {:?}", chunk).into(),
 516                                        false,
 517                                        cx,
 518                                    )
 519                                })
 520                                .log_err();
 521                        }
 522                    }
 523                }
 524            }
 525            SdkMessage::Assistant {
 526                message,
 527                session_id: _,
 528            } => {
 529                let Some(thread) = thread_rx
 530                    .recv()
 531                    .await
 532                    .log_err()
 533                    .and_then(|entity| entity.upgrade())
 534                else {
 535                    log::error!("Received an SDK message but thread is gone");
 536                    return;
 537                };
 538
 539                for chunk in message.content.chunks() {
 540                    match chunk {
 541                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 542                            thread
 543                                .update(cx, |thread, cx| {
 544                                    thread.push_assistant_content_block(text.into(), false, cx)
 545                                })
 546                                .log_err();
 547                        }
 548                        ContentChunk::Thinking { thinking } => {
 549                            thread
 550                                .update(cx, |thread, cx| {
 551                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 552                                })
 553                                .log_err();
 554                        }
 555                        ContentChunk::RedactedThinking => {
 556                            thread
 557                                .update(cx, |thread, cx| {
 558                                    thread.push_assistant_content_block(
 559                                        "[REDACTED]".into(),
 560                                        true,
 561                                        cx,
 562                                    )
 563                                })
 564                                .log_err();
 565                        }
 566                        ContentChunk::ToolUse { id, name, input } => {
 567                            let claude_tool = ClaudeTool::infer(&name, input);
 568
 569                            thread
 570                                .update(cx, |thread, cx| {
 571                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 572                                        thread.update_plan(
 573                                            acp::Plan {
 574                                                entries: params
 575                                                    .todos
 576                                                    .into_iter()
 577                                                    .map(Into::into)
 578                                                    .collect(),
 579                                            },
 580                                            cx,
 581                                        )
 582                                    } else {
 583                                        thread.upsert_tool_call(
 584                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 585                                            cx,
 586                                        )?;
 587                                    }
 588                                    anyhow::Ok(())
 589                                })
 590                                .log_err();
 591                        }
 592                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 593                            debug_panic!(
 594                                "Should not get tool results with role: assistant. should we handle this?"
 595                            );
 596                        }
 597                        ContentChunk::Image { source } => {
 598                            thread
 599                                .update(cx, |thread, cx| {
 600                                    thread.push_assistant_content_block(source.into(), false, cx)
 601                                })
 602                                .log_err();
 603                        }
 604                        ContentChunk::Document => {
 605                            thread
 606                                .update(cx, |thread, cx| {
 607                                    thread.push_assistant_content_block(
 608                                        format!("Unsupported content: {:?}", chunk).into(),
 609                                        false,
 610                                        cx,
 611                                    )
 612                                })
 613                                .log_err();
 614                        }
 615                    }
 616                }
 617            }
 618            SdkMessage::Result {
 619                is_error,
 620                subtype,
 621                result,
 622                ..
 623            } => {
 624                let turn_state = turn_state.take();
 625                let was_canceled = turn_state.is_canceled();
 626                let Some(end_turn_tx) = turn_state.end_tx() else {
 627                    debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
 628                    return;
 629                };
 630
 631                if is_error || (!was_canceled && subtype == ResultErrorType::ErrorDuringExecution) {
 632                    end_turn_tx
 633                        .send(Err(anyhow!(
 634                            "Error: {}",
 635                            result.unwrap_or_else(|| subtype.to_string())
 636                        )))
 637                        .ok();
 638                } else {
 639                    let stop_reason = match subtype {
 640                        ResultErrorType::Success => acp::StopReason::EndTurn,
 641                        ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 642                        ResultErrorType::ErrorDuringExecution => acp::StopReason::Canceled,
 643                    };
 644                    end_turn_tx
 645                        .send(Ok(acp::PromptResponse { stop_reason }))
 646                        .ok();
 647                }
 648            }
 649            SdkMessage::ControlResponse { response } => {
 650                if matches!(response.subtype, ResultErrorType::Success) {
 651                    let new_state = turn_state.take().confirm_cancellation(&response.request_id);
 652                    turn_state.replace(new_state);
 653                } else {
 654                    log::error!("Control response error: {:?}", response);
 655                }
 656            }
 657            SdkMessage::System { .. } => {}
 658        }
 659    }
 660
 661    async fn handle_io(
 662        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 663        incoming_tx: UnboundedSender<SdkMessage>,
 664        mut outgoing_bytes: impl Unpin + AsyncWrite,
 665        incoming_bytes: impl Unpin + AsyncRead,
 666    ) -> Result<UnboundedReceiver<SdkMessage>> {
 667        let mut output_reader = BufReader::new(incoming_bytes);
 668        let mut outgoing_line = Vec::new();
 669        let mut incoming_line = String::new();
 670        loop {
 671            select_biased! {
 672                message = outgoing_rx.next() => {
 673                    if let Some(message) = message {
 674                        outgoing_line.clear();
 675                        serde_json::to_writer(&mut outgoing_line, &message)?;
 676                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 677                        outgoing_line.push(b'\n');
 678                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 679                    } else {
 680                        break;
 681                    }
 682                }
 683                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
 684                    if bytes_read? == 0 {
 685                        break
 686                    }
 687                    log::trace!("recv: {}", &incoming_line);
 688                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 689                        Ok(message) => {
 690                            incoming_tx.unbounded_send(message).log_err();
 691                        }
 692                        Err(error) => {
 693                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 694                        }
 695                    }
 696                    incoming_line.clear();
 697                }
 698            }
 699        }
 700
 701        Ok(outgoing_rx)
 702    }
 703}
 704
/// A message in the stream-json protocol.
///
/// Used both for incoming user/assistant messages and for outgoing user prompts. The shape
/// appears to mirror the Anthropic SDK message (see the comments on `SdkMessage`). Optional
/// metadata fields are left out of the serialized JSON when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    role: Role,
    content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequence: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Usage>,
}
 720
/// Message content: either a bare string or a list of typed chunks.
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order. A JSON string
/// deserializes as `UntaggedText` and a JSON array as `Chunks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Content {
    UntaggedText(String),
    Chunks(Vec<ContentChunk>),
}
 727
 728impl Content {
 729    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 730        match self {
 731            Self::Chunks(chunks) => chunks.into_iter(),
 732            Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
 733        }
 734    }
 735}
 736
 737impl Display for Content {
 738    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 739        match self {
 740            Content::UntaggedText(txt) => write!(f, "{}", txt),
 741            Content::Chunks(chunks) => {
 742                for chunk in chunks {
 743                    write!(f, "{}", chunk)?;
 744                }
 745                Ok(())
 746            }
 747        }
 748    }
 749}
 750
/// One element of a chunked [`Content`] list.
///
/// The variant is chosen by the JSON `"type"` field, using snake_case variant names (e.g.
/// `"tool_use"`, `"redacted_thinking"`). A bare JSON string falls through to the
/// `#[serde(untagged)]` `UntaggedText` variant, which must stay last.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ContentChunk {
    Text {
        text: String,
    },
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    ToolResult {
        content: Content,
        tool_use_id: String,
    },
    Thinking {
        thinking: String,
    },
    RedactedThinking,
    Image {
        source: ImageSource,
    },
    // TODO: model the payloads of these variants; for now only the variant is recorded.
    Document,
    WebSearchToolResult,
    #[serde(untagged)]
    UntaggedText(String),
}
 779
/// Where an image's data comes from, chosen by the `"type"` field: inline base64 data
/// (`"base64"`) or a remote URL (`"url"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ImageSource {
    Base64 { data: String, media_type: String },
    Url { url: String },
}
 786
 787impl Into<acp::ContentBlock> for ImageSource {
 788    fn into(self) -> acp::ContentBlock {
 789        match self {
 790            ImageSource::Base64 { data, media_type } => {
 791                acp::ContentBlock::Image(acp::ImageContent {
 792                    annotations: None,
 793                    data,
 794                    mime_type: media_type,
 795                    uri: None,
 796                })
 797            }
 798            ImageSource::Url { url } => acp::ContentBlock::Image(acp::ImageContent {
 799                annotations: None,
 800                data: "".to_string(),
 801                mime_type: "".to_string(),
 802                uri: Some(url),
 803            }),
 804        }
 805    }
 806}
 807
 808impl Display for ContentChunk {
 809    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 810        match self {
 811            ContentChunk::Text { text } => write!(f, "{}", text),
 812            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 813            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 814            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 815            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 816            ContentChunk::Image { .. }
 817            | ContentChunk::Document
 818            | ContentChunk::ToolUse { .. }
 819            | ContentChunk::WebSearchToolResult => {
 820                write!(f, "\n{:?}\n", &self)
 821            }
 822        }
 823    }
 824}
 825
/// Token usage counters.
///
/// NOTE(review): the field names match the Anthropic API `usage` object;
/// every field is required here, so a payload missing any of them will fail
/// to deserialize — confirm the CLI always sends all five.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Usage {
    input_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
    output_tokens: u32,
    service_tier: String,
}
 834
/// Author of a message; serialized as `"system"`, `"assistant"`, or `"user"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}
 842
/// A message consisting of a `role` and a single plain-string `content`
/// body (no structured content chunks).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    role: Role,
    content: String,
}
 848
/// A single JSON message exchanged with the Claude Code CLI, tagged by its
/// snake_case `type` field: `assistant`, `user`, `result`, `system`,
/// `control_request`, or `control_response`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    /// An assistant message.
    Assistant {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// A user message.
    User {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// Emitted as the last message in a conversation. Summarizes the run:
    /// outcome `subtype`, durations, turn count, optional final `result`
    /// text, and total cost in USD.
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    /// Emitted as the first message at the start of a conversation. Describes
    /// the session: working directory, tool names, model, MCP servers, API
    /// key source, and permission mode. The last two use camelCase keys on
    /// the wire.
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to
    /// the model. `request_id` identifies the request so its response can be
    /// matched to it.
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request
    ControlResponse { response: ControlResponse },
}
 896
/// Payload of a control request, tagged by a `subtype` field
/// (e.g. `{"subtype": "interrupt"}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation
    Interrupt,
}
 903
/// Reply to a control request. `request_id` presumably matches the id of the
/// `ControlRequest` being answered (verify against the sender), and `subtype`
/// carries the outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
    request_id: String,
    subtype: ResultErrorType,
}
 909
/// Outcome subtype used by `Result` messages and control responses.
/// Serialized as `success`, `error_max_turns`, or `error_during_execution`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
 917
 918impl Display for ResultErrorType {
 919    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 920        match self {
 921            ResultErrorType::Success => write!(f, "success"),
 922            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 923            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 924        }
 925    }
 926}
 927
 928fn acp_content_to_claude(prompt: Vec<acp::ContentBlock>) -> Vec<ContentChunk> {
 929    let mut content = Vec::with_capacity(prompt.len());
 930    let mut context = Vec::with_capacity(prompt.len());
 931
 932    for chunk in prompt {
 933        match chunk {
 934            acp::ContentBlock::Text(text_content) => {
 935                content.push(ContentChunk::Text {
 936                    text: text_content.text,
 937                });
 938            }
 939            acp::ContentBlock::ResourceLink(resource_link) => {
 940                match MentionUri::parse(&resource_link.uri) {
 941                    Ok(uri) => {
 942                        content.push(ContentChunk::Text {
 943                            text: format!("{}", uri.as_link()),
 944                        });
 945                    }
 946                    Err(_) => {
 947                        content.push(ContentChunk::Text {
 948                            text: resource_link.uri,
 949                        });
 950                    }
 951                }
 952            }
 953            acp::ContentBlock::Resource(resource) => match resource.resource {
 954                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
 955                    match MentionUri::parse(&resource.uri) {
 956                        Ok(uri) => {
 957                            content.push(ContentChunk::Text {
 958                                text: format!("{}", uri.as_link()),
 959                            });
 960                        }
 961                        Err(_) => {
 962                            content.push(ContentChunk::Text {
 963                                text: resource.uri.clone(),
 964                            });
 965                        }
 966                    }
 967
 968                    context.push(ContentChunk::Text {
 969                        text: format!(
 970                            "\n<context ref=\"{}\">\n{}\n</context>",
 971                            resource.uri, resource.text
 972                        ),
 973                    });
 974                }
 975                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
 976                    // Unsupported by SDK
 977                }
 978            },
 979            acp::ContentBlock::Image(acp::ImageContent {
 980                data, mime_type, ..
 981            }) => content.push(ContentChunk::Image {
 982                source: ImageSource::Base64 {
 983                    data,
 984                    media_type: mime_type,
 985                },
 986            }),
 987            acp::ContentBlock::Audio(_) => {
 988                // Unsupported by SDK
 989            }
 990        }
 991    }
 992
 993    content.extend(context);
 994    content
 995}
 996
 997fn new_request_id() -> String {
 998    use rand::Rng;
 999    // In the Claude Code TS SDK they just generate a random 12 character string,
1000    // `Math.random().toString(36).substring(2, 15)`
1001    rand::thread_rng()
1002        .sample_iter(&rand::distributions::Alphanumeric)
1003        .take(12)
1004        .map(char::from)
1005        .collect()
1006}
1007
/// An MCP server entry listed in the CLI's `system` message: its `name` and a
/// free-form `status` string.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    name: String,
    status: String,
}
1013
/// Permission mode reported in the CLI's `system` message. Serialized in
/// camelCase: `default`, `acceptEdits`, `bypassPermissions`, `plan`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum PermissionMode {
    Default,
    AcceptEdits,
    BypassPermissions,
    Plan,
}
1022
/// Tests for the Claude Code agent server: the shared end-to-end suite, a
/// Claude-specific plan e2e test, and unit tests for the serde models and
/// the ACP-to-Claude prompt conversion.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::e2e_tests;
    use gpui::TestAppContext;
    use serde_json::json;

    // Instantiate the shared end-to-end suite for this agent, passing "allow"
    // as `allow_option_id` (NOTE(review): presumably the permission option the
    // suite selects to approve tool calls — see the macro definition).
    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");

    /// Command that runs a bare `claude` executable (presumably resolved via
    /// `PATH`) with no extra arguments and no environment overrides.
    pub fn local_command() -> AgentServerCommand {
        AgentServerCommand {
            path: "claude".into(),
            args: vec![],
            env: None,
        }
    }

    // End-to-end (ignored unless the `e2e` feature is enabled): asks the agent
    // for a todo plan, then has it mark the first entry in progress and then
    // completed, checking the entry count never changes between turns.
    #[gpui::test]
    #[cfg_attr(not(feature = "e2e"), ignore)]
    async fn test_todo_plan(cx: &mut TestAppContext) {
        let fs = e2e_tests::init_test(cx).await;
        let project = Project::test(fs, [], cx).await;
        let thread =
            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        // Size of the initial plan; later turns must not add or drop entries.
        let mut entries_len = 0;

        thread.read_with(cx, |thread, _| {
            entries_len = thread.plan().entries.len();
            assert!(thread.plan().entries.len() > 0, "Empty plan");
        });

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Mark the first entry status as in progress without acting on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.plan().entries[0].status,
                acp::PlanEntryStatus::InProgress
            ));
            assert_eq!(thread.plan().entries.len(), entries_len);
        });

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Now mark the first entry as completed without acting on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.plan().entries[0].status,
                acp::PlanEntryStatus::Completed
            ));
            assert_eq!(thread.plan().entries.len(), entries_len);
        });
    }

    // A bare JSON string deserializes as `Content::UntaggedText`.
    #[test]
    fn test_deserialize_content_untagged_text() {
        let json = json!("Hello, world!");
        let content: Content = serde_json::from_value(json).unwrap();
        match content {
            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected UntaggedText variant"),
        }
    }

    // A JSON array of tagged objects deserializes as `Content::Chunks`, with
    // each element mapped to its matching `ContentChunk` variant.
    #[test]
    fn test_deserialize_content_chunks() {
        let json = json!([
            {
                "type": "text",
                "text": "Hello"
            },
            {
                "type": "tool_use",
                "id": "tool_123",
                "name": "calculator",
                "input": {"operation": "add", "a": 1, "b": 2}
            }
        ]);
        let content: Content = serde_json::from_value(json).unwrap();
        match content {
            Content::Chunks(chunks) => {
                assert_eq!(chunks.len(), 2);
                match &chunks[0] {
                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
                    _ => panic!("Expected Text chunk"),
                }
                match &chunks[1] {
                    ContentChunk::ToolUse { id, name, input } => {
                        assert_eq!(id, "tool_123");
                        assert_eq!(name, "calculator");
                        assert_eq!(input["operation"], "add");
                        assert_eq!(input["a"], 1);
                        assert_eq!(input["b"], 2);
                    }
                    _ => panic!("Expected ToolUse chunk"),
                }
            }
            _ => panic!("Expected Chunks variant"),
        }
    }

    // A `tool_result` whose `content` is a plain string yields untagged text.
    #[test]
    fn test_deserialize_tool_result_untagged_text() {
        let json = json!({
            "type": "tool_result",
            "content": "Result content",
            "tool_use_id": "tool_456"
        });
        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
        match chunk {
            ContentChunk::ToolResult {
                content,
                tool_use_id,
            } => {
                match content {
                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
                    _ => panic!("Expected UntaggedText content"),
                }
                assert_eq!(tool_use_id, "tool_456");
            }
            _ => panic!("Expected ToolResult variant"),
        }
    }

    // A `tool_result` whose `content` is an array yields nested chunks.
    #[test]
    fn test_deserialize_tool_result_chunks() {
        let json = json!({
            "type": "tool_result",
            "content": [
                {
                    "type": "text",
                    "text": "Processing complete"
                },
                {
                    "type": "text",
                    "text": "Result: 42"
                }
            ],
            "tool_use_id": "tool_789"
        });
        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
        match chunk {
            ContentChunk::ToolResult {
                content,
                tool_use_id,
            } => {
                match content {
                    Content::Chunks(chunks) => {
                        assert_eq!(chunks.len(), 2);
                        match &chunks[0] {
                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
                            _ => panic!("Expected Text chunk"),
                        }
                        match &chunks[1] {
                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
                            _ => panic!("Expected Text chunk"),
                        }
                    }
                    _ => panic!("Expected Chunks content"),
                }
                assert_eq!(tool_use_id, "tool_789");
            }
            _ => panic!("Expected ToolResult variant"),
        }
    }

    // Covers every supported ACP block kind plus an unparseable resource link.
    #[test]
    fn test_acp_content_to_claude() {
        let acp_content = vec![
            acp::ContentBlock::Text(acp::TextContent {
                text: "Hello world".to_string(),
                annotations: None,
            }),
            acp::ContentBlock::Image(acp::ImageContent {
                data: "base64data".to_string(),
                mime_type: "image/png".to_string(),
                annotations: None,
                uri: None,
            }),
            acp::ContentBlock::ResourceLink(acp::ResourceLink {
                uri: "file:///path/to/example.rs".to_string(),
                name: "example.rs".to_string(),
                annotations: None,
                description: None,
                mime_type: None,
                size: None,
                title: None,
            }),
            acp::ContentBlock::Resource(acp::EmbeddedResource {
                annotations: None,
                resource: acp::EmbeddedResourceResource::TextResourceContents(
                    acp::TextResourceContents {
                        mime_type: None,
                        text: "fn main() { println!(\"Hello!\"); }".to_string(),
                        uri: "file:///path/to/code.rs".to_string(),
                    },
                ),
            }),
            acp::ContentBlock::ResourceLink(acp::ResourceLink {
                uri: "invalid_uri_format".to_string(),
                name: "invalid.txt".to_string(),
                annotations: None,
                description: None,
                mime_type: None,
                size: None,
                title: None,
            }),
        ];

        let claude_content = acp_content_to_claude(acp_content);

        // One chunk per input block (5), plus the embedded resource's
        // `<context>` chunk, which is appended after all of them.
        assert_eq!(claude_content.len(), 6);

        match &claude_content[0] {
            ContentChunk::Text { text } => assert_eq!(text, "Hello world"),
            _ => panic!("Expected Text chunk"),
        }

        match &claude_content[1] {
            ContentChunk::Image { source } => match source {
                ImageSource::Base64 { data, media_type } => {
                    assert_eq!(data, "base64data");
                    assert_eq!(media_type, "image/png");
                }
                _ => panic!("Expected Base64 image source"),
            },
            _ => panic!("Expected Image chunk"),
        }

        match &claude_content[2] {
            ContentChunk::Text { text } => {
                assert!(text.contains("example.rs"));
                assert!(text.contains("file:///path/to/example.rs"));
            }
            _ => panic!("Expected Text chunk for ResourceLink"),
        }

        match &claude_content[3] {
            ContentChunk::Text { text } => {
                assert!(text.contains("code.rs"));
                assert!(text.contains("file:///path/to/code.rs"));
            }
            _ => panic!("Expected Text chunk for Resource"),
        }

        // An unparseable URI falls back to the raw URI text.
        match &claude_content[4] {
            ContentChunk::Text { text } => {
                assert_eq!(text, "invalid_uri_format");
            }
            _ => panic!("Expected Text chunk for invalid URI"),
        }

        match &claude_content[5] {
            ContentChunk::Text { text } => {
                assert!(text.contains("<context ref=\"file:///path/to/code.rs\">"));
                assert!(text.contains("fn main() { println!(\"Hello!\"); }"));
                assert!(text.contains("</context>"));
            }
            _ => panic!("Expected Text chunk for context"),
        }
    }
}