anthropic.rs

   1use crate::AllLanguageModelSettings;
   2use crate::ui::InstructionListItem;
   3use anthropic::{
   4    AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent,
   5    ToolResultPart, Usage,
   6};
   7use anyhow::{Context as _, Result, anyhow};
   8use collections::{BTreeMap, HashMap};
   9use credentials_provider::CredentialsProvider;
  10use editor::{Editor, EditorElement, EditorStyle};
  11use futures::Stream;
  12use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
  13use gpui::{
  14    AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
  15};
  16use http_client::HttpClient;
  17use language_model::{
  18    AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
  19    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
  20    LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  21    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
  22    LanguageModelToolResultContent, MessageContent, RateLimiter, Role,
  23};
  24use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
  25use schemars::JsonSchema;
  26use serde::{Deserialize, Serialize};
  27use settings::{Settings, SettingsStore};
  28use std::pin::Pin;
  29use std::str::FromStr;
  30use std::sync::Arc;
  31use strum::IntoEnumIterator;
  32use theme::ThemeSettings;
  33use ui::{Icon, IconName, List, Tooltip, prelude::*};
  34use util::ResultExt;
  35
  36const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
  37const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
  38
  39#[derive(Default, Clone, Debug, PartialEq)]
  40pub struct AnthropicSettings {
  41    pub api_url: String,
  42    /// Extend Zed's list of Anthropic models.
  43    pub available_models: Vec<AvailableModel>,
  44}
  45
  46#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
  47pub struct AvailableModel {
  48    /// The model's name in the Anthropic API. e.g. claude-3-5-sonnet-latest, claude-3-opus-20240229, etc
  49    pub name: String,
  50    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
  51    pub display_name: Option<String>,
  52    /// The model's context window size.
  53    pub max_tokens: u64,
  54    /// A model `name` to substitute when calling tools, in case the primary model doesn't support tool calling.
  55    pub tool_override: Option<String>,
  56    /// Configuration of Anthropic's caching API.
  57    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
  58    pub max_output_tokens: Option<u64>,
  59    pub default_temperature: Option<f32>,
  60    #[serde(default)]
  61    pub extra_beta_headers: Vec<String>,
  62    /// The model's mode (e.g. thinking)
  63    pub mode: Option<ModelMode>,
  64}
  65
  66#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
  67#[serde(tag = "type", rename_all = "lowercase")]
  68pub enum ModelMode {
  69    #[default]
  70    Default,
  71    Thinking {
  72        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
  73        budget_tokens: Option<u32>,
  74    },
  75}
  76
  77impl From<ModelMode> for AnthropicModelMode {
  78    fn from(value: ModelMode) -> Self {
  79        match value {
  80            ModelMode::Default => AnthropicModelMode::Default,
  81            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
  82        }
  83    }
  84}
  85
  86impl From<AnthropicModelMode> for ModelMode {
  87    fn from(value: AnthropicModelMode) -> Self {
  88        match value {
  89            AnthropicModelMode::Default => ModelMode::Default,
  90            AnthropicModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
  91        }
  92    }
  93}
  94
  95pub struct AnthropicLanguageModelProvider {
  96    http_client: Arc<dyn HttpClient>,
  97    state: gpui::Entity<State>,
  98}
  99
 100const ANTHROPIC_API_KEY_VAR: &str = "ANTHROPIC_API_KEY";
 101
 102pub struct State {
 103    api_key: Option<String>,
 104    api_key_from_env: bool,
 105    _subscription: Subscription,
 106}
 107
 108impl State {
 109    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 110        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 111        let api_url = AllLanguageModelSettings::get_global(cx)
 112            .anthropic
 113            .api_url
 114            .clone();
 115        cx.spawn(async move |this, cx| {
 116            credentials_provider
 117                .delete_credentials(&api_url, cx)
 118                .await
 119                .ok();
 120            this.update(cx, |this, cx| {
 121                this.api_key = None;
 122                this.api_key_from_env = false;
 123                cx.notify();
 124            })
 125        })
 126    }
 127
 128    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
 129        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 130        let api_url = AllLanguageModelSettings::get_global(cx)
 131            .anthropic
 132            .api_url
 133            .clone();
 134        cx.spawn(async move |this, cx| {
 135            credentials_provider
 136                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
 137                .await
 138                .ok();
 139
 140            this.update(cx, |this, cx| {
 141                this.api_key = Some(api_key);
 142                cx.notify();
 143            })
 144        })
 145    }
 146
 147    fn is_authenticated(&self) -> bool {
 148        self.api_key.is_some()
 149    }
 150
 151    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 152        if self.is_authenticated() {
 153            return Task::ready(Ok(()));
 154        }
 155
 156        let key = AnthropicLanguageModelProvider::api_key(cx);
 157
 158        cx.spawn(async move |this, cx| {
 159            let key = key.await?;
 160
 161            this.update(cx, |this, cx| {
 162                this.api_key = Some(key.key);
 163                this.api_key_from_env = key.from_env;
 164                cx.notify();
 165            })?;
 166
 167            Ok(())
 168        })
 169    }
 170}
 171
 172pub struct ApiKey {
 173    pub key: String,
 174    pub from_env: bool,
 175}
 176
 177impl AnthropicLanguageModelProvider {
 178    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
 179        let state = cx.new(|cx| State {
 180            api_key: None,
 181            api_key_from_env: false,
 182            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 183                cx.notify();
 184            }),
 185        });
 186
 187        Self { http_client, state }
 188    }
 189
 190    fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
 191        Arc::new(AnthropicModel {
 192            id: LanguageModelId::from(model.id().to_string()),
 193            model,
 194            state: self.state.clone(),
 195            http_client: self.http_client.clone(),
 196            request_limiter: RateLimiter::new(4),
 197        })
 198    }
 199
 200    pub fn api_key(cx: &mut App) -> Task<Result<ApiKey, AuthenticateError>> {
 201        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 202        let api_url = AllLanguageModelSettings::get_global(cx)
 203            .anthropic
 204            .api_url
 205            .clone();
 206
 207        if let Ok(key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
 208            Task::ready(Ok(ApiKey {
 209                key,
 210                from_env: true,
 211            }))
 212        } else {
 213            cx.spawn(async move |cx| {
 214                let (_, api_key) = credentials_provider
 215                    .read_credentials(&api_url, cx)
 216                    .await?
 217                    .ok_or(AuthenticateError::CredentialsNotFound)?;
 218
 219                Ok(ApiKey {
 220                    key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
 221                    from_env: false,
 222                })
 223            })
 224        }
 225    }
 226}
 227
 228impl LanguageModelProviderState for AnthropicLanguageModelProvider {
 229    type ObservableEntity = State;
 230
 231    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
 232        Some(self.state.clone())
 233    }
 234}
 235
 236impl LanguageModelProvider for AnthropicLanguageModelProvider {
 237    fn id(&self) -> LanguageModelProviderId {
 238        PROVIDER_ID
 239    }
 240
 241    fn name(&self) -> LanguageModelProviderName {
 242        PROVIDER_NAME
 243    }
 244
 245    fn icon(&self) -> IconName {
 246        IconName::AiAnthropic
 247    }
 248
 249    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 250        Some(self.create_language_model(anthropic::Model::default()))
 251    }
 252
 253    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 254        Some(self.create_language_model(anthropic::Model::default_fast()))
 255    }
 256
 257    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 258        [
 259            anthropic::Model::ClaudeSonnet4,
 260            anthropic::Model::ClaudeSonnet4Thinking,
 261        ]
 262        .into_iter()
 263        .map(|model| self.create_language_model(model))
 264        .collect()
 265    }
 266
 267    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 268        let mut models = BTreeMap::default();
 269
 270        // Add base models from anthropic::Model::iter()
 271        for model in anthropic::Model::iter() {
 272            if !matches!(model, anthropic::Model::Custom { .. }) {
 273                models.insert(model.id().to_string(), model);
 274            }
 275        }
 276
 277        // Override with available models from settings
 278        for model in AllLanguageModelSettings::get_global(cx)
 279            .anthropic
 280            .available_models
 281            .iter()
 282        {
 283            models.insert(
 284                model.name.clone(),
 285                anthropic::Model::Custom {
 286                    name: model.name.clone(),
 287                    display_name: model.display_name.clone(),
 288                    max_tokens: model.max_tokens,
 289                    tool_override: model.tool_override.clone(),
 290                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
 291                        anthropic::AnthropicModelCacheConfiguration {
 292                            max_cache_anchors: config.max_cache_anchors,
 293                            should_speculate: config.should_speculate,
 294                            min_total_token: config.min_total_token,
 295                        }
 296                    }),
 297                    max_output_tokens: model.max_output_tokens,
 298                    default_temperature: model.default_temperature,
 299                    extra_beta_headers: model.extra_beta_headers.clone(),
 300                    mode: model.mode.clone().unwrap_or_default().into(),
 301                },
 302            );
 303        }
 304
 305        models
 306            .into_values()
 307            .map(|model| self.create_language_model(model))
 308            .collect()
 309    }
 310
 311    fn is_authenticated(&self, cx: &App) -> bool {
 312        self.state.read(cx).is_authenticated()
 313    }
 314
 315    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 316        self.state.update(cx, |state, cx| state.authenticate(cx))
 317    }
 318
 319    fn configuration_view(
 320        &self,
 321        target_agent: ConfigurationViewTargetAgent,
 322        window: &mut Window,
 323        cx: &mut App,
 324    ) -> AnyView {
 325        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
 326            .into()
 327    }
 328
 329    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 330        self.state.update(cx, |state, cx| state.reset_api_key(cx))
 331    }
 332}
 333
 334pub struct AnthropicModel {
 335    id: LanguageModelId,
 336    model: anthropic::Model,
 337    state: gpui::Entity<State>,
 338    http_client: Arc<dyn HttpClient>,
 339    request_limiter: RateLimiter,
 340}
 341
 342pub fn count_anthropic_tokens(
 343    request: LanguageModelRequest,
 344    cx: &App,
 345) -> BoxFuture<'static, Result<u64>> {
 346    cx.background_spawn(async move {
 347        let messages = request.messages;
 348        let mut tokens_from_images = 0;
 349        let mut string_messages = Vec::with_capacity(messages.len());
 350
 351        for message in messages {
 352            use language_model::MessageContent;
 353
 354            let mut string_contents = String::new();
 355
 356            for content in message.content {
 357                match content {
 358                    MessageContent::Text(text) => {
 359                        string_contents.push_str(&text);
 360                    }
 361                    MessageContent::Thinking { .. } => {
 362                        // Thinking blocks are not included in the input token count.
 363                    }
 364                    MessageContent::RedactedThinking(_) => {
 365                        // Thinking blocks are not included in the input token count.
 366                    }
 367                    MessageContent::Image(image) => {
 368                        tokens_from_images += image.estimate_tokens();
 369                    }
 370                    MessageContent::ToolUse(_tool_use) => {
 371                        // TODO: Estimate token usage from tool uses.
 372                    }
 373                    MessageContent::ToolResult(tool_result) => match &tool_result.content {
 374                        LanguageModelToolResultContent::Text(text) => {
 375                            string_contents.push_str(text);
 376                        }
 377                        LanguageModelToolResultContent::Image(image) => {
 378                            tokens_from_images += image.estimate_tokens();
 379                        }
 380                    },
 381                }
 382            }
 383
 384            if !string_contents.is_empty() {
 385                string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
 386                    role: match message.role {
 387                        Role::User => "user".into(),
 388                        Role::Assistant => "assistant".into(),
 389                        Role::System => "system".into(),
 390                    },
 391                    content: Some(string_contents),
 392                    name: None,
 393                    function_call: None,
 394                });
 395            }
 396        }
 397
 398        // Tiktoken doesn't yet support these models, so we manually use the
 399        // same tokenizer as GPT-4.
 400        tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
 401            .map(|tokens| (tokens + tokens_from_images) as u64)
 402    })
 403    .boxed()
 404}
 405
 406impl AnthropicModel {
 407    fn stream_completion(
 408        &self,
 409        request: anthropic::Request,
 410        cx: &AsyncApp,
 411    ) -> BoxFuture<
 412        'static,
 413        Result<
 414            BoxStream<'static, Result<anthropic::Event, AnthropicError>>,
 415            LanguageModelCompletionError,
 416        >,
 417    > {
 418        let http_client = self.http_client.clone();
 419
 420        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
 421            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
 422            (state.api_key.clone(), settings.api_url.clone())
 423        }) else {
 424            return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
 425        };
 426
 427        async move {
 428            let Some(api_key) = api_key else {
 429                return Err(LanguageModelCompletionError::NoApiKey {
 430                    provider: PROVIDER_NAME,
 431                });
 432            };
 433            let request =
 434                anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
 435            request.await.map_err(Into::into)
 436        }
 437        .boxed()
 438    }
 439}
 440
 441impl LanguageModel for AnthropicModel {
 442    fn id(&self) -> LanguageModelId {
 443        self.id.clone()
 444    }
 445
 446    fn name(&self) -> LanguageModelName {
 447        LanguageModelName::from(self.model.display_name().to_string())
 448    }
 449
 450    fn provider_id(&self) -> LanguageModelProviderId {
 451        PROVIDER_ID
 452    }
 453
 454    fn provider_name(&self) -> LanguageModelProviderName {
 455        PROVIDER_NAME
 456    }
 457
 458    fn supports_tools(&self) -> bool {
 459        true
 460    }
 461
 462    fn supports_images(&self) -> bool {
 463        true
 464    }
 465
 466    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 467        match choice {
 468            LanguageModelToolChoice::Auto
 469            | LanguageModelToolChoice::Any
 470            | LanguageModelToolChoice::None => true,
 471        }
 472    }
 473
 474    fn telemetry_id(&self) -> String {
 475        format!("anthropic/{}", self.model.id())
 476    }
 477
 478    fn api_key(&self, cx: &App) -> Option<String> {
 479        self.state.read(cx).api_key.clone()
 480    }
 481
 482    fn max_token_count(&self) -> u64 {
 483        self.model.max_token_count()
 484    }
 485
 486    fn max_output_tokens(&self) -> Option<u64> {
 487        Some(self.model.max_output_tokens())
 488    }
 489
 490    fn count_tokens(
 491        &self,
 492        request: LanguageModelRequest,
 493        cx: &App,
 494    ) -> BoxFuture<'static, Result<u64>> {
 495        count_anthropic_tokens(request, cx)
 496    }
 497
 498    fn stream_completion(
 499        &self,
 500        request: LanguageModelRequest,
 501        cx: &AsyncApp,
 502    ) -> BoxFuture<
 503        'static,
 504        Result<
 505            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 506            LanguageModelCompletionError,
 507        >,
 508    > {
 509        let request = into_anthropic(
 510            request,
 511            self.model.request_id().into(),
 512            self.model.default_temperature(),
 513            self.model.max_output_tokens(),
 514            self.model.mode(),
 515        );
 516        let request = self.stream_completion(request, cx);
 517        let future = self.request_limiter.stream(async move {
 518            let response = request.await?;
 519            Ok(AnthropicEventMapper::new().map_stream(response))
 520        });
 521        async move { Ok(future.await?.boxed()) }.boxed()
 522    }
 523
 524    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 525        self.model
 526            .cache_configuration()
 527            .map(|config| LanguageModelCacheConfiguration {
 528                max_cache_anchors: config.max_cache_anchors,
 529                should_speculate: config.should_speculate,
 530                min_total_token: config.min_total_token,
 531            })
 532    }
 533}
 534
 535pub fn into_anthropic(
 536    request: LanguageModelRequest,
 537    model: String,
 538    default_temperature: f32,
 539    max_output_tokens: u64,
 540    mode: AnthropicModelMode,
 541) -> anthropic::Request {
 542    let mut new_messages: Vec<anthropic::Message> = Vec::new();
 543    let mut system_message = String::new();
 544
 545    for message in request.messages {
 546        if message.contents_empty() {
 547            continue;
 548        }
 549
 550        match message.role {
 551            Role::User | Role::Assistant => {
 552                let mut anthropic_message_content: Vec<anthropic::RequestContent> = message
 553                    .content
 554                    .into_iter()
 555                    .filter_map(|content| match content {
 556                        MessageContent::Text(text) => {
 557                            let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
 558                                text.trim_end().to_string()
 559                            } else {
 560                                text
 561                            };
 562                            if !text.is_empty() {
 563                                Some(anthropic::RequestContent::Text {
 564                                    text,
 565                                    cache_control: None,
 566                                })
 567                            } else {
 568                                None
 569                            }
 570                        }
 571                        MessageContent::Thinking {
 572                            text: thinking,
 573                            signature,
 574                        } => {
 575                            if !thinking.is_empty() {
 576                                Some(anthropic::RequestContent::Thinking {
 577                                    thinking,
 578                                    signature: signature.unwrap_or_default(),
 579                                    cache_control: None,
 580                                })
 581                            } else {
 582                                None
 583                            }
 584                        }
 585                        MessageContent::RedactedThinking(data) => {
 586                            if !data.is_empty() {
 587                                Some(anthropic::RequestContent::RedactedThinking { data })
 588                            } else {
 589                                None
 590                            }
 591                        }
 592                        MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
 593                            source: anthropic::ImageSource {
 594                                source_type: "base64".to_string(),
 595                                media_type: "image/png".to_string(),
 596                                data: image.source.to_string(),
 597                            },
 598                            cache_control: None,
 599                        }),
 600                        MessageContent::ToolUse(tool_use) => {
 601                            Some(anthropic::RequestContent::ToolUse {
 602                                id: tool_use.id.to_string(),
 603                                name: tool_use.name.to_string(),
 604                                input: tool_use.input,
 605                                cache_control: None,
 606                            })
 607                        }
 608                        MessageContent::ToolResult(tool_result) => {
 609                            Some(anthropic::RequestContent::ToolResult {
 610                                tool_use_id: tool_result.tool_use_id.to_string(),
 611                                is_error: tool_result.is_error,
 612                                content: match tool_result.content {
 613                                    LanguageModelToolResultContent::Text(text) => {
 614                                        ToolResultContent::Plain(text.to_string())
 615                                    }
 616                                    LanguageModelToolResultContent::Image(image) => {
 617                                        ToolResultContent::Multipart(vec![ToolResultPart::Image {
 618                                            source: anthropic::ImageSource {
 619                                                source_type: "base64".to_string(),
 620                                                media_type: "image/png".to_string(),
 621                                                data: image.source.to_string(),
 622                                            },
 623                                        }])
 624                                    }
 625                                },
 626                                cache_control: None,
 627                            })
 628                        }
 629                    })
 630                    .collect();
 631                let anthropic_role = match message.role {
 632                    Role::User => anthropic::Role::User,
 633                    Role::Assistant => anthropic::Role::Assistant,
 634                    Role::System => unreachable!("System role should never occur here"),
 635                };
 636                if let Some(last_message) = new_messages.last_mut()
 637                    && last_message.role == anthropic_role
 638                {
 639                    last_message.content.extend(anthropic_message_content);
 640                    continue;
 641                }
 642
 643                // Mark the last segment of the message as cached
 644                if message.cache {
 645                    let cache_control_value = Some(anthropic::CacheControl {
 646                        cache_type: anthropic::CacheControlType::Ephemeral,
 647                    });
 648                    for message_content in anthropic_message_content.iter_mut().rev() {
 649                        match message_content {
 650                            anthropic::RequestContent::RedactedThinking { .. } => {
 651                                // Caching is not possible, fallback to next message
 652                            }
 653                            anthropic::RequestContent::Text { cache_control, .. }
 654                            | anthropic::RequestContent::Thinking { cache_control, .. }
 655                            | anthropic::RequestContent::Image { cache_control, .. }
 656                            | anthropic::RequestContent::ToolUse { cache_control, .. }
 657                            | anthropic::RequestContent::ToolResult { cache_control, .. } => {
 658                                *cache_control = cache_control_value;
 659                                break;
 660                            }
 661                        }
 662                    }
 663                }
 664
 665                new_messages.push(anthropic::Message {
 666                    role: anthropic_role,
 667                    content: anthropic_message_content,
 668                });
 669            }
 670            Role::System => {
 671                if !system_message.is_empty() {
 672                    system_message.push_str("\n\n");
 673                }
 674                system_message.push_str(&message.string_contents());
 675            }
 676        }
 677    }
 678
 679    anthropic::Request {
 680        model,
 681        messages: new_messages,
 682        max_tokens: max_output_tokens,
 683        system: if system_message.is_empty() {
 684            None
 685        } else {
 686            Some(anthropic::StringOrContents::String(system_message))
 687        },
 688        thinking: if request.thinking_allowed
 689            && let AnthropicModelMode::Thinking { budget_tokens } = mode
 690        {
 691            Some(anthropic::Thinking::Enabled { budget_tokens })
 692        } else {
 693            None
 694        },
 695        tools: request
 696            .tools
 697            .into_iter()
 698            .map(|tool| anthropic::Tool {
 699                name: tool.name,
 700                description: tool.description,
 701                input_schema: tool.input_schema,
 702            })
 703            .collect(),
 704        tool_choice: request.tool_choice.map(|choice| match choice {
 705            LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
 706            LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
 707            LanguageModelToolChoice::None => anthropic::ToolChoice::None,
 708        }),
 709        metadata: None,
 710        stop_sequences: Vec::new(),
 711        temperature: request.temperature.or(Some(default_temperature)),
 712        top_k: None,
 713        top_p: None,
 714    }
 715}
 716
 717pub struct AnthropicEventMapper {
 718    tool_uses_by_index: HashMap<usize, RawToolUse>,
 719    usage: Usage,
 720    stop_reason: StopReason,
 721}
 722
 723impl AnthropicEventMapper {
 724    pub fn new() -> Self {
 725        Self {
 726            tool_uses_by_index: HashMap::default(),
 727            usage: Usage::default(),
 728            stop_reason: StopReason::EndTurn,
 729        }
 730    }
 731
 732    pub fn map_stream(
 733        mut self,
 734        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
 735    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 736    {
 737        events.flat_map(move |event| {
 738            futures::stream::iter(match event {
 739                Ok(event) => self.map_event(event),
 740                Err(error) => vec![Err(error.into())],
 741            })
 742        })
 743    }
 744
 745    pub fn map_event(
 746        &mut self,
 747        event: Event,
 748    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 749        match event {
 750            Event::ContentBlockStart {
 751                index,
 752                content_block,
 753            } => match content_block {
 754                ResponseContent::Text { text } => {
 755                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
 756                }
 757                ResponseContent::Thinking { thinking } => {
 758                    vec![Ok(LanguageModelCompletionEvent::Thinking {
 759                        text: thinking,
 760                        signature: None,
 761                    })]
 762                }
 763                ResponseContent::RedactedThinking { data } => {
 764                    vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
 765                }
 766                ResponseContent::ToolUse { id, name, .. } => {
 767                    self.tool_uses_by_index.insert(
 768                        index,
 769                        RawToolUse {
 770                            id,
 771                            name,
 772                            input_json: String::new(),
 773                        },
 774                    );
 775                    Vec::new()
 776                }
 777            },
 778            Event::ContentBlockDelta { index, delta } => match delta {
 779                ContentDelta::TextDelta { text } => {
 780                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
 781                }
 782                ContentDelta::ThinkingDelta { thinking } => {
 783                    vec![Ok(LanguageModelCompletionEvent::Thinking {
 784                        text: thinking,
 785                        signature: None,
 786                    })]
 787                }
 788                ContentDelta::SignatureDelta { signature } => {
 789                    vec![Ok(LanguageModelCompletionEvent::Thinking {
 790                        text: "".to_string(),
 791                        signature: Some(signature),
 792                    })]
 793                }
 794                ContentDelta::InputJsonDelta { partial_json } => {
 795                    if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
 796                        tool_use.input_json.push_str(&partial_json);
 797
 798                        // Try to convert invalid (incomplete) JSON into
 799                        // valid JSON that serde can accept, e.g. by closing
 800                        // unclosed delimiters. This way, we can update the
 801                        // UI with whatever has been streamed back so far.
 802                        if let Ok(input) = serde_json::Value::from_str(
 803                            &partial_json_fixer::fix_json(&tool_use.input_json),
 804                        ) {
 805                            return vec![Ok(LanguageModelCompletionEvent::ToolUse(
 806                                LanguageModelToolUse {
 807                                    id: tool_use.id.clone().into(),
 808                                    name: tool_use.name.clone().into(),
 809                                    is_input_complete: false,
 810                                    raw_input: tool_use.input_json.clone(),
 811                                    input,
 812                                },
 813                            ))];
 814                        }
 815                    }
 816                    vec![]
 817                }
 818            },
 819            Event::ContentBlockStop { index } => {
 820                if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
 821                    let input_json = tool_use.input_json.trim();
 822                    let input_value = if input_json.is_empty() {
 823                        Ok(serde_json::Value::Object(serde_json::Map::default()))
 824                    } else {
 825                        serde_json::Value::from_str(input_json)
 826                    };
 827                    let event_result = match input_value {
 828                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
 829                            LanguageModelToolUse {
 830                                id: tool_use.id.into(),
 831                                name: tool_use.name.into(),
 832                                is_input_complete: true,
 833                                input,
 834                                raw_input: tool_use.input_json.clone(),
 835                            },
 836                        )),
 837                        Err(json_parse_err) => {
 838                            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
 839                                id: tool_use.id.into(),
 840                                tool_name: tool_use.name.into(),
 841                                raw_input: input_json.into(),
 842                                json_parse_error: json_parse_err.to_string(),
 843                            })
 844                        }
 845                    };
 846
 847                    vec![event_result]
 848                } else {
 849                    Vec::new()
 850                }
 851            }
 852            Event::MessageStart { message } => {
 853                update_usage(&mut self.usage, &message.usage);
 854                vec![
 855                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
 856                        &self.usage,
 857                    ))),
 858                    Ok(LanguageModelCompletionEvent::StartMessage {
 859                        message_id: message.id,
 860                    }),
 861                ]
 862            }
 863            Event::MessageDelta { delta, usage } => {
 864                update_usage(&mut self.usage, &usage);
 865                if let Some(stop_reason) = delta.stop_reason.as_deref() {
 866                    self.stop_reason = match stop_reason {
 867                        "end_turn" => StopReason::EndTurn,
 868                        "max_tokens" => StopReason::MaxTokens,
 869                        "tool_use" => StopReason::ToolUse,
 870                        "refusal" => StopReason::Refusal,
 871                        _ => {
 872                            log::error!("Unexpected anthropic stop_reason: {stop_reason}");
 873                            StopReason::EndTurn
 874                        }
 875                    };
 876                }
 877                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
 878                    convert_usage(&self.usage),
 879                ))]
 880            }
 881            Event::MessageStop => {
 882                vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
 883            }
 884            Event::Error { error } => {
 885                vec![Err(error.into())]
 886            }
 887            _ => Vec::new(),
 888        }
 889    }
 890}
 891
 892struct RawToolUse {
 893    id: String,
 894    name: String,
 895    input_json: String,
 896}
 897
 898/// Updates usage data by preferring counts from `new`.
 899fn update_usage(usage: &mut Usage, new: &Usage) {
 900    if let Some(input_tokens) = new.input_tokens {
 901        usage.input_tokens = Some(input_tokens);
 902    }
 903    if let Some(output_tokens) = new.output_tokens {
 904        usage.output_tokens = Some(output_tokens);
 905    }
 906    if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
 907        usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
 908    }
 909    if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
 910        usage.cache_read_input_tokens = Some(cache_read_input_tokens);
 911    }
 912}
 913
 914fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
 915    language_model::TokenUsage {
 916        input_tokens: usage.input_tokens.unwrap_or(0),
 917        output_tokens: usage.output_tokens.unwrap_or(0),
 918        cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
 919        cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
 920    }
 921}
 922
 923struct ConfigurationView {
 924    api_key_editor: Entity<Editor>,
 925    state: gpui::Entity<State>,
 926    load_credentials_task: Option<Task<()>>,
 927    target_agent: ConfigurationViewTargetAgent,
 928}
 929
 930impl ConfigurationView {
 931    const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
 932
 933    fn new(
 934        state: gpui::Entity<State>,
 935        target_agent: ConfigurationViewTargetAgent,
 936        window: &mut Window,
 937        cx: &mut Context<Self>,
 938    ) -> Self {
 939        cx.observe(&state, |_, _, cx| {
 940            cx.notify();
 941        })
 942        .detach();
 943
 944        let load_credentials_task = Some(cx.spawn({
 945            let state = state.clone();
 946            async move |this, cx| {
 947                if let Some(task) = state
 948                    .update(cx, |state, cx| state.authenticate(cx))
 949                    .log_err()
 950                {
 951                    // We don't log an error, because "not signed in" is also an error.
 952                    let _ = task.await;
 953                }
 954                this.update(cx, |this, cx| {
 955                    this.load_credentials_task = None;
 956                    cx.notify();
 957                })
 958                .log_err();
 959            }
 960        }));
 961
 962        Self {
 963            api_key_editor: cx.new(|cx| {
 964                let mut editor = Editor::single_line(window, cx);
 965                editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, cx);
 966                editor
 967            }),
 968            state,
 969            load_credentials_task,
 970            target_agent,
 971        }
 972    }
 973
 974    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 975        let api_key = self.api_key_editor.read(cx).text(cx);
 976        if api_key.is_empty() {
 977            return;
 978        }
 979
 980        let state = self.state.clone();
 981        cx.spawn_in(window, async move |_, cx| {
 982            state
 983                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
 984                .await
 985        })
 986        .detach_and_log_err(cx);
 987
 988        cx.notify();
 989    }
 990
 991    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 992        self.api_key_editor
 993            .update(cx, |editor, cx| editor.set_text("", window, cx));
 994
 995        let state = self.state.clone();
 996        cx.spawn_in(window, async move |_, cx| {
 997            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
 998        })
 999        .detach_and_log_err(cx);
1000
1001        cx.notify();
1002    }
1003
1004    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
1005        let settings = ThemeSettings::get_global(cx);
1006        let text_style = TextStyle {
1007            color: cx.theme().colors().text,
1008            font_family: settings.ui_font.family.clone(),
1009            font_features: settings.ui_font.features.clone(),
1010            font_fallbacks: settings.ui_font.fallbacks.clone(),
1011            font_size: rems(0.875).into(),
1012            font_weight: settings.ui_font.weight,
1013            font_style: FontStyle::Normal,
1014            line_height: relative(1.3),
1015            white_space: WhiteSpace::Normal,
1016            ..Default::default()
1017        };
1018        EditorElement::new(
1019            &self.api_key_editor,
1020            EditorStyle {
1021                background: cx.theme().colors().editor_background,
1022                local_player: cx.theme().players().local(),
1023                text: text_style,
1024                ..Default::default()
1025            },
1026        )
1027    }
1028
1029    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
1030        !self.state.read(cx).is_authenticated()
1031    }
1032}
1033
1034impl Render for ConfigurationView {
1035    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1036        let env_var_set = self.state.read(cx).api_key_from_env;
1037
1038        if self.load_credentials_task.is_some() {
1039            div().child(Label::new("Loading credentials...")).into_any()
1040        } else if self.should_render_editor(cx) {
1041            v_flex()
1042                .size_full()
1043                .on_action(cx.listener(Self::save_api_key))
1044                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
1045                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic".into(),
1046                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
1047                })))
1048                .child(
1049                    List::new()
1050                        .child(
1051                            InstructionListItem::new(
1052                                "Create one by visiting",
1053                                Some("Anthropic's settings"),
1054                                Some("https://console.anthropic.com/settings/keys")
1055                            )
1056                        )
1057                        .child(
1058                            InstructionListItem::text_only("Paste your API key below and hit enter to start using the agent")
1059                        )
1060                )
1061                .child(
1062                    h_flex()
1063                        .w_full()
1064                        .my_2()
1065                        .px_2()
1066                        .py_1()
1067                        .bg(cx.theme().colors().editor_background)
1068                        .border_1()
1069                        .border_color(cx.theme().colors().border)
1070                        .rounded_sm()
1071                        .child(self.render_api_key_editor(cx)),
1072                )
1073                .child(
1074                    Label::new(
1075                        format!("You can also assign the {ANTHROPIC_API_KEY_VAR} environment variable and restart Zed."),
1076                    )
1077                    .size(LabelSize::Small)
1078                    .color(Color::Muted),
1079                )
1080                .into_any()
1081        } else {
1082            h_flex()
1083                .mt_1()
1084                .p_1()
1085                .justify_between()
1086                .rounded_md()
1087                .border_1()
1088                .border_color(cx.theme().colors().border)
1089                .bg(cx.theme().colors().background)
1090                .child(
1091                    h_flex()
1092                        .gap_1()
1093                        .child(Icon::new(IconName::Check).color(Color::Success))
1094                        .child(Label::new(if env_var_set {
1095                            format!("API key set in {ANTHROPIC_API_KEY_VAR} environment variable.")
1096                        } else {
1097                            "API key configured.".to_string()
1098                        })),
1099                )
1100                .child(
1101                    Button::new("reset-key", "Reset Key")
1102                        .label_size(LabelSize::Small)
1103                        .icon(Some(IconName::Trash))
1104                        .icon_size(IconSize::Small)
1105                        .icon_position(IconPosition::Start)
1106                        .disabled(env_var_set)
1107                        .when(env_var_set, |this| {
1108                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {ANTHROPIC_API_KEY_VAR} environment variable.")))
1109                        })
1110                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
1111                )
1112                .into_any()
1113        }
1114    }
1115}
1116
#[cfg(test)]
mod tests {
    use super::*;
    use anthropic::AnthropicModelMode;
    use language_model::{LanguageModelRequestMessage, MessageContent};

    /// A cached message made of several content segments must carry the
    /// cache-control marker on its final segment only.
    #[test]
    fn test_cache_control_only_on_last_segment() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("Some prompt".to_string()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                ],
                cache: true,
            }],
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
        };

        let anthropic_request = into_anthropic(
            request,
            "claude-3-5-sonnet".to_string(),
            0.7,
            4096,
            AnthropicModelMode::Default,
        );

        assert_eq!(anthropic_request.messages.len(), 1);

        let message = &anthropic_request.messages[0];
        assert_eq!(message.content.len(), 5);

        // The leading text segment is not the last one, so it stays unmarked.
        assert!(matches!(
            message.content[0],
            anthropic::RequestContent::Text {
                cache_control: None,
                ..
            }
        ));

        // Every image except the last must be unmarked as well. The slice
        // covers indices 1, 2 and 3; index 4 is verified separately below.
        for segment in &message.content[1..4] {
            assert!(matches!(
                segment,
                anthropic::RequestContent::Image {
                    cache_control: None,
                    ..
                }
            ));
        }

        // Only the final segment carries the ephemeral cache marker.
        assert!(matches!(
            message.content[4],
            anthropic::RequestContent::Image {
                cache_control: Some(anthropic::CacheControl {
                    cache_type: anthropic::CacheControlType::Ephemeral,
                }),
                ..
            }
        ));
    }
}