Use anthropic provider key for CC

Agus Zubiaga created

Change summary

Cargo.lock                                       |  2 
crates/agent_servers/Cargo.toml                  |  2 
crates/agent_servers/src/claude.rs               | 51 +++++++++++------
crates/language_models/src/provider/anthropic.rs | 53 ++++++++++++-----
4 files changed, 73 insertions(+), 35 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -267,6 +267,8 @@ dependencies = [
  "indoc",
  "itertools 0.14.0",
  "language",
+ "language_model",
+ "language_models",
  "libc",
  "log",
  "nix 0.29.0",

crates/agent_servers/Cargo.toml 🔗

@@ -27,6 +27,8 @@ futures.workspace = true
 gpui.workspace = true
 indoc.workspace = true
 itertools.workspace = true
+language_model.workspace = true
+language_models.workspace = true
 log.workspace = true
 paths.workspace = true
 project.workspace = true

crates/agent_servers/src/claude.rs 🔗

@@ -30,7 +30,7 @@ use util::{ResultExt, debug_panic};
 use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
 use crate::claude::tools::ClaudeTool;
 use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
-use acp_thread::{AcpThread, AgentConnection};
+use acp_thread::{AcpThread, AgentConnection, AuthRequired};
 
 #[derive(Clone)]
 pub struct ClaudeCode;
@@ -79,6 +79,34 @@ impl AgentConnection for ClaudeAgentConnection {
     ) -> Task<Result<Entity<AcpThread>>> {
         let cwd = cwd.to_owned();
         cx.spawn(async move |cx| {
+            let settings = cx.read_global(|settings: &SettingsStore, _| {
+                settings.get::<AllAgentServersSettings>(None).claude.clone()
+            })?;
+
+            let Some(command) = AgentServerCommand::resolve(
+                "claude",
+                &[],
+                Some(&util::paths::home_dir().join(".claude/local/claude")),
+                settings,
+                &project,
+                cx,
+            )
+            .await
+            else {
+                anyhow::bail!("Failed to find claude binary");
+            };
+
+            let api_key = cx
+                .update(|cx| language_models::provider::anthropic::ApiKey::get(cx))?
+                .await
+                .map_err(|err| {
+                    if err.is::<language_model::AuthenticateError>() {
+                        anyhow!(AuthRequired)
+                    } else {
+                        anyhow!(err)
+                    }
+                })?;
+
             let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
             let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
 
@@ -98,23 +126,6 @@ impl AgentConnection for ClaudeAgentConnection {
                 .await?;
             mcp_config_file.flush().await?;
 
-            let settings = cx.read_global(|settings: &SettingsStore, _| {
-                settings.get::<AllAgentServersSettings>(None).claude.clone()
-            })?;
-
-            let Some(command) = AgentServerCommand::resolve(
-                "claude",
-                &[],
-                Some(&util::paths::home_dir().join(".claude/local/claude")),
-                settings,
-                &project,
-                cx,
-            )
-            .await
-            else {
-                anyhow::bail!("Failed to find claude binary");
-            };
-
             let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
             let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
 
@@ -126,6 +137,7 @@ impl AgentConnection for ClaudeAgentConnection {
                 &command,
                 ClaudeSessionMode::Start,
                 session_id.clone(),
+                api_key,
                 &mcp_config_path,
                 &cwd,
             )?;
@@ -320,6 +332,7 @@ fn spawn_claude(
     command: &AgentServerCommand,
     mode: ClaudeSessionMode,
     session_id: acp::SessionId,
+    api_key: language_models::provider::anthropic::ApiKey,
     mcp_config_path: &Path,
     root_dir: &Path,
 ) -> Result<Child> {
@@ -355,6 +368,8 @@ fn spawn_claude(
             ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
         })
         .args(command.args.iter().map(|arg| arg.as_str()))
+        .envs(command.env.iter().flatten())
+        .env("ANTHROPIC_API_KEY", api_key.key)
         .current_dir(root_dir)
         .stdin(std::process::Stdio::piped())
         .stdout(std::process::Stdio::piped())

crates/language_models/src/provider/anthropic.rs 🔗

@@ -153,34 +153,53 @@ impl State {
             return Task::ready(Ok(()));
         }
 
+        let key = ApiKey::get(cx);
+
+        cx.spawn(async move |this, cx| {
+            let key = key.await?;
+
+            this.update(cx, |this, cx| {
+                this.api_key = Some(key.key);
+                this.api_key_from_env = key.from_env;
+                cx.notify();
+            })?;
+
+            Ok(())
+        })
+    }
+}
+
+pub struct ApiKey {
+    pub key: String,
+    pub from_env: bool,
+}
+
+impl ApiKey {
+    pub fn get(cx: &mut App) -> Task<Result<Self>> {
         let credentials_provider = <dyn CredentialsProvider>::global(cx);
         let api_url = AllLanguageModelSettings::get_global(cx)
             .anthropic
             .api_url
             .clone();
 
-        cx.spawn(async move |this, cx| {
-            let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
-                (api_key, true)
-            } else {
+        if let Ok(key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
+            Task::ready(Ok(ApiKey {
+                key,
+                from_env: true,
+            }))
+        } else {
+            cx.spawn(async move |cx| {
                 let (_, api_key) = credentials_provider
                     .read_credentials(&api_url, &cx)
                     .await?
                     .ok_or(AuthenticateError::CredentialsNotFound)?;
-                (
-                    String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
-                    false,
-                )
-            };
-
-            this.update(cx, |this, cx| {
-                this.api_key = Some(api_key);
-                this.api_key_from_env = from_env;
-                cx.notify();
-            })?;
 
-            Ok(())
-        })
+                Ok(ApiKey {
+                    key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
+                    from_env: false,
+                })
+            })
+        }
     }
 }