// anthropic.rs

   1use crate::AllLanguageModelSettings;
   2use crate::ui::InstructionListItem;
   3use anthropic::{AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, Usage};
   4use anyhow::{Context as _, Result, anyhow};
   5use collections::{BTreeMap, HashMap};
   6use credentials_provider::CredentialsProvider;
   7use editor::{Editor, EditorElement, EditorStyle};
   8use futures::Stream;
   9use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
  10use gpui::{
  11    AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
  12};
  13use http_client::HttpClient;
  14use language_model::{
  15    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
  16    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
  17    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  18    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, MessageContent,
  19    RateLimiter, Role,
  20};
  21use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
  22use schemars::JsonSchema;
  23use serde::{Deserialize, Serialize};
  24use settings::{Settings, SettingsStore};
  25use std::pin::Pin;
  26use std::str::FromStr;
  27use std::sync::Arc;
  28use strum::IntoEnumIterator;
  29use theme::ThemeSettings;
  30use ui::{Icon, IconName, List, Tooltip, prelude::*};
  31use util::ResultExt;
  32
/// Identifier used to register this provider with the language-model registry.
const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
/// Human-readable provider name shown in the UI and error messages.
const PROVIDER_NAME: &str = "Anthropic";
  35
/// User-configurable settings for the Anthropic provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
    // Base URL used for all Anthropic API requests and as the credentials key.
    pub api_url: String,
    /// Extend Zed's list of Anthropic models.
    pub available_models: Vec<AvailableModel>,
    // Set when settings were written in an older format and need migration.
    pub needs_setting_migration: bool,
}
  43
/// A user-defined model entry that augments or overrides the built-in model list.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model's name in the Anthropic API. e.g. claude-3-5-sonnet-latest, claude-3-opus-20240229, etc
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
    /// A model `name` to substitute when calling tools, in case the primary model doesn't support tool calling.
    pub tool_override: Option<String>,
    /// Configuration of Anthropic's caching API.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    // Maximum number of output tokens to request; falls back to the model default when `None`.
    pub max_output_tokens: Option<u32>,
    // Temperature used when the request does not specify one.
    pub default_temperature: Option<f32>,
    // Additional `anthropic-beta` header values to send with requests.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}
  63
/// Settings-level mirror of `anthropic::AnthropicModelMode`, serialized with a
/// lowercase `type` tag (e.g. `{"type": "thinking"}`).
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  74
  75impl From<ModelMode> for AnthropicModelMode {
  76    fn from(value: ModelMode) -> Self {
  77        match value {
  78            ModelMode::Default => AnthropicModelMode::Default,
  79            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
  80        }
  81    }
  82}
  83
  84impl From<AnthropicModelMode> for ModelMode {
  85    fn from(value: AnthropicModelMode) -> Self {
  86        match value {
  87            AnthropicModelMode::Default => ModelMode::Default,
  88            AnthropicModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
  89        }
  90    }
  91}
  92
/// Language-model provider backed by the Anthropic API.
pub struct AnthropicLanguageModelProvider {
    // Shared HTTP client handed to each created model.
    http_client: Arc<dyn HttpClient>,
    // Authentication state shared by the provider and all of its models.
    state: gpui::Entity<State>,
}
  97
/// Environment variable consulted for an API key before the credentials store.
const ANTHROPIC_API_KEY_VAR: &str = "ANTHROPIC_API_KEY";
  99
/// Holds the provider's authentication state.
pub struct State {
    // The API key currently in use, if any.
    api_key: Option<String>,
    // True when the key came from `ANTHROPIC_API_KEY` rather than stored credentials.
    api_key_from_env: bool,
    // Keeps the settings-store observer alive so observers are re-notified on changes.
    _subscription: Subscription,
}
 105
 106impl State {
 107    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 108        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 109        let api_url = AllLanguageModelSettings::get_global(cx)
 110            .anthropic
 111            .api_url
 112            .clone();
 113        cx.spawn(async move |this, cx| {
 114            credentials_provider
 115                .delete_credentials(&api_url, &cx)
 116                .await
 117                .ok();
 118            this.update(cx, |this, cx| {
 119                this.api_key = None;
 120                this.api_key_from_env = false;
 121                cx.notify();
 122            })
 123        })
 124    }
 125
 126    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
 127        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 128        let api_url = AllLanguageModelSettings::get_global(cx)
 129            .anthropic
 130            .api_url
 131            .clone();
 132        cx.spawn(async move |this, cx| {
 133            credentials_provider
 134                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
 135                .await
 136                .ok();
 137
 138            this.update(cx, |this, cx| {
 139                this.api_key = Some(api_key);
 140                cx.notify();
 141            })
 142        })
 143    }
 144
 145    fn is_authenticated(&self) -> bool {
 146        self.api_key.is_some()
 147    }
 148
 149    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 150        if self.is_authenticated() {
 151            return Task::ready(Ok(()));
 152        }
 153
 154        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 155        let api_url = AllLanguageModelSettings::get_global(cx)
 156            .anthropic
 157            .api_url
 158            .clone();
 159
 160        cx.spawn(async move |this, cx| {
 161            let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
 162                (api_key, true)
 163            } else {
 164                let (_, api_key) = credentials_provider
 165                    .read_credentials(&api_url, &cx)
 166                    .await?
 167                    .ok_or(AuthenticateError::CredentialsNotFound)?;
 168                (
 169                    String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
 170                    false,
 171                )
 172            };
 173
 174            this.update(cx, |this, cx| {
 175                this.api_key = Some(api_key);
 176                this.api_key_from_env = from_env;
 177                cx.notify();
 178            })?;
 179
 180            Ok(())
 181        })
 182    }
 183}
 184
impl AnthropicLanguageModelProvider {
    /// Creates the provider with empty key state; the state notifies its
    /// observers whenever the global settings store changes.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| State {
            api_key: None,
            api_key_from_env: false,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        Self { http_client, state }
    }

    /// Wraps an `anthropic::Model` in an `AnthropicModel` that shares this
    /// provider's HTTP client and authentication state.
    fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
        Arc::new(AnthropicModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            // Allow up to 4 concurrent requests per model.
            request_limiter: RateLimiter::new(4),
        })
    }
}
 208
impl LanguageModelProviderState for AnthropicLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the shared auth state so the UI can observe authentication changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 216
impl LanguageModelProvider for AnthropicLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiAnthropic
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(anthropic::Model::default()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(anthropic::Model::default_fast()))
    }

    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        [
            anthropic::Model::Claude3_7Sonnet,
            anthropic::Model::Claude3_7SonnetThinking,
        ]
        .into_iter()
        .map(|model| self.create_language_model(model))
        .collect()
    }

    /// Returns the built-in models merged with user-configured ones; entries
    /// from settings override built-ins that share the same model name.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        // Keyed by model id so settings entries can replace built-ins, and so
        // the resulting list is ordered deterministically by id.
        let mut models = BTreeMap::default();

        // Add base models from anthropic::Model::iter()
        for model in anthropic::Model::iter() {
            if !matches!(model, anthropic::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .anthropic
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    max_output_tokens: model.max_output_tokens,
                    default_temperature: model.default_temperature,
                    extra_beta_headers: model.extra_beta_headers.clone(),
                    mode: model.mode.clone().unwrap_or_default().into(),
                },
            );
        }

        models
            .into_values()
            .map(|model| self.create_language_model(model))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.reset_api_key(cx))
    }
}
 309
/// A single Anthropic model instance exposed to the rest of the app.
pub struct AnthropicModel {
    id: LanguageModelId,
    model: anthropic::Model,
    // Shared auth state; read at request time for the current API key.
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Caps the number of concurrent streaming requests for this model.
    request_limiter: RateLimiter,
}
 317
 318pub fn count_anthropic_tokens(
 319    request: LanguageModelRequest,
 320    cx: &App,
 321) -> BoxFuture<'static, Result<usize>> {
 322    cx.background_spawn(async move {
 323        let messages = request.messages;
 324        let mut tokens_from_images = 0;
 325        let mut string_messages = Vec::with_capacity(messages.len());
 326
 327        for message in messages {
 328            use language_model::MessageContent;
 329
 330            let mut string_contents = String::new();
 331
 332            for content in message.content {
 333                match content {
 334                    MessageContent::Text(text) => {
 335                        string_contents.push_str(&text);
 336                    }
 337                    MessageContent::Thinking { .. } => {
 338                        // Thinking blocks are not included in the input token count.
 339                    }
 340                    MessageContent::RedactedThinking(_) => {
 341                        // Thinking blocks are not included in the input token count.
 342                    }
 343                    MessageContent::Image(image) => {
 344                        tokens_from_images += image.estimate_tokens();
 345                    }
 346                    MessageContent::ToolUse(_tool_use) => {
 347                        // TODO: Estimate token usage from tool uses.
 348                    }
 349                    MessageContent::ToolResult(tool_result) => {
 350                        string_contents.push_str(&tool_result.content);
 351                    }
 352                }
 353            }
 354
 355            if !string_contents.is_empty() {
 356                string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
 357                    role: match message.role {
 358                        Role::User => "user".into(),
 359                        Role::Assistant => "assistant".into(),
 360                        Role::System => "system".into(),
 361                    },
 362                    content: Some(string_contents),
 363                    name: None,
 364                    function_call: None,
 365                });
 366            }
 367        }
 368
 369        // Tiktoken doesn't yet support these models, so we manually use the
 370        // same tokenizer as GPT-4.
 371        tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
 372            .map(|tokens| tokens + tokens_from_images)
 373    })
 374    .boxed()
 375}
 376
impl AnthropicModel {
    /// Starts a streaming completion against the Anthropic API, resolving the
    /// API key and URL from app state before spawning the request.
    fn stream_completion(
        &self,
        request: anthropic::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event, AnthropicError>>>>
    {
        let http_client = self.http_client.clone();

        // Read key/URL synchronously so the returned future is 'static.
        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
            (state.api_key.clone(), settings.api_url.clone())
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move {
            let api_key = api_key.ok_or_else(|| anyhow!("Missing Anthropic API Key"))?;
            let request =
                anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
            request.await.context("failed to stream completion")
        }
        .boxed()
    }
}
 402
impl LanguageModel for AnthropicModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        true
    }

    // All tool-choice modes are supported by this provider.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("anthropic/{}", self.model.id())
    }

    fn api_key(&self, cx: &App) -> Option<String> {
        self.state.read(cx).api_key.clone()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u32> {
        Some(self.model.max_output_tokens())
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        count_anthropic_tokens(request, cx)
    }

    /// Converts the request into Anthropic's wire format, streams the
    /// completion through the rate limiter, and maps raw API events into
    /// `LanguageModelCompletionEvent`s.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let request = into_anthropic(
            request,
            self.model.request_id().into(),
            self.model.default_temperature(),
            self.model.max_output_tokens(),
            self.model.mode(),
        );
        let request = self.stream_completion(request, cx);
        let future = self.request_limiter.stream(async move {
            // Surface Anthropic-specific errors with their dedicated mapping;
            // everything else is wrapped as a generic anyhow error.
            let response = request
                .await
                .map_err(|err| match err.downcast::<AnthropicError>() {
                    Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err),
                    Err(err) => anyhow!(err),
                })?;
            Ok(AnthropicEventMapper::new().map_stream(response))
        });
        async move { Ok(future.await?.boxed()) }.boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        self.model
            .cache_configuration()
            .map(|config| LanguageModelCacheConfiguration {
                max_cache_anchors: config.max_cache_anchors,
                should_speculate: config.should_speculate,
                min_total_token: config.min_total_token,
            })
    }
}
 496
/// Converts a `LanguageModelRequest` into an `anthropic::Request`.
///
/// System messages are concatenated (separated by blank lines) into the
/// request's `system` field; consecutive user/assistant messages with the same
/// role are merged into a single message; empty content parts are dropped.
/// `mode` controls whether extended thinking is enabled.
pub fn into_anthropic(
    request: LanguageModelRequest,
    model: String,
    default_temperature: f32,
    max_output_tokens: u32,
    mode: AnthropicModelMode,
) -> anthropic::Request {
    let mut new_messages: Vec<anthropic::Message> = Vec::new();
    let mut system_message = String::new();

    for message in request.messages {
        // Skip messages with no usable content entirely.
        if message.contents_empty() {
            continue;
        }

        match message.role {
            Role::User | Role::Assistant => {
                // Cache-control marker applied to each content part of a
                // message flagged for prompt caching.
                let cache_control = if message.cache {
                    Some(anthropic::CacheControl {
                        cache_type: anthropic::CacheControlType::Ephemeral,
                    })
                } else {
                    None
                };
                let anthropic_message_content: Vec<anthropic::RequestContent> = message
                    .content
                    .into_iter()
                    .filter_map(|content| match content {
                        MessageContent::Text(text) => {
                            if !text.is_empty() {
                                Some(anthropic::RequestContent::Text {
                                    text,
                                    cache_control,
                                })
                            } else {
                                None
                            }
                        }
                        MessageContent::Thinking {
                            text: thinking,
                            signature,
                        } => {
                            if !thinking.is_empty() {
                                Some(anthropic::RequestContent::Thinking {
                                    thinking,
                                    signature: signature.unwrap_or_default(),
                                    cache_control,
                                })
                            } else {
                                None
                            }
                        }
                        MessageContent::RedactedThinking(data) => {
                            if !data.is_empty() {
                                // Redacted thinking data must be valid UTF-8;
                                // otherwise the part is silently dropped.
                                Some(anthropic::RequestContent::RedactedThinking {
                                    data: String::from_utf8(data).ok()?,
                                })
                            } else {
                                None
                            }
                        }
                        MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
                            source: anthropic::ImageSource {
                                source_type: "base64".to_string(),
                                media_type: "image/png".to_string(),
                                data: image.source.to_string(),
                            },
                            cache_control,
                        }),
                        MessageContent::ToolUse(tool_use) => {
                            Some(anthropic::RequestContent::ToolUse {
                                id: tool_use.id.to_string(),
                                name: tool_use.name.to_string(),
                                input: tool_use.input,
                                cache_control,
                            })
                        }
                        MessageContent::ToolResult(tool_result) => {
                            Some(anthropic::RequestContent::ToolResult {
                                tool_use_id: tool_result.tool_use_id.to_string(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.to_string(),
                                cache_control,
                            })
                        }
                    })
                    .collect();
                let anthropic_role = match message.role {
                    Role::User => anthropic::Role::User,
                    Role::Assistant => anthropic::Role::Assistant,
                    Role::System => unreachable!("System role should never occur here"),
                };
                // Merge with the previous message when roles match — the
                // Anthropic API requires alternating user/assistant turns.
                if let Some(last_message) = new_messages.last_mut() {
                    if last_message.role == anthropic_role {
                        last_message.content.extend(anthropic_message_content);
                        continue;
                    }
                }
                new_messages.push(anthropic::Message {
                    role: anthropic_role,
                    content: anthropic_message_content,
                });
            }
            Role::System => {
                if !system_message.is_empty() {
                    system_message.push_str("\n\n");
                }
                system_message.push_str(&message.string_contents());
            }
        }
    }

    anthropic::Request {
        model,
        messages: new_messages,
        max_tokens: max_output_tokens,
        system: if system_message.is_empty() {
            None
        } else {
            Some(anthropic::StringOrContents::String(system_message))
        },
        thinking: if let AnthropicModelMode::Thinking { budget_tokens } = mode {
            Some(anthropic::Thinking::Enabled { budget_tokens })
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| anthropic::Tool {
                name: tool.name,
                description: tool.description,
                input_schema: tool.input_schema,
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
            LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
            LanguageModelToolChoice::None => anthropic::ToolChoice::None,
        }),
        metadata: None,
        stop_sequences: Vec::new(),
        // Explicit request temperature wins over the model default.
        temperature: request.temperature.or(Some(default_temperature)),
        top_k: None,
        top_p: None,
    }
}
 644
/// Stateful translator from raw Anthropic streaming events to
/// `LanguageModelCompletionEvent`s.
pub struct AnthropicEventMapper {
    // In-flight tool uses keyed by their content-block index, accumulating
    // partial JSON input until the block stops.
    tool_uses_by_index: HashMap<usize, RawToolUse>,
    // Running usage totals, updated as usage deltas arrive.
    usage: Usage,
    // Most recent stop reason; emitted when the message stops.
    stop_reason: StopReason,
}
 650
impl AnthropicEventMapper {
    /// Creates a mapper with no in-flight tool uses, zeroed usage, and a
    /// default stop reason of `EndTurn`.
    pub fn new() -> Self {
        Self {
            tool_uses_by_index: HashMap::default(),
            usage: Usage::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Maps a stream of raw API events into completion events, threading the
    /// mapper's accumulated state through each event in order.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
            })
        })
    }

    /// Translates one API event into zero or more completion events, updating
    /// usage totals, the stop reason, and in-flight tool uses as side effects.
    pub fn map_event(
        &mut self,
        event: Event,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        match event {
            Event::ContentBlockStart {
                index,
                content_block,
            } => match content_block {
                ResponseContent::Text { text } => {
                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
                }
                ResponseContent::Thinking { thinking } => {
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: thinking,
                        signature: None,
                    })]
                }
                ResponseContent::RedactedThinking { .. } => {
                    // Redacted thinking is encrypted and not accessible to the user, see:
                    // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#suggestions-for-handling-redacted-thinking-in-production
                    Vec::new()
                }
                ResponseContent::ToolUse { id, name, .. } => {
                    // Begin accumulating this tool use's JSON input; events are
                    // emitted later as input deltas arrive.
                    self.tool_uses_by_index.insert(
                        index,
                        RawToolUse {
                            id,
                            name,
                            input_json: String::new(),
                        },
                    );
                    Vec::new()
                }
            },
            Event::ContentBlockDelta { index, delta } => match delta {
                ContentDelta::TextDelta { text } => {
                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
                }
                ContentDelta::ThinkingDelta { thinking } => {
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: thinking,
                        signature: None,
                    })]
                }
                ContentDelta::SignatureDelta { signature } => {
                    // Signature arrives separately; emit it with empty text.
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: "".to_string(),
                        signature: Some(signature),
                    })]
                }
                ContentDelta::InputJsonDelta { partial_json } => {
                    if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
                        tool_use.input_json.push_str(&partial_json);

                        // Try to convert invalid (incomplete) JSON into
                        // valid JSON that serde can accept, e.g. by closing
                        // unclosed delimiters. This way, we can update the
                        // UI with whatever has been streamed back so far.
                        if let Ok(input) = serde_json::Value::from_str(
                            &partial_json_fixer::fix_json(&tool_use.input_json),
                        ) {
                            return vec![Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id: tool_use.id.clone().into(),
                                    name: tool_use.name.clone().into(),
                                    is_input_complete: false,
                                    raw_input: tool_use.input_json.clone(),
                                    input,
                                },
                            ))];
                        }
                    }
                    return vec![];
                }
            },
            Event::ContentBlockStop { index } => {
                // Finalize any in-flight tool use for this block: parse the
                // accumulated JSON and emit a completed tool-use event.
                if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
                    let input_json = tool_use.input_json.trim();
                    // An empty input is treated as an empty JSON object.
                    let input_value = if input_json.is_empty() {
                        Ok(serde_json::Value::Object(serde_json::Map::default()))
                    } else {
                        serde_json::Value::from_str(input_json)
                    };
                    let event_result = match input_value {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_use.id.into(),
                                name: tool_use.name.into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_use.input_json.clone(),
                            },
                        )),
                        Err(json_parse_err) => Err(LanguageModelCompletionError::BadInputJson {
                            id: tool_use.id.into(),
                            tool_name: tool_use.name.into(),
                            raw_input: input_json.into(),
                            json_parse_error: json_parse_err.to_string(),
                        }),
                    };

                    vec![event_result]
                } else {
                    Vec::new()
                }
            }
            Event::MessageStart { message } => {
                update_usage(&mut self.usage, &message.usage);
                vec![
                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
                        &self.usage,
                    ))),
                    Ok(LanguageModelCompletionEvent::StartMessage {
                        message_id: message.id,
                    }),
                ]
            }
            Event::MessageDelta { delta, usage } => {
                update_usage(&mut self.usage, &usage);
                // Remember the stop reason so it can be emitted at MessageStop.
                if let Some(stop_reason) = delta.stop_reason.as_deref() {
                    self.stop_reason = match stop_reason {
                        "end_turn" => StopReason::EndTurn,
                        "max_tokens" => StopReason::MaxTokens,
                        "tool_use" => StopReason::ToolUse,
                        _ => {
                            log::error!("Unexpected anthropic stop_reason: {stop_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
                    convert_usage(&self.usage),
                ))]
            }
            Event::MessageStop => {
                vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
            }
            Event::Error { error } => {
                vec![Err(LanguageModelCompletionError::Other(anyhow!(
                    AnthropicError::ApiError(error)
                )))]
            }
            // Other events (e.g. pings) carry no information for callers.
            _ => Vec::new(),
        }
    }
}
 820
/// Accumulator for a streamed tool-use content block whose JSON input
/// arrives incrementally; the buffered input is parsed when the block stops.
struct RawToolUse {
    // Tool-use ID assigned by the Anthropic API.
    id: String,
    // Name of the tool being invoked.
    name: String,
    // Raw JSON input text accumulated so far (may be incomplete mid-stream).
    input_json: String,
}
 826
 827pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error {
 828    if let AnthropicError::ApiError(api_err) = &err {
 829        if let Some(tokens) = api_err.match_window_exceeded() {
 830            return anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens });
 831        }
 832    }
 833
 834    anyhow!(err)
 835}
 836
 837/// Updates usage data by preferring counts from `new`.
 838fn update_usage(usage: &mut Usage, new: &Usage) {
 839    if let Some(input_tokens) = new.input_tokens {
 840        usage.input_tokens = Some(input_tokens);
 841    }
 842    if let Some(output_tokens) = new.output_tokens {
 843        usage.output_tokens = Some(output_tokens);
 844    }
 845    if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
 846        usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
 847    }
 848    if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
 849        usage.cache_read_input_tokens = Some(cache_read_input_tokens);
 850    }
 851}
 852
 853fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
 854    language_model::TokenUsage {
 855        input_tokens: usage.input_tokens.unwrap_or(0),
 856        output_tokens: usage.output_tokens.unwrap_or(0),
 857        cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
 858        cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
 859    }
 860}
 861
/// Settings UI for entering, saving, and resetting the Anthropic API key.
struct ConfigurationView {
    // Single-line editor the user types the API key into.
    api_key_editor: Entity<Editor>,
    // Provider state holding credentials and authentication status.
    state: gpui::Entity<State>,
    // Present while stored credentials are being loaded; cleared once done.
    load_credentials_task: Option<Task<()>>,
}
 867
 868impl ConfigurationView {
 869    const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
 870
 871    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
 872        cx.observe(&state, |_, _, cx| {
 873            cx.notify();
 874        })
 875        .detach();
 876
 877        let load_credentials_task = Some(cx.spawn({
 878            let state = state.clone();
 879            async move |this, cx| {
 880                if let Some(task) = state
 881                    .update(cx, |state, cx| state.authenticate(cx))
 882                    .log_err()
 883                {
 884                    // We don't log an error, because "not signed in" is also an error.
 885                    let _ = task.await;
 886                }
 887                this.update(cx, |this, cx| {
 888                    this.load_credentials_task = None;
 889                    cx.notify();
 890                })
 891                .log_err();
 892            }
 893        }));
 894
 895        Self {
 896            api_key_editor: cx.new(|cx| {
 897                let mut editor = Editor::single_line(window, cx);
 898                editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, cx);
 899                editor
 900            }),
 901            state,
 902            load_credentials_task,
 903        }
 904    }
 905
 906    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 907        let api_key = self.api_key_editor.read(cx).text(cx);
 908        if api_key.is_empty() {
 909            return;
 910        }
 911
 912        let state = self.state.clone();
 913        cx.spawn_in(window, async move |_, cx| {
 914            state
 915                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
 916                .await
 917        })
 918        .detach_and_log_err(cx);
 919
 920        cx.notify();
 921    }
 922
 923    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 924        self.api_key_editor
 925            .update(cx, |editor, cx| editor.set_text("", window, cx));
 926
 927        let state = self.state.clone();
 928        cx.spawn_in(window, async move |_, cx| {
 929            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
 930        })
 931        .detach_and_log_err(cx);
 932
 933        cx.notify();
 934    }
 935
 936    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
 937        let settings = ThemeSettings::get_global(cx);
 938        let text_style = TextStyle {
 939            color: cx.theme().colors().text,
 940            font_family: settings.ui_font.family.clone(),
 941            font_features: settings.ui_font.features.clone(),
 942            font_fallbacks: settings.ui_font.fallbacks.clone(),
 943            font_size: rems(0.875).into(),
 944            font_weight: settings.ui_font.weight,
 945            font_style: FontStyle::Normal,
 946            line_height: relative(1.3),
 947            white_space: WhiteSpace::Normal,
 948            ..Default::default()
 949        };
 950        EditorElement::new(
 951            &self.api_key_editor,
 952            EditorStyle {
 953                background: cx.theme().colors().editor_background,
 954                local_player: cx.theme().players().local(),
 955                text: text_style,
 956                ..Default::default()
 957            },
 958        )
 959    }
 960
 961    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
 962        !self.state.read(cx).is_authenticated()
 963    }
 964}
 965
impl Render for ConfigurationView {
    // Renders one of three states: credentials still loading, API key needed
    // (instructions + input), or key already configured (status + reset).
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // True when the key was supplied via the environment variable, in
        // which case it cannot be reset from this view.
        let env_var_set = self.state.read(cx).api_key_from_env;

        if self.load_credentials_task.is_some() {
            // Stored credentials are still being loaded; show a placeholder.
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            // Not authenticated: show setup instructions and the key editor.
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's assistant with Anthropic, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(
                            InstructionListItem::new(
                                "Create one by visiting",
                                Some("Anthropic's settings"),
                                Some("https://console.anthropic.com/settings/keys")
                            )
                        )
                        .child(
                            InstructionListItem::text_only("Paste your API key below and hit enter to start using the assistant")
                        )
                )
                .child(
                    h_flex()
                        .w_full()
                        .my_2()
                        .px_2()
                        .py_1()
                        .bg(cx.theme().colors().editor_background)
                        .border_1()
                        .border_color(cx.theme().colors().border)
                        .rounded_sm()
                        .child(self.render_api_key_editor(cx)),
                )
                .child(
                    Label::new(
                        format!("You can also assign the {ANTHROPIC_API_KEY_VAR} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .into_any()
        } else {
            // Authenticated: show where the key came from and a reset button.
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {ANTHROPIC_API_KEY_VAR} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    // Reset is disabled for env-var keys; only the stored
                    // credential can be cleared from here.
                    Button::new("reset-key", "Reset Key")
                        .label_size(LabelSize::Small)
                        .icon(Some(IconName::Trash))
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .disabled(env_var_set)
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {ANTHROPIC_API_KEY_VAR} environment variable.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )
                .into_any()
        }
    }
}