claude.rs

   1mod mcp_server;
   2pub mod tools;
   3
   4use collections::HashMap;
   5use context_server::listener::McpServerTool;
   6use project::Project;
   7use settings::SettingsStore;
   8use smol::process::Child;
   9use std::cell::RefCell;
  10use std::fmt::Display;
  11use std::path::Path;
  12use std::rc::Rc;
  13use uuid::Uuid;
  14
  15use agent_client_protocol as acp;
  16use anyhow::{Result, anyhow};
  17use futures::channel::oneshot;
  18use futures::{AsyncBufReadExt, AsyncWriteExt};
  19use futures::{
  20    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  21    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  22    io::BufReader,
  23    select_biased,
  24};
  25use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  26use serde::{Deserialize, Serialize};
  27use util::{ResultExt, debug_panic};
  28
  29use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  30use crate::claude::tools::ClaudeTool;
  31use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  32use acp_thread::{AcpThread, AgentConnection};
  33
  34#[derive(Clone)]
  35pub struct ClaudeCode;
  36
  37impl AgentServer for ClaudeCode {
  38    fn name(&self) -> &'static str {
  39        "Claude Code"
  40    }
  41
  42    fn empty_state_headline(&self) -> &'static str {
  43        self.name()
  44    }
  45
  46    fn empty_state_message(&self) -> &'static str {
  47        "How can I help you today?"
  48    }
  49
  50    fn logo(&self) -> ui::IconName {
  51        ui::IconName::AiClaude
  52    }
  53
  54    fn connect(
  55        &self,
  56        _root_dir: &Path,
  57        _project: &Entity<Project>,
  58        _cx: &mut App,
  59    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  60        let connection = ClaudeAgentConnection {
  61            sessions: Default::default(),
  62        };
  63
  64        Task::ready(Ok(Rc::new(connection) as _))
  65    }
  66}
  67
/// [`AgentConnection`] for Claude Code. Each thread created through `new_thread`
/// spawns its own `claude` process and is tracked here as a session.
struct ClaudeAgentConnection {
    /// Live sessions keyed by their id; shared (`Rc<RefCell<..>>`) so the
    /// `new_thread` task can register sessions after it is spawned.
    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
}
  71
  72impl AgentConnection for ClaudeAgentConnection {
  73    fn new_thread(
  74        self: Rc<Self>,
  75        project: Entity<Project>,
  76        cwd: &Path,
  77        cx: &mut AsyncApp,
  78    ) -> Task<Result<Entity<AcpThread>>> {
  79        let cwd = cwd.to_owned();
  80        cx.spawn(async move |cx| {
  81            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
  82            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
  83
  84            let mut mcp_servers = HashMap::default();
  85            mcp_servers.insert(
  86                mcp_server::SERVER_NAME.to_string(),
  87                permission_mcp_server.server_config()?,
  88            );
  89            let mcp_config = McpConfig { mcp_servers };
  90
  91            let mcp_config_file = tempfile::NamedTempFile::new()?;
  92            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
  93
  94            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
  95            mcp_config_file
  96                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
  97                .await?;
  98            mcp_config_file.flush().await?;
  99
 100            let settings = cx.read_global(|settings: &SettingsStore, _| {
 101                settings.get::<AllAgentServersSettings>(None).claude.clone()
 102            })?;
 103
 104            let Some(command) = AgentServerCommand::resolve(
 105                "claude",
 106                &[],
 107                Some(&util::paths::home_dir().join(".claude/local/claude")),
 108                settings,
 109                &project,
 110                cx,
 111            )
 112            .await
 113            else {
 114                anyhow::bail!("Failed to find claude binary");
 115            };
 116
 117            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
 118            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 119
 120            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
 121
 122            log::trace!("Starting session with id: {}", session_id);
 123
 124            let mut child = spawn_claude(
 125                &command,
 126                ClaudeSessionMode::Start,
 127                session_id.clone(),
 128                &mcp_config_path,
 129                &cwd,
 130            )?;
 131
 132            let stdin = child.stdin.take().unwrap();
 133            let stdout = child.stdout.take().unwrap();
 134
 135            let pid = child.id();
 136            log::trace!("Spawned (pid: {})", pid);
 137
 138            cx.background_spawn(async move {
 139                let mut outgoing_rx = Some(outgoing_rx);
 140
 141                ClaudeAgentSession::handle_io(
 142                    outgoing_rx.take().unwrap(),
 143                    incoming_message_tx.clone(),
 144                    stdin,
 145                    stdout,
 146                )
 147                .await?;
 148
 149                log::trace!("Stopped (pid: {})", pid);
 150
 151                drop(mcp_config_path);
 152                anyhow::Ok(())
 153            })
 154            .detach();
 155
 156            let cancellation_state = Rc::new(RefCell::new(CancellationState::None));
 157
 158            let end_turn_tx = Rc::new(RefCell::new(None));
 159            let handler_task = cx.spawn({
 160                let end_turn_tx = end_turn_tx.clone();
 161                let mut thread_rx = thread_rx.clone();
 162                let cancellation_state = cancellation_state.clone();
 163                async move |cx| {
 164                    while let Some(message) = incoming_message_rx.next().await {
 165                        ClaudeAgentSession::handle_message(
 166                            thread_rx.clone(),
 167                            message,
 168                            end_turn_tx.clone(),
 169                            cancellation_state.clone(),
 170                            cx,
 171                        )
 172                        .await
 173                    }
 174
 175                    if let Some(status) = child.status().await.log_err() {
 176                        if let Some(thread) = thread_rx.recv().await.ok() {
 177                            thread
 178                                .update(cx, |thread, cx| {
 179                                    thread.emit_server_exited(status, cx);
 180                                })
 181                                .ok();
 182                        }
 183                    }
 184                }
 185            });
 186
 187            let thread = cx.new(|cx| {
 188                AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
 189            })?;
 190
 191            thread_tx.send(thread.downgrade())?;
 192
 193            let session = ClaudeAgentSession {
 194                outgoing_tx,
 195                end_turn_tx,
 196                cancellation_state,
 197                _handler_task: handler_task,
 198                _mcp_server: Some(permission_mcp_server),
 199            };
 200
 201            self.sessions.borrow_mut().insert(session_id, session);
 202
 203            Ok(thread)
 204        })
 205    }
 206
 207    fn auth_methods(&self) -> &[acp::AuthMethod] {
 208        &[]
 209    }
 210
 211    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 212        Task::ready(Err(anyhow!("Authentication not supported")))
 213    }
 214
 215    fn prompt(
 216        &self,
 217        params: acp::PromptRequest,
 218        cx: &mut App,
 219    ) -> Task<Result<acp::PromptResponse>> {
 220        let sessions = self.sessions.borrow();
 221        let Some(session) = sessions.get(&params.session_id) else {
 222            return Task::ready(Err(anyhow!(
 223                "Attempted to send message to nonexistent session {}",
 224                params.session_id
 225            )));
 226        };
 227
 228        let (tx, rx) = oneshot::channel();
 229        session.end_turn_tx.borrow_mut().replace(tx);
 230
 231        let mut content = String::new();
 232        for chunk in params.prompt {
 233            match chunk {
 234                acp::ContentBlock::Text(text_content) => {
 235                    content.push_str(&text_content.text);
 236                }
 237                acp::ContentBlock::ResourceLink(resource_link) => {
 238                    content.push_str(&format!("@{}", resource_link.uri));
 239                }
 240                acp::ContentBlock::Audio(_)
 241                | acp::ContentBlock::Image(_)
 242                | acp::ContentBlock::Resource(_) => {
 243                    // TODO
 244                }
 245            }
 246        }
 247
 248        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 249            message: Message {
 250                role: Role::User,
 251                content: Content::UntaggedText(content),
 252                id: None,
 253                model: None,
 254                stop_reason: None,
 255                stop_sequence: None,
 256                usage: None,
 257            },
 258            session_id: Some(params.session_id.to_string()),
 259        }) {
 260            return Task::ready(Err(anyhow!(err)));
 261        }
 262
 263        let cancellation_state = session.cancellation_state.clone();
 264        cx.foreground_executor().spawn(async move {
 265            let result = rx.await??;
 266            cancellation_state.replace(CancellationState::None);
 267            Ok(result)
 268        })
 269    }
 270
 271    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 272        let sessions = self.sessions.borrow();
 273        let Some(session) = sessions.get(&session_id) else {
 274            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 275            return;
 276        };
 277
 278        let request_id = new_request_id();
 279
 280        session
 281            .cancellation_state
 282            .replace(CancellationState::Requested {
 283                request_id: request_id.clone(),
 284            });
 285
 286        session
 287            .outgoing_tx
 288            .unbounded_send(SdkMessage::ControlRequest {
 289                request_id,
 290                request: ControlRequest::Interrupt,
 291            })
 292            .log_err();
 293    }
 294}
 295
/// How `spawn_claude` attaches the `claude` process to a session id.
#[derive(Clone, Copy)]
enum ClaudeSessionMode {
    /// Start a new session: passes `--session-id <id>`.
    Start,
    /// Resume an existing session: passes `--resume <id>`. Not used yet.
    #[expect(dead_code)]
    Resume,
}
 302
 303fn spawn_claude(
 304    command: &AgentServerCommand,
 305    mode: ClaudeSessionMode,
 306    session_id: acp::SessionId,
 307    mcp_config_path: &Path,
 308    root_dir: &Path,
 309) -> Result<Child> {
 310    let child = util::command::new_smol_command(&command.path)
 311        .args([
 312            "--input-format",
 313            "stream-json",
 314            "--output-format",
 315            "stream-json",
 316            "--print",
 317            "--verbose",
 318            "--mcp-config",
 319            mcp_config_path.to_string_lossy().as_ref(),
 320            "--permission-prompt-tool",
 321            &format!(
 322                "mcp__{}__{}",
 323                mcp_server::SERVER_NAME,
 324                mcp_server::PermissionTool::NAME,
 325            ),
 326            "--allowedTools",
 327            &format!(
 328                "mcp__{}__{},mcp__{}__{}",
 329                mcp_server::SERVER_NAME,
 330                mcp_server::EditTool::NAME,
 331                mcp_server::SERVER_NAME,
 332                mcp_server::ReadTool::NAME
 333            ),
 334            "--disallowedTools",
 335            "Read,Edit",
 336        ])
 337        .args(match mode {
 338            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 339            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 340        })
 341        .args(command.args.iter().map(|arg| arg.as_str()))
 342        .current_dir(root_dir)
 343        .stdin(std::process::Stdio::piped())
 344        .stdout(std::process::Stdio::piped())
 345        .stderr(std::process::Stdio::inherit())
 346        .kill_on_drop(true)
 347        .spawn()?;
 348
 349    Ok(child)
 350}
 351
/// Per-thread state that `ClaudeAgentConnection` keeps for one `claude` process.
struct ClaudeAgentSession {
    /// Messages to be serialized and written to the process's stdin.
    outgoing_tx: UnboundedSender<SdkMessage>,
    /// Sender for the in-flight `prompt` call; taken and resolved when a `result`
    /// message is handled.
    end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
    /// Interrupt state of the current turn, shared with the message handler task.
    cancellation_state: Rc<RefCell<CancellationState>>,
    /// Permission MCP server; never read, only owned so it lives as long as the session.
    _mcp_server: Option<ClaudeZedMcpServer>,
    /// Task that applies incoming messages to the thread; only owned, never read.
    _handler_task: Task<()>,
}
 359
 360#[derive(Debug, Default)]
 361enum CancellationState {
 362    #[default]
 363    None,
 364    Requested {
 365        request_id: String,
 366    },
 367    Confirmed,
 368}
 369
 370impl CancellationState {
 371    fn is_confirmed(&self) -> bool {
 372        matches!(self, CancellationState::Confirmed)
 373    }
 374}
 375
 376impl ClaudeAgentSession {
 377    async fn handle_message(
 378        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 379        message: SdkMessage,
 380        end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
 381        cancellation_state: Rc<RefCell<CancellationState>>,
 382        cx: &mut AsyncApp,
 383    ) {
 384        match message {
 385            // we should only be sending these out, they don't need to be in the thread
 386            SdkMessage::ControlRequest { .. } => {}
 387            SdkMessage::User {
 388                message,
 389                session_id: _,
 390            } => {
 391                let Some(thread) = thread_rx
 392                    .recv()
 393                    .await
 394                    .log_err()
 395                    .and_then(|entity| entity.upgrade())
 396                else {
 397                    log::error!("Received an SDK message but thread is gone");
 398                    return;
 399                };
 400
 401                for chunk in message.content.chunks() {
 402                    match chunk {
 403                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 404                            if !cancellation_state.borrow().is_confirmed() {
 405                                thread
 406                                    .update(cx, |thread, cx| {
 407                                        thread.push_user_content_block(text.into(), cx)
 408                                    })
 409                                    .log_err();
 410                            }
 411                        }
 412                        ContentChunk::ToolResult {
 413                            content,
 414                            tool_use_id,
 415                        } => {
 416                            let content = content.to_string();
 417                            thread
 418                                .update(cx, |thread, cx| {
 419                                    thread.update_tool_call(
 420                                        acp::ToolCallUpdate {
 421                                            id: acp::ToolCallId(tool_use_id.into()),
 422                                            fields: acp::ToolCallUpdateFields {
 423                                                status: if cancellation_state
 424                                                    .borrow()
 425                                                    .is_confirmed()
 426                                                {
 427                                                    // Do not set to completed if turn was cancelled
 428                                                    None
 429                                                } else {
 430                                                    Some(acp::ToolCallStatus::Completed)
 431                                                },
 432                                                content: (!content.is_empty())
 433                                                    .then(|| vec![content.into()]),
 434                                                ..Default::default()
 435                                            },
 436                                        },
 437                                        cx,
 438                                    )
 439                                })
 440                                .log_err();
 441                        }
 442                        ContentChunk::Thinking { .. }
 443                        | ContentChunk::RedactedThinking
 444                        | ContentChunk::ToolUse { .. } => {
 445                            debug_panic!(
 446                                "Should not get {:?} with role: assistant. should we handle this?",
 447                                chunk
 448                            );
 449                        }
 450
 451                        ContentChunk::Image
 452                        | ContentChunk::Document
 453                        | ContentChunk::WebSearchToolResult => {
 454                            thread
 455                                .update(cx, |thread, cx| {
 456                                    thread.push_assistant_content_block(
 457                                        format!("Unsupported content: {:?}", chunk).into(),
 458                                        false,
 459                                        cx,
 460                                    )
 461                                })
 462                                .log_err();
 463                        }
 464                    }
 465                }
 466            }
 467            SdkMessage::Assistant {
 468                message,
 469                session_id: _,
 470            } => {
 471                let Some(thread) = thread_rx
 472                    .recv()
 473                    .await
 474                    .log_err()
 475                    .and_then(|entity| entity.upgrade())
 476                else {
 477                    log::error!("Received an SDK message but thread is gone");
 478                    return;
 479                };
 480
 481                for chunk in message.content.chunks() {
 482                    match chunk {
 483                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 484                            thread
 485                                .update(cx, |thread, cx| {
 486                                    thread.push_assistant_content_block(text.into(), false, cx)
 487                                })
 488                                .log_err();
 489                        }
 490                        ContentChunk::Thinking { thinking } => {
 491                            thread
 492                                .update(cx, |thread, cx| {
 493                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 494                                })
 495                                .log_err();
 496                        }
 497                        ContentChunk::RedactedThinking => {
 498                            thread
 499                                .update(cx, |thread, cx| {
 500                                    thread.push_assistant_content_block(
 501                                        "[REDACTED]".into(),
 502                                        true,
 503                                        cx,
 504                                    )
 505                                })
 506                                .log_err();
 507                        }
 508                        ContentChunk::ToolUse { id, name, input } => {
 509                            let claude_tool = ClaudeTool::infer(&name, input);
 510
 511                            thread
 512                                .update(cx, |thread, cx| {
 513                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 514                                        thread.update_plan(
 515                                            acp::Plan {
 516                                                entries: params
 517                                                    .todos
 518                                                    .into_iter()
 519                                                    .map(Into::into)
 520                                                    .collect(),
 521                                            },
 522                                            cx,
 523                                        )
 524                                    } else {
 525                                        thread.upsert_tool_call(
 526                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 527                                            cx,
 528                                        );
 529                                    }
 530                                })
 531                                .log_err();
 532                        }
 533                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 534                            debug_panic!(
 535                                "Should not get tool results with role: assistant. should we handle this?"
 536                            );
 537                        }
 538                        ContentChunk::Image | ContentChunk::Document => {
 539                            thread
 540                                .update(cx, |thread, cx| {
 541                                    thread.push_assistant_content_block(
 542                                        format!("Unsupported content: {:?}", chunk).into(),
 543                                        false,
 544                                        cx,
 545                                    )
 546                                })
 547                                .log_err();
 548                        }
 549                    }
 550                }
 551            }
 552            SdkMessage::Result {
 553                is_error,
 554                subtype,
 555                result,
 556                ..
 557            } => {
 558                if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
 559                    if is_error
 560                        || (subtype == ResultErrorType::ErrorDuringExecution
 561                            && !cancellation_state.borrow().is_confirmed())
 562                    {
 563                        end_turn_tx
 564                            .send(Err(anyhow!(
 565                                "Error: {}",
 566                                result.unwrap_or_else(|| subtype.to_string())
 567                            )))
 568                            .ok();
 569                    } else {
 570                        let stop_reason = match subtype {
 571                            ResultErrorType::Success => acp::StopReason::EndTurn,
 572                            ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 573                            ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
 574                        };
 575                        end_turn_tx
 576                            .send(Ok(acp::PromptResponse { stop_reason }))
 577                            .ok();
 578                    }
 579                }
 580            }
 581            SdkMessage::ControlResponse { response } => {
 582                if matches!(response.subtype, ResultErrorType::Success) {
 583                    let mut cancellation_state = cancellation_state.borrow_mut();
 584                    if let CancellationState::Requested { request_id } = &*cancellation_state
 585                        && request_id == &response.request_id
 586                    {
 587                        *cancellation_state = CancellationState::Confirmed;
 588                    }
 589                }
 590            }
 591            SdkMessage::System { .. } => {}
 592        }
 593    }
 594
 595    async fn handle_io(
 596        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 597        incoming_tx: UnboundedSender<SdkMessage>,
 598        mut outgoing_bytes: impl Unpin + AsyncWrite,
 599        incoming_bytes: impl Unpin + AsyncRead,
 600    ) -> Result<UnboundedReceiver<SdkMessage>> {
 601        let mut output_reader = BufReader::new(incoming_bytes);
 602        let mut outgoing_line = Vec::new();
 603        let mut incoming_line = String::new();
 604        loop {
 605            select_biased! {
 606                message = outgoing_rx.next() => {
 607                    if let Some(message) = message {
 608                        outgoing_line.clear();
 609                        serde_json::to_writer(&mut outgoing_line, &message)?;
 610                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 611                        outgoing_line.push(b'\n');
 612                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 613                    } else {
 614                        break;
 615                    }
 616                }
 617                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
 618                    if bytes_read? == 0 {
 619                        break
 620                    }
 621                    log::trace!("recv: {}", &incoming_line);
 622                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 623                        Ok(message) => {
 624                            incoming_tx.unbounded_send(message).log_err();
 625                        }
 626                        Err(error) => {
 627                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 628                        }
 629                    }
 630                    incoming_line.clear();
 631                }
 632            }
 633        }
 634
 635        Ok(outgoing_rx)
 636    }
 637}
 638
/// Chat message carried in the `message` field of `SdkMessage::User` and
/// `SdkMessage::Assistant` (shape taken from the Anthropic SDK, per the notes there).
///
/// Optional fields are omitted from the serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    role: Role,
    content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequence: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Usage>,
}
 654
/// Message content: either a bare string or a list of typed chunks.
///
/// Because of `#[serde(untagged)]`, variants are tried in declaration order when
/// deserializing, so keep `UntaggedText` first.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Content {
    UntaggedText(String),
    Chunks(Vec<ContentChunk>),
}
 661
 662impl Content {
 663    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 664        match self {
 665            Self::Chunks(chunks) => chunks.into_iter(),
 666            Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
 667        }
 668    }
 669}
 670
 671impl Display for Content {
 672    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 673        match self {
 674            Content::UntaggedText(txt) => write!(f, "{}", txt),
 675            Content::Chunks(chunks) => {
 676                for chunk in chunks {
 677                    write!(f, "{}", chunk)?;
 678                }
 679                Ok(())
 680            }
 681        }
 682    }
 683}
 684
/// One block of message content, discriminated by its `type` field, with variant
/// names in snake_case (e.g. `"text"`, `"tool_use"`, `"tool_result"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ContentChunk {
    /// Plain text.
    Text {
        text: String,
    },
    /// A tool invocation; `input` is kept as raw JSON.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Result of the tool call identified by `tool_use_id`.
    ToolResult {
        content: Content,
        tool_use_id: String,
    },
    /// Thinking text.
    Thinking {
        thinking: String,
    },
    /// Redacted thinking; carries no text.
    RedactedThinking,
    // TODO: the payloads of these variants are not modeled yet.
    Image,
    Document,
    WebSearchToolResult,
    /// Fallback so a bare JSON string is also accepted as a chunk.
    #[serde(untagged)]
    UntaggedText(String),
}
 711
 712impl Display for ContentChunk {
 713    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 714        match self {
 715            ContentChunk::Text { text } => write!(f, "{}", text),
 716            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 717            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 718            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 719            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 720            ContentChunk::Image
 721            | ContentChunk::Document
 722            | ContentChunk::ToolUse { .. }
 723            | ContentChunk::WebSearchToolResult => {
 724                write!(f, "\n{:?}\n", &self)
 725            }
 726        }
 727    }
 728}
 729
/// Token usage attached to a [`Message`].
///
/// NOTE(review): every field is required when deserializing, so a payload missing any of
/// them fails to parse the whole message. Confirm the CLI always sends all of them.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Usage {
    input_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
    output_tokens: u32,
    service_tier: String,
}
 738
/// Author of a [`Message`], serialized in snake_case
/// (`"system"`, `"assistant"`, `"user"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}
 746
/// A role plus plain-text content.
///
/// NOTE(review): not referenced by the code visible in this file chunk. Confirm it is
/// still needed.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    role: Role,
    content: String,
}
 752
/// One line of the `claude` stream-JSON protocol, discriminated by its `type` field
/// in snake_case (`"assistant"`, `"user"`, `"result"`, `"system"`, `"control_request"`,
/// `"control_response"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    /// An assistant message.
    Assistant {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// A user message. Also the type Zed sends to submit a prompt.
    User {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// Emitted as the last message in a conversation turn; `subtype` and `is_error`
    /// decide how the pending prompt is resolved.
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    /// Emitted as the first message at the start of a conversation.
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to the model
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request
    ControlResponse { response: ControlResponse },
}
 800
/// Payload of `SdkMessage::ControlRequest`, discriminated by a snake_case `subtype`
/// field (e.g. `Interrupt` serializes as `{"subtype":"interrupt"}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation
    Interrupt,
}
 807
/// Reply to a [`ControlRequest`], matched to it by `request_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
    request_id: String,
    subtype: ResultErrorType,
}
 813
/// Outcome subtype of a `result` message, also used as the `subtype` of a
/// [`ControlResponse`]. Serialized in snake_case (`"success"`, `"error_max_turns"`,
/// `"error_during_execution"`).
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
 821
 822impl Display for ResultErrorType {
 823    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 824        match self {
 825            ResultErrorType::Success => write!(f, "success"),
 826            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 827            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 828        }
 829    }
 830}
 831
 832fn new_request_id() -> String {
 833    use rand::Rng;
 834    // In the Claude Code TS SDK they just generate a random 12 character string,
 835    // `Math.random().toString(36).substring(2, 15)`
 836    rand::thread_rng()
 837        .sample_iter(&rand::distributions::Alphanumeric)
 838        .take(12)
 839        .map(char::from)
 840        .collect()
 841}
 842
/// One entry of the `mcp_servers` list in `SdkMessage::System`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    name: String,
    /// Server status, kept as a free-form string rather than a typed enum.
    status: String,
}
 848
/// Permission mode reported in `SdkMessage::System` (wire field `permissionMode`).
///
/// Serialized in camelCase: `default`, `acceptEdits`, `bypassPermissions`,
/// `plan`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum PermissionMode {
    Default,
    AcceptEdits,
    BypassPermissions,
    Plan,
}
 857
 858#[cfg(test)]
 859pub(crate) mod tests {
 860    use super::*;
 861    use crate::e2e_tests;
 862    use gpui::TestAppContext;
 863    use serde_json::json;
 864
 865    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
 866
 867    pub fn local_command() -> AgentServerCommand {
 868        AgentServerCommand {
 869            path: "claude".into(),
 870            args: vec![],
 871            env: None,
 872        }
 873    }
 874
 875    #[gpui::test]
 876    #[cfg_attr(not(feature = "e2e"), ignore)]
 877    async fn test_todo_plan(cx: &mut TestAppContext) {
 878        let fs = e2e_tests::init_test(cx).await;
 879        let project = Project::test(fs, [], cx).await;
 880        let thread =
 881            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
 882
 883        thread
 884            .update(cx, |thread, cx| {
 885                thread.send_raw(
 886                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
 887                    cx,
 888                )
 889            })
 890            .await
 891            .unwrap();
 892
 893        let mut entries_len = 0;
 894
 895        thread.read_with(cx, |thread, _| {
 896            entries_len = thread.plan().entries.len();
 897            assert!(thread.plan().entries.len() > 0, "Empty plan");
 898        });
 899
 900        thread
 901            .update(cx, |thread, cx| {
 902                thread.send_raw(
 903                    "Mark the first entry status as in progress without acting on it.",
 904                    cx,
 905                )
 906            })
 907            .await
 908            .unwrap();
 909
 910        thread.read_with(cx, |thread, _| {
 911            assert!(matches!(
 912                thread.plan().entries[0].status,
 913                acp::PlanEntryStatus::InProgress
 914            ));
 915            assert_eq!(thread.plan().entries.len(), entries_len);
 916        });
 917
 918        thread
 919            .update(cx, |thread, cx| {
 920                thread.send_raw(
 921                    "Now mark the first entry as completed without acting on it.",
 922                    cx,
 923                )
 924            })
 925            .await
 926            .unwrap();
 927
 928        thread.read_with(cx, |thread, _| {
 929            assert!(matches!(
 930                thread.plan().entries[0].status,
 931                acp::PlanEntryStatus::Completed
 932            ));
 933            assert_eq!(thread.plan().entries.len(), entries_len);
 934        });
 935    }
 936
 937    #[test]
 938    fn test_deserialize_content_untagged_text() {
 939        let json = json!("Hello, world!");
 940        let content: Content = serde_json::from_value(json).unwrap();
 941        match content {
 942            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
 943            _ => panic!("Expected UntaggedText variant"),
 944        }
 945    }
 946
 947    #[test]
 948    fn test_deserialize_content_chunks() {
 949        let json = json!([
 950            {
 951                "type": "text",
 952                "text": "Hello"
 953            },
 954            {
 955                "type": "tool_use",
 956                "id": "tool_123",
 957                "name": "calculator",
 958                "input": {"operation": "add", "a": 1, "b": 2}
 959            }
 960        ]);
 961        let content: Content = serde_json::from_value(json).unwrap();
 962        match content {
 963            Content::Chunks(chunks) => {
 964                assert_eq!(chunks.len(), 2);
 965                match &chunks[0] {
 966                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
 967                    _ => panic!("Expected Text chunk"),
 968                }
 969                match &chunks[1] {
 970                    ContentChunk::ToolUse { id, name, input } => {
 971                        assert_eq!(id, "tool_123");
 972                        assert_eq!(name, "calculator");
 973                        assert_eq!(input["operation"], "add");
 974                        assert_eq!(input["a"], 1);
 975                        assert_eq!(input["b"], 2);
 976                    }
 977                    _ => panic!("Expected ToolUse chunk"),
 978                }
 979            }
 980            _ => panic!("Expected Chunks variant"),
 981        }
 982    }
 983
 984    #[test]
 985    fn test_deserialize_tool_result_untagged_text() {
 986        let json = json!({
 987            "type": "tool_result",
 988            "content": "Result content",
 989            "tool_use_id": "tool_456"
 990        });
 991        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
 992        match chunk {
 993            ContentChunk::ToolResult {
 994                content,
 995                tool_use_id,
 996            } => {
 997                match content {
 998                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
 999                    _ => panic!("Expected UntaggedText content"),
1000                }
1001                assert_eq!(tool_use_id, "tool_456");
1002            }
1003            _ => panic!("Expected ToolResult variant"),
1004        }
1005    }
1006
1007    #[test]
1008    fn test_deserialize_tool_result_chunks() {
1009        let json = json!({
1010            "type": "tool_result",
1011            "content": [
1012                {
1013                    "type": "text",
1014                    "text": "Processing complete"
1015                },
1016                {
1017                    "type": "text",
1018                    "text": "Result: 42"
1019                }
1020            ],
1021            "tool_use_id": "tool_789"
1022        });
1023        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1024        match chunk {
1025            ContentChunk::ToolResult {
1026                content,
1027                tool_use_id,
1028            } => {
1029                match content {
1030                    Content::Chunks(chunks) => {
1031                        assert_eq!(chunks.len(), 2);
1032                        match &chunks[0] {
1033                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
1034                            _ => panic!("Expected Text chunk"),
1035                        }
1036                        match &chunks[1] {
1037                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
1038                            _ => panic!("Expected Text chunk"),
1039                        }
1040                    }
1041                    _ => panic!("Expected Chunks content"),
1042                }
1043                assert_eq!(tool_use_id, "tool_789");
1044            }
1045            _ => panic!("Expected ToolResult variant"),
1046        }
1047    }
1048}