claude.rs

   1mod mcp_server;
   2pub mod tools;
   3
   4use collections::HashMap;
   5use context_server::listener::McpServerTool;
   6use project::Project;
   7use settings::SettingsStore;
   8use smol::process::Child;
   9use std::cell::RefCell;
  10use std::fmt::Display;
  11use std::path::Path;
  12use std::rc::Rc;
  13use uuid::Uuid;
  14
  15use agent_client_protocol as acp;
  16use anyhow::{Result, anyhow};
  17use futures::channel::oneshot;
  18use futures::{AsyncBufReadExt, AsyncWriteExt};
  19use futures::{
  20    AsyncRead, AsyncWrite, FutureExt, StreamExt,
  21    channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
  22    io::BufReader,
  23    select_biased,
  24};
  25use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  26use serde::{Deserialize, Serialize};
  27use util::{ResultExt, debug_panic};
  28
  29use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
  30use crate::claude::tools::ClaudeTool;
  31use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
  32use acp_thread::{AcpThread, AgentConnection};
  33
  34#[derive(Clone)]
  35pub struct ClaudeCode;
  36
  37impl AgentServer for ClaudeCode {
  38    fn name(&self) -> &'static str {
  39        "Claude Code"
  40    }
  41
  42    fn empty_state_headline(&self) -> &'static str {
  43        self.name()
  44    }
  45
  46    fn empty_state_message(&self) -> &'static str {
  47        "How can I help you today?"
  48    }
  49
  50    fn logo(&self) -> ui::IconName {
  51        ui::IconName::AiClaude
  52    }
  53
  54    fn connect(
  55        &self,
  56        _root_dir: &Path,
  57        _project: &Entity<Project>,
  58        _cx: &mut App,
  59    ) -> Task<Result<Rc<dyn AgentConnection>>> {
  60        let connection = ClaudeAgentConnection {
  61            sessions: Default::default(),
  62        };
  63
  64        Task::ready(Ok(Rc::new(connection) as _))
  65    }
  66}
  67
  68struct ClaudeAgentConnection {
  69    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
  70}
  71
  72impl AgentConnection for ClaudeAgentConnection {
  73    fn new_thread(
  74        self: Rc<Self>,
  75        project: Entity<Project>,
  76        cwd: &Path,
  77        cx: &mut AsyncApp,
  78    ) -> Task<Result<Entity<AcpThread>>> {
  79        let cwd = cwd.to_owned();
  80        cx.spawn(async move |cx| {
  81            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
  82            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
  83
  84            let mut mcp_servers = HashMap::default();
  85            mcp_servers.insert(
  86                mcp_server::SERVER_NAME.to_string(),
  87                permission_mcp_server.server_config()?,
  88            );
  89            let mcp_config = McpConfig { mcp_servers };
  90
  91            let mcp_config_file = tempfile::NamedTempFile::new()?;
  92            let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
  93
  94            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
  95            mcp_config_file
  96                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
  97                .await?;
  98            mcp_config_file.flush().await?;
  99
 100            let settings = cx.read_global(|settings: &SettingsStore, _| {
 101                settings.get::<AllAgentServersSettings>(None).claude.clone()
 102            })?;
 103
 104            let Some(command) = AgentServerCommand::resolve(
 105                "claude",
 106                &[],
 107                Some(&util::paths::home_dir().join(".claude/local/claude")),
 108                settings,
 109                &project,
 110                cx,
 111            )
 112            .await
 113            else {
 114                anyhow::bail!("Failed to find claude binary");
 115            };
 116
 117            let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
 118            let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 119
 120            let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
 121
 122            log::trace!("Starting session with id: {}", session_id);
 123
 124            let mut child = spawn_claude(
 125                &command,
 126                ClaudeSessionMode::Start,
 127                session_id.clone(),
 128                &mcp_config_path,
 129                &cwd,
 130            )?;
 131
 132            let stdin = child.stdin.take().unwrap();
 133            let stdout = child.stdout.take().unwrap();
 134
 135            let pid = child.id();
 136            log::trace!("Spawned (pid: {})", pid);
 137
 138            cx.background_spawn(async move {
 139                let mut outgoing_rx = Some(outgoing_rx);
 140
 141                ClaudeAgentSession::handle_io(
 142                    outgoing_rx.take().unwrap(),
 143                    incoming_message_tx.clone(),
 144                    stdin,
 145                    stdout,
 146                )
 147                .await?;
 148
 149                log::trace!("Stopped (pid: {})", pid);
 150
 151                drop(mcp_config_path);
 152                anyhow::Ok(())
 153            })
 154            .detach();
 155
 156            let cancellation_state = Rc::new(RefCell::new(CancellationState::None));
 157
 158            let end_turn_tx = Rc::new(RefCell::new(None));
 159            let handler_task = cx.spawn({
 160                let end_turn_tx = end_turn_tx.clone();
 161                let mut thread_rx = thread_rx.clone();
 162                let cancellation_state = cancellation_state.clone();
 163                async move |cx| {
 164                    while let Some(message) = incoming_message_rx.next().await {
 165                        ClaudeAgentSession::handle_message(
 166                            thread_rx.clone(),
 167                            message,
 168                            end_turn_tx.clone(),
 169                            cancellation_state.clone(),
 170                            cx,
 171                        )
 172                        .await
 173                    }
 174
 175                    if let Some(status) = child.status().await.log_err() {
 176                        if let Some(thread) = thread_rx.recv().await.ok() {
 177                            thread
 178                                .update(cx, |thread, cx| {
 179                                    thread.emit_server_exited(status, cx);
 180                                })
 181                                .ok();
 182                        }
 183                    }
 184                }
 185            });
 186
 187            let thread = cx.new(|cx| {
 188                AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
 189            })?;
 190
 191            thread_tx.send(thread.downgrade())?;
 192
 193            let session = ClaudeAgentSession {
 194                outgoing_tx,
 195                end_turn_tx,
 196                cancellation_state,
 197                _handler_task: handler_task,
 198                _mcp_server: Some(permission_mcp_server),
 199            };
 200
 201            self.sessions.borrow_mut().insert(session_id, session);
 202
 203            Ok(thread)
 204        })
 205    }
 206
 207    fn auth_methods(&self) -> &[acp::AuthMethod] {
 208        &[]
 209    }
 210
 211    fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 212        Task::ready(Err(anyhow!("Authentication not supported")))
 213    }
 214
 215    fn prompt(
 216        &self,
 217        params: acp::PromptRequest,
 218        cx: &mut App,
 219    ) -> Task<Result<acp::PromptResponse>> {
 220        let sessions = self.sessions.borrow();
 221        let Some(session) = sessions.get(&params.session_id) else {
 222            return Task::ready(Err(anyhow!(
 223                "Attempted to send message to nonexistent session {}",
 224                params.session_id
 225            )));
 226        };
 227
 228        let (tx, rx) = oneshot::channel();
 229        session.end_turn_tx.borrow_mut().replace(tx);
 230
 231        let mut content = String::new();
 232        for chunk in params.prompt {
 233            match chunk {
 234                acp::ContentBlock::Text(text_content) => {
 235                    content.push_str(&text_content.text);
 236                }
 237                acp::ContentBlock::ResourceLink(resource_link) => {
 238                    content.push_str(&format!("@{}", resource_link.uri));
 239                }
 240                acp::ContentBlock::Audio(_)
 241                | acp::ContentBlock::Image(_)
 242                | acp::ContentBlock::Resource(_) => {
 243                    // TODO
 244                }
 245            }
 246        }
 247
 248        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
 249            message: Message {
 250                role: Role::User,
 251                content: Content::UntaggedText(content),
 252                id: None,
 253                model: None,
 254                stop_reason: None,
 255                stop_sequence: None,
 256                usage: None,
 257            },
 258            session_id: Some(params.session_id.to_string()),
 259        }) {
 260            return Task::ready(Err(anyhow!(err)));
 261        }
 262
 263        let cancellation_state = session.cancellation_state.clone();
 264        cx.foreground_executor().spawn(async move {
 265            let result = rx.await??;
 266            *cancellation_state.borrow_mut() = CancellationState::None;
 267            Ok(result)
 268        })
 269    }
 270
 271    fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
 272        let sessions = self.sessions.borrow();
 273        let Some(session) = sessions.get(&session_id) else {
 274            log::warn!("Attempted to cancel nonexistent session {}", session_id);
 275            return;
 276        };
 277
 278        let request_id = new_request_id();
 279
 280        *session.cancellation_state.borrow_mut() = CancellationState::Requested {
 281            request_id: request_id.clone(),
 282        };
 283
 284        session
 285            .outgoing_tx
 286            .unbounded_send(SdkMessage::ControlRequest {
 287                request_id,
 288                request: ControlRequest::Interrupt,
 289            })
 290            .log_err();
 291    }
 292}
 293
 294#[derive(Clone, Copy)]
 295enum ClaudeSessionMode {
 296    Start,
 297    #[expect(dead_code)]
 298    Resume,
 299}
 300
 301fn spawn_claude(
 302    command: &AgentServerCommand,
 303    mode: ClaudeSessionMode,
 304    session_id: acp::SessionId,
 305    mcp_config_path: &Path,
 306    root_dir: &Path,
 307) -> Result<Child> {
 308    let child = util::command::new_smol_command(&command.path)
 309        .args([
 310            "--input-format",
 311            "stream-json",
 312            "--output-format",
 313            "stream-json",
 314            "--print",
 315            "--verbose",
 316            "--mcp-config",
 317            mcp_config_path.to_string_lossy().as_ref(),
 318            "--permission-prompt-tool",
 319            &format!(
 320                "mcp__{}__{}",
 321                mcp_server::SERVER_NAME,
 322                mcp_server::PermissionTool::NAME,
 323            ),
 324            "--allowedTools",
 325            &format!(
 326                "mcp__{}__{},mcp__{}__{}",
 327                mcp_server::SERVER_NAME,
 328                mcp_server::EditTool::NAME,
 329                mcp_server::SERVER_NAME,
 330                mcp_server::ReadTool::NAME
 331            ),
 332            "--disallowedTools",
 333            "Read,Edit",
 334        ])
 335        .args(match mode {
 336            ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
 337            ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
 338        })
 339        .args(command.args.iter().map(|arg| arg.as_str()))
 340        .current_dir(root_dir)
 341        .stdin(std::process::Stdio::piped())
 342        .stdout(std::process::Stdio::piped())
 343        .stderr(std::process::Stdio::inherit())
 344        .kill_on_drop(true)
 345        .spawn()?;
 346
 347    Ok(child)
 348}
 349
 350struct ClaudeAgentSession {
 351    outgoing_tx: UnboundedSender<SdkMessage>,
 352    end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
 353    cancellation_state: Rc<RefCell<CancellationState>>,
 354    _mcp_server: Option<ClaudeZedMcpServer>,
 355    _handler_task: Task<()>,
 356}
 357
 358#[derive(Debug, Default)]
 359enum CancellationState {
 360    #[default]
 361    None,
 362    Requested {
 363        request_id: String,
 364    },
 365    Confirmed,
 366}
 367
 368impl CancellationState {
 369    fn is_confirmed(&self) -> bool {
 370        matches!(self, CancellationState::Confirmed)
 371    }
 372}
 373
 374impl ClaudeAgentSession {
 375    async fn handle_message(
 376        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 377        message: SdkMessage,
 378        end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
 379        cancellation_state: Rc<RefCell<CancellationState>>,
 380        cx: &mut AsyncApp,
 381    ) {
 382        match message {
 383            // we should only be sending these out, they don't need to be in the thread
 384            SdkMessage::ControlRequest { .. } => {}
 385            SdkMessage::User {
 386                message,
 387                session_id: _,
 388            } => {
 389                let Some(thread) = thread_rx
 390                    .recv()
 391                    .await
 392                    .log_err()
 393                    .and_then(|entity| entity.upgrade())
 394                else {
 395                    log::error!("Received an SDK message but thread is gone");
 396                    return;
 397                };
 398
 399                for chunk in message.content.chunks() {
 400                    match chunk {
 401                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 402                            if !cancellation_state.borrow().is_confirmed() {
 403                                thread
 404                                    .update(cx, |thread, cx| {
 405                                        thread.push_user_content_block(text.into(), cx)
 406                                    })
 407                                    .log_err();
 408                            }
 409                        }
 410                        ContentChunk::ToolResult {
 411                            content,
 412                            tool_use_id,
 413                        } => {
 414                            let content = content.to_string();
 415                            thread
 416                                .update(cx, |thread, cx| {
 417                                    thread.update_tool_call(
 418                                        acp::ToolCallUpdate {
 419                                            id: acp::ToolCallId(tool_use_id.into()),
 420                                            fields: acp::ToolCallUpdateFields {
 421                                                status: if cancellation_state
 422                                                    .borrow()
 423                                                    .is_confirmed()
 424                                                {
 425                                                    // Do not set to completed if turn was cancelled
 426                                                    None
 427                                                } else {
 428                                                    Some(acp::ToolCallStatus::Completed)
 429                                                },
 430                                                content: (!content.is_empty())
 431                                                    .then(|| vec![content.into()]),
 432                                                ..Default::default()
 433                                            },
 434                                        },
 435                                        cx,
 436                                    )
 437                                })
 438                                .log_err();
 439                        }
 440                        ContentChunk::Thinking { .. }
 441                        | ContentChunk::RedactedThinking
 442                        | ContentChunk::ToolUse { .. } => {
 443                            debug_panic!(
 444                                "Should not get {:?} with role: assistant. should we handle this?",
 445                                chunk
 446                            );
 447                        }
 448
 449                        ContentChunk::Image
 450                        | ContentChunk::Document
 451                        | ContentChunk::WebSearchToolResult => {
 452                            thread
 453                                .update(cx, |thread, cx| {
 454                                    thread.push_assistant_content_block(
 455                                        format!("Unsupported content: {:?}", chunk).into(),
 456                                        false,
 457                                        cx,
 458                                    )
 459                                })
 460                                .log_err();
 461                        }
 462                    }
 463                }
 464            }
 465            SdkMessage::Assistant {
 466                message,
 467                session_id: _,
 468            } => {
 469                let Some(thread) = thread_rx
 470                    .recv()
 471                    .await
 472                    .log_err()
 473                    .and_then(|entity| entity.upgrade())
 474                else {
 475                    log::error!("Received an SDK message but thread is gone");
 476                    return;
 477                };
 478
 479                for chunk in message.content.chunks() {
 480                    match chunk {
 481                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
 482                            thread
 483                                .update(cx, |thread, cx| {
 484                                    thread.push_assistant_content_block(text.into(), false, cx)
 485                                })
 486                                .log_err();
 487                        }
 488                        ContentChunk::Thinking { thinking } => {
 489                            thread
 490                                .update(cx, |thread, cx| {
 491                                    thread.push_assistant_content_block(thinking.into(), true, cx)
 492                                })
 493                                .log_err();
 494                        }
 495                        ContentChunk::RedactedThinking => {
 496                            thread
 497                                .update(cx, |thread, cx| {
 498                                    thread.push_assistant_content_block(
 499                                        "[REDACTED]".into(),
 500                                        true,
 501                                        cx,
 502                                    )
 503                                })
 504                                .log_err();
 505                        }
 506                        ContentChunk::ToolUse { id, name, input } => {
 507                            let claude_tool = ClaudeTool::infer(&name, input);
 508
 509                            thread
 510                                .update(cx, |thread, cx| {
 511                                    if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
 512                                        thread.update_plan(
 513                                            acp::Plan {
 514                                                entries: params
 515                                                    .todos
 516                                                    .into_iter()
 517                                                    .map(Into::into)
 518                                                    .collect(),
 519                                            },
 520                                            cx,
 521                                        )
 522                                    } else {
 523                                        thread.upsert_tool_call(
 524                                            claude_tool.as_acp(acp::ToolCallId(id.into())),
 525                                            cx,
 526                                        );
 527                                    }
 528                                })
 529                                .log_err();
 530                        }
 531                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
 532                            debug_panic!(
 533                                "Should not get tool results with role: assistant. should we handle this?"
 534                            );
 535                        }
 536                        ContentChunk::Image | ContentChunk::Document => {
 537                            thread
 538                                .update(cx, |thread, cx| {
 539                                    thread.push_assistant_content_block(
 540                                        format!("Unsupported content: {:?}", chunk).into(),
 541                                        false,
 542                                        cx,
 543                                    )
 544                                })
 545                                .log_err();
 546                        }
 547                    }
 548                }
 549            }
 550            SdkMessage::Result {
 551                is_error,
 552                subtype,
 553                result,
 554                ..
 555            } => {
 556                if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
 557                    if is_error
 558                        || (subtype == ResultErrorType::ErrorDuringExecution
 559                            && !cancellation_state.borrow().is_confirmed())
 560                    {
 561                        end_turn_tx
 562                            .send(Err(anyhow!(
 563                                "Error: {}",
 564                                result.unwrap_or_else(|| subtype.to_string())
 565                            )))
 566                            .ok();
 567                    } else {
 568                        let stop_reason = match subtype {
 569                            ResultErrorType::Success => acp::StopReason::EndTurn,
 570                            ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
 571                            ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
 572                        };
 573                        end_turn_tx
 574                            .send(Ok(acp::PromptResponse { stop_reason }))
 575                            .ok();
 576                    }
 577                }
 578            }
 579            SdkMessage::ControlResponse { response } => {
 580                if matches!(response.subtype, ResultErrorType::Success) {
 581                    let mut cancellation_state = cancellation_state.borrow_mut();
 582                    if let CancellationState::Requested { request_id } = &*cancellation_state
 583                        && request_id == &response.request_id
 584                    {
 585                        *cancellation_state = CancellationState::Confirmed;
 586                    }
 587                }
 588            }
 589            SdkMessage::System { .. } => {}
 590        }
 591    }
 592
 593    async fn handle_io(
 594        mut outgoing_rx: UnboundedReceiver<SdkMessage>,
 595        incoming_tx: UnboundedSender<SdkMessage>,
 596        mut outgoing_bytes: impl Unpin + AsyncWrite,
 597        incoming_bytes: impl Unpin + AsyncRead,
 598    ) -> Result<UnboundedReceiver<SdkMessage>> {
 599        let mut output_reader = BufReader::new(incoming_bytes);
 600        let mut outgoing_line = Vec::new();
 601        let mut incoming_line = String::new();
 602        loop {
 603            select_biased! {
 604                message = outgoing_rx.next() => {
 605                    if let Some(message) = message {
 606                        outgoing_line.clear();
 607                        serde_json::to_writer(&mut outgoing_line, &message)?;
 608                        log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
 609                        outgoing_line.push(b'\n');
 610                        outgoing_bytes.write_all(&outgoing_line).await.ok();
 611                    } else {
 612                        break;
 613                    }
 614                }
 615                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
 616                    if bytes_read? == 0 {
 617                        break
 618                    }
 619                    log::trace!("recv: {}", &incoming_line);
 620                    match serde_json::from_str::<SdkMessage>(&incoming_line) {
 621                        Ok(message) => {
 622                            incoming_tx.unbounded_send(message).log_err();
 623                        }
 624                        Err(error) => {
 625                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
 626                        }
 627                    }
 628                    incoming_line.clear();
 629                }
 630            }
 631        }
 632
 633        Ok(outgoing_rx)
 634    }
 635}
 636
 637#[derive(Debug, Clone, Serialize, Deserialize)]
 638struct Message {
 639    role: Role,
 640    content: Content,
 641    #[serde(skip_serializing_if = "Option::is_none")]
 642    id: Option<String>,
 643    #[serde(skip_serializing_if = "Option::is_none")]
 644    model: Option<String>,
 645    #[serde(skip_serializing_if = "Option::is_none")]
 646    stop_reason: Option<String>,
 647    #[serde(skip_serializing_if = "Option::is_none")]
 648    stop_sequence: Option<String>,
 649    #[serde(skip_serializing_if = "Option::is_none")]
 650    usage: Option<Usage>,
 651}
 652
 653#[derive(Debug, Clone, Serialize, Deserialize)]
 654#[serde(untagged)]
 655enum Content {
 656    UntaggedText(String),
 657    Chunks(Vec<ContentChunk>),
 658}
 659
 660impl Content {
 661    pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
 662        match self {
 663            Self::Chunks(chunks) => chunks.into_iter(),
 664            Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
 665        }
 666    }
 667}
 668
 669impl Display for Content {
 670    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 671        match self {
 672            Content::UntaggedText(txt) => write!(f, "{}", txt),
 673            Content::Chunks(chunks) => {
 674                for chunk in chunks {
 675                    write!(f, "{}", chunk)?;
 676                }
 677                Ok(())
 678            }
 679        }
 680    }
 681}
 682
 683#[derive(Debug, Clone, Serialize, Deserialize)]
 684#[serde(tag = "type", rename_all = "snake_case")]
 685enum ContentChunk {
 686    Text {
 687        text: String,
 688    },
 689    ToolUse {
 690        id: String,
 691        name: String,
 692        input: serde_json::Value,
 693    },
 694    ToolResult {
 695        content: Content,
 696        tool_use_id: String,
 697    },
 698    Thinking {
 699        thinking: String,
 700    },
 701    RedactedThinking,
 702    // TODO
 703    Image,
 704    Document,
 705    WebSearchToolResult,
 706    #[serde(untagged)]
 707    UntaggedText(String),
 708}
 709
 710impl Display for ContentChunk {
 711    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 712        match self {
 713            ContentChunk::Text { text } => write!(f, "{}", text),
 714            ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
 715            ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
 716            ContentChunk::UntaggedText(text) => write!(f, "{}", text),
 717            ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
 718            ContentChunk::Image
 719            | ContentChunk::Document
 720            | ContentChunk::ToolUse { .. }
 721            | ContentChunk::WebSearchToolResult => {
 722                write!(f, "\n{:?}\n", &self)
 723            }
 724        }
 725    }
 726}
 727
 728#[derive(Debug, Clone, Serialize, Deserialize)]
 729struct Usage {
 730    input_tokens: u32,
 731    cache_creation_input_tokens: u32,
 732    cache_read_input_tokens: u32,
 733    output_tokens: u32,
 734    service_tier: String,
 735}
 736
 737#[derive(Debug, Clone, Serialize, Deserialize)]
 738#[serde(rename_all = "snake_case")]
 739enum Role {
 740    System,
 741    Assistant,
 742    User,
 743}
 744
 745#[derive(Debug, Clone, Serialize, Deserialize)]
 746struct MessageParam {
 747    role: Role,
 748    content: String,
 749}
 750
 751#[derive(Debug, Clone, Serialize, Deserialize)]
 752#[serde(tag = "type", rename_all = "snake_case")]
 753enum SdkMessage {
 754    // An assistant message
 755    Assistant {
 756        message: Message, // from Anthropic SDK
 757        #[serde(skip_serializing_if = "Option::is_none")]
 758        session_id: Option<String>,
 759    },
 760    // A user message
 761    User {
 762        message: Message, // from Anthropic SDK
 763        #[serde(skip_serializing_if = "Option::is_none")]
 764        session_id: Option<String>,
 765    },
 766    // Emitted as the last message in a conversation
 767    Result {
 768        subtype: ResultErrorType,
 769        duration_ms: f64,
 770        duration_api_ms: f64,
 771        is_error: bool,
 772        num_turns: i32,
 773        #[serde(skip_serializing_if = "Option::is_none")]
 774        result: Option<String>,
 775        session_id: String,
 776        total_cost_usd: f64,
 777    },
 778    // Emitted as the first message at the start of a conversation
 779    System {
 780        cwd: String,
 781        session_id: String,
 782        tools: Vec<String>,
 783        model: String,
 784        mcp_servers: Vec<McpServer>,
 785        #[serde(rename = "apiKeySource")]
 786        api_key_source: String,
 787        #[serde(rename = "permissionMode")]
 788        permission_mode: PermissionMode,
 789    },
 790    /// Messages used to control the conversation, outside of chat messages to the model
 791    ControlRequest {
 792        request_id: String,
 793        request: ControlRequest,
 794    },
 795    /// Response to a control request
 796    ControlResponse { response: ControlResponse },
 797}
 798
 799#[derive(Debug, Clone, Serialize, Deserialize)]
 800#[serde(tag = "subtype", rename_all = "snake_case")]
 801enum ControlRequest {
 802    /// Cancel the current conversation
 803    Interrupt,
 804}
 805
 806#[derive(Debug, Clone, Serialize, Deserialize)]
 807struct ControlResponse {
 808    request_id: String,
 809    subtype: ResultErrorType,
 810}
 811
 812#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
 813#[serde(rename_all = "snake_case")]
 814enum ResultErrorType {
 815    Success,
 816    ErrorMaxTurns,
 817    ErrorDuringExecution,
 818}
 819
 820impl Display for ResultErrorType {
 821    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 822        match self {
 823            ResultErrorType::Success => write!(f, "success"),
 824            ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
 825            ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
 826        }
 827    }
 828}
 829
 830fn new_request_id() -> String {
 831    use rand::Rng;
 832    // In the Claude Code TS SDK they just generate a random 12 character string,
 833    // `Math.random().toString(36).substring(2, 15)`
 834    rand::thread_rng()
 835        .sample_iter(&rand::distributions::Alphanumeric)
 836        .take(12)
 837        .map(char::from)
 838        .collect()
 839}
 840
 841#[derive(Debug, Clone, Serialize, Deserialize)]
 842struct McpServer {
 843    name: String,
 844    status: String,
 845}
 846
 847#[derive(Debug, Clone, Serialize, Deserialize)]
 848#[serde(rename_all = "camelCase")]
 849enum PermissionMode {
 850    Default,
 851    AcceptEdits,
 852    BypassPermissions,
 853    Plan,
 854}
 855
 856#[cfg(test)]
 857pub(crate) mod tests {
 858    use super::*;
 859    use crate::e2e_tests;
 860    use gpui::TestAppContext;
 861    use serde_json::json;
 862
 863    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
 864
 865    pub fn local_command() -> AgentServerCommand {
 866        AgentServerCommand {
 867            path: "claude".into(),
 868            args: vec![],
 869            env: None,
 870        }
 871    }
 872
 873    #[gpui::test]
 874    #[cfg_attr(not(feature = "e2e"), ignore)]
 875    async fn test_todo_plan(cx: &mut TestAppContext) {
 876        let fs = e2e_tests::init_test(cx).await;
 877        let project = Project::test(fs, [], cx).await;
 878        let thread =
 879            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
 880
 881        thread
 882            .update(cx, |thread, cx| {
 883                thread.send_raw(
 884                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
 885                    cx,
 886                )
 887            })
 888            .await
 889            .unwrap();
 890
 891        let mut entries_len = 0;
 892
 893        thread.read_with(cx, |thread, _| {
 894            entries_len = thread.plan().entries.len();
 895            assert!(thread.plan().entries.len() > 0, "Empty plan");
 896        });
 897
 898        thread
 899            .update(cx, |thread, cx| {
 900                thread.send_raw(
 901                    "Mark the first entry status as in progress without acting on it.",
 902                    cx,
 903                )
 904            })
 905            .await
 906            .unwrap();
 907
 908        thread.read_with(cx, |thread, _| {
 909            assert!(matches!(
 910                thread.plan().entries[0].status,
 911                acp::PlanEntryStatus::InProgress
 912            ));
 913            assert_eq!(thread.plan().entries.len(), entries_len);
 914        });
 915
 916        thread
 917            .update(cx, |thread, cx| {
 918                thread.send_raw(
 919                    "Now mark the first entry as completed without acting on it.",
 920                    cx,
 921                )
 922            })
 923            .await
 924            .unwrap();
 925
 926        thread.read_with(cx, |thread, _| {
 927            assert!(matches!(
 928                thread.plan().entries[0].status,
 929                acp::PlanEntryStatus::Completed
 930            ));
 931            assert_eq!(thread.plan().entries.len(), entries_len);
 932        });
 933    }
 934
 935    #[test]
 936    fn test_deserialize_content_untagged_text() {
 937        let json = json!("Hello, world!");
 938        let content: Content = serde_json::from_value(json).unwrap();
 939        match content {
 940            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
 941            _ => panic!("Expected UntaggedText variant"),
 942        }
 943    }
 944
 945    #[test]
 946    fn test_deserialize_content_chunks() {
 947        let json = json!([
 948            {
 949                "type": "text",
 950                "text": "Hello"
 951            },
 952            {
 953                "type": "tool_use",
 954                "id": "tool_123",
 955                "name": "calculator",
 956                "input": {"operation": "add", "a": 1, "b": 2}
 957            }
 958        ]);
 959        let content: Content = serde_json::from_value(json).unwrap();
 960        match content {
 961            Content::Chunks(chunks) => {
 962                assert_eq!(chunks.len(), 2);
 963                match &chunks[0] {
 964                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
 965                    _ => panic!("Expected Text chunk"),
 966                }
 967                match &chunks[1] {
 968                    ContentChunk::ToolUse { id, name, input } => {
 969                        assert_eq!(id, "tool_123");
 970                        assert_eq!(name, "calculator");
 971                        assert_eq!(input["operation"], "add");
 972                        assert_eq!(input["a"], 1);
 973                        assert_eq!(input["b"], 2);
 974                    }
 975                    _ => panic!("Expected ToolUse chunk"),
 976                }
 977            }
 978            _ => panic!("Expected Chunks variant"),
 979        }
 980    }
 981
 982    #[test]
 983    fn test_deserialize_tool_result_untagged_text() {
 984        let json = json!({
 985            "type": "tool_result",
 986            "content": "Result content",
 987            "tool_use_id": "tool_456"
 988        });
 989        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
 990        match chunk {
 991            ContentChunk::ToolResult {
 992                content,
 993                tool_use_id,
 994            } => {
 995                match content {
 996                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
 997                    _ => panic!("Expected UntaggedText content"),
 998                }
 999                assert_eq!(tool_use_id, "tool_456");
1000            }
1001            _ => panic!("Expected ToolResult variant"),
1002        }
1003    }
1004
1005    #[test]
1006    fn test_deserialize_tool_result_chunks() {
1007        let json = json!({
1008            "type": "tool_result",
1009            "content": [
1010                {
1011                    "type": "text",
1012                    "text": "Processing complete"
1013                },
1014                {
1015                    "type": "text",
1016                    "text": "Result: 42"
1017                }
1018            ],
1019            "tool_use_id": "tool_789"
1020        });
1021        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
1022        match chunk {
1023            ContentChunk::ToolResult {
1024                content,
1025                tool_use_id,
1026            } => {
1027                match content {
1028                    Content::Chunks(chunks) => {
1029                        assert_eq!(chunks.len(), 2);
1030                        match &chunks[0] {
1031                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
1032                            _ => panic!("Expected Text chunk"),
1033                        }
1034                        match &chunks[1] {
1035                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
1036                            _ => panic!("Expected Text chunk"),
1037                        }
1038                    }
1039                    _ => panic!("Expected Chunks content"),
1040                }
1041                assert_eq!(tool_use_id, "tool_789");
1042            }
1043            _ => panic!("Expected ToolResult variant"),
1044        }
1045    }
1046}