anthropic.rs

  1pub mod telemetry;
  2
  3use anthropic::{ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode};
  4use anyhow::Result;
  5use collections::BTreeMap;
  6use credentials_provider::CredentialsProvider;
  7use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
  8use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
  9use http_client::HttpClient;
 10use language_model::{
 11    ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, ApiKeyState, AuthenticateError,
 12    ConfigurationViewTargetAgent, EnvVar, IconOrSvg, LanguageModel,
 13    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
 14    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 15    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 16    LanguageModelToolChoice, RateLimiter, env_var,
 17};
 18use settings::{Settings, SettingsStore};
 19use std::sync::{Arc, LazyLock};
 20use strum::IntoEnumIterator;
 21use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
 22use ui_input::InputField;
 23use util::ResultExt;
 24
 25pub use anthropic::completion::{AnthropicEventMapper, into_anthropic};
 26pub use settings::AnthropicAvailableModel as AvailableModel;
 27
 28const PROVIDER_ID: LanguageModelProviderId = ANTHROPIC_PROVIDER_ID;
 29const PROVIDER_NAME: LanguageModelProviderName = ANTHROPIC_PROVIDER_NAME;
 30
 31#[derive(Default, Clone, Debug, PartialEq)]
 32pub struct AnthropicSettings {
 33    pub api_url: String,
 34    /// Extend Zed's list of Anthropic models.
 35    pub available_models: Vec<AvailableModel>,
 36}
 37
 38pub struct AnthropicLanguageModelProvider {
 39    http_client: Arc<dyn HttpClient>,
 40    state: Entity<State>,
 41}
 42
 43const API_KEY_ENV_VAR_NAME: &str = "ANTHROPIC_API_KEY";
 44static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
 45
 46pub struct State {
 47    api_key_state: ApiKeyState,
 48    credentials_provider: Arc<dyn CredentialsProvider>,
 49}
 50
 51impl State {
 52    fn is_authenticated(&self) -> bool {
 53        self.api_key_state.has_key()
 54    }
 55
 56    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
 57        let credentials_provider = self.credentials_provider.clone();
 58        let api_url = AnthropicLanguageModelProvider::api_url(cx);
 59        self.api_key_state.store(
 60            api_url,
 61            api_key,
 62            |this| &mut this.api_key_state,
 63            credentials_provider,
 64            cx,
 65        )
 66    }
 67
 68    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 69        let credentials_provider = self.credentials_provider.clone();
 70        let api_url = AnthropicLanguageModelProvider::api_url(cx);
 71        self.api_key_state.load_if_needed(
 72            api_url,
 73            |this| &mut this.api_key_state,
 74            credentials_provider,
 75            cx,
 76        )
 77    }
 78}
 79
 80impl AnthropicLanguageModelProvider {
 81    pub fn new(
 82        http_client: Arc<dyn HttpClient>,
 83        credentials_provider: Arc<dyn CredentialsProvider>,
 84        cx: &mut App,
 85    ) -> Self {
 86        let state = cx.new(|cx| {
 87            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
 88                let credentials_provider = this.credentials_provider.clone();
 89                let api_url = Self::api_url(cx);
 90                this.api_key_state.handle_url_change(
 91                    api_url,
 92                    |this| &mut this.api_key_state,
 93                    credentials_provider,
 94                    cx,
 95                );
 96                cx.notify();
 97            })
 98            .detach();
 99            State {
100                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
101                credentials_provider,
102            }
103        });
104
105        Self { http_client, state }
106    }
107
108    fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
109        Arc::new(AnthropicModel {
110            id: LanguageModelId::from(model.id().to_string()),
111            model,
112            state: self.state.clone(),
113            http_client: self.http_client.clone(),
114            request_limiter: RateLimiter::new(4),
115        })
116    }
117
118    fn settings(cx: &App) -> &AnthropicSettings {
119        &crate::AllLanguageModelSettings::get_global(cx).anthropic
120    }
121
122    fn api_url(cx: &App) -> SharedString {
123        let api_url = &Self::settings(cx).api_url;
124        if api_url.is_empty() {
125            ANTHROPIC_API_URL.into()
126        } else {
127            SharedString::new(api_url.as_str())
128        }
129    }
130}
131
132impl LanguageModelProviderState for AnthropicLanguageModelProvider {
133    type ObservableEntity = State;
134
135    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
136        Some(self.state.clone())
137    }
138}
139
140impl LanguageModelProvider for AnthropicLanguageModelProvider {
141    fn id(&self) -> LanguageModelProviderId {
142        PROVIDER_ID
143    }
144
145    fn name(&self) -> LanguageModelProviderName {
146        PROVIDER_NAME
147    }
148
149    fn icon(&self) -> IconOrSvg {
150        IconOrSvg::Icon(IconName::AiAnthropic)
151    }
152
153    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
154        Some(self.create_language_model(anthropic::Model::default()))
155    }
156
157    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
158        Some(self.create_language_model(anthropic::Model::default_fast()))
159    }
160
161    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
162        [anthropic::Model::ClaudeSonnet4_6]
163            .into_iter()
164            .map(|model| self.create_language_model(model))
165            .collect()
166    }
167
168    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
169        let mut models = BTreeMap::default();
170
171        // Add base models from anthropic::Model::iter()
172        for model in anthropic::Model::iter() {
173            if !matches!(model, anthropic::Model::Custom { .. }) {
174                models.insert(model.id().to_string(), model);
175            }
176        }
177
178        // Override with available models from settings
179        for model in &AnthropicLanguageModelProvider::settings(cx).available_models {
180            models.insert(
181                model.name.clone(),
182                anthropic::Model::Custom {
183                    name: model.name.clone(),
184                    display_name: model.display_name.clone(),
185                    max_tokens: model.max_tokens,
186                    tool_override: model.tool_override.clone(),
187                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
188                        anthropic::AnthropicModelCacheConfiguration {
189                            max_cache_anchors: config.max_cache_anchors,
190                            should_speculate: config.should_speculate,
191                            min_total_token: config.min_total_token,
192                        }
193                    }),
194                    max_output_tokens: model.max_output_tokens,
195                    default_temperature: model.default_temperature,
196                    extra_beta_headers: model.extra_beta_headers.clone(),
197                    mode: match model.mode.unwrap_or_default() {
198                        settings::ModelMode::Default => AnthropicModelMode::Default,
199                        settings::ModelMode::Thinking { budget_tokens } => {
200                            AnthropicModelMode::Thinking { budget_tokens }
201                        }
202                    },
203                },
204            );
205        }
206
207        models
208            .into_values()
209            .map(|model| self.create_language_model(model))
210            .collect()
211    }
212
213    fn is_authenticated(&self, cx: &App) -> bool {
214        self.state.read(cx).is_authenticated()
215    }
216
217    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
218        self.state.update(cx, |state, cx| state.authenticate(cx))
219    }
220
221    fn configuration_view(
222        &self,
223        target_agent: ConfigurationViewTargetAgent,
224        window: &mut Window,
225        cx: &mut App,
226    ) -> AnyView {
227        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
228            .into()
229    }
230
231    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
232        self.state
233            .update(cx, |state, cx| state.set_api_key(None, cx))
234    }
235}
236
237pub struct AnthropicModel {
238    id: LanguageModelId,
239    model: anthropic::Model,
240    state: Entity<State>,
241    http_client: Arc<dyn HttpClient>,
242    request_limiter: RateLimiter,
243}
244
245impl AnthropicModel {
246    fn stream_completion(
247        &self,
248        request: anthropic::Request,
249        cx: &AsyncApp,
250    ) -> BoxFuture<
251        'static,
252        Result<
253            BoxStream<'static, Result<anthropic::Event, AnthropicError>>,
254            LanguageModelCompletionError,
255        >,
256    > {
257        let http_client = self.http_client.clone();
258
259        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
260            let api_url = AnthropicLanguageModelProvider::api_url(cx);
261            (state.api_key_state.key(&api_url), api_url)
262        });
263
264        let beta_headers = self.model.beta_headers();
265
266        async move {
267            let Some(api_key) = api_key else {
268                return Err(LanguageModelCompletionError::NoApiKey {
269                    provider: PROVIDER_NAME,
270                });
271            };
272            let request = anthropic::stream_completion(
273                http_client.as_ref(),
274                &api_url,
275                &api_key,
276                request,
277                beta_headers,
278            );
279            request.await.map_err(Into::into)
280        }
281        .boxed()
282    }
283}
284
285impl LanguageModel for AnthropicModel {
286    fn id(&self) -> LanguageModelId {
287        self.id.clone()
288    }
289
290    fn name(&self) -> LanguageModelName {
291        LanguageModelName::from(self.model.display_name().to_string())
292    }
293
294    fn provider_id(&self) -> LanguageModelProviderId {
295        PROVIDER_ID
296    }
297
298    fn provider_name(&self) -> LanguageModelProviderName {
299        PROVIDER_NAME
300    }
301
302    fn supports_tools(&self) -> bool {
303        true
304    }
305
306    fn supports_images(&self) -> bool {
307        true
308    }
309
310    fn supports_streaming_tools(&self) -> bool {
311        true
312    }
313
314    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
315        match choice {
316            LanguageModelToolChoice::Auto
317            | LanguageModelToolChoice::Any
318            | LanguageModelToolChoice::None => true,
319        }
320    }
321
322    fn supports_thinking(&self) -> bool {
323        self.model.supports_thinking()
324    }
325
326    fn supports_fast_mode(&self) -> bool {
327        self.model.supports_speed()
328    }
329
330    fn supported_effort_levels(&self) -> Vec<language_model::LanguageModelEffortLevel> {
331        if self.model.supports_adaptive_thinking() {
332            vec![
333                language_model::LanguageModelEffortLevel {
334                    name: "Low".into(),
335                    value: "low".into(),
336                    is_default: false,
337                },
338                language_model::LanguageModelEffortLevel {
339                    name: "Medium".into(),
340                    value: "medium".into(),
341                    is_default: false,
342                },
343                language_model::LanguageModelEffortLevel {
344                    name: "High".into(),
345                    value: "high".into(),
346                    is_default: true,
347                },
348                language_model::LanguageModelEffortLevel {
349                    name: "Max".into(),
350                    value: "max".into(),
351                    is_default: false,
352                },
353            ]
354        } else {
355            Vec::new()
356        }
357    }
358
359    fn telemetry_id(&self) -> String {
360        format!("anthropic/{}", self.model.id())
361    }
362
363    fn api_key(&self, cx: &App) -> Option<String> {
364        self.state.read_with(cx, |state, cx| {
365            let api_url = AnthropicLanguageModelProvider::api_url(cx);
366            state.api_key_state.key(&api_url).map(|key| key.to_string())
367        })
368    }
369
370    fn max_token_count(&self) -> u64 {
371        self.model.max_token_count()
372    }
373
374    fn max_output_tokens(&self) -> Option<u64> {
375        Some(self.model.max_output_tokens())
376    }
377
378    fn stream_completion(
379        &self,
380        request: LanguageModelRequest,
381        cx: &AsyncApp,
382    ) -> BoxFuture<
383        'static,
384        Result<
385            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
386            LanguageModelCompletionError,
387        >,
388    > {
389        let mut request = into_anthropic(
390            request,
391            self.model.request_id().into(),
392            self.model.default_temperature(),
393            self.model.max_output_tokens(),
394            self.model.mode(),
395        );
396        if !self.model.supports_speed() {
397            request.speed = None;
398        }
399        let request = self.stream_completion(request, cx);
400        let future = self.request_limiter.stream(async move {
401            let response = request.await?;
402            Ok(AnthropicEventMapper::new().map_stream(response))
403        });
404        async move { Ok(future.await?.boxed()) }.boxed()
405    }
406
407    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
408        self.model
409            .cache_configuration()
410            .map(|config| LanguageModelCacheConfiguration {
411                max_cache_anchors: config.max_cache_anchors,
412                should_speculate: config.should_speculate,
413                min_total_token: config.min_total_token,
414            })
415    }
416}
417
418struct ConfigurationView {
419    api_key_editor: Entity<InputField>,
420    state: Entity<State>,
421    load_credentials_task: Option<Task<()>>,
422    target_agent: ConfigurationViewTargetAgent,
423}
424
425impl ConfigurationView {
426    const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
427
428    fn new(
429        state: Entity<State>,
430        target_agent: ConfigurationViewTargetAgent,
431        window: &mut Window,
432        cx: &mut Context<Self>,
433    ) -> Self {
434        cx.observe(&state, |_, _, cx| {
435            cx.notify();
436        })
437        .detach();
438
439        let load_credentials_task = Some(cx.spawn({
440            let state = state.clone();
441            async move |this, cx| {
442                let task = state.update(cx, |state, cx| state.authenticate(cx));
443                // We don't log an error, because "not signed in" is also an error.
444                let _ = task.await;
445                this.update(cx, |this, cx| {
446                    this.load_credentials_task = None;
447                    cx.notify();
448                })
449                .log_err();
450            }
451        }));
452
453        Self {
454            api_key_editor: cx.new(|cx| InputField::new(window, cx, Self::PLACEHOLDER_TEXT)),
455            state,
456            load_credentials_task,
457            target_agent,
458        }
459    }
460
461    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
462        let api_key = self.api_key_editor.read(cx).text(cx);
463        if api_key.is_empty() {
464            return;
465        }
466
467        // url changes can cause the editor to be displayed again
468        self.api_key_editor
469            .update(cx, |editor, cx| editor.set_text("", window, cx));
470
471        let state = self.state.clone();
472        cx.spawn_in(window, async move |_, cx| {
473            state
474                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
475                .await
476        })
477        .detach_and_log_err(cx);
478    }
479
480    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
481        self.api_key_editor
482            .update(cx, |editor, cx| editor.set_text("", window, cx));
483
484        let state = self.state.clone();
485        cx.spawn_in(window, async move |_, cx| {
486            state
487                .update(cx, |state, cx| state.set_api_key(None, cx))
488                .await
489        })
490        .detach_and_log_err(cx);
491    }
492
493    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
494        !self.state.read(cx).is_authenticated()
495    }
496}
497
498impl Render for ConfigurationView {
499    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
500        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
501        let configured_card_label = if env_var_set {
502            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
503        } else {
504            let api_url = AnthropicLanguageModelProvider::api_url(cx);
505            if api_url == ANTHROPIC_API_URL {
506                "API key configured".to_string()
507            } else {
508                format!("API key configured for {}", api_url)
509            }
510        };
511
512        if self.load_credentials_task.is_some() {
513            div()
514                .child(Label::new("Loading credentials..."))
515                .into_any_element()
516        } else if self.should_render_editor(cx) {
517            v_flex()
518                .size_full()
519                .on_action(cx.listener(Self::save_api_key))
520                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
521                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic".into(),
522                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
523                })))
524                .child(
525                    List::new()
526                        .child(
527                            ListBulletItem::new("")
528                                .child(Label::new("Create one by visiting"))
529                                .child(ButtonLink::new("Anthropic's settings", "https://console.anthropic.com/settings/keys"))
530                        )
531                        .child(
532                            ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
533                        )
534                )
535                .child(self.api_key_editor.clone())
536                .child(
537                    Label::new(
538                        format!("You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
539                    )
540                    .size(LabelSize::Small)
541                    .color(Color::Muted)
542                    .mt_0p5(),
543                )
544                .into_any_element()
545        } else {
546            ConfiguredApiCard::new(configured_card_label)
547                .disabled(env_var_set)
548                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
549                .when(env_var_set, |this| {
550                    this.tooltip_label(format!(
551                    "To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."
552                ))
553                })
554                .into_any_element()
555        }
556    }
557}