claude.rs

   1mod mcp_server;
   2pub mod tools;
   3
   4use collections::HashMap;
   5use context_server::listener::McpServerTool;
   6use project::Project;
   7use settings::SettingsStore;
   8use smol::process::Child;
   9use std::cell::{Cell, RefCell};
  10use std::fmt::Display;
  11use std::path::Path;
  12use std::rc::Rc;
  13use uuid::Uuid;
  14
  15use agent_client_protocol as acp;
  16use anyhow::{Result, anyhow};
  17use futures::channel::oneshot;
  18use futures::{AsyncBufReadExt, AsyncWriteExt};
  19use futures::{
  20    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  21    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  22    io::BufReader,
  23    select_biased,
  24};
  25use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  26use serde::{Deserialize, Serialize};
  27use util::{ResultExt, debug_panic};
  28
  29use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  30use crate::claude::tools::ClaudeTool;
  31use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  32use acp_thread::{AcpThread, AgentConnection};
  33
  34#[derive(Clone)]
  35pub struct ClaudeCode;
  36
  37impl AgentServer for ClaudeCode {
  38    fn name(&self) -> &'static str {
  39        "Claude Code"
  40    }
  41
  42    fn empty_state_headline(&self) -> &'static str {
  43        self.name()
  44    }
  45
  46    fn empty_state_message(&self) -> &'static str {
  47        "How can I help you today?"
  48    }
  49
  50    fn logo(&self) -> ui::IconName {
  51        ui::IconName::AiClaude
  52    }
  53
  54    fn connect(
  55        &self,
  56        _root_dir: &Path,
  57        _project: &Entity<Project>,
  58        _cx: &mut App,
  59    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  60        let connection = ClaudeAgentConnection {
  61            sessions: Default::default(),
  62        };
  63
  64        Task::ready(Ok(Rc::new(connection) as _))
  65    }
  66}
  67
  68struct ClaudeAgentConnection {
  69    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
  70}
  71
  72impl AgentConnection for ClaudeAgentConnection {
  73    fn new_thread(
  74        self: Rc<Self>,
  75        project: Entity<Project>,
  76        cwd: &Path,
  77        cx: &mut AsyncApp,
  78    ) -> Task<Result<Entity<AcpThread>>> {
  79        let cwd = cwd.to_owned();
  80        cx.spawn(async move |cx| {
  81            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
  82            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
  83
  84            let mut mcp_servers = HashMap::default();
  85            mcp_servers.insert(
  86                mcp_server::SERVER_NAME.to_string(),
  87                permission_mcp_server.server_config()?,
  88            );
  89            let mcp_config = McpConfig { mcp_servers };
  90
  91            let mcp_config_file = tempfile::NamedTempFile::new()?;
  92            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
  93
  94            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
  95            mcp_config_file
  96                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
  97                .await?;
  98            mcp_config_file.flush().await?;
  99
 100            let settings = cx.read_global(|settings: &SettingsStore, _| {
 101                settings.get::<AllAgentServersSettings>(None).claude.clone()
 102            })?;
 103
 104            let Some(command) = AgentServerCommand::resolve(
 105                "claude",
 106                &[],
 107                Some(&util::paths::home_dir().join(".claude/local/claude")),
 108                settings,
 109                &project,
 110                cx,
 111            )
 112            .await
 113            else {
 114                anyhow::bail!("Failed to find claude binary");
 115            };
 116
 117            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
 118            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 119
 120            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
 121
 122            log::trace!("Starting session with id: {}", session_id);
 123
 124            let mut child = spawn_claude(
 125                &command,
 126                ClaudeSessionMode::Start,
 127                session_id.clone(),
 128                &mcp_config_path,
 129                &cwd,
 130            )?;
 131
 132            let stdin = child.stdin.take().unwrap();
 133            let stdout = child.stdout.take().unwrap();
 134
 135            let pid = child.id();
 136            log::trace!("Spawned (pid: {})", pid);
 137
 138            cx.background_spawn(async move {
 139                let mut outgoing_rx = Some(outgoing_rx);
 140
 141                ClaudeAgentSession::handle_io(
 142                    outgoing_rx.take().unwrap(),
 143                    incoming_message_tx.clone(),
 144                    stdin,
 145                    stdout,
 146                )
 147                .await?;
 148
 149                log::trace!("Stopped (pid: {})", pid);
 150
 151                drop(mcp_config_path);
 152                anyhow::Ok(())
 153            })
 154            .detach();
 155
 156            let pending_cancellation = Rc::new(Cell::new(PendingCancellation::None));
 157
 158            let end_turn_tx = Rc::new(RefCell::new(None));
 159            let handler_task = cx.spawn({
 160                let end_turn_tx = end_turn_tx.clone();
 161                let mut thread_rx = thread_rx.clone();
 162                let cancellation_state = pending_cancellation.clone();
 163                async move |cx| {
 164                    while let Some(message) = incoming_message_rx.next().await {
 165                        ClaudeAgentSession::handle_message(
 166                            thread_rx.clone(),
 167                            message,
 168                            end_turn_tx.clone(),
 169                            cancellation_state.clone(),
 170                            cx,
 171                        )
 172                        .await
 173                    }
 174
 175                    if let Some(status) = child.status().await.log_err() {
 176                        if let Some(thread) = thread_rx.recv().await.ok() {
 177                            thread
 178                                .update(cx, |thread, cx| {
 179                                    thread.emit_server_exited(status, cx);
 180                                })
 181                                .ok();
 182                        }
 183                    }
 184                }
 185            });
 186
 187            let thread = cx.new(|cx| {
 188                AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
 189            })?;
 190
 191            thread_tx.send(thread.downgrade())?;
 192
 193            let session = ClaudeAgentSession {
 194                outgoing_tx,
 195                end_turn_tx,
 196                pending_cancellation,
 197                _handler_task: handler_task,
 198                _mcp_server: Some(permission_mcp_server),
 199            };
 200
 201            self.sessions.borrow_mut().insert(session_id, session);
 202
 203            Ok(thread)
 204        })
 205    }
 206
 207    fn auth_methods(&self) -> &[acp::AuthMethod] {
 208        &[]
 209    }
 210
 211    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 212        Task::ready(Err(anyhow!("Authentication not supported")))
 213    }
 214
 215    fn prompt(
 216        &self,
 217        params: acp::PromptRequest,
 218        cx: &mut App,
 219    ) -> Task<Result<acp::PromptResponse>> {
 220        let sessions = self.sessions.borrow();
 221        let Some(session) = sessions.get(&params.session_id) else {
 222            return Task::ready(Err(anyhow!(
 223                "Attempted to send message to nonexistent session {}",
 224                params.session_id
 225            )));
 226        };
 227
 228        let (tx, rx) = oneshot::channel();
 229        session.end_turn_tx.borrow_mut().replace(tx);
 230
 231        let mut content = String::new();
 232        for chunk in params.prompt {
 233            match chunk {
 234                acp::ContentBlock::Text(text_content) => {
 235                    content.push_str(&text_content.text);
 236                }
 237                acp::ContentBlock::ResourceLink(resource_link) => {
 238                    content.push_str(&format!("@{}", resource_link.uri));
 239                }
 240                acp::ContentBlock::Audio(_)
 241                | acp::ContentBlock::Image(_)
 242                | acp::ContentBlock::Resource(_) => {
 243                    // TODO
 244                }
 245            }
 246        }
 247
 248        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 249            message: Message {
 250                role: Role::User,
 251                content: Content::UntaggedText(content),
 252                id: None,
 253                model: None,
 254                stop_reason: None,
 255                stop_sequence: None,
 256                usage: None,
 257            },
 258            session_id: Some(params.session_id.to_string()),
 259        }) {
 260            return Task::ready(Err(anyhow!(err)));
 261        }
 262
 263        let cancellation_state = session.pending_cancellation.clone();
 264        cx.foreground_executor().spawn(async move {
 265            let result = rx.await??;
 266            cancellation_state.set(PendingCancellation::None);
 267            Ok(result)
 268        })
 269    }
 270
 271    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 272        let sessions = self.sessions.borrow();
 273        let Some(session) = sessions.get(&session_id) else {
 274            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 275            return;
 276        };
 277
 278        let request_id = new_request_id();
 279
 280        session.pending_cancellation.set(PendingCancellation::Sent {
 281            request_id: request_id.clone(),
 282        });
 283
 284        session
 285            .outgoing_tx
 286            .unbounded_send(SdkMessage::ControlRequest {
 287                request_id,
 288                request: ControlRequest::Interrupt,
 289            })
 290            .log_err();
 291    }
 292}
 293
 294#[derive(Clone, Copy)]
 295enum ClaudeSessionMode {
 296    Start,
 297    #[expect(dead_code)]
 298    Resume,
 299}
 300
 301fn spawn_claude(
 302    command: &AgentServerCommand,
 303    mode: ClaudeSessionMode,
 304    session_id: acp::SessionId,
 305    mcp_config_path: &Path,
 306    root_dir: &Path,
 307) -> Result<Child> {
 308    let child = util::command::new_smol_command(&command.path)
 309        .args([
 310            "--input-format",
 311            "stream-json",
 312            "--output-format",
 313            "stream-json",
 314            "--print",
 315            "--verbose",
 316            "--mcp-config",
 317            mcp_config_path.to_string_lossy().as_ref(),
 318            "--permission-prompt-tool",
 319            &format!(
 320                "mcp__{}__{}",
 321                mcp_server::SERVER_NAME,
 322                mcp_server::PermissionTool::NAME,
 323            ),
 324            "--allowedTools",
 325            &format!(
 326                "mcp__{}__{},mcp__{}__{}",
 327                mcp_server::SERVER_NAME,
 328                mcp_server::EditTool::NAME,
 329                mcp_server::SERVER_NAME,
 330                mcp_server::ReadTool::NAME
 331            ),
 332            "--disallowedTools",
 333            "Read,Edit",
 334        ])
 335        .args(match mode {
 336            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 337            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 338        })
 339        .args(command.args.iter().map(|arg| arg.as_str()))
 340        .current_dir(root_dir)
 341        .stdin(std::process::Stdio::piped())
 342        .stdout(std::process::Stdio::piped())
 343        .stderr(std::process::Stdio::inherit())
 344        .kill_on_drop(true)
 345        .spawn()?;
 346
 347    Ok(child)
 348}
 349
/// Per-session state owned by `ClaudeAgentConnection`.
///
/// NOTE: fields drop in declaration order, so `_mcp_server` is dropped before
/// `_handler_task`. Reordering fields changes shutdown sequencing.
struct ClaudeAgentSession {
    /// Messages to serialize onto the `claude` process's stdin. This is the only
    /// sender, so dropping the session closes the channel and ends the I/O loop.
    outgoing_tx: UnboundedSender<SdkMessage>,
    /// Completes the in-flight `prompt` call when a `result` message arrives.
    end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
    /// Interrupt state shared with the message handler.
    pending_cancellation: Rc<Cell<PendingCancellation>>,
    /// Held only to keep the MCP server alive for the session's lifetime.
    _mcp_server: Option<ClaudeZedMcpServer>,
    /// Foreground task applying incoming messages to the thread; presumably
    /// cancelled when dropped (gpui `Task` semantics) — TODO confirm.
    _handler_task: Task<()>,
}
 357
 358#[derive(Debug, Default, PartialEq)]
 359enum PendingCancellation {
 360    #[default]
 361    None,
 362    Sent {
 363        request_id: String,
 364    },
 365    Confirmed,
 366}
 367
 368impl ClaudeAgentSession {
 369    async fn handle_message(
 370        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 371        message: SdkMessage,
 372        end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
 373        pending_cancellation: Rc<Cell<PendingCancellation>>,
 374        cx: &mut AsyncApp,
 375    ) {
 376        match message {
 377            // we should only be sending these out, they don't need to be in the thread
 378            SdkMessage::ControlRequest { .. } => {}
 379            SdkMessage::User {
 380                message,
 381                session_id: _,
 382            } => {
 383                let Some(thread) = thread_rx
 384                    .recv()
 385                    .await
 386                    .log_err()
 387                    .and_then(|entity| entity.upgrade())
 388                else {
 389                    log::error!("Received an SDK message but thread is gone");
 390                    return;
 391                };
 392
 393                for chunk in message.content.chunks() {
 394                    match chunk {
 395                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 396                            let state = pending_cancellation.take();
 397                            if state != PendingCancellation::Confirmed {
 398                                thread
 399                                    .update(cx, |thread, cx| {
 400                                        thread.push_user_content_block(text.into(), cx)
 401                                    })
 402                                    .log_err();
 403                            }
 404                            pending_cancellation.set(state);
 405                        }
 406                        ContentChunk::ToolResult {
 407                            content,
 408                            tool_use_id,
 409                        } => {
 410                            let content = content.to_string();
 411                            thread
 412                                .update(cx, |thread, cx| {
 413                                    thread.update_tool_call(
 414                                        acp::ToolCallUpdate {
 415                                            id: acp::ToolCallId(tool_use_id.into()),
 416                                            fields: acp::ToolCallUpdateFields {
 417                                                status: Some(acp::ToolCallStatus::Completed),
 418                                                content: (!content.is_empty())
 419                                                    .then(|| vec![content.into()]),
 420                                                ..Default::default()
 421                                            },
 422                                        },
 423                                        cx,
 424                                    )
 425                                })
 426                                .log_err();
 427                        }
 428                        ContentChunk::Thinking { .. }
 429                        | ContentChunk::RedactedThinking
 430                        | ContentChunk::ToolUse { .. } => {
 431                            debug_panic!(
 432                                "Should not get {:?} with role: assistant. should we handle this?",
 433                                chunk
 434                            );
 435                        }
 436
 437                        ContentChunk::Image
 438                        | ContentChunk::Document
 439                        | ContentChunk::WebSearchToolResult => {
 440                            thread
 441                                .update(cx, |thread, cx| {
 442                                    thread.push_assistant_content_block(
 443                                        format!("Unsupported content: {:?}", chunk).into(),
 444                                        false,
 445                                        cx,
 446                                    )
 447                                })
 448                                .log_err();
 449                        }
 450                    }
 451                }
 452            }
 453            SdkMessage::Assistant {
 454                message,
 455                session_id: _,
 456            } => {
 457                let Some(thread) = thread_rx
 458                    .recv()
 459                    .await
 460                    .log_err()
 461                    .and_then(|entity| entity.upgrade())
 462                else {
 463                    log::error!("Received an SDK message but thread is gone");
 464                    return;
 465                };
 466
 467                for chunk in message.content.chunks() {
 468                    match chunk {
 469                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 470                            thread
 471                                .update(cx, |thread, cx| {
 472                                    thread.push_assistant_content_block(text.into(), false, cx)
 473                                })
 474                                .log_err();
 475                        }
 476                        ContentChunk::Thinking { thinking } => {
 477                            thread
 478                                .update(cx, |thread, cx| {
 479                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 480                                })
 481                                .log_err();
 482                        }
 483                        ContentChunk::RedactedThinking => {
 484                            thread
 485                                .update(cx, |thread, cx| {
 486                                    thread.push_assistant_content_block(
 487                                        "[REDACTED]".into(),
 488                                        true,
 489                                        cx,
 490                                    )
 491                                })
 492                                .log_err();
 493                        }
 494                        ContentChunk::ToolUse { id, name, input } => {
 495                            let claude_tool = ClaudeTool::infer(&name, input);
 496
 497                            thread
 498                                .update(cx, |thread, cx| {
 499                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 500                                        thread.update_plan(
 501                                            acp::Plan {
 502                                                entries: params
 503                                                    .todos
 504                                                    .into_iter()
 505                                                    .map(Into::into)
 506                                                    .collect(),
 507                                            },
 508                                            cx,
 509                                        )
 510                                    } else {
 511                                        thread.upsert_tool_call(
 512                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 513                                            cx,
 514                                        );
 515                                    }
 516                                })
 517                                .log_err();
 518                        }
 519                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 520                            debug_panic!(
 521                                "Should not get tool results with role: assistant. should we handle this?"
 522                            );
 523                        }
 524                        ContentChunk::Image | ContentChunk::Document => {
 525                            thread
 526                                .update(cx, |thread, cx| {
 527                                    thread.push_assistant_content_block(
 528                                        format!("Unsupported content: {:?}", chunk).into(),
 529                                        false,
 530                                        cx,
 531                                    )
 532                                })
 533                                .log_err();
 534                        }
 535                    }
 536                }
 537            }
 538            SdkMessage::Result {
 539                is_error,
 540                subtype,
 541                result,
 542                ..
 543            } => {
 544                if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
 545                    if is_error
 546                        || (subtype == ResultErrorType::ErrorDuringExecution
 547                            && pending_cancellation.take() != PendingCancellation::Confirmed)
 548                    {
 549                        end_turn_tx
 550                            .send(Err(anyhow!(
 551                                "Error: {}",
 552                                result.unwrap_or_else(|| subtype.to_string())
 553                            )))
 554                            .ok();
 555                    } else {
 556                        let stop_reason = match subtype {
 557                            ResultErrorType::Success => acp::StopReason::EndTurn,
 558                            ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 559                            ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
 560                        };
 561                        end_turn_tx
 562                            .send(Ok(acp::PromptResponse { stop_reason }))
 563                            .ok();
 564                    }
 565                }
 566            }
 567            SdkMessage::ControlResponse { response } => {
 568                if matches!(response.subtype, ResultErrorType::Success) {
 569                    let pending_cancellation_value = pending_cancellation.take();
 570
 571                    if let PendingCancellation::Sent { request_id } = &pending_cancellation_value
 572                        && request_id == &response.request_id
 573                    {
 574                        pending_cancellation.set(PendingCancellation::Confirmed);
 575                    } else {
 576                        pending_cancellation.set(pending_cancellation_value);
 577                    }
 578                }
 579            }
 580            SdkMessage::System { .. } => {}
 581        }
 582    }
 583
 584    async fn handle_io(
 585        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 586        incoming_tx: UnboundedSender<SdkMessage>,
 587        mut outgoing_bytes: impl Unpin + AsyncWrite,
 588        incoming_bytes: impl Unpin + AsyncRead,
 589    ) -> Result<UnboundedReceiver<SdkMessage>> {
 590        let mut output_reader = BufReader::new(incoming_bytes);
 591        let mut outgoing_line = Vec::new();
 592        let mut incoming_line = String::new();
 593        loop {
 594            select_biased! {
 595                message = outgoing_rx.next() => {
 596                    if let Some(message) = message {
 597                        outgoing_line.clear();
 598                        serde_json::to_writer(&mut outgoing_line, &message)?;
 599                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 600                        outgoing_line.push(b'\n');
 601                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 602                    } else {
 603                        break;
 604                    }
 605                }
 606                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
 607                    if bytes_read? == 0 {
 608                        break
 609                    }
 610                    log::trace!("recv: {}", &incoming_line);
 611                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 612                        Ok(message) => {
 613                            incoming_tx.unbounded_send(message).log_err();
 614                        }
 615                        Err(error) => {
 616                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 617                        }
 618                    }
 619                    incoming_line.clear();
 620                }
 621            }
 622        }
 623
 624        Ok(outgoing_rx)
 625    }
 626}
 627
/// A chat message in the Anthropic SDK shape, carried by the `User` and
/// `Assistant` variants of [`SdkMessage`].
///
/// `None` optional fields are omitted when serializing, as for the user
/// messages Zed sends. NOTE: field order determines JSON key order on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    role: Role,
    content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequence: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Usage>,
}
 643
/// Message content: either a bare string or a list of typed chunks.
///
/// `#[serde(untagged)]` tries variants in order, so a JSON string becomes
/// `UntaggedText` and a JSON array becomes `Chunks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Content {
    UntaggedText(String),
    Chunks(Vec<ContentChunk>),
}
 650
 651impl Content {
 652    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 653        match self {
 654            Self::Chunks(chunks) => chunks.into_iter(),
 655            Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
 656        }
 657    }
 658}
 659
 660impl Display for Content {
 661    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 662        match self {
 663            Content::UntaggedText(txt) => write!(f, "{}", txt),
 664            Content::Chunks(chunks) => {
 665                for chunk in chunks {
 666                    write!(f, "{}", chunk)?;
 667                }
 668                Ok(())
 669            }
 670        }
 671    }
 672}
 673
/// One typed block of message content.
///
/// Internally tagged by the `type` field (snake_case variant names). The
/// `#[serde(untagged)]` `UntaggedText` variant is intended to catch content
/// given as a bare JSON string. NOTE(review): relies on serde's untagged-variant
/// fallback inside a tagged enum — confirm the serde version supports it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ContentChunk {
    Text {
        text: String,
    },
    /// A tool invocation requested by the assistant.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Output of a tool, matched to its `ToolUse` by `tool_use_id`.
    ToolResult {
        content: Content,
        tool_use_id: String,
    },
    Thinking {
        thinking: String,
    },
    RedactedThinking,
    // TODO: the payloads of the following variants are not modeled yet.
    Image,
    Document,
    WebSearchToolResult,
    #[serde(untagged)]
    UntaggedText(String),
}
 700
 701impl Display for ContentChunk {
 702    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 703        match self {
 704            ContentChunk::Text { text } => write!(f, "{}", text),
 705            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 706            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 707            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 708            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 709            ContentChunk::Image
 710            | ContentChunk::Document
 711            | ContentChunk::ToolUse { .. }
 712            | ContentChunk::WebSearchToolResult => {
 713                write!(f, "\n{:?}\n", &self)
 714            }
 715        }
 716    }
 717}
 718
/// Token usage attached to a [`Message`].
///
/// Every field is required, so a `usage` object missing any of them fails to
/// deserialize (and with it the enclosing message).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Usage {
    input_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
    output_tokens: u32,
    service_tier: String,
}

/// Author of a [`Message`]; serialized as `"system"`, `"assistant"` or `"user"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}

/// A role plus plain-text content.
///
/// NOTE(review): not referenced in this part of the file; it may be used by a
/// submodule (`tools` / `mcp_server`) or be dead code — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    role: Role,
    content: String,
}
 741
/// One line of Claude Code's stream-json protocol, in either direction.
///
/// Internally tagged by `type` (snake_case variant names). Zed sends `User`
/// and `ControlRequest` and handles the rest as incoming messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    /// An assistant message.
    Assistant {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// A user message (also carries tool results back to the model).
    User {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// Emitted as the last message of a turn; completes the pending prompt.
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    /// Emitted as the first message at the start of a conversation.
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to the model
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request
    ControlResponse { response: ControlResponse },
}
 789
/// Payload of [`SdkMessage::ControlRequest`], tagged by `subtype`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation
    Interrupt,
}

/// Acknowledgement of a control request, matched by `request_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
    request_id: String,
    subtype: ResultErrorType,
}

/// Outcome subtype shared by `Result` messages and control responses.
/// Serialized as snake_case (`"success"`, `"error_max_turns"`,
/// `"error_during_execution"`); its `Display` impl must stay in sync.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
 810
 811impl Display for ResultErrorType {
 812    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 813        match self {
 814            ResultErrorType::Success => write!(f, "success"),
 815            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 816            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 817        }
 818    }
 819}
 820
 821fn new_request_id() -> String {
 822    use rand::Rng;
 823    // In the Claude Code TS SDK they just generate a random 12 character string,
 824    // `Math.random().toString(36).substring(2, 15)`
 825    rand::thread_rng()
 826        .sample_iter(&rand::distributions::Alphanumeric)
 827        .take(12)
 828        .map(char::from)
 829        .collect()
 830}
 831
/// One entry of the `mcp_servers` list reported in the `system` message.
// Field names and declaration order define the serialized JSON shape; keep them stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    /// Server name as reported by Claude Code.
    name: String,
    /// Free-form status string as reported; not validated or parsed here.
    status: String,
}
 837
 838#[derive(Debug, Clone, Serialize, Deserialize)]
 839#[serde(rename_all = "camelCase")]
 840enum PermissionMode {
 841    Default,
 842    AcceptEdits,
 843    BypassPermissions,
 844    Plan,
 845}
 846
 847#[cfg(test)]
 848pub(crate) mod tests {
 849    use super::*;
 850    use crate::e2e_tests;
 851    use gpui::TestAppContext;
 852    use serde_json::json;
 853
 854    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
 855
 856    pub fn local_command() -> AgentServerCommand {
 857        AgentServerCommand {
 858            path: "claude".into(),
 859            args: vec![],
 860            env: None,
 861        }
 862    }
 863
 864    #[gpui::test]
 865    #[cfg_attr(not(feature = "e2e"), ignore)]
 866    async fn test_todo_plan(cx: &mut TestAppContext) {
 867        let fs = e2e_tests::init_test(cx).await;
 868        let project = Project::test(fs, [], cx).await;
 869        let thread =
 870            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
 871
 872        thread
 873            .update(cx, |thread, cx| {
 874                thread.send_raw(
 875                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
 876                    cx,
 877                )
 878            })
 879            .await
 880            .unwrap();
 881
 882        let mut entries_len = 0;
 883
 884        thread.read_with(cx, |thread, _| {
 885            entries_len = thread.plan().entries.len();
 886            assert!(thread.plan().entries.len() > 0, "Empty plan");
 887        });
 888
 889        thread
 890            .update(cx, |thread, cx| {
 891                thread.send_raw(
 892                    "Mark the first entry status as in progress without acting on it.",
 893                    cx,
 894                )
 895            })
 896            .await
 897            .unwrap();
 898
 899        thread.read_with(cx, |thread, _| {
 900            assert!(matches!(
 901                thread.plan().entries[0].status,
 902                acp::PlanEntryStatus::InProgress
 903            ));
 904            assert_eq!(thread.plan().entries.len(), entries_len);
 905        });
 906
 907        thread
 908            .update(cx, |thread, cx| {
 909                thread.send_raw(
 910                    "Now mark the first entry as completed without acting on it.",
 911                    cx,
 912                )
 913            })
 914            .await
 915            .unwrap();
 916
 917        thread.read_with(cx, |thread, _| {
 918            assert!(matches!(
 919                thread.plan().entries[0].status,
 920                acp::PlanEntryStatus::Completed
 921            ));
 922            assert_eq!(thread.plan().entries.len(), entries_len);
 923        });
 924    }
 925
 926    #[test]
 927    fn test_deserialize_content_untagged_text() {
 928        let json = json!("Hello, world!");
 929        let content: Content = serde_json::from_value(json).unwrap();
 930        match content {
 931            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
 932            _ => panic!("Expected UntaggedText variant"),
 933        }
 934    }
 935
 936    #[test]
 937    fn test_deserialize_content_chunks() {
 938        let json = json!([
 939            {
 940                "type": "text",
 941                "text": "Hello"
 942            },
 943            {
 944                "type": "tool_use",
 945                "id": "tool_123",
 946                "name": "calculator",
 947                "input": {"operation": "add", "a": 1, "b": 2}
 948            }
 949        ]);
 950        let content: Content = serde_json::from_value(json).unwrap();
 951        match content {
 952            Content::Chunks(chunks) => {
 953                assert_eq!(chunks.len(), 2);
 954                match &chunks[0] {
 955                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
 956                    _ => panic!("Expected Text chunk"),
 957                }
 958                match &chunks[1] {
 959                    ContentChunk::ToolUse { id, name, input } => {
 960                        assert_eq!(id, "tool_123");
 961                        assert_eq!(name, "calculator");
 962                        assert_eq!(input["operation"], "add");
 963                        assert_eq!(input["a"], 1);
 964                        assert_eq!(input["b"], 2);
 965                    }
 966                    _ => panic!("Expected ToolUse chunk"),
 967                }
 968            }
 969            _ => panic!("Expected Chunks variant"),
 970        }
 971    }
 972
 973    #[test]
 974    fn test_deserialize_tool_result_untagged_text() {
 975        let json = json!({
 976            "type": "tool_result",
 977            "content": "Result content",
 978            "tool_use_id": "tool_456"
 979        });
 980        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
 981        match chunk {
 982            ContentChunk::ToolResult {
 983                content,
 984                tool_use_id,
 985            } => {
 986                match content {
 987                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
 988                    _ => panic!("Expected UntaggedText content"),
 989                }
 990                assert_eq!(tool_use_id, "tool_456");
 991            }
 992            _ => panic!("Expected ToolResult variant"),
 993        }
 994    }
 995
 996    #[test]
 997    fn test_deserialize_tool_result_chunks() {
 998        let json = json!({
 999            "type": "tool_result",
1000            "content": [
1001                {
1002                    "type": "text",
1003                    "text": "Processing complete"
1004                },
1005                {
1006                    "type": "text",
1007                    "text": "Result: 42"
1008                }
1009            ],
1010            "tool_use_id": "tool_789"
1011        });
1012        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1013        match chunk {
1014            ContentChunk::ToolResult {
1015                content,
1016                tool_use_id,
1017            } => {
1018                match content {
1019                    Content::Chunks(chunks) => {
1020                        assert_eq!(chunks.len(), 2);
1021                        match &chunks[0] {
1022                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
1023                            _ => panic!("Expected Text chunk"),
1024                        }
1025                        match &chunks[1] {
1026                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
1027                            _ => panic!("Expected Text chunk"),
1028                        }
1029                    }
1030                    _ => panic!("Expected Chunks content"),
1031                }
1032                assert_eq!(tool_use_id, "tool_789");
1033            }
1034            _ => panic!("Expected ToolResult variant"),
1035        }
1036    }
1037}